Type inference for plain data using Monoids
The context behind this post is that my partner asked me how to
implement type inference for plain data structures (e.g. JSON or YAML)
which was awfully convenient because this is something I’ve done a
couple of times already and there is a pretty elegant trick for this I
wanted to share.
Now, normally type inference and unification
are a bit tricky to implement in a programming language with functions,
but they’re actually fairly simple to implement if all you have to work
with is plain data. To illustrate this, I’ll implement and walk through
a simple type inference algorithm for JSON-like expressions.
For this post I’ll use the Value type from Haskell’s
aeson package, which represents a JSON value:
data Value
= Object (KeyMap Value) -- { "key₀": value₀, "key₁": value₁, … }
| Array (Vector Value) -- [ element₀, element₁, … ]
| String Text -- e.g. "example string"
| Number Scientific -- e.g. 42.0
| Bool Bool -- true or false
| Null -- null
I’ll also introduce a Type datatype to represent the
type of a JSON value, which is partially inspired by TypeScript:
import Data.Aeson.KeyMap (KeyMap)
data Type
= ObjectType (KeyMap Type) -- { "key₀": type₀, "key₁": type₁, … }
| ArrayType Type -- type[]
| StringType -- string
| NumberType -- number
| BoolType -- boolean
| Optional Type -- null | type
| Never -- never, the subtype of all other types
| Any -- any, the supertype of all other types
deriving (Show)
… and the goal is that we want to implement an infer
function that has this type:
import Data.Aeson (Value(..))
infer :: Value -> Type
I want to walk through a few test cases before diving into the
implementation, otherwise it might not be clear what the
Type constructors are supposed to represent:
>>> -- I'll use the usual `x : T` syntax to denote "`x` has type `T`"
>>> -- I'll also use TypeScript notation for the types
>>> -- "example string" : string
>>> infer (String "example string")
StringType
>>> -- true : boolean
>>> infer (Bool True)
BoolType
>>> -- false : boolean
>>> infer (Bool False)
BoolType
>>> -- 42 : number
>>> infer (Number 42)
NumberType
>>> -- [ 2, 3, 5 ] : number[]
>>> infer (Array [Number 2, Number 3, Number 5])
ArrayType NumberType
>>> -- [ 2, "hello" ] : any[]
>>> -- To keep things simple, we'll differ from TypeScript and not infer
>>> -- a type like (number | string)[]. That's an exercise for the reader.
>>> infer (Array [Number 2, String "hello"])
ArrayType Any
>>> -- [] : never[]
>>> infer (Array [])
ArrayType Never
>>> -- { "key₀": true, "key₁": 42 } : { "key₀": boolean, "key₁": number }
>>> infer (Object [("key₀", Bool True), ("key₁", Number 42)])
ObjectType [("key₀", BoolType), ("key₁", NumberType)]
>>> -- [{ "key₀": true }, { "key₁": 42 }] : { "key₀": null | boolean, "key₁": null | number }[]
>>> infer (Array [Object [("key₀", Bool True)], Object [("key₁", Number 42)]])
ArrayType (ObjectType (fromList [("key₀",Optional BoolType),("key₁",Optional NumberType)]))
>>> -- null : null | never
>>> infer Null
Optional Never
>>> -- [ null, true ] : (null | boolean)[]
>>> infer (Array [Null, Bool True])
ArrayType (Optional BoolType)
Some of those test cases correspond almost 1-to-1 with the
implementation of infer, which we can begin to
implement:
infer :: Value -> Type
infer (String _) = StringType
infer (Bool _) = BoolType
infer (Number _) = NumberType
infer Null = Optional Never
…
The main two non-trivial cases are the implementation of
infer for Objects and Arrays.
We’ll start with Objects since that’s the easier case to
infer. To infer the type of an object we infer the type of each field
and then collect those field types into the final object type:
infer (Object fields) = ObjectType (fmap infer fields)
The last tricky bit to implement is the case for Arrays.
We might start with something like this:
infer (Array elements) = ArrayType ???
… but what goes in the result? This is NOT
correct:
infer (Array elements) = ArrayType (fmap infer elements)
… because there can only be a single element type for the whole
array. We can infer the type of each element, but if those element types
don’t match then we need some way to unify those element types into a
single element type representing the entire array. In other words, we
need a function with this type:
unify :: Vector Type -> Type
… because if we had such a function then we could write:
infer (Array elements) = ArrayType (unify (fmap infer elements))
The trick to doing this is that we need to implement a
Monoid instance and Semigroup instance for
Type, which is the same as saying that we need to define
two functions:
-- The default type `unify` returns if our list is empty
mempty :: Type
-- Unify two types into one
(<>) :: Type -> Type -> Type
… because if we implement those two functions then our
unify function becomes … fold!
import Data.Foldable (fold)
import Data.Vector (Vector)
unify :: Vector Type -> Type
unify = fold
The documentation for fold explains how it works:
Given a structure with elements whose type is a Monoid, combine them via the monoid’s (<>) operator.
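To make that description concrete, here is a minimal sketch of fold on monoids that ship with base rather than on our Type. The names total, emptyTotal, and joined are hypothetical and only exist for this illustration:

```haskell
import Data.Foldable (fold)
import Data.Monoid (Sum(..))

-- fold combines every element with (<>); for Sum, (<>) is addition
total :: Int
total = getSum (fold [Sum 2, Sum 3, Sum 5])

-- fold of an empty structure returns mempty; for Sum, mempty is Sum 0
emptyTotal :: Int
emptyTotal = getSum (fold ([] :: [Sum Int]))

-- for lists, (<>) is (++) and mempty is []
joined :: String
joined = fold ["type", " ", "inference"]
```

Loading this in GHCi, total evaluates to 10, emptyTotal to 0, and joined to "type inference". The empty case is the one that matters for us: it is why mempty decides the element type of an empty array.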
Laws
There are a few rules we need to be aware of when implementing
mempty and (<>) which will help ensure
that our implementation of unification is well-behaved.
First, mempty and (<>) must obey the
“Monoid laws”, which require that:
-- Left identity
mempty <> x = x
-- Right identity
x <> mempty = x
-- Associativity
x <> (y <> z) = (x <> y) <> z
Second, mempty and (<>) must
additionally obey the following unification laws:
mempty is a subtype of x, for all
x
x <> y is a supertype of both x and
y
Unification
mempty is easy to implement since according to the
unification laws mempty must be the universal subtype,
which is the Never type:
instance Monoid Type where
mempty = Never
(<>) is the more interesting function to
implement, and we’ll start with the easy cases:
instance Semigroup Type where
StringType <> StringType = StringType
NumberType <> NumberType = NumberType
BoolType <> BoolType = BoolType
…
If we unify any scalar type with itself, we get back the same type.
That’s pretty self-explanatory.
The next two cases are also pretty simple:
Never <> other = other
other <> Never = other
If we unify the Never type with any other
type, then we get the other type because Never is a subtype
of every other type.
The next case is slightly more interesting:
ArrayType left <> ArrayType right = ArrayType (left <> right)
If we unify two array types, then we unify their element types. But
what about Optional types?
Optional left <> Optional right = Optional (left <> right)
Optional left <> right = Optional (left <> right)
left <> Optional right = Optional (left <> right)
If we unify two Optional types, then we unify their
element types, but we also handle the case where only one or the other
type is Optional, too.
The last complex data type is objects, which has the most interesting
implementation:
ObjectType left <> ObjectType right =
ObjectType (KeyMap.alignWith adapt left right)
where
adapt (This (Optional a)) = Optional a
adapt (That (Optional b)) = Optional b
adapt (This a) = Optional a
adapt (That b) = Optional b
adapt (These a b) = a <> b
You can read that as saying “to unify two objects, unify the types of
their respective fields, and if either object has an extra field not
present in the other object then wrap the field’s type in
Optional”.
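If alignWith is unfamiliar, the following sketch shows how it pairs up the keys of two maps. To keep it self-contained it does not use aeson or the these package: These and alignWith below are local stand-ins written against Data.Map from containers, and example is a hypothetical name.

```haskell
import qualified Data.Map.Strict as Map
import Data.Map.Merge.Strict (merge, mapMissing, zipWithMatched)

-- Local stand-in for the These type from Data.These
data These a b = This a | That b | These a b
  deriving (Eq, Show)

-- Stand-in for KeyMap.alignWith, written against Data.Map
alignWith :: Ord k => (These a b -> c) -> Map.Map k a -> Map.Map k b -> Map.Map k c
alignWith f =
  merge
    (mapMissing (\_ a -> f (This a)))          -- key only in the left map
    (mapMissing (\_ b -> f (That b)))          -- key only in the right map
    (zipWithMatched (\_ a b -> f (These a b))) -- key in both maps

-- Align two maps that share only the key "b"
example :: [(String, These Int String)]
example = Map.toList (alignWith id left right)
  where
    left  = Map.fromList [("a", 1), ("b", 2)]
    right = Map.fromList [("b", "two"), ("c", "three")]
```

Here example evaluates to [("a",This 1),("b",These 2 "two"),("c",That "three")]: every key from either map shows up exactly once, tagged with which side it came from. That is what lets adapt treat fields missing from one object differently from fields present in both.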
Finally, we have the case of last resort:
_ <> _ = Any
If we try to unify two types that could not unify via the previous
rules, then fall back to Any (the supertype of all other
types).
This gives us our final program (which I’ve included in its entirety
here):
import Data.Aeson (Value(..))
import Data.Aeson.KeyMap (KeyMap)
import Data.Foldable (fold)
import Data.These (These(..))
import Data.Vector (Vector)
import qualified Data.Aeson.KeyMap as KeyMap
data Type
= ObjectType (KeyMap Type) -- { "key₀": type₀, "key₁": type₁, … }
| ArrayType Type -- type[]
| StringType -- string
| NumberType -- number
| BoolType -- boolean
| Optional Type -- null | type
| Never -- never, the subtype of all other types
| Any -- any, the supertype of all other types
deriving (Show)
infer :: Value -> Type
infer (String _) = StringType
infer (Bool _) = BoolType
infer (Number _) = NumberType
infer Null = Optional Never
infer (Object fields) = ObjectType (fmap infer fields)
infer (Array elements) = ArrayType (unify (fmap infer elements))
unify :: Vector Type -> Type
unify = fold
instance Monoid Type where
mempty = Never
instance Semigroup Type where
StringType <> StringType = StringType
NumberType <> NumberType = NumberType
BoolType <> BoolType = BoolType
Never <> other = other
other <> Never = other
ArrayType left <> ArrayType right = ArrayType (left <> right)
Optional left <> Optional right = Optional (left <> right)
Optional left <> right = Optional (left <> right)
left <> Optional right = Optional (left <> right)
ObjectType left <> ObjectType right =
ObjectType (KeyMap.alignWith adapt left right)
where
adapt (This (Optional a)) = Optional a
adapt (That (Optional b)) = Optional b
adapt (This a) = Optional a
adapt (That b) = Optional b
adapt (These a b) = a <> b
_ <> _ = Any
Pretty simple! That’s the complete implementation of type inference and unification.
Unification laws
I mentioned that our implementation should satisfy the
Monoid laws and unification laws, so I’ll include some
quick proof sketches (albeit not full formal proofs), starting with the
unification laws.
Let’s start with the first unification law:
mempty is the subtype of x, for all
x
This is true because we define mempty = Never and
Never is the subtype of all other types.
Next, let’s show that the implementation of (<>)
satisfies the other unification law:
x <> y is a supertype of both x and
y
The first case is:
StringType <> StringType = StringType
This satisfies the unification law because if we replace both
x and y with StringType we
get:
StringType <> StringType is a supertype of both
StringType and StringType
… and since StringType <> StringType = StringType
that simplifies down to:
StringType is a supertype of both
StringType and StringType
… and every type is a supertype of itself, so this satisfies the
unification law.
We’d prove the unification law for the next two cases in the exact
same way (just replacing StringType with
NumberType or BoolType):
NumberType <> NumberType = NumberType
BoolType <> BoolType = BoolType
What about the next case:
Never <> other = other
Well, if we take our unification law and replace x with
Never and replace y with other we
get:
Never <> other is a supertype of
Never and other
… and since Never <> other = other that simplifies
to:
other is a supertype of Never and
other
… which is true because:
- other is a supertype of Never (because Never is the universal subtype)
- other is a supertype of other (because every type is a supertype of itself)
We’d prove the next case in the exact same way (just swapping
Never and other):
other <> Never = other
For the next case:
ArrayType left <> ArrayType right = ArrayType (left <> right)
The unification law becomes:
ArrayType (left <> right) is a supertype of both
ArrayType left and ArrayType right
… which is true because ArrayType is covariant
and by induction left <> right is a supertype of both
left and right.
We’d prove the first case for Optional in the exact same
way (just replace ArrayType with Optional):
Optional left <> Optional right = Optional (left <> right)
The next case for Optional is more interesting:
Optional left <> right = Optional (left <> right)
Here the unification law would be:
Optional (left <> right) is a supertype of
Optional left and right
… which is true because:
- Optional (left <> right) is a supertype of Optional left
  This is true because Optional is covariant and left <> right is a supertype of left
- Optional (left <> right) is a supertype of right
  This is true because:
  - Optional (left <> right) is a supertype of Optional right
  - Optional right is a supertype of right
  - Therefore, by transitivity, Optional (left <> right) is a supertype of right
We’d prove the next case in the same way, just switching
left and right:
left <> Optional right = Optional (left <> right)
The case for objects is the most interesting case:
ObjectType left <> ObjectType right =
ObjectType (KeyMap.alignWith adapt left right)
where
adapt (This (Optional a)) = Optional a
adapt (That (Optional b)) = Optional b
adapt (This a) = Optional a
adapt (That b) = Optional b
adapt (These a b) = a <> b
I won’t prove this case as formally, but the basic idea is that this
is true because a record type (A) is a supertype of another
record type (B) if and only if:
- for each field k they share in common, A.k is a supertype of B.k
- for each field k present only in A, A.k is a supertype of Optional Never
- there are no fields present only in B
… and given that definition of record subtyping then the above
implementation satisfies the unification law.
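One way to make that definition of subtyping less hand-wavy is to write it down as a function. The sketch below is my reading of the rules used in these proof sketches, not part of the program above: isSupertypeOf is a hypothetical name, Data.Map stands in for KeyMap so that the block only needs containers, and left, right, and merged are example values.

```haskell
import qualified Data.Map.Strict as Map

-- Same constructors as the post's Type, with Data.Map standing in for KeyMap
data Type
  = ObjectType (Map.Map String Type)
  | ArrayType Type
  | StringType
  | NumberType
  | BoolType
  | Optional Type
  | Never
  | Any
  deriving (Eq, Show)

-- a `isSupertypeOf` b: every value of type b is also a value of type a
isSupertypeOf :: Type -> Type -> Bool
Any          `isSupertypeOf` _            = True  -- Any is the universal supertype
_            `isSupertypeOf` Never        = True  -- Never is the universal subtype
StringType   `isSupertypeOf` StringType   = True
NumberType   `isSupertypeOf` NumberType   = True
BoolType     `isSupertypeOf` BoolType     = True
ArrayType a  `isSupertypeOf` ArrayType b  = a `isSupertypeOf` b  -- covariance
Optional a   `isSupertypeOf` Optional b   = a `isSupertypeOf` b  -- covariance
Optional a   `isSupertypeOf` b            = a `isSupertypeOf` b  -- Optional b is a supertype of b
ObjectType a `isSupertypeOf` ObjectType b =
  -- no fields present only in b, and every field of a covers b's field
  Map.null (Map.difference b a) && and (Map.mapWithKey fieldOk a)
  where
    -- a field missing from b is treated as Optional Never
    fieldOk k aType = aType `isSupertypeOf` Map.findWithDefault (Optional Never) k b
_            `isSupertypeOf` _            = False

-- Two object types with disjoint fields, and the type they unify to
left, right, merged :: Type
left   = ObjectType (Map.fromList [("a", BoolType)])
right  = ObjectType (Map.fromList [("b", NumberType)])
merged = ObjectType (Map.fromList [("a", Optional BoolType), ("b", Optional NumberType)])
```

With these definitions, merged `isSupertypeOf` left and merged `isSupertypeOf` right are both True, while left `isSupertypeOf` merged is False because merged has a field that left lacks.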
Monoid laws
The first two Monoid laws are trivial to prove:
mempty <> x = x
x <> mempty = x
… because we defined:
mempty = Never
… and if we replace mempty with Never in
those laws:
Never <> x = x
x <> Never = x
… that is literally what our code defines (except replacing
x with other):
Never <> other = other
other <> Never = other
The last law, associativity, is pretty tedious to prove in full:
(x <> y) <> z = x <> (y <> z)
… but I’ll do a few cases to show the basic gist of how the proof
works.
First, the associativity law is easy to prove for the case where any
of x, y, or z is
Never. For example, if x = Never, then we
get:
(Never <> y) <> z = Never <> (y <> z)
-- Never <> other = other
y <> z = y <> z
… which is true. The other two cases for y = Never and
z = Never are equally simple to prove.
Associativity is also easy to prove when any of x,
y, or z is Any. For example, if
x = Any, then we get:
(Any <> y) <> z = Any <> (y <> z)
-- Any <> other = Any
Any <> z = Any
-- Any <> other = Any
Any = Any
… which is true. The other two cases for y = Any and
z = Any are equally simple to prove.
Now we can prove associativity if any of x,
y or z is StringType. The reason
why is that these are the only relevant cases in the implementation of
unification for StringType:
StringType <> StringType = StringType
StringType <> Never = StringType
Never <> StringType = StringType
StringType <> _ = Any
_ <> StringType = Any
… but we already proved associativity for all cases involving a
Never, so we don’t need to consider the second and third cases, which
simplifies things down to:
StringType <> StringType = StringType
StringType <> _ = Any
_ <> StringType = Any
That means that there are only seven cases we need to consider to
prove the associativity laws if at least one of x,
y, and z is StringType (using
_ below to denote “any type other than
StringType”):
-- true: both sides evaluate to StringType
(StringType <> StringType) <> StringType = StringType <> (StringType <> StringType)
-- all other cases below are also true: they all evaluate to `Any`
(StringType <> StringType) <> _ = StringType <> (StringType <> _ )
(StringType <> _ ) <> StringType = StringType <> (_ <> StringType)
(StringType <> _ ) <> _ = StringType <> (_ <> _ )
(_ <> StringType) <> StringType = _ <> (StringType <> StringType)
(_ <> StringType) <> _ = _ <> (StringType <> _ )
(_ <> _ ) <> StringType = _ <> (_ <> StringType)
We can similarly prove associativity for all cases involving at least
one NumberType or BoolType.
The proof for ArrayType is almost the same as the proof
for
StringType/NumberType/BoolType.
The only relevant cases are:
ArrayType left <> ArrayType right = ArrayType (left <> right)
ArrayType left <> Never = ArrayType left
Never <> ArrayType right = ArrayType right
ArrayType left <> _ = Any
_ <> ArrayType right = Any
Just like before, we can ignore the case where either argument is
Never because we already proved associativity for that.
That just leaves:
ArrayType left <> ArrayType right = ArrayType (left <> right)
ArrayType left <> _ = Any
_ <> ArrayType right = Any
Just like before, there are only seven cases we have to prove (using
_ below to denote “any type other than
ArrayType”):
ArrayType x <> (ArrayType y <> ArrayType z) = (ArrayType x <> ArrayType y) <> ArrayType z
-- … simplifies to:
ArrayType (x <> (y <> z)) = ArrayType ((x <> y) <> z)
-- … which is true because unification of the element types is associative
-- all other cases below are also true: they all evaluate to `Any`
(ArrayType x <> ArrayType y) <> _ = ArrayType x <> (ArrayType y <> _ )
(ArrayType x <> _ ) <> ArrayType z = ArrayType x <> (_ <> ArrayType z)
(ArrayType x <> _ ) <> _ = ArrayType x <> (_ <> _ )
(_ <> ArrayType y) <> ArrayType z = _ <> (ArrayType y <> ArrayType z)
(_ <> ArrayType y) <> _ = _ <> (ArrayType y <> _ )
(_ <> _ ) <> ArrayType z = _ <> (_ <> ArrayType z)
The proofs for the Optional and ObjectType
cases are longer and more laborious so I’ll omit them. They’re an
exercise for the reader because I am LAZY.
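Short of writing those proofs out, one way to gain some confidence is to brute-force the Monoid laws over a handful of sample types that include the Optional and ObjectType cases. The sketch below is a spot check, not a proof. It assumes a port of the Semigroup instance to Data.Map (a substitute for KeyMap, so the block only needs containers), and samples, identityHolds, and associativityHolds are hypothetical names.

```haskell
import qualified Data.Map.Strict as Map
import Data.Map.Merge.Strict (merge, mapMissing, zipWithMatched)

data Type
  = ObjectType (Map.Map String Type)
  | ArrayType Type
  | StringType
  | NumberType
  | BoolType
  | Optional Type
  | Never
  | Any
  deriving (Eq, Show)

instance Semigroup Type where
  StringType <> StringType = StringType
  NumberType <> NumberType = NumberType
  BoolType <> BoolType = BoolType
  Never <> other = other
  other <> Never = other
  ArrayType left <> ArrayType right = ArrayType (left <> right)
  Optional left <> Optional right = Optional (left <> right)
  Optional left <> right = Optional (left <> right)
  left <> Optional right = Optional (left <> right)
  ObjectType left <> ObjectType right =
    ObjectType
      (merge
        (mapMissing (const optional))  -- field only in the left object
        (mapMissing (const optional))  -- field only in the right object
        (zipWithMatched (const (<>)))  -- field in both objects
        left
        right)
    where
      -- wrap in Optional unless the type is already Optional
      optional (Optional a) = Optional a
      optional a = Optional a
  _ <> _ = Any

instance Monoid Type where
  mempty = Never

-- A small, hand-picked universe of types to test against
samples :: [Type]
samples =
  [ Never, Any, StringType, NumberType, BoolType
  , Optional Never, Optional StringType
  , ArrayType Never, ArrayType (Optional NumberType)
  , ObjectType Map.empty
  , ObjectType (Map.fromList [("a", StringType)])
  , ObjectType (Map.fromList [("a", NumberType), ("b", BoolType)])
  ]

-- Left and right identity for every sample
identityHolds :: Bool
identityHolds = and [mempty <> x == x && x <> mempty == x | x <- samples]

-- Associativity for every triple of samples
associativityHolds :: Bool
associativityHolds =
  and [(x <> y) <> z == x <> (y <> z) | x <- samples, y <- samples, z <- samples]
```

Both identityHolds and associativityHolds evaluate to True for this sample set. Note that none of the samples contains a directly nested Optional (such as Optional (Optional StringType)), which infer never produces either.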