chore: cleanup
This commit is contained in:
parent
da68f629f9
commit
82cd5c8eef
1 changed files with 50 additions and 55 deletions
|
|
@ -91,7 +91,7 @@ abbrev M := ReaderT MetavarContext (StateM State)
|
|||
|
||||
partial def normLevel (u : Level) : M Level := do
|
||||
if !u.hasMVar then
|
||||
pure u
|
||||
return u
|
||||
else match u with
|
||||
| Level.succ v _ => return u.updateSucc! (← normLevel v)
|
||||
| Level.max v w _ => return u.updateMax! (← normLevel v) (← normLevel w)
|
||||
|
|
@ -99,16 +99,16 @@ partial def normLevel (u : Level) : M Level := do
|
|||
| Level.mvar mvarId _ =>
|
||||
let mctx ← read
|
||||
if !mctx.isLevelAssignable mvarId then
|
||||
pure u
|
||||
return u
|
||||
else
|
||||
let s ← get
|
||||
match s.lmap.find? mvarId with
|
||||
match (← get).lmap.find? mvarId with
|
||||
| some u' => pure u'
|
||||
| none =>
|
||||
let u' := mkLevelParam $ Name.mkNum `_tc s.nextIdx
|
||||
let u' := mkLevelParam <| Name.mkNum `_tc s.nextIdx
|
||||
modify fun s => { s with nextIdx := s.nextIdx + 1, lmap := s.lmap.insert mvarId u' }
|
||||
pure u'
|
||||
| u => pure u
|
||||
return u'
|
||||
| u => return u
|
||||
|
||||
partial def normExpr (e : Expr) : M Expr := do
|
||||
if !e.hasMVar then
|
||||
|
|
@ -123,9 +123,8 @@ partial def normExpr (e : Expr) : M Expr := do
|
|||
| Expr.mdata _ b _ => return e.updateMData! (← normExpr b)
|
||||
| Expr.proj _ _ b _ => return e.updateProj! (← normExpr b)
|
||||
| Expr.mvar mvarId _ =>
|
||||
let mctx ← read
|
||||
if !mctx.isExprAssignable mvarId then
|
||||
pure e
|
||||
if !(← read).isExprAssignable mvarId then
|
||||
return e
|
||||
else
|
||||
let s ← get
|
||||
match s.emap.find? mvarId with
|
||||
|
|
@ -133,8 +132,8 @@ partial def normExpr (e : Expr) : M Expr := do
|
|||
| none => do
|
||||
let e' := mkFVar { name := Name.mkNum `_tc s.nextIdx }
|
||||
modify fun s => { s with nextIdx := s.nextIdx + 1, emap := s.emap.insert mvarId e' }
|
||||
pure e'
|
||||
| _ => pure e
|
||||
return e'
|
||||
| _ => return e
|
||||
|
||||
end MkTableKey
|
||||
|
||||
|
|
@ -205,26 +204,24 @@ def getInstances (type : Expr) : MetaM (Array Expr) := do
|
|||
trace[Meta.synthInstance.globalInstances] "{type}, {result}"
|
||||
let result := localInstances.foldl (init := result) fun (result : Array Expr) linst =>
|
||||
if linst.className == className then result.push linst.fvar else result
|
||||
pure result
|
||||
return result
|
||||
|
||||
def mkGeneratorNode? (key mvar : Expr) : MetaM (Option GeneratorNode) := do
|
||||
let mvarType ← inferType mvar
|
||||
let mvarType ← instantiateMVars mvarType
|
||||
let instances ← getInstances mvarType
|
||||
if instances.isEmpty then
|
||||
pure none
|
||||
return none
|
||||
else
|
||||
let mctx ← getMCtx
|
||||
pure $ some {
|
||||
mvar := mvar,
|
||||
key := key,
|
||||
mctx := mctx,
|
||||
instances := instances,
|
||||
return some {
|
||||
mvar, key, mctx, instances
|
||||
currInstanceIdx := instances.size
|
||||
}
|
||||
|
||||
/-- Create a new generator node for `mvar` and add `waiter` as its waiter.
|
||||
`key` must be `mkTableKey mctx mvarType`. -/
|
||||
/--
|
||||
Create a new generator node for `mvar` and add `waiter` as its waiter.
|
||||
`key` must be `mkTableKey mctx mvarType`. -/
|
||||
def newSubgoal (mctx : MetavarContext) (key : Expr) (mvar : Expr) (waiter : Waiter) : SynthM Unit :=
|
||||
withMCtx mctx do
|
||||
trace[Meta.synthInstance.newSubgoal] key
|
||||
|
|
@ -234,7 +231,7 @@ def newSubgoal (mctx : MetavarContext) (key : Expr) (mvar : Expr) (waiter : Wait
|
|||
let entry : TableEntry := { waiters := #[waiter] }
|
||||
modify fun s =>
|
||||
{ s with
|
||||
generatorStack := s.generatorStack.push node,
|
||||
generatorStack := s.generatorStack.push node
|
||||
tableEntries := s.tableEntries.insert key entry }
|
||||
|
||||
def findEntry? (key : Expr) : SynthM (Option TableEntry) := do
|
||||
|
|
@ -284,7 +281,7 @@ private partial def getSubgoalsAux (lctx : LocalContext) (localInsts : LocalInst
|
|||
if type.isForall then
|
||||
getSubgoalsAux lctx localInsts xs args args.size subgoals instVal type
|
||||
else
|
||||
pure ⟨subgoals, instVal, type⟩
|
||||
return ⟨subgoals, instVal, type⟩
|
||||
|
||||
/--
|
||||
`getSubgoals lctx localInsts xs inst` creates the subgoals for the instance `inst`.
|
||||
|
|
@ -308,10 +305,10 @@ def getSubgoals (lctx : LocalContext) (localInsts : LocalInstances) (xs : Array
|
|||
| Expr.const constName _ _ =>
|
||||
let env ← getEnv
|
||||
if hasInferTCGoalsRLAttribute env constName then
|
||||
pure result
|
||||
return result
|
||||
else
|
||||
pure { result with subgoals := result.subgoals.reverse }
|
||||
| _ => pure result
|
||||
return { result with subgoals := result.subgoals.reverse }
|
||||
| _ => return result
|
||||
|
||||
def tryResolveCore (mvar : Expr) (inst : Expr) : MetaM (Option (MetavarContext × List Expr)) := do
|
||||
let mvar ← instantiateMVars mvar
|
||||
|
|
@ -345,13 +342,13 @@ def tryResolveCore (mvar : Expr) (inst : Expr) : MetaM (Option (MetavarContext
|
|||
let instVal ← mkLambdaFVars xs instVal
|
||||
if (← isDefEq mvar instVal) then
|
||||
trace[Meta.synthInstance.tryResolve] "success"
|
||||
pure (some ((← getMCtx), subgoals))
|
||||
return some ((← getMCtx), subgoals)
|
||||
else
|
||||
trace[Meta.synthInstance.tryResolve] "failure assigning"
|
||||
pure none
|
||||
return none
|
||||
else
|
||||
trace[Meta.synthInstance.tryResolve] "failure"
|
||||
pure none
|
||||
return none
|
||||
|
||||
/--
|
||||
Try to synthesize metavariable `mvar` using the instance `inst`.
|
||||
|
|
@ -368,9 +365,9 @@ def tryAnswer (mctx : MetavarContext) (mvar : Expr) (answer : Answer) : SynthM (
|
|||
withMCtx mctx do
|
||||
let (_, _, val) ← openAbstractMVarsResult answer.result
|
||||
if (← isDefEq mvar val) then
|
||||
pure (some (← getMCtx))
|
||||
return some (← getMCtx)
|
||||
else
|
||||
pure none
|
||||
return none
|
||||
|
||||
/-- Move waiters that are waiting for the given answer to the resume stack. -/
|
||||
def wakeUp (answer : Answer) : Waiter → SynthM Unit
|
||||
|
|
@ -384,7 +381,6 @@ def wakeUp (answer : Answer) : Waiter → SynthM Unit
|
|||
else
|
||||
let (_, _, answerExpr) ← openAbstractMVarsResult answer.result
|
||||
trace[Meta.synthInstance] "skip answer containing metavariables {answerExpr}"
|
||||
pure ()
|
||||
| Waiter.consumerNode cNode =>
|
||||
modify fun s => { s with resumeStack := s.resumeStack.push (cNode, answer) }
|
||||
|
||||
|
|
@ -401,7 +397,7 @@ private def mkAnswer (cNode : ConsumerNode) : MetaM Answer :=
|
|||
trace[Meta.synthInstance.newAnswer] "val: {val}"
|
||||
let result ← abstractMVars val -- assignable metavariables become parameters
|
||||
let resultType ← inferType result.expr
|
||||
pure { result := result, resultType := resultType, size := cNode.size + 1 }
|
||||
return { result, resultType, size := cNode.size + 1 }
|
||||
|
||||
/--
|
||||
Create a new answer after `cNode` resolved all subgoals.
|
||||
|
|
@ -474,10 +470,10 @@ private def removeUnusedArguments? (mctx : MetavarContext) (mvar : Expr) : MetaM
|
|||
return some (mvarType', transformer)
|
||||
|
||||
/-- Process the next subgoal in the given consumer node. -/
|
||||
def consume (cNode : ConsumerNode) : SynthM Unit :=
|
||||
def consume (cNode : ConsumerNode) : SynthM Unit := do
|
||||
match cNode.subgoals with
|
||||
| [] => addAnswer cNode
|
||||
| mvar::_ => do
|
||||
| mvar::_ =>
|
||||
let waiter := Waiter.consumerNode cNode
|
||||
let key ← mkTableKeyFor cNode.mctx mvar
|
||||
let entry? ← findEntry? key
|
||||
|
|
@ -489,12 +485,12 @@ def consume (cNode : ConsumerNode) : SynthM Unit :=
|
|||
| some (mvarType', transformer) =>
|
||||
let key' := mkTableKey cNode.mctx mvarType'
|
||||
match (← findEntry? key') with
|
||||
| none => do
|
||||
| none =>
|
||||
let (mctx', mvar') ← withMCtx cNode.mctx do
|
||||
let mvar' ← mkFreshExprMVar mvarType'
|
||||
return (← getMCtx, mvar')
|
||||
newSubgoal mctx' key' mvar' (Waiter.consumerNode { cNode with mctx := mctx', subgoals := mvar'::cNode.subgoals })
|
||||
| some entry' => do
|
||||
| some entry' =>
|
||||
let answers' ← entry'.answers.mapM fun a => withMCtx cNode.mctx do
|
||||
let trAnswr := Expr.betaRev transformer #[← instantiateMVars a.result.expr]
|
||||
let trAnswrType ← inferType trAnswr
|
||||
|
|
@ -508,8 +504,8 @@ def consume (cNode : ConsumerNode) : SynthM Unit :=
|
|||
resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
|
||||
tableEntries := s.tableEntries.insert key { entry with waiters := entry.waiters.push waiter } }
|
||||
|
||||
def getTop : SynthM GeneratorNode := do
|
||||
pure (← get).generatorStack.back
|
||||
def getTop : SynthM GeneratorNode :=
|
||||
return (← get).generatorStack.back
|
||||
|
||||
@[inline] def modifyTop (f : GeneratorNode → GeneratorNode) : SynthM Unit :=
|
||||
modify fun s => { s with generatorStack := s.generatorStack.modify (s.generatorStack.size - 1) f }
|
||||
|
|
@ -519,7 +515,7 @@ def generate : SynthM Unit := do
|
|||
let gNode ← getTop
|
||||
if gNode.currInstanceIdx == 0 then
|
||||
modify fun s => { s with generatorStack := s.generatorStack.pop }
|
||||
else do
|
||||
else
|
||||
let key := gNode.key
|
||||
let idx := gNode.currInstanceIdx - 1
|
||||
let inst := gNode.instances.get! idx
|
||||
|
|
@ -528,14 +524,13 @@ def generate : SynthM Unit := do
|
|||
trace[Meta.synthInstance.generate] "instance {inst}"
|
||||
modifyTop fun gNode => { gNode with currInstanceIdx := idx }
|
||||
match (← tryResolve mctx mvar inst) with
|
||||
| none => pure ()
|
||||
| some (mctx, subgoals) => consume { key := key, mvar := mvar, subgoals := subgoals, mctx := mctx, size := 0 }
|
||||
| none => return ()
|
||||
| some (mctx, subgoals) => consume { key, mvar, subgoals, mctx, size := 0 }
|
||||
|
||||
def getNextToResume : SynthM (ConsumerNode × Answer) := do
|
||||
let s ← get
|
||||
let r := s.resumeStack.back
|
||||
let r := (← get).resumeStack.back
|
||||
modify fun s => { s with resumeStack := s.resumeStack.pop }
|
||||
pure r
|
||||
return r
|
||||
|
||||
/--
|
||||
Given `(cNode, answer)` on the top of the resume stack, continue execution by using `answer` to solve the
|
||||
|
|
@ -546,37 +541,37 @@ def resume : SynthM Unit := do
|
|||
| [] => panic! "resume found no remaining subgoals"
|
||||
| mvar::rest =>
|
||||
match (← tryAnswer cNode.mctx mvar answer) with
|
||||
| none => pure ()
|
||||
| none => return ()
|
||||
| some mctx =>
|
||||
withMCtx mctx <| traceM `Meta.synthInstance.resume do
|
||||
let goal ← inferType cNode.mvar
|
||||
let subgoal ← inferType mvar
|
||||
pure m!"size: {cNode.size + answer.size}, {goal} <== {subgoal}"
|
||||
consume { key := cNode.key, mvar := cNode.mvar, subgoals := rest, mctx := mctx, size := cNode.size + answer.size }
|
||||
return m!"size: {cNode.size + answer.size}, {goal} <== {subgoal}"
|
||||
consume { key := cNode.key, mvar := cNode.mvar, subgoals := rest, mctx, size := cNode.size + answer.size }
|
||||
|
||||
def step : SynthM Bool := do
|
||||
checkMaxHeartbeats
|
||||
let s ← get
|
||||
if !s.resumeStack.isEmpty then
|
||||
resume
|
||||
pure true
|
||||
return true
|
||||
else if !s.generatorStack.isEmpty then
|
||||
generate
|
||||
pure true
|
||||
return true
|
||||
else
|
||||
pure false
|
||||
return false
|
||||
|
||||
def getResult : SynthM (Option AbstractMVarsResult) := do
|
||||
pure (← get).result?
|
||||
def getResult : SynthM (Option AbstractMVarsResult) :=
|
||||
return (← get).result?
|
||||
|
||||
partial def synth : SynthM (Option AbstractMVarsResult) := do
|
||||
if (← step) then
|
||||
match (← getResult) with
|
||||
| none => synth
|
||||
| some result => pure result
|
||||
| some result => return result
|
||||
else
|
||||
trace[Meta.synthInstance] "failed"
|
||||
pure none
|
||||
return none
|
||||
|
||||
def main (type : Expr) (maxResultSize : Nat) : MetaM (Option AbstractMVarsResult) :=
|
||||
withCurrHeartbeats <| traceCtx `Meta.synthInstance do
|
||||
|
|
@ -720,7 +715,7 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (
|
|||
if type.hasMVar || resultHasUnivMVars then
|
||||
pure result?
|
||||
else do
|
||||
modify fun s => { s with cache := { s.cache with synthInstance := s.cache.synthInstance.insert type result? } }
|
||||
modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert type result? }
|
||||
pure result?
|
||||
|
||||
/--
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue