Thursday, September 9, 2021

Optics are monoids


This post documents my favorite lens trick of all time. The trick also works for any optics package based on van Laarhoven lenses, such as lens-family-core or microlens.

This post assumes some familiarity with lenses, so if you are new to lenses then you might want to first read:

The title is slightly misleading: the precise statement is that Folds are Monoids, and all of the following optics are subtypes of Folds:

  • Getter
  • Lens
  • Traversal
  • Prism
  • Iso

… but I couldn’t fit all of that in the title.

That means that if you combine any of the above optic types using <>, you will get a new optic that can be used as a Fold that combines their targets. For example:

>>> toListOf _1 (True, False)          -- _1 is a Lens
[True]
>>> toListOf _2 (True, False)          -- _2 is a Lens
[False]
>>> toListOf (_1 <> _2) (True, False)  -- (_1 <> _2) is a Fold
[True,False]

Also, mempty is the “empty” Fold that targets nothing:

>>> toListOf mempty (True, False)
[]
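Nothing in these examples depends on machinery specific to the lens package. The following is a minimal, base-only sketch. It assumes a simplified Fold type specialised to Const (the same simplification used in the "Type level" section below) and hand-written stand-ins `first'`, `second'`, and `toListOf'` for lens's `_1`, `_2`, and `toListOf`. The primed names are hypothetical and exist only for this illustration:

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const (..))

-- Simplified Fold: a function over Const that works for any Monoid m
type Fold s a = forall m. Monoid m => (a -> Const m a) -> s -> Const m s

-- Hypothetical stand-ins for lens's _1 and _2, restricted to Const
first' :: Fold (a, b) a
first' k (a, _) = Const (getConst (k a))

second' :: Fold (a, b) b
second' k (_, b) = Const (getConst (k b))

-- Hypothetical stand-in for toListOf: collect every target into a list
toListOf' :: Fold s a -> s -> [a]
toListOf' l s = getConst (l (\a -> Const [a]) s)

main :: IO ()
main = do
    print (toListOf' first' (True, False))               -- [True]
    print (toListOf' second' (True, False))              -- [False]
    print (toListOf' (first' <> second') (True, False))  -- [True,False]
    print (toListOf' mempty (True, False) :: [Bool])     -- []
```

Note that `first' <> second'` and `mempty` need no instance declarations of our own; the existing Semigroup and Monoid instances in base do all the work.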

There’s more to this trick, though, and we can build upon this idea to create optics that traverse complex data structures in a single pass.

Realistic example

To illustrate the trick, I’ll use a realistic example inspired by one of my interpreter side projects. I’ll keep things simple by reducing the original example to the following syntax tree for a toy lambda calculus implementation:

data Syntax
    = Variable String
    | Lambda String Syntax
    | Apply Syntax Syntax
    | Let String Syntax Syntax

example :: Syntax
example = Lambda "x" (Lambda "y" (Apply (Variable "x") (Variable "y")))

Now suppose we’d like to write a function that verifies that our syntax tree has no empty variable names. Without optics, we could write something like this:

wellFormed :: Syntax -> Bool
wellFormed (Variable name) =
    not (null name)
wellFormed (Lambda name body) =
    not (null name) && wellFormed body
wellFormed (Apply function argument) =
    wellFormed function && wellFormed argument
wellFormed (Let name assignment body) =
    not (null name) && wellFormed assignment && wellFormed body

… which works as expected on a few smoke tests:

>>> wellFormed example
True
>>> wellFormed (Variable "")
False

This implementation is simple enough for now. However, real interpreters tend to add a whole bunch of other constructors to the syntax tree. For example, each new keyword or datatype we add to the language will add another constructor to the syntax tree and each new constructor increases the boilerplate code for functions like wellFormed.

Thankfully, the lens and generic-lens packages provide useful utilities that simplify recursive functions like wellFormed. All we have to do is derive Data and Generic for our Syntax type and add an empty Plated instance (whose default implementation uses the Data instance), like this:

{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric      #-}

module Example where

import Data.Generics.Product (the)
import Data.Generics.Sum (_As)
import GHC.Generics (Generic)

import Control.Lens
import Data.Data (Data)

data Syntax
    = Variable String
    | Lambda String Syntax
    | Apply Syntax Syntax
    | Let String Syntax Syntax
    deriving (Data, Generic, Show)

instance Plated Syntax

example :: Syntax
example = Lambda "x" (Lambda "y" (Apply (Variable "x") (Variable "y")))

Once we have the Plated instance we can use the cosmos Fold to zoom in on all sub-expressions, including the whole expression itself:

>>> toListOf cosmos example
[ Lambda "x" (Lambda "y" (Apply (Variable "x") (Variable "y")))
, Lambda "y" (Apply (Variable "x") (Variable "y"))
, Apply (Variable "x") (Variable "y")
, Variable "x"
, Variable "y"
]
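For intuition, here is a hand-written, base-only sketch of what `toListOf cosmos` computes for this particular type: the node itself, followed by the sub-expressions of each immediate child, left to right (a pre-order walk, matching the output above). The helpers `immediate` and `allSubterms` are hypothetical names for this illustration; for any Plated type, lens's own `universe` function returns this kind of list:

```haskell
data Syntax
    = Variable String
    | Lambda String Syntax
    | Apply Syntax Syntax
    | Let String Syntax Syntax
    deriving (Eq, Show)

-- The immediate Syntax children of a node (the default Plated instance
-- discovers these automatically through the Data instance)
immediate :: Syntax -> [Syntax]
immediate (Variable _)            = []
immediate (Lambda _ body)         = [body]
immediate (Apply function arg)    = [function, arg]
immediate (Let _ assignment body) = [assignment, body]

-- Pre-order walk: the node first, then all descendants of each child
allSubterms :: Syntax -> [Syntax]
allSubterms s = s : concatMap allSubterms (immediate s)

example :: Syntax
example = Lambda "x" (Lambda "y" (Apply (Variable "x") (Variable "y")))

main :: IO ()
main = mapM_ print (allSubterms example)
```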

… and when we derive Generic we can further narrow down the results using the _As and the optics from the generic-lens package:

>>> :set -XTypeApplications
>>> :set -XDataKinds
>>> toListOf (cosmos . _As @"Variable") example  -- Get all Variable names
["x","y"]

>>> toListOf (cosmos . _As @"Lambda" . the @1) example  -- Get all Lambda names
["x","y"]

>>> toListOf (cosmos . _As @"Let" . the @1) example  -- Get all Let names
[]

So we can combine these tricks to implement our wellFormed function using optics to handle the automatic tree traversal:

wellFormed :: Syntax -> Bool
wellFormed syntax =
        hasn't (cosmos . _As @"Variable" . only "") syntax
    &&  hasn't (cosmos . _As @"Lambda" . the @1 . only "") syntax
    &&  hasn't (cosmos . _As @"Let" . the @1 . only "") syntax

… but we’re not done here!

The cosmos Fold factored away some of the repetition because we no longer need to recursively descend into subexpressions. We also no longer need to explicitly handle constructors that have no variable names, like Apply.

Our wellFormed function is still repetitive, though, because three times in a row we write:

hasn't (cosmos . … . only "") syntax

… and we’d like to factor out this repetition.

Our first instinct might be to factor out the repetition with a helper function, like this:

wellFormed :: Syntax -> Bool
wellFormed syntax =
        noEmptyVariables (_As @"Variable")
    &&  noEmptyVariables (_As @"Lambda" . the @1) 
    &&  noEmptyVariables (_As @"Let" . the @1)
  where
    noEmptyVariables fold =
        hasn't (cosmos . fold . only "") syntax

… and that does work, but there is actually a better way. We can instead use the fact that Traversals are also Folds and Folds are Monoids to write this:

wellFormed :: Syntax -> Bool
wellFormed syntax =
    hasn't
        ( cosmos
        . (   _As @"Variable"
          <>  _As @"Lambda" . the @1
          <>  _As @"Let" . the @1
          )
        . only ""
        )
        syntax

In other words, we can take the following three Traversals that each focus on a different source of variable names:

  • _As @"Variable" - Focus in on variable names
  • _As @"Lambda" . the @1 - Focus in on Lambda-bound variables
  • _As @"Let" . the @1 - Focus in on Let-bound variables

… and when we combine them using <> we get a Fold that focuses on all possible sources of variable names. We can then chain this composite Fold in between cosmos and only to find all of the empty variable names.
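The same single-pass check can be reproduced without lens or generic-lens, which makes it easier to see that `<>` is doing nothing more than ordinary function-Monoid work. In this sketch, `variableName`, `lambdaName`, `letName`, `cosmos'`, `only'`, and `hasn't'` are hypothetical hand-written stand-ins (not the library definitions), and `Fold` is the simplified Const-based type described in the "Type level" section below:

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const (..))
import Data.Monoid (Any (..))

data Syntax
    = Variable String
    | Lambda String Syntax
    | Apply Syntax Syntax
    | Let String Syntax Syntax
    deriving (Show)

-- Simplified Fold type, specialised to Const
type Fold s a = forall m. Monoid m => (a -> Const m a) -> s -> Const m s

-- Hypothetical stand-ins for the three generic-lens optics in the post
variableName, lambdaName, letName :: Fold Syntax String
variableName k (Variable n) = Const (getConst (k n))
variableName _ _            = mempty
lambdaName k (Lambda n _)   = Const (getConst (k n))
lambdaName _ _              = mempty
letName k (Let n _ _)       = Const (getConst (k n))
letName _ _                 = mempty

-- Hypothetical stand-in for cosmos: the node itself, then all descendants
cosmos' :: Fold Syntax Syntax
cosmos' k s = Const (getConst (k s) <> foldMap (getConst . cosmos' k) (immediate s))
  where
    immediate (Variable _) = []
    immediate (Lambda _ b) = [b]
    immediate (Apply f x)  = [f, x]
    immediate (Let _ a b)  = [a, b]

-- Hypothetical stand-in for only: targets () when the value equals x
only' :: (Eq a, Monoid m) => a -> (() -> Const m ()) -> a -> Const m a
only' x k a = if a == x then Const (getConst (k ())) else mempty

-- Hypothetical stand-in for hasn't: True when the optic has no targets
hasn't' :: ((a -> Const Any a) -> s -> Const Any s) -> s -> Bool
hasn't' l = not . getAny . getConst . l (\_ -> Const (Any True))

-- The three name sources combined with <>, checked in a single pass
wellFormed :: Syntax -> Bool
wellFormed =
    hasn't' (cosmos' . (variableName <> lambdaName <> letName) . only' "")

main :: IO ()
main = do
    print (wellFormed (Lambda "x" (Lambda "y" (Apply (Variable "x") (Variable "y")))))
    print (wellFormed (Apply (Variable "f") (Let "" (Variable "a") (Variable "a"))))
```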

In fact, we’re not limited to using <>. Any utility that works on Monoids will work, like mconcat, so we can refactor our code even further like this:

wellFormed :: Syntax -> Bool
wellFormed = hasn't (cosmos . names . only "")

-- | Get all variable names within the current constructor
names :: Monoid m => Getting m Syntax String
names =
    mconcat
        [ _As @"Variable"
        , _As @"Lambda" . the @1
        , _As @"Let" . the @1
        ]   

Now we have factored out a useful and reusable names Fold[1] that we can combine with cosmos to get all names within our syntax tree:

>>> toListOf (cosmos . names) example
["x","y","x","y"]

This means that any new contributor to our interpreter can register a new source of variable names by adding a Traversal to that list and all downstream functions that use names will automatically update to do the right thing.
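Footnote 1 notes that only the Getting type works for names. As far as I can tell, the reason is that lens's real Fold type is polymorphic over the functor f, and the function Monoid instance needs the final result f Syntax to be a Monoid. That holds for Const m (when m is a Monoid) but not for an arbitrary f, whereas Getting m fixes the functor to Const m, so the instance chain goes through. The base-only sketch below assumes a local Getting synonym with the same shape as lens's, plus hypothetical stand-ins for the name optics, cosmos, and toListOf:

```haskell
import Data.Functor.Const (Const (..))

data Syntax
    = Variable String
    | Lambda String Syntax
    | Apply Syntax Syntax
    | Let String Syntax Syntax
    deriving (Show)

-- Same shape as lens's Getting: an optic specialised to the Const functor
type Getting r s a = (a -> Const r a) -> s -> Const r s

-- Hypothetical stand-ins for _As @"Variable", _As @"Lambda" . the @1, etc.
variableName, lambdaName, letName :: Monoid r => Getting r Syntax String
variableName k (Variable n) = Const (getConst (k n))
variableName _ _            = mempty
lambdaName k (Lambda n _)   = Const (getConst (k n))
lambdaName _ _              = mempty
letName k (Let n _ _)       = Const (getConst (k n))
letName _ _                 = mempty

-- Every list element has the same type, Getting r Syntax String, which is
-- a Monoid whenever r is, so mconcat applies directly
names :: Monoid r => Getting r Syntax String
names = mconcat [variableName, lambdaName, letName]

-- Hypothetical stand-in for cosmos: the node itself, then all descendants
cosmos' :: Monoid r => Getting r Syntax Syntax
cosmos' k s = Const (getConst (k s) <> foldMap (getConst . cosmos' k) (immediate s))
  where
    immediate (Variable _) = []
    immediate (Lambda _ b) = [b]
    immediate (Apply f x)  = [f, x]
    immediate (Let _ a b)  = [a, b]

-- Hypothetical stand-in for toListOf, using the simplified list-based definition
toListOf' :: Getting [a] s a -> s -> [a]
toListOf' l s = getConst (l (\a -> Const [a]) s)

example :: Syntax
example = Lambda "x" (Lambda "y" (Apply (Variable "x") (Variable "y")))

main :: IO ()
main = print (toListOf' (cosmos' . names) example)  -- ["x","y","x","y"]
```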

Why this trick?

I briefly touched on some other cool tricks in the course of this post, including:

  • The use of Plated for simplifying recursive Traversals
  • The use of generic-lens for boilerplate-free optics

… so why do I consider “optics as monoids” to be the coolest trick of them all? After all, Plated and generic-lens did most of the heavy lifting in the above example.

I particularly love the Monoid instance for lenses because (as far as I know) nobody ever designed lenses to do this; it is purely an emergent property of independent design choices spread out over time.

This shouldn’t surprise us too much, because Haskell’s mathematically inspired type classes loosely follow the Unix philosophy: “Do one thing and do it well.” If each piece does one small thing correctly, then combining multiple pieces produces a composite result that is correct by construction.

However, you don’t need to take my word for it. I’ll walk through in detail how this works, first at the type level and then at the term level.

Type level

The first piece of the puzzle is the following Semigroup and Monoid instances for functions in base:

instance Semigroup b => Semigroup (a -> b) where
    (f <> g) x = f x <> g x

instance Monoid b => Monoid (a -> b) where
    mempty x = mempty

These instances say that a function is a Monoid if the function’s result is also a Monoid. We combine two functions by combining their outputs (when given the same input), and the empty function ignores its input and produces an empty output.
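A small base-only illustration of these two instances (the names `greet`, `pair`, and `nothing` are just examples). The `pair` definition is the one that matters for optics: a Fold takes two arguments, so combining two Folds uses the function instance twice, once per argument:

```haskell
-- Combining two functions combines their results on the same input
greet :: String -> String
greet = (\name -> "Hello, " ++ name) <> (\_ -> "!")

-- Two-argument functions combine too: the instance is used once per argument
pair :: Int -> Int -> [Int]
pair = (\x _ -> [x]) <> (\_ y -> [y])

-- The empty function ignores its input and returns the empty result
nothing :: Int -> String
nothing = mempty

main :: IO ()
main = do
    putStrLn (greet "world")  -- Hello, world!
    print (pair 1 2)          -- [1,2]
    print (nothing 42)        -- ""
```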

The second piece of the puzzle is the Const type (the constant Functor), which has a Semigroup and Monoid instance, too:

newtype Const a b = Const { getConst :: a }

instance Monoid a => Monoid (Const a b) where
    mempty = Const mempty

instance Semigroup a => Semigroup (Const a b) where
    Const x <> Const y = Const (x <> y)

These instances are so simple that we could have just derived them (and indeed, that is what the base package does):

newtype Const a b = Const { getConst :: a }
    deriving newtype (Semigroup, Monoid)

In other words, Const a b is a Monoid if a is a Monoid. Combining two Consts combines their stored value, and the empty Const stores an empty value.
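A quick check of these instances using base's Data.Functor.Const. The second type argument (Bool here) is a phantom type that plays no role when combining; `combined` and `emptyConst` are example names:

```haskell
import Data.Functor.Const (Const (..))

-- Combining two Consts combines the lists they store
combined :: Const [Int] Bool
combined = Const [1, 2] <> Const [3]

-- The empty Const stores the empty list
emptyConst :: Const [Int] Bool
emptyConst = mempty

main :: IO ()
main = do
    print (getConst combined)    -- [1,2,3]
    print (getConst emptyConst)  -- []
```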

The final piece of the puzzle is that a Fold from the lens package is just a higher-order function over Consts:

-- This is not the real type, but it's equivalent
type Fold a b = forall m . Monoid m => (b -> Const m b) -> (a -> Const m a)

… and that type is a valid Monoid, because:

  • (b -> Const m b) -> (a -> Const m a) is a Monoid if (a -> Const m a) is a Monoid

    … according to the Monoid instance for functions

  • a -> Const m a is a Monoid if Const m a is a Monoid

    … also according to the Monoid instance for functions

  • Const m a is a Monoid if m is a Monoid

    … according to the Monoid instance for Const

  • m is a Monoid

    … according to the Monoid m => constraint in the type of Fold

Therefore, all Folds are Monoids.

Term level

However, knowing that a Fold type-checks as a Monoid is not enough. We want to build an intuition for what happens when we use Monoid operations on Folds.

The most compact way we can state our intuition is by the following two laws:

toListOf (l <> r) a = toListOf l a <> toListOf r a

toListOf mempty a = mempty

In other words, if you combine two Folds then you combine their “targets”, and the empty Fold has no targets.

These laws are also known as “monoid morphism laws”. In other words, toListOf is a function that transforms one type of Monoid (Folds) into another type of Monoid (lists).
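The two laws can also be spot-checked mechanically. The sketch below reuses the simplified Fold type and toListOf from this post (base only) together with two hypothetical example Folds over a list, `elements` and `evens`. Passing on a handful of inputs is evidence rather than proof; the equational reasoning below provides the proof:

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const (..))

type Fold a b = forall m . Monoid m => (b -> Const m b) -> (a -> Const m a)

toListOf :: Fold a b -> a -> [b]
toListOf fold a = getConst (fold (\b -> Const [b]) a)

-- Hypothetical example Fold: every element of a list
elements :: Fold [Int] Int
elements k xs = Const (foldMap (getConst . k) xs)

-- Hypothetical example Fold: only the even elements of a list
evens :: Fold [Int] Int
evens k xs = Const (foldMap (getConst . k) (filter even xs))

-- toListOf (l <> r) a = toListOf l a <> toListOf r a
firstLaw :: [Int] -> Bool
firstLaw a = toListOf (elements <> evens) a == toListOf elements a <> toListOf evens a

-- toListOf mempty a = mempty
secondLaw :: [Int] -> Bool
secondLaw a = toListOf mempty a == (mempty :: [Int])

main :: IO ()
main = do
    let samples = [[], [1], [1 .. 10], [3, 5, 7]]
    print (all firstLaw samples, all secondLaw samples)  -- (True,True)
```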

We can verify those two laws using equational reasoning, but in order to do so we need to use the following simplified definition for toListOf:

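{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const(..))

-- The simplified Fold type from the "Type level" section
type Fold a b = forall m . Monoid m => (b -> Const m b) -> (a -> Const m a)

-- Instantiate m to [b] and wrap each target in a singleton list
toListOf :: Fold a b -> a -> [b]
toListOf fold a = getConst (fold (\b -> Const [b]) a)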

The real toListOf function from the lens package has a different, but equivalent, implementation. The real implementation is more efficient, but takes more steps when proving things using equational reasoning, so I prefer to use this simpler implementation.
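For comparison, here is a sketch of a more efficient variant written against the same simplified Fold type. It swaps plain lists for difference lists via Data.Monoid's Endo, which is the general idea behind lens's version (its toListOf goes through an Endo-based foldrOf), though this is not lens's actual code. The names `toListOfEndo` and `elements` are hypothetical. Because `<>` on Endo is function composition, no left-nested `(++)` chains are built:

```haskell
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const (..))
import Data.Monoid (Endo (..))

type Fold a b = forall m . Monoid m => (b -> Const m b) -> (a -> Const m a)

-- The simplified definition from this post: each target becomes [b]
toListOf :: Fold a b -> a -> [b]
toListOf fold a = getConst (fold (\b -> Const [b]) a)

-- Difference-list variant: each target becomes the function (b :), the
-- targets are composed in order, and the result is applied to []
toListOfEndo :: Fold a b -> a -> [b]
toListOfEndo fold a = appEndo (getConst (fold (\b -> Const (Endo (b :))) a)) []

-- Hypothetical example Fold: every element of a list
elements :: Fold [x] x
elements k xs = Const (foldMap (getConst . k) xs)

main :: IO ()
main = do
    print (toListOf elements "abc")                   -- "abc"
    print (toListOfEndo (elements <> elements) "ab")  -- "abab"
```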

Now let’s prove the two monoid morphism laws. The proof for the first law is:

toListOf (l <> r) a

-- Substitute `toListOf` with its definition
= getConst ((l <> r) (\b -> Const [b]) a)

-- If `l` and `r` are functions, then according to the `Semigroup` instance for
-- functions:
--
--     (f <> g) x = f x <> g x
--
-- … where in this case:
--
--     f = l
--     g = r
--     x = \b -> Const [b]
= getConst ((l (\b -> Const [b]) <> r (\b -> Const [b])) a)

-- Use the `Semigroup` instance for functions again, except this time:
--
--     f = l (\b -> Const [b])
--     g = r (\b -> Const [b])
--     x = a
= getConst (l (\b -> Const [b]) a <> r (\b -> Const [b]) a)

-- Now use the `Semigroup` instance for `Const`, which (essentially) says:
--
--     getConst (x <> y) = getConst x <> getConst y
--
-- … where:
--
--     x = l (\b -> Const [b]) a
--     y = r (\b -> Const [b]) a
= getConst (l (\b -> Const [b]) a) <> getConst (r (\b -> Const [b]) a)

-- Now apply the definition of `toListOf`, but in reverse:
= toListOf l a <> toListOf r a

… and the proof for the second law is:

toListOf mempty a

-- Substitute `toListOf` with its definition
= getConst (mempty (\b -> Const [b]) a)

-- If `mempty` is a function, then according to the `Monoid` instance for
-- functions:
--
--     mempty x = mempty
--
-- … where in this case:
--
--     x = \b -> Const [b]
= getConst (mempty a)

-- Use the `Monoid` instance for functions again, except this time:
--
--     x = a
= getConst mempty

-- Now use the `Monoid` instance for `Const`, which says:
--
--    mempty = Const mempty
= getConst (Const mempty)

-- getConst (Const x) = x
= mempty

Conclusion

Hopefully that gives you a taste for how slick and elegant Haskell’s lens package is. If you like this post, you might also like these other posts:

Also, I know that I skimmed through the subjects of Plated and generic-lens, which are interesting topics in their own right. I hope to cover those in more detail in future posts.

I don’t know if this trick can be made to work for optics (an alternative to lens that uses an abstract interface to improve error messages). I know that it does not work at the time of this writing, but perhaps a Monoid instance could be added for the Optic type? I also have no idea if this trick or a related trick works for profunctor-optics (a different formulation of lenses based on profunctors).

I haven’t benchmarked whether combining Folds is faster than doing separate passes over the same data structure. I do believe, though, that processing the data structure in a single pass using a composite Fold is lazier.
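One concrete sense in which the single pass is lazier can be shown with base-only stand-ins (hypothetical names, not the lens API) on a deliberately infinite tree: the composite Fold stops at the first empty name it reaches, whichever constructor it comes from, while the chain of separate passes can spend forever on its first pass. This only illustrates short-circuiting; it says nothing about performance:

```haskell
import Data.Functor.Const (Const (..))
import Data.Monoid (Any (..))

data Syntax
    = Variable String
    | Lambda String Syntax
    | Apply Syntax Syntax
    | Let String Syntax Syntax

type Getting r s a = (a -> Const r a) -> s -> Const r s

-- Hypothetical stand-ins for the three name optics, names, and cosmos
variableName, lambdaName, letName :: Monoid r => Getting r Syntax String
variableName k (Variable n) = Const (getConst (k n))
variableName _ _            = mempty
lambdaName k (Lambda n _)   = Const (getConst (k n))
lambdaName _ _              = mempty
letName k (Let n _ _)       = Const (getConst (k n))
letName _ _                 = mempty

names :: Monoid r => Getting r Syntax String
names = mconcat [variableName, lambdaName, letName]

cosmos' :: Monoid r => Getting r Syntax Syntax
cosmos' k s = Const (getConst (k s) <> foldMap (getConst . cosmos' k) (immediate s))
  where
    immediate (Variable _) = []
    immediate (Lambda _ b) = [b]
    immediate (Apply f x)  = [f, x]
    immediate (Let _ a b)  = [a, b]

-- Hypothetical stand-in for hasn't (l . only ""): True when no target is ""
noEmptyNames :: Getting Any Syntax String -> Syntax -> Bool
noEmptyNames l = not . getAny . getConst . l (\n -> Const (Any (null n)))

-- An infinitely deep tree whose only empty name is a Let near the root
infiniteTree :: Syntax
infiniteTree = Lambda "x" (Let "" (Variable "y") spine)
  where
    spine = Lambda "z" spine

-- One composite pass: stops at the Let "" because Any's <> is a lazy (||)
singlePass :: Bool
singlePass = noEmptyNames (cosmos' . names) infiniteTree

-- Separate passes: the first conjunct searches the whole infinite tree for
-- an empty Variable name, never finds one, and never returns (do not evaluate)
separatePasses :: Bool
separatePasses =
       noEmptyNames (cosmos' . variableName) infiniteTree
    && noEmptyNames (cosmos' . lambdaName) infiniteTree
    && noEmptyNames (cosmos' . letName) infiniteTree

main :: IO ()
main = print singlePass  -- False
```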

Appendix

Here is the complete code example if you want to test this out locally:

{-# LANGUAGE DataKinds          #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric      #-}
{-# LANGUAGE TypeApplications   #-}

module Example where

import Data.Generics.Product (the)
import Data.Generics.Sum (_As)
import GHC.Generics (Generic)

import Control.Lens
import Data.Data (Data)

data Syntax
    = Variable String
    | Lambda String Syntax
    | Apply Syntax Syntax
    | Let String Syntax Syntax
    deriving (Data, Generic, Show)

instance Plated Syntax

example :: Syntax
example = Lambda "x" (Lambda "y" (Apply (Variable "x") (Variable "y")))

wellFormed :: Syntax -> Bool
wellFormed = hasn't (cosmos . names . only "")

names :: Monoid m => Getting m Syntax String
names =
    mconcat
        [ _As @"Variable"
        , _As @"Lambda" . the @1
        , _As @"Let" . the @1
        ]

  1. The type of names is morally Fold Syntax String, which according to the lens documentation is essentially the same type, but only the Getting version of the type will type-check here.

3 comments:

  1. One downside of the transformation for `wellFormed` is you lose exhaustivness checking. Is there some way to make the compiler complain if you forget to add a new constructor to the list of monoids?

    Replies
    1. Not for Folds, but you can do exhaustive pattern matching for Prisms. See my `total` package: https://hackage.haskell.org/package/total

  2. > I don’t know if this trick can be made to work for optics (an alternative to lens that uses an abstract interface to improve error messages). I know that it does not work at the time of this writing, but perhaps a Monoid instance could be added for the Optic type?

    There's been some discussion:
    https://github.com/well-typed/optics/pull/300
    https://github.com/well-typed/optics/pull/332#issuecomment-669461724
    Leading to the following documentation:
    https://hackage.haskell.org/package/optics-core-0.4/docs/Optics-Fold.html#g:monoids
    https://hackage.haskell.org/package/optics-core-0.4/docs/Optics-Traversal.html#g:9

    In particular, no Semigroup instance was added, because two equally valid `(<>)`s are possible, `summing` being the one which arises in `lens`. And because the type `a -> a -> a` is too restrictive to work well in practice with `optics`'s overloading.
