Wednesday, December 25, 2013

Equational reasoning

You may have heard that Haskell is "great for equational reasoning", but perhaps you didn't know precisely what that meant. This post will walk through an intermediate example of equational reasoning to show how you can interpret Haskell code by hand by applying substitutions as if your code were a mathematical expression. This process of interpreting code by substituting equals-for-equals is known as equational reasoning.


Equality

In Haskell, when you see a single equals sign it means that the left-hand and right-hand sides are substitutable for each other both ways. For example, if you see:

x = 5

That means that wherever you see an x in your program, you can substitute it with a 5, and, vice versa, anywhere you see a 5 in your program you can substitute it with an x. Substitutions like these preserve the behavior of your program.

In fact, you can "run" your program just by repeatedly applying these substitutions until your program is just one giant main. To illustrate this, we will begin with the following program, which should print the number 1 three times:

import Control.Monad

main = replicateM_ 3 (print 1)

replicateM_ is a function that repeats an action a specified number of times. Its type when specialized to IO is:

replicateM_ :: Int -> IO () -> IO ()

The first argument is an Int specifying how many times to repeat the action and the second argument is the action we wish to repeat. In the above example, we specify that we wish to repeat the print 1 action three times.

But what if you don't believe me? What if you wanted to prove to yourself that it repeated the action three times? How would you prove that?


Use the source!

You can locate the source to Haskell functions using one of three tricks:

  • Use Hoogle, which can also search for functions by type signature

  • Use Hayoo!, which is like Hoogle, but searches a larger package database and is stricter about matches

  • Use Google and search for "hackage <function>". This also works well for searching for packages.

Using any of those three tricks we would locate the documentation for replicateM_ on Hackage, and then we can click the Source link to the right of it to view its definition, which I reproduce here:

replicateM_ n x = sequence_ (replicate n x)

The equals sign means that any time we see something of the form replicateM_ n x, we can substitute it with sequence_ (replicate n x), for any choice of n or x. For example, if we choose the following values for n and x:

n = 3

x = print 1

... then we obtain the following more specific equation:

replicateM_ 3 (print 1) = sequence_ (replicate 3 (print 1))

We will use this equation to replace our program's replicateM_ command with an equal program built from sequence_ and replicate:

main = sequence_ (replicate 3 (print 1))

The equation tells us that this substitution is safe and preserves the original behavior of our program.
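You can also spot-check this equation by running both sides. IO actions cannot be compared with (==), so this sketch swaps IO for the pair monad ([String], a) that recent versions of base provide, which records a log of labels. The helper `say` is hypothetical, not a library function. Newer versions of base may implement replicateM_ with a loop instead of the definition quoted above, so this checks the equation rather than the source code.

```haskell
import Control.Monad (replicateM_)

-- Assumption: the pair monad ([String], a) stands in for IO, so an
-- action's effects become a log of labels we can compare with (==).
-- `say` is a hypothetical helper that logs one label.
say :: String -> ([String], ())
say s = ([s], ())

-- Left-hand side of the equation: the library function.
lhs :: Int -> ([String], ())
lhs n = replicateM_ n (say "1")

-- Right-hand side of the equation: the definition quoted above.
rhs :: Int -> ([String], ())
rhs n = sequence_ (replicate n (say "1"))

main :: IO ()
main = do
  print (lhs 3)                                 -- (["1","1","1"],())
  print (and [lhs n == rhs n | n <- [0 .. 10]])  -- True
```

Testing a handful of inputs is not a proof, but it is a cheap way to catch a mistaken equation before investing effort in proving it.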

Now, in order to simplify this further we must consult the definition of both replicate and sequence_. When in doubt about which one to pick, always pick the outermost function, because Haskell is lazy and evaluates expressions from the outside in.

In this case, our outermost function is sequence_, defined here:

-- | Combine a list of actions into a single action
sequence_ :: [IO ()] -> IO ()  -- I've simplified the type
sequence_ ms = foldr (>>) (return ()) ms

We will substitute this into our main, noting that ms is replicate 3 (print 1) for the purpose of this substitution:

main = foldr (>>) (return ()) (replicate 3 (print 1))

Now foldr is our outermost function, so we'll consult the definition of foldr:

-- | 'foldr k z' replaces all `(:)`s with `k`
--   and replaces `[]` with `z`
foldr :: (a -> b -> b) -> b -> [a] -> b
foldr k z  []    = z
foldr k z (x:xs) = k x (foldr k z xs)
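To see the "replace (:) with k and [] with z" reading concretely, this small sketch folds with a hypothetical `showK` that renders its arguments as text instead of computing anything, so the nesting that foldr produces becomes visible.

```haskell
-- `showK` is a hypothetical stand-in for `k` that renders its
-- arguments instead of computing anything.
showK :: Int -> String -> String
showK x acc = "(k " ++ show x ++ " " ++ acc ++ ")"

main :: IO ()
main = do
  -- 1 : (2 : (3 : [])) becomes k 1 (k 2 (k 3 z))
  putStrLn (foldr showK "z" [1, 2, 3])  -- prints (k 1 (k 2 (k 3 z)))
  -- An empty list is replaced by z alone
  putStrLn (foldr showK "z" [])         -- prints z
```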

Here we see two equations. Both equations work both ways, but we don't know which equation to apply until we know whether the third argument to foldr is an empty list. We must evaluate the call to replicate to see whether it will pass us an empty list or a non-empty list, so we consult the definition of replicate:

-- | Create a list containing `x` repeated `n` times
replicate :: Int -> a -> [a]
replicate n x = take n (repeat x)

-- Apply the substitution, using these values:
-- n = 3
-- x = print 1

main = foldr (>>) (return ()) (take 3 (repeat (print 1)))

Boy, this rabbit hole keeps getting deeper and deeper! However, we must persevere. Let's now consult the definition of take:

-- | Take the first `n` elements from a list
take :: Int -> [a] -> [a]
take n _      | n <= 0 =  []
take _ []              =  []
take n (x:xs)          =  x : take (n-1) xs

Here we see three equations. The first one has a predicate, saying that the equality is only valid if n is less than or equal to 0. In our case n is 3, so we skip that equation. However, we cannot distinguish which of the latter two equations to use unless we know whether repeat (print 1) produces an empty list or not, so we must consult the definition of repeat:

-- | `repeat x` creates an infinite list of `x`s
repeat :: a -> [a]
repeat x = x:repeat x  -- Simplified from the real definition

-- Apply the substitution, using these values:
-- x = print 1

main = foldr (>>) (return ()) (take 3 (print 1:repeat (print 1)))

The buck stops here! Although repeat is infinitely recursive, we don't have to fully evaluate it. We can just evaluate it once and lazily defer the recursive call, since all we need to know for now is that the list has at least one value. This now provides us with enough information to select the third equation for take that requires a non-empty list as its argument:

take n (x:xs)          =  x : take (n-1) xs

-- Apply the substitution, using these values:
-- n  = 3
-- x  = print 1
-- xs = repeat (print 1)

main = foldr (>>) (return ()) (print 1:take 2 (repeat (print 1)))
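The step above only needed to know that repeat (print 1) starts with a cons cell. This sketch shows that pattern matching on an infinite list forces just that much. Here `repeat'` is the post's simplified repeat, renamed to avoid clashing with the Prelude, and `isNonEmpty` is a hypothetical helper.

```haskell
-- The post's simplified `repeat`, renamed to avoid clashing with Prelude.
repeat' :: a -> [a]
repeat' x = x : repeat' x

-- Hypothetical helper: choosing an equation here forces only the
-- outermost constructor of the list, never the rest of it.
isNonEmpty :: [a] -> Bool
isNonEmpty []      = False
isNonEmpty (_ : _) = True

main :: IO ()
main = do
  print (isNonEmpty (repeat' 'x'))  -- True, although the list is infinite
  print (take 3 (repeat' 'x'))      -- "xxx"
```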

Now we know for sure that foldr's third argument is a non-empty list, so we can select the second equation for foldr:

foldr k z (x:xs) = k x (foldr k z xs)

-- Apply the substitution, using these values:
-- k  = (>>)
-- z  = return ()
-- x  = print 1
-- xs = take 2 (repeat (print 1))

main =
    (>>) (print 1)
         (foldr (>>) (return ()) (take 2 (repeat (print 1))))

-- Note: "(>>) x y" is the same thing as "x >> y"
main = print 1 >> foldr (>>)
                        (return ())
                        (take 2 (repeat (print 1)))

-- Note: "x >> y" is the same thing as "do { x; y }"
main = do
    print 1
    foldr (>>) (return ()) (take 2 (repeat (print 1)))

Now our Haskell runtime knows enough to realize that it needs to print a single 1. The language is smart and will execute the first print statement before further evaluating the call to foldr.

We can repeat this process two more times, cycling through evaluating repeat, take, and foldr, which emit two additional print commands:

-- Evaluate `repeat`
main = do
    print 1
    foldr (>>) (return ()) (take 2 (print 1:repeat (print 1)))

-- Evaluate `take`
main = do
    print 1
    foldr (>>) (return ()) (print 1:take 1 (repeat (print 1)))

-- Evaluate `foldr`
main = do
    print 1
    print 1
    foldr (>>) (return ()) (take 1 (repeat (print 1)))

-- Evaluate `repeat`
main = do
    print 1
    print 1
    foldr (>>) (return ()) (take 1 (print 1:repeat (print 1)))

-- Evaluate `take`
main = do
    print 1
    print 1
    foldr (>>) (return ()) (print 1:take 0 (repeat (print 1)))

-- Evaluate `foldr`
main = do
    print 1
    print 1
    print 1
    foldr (>>) (return ()) (take 0 (repeat (print 1)))

Now something different will happen. Let's revisit our definition of take:

take n _      | n <= 0 =  []
take _ []              =  []
take n (x:xs)          =  x : take (n-1) xs

This time the first equation matches, because n is now equal to 0. However, the third equation would also match, because repeat will emit a new element. Whenever more than one equation matches, Haskell always picks the first one (equations are tried from top to bottom), so we use the first equation to substitute the call to take with an empty list:

main = do
    print 1
    print 1
    print 1
    foldr (>>) (return ()) []

This then triggers the first equation for foldr:

foldr k z [] = z

-- Therefore:
main = do
    print 1
    print 1
    print 1
    return ()

There it is! We've fully expanded out our use of replicateM_ to prove that it prints the number 1 three times. For reasons I won't go into, we can also remove the final return () and finish with:

main = do
    print 1
    print 1
    print 1
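The step where take's first equation won, even though its third equation could also have matched, can be observed directly. This sketch copies the post's definition of take as `take'` (renamed to avoid the Prelude clash) and passes undefined as the list: since the equations are tried top to bottom and the first one never inspects the list, no error is raised.

```haskell
-- The post's definition of `take`, renamed to avoid clashing with Prelude.
take' :: Int -> [a] -> [a]
take' n _      | n <= 0 = []
take' _ []              = []
take' n (x:xs)          = x : take' (n - 1) xs

main :: IO ()
main = do
  -- Equations are tried top to bottom. The first one matches any list
  -- without inspecting it, so the undefined list is never evaluated.
  print (take' 0 (undefined :: [Int]))  -- []
  -- With a positive count the third equation does the work.
  print (take' 2 "abc")                 -- "ab"
```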

Equational reasoning - Part 1

Notice how Haskell has a very straightforward way to interpret all code: substitution. If you can substitute equals-for-equals then you can interpret a Haskell program on pen and paper. Substitution is the engine of application state.

Equally cool: we never needed to understand the language runtime to know what our code did. The Haskell language has a very limited role: ensure that substitution is valid and otherwise get out of our way.

Unlike imperative languages, there are no extra language statements such as for or break that we need to understand in order to interpret our code. Our printing "loop" that repeated three times was just a bunch of ordinary Haskell code. This is a common theme in Haskell: "language features" are usually ordinary libraries.

Also, had we tried the same pen-and-paper approach to interpreting an imperative language we would have had to keep track of temporary values somewhere in the margins while evaluating our program. In Haskell, all the "state" of our program resides within the expression we are evaluating, and in our case the "state" was the integer argument to take that we threaded through our substitutions. Haskell code requires less context to understand than imperative code because Haskell expressions are self-contained.

We didn't need to be told to keep track of state. We just kept mindlessly applying substitution and perhaps realized after the fact that what we were doing was equivalent to a state machine. Indeed, state is implemented within the language just like everything else.


Proof reduction

Proving the behavior of our code was tedious, and we're really more interested in proving generally reusable properties than in deducing the behavior of specific, disposable programs. We spent a lot of effort to prove our last equation, so we want to pick our battles wisely and only spend time proving equations that we can reuse heavily.

So I will propose that we should only really bother to prove the following four equations in order to cover most common uses of replicateM_:

-- replicateM_ "distributes" over addition
replicateM_  0      x = return ()
replicateM_ (m + n) x = replicateM_ m x >> replicateM_ n x

-- replicateM_ "distributes" over multiplication
replicateM_  1      = id
replicateM_ (m * n) = replicateM_ m . replicateM_ n

The last two equations are written in a "point-free style" to emphasize that replicateM_ distributes in a different way over multiplication. If you expand those two equations out to a "point-ful style" you get:

replicateM_ (m * n) x = replicateM_ m (replicateM_ n x)
replicateM_  1      x = x
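Before proving these four equations, it is cheap to test them. This sketch again swaps IO for the pair monad ([String], a) so that actions can be compared with (==), uses a hypothetical helper `say`, and only tries small non-negative counts.

```haskell
import Control.Monad (replicateM_)

-- Assumption: the pair monad ([String], a) stands in for IO, so an
-- action's effects are a log of labels that can be compared with (==).
type Action = ([String], ())

-- Hypothetical helper: an action that logs one label.
say :: String -> Action
say s = ([s], ())

-- replicateM_ "distributes" over addition
zeroLaw :: Action -> Bool
zeroLaw x = replicateM_ 0 x == return ()

addLaw :: Int -> Int -> Action -> Bool
addLaw m n x = replicateM_ (m + n) x == (replicateM_ m x >> replicateM_ n x)

-- replicateM_ "distributes" over multiplication
oneLaw :: Action -> Bool
oneLaw x = replicateM_ 1 x == x

mulLaw :: Int -> Int -> Action -> Bool
mulLaw m n x = replicateM_ (m * n) x == replicateM_ m (replicateM_ n x)

-- Only small non-negative counts are tried, matching the proofs' assumption.
allLawsHold :: Bool
allLawsHold =
  zeroLaw x && oneLaw x
    && and [addLaw m n x && mulLaw m n x | m <- [0 .. 6], n <- [0 .. 6]]
  where
    x = say "a" >> say "b"

main :: IO ()
main = print allLawsHold  -- True
```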

If we could prove these equations, then we could much more easily deduce how replicateM_ behaves for our example program, because we could transform our code more rapidly like this:

replicateM_ 3 (print 1)

-- 3 = 1 + 1 + 1
= replicateM_ (1 + 1 + 1) (print 1)

-- replicateM_ distributes over addition
= do replicateM_ 1 (print 1)
     replicateM_ 1 (print 1)
     replicateM_ 1 (print 1)

-- replicateM_ 1 x = x
= do print 1
     print 1
     print 1

These four master equations are still very tedious to prove all in one go, but we can break this complex task into smaller bite-sized tasks. As a bonus, this divide-and-conquer approach will also produce several other useful and highly reusable equations along the way.

Let's begin by revisiting the definition of replicateM_:

replicateM_ n x = sequence_ (replicate n x)

Like many Haskell functions, replicateM_ is built from two smaller composable pieces: sequence_ and replicate. So perhaps we can also build our proofs from smaller and composable proofs about the individual behaviors of sequence_ and replicate.

Indeed, replicate possesses a set of properties that are remarkably similar to replicateM_:

-- replicate distributes over addition
replicate  0      x = []
replicate (m + n) x = replicate m x ++ replicate n x

-- replicate distributes over multiplication
replicate  1      = return
replicate (m * n) = replicate m <=< replicate n

Again, the last two equations can be expanded to a "point-ful" form:

-- `return x` is `[x]` for lists
replicate 1 x = [x]

-- `(<=<)` is like a point-free `concatMap` for lists
replicate (m * n) x = concatMap (replicate m) (replicate n x)
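Like the replicateM_ equations, these can be spot-checked before proving them. This sketch tests the four equations, including both the point-free and point-ful forms of the last two, for small non-negative m and n, using (<=<) from Control.Monad.

```haskell
import Control.Monad ((<=<))

-- Spot-checks the four `replicate` equations for small non-negative counts.
checkReplicateLaws :: Bool
checkReplicateLaws =
  and
    [ replicate 0 x == ([] :: String)
    , replicate 1 x == return x            -- point-free: replicate 1 = return
    , replicate 1 x == [x]                 -- point-ful form
    , and [ replicate (m + n) x == replicate m x ++ replicate n x
          | m <- ns, n <- ns ]
    , and [ replicate (m * n) x == (replicate m <=< replicate n) x
            && replicate (m * n) x == concatMap (replicate m) (replicate n x)
          | m <- ns, n <- ns ]
    ]
  where
    x  = 'x'
    ns = [0 .. 6]

main :: IO ()
main = print checkReplicateLaws  -- True
```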

These four replicate equations are easier to prove than the corresponding replicateM_ equations. If we can somehow prove these simpler equations, then all we have to do is prove that sequence_ lifts all of the replicate proofs into the equivalent replicateM_ proofs:

-- sequence_ lifts list concatenation to sequencing
sequence_  []        = return ()
sequence_ (xs ++ ys) = sequence_ xs >> sequence_ ys

-- (sequence_ .) lifts list functions to ordinary functions
sequence_ .  return   = id
sequence_ . (f <=< g) = (sequence_ . f) . (sequence_ . g)

-- Point-ful version of last two equations:
sequence_ [x] = x
sequence_ (concatMap f (g x)) = sequence_ (f (sequence_ (g x)))
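These equations can be tested the same way. The sketch below uses the pair monad ([String], a) in place of IO, with hypothetical helpers `say`, `a`, `b`, and `f`. It checks the fourth equation only for f = replicate m and g = replicate n; `generalFormFails` shows a choice of f for which the point-ful form does not hold.

```haskell
import Control.Monad ((<=<))

-- Assumption: the pair monad ([String], a) stands in for IO, so an
-- action's effects are a log of labels that can be compared with (==).
type Action = ([String], ())

-- Hypothetical helpers: actions that log one label each.
say :: String -> Action
say s = ([s], ())

a, b :: Action
a = say "a"
b = say "b"

-- The first three sequence_ equations, on concrete lists of actions.
firstThreeHold :: Bool
firstThreeHold =
  sequence_ ([] :: [Action]) == return ()
    && sequence_ ([a, b] ++ [b]) == (sequence_ [a, b] >> sequence_ [b])
    && sequence_ [a] == a

-- The fourth equation, specialized to f = replicate m and g = replicate n.
fourthHolds :: Int -> Int -> Action -> Bool
fourthHolds m n x =
  sequence_ ((replicate m <=< replicate n) x)
    == sequence_ (replicate m (sequence_ (replicate n x)))

-- A hypothetical f for which the general form of the fourth equation
-- fails: it appends an extra action after its argument.
f :: Action -> [Action]
f x = [x, say "!"]

generalFormFails :: Bool
generalFormFails =
  sequence_ ((f <=< replicate 2) a)
    /= sequence_ (f (sequence_ (replicate 2 a)))

main :: IO ()
main = do
  print firstThreeHold                                                   -- True
  print (and [fourthHolds m n (a >> b) | m <- [0 .. 4], n <- [0 .. 4]])  -- True
  print generalFormFails                                                 -- True
```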

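It might not be obvious at first, but the above four equations for sequence_ suffice to transform the replicate proofs into the analogous replicateM_ proofs. One caveat: the fourth equation (in both forms) does not hold for arbitrary f and g, since an f that appends an extra action breaks it, but it does hold for f = replicate m and g = replicate n, which is the only case we need. For example, this is how we would prove the first replicateM_ equation in terms of the first replicate equation and the first sequence_ equation: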

replicateM_ 0 x

-- Definition of `replicateM_`
= sequence_ (replicate 0 x)

-- replicate 0 x = []
= sequence_ []

-- sequence_ [] = return ()
= return ()

That was simple. It is only slightly more tricky to prove the second equation:

replicateM_ (m + n) x

-- Definition of `replicateM_`
= sequence_ (replicate (m + n) x)

-- replicate (m + n) x = replicate m x ++ replicate n x
= sequence_ (replicate m x ++ replicate n x)

-- sequence_ (xs ++ ys) = sequence_ xs >> sequence_ ys
= sequence_ (replicate m x) >> sequence_ (replicate n x)

-- Definition of `replicateM_`, in reverse!
= replicateM_ m x >> replicateM_ n x

Notice how the last step of the proof involves using the original replicateM_ equation, but instead substituting from right-to-left!

--  +-- We can substitute this way --+
--  |                                |
--  ^                                v

replicateM_ n x = sequence_ (replicate n x)

--  ^                                v
--  |                                |
--  +--------- Or this way! ---------+

This is a nice example of the utility of bidirectional substitution. We can replace the body of a function with the equivalent function call.

The third replicateM_ equation is also simple to prove in terms of the third replicate and sequence_ equations. I will use the point-ful forms of all these equations for simplicity:

replicateM_ 1 x

-- Definition of `replicateM_`
= sequence_ (replicate 1 x)

-- replicate 1 x = [x]
= sequence_ [x]

-- sequence_ [x] = x
= x

The fourth property is surprisingly short, too:

replicateM_ (m * n) x

-- Definition of `replicateM_`
= sequence_ (replicate (m * n) x)

-- replicate (m * n) x = concatMap (replicate m) (replicate n x)
= sequence_ (concatMap (replicate m) (replicate n x))

-- sequence_ (concatMap f (g x)) = sequence_ (f (sequence_ (g x)))
= sequence_ (replicate m (sequence_ (replicate n x)))

-- Definition of `replicateM_`, in reverse
= replicateM_ m (replicateM_ n x)

Equational reasoning - Part 2

We reduced our proofs of the replicateM_ properties to smaller proofs for replicate and sequence_ properties. The overhead of this proof reduction is tiny, and we gain the benefit of reusing proofs for replicate and sequence_.

As programmers we try to reuse code when we program and the way we promote code reuse is to divide programs into smaller composable pieces that are more reusable. Likewise, we try to reuse proofs when we equationally reason about code and the way we encourage proof reuse is to divide larger proofs into smaller proofs using proof reduction. In the above example we reduced the four equations for replicateM_ into four equations for replicate and four equations for sequence_. These smaller equations are equally useful in their own right and they can be reused by other people as sub-proofs for their own proofs.

However, proof reuse also faces the same challenges as code reuse. When we break up code into smaller pieces sometimes we take things too far and create components we like to think are reusable but really aren't. Similarly, when we reduce proofs sometimes we pick sub-proofs that are worthless and only add more overhead to the entire proof process. How can we sift out the gold from the garbage?

I find that the most reusable proofs are category laws or functor laws of some sort. In fact, every single proof from the previous section was a functor law in disguise. To learn more about functor laws and how they arise everywhere you can read another post of mine about the functor design pattern.


Proof techniques

This section will walk through the complete proofs for the replicate equations to provide several worked examples and to also illustrate several useful proof tricks.

I deliberately write these proofs to be reasonably detailed and to skip as few steps as possible. In practice, though, proofs become much easier the more you equationally reason about code because you get better at taking larger steps.

Let's revisit the equations we wish to prove for replicate, in point-ful form:

replicate  0      x = []
replicate (m + n) x = replicate m x ++ replicate n x

replicate  1      x = [x]
replicate (m * n) x = concatMap (replicate m) (replicate n x)

... where concatMap is defined as:

concatMap :: (a -> [b]) -> [a] -> [b]
concatMap f = foldr ((++) . f) []
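A quick sketch confirming that this definition agrees with the Prelude's concatMap on a sample input; the definition is renamed `concatMap'` to avoid a name clash.

```haskell
-- The post's definition of concatMap, renamed to avoid clashing with Prelude.
-- ((++) . f) x acc = f x ++ acc, so the fold puts each element's list
-- in front of the result for the rest of the input.
concatMap' :: (a -> [b]) -> [a] -> [b]
concatMap' f = foldr ((++) . f) []

main :: IO ()
main = do
  print (concatMap' (replicate 2) "abc")                                  -- "aabbcc"
  print (concatMap' (replicate 2) "abc" == concatMap (replicate 2) "abc") -- True
```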

Now we must use the same equational reasoning skills we developed in the first section to prove all four of these equations.

The first equation is simple:

replicate 0 x

-- Definition of `replicate`
= take 0 (repeat x)

-- Definition of `take`
= []

-- Proof complete

The second equation is trickier:

replicate (m + n) x

-- Definition of `replicate`
= take (m + n) (repeat x)

We can't proceed further unless we know whether or not m + n is greater than 0. For simplicity we'll assume that m and n are non-negative.

We then do something analogous to "case analysis" on m, pretending it is like a Peano number. That means that m is either 0 or positive (i.e. greater than 0). We'll first prove our equation for the case where m equals 0:

-- Assume: m = 0
= take (0 + n) (repeat x)

-- 0 + n = n
= take n (repeat x)

-- Definition of `(++)`, in reverse
= [] ++ take n (repeat x)

-- Definition of `take`, in reverse
= take 0 (repeat x) ++ take n (repeat x)

-- m = 0
= take m (repeat x) ++ take n (repeat x)

-- Definition of `replicate`, in reverse
= replicate m x ++ replicate n x

-- Proof complete for m = 0

Then there is the second case, where m is positive, meaning that we can represent m as 1 plus some other non-negative number m':

-- Assume: m = 1 + m'
= take (1 + m' + n) (repeat x)

-- Definition of `repeat`
= take (1 + m' + n) (x:repeat x)

-- Definition of `take`
= x:take (m' + n) (repeat x)

-- Definition of `replicate` in reverse
= x:replicate (m' + n) x

Now we can use induction to reuse the original premise since m' is strictly smaller than m. Since we are assuming that m is non-negative this logical recursion is well-founded and guaranteed to eventually bottom out at the base case where m equals 0:

-- Induction: reuse the premise
= x:(replicate m' x ++ replicate n x)

-- Definition of `(++)`, in reverse
= (x:replicate m' x) ++ replicate n x

-- Definition of `replicate`
= (x:take m' (repeat x)) ++ replicate n x

-- Definition of `take`, in reverse
= take (1 + m') (x:repeat x) ++ replicate n x

-- Definition of `repeat`, in reverse
= take (1 + m') (repeat x) ++ replicate n x

-- Definition of `replicate`, in reverse
= replicate (1 + m') x ++ replicate n x

-- m = 1 + m', in reverse
= replicate m x ++ replicate n x

-- Proof complete for m = 1 + m'

This completes the proof for both cases so the proof is "total", meaning that we covered all possibilities. Actually, that's a lie because really rigorous Haskell proofs must account for the possibility of non-termination (a.k.a. "bottom"). However, I usually consider proofs that don't account for non-termination to be good enough for most practical purposes.
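The proof also assumed that m and n are non-negative. This sketch shows why that assumption matters: replicate treats every negative count like 0, so the addition equation fails as soon as a negative count is involved.

```haskell
-- The addition equation for replicate, checked at one pair of counts.
additionLaw :: Int -> Int -> Bool
additionLaw m n = replicate (m + n) 'x' == replicate m 'x' ++ replicate n 'x'

main :: IO ()
main = do
  -- Holds for every pair of small non-negative counts
  print (and [additionLaw m n | m <- [0 .. 10], n <- [0 .. 10]])  -- True
  -- Fails for m = -1: the left side is replicate 0 'x', which is "",
  -- but the right side is "" ++ "x", which is "x"
  print (additionLaw (-1) 1)                                        -- False
```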

The third replicate law is very straightforward to prove:

replicate 1 x

-- Definition of `replicate`
= take 1 (repeat x)

-- Definition of `repeat`
= take 1 (x:repeat x)

-- Definition of `take`
= x:take 0 (repeat x)

-- Definition of `take`
= x:[]

-- [x] is syntactic sugar for `x:[]`
= [x]

The fourth equation for replicate also requires us to split our proof into two branches. Either n is zero or greater than zero. First we consider the case where n is zero:

replicate (m * n) x

-- Assume: n = 0
= replicate 0 x

-- replicate 0 x = []
= []

-- Definition of `foldr`, in reverse
= foldr ((++) . replicate m) [] []

-- Definition of `concatMap`, in reverse
= concatMap (replicate m) []

-- replicate 0 x = [], in reverse
= concatMap (replicate m) (replicate 0 x)

-- n = 0, in reverse
= concatMap (replicate m) (replicate n x)

-- Proof complete for n = 0

Then we consider the case where n is greater than zero:

replicate (m * n) x

-- Assume: n = 1 + n'
= replicate (m * (1 + n')) x

-- m * (1 + n') = m + m * n'
= replicate (m + m * n') x

-- replicate distributes over addition
= replicate m x ++ replicate (m * n') x

-- Induction: reuse the premise
= replicate m x ++ concatMap (replicate m) (replicate n' x)

-- Definition of `concatMap`
= replicate m x ++ foldr ((++) . replicate m) [] (replicate n' x)

-- Definition of `foldr`, in reverse
= foldr ((++) . replicate m) [] (x:replicate n' x)

-- Definition of `concatMap`, in reverse
= concatMap (replicate m) (x:replicate n' x)

-- Definition of `replicate`
= concatMap (replicate m) (x:take n' (repeat x))

-- Definition of `take`, in reverse
= concatMap (replicate m) (take (1 + n') (x:repeat x))

-- n = 1 + n', in reverse
= concatMap (replicate m) (take n (x:repeat x))

-- Definition of `repeat`, in reverse
= concatMap (replicate m) (take n (repeat x))

-- Definition of `replicate`, in reverse
= concatMap (replicate m) (replicate n x)

-- Proof complete for n = 1 + n'

Hopefully these proofs give an idea for the amount of effort involved to prove properties of moderate complexity. I omitted the final part of proving the sequence_ equations in the interest of space, but they make for a great exercise.


Equational reasoning - Part 3

Reasoning about Haskell differs from reasoning about code in other languages. Traditionally, reasoning about code would require:

  • building a formal model of a program's algorithm,
  • reasoning about the behavior of the formal model using its axioms, and
  • proving that the program matches the model.

In a purely functional language like Haskell you formally reason about your code within the language itself. There is no need for a separate formal model because the code is already written as a bunch of formal equations. This is what people mean when they say that Haskell has strong ties to mathematics, because you can reason about Haskell code the same way you reason about mathematical equations.

This is why Haskell syntax for function definitions deviates from mainstream languages. All function definitions are just equalities, which is why Haskell is great for equational reasoning.

This post illustrated how equational reasoning in Haskell can scale to larger complexity through the use of proof reduction. A future post of mine will walk through a second tool for reducing proof complexity: type classes inspired by mathematics that come with associated type class laws.

3 comments:

  1. Great article!

    A small observation: the type of concatMap should be (a -> [b]) -> [a] -> [b] to match the given definition.

  2. This is such a fun article. You write so clearly. And that was almost 8 years ago!

    Possible correction. Should the sequences in "-- sequence (xs ++ ys) = sequence xs >> sequence ys" be sequence_'s ?
