The context behind this post is that my partner asked me how to implement type inference for plain data structures (e.g. JSON or YAML). That was awfully convenient, because this is something I’ve done a couple of times already and there is a pretty elegant trick for it that I wanted to share.
Now, normally type inference and unification are a bit tricky to implement in a programming language with functions, but they’re actually fairly simple to implement if all you have to work with is plain data. To illustrate this, I’ll implement and walk through a simple type inference algorithm for JSON-like expressions.
For this post I’ll use the Value type from Haskell’s aeson package, which represents a JSON value¹:
data Value
    = Object (KeyMap Value)  -- { "key₀": value₀, "key₁": value₁, … }
    | Array (Vector Value)   -- [ element₀, element₁, … ]
    | String Text            -- e.g. "example string"
    | Number Scientific      -- e.g. 42.0
    | Bool Bool              -- true or false
    | Null                   -- null
I’ll also introduce a Type datatype to represent the type of a JSON value, which is partially inspired by TypeScript:
import Data.Aeson.KeyMap (KeyMap)
data Type
    = ObjectType (KeyMap Type)  -- { "key₀": type₀, "key₁": type₁, … }
    | ArrayType Type            -- type[]
    | StringType                -- string
    | NumberType                -- number
    | BoolType                  -- boolean
    | Optional Type             -- null | type
    | Never                     -- never, the subtype of all other types
    | Any                       -- any, the supertype of all other types
    deriving (Show)
… and the goal is to implement an infer function that has this type:
import Data.Aeson (Value(..))
infer :: Value -> Type
I want to walk through a few test cases before diving into the implementation, otherwise it might not be clear what the Type constructors are supposed to represent:
>>> -- I'll use the usual `x : T` syntax to denote "`x` has type `T`"
>>> -- I'll also use TypeScript notation for the types
>>> -- "example string" : string
>>> infer (String "example string")
StringType
>>> -- true : boolean
>>> infer (Bool True)
BoolType
>>> -- false : boolean
>>> infer (Bool False)
BoolType
>>> -- 42 : number
>>> infer (Number 42)
NumberType
>>> -- [ 2, 3, 5 ] : number[]
>>> infer (Array [Number 2, Number 3, Number 5])
ArrayType NumberType
>>> -- [ 2, "hello" ] : any[]
>>> -- To keep things simple, we'll differ from TypeScript and not infer
>>> -- a type like (number | string)[]. That's an exercise for the reader.
>>> infer (Array [Number 2, String "hello"])
ArrayType Any
>>> -- [] : never[]
>>> infer (Array [])
ArrayType Never
>>> -- { "key₀": true, "key₁": 42 } : { "key₀": boolean, "key₁": number }
>>> infer (Object [("key₀", Bool True), ("key₁", Number 42)])
ObjectType (fromList [("key₀",BoolType),("key₁",NumberType)])
>>> -- [{ "key₀": true }, { "key₁": 42 }] : { "key₀": null | boolean, "key₁": null | number }[]
>>> infer (Array [Object [("key₀", Bool True)], Object [("key₁", Number 42)]])
ArrayType (ObjectType (fromList [("key₀",Optional BoolType),("key₁",Optional NumberType)]))
>>> -- null : null | never
>>> infer Null
Optional Never
>>> -- [ null, true ] : (null | boolean)[]
>>> infer (Array [Null, Bool True])
ArrayType (Optional BoolType)
Some of those test cases correspond almost 1-to-1 with the implementation of infer, which we can begin to implement:
infer :: Value -> Type
infer (String _) = StringType
infer (Bool _) = BoolType
infer (Number _) = NumberType
infer Null = Optional Never
…
The main two non-trivial cases are the implementation of infer for Objects and Arrays.
We’ll start with Objects since that’s the easier case to infer. To infer the type of an object we infer the type of each field and then collect those field types into the final object type:
infer (Object fields) = ObjectType (fmap infer fields)
The last tricky bit to implement is the case for Arrays. We might start with something like this:
infer (Array elements) = ArrayType ???
… but what goes in the result? This is NOT correct:
infer (Array elements) = ArrayType (fmap infer elements)
… because there can only be a single element type for the whole array. We can infer the type of each element, but if those element types don’t match then we need some way to unify those element types into a single element type representing the entire array. In other words, we need a function with this type:
unify :: Vector Type -> Type
… because if we had such a function then we could write:
infer (Array elements) = ArrayType (unify (fmap infer elements))
The trick to doing this is that we need to implement a Monoid instance and a Semigroup instance for Type, which is the same as saying that we need to define two functions:
-- The default type `unify` returns if our list is empty
mempty :: Type
-- Unify two types into one
(<>) :: Type -> Type -> Type
… because if we implement those two functions then our unify function becomes … fold!
import Data.Foldable (fold)
import Data.Vector (Vector)
unify :: Vector Type -> Type
unify = fold
The documentation for fold explains how it works:
Given a structure with elements whose type is a Monoid, combine them via the monoid’s (<>) operator.
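To get a feel for fold before applying it to Type, here is a minimal sketch using two Monoid instances that ship with base (lists and Sum). The names joined, emptyCase, and total are made up for this illustration and are not part of the post's program.

```haskell
import Data.Foldable (fold)
import Data.Monoid (Sum(..))

-- Lists are a Monoid: mempty = [] and (<>) = (++)
joined :: String
joined = fold ["type ", "inference"]

-- Folding an empty structure returns mempty
emptyCase :: String
emptyCase = fold ([] :: [String])

-- Sum wraps a number so that (<>) is addition and mempty is 0
total :: Int
total = getSum (fold [Sum 2, Sum 3, Sum 5])
```

Loaded into GHCi, joined evaluates to "type inference", emptyCase to "", and total to 10. The empty case is why unify needs a sensible mempty: it is what an empty array's element type falls back to.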
Laws
There are a few rules we need to be aware of when implementing mempty and (<>) which will help ensure that our implementation of unification is well-behaved.
First, mempty and (<>) must obey the “Monoid laws”, which require that:
-- Left identity
mempty <> x = x
-- Right identity
x <> mempty = x
-- Associativity
x <> (y <> z) = (x <> y) <> z
Second, mempty and (<>) must additionally obey the following unification laws:
- mempty is a subtype of x, for all x
- x <> y is a supertype of both x and y
Unification
mempty is easy to implement since according to the unification laws mempty must be the universal subtype, which is the Never type:
instance Monoid Type where
    mempty = Never
(<>) is the more interesting function to implement, and we’ll start with the easy cases:
instance Semigroup Type where
    StringType <> StringType = StringType
    NumberType <> NumberType = NumberType
    BoolType <> BoolType = BoolType
    …
If we unify any scalar type with itself, we get back the same type. That’s pretty self-explanatory.
The next two cases are also pretty simple:
    Never <> other = other
    other <> Never = other
If we unify the Never type with any other type, then we get the other type because Never is a subtype of every other type.
The next case is slightly more interesting:
    ArrayType left <> ArrayType right = ArrayType (left <> right)
If we unify two array types, then we unify their element types. But what about Optional types?
    Optional left <> Optional right = Optional (left <> right)
    Optional left <> right = Optional (left <> right)
    left <> Optional right = Optional (left <> right)
If we unify two Optional types, then we unify their element types, but we also handle the case where only one or the other type is Optional, too.
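As a worked example of these rules, here is how the two element types from the earlier [ null, true ] test case unify, step by step. This is equational reasoning rather than a runnable program, and it only uses equations already given in the post:

```haskell
-- infer Null = Optional Never and infer (Bool True) = BoolType, so:
Optional Never <> BoolType
  = Optional (Never <> BoolType)  -- Optional left <> right = Optional (left <> right)
  = Optional BoolType             -- Never <> other = other
```

That matches the expected element type (null | boolean) from the test case.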
The last complex data type is objects, which has the most interesting implementation:
    ObjectType left <> ObjectType right =
        ObjectType (KeyMap.alignWith adapt left right)
      where
        adapt (This a) = Optional a
        adapt (That b) = Optional b
        adapt (These a b) = a <> b
You can read that as saying “to unify two objects, unify the types of their respective fields, and if either object has an extra field not present in the other object then wrap the field’s type in Optional”.
Finally, we have the case of last resort:
    _ <> _ = Any
If we try to unify two types that could not unify via the previous rules, then fall back to Any (the supertype of all other types).
This gives us our final program (which I’ll include in its entirety here):
import Data.Aeson (Value(..))
import Data.Aeson.KeyMap (KeyMap)
import Data.Foldable (fold)
import Data.These (These(..))
import Data.Vector (Vector)
import qualified Data.Aeson.KeyMap as KeyMap
data Type
    = ObjectType (KeyMap Type)  -- { "key₀": type₀, "key₁": type₁, … }
    | ArrayType Type            -- type[]
    | StringType                -- string
    | NumberType                -- number
    | BoolType                  -- boolean
    | Optional Type             -- null | type
    | Never                     -- never, the subtype of all other types
    | Any                       -- any, the supertype of all other types
    deriving (Show)
infer :: Value -> Type
infer (String _) = StringType
infer (Bool _) = BoolType
infer (Number _) = NumberType
infer Null = Optional Never
infer (Object fields) = ObjectType (fmap infer fields)
infer (Array elements) = ArrayType (unify (fmap infer elements))
unify :: Vector Type -> Type
unify = fold
instance Monoid Type where
    mempty = Never
instance Semigroup Type where
    StringType <> StringType = StringType
    NumberType <> NumberType = NumberType
    BoolType <> BoolType = BoolType
    Never <> other = other
    other <> Never = other
    ArrayType left <> ArrayType right = ArrayType (left <> right)
    Optional left <> Optional right = Optional (left <> right)
    Optional left <> right = Optional (left <> right)
    left <> Optional right = Optional (left <> right)
    ObjectType left <> ObjectType right =
        ObjectType (KeyMap.alignWith adapt left right)
      where
        adapt (This a) = Optional a
        adapt (That b) = Optional b
        adapt (These a b) = a <> b
    _ <> _ = Any
Pretty simple! That’s the complete implementation of type inference and unification.
Unification laws
I mentioned that our implementation should satisfy the Monoid laws and unification laws, so I’ll include some quick proof sketches (albeit not full formal proofs), starting with the unification laws.
Let’s start with the first unification law:
- mempty is a subtype of x, for all x
This is true because we define mempty = Never and Never is a subtype of all other types.
Next, let’s show that the implementation of (<>) satisfies the other unification law:
- x <> y is a supertype of both x and y
The first case is:
The first case is:
StringType <> StringType = StringType
This satisfies the unification law because if we replace both x and y with StringType we get:
- StringType <> StringType is a supertype of both StringType and StringType
… and since StringType <> StringType = StringType that simplifies down to:
- StringType is a supertype of both StringType and StringType
… and every type is a supertype of itself, so this satisfies the unification law.
We’d prove the unification law for the next two cases in the exact same way (just replacing StringType with NumberType or BoolType):
NumberType <> NumberType = NumberType
BoolType <> BoolType = BoolType
What about the next case:
Never <> other = other
Well, if we take our unification law and replace x with Never and replace y with other we get:
- Never <> other is a supertype of Never and other
… and since Never <> other = other that simplifies to:
- other is a supertype of Never and other
… which is true because:
- other is a supertype of Never (because Never is the universal subtype)
- other is a supertype of other (because every type is a supertype of itself)
We’d prove the next case in the exact same way (just swapping Never and other):
other <> Never = other
For the next case:
ArrayType left <> ArrayType right = ArrayType (left <> right)
The unification law becomes:
- ArrayType (left <> right) is a supertype of both ArrayType left and ArrayType right
… which is true because ArrayType is covariant and by induction left <> right is a supertype of both left and right.
We’d prove the first case for Optional in the exact same way (just replace ArrayType with Optional):
Optional left <> Optional right = Optional (left <> right)
The next case for Optional is more interesting:
Optional left <> right = Optional (left <> right)
Here the unification law would be:
- Optional (left <> right) is a supertype of Optional left and right
… which is true because:
- Optional (left <> right) is a supertype of Optional left
  This is true because Optional is covariant and left <> right is a supertype of left
- Optional (left <> right) is a supertype of right
  This is true because:
  - Optional (left <> right) is a supertype of Optional right
  - Optional right is a supertype of right
  - Therefore, by transitivity, Optional (left <> right) is a supertype of right
We’d prove the next case in the same way, just switching left and right:
left <> Optional right = Optional (left <> right)
The case for objects is the most interesting case:
ObjectType left <> ObjectType right =
    ObjectType (KeyMap.alignWith adapt left right)
  where
    adapt (This a) = Optional a
    adapt (That b) = Optional b
    adapt (These a b) = a <> b
I won’t prove this case as formally, but the basic idea is that this is true because a record type (A) is a supertype of another record type (B) if and only if:
- for each field k they share in common, A.k is a supertype of B.k
- for each field k present only in A, A.k is a supertype of Optional Never
- there are no fields present only in B
… and given that definition of record subtyping then the above implementation satisfies the unification law.
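The second unification law can also be spot-checked mechanically. The sketch below is not part of the post's program: isSubtypeOf, samples, and lawHolds are hypothetical names, the subtype relation is my reading of the rules described above, and Data.Map (with Data.Map.Merge.Strict in place of KeyMap.alignWith) stands in for KeyMap so that only the containers package is needed. It checks a handful of sample types, which is evidence rather than a proof.

```haskell
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Map.Merge.Strict as Merge

-- Stand-in for the post's Type, using Map String instead of KeyMap (assumption)
data Type
    = ObjectType (Map String Type)
    | ArrayType Type
    | StringType
    | NumberType
    | BoolType
    | Optional Type
    | Never
    | Any
    deriving (Eq, Show)

-- Same clauses as the post; the merge call plays the role of alignWith
instance Semigroup Type where
    StringType <> StringType = StringType
    NumberType <> NumberType = NumberType
    BoolType <> BoolType = BoolType
    Never <> other = other
    other <> Never = other
    ArrayType left <> ArrayType right = ArrayType (left <> right)
    Optional left <> Optional right = Optional (left <> right)
    Optional left <> right = Optional (left <> right)
    left <> Optional right = Optional (left <> right)
    ObjectType left <> ObjectType right =
        ObjectType
            (Merge.merge
                (Merge.mapMissing (\_ a -> Optional a))      -- field only on the left
                (Merge.mapMissing (\_ b -> Optional b))      -- field only on the right
                (Merge.zipWithMatched (\_ a b -> a <> b))    -- field on both sides
                left
                right)
    _ <> _ = Any

-- Hypothetical subtype check: sub `isSubtypeOf` super
isSubtypeOf :: Type -> Type -> Bool
isSubtypeOf Never _ = True
isSubtypeOf _ Any = True
isSubtypeOf (Optional sub) (Optional super) = sub `isSubtypeOf` super
isSubtypeOf sub (Optional super) = sub `isSubtypeOf` super
isSubtypeOf (ArrayType sub) (ArrayType super) = sub `isSubtypeOf` super
isSubtypeOf (ObjectType sub) (ObjectType super) =
    -- no field may be present only in the subtype …
    Map.null (Map.difference sub super) && all fieldOk (Map.toList super)
  where
    -- … and a field missing from the subtype is treated as Optional Never
    fieldOk (key, superField) =
        Map.findWithDefault (Optional Never) key sub `isSubtypeOf` superField
isSubtypeOf sub super = sub == super

samples :: [Type]
samples =
    [ Never, Any, StringType, NumberType, Optional BoolType
    , ArrayType NumberType, ArrayType Never
    , ObjectType (Map.fromList [("key₀", BoolType)])
    , ObjectType (Map.fromList [("key₁", NumberType)])
    ]

-- x <> y should be a supertype of both x and y, for every pair of samples
lawHolds :: Bool
lawHolds = and
    [ x `isSubtypeOf` (x <> y) && y `isSubtypeOf` (x <> y)
    | x <- samples, y <- samples
    ]
```

Evaluating lawHolds in GHCi should give True for these samples. The object clause of isSubtypeOf encodes all three record rules at once by looking up each field of the supertype in the subtype and defaulting to Optional Never.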
Monoid laws
The first two Monoid laws are trivial to prove:
mempty <> x = x
x <> mempty = x
… because we defined:
mempty = Never
… and if we replace mempty with Never in those laws:
Never <> x = x
x <> Never = x
… that is literally what our code defines (except replacing x with other):
Never <> other = other
other <> Never = other
The last law, associativity, is pretty tedious to prove in full:
(x <> y) <> z = x <> (y <> z)
… but I’ll do a few cases to show the basic gist of how the proof works.
First, the associativity law is easy to prove for the case where any of x, y, or z is Never. For example, if x = Never, then we get:
(Never <> y) <> z = Never <> (y <> z)
-- Never <> other = other
y <> z = y <> z
… which is true. The other two cases for y = Never and z = Never are equally simple to prove.
Associativity is also easy to prove when any of x, y, or z is Any. For example, if x = Any, then we get:
(Any <> y) <> z = Any <> (y <> z)
-- Any <> other = Any
Any <> z = Any
-- Any <> other = Any
Any = Any
… which is true. The other two cases for y = Any and z = Any are equally simple to prove.
Now we can prove associativity if any of x, y or z is StringType. The reason why is that these are the only relevant cases in the implementation of unification for StringType:
StringType <> StringType = StringType
StringType <> Never = StringType
Never <> StringType = StringType
StringType <> _ = Any
_ <> StringType = Any
… but we already proved associativity for all cases involving a Never, so we don’t need to consider the second and third cases, which simplifies things down to:
StringType <> StringType = StringType
StringType <> _ = Any
_ <> StringType = Any
That means that there are only seven cases we need to consider to prove the associativity law if at least one of x, y, and z is StringType (using _ below to denote “any type other than StringType”):
-- true: both sides evaluate to StringType
(StringType <> StringType) <> StringType = StringType <> (StringType <> StringType)
-- all other cases below are also true: they all evaluate to `Any`
(StringType <> StringType) <> _ = StringType <> (StringType <> _)
(StringType <> _) <> StringType = StringType <> (_ <> StringType)
(StringType <> _) <> _ = StringType <> (_ <> _)
(_ <> StringType) <> StringType = _ <> (StringType <> StringType)
(_ <> StringType) <> _ = _ <> (StringType <> _)
(_ <> _) <> StringType = _ <> (_ <> StringType)
We can similarly prove associativity for all cases involving at least one NumberType or BoolType.
The proof for ArrayType is almost the same as the proof for StringType/NumberType/BoolType. The only relevant cases are:
ArrayType left <> ArrayType right = ArrayType (left <> right)
ArrayType left <> Never = ArrayType left
Never <> ArrayType right = ArrayType right
ArrayType left <> _ = Any
_ <> ArrayType right = Any
Just like before, we can ignore the cases where either argument is Never because we already proved associativity for that.
That just leaves:
ArrayType left <> ArrayType right = ArrayType (left <> right)
ArrayType left <> _ = Any
_ <> ArrayType right = Any
Just like before, there are only seven cases we have to prove (using _ below to denote “any type other than ArrayType”):
ArrayType x <> (ArrayType y <> ArrayType z) = (ArrayType x <> ArrayType y) <> ArrayType z
-- … simplifies to:
ArrayType (x <> (y <> z)) = ArrayType ((x <> y) <> z)
-- … which is true because unification of the element types is associative
-- all other cases below are also true: they all evaluate to `Any`
(ArrayType x <> ArrayType y) <> _ = ArrayType x <> (ArrayType y <> _)
(ArrayType x <> _) <> ArrayType z = ArrayType x <> (_ <> ArrayType z)
(ArrayType x <> _) <> _ = ArrayType x <> (_ <> _)
(_ <> ArrayType y) <> ArrayType z = _ <> (ArrayType y <> ArrayType z)
(_ <> ArrayType y) <> _ = _ <> (ArrayType y <> _)
(_ <> _) <> ArrayType z = _ <> (_ <> ArrayType z)
The proofs for the Optional and ObjectType cases are longer and more laborious so I’ll omit them. They’re an exercise for the reader because I am LAZY.
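The cases sketched above can also be spot-checked by brute force. The following is a minimal sketch under stated assumptions: it uses a trimmed copy of Type containing only the constructors covered by the proof sketch (Never, Any, the scalars, and ArrayType), so Optional and ObjectType are deliberately left out, and it adds an Eq instance that the post's Type does not derive. The helpers typesUpTo, associative, identityLaws, and lawsHold are hypothetical names for this illustration.

```haskell
-- Trimmed copy of the post's Type: no Optional, no ObjectType (assumption)
data Type
    = ArrayType Type
    | StringType
    | NumberType
    | BoolType
    | Never
    | Any
    deriving (Eq, Show)

-- The post's clauses, minus the ones for the omitted constructors
instance Semigroup Type where
    StringType <> StringType = StringType
    NumberType <> NumberType = NumberType
    BoolType <> BoolType = BoolType
    Never <> other = other
    other <> Never = other
    ArrayType left <> ArrayType right = ArrayType (left <> right)
    _ <> _ = Any

instance Monoid Type where
    mempty = Never

-- Every type whose arrays are nested at most n levels deep
typesUpTo :: Int -> [Type]
typesUpTo 0 = [Never, Any, StringType, NumberType, BoolType]
typesUpTo n = typesUpTo 0 ++ map ArrayType (typesUpTo (n - 1))

associative :: Type -> Type -> Type -> Bool
associative x y z = (x <> y) <> z == x <> (y <> z)

identityLaws :: Type -> Bool
identityLaws x = mempty <> x == x && x <> mempty == x

-- Check the Monoid laws on every combination of the sample types
lawsHold :: Bool
lawsHold =
    all identityLaws sample
        && and [associative x y z | x <- sample, y <- sample, z <- sample]
  where
    sample = typesUpTo 2
```

Evaluating lawsHold in GHCi should give True. The check says nothing about the Optional and ObjectType cases, which this sketch does not cover.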
¹ I’ve inlined all the type synonyms and removed strictness annotations, for clarity.