Friday, September 12, 2014

Morte: an intermediate language for super-optimizing functional programs

The Haskell language provides the following guarantee (with caveats): if two programs are equal according to equational reasoning then they will behave the same. On the other hand, Haskell does not guarantee that equal programs will generate identical performance. Consequently, Haskell library writers must employ rewrite rules to ensure that their abstractions do not interfere with performance.

Now suppose there were a hypothetical language with a stronger guarantee: if two programs are equal then they generate identical executables. Such a language would be immune to abstraction: no matter how many layers of indirection you might add the binary size and runtime performance would be unaffected.

Here I will introduce such an intermediate language named Morte that obeys this stronger guarantee. I have not yet implemented a back-end code generator for Morte, but I wanted to pause to share what I have completed so far because Morte uses several tricks from computer science that I believe deserve more attention.

Morte is nothing more than a bare-bones implementation of the calculus of constructions, which is a specific type of lambda calculus. The only novelty is how I intend to use this lambda calculus: as a super-optimizer.

Normalization

The calculus of constructions, like other strongly normalizing typed lambda calculi, possesses a useful property: every well-typed term has a unique normal form, which you reach by beta-reducing everything. If you're new to lambda calculus, normalizing an expression equates to indiscriminately inlining every function call.

What if we built a programming language whose intermediate language was lambda calculus? What if optimization was just normalization of lambda terms (i.e. indiscriminate inlining)? If so, then we could abstract freely, knowing that while compile times might increase, our final executable would never change.
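To make "optimization is just normalization" concrete, here is a minimal sketch of a normalizer, written in Haskell. It is not Morte's implementation: it works on the untyped lambda calculus with de Bruijn indices, and the names Term, shift, subst, beta, and normalize are chosen here purely for illustration. Because the sketch is untyped, it can loop forever on terms that have no normal form, which is exactly what a type checker rules out.

```haskell
-- Untyped lambda terms using de Bruijn indices: `Var 0` refers to the
-- nearest enclosing lambda, `Var 1` to the next one out, and so on.
data Term = Var Int | Lam Term | App Term Term
    deriving (Eq, Show)

-- Add `d` to every free variable whose index is at least the cutoff `c`
shift :: Int -> Int -> Term -> Term
shift d c (Var k)   = Var (if k >= c then k + d else k)
shift d c (Lam b)   = Lam (shift d (c + 1) b)
shift d c (App f a) = App (shift d c f) (shift d c a)

-- Replace variable `j` with the term `s`
subst :: Int -> Term -> Term -> Term
subst j s (Var k)   = if k == j then s else Var k
subst j s (Lam b)   = Lam (subst (j + 1) (shift 1 0 s) b)
subst j s (App f a) = App (subst j s f) (subst j s a)

-- One beta step: (\. b) a  ~>  b[0 := a]
beta :: Term -> Term -> Term
beta b a = shift (-1) 0 (subst 0 (shift 1 0 a) b)

-- Beta-reduce everywhere, including under lambdas (i.e. inline everything)
normalize :: Term -> Term
normalize (Var k)   = Var k
normalize (Lam b)   = Lam (normalize b)
normalize (App f a) = case normalize f of
    Lam b -> normalize (beta b a)
    f'    -> App f' (normalize a)

-- \x -> x
idT :: Term
idT = Lam (Var 0)

-- \f -> \g -> \x -> f (g x)
composeT :: Term
composeT = Lam (Lam (Lam (App (Var 2) (App (Var 1) (Var 0)))))
```

With de Bruijn indices, alpha-equivalent terms are syntactically identical, so normalize (App (App composeT idT) idT) == idT evaluates to True: composing the identity with itself inlines all the way down to the identity.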

Recursion

Normally you would not want to inline everything because infinitely recursive functions would become infinitely large expressions. Fortunately, we can often translate recursive code to non-recursive code!

I'll demonstrate this trick first in Haskell and then in Morte. Let's begin from the following recursive List type along with a recursive map function over lists:

import Prelude hiding (map, foldr)

data List a = Cons a (List a) | Nil

example :: List Int
example = Cons 1 (Cons 2 (Cons 3 Nil))

map :: (a -> b) -> List a -> List b
map f  Nil       = Nil
map f (Cons a l) = Cons (f a) (map f l)

-- Argument order intentionally switched
foldr :: List a -> (a -> x -> x) -> x -> x
foldr  Nil       c n = n
foldr (Cons a l) c n = c a (foldr l c n)

result :: Int
result = foldr (map (+1) example) (+) 0

-- result = 9

Now imagine that we disable all recursion in Haskell: no more recursive types and no more recursive functions. Now we must reject the above program because:

  • the List data type definition recursively refers to itself

  • the map and foldr functions recursively refer to themselves

Can we still encode lists in a non-recursive dialect of Haskell?

Yes, we can!

-- This is a valid Haskell program

{-# LANGUAGE RankNTypes #-}

import Prelude hiding (map, foldr)

type List a = forall x . (a -> x -> x) -> x -> x

example :: List Int
example = \cons nil -> cons 1 (cons 2 (cons 3 nil))

map :: (a -> b) -> List a -> List b
map f l = \cons nil -> l (\a x -> cons (f a) x) nil

foldr :: List a -> (a -> x -> x) -> x -> x
foldr l = l

result :: Int
result = foldr (map (+ 1) example) (+) 0

-- result = 9

Carefully note that:

  • List is no longer defined recursively in terms of itself

  • map and foldr are no longer defined recursively in terms of themselves

Yet, we somehow managed to build a list, map a function over the list, and fold the list, all without ever using recursion! We do this by encoding the list as a fold, which is why foldr became the identity function.

This trick works for more than just lists. You can take any recursive data type and mechanically transform the type into a fold and transform functions on the type into functions on folds. If you want to learn more about this trick, the specific name for it is "Boehm-Berarducci encoding". If you are curious, this in turn is equivalent to an even more general concept from category theory known as "F-algebras", which let you encode inductive things in a non-inductive way.
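As a sketch of the same transformation applied to a different type, here is a binary tree encoded as its own fold. The names Tree, leaf, node, mapTree, and sumTree are made up for this illustration and do not appear elsewhere in the post.

```haskell
{-# LANGUAGE RankNTypes #-}

-- The recursive type we are encoding, for comparison:
--
--     data Tree a = Node (Tree a) a (Tree a) | Leaf

-- One argument per constructor; each recursive field becomes an `x`
type Tree a = forall x . (x -> a -> x -> x) -> x -> x

leaf :: Tree a
leaf = \_node lf -> lf

node :: Tree a -> a -> Tree a -> Tree a
node l a r = \nd lf -> nd (l nd lf) a (r nd lf)

-- No recursion: we only change what the `Node` handler does
mapTree :: (a -> b) -> Tree a -> Tree b
mapTree f t = \nd lf -> t (\l a r -> nd l (f a) r) lf

-- Folding is just applying the tree to the handlers
sumTree :: Tree Int -> Int
sumTree t = t (\l a r -> l + a + r) 0

exampleTree :: Tree Int
exampleTree = node (node leaf 1 leaf) 2 (node leaf 3 leaf)
```

Here sumTree (mapTree (+ 1) exampleTree) evaluates to 9, and none of the definitions refers to itself.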

Non-recursive code greatly simplifies equational reasoning. For example, we can easily prove that we can optimize map id l to l:

map id l

-- Inline: map f l = \cons nil -> l (\a x -> cons (f a) x) nil
= \cons nil -> l (\a x -> cons (id a) x) nil

-- Inline: id x = x
= \cons nil -> l (\a x -> cons a x) nil

-- Eta-reduce
= \cons nil -> l cons nil

-- Eta-reduce
= l

Note that we did not need to use induction to prove this optimization because map is no longer recursive. The optimization became downright trivial, so trivial that we can automate it!
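The same style of reasoning also covers map fusion. The sketch below repeats the non-recursive definitions so that it stands alone, spells out the derivation in comments, and then checks one instance of the law. The helper toList and the names lhs and rhs are assumptions introduced only to observe the results.

```haskell
{-# LANGUAGE RankNTypes #-}

import Prelude hiding (map)

type List a = forall x . (a -> x -> x) -> x -> x

map :: (a -> b) -> List a -> List b
map f l = \cons nil -> l (\a x -> cons (f a) x) nil

-- map f (map g l)
--
-- -- Inline the outer `map`
-- = \cons nil -> (map g l) (\a x -> cons (f a) x) nil
--
-- -- Inline the inner `map`
-- = \cons nil ->
--       (\cons' nil' -> l (\a x -> cons' (g a) x) nil')
--           (\a x -> cons (f a) x)
--           nil
--
-- -- Beta-reduce: substitute for cons' and nil'
-- = \cons nil -> l (\a x -> (\a' x' -> cons (f a') x') (g a) x) nil
--
-- -- Beta-reduce again
-- = \cons nil -> l (\a x -> cons (f (g a)) x) nil
--
-- -- Definition of `map`, read right to left
-- = map (f . g) l

-- Helper (not part of the encoding) to convert to a built-in list
toList :: List a -> [a]
toList l = l (:) []

example :: List Int
example = \cons nil -> cons 1 (cons 2 (cons 3 nil))

lhs, rhs :: [Int]
lhs = toList (map (+ 1) (map (* 2) example))
rhs = toList (map ((+ 1) . (* 2)) example)
```

Both lhs and rhs evaluate to [3,5,7]. The check only tests one instance, of course; the derivation in the comments is what establishes the equality in general.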

Morte optimizes programs using this same simple scheme:

  • Beta-reduce everything (equivalent to inlining)
  • Eta-reduce everything

To illustrate this, I will desugar our high-level Haskell code to the calculus of constructions. This desugaring process is currently manual (and tedious), but I plan to automate this, too, by providing a front-end high-level language similar to Haskell that compiles to Morte:

-- mapid.mt

(    \(List : * -> *)
->   \(  map
     :   forall (a : *)
     ->  forall (b : *)
     -> (a -> b) -> List a -> List b
     )
->   \(id : forall (a : *) -> a -> a)

    ->   \(a : *) -> map a a (id a)
)

-- List
(\(a : *) -> forall (x : *) -> (a -> x -> x) -> x -> x)

-- map
(   \(a : *)
->  \(b : *)
->  \(f : a -> b)
->  \(l : forall (x : *) -> (a -> x -> x) -> x -> x)
->  \(x : *)
->  \(Cons : b -> x -> x)
->  \(Nil: x)
->  l x (\(va : a) -> \(vx : x) -> Cons (f va) vx) Nil
)

-- id
(\(a : *) -> \(va : a) -> va)

This line of code is the "business end" of the program:

\(a : *) -> map a a (id a)

The extra 'a' business is because in any polymorphic lambda calculus you explicitly accept polymorphic types as arguments and specialize functions by applying them to types. Higher-level functional languages like Haskell or ML use type inference to automatically infer and supply type arguments when possible.
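You can get a feel for explicit type arguments without leaving Haskell, because GHC's TypeApplications extension lets you supply them by hand. The following sketch uses the Prelude's map over built-in lists (not the encoded List from this post), and the name mapId is made up for the example:

```haskell
{-# LANGUAGE TypeApplications #-}

-- What we normally write, leaving the type arguments to inference
inferred :: [Int] -> [Int]
inferred = map id

-- The same function with every type argument spelled out, which is
-- roughly what `map a a (id a)` does in Morte with `a` chosen to be Int
mapId :: [Int] -> [Int]
mapId = map @Int @Int (id @Int)
```

The two definitions denote the same function; the second simply makes visible the work that type inference does in the first.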

We can compile this program using the morte executable, which accepts a Morte program on stdin, outputs the program's type to stderr, and outputs the optimized program on stdout:

$ morte < mapid.mt
∀(a : *) → (∀(x : *) → (a → x → x) → x → x) → ∀(x : *) → (a 
→ x → x) → x → x

λ(a : *) → λ(l : ∀(x : *) → (a → x → x) → x → x) → l

The first line is the type, which is a desugared form of:

forall a . List a -> List a

The second line is the program, which is the identity function on lists. Morte optimized away the map completely, the same way we did by hand.

Morte also optimized away the rest of the code. Dead-code elimination is just an emergent property of Morte's simple optimization scheme.

Equality

We could double-check our answer by asking Morte to optimize the identity function on lists:

-- idlist.mt

(    \(List : * -> *)
->   \(id   : forall (a : *) -> a -> a)

    ->   \(a : *) -> id (List a)
)

-- List
(\(a : *) -> forall (x : *) -> (a -> x -> x) -> x -> x)

-- id
(\(a : *) -> \(va : a) -> va)

Sure enough, Morte outputs an alpha-equivalent result (meaning the same up to variable renaming):

$ ~/.cabal/bin/morte < idlist.mt
∀(a : *) → (∀(x : *) → (a → x → x) → x → x) → ∀(x : *) → (a 
→ x → x) → x → x

λ(a : *) → λ(va : ∀(x : *) → (a → x → x) → x → x) → va

We can even use the morte library to mechanically check if two Morte expressions are alpha-, beta-, and eta-equivalent. We can parse our two Morte files into Morte's Expr type and then use the Eq instance for Expr to test for equivalence:

$ ghci
Prelude> import qualified Data.Text.Lazy.IO as Text
Prelude Text> txt1 <- Text.readFile "mapid.mt"
Prelude Text> txt2 <- Text.readFile "idlist.mt"
Prelude Text> import Morte.Parser (exprFromText)
Prelude Text Morte.Parser> let e1 = exprFromText txt1
Prelude Text Morte.Parser> let e2 = exprFromText txt2
Prelude Text Morte.Parser> import Control.Applicative (liftA2)
Prelude Text Morte.Parser Control.Applicative> liftA2 (==) e1 e2
Right True
$ -- `Right` means both expressions parsed successfully
$ -- `True` means they are alpha-, beta-, and eta-equivalent

We can use this to mechanically verify that two Morte programs optimize to the same result.

Compile-time computation

Morte can compute as much (or as little) at compile time as you want. The more information you encode directly within lambda calculus, the more compile-time computation Morte will perform for you. For example, if we translate our Haskell List code entirely to lambda calculus, then Morte will statically compute the result at compile time.

-- nine.mt

(   \(Nat : *)
->  \(zero : Nat)
->  \(one : Nat)
->  \((+) : Nat -> Nat -> Nat)
->  \((*) : Nat -> Nat -> Nat)
->  \(List : * -> *)
->  \(Cons : forall (a : *) -> a -> List a -> List a)
->  \(Nil  : forall (a : *)                -> List a)
->  \(  map
    :   forall (a : *) -> forall (b : *)
    ->  (a -> b) -> List a -> List b
    )
->  \(  foldr
    :   forall (a : *)
    ->  List a
    ->  forall (r : *)
    ->  (a -> r -> r) -> r -> r
    )
->  (    \(two   : Nat)
    ->   \(three : Nat)
    ->   (    \(example : List Nat)

         ->   foldr Nat (map Nat Nat ((+) one) example) Nat (+) zero
         )

         -- example
         (Cons Nat one (Cons Nat two (Cons Nat three (Nil Nat))))
    )

    -- two
    ((+) one one)

    -- three
    ((+) one ((+) one one))
)

-- Nat
(   forall (a : *)
->  (a -> a)
->  a
->  a
)

-- zero
(   \(a : *)
->  \(Succ : a -> a)
->  \(Zero : a)
->  Zero
)

-- one
(   \(a : *)
->  \(Succ : a -> a)
->  \(Zero : a)
->  Succ Zero
)

-- (+)
(   \(m : forall (a : *) -> (a -> a) -> a -> a)
->  \(n : forall (a : *) -> (a -> a) -> a -> a)
->  \(a : *)
->  \(Succ : a -> a)
->  \(Zero : a)
->  m a Succ (n a Succ Zero)
)

-- (*)
(   \(m : forall (a : *) -> (a -> a) -> a -> a)
->  \(n : forall (a : *) -> (a -> a) -> a -> a)
->  \(a : *)
->  \(Succ : a -> a)
->  \(Zero : a)
->  m a (n a Succ) Zero
)

-- List
(   \(a : *)
->  forall (x : *)
->  (a -> x -> x)  -- Cons
->  x              -- Nil
->  x
)

-- Cons
(   \(a : *)
->  \(va  : a)
->  \(vas : forall (x : *) -> (a -> x -> x) -> x -> x)
->  \(x : *)
->  \(Cons : a -> x -> x)
->  \(Nil  : x)
->  Cons va (vas x Cons Nil)
)

-- Nil
(   \(a : *)
->  \(x : *)
->  \(Cons : a -> x -> x)
->  \(Nil  : x)
->  Nil
)

-- map
(   \(a : *)
->  \(b : *)
->  \(f : a -> b)
->  \(l : forall (x : *) -> (a -> x -> x) -> x -> x)
->  \(x : *)
->  \(Cons : b -> x -> x)
->  \(Nil: x)
->  l x (\(va : a) -> \(vx : x) -> Cons (f va) vx) Nil
)

-- foldr
(   \(a : *)
->  \(vas : forall (x : *) -> (a -> x -> x) -> x -> x)
->  vas
)

The relevant line is:

foldr Nat (map Nat Nat ((+) one) example) Nat (+) zero

If you remove the type-applications to Nat, this parallels our original Haskell example. We can then evaluate this expression at compile time:

$ morte < nine.mt
∀(a : *) → (a → a) → a → a

λ(a : *) → λ(Succ : a → a) → λ(Zero : a) → Succ (Succ (Succ 
(Succ (Succ (Succ (Succ (Succ (Succ Zero))))))))

Morte reduces our program to a Church-encoded nine.
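If Church numerals are new to you, the following Haskell sketch mirrors the Nat, zero, one, (+), and (*) definitions from nine.mt. The names plus, times, and toInt are chosen here for illustration; toInt reads a numeral back by supplying (+ 1) for Succ and 0 for Zero.

```haskell
{-# LANGUAGE RankNTypes #-}

-- A natural number is "apply Succ this many times to Zero"
type Nat = forall a . (a -> a) -> a -> a

zero :: Nat
zero = \_succ z -> z

one :: Nat
one = \s z -> s z

-- m + n: start from n and apply Succ m more times
plus :: Nat -> Nat -> Nat
plus m n = \s z -> m s (n s z)

-- m * n: apply "add n" m times, starting from Zero
times :: Nat -> Nat -> Nat
times m n = \s z -> m (n s) z

two, three, four :: Nat
two   = plus one one
three = plus one two
four  = plus one three

-- Read a numeral back as a machine integer
toInt :: Nat -> Int
toInt n = n (+ 1) 0
```

For example, toInt (plus two (plus three four)) evaluates to 9, which matches the sum of the mapped list elements 2, 3, and 4, and toInt (times three three) gives 9 as well.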

Run-time computation

Morte does not force you to compute everything using lambda calculus at compile time. Suppose that we wanted to use machine arithmetic at run-time instead. We can do this by parametrizing our program on:

  • the Int type,
  • operations on Ints, and
  • any integer literals we use

We accept these "foreign imports" as ordinary arguments to our program:

-- foreign.mt

-- Foreign imports
    \(Int : *)                      -- Foreign type
->  \((+) : Int -> Int -> Int)      -- Foreign function
->  \((*) : Int -> Int -> Int)      -- Foreign function
->  \(lit@0 : Int)  -- Literal "1"  -- Foreign data
->  \(lit@1 : Int)  -- Literal "2"  -- Foreign data
->  \(lit@2 : Int)  -- Literal "3"  -- Foreign data
->  \(lit@3 : Int)  -- Literal "1"  -- Foreign data
->  \(lit@4 : Int)  -- Literal "0"  -- Foreign data

-- The rest is compile-time lambda calculus
->  (   \(List : * -> *)
    ->  \(Cons : forall (a : *) -> a -> List a -> List a)
    ->  \(Nil  : forall (a : *)                -> List a)
    ->  \(  map
        :   forall (a : *)
        ->  forall (b : *)
        ->  (a -> b) -> List a -> List b
        )
    ->  \(  foldr
        :   forall (a : *)
        ->  List a
        ->  forall (r : *)
        ->  (a -> r -> r) -> r -> r
        )
        ->   (    \(example : List Int)

             ->   foldr Int (map Int Int ((+) lit@3) example) Int (+) lit@4
             )

             -- example
             (Cons Int lit@0 (Cons Int lit@1 (Cons Int lit@2 (Nil Int))))
    )

    -- List
    (   \(a : *)
    ->  forall (x : *)
    ->  (a -> x -> x)  -- Cons
    ->  x              -- Nil
    ->  x
    )

    -- Cons
    (   \(a : *)
    ->  \(va  : a)
    ->  \(vas : forall (x : *) -> (a -> x -> x) -> x -> x)
    ->  \(x : *)
    ->  \(Cons : a -> x -> x)
    ->  \(Nil  : x)
    ->  Cons va (vas x Cons Nil)
    )

    -- Nil
    (   \(a : *)
    ->  \(x : *)
    ->  \(Cons : a -> x -> x)
    ->  \(Nil  : x)
    ->  Nil
    )

    -- map
    (   \(a : *)
    ->  \(b : *)
    ->  \(f : a -> b)
    ->  \(l : forall (x : *) -> (a -> x -> x) -> x -> x)
    ->  \(x : *)
    ->  \(Cons : b -> x -> x)
    ->  \(Nil: x)
    ->  l x (\(va : a) -> \(vx : x) -> Cons (f va) vx) Nil
    )

    -- foldr
    (   \(a : *)
    ->  \(vas : forall (x : *) -> (a -> x -> x) -> x -> x)
    ->  vas
    )

We can use Morte to optimize the above program and Morte will reduce the program to nothing but foreign types, operations, and values:

$ morte < foreign.mt
∀(Int : *) → (Int → Int → Int) → (Int → Int → Int) → Int →
Int → Int → Int → Int → Int

λ(Int : *) → λ((+) : Int → Int → Int) → λ((*) : Int → Int → 
Int) → λ(lit : Int) → λ(lit@1 : Int) → λ(lit@2 : Int) → 
λ(lit@3 : Int) → λ(lit@4 : Int) → (+) ((+) lit@3 lit) ((+) 
((+) lit@3 lit@1) ((+) ((+) lit@3 lit@2) lit@4))

If you study that closely, Morte adds lit@3 (the "1" literal) to each literal of the list and then adds them up. We can then pass this foreign syntax tree to our machine arithmetic backend to transform those foreign operations to efficient operations.
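Since Morte has no back-end yet, here is a hedged sketch in Haskell of what "passing the foreign syntax tree to a backend" could look like. The function optimized is a hand transcription of Morte's output above, and instantiating it with Haskell's Int and built-in arithmetic is an assumption about one possible backend, not something Morte does today.

```haskell
-- Hand transcription of the normalized program: it only mentions the
-- foreign type, the foreign operations, and the foreign literals.
optimized
    :: (int -> int -> int)  -- (+)
    -> (int -> int -> int)  -- (*), accepted but unused
    -> int                  -- literal "1" (first list element)
    -> int                  -- literal "2"
    -> int                  -- literal "3"
    -> int                  -- literal "1" (the amount added to each element)
    -> int                  -- literal "0"
    -> int
optimized add _mul lit0 lit1 lit2 lit3 lit4 =
    add (add lit3 lit0) (add (add lit3 lit1) (add (add lit3 lit2) lit4))

-- One possible "backend": plug in machine integers and machine addition
result :: Int
result = optimized (+) (*) 1 2 3 1 0
```

Here result evaluates to 9, the same answer as before, but the additions now happen at run time using whatever arithmetic the backend supplies.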

Morte lets you choose how much information you want to encode within lambda calculus. The more information you encode in lambda calculus the more Morte can optimize your program, but the slower your compile times will get, so it's a tradeoff.

Corecursion

Corecursion is the dual of recursion. Where recursion works on finite data types, corecursion works on potentially infinite data types. An example would be the following infinite Stream in Haskell:

data Stream a = Cons a (Stream a)

numbers :: Stream Int
numbers = go 0
  where
    go n = Cons n (go (n + 1))

-- numbers = Cons 0 (Cons 1 (Cons 2 (...

map :: (a -> b) -> Stream a -> Stream b
map f (Cons a l) = Cons (f a) (map f l)

example :: Stream Int
example = map (+ 1) numbers

-- example = Cons 1 (Cons 2 (Cons 3 (...

Again, pretend that we disable any function from referencing itself so that the above code becomes invalid. This time we cannot reuse the same trick from previous sections because we cannot encode numbers as a fold without referencing itself. Try this if you don't believe me.

However, we can still encode corecursive things in a non-corecursive way. This time, we encode our Stream type as an unfold instead of a fold:

-- This is also valid Haskell code

{-# LANGUAGE ExistentialQuantification #-}

data Stream a = forall s . MkStream
    { seed :: s
    , step :: s -> (a, s)
    }

numbers :: Stream Int
numbers = MkStream 0 (\n -> (n, n + 1))

map :: (a -> b) -> Stream a -> Stream b
map f (MkStream s0 k) = MkStream s0 k'
  where
    k' s = (f a, s')
      where (a, s') = k s 

In other words, we store an initial seed of some type s and a step function of type s -> (a, s) that emits one element of our Stream. The type of our seed s can be anything and in our numbers example, the type of the internal state is Int. Another stream could use a completely different internal state of type (), like this:

-- ones = Cons 1 ones

ones :: Stream Int
ones = MkStream () (\_ -> (1, ()))

The general name for this trick is an "F-coalgebra" encoding of a corecursive type.
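To see that an encoded stream really does describe the expected elements, we can unfold it a few steps. In the sketch below, takeS is a helper assumed for illustration; it is an ordinary recursive Haskell function that sits outside the non-recursive dialect and only serves to observe a stream. The Stream type is restated without record syntax, and map is renamed to mapS to avoid clashing with the Prelude.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- A stream is a hidden seed plus a step function
data Stream a = forall s . MkStream s (s -> (a, s))

numbers :: Stream Int
numbers = MkStream 0 (\n -> (n, n + 1))

ones :: Stream Int
ones = MkStream () (\_ -> (1, ()))

-- Non-recursive: we only post-process what the step function emits
mapS :: (a -> b) -> Stream a -> Stream b
mapS f (MkStream s0 k) = MkStream s0 k'
  where
    k' s = let (a, s') = k s in (f a, s')

-- Observer (recursive on purpose): run the step function `n` times
takeS :: Int -> Stream a -> [a]
takeS n (MkStream s0 k) = go n s0
  where
    go i s
        | i <= 0    = []
        | otherwise = let (a, s') = k s in a : go (i - 1) s'
```

For instance, takeS 3 numbers gives [0,1,2], takeS 3 (mapS (+ 1) numbers) gives [1,2,3], and takeS 3 ones gives [1,1,1].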

Once we encode our infinite stream non-recursively, we can safely optimize the stream by inlining and eta reduction:

map id l

-- l = MkStream s0 k
= map id (MkStream s0 k)

-- Inline definition of `map`
= MkStream s0 k'
  where
    k' = \s -> (id a, s')
      where
        (a, s') = k s

-- Inline definition of `id`
= MkStream s0 k'
  where
    k' = \s -> (a, s')
      where
        (a, s') = k s

-- Inline: (a, s') = k s
= MkStream s0 k'
  where
    k' = \s -> k s

-- Eta reduce
= MkStream s0 k'
  where
    k' = k

-- Inline: k' = k
= MkStream s0 k

-- l = MkStream s0 k
= l

Now let's encode Stream and map in Morte and compile the following four expressions:

map id

id

map f . map g

map (f . g)

Save the following Morte file to stream.mt and then uncomment the expression you want to test:

(   \(id : forall (a : *) -> a -> a)
->  \(  (.)
    :   forall (a : *)
    ->  forall (b : *)
    ->  forall (c : *)
    ->  (b -> c)
    ->  (a -> b)
    ->  (a -> c)
    )
->  \(Pair : * -> * -> *)
->  \(P : forall (a : *) -> forall (b : *) -> a -> b -> Pair a b)
->  \(  first
    :   forall (a : *)
    ->  forall (b : *)
    ->  forall (c : *)
    ->  (a -> b)
    ->  Pair a c
    ->  Pair b c
    )

->  (   \(Stream : * -> *)
    ->  \(  map
        :   forall (a : *)
        ->  forall (b : *)
        ->  (a -> b)
        ->  Stream a
        ->  Stream b
        )

        -- example@1 = example@2
    ->  (   \(example@1 : forall (a : *) -> Stream a -> Stream a)
        ->  \(example@2 : forall (a : *) -> Stream a -> Stream a)

        -- example@3 = example@4
        ->  \(  example@3
            :   forall (a : *)
            ->  forall (b : *)
            ->  forall (c : *)
            ->  (b -> c)
            ->  (a -> b)
            ->  Stream a
            ->  Stream c
            )

        ->  \(  example@4
            :   forall (a : *)
            ->  forall (b : *)
            ->  forall (c : *)
            ->  (b -> c)
            ->  (a -> b)
            ->  Stream a
            ->  Stream c
            )

        -- Uncomment the example you want to test
        ->  example@1
--      ->  example@2
--      ->  example@3
--      ->  example@4
        )

        -- example@1
        (\(a : *) -> map a a (id a))

        -- example@2
        (\(a : *) -> id (Stream a))

        -- example@3
        (   \(a : *)
        ->  \(b : *)
        ->  \(c : *)
        ->  \(f : b -> c)
        ->  \(g : a -> b)
        ->  map a c ((.) a b c f g)
        )

        --  example@4
        (   \(a : *)
        ->  \(b : *)
        ->  \(c : *)
        ->  \(f : b -> c)
        ->  \(g : a -> b)
        ->  (.) (Stream a) (Stream b) (Stream c) (map b c f) (map a b g)
        )
    )

    -- Stream
    (   \(a : *)
    ->  forall (x : *)
    ->  (forall (s : *) -> s -> (s -> Pair a s) -> x)
    ->  x
    )

    -- map
    (   \(a : *)
    ->  \(b : *)
    ->  \(f : a -> b)
    ->  \(  st
        :   forall (x : *)
        -> (forall (s : *) -> s -> (s -> Pair a s) -> x)
        -> x
        )
    ->  \(x : *)
    ->  \(S : forall (s : *) -> s -> (s -> Pair b s) -> x)
    ->  st
        x
        (   \(s : *)
        ->  \(seed : s)
        ->  \(step : s -> Pair a s)
        ->  S
            s
            seed
            (\(seed@1 : s) -> first a b s f (step seed@1))
        )
    )
)

-- id
(\(a : *) -> \(va : a) -> va)

-- (.)
(   \(a : *)
->  \(b : *)
->  \(c : *)
->  \(f : b -> c)
->  \(g : a -> b)
->  \(va : a)
->  f (g va)
)

-- Pair
(\(a : *) -> \(b : *) -> forall (x : *) -> (a -> b -> x) -> x)

-- P
(   \(a : *)
->  \(b : *)
->  \(va : a)
->  \(vb : b)
->  \(x : *)
->  \(P : a -> b -> x)
->  P va vb
)

-- first
(   \(a : *)
->  \(b : *)
->  \(c : *)
->  \(f : a -> b)
->  \(p : forall (x : *) -> (a -> c -> x) -> x)
->  \(x : *)
->  \(Pair : b -> c -> x)
->  p x (\(va : a) -> \(vc : c) -> Pair (f va) vc)
)

Both example@1 and example@2 will generate alpha-equivalent code:

$ morte < example1.mt
∀(a : *) → (∀(x : *) → (∀(s : *) → s → (s → ∀(x : *) → (a → 
s → x) → x) → x) → x) → ∀(x : *) → (∀(s : *) → s → (s → ∀(x
 : *) → (a → s → x) → x) → x) → x

λ(a : *) → λ(st : ∀(x : *) → (∀(s : *) → s → (s → ∀(x : *) →
 (a → s → x) → x) → x) → x) → st

$ morte < example2.mt
∀(a : *) → (∀(x : *) → (∀(s : *) → s → (s → ∀(x : *) → (a → 
s → x) → x) → x) → x) → ∀(x : *) → (∀(s : *) → s → (s → ∀(x
 : *) → (a → s → x) → x) → x) → x

λ(a : *) → λ(va : ∀(x : *) → (∀(s : *) → s → (s → ∀(x : *) →
 (a → s → x) → x) → x) → x) → va

Similarly, example@3 and example@4 will generate alpha-equivalent code:

$ morte < example3.mt
∀(a : *) → ∀(b : *) → ∀(c : *) → (b → c) → (a → b) → (∀(x : 
*) → (∀(s : *) → s → (s → ∀(x : *) → (a → s → x) → x) → x) →
 x) → ∀(x : *) → (∀(s : *) → s → (s → ∀(x : *) → (c → s → x)
 → x) → x) → x

λ(a : *) → λ(b : *) → λ(c : *) → λ(f : b → c) → λ(g : a → b)
 → λ(st : ∀(x : *) → (∀(s : *) → s → (s → ∀(x : *) → (a → s 
→ x) → x) → x) → x) → λ(x : *) → λ(S : ∀(s : *) → s → (s → ∀
(x : *) → (c → s → x) → x) → x) → st x (λ(s : *) → λ(seed : 
s) → λ(step : s → ∀(x : *) → (a → s → x) → x) → S s seed (λ(
seed@1 : s) → λ(x : *) → λ(Pair : c → s → x) → step seed@1 x
 (λ(va : a) → Pair (f (g va)))))

$ morte < example4.mt
∀(a : *) → ∀(b : *) → ∀(c : *) → (b → c) → (a → b) → (∀(x : 
*) → (∀(s : *) → s → (s → ∀(x : *) → (a → s → x) → x) → x) →
 x) → ∀(x : *) → (∀(s : *) → s → (s → ∀(x : *) → (c → s → x)
 → x) → x) → x

λ(a : *) → λ(b : *) → λ(c : *) → λ(f : b → c) → λ(g : a → b)
 → λ(va : ∀(x : *) → (∀(s : *) → s → (s → ∀(x : *) → (a → s 
→ x) → x) → x) → x) → λ(x : *) → λ(S : ∀(s : *) → s → (s → ∀
(x : *) → (c → s → x) → x) → x) → va x (λ(s : *) → λ(seed : 
s) → λ(step : s → ∀(x : *) → (a → s → x) → x) → S s seed (λ(
seed@1 : s) → λ(x : *) → λ(Pair : c → s → x) → step seed@1 x
 (λ(va : a) → Pair (f (g va))))

We inadvertently proved stream fusion for free, but we're not done yet! Everything we learned about recursive and corecursive sequences also applies to modeling recursive and corecursive effects!

Effects

I will conclude this post by showing how to model both recursive and corecursive programs that have side effects. The recursive program will echo ninety-nine lines from stdin to stdout. The equivalent Haskell program is in the comment header:

-- recursive.mt

-- The Haskell code we will translate to Morte:
--
--     import Prelude hiding (
--         (+), (*), IO, putStrLn, getLine, (>>=), (>>), return )
-- 
--     -- Simple prelude
--
--     data Nat = Succ Nat | Zero
--
--     zero :: Nat
--     zero = Zero
--
--     one :: Nat
--     one = Succ Zero
--
--     (+) :: Nat -> Nat -> Nat
--     Zero   + n = n
--     Succ m + n = m + Succ n
--
--     (*) :: Nat -> Nat -> Nat
--     Zero   * n = Zero
--     Succ m * n = n + (m * n)
--
--     foldNat :: Nat -> (a -> a) -> a -> a
--     foldNat  Zero    f x = x
--     foldNat (Succ m) f x = f (foldNat m f x)
--
--     data IO r
--         = PutStrLn String (IO r)
--         | GetLine (String -> IO r)
--         | Return r
--
--     putStrLn :: String -> IO U
--     putStrLn str = PutStrLn str (Return Unit)
--
--     getLine :: IO String
--     getLine = GetLine Return
--
--     return :: a -> IO a
--     return = Return
--
--     (>>=) :: IO a -> (a -> IO b) -> IO b
--     PutStrLn str io >>= f = PutStrLn str (io >>= f)
--     GetLine k       >>= f = GetLine (\str -> k str >>= f)
--     Return r        >>= f = f r
--
--     -- Derived functions
--
--     (>>) :: IO U -> IO U -> IO U
--     m >> n = m >>= \_ -> n
--
--     two :: Nat
--     two = one + one
--
--     three :: Nat
--     three = one + one + one
--
--     four :: Nat
--     four = one + one + one + one
--
--     five :: Nat
--     five = one + one + one + one + one
--
--     six :: Nat
--     six = one + one + one + one + one + one
--
--     seven :: Nat
--     seven = one + one + one + one + one + one + one
--
--     eight :: Nat
--     eight = one + one + one + one + one + one + one + one
--
--     nine :: Nat
--     nine = one + one + one + one + one + one + one + one + one
--
--     ten :: Nat
--     ten = one + one + one + one + one + one + one + one + one + one
--
--     replicateM_ :: Nat -> IO U -> IO U
--     replicateM_ n io = foldNat n (io >>) (return Unit)
--
--     ninetynine :: Nat
--     ninetynine = nine * ten + nine
--
--     main_ :: IO U
--     main_ = replicateM_ ninetynine (getLine >>= putStrLn)

-- "Free" variables
(   \(String : *   )
->  \(U : *)
->  \(Unit : U)

    -- Simple prelude
->  (   \(Nat : *)
    ->  \(zero : Nat)
    ->  \(one : Nat)
    ->  \((+) : Nat -> Nat -> Nat)
    ->  \((*) : Nat -> Nat -> Nat)
    ->  \(foldNat : Nat -> forall (a : *) -> (a -> a) -> a -> a)
    ->  \(IO : * -> *)
    ->  \(return : forall (a : *) -> a -> IO a)
    ->  \((>>=)
        :   forall (a : *)
        ->  forall (b : *)
        ->  IO a
        ->  (a -> IO b)
        ->  IO b
        )
    ->  \(putStrLn : String -> IO U)
    ->  \(getLine : IO String)

        -- Derived functions
    ->  (   \((>>) : IO U -> IO U -> IO U)
        ->  \(two   : Nat)
        ->  \(three : Nat)
        ->  \(four  : Nat)
        ->  \(five  : Nat)
        ->  \(six   : Nat)
        ->  \(seven : Nat)
        ->  \(eight : Nat)
        ->  \(nine  : Nat)
        ->  \(ten   : Nat)
        ->  (   \(replicateM_ : Nat -> IO U -> IO U)
            ->  \(ninetynine : Nat)

            ->  replicateM_ ninetynine ((>>=) String U getLine putStrLn)
            )

            -- replicateM_
            (   \(n : Nat)
            ->  \(io : IO U)
            ->  foldNat n (IO U) ((>>) io) (return U Unit)
            )

            -- ninetynine
            ((+) ((*) nine ten) nine)
        )

        -- (>>)
        (   \(m : IO U)
        ->  \(n : IO U)
        ->  (>>=) U U m (\(_ : U) -> n)
        )

        -- two
        ((+) one one)

        -- three
        ((+) one ((+) one one))

        -- four
        ((+) one ((+) one ((+) one one)))

        -- five
        ((+) one ((+) one ((+) one ((+) one one))))

        -- six
        ((+) one ((+) one ((+) one ((+) one ((+) one one)))))

        -- seven
        ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one one))))))

        -- eight
        ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one one)))))))

        -- nine
        ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one one))))))))

        -- ten
        ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one ((+) one one)))))))))
    )

    -- Nat
    (   forall (a : *)
    ->  (a -> a)
    ->  a
    ->  a
    )

    -- zero
    (   \(a : *)
    ->  \(Succ : a -> a)
    ->  \(Zero : a)
    ->  Zero
    )

    -- one
    (   \(a : *)
    ->  \(Succ : a -> a)
    ->  \(Zero : a)
    ->  Succ Zero
    )

    -- (+)
    (   \(m : forall (a : *) -> (a -> a) -> a -> a)
    ->  \(n : forall (a : *) -> (a -> a) -> a -> a)
    ->  \(a : *)
    ->  \(Succ : a -> a)
    ->  \(Zero : a)
    ->  m a Succ (n a Succ Zero)
    )

    -- (*)
    (   \(m : forall (a : *) -> (a -> a) -> a -> a)
    ->  \(n : forall (a : *) -> (a -> a) -> a -> a)
    ->  \(a : *)
    ->  \(Succ : a -> a)
    ->  \(Zero : a)
    ->  m a (n a Succ) Zero
    )

    -- foldNat
    (   \(n : forall (a : *) -> (a -> a) -> a -> a)
    ->  n
    )

    -- IO
    (   \(r : *)
    ->  forall (x : *)
    ->  (String -> x -> x)
    ->  ((String -> x) -> x)
    ->  (r -> x)
    ->  x
    )

    -- return
    (   \(a : *)
    ->  \(va : a)
    ->  \(x : *)
    ->  \(PutStrLn : String -> x -> x)
    ->  \(GetLine : (String -> x) -> x)
    ->  \(Return : a -> x)
    ->  Return va
    )

    -- (>>=)
    (   \(a : *)
    ->  \(b : *)
    ->  \(m : forall (x : *)
        ->  (String -> x -> x)
        ->  ((String -> x) -> x)
        ->  (a -> x)
        ->  x
        )
    ->  \(f : a
        ->  forall (x : *)
        -> (String -> x -> x)
        -> ((String -> x) -> x)
        -> (b -> x)
        -> x
        )
    ->  \(x : *)
    ->  \(PutStrLn : String -> x -> x)
    ->  \(GetLine : (String -> x) -> x)
    ->  \(Return : b -> x)
    ->  m x PutStrLn GetLine (\(va : a) -> f va x PutStrLn GetLine Return)
    )

    -- putStrLn
    (   \(str : String)
    ->  \(x : *)
    ->  \(PutStrLn : String -> x -> x  )
    ->  \(GetLine  : (String -> x) -> x)
    ->  \(Return   : U -> x)
    ->  PutStrLn str (Return Unit)
    )

    -- getLine
    (   \(x : *)
    ->  \(PutStrLn : String -> x -> x  )
    ->  \(GetLine  : (String -> x) -> x)
    ->  \(Return   : String -> x)
    -> GetLine Return
    )
)

This program will compile to a completely unrolled read-write loop, as most recursive programs will:

$ morte < recursive.mt
∀(String : *) → ∀(U : *) → U → ∀(x : *) → (String → x → x) →
 ((String → x) → x) → (U → x) → x

λ(String : *) → λ(U : *) → λ(Unit : U) → λ(x : *) → λ(PutStr
Ln : String → x → x) → λ(GetLine : (String → x) → x) → λ(Ret
urn : U → x) → GetLine (λ(va : String) → PutStrLn va (GetLin
e (λ(va@1 : String) → PutStrLn va@1 (GetLine (λ(va@2 : Strin
g) → PutStrLn va@2 (GetLine (λ(va@3 : String) → PutStrLn ...
 <snip>
... GetLine (λ(va@92 : String) → PutStrLn va@92 (GetLine (λ(
va@93 : String) → PutStrLn va@93 (GetLine (λ(va@94 : String)
 → PutStrLn va@94 (GetLine (λ(va@95 : String) → PutStrLn va@
95 (GetLine (λ(va@96 : String) → PutStrLn va@96 (GetLine (λ(
va@97 : String) → PutStrLn va@97 (GetLine (λ(va@98 : String)
 → PutStrLn va@98 (Return Unit))))))))))))))))))))))))))))))
))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))
))))))))))))))))))))))))))))))))))))))))))))))))))))))))))))
))))))))))))))))))))))))))))))))))))))))))))))))
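The normalized program is still abstract in PutStrLn, GetLine, and Return, so a backend gives it meaning by choosing what those three handlers do. Here is a hedged Haskell sketch of two such interpretations for the same encoding of IO. The names IO', echoTwice, runPure, and runIO are assumptions made for this example, and echoTwice stands in for the ninety-nine-fold unrolled program.

```haskell
{-# LANGUAGE RankNTypes #-}

-- The same encoding as the Morte `IO` type, with String fixed to Haskell's
type IO' r
    = forall x
    . (String -> x -> x)    -- PutStrLn
   -> ((String -> x) -> x)  -- GetLine
   -> (r -> x)              -- Return
   -> x

-- A two-iteration version of the unrolled echo program
echoTwice :: IO' ()
echoTwice = \putStrLn' getLine' return' ->
    getLine' (\a -> putStrLn' a (
    getLine' (\b -> putStrLn' b (
    return' ()))))

-- Pure interpretation: consume a list of input lines, collect the output.
-- Reading past the end of the input yields the empty string.
runPure :: IO' r -> [String] -> ([String], r)
runPure io = io put get ret
  where
    put str k = \inputs -> let (outs, r) = k inputs in (str : outs, r)
    get k     = \inputs -> case inputs of
        []     -> k "" []
        i : is -> k i is
    ret r     = \_ -> ([], r)

-- Effectful interpretation: map each handler onto the real IO action
runIO :: IO' r -> IO r
runIO io = io (\str k -> putStrLn str >> k) (\k -> getLine >>= k) return
```

For example, runPure echoTwice ["a", "b"] evaluates to (["a","b"],()), while runIO echoTwice reads two lines from stdin and echoes each one to stdout.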

In contrast, if we encode the effects corecursively we can express a program that echoes indefinitely from stdin to stdout:

-- corecursive.mt

-- data IOF r s
--     = PutStrLn String s
--     | GetLine (String -> s)
--     | Return r
--
-- data IO r = forall s . MkIO s (s -> IOF r s)
--
-- main = MkIO
--     Nothing
--     (maybe (GetLine Just) (\str -> PutStrLn str Nothing))

(   \(String : *)
->  (   \(Maybe : * -> *)
    ->  \(Just : forall (a : *) -> a -> Maybe a)
    ->  \(Nothing : forall (a : *) -> Maybe a)
    ->  \(  maybe
        :   forall (a : *)
        ->  Maybe a
        ->  forall (x : *)
        ->  (a -> x)
        ->  x
        ->  x
        )
    ->  \(IOF : * -> * -> *)
    ->  \(  PutStrLn
        :   forall (r : *)
        ->  forall (s : *)
        ->  String
        ->  s
        ->  IOF r s
        )
    ->  \(  GetLine
        :   forall (r : *)
        ->  forall (s : *)
        ->  (String -> s)
        ->  IOF r s
        )
    ->  \(  Return
        :   forall (r : *)
        ->  forall (s : *)
        ->  r
        ->  IOF r s
        )
    ->  (   \(IO : * -> *)
        ->  \(  MkIO
            :   forall (r : *)
            ->  forall (s : *)
            ->  s
            ->  (s -> IOF r s)
            ->  IO r
            )
        ->  (   \(main : forall (r : *) -> IO r)
            ->  main
            )

            -- main
            (   \(r : *)
            ->  MkIO
                r
                (Maybe String)
                (Nothing String)
                (   \(m : Maybe String)
                ->  maybe
                        String
                        m
                        (IOF r (Maybe String))
                        (\(str : String) ->
                            PutStrLn
                                r
                                (Maybe String)
                                str
                                (Nothing String)
                        )
                        (GetLine r (Maybe String) (Just String))
                )
            )
        )

        -- IO
        (   \(r : *)
        ->  forall (x : *)
        ->  (forall (s : *) -> s -> (s -> IOF r s) -> x)
        ->  x
        )

        -- MkIO
        (   \(r : *)
        ->  \(s : *)
        ->  \(seed : s)
        ->  \(step : s -> IOF r s)
        ->  \(x : *)
        ->  \(k : forall (s : *) -> s -> (s -> IOF r s) -> x)
        ->  k s seed step
        )
    )

    -- Maybe
    (\(a : *) -> forall (x : *) -> (a -> x) -> x -> x)

    -- Just
    (   \(a : *)
    ->  \(va : a)
    ->  \(x : *)
    ->  \(Just : a -> x)
    ->  \(Nothing : x)
    ->  Just va
    )

    -- Nothing
    (   \(a : *)
    ->  \(x : *)
    ->  \(Just : a -> x)
    ->  \(Nothing : x)
    ->  Nothing
    )

    -- maybe
    (   \(a : *)
    ->  \(m : forall (x : *) -> (a -> x) -> x -> x)
    ->  m
    )

    -- IOF
    (   \(r : *)
    ->  \(s : *)
    ->  forall (x : *)
    ->  (String -> s -> x)
    ->  ((String -> s) -> x)
    ->  (r -> x)
    ->  x
    )

    -- PutStrLn
    (   \(r : *)
    ->  \(s : *)
    ->  \(str : String)
    ->  \(vs : s)
    ->  \(x : *)
    ->  \(PutStrLn : String -> s -> x)
    ->  \(GetLine : (String -> s) -> x)
    ->  \(Return : r -> x)
    ->  PutStrLn str vs
    )

    -- GetLine
    (   \(r : *)
    ->  \(s : *)
    ->  \(k : String -> s)
    ->  \(x : *)
    ->  \(PutStrLn : String -> s -> x)
    ->  \(GetLine : (String -> s) -> x)
    ->  \(Return : r -> x)
    ->  GetLine k
    )

    -- Return
    (   \(r : *)
    ->  \(s : *)
    ->  \(vr : r)
    ->  \(x : *)
    ->  \(PutStrLn : String -> s -> x)
    ->  \(GetLine : (String -> s) -> x)
    ->  \(Return : r -> x)
    ->  Return vr
    )

)

This compiles to a state machine that we can unfold one step at a time:

$ morte < corecursive.mt
∀(String : *) → ∀(r : *) → ∀(x : *) → (∀(s : *) → s → (s → ∀
(x : *) → (String → s → x) → ((String → s) → x) → (r → x) → 
x) → x) → x

λ(String : *) → λ(r : *) → λ(x : *) → λ(k : ∀(s : *) → s → (
s → ∀(x : *) → (String → s → x) → ((String → s) → x) → (r → 
x) → x) → x) → k (∀(x : *) → (String → x) → x → x) (λ(x : *)
 → λ(Just : String → x) → λ(Nothing : x) → Nothing) (λ(m : ∀
(x : *) → (String → x) → x → x) → m (∀(x : *) → (String → (∀
(x : *) → (String → x) → x → x) → x) → ((String → ∀(x : *) →
 (String → x) → x → x) → x) → (r → x) → x) (λ(str : String) 
→ λ(x : *) → λ(PutStrLn : String → (∀(x : *) → (String → x) 
→ x → x) → x) → λ(GetLine : (String → ∀(x : *) → (String → x
) → x → x) → x) → λ(Return : r → x) → PutStrLn str (λ(x : *)
 → λ(Just : String → x) → λ(Nothing : x) → Nothing)) (λ(x : 
*) → λ(PutStrLn : String → (∀(x : *) → (String → x) → x → x)
 → x) → λ(GetLine : (String → ∀(x : *) → (String → x) → x → 
x) → x) → λ(Return : r → x) → GetLine (λ(va : String) → λ(x 
: *) → λ(Just : String → x) → λ(Nothing : x) → Just va))

I don't expect you to understand that output, only to recognize that we can translate it to any backend that provides functions and primitive read/write operations.
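To make the "state machine" reading concrete, here is a minimal Haskell sketch of how a backend might drive such a program. It mirrors the types from the comment at the top of corecursive.mt; the names `IO'`, `echo`, `runPure`, and `runIO` are hypothetical and chosen only for this illustration, not part of Morte.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- One layer of effects; `s` is the state to resume from afterwards.
data IOF r s
    = PutStrLn String s
    | GetLine (String -> s)
    | Return r

-- A seed state packaged with a step function; the state type is hidden.
-- (Named IO' here only to avoid clashing with the Prelude's IO.)
data IO' r = forall s . MkIO s (s -> IOF r s)

-- The echo program: Nothing means "read a line", Just str means "print str".
echo :: IO' r
echo = MkIO Nothing step
  where
    step :: Maybe String -> IOF r (Maybe String)
    step Nothing    = GetLine Just
    step (Just str) = PutStrLn str Nothing

-- Hypothetical pure driver: feeds a list of input lines to the machine and
-- collects the output, stopping at Return or when the input runs out.
runPure :: IO' r -> [String] -> [String]
runPure (MkIO seed step) = go seed
  where
    go s input = case step s of
        PutStrLn str s' -> str : go s' input
        GetLine k       -> case input of
            []       -> []
            (l : ls) -> go (k l) ls
        Return _        -> []

-- Hypothetical effectful driver: the same loop against real stdin/stdout.
runIO :: IO' r -> IO r
runIO (MkIO seed step) = go seed
  where
    go s = case step s of
        PutStrLn str s' -> putStrLn str >> go s'
        GetLine k       -> getLine >>= go . k
        Return r        -> return r

main :: IO ()
main = mapM_ putStrLn (runPure echo ["hello", "world"])
```

Each call to the step function unfolds exactly one layer of `IOF`, so the driver only ever needs function application plus the two primitives. Swapping `runPure` for `runIO echo` would echo from stdin until the input ends.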

Conclusion

If you would like to use Morte, you can find the library on both GitHub and Hackage. I also provide a Morte tutorial that you can use to learn more about the library.

Morte is dependently typed in theory, but in practice I have not exercised this feature, so I don't yet understand its implications. If this turns out to be a mistake then I will downgrade Morte to System Fω, which has higher kinds and polymorphism, but no dependent types.

Additionally, Morte might be usable to transmit code in a secure and typed way in a distributed environment, or to share code between diverse functional languages by providing a common intermediate language. However, both of those scenarios require additional work, such as establishing a shared set of foreign primitives and creating Morte encoders/decoders for each target language.

Also, there are additional optimizations which Morte might implement in the future. For example, Morte could use free theorems (equalities you deduce from the types) to simplify some code fragments even further, but Morte currently does not do this.
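As an illustration of what a free theorem looks like (this is a standard textbook example, not something Morte implements): any function `g` of type `forall a . [a] -> [a]` can only rearrange, drop, or duplicate elements, so its type alone guarantees `map f . g = g . map f`. The Haskell sketch below spot-checks that equation on a sample list; the names `lhs`, `rhs`, and `holdsFor` are hypothetical.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Both sides of the free theorem for g :: forall a . [a] -> [a].
lhs, rhs :: (forall a . [a] -> [a]) -> (Int -> String) -> [Int] -> [String]
lhs g f = map f . g
rhs g f = g . map f

-- Spot-check the equation for one choice of g on a fixed sample list.
-- This is a test on one input, not a proof; the proof comes from the type.
holdsFor :: (forall a . [a] -> [a]) -> Bool
holdsFor g = lhs g show sample == rhs g show sample
  where
    sample = [1, 2, 3, 4, 5]

main :: IO ()
main = print (and
    [ holdsFor reverse
    , holdsFor (take 2)
    , holdsFor (drop 1)
    , holdsFor (\xs -> xs ++ xs)
    ])
```

An optimizer that knew this equality could rewrite either side into the other without inspecting the definition of `g`, which is the kind of simplification that plain normalization cannot reach.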

My next goals are:

  • Add a back-end to compile Morte to LLVM
  • Add a front-end to desugar a medium-level Haskell-like language to Morte

Once those steps are complete then Morte will be a usable intermediate language for writing super-optimizable programs.

Also, if you're wondering, the name Morte is a tribute to a talking skull from the game Planescape: Torment, since the Morte library is a "bare-bones" calculus of constructions.

Literature

If this topic interests you more, you may find the following links helpful, in roughly increasing order of difficulty:

7 comments:

  1. Very interesting post! (You had won me over by the time you said "Morte" :)

    I have a question regarding your statement that if two programs are “equal” they produce the same executable. What do you mean by equal in this context? alpha-beta-eta-equivalence? And if so, do you think that is the right concept here?

    The reason why I am asking here is the following. In classical treatments of lambda calculus, programs are just single lambda terms. There, we know that simply typed lambda calculus is strongly normalizing, and therefore (alpha)-beta-eta-equivalence is decidable by computing the normal form. However, what we are interested in here is the question whether the two *programs* are the same, i.e., compute the same output on all inputs. More precisely, the question would be, given lambda terms A and B decide whether (A X) beta-eta-= (B X) for all correctly typed X. Even though simply typed lambda calculus is not Turing complete, I still cannot imagine that this problem is decidable for all A and B.

    What I am getting at is this: We have two notions of equality. i) ordinary beta-eta-equivalence and ii) beta-eta-equivalence over all X in the above sense. ii) seems to be too strong a notion of equality to be tackled with the approach you outline. On the other hand, i) seems too weak to be interesting for programs that act on many different inputs. So, is there a notion of equality that lies between i) and ii) that you are shooting for here?

    Replies
    1. Yeah, in this case I mean alpha-beta-eta-equivalence. I think the right concept is also free theorem equivalence, but that's harder to implement, and I'm trying to see if there is an easy subset of theorems that can be implemented that is reasonably complete.

      I think if you also have free theorems you can implement the kind of equivalence you just described, but that is still a very big "if".

  2. In the Haskell version of the IO example, the definition of `main = replicateM_ ninetynine main_` is missing.

  3. There are examples of the list operations on folds on Oleg's site.

  5. I really like the ideas you're following with Morte, Annah, Dhall, etc. but stumbling on this post I'm wondering if "super-optimize" is the right phrase to use, and whether "super-compile" would be more accurate.

    As far as I'm aware, super-optimisation involves brute-force search for a program which is provably equivalent to the input, starting with small, fast programs and gradually trying longer, slower programs until a provably-equivalent one can be found. Crucially, we don't generate code by modifying the given program into some other form; instead, it's only used as a *specification* against which we check our generated code. It seems like state of the art superoptimisers can handle programs containing around 10 machine code instructions, without branches.

    On the other hand, super-compilation involves modifying the given code as much as possible, acting like constant-folding on steroids. Rather than special-case optimisations like constant folding, function inlining, loop unrolling, etc. a supercompiler will try to evaluate as many terms of the program as possible at compile time. Known functions applied to known arguments will be run to completion, abstractions will be collapsed and any remaining computation will be specialised to perform only those steps which couldn't be reduced without knowing the input (e.g. a pattern-match). That seems to be what Morte does; relying on CoC being total to ensure that performing such calls will terminate.

    Some super-compilers also extend the reduction rules of the language, to try and reduce redundancy between branches. For example, during supercompilation our code may end up containing something like 'if (x > 3) then (if (x > 0) then "hello" else "goodbye") else "world"' which we can't beta-reduce without knowing "x"; however, we can still propagate the "positive" information that "x > 3" into the "then" branch, which tells us enough about "x" that we can evaluate the inner "if" to get 'if (x > 3) then "hello" else "world"'. We can also propagate the "negative" information (in this case "not (x > 3)") into "else" branches.
