feat: add SynthInstance.lean

It is the first step to integrate the new type class procedure intro
`MetaM`. It implements the main entry point where we preprocess the
instance, check cache, invoke main function, process replacements, and
cache result.
This commit is contained in:
Leonardo de Moura 2019-12-01 18:30:02 -08:00
parent 5682c6e33b
commit 7e34cb4474
3 changed files with 234 additions and 1 deletions

View file

@ -14,3 +14,4 @@ import Init.Lean.Meta.DiscrTree
import Init.Lean.Meta.Reduce
import Init.Lean.Meta.Instances
import Init.Lean.Meta.AbstractMVars
import Init.Lean.Meta.SynthInstance

View file

@ -0,0 +1,192 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Daniel Selsam, Leonardo de Moura
Type class instance synthesizer using tabled resolution.
-/
import Init.Lean.Meta.Basic
import Init.Lean.Meta.Instances
import Init.Lean.Meta.LevelDefEq
import Init.Lean.Meta.AbstractMVars
namespace Lean
namespace Meta
namespace SynthInstance
structure Context :=
(config : Config := {})
(lctx : LocalContext := {})
(localInstances : LocalInstances := #[])
(globalInstances : DiscrTree Expr := {})
structure GeneratorNode :=
(mctx : MetavarContext)
(instances : Array Expr)
(currInstanceIdx : Nat)
structure ConsumerNode :=
(mctx : MetavarContext)
(subgoals : List Expr)
(answer : MVarId)
inductive Waiter
| consumerNode : ConsumerNode → Waiter
| root : Waiter
/-
We represent the tabled/cached entries using
1- An imperfect discrimination tree that stores the type class instances (i.e., types)
an unique index.
2- A persistent array which represents a map from unique indices to `TableEntry`.
-/
structure Key :=
(key : AbstractMVarsResult)
(idx : Nat)
structure TableEntry :=
(waiters : Array Waiter)
(answers : Array AbstractMVarsResult)
structure State :=
(env : Environment)
(cache : Cache)
(ngen : NameGenerator)
(traceState : TraceState)
(mainMVarId : MVarId)
(generatorStack : Array GeneratorNode := #[])
(resumeStack : Array (ConsumerNode × Expr) := #[])
(tableKeys : DiscrTree Key := {})
(tableEntries : PersistentArray TableEntry := {})
abbrev M := ReaderT Context (EStateM Exception State)
@[inline] private def getTraceState : M TraceState :=
do s ← get; pure s.traceState
@[inline] private def getOptions : M Options :=
do ctx ← read; pure ctx.config.opts
instance tracer : SimpleMonadTracerAdapter M :=
{ getOptions := getOptions,
getTraceState := getTraceState,
modifyTraceState := fun f => modify $ fun s => { traceState := f s.traceState, .. s } }
@[inline] def trace (cls : Name) (mctx : MetavarContext) (msg : Unit → MessageData) : M Unit :=
whenM (MonadTracerAdapter.isTracingEnabledFor cls) $ do
ctx ← read;
s ← get;
MonadTracerAdapter.addTrace cls (MessageData.context s.env mctx ctx.lctx (msg ()))
@[inline] def updateState (s : State) (newS : Meta.State) : State :=
{ env := newS.env, cache := newS.cache, ngen := newS.ngen, traceState := newS.traceState, .. s }
@[inline] def runMetaM {α} (x : MetaM α) (mctx : MetavarContext) : M (α × MetavarContext) :=
fun ctx s =>
let r := (x { config := ctx.config, lctx := ctx.lctx, localInstances := ctx.localInstances }).run {
env := s.env,
mctx := mctx,
cache := s.cache,
ngen := s.ngen,
traceState := s.traceState
};
match r with
| EStateM.Result.error ex newS => EStateM.Result.error ex (updateState s newS)
| EStateM.Result.ok a newS => EStateM.Result.ok (a, newS.mctx) (updateState s newS)
def main (type : Expr) : MetaM (Option Expr) :=
pure none -- TODO
end SynthInstance
structure Replacements :=
(levelReplacements : Array (Level × Level) := #[])
(exprReplacements : Array (Expr × Expr) := #[])
private def preprocess (type : Expr) : MetaM Expr :=
forallTelescopeReducing type $ fun xs type => do
type ← whnf type;
mkForall xs type
private def preprocessLevels (us : List Level) : MetaM (List Level × Array (Level × Level)) :=
do (us, r) ← us.foldlM
(fun (r : List Level × Array (Level × Level)) (u : Level) => do
u ← instantiateLevelMVars u;
if u.hasMVar then do
u' ← mkFreshLevelMVar;
pure (u'::r.1, r.2.push (u, u'))
else
pure (u::r.1, r.2))
([], #[]);
pure (us.reverse, r)
private partial def preprocessArgs (ys : Array Expr) : Nat → Array Expr → Array (Expr × Expr) → MetaM (Array Expr × Array (Expr × Expr))
| i, args, r => do
if h : i < ys.size then do
let y := ys.get ⟨i, h⟩;
yType ← inferType y;
if isOutParam yType then do
if h : i < args.size then do
let arg := args.get ⟨i, h⟩;
arg' ← mkFreshExprMVar yType;
preprocessArgs (i+1) (args.set ⟨i, h⟩ arg') (r.push (arg, arg'))
else
throw $ Exception.other "type class resolution failed, insufficient number of arguments" -- TODO improve error message
else
preprocessArgs (i+1) args r
else
pure (args, r)
private def preprocessOutParam (type : Expr) : MetaM (Expr × Replacements) :=
forallTelescope type $ fun xs typeBody =>
match typeBody.getAppFn with
| c@(Expr.const constName us _) => do
env ← getEnv;
if !hasOutParams env constName then pure (type, {})
else do
let args := typeBody.getAppArgs;
cType ← inferType c;
forallTelescopeReducing cType $ fun ys _ => do
(us, levelReplacements) ← preprocessLevels us;
(args, exprReplacements) ← preprocessArgs ys 0 args #[];
type ← mkForall xs (mkAppN (mkConst constName us) args);
pure (type, { levelReplacements := levelReplacements, exprReplacements := exprReplacements })
| _ => pure (type, {})
private def resolveReplacements (r : Replacements) : MetaM Bool :=
r.levelReplacements.allM (fun ⟨u, u'⟩ => isLevelDefEqAux u u')
<&&>
r.exprReplacements.allM (fun ⟨e, e'⟩ => isExprDefEqAux e e')
def synthInstance (type : Expr) : MetaM (Option Expr) :=
usingTransparency TransparencyMode.reducible $ do
type ← preprocess type;
s ← get;
match s.cache.synthInstance.find type with
| some result => pure result
| none => do
result ← withNewMCtxDepth $ do {
(normType, replacements) ← preprocessOutParam type;
trace `Meta.synthInstance $ fun _ => normType;
result? ← SynthInstance.main normType;
match result? with
| none => pure none
| some result => do
condM (resolveReplacements replacements)
(do result ← instantiateMVars result;
condM (hasAssignableMVar result)
(pure none)
(pure (some result)))
(pure none)
};
if type.hasMVar then do
modify $ fun s => { cache := { synthInstance := s.cache.synthInstance.insert type result, .. s.cache }, .. s };
pure result
else
pure result
end Meta
end Lean

View file

@ -289,4 +289,44 @@ do print "----- tst13 -----";
print e;
pure ()
#eval run [`Init.Core] tst13
def mkDecEq (type : Expr) : MetaM Expr :=
do u ← getLevel type;
pure $ mkApp (mkConst `DecidableEq [u]) type
def mkStateM (σ : Expr) : MetaM Expr :=
do u ← getLevel σ;
(some u) ← pure u.dec | throw $ Exception.other "failed to create StateM application";
let r := mkApp (mkConst `StateM [u]) σ;
check r;
pure r
def mkMonadState (σ m : Expr) : MetaM Expr :=
do u ← getLevel σ;
(some u) ← pure u.dec | throw $ Exception.other "failed to create MonadState application";
mType ← inferType m;
v ← mkFreshLevelMVar;
w ← mkFreshLevelMVar;
let arrow := mkArrow (mkSort (mkLevelSucc v)) (mkSort (mkLevelSucc w));
mType ← whnf mType;
print arrow;
print mType;
condM (isDefEq arrow mType)
(do w ← instantiateLevelMVars w;
let r := mkAppB (mkConst `MonadState [u, w]) σ m;
print r;
check r;
pure r)
(throw $ Exception.other "failed to create MonadState application")
def tst14 : MetaM Unit :=
do print "----- tst14 -----";
decEqNat ← mkDecEq nat;
c ← synthInstance decEqNat;
stateM ← mkStateM nat;
print stateM;
monadState ← mkMonadState nat stateM;
print monadState;
c ← synthInstance monadState;
pure ()
#eval run [`Init.Control.State] tst14