feat: add generate, newSubgoal, tryResolve, and simpler table
TODO: `resume`
This commit is contained in:
parent
3eabda1c4d
commit
833c587fa3
4 changed files with 249 additions and 33 deletions
|
|
@ -14,15 +14,21 @@ structure AbstractMVarsResult :=
|
|||
(numMVars : Nat)
|
||||
(expr : Expr)
|
||||
|
||||
def AbstractMVarsResult.beq (r₁ r₂ : AbstractMVarsResult) : Bool :=
|
||||
r₁.paramNames == r₂.paramNames && r₁.numMVars == r₂.numMVars && r₁.expr == r₂.expr
|
||||
|
||||
instance AbstractMVarsResult.hasBeq : HasBeq AbstractMVarsResult := ⟨AbstractMVarsResult.beq⟩
|
||||
|
||||
namespace AbstractMVars
|
||||
|
||||
structure State :=
|
||||
(ngen : NameGenerator)
|
||||
(lctx : LocalContext)
|
||||
(paramNames : Array Name := #[])
|
||||
(fvars : Array Expr := #[])
|
||||
(lmap : HashMap Name Level := {})
|
||||
(emap : HashMap Name Expr := {})
|
||||
(ngen : NameGenerator)
|
||||
(lctx : LocalContext)
|
||||
(nextParamIdx : Nat := 0)
|
||||
(paramNames : Array Name := #[])
|
||||
(fvars : Array Expr := #[])
|
||||
(lmap : HashMap Name Level := {})
|
||||
(emap : HashMap Name Expr := {})
|
||||
|
||||
abbrev M := ReaderT MetavarContext (StateM State)
|
||||
|
||||
|
|
@ -55,9 +61,9 @@ private partial def abstractLevelMVars : Level → M Level
|
|||
match s.lmap.find mvarId with
|
||||
| some u => pure u
|
||||
| none => do
|
||||
paramId ← mkFreshId;
|
||||
let paramId := mkNameNum `_abstMVar s.nextParamIdx;
|
||||
let u := mkLevelParam paramId;
|
||||
modify $ fun s => { lmap := s.lmap.insert mvarId u, paramNames := s.paramNames.push paramId, .. s };
|
||||
modify $ fun s => { nextParamIdx := s.nextParamIdx + 1, lmap := s.lmap.insert mvarId u, paramNames := s.paramNames.push paramId, .. s };
|
||||
pure u
|
||||
|
||||
partial def abstractExprMVars : Expr → M Expr
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ inductive Exception
|
|||
| isExprDefEqStuck (t s : Expr) (ctx : ExceptionContext)
|
||||
| letTypeMismatch (fvarId : FVarId) (ctx : ExceptionContext)
|
||||
| appTypeMismatch (f a : Expr) (ctx : ExceptionContext)
|
||||
| notInstance (e : Expr) (ctx : ExceptionContext)
|
||||
| bug (b : Bug) (ctx : ExceptionContext)
|
||||
| other (msg : String)
|
||||
|
||||
|
|
@ -56,6 +57,7 @@ def toStr : Exception → String
|
|||
| isExprDefEqStuck _ _ _ => "isDefEq is stuck"
|
||||
| letTypeMismatch _ _ => "type mismatch at let-expression"
|
||||
| appTypeMismatch _ _ _ => "application type mismatch"
|
||||
| notInstance _ _ => "type class instance expected"
|
||||
| bug _ _ => "bug"
|
||||
| other s => s
|
||||
|
||||
|
|
@ -80,6 +82,7 @@ def toMessageData : Exception → MessageData
|
|||
| isExprDefEqStuck t s ctx => mkCtx ctx $ `isExprDefEqStuck ++ " " ++ t ++ " =?= " ++ s
|
||||
| letTypeMismatch fvarId ctx => mkCtx ctx $ `letTypeMismatch ++ " " ++ mkFVar fvarId
|
||||
| appTypeMismatch f a ctx => mkCtx ctx $ `appTypeMismatch ++ " " ++ mkApp f a
|
||||
| notInstance i ctx => mkCtx ctx $ `notInstance ++ " " ++ i
|
||||
| bug _ _ => "internal bug" -- TODO improve
|
||||
| other s => s
|
||||
|
||||
|
|
|
|||
|
|
@ -42,9 +42,6 @@ match env.find constName with
|
|||
(keys, env) ← IO.runMeta (mkInstanceKey c) env;
|
||||
pure $ instanceExtension.addEntry env { keys := keys, val := c }
|
||||
|
||||
def getInstances (env : Environment) : Instances :=
|
||||
instanceExtension.getState env
|
||||
|
||||
@[init] def registerInstanceAttr : IO Unit :=
|
||||
registerAttribute {
|
||||
name := `instance,
|
||||
|
|
@ -57,4 +54,16 @@ registerAttribute {
|
|||
}
|
||||
|
||||
end Meta
|
||||
|
||||
def Environment.getGlobalInstances (env : Environment) : Meta.Instances :=
|
||||
Meta.instanceExtension.getState env
|
||||
|
||||
namespace Meta
|
||||
|
||||
def getGlobalInstances : MetaM Instances :=
|
||||
do env ← getEnv;
|
||||
pure env.getGlobalInstances
|
||||
|
||||
end Meta
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -14,39 +14,93 @@ namespace Lean
|
|||
namespace Meta
|
||||
namespace SynthInstance
|
||||
|
||||
structure Context extends Meta.Context :=
|
||||
(globalInstances : DiscrTree Expr := {})
|
||||
|
||||
structure GeneratorNode :=
|
||||
(mvar : Expr)
|
||||
(key : Expr)
|
||||
(mctx : MetavarContext)
|
||||
(instances : Array Expr)
|
||||
(currInstanceIdx : Nat)
|
||||
|
||||
instance GeneratorNode.inhabited : Inhabited GeneratorNode := ⟨⟨arbitrary _, arbitrary _, arbitrary _, arbitrary _, 0⟩⟩
|
||||
|
||||
structure ConsumerNode :=
|
||||
(mvar : Expr)
|
||||
(key : Expr)
|
||||
(mctx : MetavarContext)
|
||||
(subgoals : List Expr)
|
||||
(answer : MVarId)
|
||||
|
||||
inductive Waiter
|
||||
| consumerNode : ConsumerNode → Waiter
|
||||
| root : Waiter
|
||||
|
||||
/-
|
||||
We represent the tabled/cached entries using
|
||||
def Waiter.isRoot : Waiter → Bool
|
||||
| Waiter.consumerNode _ => false
|
||||
| Waiter.root => true
|
||||
|
||||
1- An imperfect discrimination tree that stores the type class instances (i.e., types)
|
||||
an unique index.
|
||||
namespace MkTableKey
|
||||
|
||||
2- A persistent array which represents a map from unique indices to `TableEntry`.
|
||||
-/
|
||||
structure State :=
|
||||
(nextLevelIdx : Nat := 0)
|
||||
(nextExprIdx : Nat := 0)
|
||||
(lmap : HashMap MVarId Level := {})
|
||||
(emap : HashMap MVarId Expr := {})
|
||||
|
||||
structure Key :=
|
||||
(key : AbstractMVarsResult)
|
||||
(idx : Nat)
|
||||
abbrev M := ReaderT MetavarContext (StateM State)
|
||||
|
||||
partial def normLevel : Level → M Level
|
||||
| u => if !u.hasMVar then pure u else
|
||||
match u with
|
||||
| Level.succ v _ => do v ← normLevel v; pure $ u.updateSucc! v
|
||||
| Level.max v w _ => do v ← normLevel v; w ← normLevel w; pure $ u.updateMax! v w
|
||||
| Level.imax v w _ => do v ← normLevel v; w ← normLevel w; pure $ u.updateIMax! v w
|
||||
| Level.mvar mvarId _ => do
|
||||
|
||||
mctx ← read;
|
||||
if !mctx.isLevelAssignable mvarId then pure u
|
||||
else do
|
||||
s ← get;
|
||||
match s.lmap.find mvarId with
|
||||
| some u' => pure u'
|
||||
| none => do
|
||||
let u' := mkLevelParam $ mkNameNum `_synthKey s.nextLevelIdx;
|
||||
modify $ fun s => { nextLevelIdx := s.nextLevelIdx + 1, lmap := s.lmap.insert mvarId u', .. s };
|
||||
pure u'
|
||||
| u => pure u
|
||||
|
||||
partial def normExpr : Expr → M Expr
|
||||
| e => if !e.hasMVar then pure e else
|
||||
match e with
|
||||
| Expr.const _ us _ => do us ← us.mapM normLevel; pure $ e.updateConst! us
|
||||
| Expr.sort u _ => do u ← normLevel u; pure $ e.updateSort! u
|
||||
| Expr.app f a _ => do f ← normExpr f; a ← normExpr a; pure $ e.updateApp! f a
|
||||
| Expr.letE _ t v b _ => do t ← normExpr t; v ← normExpr v; b ← normExpr b; pure $ e.updateLet! t v b
|
||||
| Expr.forallE _ d b _ => do d ← normExpr d; b ← normExpr b; pure $ e.updateForallE! d b
|
||||
| Expr.lam _ d b _ => do d ← normExpr d; b ← normExpr b; pure $ e.updateLambdaE! d b
|
||||
| Expr.mdata _ b _ => do b ← normExpr b; pure $ e.updateMData! b
|
||||
| Expr.proj _ _ b _ => do b ← normExpr b; pure $ e.updateProj! b
|
||||
| Expr.mvar mvarId _ => do
|
||||
mctx ← read;
|
||||
if !mctx.isExprAssignable mvarId then pure e
|
||||
else do
|
||||
s ← get;
|
||||
match s.emap.find mvarId with
|
||||
| some e' => pure e'
|
||||
| none => do
|
||||
let e' := mkFVar $ mkNameNum `_synthKey s.nextExprIdx;
|
||||
modify $ fun s => { nextExprIdx := s.nextExprIdx + 1, emap := s.emap.insert mvarId e', .. s };
|
||||
pure e'
|
||||
| _ => pure e
|
||||
|
||||
end MkTableKey
|
||||
|
||||
def mkTableKey (mctx : MetavarContext) (e : Expr) : Expr :=
|
||||
(MkTableKey.normExpr e mctx).run' {}
|
||||
|
||||
abbrev Answer := AbstractMVarsResult
|
||||
|
||||
structure TableEntry :=
|
||||
(waiters : Array Waiter)
|
||||
(answers : Array AbstractMVarsResult)
|
||||
(answers : Array Answer := #[])
|
||||
|
||||
/-
|
||||
Remark: the SynthInstance.State is not really an extension of `Meta.State`.
|
||||
|
|
@ -57,13 +111,15 @@ structure TableEntry :=
|
|||
-/
|
||||
structure State extends Meta.State :=
|
||||
(mainMVarId : MVarId)
|
||||
(generatorStack : Array GeneratorNode := #[])
|
||||
(resumeStack : Array (ConsumerNode × Expr) := #[])
|
||||
(tableKeys : DiscrTree Key := {})
|
||||
(tableEntries : PersistentArray TableEntry := {})
|
||||
(nextKeyIdx : Nat := 0)
|
||||
(generatorStack : Array GeneratorNode := #[])
|
||||
(resumeStack : Array (ConsumerNode × Answer) := #[])
|
||||
(tableEntries : PersistentHashMap Expr TableEntry := {})
|
||||
|
||||
abbrev SynthM := ReaderT Context (EStateM Exception State)
|
||||
|
||||
instance SynthM.inhabited {α} : Inhabited (SynthM α) := ⟨throw $ Exception.other ""⟩
|
||||
|
||||
@[inline] private def getTraceState : SynthM TraceState :=
|
||||
do s ← get; pure s.traceState
|
||||
|
||||
|
|
@ -81,11 +137,154 @@ whenM (MonadTracerAdapter.isTracingEnabledFor cls) $ do
|
|||
s ← get;
|
||||
MonadTracerAdapter.addTrace cls (MessageData.context s.env mctx ctx.lctx (msg ()))
|
||||
|
||||
@[inline] def runMetaM {α} (x : MetaM α) : SynthM α :=
|
||||
fun ctx => adaptState (fun (s : State) => (s.toState, s)) (fun s' s => { toState := s', .. s }) (x ctx.toContext)
|
||||
@[inline] def liftMeta {α} (x : MetaM α) : SynthM α :=
|
||||
adaptState (fun (s : State) => (s.toState, s)) (fun s' s => { toState := s', .. s }) x
|
||||
|
||||
instance meta2Synth {α} : HasCoe (MetaM α) (SynthM α) := ⟨liftMeta⟩
|
||||
|
||||
/-- Return globals and locals instances that may unify with `type` -/
|
||||
def getInstances (type : Expr) : MetaM (Array Expr) :=
|
||||
forallTelescopeReducing type $ fun _ type => do
|
||||
className? ← isClass type;
|
||||
match className? with
|
||||
| none => throwEx $ Exception.notInstance type
|
||||
| some className => do
|
||||
globalInstances ← getGlobalInstances;
|
||||
result ← globalInstances.getUnify type;
|
||||
localInstances ← getLocalInstances;
|
||||
let result := localInstances.foldl
|
||||
(fun (result : Array Expr) linst => if linst.className == className then result.push linst.fvar else result)
|
||||
result;
|
||||
pure result
|
||||
|
||||
/-- Create a new generator node for `mvar` and add `waiter` as its waiter.
|
||||
`key` must be `mkTableKey mctx mvarType`. -/
|
||||
def newSubgoal (key : Expr) (mvar : Expr) (waiter : Waiter) : SynthM Unit :=
|
||||
do mvarType ← inferType mvar;
|
||||
instances ← getInstances mvarType;
|
||||
if instances.isEmpty then pure ()
|
||||
else do
|
||||
mctx ← getMCtx;
|
||||
let node : GeneratorNode := {
|
||||
mvar := mvar,
|
||||
key := key,
|
||||
mctx := mctx,
|
||||
instances := instances,
|
||||
currInstanceIdx := instances.size
|
||||
};
|
||||
let entry : TableEntry := { waiters := #[waiter] };
|
||||
modify $ fun s =>
|
||||
{ generatorStack := s.generatorStack.push node,
|
||||
tableEntries := s.tableEntries.insert key entry,
|
||||
.. s }
|
||||
|
||||
def wakeUp (answer : Answer) : Waiter → SynthM Unit
|
||||
| Waiter.root => modify $ fun s => s -- TODO
|
||||
| Waiter.consumerNode cNode => modify $ fun s => { resumeStack := s.resumeStack.push (cNode, answer), .. s }
|
||||
|
||||
def findEntry (key : Expr) : SynthM (Option TableEntry) :=
|
||||
do s ← get;
|
||||
pure $ s.tableEntries.find key
|
||||
|
||||
def getEntry (key : Expr) : SynthM TableEntry :=
|
||||
do entry? ← findEntry key;
|
||||
match entry? with
|
||||
| none => panic! "invalid key at synthInstance"
|
||||
| some entry => pure entry
|
||||
|
||||
def newAnswer (key : Expr) (answer : Answer) : SynthM Unit :=
|
||||
do entry ← getEntry key;
|
||||
if entry.answers.contains answer then pure ()
|
||||
else condM (pure (entry.waiters.any Waiter.isRoot) <||> hasAssignableMVar answer.expr) (pure ()) $
|
||||
let newEntry := { answers := entry.answers.push answer, .. entry };
|
||||
modify $ fun s => { tableEntries := s.tableEntries.insert key newEntry, .. s };
|
||||
entry.waiters.forM (wakeUp answer)
|
||||
|
||||
def mkAnswer (cNode : ConsumerNode) : MetaM Answer :=
|
||||
withMCtx cNode.mctx $ do
|
||||
val ← instantiateMVars cNode.mvar;
|
||||
abstractMVars val
|
||||
|
||||
def mkTableKeyFor (mctx : MetavarContext) (mvar : Expr) : MetaM Expr :=
|
||||
withMCtx mctx $ do
|
||||
mvarType ← inferType mvar;
|
||||
pure $ mkTableKey mctx mvarType
|
||||
|
||||
def consume (cNode : ConsumerNode) : SynthM Unit :=
|
||||
match cNode.subgoals with
|
||||
| [] => do
|
||||
answer ← mkAnswer cNode;
|
||||
newAnswer cNode.key answer
|
||||
| mvar::_ => do
|
||||
let waiter := Waiter.consumerNode cNode;
|
||||
let key := mkTableKey cNode.mctx mvar;
|
||||
entry? ← findEntry key;
|
||||
match entry? with
|
||||
| none => newSubgoal key mvar waiter
|
||||
| some entry => modify $ fun s =>
|
||||
{ resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
|
||||
tableEntries := s.tableEntries.insert key { waiters := entry.waiters.push waiter, .. entry },
|
||||
.. s }
|
||||
|
||||
private partial def mkInstanceTelescopeAux
|
||||
(xs : Array Expr) : Array Expr → Nat → List Expr → Expr → Expr → MetaM (List Expr × Expr × Expr)
|
||||
| mvars, j, subgoals, instVal, Expr.forallE n d b c => do
|
||||
let d := d.instantiateRevRange j mvars.size mvars;
|
||||
type ← mkForall xs d;
|
||||
mvar ← mkFreshExprMVar type;
|
||||
let arg := mkAppN mvar xs;
|
||||
let instVal := mkApp instVal arg;
|
||||
let subgoals := if c.binderInfo.isInstImplicit then mvar::subgoals else subgoals;
|
||||
let mvars := mvars.push mvar;
|
||||
mkInstanceTelescopeAux mvars j subgoals instVal b
|
||||
| mvars, j, subgoals, instVal, type => do
|
||||
let type := type.instantiateRevRange j mvars.size mvars;
|
||||
type ← whnf type;
|
||||
if type.isForall then
|
||||
mkInstanceTelescopeAux mvars mvars.size subgoals instVal type
|
||||
else
|
||||
pure (subgoals, instVal, type)
|
||||
|
||||
def mkInstanceTelescope (xs : Array Expr) (inst : Expr) : MetaM (List Expr × Expr × Expr) :=
|
||||
do instType ← inferType inst;
|
||||
mkInstanceTelescopeAux xs #[] 0 [] inst instType
|
||||
|
||||
def tryResolve (mctx : MetavarContext) (mvar : Expr) (inst : Expr) : MetaM (Option (MetavarContext × List Expr)) :=
|
||||
withMCtx mctx $ do
|
||||
mvarType ← inferType mvar;
|
||||
forallTelescopeReducing mvarType $ fun xs mvarTypeBody => do
|
||||
(subgoals, instVal, instTypeBody) ← mkInstanceTelescope xs inst;
|
||||
condM (isDefEq mvarTypeBody instTypeBody)
|
||||
(do instVal ← mkLambda xs instVal;
|
||||
condM (isDefEq mvar instVal)
|
||||
(do mctx ← getMCtx; pure (some (mctx, subgoals)))
|
||||
(pure none))
|
||||
(pure none)
|
||||
|
||||
def getTop : SynthM GeneratorNode :=
|
||||
do s ← get;
|
||||
pure s.generatorStack.back
|
||||
|
||||
@[inline] def modifyTop (f : GeneratorNode → GeneratorNode) : SynthM Unit :=
|
||||
modify $ fun s => { generatorStack := s.generatorStack.modify (s.generatorStack.size - 1) f, .. s }
|
||||
|
||||
def generate : SynthM Unit :=
|
||||
do gNode ← getTop;
|
||||
if gNode.currInstanceIdx == 0 then
|
||||
modify $ fun s => { generatorStack := s.generatorStack.pop, .. s }
|
||||
else do
|
||||
let idx := gNode.currInstanceIdx - 1;
|
||||
modifyTop $ fun gNode => { currInstanceIdx := idx, .. gNode };
|
||||
let inst := gNode.instances.get! idx;
|
||||
result? ← tryResolve gNode.mctx gNode.mvar inst;
|
||||
match result? with
|
||||
| none => pure ()
|
||||
| some (mctx, subgoals) => consume { key := gNode.key, mvar := gNode.mvar, subgoals := subgoals, mctx := mctx }
|
||||
|
||||
def main (type : Expr) : MetaM (Option Expr) :=
|
||||
pure none -- TODO
|
||||
do Meta.trace `Meta.synthInstance $ fun _ => type;
|
||||
mvar ← mkFreshExprMVar type;
|
||||
pure none -- TODO
|
||||
|
||||
end SynthInstance
|
||||
|
||||
|
|
@ -159,7 +358,6 @@ usingTransparency TransparencyMode.reducible $ do
|
|||
| none => do
|
||||
result ← withNewMCtxDepth $ do {
|
||||
(normType, replacements) ← preprocessOutParam type;
|
||||
trace `Meta.synthInstance $ fun _ => normType;
|
||||
result? ← SynthInstance.main normType;
|
||||
match result? with
|
||||
| none => pure none
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue