Wednesday, February 10, 2010

Continuation Monad

Last month I worked on a TopCoder problem that led me to continue my exploration of monads. The problem is to sort a set of elements using comparisons whose cost is a function of the comparands. Solutions comprise a pair of methods, one of which initializes the program with the costs, and the other of which is invoked repeatedly to effect the dialog between the program and the less-than oracle.

Now choose your own adventure. Do you:

  1. Ask how a real computer scientist would solve it
  2. Watch me muddle through comparison-based sorting with variable costs
  3. Follow my continuation monad implementation in Java

Comparison-based sorting with variable costs

Sidebar: true story

I've seen a related problem in real life. Our app iDeal helped salespersons develop contract terms subject to Boolean constraints called guidelines, analogous to the Boolean trees examined in Charikar et al. We didn't actually have to buy our guideline results, but we did care about evaluation speed. There were basically two prices: low (computable on our local app server) and high (computable at a minicomputer reached by at least two SOAP hops and a batch job). Unlike the Charikar model, our cost function was not additive: it behaved like a logical OR, giving a result of "expensive" whenever at least one guideline couldn't be computed locally and "cheap" otherwise. Consequently, the real-life challenge was less about algorithm design and more about software engineering: given that guideline evaluation is demanded at many call sites, how can we avoid unnecessarily consulting the minicomputer?

I first thought of divide and conquer approaches to avoid redundant queries (don't ask a < c if it is already known that a < b and b < c). I thought about algorithms for the special case when C(a, b) = 1: mergesort, heapsort, and quicksort. All of these relate to binary trees, and I thought about properties a binary search tree constructed at optimal cost would have in the general case.

Shouldn't the root of the tree be the element for which the sum of the costs of comparing it with each of its descendants (that is, with every other element) is minimal? Intuitively, we compare every element with the root to locate that element in the tree, and we want to minimize the cost of inserting each element by this recursive process. Not quite: this fails to account for tree shape. Consider p < q < r with costs C(p, X) = 1 for all X in {p, q, r} and C(q, r) = 2: p minimizes that sum, but rooting at p costs C(p, q) + C(p, r) + C(q, r) = 4 in total, while rooting at the "costlier" q costs only C(q, p) + C(q, r) = 3. And thinking back to the special case of a constant cost function, it's clear that this root-selection heuristic has no power there. Discussions of quicksort and friends often stress the importance of a balanced tree, but we're not looking just to balance and fill the tree: by choosing the cost function carefully, it's easy to arrange that every left child in an optimal tree is a leaf node.

The shape of the tree is determined both by the sequence of queries the program makes and by the order of elements determined by the oracle. The best choice for the root is the median of the elements, so an exact solution is O(n^2) but we could think about approximate selection and average-case performance.

What about a bottom-up approach? Do our trees have optimal substructure? First: in an optimal tree, is every subtree optimal? Suppose not. Then some optimal tree has a subtree that can be rearranged for lower cost. But this doesn't affect the cost to construct the rest of the tree, so after the rearrangement the overall cost is lowered. This contradicts the supposition that our original tree was optimal with a suboptimal subtree. But second: can we efficiently combine optimal subsolutions to derive optimal solutions?

A straightforward mergesort would guarantee a minimal number of comparisons, but not necessarily a minimal-cost set of comparisons. Although a minimal tree implies minimal subtrees, not all combinations of minimal subtrees constitute a (rearrangement of a) minimal tree. To see this, imagine a cost function that does not satisfy the triangle inequality, say C(a, b) = C(b, c) = 1 but C(a, c) = 10: the cheapest way to combine two trees might then involve additional vertices (here b) that are not part of either original tree.

Let g(t) = sum(C(n, n') for n in t for n' not in t), the total cost of comparing t's elements with everything outside t. Repeatedly merge the two runs with the lowest g.

Claim: g-ordered mergesort yields optimal solutions. Proof sketch:

Let A and B represent two C-optimal subsolutions whose g-costs are least among all subsolutions. Claim: the combination S of A and B is C-optimal. Suppose not. Then there is a (sub)solution S' containing all the elements of A and B such that C(S) > C(S'). S' is a proper superset of S, so we can obtain it by adding elements from the complement of S. These missing elements must exist in other optimal subsolutions having higher g-costs than those of A and B. Also, each of these elements must have edges to elements of both A and B in the optimal tree of S'. Consider any subsolution G hosting one of these missing elements. Its g-cost is \sum_{g \in G}(\sum_{a \in A} C(g, a) + \sum_{b \in B} C(g, b) + \sum_{c'} C(g, c')), where a, b, and c' are drawn from our two input sets and the complement of their union, respectively. To simplify the notation, I'm going to switch now to implicit summations: the g-cost of subsolution G is c(g, a) + c(g, b) + c(g, c). The g-cost of S is c(a, c) + c(b, c) + c(s, c), where s stands for the missing elements. The g-cost of A is c(a, b) + c(a, c). The g-cost of B is c(a, b) + c(b, c). With a little algebra and the observation that the complement of the union of A and B includes all the g \in G, we can derive a contradiction from the assumption that there exist elements outside the min-g-cost subsolutions A and B that would improve the C-cost of the combination of A and B:


c(s, a) + c(s, b) + c(s, c) > max(c(a, c) + c(a, b); c(b, c) + c(a, b))
                            = max(c(S) - c(s, c) - c(b, c) + c(a, b);
                                  c(S) - c(s, c) - c(a, c) + c(a, b))      [since c(S) = c(a, c) + c(b, c) + c(s, c)]
c(s, a) + c(s, b) + 2c(s, c) - c(a, b) - c(S) > max(-c(b, c); -c(a, c))
c(s, a) + c(s, b) + c(s, c) - c(a, b) - c(a, c) - c(b, c) > max(-c(b, c); -c(a, c))
c(s, a) + c(s, b) + c(s, c) - c(a, b) - c(a, c) > 0
c(s, a) + c(s, b) + c(s, c) - c(a, b) - c(b, c) > 0
c(s, a) + c(s, b) + c(s, c) > c(a, b) + c(a, c)
c(s, a) + c(s, b) + c(s, c) > c(a, b) + c(b, c)
c(a, c) includes c(s, a) and c(b, c) includes c(s, b), so c(a, c) ≥ c(s, a) and c(b, c) ≥ c(s, b)
c(s, b) + c(s, c) > c(a, b)
c(s, a) + c(s, c) > c(a, b)
c(s, b) - c(s, a) > 0
c(s, a) - c(s, b) > 0
⟶⟵

Sadly, although this heuristic won't mislead us, it may often fail to give an answer: more than two trees can tie for the least g-cost.

A more dynamic heuristic could rely on probabilistic estimates of the information gain at each query. Assuming a uniformly random permutation, initially we should just do the cheapest comparison. As we progress, we'll develop equivalence classes: these elements could be anywhere; those elements could only be in the top 6. Comparisons involving elements in large classes, or in classes near the min or max, yield relatively little information. We could combine estimates of information value with the given cost information to get reasonable average-case performance.

I decided not to work harder to prove a tight lower bound on the solution to the problem or to find a better algorithm or heuristic. It's TopCoder after all, and speed is important. I decided instead to move on to implementation and Google for a better answer later.

So in a nutshell my solution is: initially regard every element as a subsolution. Order them by g-cost and merge the two lowest-ranked solutions. Recurse.
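
Here's roughly what that driver looks like in Java. This is a sketch rather than my contest code: the cost matrix, the Comparator standing in for the less-than oracle, and the merge helper are all stand-ins (in the real problem the comparison answers arrive through callbacks, which is what the rest of this post is about).

// Sketch, not the contest code. Elements are the indices 0..n-1, cost[i][j] is C(i, j),
// and ord stands in for the less-than oracle.
// Needs java.util.{ArrayList, Collections, Comparator, List}.
static List<Integer> gOrderedMergesort(final double[][] cost, Comparator<Integer> ord) {
    List<List<Integer>> runs = new ArrayList<List<Integer>>();
    for (int i = 0; i < cost.length; i++) {
        // Initially, every element is its own one-element run.
        List<Integer> run = new ArrayList<Integer>();
        run.add(i);
        runs.add(run);
    }
    while (runs.size() > 1) {
        // Order the runs by g-cost and merge the two cheapest.
        Collections.sort(runs, new Comparator<List<Integer>>() {
            public int compare(List<Integer> s, List<Integer> t) {
                return Double.compare(g(s, cost), g(t, cost));
            }
        });
        runs.add(merge(runs.remove(0), runs.remove(0), ord));
    }
    return runs.get(0);
}

// g(t) = sum of C(i, j) over i in t and j not in t.
static double g(List<Integer> t, double[][] cost) {
    double sum = 0;
    for (int i : t) {
        for (int j = 0; j < cost.length; j++) {
            if (!t.contains(j)) {
                sum += cost[i][j];
            }
        }
    }
    return sum;
}

// An ordinary two-way merge; every ord.compare call is a comparison we pay for.
static List<Integer> merge(List<Integer> a, List<Integer> b, Comparator<Integer> ord) {
    List<Integer> out = new ArrayList<Integer>();
    int i = 0, j = 0;
    while (i < a.size() && j < b.size()) {
        if (ord.compare(a.get(i), b.get(j)) <= 0) {
            out.add(a.get(i++));
        } else {
            out.add(b.get(j++));
        }
    }
    while (i < a.size()) out.add(a.get(i++));
    while (j < b.size()) out.add(b.get(j++));
    return out;
}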

Continuation monad

The online interaction between the TopCoder oracle and my solution raises an interesting question. The first thing that occurred to me, of course, was to write g and merge, arrange to store state in instance variables, and update that state incrementally in the two callback functions whose existence is mandated by the oracle. On second thought, I wondered if my code could more closely resemble the nutshell description I gave above, with the state maintenance handled separately. It seemed like a good excuse to get some first-hand experience with continuation-passing style (CPS), a canonical form for programs that can be paused and restarted. The continuation monad brings CPS into a more general framework, improving modularity.
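
For concreteness, the interaction has roughly the following shape. These names and signatures are invented for illustration; the actual TopCoder interface differs.

// Hypothetical sketch of the oracle-mandated callback pair; not the real TC signatures.
public interface SorterCallbacks {
    // Called once with the comparison costs before any comparisons are made.
    void init(int[][] cost);

    // Called repeatedly: receives the oracle's answer to the previous query
    // (ignored on the first call) and returns the next pair {i, j} to compare,
    // or the final sorted order once no more comparisons are needed.
    int[] next(boolean previousWasLess);
}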

So I broke out my old favorites Monads for Functional Programming and Comprehending Monads. I also found a couple of new sources: Lecture notes: Continuations and call/cc and The Continuation Monad in Clojure. Continuations are also discussed at some length in the lambda papers, although the idea originated a decade earlier.

First, the monad preliminaries. When I implemented the maybe monad in C#, I had the ambition to support composition of different monads, and that led to some difficulty with the type system. I also noticed how cumbersome the syntax was and what effect that had on my enthusiasm for an explicit, separate expression of an abstraction for monads. And that was with C# 3, which has overtaken Java in its evolution towards greater expressive power. So this time I decided to focus strictly on the continuation monad and the question of whether it could enable me to separate the oracle I/O from the specification of my sort algorithm.

Recall the monad laws:


@Test
// unit a ⋆ λb. n = n[a/b]
public void leftUnit() {
    Func1<String, CMonad<Integer>> f =
        new Func1<String, CMonad<Integer>>() {
            public CMonad<Integer> f(String a) {
                return CMonad.unit(a.length());
            }
        };

    String a = "a";
    CMonad<Integer> expected =
        f.f(a);
    CMonad<Integer> actual =
        CMonad.map(
            f,
            CMonad.unit(a));

    assertMEquals(expected, actual);
}

@Test
// m ⋆ λa. unit a = m
public void rightUnit() {
    CMonad<Integer> m =
        CMonad.unit(42);
    Func1<Integer, CMonad<Integer>> id =
        new Func1<Integer, CMonad<Integer>>() {
            public CMonad<Integer> f(Integer x) {
                return CMonad.unit(x);
            }
        };

    assertMEquals(
        m,
        CMonad.map(id, m));
}

@Test
// m ⋆ (λa. n ⋆ λb. o) = (m ⋆ λa. n) ⋆ λb. o
public void associative() {
    final CMonad<String> m =
        CMonad.unit("m");

    final CMonad<Integer> n =
        CMonad.unit(5);

    final Func1<Integer, CMonad<Integer>> o =
        new Func1<Integer, CMonad<Integer>>() {
            public CMonad<Integer> f(Integer x) {
                return CMonad.unit(x - 1);
            }
        };

    CMonad<Integer> expected =
        CMonad.map(
            new Func1<String, CMonad<Integer>>() {
                public CMonad<Integer> f(String x) {
                    return
                        CMonad.map(
                            o,
                            n);
                }
            },
            m);

    CMonad<Integer> actual =
        CMonad.map(
            o,
            CMonad.map(
                new Func1<String, CMonad<Integer>>() {
                    public CMonad<Integer> f(String x) {
                        return n;
                    }
                },
                m));

    assertMEquals(expected, actual);
}

assertMEquals is more or less what you'd expect:


private <T> void assertMEquals(CMonad<T> expected, CMonad<T> actual) {
    // Run both computations with the identity continuation and compare the results.
    Func1<T, T> id =
        new Func1<T, T>() {
            public T f(T x) {
                return x;
            }
        };
    T exp = expected.f(id);
    T act = actual.f(id);
    assertEquals(exp, act);
}

My continuation monad satisfies those:


public abstract static class CMonad<X> {
    private CMonad() {}

    // A computation that, given a continuation for X, produces the final result.
    public abstract <R> R f(Func1<X, R> k);

    // aka return: wrap a plain value so it simply feeds its continuation.
    public static <X> CMonad<X> unit(final X x) {
        return
            new CMonad<X>() {
                public <R> R f(Func1<X, R> k) {
                    return k.f(x);
                }
            };
    }

    // aka bind aka ★: run mx, feed its value to f, and continue with the result.
    public static <X, R> CMonad<R> map(final Func1<X, CMonad<R>> f, final CMonad<X> mx) {
        return new CMonad<R>() {
            public <S> S f(Func1<R, S> k) {
                return mx.f(f).f(k);
            }
        };
    }
}
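
To see the mechanics in one place, here's a small usage sketch (not from the contest code): build a computation with unit and map, then run it by supplying the final continuation.

// A tiny pipeline: wrap 3, double it, then render the result as a String.
CMonad<Integer> three = CMonad.unit(3);
CMonad<Integer> doubled =
    CMonad.map(
        new Func1<Integer, CMonad<Integer>>() {
            public CMonad<Integer> f(Integer x) {
                return CMonad.unit(x * 2);
            }
        },
        three);
// Nothing runs until we hand over the final continuation.
String rendered =
    doubled.f(
        new Func1<Integer, String>() {
            public String f(Integer x) {
                return "result: " + x;   // "result: 6"
            }
        });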

call/cc seemed like a useful abstraction, so I implemented that in terms of my monad:


public Integer divideByZeroEg(final int x, final int y) {
    CMonad<Integer> eg =
        CMonad.callcc(
            new Func1<Func1<Integer, CMonad<Integer>>, CMonad<Integer>>() {
                public CMonad<Integer> f(Func1<Integer, CMonad<Integer>> esc) {
                    return
                        CMonad.map(
                            new Func1<Integer, CMonad<Integer>>() {
                                public CMonad<Integer> f(Integer z) {
                                    return CMonad.unit(x / z);
                                }
                            },
                            (y == 0 ? esc.f(42) : CMonad.unit(y)));
                }
            });
    return eg.f(
        new Func1<Integer, Integer>() {
            public Integer f(Integer x) {
                return x;
            }
        });
}

@Test
public void divideByZeroEg() {
    assertEquals((Object)5, divideByZeroEg(20, 4));
    assertEquals((Object)42, divideByZeroEg(20, 0));
}

public static <X, Y> CMonad<X> callcc(final Func1<Func1<X, CMonad<Y>>, CMonad<X>> g) {
    // \k -> g (\x -> \k' -> k x) k
    // g :: ((X -> MY) -> MX)
    // g :: ((X -> (Y -> R)) -> (X -> R'))
    // k :: X -> R''
    // R'' = R
    // R = X -> R'
    return
        map(
            new Func1<X, CMonad<X>>() {
                public CMonad<X> f(X x) {
                    return unit(x);
                }
            },
            new CMonad<X>() {
                public <R> R f(final Func1<X, R> k) {
                    CMonad<X> gresult =
                        g.f(
                            new Func1<X, CMonad<Y>>() {
                                public CMonad<Y> f(final X x) {
                                    // The escape continuation: ignore k' and jump straight back to k.
                                    return new CMonad<Y>() {
                                        public <S> S f(Func1<Y, S> kprime) {
                                            return (S)k.f(x);
                                        }
                                    };
                                }
                            });
                    return gresult.f(k);
                }
            });
}

and then I implemented a while loop atop continuation monad and call/cc:


@Test
public void finiteWhile() {
    final int[] countdown = new int[] {3};
    CMonad.cpswhile(
        new Func0<Boolean>() {
            public Boolean f() {
                return countdown[0] > 0;
            }
        },
        new SideEffectful() {
            public void f() {
                countdown[0] -= 1;
            }
        });
    assertEquals(0, countdown[0]);
}

private static <T> CMonad<T> recurse(CMonad<T> m) {
    // Tie the knot: the mapped function returns the very monad being defined,
    // so running the result repeats m indefinitely.
    final WhateverClosure<CMonad<T>> loop =
        new WhateverClosure<CMonad<T>>();
    loop.whatever =
        CMonad.map(
            new Func1<T, CMonad<T>>() {
                public CMonad<T> f(T x) {
                    System.out.println("recurse");
                    return loop.whatever;
                }
            },
            m);
    return loop.whatever;
}

public static void cpswhile(final Func0<Boolean> loopInvariant, final SideEffectful body) {
    callcc(
        new Func1<Func1<Object, CMonad<Object>>, CMonad<Object>>() {
            public CMonad<Object> f(final Func1<Object, CMonad<Object>> esc) {
                return recurse(
                    new CMonad<Object>() {
                        public <R> R f(Func1<Object, R> f) {
                            System.out.println("start.f");
                            if (!loopInvariant.f()) {
                                System.out.println("done");
                                return
                                    esc.f(null).f(f);
                            }
                            System.out.println("body.f");
                            body.f();
                            return f.f(null);
                        }
                    });
            }
        }).f(new Func1<Object, Object>() {
            public Object f(Object x) {
                return x;
            }
        });
}

public interface SideEffectful {
    public void f();
}

Why reimplement while? Java's while keyword composes only in the predefined ways that Java's language designers cared about. I wouldn't be able to effect oracle I/O between loop iterations, except by resorting to contrived method invocations in the loop body or test. Also, because the communication mechanism is method invocation, I couldn't actually use while without also using threads. By using composable continuation monads for control flow, I hoped to express my sort algorithm separately from managing the data flow between the program and the TC oracle.

Of course, Lisp and Haskell don't always translate easily to Java:


@Test
// (recipe for stack overflow)
public void infiniteWhile() {
    // customize this value for your stack size/patience
    int approximatelyInfinity = 1;

    final Calendar finish = Calendar.getInstance();
    finish.add(Calendar.SECOND, approximatelyInfinity);
    //CMonad.cpsbounce(
    CMonad.cpsbouncewhile(
        new Func0<Boolean>() {
            public Boolean f() {
                return finish.after(Calendar.getInstance());
            }
        },
        new SideEffectful() {
            public void f() {
            }
        });
}

public static void cpsbouncewhile(final Func0<Boolean> loopInvariant, final SideEffectful body) {
    CMonad<Object> mo =
        callcc(
            new Func1<Func1<Object, CMonad<Object>>, CMonad<Object>>() {
                public CMonad<Object> f(final Func1<Object, CMonad<Object>> esc) {
                    final CMonad<Object> start =
                        new CMonad<Object>() {
                            public <R> R f(Func1<Object, R> f) {
                                System.out.println("start.f");
                                if (!loopInvariant.f()) {
                                    System.out.println("done");
                                    return
                                        esc.f(null).f(f);
                                }
                                System.out.println("body.f");
                                body.f();
                                return f.f(null);
                            }
                        };

                    final WhateverClosure<CMonad<Object>> restartHolder = new WhateverClosure<CMonad<Object>>();
                    restartHolder.whatever =
                        map(
                            new Func1<Object, CMonad<Object>>() {
                                public CMonad<Object> f(Object x) {
                                    System.out.println("restart.f");
                                    return
                                        esc.f(
                                            Bounce.wrap(
                                                restartHolder.whatever));
                                }
                            },
                            start);

                    return restartHolder.whatever;
                }
            });
    for (; mo != null;
           mo =
               (CMonad<Object>)mo.f(new Func1<Object, Object>() {
                   public Object f(Object x) {
                       return x;
                   }
               })) {
        if (!(mo instanceof Bounce)) continue;
        System.out.println("bouncey");
    }
}
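
A few helpers referenced above (Func1, Func0, WhateverClosure, and Bounce) never appear in the post. Reconstructed from how they're used, they presumably look something like this; the bodies are my guesses, not the original code.

// The obvious one- and zero-argument function interfaces.
public interface Func1<A, R> {
    R f(A a);
}
public interface Func0<R> {
    R f();
}

// A one-field mutable cell used to tie recursive knots ("loop.whatever").
public static class WhateverClosure<T> {
    public T whatever;
}

// A marker wrapper so the trampoline in cpsbouncewhile can distinguish a paused
// iteration from an ordinary escaped value; running it just runs the wrapped monad.
// (It would need to be nested in the same outer class as CMonad, since CMonad's
// constructor is private.)
public static class Bounce extends CMonad<Object> {
    private final CMonad<Object> wrapped;
    private Bounce(CMonad<Object> wrapped) { this.wrapped = wrapped; }
    public static Bounce wrap(CMonad<Object> m) { return new Bounce(m); }
    public <R> R f(Func1<Object, R> k) { return wrapped.f(k); }
}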

So then how'd I do? Well, the test code for while loops is fairly readable, but my goal was to write something like


// T is a priority queue of sorted runs
while (T.size() > 1) {
    T.add(merge(comparator, T.remove(), T.remove()));
}

and have that automatically translated to a continuation monad representing the progress of the loop: a function from a Boolean (the latest comparison result) to a pair consisting of another such function and the next pair of elements to compare.
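
In other words, I wanted the suspended sort to present itself to the oracle plumbing as something like the following. This is a hypothetical target type sketched for illustration, not something I actually built.

// Hypothetical sketch of the "paused sort" representation described above.
public interface SortStep {
    boolean isDone();
    int[] finalOrder();                    // meaningful only when isDone()
    int[] nextComparison();                // the pair {i, j} to ask the oracle about next
    SortStep resume(boolean iLessThanJ);   // feed in the answer, get the next pause point
}

// The oracle-side driver then becomes a trivial loop (again, a sketch):
// SortStep step = startSort(costs);
// while (!step.isDone()) {
//     int[] pair = step.nextComparison();
//     step = step.resume(oracleLess(pair[0], pair[1]));
// }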

Dealing with the while per se was not bad, but translating the body required quite a lot of fussiness. We need not only a restartable outer (while) loop but also a restartable inner (merge) loop, and we want these flattened so that the while's client code can be oblivious to the implementation choice to use functional composition in the while body. Looking at things a different way, we would like to encapsulate data flow so that a caller can repeatedly supply data to meet the demands of the callee until the callee finishes.

We could use call/cc to obtain a restart wherever we need input. These restarts are relatively easy to pass around (no need for joins), but they support only one standard operation (escape). Accordingly, each adjacent caller/callee pair has to be tweaked to expect and return restarts (analogously to the flattening described in the previous paragraph), and each caller has to branch at each invocation to either pass back a restart or move on with its own computation.

Alternatively, we could add a function to represent function application within the continuation monad framework. We could tweak while so that its body argument has a result type of continuation monad, and then this new apply function could yield such a monad; while could be responsible for composing that result into its own result (using the join operation, M (M A) -> M A).
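
For reference, join is easy to write against the CMonad above; it's just bind applied to the identity. A sketch:

// join :: M (M a) -> M a, expressed with the map (bind) already defined above.
public static <X> CMonad<X> join(final CMonad<CMonad<X>> mmx) {
    return CMonad.map(
        new Func1<CMonad<X>, CMonad<X>>() {
            public CMonad<X> f(CMonad<X> mx) {
                return mx;   // the inner monad is already the computation we want
            }
        },
        mmx);
}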

Although monads let us dispense with the repetitive branching at function invocations that the call/cc approach requires, we can't escape the need to reimplement Java all along the way. If the body of a while loop includes its own function delegation, we'll have to make that monadic. If we need a conditional branch, we'll have to make that monadic. Even to implement merge sort we'd wind up with multiple nested anonymous class instance creation expressions. In between scads of braces, parens, brackets, and repetitions of type names like "Func1," we'd have the names of the functions we used to replace all the Java primitives. Without syntactic abstraction, the readability of the resulting DSL doesn't scale beyond the simple examples in this blog post.
