feat: cache failures at isDefEq

We can compile Lean with these changes, but 3 tests are still broken.
This cache is used to address a performance issue reported at
  https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/deterministic.20timeout.20with.20structures/near/288180087
This commit is contained in:
Leonardo de Moura 2022-07-03 21:49:45 -07:00
parent 76245b39d1
commit a1413b8fa1
20 changed files with 320 additions and 270 deletions

View file

@ -343,7 +343,7 @@ private def anyNamedArgDependsOnCurrent : M Bool := do
for i in [1:xs.size] do
let xDecl ← getLocalDecl xs[i]!.fvarId!
if s.namedArgs.any fun arg => arg.name == xDecl.userName then
if (← getMCtx).localDeclDependsOn xDecl curr.fvarId! then
if (← MetavarContext.localDeclDependsOn xDecl curr.fvarId!) then
return true
return false

View file

@ -72,8 +72,7 @@ private def elabOptLevel (stx : Syntax) : TermElabM Level :=
if userName.isAnonymous || (← read).inPattern then
mkNewHole ()
else
let mctx ← getMCtx
match mctx.findUserName? userName with
match (← getMCtx).findUserName? userName with
| none => mkNewHole ()
| some mvarId =>
let mvar := mkMVar mvarId
@ -81,16 +80,16 @@ private def elabOptLevel (stx : Syntax) : TermElabM Level :=
let lctx ← getLCtx
if mvarDecl.lctx.isSubPrefixOf lctx then
return mvar
else match mctx.getExprAssignment? mvarId with
else match (← getExprMVarAssignment? mvarId) with
| some val =>
let val ← instantiateMVars val
if mctx.isWellFormed lctx val then
if (← MetavarContext.isWellFormed lctx val) then
return val
else
withLCtx mvarDecl.lctx mvarDecl.localInstances do
throwError "synthetic hole has already been defined and assigned to value incompatible with the current context{indentExpr val}"
| none =>
if mctx.isDelayedAssigned mvarId then
if (← getMCtx).isDelayedAssigned mvarId then
-- We can try to improve this case if needed.
throwError "synthetic hole has already beend defined and delayed assigned with an incompatible local context"
else if lctx.isSubPrefixOf mvarDecl.lctx then

View file

@ -982,9 +982,8 @@ where
for fvarId in s.fvarSet.toList do
unless containsFVar discrs fvarId || containsFVar indices fvarId do
let localDecl ← getLocalDecl fvarId
let mctx ← getMCtx
for indexFVarId in indicesFVar do
if mctx.localDeclDependsOn localDecl indexFVarId then
if (← MetavarContext.localDeclDependsOn localDecl indexFVarId) then
toAdd := toAdd.push fvarId
let indicesFVar ← sortFVarIds (indicesFVar ++ toAdd)
return indicesFVar.map mkFVar ++ indicesNonFVar

View file

