2010-08-22

Polymorphic recursion with rank-2 polymorphism in OCaml 3.12

Using polymorphic recursion in OCaml just got easy and direct! With the new syntax for explicitly polymorphic annotations (`'a . t`), Okasaki's binary random-access lists translate to OCaml 3.12 practically verbatim, without the need for cumbersome encodings. Compare with the book (Purely Functional Data Structures, p. 147); here is the implementation in its entirety:

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let nil = Nil

let is_empty = function Nil -> true | _ -> false

let rec cons : 'a . 'a -> 'a seq -> 'a seq =
  fun x l -> match l with
  | Nil         -> One (x, Nil)
  | Zero    ps  -> One (x, ps)
  | One (y, ps) -> Zero (cons (x, y) ps)

let rec uncons : 'a . 'a seq -> 'a * 'a seq = function
  | Nil          -> failwith "uncons"
  | One (x, Nil) -> x, Nil
  | One (x, ps ) -> x, Zero ps
  | Zero    ps   ->
    let (x, y), qs = uncons ps in
    x, One (y, qs)

let head l = let x, _  = uncons l in x
and tail l = let _, xs = uncons l in xs

let rec lookup : 'a . int -> 'a seq -> 'a =
  fun n l -> match l with
  | Nil                    -> failwith "lookup"
  | One (x, _ ) when n = 0 -> x
  | One (_, ps)            -> lookup (n - 1) (Zero ps)
  | Zero ps                ->
    let (x, y) = lookup (n / 2) ps in
    if n mod 2 = 0 then x else y

let update n e =
  let rec go : 'a . ('a -> 'a) -> int -> 'a seq -> 'a seq =
    fun f n l -> match l with
    | Nil                    -> failwith "update"
    | One (x, ps) when n = 0 -> One (f x, ps)
    | One (x, ps)            -> cons x (go f (n - 1) (Zero ps))
    | Zero ps                ->
      let g (x, y) = if n mod 2 = 0 then (f x, y) else (x, f y) in
      Zero (go g (n / 2) ps)
  in go (fun _ -> e) n
```
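As a quick sanity check, the following self-contained sketch repeats `cons`, `lookup` and `update` from above and exercises them on a small list. The helper `upto` is hypothetical, introduced only for this example: it conses `n - 1` down to `0`, so the element at index `i` is `i`.

```ocaml
(* Definitions repeated from above so this sketch compiles on its own. *)
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let rec cons : 'a . 'a -> 'a seq -> 'a seq =
  fun x l -> match l with
  | Nil         -> One (x, Nil)
  | Zero    ps  -> One (x, ps)
  | One (y, ps) -> Zero (cons (x, y) ps)

let rec lookup : 'a . int -> 'a seq -> 'a =
  fun n l -> match l with
  | Nil                    -> failwith "lookup"
  | One (x, _ ) when n = 0 -> x
  | One (_, ps)            -> lookup (n - 1) (Zero ps)
  | Zero ps                ->
    let (x, y) = lookup (n / 2) ps in
    if n mod 2 = 0 then x else y

let update n e =
  let rec go : 'a . ('a -> 'a) -> int -> 'a seq -> 'a seq =
    fun f n l -> match l with
    | Nil                    -> failwith "update"
    | One (x, ps) when n = 0 -> One (f x, ps)
    | One (x, ps)            -> cons x (go f (n - 1) (Zero ps))
    | Zero ps                ->
      let g (x, y) = if n mod 2 = 0 then (f x, y) else (x, f y) in
      Zero (go g (n / 2) ps)
  in go (fun _ -> e) n

(* Hypothetical helper: [0; 1; ...; n-1], consed from the back so that
   index i holds the value i. *)
let upto n =
  let rec go i acc = if i < 0 then acc else go (i - 1) (cons i acc) in
  go (n - 1) Nil

let s  = upto 10
let s' = update 3 42 s   (* a new list; s itself is not modified *)
```

Because the structure is persistent, `update` returns a new list and the original `s` still holds `3` at index 3.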

The implementation given in the book is rather bare-bones, but it can be extended with some thought and by paying close attention to the techniques Okasaki uses. To begin with, a `length` function is a very simple O(log n) mapping from constructors to integers:

```ocaml
let rec length : 'a . 'a seq -> int =
  fun l -> match l with
  | Nil         -> 0
  | Zero ps     ->     2 * length ps
  | One (_, ps) -> 1 + 2 * length ps
```
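To see the mapping from constructors to integers concretely: the constructors spell the length in binary, least significant bit first, with `Zero` and `One` as the digits. A small sketch, where `replicate` is a hypothetical O(n) helper used only for cross-checking:

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let rec cons : 'a . 'a -> 'a seq -> 'a seq =
  fun x l -> match l with
  | Nil         -> One (x, Nil)
  | Zero    ps  -> One (x, ps)
  | One (y, ps) -> Zero (cons (x, y) ps)

let rec length : 'a . 'a seq -> int =
  fun l -> match l with
  | Nil         -> 0
  | Zero ps     ->     2 * length ps
  | One (_, ps) -> 1 + 2 * length ps

(* Hypothetical helper: n copies of x, built with n conses. *)
let rec replicate n x = if n = 0 then Nil else cons x (replicate (n - 1) x)

(* 6 is 110 in binary, i.e. 0, 1, 1 read least significant bit first,
   so the list has the shape Zero (One (_, One (_, Nil))). *)
let six = replicate 6 'a'
```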

It is also rather easy to write a `map` for binary random-access lists:

```ocaml
let rec map : 'a 'b . ('a -> 'b) -> 'a seq -> 'b seq =
  fun f l -> match l with
  | Nil          -> Nil
  | One (x,  ps) -> One (f x, map (fun (x, y) -> (f x, f y)) ps)
  | Zero ps      -> Zero     (map (fun (x, y) -> (f x, f y)) ps)
```
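A minimal check of `map` (definitions repeated so the block compiles alone; `upto` is again a hypothetical helper): erasing the elements with `map ignore` exposes the bare shape, which `map` leaves untouched, and the element type may change, as in `map string_of_int`.

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let rec cons : 'a . 'a -> 'a seq -> 'a seq =
  fun x l -> match l with
  | Nil         -> One (x, Nil)
  | Zero    ps  -> One (x, ps)
  | One (y, ps) -> Zero (cons (x, y) ps)

let rec lookup : 'a . int -> 'a seq -> 'a =
  fun n l -> match l with
  | Nil                    -> failwith "lookup"
  | One (x, _ ) when n = 0 -> x
  | One (_, ps)            -> lookup (n - 1) (Zero ps)
  | Zero ps                ->
    let (x, y) = lookup (n / 2) ps in
    if n mod 2 = 0 then x else y

let rec map : 'a 'b . ('a -> 'b) -> 'a seq -> 'b seq =
  fun f l -> match l with
  | Nil          -> Nil
  | One (x,  ps) -> One (f x, map (fun (x, y) -> (f x, f y)) ps)
  | Zero ps      -> Zero     (map (fun (x, y) -> (f x, f y)) ps)

(* Hypothetical helper: [0; 1; ...; n-1]. *)
let upto n =
  let rec go i acc = if i < 0 then acc else go (i - 1) (cons i acc) in
  go (n - 1) Nil

let s       = upto 7
let squares = map (fun i -> i * i) s   (* 'a = 'b = int *)
let labels  = map string_of_int s      (* 'a = int, 'b = string *)

(* Same constructors, same positions: map never reshapes the list. *)
let same_shape = (map ignore s = map ignore squares)
```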

Note two things. First, both type variables must be quantified, since both the argument and the result types change from one recursive call to the next, as the `Zero` case shows. Second, there is no need to use `cons`, since `map` preserves the shape of the list. With this as a warm-up, writing `fold_right` is analogous:

```ocaml
let rec fold_right : 'a 'b . ('a -> 'b -> 'b) -> 'a seq -> 'b -> 'b =
  fun f l e -> match l with
  | Nil          -> e
  | One (x, ps)  -> f x (fold_right (fun (x, y) z -> f x (f y z)) ps e)
  | Zero ps      ->      fold_right (fun (x, y) z -> f x (f y z)) ps e
```
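For instance (a sketch that repeats the definitions it needs), folding string concatenation over a list built with `cons` recovers the elements front to back, and a non-associative operator such as subtraction exposes the right-nested bracketing:

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let rec cons : 'a . 'a -> 'a seq -> 'a seq =
  fun x l -> match l with
  | Nil         -> One (x, Nil)
  | Zero    ps  -> One (x, ps)
  | One (y, ps) -> Zero (cons (x, y) ps)

let rec fold_right : 'a 'b . ('a -> 'b -> 'b) -> 'a seq -> 'b -> 'b =
  fun f l e -> match l with
  | Nil          -> e
  | One (x, ps)  -> f x (fold_right (fun (x, y) z -> f x (f y z)) ps e)
  | Zero ps      ->      fold_right (fun (x, y) z -> f x (f y z)) ps e

let abcde = cons "a" (cons "b" (cons "c" (cons "d" (cons "e" Nil))))

(* "a" ^ ("b" ^ ("c" ^ ("d" ^ ("e" ^ "")))) *)
let joined = fold_right (fun x acc -> x ^ acc) abcde ""

(* 1 - (2 - (3 - 0)) *)
let alternating = fold_right (fun x acc -> x - acc) (cons 1 (cons 2 (cons 3 Nil))) 0
```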

Given a right fold, any catamorphism is a one-liner:

```ocaml
let append l = fold_right cons l

let of_list l = List.fold_right cons l nil
and to_list l = fold_right (fun x l -> x :: l) l []
```
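Other catamorphisms follow the same pattern. The helpers `sum`, `exists` and `filter` below are hypothetical additions, not part of the library above; note that `exists` written this way visits every element, since the fold evaluates the recursive result before applying the function.

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let rec cons : 'a . 'a -> 'a seq -> 'a seq =
  fun x l -> match l with
  | Nil         -> One (x, Nil)
  | Zero    ps  -> One (x, ps)
  | One (y, ps) -> Zero (cons (x, y) ps)

let rec fold_right : 'a 'b . ('a -> 'b -> 'b) -> 'a seq -> 'b -> 'b =
  fun f l e -> match l with
  | Nil          -> e
  | One (x, ps)  -> f x (fold_right (fun (x, y) z -> f x (f y z)) ps e)
  | Zero ps      ->      fold_right (fun (x, y) z -> f x (f y z)) ps e

(* Repeated from above. *)
let append l = fold_right cons l
let of_list l = List.fold_right cons l Nil
let to_list l = fold_right (fun x l -> x :: l) l []

(* Hypothetical helpers, each a single right fold. *)
let sum l = fold_right ( + ) l 0
let exists p l = fold_right (fun x found -> p x || found) l false
let filter p l = fold_right (fun x acc -> if p x then cons x acc else acc) l Nil
```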

Now, armed with `fold_right`, fleshing out a list library is easy; taking advantage of the logarithmic nature of the representation, however, requires thought. For instance, a random-access list of n copies of the same element can be built in logarithmic time with maximal sharing:

```ocaml
let repeat n x =
  let rec go : 'a . int -> 'a -> 'a seq = fun n x ->
    if n = 0 then Nil else
    if n = 1 then One (x, Nil) else
    let t = go (n / 2) (x, x) in
    if n mod 2 = 0 then Zero t else One (x, t)
  in
  if n < 0 then failwith "repeat" else go n x
```
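The sharing is easy to observe in a sketch (repeating `repeat` and `length`): in `repeat 4 x` the pair `(x, x)` is allocated once, and both components of the next level's pair point to it, so the two are physically equal (`==`). Each recursive call of `go` allocates only one tuple and one constructor, so only O(log n) cells are built.

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let repeat n x =
  let rec go : 'a . int -> 'a -> 'a seq = fun n x ->
    if n = 0 then Nil else
    if n = 1 then One (x, Nil) else
    let t = go (n / 2) (x, x) in
    if n mod 2 = 0 then Zero t else One (x, t)
  in
  if n < 0 then failwith "repeat" else go n x

let rec length : 'a . 'a seq -> int =
  fun l -> match l with
  | Nil         -> 0
  | Zero ps     ->     2 * length ps
  | One (_, ps) -> 1 + 2 * length ps

(* repeat 4 x = Zero (Zero (One ((p, q), Nil))) where p and q are the
   same (x, x) block, not two equal copies. *)
let shared =
  match repeat 4 "x" with
  | Zero (Zero (One ((p, q), Nil))) -> p == q
  | _ -> false
```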

By analogy with binary addition, there is also a fast O(log n) merge:

```ocaml
let rec merge : 'a . 'a seq -> 'a seq -> 'a seq =
  fun l r -> match l, r with
  | Nil        ,         ps
  |         ps , Nil         -> ps
  | Zero    ps , Zero    qs  -> Zero (merge ps qs)
  | Zero    ps , One (x, qs)
  | One (x, ps), Zero    qs  -> One (x, merge ps qs)
  | One (x, ps), One (y, qs) -> Zero (cons (x, y) (merge ps qs))
```
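A worked example of the binary-addition analogy (a sketch, repeating the definitions it needs): three plus three, least significant bit first, is `11 + 11 = 011`. Both low bits are `One`, so the pair `(1, 4)` is carried into the next position with `cons`, and the same happens one level up with `((2, 3), (5, 6))`.

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let rec cons : 'a . 'a -> 'a seq -> 'a seq =
  fun x l -> match l with
  | Nil         -> One (x, Nil)
  | Zero    ps  -> One (x, ps)
  | One (y, ps) -> Zero (cons (x, y) ps)

let rec fold_right : 'a 'b . ('a -> 'b -> 'b) -> 'a seq -> 'b -> 'b =
  fun f l e -> match l with
  | Nil          -> e
  | One (x, ps)  -> f x (fold_right (fun (x, y) z -> f x (f y z)) ps e)
  | Zero ps      ->      fold_right (fun (x, y) z -> f x (f y z)) ps e

let of_list l = List.fold_right cons l Nil
let to_list l = fold_right (fun x l -> x :: l) l []

let rec merge : 'a . 'a seq -> 'a seq -> 'a seq =
  fun l r -> match l, r with
  | Nil        ,         ps
  |         ps , Nil         -> ps
  | Zero    ps , Zero    qs  -> Zero (merge ps qs)
  | Zero    ps , One (x, qs)
  | One (x, ps), Zero    qs  -> One (x, merge ps qs)
  | One (x, ps), One (y, qs) -> Zero (cons (x, y) (merge ps qs))

(* of_list [1; 2; 3] = One (1, One ((2, 3), Nil)), and likewise for
   [4; 5; 6]; their sum has the shape Zero (One (_, One (_, Nil))). *)
let six = merge (of_list [1; 2; 3]) (of_list [4; 5; 6])
```

The elements all survive, but in the order `[1; 4; 2; 3; 5; 6]` rather than the `[1; 2; 3; 4; 5; 6]` that `append` would give.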

It walks both lists, "adding" the "bits" at the head. The only complication is the case where both lists are headed by a `One`, which requires rippling the carry with a call to `cons`. An alternative is to keep explicit track of the carry, but that doubles the number of branches. This `merge` does not preserve the order of the elements of the two lists (it is not `append`), but it can be used as the basis for a very fast nondeterminism monad:

```ocaml
let return x = One (x, Nil)
let join mm  = fold_right merge mm Nil
let bind f m = join (map f m)

let mzero = Nil
let mplus = merge
```
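As a usage sketch (definitions repeated; `guard` is a hypothetical helper in the usual style), here is a small search for the pairs drawn from 1 to 5 that sum to 6. Note that `bind` takes the function first. Since `merge` does not preserve order, the result is best compared after sorting.

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let rec cons : 'a . 'a -> 'a seq -> 'a seq =
  fun x l -> match l with
  | Nil         -> One (x, Nil)
  | Zero    ps  -> One (x, ps)
  | One (y, ps) -> Zero (cons (x, y) ps)

let rec map : 'a 'b . ('a -> 'b) -> 'a seq -> 'b seq =
  fun f l -> match l with
  | Nil          -> Nil
  | One (x,  ps) -> One (f x, map (fun (x, y) -> (f x, f y)) ps)
  | Zero ps      -> Zero     (map (fun (x, y) -> (f x, f y)) ps)

let rec fold_right : 'a 'b . ('a -> 'b -> 'b) -> 'a seq -> 'b -> 'b =
  fun f l e -> match l with
  | Nil          -> e
  | One (x, ps)  -> f x (fold_right (fun (x, y) z -> f x (f y z)) ps e)
  | Zero ps      ->      fold_right (fun (x, y) z -> f x (f y z)) ps e

let rec merge : 'a . 'a seq -> 'a seq -> 'a seq =
  fun l r -> match l, r with
  | Nil        ,         ps
  |         ps , Nil         -> ps
  | Zero    ps , Zero    qs  -> Zero (merge ps qs)
  | Zero    ps , One (x, qs)
  | One (x, ps), Zero    qs  -> One (x, merge ps qs)
  | One (x, ps), One (y, qs) -> Zero (cons (x, y) (merge ps qs))

let of_list l = List.fold_right cons l Nil
let to_list l = fold_right (fun x l -> x :: l) l []

let return x = One (x, Nil)
let join mm  = fold_right merge mm Nil
let bind f m = join (map f m)
let mzero = Nil
let mplus = merge

(* Hypothetical helper: keep the current branch only if b holds. *)
let guard b = if b then return () else mzero

let pairs =
  let xs = of_list [1; 2; 3; 4; 5] in
  bind (fun x ->
    bind (fun y ->
      bind (fun () -> return (x, y)) (guard (x + y = 6)))
      xs)
    xs
```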

Unlike rank-2 types in Haskell, this extension to the typing algorithm does not extend to signatures. It only affects the typing of the function body, by keeping the quantified type variable fresh during unification so that each recursive call can instantiate it anew. As the signature of the module shows, the resulting types are ordinary:

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq
val nil : 'a seq
val is_empty : 'a seq -> bool
val cons : 'a -> 'a seq -> 'a seq
val uncons : 'a seq -> 'a * 'a seq
val head : 'a seq -> 'a
val tail : 'a seq -> 'a seq
val lookup : int -> 'a seq -> 'a
val update : int -> 'a -> 'a seq -> 'a seq
val length : 'a seq -> int
val map : ('a -> 'b) -> 'a seq -> 'b seq
val fold_right : ('a -> 'b -> 'b) -> 'a seq -> 'b -> 'b
val append : 'a seq -> 'a seq -> 'a seq
val of_list : 'a list -> 'a seq
val to_list : 'a seq -> 'a list
val repeat : int -> 'a -> 'a seq
val merge : 'a seq -> 'a seq -> 'a seq
val return : 'a -> 'a seq
val join : 'a seq seq -> 'a seq
val bind : ('a -> 'b seq) -> 'a seq -> 'b seq
val mzero : 'a seq
val mplus : 'a seq -> 'a seq -> 'a seq
```
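To make the point concrete, here is a sketch with a hypothetical module type `RANDOM_ACCESS_LIST` and module `Ral`, covering only a few operations: the signature uses plain rank-1 types, and the `'a .` annotations appear only inside the structure, where the recursive definitions need them.

```ocaml
module type RANDOM_ACCESS_LIST = sig
  type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq
  val nil : 'a seq
  val cons : 'a -> 'a seq -> 'a seq     (* no 'a . needed here *)
  val length : 'a seq -> int
end

module Ral : RANDOM_ACCESS_LIST = struct
  type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

  let nil = Nil

  (* The annotation is required here, to type the polymorphic recursion. *)
  let rec cons : 'a . 'a -> 'a seq -> 'a seq =
    fun x l -> match l with
    | Nil         -> One (x, Nil)
    | Zero    ps  -> One (x, ps)
    | One (y, ps) -> Zero (cons (x, y) ps)

  let rec length : 'a . 'a seq -> int =
    fun l -> match l with
    | Nil         -> 0
    | Zero ps     ->     2 * length ps
    | One (_, ps) -> 1 + 2 * length ps
end

let three = Ral.length (Ral.cons 1 (Ral.cons 2 (Ral.cons 3 Ral.nil)))
```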

To use rank-2 types in interfaces, it is still necessary to encode them via records with polymorphic fields or objects with polymorphic methods.
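For example (a sketch; the type `seq_fun` and the functions `total` and `total_obj` are hypothetical), a function that receives a length-like function and applies it at two different element types can take it wrapped in a record with a polymorphic field, or as an object with a polymorphic method:

```ocaml
type 'a seq = Nil | Zero of ('a * 'a) seq | One of 'a * ('a * 'a) seq

let rec length : 'a . 'a seq -> int =
  fun l -> match l with
  | Nil         -> 0
  | Zero ps     ->     2 * length ps
  | One (_, ps) -> 1 + 2 * length ps

(* Record encoding: the field [apply] is polymorphic in 'a. *)
type seq_fun = { apply : 'a. 'a seq -> int }

(* [f.apply] is used at int seq and at string seq in the same body. *)
let total (f : seq_fun) (xs : int seq) (ys : string seq) =
  f.apply xs + f.apply ys

(* Object encoding: the method [apply] is polymorphic in 'a. *)
let total_obj (o : < apply : 'a. 'a seq -> int >) (xs : int seq) (ys : string seq) =
  o#apply xs + o#apply ys

let ints    = One (1, One ((2, 3), Nil))     (* 3 elements *)
let strings = Zero (One (("a", "b"), Nil))   (* 2 elements *)

let five  = total { apply = length } ints strings
let five' = total_obj (object method apply : 'a. 'a seq -> int = length end) ints strings
```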
