Unification Opgave

Afp0405

-- DoaitseSwierstra - 30 Sep 2003

  1. Define a sufficiently rich StateMonad instance and construct a working type inferencer.
  2. Change the definitions so that you also count the number of unifications needed when inferring the type of the code.

-- NOTE THE FOLLOWING

-- fun is nowadays called fmap:
-- class Functor f where
--   fmap :: (a -> b) -> (f a -> f b)

-- Note that in the Haskell libraries `bind` is called `>>=`
-- and `result` is called `return`.
-- The Monad class (re-exported by the standard Monad module we import) already declares a fail method:
-- class Monad m where
--  (>>=) :: m a -> (a -> m b) -> m b
--  (>>) :: m a -> m b -> m b
--  return :: a -> m a
--  fail :: String -> m a
--  m >> k = m >>= \_ -> k
--  fail s = error s -- note that you will have to override this method, since calling
--                      error aborts the whole program


module Infer where
import Monad



-- | Monads that carry a piece of state of type @s@.
--
-- @update f@ replaces the state by @f@ applied to it; what @update@
-- returns is left to each instance.
class Monad m => StateMonad m s where
  update :: (s -> s) -> m s

-- | A state monad that can also fail with an error message.
--
-- A computation takes the current state and either fails ('Left') with a
-- message, or succeeds ('Right') with a result and the new state.
-- Failure is threaded through every '>>=', so a 'fail' anywhere in a chain
-- aborts the rest of the chain and is reported by 'runInfer', instead of
-- crashing the program the way 'error' would.
newtype InferAll s a = InferAll (s -> Either String (a, s))

-- | Run a computation from an initial state.
runInfer :: InferAll s a -> s -> Either String (a, s)
runInfer (InferAll m) = m

instance Functor (InferAll s) where
  fmap f m = InferAll (\s -> case runInfer m s of
                               Left msg      -> Left msg
                               Right (x, s1) -> Right (f x, s1))

-- NOTE(review): written against the Haskell 98 'Monad' class used by the
-- rest of this module, where 'fail' is a method of 'Monad'.  Modern GHC
-- (7.10+) also requires an 'Applicative' instance, and 8.8+ moved 'fail'
-- to 'MonadFail'; both would be needed when porting.
instance Monad (InferAll s) where
  return x = InferAll (\s -> Right (x, s))
  m >>= f  = InferAll (\s -> case runInfer m s of
                               Left msg      -> Left msg
                               Right (x, s1) -> runInfer (f x) s1)
  fail msg = InferAll (\_ -> Left msg)

-- | 'update' returns the state as it was /before/ the function is applied.
instance StateMonad (InferAll s) s where
  update f = InferAll (\s -> Right (s, f s))
  
 
-- | Bump an 'Int' state by one, yielding whatever 'update' yields (for the
-- 'InferAll' instance that is the value before the increment).  'infer'
-- uses it to obtain fresh type-variable numbers.
incr :: StateMonad m Int => m Int
incr = update (\n -> 1 + n)

-- | Types: type variables (named by @v@), the integer type, and function
-- types @Fun argument result@.  'Eq' is derived alongside 'Show' so that
-- types can be compared directly, e.g. when checking inference results.
data Type v = TVar v                 -- type variable
            | TInt                   -- integer type
            | Fun (Type v) (Type v)  -- function type: argument, result
              deriving (Eq, Show)

-- | Untyped lambda terms with integer literals.
data Term
  = Var Name       -- variable occurrence
  | Ap Term Term   -- application: function, argument
  | Lam Name Term  -- lambda abstraction: bound variable, body
  | Num Int        -- numeric literal

-- | Variable names.
type Name = String

-- | A substitution maps every variable to a structure over variables;
-- e.g. @Subst Type v@ maps type variables to types.
type Subst m v = v -> m v

-- | Rename type variables: apply the function to every 'TVar', keeping
-- the shape of the type unchanged.
instance Functor Type where
  fmap g ty = case ty of
    TVar v  -> TVar (g v)
    TInt    -> TInt
    Fun a b -> Fun (fmap g a) (fmap g b)

-- | Substitution as bind: @ty >>= k@ replaces every @TVar v@ in @ty@ by
-- @k v@; 'return' builds a bare type variable.
instance Monad Type where
  return = TVar
  ty >>= k = case ty of
    TVar v  -> k v
    TInt    -> TInt
    Fun a b -> Fun (a >>= k) (b >>= k)

