refactor: simp arg elaboration (#8815)

This PR refactors the way simp arguments are elaborated: Instead of
changing the `SimpTheorems` structure as we go, this elaborates each
argument to a more declarative description of what it does, and then
apply those. This enables more interesting checks of simp arguments that
need to happen in the context of the eventually constructed simp context
(the checks in #8688), or after simp has run (unused argument linter
#8901).

The new data structure describing an elaborated simp argument isn’t the
most elegant, but follows from the code.

While I am at it, move handling of `[*]` into `elabSimpArgs`. Downstream
adaption branches exist (but may not be fully up to date because of the
permission changes).

While I am at it, I cleaned up `SimpTheorems.lean` file a bit (sorting
declarations, mild renaming) and added documentation.
This commit is contained in:
Joachim Breitner 2025-06-21 19:55:53 +02:00 committed by GitHub
parent 85992757e7
commit 4d697874b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 535 additions and 332 deletions

View file

@ -266,7 +266,7 @@ def evalConvNormCast : Tactic :=
@[builtin_tactic pushCast]
def evalPushCast : Tactic := fun stx => do
let { ctx, simprocs, dischargeWrapper } ← withMainContext do
let { ctx, simprocs, dischargeWrapper, .. } ← withMainContext do
mkSimpContext (simpTheorems := pushCastExt.getTheorems) stx (eraseLocal := false)
let ctx := ctx.setFailIfUnchanged false
dischargeWrapper.with fun discharge? =>

View file

@ -91,56 +91,6 @@ def elabSimpConfig (optConfig : Syntax) (kind : SimpKind) : TacticM Meta.Simp.Co
| .simpAll => return (← elabSimpConfigCtxCore optConfig).toConfig
| .dsimp => return { (← elabDSimpConfigCore optConfig) with }
private def addDeclToUnfoldOrTheorem (config : Meta.ConfigWithKey) (thms : SimpTheorems) (id : Origin) (e : Expr) (post : Bool) (inv : Bool) (kind : SimpKind) : MetaM SimpTheorems := do
if e.isConst then
let declName := e.constName!
let info ← getConstVal declName
if (← isProp info.type) then
thms.addConst declName (post := post) (inv := inv)
else
if inv then
throwError "invalid '←' modifier, '{declName}' is a declaration name to be unfolded"
if kind == .dsimp then
return thms.addDeclToUnfoldCore declName
else
thms.addDeclToUnfold declName
else if e.isFVar then
let fvarId := e.fvarId!
let decl ← fvarId.getDecl
if (← isProp decl.type) then
thms.add id #[] e (post := post) (inv := inv) (config := config)
else if !decl.isLet then
throwError "invalid argument, variable is not a proposition or let-declaration"
else if inv then
throwError "invalid '←' modifier, '{e}' is a let-declaration name to be unfolded"
else
return thms.addLetDeclToUnfold fvarId
else
thms.add id #[] e (post := post) (inv := inv) (config := config)
private def addSimpTheorem (config : Meta.ConfigWithKey) (thms : SimpTheorems) (id : Origin) (stx : Syntax) (post : Bool) (inv : Bool) : TermElabM SimpTheorems := do
let thm? ← Term.withoutModifyingElabMetaStateWithInfo <| withRef stx do
let e ← Term.elabTerm stx none
Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true)
let e ← instantiateMVars e
if e.hasSyntheticSorry then
return none
let e := e.eta
if e.hasMVar then
let r ← abstractMVars e
return some (r.paramNames, r.expr)
else
return some (#[], e)
if let some (levelParams, proof) := thm? then
thms.add id levelParams proof (post := post) (inv := inv) (config := config)
else
return thms
structure ElabSimpArgsResult where
ctx : Simp.Context
simprocs : Simp.SimprocsArray
starArg : Bool := false
inductive ResolveSimpIdResult where
| none
| expr (e : Expr)
@ -154,104 +104,8 @@ inductive ResolveSimpIdResult where
-/
| ext (ext₁? : Option SimpExtension) (ext₂? : Option Simp.SimprocExtension) (h : ext₁?.isSome || ext₂?.isSome)
/--
Elaborate extra simp theorems provided to `simp`. `stx` is of the form `"[" simpTheorem,* "]"`
If `eraseLocal == true`, then we consider local declarations when resolving names for erased theorems (`- id`),
this option only makes sense for `simp_all` or `*` is used.
When `recover := true`, try to recover from errors as much as possible so that users keep seeing
the current goal.
-/
def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (eraseLocal : Bool) (kind : SimpKind) : TacticM ElabSimpArgsResult := do
if stx.isNone then
return { ctx, simprocs }
else
/-
syntax simpPre := "↓"
syntax simpPost := "↑"
syntax simpLemma := (simpPre <|> simpPost)? "← "? term
syntax simpErase := "-" ident
-/
let go := withMainContext do
let zetaDeltaSet ← toZetaDeltaSet stx ctx
withTrackingZetaDeltaSet zetaDeltaSet do
let mut thmsArray := ctx.simpTheorems
let mut thms := thmsArray[0]!
let mut simprocs := simprocs
let mut starArg := false
for arg in stx[1].getSepArgs do
try -- like withLogging, but compatible with do-notation
if arg.getKind == ``Lean.Parser.Tactic.simpErase then
let fvar? ← if eraseLocal || starArg then Term.isLocalIdent? arg[1] else pure none
if let some fvar := fvar? then
-- We use `eraseCore` because the simp theorem for the hypothesis was not added yet
thms := thms.eraseCore (.fvar fvar.fvarId!)
else
let id := arg[1]
if let .ok declName ← observing (realizeGlobalConstNoOverloadWithInfo id) then
if (← Simp.isSimproc declName) then
simprocs := simprocs.erase declName
else if ctx.config.autoUnfold then
thms := thms.eraseCore (.decl declName)
else
thms ← withRef id <| thms.erase (.decl declName)
else
-- If `id` could not be resolved, we should check whether it is a builtin simproc.
-- before returning error.
let name := id.getId.eraseMacroScopes
if (← Simp.isBuiltinSimproc name) then
simprocs := simprocs.erase name
else
throwUnknownConstantAt id name
else if arg.getKind == ``Lean.Parser.Tactic.simpLemma then
let post :=
if arg[0].isNone then
true
else
arg[0][0].getKind == ``Parser.Tactic.simpPost
let inv := !arg[1].isNone
let term := arg[2]
match (← resolveSimpIdTheorem? term) with
| .expr e =>
let name ← mkFreshId
thms ← addDeclToUnfoldOrTheorem ctx.indexConfig thms (.stx name arg) e post inv kind
| .simproc declName =>
simprocs ← simprocs.add declName post
| .ext (some ext₁) (some ext₂) _ =>
thmsArray := thmsArray.push (← ext₁.getTheorems)
simprocs := simprocs.push (← ext₂.getSimprocs)
| .ext (some ext₁) none _ =>
thmsArray := thmsArray.push (← ext₁.getTheorems)
| .ext none (some ext₂) _ =>
simprocs := simprocs.push (← ext₂.getSimprocs)
| .none =>
let name ← mkFreshId
thms ← addSimpTheorem ctx.indexConfig thms (.stx name arg) term post inv
else if arg.getKind == ``Lean.Parser.Tactic.simpStar then
starArg := true
else
throwUnsupportedSyntax
catch ex =>
if (← read).recover then
logException ex
else
throw ex
let ctx := ctx.setZetaDeltaSet zetaDeltaSet (← getZetaDeltaFVarIds)
return { ctx := ctx.setSimpTheorems (thmsArray.set! 0 thms), simprocs, starArg }
-- If recovery is disabled, then we want simp argument elaboration failures to be exceptions.
-- This affects `addSimpTheorem`.
if (← read).recover then
go
else
Term.withoutErrToSorry go
where
isSimproc? (e : Expr) : MetaM (Option Name) := do
let .const declName _ := e | return none
unless (← Simp.isSimproc declName) do return none
return some declName
resolveSimpIdTheorem? (simpArgTerm : Term) : TacticM ResolveSimpIdResult := do
let resolveExt (n : Name) : TacticM ResolveSimpIdResult := do
private def resolveSimpIdTheorem? (simpArgTerm : Term) : TermElabM ResolveSimpIdResult := do
let resolveExt (n : Name) : TermElabM ResolveSimpIdResult := do
let ext₁? ← getSimpExtension? n
let ext₂? ← Simp.getSimprocExtension? n
if h : ext₁?.isSome || ext₂?.isSome then
@ -279,7 +133,234 @@ where
return .expr e
else
return .none
where
isSimproc? (e : Expr) : MetaM (Option Name) := do
let .const declName _ := e | return none
unless (← Simp.isSimproc declName) do return none
return some declName
/--
The result of elaborating a single `simp` argument
-/
inductive ElabSimpArgResult where
| addEntries (entries : Array SimpEntry)
| addSimproc («simproc» : Name) (post : Bool)
| addLetToUnfold (fvarId : FVarId)
| ext (ext₁? : Option SimpExtension) (ext₂? : Option Simp.SimprocExtension) (h : ext₁?.isSome || ext₂?.isSome)
| erase (toErase : Origin)
| eraseSimproc (toErase : Name)
| star
| none -- used for example when elaboration fails
private def elabDeclToUnfoldOrTheorem (config : Meta.ConfigWithKey) (id : Origin)
(e : Expr) (post : Bool) (inv : Bool) (kind : SimpKind) : MetaM ElabSimpArgResult := do
if e.isConst then
let declName := e.constName!
let info ← getConstVal declName
if (← isProp info.type) then
let thms ← mkSimpTheoremFromConst declName (post := post) (inv := inv)
return .addEntries <| thms.map (SimpEntry.thm ·)
else
if inv then
throwError "invalid '←' modifier, '{declName}' is a declaration name to be unfolded"
if kind == .dsimp then
return .addEntries #[.toUnfold declName]
else
.addEntries <$> mkSimpEntryOfDeclToUnfold declName
else if e.isFVar then
let fvarId := e.fvarId!
let decl ← fvarId.getDecl
if (← isProp decl.type) then
let thms ← mkSimpTheoremFromExpr id #[] e (post := post) (inv := inv) (config := config)
return .addEntries <| thms.map (SimpEntry.thm ·)
else if !decl.isLet then
throwError "invalid argument, variable is not a proposition or let-declaration"
else if inv then
throwError "invalid '←' modifier, '{e}' is a let-declaration name to be unfolded"
else
return .addLetToUnfold fvarId
else
let thms ← mkSimpTheoremFromExpr id #[] e (post := post) (inv := inv) (config := config)
return .addEntries <| thms.map (SimpEntry.thm ·)
private def elabSimpTheorem (config : Meta.ConfigWithKey) (id : Origin) (stx : Syntax)
(post : Bool) (inv : Bool) : TermElabM ElabSimpArgResult := do
let thm? ← Term.withoutModifyingElabMetaStateWithInfo <| withRef stx do
let e ← Term.elabTerm stx .none
Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true)
let e ← instantiateMVars e
if e.hasSyntheticSorry then
return .none
let e := e.eta
if e.hasMVar then
let r ← abstractMVars e
return some (r.paramNames, r.expr)
else
return some (#[], e)
if let some (levelParams, proof) := thm? then
let thms ← mkSimpTheoremFromExpr id levelParams proof (post := post) (inv := inv) (config := config)
return .addEntries <| thms.map (SimpEntry.thm ·)
else
return .none
private def elabSimpArg (indexConfig : Meta.ConfigWithKey) (eraseLocal : Bool) (kind : SimpKind)
(arg : Syntax) : TacticM ElabSimpArgResult := withRef arg do
try
/-
syntax simpPre := "↓"
syntax simpPost := "↑"
syntax simpLemma := (simpPre <|> simpPost)? "← "? term
syntax simpErase := "-" ident
-/
if arg.getKind == ``Lean.Parser.Tactic.simpErase then
let fvar? ← if eraseLocal then Term.isLocalIdent? arg[1] else pure none
if let some fvar := fvar? then
-- We use `eraseCore` because the simp theorem for the hypothesis was not added yet
return .erase (.fvar fvar.fvarId!)
else
let id := arg[1]
if let .ok declName ← observing (realizeGlobalConstNoOverloadWithInfo id) then
if (← Simp.isSimproc declName) then
return .eraseSimproc declName
else
return .erase (.decl declName)
else
-- If `id` could not be resolved, we should check whether it is a builtin simproc.
-- before returning error.
let name := id.getId.eraseMacroScopes
if (← Simp.isBuiltinSimproc name) then
return .eraseSimproc name
else
throwUnknownConstantAt id name
else if arg.getKind == ``Lean.Parser.Tactic.simpLemma then
let post :=
if arg[0].isNone then
true
else
arg[0][0].getKind == ``Parser.Tactic.simpPost
let inv := !arg[1].isNone
let term := arg[2]
match (← resolveSimpIdTheorem? term) with
| .expr e =>
let name ← mkFreshId
elabDeclToUnfoldOrTheorem indexConfig (.stx name arg) e post inv kind
| .simproc declName =>
return .addSimproc declName post
| .ext ext₁? ext₂? h =>
return .ext ext₁? ext₂? h
| .none =>
let name ← mkFreshId
elabSimpTheorem indexConfig (.stx name arg) term post inv
else if arg.getKind == ``Lean.Parser.Tactic.simpStar then
return .star
else
throwUnsupportedSyntax
catch ex =>
if (← read).recover then
logException ex
return .none
else
throw ex
/--
The result of elaborating a full array of simp arguments and applying them to the simp context.
-/
structure ElabSimpArgsResult where
ctx : Simp.Context
simprocs : Simp.SimprocsArray
/-- The elaborated simp arguments with syntax -/
simpArgs : Array (Syntax × ElabSimpArgResult)
/-- Implements the effect of the `*` attribute. -/
private def applyStarArg (ctx : Simp.Context) : MetaM Simp.Context := do
let mut simpTheorems := ctx.simpTheorems
/-
When using `zetaDelta := false`, we do not expand let-declarations when using `[*]`.
Users must explicitly include it in the list.
-/
let hs ← getPropHyps
for h in hs do
unless simpTheorems.isErased (.fvar h) do
simpTheorems ← simpTheorems.addTheorem (.fvar h) (← h.getDecl).toExpr (config := ctx.indexConfig)
return ctx.setSimpTheorems simpTheorems
/--
Elaborate extra simp theorems provided to `simp`. `stx` is of the form `"[" simpTheorem,* "]"`
If `eraseLocal == true`, then we consider local declarations when resolving names for erased theorems (`- id`),
this option only makes sense for `simp_all` or `*` is used.
When `recover := true`, try to recover from errors as much as possible so that users keep seeing
the current goal.
-/
def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (eraseLocal : Bool)
(kind : SimpKind) (ignoreStarArg := false) : TacticM ElabSimpArgsResult := do
if stx.isNone then
return { ctx, simprocs, simpArgs := #[] }
else
/-
syntax simpPre := "↓"
syntax simpPost := "↑"
syntax simpLemma := (simpPre <|> simpPost)? "← "? term
syntax simpErase := "-" ident
-/
let go := withMainContext do
let zetaDeltaSet ← toZetaDeltaSet stx ctx
withTrackingZetaDeltaSet zetaDeltaSet do
let mut starArg := false -- only after * we can erase local declarations
let mut args : Array (Syntax × ElabSimpArgResult) := #[]
for argStx in stx[1].getSepArgs do
let arg ← elabSimpArg ctx.indexConfig (eraseLocal || starArg) kind argStx
starArg := !ignoreStarArg && (starArg || arg matches .star)
args := args.push (argStx, arg)
let mut thmsArray := ctx.simpTheorems
let mut thms := thmsArray[0]!
let mut simprocs := simprocs
for (ref, arg) in args do
match arg with
| .addEntries entries =>
for entry in entries do
thms := thms.uneraseSimpEntry entry
thms := thms.addSimpEntry entry
| .addLetToUnfold fvarId =>
thms := thms.addLetDeclToUnfold fvarId
| .addSimproc declName post =>
simprocs ← simprocs.add declName post
| .erase origin =>
-- `thms.erase` checks if the erasure is effective.
-- We do not want this check for local hypotheses (they are added later based on `starArg`)
if origin matches .fvar _ then
thms := thms.eraseCore origin
-- Nor for decls to unfold when we do auto unfolding
else if ctx.config.autoUnfold then
thms := thms.eraseCore origin
else
thms ← withRef ref <| thms.erase origin
| .eraseSimproc name =>
simprocs := simprocs.erase name
| .ext simpExt? simprocExt? _ =>
if let some simpExt := simpExt? then
thmsArray := thmsArray.push (← simpExt.getTheorems)
if let some simprocExt := simprocExt? then
simprocs := simprocs.push (← simprocExt.getSimprocs)
| .star => pure ()
| .none => pure ()
let mut ctx := ctx.setZetaDeltaSet zetaDeltaSet (← getZetaDeltaFVarIds)
ctx := ctx.setSimpTheorems (thmsArray.set! 0 thms)
if !ignoreStarArg && starArg then
ctx ← applyStarArg ctx
return { ctx, simprocs, simpArgs := args}
-- If recovery is disabled, then we want simp argument elaboration failures to be exceptions.
-- This affects `addSimpTheorem`.
if (← read).recover then
go
else
Term.withoutErrToSorry go
where
/-- If `zetaDelta := false`, create a `FVarId` set with all local let declarations in the `simp` argument list. -/
toZetaDeltaSet (stx : Syntax) (ctx : Simp.Context) : TacticM FVarIdSet := do
if ctx.config.zetaDelta then return {}
@ -319,6 +400,8 @@ structure MkSimpContextResult where
ctx : Simp.Context
simprocs : Simp.SimprocsArray
dischargeWrapper : Simp.DischargeWrapper
/-- The elaborated simp arguments with syntax -/
simpArgs : Array (Syntax × ElabSimpArgResult) := #[]
/--
Create the `Simp.Context` for the `simp`, `dsimp`, and `simp_all` tactics.
@ -351,23 +434,8 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp)
(config := (← elabSimpConfig stx[1] (kind := kind)))
(simpTheorems := #[simpTheorems])
congrTheorems
let r ← elabSimpArgs stx[4] (eraseLocal := eraseLocal) (kind := kind) (simprocs := #[simprocs]) ctx
if !r.starArg || ignoreStarArg then
return { r with dischargeWrapper }
else
let ctx := r.ctx
let simprocs := r.simprocs
let mut simpTheorems := ctx.simpTheorems
/-
When using `zetaDelta := false`, we do not expand let-declarations when using `[*]`.
Users must explicitly include it in the list.
-/
let hs ← getPropHyps
for h in hs do
unless simpTheorems.isErased (.fvar h) do
simpTheorems ← simpTheorems.addTheorem (.fvar h) (← h.getDecl).toExpr (config := ctx.indexConfig)
let ctx := ctx.setSimpTheorems simpTheorems
return { ctx, simprocs, dischargeWrapper }
let r ← elabSimpArgs stx[4] (eraseLocal := eraseLocal) (kind := kind) (simprocs := #[simprocs]) (ignoreStarArg := ignoreStarArg) ctx
return { r with dischargeWrapper }
register_builtin_option tactic.simp.trace : Bool := {
defValue := false
@ -477,7 +545,7 @@ def withSimpDiagnostics (x : TacticM Simp.Diagnostics) : TacticM Unit := do
(location)?
-/
@[builtin_tactic Lean.Parser.Tactic.simp] def evalSimp : Tactic := fun stx => withMainContext do withSimpDiagnostics do
let { ctx, simprocs, dischargeWrapper } ← mkSimpContext stx (eraseLocal := false)
let { ctx, simprocs, dischargeWrapper, .. } ← mkSimpContext stx (eraseLocal := false)
let stats ← dischargeWrapper.with fun discharge? =>
simpLocation ctx simprocs discharge? (expandOptLocation stx[5])
if tactic.simp.trace.get (← getOptions) then

View file

@ -30,7 +30,7 @@ def mkSimpCallStx (stx : Syntax) (usedSimps : UsedSimps) : MetaM (TSyntax `tacti
`(tactic| simp!%$tk $cfg:optConfig $(discharger)? $[only%$o]? $[[$args,*]]? $(loc)?)
else
`(tactic| simp%$tk $cfg:optConfig $[$discharger]? $[only%$o]? $[[$args,*]]? $(loc)?)
let { ctx, simprocs, dischargeWrapper } ← mkSimpContext stx (eraseLocal := false)
let { ctx, simprocs, dischargeWrapper, ..} ← mkSimpContext stx (eraseLocal := false)
let ctx := if bang.isSome then ctx.setAutoUnfold else ctx
let stats ← dischargeWrapper.with fun discharge? =>
simpLocation ctx (simprocs := simprocs) discharge? <|

View file

@ -34,7 +34,7 @@ deriving instance Repr for UseImplicitLambdaResult
| `(tactic| simpa%$tk $[?%$squeeze]? $[!%$unfold]? $cfg:optConfig $(disch)? $[only%$only]?
$[[$args,*]]? $[using $usingArg]?) => Elab.Tactic.focus do withSimpDiagnostics do
let stx ← `(tactic| simp $cfg:optConfig $(disch)? $[only%$only]? $[[$args,*]]?)
let { ctx, simprocs, dischargeWrapper } ←
let { ctx, simprocs, dischargeWrapper, .. } ←
withMainContext <| mkSimpContext stx (eraseLocal := false)
let ctx := if unfold.isSome then ctx.setAutoUnfold else ctx
-- TODO: have `simpa` fail if it doesn't use `simp`.

View file

@ -33,13 +33,13 @@ def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension)
else if info.kind matches .defn then
if inv then
throwError "invalid '←' modifier, '{declName}' is a declaration name to be unfolded"
if (← SimpTheorems.ignoreEquations declName) then
if (← Simp.ignoreEquations declName) then
ext.add (SimpEntry.toUnfold declName) attrKind
else if let some eqns ← getEqnsFor? declName then
for eqn in eqns do
addSimpTheorem ext eqn post (inv := false) attrKind prio
ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind
if (← SimpTheorems.unfoldEvenWithEqns declName) then
if (← Simp.unfoldEvenWithEqns declName) then
ext.add (SimpEntry.toUnfold declName) attrKind
else
ext.add (SimpEntry.toUnfold declName) attrKind

View file

@ -12,6 +12,27 @@ import Lean.Meta.Eqns
import Lean.Meta.Tactic.AuxLemma
import Lean.DefEqAttrib
import Lean.DocString
/-!
This module contains types to manages simp theorems and sets theirof.
Overview of types in this module:
* `Origin`: Identifies where a simp theorem comes from (global declaration or local expression).
Includes the direction of the theorem for global declarations.
* `SimpTheorem`: Represents a single simp theorem, including its origin and proof.
* `SimpEntry`: The effect of a simp attribute; either a `SimpTheorem` or information about a
definition to unfold. This is stored in oleans.
* `SimpTheorems`: Main data structure to store the simp set for a given `simp` invocation, including
discrimination trees, sets of erased theorem, declarations to unfold.
* `SimpExtension`: Environment extension to store the default simp set, or user-defined simp sets.
Each simp extension maintains its own `SimpTheorems` within a module.
* `SimpTheoremsArray`: Array of `SimpTheorems`, to avoid the need for merging `SimpTheorems` when
more than one simp extension is enabled.
-/
namespace Lean.Meta
register_builtin_option backward.dsimp.useDefEqAttr : Bool := {
@ -219,23 +240,6 @@ def ppSimpTheorem [Monad m] [MonadEnv m] [MonadError m] (s : SimpTheorem) : m Me
instance : BEq SimpTheorem where
beq e₁ e₂ := e₁.proof == e₂.proof
abbrev SimpTheoremTree := DiscrTree SimpTheorem
/--
The theorems in a simp set.
-/
structure SimpTheorems where
pre : SimpTheoremTree := DiscrTree.empty
post : SimpTheoremTree := DiscrTree.empty
lemmaNames : PHashSet Origin := {}
/--
Constants (and let-declaration `FVarId`) to unfold.
When `zetaDelta := false`, the simplifier will expand a let-declaration if it is in this set.
-/
toUnfold : PHashSet Name := {}
erased : PHashSet Origin := {}
toUnfoldThms : PHashMap Name (Array Name) := {}
deriving Inhabited
/--
Configuration for `MetaM` used to process global simp theorems
@ -250,81 +254,7 @@ def simpGlobalConfig : ConfigWithKey :=
@[inline] def withSimpGlobalConfig : MetaM α → MetaM α :=
withConfigWithKey simpGlobalConfig
partial def SimpTheorems.eraseCore (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
let d := { d with erased := d.erased.insert thmId, lemmaNames := d.lemmaNames.erase thmId }
if let .decl declName .. := thmId then
let d := { d with toUnfold := d.toUnfold.erase declName }
if let some thms := d.toUnfoldThms.find? declName then
let dummy := true
thms.foldl (init := d) (eraseCore · <| .decl · dummy (inv := false))
else
d
else
d
private def eraseIfExists (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
if d.lemmaNames.contains thmId then
d.eraseCore thmId
else
d
/--
If `e` is a backwards theorem `← thm`, we must ensure the forward theorem is erased
from `d`. See issue #4290
-/
private def eraseFwdIfBwd (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
if let some converseOrigin := e.origin.converse then
eraseIfExists d converseOrigin
else
d
def addSimpTheoremEntry (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
let d := eraseFwdIfBwd d e
if e.post then
{ d with post := d.post.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
else
{ d with pre := d.pre.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
where
updateLemmaNames (s : PHashSet Origin) : PHashSet Origin :=
s.insert e.origin
def SimpTheorems.addDeclToUnfoldCore (d : SimpTheorems) (declName : Name) : SimpTheorems :=
{ d with toUnfold := d.toUnfold.insert declName }
def SimpTheorems.addLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : SimpTheorems :=
-- A small hack that relies on the fact that constants and `FVarId` names should be disjoint.
{ d with toUnfold := d.toUnfold.insert fvarId.name }
/-- Return `true` if `declName` is tagged to be unfolded using `unfoldDefinition?` (i.e., without using equational theorems). -/
def SimpTheorems.isDeclToUnfold (d : SimpTheorems) (declName : Name) : Bool :=
d.toUnfold.contains declName
def SimpTheorems.isLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : Bool :=
d.toUnfold.contains fvarId.name -- See comment at `addLetDeclToUnfold`
def SimpTheorems.isLemma (d : SimpTheorems) (thmId : Origin) : Bool :=
d.lemmaNames.contains thmId
/-- Register the equational theorems for the given definition. -/
def SimpTheorems.registerDeclToUnfoldThms (d : SimpTheorems) (declName : Name) (eqThms : Array Name) : SimpTheorems :=
{ d with toUnfoldThms := d.toUnfoldThms.insert declName eqThms }
def SimpTheorems.erase [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m]
(d : SimpTheorems) (thmId : Origin) : m SimpTheorems := do
if d.isLemma thmId ||
match thmId with
| .decl declName .. => d.isDeclToUnfold declName || d.toUnfoldThms.contains declName
| _ => false
then
return d.eraseCore thmId
-- `attribute [-simp] foo` should also undo `attribute [simp ←] foo`.
if let some thmId' := thmId.converse then
if d.isLemma thmId' then
return d.eraseCore thmId'
logWarning m!"'{thmId.key}' does not have [simp] attribute"
return d
private partial def isPerm : Expr → Expr → MetaM Bool
| .app f₁ a₁, .app f₂ a₂ => isPerm f₁ f₂ <&&> isPerm a₁ a₂
@ -432,7 +362,12 @@ private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array
| none => throwError "unexpected kind of 'simp' theorem{indentExpr type}"
return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) }
private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool) (prio : Nat) : MetaM (Array SimpTheorem) := do
/--
Creates a `SimpTheorem` from a global theorem.
Because some theorems lead to multiple `SimpTheorems` (in particular conjunctions), returns an array.
-/
def mkSimpTheoremFromConst (declName : Name) (post := true) (inv := false)
(prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) := do
let cinfo ← getConstVal declName
let us := cinfo.levelParams.map mkLevelParam
let origin := .decl declName post inv
@ -449,52 +384,6 @@ private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool)
else
return #[← withoutExporting do mkSimpTheoremCore origin (mkConst declName us) #[] (mkConst declName) post prio (noIndexAtArgs := false)]
inductive SimpEntry where
| thm : SimpTheorem → SimpEntry
| toUnfold : Name → SimpEntry
| toUnfoldThms : Name → Array Name → SimpEntry
deriving Inhabited
/--
The environment extension that contains a simp set, returned by `Lean.Meta.registerSimpAttr`.
Use the simp set's attribute or `Lean.Meta.addSimpTheorem` to add theorems to the simp set. Use
`Lean.Meta.SimpExtension.getTheorems` to get the contents.
-/
abbrev SimpExtension := SimpleScopedEnvExtension SimpEntry SimpTheorems
def SimpExtension.getTheorems (ext : SimpExtension) : CoreM SimpTheorems :=
return ext.getState (← getEnv)
def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv : Bool) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do
let simpThms ← withExporting (isExporting := !isPrivateName declName) do mkSimpTheoremsFromConst declName post inv prio
for simpThm in simpThms do
ext.add (SimpEntry.thm simpThm) attrKind
def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
registerSimpleScopedEnvExtension {
name := name
initial := {}
addEntry := fun d e =>
match e with
| .thm e => addSimpTheoremEntry d e
| .toUnfold n => d.addDeclToUnfoldCore n
| .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
}
abbrev SimpExtensionMap := Std.HashMap Name SimpExtension
builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {}
def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
return (← simpExtensionMapRef.get)[attrName]?
/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
let s := { s with erased := s.erased.erase (.decl declName post inv) }
let simpThms ← mkSimpTheoremsFromConst declName post inv prio
return simpThms.foldl addSimpTheoremEntry s
def SimpTheorem.getValue (simpThm : SimpTheorem) : MetaM Expr := do
if simpThm.proof.isConst && simpThm.levelParams.isEmpty then
let info ← getConstVal simpThm.proof.constName!
@ -512,10 +401,27 @@ private def preprocessProof (val : Expr) (inv : Bool) : MetaM (Array Expr) := do
let ps ← preprocess val type inv (isGlobal := false)
return ps.toArray.map fun (val, _) => val
/-- Auxiliary method for creating simp theorems from a proof term `val`. -/
private def mkSimpTheorems (id : Origin) (levelParams : Array Name) (proof : Expr) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) :=
withReducible do
(← preprocessProof proof inv).mapM fun val => mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true)
def mkSimpTheoremFromExpr (id : Origin) (levelParams : Array Name) (proof : Expr) (inv := false)
(post := true) (prio : Nat := eval_prio default) (config : ConfigWithKey := simpGlobalConfig) :
MetaM (Array SimpTheorem) := do
if proof.isConst then
-- Recall that we use `simpGlobalConfig` for processing global declarations.
mkSimpTheoremFromConst proof.constName! post inv prio
else
withConfigWithKey config do
withReducible do
(← preprocessProof proof inv).mapM fun val =>
mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true)
/--
A simp theorem or information about a declaration to unfold by simp.
This is stored in the oleans to implement the `simp` attribute and user-defined simp sets.
-/
inductive SimpEntry where
| thm : SimpTheorem → SimpEntry
| toUnfold : Name → SimpEntry
| toUnfoldThms : Name → Array Name → SimpEntry
deriving Inhabited
/--
Reducible functions and projection functions should always be put in `toUnfold`, instead
@ -524,7 +430,7 @@ of trying to use equational theorems.
The simplifiers has special support for structure and class projections, and gets
confused when they suddenly rewrite, so ignore equations for them
-/
def SimpTheorems.ignoreEquations (declName : Name) : CoreM Bool := do
def Simp.ignoreEquations (declName : Name) : CoreM Bool := do
return (← isProjectionFn declName) || (← isReducible declName)
/--
@ -540,20 +446,24 @@ behavior unless `unfoldPartialApp := true`.
Moreover, users will have to use `f.eq_def` if they want to force the definition to be
unfolded.
-/
def SimpTheorems.unfoldEvenWithEqns (declName : Name) : CoreM Bool := do
def Simp.unfoldEvenWithEqns (declName : Name) : CoreM Bool := do
if hasSmartUnfoldingDecl (← getEnv) declName then return true
unless (← isRecursiveDefinition declName) do return true
return false
def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM SimpTheorems := do
/--
Given the name of a declaration to unfold, return the `SimpEntry` (or entries) that
implement this unfolding, using either the equational theorems, or `SimpEntry.toUnfold`, or both.
-/
def mkSimpEntryOfDeclToUnfold (declName : Name) : MetaM (Array SimpEntry) := do
let mut entries : Array SimpEntry := #[]
-- NOTE: the latter condition is only to preserve previous behavior where simp accepts even things
-- that neither theorems nor unfoldable. This should likely be tightened up in the future.
if !(← getConstInfo declName).isDefinition && getOriginalConstKind? (← getEnv) declName == some .defn then
throwError "invalid 'simp', definition with exposed body expected: {.ofConstName declName}"
if (← ignoreEquations declName) then
return d.addDeclToUnfoldCore declName
if (← Simp.ignoreEquations declName) then
entries := entries.push (.toUnfold declName)
else if let some eqns ← getEqnsFor? declName then
let mut d := d
for h : i in [:eqns.size] do
let eqn := eqns[i]
/-
@ -571,24 +481,187 @@ def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM Si
if i + 1 = eqns.size then 0 else 1
else
100 - i
d ← SimpTheorems.addConst d eqn (prio := prio)
if (← unfoldEvenWithEqns declName) then
d := d.addDeclToUnfoldCore declName
return d
let thms ← mkSimpTheoremFromConst eqn (prio := prio)
entries := entries ++ thms.map (.thm ·)
if (← Simp.unfoldEvenWithEqns declName) then
entries := entries.push (.toUnfold declName)
else
return d.addDeclToUnfoldCore declName
entries := entries.push (.toUnfold declName)
return entries
abbrev SimpTheoremTree := DiscrTree SimpTheorem
/--
The theorems in a simp set.
-/
structure SimpTheorems where
pre : SimpTheoremTree := DiscrTree.empty
post : SimpTheoremTree := DiscrTree.empty
lemmaNames : PHashSet Origin := {}
/--
Constants (and let-declaration `FVarId`) to unfold.
When `zetaDelta := false`, the simplifier will expand a let-declaration if it is in this set.
-/
toUnfold : PHashSet Name := {}
erased : PHashSet Origin := {}
toUnfoldThms : PHashMap Name (Array Name) := {}
deriving Inhabited
partial def SimpTheorems.eraseCore (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
let d := { d with erased := d.erased.insert thmId, lemmaNames := d.lemmaNames.erase thmId }
if let .decl declName .. := thmId then
let d := { d with toUnfold := d.toUnfold.erase declName }
if let some thms := d.toUnfoldThms.find? declName then
let dummy := true
thms.foldl (init := d) (eraseCore · <| .decl · dummy (inv := false))
else
d
else
d
private def eraseIfExists (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
if d.lemmaNames.contains thmId then
d.eraseCore thmId
else
d
/--
If `e` is a backwards theorem `← thm`, we must ensure the forward theorem is erased
from `d`. See issue #4290
-/
private def eraseFwdIfBwd (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
if let some converseOrigin := e.origin.converse then
eraseIfExists d converseOrigin
else
d
def SimpTheorems.unerase (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
{ d with erased := d.erased.erase thmId }
def SimpTheorems.addSimpTheorem (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
-- Erase the converse, if it exists
let d := eraseFwdIfBwd d e
if e.post then
{ d with post := d.post.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
else
{ d with pre := d.pre.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
where
updateLemmaNames (s : PHashSet Origin) : PHashSet Origin :=
s.insert e.origin
@[deprecated SimpTheorems.addSimpTheorem (since := "2025-06-17")]
def addSimpTheoremEntry := SimpTheorems.addSimpTheorem
def SimpTheorems.addDeclToUnfoldCore (d : SimpTheorems) (declName : Name) : SimpTheorems :=
{ d with toUnfold := d.toUnfold.insert declName }
def SimpTheorems.addLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : SimpTheorems :=
-- A small hack that relies on the fact that constants and `FVarId` names should be disjoint.
{ d with toUnfold := d.toUnfold.insert fvarId.name }
/-- Return `true` if `declName` is tagged to be unfolded using `unfoldDefinition?` (i.e., without using equational theorems). -/
def SimpTheorems.isDeclToUnfold (d : SimpTheorems) (declName : Name) : Bool :=
d.toUnfold.contains declName
def SimpTheorems.isLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : Bool :=
d.toUnfold.contains fvarId.name -- See comment at `addLetDeclToUnfold`
def SimpTheorems.isLemma (d : SimpTheorems) (thmId : Origin) : Bool :=
d.lemmaNames.contains thmId
/-- Register the equational theorems for the given definition. -/
def SimpTheorems.registerDeclToUnfoldThms (d : SimpTheorems) (declName : Name) (eqThms : Array Name) : SimpTheorems :=
{ d with toUnfoldThms := d.toUnfoldThms.insert declName eqThms }
def SimpTheorems.erase [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m]
(d : SimpTheorems) (thmId : Origin) : m SimpTheorems := do
if d.isLemma thmId ||
match thmId with
| .decl declName .. => d.isDeclToUnfold declName || d.toUnfoldThms.contains declName
| _ => false
then
return d.eraseCore thmId
-- `attribute [-simp] foo` should also undo `attribute [simp ←] foo`.
if let some thmId' := thmId.converse then
if d.isLemma thmId' then
return d.eraseCore thmId'
logWarning m!"'{thmId.key}' does not have [simp] attribute"
return d
def SimpTheorems.addSimpEntry (d : SimpTheorems) (e : SimpEntry) : SimpTheorems :=
match e with
| .thm e => d.addSimpTheorem e
| .toUnfold n => d.addDeclToUnfoldCore n
| .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
/--
`simp [foo]` should undo a previous `attribute @[-simp] foo`.
(Note that `attribute @[simp] foo` does not undo a `attribute @[simp] foo`, see #5852)
-/
def SimpTheorems.uneraseSimpEntry (d : SimpTheorems) (e : SimpEntry) : SimpTheorems :=
match e with
| .thm e => d.unerase e.origin
| _ => d
/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
let simpThms ← mkSimpTheoremFromConst declName post inv prio
return simpThms.foldl SimpTheorems.addSimpTheorem s
/--
The environment extension that contains a simp set, returned by `Lean.Meta.registerSimpAttr`.
Use the simp set's attribute or `Lean.Meta.addSimpTheorem` to add theorems to the simp set. Use
`Lean.Meta.SimpExtension.getTheorems` to get the contents.
-/
abbrev SimpExtension := SimpleScopedEnvExtension SimpEntry SimpTheorems
def SimpExtension.getTheorems (ext : SimpExtension) : CoreM SimpTheorems :=
return ext.getState (← getEnv)
/--
Adds a simp theorem to a simp extension
-/
def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv : Bool) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do
let simpThms ← withExporting (isExporting := !isPrivateName declName) do mkSimpTheoremFromConst declName post inv prio
for simpThm in simpThms do
ext.add (SimpEntry.thm simpThm) attrKind
def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
registerSimpleScopedEnvExtension {
name := name
initial := {}
addEntry := fun d e => d.addSimpEntry e
}
abbrev SimpExtensionMap := Std.HashMap Name SimpExtension
builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {}
def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
return (← simpExtensionMapRef.get)[attrName]?
def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM SimpTheorems := do
let entries ← mkSimpEntryOfDeclToUnfold declName
return entries.foldl (init := d) fun d e => d.addSimpEntry e
/-- Auxiliary method for adding a local simp theorem to a `SimpTheorems` datastructure. -/
def SimpTheorems.add (s : SimpTheorems) (id : Origin) (levelParams : Array Name) (proof : Expr)
(inv := false) (post := true) (prio : Nat := eval_prio default)
(config : ConfigWithKey := simpGlobalConfig) : MetaM SimpTheorems := do
if proof.isConst then
-- Recall that we use `simpGlobalConfig` for processing global declarations.
s.addConst proof.constName! post inv prio
else
let simpThms ← withConfigWithKey config <| mkSimpTheorems id levelParams proof post inv prio
return simpThms.foldl addSimpTheoremEntry s
let simpThms ← mkSimpTheoremFromExpr id levelParams proof inv post prio config
return simpThms.foldl SimpTheorems.addSimpTheorem s
/--
A `SimpTheoremsArray` is a collection of `SimpTheorems`. The first entry is the default simp set
and possible extensions as simp args (`simp [thm]`), further entries are custom simp sets added
a s simp arguments (`simp [my_simp_set]`). The array is scanned linear during rewriting.
This avoids the need for efficiently merging the `SimpTheorems` data structure.
-/
abbrev SimpTheoremsArray := Array SimpTheorems
def SimpTheoremsArray.addTheorem (thmsArray : SimpTheoremsArray) (id : Origin) (h : Expr) (config : ConfigWithKey := simpGlobalConfig) : MetaM SimpTheoremsArray :=

62
tests/lean/run/8815.lean Normal file
View file

@ -0,0 +1,62 @@
/-!
Assortion of tests to make sure the #8815 simp arg elaboration refactoring did not change
behavior.
-/
set_option linter.unusedVariables false
example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp [*, -hQ]
/-- error: simp made no progress -/
#guard_msgs in example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp [*, -hP]
/-- error: unknown constant 'hQ' -/
#guard_msgs in example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp [-hQ, *]
#guard_msgs in example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp_all [-hQ]
/--
error: unknown constant 'hQ'
---
error: simp made no progress
-/
#guard_msgs in example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp [-hQ]
theorem a_thm : True := trivial
def f : Nat → Nat
| 0 => 1
| n + 1 => f n + 1
example : f 0 > 0 := by simp [f]
example : f 0 > 0 := by simp!
-- NB: simp! disables all warnings, not just for declarations to unfold
-- Mild bug, but not a regresion.
/--
error: unsolved goals
⊢ 0 < f 0
-/
#guard_msgs in example : f 0 > 0 := by simp! [-f, -a_thm]
/--
warning: 'f' does not have [simp] attribute
---
warning: 'a_thm' does not have [simp] attribute
---
error: unsolved goals
⊢ 0 < f 0
-/
#guard_msgs in example : f 0 > 0 := by
simp [-f, -a_thm]
/--
error: invalid 'simp', proposition expected
Type 32
-/
#guard_msgs in
example : True := by simp [Sort 32] -- mostly about error location, once guard_msgs shows that

View file

@ -22,13 +22,13 @@ termination_by _ n => n
#check foo.eq_3
-- In order to reliably check if simp is not attempting to rewrite with a certain lemma
-- we can look at te diagnostics. But simply dumping all diangostics is too noisy for a test,
-- we can look at the diagnostics. But simply dumping all diangostics is too noisy for a test,
-- so here we try to get our hands at the `Simp.Stats` and look there.
open Lean Meta Elab Tactic in
elab "simp_foo_with_check" : tactic =>
withOptions (fun o => diagnostics.set o true) do withMainContext do
let stx ← `(tactic|simp [foo])
let { ctx, simprocs, dischargeWrapper } ← mkSimpContext stx (eraseLocal := false)
let { ctx, simprocs, dischargeWrapper, .. } ← mkSimpContext stx (eraseLocal := false)
let stats ← dischargeWrapper.with fun discharge? => do
simpLocation ctx simprocs discharge? (expandOptLocation stx.raw[5])
unless stats.diag.triedThmCounter.toList.any (fun (o, _n) => o.key = ``foo.eq_2) do