From 833c587fa30cd9dfcc0faa090b2cac5e8a16495f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 2 Dec 2019 19:00:43 -0800 Subject: [PATCH] feat: add `generate`, `newSubgoal`, `tryResolve`, and simpler table TODO: `resume` --- src/Init/Lean/Meta/AbstractMVars.lean | 22 ++- src/Init/Lean/Meta/Exception.lean | 3 + src/Init/Lean/Meta/Instances.lean | 15 +- src/Init/Lean/Meta/SynthInstance.lean | 242 +++++++++++++++++++++++--- 4 files changed, 249 insertions(+), 33 deletions(-) diff --git a/src/Init/Lean/Meta/AbstractMVars.lean b/src/Init/Lean/Meta/AbstractMVars.lean index ca6ac27e36..4f1131d026 100644 --- a/src/Init/Lean/Meta/AbstractMVars.lean +++ b/src/Init/Lean/Meta/AbstractMVars.lean @@ -14,15 +14,21 @@ structure AbstractMVarsResult := (numMVars : Nat) (expr : Expr) +def AbstractMVarsResult.beq (r₁ r₂ : AbstractMVarsResult) : Bool := +r₁.paramNames == r₂.paramNames && r₁.numMVars == r₂.numMVars && r₁.expr == r₂.expr + +instance AbstractMVarsResult.hasBeq : HasBeq AbstractMVarsResult := ⟨AbstractMVarsResult.beq⟩ + namespace AbstractMVars structure State := -(ngen : NameGenerator) -(lctx : LocalContext) -(paramNames : Array Name := #[]) -(fvars : Array Expr := #[]) -(lmap : HashMap Name Level := {}) -(emap : HashMap Name Expr := {}) +(ngen : NameGenerator) +(lctx : LocalContext) +(nextParamIdx : Nat := 0) +(paramNames : Array Name := #[]) +(fvars : Array Expr := #[]) +(lmap : HashMap Name Level := {}) +(emap : HashMap Name Expr := {}) abbrev M := ReaderT MetavarContext (StateM State) @@ -55,9 +61,9 @@ private partial def abstractLevelMVars : Level → M Level match s.lmap.find mvarId with | some u => pure u | none => do - paramId ← mkFreshId; + let paramId := mkNameNum `_abstMVar s.nextParamIdx; let u := mkLevelParam paramId; - modify $ fun s => { lmap := s.lmap.insert mvarId u, paramNames := s.paramNames.push paramId, .. s }; + modify $ fun s => { nextParamIdx := s.nextParamIdx + 1, lmap := s.lmap.insert mvarId u, paramNames := s.paramNames.push paramId, .. s }; pure u partial def abstractExprMVars : Expr → M Expr diff --git a/src/Init/Lean/Meta/Exception.lean b/src/Init/Lean/Meta/Exception.lean index 9ea0d94a5b..2b7a92d3f0 100644 --- a/src/Init/Lean/Meta/Exception.lean +++ b/src/Init/Lean/Meta/Exception.lean @@ -33,6 +33,7 @@ inductive Exception | isExprDefEqStuck (t s : Expr) (ctx : ExceptionContext) | letTypeMismatch (fvarId : FVarId) (ctx : ExceptionContext) | appTypeMismatch (f a : Expr) (ctx : ExceptionContext) +| notInstance (e : Expr) (ctx : ExceptionContext) | bug (b : Bug) (ctx : ExceptionContext) | other (msg : String) @@ -56,6 +57,7 @@ def toStr : Exception → String | isExprDefEqStuck _ _ _ => "isDefEq is stuck" | letTypeMismatch _ _ => "type mismatch at let-expression" | appTypeMismatch _ _ _ => "application type mismatch" +| notInstance _ _ => "type class instance expected" | bug _ _ => "bug" | other s => s @@ -80,6 +82,7 @@ def toMessageData : Exception → MessageData | isExprDefEqStuck t s ctx => mkCtx ctx $ `isExprDefEqStuck ++ " " ++ t ++ " =?= " ++ s | letTypeMismatch fvarId ctx => mkCtx ctx $ `letTypeMismatch ++ " " ++ mkFVar fvarId | appTypeMismatch f a ctx => mkCtx ctx $ `appTypeMismatch ++ " " ++ mkApp f a +| notInstance i ctx => mkCtx ctx $ `notInstance ++ " " ++ i | bug _ _ => "internal bug" -- TODO improve | other s => s diff --git a/src/Init/Lean/Meta/Instances.lean b/src/Init/Lean/Meta/Instances.lean index bc2fa1befb..84b90878c6 100644 --- a/src/Init/Lean/Meta/Instances.lean +++ b/src/Init/Lean/Meta/Instances.lean @@ -42,9 +42,6 @@ match env.find constName with (keys, env) ← IO.runMeta (mkInstanceKey c) env; pure $ instanceExtension.addEntry env { keys := keys, val := c } -def getInstances (env : Environment) : Instances := -instanceExtension.getState env - @[init] def registerInstanceAttr : IO Unit := registerAttribute { name := `instance, @@ -57,4 +54,16 @@ registerAttribute { } end Meta + +def Environment.getGlobalInstances (env : Environment) : Meta.Instances := +Meta.instanceExtension.getState env + +namespace Meta + +def getGlobalInstances : MetaM Instances := +do env ← getEnv; + pure env.getGlobalInstances + +end Meta + end Lean diff --git a/src/Init/Lean/Meta/SynthInstance.lean b/src/Init/Lean/Meta/SynthInstance.lean index 09bd907077..70d65617f1 100644 --- a/src/Init/Lean/Meta/SynthInstance.lean +++ b/src/Init/Lean/Meta/SynthInstance.lean @@ -14,39 +14,93 @@ namespace Lean namespace Meta namespace SynthInstance -structure Context extends Meta.Context := -(globalInstances : DiscrTree Expr := {}) - structure GeneratorNode := +(mvar : Expr) +(key : Expr) (mctx : MetavarContext) (instances : Array Expr) (currInstanceIdx : Nat) +instance GeneratorNode.inhabited : Inhabited GeneratorNode := ⟨⟨arbitrary _, arbitrary _, arbitrary _, arbitrary _, 0⟩⟩ + structure ConsumerNode := +(mvar : Expr) +(key : Expr) (mctx : MetavarContext) (subgoals : List Expr) -(answer : MVarId) inductive Waiter | consumerNode : ConsumerNode → Waiter | root : Waiter -/- -We represent the tabled/cached entries using +def Waiter.isRoot : Waiter → Bool +| Waiter.consumerNode _ => false +| Waiter.root => true -1- An imperfect discrimination tree that stores the type class instances (i.e., types) - an unique index. +namespace MkTableKey -2- A persistent array which represents a map from unique indices to `TableEntry`. --/ +structure State := +(nextLevelIdx : Nat := 0) +(nextExprIdx : Nat := 0) +(lmap : HashMap MVarId Level := {}) +(emap : HashMap MVarId Expr := {}) -structure Key := -(key : AbstractMVarsResult) -(idx : Nat) +abbrev M := ReaderT MetavarContext (StateM State) + +partial def normLevel : Level → M Level +| u => if !u.hasMVar then pure u else + match u with + | Level.succ v _ => do v ← normLevel v; pure $ u.updateSucc! v + | Level.max v w _ => do v ← normLevel v; w ← normLevel w; pure $ u.updateMax! v w + | Level.imax v w _ => do v ← normLevel v; w ← normLevel w; pure $ u.updateIMax! v w + | Level.mvar mvarId _ => do + + mctx ← read; + if !mctx.isLevelAssignable mvarId then pure u + else do + s ← get; + match s.lmap.find mvarId with + | some u' => pure u' + | none => do + let u' := mkLevelParam $ mkNameNum `_synthKey s.nextLevelIdx; + modify $ fun s => { nextLevelIdx := s.nextLevelIdx + 1, lmap := s.lmap.insert mvarId u', .. s }; + pure u' + | u => pure u + +partial def normExpr : Expr → M Expr +| e => if !e.hasMVar then pure e else + match e with + | Expr.const _ us _ => do us ← us.mapM normLevel; pure $ e.updateConst! us + | Expr.sort u _ => do u ← normLevel u; pure $ e.updateSort! u + | Expr.app f a _ => do f ← normExpr f; a ← normExpr a; pure $ e.updateApp! f a + | Expr.letE _ t v b _ => do t ← normExpr t; v ← normExpr v; b ← normExpr b; pure $ e.updateLet! t v b + | Expr.forallE _ d b _ => do d ← normExpr d; b ← normExpr b; pure $ e.updateForallE! d b + | Expr.lam _ d b _ => do d ← normExpr d; b ← normExpr b; pure $ e.updateLambdaE! d b + | Expr.mdata _ b _ => do b ← normExpr b; pure $ e.updateMData! b + | Expr.proj _ _ b _ => do b ← normExpr b; pure $ e.updateProj! b + | Expr.mvar mvarId _ => do + mctx ← read; + if !mctx.isExprAssignable mvarId then pure e + else do + s ← get; + match s.emap.find mvarId with + | some e' => pure e' + | none => do + let e' := mkFVar $ mkNameNum `_synthKey s.nextExprIdx; + modify $ fun s => { nextExprIdx := s.nextExprIdx + 1, emap := s.emap.insert mvarId e', .. s }; + pure e' + | _ => pure e + +end MkTableKey + +def mkTableKey (mctx : MetavarContext) (e : Expr) : Expr := +(MkTableKey.normExpr e mctx).run' {} + +abbrev Answer := AbstractMVarsResult structure TableEntry := (waiters : Array Waiter) -(answers : Array AbstractMVarsResult) +(answers : Array Answer := #[]) /- Remark: the SynthInstance.State is not really an extension of `Meta.State`. @@ -57,13 +111,15 @@ structure TableEntry := -/ structure State extends Meta.State := (mainMVarId : MVarId) -(generatorStack : Array GeneratorNode := #[]) -(resumeStack : Array (ConsumerNode × Expr) := #[]) -(tableKeys : DiscrTree Key := {}) -(tableEntries : PersistentArray TableEntry := {}) +(nextKeyIdx : Nat := 0) +(generatorStack : Array GeneratorNode := #[]) +(resumeStack : Array (ConsumerNode × Answer) := #[]) +(tableEntries : PersistentHashMap Expr TableEntry := {}) abbrev SynthM := ReaderT Context (EStateM Exception State) +instance SynthM.inhabited {α} : Inhabited (SynthM α) := ⟨throw $ Exception.other ""⟩ + @[inline] private def getTraceState : SynthM TraceState := do s ← get; pure s.traceState @@ -81,11 +137,154 @@ whenM (MonadTracerAdapter.isTracingEnabledFor cls) $ do s ← get; MonadTracerAdapter.addTrace cls (MessageData.context s.env mctx ctx.lctx (msg ())) -@[inline] def runMetaM {α} (x : MetaM α) : SynthM α := -fun ctx => adaptState (fun (s : State) => (s.toState, s)) (fun s' s => { toState := s', .. s }) (x ctx.toContext) +@[inline] def liftMeta {α} (x : MetaM α) : SynthM α := +adaptState (fun (s : State) => (s.toState, s)) (fun s' s => { toState := s', .. s }) x + +instance meta2Synth {α} : HasCoe (MetaM α) (SynthM α) := ⟨liftMeta⟩ + +/-- Return globals and locals instances that may unify with `type` -/ +def getInstances (type : Expr) : MetaM (Array Expr) := +forallTelescopeReducing type $ fun _ type => do + className? ← isClass type; + match className? with + | none => throwEx $ Exception.notInstance type + | some className => do + globalInstances ← getGlobalInstances; + result ← globalInstances.getUnify type; + localInstances ← getLocalInstances; + let result := localInstances.foldl + (fun (result : Array Expr) linst => if linst.className == className then result.push linst.fvar else result) + result; + pure result + +/-- Create a new generator node for `mvar` and add `waiter` as its waiter. + `key` must be `mkTableKey mctx mvarType`. -/ +def newSubgoal (key : Expr) (mvar : Expr) (waiter : Waiter) : SynthM Unit := +do mvarType ← inferType mvar; + instances ← getInstances mvarType; + if instances.isEmpty then pure () + else do + mctx ← getMCtx; + let node : GeneratorNode := { + mvar := mvar, + key := key, + mctx := mctx, + instances := instances, + currInstanceIdx := instances.size + }; + let entry : TableEntry := { waiters := #[waiter] }; + modify $ fun s => + { generatorStack := s.generatorStack.push node, + tableEntries := s.tableEntries.insert key entry, + .. s } + +def wakeUp (answer : Answer) : Waiter → SynthM Unit +| Waiter.root => modify $ fun s => s -- TODO +| Waiter.consumerNode cNode => modify $ fun s => { resumeStack := s.resumeStack.push (cNode, answer), .. s } + +def findEntry (key : Expr) : SynthM (Option TableEntry) := +do s ← get; + pure $ s.tableEntries.find key + +def getEntry (key : Expr) : SynthM TableEntry := +do entry? ← findEntry key; + match entry? with + | none => panic! "invalid key at synthInstance" + | some entry => pure entry + +def newAnswer (key : Expr) (answer : Answer) : SynthM Unit := +do entry ← getEntry key; + if entry.answers.contains answer then pure () + else condM (pure (entry.waiters.any Waiter.isRoot) <||> hasAssignableMVar answer.expr) (pure ()) $ + let newEntry := { answers := entry.answers.push answer, .. entry }; + modify $ fun s => { tableEntries := s.tableEntries.insert key newEntry, .. s }; + entry.waiters.forM (wakeUp answer) + +def mkAnswer (cNode : ConsumerNode) : MetaM Answer := +withMCtx cNode.mctx $ do + val ← instantiateMVars cNode.mvar; + abstractMVars val + +def mkTableKeyFor (mctx : MetavarContext) (mvar : Expr) : MetaM Expr := +withMCtx mctx $ do + mvarType ← inferType mvar; + pure $ mkTableKey mctx mvarType + +def consume (cNode : ConsumerNode) : SynthM Unit := +match cNode.subgoals with +| [] => do + answer ← mkAnswer cNode; + newAnswer cNode.key answer +| mvar::_ => do + let waiter := Waiter.consumerNode cNode; + let key := mkTableKey cNode.mctx mvar; + entry? ← findEntry key; + match entry? with + | none => newSubgoal key mvar waiter + | some entry => modify $ fun s => + { resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack, + tableEntries := s.tableEntries.insert key { waiters := entry.waiters.push waiter, .. entry }, + .. s } + +private partial def mkInstanceTelescopeAux + (xs : Array Expr) : Array Expr → Nat → List Expr → Expr → Expr → MetaM (List Expr × Expr × Expr) +| mvars, j, subgoals, instVal, Expr.forallE n d b c => do + let d := d.instantiateRevRange j mvars.size mvars; + type ← mkForall xs d; + mvar ← mkFreshExprMVar type; + let arg := mkAppN mvar xs; + let instVal := mkApp instVal arg; + let subgoals := if c.binderInfo.isInstImplicit then mvar::subgoals else subgoals; + let mvars := mvars.push mvar; + mkInstanceTelescopeAux mvars j subgoals instVal b +| mvars, j, subgoals, instVal, type => do + let type := type.instantiateRevRange j mvars.size mvars; + type ← whnf type; + if type.isForall then + mkInstanceTelescopeAux mvars mvars.size subgoals instVal type + else + pure (subgoals, instVal, type) + +def mkInstanceTelescope (xs : Array Expr) (inst : Expr) : MetaM (List Expr × Expr × Expr) := +do instType ← inferType inst; + mkInstanceTelescopeAux xs #[] 0 [] inst instType + +def tryResolve (mctx : MetavarContext) (mvar : Expr) (inst : Expr) : MetaM (Option (MetavarContext × List Expr)) := +withMCtx mctx $ do + mvarType ← inferType mvar; + forallTelescopeReducing mvarType $ fun xs mvarTypeBody => do + (subgoals, instVal, instTypeBody) ← mkInstanceTelescope xs inst; + condM (isDefEq mvarTypeBody instTypeBody) + (do instVal ← mkLambda xs instVal; + condM (isDefEq mvar instVal) + (do mctx ← getMCtx; pure (some (mctx, subgoals))) + (pure none)) + (pure none) + +def getTop : SynthM GeneratorNode := +do s ← get; + pure s.generatorStack.back + +@[inline] def modifyTop (f : GeneratorNode → GeneratorNode) : SynthM Unit := +modify $ fun s => { generatorStack := s.generatorStack.modify (s.generatorStack.size - 1) f, .. s } + +def generate : SynthM Unit := +do gNode ← getTop; + if gNode.currInstanceIdx == 0 then + modify $ fun s => { generatorStack := s.generatorStack.pop, .. s } + else do + let idx := gNode.currInstanceIdx - 1; + modifyTop $ fun gNode => { currInstanceIdx := idx, .. gNode }; + let inst := gNode.instances.get! idx; + result? ← tryResolve gNode.mctx gNode.mvar inst; + match result? with + | none => pure () + | some (mctx, subgoals) => consume { key := gNode.key, mvar := gNode.mvar, subgoals := subgoals, mctx := mctx } def main (type : Expr) : MetaM (Option Expr) := -pure none -- TODO +do Meta.trace `Meta.synthInstance $ fun _ => type; + mvar ← mkFreshExprMVar type; + pure none -- TODO end SynthInstance @@ -159,7 +358,6 @@ usingTransparency TransparencyMode.reducible $ do | none => do result ← withNewMCtxDepth $ do { (normType, replacements) ← preprocessOutParam type; - trace `Meta.synthInstance $ fun _ => normType; result? ← SynthInstance.main normType; match result? with | none => pure none