Tuesday, April 21, 2020

Blazing fast Fibonacci numbers using Monoids


This post illustrates a nifty application of Haskell’s standard library to solve a numeric problem.

The Fibonacci series is a well-known sequence of numbers defined by the following rules:

f(0) = 0
f(1) = 1
f(n) = f(n - 1) + f(n - 2)

In fact, that’s not only a specification of the Fibonacci numbers: that’s also valid Haskell code (with a few gratuitous parentheses to resemble traditional mathematical notation).
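For reference, here is that recursive definition as a self-contained Haskell snippet with a type signature. The name `fibNaive` is hypothetical and is chosen only so that it does not clash with the `f` defined later in this post.

```haskell
-- The recursive specification, transcribed directly (hypothetical name).
-- Each call for n >= 2 makes two more recursive calls, and those calls
-- recompute the same smaller values repeatedly, so the running time grows
-- exponentially with n.
fibNaive :: Integer -> Integer
fibNaive 0 = 0
fibNaive 1 = 1
fibNaive n = fibNaive (n - 1) + fibNaive (n - 2)
```

For example, `map fibNaive [0..10]` evaluates to `[0,1,1,2,3,5,8,13,21,34,55]`. Results like `fibNaive 40` already take noticeably long to compute.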

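However, that solution is inefficient: every call recomputes the same smaller Fibonacci numbers over and over, so the running time grows exponentially with n. Instead, you can use one of two “closed form” solutions for the Fibonacci numbers.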

The first solution, known as Binet’s formula, says that you can compute the Nth Fibonacci number using the following formula:

f(n) = (φ^n - ψ^n) / (φ - ψ)
  where
    φ = (1 + sqrt(5)) / 2

    ψ = (1 - sqrt(5)) / 2

… which is also valid Haskell code.

Unfortunately, the above solution has two issues when translated to a computer algorithm using IEEE 754 floating-point numbers:

  • Floating-point arithmetic is imprecise, so even small results come out slightly wrong:

    >>> f(10)
    54.99999999999999

  • Double-precision floating-point numbers cannot represent values larger than ~1.8 × 10³⁰⁸ (the largest finite double), so large results overflow:

    >>> f(10000)
    Infinity
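To make both problems concrete, here is a minimal sketch of Binet’s formula with the floating-point type spelled out. The names `fibBinet` and `fibBinetRounded` are hypothetical. Rounding to the nearest integer hides the small error for small inputs, but it cannot help once the result overflows, and a Double’s 53-bit significand means it cannot even represent every integer above 2⁵³.

```haskell
-- Binet's formula evaluated with IEEE 754 doubles (hypothetical name).
fibBinet :: Integer -> Double
fibBinet n = (phi ^ n - psi ^ n) / (phi - psi)
  where
    phi = (1 + sqrt 5) / 2
    psi = (1 - sqrt 5) / 2

-- Rounding repairs results such as 54.99999999999999 for small n, but it
-- cannot recover exact answers once they exceed what a Double can hold.
fibBinetRounded :: Integer -> Integer
fibBinetRounded = round . fibBinet
```

For example, `fibBinetRounded 10` gives `55`, while `isInfinite (fibBinet 10000)` is `True`.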

I instead prefer the second closed form solution using matrix arithmetic, which you can find here:

I will present a minor variation on that solution, which is essentially equivalent.

You can compute the Nth Fibonacci number by using the following matrix multiplication expression:

-- Okay, this is not valid Haskell 😌
             ┌   ┐ⁿ ┌ ┐
             │0 1│  │0│
f(n) = [1 0] │1 1│  │1│
             └   ┘  └ ┘
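As a quick sanity check (a worked example, not part of the original post), take n = 3:

```latex
\begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix}^{3}
= \begin{bmatrix} 1 & 2 \\ 2 & 3 \end{bmatrix},
\qquad
\begin{bmatrix} 1 & 0 \end{bmatrix}
\begin{bmatrix} 1 & 2 \\ 2 & 3 \end{bmatrix}
\begin{bmatrix} 0 \\ 1 \end{bmatrix}
= \begin{bmatrix} 1 & 0 \end{bmatrix}
\begin{bmatrix} 2 \\ 3 \end{bmatrix}
= 2 = f(3)
```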

There are two reasons I prefer this matrix-based closed-form solution:

  • This solution doesn’t require floating point numbers

  • You can more easily generalize this solution to other linear recurrences

To expand upon the latter point, if you have a sequence defined by a linear recurrence of the form:

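$$
\begin{aligned}
f(0) &= a_1 \\
f(1) &= a_2 \\
&\;\;\vdots \\
f(m - 1) &= a_m \\
f(n) &= b_1 \times f(n - 1) + b_2 \times f(n - 2) + \cdots + b_m \times f(n - m) \quad (n \ge m)
\end{aligned}
$$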

… then the closed-form matrix solution is:

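$$
f(n) =
\begin{bmatrix} 1 & 0 & \cdots & 0 \end{bmatrix}
\begin{bmatrix}
0 & 1 & 0 & \cdots & 0 \\
0 & 0 & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
0 & 0 & 0 & \cdots & 1 \\
b_m & b_{m-1} & \cdots & b_2 & b_1
\end{bmatrix}^{n}
\begin{bmatrix} a_1 \\ a_2 \\ \vdots \\ a_m \end{bmatrix}
$$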
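Although the rest of the post sticks to Fibonacci numbers, a rough sketch of the general case may help. It rests on my own assumptions: matrices are stored as lists of rows, and all names (`mmul`, `identity`, `mpow`, `companion`, `linearRec`) are hypothetical. It performs exponentiation by squaring by hand, because a list-based matrix has no single size-independent identity value to use as `mempty`.

```haskell
import Data.List (transpose)

-- Hypothetical representation: a matrix as a list of rows.
type Matrix = [[Integer]]

-- Ordinary matrix product.
mmul :: Matrix -> Matrix -> Matrix
mmul a b = [[sum (zipWith (*) row col) | col <- transpose b] | row <- a]

-- The k×k identity matrix.
identity :: Int -> Matrix
identity k = [[if i == j then 1 else 0 | j <- [1 .. k]] | i <- [1 .. k]]

-- Exponentiation by squaring on k×k matrices: O(log n) matrix products.
mpow :: Int -> Matrix -> Integer -> Matrix
mpow k m n
  | n < 0 = error "mpow: negative exponent"
  | n == 0 = identity k
  | even n = let h = mpow k m (n `div` 2) in mmul h h
  | otherwise = mmul m (mpow k m (n - 1))

-- Companion matrix for f(n) = b1*f(n-1) + ... + bm*f(n-m), built from
-- [b1, ..., bm]: ones just above the diagonal, then [bm, ..., b1] as the last row.
companion :: [Integer] -> Matrix
companion bs = shiftRows ++ [reverse bs]
  where
    m = length bs
    shiftRows = [[if j == i + 1 then 1 else 0 | j <- [1 .. m]] | i <- [1 .. m - 1]]

-- f(n) from the coefficients [b1, ..., bm] and the initial values
-- [a1, ..., am] = [f(0), ..., f(m - 1)]: the first row of C^n dotted with them.
linearRec :: [Integer] -> [Integer] -> Integer -> Integer
linearRec bs initial n =
  case mpow (length bs) (companion bs) n of
    firstRow : _ -> sum (zipWith (*) firstRow initial)
    [] -> error "linearRec: need at least one coefficient"
```

With `bs = [1, 1]` and initial values `[0, 1]` this reproduces the Fibonacci numbers, and `map (linearRec [1, 1, 1] [0, 0, 1]) [0..10]` gives the “Tribonacci” sequence `[0,0,1,1,2,4,7,13,24,44,81]`.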

For now, though, we’ll stick to Fibonacci numbers, which we can implement efficiently in Haskell in less than 30 lines of code.

First, we’ll define a quick and dirty 2×2 matrix type as a record of four fields:

data Matrix2x2 = Matrix
    { x00 :: Integer, x01 :: Integer
    , x10 :: Integer, x11 :: Integer
    }

Haskell does have linear algebra packages, but I wanted to keep this solution as dependency-free as possible.

Then we’ll define matrix multiplication for this type using Haskell’s Semigroup class, which you can think of as a generic interface for any operator that is associative:

instance Semigroup Matrix2x2 where
    Matrix l00 l01 l10 l11 <> Matrix r00 r01 r10 r11 =
        Matrix
            { x00 = l00 * r00 + l01 * r10, x01 = l00 * r01 + l01 * r11
            , x10 = l10 * r00 + l11 * r10, x11 = l10 * r01 + l11 * r11
            }

We’ll see why we implement this general interface in just a second.

The only rule for this Semigroup interface is that the operator we implement must obey the following associativity law:

(x <> y) <> z = x <> (y <> z)

… and matrix multiplication is indeed associative.

Next, we implement the Monoid interface, which is essentially the same as the Semigroup interface except with an additional mempty value. This value is the “identity” of the corresponding Semigroup operation, meaning that the value obeys the following “identity laws”:

x <> mempty = x

mempty <> x = x

Since our Semigroup operation is matrix multiplication, the corresponding identity value is … the identity matrix (and now you know how it got that name):

instance Monoid Matrix2x2 where
    mempty =
        Matrix
            { x00 = 1, x01 = 0
            , x10 = 0, x11 = 1
            }
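The post’s Matrix2x2 has no Eq or Show instance, so the laws above cannot be checked directly. The sketch below repeats the type and both instances, adds an assumed `deriving (Eq, Show)`, and spot-checks associativity and the identity laws on a few hypothetical sample matrices. Checking examples is of course not a proof, but it catches typos in the multiplication formula quickly.

```haskell
-- Same type and instances as in the post, plus an assumed
-- `deriving (Eq, Show)` so that results can be compared and printed.
data Matrix2x2 = Matrix
    { x00 :: Integer, x01 :: Integer
    , x10 :: Integer, x11 :: Integer
    }
    deriving (Eq, Show)

instance Semigroup Matrix2x2 where
    Matrix l00 l01 l10 l11 <> Matrix r00 r01 r10 r11 =
        Matrix
            { x00 = l00 * r00 + l01 * r10, x01 = l00 * r01 + l01 * r11
            , x10 = l10 * r00 + l11 * r10, x11 = l10 * r01 + l11 * r11
            }

instance Monoid Matrix2x2 where
    mempty = Matrix { x00 = 1, x01 = 0, x10 = 0, x11 = 1 }

-- A few arbitrary sample matrices (hypothetical test data).
samples :: [Matrix2x2]
samples = [Matrix 0 1 1 1, Matrix 2 (-1) 3 5, Matrix 7 0 (-4) 1, mempty]

-- True when (x <> y) <> z == x <> (y <> z) for every triple of samples.
associativeOnSamples :: Bool
associativeOnSamples =
    and [ (x <> y) <> z == x <> (y <> z) | x <- samples, y <- samples, z <- samples ]

-- True when mempty is both a left and a right identity for every sample.
identityOnSamples :: Bool
identityOnSamples = and [ x <> mempty == x && mempty <> x == x | x <- samples ]
```

In GHCi, `Matrix 0 1 1 1 <> Matrix 0 1 1 1` prints `Matrix {x00 = 1, x01 = 1, x10 = 1, x11 = 2}`, and both checks evaluate to `True`.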

Now, in order to translate this expression to Haskell:

             ┌   ┐ⁿ ┌ ┐
             │0 1│  │0│
f(n) = [1 0] │1 1│  │1│
             └   ┘  └ ┘

… we need a fast way to exponentiate our Matrix2x2 type. Fortunately, we can do so using the mtimesDefault utility from Haskell’s standard library, which works for any type that implements Monoid:

-- | Repeat a value @n@ times.
--
-- > mtimesDefault n a = a <> a <> ... <> a  -- using <> (n-1) times
--
-- Implemented using 'stimes' and 'mempty'.
mtimesDefault :: (Integral b, Monoid a) => b -> a -> a

This is why I chose to implement the Semigroup and Monoid interfaces: once we do so, we get the above utility for free. The mtimesDefault function works for any type that implements those two interfaces (like our Matrix2x2 type), so in order to exponentiate a matrix I only need to write mtimesDefault n matrix, which computes the matrix raised to the nth power (the identity matrix when n is 0).
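To see mtimesDefault on a simpler Monoid than our matrix, here is a small sketch using String, which is a Monoid under concatenation with "" as mempty. The names `threeTimes` and `zeroTimes` are hypothetical.

```haskell
import Data.Semigroup (mtimesDefault)

-- Repeating "ab" three times concatenates three copies.
threeTimes :: String
threeTimes = mtimesDefault (3 :: Integer) "ab"

-- A count of 0 returns mempty, which for String is the empty string.
zeroTimes :: String
zeroTimes = mtimesDefault (0 :: Integer) "ab"
```

Here `threeTimes` is `"ababab"` and `zeroTimes` is `""`; for Matrix2x2 the zero case is exactly where the identity matrix from our Monoid instance comes in.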

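The documentation for this utility fails to note one important detail: mtimesDefault delegates to stimes, whose default implementation uses the trick known as exponentiation by squaring, so it computes the result with only O(log(n)) multiplications (each of which gets more expensive as the Integer entries grow).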
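To make the O(log(n)) claim concrete, here is a minimal sketch of exponentiation by squaring for an arbitrary Monoid. The name `power` is hypothetical and this is not base’s actual source, although base’s default stimes relies on the same halving idea. Note that replacing the n-fold product by `half <> half` is only valid because `<>` is associative, which is one practical reason that law matters.

```haskell
-- Exponentiation by squaring for any Monoid (hypothetical name).
-- Even exponents halve the problem; odd exponents peel off one factor and
-- become even, so only O(log n) uses of <> are needed.
power :: Monoid a => Integer -> a -> a
power n x
  | n < 0 = error "power: negative exponent"
  | n == 0 = mempty
  | even n = let half = power (n `div` 2) x in half <> half
  | otherwise = x <> power (n - 1) x
```

For instance, `power 3 "ab"` is `"ababab"`, the same result mtimesDefault gives.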

This leads to the solution for our elegant and efficient Fibonacci function, which is:

import qualified Data.Semigroup as Semigroup

f :: Integer -> Integer
f n = x01 (Semigroup.mtimesDefault n matrix)
  where
    matrix =
        Matrix
            { x00 = 0, x01 = 1
            , x10 = 1, x11 = 1
            }

Here I’ve added one last simplification: instead of performing the final vector multiplications, we extract the value in the top-right corner of our 2×2 matrix. This simplification works for the Fibonacci numbers, but does not necessarily work for the general solution for an arbitrary linear recurrence.
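The top-right-corner shortcut works because of a standard identity, which follows by induction on n (sketched here; it is not spelled out in the original post). For n = 1 both sides are the original matrix, and each further multiplication shifts the Fibonacci numbers along by one:

```latex
\begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix}^{n}
= \begin{bmatrix} f(n-1) & f(n) \\ f(n) & f(n+1) \end{bmatrix}
\quad (n \ge 1),
\qquad
\begin{bmatrix} f(n-1) & f(n) \\ f(n) & f(n+1) \end{bmatrix}
\begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix}
= \begin{bmatrix} f(n) & f(n-1) + f(n) \\ f(n+1) & f(n) + f(n+1) \end{bmatrix}
= \begin{bmatrix} f(n) & f(n+1) \\ f(n+1) & f(n+2) \end{bmatrix}
```

For n = 0 the power is the identity matrix, whose top-right entry 0 is also f(0).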

Let’s quickly eyeball that things work:

>>> map f [0..20]
[0,1,1,2,3,5,8,13,21,34,55,89,144,233,377,610,987,1597,2584,4181,6765]

… and now we can compute extraordinarily large Fibonacci numbers, even more quickly than the computer can display them:

>>> f(100000)
25974069347221724166155034021275915414880485386517696584724770703952534543511273
68626555677283671674475463758722307443211163839947387509103096569738218830449305
22876385313349213530267927895670105127657827163560807305053220024323311438398651

🌺 200+ lines later 🌺

03835085621908060270866604873585849001704200923929789193938125116798421788115209
25913043557232163566089560351438388393901895316627435560997001569978028923636234
9895374653428746875

… in fact, you can easily compute up to f(10^8) in a couple of seconds using this code (not shown, because the result takes far longer to print than to compute).

Appendix

Here is the complete example in case you want to test this out on your own:

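module Fibonacci where

import qualified Data.Semigroup as Semigroup

-- | A 2×2 matrix of arbitrary-precision integers
data Matrix2x2 = Matrix
    { x00 :: Integer, x01 :: Integer
    , x10 :: Integer, x11 :: Integer
    }

-- The identity matrix is the identity element for matrix multiplication
instance Monoid Matrix2x2 where
    mempty =
        Matrix
            { x00 = 1, x01 = 0
            , x10 = 0, x11 = 1
            }

-- Matrix multiplication, which is associative
instance Semigroup Matrix2x2 where
    Matrix l00 l01 l10 l11 <> Matrix r00 r01 r10 r11 =
        Matrix
            { x00 = l00 * r00 + l01 * r10, x01 = l00 * r01 + l01 * r11
            , x10 = l10 * r00 + l11 * r10, x11 = l10 * r01 + l11 * r11
            }

-- | The nth Fibonacci number: the top-right entry of the nth power of
-- [[0, 1], [1, 1]], computed with O(log n) matrix multiplications
f :: Integer -> Integer
f n = x01 (Semigroup.mtimesDefault n matrix)
  where
    matrix =
        Matrix
            { x00 = 0, x01 = 1
            , x10 = 1, x11 = 1
            }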