feat: add generate, newSubgoal, tryResolve, and simpler table

TODO: `resume`
This commit is contained in:
Leonardo de Moura 2019-12-02 19:00:43 -08:00
parent 3eabda1c4d
commit 833c587fa3
4 changed files with 249 additions and 33 deletions

View file

@ -14,15 +14,21 @@ structure AbstractMVarsResult :=
(numMVars : Nat)
(expr : Expr)
def AbstractMVarsResult.beq (r₁ r₂ : AbstractMVarsResult) : Bool :=
r₁.paramNames == r₂.paramNames && r₁.numMVars == r₂.numMVars && r₁.expr == r₂.expr
instance AbstractMVarsResult.hasBeq : HasBeq AbstractMVarsResult := ⟨AbstractMVarsResult.beq⟩
namespace AbstractMVars
structure State :=
(ngen : NameGenerator)
(lctx : LocalContext)
(paramNames : Array Name := #[])
(fvars : Array Expr := #[])
(lmap : HashMap Name Level := {})
(emap : HashMap Name Expr := {})
(ngen : NameGenerator)
(lctx : LocalContext)
(nextParamIdx : Nat := 0)
(paramNames : Array Name := #[])
(fvars : Array Expr := #[])
(lmap : HashMap Name Level := {})
(emap : HashMap Name Expr := {})
abbrev M := ReaderT MetavarContext (StateM State)
@ -55,9 +61,9 @@ private partial def abstractLevelMVars : Level → M Level
match s.lmap.find mvarId with
| some u => pure u
| none => do
paramId ← mkFreshId;
let paramId := mkNameNum `_abstMVar s.nextParamIdx;
let u := mkLevelParam paramId;
modify $ fun s => { lmap := s.lmap.insert mvarId u, paramNames := s.paramNames.push paramId, .. s };
modify $ fun s => { nextParamIdx := s.nextParamIdx + 1, lmap := s.lmap.insert mvarId u, paramNames := s.paramNames.push paramId, .. s };
pure u
partial def abstractExprMVars : Expr → M Expr

View file

@ -33,6 +33,7 @@ inductive Exception
| isExprDefEqStuck (t s : Expr) (ctx : ExceptionContext)
| letTypeMismatch (fvarId : FVarId) (ctx : ExceptionContext)
| appTypeMismatch (f a : Expr) (ctx : ExceptionContext)
| notInstance (e : Expr) (ctx : ExceptionContext)
| bug (b : Bug) (ctx : ExceptionContext)
| other (msg : String)
@ -56,6 +57,7 @@ def toStr : Exception → String
| isExprDefEqStuck _ _ _ => "isDefEq is stuck"
| letTypeMismatch _ _ => "type mismatch at let-expression"
| appTypeMismatch _ _ _ => "application type mismatch"
| notInstance _ _ => "type class instance expected"
| bug _ _ => "bug"
| other s => s
@ -80,6 +82,7 @@ def toMessageData : Exception → MessageData
| isExprDefEqStuck t s ctx => mkCtx ctx $ `isExprDefEqStuck ++ " " ++ t ++ " =?= " ++ s
| letTypeMismatch fvarId ctx => mkCtx ctx $ `letTypeMismatch ++ " " ++ mkFVar fvarId
| appTypeMismatch f a ctx => mkCtx ctx $ `appTypeMismatch ++ " " ++ mkApp f a
| notInstance i ctx => mkCtx ctx $ `notInstance ++ " " ++ i
| bug _ _ => "internal bug" -- TODO improve
| other s => s

View file

@ -42,9 +42,6 @@ match env.find constName with
(keys, env) ← IO.runMeta (mkInstanceKey c) env;
pure $ instanceExtension.addEntry env { keys := keys, val := c }
def getInstances (env : Environment) : Instances :=
instanceExtension.getState env
@[init] def registerInstanceAttr : IO Unit :=
registerAttribute {
name := `instance,
@ -57,4 +54,16 @@ registerAttribute {
}
end Meta
def Environment.getGlobalInstances (env : Environment) : Meta.Instances :=
Meta.instanceExtension.getState env
namespace Meta
def getGlobalInstances : MetaM Instances :=
do env ← getEnv;
pure env.getGlobalInstances
end Meta
end Lean

View file

@ -14,39 +14,93 @@ namespace Lean
namespace Meta
namespace SynthInstance
structure Context extends Meta.Context :=
(globalInstances : DiscrTree Expr := {})
structure GeneratorNode :=
(mvar : Expr)
(key : Expr)
(mctx : MetavarContext)
(instances : Array Expr)
(currInstanceIdx : Nat)
instance GeneratorNode.inhabited : Inhabited GeneratorNode := ⟨⟨arbitrary _, arbitrary _, arbitrary _, arbitrary _, 0⟩⟩
structure ConsumerNode :=
(mvar : Expr)
(key : Expr)
(mctx : MetavarContext)
(subgoals : List Expr)
(answer : MVarId)
inductive Waiter
| consumerNode : ConsumerNode → Waiter
| root : Waiter
/-
We represent the tabled/cached entries using
def Waiter.isRoot : Waiter → Bool
| Waiter.consumerNode _ => false
| Waiter.root => true
1- An imperfect discrimination tree that stores the type class instances (i.e., types)
an unique index.
namespace MkTableKey
2- A persistent array which represents a map from unique indices to `TableEntry`.
-/
structure State :=
(nextLevelIdx : Nat := 0)
(nextExprIdx : Nat := 0)
(lmap : HashMap MVarId Level := {})
(emap : HashMap MVarId Expr := {})
structure Key :=
(key : AbstractMVarsResult)
(idx : Nat)
abbrev M := ReaderT MetavarContext (StateM State)
partial def normLevel : Level → M Level
| u => if !u.hasMVar then pure u else
match u with
| Level.succ v _ => do v ← normLevel v; pure $ u.updateSucc! v
| Level.max v w _ => do v ← normLevel v; w ← normLevel w; pure $ u.updateMax! v w
| Level.imax v w _ => do v ← normLevel v; w ← normLevel w; pure $ u.updateIMax! v w
| Level.mvar mvarId _ => do
mctx ← read;
if !mctx.isLevelAssignable mvarId then pure u
else do
s ← get;
match s.lmap.find mvarId with
| some u' => pure u'
| none => do
let u' := mkLevelParam $ mkNameNum `_synthKey s.nextLevelIdx;
modify $ fun s => { nextLevelIdx := s.nextLevelIdx + 1, lmap := s.lmap.insert mvarId u', .. s };
pure u'
| u => pure u
partial def normExpr : Expr → M Expr
| e => if !e.hasMVar then pure e else
match e with
| Expr.const _ us _ => do us ← us.mapM normLevel; pure $ e.updateConst! us
| Expr.sort u _ => do u ← normLevel u; pure $ e.updateSort! u
| Expr.app f a _ => do f ← normExpr f; a ← normExpr a; pure $ e.updateApp! f a
| Expr.letE _ t v b _ => do t ← normExpr t; v ← normExpr v; b ← normExpr b; pure $ e.updateLet! t v b
| Expr.forallE _ d b _ => do d ← normExpr d; b ← normExpr b; pure $ e.updateForallE! d b
| Expr.lam _ d b _ => do d ← normExpr d; b ← normExpr b; pure $ e.updateLambdaE! d b
| Expr.mdata _ b _ => do b ← normExpr b; pure $ e.updateMData! b
| Expr.proj _ _ b _ => do b ← normExpr b; pure $ e.updateProj! b
| Expr.mvar mvarId _ => do
mctx ← read;
if !mctx.isExprAssignable mvarId then pure e
else do
s ← get;
match s.emap.find mvarId with
| some e' => pure e'
| none => do
let e' := mkFVar $ mkNameNum `_synthKey s.nextExprIdx;
modify $ fun s => { nextExprIdx := s.nextExprIdx + 1, emap := s.emap.insert mvarId e', .. s };
pure e'
| _ => pure e
end MkTableKey
def mkTableKey (mctx : MetavarContext) (e : Expr) : Expr :=
(MkTableKey.normExpr e mctx).run' {}
abbrev Answer := AbstractMVarsResult
structure TableEntry :=
(waiters : Array Waiter)
(answers : Array AbstractMVarsResult)
(answers : Array Answer := #[])
/-
Remark: the SynthInstance.State is not really an extension of `Meta.State`.
@ -57,13 +111,15 @@ structure TableEntry :=
-/
structure State extends Meta.State :=
(mainMVarId : MVarId)
(generatorStack : Array GeneratorNode := #[])
(resumeStack : Array (ConsumerNode × Expr) := #[])
(tableKeys : DiscrTree Key := {})
(tableEntries : PersistentArray TableEntry := {})
(nextKeyIdx : Nat := 0)
(generatorStack : Array GeneratorNode := #[])
(resumeStack : Array (ConsumerNode × Answer) := #[])
(tableEntries : PersistentHashMap Expr TableEntry := {})
abbrev SynthM := ReaderT Context (EStateM Exception State)
instance SynthM.inhabited {α} : Inhabited (SynthM α) := ⟨throw $ Exception.other ""⟩
@[inline] private def getTraceState : SynthM TraceState :=
do s ← get; pure s.traceState
@ -81,11 +137,154 @@ whenM (MonadTracerAdapter.isTracingEnabledFor cls) $ do
s ← get;
MonadTracerAdapter.addTrace cls (MessageData.context s.env mctx ctx.lctx (msg ()))
@[inline] def runMetaM {α} (x : MetaM α) : SynthM α :=
fun ctx => adaptState (fun (s : State) => (s.toState, s)) (fun s' s => { toState := s', .. s }) (x ctx.toContext)
@[inline] def liftMeta {α} (x : MetaM α) : SynthM α :=
adaptState (fun (s : State) => (s.toState, s)) (fun s' s => { toState := s', .. s }) x
instance meta2Synth {α} : HasCoe (MetaM α) (SynthM α) := ⟨liftMeta⟩
/-- Return globals and locals instances that may unify with `type` -/
def getInstances (type : Expr) : MetaM (Array Expr) :=
forallTelescopeReducing type $ fun _ type => do
className? ← isClass type;
match className? with
| none => throwEx $ Exception.notInstance type
| some className => do
globalInstances ← getGlobalInstances;
result ← globalInstances.getUnify type;
localInstances ← getLocalInstances;
let result := localInstances.foldl
(fun (result : Array Expr) linst => if linst.className == className then result.push linst.fvar else result)
result;
pure result
/-- Create a new generator node for `mvar` and add `waiter` as its waiter.
`key` must be `mkTableKey mctx mvarType`. -/
def newSubgoal (key : Expr) (mvar : Expr) (waiter : Waiter) : SynthM Unit :=
do mvarType ← inferType mvar;
instances ← getInstances mvarType;
if instances.isEmpty then pure ()
else do
mctx ← getMCtx;
let node : GeneratorNode := {
mvar := mvar,
key := key,
mctx := mctx,
instances := instances,
currInstanceIdx := instances.size
};
let entry : TableEntry := { waiters := #[waiter] };
modify $ fun s =>
{ generatorStack := s.generatorStack.push node,
tableEntries := s.tableEntries.insert key entry,
.. s }
def wakeUp (answer : Answer) : Waiter → SynthM Unit
| Waiter.root => modify $ fun s => s -- TODO
| Waiter.consumerNode cNode => modify $ fun s => { resumeStack := s.resumeStack.push (cNode, answer), .. s }
def findEntry (key : Expr) : SynthM (Option TableEntry) :=
do s ← get;
pure $ s.tableEntries.find key
def getEntry (key : Expr) : SynthM TableEntry :=
do entry? ← findEntry key;
match entry? with
| none => panic! "invalid key at synthInstance"
| some entry => pure entry
def newAnswer (key : Expr) (answer : Answer) : SynthM Unit :=
do entry ← getEntry key;
if entry.answers.contains answer then pure ()
else condM (pure (entry.waiters.any Waiter.isRoot) <||> hasAssignableMVar answer.expr) (pure ()) $
let newEntry := { answers := entry.answers.push answer, .. entry };
modify $ fun s => { tableEntries := s.tableEntries.insert key newEntry, .. s };
entry.waiters.forM (wakeUp answer)
def mkAnswer (cNode : ConsumerNode) : MetaM Answer :=
withMCtx cNode.mctx $ do
val ← instantiateMVars cNode.mvar;
abstractMVars val
def mkTableKeyFor (mctx : MetavarContext) (mvar : Expr) : MetaM Expr :=
withMCtx mctx $ do
mvarType ← inferType mvar;
pure $ mkTableKey mctx mvarType
def consume (cNode : ConsumerNode) : SynthM Unit :=
match cNode.subgoals with
| [] => do
answer ← mkAnswer cNode;
newAnswer cNode.key answer
| mvar::_ => do
let waiter := Waiter.consumerNode cNode;
let key := mkTableKey cNode.mctx mvar;
entry? ← findEntry key;
match entry? with
| none => newSubgoal key mvar waiter
| some entry => modify $ fun s =>
{ resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
tableEntries := s.tableEntries.insert key { waiters := entry.waiters.push waiter, .. entry },
.. s }
private partial def mkInstanceTelescopeAux
(xs : Array Expr) : Array Expr → Nat → List Expr → Expr → Expr → MetaM (List Expr × Expr × Expr)
| mvars, j, subgoals, instVal, Expr.forallE n d b c => do
let d := d.instantiateRevRange j mvars.size mvars;
type ← mkForall xs d;
mvar ← mkFreshExprMVar type;
let arg := mkAppN mvar xs;
let instVal := mkApp instVal arg;
let subgoals := if c.binderInfo.isInstImplicit then mvar::subgoals else subgoals;
let mvars := mvars.push mvar;
mkInstanceTelescopeAux mvars j subgoals instVal b
| mvars, j, subgoals, instVal, type => do
let type := type.instantiateRevRange j mvars.size mvars;
type ← whnf type;
if type.isForall then
mkInstanceTelescopeAux mvars mvars.size subgoals instVal type
else
pure (subgoals, instVal, type)
def mkInstanceTelescope (xs : Array Expr) (inst : Expr) : MetaM (List Expr × Expr × Expr) :=
do instType ← inferType inst;
mkInstanceTelescopeAux xs #[] 0 [] inst instType
def tryResolve (mctx : MetavarContext) (mvar : Expr) (inst : Expr) : MetaM (Option (MetavarContext × List Expr)) :=
withMCtx mctx $ do
mvarType ← inferType mvar;
forallTelescopeReducing mvarType $ fun xs mvarTypeBody => do
(subgoals, instVal, instTypeBody) ← mkInstanceTelescope xs inst;
condM (isDefEq mvarTypeBody instTypeBody)
(do instVal ← mkLambda xs instVal;
condM (isDefEq mvar instVal)
(do mctx ← getMCtx; pure (some (mctx, subgoals)))
(pure none))
(pure none)
def getTop : SynthM GeneratorNode :=
do s ← get;
pure s.generatorStack.back
@[inline] def modifyTop (f : GeneratorNode → GeneratorNode) : SynthM Unit :=
modify $ fun s => { generatorStack := s.generatorStack.modify (s.generatorStack.size - 1) f, .. s }
def generate : SynthM Unit :=
do gNode ← getTop;
if gNode.currInstanceIdx == 0 then
modify $ fun s => { generatorStack := s.generatorStack.pop, .. s }
else do
let idx := gNode.currInstanceIdx - 1;
modifyTop $ fun gNode => { currInstanceIdx := idx, .. gNode };
let inst := gNode.instances.get! idx;
result? ← tryResolve gNode.mctx gNode.mvar inst;
match result? with
| none => pure ()
| some (mctx, subgoals) => consume { key := gNode.key, mvar := gNode.mvar, subgoals := subgoals, mctx := mctx }
def main (type : Expr) : MetaM (Option Expr) :=
pure none -- TODO
do Meta.trace `Meta.synthInstance $ fun _ => type;
mvar ← mkFreshExprMVar type;
pure none -- TODO
end SynthInstance
@ -159,7 +358,6 @@ usingTransparency TransparencyMode.reducible $ do
| none => do
result ← withNewMCtxDepth $ do {
(normType, replacements) ← preprocessOutParam type;
trace `Meta.synthInstance $ fun _ => normType;
result? ← SynthInstance.main normType;
match result? with
| none => pure none