refactor: use mkAuxLemma in mkAuxTheorem (#7762)

cc @Kha

---------

Co-authored-by: Sebastian Ullrich <sebasti@nullri.ch>
This commit is contained in:
Leonardo de Moura 2025-03-31 15:50:30 -07:00 committed by GitHub
parent d6303a8e7f
commit bb07a732e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 62 additions and 78 deletions

View file

@ -87,11 +87,11 @@ def applyAttributesOf (preDefs : Array PreDefinition) (applicationTime : Attribu
for preDef in preDefs do
applyAttributesAt preDef.declName preDef.modifiers.attrs applicationTime
def abstractNestedProofs (preDef : PreDefinition) : MetaM PreDefinition := withRef preDef.ref do
def abstractNestedProofs (preDef : PreDefinition) (cache := true) : MetaM PreDefinition := withRef preDef.ref do
if preDef.kind.isTheorem || preDef.kind.isExample then
pure preDef
else do
let value ← Meta.abstractNestedProofs preDef.declName preDef.value
let value ← Meta.abstractNestedProofs (cache := cache) preDef.declName preDef.value
pure { preDef with value := value }
/-- Auxiliary method for (temporarily) adding pre definition as an axiom -/
@ -121,9 +121,9 @@ private def reportTheoremDiag (d : TheoremVal) : TermElabM Unit := do
-- let info
logInfo <| MessageData.trace { cls := `theorem } m!"{d.name}" (#[sizeMsg] ++ constOccsMsg)
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List Name) (applyAttrAfterCompilation := true) : TermElabM Unit :=
private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List Name) (applyAttrAfterCompilation := true) (cacheProofs := true) : TermElabM Unit :=
withRef preDef.ref do
let preDef ← abstractNestedProofs preDef
let preDef ← abstractNestedProofs (cache := cacheProofs) preDef
let mkDefDecl : TermElabM Declaration :=
return Declaration.defnDecl {
name := preDef.declName, levelParams := preDef.levelParams, type := preDef.type, value := preDef.value
@ -168,8 +168,8 @@ private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List N
def addAndCompileNonRec (preDef : PreDefinition) (all : List Name := [preDef.declName]) : TermElabM Unit := do
addNonRecAux preDef (compile := true) (all := all)
def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all : List Name := [preDef.declName]) : TermElabM Unit := do
addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation) (all := all)
def addNonRec (preDef : PreDefinition) (applyAttrAfterCompilation := true) (all : List Name := [preDef.declName]) (cacheProofs := true) : TermElabM Unit := do
addNonRecAux preDef (compile := false) (applyAttrAfterCompilation := applyAttrAfterCompilation) (all := all) (cacheProofs := cacheProofs)
/--
Eliminate recursive application annotations containing syntax. These annotations are used by the well-founded recursion module

View file

@ -27,7 +27,7 @@ where
go (fvars.push x) (vals.map fun val => val.bindingBody!.instantiate1 x)
def addPreDefsFromUnary (preDefs : Array PreDefinition) (preDefsNonrec : Array PreDefinition)
(unaryPreDefNonRec : PreDefinition) : TermElabM Unit := do
(unaryPreDefNonRec : PreDefinition) (cacheProofs := true) : TermElabM Unit := do
/-
We must remove `implemented_by` attributes from the auxiliary application because
this attribute is only relevant for code that is compiled. Moreover, the `[implemented_by <decl>]`
@ -41,21 +41,21 @@ def addPreDefsFromUnary (preDefs : Array PreDefinition) (preDefsNonrec : Array P
-- we recognize that below and then do not set @[irreducible]
withOptions (allowUnsafeReducibility.set · true) do
if unaryPreDefNonRec.declName = preDefs[0]!.declName then
addNonRec preDefNonRec (applyAttrAfterCompilation := false)
addNonRec preDefNonRec (applyAttrAfterCompilation := false) (cacheProofs := cacheProofs)
else
withEnableInfoTree false do
addNonRec preDefNonRec (applyAttrAfterCompilation := false)
preDefsNonrec.forM (addNonRec · (applyAttrAfterCompilation := false) (all := declNames))
addNonRec preDefNonRec (applyAttrAfterCompilation := false) (cacheProofs := cacheProofs)
preDefsNonrec.forM (addNonRec · (applyAttrAfterCompilation := false) (all := declNames) (cacheProofs := cacheProofs))
/--
Cleans the right-hand-sides of the predefinitions, to prepare for inclusion in the EqnInfos:
* Remove RecAppSyntax markers
* Abstracts nested proofs (and for that, add the `_unsafe_rec` definitions)
-/
def cleanPreDefs (preDefs : Array PreDefinition) : TermElabM (Array PreDefinition) := do
def cleanPreDefs (preDefs : Array PreDefinition) (cacheProofs := true) : TermElabM (Array PreDefinition) := do
addAndCompilePartialRec preDefs
let preDefs ← preDefs.mapM (eraseRecAppSyntax ·)
let preDefs ← preDefs.mapM (abstractNestedProofs ·)
let preDefs ← preDefs.mapM (abstractNestedProofs (cache := cacheProofs) ·)
return preDefs
/--

View file

@ -66,8 +66,8 @@ def wfRecursion (preDefs : Array PreDefinition) (termMeasure?s : Array (Option T
trace[Elab.definition.wf] ">> {preDefNonRec.declName} :=\n{preDefNonRec.value}"
let preDefsNonrec ← preDefsFromUnaryNonRec fixedParamPerms argsPacker preDefs preDefNonRec
Mutual.addPreDefsFromUnary preDefs preDefsNonrec preDefNonRec
let preDefs ← Mutual.cleanPreDefs preDefs
Mutual.addPreDefsFromUnary (cacheProofs := false) preDefs preDefsNonrec preDefNonRec
let preDefs ← Mutual.cleanPreDefs (cacheProofs := false) preDefs
registerEqnsInfo preDefs preDefNonRec.declName fixedParamPerms argsPacker
for preDef in preDefs, wfPreprocessProof in wfPreprocessProofs do
unless preDef.kind.isTheorem do

View file

@ -20,8 +20,6 @@ def elabAsAuxLemma : Lean.Elab.Tactic.Tactic
unless mvars.isEmpty do
throwError "Cannot abstract term into auxiliary lemma because there are open goals."
let e ← instantiateMVars (mkMVar mvarId)
let env ← getEnv
-- TODO: this likely should share name creation code with `mkAuxLemma`
let e ← mkAuxTheorem (← mkFreshUserName <| env.asyncPrefix?.getD env.mainModule ++ `_auxLemma) (← mvarId.getType) e
let e ← mkAuxTheorem (prefix? := (← Term.getDeclName?)) (← mvarId.getType) e
mvarId.assign e
| _ => throwError "Invalid as_aux_lemma syntax"

View file

@ -141,11 +141,11 @@ def grind
let result ← Grind.main mvar'.mvarId! params mainDeclName fallback
if result.hasFailures then
throwError "`grind` failed\n{← result.toMessageData}"
let auxName ← Term.mkAuxName `grind
-- `grind` proofs are often big
let e ← if (← isProp type) then
mkAuxTheorem auxName type (← instantiateMVarsProfiling mvar') (zetaDelta := true)
mkAuxTheorem (prefix? := mainDeclName) type (← instantiateMVarsProfiling mvar') (zetaDelta := true)
else
let auxName ← Term.mkAuxName `grind
mkAuxDefinition auxName type (← instantiateMVarsProfiling mvar') (zetaDelta := true)
mvarId.assign e
return result.trace

View file

@ -672,7 +672,7 @@ open Lean Elab Tactic Parser.Tactic
/-- The `omega` tactic, for resolving integer and natural linear arithmetic problems. -/
def omegaTactic (cfg : OmegaConfig) : TacticM Unit := do
let auxName ← Term.mkAuxName `omega
let declName? ← Term.getDeclName?
liftMetaFinishingTactic fun g => do
let some g ← g.falseOrByContra | return ()
g.withContext do
@ -682,7 +682,7 @@ def omegaTactic (cfg : OmegaConfig) : TacticM Unit := do
trace[omega] "analyzing {hyps.length} hypotheses:\n{← hyps.mapM inferType}"
omega hyps g'.mvarId! cfg
-- Omega proofs are typically rather large, so hide them in a separate definition
let e ← mkAuxTheorem auxName type (← instantiateMVarsProfiling g') (zetaDelta := true)
let e ← mkAuxTheorem (prefix? := declName?) type (← instantiateMVarsProfiling g') (zetaDelta := true)
g.assign e

View file

@ -31,6 +31,7 @@ def isNonTrivialProof (e : Expr) : MetaM Bool := do
pure $ !f.isAtomic || args.any fun arg => !arg.isAtomic
structure Context where
cache : Bool
baseName : Name
structure State where
@ -74,21 +75,19 @@ where
let type ← zetaReduce type
/- Ensure proofs nested in type are also abstracted -/
let type ← visit type
let lemmaName ← mkAuxName (ctx.baseName ++ `proof) (← get).nextIdx
modify fun s => { s with nextIdx := s.nextIdx + 1 }
/- We turn on zetaDelta-expansion to make sure we don't need to perform an expensive `check` step to
identify which let-decls can be abstracted. If we design a more efficient test, we can avoid the eager zetaDelta expansion step.
It a benchmark created by @selsam, The extra `check` step was a bottleneck. -/
mkAuxTheorem lemmaName type e (zetaDelta := true)
mkAuxTheorem (prefix? := ctx.baseName) (cache := ctx.cache) type e (zetaDelta := true)
end AbstractNestedProofs
/-- Replace proofs nested in `e` with new lemmas. The new lemmas have names of the form `mainDeclName.proof_<idx>` -/
def abstractNestedProofs (mainDeclName : Name) (e : Expr) : MetaM Expr := do
def abstractNestedProofs (mainDeclName : Name) (e : Expr) (cache := true) : MetaM Expr := do
if (← isProof e) then
-- `e` is a proof itself. So, we don't abstract nested proofs
return e
else
AbstractNestedProofs.visit e |>.run { baseName := mainDeclName } |>.run |>.run' { nextIdx := 1 }
AbstractNestedProofs.visit e |>.run { cache, baseName := mainDeclName } |>.run |>.run' { nextIdx := 1 }
end Lean.Meta

View file

@ -10,6 +10,7 @@ import Lean.AddDecl
import Lean.Util.FoldConsts
import Lean.Meta.Basic
import Lean.Meta.Check
import Lean.Meta.Tactic.AuxLemma
/-!
@ -391,36 +392,9 @@ def mkAuxDefinitionFor (name : Name) (value : Expr) (zetaDelta : Bool := false)
/--
Create an auxiliary theorem with the given name, type and value. It is similar to `mkAuxDefinition`.
-/
def mkAuxTheorem (name : Name) (type : Expr) (value : Expr) (zetaDelta : Bool := false) : MetaM Expr := do
def mkAuxTheorem (type : Expr) (value : Expr) (zetaDelta : Bool := false) (prefix? : Option Name) (cache := true) : MetaM Expr := do
let result ← Closure.mkValueTypeClosure type value zetaDelta
let env ← getEnv
let decl :=
if env.hasUnsafe result.type || env.hasUnsafe result.value then
-- `result` contains unsafe code, thus we cannot use a theorem.
Declaration.defnDecl {
name
levelParams := result.levelParams.toList
type := result.type
value := result.value
hints := ReducibilityHints.opaque
safety := DefinitionSafety.unsafe
}
else
Declaration.thmDecl {
name
levelParams := result.levelParams.toList
type := result.type
value := result.value
}
addDecl decl
let name ← mkAuxLemma (prefix? := prefix?) (cache := cache) result.levelParams.toList result.type result.value
return mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
/--
Similar to `mkAuxTheorem`, but infers the type of `value`.
-/
def mkAuxTheoremFor (name : Name) (value : Expr) (zetaDelta : Bool := false) : MetaM Expr := do
let type ← inferType value
let type := type.headBeta
mkAuxTheorem name type value zetaDelta
end Lean.Meta

View file

@ -28,19 +28,32 @@ builtin_initialize auxLemmasExt : EnvExtension AuxLemmas ←
This method is useful for tactics (e.g., `simp`) that may perform preprocessing steps to lemmas provided by
users. For example, `simp` preprocessor may convert a lemma into multiple ones.
-/
def mkAuxLemma (levelParams : List Name) (type : Expr) (value : Expr) : MetaM Name := do
def mkAuxLemma (levelParams : List Name) (type : Expr) (value : Expr) (prefix? : Option Name := none) (cache := true) : MetaM Name := do
let env ← getEnv
let s := auxLemmasExt.getState env
let mkNewAuxLemma := do
let auxName := Name.mkNum (env.asyncPrefix?.getD env.mainModule ++ `_auxLemma) s.idx
addDecl <| Declaration.thmDecl {
name := auxName
levelParams, type, value
}
let auxName := prefix?.getD (env.asyncPrefix?.getD (mkPrivateName env .anonymous)) ++ `_proof |>.appendIndexAfter s.idx
let decl :=
if env.hasUnsafe type || env.hasUnsafe value then
-- `result` contains unsafe code, thus we cannot use a theorem.
Declaration.defnDecl {
name := auxName
hints := ReducibilityHints.opaque
safety := DefinitionSafety.unsafe
levelParams, type, value
}
else
Declaration.thmDecl {
name := auxName
levelParams, type, value
}
addDecl decl
modifyEnv fun env => auxLemmasExt.modifyState env fun ⟨idx, lemmas⟩ => ⟨idx + 1, lemmas.insert type (auxName, levelParams)⟩
return auxName
match s.lemmas.find? type with
| some (name, levelParams') => if levelParams == levelParams' then return name else mkNewAuxLemma
| none => mkNewAuxLemma
if cache then
if let some (name, levelParams') := s.lemmas.find? type then
if levelParams == levelParams' then
return name
mkNewAuxLemma
end Lean.Meta

View file

@ -199,7 +199,7 @@ Abtracts nested proofs in `e`. This is a preprocessing step performed before int
-/
def abstractNestedProofs (e : Expr) : GrindM Expr := do
let nextIdx := (← get).nextThmIdx
let (e, s') ← AbstractNestedProofs.visit e |>.run { baseName := (← getMainDeclName) } |>.run |>.run { nextIdx }
let (e, s') ← AbstractNestedProofs.visit e |>.run { cache := true, baseName := (← getMainDeclName) } |>.run |>.run { nextIdx }
modify fun s => { s with nextThmIdx := s'.nextIdx }
return e

View file

@ -16,7 +16,7 @@ info: foo.eq_def (n : Nat) :
if n = 0 then 0
else
let x := n - 1;
let_fun this := foo.proof_4;
let_fun this := foo._proof_4;
foo x
-/
#guard_msgs in

View file

@ -11,9 +11,9 @@ set_option pp.explicit true
/--
info: def foo : Foo :=
{ obj := fun x => @Function.const Type (@Eq Unit Unit.unit Unit.unit) Nat foo.proof_1,
{ obj := fun x => @Function.const Type (@Eq Unit Unit.unit Unit.unit) Nat foo._proof_1,
map :=
@id (@Function.const Type (@Eq Unit Unit.unit Unit.unit) Nat foo.proof_1) (@OfNat.ofNat Nat 0 (instOfNatNat 0)) }
@id (@Function.const Type (@Eq Unit Unit.unit Unit.unit) Nat foo._proof_1) (@OfNat.ofNat Nat 0 (instOfNatNat 0)) }
-/
#guard_msgs in
#print foo

View file

@ -7,12 +7,12 @@ structure Foo (n : Nat) (h : n > 1 := by omega) : Type
set_option pp.proofs true
/-- info: Foo 3 (Decidable.byContradiction fun a => _check.omega_1 a) : Type -/
/-- info: Foo 3 (Decidable.byContradiction fun a => _check._proof_1 a) : Type -/
#guard_msgs in
#check Foo 3
variable (x : Foo 2)
/-- info: x : Foo 2 (Decidable.byContradiction fun a => aux_2 a) -/
/-- info: x : Foo 2 (Decidable.byContradiction fun a => _proof_1 a) -/
#guard_msgs in
#check x

View file

@ -60,12 +60,12 @@ theorem thm1' : ∀ x < 100, x * x ≤ 10000 := by decide +kernel
/--
info: theorem thm1 : ∀ (x : Nat), x < 100 → x * x ≤ 10000 :=
thm1._auxLemma.1
thm1._proof_1
-/
#guard_msgs in #print thm1
/--
info: theorem thm1' : ∀ (x : Nat), x < 100 → x * x ≤ 10000 :=
thm1'._auxLemma.1
thm1'._proof_1
-/
#guard_msgs in #print thm1'

View file

@ -260,7 +260,7 @@ theorem ex1 (p : Prop) (a1 a2 a3 : Nat) : (p ↔ a2 ≤ a1) → ¬p → a2 + 3
grind
/--
info: theorem ex1.grind_1 : ∀ {a4 : Nat} (p : Prop) (a1 a2 a3 : Nat),
info: theorem ex1._proof_1 : ∀ {a4 : Nat} (p : Prop) (a1 a2 a3 : Nat),
(p ↔ a2 ≤ a1) → ¬p → a2 + 3 ≤ a3 → (p ↔ a4 ≤ a3 + 2) → a1 ≤ a4 :=
fun {a4} p a1 a2 a3 =>
intro_with_eq (p ↔ a2 ≤ a1) (p = (a2 ≤ a1)) (¬p → a2 + 3 ≤ a3 → (p ↔ a4 ≤ a3 + 2) → a1 ≤ a4) (iff_eq p (a2 ≤ a1))
@ -277,7 +277,7 @@ fun {a4} p a1 a2 a3 =>
-/
#guard_msgs (info) in
open Lean Grind in
#print ex1.grind_1
#print ex1._proof_1
/-! Propagate `cnstr = False` tests -/

View file

@ -17,7 +17,7 @@ info: f.induct (motive : (n : Nat) → n % 2 = 1 → (m : Nat) → (n + m) % 2 =
(case3 :
∀ (n' : Nat) (hn : (n' + 3) % 2 = 1) (m' : Nat) (hm : (n' + 3 + (m' + 1)) % 2 = 1),
(n' + 3 + m'.succ) % 2 = 1 →
motive n' (f.proof_1 n') m' (f.proof_2 n' m') → motive n'.succ.succ.succ hn m'.succ hm)
motive n' (f._proof_1 n') m' (f._proof_2 n' m') → motive n'.succ.succ.succ hn m'.succ hm)
(n : Nat) (hn : n % 2 = 1) (m : Nat) (hm : (n + m) % 2 = 1) : motive n hn m hm
-/
#guard_msgs in

View file

@ -23,7 +23,7 @@ def g (i j k : Nat) (a : Array Nat) (h₁ : i < k) (h₂ : k < j) (h₃ : j < a.
set_option pp.all true in
#print g
#check g.proof_1
#check g._proof_1
theorem ex1 {p q r s : Prop} : p ∧ q ∧ r ∧ s → r ∧ s ∧ q ∧ p :=
fun ⟨hp, hq, hr, hs⟩ => ⟨hr, hs, hq, hp⟩

View file

@ -311,7 +311,7 @@ namespace hashmap
else Result.ok (ntable, slots)
partial_fixpoint
set_option pp.proofs true in
#print HashMap.move_elements_loop.proof_2
#print HashMap.move_elements_loop._proof_13
def HashMap.move_elements
{T : Type} (ntable : HashMap T) (slots : alloc.vec.Vec (AList T)) :

View file

@ -1,5 +1,5 @@
@[irreducible] def f : Nat → Nat :=
f.proof_1.fix fun n a =>
f._proof_1.fix fun n a =>
if h : n = 0 then 1
else
let y := 42;