You may have heard that Haskell is "great for equational reasoning", but perhaps you didn't know precisely what that meant. This post will walk through an intermediate example of equational reasoning to show how you can interpret Haskell code by hand by applying substitutions as if your code were a mathematical expression. This process of interpreting code by substituting equals-for-equals is known as equational reasoning.
Equality
In Haskell, a single equals sign means that the left-hand and right-hand sides are substitutable for each other in both directions. For example, if you see:
x = 5
That means that wherever you see an x in your program, you can substitute it with a 5, and, vice versa, anywhere you see a 5 in your program you can substitute it with an x. Substitutions like these preserve the behavior of your program.
In fact, you can "run" your program just by repeatedly applying these substitutions until your program is just one giant main. To illustrate this, we will begin with the following program, which should print the number 1 three times:
import Control.Monad
main = replicateM_ 3 (print 1)
replicateM_ is a function that repeats an action a specified number of times. Its type when specialized to IO is:
replicateM_ :: Int -> IO () -> IO ()
The first argument is an Int specifying how many times to repeat the action and the second argument is the action we wish to repeat. In the above example, we specify that we wish to repeat the print 1 action three times.
But what if you don't believe me? What if you wanted to prove to yourself that it repeated the action three times? How would you prove that?
Use the source!
You can locate the source to Haskell functions using one of three tricks:
- Use Hoogle, which can also search for functions by type signature
- Use Hayoo!, which is like Hoogle, but searches a larger package database and is stricter about matches
- Use Google and search for "hackage <function>". This also works well for searching for packages.
Using any of those three tricks we would locate the documentation for replicateM_, and then we can click the Source link to the right of it to view its definition, which I reproduce here:
replicateM_ n x = sequence_ (replicate n x)
The equals sign means that any time we see something of the form replicateM_ n x, we can substitute it with sequence_ (replicate n x), for any choice of n or x. For example, if we choose the following values for n and x:
n = 3
x = print 1
... then we obtain the following more specific equation:
replicateM_ 3 (print 1) = sequence_ (replicate 3 (print 1))
We will use this equation to replace our program's replicateM_ command with an equal program built from sequence_ and replicate:
main = sequence_ (replicate 3 (print 1))
The equation tells us that this substitution is safe and preserves the original behavior of our program.
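One way to sanity-check this first substitution without running any IO is to swap print 1 for an action whose result can be compared with ==. The sketch below assumes the pair monad from base (the `Monoid w => Monad ((,) w)` instance), where the first component accumulates a log. The helper `say` and the names `lhs` and `rhs` are hypothetical and are not part of the original program.

```haskell
import Control.Monad (replicateM_)

-- Hypothetical stand-in for `print`: rather than performing IO, it
-- records a line in a log carried in the first component of a pair.
-- This relies on base's `Monoid w => Monad ((,) w)` instance.
say :: String -> ([String], ())
say line = ([line], ())

-- Left-hand side of the equation
lhs :: ([String], ())
lhs = replicateM_ 3 (say "1")

-- Right-hand side, after substituting the definition of replicateM_
rhs :: ([String], ())
rhs = sequence_ (replicate 3 (say "1"))

main :: IO ()
main = do
    print lhs          -- (["1","1","1"],())
    print (lhs == rhs) -- True
```

Agreement in one monad is only evidence, not a proof, but it is a cheap way to catch a substitution that was copied down incorrectly.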
Now, in order to simplify this further we must consult the definitions of both replicate and sequence_. When in doubt about which one to pick, always pick the outermost function because Haskell is lazy and evaluates everything from outside to in.
In this case, our outermost function is sequence_, defined here:
-- | Combine a list of actions into a single action
sequence_ :: [IO ()] -> IO () -- I've simplified the type
sequence_ ms = foldr (>>) (return ()) ms
We will substitute this into our main, noting that ms is replicate 3 (print 1) for the purpose of this substitution:
main = foldr (>>) (return ()) (replicate 3 (print 1))
Now foldr is our outermost function, so we'll consult the definition of foldr:
-- | 'foldr k z' replaces all `(:)`s with `k`
-- and replaces `[]` with `z`
foldr :: (a -> b -> b) -> b -> [a] -> b
foldr k z [] = z
foldr k z (x:xs) = k x (foldr k z xs)
Here we see two equations. Both equations work both ways, but we don't know which equation to apply unless we know whether the third argument to foldr is an empty list. We must evaluate the call to replicate to see whether it will pass us an empty list or a non-empty list, so we consult the definition of replicate:
-- | Create a list containing `x` repeated `n` times
replicate :: Int -> a -> [a]
replicate n x = take n (repeat x)
-- Apply the substitution, using these values:
-- n = 3
-- x = print 1
main = foldr (>>) (return ()) (take 3 (repeat (print 1)))
Boy, this rabbit hole keeps getting deeper and deeper! However, we must persevere. Let's now consult the definition of take:
-- | Take the first `n` elements from a list
take :: Int -> [a] -> [a]
take n _ | n <= 0 = []
take _ [] = []
take n (x:xs) = x : take (n-1) xs
Here we see three equations. The first one has a predicate, saying that the equality is only valid if n is less than or equal to 0. In our case n is 3, so we skip that equation. However, we cannot distinguish which of the latter two equations to use unless we know whether repeat (print 1) produces an empty list or not, so we must consult the definition of repeat:
-- | `repeat x` creates an infinite list of `x`s
repeat :: a -> [a]
repeat x = x:repeat x -- Simplified from the real definition
-- Apply the substitution, using these values:
-- x = print 1
main = foldr (>>) (return ()) (take 3 (print 1:repeat (print 1)))
The buck stops here! Although repeat is infinitely recursive, we don't have to fully evaluate it. We can just evaluate it once and lazily defer the recursive call, since all we need to know for now is that the list has at least one value. This now provides us with enough information to select the third equation for take that requires a non-empty list as its argument:
take n (x:xs) = x : take (n-1) xs
-- Apply the substitution, using these values:
-- n = 3
-- x = print 1
-- xs = repeat (print 1)
main = foldr (>>) (return ()) (print 1:take 2 (repeat (print 1)))
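The laziness argument can be tried out directly. This is a minimal sketch using the Prelude's own repeat and take on plain numbers instead of print 1. The names `isNonEmpty`, `infiniteOnes`, and `firstThree` are hypothetical.

```haskell
-- Pattern matching only forces the outermost constructor, so asking
-- whether an infinite list is non-empty terminates.
isNonEmpty :: [a] -> Bool
isNonEmpty (_ : _) = True
isNonEmpty []      = False

infiniteOnes :: [Int]
infiniteOnes = repeat 1

-- `take` only ever demands three cells of the infinite list.
firstThree :: [Int]
firstThree = take 3 infiniteOnes

main :: IO ()
main = do
    print (isNonEmpty infiniteOnes) -- True
    print firstThree                -- [1,1,1]
```

Both results come back in finite time because each step only unfolds repeat as far as the enclosing pattern match requires.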
Now we know for sure that foldr's third argument is a non-empty list, so we can select the second equation for foldr:
foldr k z (x:xs) = k x (foldr k z xs)
-- Apply the substitution, using these values:
-- k = (>>)
-- z = return ()
-- x = print 1
-- xs = take 2 (repeat (print 1))
main =
    (>>) (print 1)
         (foldr (>>) (return ()) (take 2 (repeat (print 1))))
-- Note: "(>>) x y" is the same thing as "x >> y"
main = print 1 >> foldr (>>)
                        (return ())
                        (take 2 (repeat (print 1)))
-- Note: "x >> y" is the same thing as "do { x; y }"
main = do
    print 1
    foldr (>>) (return ()) (take 2 (repeat (print 1)))
Now our Haskell runtime knows enough information to realize that it needs to print a single 1. The language is smart and will execute the first print statement before further evaluating the call to foldr.
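To see the foldr step in isolation, here is a small sketch that applies the two foldr equations to an explicit three-element list of actions. It assumes the pair monad from base as a loggable substitute for IO. The names `myFoldr`, `say`, `folded`, and `unfolded` are hypothetical.

```haskell
-- The post's definition of foldr, renamed so that it does not clash
-- with the Prelude.
myFoldr :: (a -> b -> b) -> b -> [a] -> b
myFoldr _ z []     = z
myFoldr k z (x:xs) = k x (myFoldr k z xs)

-- Hypothetical logging action: appends a line to the log in the
-- first component of the pair.
say :: String -> ([String], ())
say line = ([line], ())

-- Every (:) becomes (>>) and the final [] becomes return ()
folded :: ([String], ())
folded = myFoldr (>>) (return ()) (say "a" : say "b" : say "c" : [])

unfolded :: ([String], ())
unfolded = say "a" >> (say "b" >> (say "c" >> return ()))

main :: IO ()
main = do
    print folded               -- (["a","b","c"],())
    print (folded == unfolded) -- True
```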
We can repeat this process two more times, cycling through evaluating repeat, take, and foldr, which emit two additional print commands:
-- Evaluate `repeat`
main = do
    print 1
    foldr (>>) (return ()) (take 2 (print 1:repeat (print 1)))
-- Evaluate `take`
main = do
    print 1
    foldr (>>) (return ()) (print 1:take 1 (repeat (print 1)))
-- Evaluate `foldr`
main = do
    print 1
    print 1
    foldr (>>) (return ()) (take 1 (repeat (print 1)))
-- Evaluate `repeat`
main = do
    print 1
    print 1
    foldr (>>) (return ()) (take 1 (print 1:repeat (print 1)))
-- Evaluate `take`
main = do
    print 1
    print 1
    foldr (>>) (return ()) (print 1:take 0 (repeat (print 1)))
-- Evaluate `foldr`
main = do
    print 1
    print 1
    print 1
    foldr (>>) (return ()) (take 0 (repeat (print 1)))
Now something different will happen. Let's revisit our definition of take:
take n _ | n <= 0 = []
take _ [] = []
take n (x:xs) = x : take (n-1) xs
This time the first equation matches, because n is now equal to 0. However, we also know the third equation will match because repeat will emit a new element. Whenever more than one equation matches, Haskell takes the first one by default, so we use the first equation to substitute the call to take with an empty list:
main = do
    print 1
    print 1
    print 1
    foldr (>>) (return ()) []
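The "first matching equation wins" rule can be checked with a copy of the post's take. In this sketch the names `myTake`, `fromInfinite`, and `fromUndefined` are hypothetical. The second value shows that the first equation never inspects its list argument at all.

```haskell
-- The post's definition of take, renamed to avoid the Prelude's take
myTake :: Int -> [a] -> [a]
myTake n _ | n <= 0 = []
myTake _ []         = []
myTake n (x:xs)     = x : myTake (n - 1) xs

-- n = 0 with a non-empty list: the first and third equations both
-- match, and the first one is chosen.
fromInfinite :: String
fromInfinite = myTake 0 (repeat 'x')

-- The wildcard in the first equation leaves the list unevaluated,
-- so even an undefined list is fine here.
fromUndefined :: String
fromUndefined = myTake 0 undefined

main :: IO ()
main = do
    print fromInfinite  -- ""
    print fromUndefined -- ""
```

Back in the main derivation, the call to take has now become an empty list in foldr's third argument.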
This then triggers the first equation for foldr:
foldr k z [] = z
-- Therefore:
main = do
    print 1
    print 1
    print 1
    return ()
There it is! We've fully expanded out our use of replicateM_ to prove that it prints the number 1 three times. For reasons I won't go into, we can also remove the final return () and finish with:
main = do
    print 1
    print 1
    print 1
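As a cross-check on the whole derivation, the sketch below compares the starting program with both final forms, again assuming the pair monad from base as a loggable substitute for IO. The names `say`, `original`, `expanded`, and `trimmed` are hypothetical. That `expanded` and `trimmed` agree is consistent with the monad right-identity law (m >>= return = m) applied to an action that returns ().

```haskell
import Control.Monad (replicateM_)

-- Hypothetical logging action standing in for `print`
say :: String -> ([String], ())
say line = ([line], ())

-- The program we started from
original :: ([String], ())
original = replicateM_ 3 (say "1")

-- The fully expanded program, with the trailing `return ()`
expanded :: ([String], ())
expanded = do
    say "1"
    say "1"
    say "1"
    return ()

-- The same program with the trailing `return ()` removed
trimmed :: ([String], ())
trimmed = do
    say "1"
    say "1"
    say "1"

main :: IO ()
main = print (original == expanded, expanded == trimmed) -- (True,True)
```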
Equational reasoning - Part 1
Notice how Haskell has a very straightforward way to interpret all code: substitution. If you can substitute equals-for-equals then you can interpret a Haskell program on pen and paper. Substitution is the engine of application state.
Equally cool: we never needed to understand the language runtime to know what our code did. The Haskell language has a very limited role: ensure that substitution is valid and otherwise get out of our way.
Unlike imperative languages, there are no extra language statements such as for or break that we need to understand in order to interpret our code. Our printing "loop" that repeated three times was just a bunch of ordinary Haskell code. This is a common theme in Haskell: "language features" are usually ordinary libraries.
Also, had we tried the same pen-and-paper approach to interpreting an imperative language we would have had to keep track of temporary values somewhere in the margins while evaluating our program. In Haskell, all the "state" of our program resides within the expression we are evaluating, and in our case the "state" was the integer argument to take that we threaded through our substitutions. Haskell code requires less context to understand than imperative code because Haskell expressions are self-contained.
We didn't need to be told to keep track of state. We just kept mindlessly applying substitution and perhaps realized after the fact that what we were doing was equivalent to a state machine. Indeed, state is implemented within the language just like everything else.
Proof reduction
Proving the behavior of our code was really tedious, and we're really interested in proving more generally reusable properties rather than deducing the behavior of specific, disposable programs. However, we spent a lot of effort to prove our last equation, so we want to pick our battles wisely and only spend time proving equations that we can reuse heavily.
So I will propose that we should only really bother to prove the following four equations in order to cover most common uses of replicateM_:
-- replicateM_ "distributes" over addition
replicateM_ 0 x = return ()
replicateM_ (m + n) x = replicateM_ m x >> replicateM_ n x
-- replicateM_ "distributes" over multiplication
replicateM_ 1 = id
replicateM_ (m * n) = replicateM_ m . replicateM_ n
The last two equations are written in a "point-free style" to emphasize that replicateM_ distributes in a different way over multiplication. If you expand those two equations out to a "point-ful style" you get:
replicateM_ (m * n) x = replicateM_ m (replicateM_ n x)
replicateM_ 1 x = x
If we could prove these equations, then we could much more easily deduce how replicateM_ behaves for our example program, because we could transform our code more rapidly like this:
replicateM_ 3 (print 1)
-- 3 = 1 + 1 + 1
= replicateM_ (1 + 1 + 1) (print 1)
-- replicateM_ distributes over addition
= do replicateM_ 1 (print 1)
     replicateM_ 1 (print 1)
     replicateM_ 1 (print 1)
-- replicateM_ 1 x = x
= do print 1
     print 1
     print 1
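Before investing in proofs, it can be reassuring to test the four proposed equations on small inputs. The following is a sketch under stated assumptions: it uses the pair monad from base as a loggable stand-in for IO, only tries non-negative counts from 0 to 4, and the names `tick`, `sizes`, and the four `...Law` values are hypothetical. Passing these checks does not prove the equations.

```haskell
import Control.Monad (replicateM_)

-- Hypothetical action that appends one character to a log
tick :: (String, ())
tick = ("x", ())

-- Small non-negative repetition counts to try
sizes :: [Int]
sizes = [0 .. 4]

-- replicateM_ 0 x = return ()
zeroLaw :: Bool
zeroLaw = replicateM_ 0 tick == return ()

-- replicateM_ (m + n) x = replicateM_ m x >> replicateM_ n x
addLaw :: Bool
addLaw = and
    [ replicateM_ (m + n) tick == (replicateM_ m tick >> replicateM_ n tick)
    | m <- sizes, n <- sizes ]

-- replicateM_ 1 x = x
oneLaw :: Bool
oneLaw = replicateM_ 1 tick == tick

-- replicateM_ (m * n) x = replicateM_ m (replicateM_ n x)
mulLaw :: Bool
mulLaw = and
    [ replicateM_ (m * n) tick == replicateM_ m (replicateM_ n tick)
    | m <- sizes, n <- sizes ]

main :: IO ()
main = print (zeroLaw, addLaw, oneLaw, mulLaw) -- (True,True,True,True)
```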
These four master equations are still very tedious to prove all in one go, but we can break this complex task into smaller bite-sized tasks. As a bonus, this divide-and-conquer approach will also produce several other useful and highly reusable equations along the way.
Let's begin by revisiting the definition of replicateM_:
replicateM_ n x = sequence_ (replicate n x)
Like many Haskell functions, replicateM_ is built from two smaller composable pieces: sequence_ and replicate. So perhaps we can also build our proofs from smaller and composable proofs about the individual behaviors of sequence_ and replicate.
Indeed, replicate possesses a set of properties that are remarkably similar to replicateM_:
-- replicate distributes over addition
replicate 0 x = []
replicate (m + n) x = replicate m x ++ replicate n x
-- replicate distributes over multiplication
replicate 1 = return
replicate (m * n) = replicate m <=< replicate n
Again, the last two equations can be expanded to a "point-ful" form:
-- `return x` is `[x]` for lists
replicate 1 x = [x]
-- `(<=<)` is like a point-free `concatMap` for lists
replicate (m * n) x = concatMap (replicate m) (replicate n x)
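The relationship between the point-free and point-ful forms can be made concrete. This sketch uses (<=<) from Control.Monad at the list monad and compares it with the concatMap form on a small range of non-negative m and n. The names `oneLaw`, `pointFree`, `pointFul`, and `formsAgree` are hypothetical.

```haskell
import Control.Monad ((<=<))

-- `return` specialised to lists builds a singleton list
oneLaw :: Bool
oneLaw = replicate 1 'x' == return 'x'

-- Point-free form: Kleisli composition in the list monad
pointFree :: Int -> Int -> String
pointFree m n = (replicate m <=< replicate n) 'x'

-- Point-ful form: concatMap
pointFul :: Int -> Int -> String
pointFul m n = concatMap (replicate m) (replicate n 'x')

-- The two forms agree on a small range of non-negative inputs
formsAgree :: Bool
formsAgree =
    and [ pointFree m n == pointFul m n | m <- [0 .. 4], n <- [0 .. 4] ]

main :: IO ()
main = do
    print oneLaw          -- True
    print (pointFree 2 3) -- "xxxxxx"
    print formsAgree      -- True
```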
These four replicate equations are easier to prove than the corresponding replicateM_ equations. If we can somehow prove these simpler equations, then all we have to do is prove that sequence_ lifts all of the replicate proofs into the equivalent replicateM_ proofs:
-- sequence_ lifts list concatenation to sequencing
sequence_ [] = return ()
sequence_ (xs ++ ys) = sequence_ xs >> sequence_ ys
-- (sequence_ .) lifts list functions to ordinary functions
sequence_ . return = id
sequence_ . (f <=< g) = (sequence_ . f) . (sequence_ . g)
-- Point-ful version of last two equations:
sequence_ [x] = x
sequence_ (concatMap f (g x)) = sequence_ (f (sequence_ (g x)))
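The four sequence_ equations can also be spot-checked on concrete values. The sketch below assumes the pair monad from base in place of IO and instantiates the fourth equation only at f = replicate 2 and g = replicate 3, which is the shape the replicateM_ proof needs. The names `say` and the four `...Law` values are hypothetical, and these checks are examples, not proofs.

```haskell
-- Hypothetical logging action: the String component is the log
say :: String -> (String, ())
say text = (text, ())

-- sequence_ [] = return ()
emptyLaw :: Bool
emptyLaw = sequence_ ([] :: [(String, ())]) == return ()

-- sequence_ (xs ++ ys) = sequence_ xs >> sequence_ ys
appendLaw :: Bool
appendLaw = sequence_ (xs ++ ys) == (sequence_ xs >> sequence_ ys)
  where
    xs = [say "a", say "b"]
    ys = [say "c"]

-- sequence_ [x] = x
singletonLaw :: Bool
singletonLaw = sequence_ [say "a"] == say "a"

-- sequence_ (concatMap f (g x)) = sequence_ (f (sequence_ (g x))),
-- instantiated at f = replicate 2 and g = replicate 3
composeLaw :: Bool
composeLaw =
    sequence_ (concatMap (replicate 2) (replicate 3 x))
        == sequence_ (replicate 2 (sequence_ (replicate 3 x)))
  where
    x = say "a"

main :: IO ()
main = print (emptyLaw, appendLaw, singletonLaw, composeLaw)
```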
It might not be obvious at first, but the above four equations for sequence_ suffice to transform the replicate proofs into the analogous replicateM_ proofs. For example, this is how we would prove the first replicateM_ equation in terms of the first replicate equation and the first sequence_ equation:
replicateM_ 0 x
-- Definition of `replicateM_`
= sequence_ (replicate 0 x)
-- replicate 0 x = []
= sequence_ []
-- sequence_ [] = return ()
= return ()
That was simple. It is only slightly more tricky to prove the second equation:
replicateM_ (m + n) x
-- Definition of `replicateM_`
= sequence_ (replicate (m + n) x)
-- replicate (m + n) x = replicate m x ++ replicate n x
= sequence_ (replicate m x ++ replicate n x)
-- sequence_ (xs ++ ys) = sequence_ xs >> sequence_ ys
= sequence_ (replicate m x) >> sequence_ (replicate n x)
-- Definition of `replicateM_`, in reverse!
= replicateM_ m x >> replicateM_ n x
Notice how the last step of the proof involves using the original replicateM_ equation, but instead substituting from right-to-left!
-- +-- We can substitute this way --+
-- |                                |
-- ^                                v
replicateM_ n x = sequence_ (replicate n x)
-- ^                                v
-- |                                |
-- +--------- Or this way! ---------+
This is a nice example of the utility of bidirectional substitution. We can replace the body of a function with the equivalent function call.
The third replicateM_ equation is also simple to prove in terms of the third replicate and sequence_ equations. I will use the point-ful forms of all these equations for simplicity:
replicateM_ 1 x
-- Definition of `replicateM_`
= sequence_ (replicate 1 x)
-- replicate 1 x = [x]
= sequence_ [x]
-- sequence_ [x] = x
= x
The fourth property is surprisingly short, too:
replicateM_ (m * n) x
-- Definition of `replicateM_`
= sequence_ (replicate (m * n) x)
-- replicate (m * n) x = concatMap (replicate m) (replicate n x)
= sequence_ (concatMap (replicate m) (replicate n x))
-- sequence_ (concatMap f (g x)) = sequence_ (f (sequence_ (g x)))
= sequence_ (replicate m (sequence_ (replicate n x)))
-- Definition of `replicateM_`, in reverse
= replicateM_ m (replicateM_ n x)
Equational reasoning - Part 2
We reduced our proofs of the replicateM_ properties to smaller proofs of replicate and sequence_ properties. The overhead of this proof reduction is tiny, and we gain the benefit of reusing the proofs for replicate and sequence_.
As programmers we try to reuse code when we program and the way we promote code reuse is to divide programs into smaller composable pieces that are more reusable. Likewise, we try to reuse proofs when we equationally reason about code and the way we encourage proof reuse is to divide larger proofs into smaller proofs using proof reduction. In the above example we reduced the four equations for replicateM_ into four equations for replicate and four equations for sequence_. These smaller equations are equally useful in their own right and they can be reused by other people as sub-proofs for their own proofs.
However, proof reuse also faces the same challenges as code reuse. When we break up code into smaller pieces sometimes we take things too far and create components we like to think are reusable but really aren't. Similarly, when we reduce proofs sometimes we pick sub-proofs that are worthless and only add more overhead to the entire proof process. How can we sift out the gold from the garbage?
I find that the most reusable proofs are category laws or functor laws of some sort. In fact, every single proof from the previous section was a functor law in disguise. To learn more about functor laws and how they arise everywhere you can read another post of mine about the functor design pattern.
Proof techniques
This section will walk through the complete proofs for the replicate equations to provide several worked examples and to also illustrate several useful proof tricks.
I deliberately write these proofs to be reasonably detailed and to skip as few steps as possible. In practice, though, proofs become much easier the more you equationally reason about code because you get better at taking larger steps.
Let's revisit the equations we wish to prove for replicate, in point-ful form:
replicate 0 x = []
replicate (m + n) x = replicate m x ++ replicate n x
replicate 1 x = [x]
replicate (m * n) x = concatMap (replicate m) (replicate n x)
... where concatMap is defined as:
concatMap :: (a -> [b]) -> [a] -> [b]
concatMap f = foldr ((++) . f) []
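To get a feel for this definition before using it in proofs, here is a short sketch that copies it under the hypothetical name `myConcatMap` and compares it with the Prelude's concatMap on one input. The value `consStep` spells out the single foldr unfolding that the later proofs rely on.

```haskell
-- The definition above, renamed to avoid the Prelude's concatMap
myConcatMap :: (a -> [b]) -> [a] -> [b]
myConcatMap f = foldr ((++) . f) []

-- Each element is expanded by `replicate 2` and the results are joined
example :: String
example = myConcatMap (replicate 2) "abc"

-- Agreement with the Prelude's version on this input
agrees :: Bool
agrees = example == concatMap (replicate 2) "abc"

-- One step of foldr: concatMap f (x:xs) = f x ++ concatMap f xs
consStep :: Bool
consStep =
    myConcatMap (replicate 2) ('a' : "bc")
        == replicate 2 'a' ++ myConcatMap (replicate 2) "bc"

main :: IO ()
main = do
    print example  -- "aabbcc"
    print agrees   -- True
    print consStep -- True
```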
Now we must use the same equational reasoning skills we developed in the first section to prove all four of these equations.
The first equation is simple:
replicate 0 x
-- Definition of `replicate`
= take 0 (repeat x)
-- Definition of `take`
= []
-- Proof complete
The second equation is trickier:
replicate (m + n) x
-- Definition of `replicate`
= take (m + n) (repeat x)
We can't proceed further unless we know whether or not m + n is greater than 0. For simplicity we'll assume that m and n are non-negative.
We then do something analogous to "case analysis" on m, pretending it is like a Peano number. That means that m is either 0 or positive (i.e. greater than 0). We'll first prove our equation for the case where m equals 0:
-- Assume: m = 0
= take (0 + n) (repeat x)
-- 0 + n = n
= take n (repeat x)
-- Definition of `(++)`, in reverse
= [] ++ take n (repeat x)
-- Definition of `take`, in reverse
= take 0 (repeat x) ++ take n (repeat x)
-- m = 0
= take m (repeat x) ++ take n (repeat x)
-- Definition of `replicate`, in reverse
= replicate m x ++ replicate n x
-- Proof complete for m = 0
Then there is the second case, where m is positive, meaning that we can represent m as 1 plus some other non-negative number m':
-- Assume: m = 1 + m'
= take (1 + m' + n) (repeat x)
-- Definition of `repeat`
= take (1 + m' + n) (x:repeat x)
-- Definition of `take`
= x:take (m' + n) (repeat x)
-- Definition of `replicate` in reverse
= x:replicate (m' + n) x
Now we can use induction to reuse the original premise since m' is strictly smaller than m. Since we are assuming that m is non-negative, this logical recursion is well-founded and guaranteed to eventually bottom out at the base case where m equals 0:
-- Induction: reuse the premise
= x:(replicate m' x ++ replicate n x)
-- Definition of `(++)`, in reverse
= (x:replicate m' x) ++ replicate n x
-- Definition of `replicate`
= (x:take m' (repeat x)) ++ replicate n x
-- Definition of `take`, in reverse
= take (1 + m') (repeat x) ++ replicate n x
-- Definition of `replicate`, in reverse
= replicate (1 + m') x ++ replicate n x
-- m = 1 + m', in reverse
= replicate m x ++ replicate n x
-- Proof complete for m = 1 + m'
This completes the proof for both cases so the proof is "total", meaning that we covered all possibilities. Actually, that's a lie because really rigorous Haskell proofs must account for the possibility of non-termination (a.k.a. "bottom"). However, I usually consider proofs that don't account for non-termination to be good enough for most practical purposes.
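The non-negativity assumption made at the start of this proof is doing real work. The sketch below checks the addition equation on a small grid of non-negative inputs and then shows one negative input where it fails. The names `additionHolds`, `allNonNegative`, and `negativeCase` are hypothetical, and the grid check is a sanity test, not a substitute for the induction.

```haskell
-- Does `replicate (m + n) x = replicate m x ++ replicate n x` hold
-- for this particular m and n?
additionHolds :: Int -> Int -> Bool
additionHolds m n =
    replicate (m + n) 'x' == replicate m 'x' ++ replicate n 'x'

-- The equation holds on a small grid of non-negative inputs
allNonNegative :: Bool
allNonNegative = and [ additionHolds m n | m <- [0 .. 10], n <- [0 .. 10] ]

-- With m = -1 and n = 2 the left side is replicate 1 'x' = "x", but
-- the right side is "" ++ "xx" = "xx", so the equation fails.
negativeCase :: Bool
negativeCase = additionHolds (-1) 2

main :: IO ()
main = do
    print allNonNegative -- True
    print negativeCase   -- False
```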
The third replicate law is very straightforward to prove:
replicate 1 x
-- Definition of `replicate`
= take 1 (repeat x)
-- Definition of `repeat`
= take 1 (x:repeat x)
-- Definition of `take`
= x:take 0 (repeat x)
-- Definition of `take`
= x:[]
-- [x] is syntactic sugar for `x:[]`
= [x]
The fourth equation for replicate also requires us to split our proof into two branches. Either n is zero or n is greater than zero. First we consider the case where n is zero:
replicate (m * n) x
-- Assume: n = 0
= replicate 0 x
-- replicate 0 x = []
= []
-- Definition of `foldr`, in reverse
= foldr ((++) . replicate m) [] []
-- Definition of `concatMap`, in reverse
= concatMap (replicate m) []
-- replicate 0 x = [], in reverse
= concatMap (replicate m) (replicate 0 x)
-- n = 0, in reverse
= concatMap (replicate m) (replicate n x)
-- Proof complete for n = 0
Then we consider the case where n is greater than zero:
replicate (m * n) x
-- Assume: n = 1 + n'
= replicate (m * (1 + n')) x
-- m * (1 + n') = m + m * n'
= replicate (m + m * n') x
-- replicate distributes over addition
= replicate m x ++ replicate (m * n') x
-- Induction: reuse the premise
= replicate m x ++ concatMap (replicate m) (replicate n' x)
-- Definition of `concatMap`
= replicate m x ++ foldr ((++) . replicate m) [] (replicate n' x)
-- Definition of `foldr`, in reverse
= foldr ((++) . replicate m) [] (x:replicate n' x)
-- Definition of `concatMap`, in reverse
= concatMap (replicate m) (x:replicate n' x)
-- Definition of `replicate`
= concatMap (replicate m) (x:take n' (repeat x))
-- Definition of `take`, in reverse
= concatMap (replicate m) (take (1 + n') (x:repeat x))
-- n = 1 + n', in reverse
= concatMap (replicate m) (take n (x:repeat x))
-- Definition of `repeat`, in reverse
= concatMap (replicate m) (take n (repeat x))
-- Definition of `replicate`, in reverse
= concatMap (replicate m) (replicate n x)
-- Proof complete for n = 1 + n'
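The multiplication equation can be spot-checked in the same spirit. This sketch tries a small grid of non-negative inputs, and also shows that the equation, like the addition equation it was derived from, should not be expected to hold when both arguments are negative. The names `multiplicationHolds`, `allNonNegative`, and `bothNegative` are hypothetical.

```haskell
-- Does `replicate (m * n) x = concatMap (replicate m) (replicate n x)`
-- hold for this particular m and n?
multiplicationHolds :: Int -> Int -> Bool
multiplicationHolds m n =
    replicate (m * n) 'x' == concatMap (replicate m) (replicate n 'x')

-- The equation holds on a small grid of non-negative inputs
allNonNegative :: Bool
allNonNegative =
    and [ multiplicationHolds m n | m <- [0 .. 6], n <- [0 .. 6] ]

-- With m = n = -1 the left side is replicate 1 'x' = "x", but the
-- right side maps over an empty list and is "", so the equation fails.
bothNegative :: Bool
bothNegative = multiplicationHolds (-1) (-1)

main :: IO ()
main = do
    print allNonNegative -- True
    print bothNegative   -- False
```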
Hopefully these proofs give an idea of the amount of effort involved in proving properties of moderate complexity. I omitted the final part, proving the sequence_ equations, in the interest of space, but they make for a great exercise.
Equational reasoning - Part 3
Reasoning about Haskell differs from reasoning about code in other languages. Traditionally, reasoning about code would require:
- building a formal model of a program's algorithm,
- reasoning about the behavior of the formal model using its axioms, and
- proving that the program matches the model.
In a purely functional language like Haskell you formally reason about your code within the language itself. There is no need for a separate formal model because the code is already written as a bunch of formal equations. This is what people mean when they say that Haskell has strong ties to mathematics, because you can reason about Haskell code the same way you reason about mathematical equations.
This is why Haskell syntax for function definitions deviates from mainstream languages. All function definitions are just equalities, which is why Haskell is great for equational reasoning.
This post illustrated how equational reasoning in Haskell can scale to larger complexity through the use of proof reduction. A future post of mine will walk through a second tool for reducing proof complexity: type classes inspired by mathematics that come with associated type class laws.