Wednesday, June 8, 2011

Logic programming in Scala, part 3: unification and state

In this post I want to build on the backtracking logic monad we covered last time by adding unification, yielding an embedded DSL for Prolog-style logic programming.

Prolog

Here is a small Prolog example, the rough equivalent of List.contains in Scala:

  member(X, [X|T]). 
  member(X, [_|T]) :- member(X, T). 

The member predicate doesn’t return a boolean; instead it succeeds or fails (in the same way as a computation in the logic monad). The goal member(1, [1,2,3]) succeeds; the goal member(4, [1,2,3]) fails. (What happens for member(1, [1,1,3])?)

A Prolog predicate is defined by one or more clauses (each ending in a period), made up of a head (the predicate and arguments before the :-) and zero or more subgoals (goals after the :-, separated by commas; if there are no subgoals the :- is omitted). To solve a goal, we unify it (match it) with each clause head, then solve each subgoal in the clause. If a subgoal fails we backtrack and try the next matching head; if there is no matching head the goal fails. A goal may succeed more than once.

For member we have two clauses: the first says that member succeeds if X is the head of the list ([X|T] is the same as x::t in Scala); the second says that member succeeds if X is a member of the tail of the list, regardless of the head. There is no clause where the list is empty (written []); a goal with an empty list fails since there is no matching clause head.

Prolog unification is more expressive than pattern matching as found in Scala, OCaml, etc. Both sides of a unification may contain variables; unification attempts to instantiate them so that the two sides are equal. Variables are instantiated by terms, which themselves may contain variables; unification finds the most general instantiation which makes the sides equal.
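To make the two-sided nature of unification concrete, here is a minimal untyped sketch. It is not the typed design developed later in this post: the names UnifySketch, Tm, V, Lit, Empty, Cons and bind are made up for illustration, variables are identified by name, and the environment is a plain Map.

```scala
object UnifySketch {
  sealed trait Tm
  case class V(name: String) extends Tm          // a variable, identified by name in this sketch
  case class Lit(value: Int) extends Tm
  case object Empty extends Tm                   // []
  case class Cons(hd: Tm, tl: Tm) extends Tm     // [hd|tl]

  type Env = Map[String, Tm]

  // Replace bound variables in t by their bindings.
  def subst(e: Env, t: Tm): Tm = t match {
    case V(n)        => e.getOrElse(n, t)
    case Cons(h, tl) => Cons(subst(e, h), subst(e, tl))
    case _           => t
  }

  def occurs(n: String, t: Tm): Boolean = t match {
    case V(m)        => m == n
    case Cons(h, tl) => occurs(n, h) || occurs(n, tl)
    case _           => false
  }

  // Bind n to t and substitute the new binding into the existing ones.
  def bind(e: Env, n: String, t: Tm): Env = {
    val single: Env = Map(n -> t)
    e.map { case (k, u) => (k, subst(single, u)) }.updated(n, t)
  }

  // Precondition: a and b have already been substituted by e.
  def unify(e: Env, a: Tm, b: Tm): Option[Env] = (a, b) match {
    case (V(n), V(m)) if n == m => Some(e)
    case (V(n), _)              => if (occurs(n, b)) None else Some(bind(e, n, b))
    case (_, V(_))              => unify(e, b, a)
    case (Lit(x), Lit(y))       => if (x == y) Some(e) else None
    case (Empty, Empty)         => Some(e)
    case (Cons(h1, t1), Cons(h2, t2)) =>
      for {
        e1 <- unify(e, h1, h2)
        e2 <- unify(e1, subst(e1, t1), subst(e1, t2))
      } yield e2
    case _ => None
  }
}
```

Unifying [X, 2] with [1|T] instantiates variables on both sides (X to 1 and T to [2]), which a one-sided pattern match cannot express. Unifying X with [X] fails because of the occurs check.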

As a small example of this expressivity, we can run member “backwards”: the goal member(X, [1,2,3]) succeeds once for each element of the list, with X bound to the element.
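As a rough plain-Scala analogue (a sketch, not part of the DSL built below), we can model a goal as the list of ways it succeeds. The names MemberSketch, memberGoal and memberOf are hypothetical; an empty list means failure, and each element is one success.

```scala
object MemberSketch {
  // One list element per way the goal member(x, l) can succeed.
  def memberGoal[A](x: A, l: List[A]): List[Unit] = l match {
    case Nil => Nil                               // no clause head matches []
    case h :: t =>
      val first = if (h == x) List(()) else Nil   // clause 1: x is the head
      val second = memberGoal(x, t)               // clause 2: x is in the tail
      first ++ second
  }

  // Running member "backwards": one solution per possible binding of X.
  def memberOf[A](l: List[A]): List[A] = l match {
    case Nil    => Nil
    case h :: t => h :: memberOf(t)
  }
}
```

The second function is just the identity on lists, which is the point: asking which X are members of a list enumerates its elements in order.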

There is much more on Prolog and logic programming in Frank Pfenning’s course notes, which I recommend highly.

Unification

For each type we want to use in unification we’ll define a corresponding type of terms, which have the same structure as the underlying type but can also contain variables. These aren’t Scala variables (which of course can’t be stored in a data structure) but “existential variables”, or evars. Evars are just tags; computations will carry an environment mapping evars to terms, which may be updated after a successful unification.

import scala.collection.immutable.{Map,HashMap} 
 
class Evar[A](val name: String) 
object Evar { def apply[A](name: String) = new Evar[A](name) } 
 
trait Term[A] { 
  // invariant: on call to unify, this and t have e substituted 
  def unify(e: Env, t: Term[A]): Option[Env] 
 
  def occurs[B](v: Evar[B]): Boolean 
  def subst(e: Env): Term[A] 
  def ground: A 
} 

The important property of an evar is that it is distinct from every other evar; the name attached to it is just a label. An evar is indexed by a phantom type indicating the underlying type of terms which may be bound to it.

A term is indexed by its underlying type. So Int becomes Term[Int], String becomes Term[String], and so on; an evar of type Evar[A] may only be bound to a term of type Term[A]. (Prolog is dynamically typed, but this statically-typed treatment of evars and terms fits better with Scala.)

The unify method unifies a term with another term of the same type, taking an environment and returning an updated environment (or None if the unification fails). The occurs method checks whether an evar occurs in a term (as we will see, this is used to prevent circular bindings). The subst method replaces the evars in a term with their bindings in an environment, and ground returns the underlying Scala value represented by the term (provided the term contains no evars).

class Env(m: Map[Evar[Any],Term[Any]]) { 
  def apply[A](v: Evar[A]) = 
    m(v.asInstanceOf[Evar[Any]]).asInstanceOf[Term[A]] 
  def get[A](v: Evar[A]): Option[Term[A]] = 
    m.get(v.asInstanceOf[Evar[Any]]).asInstanceOf[Option[Term[A]]] 
  def updated[A](v: Evar[A], t: Term[A]): Env = { 
    val v2 = v.asInstanceOf[Evar[Any]] 
    val t2 = t.asInstanceOf[Term[Any]] 
    val e2 = Env(Map(v2 -> t2)) 
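    // map strictly: mapValues returns a lazy view (and is deprecated in newer Scala versions)
    val m2 = m.map { case (k, u) => (k, u.subst(e2)) }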
    Env(m2.updated(v2, t2)) 
  } 
} 
object Env { 
  def apply(m: Map[Evar[Any],Term[Any]]) = new Env(m) 
  def empty = new Env(HashMap()) 
} 

An environment is just a map from evars to terms. Because we need to store evars and terms of different types in the same environment, we cast them to and from Any; this is safe because updated only accepts an Evar[A] together with a Term[A], so the phantom type on Evar guarantees that a lookup casts back to the type that was stored. For simplicity we maintain the invariant that the term bound to each evar is already substituted by the rest of the environment.

case class VarTerm[A](v: Evar[A]) extends Term[A] { 
  def unify(e: Env, t: Term[A]) = 
    t match { 
      case VarTerm(v2) if (v2 == v) => Some(e) 
      case _ => 
        if (t.occurs(v)) None 
        else Some(e.updated(v, t)) 
    } 
 
  def occurs[B](v2: Evar[B]) = v2 == v 
 
  def subst(e: Env) = 
    e.get(v) match { 
      case Some(t) => t 
      case None => this 
    } 
 
  def ground = 
    throw new IllegalArgumentException("not ground") 
 
  override def toString = { v.name  } 
} 

The VarTerm class represents terms consisting of an evar. To unify a VarTerm with another VarTerm containing the same evar, we just return the environment unchanged (since there is no new information). Otherwise we check that the evar doesn’t appear in the term (since a unification x =:= List(x) would create a circular term) then return the updated environment.

To substitute a VarTerm we return the term bound to the evar in the environment if one exists, otherwise the unsubstituted VarTerm. A VarTerm is never ground (we assume ground is called only on terms which are already substituted by the environment).

case class LitTerm[A](a: A) extends Term[A] { 
  def unify(e: Env, t: Term[A]) = 
    t match { 
      case LitTerm(a2) => if (a == a2) Some(e) else None 
      case _: VarTerm[_] => t.unify(e, this) 
      case _ => None 
    } 
 
  def occurs[B](v: Evar[B]) = false 
  def subst(e: Env) = this 
  def ground = a 
 
  override def toString = { a.toString } 
} 

LitTerm represents terms of literal Scala values. A LitTerm unifies with another LitTerm containing an equal value, which adds nothing to the environment. The remaining two cases are needed for every term type: to unify with a VarTerm we call unify back on it, so that the VarTerm does the binding; anything else fails.

case class NilTerm[A]() extends Term[List[A]] { 
  def unify(e: Env, t: Term[List[A]]) = 
    t match { 
      case NilTerm() => Some(e) 
      case _: VarTerm[_] => t.unify(e, this) 
      case _ => None 
    } 
 
  def occurs[B](v: Evar[B]) = false 
  def subst(e: Env) = this 
  def ground = Nil 
 
  override def toString = { Nil.toString } 
} 
 
case class ConsTerm[A](hd: Term[A], tl: Term[List[A]]) 
  extends Term[List[A]] 
{ 
  def unify(e: Env, t: Term[List[A]]) = 
    t match { 
      case ConsTerm(hd2, tl2) => 
        for { 
          e1 <- hd.unify(e, hd2) 
          e2 <- tl.subst(e1).unify(e1, tl2.subst(e1)) 
        } yield e2 
      case _: VarTerm[_] => t.unify(e, this) 
      case _ => None 
    } 
 
  def occurs[C](v: Evar[C]) = hd.occurs(v) || tl.occurs(v) 
  def subst(e: Env) = ConsTerm(hd.subst(e), tl.subst(e)) 
  def ground = hd.ground :: tl.ground 
 
  override def toString = { hd.toString + " :: " + tl.toString } 
} 

NilTerm and ConsTerm represent the Nil and :: constructors for lists. Nil is sort of like a literal, so the methods for NilTerm are similar to those for LitTerm. For ConsTerm we unify by unifying the heads and tails, calling subst on the tails since unifying the heads may have added bindings to the environment. (Here it’s convenient to use a for-comprehension on the Option[Env] type since either unification may fail.) Similarly we implement occurs, subst, and ground by calling them on the head and tail.

object Term { 
  implicit def var2Term[A](v: Evar[A]): Term[A] = VarTerm(v) 
  //implicit def lit2term[A](a: A): Term[A] = LitTerm(a) 
  implicit def int2Term(a: Int): Term[Int] = LitTerm(a) 
  implicit def list2Term[A](l: List[Term[A]]): Term[List[A]] = 
    l match { 
      case Nil => NilTerm[A]()
      case hd :: tl => ConsTerm(hd, list2Term(tl)) 
    } 
} 

Finally we have some implicit conversions to make it a little easier to build Term values. The lit2term conversion turned out to be a bad idea; in particular you don’t want a LitTerm[List[A]] since it doesn’t unify with a ConsTerm[A] or NilTerm[A].

State

In order to combine unification with backtracking, we need to keep track of the environment along each branch of the tree of choices. We don’t want the environments from different branches to interfere, so it’s convenient to use a purely functional environment representation; we pass the current environment down the tree as computation proceeds. However, we can hide this state passing in the monad interface:

trait LogicState { L => 
  type T[S,A] 
  // as before 
  def split[S,A](s: S, t: T[S,A]): Option[(S,A,T[S,A])] 
 
  def get[S]: T[S,S] 
  def set[S](s: S): T[S, Unit] 
 
  case class Syntax[S,A](t: T[S,A]) { 
    // as before 
    def &[B](t2: => T[S,B]): T[S,B] = L.bind(t, { _: A => t2 }) 
  } 
} 

LogicState is mostly the same as Logic, except that the type of choices has an extra parameter for the type of the state. The get and set functions get and set the current state. To split we need an initial state to get things started, and each result includes an updated state. Finally we add the syntax & to sequence two computations, ignoring the value of the first. We’ll use this to sequence goals, since we care only about the updated environment.

The simplest implementation of LogicState builds on Logic:

trait LogicStateT extends LogicState { 
  val Logic: Logic 
 
  type T[S,A] = S => Logic.T[(S, A)] 

We embed state-passing in a Logic.T as a function from an initial state to a choice of alternatives, where each alternative includes an updated state along with its value.

  // result type given explicitly so that Logic.fail is inferred at type (S, A)
  def fail[S,A]: T[S,A] = { s: S => Logic.fail }
  def unit[S,A](a: A) = { s: S => Logic.unit((s, a)) } 
 
  def or[S,A](t1: T[S,A], t2: => T[S,A]) = 
    { s: S => Logic.or(t1(s), t2(s)) } 
 
  def bind[S,A,B](t: T[S,A], f: A => T[S,B]) = { 
    val f2: ((S,A)) => Logic.T[(S,B)] = { case (s, a) => f(a)(s) } 
    { s: S => Logic.bind(t(s), f2) } 
  } 
 
  def apply[S,A,B](t: T[S,A], f: A => B) = { 
    val f2: ((S,A)) => ((S,B)) = { case (s, a) => (s, f(a)) } 
    { s: S => Logic.apply(t(s), f2) } 
  } 
 
  def filter[S,A](t: T[S,A], p: A => Boolean) = { 
    val p2: ((S,A)) => Boolean = { case (_, a) => p(a) } 
    { s: S => Logic.filter(t(s), p2) } 
  } 

None of these operations changes the state itself: unit, apply, and filter pass it through unchanged, and bind threads the state produced by t into f. Note that or passes the same state to both alternatives, so different branches of the tree cannot interfere with one another’s state.

  def split[S,A](s: S, t: T[S,A]) = { 
    Logic.split(t(s)) match { 
      case None => None 
      case Some(((s, a), t)) => Some((s, a, { _: S => t }))
    } 
  } 
 
  def get[S] = { s: S => Logic.unit((s,s)) } 
  def set[S](s: S) = { _: S => Logic.unit((s,())) } 
} 

In split we pass the given state to the underlying Logic.T, and for each alternative we unpack the pair of state and value. The choice of remaining alternatives t encapsulates the current state, so when we return it we ignore the input state. In get and set we return and replace the current state.
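To see the state threading in isolation, here is a small sketch in which List stands in for Logic.T. That substitution is an assumption made for illustration (the real Logic implementations are lazier), and the names ListStateSketch, incr and branches are hypothetical.

```scala
object ListStateSketch {
  // Assumption: List plays the role of Logic.T, so alternatives are list elements.
  type T[S, A] = S => List[(S, A)]

  def fail[S, A]: T[S, A] = _ => Nil
  def unit[S, A](a: A): T[S, A] = s => List((s, a))
  def or[S, A](t1: T[S, A], t2: => T[S, A]): T[S, A] = s => t1(s) ++ t2(s)
  def bind[S, A, B](t: T[S, A], f: A => T[S, B]): T[S, B] =
    s => t(s).flatMap { case (s2, a) => f(a)(s2) }   // thread the updated state into f
  def get[S]: T[S, S] = s => List((s, s))
  def set[S](s: S): T[S, Unit] = _ => List((s, ()))

  // Increment the integer state.
  val incr: T[Int, Unit] = bind[Int, Int, Unit](get, n => set(n + 1))

  // Two alternatives: increment once, or increment twice; then read the state.
  val branches: T[Int, Int] =
    bind[Int, Unit, Int](
      or[Int, Unit](incr, bind[Int, Unit, Unit](incr, _ => incr)),
      _ => get)
}
```

Running branches from state 0 gives one result with state 1 and one with state 2. Both alternatives start from the same initial state, so the second does not see the first one's increment.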

Another approach is to pass state explicitly through LogicSFK:

object LogicStateSFK extends LogicState { 
  type FK[R] = () => R 
  type SK[S,A,R] = (S, A, FK[R]) => R 
 
  trait T[S,A] { def apply[R](s: S, sk: SK[S,A,R], fk: FK[R]): R } 

This is not really any different from LogicStateT applied to LogicSFK; we have just uncurried the state argument. We can take the same path as last time and defunctionalize this into a tail-recursive implementation (see the full code). Alternatively, LogicStateT applied to LogicSFKDefuncTailrec inherits tail-recursiveness from the underlying Logic monad.
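For readers who want to see what the operations might look like in this style, here is a sketch built on the same FK, SK and T definitions. It is my reconstruction for illustration, not the code from the full source; StateSFKSketch and runAll are hypothetical names, and runAll is a simple eager collector rather than the split used in this series.

```scala
object StateSFKSketch {
  type FK[R] = () => R
  type SK[S, A, R] = (S, A, FK[R]) => R

  trait T[S, A] { def apply[R](s: S, sk: SK[S, A, R], fk: FK[R]): R }

  def fail[S, A]: T[S, A] = new T[S, A] {
    def apply[R](s: S, sk: SK[S, A, R], fk: FK[R]): R = fk()
  }
  def unit[S, A](a: A): T[S, A] = new T[S, A] {
    def apply[R](s: S, sk: SK[S, A, R], fk: FK[R]): R = sk(s, a, fk)
  }
  // On failure of t1, retry t2 from the same initial state s.
  def or[S, A](t1: T[S, A], t2: => T[S, A]): T[S, A] = new T[S, A] {
    def apply[R](s: S, sk: SK[S, A, R], fk: FK[R]): R =
      t1.apply[R](s, sk, () => t2.apply[R](s, sk, fk))
  }
  // The success continuation receives the updated state s2 and passes it on to f(a).
  def bind[S, A, B](t: T[S, A], f: A => T[S, B]): T[S, B] = new T[S, B] {
    def apply[R](s: S, sk: SK[S, B, R], fk: FK[R]): R =
      t.apply[R](s, (s2: S, a: A, fk2: FK[R]) => f(a).apply[R](s2, sk, fk2), fk)
  }
  def get[S]: T[S, S] = new T[S, S] {
    def apply[R](s: S, sk: SK[S, S, R], fk: FK[R]): R = sk(s, s, fk)
  }
  def set[S](s: S): T[S, Unit] = new T[S, Unit] {
    def apply[R](old: S, sk: SK[S, Unit, R], fk: FK[R]): R = sk(s, (), fk)
  }

  // Collect every (state, value) result; not tail-recursive, fine for small examples.
  def runAll[S, A](s: S, t: T[S, A]): List[(S, A)] =
    t.apply[List[(S, A)]](
      s,
      (s2: S, a: A, fk: FK[List[(S, A)]]) => (s2, a) :: fk(),
      () => Nil)
}
```

The failure continuation built in or closes over the original state s, which is how this style keeps branches from seeing each other's updates.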

Scrolog

Finally we can put the pieces together into a Prolog-like embedded DSL:

trait Scrolog { 
  val LogicState: LogicState 
  import LogicState._ 
 
  type G = T[Env,Unit] 

From our point of view, a goal is a stateful choice among alternatives, where we don’t care about the value returned, only the environment.

  class TermSyntax[A](t: Term[A]) { 
    def =:=(t2: Term[A]): G = 
      for { 
        env <- get 
        env2 <- { 
          t.subst(env).unify(env, t2.subst(env)) match { 
            case None => fail[Env,Unit] 
            case Some(e) => set(e) 
          } 
        } 
      } yield env2 
  } 
 
  implicit def termSyntax[A](t: Term[A]) = new TermSyntax(t) 
  implicit def syntax[A](t: G) = LogicState.syntax(t) 

We connect term unification to the stateful logic monad with a wrapper class defining a =:= operator. To unify terms in the monad, we get the current environment, substitute it into the two terms (to satisfy the invariant above), then call unify; if it fails we fail the computation, else we set the new state.

  def run[A](t: G, n: Int, tm: Term[A]): List[Term[A]] = 
    LogicState.run(Env.empty, t, n) 
      .map({ case (e, _) => tm.subst(e) }) 
} 

The run function solves a goal, taking as arguments the goal, the maximum number of solutions to find, and a term to be evaluated in the environment of each solution.

Examples

First we need to set up Scrolog:

val Scrolog = 
  new Scrolog { val LogicState = 
    new LogicStateT { val Logic = LogicSFKDefuncTailrec } 
  } 
import Scrolog._ 

Here is a translation of the member predicate:

  def member[A](x: Term[A], l: Term[List[A]]): G = { 
    val hd = Evar[A]("hd"); val tl = Evar[List[A]]("tl") 
    ConsTerm(x, tl) =:= l | 
    (ConsTerm(hd, tl) =:= l & member(x, tl)) 
  } 

We implement predicates by functions, and goals by function calls. To implement matching the clause head, we explicitly unify the input arguments against each clause head, and combine the clauses with |. Subgoals are sequenced with &. Finally, we must create local evars explicitly, since they are fresh for each call (just as local variables are in Scala).

Finally we can run the goal above:

scala> val x = Evar[Int]("x") 
scala> run(member(x, List[Term[Int]](1, 2, 3)), 3, x) 
res6: List[Term[Int]] = List(1, 2, 3) 
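To try the same idea without the monad machinery, here is a self-contained miniature. It is a sketch under several assumptions: goals are plain functions Env => List[Env] (an eager List stands in for the logic monad, so it only suits goals with finitely many solutions), terms are untyped, and evars are compared by object identity. All names (MiniScrolog, V, walk, unifyGoal, or, and, list) are hypothetical.

```scala
object MiniScrolog {
  sealed trait Tm
  final class V extends Tm                        // an evar; its identity is the object itself
  case class Lit(value: Int) extends Tm
  case object Empty extends Tm
  case class Cons(hd: Tm, tl: Tm) extends Tm

  type Env = Map[V, Tm]
  type G = Env => List[Env]                       // a goal maps an environment to its solutions

  // Resolve a term against the environment, following chains of bindings.
  def walk(e: Env, t: Tm): Tm = t match {
    case v: V        => e.get(v).map(walk(e, _)).getOrElse(v)
    case Cons(h, tl) => Cons(walk(e, h), walk(e, tl))
    case _           => t
  }

  def occurs(v: V, t: Tm): Boolean = t match {
    case w: V        => w eq v
    case Cons(h, tl) => occurs(v, h) || occurs(v, tl)
    case _           => false
  }

  def unify(e: Env, a: Tm, b: Tm): Option[Env] = (walk(e, a), walk(e, b)) match {
    case (v: V, w: V) if v eq w => Some(e)
    case (v: V, t)              => if (occurs(v, t)) None else Some(e.updated(v, t))
    case (t, v: V)              => if (occurs(v, t)) None else Some(e.updated(v, t))
    case (Lit(x), Lit(y))       => if (x == y) Some(e) else None
    case (Empty, Empty)         => Some(e)
    case (Cons(h1, t1), Cons(h2, t2)) => unify(e, h1, h2).flatMap(unify(_, t1, t2))
    case _                      => None
  }

  def unifyGoal(a: Tm, b: Tm): G = e => unify(e, a, b).toList
  def or(g: G, h: => G): G = e => g(e) ++ h(e)               // both clauses see the same e
  def and(g: G, h: => G): G = e => g(e).flatMap(e1 => h(e1)) // h is built only if g succeeds

  def member(x: Tm, l: Tm): G = {
    val hd = new V; val tl = new V                // fresh evars for every call
    or(unifyGoal(Cons(x, tl), l),
       and(unifyGoal(Cons(hd, tl), l), member(x, tl)))
  }

  def run(g: G, tm: Tm): List[Tm] = g(Map.empty).map(walk(_, tm))

  def list(xs: Int*): Tm = xs.foldRight(Empty: Tm)((x, acc) => Cons(Lit(x), acc))
}
```

Because the second argument of and is by-name, the recursive call to member is only constructed when the preceding unification succeeds, so the recursion stops once the list is exhausted.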

As another example, we can implement addition over unary natural numbers. In Prolog this would be

  sum(z, N, N). 
  sum(s(M), N, s(P)) :- sum(M, N, P). 

In Prolog we can just invent symbols like s and z; in Scala we need first to define a type of natural numbers, then terms over that type:

  sealed trait Nat 
  case object Z extends Nat 
  case class S(n: Nat) extends Nat 
 
  case object ZTerm extends Term[Nat] {
    // implemented like NilTerm (elided)
  }

  case class STerm(n: Term[Nat]) extends Term[Nat] {
    // implemented like ConsTerm (elided)
  }
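The elided bodies might look like the following. This is a sketch that assumes they mirror NilTerm and ConsTerm, as the comments suggest; the full code may differ. To keep the block self-contained it repeats trimmed copies of Evar, Env, Term and VarTerm inside a hypothetical NatTerms object.

```scala
object NatTerms {
  class Evar[A](val name: String)

  class Env(m: Map[Evar[Any], Term[Any]]) {
    def get[A](v: Evar[A]): Option[Term[A]] =
      m.get(v.asInstanceOf[Evar[Any]]).asInstanceOf[Option[Term[A]]]
    def updated[A](v: Evar[A], t: Term[A]): Env = {
      val v2 = v.asInstanceOf[Evar[Any]]
      val t2 = t.asInstanceOf[Term[Any]]
      val e2 = new Env(Map(v2 -> t2))
      new Env(m.map { case (k, u) => (k, u.subst(e2)) }.updated(v2, t2))
    }
  }

  trait Term[A] {
    def unify(e: Env, t: Term[A]): Option[Env]
    def occurs[B](v: Evar[B]): Boolean
    def subst(e: Env): Term[A]
    def ground: A
  }

  case class VarTerm[A](v: Evar[A]) extends Term[A] {
    def unify(e: Env, t: Term[A]): Option[Env] = t match {
      case VarTerm(v2) if v2 == v => Some(e)
      case _ => if (t.occurs(v)) None else Some(e.updated(v, t))
    }
    def occurs[B](v2: Evar[B]) = v2 == v
    def subst(e: Env): Term[A] = e.get(v).getOrElse(this)
    def ground: A = throw new IllegalArgumentException("not ground")
  }

  sealed trait Nat
  case object Z extends Nat
  case class S(n: Nat) extends Nat

  // Like NilTerm: a constant that unifies only with itself or with an evar.
  case object ZTerm extends Term[Nat] {
    def unify(e: Env, t: Term[Nat]): Option[Env] = t match {
      case ZTerm => Some(e)
      case _: VarTerm[_] => t.unify(e, this)
      case _ => None
    }
    def occurs[B](v: Evar[B]) = false
    def subst(e: Env): Term[Nat] = this
    def ground: Nat = Z
    override def toString = "Z"
  }

  // Like ConsTerm, but with a single subterm to unify.
  case class STerm(n: Term[Nat]) extends Term[Nat] {
    def unify(e: Env, t: Term[Nat]): Option[Env] = t match {
      case STerm(n2) => n.unify(e, n2)
      case _: VarTerm[_] => t.unify(e, this)
      case _ => None
    }
    def occurs[B](v: Evar[B]) = n.occurs(v)
    def subst(e: Env): Term[Nat] = STerm(n.subst(e))
    def ground: Nat = S(n.ground)
    override def toString = "S(" + n + ")"
  }
}
```

Since STerm has only one subterm, there is no second unification that needs the environment re-substituted, so its unify is simpler than the one in ConsTerm.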

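Nat values such as Z and S(Z) appear below where a Term[Nat] is expected, which assumes an implicit conversion from Nat to Term[Nat] along the lines of int2Term (not shown in this post). With that in place we can define sum, again separating the clauses by | and explicitly unifying the clause heads: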

  def sum(m: Term[Nat], n: Term[Nat], p: Term[Nat]): G = { 
    val m2 = Evar[Nat]("m"); val p2 = Evar[Nat]("p") 
    (m =:= Z & n =:= p) | 
    (m =:= STerm(m2) & p =:= STerm(p2) & sum(m2, n, p2)) 
  } 

We can use sum to do addition:

scala> val x = Evar[Nat]("x"); val y = Evar[Nat]("y") 
scala> run(sum(S(Z), S(S(Z)), x), 1, x) 
res8: List[Term[Nat]] = List(S(S(S(Z)))) 

or subtraction:

scala> run(sum(x, S(S(Z)), S(S(S(Z)))), 1, x) 
res10: List[Term[Nat]] = List(S(Z)) 
 
scala> run(sum(S(Z), x, S(S(S(Z)))), 1, x) 
res11: List[Term[Nat]] = List(S(S(Z))) 

or even to find all the pairs of naturals which sum to 3:

scala> run(sum(x, y, S(S(S(Z)))), 10, List[Term[Nat]](x, y)) 
res14: List[Term[List[Nat]]] = 
  List(Z :: S(S(S(Z))) :: List(), 
       S(Z) :: S(S(Z)) :: List(), 
       S(S(Z)) :: S(Z) :: List(), 
       S(S(S(Z))) :: Z :: List()) 

although the printing of Term[List] could be better.

This is only a small taste of the expressivity of Prolog-style logic programming. Again let me recommend Frank Pfenning’s course notes, which explore the semantics of Prolog in a “definitional interpreters” style, by gradually refining an interpreter to expose more of the machinery of the language.

See the full code.