{-|
Module: Types
Description: Type inference and checking
Author: gatlin@niltag.net

This is still a work in progress.

= No really, what's this type system?

Eventually? I want

* Linear
* call-by-push-value with
* arbitrary rank type and kind polymorphism and
* graded (co)effects.

We can take this one at a time.

= Linear

/Linear logic/, seemingly like everything else in academia, is at once a dense
subject of study with profound implications, and also a very simple idea.

As it applies to programming, the basic idea is to think of types as defining
/resources/, and the values of those types as matter which can't simply be
/duplicated/ or /discarded/ (ignored) at will!
The arrogance!
So if you bind a value to a name in some scope, you must use it __exactly once__
within that scope.

Here is what pure linear file I/O would look like in pseudo-code:

> file_handle := open_file ("foo.txt");
> (file_handle, line_1) := read_line (file_handle);
> (file_handle, line_2) := read_line (file_handle);
> close (file_handle);

Notice that @file_handle@ must be re-bound after each use, because each binding
/must be used exactly once/.
There are many ways to relax this discipline and extend it to other uses, but
that is the basic premise.
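
To make the rule concrete, here is a minimal sketch (an illustration, not psilo's
actual checker) over a hypothetical untyped lambda-calculus @Term@. It counts the
free occurrences of each lambda-bound variable and accepts a term only if every
binder is used exactly once.

> -- Untyped lambda terms, just enough to talk about variable usage.
> data Term = Var String | Lam String Term | App Term Term
>
> -- Number of free occurrences of x in a term.
> uses :: String -> Term -> Int
> uses x (Var y)   = if x == y then 1 else 0
> uses x (Lam y b) = if x == y then 0 else uses x b  -- an inner binder shadows x
> uses x (App f a) = uses x f + uses x a
>
> -- A term is linear when every lambda binds a variable used exactly once.
> linear :: Term -> Bool
> linear (Var _)   = True
> linear (Lam x b) = uses x b == 1 && linear b
> linear (App f a) = linear f && linear a

The identity @λ x. x@ passes. The body of @square@ is rejected because @x@ occurs
twice, and @λ x y. x@ is rejected because @y@ is discarded.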

The appeal of linear logic in programming circles, it seems, stemmed from the
promise of preventing memory leaks statically, at compile time: if you can't use
a pointer twice without explicitly duplicating it, it's hard to use it a million
times by accident in a loop!

On the other hand, not being able to write @let square x = x * x@ without
resorting to a helper utility is a pretty draconian restriction.
One of psilo's goals is to deliver on the promises of linear logic in an
intuitive way such that linearity /aids/ development.

= Call-By-Push-Value

Most languages are "strict", or /call-by-value/: all procedure arguments must
be fully evaluated before they may be passed along.
There are usually ways to explicitly circumvent this for when a computation
needs to be delayed.

Some, like Haskell, are "lazy", or /call-by-name\/need/: a term passed to a
function is not evaluated fully until absolutely necessary.

This is why the following interactive Haskell session terminates:

> > fix f = let x = f x in x
> > bottom = fix id
> > 1 + (const 2 bottom)
> 3

@bottom@ is never actually evaluated by @const@, so the fact that it would not
terminate does not matter.

== Psilo is both and neither!

/Call-by-push-value/ subsumes both of these and instead encodes evaluation order
in the type system.

A __positive__ type is "data": it is defined by the shape of its values.
You could conceivably come up with a way to print it out on paper.
It's static, at rest.

A __negative__ type is a "computation": it is a process executing, a function
during application; dynamic, "in motion."

Crucially, the system provides ways to convert a term between these two
polarities: for instance, a function might be "suspended" to be "resumed" later,
and the resulting /thunk/ could be passed as an argument.
Likewise, a static piece of data may be promoted into a do-nothing function that
simply returns that value.
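
The following type checker sketches how the two polarities stay apart. It uses
the @U@ (thunk) and @F@ (return) type formers from Levy's original presentation
of call-by-push-value. The constructor names are assumptions for illustration,
not psilo's syntax. Values and computations live in separate Haskell types, so
the only ways across are @Thunk@ and @Force@, or @Return@ and @Bind@.

> -- Positive (value) types and negative (computation) types.
> data VType = TInt | U CType             -- U c: a suspended computation (a thunk)
>   deriving (Eq, Show)
> data CType = F VType | VType :-> CType  -- F a: a computation that returns an a
>   deriving (Eq, Show)
> infixr 5 :->
>
> data Val  = Lit Int | Var String | Thunk Comp
> data Comp = Return Val | Force Val | Lam String VType Comp
>           | App Comp Val | Bind Comp String Comp  -- Bind m x n: run m, name its result x in n
>
> type Env = [(String, VType)]  -- variables always have value types
>
> checkV :: Env -> Val -> Maybe VType
> checkV _   (Lit _)   = Just TInt
> checkV env (Var x)   = lookup x env
> checkV env (Thunk m) = U <$> checkC env m
>
> checkC :: Env -> Comp -> Maybe CType
> checkC env (Return v)   = F <$> checkV env v
> checkC env (Force v)    = case checkV env v of
>   Just (U c) -> Just c
>   _          -> Nothing
> checkC env (Lam x a m)  = (a :->) <$> checkC ((x, a) : env) m
> checkC env (App m v)    = case (checkC env m, checkV env v) of
>   (Just (a :-> c), Just a') | a == a' -> Just c
>   _                                   -> Nothing
> checkC env (Bind m x n) = case checkC env m of
>   Just (F a) -> checkC ((x, a) : env) n
>   _          -> Nothing

For instance, @Force (Lit 1)@ is rejected, since only a thunk can be resumed.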

And it so happens it combines well with linearity.

= All together, say it with me: Linear Call-By-Push-Value!

It turns out that linear types and CBPV very naturally and elegantly pair with
one another: positive terms are precisely those which may be duplicated or
ignored without a second-thought; negative terms are those which must be treated
linearly.

Example: in the function @λ x. f x@, the argument @x@ is positive, while the
function term @f@, the application @f x@, and the whole abstraction are negative.

= Coeffects!

It so happens that you can regard the linear logic I described above as a
specialization of a more general idea called /coeffects/.
I call them coeffects because

1. Believe it or not that's the term in the literature, and
2. The only other appropriate name I can honestly think of is "context" which is
the most overloaded term in programming.

You can think of the linear logic described above as a coeffect where each bound
term has a number attached to it defining how many times it may be used in a
scope.
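
Here is one hedged reading of that idea. Each lambda binder carries a declared
grade from the usage semiring {0, 1, ω}. The names @Zero@, @One@ and @Many@ are
illustrative assumptions. Usages of a variable are added across subterms, and a
binder is accepted when its actual usage fits its grade. Declaring every binder
@One@ recovers plain linearity, while @Many@ is the escape hatch that admits
@square@.

> import qualified Data.Map as M
>
> -- Usage grades: never, exactly once, or any number of times (ω).
> data Grade = Zero | One | Many
>   deriving (Eq, Show)
>
> -- Semiring addition: combine the usage of two subterms (so 1 + 1 = ω).
> plus :: Grade -> Grade -> Grade
> plus Zero g = g
> plus g Zero = g
> plus _ _    = Many
>
> data Term = Var String | Lam String Grade Term | App Term Term
>
> -- Actual usage of every free variable in a term.
> usage :: Term -> M.Map String Grade
> usage (Var x)     = M.singleton x One
> usage (Lam x _ b) = M.delete x (usage b)
> usage (App f a)   = M.unionWith plus (usage f) (usage a)
>
> -- Does the actual usage fit the grade declared on the binder?
> fits :: Grade -> Grade -> Bool
> fits Many _        = True
> fits declared used = declared == used
>
> check :: Term -> Bool
> check (Var _)     = True
> check (Lam x g b) = fits g (M.findWithDefault Zero x (usage b)) && check b
> check (App f a)   = check f && check a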

A more exotic example might be an /enum/ of security roles attached to every
term the moment it is created, together with rules for performing security
auditing /statically at compile time/.

A coeffect system is a set of rules for defining contextual data like that,
along with the rules for the compiler to then enforce certain constraints.

Part of my current work is figuring out how this integrates with linear CBPV and
thus how it should be represented syntactically.

= Arbitrary-rank type and kind polymorphism.

I'm out of time for today!

* kinds: the "types of types". Basic types have kind @*@. The function type
@A -> B@ built from @A : *@ and @B : *@ also has kind @*@, while the arrow
constructor @(->)@ itself has kind @* -> * -> *@.

* arbitrary rank types: I can write types such as
@∀A B. (∀R. (A -> R) -> R) -> B@, whose argument is itself polymorphic. Future
drafts will spell this out more.

* "kind polymorphism?" Yeah what, did you think @*@ is the only kind?
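
The first bullet can be sketched as a kind checker. This uses a hypothetical type
language with an explicit @Arrow@ constructor for @(->)@ and type application
@TApp@, neither of which the module's @Type@ has. Applying @Arrow@ to two types
of kind @*@ yields kind @*@, and a partial application yields @* -> *@.

> data Kind = Star | KArr Kind Kind
>   deriving (Eq, Show)
>
> -- A tiny type language with an explicit arrow constructor and application.
> data Ty = TCon String | Arrow | TApp Ty Ty
>
> kindOf :: [(String, Kind)] -> Ty -> Maybe Kind
> kindOf env (TCon c)   = lookup c env
> kindOf _   Arrow      = Just (KArr Star (KArr Star Star))  -- (->) : * -> * -> *
> kindOf env (TApp f a) = case (kindOf env f, kindOf env a) of
>   (Just (KArr k1 k2), Just k) | k == k1 -> Just k2
>   _                                     -> Nothing
>
> -- Sugar: A -> B is (->) applied to A, then to B.
> arrow :: Ty -> Ty -> Ty
> arrow a b = TApp (TApp Arrow a) b

The function type @A -> B@ is encoded as @TApp (TApp Arrow A) B@ and checks at
kind @*@, while applying a type of kind @*@ to anything fails.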

= That sounds like a tall order.

At the moment, setting aside the rules for @shift@ and @reset@, this module
competently implements a basic version of System F with let-polymorphism (in
effect, Hindley-Milner inference).
Once the delimited control operators are settled, the next step will probably be
adding surface syntax for the type language and implementing arbitrary-rank
polymorphism. My inference algorithm will sometimes need /just a little/ help
from the programmer in the form of type annotations, and I don't have a type
annotation syntax yet.
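
GHC Haskell makes the same trade-off, which gives a feel for the kind of
annotation meant here. This is an analogy only, not psilo syntax. The argument of
@applyBoth@ is used at two different list types, so it needs the rank-2 type in
the signature. Without the signature, GHC infers the less general
@Num b => (a -> b) -> (a, a) -> b@, and a call passing a list of numbers together
with a string is rejected.

> {-# LANGUAGE RankNTypes #-}
>
> -- The argument f must itself be polymorphic: it is used at [Int] and at String.
> applyBoth :: (forall a. [a] -> Int) -> ([Int], String) -> Int
> applyBoth f (xs, s) = f xs + f s

Applied to @length@ and a pair of a three-element list and a two-character
string, it returns @5@.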

After that, finally, I'll likely be ready to tackle (co)effects.

Papers and posts I am cribbing from:

* https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/putting.pdf
* https://github.com/sdiehl/write-you-a-haskell/blob/master/chapter7/poly_constraints/src/Infer.hs
* https://cs-people.bu.edu/gaboardi/publication/GaboardiEtAlIicfp16.pdf
* others, stay tuned.

-}

{-# LANGUAGE DataKinds, PolyKinds, ExistentialQuantification, RankNTypes #-}
{-# LANGUAGE FlexibleInstances, GeneralizedNewtypeDeriving, StandaloneDeriving #-}

module Types
where

import Control.Monad
  ( foldM
  , forM
  , liftM2
  , mapAndUnzipM )
import Control.Monad.State
  (StateT(..)
  , evalStateT
  , get
  , gets
  , modify
  , put )

import Data.Functor.Identity (Identity(..))
import Control.Monad.Except (Except(..), ExceptT(..),runExcept,runExceptT,throwError)
import Control.Monad.Reader (ReaderT(..), runReaderT, ask, local)
import Control.Monad.Writer (WriterT, runWriterT, tell)
import Control.Comonad (Comonad(..), extract, extend)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import Data.List (intersperse, intercalate, foldl', foldl1, nub)
import Data.Maybe (fromJust)
import qualified Data.Graph as G

import Text.Show.Unicode (ushow)
import Syntax (Symbol, Cbpv(..), CbpvExp, deps)
import Control.Comonad.Cofree (Cofree(..))

-- * (basic) Type language

type TypeVar = Int

data Type
  = TVar TypeVar
  | TCon String
  | Type :-> Type
  | TForall [TypeVar] Type
  deriving (Eq, Ord)

instance Show Type where
  show (TVar n) = "_" ++ ushow n
  show (TCon s) = s
  -- a function from ⊥ (ty_bottom, i.e. TCon "") is printed in braces
  show (TCon "" :-> t) = "{ " ++ ushow t ++ " }"
  show (t1 :-> t2) = (ushow t1) ++ " -> " ++ (ushow t2)
  show (TForall [] t) = ushow t
  show (TForall tvs t) = "∀ " ++ tvs' ++ ". " ++ ushow t
    where tvs' = concat $ intersperse " " $ map (\n -> "_" ++ ushow n) tvs

-- these exist to remind us of constraints for now,
-- and in the future they may become type families of some sort
type Sigma = Type
type Rho = Type
type Tau = Type

ty_int, ty_float, ty_bool, ty_bottom :: Rho
ty_int = TCon "int"
ty_float = TCon "float"
ty_bool = TCon "boolean"
ty_bottom = TCon ""

-- | Curry a list of argument types onto a return type, so that
-- @ty_fun_sig [a, b] r == a :-> (b :-> r)@.
ty_fun_sig :: [Type] -> Type -> Type
ty_fun_sig args ret = foldl1 (flip (:->)) $ ret : (reverse args)

-- * Typing environment ("frame")

data Frame = Frame { types :: Map Symbol Type }
  deriving (Eq, Show)

frame_empty :: Frame
frame_empty = Frame M.empty

frame_extend :: Frame -> (Symbol, Type) -> Frame
frame_extend frame (x, s) = frame { types = M.insert x s (types frame) }

frame_remove :: Frame -> Symbol -> Frame
frame_remove (Frame frame) var = Frame (M.delete var frame)

frame_extends :: Frame -> [(Symbol, Type)] -> Frame
frame_extends frame xs = frame { types = M.union (M.fromList xs) (types frame) }

frame_lookup :: Symbol -> Frame -> Maybe Type
frame_lookup key (Frame frame) = M.lookup key frame

frame_merge :: Frame -> Frame -> Frame
frame_merge (Frame a) (Frame b) = Frame (M.union a b)

frames_merge :: [Frame] -> Frame
frames_merge = foldl' frame_merge frame_empty

frame_singleton :: Symbol -> Type -> Frame
frame_singleton x y = Frame (M.singleton x y)

frame_keys :: Frame -> [Symbol]
frame_keys (Frame frame) = M.keys frame

frames_fromList :: [(Symbol, Type)] -> Frame
frames_fromList xs = Frame (M.fromList xs)

frames_toList :: Frame -> [(Symbol, Type)]
frames_toList (Frame frame) = M.toList frame

instance Semigroup Frame where
  (<>) = frame_merge

instance Monoid Frame where
  mempty = frame_empty
  mappend = (<>)

-- * Inference monad

type Infer = ReaderT Frame (StateT InferState (Except String))
data InferState = InferState { count :: Int }

default_infer_state :: InferState
default_infer_state = InferState { count = 0 }

-- * Constraint solving monad

newtype Subst = Subst (Map TypeVar Rho)
  deriving ( Eq
           , Ord
           , Show
           , Semigroup
           , Monoid)

type Constraint = (Type, Type)
type Unifier = (Subst, [Constraint])

type Solve a = ExceptT String Identity a

class Substitutable a where
  substitute :: Subst -> a -> a
  ftv :: a -> Set TypeVar

instance Substitutable Type where
  substitute _ (TCon a) = TCon a
  substitute (Subst s) t@(TVar a) = M.findWithDefault t a s
  substitute s (t1 :-> t2) = (substitute s t1) :-> (substitute s t2)
  -- variables bound by the forall shadow the substitution, so drop them first
  substitute (Subst s) (TForall tys t) = TForall tys $ substitute s' t
    where s' = Subst $ foldr M.delete s tys

  ftv TCon{} = S.empty
  ftv (TVar a) = S.singleton a
  ftv (t1 :-> t2) = (ftv t1) `S.union` (ftv t2)
  ftv (TForall tys t) = (ftv t) `S.difference` (S.fromList tys)

instance Substitutable Constraint where
  substitute s (t1, t2) = (substitute s t1, substitute s t2)
  ftv (t1, t2) = (ftv t1) `S.union` (ftv t2)

instance Substitutable a => Substitutable [a] where
  substitute = map . substitute
  ftv = foldr (S.union . ftv) S.empty

instance Substitutable Frame where
  substitute s (Frame frame) = Frame $ M.map (substitute s) frame
  ftv (Frame frame) = ftv $ M.elems frame

-- * Inference

run_infer :: Frame -> Infer a -> Either String a
run_infer frame m = runExcept $ evalStateT (runReaderT m frame) default_infer_state

typecheck :: Frame -> CbpvExp -> Either String Type
typecheck frame ex = case run_infer frame (infer ex) of
  Left err -> Left err
  Right (ty, cs) -> case run_solve cs of
    Left err -> Left err
    Right subst -> Right (close_over $ substitute subst ty)

close_over :: Type -> Sigma
close_over = normalize . generalize frame_empty

normalize :: Sigma -> Sigma
normalize (TForall tys body) = TForall (map snd ord) (normtype body) where
  ord = zip (nub $ fv body) tys
  fv (TVar a) = [a]
  fv (a :-> b) = fv a ++ fv b
  fv (TCon _) = []
  normtype (a :-> b) = (normtype a) :-> (normtype b)
  normtype (TCon a) = TCon a
  normtype (TVar a) = case Prelude.lookup a ord of
    Just x -> TVar x
    Nothing -> error "type variable not in signature"
normalize t = t

in_frame :: (Symbol, Type) -> Infer a -> Infer a
in_frame (x, sc) m = do
  let scope e = (frame_remove e x) `frame_extend` (x, sc)
  local scope m

in_frames :: [(Symbol, Type)] -> Infer a -> Infer a
in_frames bindings m = do
  let scope e = (frame_removes e bindings) `frame_extends` bindings
      frame_removes (Frame frame) bindings =
        Frame $ frame `M.difference` (M.fromList bindings)
  local scope m

lookup_in_frame :: Symbol -> Infer (Rho,[Constraint])
lookup_in_frame s = do
  σ <- case prim_type s of
    Nothing -> do
      frame <- ask
      case frame_lookup s frame of
        Nothing -> throwError ("Unbound symbol: " ++ s)
        Just sig -> return sig
    Just sig -> return sig
  ρ <- instantiate σ
  return (ρ, [])

fresh :: Infer Tau
fresh = do
  s <- get
  put s { count = count s + 1 }
  return $ TVar (count s)

instantiate :: Sigma -> Infer Rho
instantiate (TForall tys t) = do
  tys' <- mapM (const fresh) tys
  let s = Subst $ M.fromList $ zip tys tys'
  return $ substitute s t
instantiate ty = return ty

generalize :: Frame -> Rho -> Sigma
generalize frame t = TForall tys t
  where tys = S.toList $ ftv t `S.difference` ftv frame

-- | Deep skolemization.
-- Returns the skolem variables and the skolemized ρ-type.
skolemize :: Sigma -> Infer ([TypeVar], Rho)
skolemize (TForall tys ρ1) = do -- rule PRPOLY
  sks1 <- mapM (const fresh) tys
  -- Substitute the new skolems for the bound variables before recursing.
  (sks2, ρ2) <- skolemize (substitute (Subst $ M.fromList $ zip tys sks1) ρ1)
  return ((S.toList (ftv sks1)) ++ sks2, ρ2)
skolemize (σ1 :-> σ2) = do -- rule PRFUN
  (sks, σ2') <- skolemize σ2
  return (sks, σ1 :-> σ2')
skolemize τ = return ([], τ) -- rule PRMONO
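{- Deep skolemization, following the PRPOLY/PRFUN/PRMONO rules named above,
replaces the quantified variables of a type with fresh rigid constants. It also
looks inside the result of a function arrow but leaves argument types alone.
This sketch models skolems as constants named @sk0@, @sk1@ and so on, and it
threads a counter by hand. @Ty@, @subst@ and @skolemize'@ are hypothetical toy
definitions, not this module's.

```haskell
import qualified Data.Map as M

-- Toy stand-in for the module's Type; an assumed shape, not the real one.
data Ty = TV Int | TCon String | Ty :-> Ty | TAll [Int] Ty
  deriving (Eq, Show)
infixr 5 :->

-- Apply a variable-to-type map, skipping variables rebound by an inner TAll.
subst :: M.Map Int Ty -> Ty -> Ty
subst s (TV n)      = M.findWithDefault (TV n) n s
subst _ t@(TCon _)  = t
subst s (a :-> b)   = subst s a :-> subst s b
subst s (TAll vs t) = TAll vs (subst (foldr M.delete s vs) t)

-- Deep skolemization with a hand-threaded counter. Skolems are modelled as
-- rigid constants "sk<n>". Returns (skolems, rho-type, next counter value).
skolemize' :: Int -> Ty -> ([String], Ty, Int)
skolemize' next (TAll vs rho) =                 -- PRPOLY
  let names       = ["sk" ++ show k | k <- take (length vs) [next ..]]
      s           = M.fromList (zip vs (map TCon names))
      (sks, r, n) = skolemize' (next + length vs) (subst s rho)
  in (names ++ sks, r, n)
skolemize' next (arg :-> res) =                 -- PRFUN: recurse into the result only
  let (sks, res', n) = skolemize' next res
  in (sks, arg :-> res', n)
skolemize' next tau = ([], tau, next)           -- PRMONO
```

For example, @Int -> forall _0. _0 -> _0@ skolemizes to @Int -> sk0 -> sk0@ with
the single skolem @sk0@.
-}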

{- | Type inference rules for 'CbpvExp'ressions.
All functions have at least one (terminal) argument of type ⊥.
-}
infer :: CbpvExp -> Infer (Rho, [Constraint])
infer (() :< cbpvexpr) = case cbpvexpr of
  IntA _ -> return (ty_int, [])
  FloatA _ -> return (ty_float, [])
  BoolA _ -> return (ty_bool, [])
  SymA sym -> case sym of
    "_" -> fresh >>= \ty -> return (ty, [])
    _   -> lookup_in_frame sym
  OpA op erands -> infer (() :< (AppA (() :< (SymA op)) erands))
  SuspendA exp -> infer exp
  IfA c t e -> do
    (ρs, cs) <- mapAndUnzipM infer [c, t, e]
    return (ρs !! 1, (concat cs) ++ [ (ρs !! 0, ty_bool), (ρs !! 1, ρs !! 2) ])
  ResumeA exp -> infer exp
  ResetA exp -> do
    infer exp
  ShiftA karg exp -> do
    frame <- ask
    tkarg <- fresh
    tkret <- fresh
    let tk = generalize frame (tkarg :-> tkret)
    (rho, c) <- in_frame (karg, tk) $ infer exp
    solve_inference c $ \sub -> do
      let TForall _ tkarg' = generalize (substitute sub frame) (substitute sub tkarg)
      return (tkarg', [])
  AppA op erands -> do
    rho <- fresh
    frame <- ask
    (opt, opc) <- infer op
    (erands_ts, erands_cs) <- mapAndUnzipM infer erands
    solve_inference (concat erands_cs) $ \sub -> do
      erands_ts' <- forM erands_ts $ \ty -> instantiate $
        generalize (substitute sub frame) (substitute sub ty)
      return (rho, opc ++ [ (opt, erands_ts' `ty_fun_sig` substitute sub rho)])
  FunA args body -> do
    tvs <- mapM (const fresh) args
    (t, c) <- in_frames (zip args tvs) (infer body)
    return (tvs `ty_fun_sig` t, c)
  LetrecA bindings body -> do
    let bindings' = reverse $ topo $ make_dep_graph $ M.fromList bindings
    btvs <- forM bindings' $ \(name, _) -> fresh >>= \tv -> return (name, tv)
    let pass fr (name, defn) = in_frames fr $ do
          (ρ, c) <- infer defn
          solve_inference c $ \sub -> do
            let σ = generalize
                    (substitute sub $ frames_fromList fr) (substitute sub ρ)
            return $ fr ++ [(name, σ)]
    fr <- foldM pass btvs bindings'
    in_frames fr $ infer body
  LetA var exp body -> do
    frame <- ask
    (ρ1, c1) <- infer exp
    solve_inference c1 $ \sub -> do
      let σ = generalize (substitute sub frame) (substitute sub ρ1)
      (ρ2, c2) <- in_frame (var, σ) $ local (substitute sub) (infer body)
      --throwError $ var ++ " = " ++ (show ρ2) ++ ", subbed fr = " ++ (show $ substitute sub frame)
      return (ρ2, c1 ++ c2)
  _ -> throwError ("Typechecking for that code Coming Soon™")

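{- The @IfA@ case above never unifies anything itself. It infers the three
sub-expressions and hands the solver two equality constraints: the condition is
@Bool@, and the two branches agree. Below is a stripped-down version of that
constraint generation. It assumes a toy expression language with only literals,
variables and @if@. @Expr@, @Ty@ and @gen@ are hypothetical names.

```haskell
-- Toy types and expressions; hypothetical, much smaller than CbpvExp.
data Ty = TInt | TBool | TV Int
  deriving (Eq, Show)

data Expr = IntE Integer | BoolE Bool | Var String | If Expr Expr Expr

type Constraint = (Ty, Ty)

-- Return the expression's type plus the equalities it requires, mirroring
-- the IfA rule: the result has the then-branch's type.
gen :: [(String, Ty)] -> Expr -> Either String (Ty, [Constraint])
gen _ (IntE _)  = Right (TInt, [])
gen _ (BoolE _) = Right (TBool, [])
gen env (Var x) =
  maybe (Left ("unbound variable: " ++ x)) (\t -> Right (t, [])) (lookup x env)
gen env (If c t e) = do
  (tc, cc) <- gen env c
  (tt, ct) <- gen env t
  (te, ce) <- gen env e
  Right (tt, cc ++ ct ++ ce ++ [(tc, TBool), (tt, te)])
```

Here @if x then 1 else 2@ with @x : _0@ yields @[(_0, Bool), (Int, Int)]@. An
ill-typed @if@ therefore shows up later as an unsatisfiable constraint, not as an
immediate error.
-}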
-- * Constraint solver

subst_empty :: Subst
subst_empty = mempty

-- * Compose substitutions
compose :: Subst -> Subst -> Subst
(Subst s1) `compose` (Subst s2) = Subst $ M.map (substitute (Subst s1)) s2 `M.union` s1
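{- @compose@ is ordered so that applying @s1 `compose` s2@ has the same effect as
applying @s2@ first and then @s1@. That is why @solver@ and @unify_many@ put the
newest substitution on the left. Below is a small check of that property. It
assumes a toy type language and plain @Data.Map@ maps in place of @Subst@.
@Ty@, @Sub@, @apply@ and @composeSub@ are hypothetical.

```haskell
import qualified Data.Map as M

-- Toy types; hypothetical, without the module's TForall.
data Ty = TV Int | TCon String | Ty :-> Ty
  deriving (Eq, Show)
infixr 5 :->

type Sub = M.Map Int Ty

apply :: Sub -> Ty -> Ty
apply s (TV n)     = M.findWithDefault (TV n) n s
apply _ t@(TCon _) = t
apply s (a :-> b)  = apply s a :-> apply s b

-- Same shape as compose: push s1 through the range of s2, then add s1's own
-- bindings. M.union is left-biased, so s2's (updated) bindings win on overlap.
composeSub :: Sub -> Sub -> Sub
composeSub s1 s2 = M.map (apply s1) s2 `M.union` s1
```

With @s2 = {_0 := _1}@ and @s1 = {_1 := Int}@, both routes map @_0 -> _1@ to
@Int -> Int@.
-}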

-- * Run the constraint solver

run_solve :: [Constraint] -> Either String Subst
run_solve cs = runIdentity $ runExceptT $ solver st
  where st = (subst_empty, cs)

unify_many :: [Type] -> [Type] -> Solve Subst
unify_many [] [] = return subst_empty
unify_many (t1 : ts1) (t2 : ts2) = do
  su1 <- unifies t1 t2
  su2 <- unify_many (substitute su1 ts1) (substitute su1 ts2)
  return (su2 `compose` su1)
unify_many t1 t2 = throwError $ "Unification mismatch: " ++ (show t1) ++ " vs " ++ (show t2)

unifies :: Type -> Type -> Solve Subst
unifies t1 t2 | t1 == t2 = return subst_empty
unifies (TVar v) t = v `bind` t
unifies t (TVar v) = v `bind` t
unifies (t1 :-> t2) (t3 :-> t4) = unify_many [t1, t2] [t3, t4]
unifies t1 t2 = throwError $ "Unification fail: " ++ (show t1) ++ " vs " ++ (show t2)

-- * Unification solver
solver :: Unifier -> Solve Subst
solver (su, cs) = case cs of
  [] -> return su
  ((t1, t2):cs0) -> do
    su1 <- unifies t1 t2
    solver (su1 `compose` su, substitute su1 cs0)

bind :: TypeVar -> Type -> Solve Subst
bind a t | t == TVar a = return subst_empty
         | occurs_check a t = throwError $ "Infinite type: _" ++ (show a) ++ " ~ " ++ (show t)
         | otherwise = return (Subst $ M.singleton a t)

occurs_check :: Substitutable a => TypeVar -> a -> Bool
occurs_check a t = a `S.member` ftv t
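{- Together, @unifies@, @unify_many@, @solver@, @bind@ and @occurs_check@ make up
a standard first-order unification solver over a list of equality constraints.
The sketch below reproduces that structure on a toy type language. It uses
@Either String@ instead of @ExceptT String Identity@ and a plain @Data.Map@ for
substitutions. Every name in it is hypothetical.

```haskell
import qualified Data.Map as M
import qualified Data.Set as S

-- Toy types; hypothetical, without the module's TForall.
data Ty = TV Int | TCon String | Ty :-> Ty
  deriving (Eq, Show)
infixr 5 :->

type Sub = M.Map Int Ty

apply :: Sub -> Ty -> Ty
apply s (TV n)     = M.findWithDefault (TV n) n s
apply _ t@(TCon _) = t
apply s (a :-> b)  = apply s a :-> apply s b

ftv :: Ty -> S.Set Int
ftv (TV n)    = S.singleton n
ftv (TCon _)  = S.empty
ftv (a :-> b) = ftv a `S.union` ftv b

compose :: Sub -> Sub -> Sub
compose s1 s2 = M.map (apply s1) s2 `M.union` s1

-- Like bind + occurs_check: refuse to bind a variable to a type containing it.
bindVar :: Int -> Ty -> Either String Sub
bindVar a t
  | a `S.member` ftv t = Left ("Infinite type: _" ++ show a ++ " ~ " ++ show t)
  | otherwise          = Right (M.singleton a t)

-- Like unifies / unify_many, specialised to binary arrows.
unify :: Ty -> Ty -> Either String Sub
unify t1 t2 | t1 == t2 = Right M.empty
unify (TV v) t = bindVar v t
unify t (TV v) = bindVar v t
unify (a :-> b) (c :-> d) = do
  s1 <- unify a c
  s2 <- unify (apply s1 b) (apply s1 d)
  Right (s2 `compose` s1)
unify t1 t2 = Left ("Unification fail: " ++ show t1 ++ " vs " ++ show t2)

-- Like solver: unify the first constraint, then apply the result to the rest.
solve :: Sub -> [(Ty, Ty)] -> Either String Sub
solve su [] = Right su
solve su ((t1, t2) : cs) = do
  su1 <- unify t1 t2
  solve (su1 `compose` su) [ (apply su1 a, apply su1 b) | (a, b) <- cs ]
```

Solving @[(_0, _1 -> _1), (_1, Int)]@ resolves @_0@ to @Int -> Int@. The
constraint @_0 ~ _0 -> Int@ fails the occurs check.
-}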

solve_inference :: [Constraint] -> (Subst -> Infer a) -> Infer a
solve_inference c ks = case run_solve c of
  Left err -> throwError err
  Right sub -> ks sub

-- | Lookup table for primitive operators.
prim_type :: Symbol -> Maybe Sigma
prim_type "add-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_int)
prim_type "mul-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_int)
prim_type "sub-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_int)
prim_type "div-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_int)
prim_type "add-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_float)
prim_type "mul-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_float)
prim_type "sub-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_float)
prim_type "div-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_float)
prim_type "eq-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_bool)
prim_type "lte-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_bool)
prim_type "gte-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_bool)
prim_type "lt-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_bool)
prim_type "gt-int" = Just $ TForall [0] $ ty_int :-> (ty_int :-> ty_bool)
prim_type "eq-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_bool)
prim_type "lte-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_bool)
prim_type "gte-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_bool)
prim_type "lt-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_bool)
prim_type "gt-float" = Just $ TForall [0] $ ty_float :-> (ty_float :-> ty_bool)
prim_type "=?" = Just $ TForall [0] $ TVar 0 :-> (TVar 0 :-> ty_bool)
prim_type "not" = Just $ TForall [0] $ ty_bool :-> ty_bool
prim_type "mod-int" = Just $ TForall [0] $ TVar 0 :-> (TVar 0 :-> TVar 0)
prim_type "mod-float" = Just $ TForall [0] $ TVar 0 :-> (TVar 0 :-> TVar 0)
prim_type "&&" = Just $ TForall [0] $ ty_bool :-> (ty_bool :-> ty_bool)
prim_type "||" = Just $ TForall [0] $ ty_bool :-> (ty_bool :-> ty_bool)
prim_type _ = Nothing

-- * Graph implementation

data Graph node key = Graph
  { _graph :: G.Graph
  , _vertices :: G.Vertex -> (node, key, [key])
  }

graph_fromList :: Ord key => [(node, key, [key])] -> Graph node key
graph_fromList = uncurry Graph . G.graphFromEdges'

-- | Map a container of vertices to their node labels.
vertex_labels :: Functor f => Graph b t -> f G.Vertex -> f b
vertex_labels g = fmap (vertex_label g)

-- | The node label stored at a vertex.
vertex_label :: Graph b t -> G.Vertex -> b
vertex_label g = (\(vi, _, _) -> vi) . _vertices g

-- | Node labels in topological order. 'G.topSort' places a vertex before the
-- vertices reachable from it (cycles aside), so each node precedes the nodes
-- its key list points to.
topo :: Graph node key -> [node]
topo g = vertex_labels g $ G.topSort (_graph g)

-- | Build the dependency graph of a set of definitions: each definition is
-- keyed by its symbol and has an edge to every symbol its body depends on,
-- as computed by 'deps'.
make_dep_graph
  :: Map Symbol CbpvExp
  -> Graph (Symbol, CbpvExp) Symbol
make_dep_graph defns = graph_fromList $ M.elems $ M.mapWithKey dep_list defns where
  dep_list sym expr = ((sym, expr), sym, extract (extend (deps defns) expr))