-- | Apply a substitution @s@ to a monadic structure @t@ (argument order
-- flipped relative to '>>=').  For 'Type', this replaces each @TVar v@
-- by @s v@.
apply :: Monad m => (a -> m b) -> m a -> m b
apply s t = t >>= s

-- | Kleisli composition: run @g@, feed its result to @f@, and flatten the
-- nested structure.  For substitutions, @s2 \@\@ s1@ applies @s1@ first
-- and then @s2@.
(@@) :: (Functor m, Monad m) => (a -> m b) -> (c -> m a) -> (c -> m b)
(f @@ g) x = join (fmap f (g x))


-- | @v >>> t@ is the single-variable substitution that maps @v@ to @t@
-- and leaves every other variable unchanged.
(>>>) :: (Eq v, Monad m) => v -> m v -> Subst m v
(v >>> t) w
  | v == w    = t
  | otherwise = return w

-- | Unify two types: return a substitution that makes them equal, or
-- 'fail' with a message when no such substitution exists (a constructor
-- clash, or an occurs-check failure reported by 'varBind').
--
-- NOTE(review): counting unifications (part 2 of the exercise) is not done
-- here.  This function only needs 'Monad m' and reads or writes no state.
unify :: (Show v, Monad m, Eq v) => Type v -> Type v -> m (Subst Type v)
unify TInt TInt = return return
unify (TVar v) (TVar w) = return sameOrBind
  where sameOrBind = if v == w then return else v >>> TVar w
unify (TVar v) t = varBind v t
unify t (TVar v) = varBind v t
unify (Fun arg1 res1) (Fun arg2 res2) = do
  s1 <- unify arg1 arg2
  s2 <- unify (apply s1 res1) (apply s1 res2)
  return (s2 @@ s1)
unify t1 t2 =
  fail ("Cannot unify " ++ show t1 ++ " with " ++ show t2)


-- | Bind type variable @v@ to type @t@.  Succeeds with the substitution
-- @v >>> t@, or 'fail's with "Occurs check fails" when @v@ occurs inside
-- @t@ (binding it would require an infinite type).
varBind :: (Eq v, Monad m) => v -> Type v -> m (Subst Type v)
varBind v t
  | occurs t  = fail "Occurs check fails"
  | otherwise = return (v >>> t)
  where
    -- Does the variable being bound appear anywhere in the given type?
    -- Short-circuits on the first occurrence instead of listing all vars.
    occurs (TVar w)  = w == v
    occurs TInt      = False
    occurs (Fun d r) = occurs d || occurs r


-- | A typing environment: an association list from variable names to
-- values of type @t@.  Entries nearer the front shadow later ones.
data Env t = Ass [(Name,t)]


-- | The environment with no bindings.
emptyEnv :: Env t
emptyEnv = Ass []

-- | Add a binding at the front, shadowing any existing binding for the name.
extend :: Name -> t -> Env t -> Env t
extend v t (Ass bindings) = Ass ((v, t) : bindings)


-- | Look up a name; the front-most (most recently added) binding wins.
-- Fails with "Unbound variable: <name>" when the name is not bound.
locate :: Monad m => Name -> Env t -> m t
locate v (Ass bindings) = foldr step notFound bindings
  where
    step (w, t) rest = if w == v then return t else rest
    notFound         = fail ("Unbound variable: " ++ v)

-- | Map a function over every bound value, keeping names and order.
instance Functor Env where
  fmap f (Ass bindings) = Ass [ (n, f t) | (n, t) <- bindings ]

-- | Infer the type of a term in an environment.
--
-- Returns the substitution found along the way (to be applied to the
-- environment's types) together with the term's type.  Fresh type
-- variables are numbered by 'incr'.
-- NOTE(review): presumably the counter must start above every variable
-- number already used in the environment, or fresh variables will
-- collide with them — confirm with callers.
infer :: StateMonad m Int => Env (Type Int) -> Term -> m (Int -> Type Int, Type Int)
infer env (Var x) = do
  t <- locate x env
  return (return, t)
infer _ (Num _) = return (return, TInt)
infer env (Lam x body) = do
  b <- incr
  (s, t) <- infer (extend x (TVar b) env) body
  return (s, Fun (s b) t)
infer env (Ap fun arg) = do
  (s, funTy) <- infer env fun
  (t, argTy) <- infer (fmap (apply s) env) arg
  b <- incr
  u <- unify (apply t funTy) (Fun argTy (TVar b))
  return (u @@ t @@ s, u b)

</verbatim>