@ -326,8 +326,8 @@ we would have a `LetRecToLift` containing:
Note that `g` is not a free variable at `(let g : B := ?m₂; body)`. We recover the fact that
`f` depends on `g` because it contains `m₂`
-/
private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift)
: UsedFVarsMap := Id.run do
private def mkInitialUsedFVarsMap [Monad m] [MonadMCtx m] (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift)
: m UsedFVarsMap := do
let mut sectionVarSet := {}
for var in sectionVars do
sectionVarSet := sectionVarSet.insert var.fvarId!
@ -342,11 +342,11 @@ private def mkInitialUsedFVarsMap (mctx : MetavarContext) (sectionVars : Array E
for the associated let-rec because we need this information to compute the fixpoint later. -/
let mvarIds := (toLift.val.collectMVars {}).result
for mvarId in mvarIds do
match letRecsToLift.findSome? fun (toLift : LetRecToLift) => if toLift.mvarId == mctx.getDelayedRoot mvarId then some toLift.fvarId else none with
match (← letRecsToLift.findSomeM? fun (toLift : LetRecToLift) => return if toLift.mvarId == (← getDelayedMVarRoot mvarId) then some toLift.fvarId else none) with
| some fvarId => set := set.insert fvarId
| none => pure ()
usedFVarMap := usedFVarMap.insert toLift.fvarId set
pure usedFVarMap
return usedFVarMap
/-
The let-recs may invoke each other. Example:
@ -423,10 +423,10 @@ end FixPoint
abbrev FreeVarMap := FVarIdMap (Array FVarId)
private def mkFreeVarMap
(mctx : MetavarContext) (sectionVars : Array Expr) (mainFVarIds : Array FVarId)
(recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : FreeVarMap := Id.run do
let usedFVarsMap := mkInitialUsedFVarsMap mctx sectionVars mainFVarIds letRecsToLift
private def mkFreeVarMap [Monad m] [MonadMCtx m]
(sectionVars : Array Expr) (mainFVarIds : Array FVarId)
(recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : m FreeVarMap := do
let usedFVarsMap ← mkInitialUsedFVarsMap sectionVars mainFVarIds letRecsToLift
let letRecFVarIds := letRecsToLift.map fun toLift => toLift.fvarId
let usedFVarsMap := FixPoint.run letRecFVarIds usedFVarsMap
let mut freeVarMap := {}
@ -536,7 +536,7 @@ private def mkLetRecClosureFor (toLift : LetRecToLift) (freeVars : Array FVarId)
private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVarId) (recFVarIds : Array FVarId) (letRecsToLift : Array LetRecToLift) : TermElabM (List LetRecClosure) := do
-- Compute the set of free variables (excluding `recFVarIds`) for each let-rec.
let mut letRecsToLift := letRecsToLift
let mut freeVarMap := mkFreeVarMap (← getMCtx) sectionVars mainFVarIds recFVarIds letRecsToLift
let mut freeVarMap ← mkFreeVarMap sectionVars mainFVarIds recFVarIds letRecsToLift
let mut result := #[]
for i in [:letRecsToLift.size] do
if letRecsToLift[i]!.val.hasExprMVar then
@ -546,7 +546,7 @@ private def mkLetRecClosures (sectionVars : Array Expr) (mainFVarIds : Array FVa
let valNew ← instantiateMVars letRecsToLift[i]!.val
letRecsToLift := letRecsToLift.modify i fun t => { t with val := valNew }
-- We have to recompute the `freeVarMap` in this case. This overhead should not be an issue in practice.
freeVarMap := mkFreeVarMap (← getMCtx) sectionVars mainFVarIds recFVarIds letRecsToLift
freeVarMap ← mkFreeVarMap sectionVars mainFVarIds recFVarIds letRecsToLift
let toLift := letRecsToLift[i]!
result := result.push (← mkLetRecClosureFor toLift (freeVarMap.find? toLift.fvarId).get!)
return result.toList

View file

@ -694,7 +694,7 @@ partial def visit (e : Expr) : M Unit := do
if e' != e then
visit e'
else
match (← getDelayedAssignment? mvarId) with
match (← getDelayedMVarAssignment? mvarId) with
| some d => visit d.val
| none => failure
| _ => return ()
@ -1531,7 +1531,7 @@ where
if auto.isFVar then
let localDecl ← getLocalDecl auto.fvarId!
for x in xs do
if (← getMCtx).localDeclDependsOn localDecl x.fvarId! then
if (← MetavarContext.localDeclDependsOn localDecl x.fvarId!) then
throwError "invalid auto implicit argument '{auto}', it depends on explicitly provided argument '{x}'"
return autos ++ xs
| auto :: todo =>
@ -1562,7 +1562,7 @@ builtin_initialize registerTraceClass `Elab.letrec
is delayed assigned to one. -/
def isLetRecAuxMVar (mvarId : MVarId) : TermElabM Bool := do
trace[Elab.letrec] "mvarId: {mkMVar mvarId} letrecMVars: {(← get).letRecsToLift.map (mkMVar $ ·.mvarId)}"
let mvarId := (getMCtx).getDelayedRoot mvarId
let mvarId ← getDelayedMVarRoot mvarId
trace[Elab.letrec] "mvarId root: {mkMVar mvarId}"
return (← get).letRecsToLift.any (·.mvarId == mvarId)

View file

@ -192,9 +192,11 @@ abbrev InferTypeCache := PersistentExprStructMap Expr
abbrev FunInfoCache := PersistentHashMap InfoCacheKey FunInfo
abbrev WhnfCache := PersistentExprStructMap Expr
/- A set of pairs. TODO: consider more efficient representations (e.g., a proper set) and caching policies (e.g., imperfect cache).
We should also investigate the impact on memory consumption. -/
abbrev DefEqCache := PersistentHashMap (Expr × Expr) Unit
/--
A mapping `(s, t) ↦ isDefEq s t`.
TODO: consider more efficient representations (e.g., a proper set) and caching policies (e.g., imperfect cache).
We should also investigate the impact on memory consumption. -/
abbrev DefEqCache := PersistentHashMap (Expr × Expr) Bool
/--
Cache datastructures for type inference, type class resolution, whnf, and definitional equality.
@ -237,12 +239,12 @@ structure PostponedEntry where
`MetaM` monad state.
-/
structure State where
mctx : MetavarContext := {}
cache : Cache := {}
mctx : MetavarContext := {}
cache : Cache := {}
/-- When `trackZeta == true`, then any let-decl free variable that is zeta expansion performed by `MetaM` is stored in `zetaFVarIds`. -/
zetaFVarIds : FVarIdSet := {}
zetaFVarIds : FVarIdSet := {}
/-- Array of postponed universe level constraints -/
postponed : PersistentArray PostponedEntry := {}
postponed : PersistentArray PostponedEntry := {}
deriving Inhabited
/--
@ -305,7 +307,8 @@ protected def saveState : MetaM SavedState :=
/-- Restore backtrackable parts of the state. -/
def SavedState.restore (b : SavedState) : MetaM Unit := do
Core.restore b.core
modify fun s => { s with mctx := b.meta.mctx, zetaFVarIds := b.meta.zetaFVarIds, postponed := b.meta.postponed }
let usedAssignment := (← getMCtx).usedAssignment
modify fun s => { s with mctx := { b.meta.mctx with usedAssignment }, zetaFVarIds := b.meta.zetaFVarIds, postponed := b.meta.postponed }
instance : MonadBacktrack SavedState MetaM where
saveState := Meta.saveState
@ -488,8 +491,11 @@ def shouldReduceAll : MetaM Bool :=
def shouldReduceReducibleOnly : MetaM Bool :=
return (← getTransparency) == TransparencyMode.reducible
def findMVarDecl? (mvarId : MVarId) : MetaM (Option MetavarDecl) :=
return (← getMCtx).findDecl? mvarId
def getMVarDecl (mvarId : MVarId) : MetaM MetavarDecl := do
match (← getMCtx).findDecl? mvarId with
match (← findMVarDecl? mvarId) with
| some d => pure d
| none => throwError "unknown metavariable '?{mvarId.name}'"
@ -527,22 +533,12 @@ def setMVarUserName (mvarId : MVarId) (newUserName : Name) : MetaM Unit :=
def isExprMVarAssigned (mvarId : MVarId) : MetaM Bool :=
return (← getMCtx).isExprAssigned mvarId
def getExprMVarAssignment? (mvarId : MVarId) : MetaM (Option Expr) :=
return (← getMCtx).getExprAssignment? mvarId
/-- Return true if `e` contains `mvarId` directly or indirectly -/
def occursCheck (mvarId : MVarId) (e : Expr) : MetaM Bool :=
return (← getMCtx).occursCheck mvarId e
def assignExprMVar (mvarId : MVarId) (val : Expr) : MetaM Unit :=
modifyMCtx fun mctx => mctx.assignExpr mvarId val
def isDelayedAssigned (mvarId : MVarId) : MetaM Bool :=
return (← getMCtx).isDelayedAssigned mvarId
def getDelayedAssignment? (mvarId : MVarId) : MetaM (Option DelayedMetavarAssignment) :=
return (← getMCtx).getDelayedAssignment? mvarId
def hasAssignableMVar (e : Expr) : MetaM Bool :=
return (← getMCtx).hasAssignableMVar e
@ -595,6 +591,10 @@ def abstractRange (e : Expr) (n : Nat) (xs : Array Expr) : MetaM Expr :=
def abstract (e : Expr) (xs : Array Expr) : MetaM Expr :=
abstractRange e xs.size xs
def collectForwardDeps (toRevert : Array Expr) (preserveOrder : Bool) : MetaM (Array Expr) := do
liftMkBindingM <| MetavarContext.collectForwardDeps toRevert preserveOrder
/-- Takes an array `xs` of free variables or metavariables and a term `e` that may contain those variables, and abstracts and binds them as universal quantifiers.
- if `usedOnly = true` then only variables that the expression body depends on will appear.
@ -1144,7 +1144,13 @@ private def withNewMCtxDepthImp (x : MetaM α) : MetaM α := do
try
x
finally
modify fun s => { s with mctx := saved.mctx, postponed := saved.postponed }
-- TODO: document why we need to restore defEqCache
modify fun s => { s with
mctx := saved.mctx
cache.defEqDefault := saved.cache.defEqDefault
cache.defEqAll := saved.cache.defEqAll
postponed := saved.postponed
}
/--
Save cache and `MetavarContext`, bump the `MetavarContext` depth, execute `x`,
@ -1283,19 +1289,19 @@ def instantiateLambda (e : Expr) (ps : Array Expr) : MetaM Expr :=
/-- Return true iff `e` depends on the free variable `fvarId` -/
def dependsOn (e : Expr) (fvarId : FVarId) : MetaM Bool :=
return (← getMCtx).exprDependsOn e fvarId
MetavarContext.exprDependsOn e fvarId
/-- Return true iff `e` depends on the free variable `fvarId` -/
def localDeclDependsOn (localDecl : LocalDecl) (fvarId : FVarId) : MetaM Bool :=
return (← getMCtx).localDeclDependsOn localDecl fvarId
MetavarContext.localDeclDependsOn localDecl fvarId
/-- Return true iff `e` depends on a free variable `x` s.t. `pf x`, or an unassigned metavariable `?m` s.t. `pm ?m` is true. -/
def dependsOnPred (e : Expr) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : MetaM Bool :=
return (← getMCtx).findExprDependsOn e pf pm
MetavarContext.findExprDependsOn e pf pm
/-- Return true iff the local declaration `localDecl` depends on a free variable `x` s.t. `pf x`, an unassigned metavariable `?m` s.t. `pm ?m` is true. -/
def localDeclDependsOnPred (localDecl : LocalDecl) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : MetaM Bool := do
return (← getMCtx).findLocalDeclDependsOn localDecl pf pm
MetavarContext.findLocalDeclDependsOn localDecl pf pm
/-- Pretty-print the given expression. -/
def ppExpr (e : Expr) : MetaM Format := do
@ -1379,7 +1385,7 @@ def isListLevelDefEqAux : List Level → List Level → MetaM Bool
| u::us, v::vs => isLevelDefEqAux u v <&&> isListLevelDefEqAux us vs
| _, _ => return false
private def getNumPostponed : MetaM Nat := do
def getNumPostponed : MetaM Nat := do
return (← getPostponed).size
def getResetPostponed : MetaM (PersistentArray PostponedEntry) := do

View file

@ -24,7 +24,7 @@ partial def collectMVars (e : Expr) : StateRefT CollectMVars.State MetaM Unit :=
let s := e.collectMVars s
set s
for mvarId in s.result[resultSavedSize:] do
match (← getDelayedAssignment? mvarId) with
match (← getDelayedMVarAssignment? mvarId) with
| none => pure ()
| some d => collectMVars d.val

View file

@ -19,8 +19,7 @@ private partial def decAux? : Level → ReaderT DecLevelContext MetaM (Option Le
| Level.zero _ => return none
| Level.param _ _ => return none
| Level.mvar mvarId _ => do
let mctx ← getMCtx
match mctx.getLevelAssignment? mvarId with
match (← getLevelMVarAssignment? mvarId) with
| some u => decAux? u
| none =>
if (← isReadOnlyLevelMVar mvarId) || !(← read).canAssignMVars then

View file

@ -16,7 +16,7 @@ import Lean.Meta.UnificationHint
namespace Lean.Meta
/--
Return true `b` is of the form `mk a.1 ... a.n`, and `a` is not a constructor application.
Return true if `b` is of the form `mk a.1 ... a.n`, and `a` is not a constructor application.
If `a` and `b` are constructor applications, the method returns `false` to force `isDefEq` to use `isDefEqArgs`.
For example, suppose we are trying to solve the constraint
@ -644,14 +644,13 @@ mutual
partial def checkMVar (mvar : Expr) : CheckAssignmentM Expr := do
let mvarId := mvar.mvarId!
let ctx ← read
let mctx ← getMCtx
if mvarId == ctx.mvarId then
traceM `Meta.isDefEq.assign.occursCheck <| addAssignmentInfo "occurs check failed"
throwCheckAssignmentFailure
else match mctx.getExprAssignment? mvarId with
else match (← getExprMVarAssignment? mvarId) with
| some v => check v
| none =>
match mctx.findDecl? mvarId with
match (← findMVarDecl? mvarId) with
| none => throwUnknownMVar mvarId
| some mvarDecl =>
if ctx.hasCtxLocals then
@ -660,7 +659,7 @@ mutual
/- The local context of `mvar` - free variables being abstracted is a subprefix of the metavariable being assigned.
We "substract" variables being abstracted because we use `elimMVarDeps` -/
pure mvar
else if mvarDecl.depth != mctx.depth || mvarDecl.kind.isSyntheticOpaque then
else if mvarDecl.depth != (← getMCtx).depth || mvarDecl.kind.isSyntheticOpaque then
traceM `Meta.isDefEq.assign.readOnlyMVarWithBiggerLCtx <| addAssignmentInfo (mkMVar mvarId)
throwCheckAssignmentFailure
else
@ -678,17 +677,17 @@ mutual
Notat that if a variable is `ctx.fvars`, but it depends on variable at `toErase`,
we must also erase it.
-/
let toErase := mvarDecl.lctx.foldl (init := #[]) fun toErase localDecl =>
let toErase ← mvarDecl.lctx.foldlM (init := #[]) fun toErase localDecl => do
if ctx.mvarDecl.lctx.contains localDecl.fvarId then
toErase
return toErase
else if ctx.fvars.any fun fvar => fvar.fvarId! == localDecl.fvarId then
if mctx.findLocalDeclDependsOn localDecl fun fvarId => toErase.contains fvarId then
if (← MetavarContext.findLocalDeclDependsOn localDecl fun fvarId => toErase.contains fvarId) then
-- localDecl depends on a variable that will be erased. So, we must add it to `toErase` too
toErase.push localDecl.fvarId
return toErase.push localDecl.fvarId
else
toErase
return toErase
else
toErase.push localDecl.fvarId
return toErase.push localDecl.fvarId
let lctx := toErase.foldl (init := mvarDecl.lctx) fun lctx toEraseFVar =>
lctx.erase toEraseFVar
/- Compute new set of local instances. -/
@ -808,7 +807,7 @@ partial def check
if fvars.any fun x => x.fvarId! == fvarId then true
else false -- We could throw an exception here, but we would have to use ExceptM. So, we let CheckAssignment.check do it
| Expr.mvar mvarId' _ =>
match mctx.getExprAssignment? mvarId' with
match mctx.getExprAssignmentCore? mvarId' with
| some _ => false -- use CheckAssignment.check to instantiate
| none =>
if mvarId' == mvarId then false -- occurs check failed, use CheckAssignment.check to throw exception
@ -1620,18 +1619,28 @@ private def skipDefEqCache : MetaM Bool := do
private def mkCacheKey (t : Expr) (s : Expr) : Expr × Expr :=
if Expr.quickLt t s then (t, s) else (s, t)
private def isCached (key : Expr × Expr) : MetaM Bool := do
match (← getConfig).transparency with
| TransparencyMode.default => return (← get).cache.defEqDefault.contains key
| TransparencyMode.all => return (← get).cache.defEqAll.contains key
| _ => return false
private def getCachedResult (key : Expr × Expr) : MetaM LBool := do
let cache ← match (← getConfig).transparency with
| TransparencyMode.default => pure (← get).cache.defEqDefault
| TransparencyMode.all => pure (← get).cache.defEqAll
| _ => return .undef
match cache.find? key with
| some val => return val.toLBool
| none => return .undef
private def cacheResult (key : Expr × Expr) : MetaM Unit := do
private def cacheResult (key : Expr × Expr) (result : Bool) : MetaM Unit := do
match (← getConfig).transparency with
| TransparencyMode.default => modify fun s => { s with cache.defEqDefault := s.cache.defEqDefault.insert key () }
| TransparencyMode.all => modify fun s => { s with cache.defEqAll := s.cache.defEqAll.insert key () }
| TransparencyMode.default => modify fun s => { s with cache.defEqDefault := s.cache.defEqDefault.insert key result }
| TransparencyMode.all => modify fun s => { s with cache.defEqAll := s.cache.defEqAll.insert key result }
| _ => pure ()
private abbrev withResetUsedAssignment (k : MetaM α) : MetaM α := do
let usedAssignment := (← getMCtx).usedAssignment
modifyMCtx fun mctx => { mctx with usedAssignment := false }
try
k
finally
modifyMCtx fun mctx => { mctx with usedAssignment }
@[export lean_is_expr_def_eq]
partial def isExprDefEqAuxImpl (t : Expr) (s : Expr) : MetaM Bool := withIncRecDepth do
@ -1658,18 +1667,24 @@ partial def isExprDefEqAuxImpl (t : Expr) (s : Expr) : MetaM Bool := withIncRecD
-/
let t ← instantiateMVars t
let s ← instantiateMVars s
if t.hasMVar || s.hasMVar then
-- It is not safe to use DefEq cache if terms contain metavariables
isExprDefEqExpensive t s
else
let k := mkCacheKey t s
if (← isCached k) then
return true
else if (← isExprDefEqExpensive t s) then
cacheResult k
return true
else
return false
let numPostponed ← getNumPostponed
let k := mkCacheKey t s
match (← getCachedResult k) with
| .true =>
trace[Meta.isDefEq.cache] "cache hit 'true' for {t} =?= {s}"
return true
| .false =>
trace[Meta.isDefEq.cache] "cache hit 'false' for {t} =?= {s}"
return false
| .undef =>
withResetUsedAssignment do
let result ← isExprDefEqExpensive t s
if numPostponed == (← getNumPostponed) && !(← getMCtx).usedAssignment then
-- It is only safe to cache the result if the mvars assignments have not been accessed/used,
-- and universe level variables have not been postponed.
trace[Meta.isDefEq.cache] "cache {result} for {t} =?= {s}"
cacheResult k result
return result
builtin_initialize
registerTraceClass `Meta.isDefEq

View file

@ -60,7 +60,7 @@ def getFVarSetToGeneralize (targets : Array Expr) (forbidden : FVarIdSet) (ignor
for localDecl in (← getLCtx) do
unless forbidden.contains localDecl.fvarId do
unless localDecl.isAuxDecl || localDecl.binderInfo.isInstImplicit || (ignoreLetDecls && localDecl.isLet) do
if (← getMCtx).findLocalDeclDependsOn localDecl (s.contains ·) then
if (← MetavarContext.findLocalDeclDependsOn localDecl (s.contains ·)) then
r := r.insert localDecl.fvarId
s := s.insert localDecl.fvarId
return r

View file

@ -69,13 +69,13 @@ def moveToHiddeProp (fvarId : FVarId) : M Unit := do
Recall that hiddenInaccessibleProps are visible, only their names are hidden -/
def hasVisibleDep (localDecl : LocalDecl) : M Bool := do
let s ← get
return (← getMCtx).findLocalDeclDependsOn localDecl (!s.hiddenInaccessible.contains ·)
MetavarContext.findLocalDeclDependsOn localDecl (!s.hiddenInaccessible.contains ·)
/- Return true if the given local declaration has a "nonvisible dependency", that is, it contains
a free variable that is `hiddenInaccessible` or `hiddenInaccessibleProp` -/
def hasInaccessibleNameDep (localDecl : LocalDecl) : M Bool := do
let s ← get
return (← getMCtx).findLocalDeclDependsOn localDecl fun fvarId =>
MetavarContext.findLocalDeclDependsOn localDecl fun fvarId =>
s.hiddenInaccessible.contains fvarId || s.hiddenInaccessibleProp.contains fvarId
/- If `e` is visible, then any inaccessible in `e` marked as hidden should be unmarked. -/

View file

@ -163,20 +163,18 @@ We say the major premise has independent indices IF
-/
private def hasIndepIndices (ctx : Context) : MetaM Bool := do
if ctx.majorTypeIndices.isEmpty then
pure true
return true
else if ctx.majorTypeIndices.any fun idx => !idx.isFVar then
/- One of the indices is not a free variable. -/
pure false
return false
else if ctx.majorTypeIndices.size.any fun i => i.any fun j => ctx.majorTypeIndices[i]! == ctx.majorTypeIndices[j]! then
/- An index ocurrs more than once -/
pure false
return false
else
let lctx ← getLCtx
let mctx ← getMCtx
return lctx.all fun decl =>
decl.fvarId == ctx.majorDecl.fvarId || -- decl is the major
ctx.majorTypeIndices.any (fun index => decl.fvarId == index.fvarId!) || -- decl is one of the indices
mctx.findLocalDeclDependsOn decl (fun fvarId => ctx.majorTypeIndices.all fun idx => idx.fvarId! != fvarId) -- or does not depend on any index
(← getLCtx).allM fun decl =>
pure (decl.fvarId == ctx.majorDecl.fvarId) <||> -- decl is the major
pure (ctx.majorTypeIndices.any (fun index => decl.fvarId == index.fvarId!)) <||> -- decl is one of the indices
MetavarContext.findLocalDeclDependsOn decl (fun fvarId => ctx.majorTypeIndices.all fun idx => idx.fvarId! != fvarId) -- or does not depend on any index
private def elimAuxIndices (s₁ : GeneralizeIndicesSubgoal) (s₂ : Array CasesSubgoal) : MetaM (Array CasesSubgoal) :=
let indicesFVarIds := s₁.indicesFVarIds

View file

@ -17,10 +17,10 @@ def clear (mvarId : MVarId) (fvarId : FVarId) : MetaM MVarId :=
let mctx ← getMCtx
lctx.forM fun localDecl => do
unless localDecl.fvarId == fvarId do
if mctx.localDeclDependsOn localDecl fvarId then
if (← MetavarContext.localDeclDependsOn localDecl fvarId) then
throwTacticEx `clear mvarId m!"variable '{localDecl.toExpr}' depends on '{mkFVar fvarId}'"
let mvarDecl ← getMVarDecl mvarId
if mctx.exprDependsOn mvarDecl.type fvarId then
if (← MetavarContext.exprDependsOn mvarDecl.type fvarId) then
throwTacticEx `clear mvarId m!"target depends on '{mkFVar fvarId}'"
let lctx := lctx.erase fvarId
let localInsts ← getLocalInstances

View file

@ -143,17 +143,18 @@ def induction (mvarId : MVarId) (majorFVarId : FVarId) (recursorName : Name) (gi
let arg := majorTypeArgs[i]!
if i != idxPos && arg == idx then
throwTacticEx `induction mvarId m!"'{idx}' is an index in major premise, but it occurs more than once{indentExpr majorType}"
if i < idxPos && mctx.exprDependsOn arg idx.fvarId! then
throwTacticEx `induction mvarId m!"'{idx}' is an index in major premise, but it occurs in previous arguments{indentExpr majorType}"
if i < idxPos then
if (← MetavarContext.exprDependsOn arg idx.fvarId!) then
throwTacticEx `induction mvarId m!"'{idx}' is an index in major premise, but it occurs in previous arguments{indentExpr majorType}"
-- If arg is also and index and a variable occurring after `idx`, we need to make sure it doesn't depend on `idx`.
-- Note that if `arg` is not a variable, we will fail anyway when we visit it.
if i > idxPos && recursorInfo.indicesPos.contains i && arg.isFVar then
let idxDecl ← getLocalDecl idx.fvarId!
if mctx.localDeclDependsOn idxDecl arg.fvarId! then
if (← MetavarContext.localDeclDependsOn idxDecl arg.fvarId!) then
throwTacticEx `induction mvarId m!"'{idx}' is an index in major premise, but it depends on index occurring at position #{i+1}"
pure idx
let target ← getMVarType mvarId
if !recursorInfo.depElim && mctx.exprDependsOn target majorFVarId then
if (← pure !recursorInfo.depElim <&&> MetavarContext.exprDependsOn target majorFVarId) then
throwTacticEx `induction mvarId m!"recursor '{recursorName}' does not support dependent elimination, but conclusion depends on major premise"
-- Revert indices and major premise preserving variable order
let (reverted, mvarId) ← revert mvarId ((indices.map Expr.fvarId!).push majorFVarId) true

View file

@ -16,30 +16,28 @@ def revert (mvarId : MVarId) (fvarIds : Array FVarId) (preserveOrder : Bool := f
if (← getLocalDecl fvarId) |>.isAuxDecl then
throwError "failed to revert {mkFVar fvarId}, it is an auxiliary declaration created to represent recursive definitions"
let fvars := fvarIds.map mkFVar
match MetavarContext.MkBinding.collectForwardDeps (← getMCtx) (← getLCtx) fvars preserveOrder with
| Except.error _ => throwError "failed to revert variables {fvars}"
| Except.ok toRevert =>
/- We should clear any `auxDecl` in `toRevert` -/
let mut mvarId := mvarId
let mut toRevertNew := #[]
for x in toRevert do
if (← getLocalDecl x.fvarId!) |>.isAuxDecl then
mvarId ← clear mvarId x.fvarId!
else
toRevertNew := toRevertNew.push x
let tag ← getMVarTag mvarId
-- TODO: the following code can be optimized because `MetavarContext.revert` will compute `collectDeps` again.
-- We should factor out the relevat part
let toRevert ← collectForwardDeps fvars preserveOrder
/- We should clear any `auxDecl` in `toRevert` -/
let mut mvarId := mvarId
let mut toRevertNew := #[]
for x in toRevert do
if (← getLocalDecl x.fvarId!) |>.isAuxDecl then
mvarId ← clear mvarId x.fvarId!
else
toRevertNew := toRevertNew.push x
let tag ← getMVarTag mvarId
-- TODO: the following code can be optimized because `MetavarContext.revert` will compute `collectDeps` again.
-- We should factor out the relevat part
-- Set metavariable kind to natural to make sure `revert` will assign it.
setMVarKind mvarId MetavarKind.natural
let (e, toRevert) ←
try
liftMkBindingM <| MetavarContext.revert toRevertNew mvarId preserveOrder
finally
setMVarKind mvarId MetavarKind.syntheticOpaque
let mvar := e.getAppFn
setMVarTag mvar.mvarId! tag
return (toRevert.map Expr.fvarId!, mvar.mvarId!)
-- Set metavariable kind to natural to make sure `revert` will assign it.
setMVarKind mvarId MetavarKind.natural
let (e, toRevert) ←
try
liftMkBindingM <| MetavarContext.revert toRevertNew mvarId preserveOrder
finally
setMVarKind mvarId MetavarKind.syntheticOpaque
let mvar := e.getAppFn
setMVarTag mvar.mvarId! tag
return (toRevert.map Expr.fvarId!, mvar.mvarId!)
end Lean.Meta

View file

@ -29,8 +29,7 @@ def substCore (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst :
| Expr.fvar aFVarId _ => do
let aFVarIdOriginal := aFVarId
trace[Meta.Tactic.subst] "substituting {a} (id: {aFVarId.name}) with {b}"
let mctx ← getMCtx
if mctx.exprDependsOn b aFVarId then
if (← MetavarContext.exprDependsOn b aFVarId) then
throwTacticEx `subst mvarId m!"'{a}' occurs at{indentExpr b}"
let (vars, mvarId) ← revert mvarId #[aFVarId, hFVarId] true
trace[Meta.Tactic.subst] "after revert {MessageData.ofGoal mvarId}"
@ -46,8 +45,9 @@ def substCore (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst :
pure false
else
let mvarType ← getMVarType mvarId
let mctx ← getMCtx
pure (!mctx.exprDependsOn mvarType aFVarId && !mctx.exprDependsOn mvarType hFVarId)
if (← MetavarContext.exprDependsOn mvarType aFVarId) then pure false
else if (← MetavarContext.exprDependsOn mvarType hFVarId) then pure false
else pure true
if skip then
if clearH then
let mvarId ← clear mvarId hFVarId
@ -64,8 +64,7 @@ def substCore (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst :
| none => unreachable!
| some (_, lhs, rhs) => do
let b ← instantiateMVars <| if symm then lhs else rhs
let mctx ← getMCtx
let depElim := mctx.exprDependsOn mvarDecl.type hFVarId
let depElim ← MetavarContext.exprDependsOn mvarDecl.type hFVarId
let cont (motive : Expr) (newType : Expr) : MetaM (FVarSubst × MVarId) := do
let major ← if symm then pure h else mkEqSymm h
let newMVar ← mkFreshExprSyntheticOpaqueMVar newType tag
@ -192,12 +191,13 @@ where
| some (_, lhs, rhs) =>
let lhs ← instantiateMVars lhs
let rhs ← instantiateMVars rhs
if rhs.isFVar && rhs.fvarId! == h && !mctx.exprDependsOn lhs h then
return some (localDecl.fvarId, true)
else if lhs.isFVar && lhs.fvarId! == h && !mctx.exprDependsOn rhs h then
return some (localDecl.fvarId, false)
else
return none
if rhs.isFVar && rhs.fvarId! == h then
if !(← MetavarContext.exprDependsOn lhs h) then
return some (localDecl.fvarId, true)
if lhs.isFVar && lhs.fvarId! == h then
if !(← MetavarContext.exprDependsOn rhs h) then
return some (localDecl.fvarId, false)
return none
| _ => return none
| throwTacticEx `subst mvarId m!"did not find equation for eliminating '{mkFVar h}'"
return (← substCore mvarId fvarId (symm := symm) (tryToSkip := true)).2

View file

@ -450,7 +450,7 @@ def reduceProj? (e : Expr) : MetaM (Option Expr) := do
-/
private def whnfDelayedAssigned? (f' : Expr) (e : Expr) : MetaM (Option Expr) := do
if f'.isMVar then
match (← getDelayedAssignment? f'.mvarId!) with
match (← getDelayedMVarAssignment? f'.mvarId!) with
| none => return none
| some { fvars := fvars, val := val, .. } =>
let args := e.getAppArgs

View file

@ -271,14 +271,15 @@ structure DelayedMetavarAssignment where
open Std (HashMap PersistentHashMap)
structure MetavarContext where
depth : Nat := 0
mvarCounter : Nat := 0 -- Counter for setting the field `index` at `MetavarDecl`
lDepth : PersistentHashMap MVarId Nat := {}
decls : PersistentHashMap MVarId MetavarDecl := {}
userNames : PersistentHashMap Name MVarId := {}
lAssignment : PersistentHashMap MVarId Level := {}
eAssignment : PersistentHashMap MVarId Expr := {}
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}
depth : Nat := 0
mvarCounter : Nat := 0 -- Counter for setting the field `index` at `MetavarDecl`
lDepth : PersistentHashMap MVarId Nat := {}
decls : PersistentHashMap MVarId MetavarDecl := {}
userNames : PersistentHashMap Name MVarId := {}
lAssignment : PersistentHashMap MVarId Level := {}
eAssignment : PersistentHashMap MVarId Expr := {}
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}
usedAssignment : Bool := false
class MonadMCtx (m : Type → Type) where
getMCtx : m MetavarContext
@ -290,6 +291,48 @@ instance (m n) [MonadLift m n] [MonadMCtx m] : MonadMCtx n where
getMCtx := liftM (getMCtx : m _)
modifyMCtx := fun f => liftM (modifyMCtx f : m _)
def markUsedAssignment [MonadMCtx m] : m Unit :=
modifyMCtx fun mctx => { mctx with usedAssignment := true }
abbrev setMCtx [MonadMCtx m] (mctx : MetavarContext) : m Unit :=
modifyMCtx fun _ => mctx
abbrev getLevelMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Level) := do
let result? := (← getMCtx).lAssignment.find? mvarId
if result?.isSome then
markUsedAssignment
return result?
def MetavarContext.getExprAssignmentCore? (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
m.eAssignment.find? mvarId
def getExprMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Expr) := do
let result? := (← getMCtx).getExprAssignmentCore? mvarId
if result?.isSome then
markUsedAssignment
return result?
def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) := do
let result? := (← getMCtx).dAssignment.find? mvarId
if result?.isSome then
markUsedAssignment
return result?
/- Given a sequence of delayed assignments
```
mvarId₁ := mvarId₂ ...;
...
mvarIdₙ := mvarId_root ... -- where `mvarId_root` is not delayed assigned
```
in `mctx`, `getDelayedRoot mctx mvarId₁` return `mvarId_root`.
If `mvarId₁` is not delayed assigned then return `mvarId₁` -/
partial def getDelayedMVarRoot [Monad m] [MonadMCtx m] (mvarId : MVarId) : m MVarId := do
match (← getDelayedMVarAssignment? mvarId) with
| some d => match d.val.getAppFn with
| Expr.mvar mvarId _ => getDelayedMVarRoot mvarId
| _ => return mvarId
| none => return mvarId
namespace MetavarContext
instance : Inhabited MetavarContext := ⟨{}⟩
@ -387,22 +430,13 @@ def isAnonymousMVar (mctx : MetavarContext) (mvarId : MVarId) : Bool :=
| some mvarDecl => mvarDecl.userName.isAnonymous
def assignLevel (m : MetavarContext) (mvarId : MVarId) (val : Level) : MetavarContext :=
{ m with lAssignment := m.lAssignment.insert mvarId val }
{ m with lAssignment := m.lAssignment.insert mvarId val, usedAssignment := true }
def assignExpr (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext :=
{ m with eAssignment := m.eAssignment.insert mvarId val }
{ m with eAssignment := m.eAssignment.insert mvarId val, usedAssignment := true }
def assignDelayed (m : MetavarContext) (mvarId : MVarId) (fvars : Array Expr) (val : Expr) : MetavarContext :=
{ m with dAssignment := m.dAssignment.insert mvarId { fvars, val } }
def getLevelAssignment? (m : MetavarContext) (mvarId : MVarId) : Option Level :=
m.lAssignment.find? mvarId
def getExprAssignment? (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
m.eAssignment.find? mvarId
def getDelayedAssignment? (m : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
m.dAssignment.find? mvarId
{ m with dAssignment := m.dAssignment.insert mvarId { fvars, val }, usedAssignment := true }
def isLevelAssigned (m : MetavarContext) (mvarId : MVarId) : Bool :=
m.lAssignment.contains mvarId
@ -416,21 +450,6 @@ def isDelayedAssigned (m : MetavarContext) (mvarId : MVarId) : Bool :=
def eraseDelayed (m : MetavarContext) (mvarId : MVarId) : MetavarContext :=
{ m with dAssignment := m.dAssignment.erase mvarId }
/- Given a sequence of delayed assignments
```
mvarId₁ := mvarId₂ ...;
...
mvarIdₙ := mvarId_root ... -- where `mvarId_root` is not delayed assigned
```
in `mctx`, `getDelayedRoot mctx mvarId₁` return `mvarId_root`.
If `mvarId₁` is not delayed assigned then return `mvarId₁` -/
partial def getDelayedRoot (m : MetavarContext) : MVarId → MVarId
| mvarId => match getDelayedAssignment? m mvarId with
| some d => match d.val.getAppFn with
| Expr.mvar mvarId _ => getDelayedRoot m mvarId
| _ => mvarId
| none => mvarId
def isLevelAssignable (mctx : MetavarContext) (mvarId : MVarId) : Bool :=
match mctx.lDepth.find? mvarId with
| some d => d == mctx.depth
@ -515,7 +534,7 @@ partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
| lvl@(Level.max lvl₁ lvl₂ _) => return Level.updateMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂)
| lvl@(Level.imax lvl₁ lvl₂ _) => return Level.updateIMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂)
| lvl@(Level.mvar mvarId _) => do
match getLevelAssignment? (← getMCtx) mvarId with
match (← getLevelMVarAssignment? mvarId) with
| some newLvl =>
if !newLvl.hasMVar then pure newLvl
else do
@ -551,8 +570,7 @@ partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLi
instArgs f
match f with
| Expr.mvar mvarId _ =>
let mctx ← getMCtx
match mctx.getDelayedAssignment? mvarId with
match (← getDelayedMVarAssignment? mvarId) with
| none => instApp
| some { fvars, val, .. } =>
/-
@ -589,8 +607,7 @@ partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLi
pure result
| _ => instApp
| e@(Expr.mvar mvarId _) => checkCache { val := e : ExprStructEq } fun _ => do
let mctx ← getMCtx
match mctx.getExprAssignment? mvarId with
match (← getExprMVarAssignment? mvarId) with
| some newE => do
let newE' ← instantiateExprMVars newE
modifyMCtx fun mctx => mctx.assignExpr mvarId newE'
@ -629,18 +646,26 @@ def instantiateMVarDeclMVars (mctx : MetavarContext) (mvarId : MVarId) : Metavar
namespace DependsOn
private abbrev M := StateM ExprSet
structure State where
visited : ExprSet := {}
mctx : MetavarContext
private abbrev M := StateM State
instance : MonadMCtx M where
getMCtx := return (← get).mctx
modifyMCtx f := modify fun s => { s with mctx := f s.mctx }
private def shouldVisit (e : Expr) : M Bool := do
if !e.hasMVar && !e.hasFVar then
return false
else if (← get).contains e then
else if (← get).visited.contains e then
return false
else
modify fun s => s.insert e
modify fun s => { s with visited := s.visited.insert e }
return true
@[specialize] private partial def dep (mctx : MetavarContext) (pf : FVarId → Bool) (pm : MVarId → Bool) (e : Expr) : M Bool :=
@[specialize] private partial def dep (pf : FVarId → Bool) (pm : MVarId → Bool) (e : Expr) : M Bool :=
let rec
visit (e : Expr) : M Bool := do
if !(← shouldVisit e) then
@ -656,10 +681,10 @@ private def shouldVisit (e : Expr) : M Bool := do
| Expr.lam _ d b _ => visit d <||> visit b
| Expr.letE _ t v b _ => visit t <||> visit v <||> visit b
| Expr.mdata _ b _ => visit b
| e@(Expr.app ..) =>
| e@(Expr.app ..) => do
let f := e.getAppFn
if f.isMVar then
let (e', _) := instantiateMVars mctx e
let e' ← modifyGet fun ⟨visited, mctx⟩ => let (e, mctx) := instantiateMVars mctx e; (e, ⟨visited, mctx⟩)
if e'.getAppFn != f then
visitMain e'
else if pm f.mvarId! then
@ -668,21 +693,21 @@ private def shouldVisit (e : Expr) : M Bool := do
visitApp e
else
visitApp e
| Expr.mvar mvarId _ =>
match mctx.getExprAssignment? mvarId with
| Expr.mvar mvarId _ => do
match (← getExprMVarAssignment? mvarId) with
| some a => visit a
| none =>
if pm mvarId then
return true
else
let lctx := (mctx.getDecl mvarId).lctx
let lctx := (← getMCtx).getDecl mvarId |>.lctx
return lctx.any fun decl => pf decl.fvarId
| Expr.fvar fvarId _ => return pf fvarId
| _ => pure false
visit e
@[inline] partial def main (mctx : MetavarContext) (pf : FVarId → Bool) (pm : MVarId → Bool) (e : Expr) : M Bool :=
if !e.hasFVar && !e.hasMVar then pure false else dep mctx pf pm e
@[inline] partial def main (pf : FVarId → Bool) (pm : MVarId → Bool) (e : Expr) : M Bool :=
if !e.hasFVar && !e.hasMVar then pure false else dep pf pm e
end DependsOn
@ -692,40 +717,45 @@ end DependsOn
1- If `?m := t`, then we visit `t` looking for `x`
2- If `?m` is unassigned, then we consider the worst case and check whether `x` is in the local context of `?m`.
This case is a "may dependency". That is, we may assign a term `t` to `?m` s.t. `t` contains `x`. -/
@[inline] def findExprDependsOn (mctx : MetavarContext) (e : Expr) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : Bool :=
DependsOn.main mctx pf pm e |>.run' {}
@[inline] def findExprDependsOn [Monad m] [MonadMCtx m] (e : Expr) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : m Bool := do
let (result, { mctx, .. }) := DependsOn.main pf pm e |>.run { mctx := (← getMCtx) }
setMCtx mctx
return result
/--
Similar to `findExprDependsOn`, but checks the expressions in the given local declaration
depends on a free variable `x` s.t. `pf x` is `true` or an unassigned metavariable `?m` s.t. `pm ?m` is true. -/
@[inline] def findLocalDeclDependsOn (mctx : MetavarContext) (localDecl : LocalDecl) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : Bool :=
@[inline] def findLocalDeclDependsOn [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (pf : FVarId → Bool := fun _ => false) (pm : MVarId → Bool := fun _ => false) : m Bool := do
match localDecl with
| LocalDecl.cdecl (type := t) .. => findExprDependsOn mctx t pf pm
| LocalDecl.ldecl (type := t) (value := v) .. => (DependsOn.main mctx pf pm t <||> DependsOn.main mctx pf pm v).run' {}
| LocalDecl.cdecl (type := t) .. => findExprDependsOn t pf pm
| LocalDecl.ldecl (type := t) (value := v) .. =>
let (result, { mctx, .. }) := (DependsOn.main pf pm t <||> DependsOn.main pf pm v).run { mctx := (← getMCtx) }
setMCtx mctx
return result
def exprDependsOn (mctx : MetavarContext) (e : Expr) (fvarId : FVarId) : Bool :=
findExprDependsOn mctx e (fvarId == ·)
def exprDependsOn [Monad m] [MonadMCtx m] (e : Expr) (fvarId : FVarId) : m Bool :=
findExprDependsOn e (fvarId == ·)
def localDeclDependsOn (mctx : MetavarContext) (localDecl : LocalDecl) (fvarId : FVarId) : Bool :=
findLocalDeclDependsOn mctx localDecl (fvarId == ·)
def localDeclDependsOn [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (fvarId : FVarId) : m Bool :=
findLocalDeclDependsOn localDecl (fvarId == ·)
/-- Similar to `exprDependsOn`, but `x` can be a free variable or an unassigned metavariable. -/
def exprDependsOn' (mctx : MetavarContext) (e : Expr) (x : Expr) : Bool :=
def exprDependsOn' [Monad m] [MonadMCtx m] (e : Expr) (x : Expr) : m Bool :=
if x.isFVar then
findExprDependsOn mctx e (x.fvarId! == ·)
findExprDependsOn e (x.fvarId! == ·)
else if x.isMVar then
findExprDependsOn mctx e (pm := (x.mvarId! == ·))
findExprDependsOn e (pm := (x.mvarId! == ·))
else
false
return false
/-- Similar to `localDeclDependsOn`, but `x` can be a free variable or an unassigned metavariable. -/
def localDeclDependsOn' (mctx : MetavarContext) (localDecl : LocalDecl) (x : Expr) : Bool :=
def localDeclDependsOn' [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (x : Expr) : m Bool :=
if x.isFVar then
findLocalDeclDependsOn mctx localDecl (x.fvarId! == ·)
findLocalDeclDependsOn localDecl (x.fvarId! == ·)
else if x.isMVar then
findLocalDeclDependsOn mctx localDecl (pm := (x.mvarId! == ·))
findLocalDeclDependsOn localDecl (pm := (x.mvarId! == ·))
else
false
return false
namespace MkBinding
@ -762,6 +792,10 @@ structure Context where
abbrev MCore := EStateM Exception State
abbrev M := ReaderT Context MCore
instance : MonadMCtx M where
getMCtx := return (← get).mctx
modifyMCtx f := modify fun s => { s with mctx := f s.mctx }
private def mkFreshBinderName (n : Name := `x) : M Name := do
let fresh ← modifyGet fun s => (s.nextMacroScope, { s with nextMacroScope := s.nextMacroScope + 1 })
return addMacroScope (← read).mainModule n fresh
@ -816,20 +850,20 @@ private def getLocalDeclWithSmallestIdx (lctx : LocalContext) (xs : Array Expr)
Note that https://github.com/leanprover/lean/issues/1258 is not an issue in Lean4 because
we have changed how we compile recursive definitions.
-/
def collectForwardDeps (mctx : MetavarContext) (lctx : LocalContext) (toRevert : Array Expr) (preserveOrder : Bool) : Except Exception (Array Expr) := do
def collectForwardDeps (lctx : LocalContext) (toRevert : Array Expr) : M (Array Expr) := do
if toRevert.size == 0 then
pure toRevert
else
if preserveOrder then
if (← preserveOrder) then
-- Make sure toRevert[j] does not depend on toRevert[i] for j > i
toRevert.size.forM fun i => do
let fvar := toRevert[i]!
i.forM fun j => do
let prevFVar := toRevert[j]!
let prevDecl := lctx.getFVar! prevFVar
if localDeclDependsOn mctx prevDecl fvar.fvarId! then
throw (Exception.revertFailure mctx lctx toRevert prevDecl.userName.toString)
let newToRevert := if preserveOrder then toRevert else Array.mkEmpty toRevert.size
if (← localDeclDependsOn prevDecl fvar.fvarId!) then
throw (Exception.revertFailure (← getMCtx) lctx toRevert prevDecl.userName.toString)
let newToRevert := if (← preserveOrder) then toRevert else Array.mkEmpty toRevert.size
let firstDeclToVisit := getLocalDeclWithSmallestIdx lctx toRevert
let initSize := newToRevert.size
lctx.foldlM (init := newToRevert) (start := firstDeclToVisit.index) fun (newToRevert : Array Expr) decl => do
@ -837,7 +871,7 @@ def collectForwardDeps (mctx : MetavarContext) (lctx : LocalContext) (toRevert :
return newToRevert
else if toRevert.any fun x => decl.fvarId == x.fvarId! then
return newToRevert.push decl.toExpr
else if findLocalDeclDependsOn mctx decl (newToRevert.any fun x => x.fvarId! == ·) then
else if (← findLocalDeclDependsOn decl (newToRevert.any fun x => x.fvarId! == ·)) then
return newToRevert.push decl.toExpr
else
return newToRevert
@ -848,9 +882,6 @@ def reduceLocalContext (lctx : LocalContext) (toRevert : Array Expr) : LocalCont
toRevert.foldr (init := lctx) fun x lctx =>
if x.isFVar then lctx.erase x.fvarId! else lctx
@[inline] private def getMCtx : M MetavarContext :=
return (← get).mctx
/-- Return free variables in `xs` that are in the local context `lctx` -/
private def getInScope (lctx : LocalContext) (xs : Array Expr) : Array Expr :=
xs.foldl (init := #[]) fun scope x =>
@ -930,8 +961,7 @@ mutual
pure (e.abstractRange i xs)
private partial def elimMVar (xs : Array Expr) (mvarId : MVarId) (args : Array Expr) : M (Expr × Array Expr) := do
let mctx ← getMCtx
let mvarDecl := mctx.getDecl mvarId
let mvarDecl := (← getMCtx).getDecl mvarId
let mvarLCtx := mvarDecl.lctx
let toRevert := getInScope mvarLCtx xs
if toRevert.size == 0 then
@ -949,7 +979,7 @@ mutual
A potential disadvantage is that `isDefEq` will not eagerly use `synthPending` for natural metavariables.
That being said, we should try this approach as soon as we have an extensive test suite.
-/
let newMVarKind := if !mctx.isExprAssignable mvarId then MetavarKind.syntheticOpaque else mvarDecl.kind
let newMVarKind := if !(← getMCtx).isExprAssignable mvarId then MetavarKind.syntheticOpaque else mvarDecl.kind
/- If `mvarId` is the lhs of a delayed assignment `?m #[x_1, ... x_n] := val`,
then `nestedFVars` is `#[x_1, ..., x_n]`.
In this case, we produce a new `syntheticOpaque` metavariable `?n` and a delayed assignment
@ -962,39 +992,36 @@ mutual
-/
let rec cont (nestedFVars : Array Expr) : M (Expr × Array Expr) := do
let args ← args.mapM (visit xs)
let preserve ← preserveOrder
-- Note that `toRevert` only contains free variables since it is the result of `getInScope`
match collectForwardDeps mctx mvarLCtx toRevert preserve with
| Except.error ex => throw ex
| Except.ok toRevert =>
let newMVarLCtx := reduceLocalContext mvarLCtx toRevert
let newLocalInsts := mvarDecl.localInstances.filter fun inst => toRevert.all fun x => inst.fvar != x
-- Remark: we must reset the before processing `mkAuxMVarType` because `toRevert` may not be equal to `xs`
let newMVarType ← withFreshCache do mkAuxMVarType mvarLCtx toRevert newMVarKind mvarDecl.type
let newMVarId := { name := (← get).ngen.curr }
let newMVar := mkMVar newMVarId
let result := mkMVarApp mvarLCtx newMVar toRevert newMVarKind
let numScopeArgs := mvarDecl.numScopeArgs + result.getAppNumArgs
modify fun s => { s with
mctx := s.mctx.addExprMVarDecl newMVarId Name.anonymous newMVarLCtx newLocalInsts newMVarType newMVarKind numScopeArgs,
ngen := s.ngen.next
}
match newMVarKind with
| MetavarKind.syntheticOpaque =>
modify fun s => { s with mctx := assignDelayed s.mctx newMVarId (toRevert ++ nestedFVars) (mkAppN (mkMVar mvarId) nestedFVars) }
| _ =>
modify fun s => { s with mctx := assignExpr s.mctx mvarId result }
return (mkAppN result args, toRevert)
let toRevert ← collectForwardDeps mvarLCtx toRevert
let newMVarLCtx := reduceLocalContext mvarLCtx toRevert
let newLocalInsts := mvarDecl.localInstances.filter fun inst => toRevert.all fun x => inst.fvar != x
-- Remark: we must reset the before processing `mkAuxMVarType` because `toRevert` may not be equal to `xs`
let newMVarType ← withFreshCache do mkAuxMVarType mvarLCtx toRevert newMVarKind mvarDecl.type
let newMVarId := { name := (← get).ngen.curr }
let newMVar := mkMVar newMVarId
let result := mkMVarApp mvarLCtx newMVar toRevert newMVarKind
let numScopeArgs := mvarDecl.numScopeArgs + result.getAppNumArgs
modify fun s => { s with
mctx := s.mctx.addExprMVarDecl newMVarId Name.anonymous newMVarLCtx newLocalInsts newMVarType newMVarKind numScopeArgs,
ngen := s.ngen.next
}
match newMVarKind with
| MetavarKind.syntheticOpaque =>
modify fun s => { s with mctx := assignDelayed s.mctx newMVarId (toRevert ++ nestedFVars) (mkAppN (mkMVar mvarId) nestedFVars) }
| _ =>
modify fun s => { s with mctx := assignExpr s.mctx mvarId result }
return (mkAppN result args, toRevert)
if !mvarDecl.kind.isSyntheticOpaque then
cont #[]
else match mctx.getDelayedAssignment? mvarId with
else match (← getDelayedMVarAssignment? mvarId) with
| none => cont #[]
| some { fvars, .. } => cont fvars
private partial def elimApp (xs : Array Expr) (f : Expr) (args : Array Expr) : M Expr := do
match f with
| Expr.mvar mvarId _ =>
match (← getMCtx).getExprAssignment? mvarId with
match (← getExprMVarAssignment? mvarId) with
| some newF =>
if newF.isLambda then
let args ← args.mapM (visit xs)
@ -1096,28 +1123,32 @@ def mkBinding (isLambda : Bool) (xs : Array Expr) (e : Expr) (usedOnly : Bool :=
@[inline] def abstractRange (e : Expr) (n : Nat) (xs : Array Expr) : MkBindingM Expr := fun ctx =>
MkBinding.abstractRange xs n e { preserveOrder := false, mainModule := ctx.mainModule }
@[inline] def collectForwardDeps (toRevert : Array Expr) (preserveOrder : Bool) : MkBindingM (Array Expr) := fun ctx =>
MkBinding.collectForwardDeps ctx.lctx toRevert { preserveOrder, mainModule := ctx.mainModule }
/--
`isWellFormed mctx lctx e` return true if
- All locals in `e` are declared in `lctx`
- All metavariables `?m` in `e` have a local context which is a subprefix of `lctx` or are assigned, and the assignment is well-formed. -/
partial def isWellFormed (mctx : MetavarContext) (lctx : LocalContext) : Expr → Bool
| Expr.mdata _ e _ => isWellFormed mctx lctx e
| Expr.proj _ _ e _ => isWellFormed mctx lctx e
| e@(Expr.app f a _) => (!e.hasExprMVar && !e.hasFVar) || (isWellFormed mctx lctx f && isWellFormed mctx lctx a)
| e@(Expr.lam _ d b _) => (!e.hasExprMVar && !e.hasFVar) || (isWellFormed mctx lctx d && isWellFormed mctx lctx b)
| e@(Expr.forallE _ d b _) => (!e.hasExprMVar && !e.hasFVar) || (isWellFormed mctx lctx d && isWellFormed mctx lctx b)
| e@(Expr.letE _ t v b _) => (!e.hasExprMVar && !e.hasFVar) || (isWellFormed mctx lctx t && isWellFormed mctx lctx v && isWellFormed mctx lctx b)
| Expr.const .. => true
| Expr.bvar .. => true
| Expr.sort .. => true
| Expr.lit .. => true
| Expr.mvar mvarId _ =>
let mvarDecl := mctx.getDecl mvarId;
if mvarDecl.lctx.isSubPrefixOf lctx then true
else match mctx.getExprAssignment? mvarId with
| none => false
| some v => isWellFormed mctx lctx v
| Expr.fvar fvarId _ => lctx.contains fvarId
partial def isWellFormed [Monad m] [MonadMCtx m] (lctx : LocalContext) : Expr → m Bool
| Expr.mdata _ e _ => isWellFormed lctx e
| Expr.proj _ _ e _ => isWellFormed lctx e
| e@(Expr.app f a _) => pure (!e.hasExprMVar && !e.hasFVar) <||> (isWellFormed lctx f <&&> isWellFormed lctx a)
| e@(Expr.lam _ d b _) => pure (!e.hasExprMVar && !e.hasFVar) <||> (isWellFormed lctx d <&&> isWellFormed lctx b)
| e@(Expr.forallE _ d b _) => pure (!e.hasExprMVar && !e.hasFVar) <||> (isWellFormed lctx d <&&> isWellFormed lctx b)
| e@(Expr.letE _ t v b _) => pure (!e.hasExprMVar && !e.hasFVar) <||> (isWellFormed lctx t <&&> isWellFormed lctx v <&&> isWellFormed lctx b)
| Expr.const .. => return true
| Expr.bvar .. => return true
| Expr.sort .. => return true
| Expr.lit .. => return true
| Expr.mvar mvarId _ => do
let mvarDecl := (← getMCtx).getDecl mvarId;
if mvarDecl.lctx.isSubPrefixOf lctx then
return true
else match (← getExprMVarAssignment? mvarId) with
| none => return false
| some v => isWellFormed lctx v
| Expr.fvar fvarId _ => return lctx.contains fvarId
namespace LevelMVarToParam
@ -1134,6 +1165,10 @@ structure State where
abbrev M := ReaderT Context <| StateM State
instance : MonadMCtx M where
getMCtx := return (← get).mctx
modifyMCtx f := modify fun s => { s with mctx := f s.mctx }
instance : MonadCache ExprStructEq Expr M where
findCached? e := return (← get).cache.find? e
cache e v := modify fun s => { s with cache := s.cache.insert e v }
@ -1158,7 +1193,7 @@ partial def visitLevel (u : Level) : M Level := do
| Level.param .. => return u
| Level.mvar mvarId _ =>
let s ← get
match s.mctx.getLevelAssignment? mvarId with
match (← getLevelMVarAssignment? mvarId) with
| some v => visitLevel v
| none =>
if (← read).except mvarId then
@ -1189,7 +1224,7 @@ where
visitApp (f : Expr) (args : Array Expr) : M Expr := do
match f with
| Expr.mvar mvarId .. =>
match (← get).mctx.getExprAssignment? mvarId with
match (← getExprMVarAssignment? mvarId) with
| some v => return (← visitApp v args).headBeta
| none => return mkAppN f (← args.mapM main)
| _ => return mkAppN (← main f) (← args.mapM main)

View file

@ -177,7 +177,7 @@ private def deriveInductiveInstance (indVal : InductiveVal) (params : Array Expr
let argFVars ← argVars.mapM (LocalDecl.fvarId <$> getFVarLocalDecl ·)
for arg in argVars do
let argTp ← inferType arg
if (← getMCtx).findExprDependsOn argTp (pf := fun fv => argFVars.contains fv) then
if (← MetavarContext.findExprDependsOn argTp (pf := fun fv => argFVars.contains fv)) then
throwError "cross-argument dependencies are not supported ({arg} : {argTp})"
if (← acc.encArgTypes.getMatch argTp).isEmpty then

View file

@ -10,26 +10,26 @@ namespace Lean
/--
Return true if `e` does **not** contain `mvarId` directly or indirectly
This function considers assigments and delayed assignments. -/
partial def MetavarContext.occursCheck (mctx : MetavarContext) (mvarId : MVarId) (e : Expr) : Bool :=
partial def occursCheck [Monad m] [MonadMCtx m] (mvarId : MVarId) (e : Expr) : m Bool := do
if !e.hasExprMVar then
true
return true
else
match visit e |>.run {} with
| EStateM.Result.ok .. => true
| EStateM.Result.error .. => false
match (← visit e |>.run |>.run {}) with
| (.ok .., _) => return true
| (.error .., _) => return false
where
visitMVar (mvarId' : MVarId) : EStateM Unit ExprSet Unit := do
visitMVar (mvarId' : MVarId) : ExceptT Unit (StateT ExprSet m) Unit := do
if mvarId == mvarId' then
throw () -- found
else
match mctx.getExprAssignment? mvarId' with
match (← getExprMVarAssignment? mvarId') with
| some v => visit v
| none =>
match mctx.getDelayedAssignment? mvarId' with
match (← getDelayedMVarAssignment? mvarId') with
| some d => visit d.val
| none => return ()
visit (e : Expr) : EStateM Unit ExprSet Unit := do
visit (e : Expr) : ExceptT Unit (StateT ExprSet m) Unit := do
if !e.hasExprMVar then
return ()
else if (← get).contains e then