feat: abstract metavars in grind preprocessor (#8299)

This PR implements a missing preprocessing step in `grind`: abstract
metavariables in the goal
This commit is contained in:
Leonardo de Moura 2025-05-12 07:53:54 -07:00 committed by GitHub
parent eda467e066
commit 3f75f08e1d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 36 additions and 10 deletions

View file

@ -15,7 +15,8 @@ structure State where
mctx : MetavarContext
nextParamIdx : Nat := 0
paramNames : Array Name := #[]
fvars : Array Expr := #[]
fvars : Array Expr := #[]
mvars : Array Expr := #[]
lmap : Std.HashMap LMVarId Level := {}
emap : Std.HashMap MVarId Expr := {}
abstractLevels : Bool -- whether to abstract level mvars
@ -100,8 +101,9 @@ partial def abstractExprMVars (e : Expr) : M Expr := do
pure decl.userName
modify fun s => {
s with
emap := s.emap.insert mvarId fvar,
fvars := s.fvars.push fvar,
emap := s.emap.insert mvarId fvar
fvars := s.fvars.push fvar
mvars := s.mvars.push e
lctx := s.lctx.mkLocalDecl fvarId userName type }
return fvar
@ -111,7 +113,7 @@ end AbstractMVars
Abstract (current depth) metavariables occurring in `e`.
The result contains
- An array of universe level parameters that replaced universe metavariables occurring in `e`.
- The number of (expr) metavariables abstracted.
- The metavariables that have been abstracted.
- And an expression of the form `fun (m_1 : A_1) ... (m_k : A_k) => e'`, where
`k` equal to the number of (expr) metavariables abstracted, and `e'` is `e` after we
replace the metavariables.
@ -126,7 +128,10 @@ end AbstractMVars
If `levels := false`, then level metavariables are not abstracted.
Application: we use this method to cache the results of type class resolution. -/
Application: we use this method to cache the results of type class resolution.
Application: tactic `MVarId.abstractMVars`
-/
def abstractMVars (e : Expr) (levels : Bool := true): MetaM AbstractMVarsResult := do
let e ← instantiateMVars e
let (e, s) := AbstractMVars.abstractExprMVars e
@ -134,7 +139,7 @@ def abstractMVars (e : Expr) (levels : Bool := true): MetaM AbstractMVarsResult
setNGen s.ngen
setMCtx s.mctx
let e := s.lctx.mkLambda s.fvars e
pure { paramNames := s.paramNames, numMVars := s.fvars.size, expr := e }
pure { paramNames := s.paramNames, mvars := s.mvars, expr := e }
def openAbstractMVarsResult (a : AbstractMVarsResult) : MetaM (Array Expr × Array BinderInfo × Expr) := do
let us ← a.paramNames.mapM fun _ => mkFreshLevelMVar

View file

@ -317,10 +317,13 @@ structure SynthInstanceCacheKey where
/-- Resulting type for `abstractMVars` -/
structure AbstractMVarsResult where
paramNames : Array Name
numMVars : Nat
mvars : Array Expr
expr : Expr
deriving Inhabited, BEq
def AbstractMVarsResult.numMVars (r : AbstractMVarsResult) : Nat :=
r.mvars.size
abbrev SynthInstanceCache := PersistentHashMap SynthInstanceCacheKey (Option AbstractMVarsResult)
-- Key for `InferType` and `WHNF` caches

View file

@ -771,7 +771,7 @@ private def cacheResult (cacheKey : SynthInstanceCacheKey) (abstResult? : Option
if abstResult.numMVars == 0 && abstResult.paramNames.isEmpty then
-- See `applyCachedAbstractResult?` If new metavariables have **not** been introduced,
-- we don't need to perform extra checks again when reusing result.
modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert cacheKey (some { expr := result, paramNames := #[], numMVars := 0 }) }
modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert cacheKey (some { expr := result, paramNames := #[], mvars := #[] }) }
else
modify fun s => { s with cache.synthInstance := s.cache.synthInstance.insert cacheKey (some abstResult) }

View file

@ -94,8 +94,7 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
activateTheorem thm 0
private def initCore (mvarId : MVarId) (params : Params) : GrindM (List Goal) := do
-- TODO: abstract metavars
mvarId.ensureNoMVar
let mvarId ← mvarId.abstractMVars
let mvarId ← mvarId.clearAuxDecls
let mvarId ← mvarId.revertAll
let mvarId ← mvarId.unfoldReducible

View file

@ -21,6 +21,19 @@ def _root_.Lean.MVarId.ensureNoMVar (mvarId : MVarId) : MetaM Unit := do
if type.hasExprMVar then
throwTacticEx `grind mvarId "goal contains metavariables"
/-- Abstracts metavariables occurring in the target. -/
def _root_.Lean.MVarId.abstractMVars (mvarId : MVarId) : MetaM MVarId := do
mvarId.checkNotAssigned `grind
let type ← instantiateMVars (← mvarId.getType)
unless type.hasExprMVar do return mvarId
mvarId.withContext do
let r ← Meta.abstractMVars type (levels := false)
let typeNew ← lambdaTelescope r.expr fun xs body => mkForallFVars xs body
let tag ← mvarId.getTag
let mvarNew ← mkFreshExprSyntheticOpaqueMVar typeNew tag
mvarId.assign (mkAppN mvarNew r.mvars)
return mvarNew.mvarId!
def _root_.Lean.MVarId.transformTarget (mvarId : MVarId) (f : Expr → MetaM Expr) : MetaM MVarId := mvarId.withContext do
mvarId.checkNotAssigned `grind
let tag ← mvarId.getTag

View file

@ -0,0 +1,6 @@
set_option grind.warning false
example (xs : Array Nat)
(w : ∀ (j : Nat), 0 ≤ j → ∀ (x : j < xs.size / 2), xs[j] = xs[xs.size - 1 - j])
(i : Nat) (hi₁ : i < xs.reverse.size) (hi₂ : i < xs.size) (h : i < xs.size / 2) : xs.reverse[i] = xs[i] := by
rw [w] <;> grind