Saturday, February 6, 2021

Folds are constructor substitution

I notice that functional programming beginners and experts understand the word “fold” to mean subtly different things, so I’d like to explain what experienced functional programmers usually mean when they use the term “fold”. This post assumes a passing familiarity with Haskell.

Overview

A “fold” is a function that replaces all constructors of a datatype with corresponding expressions. Folds are not limited to lists, linear sequences, or even containers; you can fold any inductively defined datatype.

To explain the more general notion of a “fold”, we’ll consider three representative data structures:

  • lists
  • Maybe values
  • binary trees

… and show how we can automatically derive the “one true fold” for each data structure by following the same general principle.

Lists

Many beginners understand the word “fold” to be a way to reduce some collection of values (e.g. a list) to a single value. For example, in Haskell you can add up the elements of a list like this:

sum :: [Int] -> Int
sum xs = foldr (+) 0 xs

… where sum reduces a sequence of Ints to a single Int by starting from an initial accumulator value of 0 and then “folding” each element of the list into the accumulator using (+).

Haskell’s standard library provides at least two fold functions named foldl and foldr, but only foldr is the “canonical” fold for a list. By “canonical” I mean that foldr is the only fold that works by substituting list constructors.
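
One way to check this on ordinary Haskell lists is to remember that [ 1, 2, 3 ] is syntactic sugar for 1 : (2 : (3 : [])). The sketch below (xs, total, rebuilt, leftNested, and rightNested are names I made up for illustration) shows foldr swapping out (:) and [], and contrasts it with foldl, which nests its calls to the left starting from the accumulator, so its result does not line up one-for-one with the list's constructors:

```haskell
-- [1, 2, 3] is sugar for 1 : (2 : (3 : []))
xs :: [Int]
xs = [1, 2, 3]

-- foldr replaces (:) with (+) and [] with 0, giving 1 + (2 + (3 + 0))
total :: Int
total = foldr (+) 0 xs

-- Replacing each constructor with itself gives back the original list
rebuilt :: [Int]
rebuilt = foldr (:) [] xs

-- foldl nests to the left, starting from the accumulator: ((0 - 1) - 2) - 3
leftNested :: Int
leftNested = foldl (-) 0 xs

-- foldr nests to the right, mirroring the list's shape: 1 - (2 - (3 - 0))
rightNested :: Int
rightNested = foldr (-) 0 xs
```

Loaded into GHCi, total is 6 and rebuilt is [1,2,3], while leftNested is -6 and rightNested is 2, which makes the different nesting visible.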

We can more easily see this if we define our own linked list type with explicitly named constructors:

data List a = Cons a (List a) | Nil

… where instead of writing a list as [ 1, 2, 3 ] we will write such a list as:

example :: List Int
example = Cons 1 (Cons 2 (Cons 3 Nil))

This is a very unergonomic representation for a list, but bear with me!

We can implement the “canonical” fold for the above List type as a function that takes two arguments:

  • The first argument (named cons) replaces all occurrences of the Cons constructor
  • The second argument (named nil) replaces all occurrences of the Nil constructor

The implementation of the canonical fold looks like this:

fold :: (a -> list -> list) -> list -> List a -> list
fold cons nil (Cons x xs) = cons x (fold cons nil xs)
fold cons nil  Nil        = nil

If you don’t follow how that implementation works, a more direct way to appreciate fold is to see how it behaves on some sample inputs:

-- The general case, step-by-step
fold cons nil (Cons x (Cons y (Cons z Nil)))
    = cons x (fold cons nil (Cons y (Cons z Nil)))
    = cons x (cons y (fold cons nil (Cons z Nil)))
    = cons x (cons y (cons z (fold cons nil Nil)))
    = cons x (cons y (cons z nil))

-- Add up the elements of the list, but skipping more steps this time
fold (+) 0 (Cons 1 (Cons 2 (Cons 3 Nil)))
  = (+) 1 ((+) 2 ((+) 3 0))
  = 1 + (2 + (3 + 0))
  = 6

-- Calculate the list length
fold (\_ n -> n + 1) 0 (Cons True (Cons False (Cons True Nil)))
  = (\_ n -> n + 1) True ((\_ n -> n + 1) False ((\_ n -> n + 1) True 0))
  = (\_ n -> n + 1) True ((\_ n -> n + 1) False 1)
  = (\_ n -> n + 1) True 2
  = 3
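
If you’d like to replay the traces above, here is a minimal self-contained sketch that collects the List type and fold into one module you can load into GHCi. The names example, sumList, and lengthList are hypothetical helpers; each is just one of the fold calls from the traces:

```haskell
-- The List type from above, with explicitly named constructors
data List a = Cons a (List a) | Nil
  deriving (Show)

-- The canonical fold: `cons` replaces every Cons, `nil` replaces the Nil
fold :: (a -> list -> list) -> list -> List a -> list
fold cons nil (Cons x xs) = cons x (fold cons nil xs)
fold _    nil  Nil        = nil

example :: List Int
example = Cons 1 (Cons 2 (Cons 3 Nil))

-- Replace Cons with (+) and Nil with 0
sumList :: List Int -> Int
sumList = fold (+) 0

-- Ignore each element and add 1 to the count for the rest of the list
lengthList :: List a -> Int
lengthList = fold (\_ n -> n + 1) 0
```

Evaluating sumList example gives 6, and lengthList (Cons True (Cons False (Cons True Nil))) gives 3, matching the step-by-step traces.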

Notice that if we format the type of fold a bit, we can see that the type of each argument to fold (sort of) matches the type of the constructor it replaces:

fold
    :: (a -> list -> list)  -- Cons :: a -> List a -> List a
    -> list                 -- Nil  :: List a
    -> List a
    -> list

In the above type, list is actually a type variable and we could have used any name for that type variable instead of list, such as b. In fact, if we were to replace list with b, we would get essentially the same type as foldr for Haskell lists:

-- Our `fold` type, replacing `list` with `b`
fold
    :: (a -> b -> b)
    -> b
    -> List a
    -> b

-- Now compare that type to the `foldr` type from the Prelude, specialized to lists:
foldr
    :: (a -> b -> b)
    -> b
    -> [a]
    -> b

We commonly use folds to reduce a List to a single scalar value, but folds are actually much more general-purpose than that and they can be used to transform one data structure into another data structure. For example, we can use the same fold function to convert our clumsy List type into the standard Haskell list type, like this:

fold (:) [] (Cons 1 (Cons 2 (Cons 3 Nil)))
    = (:) 1 ((:) 2 ((:) 3 []))
    = 1 : (2 : (3 : []))
    = [ 1, 2, 3 ]
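
The conversion also works in the other direction: since foldr is the canonical fold for Haskell lists, replacing (:) with Cons and [] with Nil gets us back to our List type. A sketch assuming the List type and fold from above, where toHaskellList and fromHaskellList are hypothetical names:

```haskell
data List a = Cons a (List a) | Nil
  deriving (Show, Eq)

fold :: (a -> list -> list) -> list -> List a -> list
fold cons nil (Cons x xs) = cons x (fold cons nil xs)
fold _    nil  Nil        = nil

-- Replace Cons with (:) and Nil with []
toHaskellList :: List a -> [a]
toHaskellList = fold (:) []

-- Replace (:) with Cons and [] with Nil, using foldr on the Haskell list
fromHaskellList :: [a] -> List a
fromHaskellList = foldr Cons Nil
```

Each direction only swaps one set of constructors for the other, so fromHaskellList (toHaskellList xs) returns the original xs.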

Maybe

Folds are not limited to recursive data types. For example, here is the canonical fold for Haskell’s Maybe type, which is not recursive:

data Maybe a = Nothing | Just a

fold :: maybe -> (a -> maybe) -> Maybe a -> maybe
fold nothing just  Nothing  = nothing
fold nothing just (Just x ) = just x

In fact, this function already exists in Haskell’s standard library by the name of maybe:

maybe :: b -> (a -> b) -> Maybe a -> b
maybe n _ Nothing  = n
maybe _ f (Just x) = f x
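
As a quick sanity check, here is the same fold written against the Prelude’s own Maybe type (rather than the redefinition above, to avoid clashing with the Prelude). foldMaybe and describePort are hypothetical names used only for illustration:

```haskell
-- Canonical fold for the Prelude's Maybe:
-- `nothing` replaces Nothing and `just` replaces Just
foldMaybe :: maybe -> (a -> maybe) -> Maybe a -> maybe
foldMaybe nothing _    Nothing  = nothing
foldMaybe _       just (Just x) = just x

-- Hypothetical helper: describe a possibly-missing port number
describePort :: Maybe Int -> String
describePort = foldMaybe "no port" (\p -> "port " ++ show p)
```

For example, foldMaybe 0 (+ 1) (Just 41) and maybe 0 (+ 1) (Just 41) both return 42, and describePort Nothing returns "no port".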

Once you think of folds in terms of constructor substitution you can quickly spot these canonical folds for other types.
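
For example, Either has two constructors, Left :: a -> Either a b and Right :: b -> Either a b, so following the same recipe gives a fold that takes one function per constructor. This sketch uses the hypothetical name foldEither; its type lines up with the Prelude’s either function, in the same way that our Maybe fold lines up with maybe:

```haskell
-- Canonical fold for Either:
-- `left` replaces every Left and `right` replaces every Right
foldEither :: (a -> r) -> (b -> r) -> Either a b -> r
foldEither left _     (Left x)  = left x
foldEither _    right (Right y) = right y
```

Passing Left and Right themselves as the two arguments gives back the input unchanged, just like the other folds in this post.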

Binary trees

What about more complex data structures, like the following binary Tree type?

data Tree a = Node a (Tree a) (Tree a) | Leaf

The canonical fold for this type is still straightforward to write, by applying the same principle of constructor substitution:

fold :: (a -> tree -> tree -> tree) -> tree -> Tree a -> tree
fold node leaf (Node x l r) = node x (fold node leaf l) (fold node leaf r)
fold node leaf  Leaf        = leaf

We only need to keep recursively descending over the Tree, replacing constructors as we go.
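
Here is a minimal runnable version of the Tree fold, along with two more reductions that follow the same pattern: size counts the Node constructors, and depth measures the longest path of Nodes from the root down to a Leaf. The names size, depth, and sample are hypothetical, introduced for illustration:

```haskell
data Tree a = Node a (Tree a) (Tree a) | Leaf
  deriving (Show)

-- `node` replaces every Node and `leaf` replaces every Leaf
fold :: (a -> tree -> tree -> tree) -> tree -> Tree a -> tree
fold node leaf (Node x l r) = node x (fold node leaf l) (fold node leaf r)
fold _    leaf  Leaf        = leaf

-- Each Node contributes 1 plus the counts from its two subtrees
size :: Tree a -> Int
size = fold (\_ l r -> 1 + l + r) 0

-- Each Node adds one level on top of its deeper subtree
depth :: Tree a -> Int
depth = fold (\_ l r -> 1 + max l r) 0

sample :: Tree Int
sample = Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf)
```

For sample, size returns 3 and depth returns 2.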

We can use this fold to reduce the Tree to a single value, like this:

-- Add up all the nodes in the tree
fold (\x l r -> x + l + r) 0 (Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf))
    = (\x l r -> x + l + r) 1
        ((\x l r -> x + l + r) 2 0 0)
        ((\x l r -> x + l + r) 3 0 0)
    = (\x l r -> x + l + r) 1
        (2 + 0 + 0)
        (3 + 0 + 0)
    = (\x l r -> x + l + r) 1
        2
        3
    = 1 + 2 + 3
    = 6

… or we can use the same fold function to transform the Tree into another data structure, like a list:

-- List `Tree` elements in pre-order
fold (\x l r -> x : l ++ r) [] (Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf))
    = (\x l r -> x : l ++ r) 1
        ((\x l r -> x : l ++ r) 2 [] [])
        ((\x l r -> x : l ++ r) 3 [] [])
    = (\x l r -> x : l ++ r) 1
        (2 : [] ++ [])
        (3 : [] ++ [])
    = (\x l r -> x : l ++ r) 1
        [2]
        [3]
    = 1 : [2] ++ [3]
    = [1, 2, 3]
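
Moving x relative to l and r changes the traversal order. A sketch reusing the same Tree and fold, with hypothetical names preOrder, inOrder, and postOrder:

```haskell
data Tree a = Node a (Tree a) (Tree a) | Leaf

fold :: (a -> tree -> tree -> tree) -> tree -> Tree a -> tree
fold node leaf (Node x l r) = node x (fold node leaf l) (fold node leaf r)
fold _    leaf  Leaf        = leaf

-- Element first, then the left subtree, then the right subtree
preOrder :: Tree a -> [a]
preOrder = fold (\x l r -> x : l ++ r) []

-- Left subtree, then the element, then the right subtree
inOrder :: Tree a -> [a]
inOrder = fold (\x l r -> l ++ [x] ++ r) []

-- Both subtrees first, then the element
postOrder :: Tree a -> [a]
postOrder = fold (\x l r -> l ++ r ++ [x]) []

sample :: Tree Int
sample = Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf)
```

For sample these return [1,2,3], [2,1,3], and [2,3,1]. Repeated (++) can take quadratic time on a badly unbalanced tree, so these are correct but not necessarily the most efficient traversals.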

… or even use the fold to mirror the tree, swapping every left and right subtree:

fold (\x l r -> Node x r l) Leaf (Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf))
    = (\x l r -> Node x r l) 1
        ((\x l r -> Node x r l) 2 Leaf Leaf)
        ((\x l r -> Node x r l) 3 Leaf Leaf)
    = (\x l r -> Node x r l) 1
        (Node 2 Leaf Leaf)
        (Node 3 Leaf Leaf)
    = Node 1 (Node 3 Leaf Leaf) (Node 2 Leaf Leaf)
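
Since mirroring swaps every pair of subtrees, mirroring twice should give back the original tree. A small sketch that checks this, where mirror and bigger are hypothetical names:

```haskell
data Tree a = Node a (Tree a) (Tree a) | Leaf
  deriving (Show, Eq)

fold :: (a -> tree -> tree -> tree) -> tree -> Tree a -> tree
fold node leaf (Node x l r) = node x (fold node leaf l) (fold node leaf r)
fold _    leaf  Leaf        = leaf

-- Replace each Node with a Node whose subtrees are swapped
mirror :: Tree a -> Tree a
mirror = fold (\x l r -> Node x r l) Leaf

sample :: Tree Int
sample = Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf)

-- A less symmetric tree, so that swapping subtrees is easy to see
bigger :: Tree Char
bigger = Node 'a' (Node 'b' (Node 'd' Leaf Leaf) Leaf) (Node 'c' Leaf (Node 'e' Leaf Leaf))
```

Here mirror sample evaluates to Node 1 (Node 3 Leaf Leaf) (Node 2 Leaf Leaf), as in the trace, and mirror (mirror bigger) == bigger holds.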

Generality

At this point you might be wondering: “what can’t a fold do?” The answer is that you can do essentially anything with a fold, although it might not be the most efficient solution to your problem. You can think of a fold as the most general-purpose interface for consuming a data structure, because the fold interface is a “lossless” way to process a data structure.

To see why a fold is a “lossless” interface, let’s revisit the fold function for Trees, and this time pass in the Node and Leaf constructors themselves as the arguments to the fold. In other words, we will replace all occurrences of Node with Node and all occurrences of Leaf with Leaf:

fold Node Leaf (Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf))
    = Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf)

This gives us back the original data structure, demonstrating how we always have the option for a fold to recover the original pristine input. This is what I mean when I say that a fold is a lossless interface.
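
Here is a sketch of the same idea in runnable form (rebuild and mapTree are hypothetical names). rebuild passes the constructors through unchanged, while mapTree swaps in a Node that transforms each element but keeps the shape of the tree, showing that the fold sees enough of the structure to reconstruct all of it:

```haskell
data Tree a = Node a (Tree a) (Tree a) | Leaf
  deriving (Show, Eq)

fold :: (a -> tree -> tree -> tree) -> tree -> Tree a -> tree
fold node leaf (Node x l r) = node x (fold node leaf l) (fold node leaf r)
fold _    leaf  Leaf        = leaf

-- Replace each constructor with itself: the result is the input
rebuild :: Tree a -> Tree a
rebuild = fold Node Leaf

-- Replace each Node with a Node holding a transformed element;
-- every Leaf stays a Leaf, so the tree keeps its shape
mapTree :: (a -> b) -> Tree a -> Tree b
mapTree f = fold (\x l r -> Node (f x) l r) Leaf

sample :: Tree Int
sample = Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf)
```

For example, rebuild sample == sample, and mapTree (* 10) sample evaluates to Node 10 (Node 20 Leaf Leaf) (Node 30 Leaf Leaf).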