diff --git a/src/Lean/Elab/Tactic/NormCast.lean b/src/Lean/Elab/Tactic/NormCast.lean index da3e09658a..952e6eafe6 100644 --- a/src/Lean/Elab/Tactic/NormCast.lean +++ b/src/Lean/Elab/Tactic/NormCast.lean @@ -266,7 +266,7 @@ def evalConvNormCast : Tactic := @[builtin_tactic pushCast] def evalPushCast : Tactic := fun stx => do - let { ctx, simprocs, dischargeWrapper } ← withMainContext do + let { ctx, simprocs, dischargeWrapper, .. } ← withMainContext do mkSimpContext (simpTheorems := pushCastExt.getTheorems) stx (eraseLocal := false) let ctx := ctx.setFailIfUnchanged false dischargeWrapper.with fun discharge? => diff --git a/src/Lean/Elab/Tactic/Simp.lean b/src/Lean/Elab/Tactic/Simp.lean index 452ff81396..54483fbb5c 100644 --- a/src/Lean/Elab/Tactic/Simp.lean +++ b/src/Lean/Elab/Tactic/Simp.lean @@ -91,56 +91,6 @@ def elabSimpConfig (optConfig : Syntax) (kind : SimpKind) : TacticM Meta.Simp.Co | .simpAll => return (← elabSimpConfigCtxCore optConfig).toConfig | .dsimp => return { (← elabDSimpConfigCore optConfig) with } -private def addDeclToUnfoldOrTheorem (config : Meta.ConfigWithKey) (thms : SimpTheorems) (id : Origin) (e : Expr) (post : Bool) (inv : Bool) (kind : SimpKind) : MetaM SimpTheorems := do - if e.isConst then - let declName := e.constName! - let info ← getConstVal declName - if (← isProp info.type) then - thms.addConst declName (post := post) (inv := inv) - else - if inv then - throwError "invalid '←' modifier, '{declName}' is a declaration name to be unfolded" - if kind == .dsimp then - return thms.addDeclToUnfoldCore declName - else - thms.addDeclToUnfold declName - else if e.isFVar then - let fvarId := e.fvarId! - let decl ← fvarId.getDecl - if (← isProp decl.type) then - thms.add id #[] e (post := post) (inv := inv) (config := config) - else if !decl.isLet then - throwError "invalid argument, variable is not a proposition or let-declaration" - else if inv then - throwError "invalid '←' modifier, '{e}' is a let-declaration name to be unfolded" - else - return thms.addLetDeclToUnfold fvarId - else - thms.add id #[] e (post := post) (inv := inv) (config := config) - -private def addSimpTheorem (config : Meta.ConfigWithKey) (thms : SimpTheorems) (id : Origin) (stx : Syntax) (post : Bool) (inv : Bool) : TermElabM SimpTheorems := do - let thm? ← Term.withoutModifyingElabMetaStateWithInfo <| withRef stx do - let e ← Term.elabTerm stx none - Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true) - let e ← instantiateMVars e - if e.hasSyntheticSorry then - return none - let e := e.eta - if e.hasMVar then - let r ← abstractMVars e - return some (r.paramNames, r.expr) - else - return some (#[], e) - if let some (levelParams, proof) := thm? then - thms.add id levelParams proof (post := post) (inv := inv) (config := config) - else - return thms - -structure ElabSimpArgsResult where - ctx : Simp.Context - simprocs : Simp.SimprocsArray - starArg : Bool := false - inductive ResolveSimpIdResult where | none | expr (e : Expr) @@ -154,104 +104,8 @@ inductive ResolveSimpIdResult where -/ | ext (ext₁? : Option SimpExtension) (ext₂? : Option Simp.SimprocExtension) (h : ext₁?.isSome || ext₂?.isSome) -/-- - Elaborate extra simp theorems provided to `simp`. `stx` is of the form `"[" simpTheorem,* "]"` - If `eraseLocal == true`, then we consider local declarations when resolving names for erased theorems (`- id`), - this option only makes sense for `simp_all` or `*` is used. - When `recover := true`, try to recover from errors as much as possible so that users keep seeing - the current goal. --/ -def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (eraseLocal : Bool) (kind : SimpKind) : TacticM ElabSimpArgsResult := do - if stx.isNone then - return { ctx, simprocs } - else - /- - syntax simpPre := "↓" - syntax simpPost := "↑" - syntax simpLemma := (simpPre <|> simpPost)? "← "? term - - syntax simpErase := "-" ident - -/ - let go := withMainContext do - let zetaDeltaSet ← toZetaDeltaSet stx ctx - withTrackingZetaDeltaSet zetaDeltaSet do - let mut thmsArray := ctx.simpTheorems - let mut thms := thmsArray[0]! - let mut simprocs := simprocs - let mut starArg := false - for arg in stx[1].getSepArgs do - try -- like withLogging, but compatible with do-notation - if arg.getKind == ``Lean.Parser.Tactic.simpErase then - let fvar? ← if eraseLocal || starArg then Term.isLocalIdent? arg[1] else pure none - if let some fvar := fvar? then - -- We use `eraseCore` because the simp theorem for the hypothesis was not added yet - thms := thms.eraseCore (.fvar fvar.fvarId!) - else - let id := arg[1] - if let .ok declName ← observing (realizeGlobalConstNoOverloadWithInfo id) then - if (← Simp.isSimproc declName) then - simprocs := simprocs.erase declName - else if ctx.config.autoUnfold then - thms := thms.eraseCore (.decl declName) - else - thms ← withRef id <| thms.erase (.decl declName) - else - -- If `id` could not be resolved, we should check whether it is a builtin simproc. - -- before returning error. - let name := id.getId.eraseMacroScopes - if (← Simp.isBuiltinSimproc name) then - simprocs := simprocs.erase name - else - throwUnknownConstantAt id name - else if arg.getKind == ``Lean.Parser.Tactic.simpLemma then - let post := - if arg[0].isNone then - true - else - arg[0][0].getKind == ``Parser.Tactic.simpPost - let inv := !arg[1].isNone - let term := arg[2] - match (← resolveSimpIdTheorem? term) with - | .expr e => - let name ← mkFreshId - thms ← addDeclToUnfoldOrTheorem ctx.indexConfig thms (.stx name arg) e post inv kind - | .simproc declName => - simprocs ← simprocs.add declName post - | .ext (some ext₁) (some ext₂) _ => - thmsArray := thmsArray.push (← ext₁.getTheorems) - simprocs := simprocs.push (← ext₂.getSimprocs) - | .ext (some ext₁) none _ => - thmsArray := thmsArray.push (← ext₁.getTheorems) - | .ext none (some ext₂) _ => - simprocs := simprocs.push (← ext₂.getSimprocs) - | .none => - let name ← mkFreshId - thms ← addSimpTheorem ctx.indexConfig thms (.stx name arg) term post inv - else if arg.getKind == ``Lean.Parser.Tactic.simpStar then - starArg := true - else - throwUnsupportedSyntax - catch ex => - if (← read).recover then - logException ex - else - throw ex - let ctx := ctx.setZetaDeltaSet zetaDeltaSet (← getZetaDeltaFVarIds) - return { ctx := ctx.setSimpTheorems (thmsArray.set! 0 thms), simprocs, starArg } - -- If recovery is disabled, then we want simp argument elaboration failures to be exceptions. - -- This affects `addSimpTheorem`. - if (← read).recover then - go - else - Term.withoutErrToSorry go -where - isSimproc? (e : Expr) : MetaM (Option Name) := do - let .const declName _ := e | return none - unless (← Simp.isSimproc declName) do return none - return some declName - - resolveSimpIdTheorem? (simpArgTerm : Term) : TacticM ResolveSimpIdResult := do - let resolveExt (n : Name) : TacticM ResolveSimpIdResult := do +private def resolveSimpIdTheorem? (simpArgTerm : Term) : TermElabM ResolveSimpIdResult := do + let resolveExt (n : Name) : TermElabM ResolveSimpIdResult := do let ext₁? ← getSimpExtension? n let ext₂? ← Simp.getSimprocExtension? n if h : ext₁?.isSome || ext₂?.isSome then @@ -279,7 +133,234 @@ where return .expr e else return .none +where + isSimproc? (e : Expr) : MetaM (Option Name) := do + let .const declName _ := e | return none + unless (← Simp.isSimproc declName) do return none + return some declName + +/-- +The result of elaborating a single `simp` argument +-/ +inductive ElabSimpArgResult where + | addEntries (entries : Array SimpEntry) + | addSimproc («simproc» : Name) (post : Bool) + | addLetToUnfold (fvarId : FVarId) + | ext (ext₁? : Option SimpExtension) (ext₂? : Option Simp.SimprocExtension) (h : ext₁?.isSome || ext₂?.isSome) + | erase (toErase : Origin) + | eraseSimproc (toErase : Name) + | star + | none -- used for example when elaboration fails + +private def elabDeclToUnfoldOrTheorem (config : Meta.ConfigWithKey) (id : Origin) + (e : Expr) (post : Bool) (inv : Bool) (kind : SimpKind) : MetaM ElabSimpArgResult := do + if e.isConst then + let declName := e.constName! + let info ← getConstVal declName + if (← isProp info.type) then + let thms ← mkSimpTheoremFromConst declName (post := post) (inv := inv) + return .addEntries <| thms.map (SimpEntry.thm ·) + else + if inv then + throwError "invalid '←' modifier, '{declName}' is a declaration name to be unfolded" + if kind == .dsimp then + return .addEntries #[.toUnfold declName] + else + .addEntries <$> mkSimpEntryOfDeclToUnfold declName + else if e.isFVar then + let fvarId := e.fvarId! + let decl ← fvarId.getDecl + if (← isProp decl.type) then + let thms ← mkSimpTheoremFromExpr id #[] e (post := post) (inv := inv) (config := config) + return .addEntries <| thms.map (SimpEntry.thm ·) + else if !decl.isLet then + throwError "invalid argument, variable is not a proposition or let-declaration" + else if inv then + throwError "invalid '←' modifier, '{e}' is a let-declaration name to be unfolded" + else + return .addLetToUnfold fvarId + else + let thms ← mkSimpTheoremFromExpr id #[] e (post := post) (inv := inv) (config := config) + return .addEntries <| thms.map (SimpEntry.thm ·) + +private def elabSimpTheorem (config : Meta.ConfigWithKey) (id : Origin) (stx : Syntax) + (post : Bool) (inv : Bool) : TermElabM ElabSimpArgResult := do + let thm? ← Term.withoutModifyingElabMetaStateWithInfo <| withRef stx do + let e ← Term.elabTerm stx .none + Term.synthesizeSyntheticMVars (postpone := .no) (ignoreStuckTC := true) + let e ← instantiateMVars e + if e.hasSyntheticSorry then + return .none + let e := e.eta + if e.hasMVar then + let r ← abstractMVars e + return some (r.paramNames, r.expr) + else + return some (#[], e) + if let some (levelParams, proof) := thm? then + let thms ← mkSimpTheoremFromExpr id levelParams proof (post := post) (inv := inv) (config := config) + return .addEntries <| thms.map (SimpEntry.thm ·) + else + return .none + +private def elabSimpArg (indexConfig : Meta.ConfigWithKey) (eraseLocal : Bool) (kind : SimpKind) + (arg : Syntax) : TacticM ElabSimpArgResult := withRef arg do + try + /- + syntax simpPre := "↓" + syntax simpPost := "↑" + syntax simpLemma := (simpPre <|> simpPost)? "← "? term + + syntax simpErase := "-" ident + -/ + if arg.getKind == ``Lean.Parser.Tactic.simpErase then + let fvar? ← if eraseLocal then Term.isLocalIdent? arg[1] else pure none + if let some fvar := fvar? then + -- We use `eraseCore` because the simp theorem for the hypothesis was not added yet + return .erase (.fvar fvar.fvarId!) + else + let id := arg[1] + if let .ok declName ← observing (realizeGlobalConstNoOverloadWithInfo id) then + if (← Simp.isSimproc declName) then + return .eraseSimproc declName + else + return .erase (.decl declName) + else + -- If `id` could not be resolved, we should check whether it is a builtin simproc. + -- before returning error. + let name := id.getId.eraseMacroScopes + if (← Simp.isBuiltinSimproc name) then + return .eraseSimproc name + else + throwUnknownConstantAt id name + else if arg.getKind == ``Lean.Parser.Tactic.simpLemma then + let post := + if arg[0].isNone then + true + else + arg[0][0].getKind == ``Parser.Tactic.simpPost + let inv := !arg[1].isNone + let term := arg[2] + match (← resolveSimpIdTheorem? term) with + | .expr e => + let name ← mkFreshId + elabDeclToUnfoldOrTheorem indexConfig (.stx name arg) e post inv kind + | .simproc declName => + return .addSimproc declName post + | .ext ext₁? ext₂? h => + return .ext ext₁? ext₂? h + | .none => + let name ← mkFreshId + elabSimpTheorem indexConfig (.stx name arg) term post inv + else if arg.getKind == ``Lean.Parser.Tactic.simpStar then + return .star + else + throwUnsupportedSyntax + catch ex => + if (← read).recover then + logException ex + return .none + else + throw ex + +/-- +The result of elaborating a full array of simp arguments and applying them to the simp context. +-/ +structure ElabSimpArgsResult where + ctx : Simp.Context + simprocs : Simp.SimprocsArray + /-- The elaborated simp arguments with syntax -/ + simpArgs : Array (Syntax × ElabSimpArgResult) + +/-- Implements the effect of the `*` attribute. -/ +private def applyStarArg (ctx : Simp.Context) : MetaM Simp.Context := do + let mut simpTheorems := ctx.simpTheorems + /- + When using `zetaDelta := false`, we do not expand let-declarations when using `[*]`. + Users must explicitly include it in the list. + -/ + let hs ← getPropHyps + for h in hs do + unless simpTheorems.isErased (.fvar h) do + simpTheorems ← simpTheorems.addTheorem (.fvar h) (← h.getDecl).toExpr (config := ctx.indexConfig) + return ctx.setSimpTheorems simpTheorems + +/-- + Elaborate extra simp theorems provided to `simp`. `stx` is of the form `"[" simpTheorem,* "]"` + If `eraseLocal == true`, then we consider local declarations when resolving names for erased theorems (`- id`), + this option only makes sense for `simp_all` or `*` is used. + When `recover := true`, try to recover from errors as much as possible so that users keep seeing + the current goal. +-/ +def elabSimpArgs (stx : Syntax) (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (eraseLocal : Bool) + (kind : SimpKind) (ignoreStarArg := false) : TacticM ElabSimpArgsResult := do + if stx.isNone then + return { ctx, simprocs, simpArgs := #[] } + else + /- + syntax simpPre := "↓" + syntax simpPost := "↑" + syntax simpLemma := (simpPre <|> simpPost)? "← "? term + + syntax simpErase := "-" ident + -/ + let go := withMainContext do + let zetaDeltaSet ← toZetaDeltaSet stx ctx + withTrackingZetaDeltaSet zetaDeltaSet do + let mut starArg := false -- only after * we can erase local declarations + let mut args : Array (Syntax × ElabSimpArgResult) := #[] + for argStx in stx[1].getSepArgs do + let arg ← elabSimpArg ctx.indexConfig (eraseLocal || starArg) kind argStx + starArg := !ignoreStarArg && (starArg || arg matches .star) + args := args.push (argStx, arg) + + let mut thmsArray := ctx.simpTheorems + let mut thms := thmsArray[0]! + let mut simprocs := simprocs + for (ref, arg) in args do + match arg with + | .addEntries entries => + for entry in entries do + thms := thms.uneraseSimpEntry entry + thms := thms.addSimpEntry entry + | .addLetToUnfold fvarId => + thms := thms.addLetDeclToUnfold fvarId + | .addSimproc declName post => + simprocs ← simprocs.add declName post + | .erase origin => + -- `thms.erase` checks if the erasure is effective. + -- We do not want this check for local hypotheses (they are added later based on `starArg`) + if origin matches .fvar _ then + thms := thms.eraseCore origin + -- Nor for decls to unfold when we do auto unfolding + else if ctx.config.autoUnfold then + thms := thms.eraseCore origin + else + thms ← withRef ref <| thms.erase origin + | .eraseSimproc name => + simprocs := simprocs.erase name + | .ext simpExt? simprocExt? _ => + if let some simpExt := simpExt? then + thmsArray := thmsArray.push (← simpExt.getTheorems) + if let some simprocExt := simprocExt? then + simprocs := simprocs.push (← simprocExt.getSimprocs) + | .star => pure () + | .none => pure () + + let mut ctx := ctx.setZetaDeltaSet zetaDeltaSet (← getZetaDeltaFVarIds) + ctx := ctx.setSimpTheorems (thmsArray.set! 0 thms) + if !ignoreStarArg && starArg then + ctx ← applyStarArg ctx + + return { ctx, simprocs, simpArgs := args} + -- If recovery is disabled, then we want simp argument elaboration failures to be exceptions. + -- This affects `addSimpTheorem`. + if (← read).recover then + go + else + Term.withoutErrToSorry go +where /-- If `zetaDelta := false`, create a `FVarId` set with all local let declarations in the `simp` argument list. -/ toZetaDeltaSet (stx : Syntax) (ctx : Simp.Context) : TacticM FVarIdSet := do if ctx.config.zetaDelta then return {} @@ -319,6 +400,8 @@ structure MkSimpContextResult where ctx : Simp.Context simprocs : Simp.SimprocsArray dischargeWrapper : Simp.DischargeWrapper + /-- The elaborated simp arguments with syntax -/ + simpArgs : Array (Syntax × ElabSimpArgResult) := #[] /-- Create the `Simp.Context` for the `simp`, `dsimp`, and `simp_all` tactics. @@ -351,23 +434,8 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp) (config := (← elabSimpConfig stx[1] (kind := kind))) (simpTheorems := #[simpTheorems]) congrTheorems - let r ← elabSimpArgs stx[4] (eraseLocal := eraseLocal) (kind := kind) (simprocs := #[simprocs]) ctx - if !r.starArg || ignoreStarArg then - return { r with dischargeWrapper } - else - let ctx := r.ctx - let simprocs := r.simprocs - let mut simpTheorems := ctx.simpTheorems - /- - When using `zetaDelta := false`, we do not expand let-declarations when using `[*]`. - Users must explicitly include it in the list. - -/ - let hs ← getPropHyps - for h in hs do - unless simpTheorems.isErased (.fvar h) do - simpTheorems ← simpTheorems.addTheorem (.fvar h) (← h.getDecl).toExpr (config := ctx.indexConfig) - let ctx := ctx.setSimpTheorems simpTheorems - return { ctx, simprocs, dischargeWrapper } + let r ← elabSimpArgs stx[4] (eraseLocal := eraseLocal) (kind := kind) (simprocs := #[simprocs]) (ignoreStarArg := ignoreStarArg) ctx + return { r with dischargeWrapper } register_builtin_option tactic.simp.trace : Bool := { defValue := false @@ -477,7 +545,7 @@ def withSimpDiagnostics (x : TacticM Simp.Diagnostics) : TacticM Unit := do (location)? -/ @[builtin_tactic Lean.Parser.Tactic.simp] def evalSimp : Tactic := fun stx => withMainContext do withSimpDiagnostics do - let { ctx, simprocs, dischargeWrapper } ← mkSimpContext stx (eraseLocal := false) + let { ctx, simprocs, dischargeWrapper, .. } ← mkSimpContext stx (eraseLocal := false) let stats ← dischargeWrapper.with fun discharge? => simpLocation ctx simprocs discharge? (expandOptLocation stx[5]) if tactic.simp.trace.get (← getOptions) then diff --git a/src/Lean/Elab/Tactic/SimpTrace.lean b/src/Lean/Elab/Tactic/SimpTrace.lean index 661207239f..28913ea31e 100644 --- a/src/Lean/Elab/Tactic/SimpTrace.lean +++ b/src/Lean/Elab/Tactic/SimpTrace.lean @@ -30,7 +30,7 @@ def mkSimpCallStx (stx : Syntax) (usedSimps : UsedSimps) : MetaM (TSyntax `tacti `(tactic| simp!%$tk $cfg:optConfig $(discharger)? $[only%$o]? $[[$args,*]]? $(loc)?) else `(tactic| simp%$tk $cfg:optConfig $[$discharger]? $[only%$o]? $[[$args,*]]? $(loc)?) - let { ctx, simprocs, dischargeWrapper } ← mkSimpContext stx (eraseLocal := false) + let { ctx, simprocs, dischargeWrapper, ..} ← mkSimpContext stx (eraseLocal := false) let ctx := if bang.isSome then ctx.setAutoUnfold else ctx let stats ← dischargeWrapper.with fun discharge? => simpLocation ctx (simprocs := simprocs) discharge? <| diff --git a/src/Lean/Elab/Tactic/Simpa.lean b/src/Lean/Elab/Tactic/Simpa.lean index 9b5e0f4cfd..f54d8215a9 100644 --- a/src/Lean/Elab/Tactic/Simpa.lean +++ b/src/Lean/Elab/Tactic/Simpa.lean @@ -34,7 +34,7 @@ deriving instance Repr for UseImplicitLambdaResult | `(tactic| simpa%$tk $[?%$squeeze]? $[!%$unfold]? $cfg:optConfig $(disch)? $[only%$only]? $[[$args,*]]? $[using $usingArg]?) => Elab.Tactic.focus do withSimpDiagnostics do let stx ← `(tactic| simp $cfg:optConfig $(disch)? $[only%$only]? $[[$args,*]]?) - let { ctx, simprocs, dischargeWrapper } ← + let { ctx, simprocs, dischargeWrapper, .. } ← withMainContext <| mkSimpContext stx (eraseLocal := false) let ctx := if unfold.isSome then ctx.setAutoUnfold else ctx -- TODO: have `simpa` fail if it doesn't use `simp`. diff --git a/src/Lean/Meta/Tactic/Simp/Attr.lean b/src/Lean/Meta/Tactic/Simp/Attr.lean index 7ac1fe261a..c0f1e1419b 100644 --- a/src/Lean/Meta/Tactic/Simp/Attr.lean +++ b/src/Lean/Meta/Tactic/Simp/Attr.lean @@ -33,13 +33,13 @@ def mkSimpAttr (attrName : Name) (attrDescr : String) (ext : SimpExtension) else if info.kind matches .defn then if inv then throwError "invalid '←' modifier, '{declName}' is a declaration name to be unfolded" - if (← SimpTheorems.ignoreEquations declName) then + if (← Simp.ignoreEquations declName) then ext.add (SimpEntry.toUnfold declName) attrKind else if let some eqns ← getEqnsFor? declName then for eqn in eqns do addSimpTheorem ext eqn post (inv := false) attrKind prio ext.add (SimpEntry.toUnfoldThms declName eqns) attrKind - if (← SimpTheorems.unfoldEvenWithEqns declName) then + if (← Simp.unfoldEvenWithEqns declName) then ext.add (SimpEntry.toUnfold declName) attrKind else ext.add (SimpEntry.toUnfold declName) attrKind diff --git a/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean b/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean index cdcd409642..7f9afa4b29 100644 --- a/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean +++ b/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean @@ -12,6 +12,27 @@ import Lean.Meta.Eqns import Lean.Meta.Tactic.AuxLemma import Lean.DefEqAttrib import Lean.DocString + +/-! +This module contains types to manages simp theorems and sets theirof. + +Overview of types in this module: + +* `Origin`: Identifies where a simp theorem comes from (global declaration or local expression). + Includes the direction of the theorem for global declarations. +* `SimpTheorem`: Represents a single simp theorem, including its origin and proof. +* `SimpEntry`: The effect of a simp attribute; either a `SimpTheorem` or information about a + definition to unfold. This is stored in oleans. +* `SimpTheorems`: Main data structure to store the simp set for a given `simp` invocation, including + discrimination trees, sets of erased theorem, declarations to unfold. +* `SimpExtension`: Environment extension to store the default simp set, or user-defined simp sets. + Each simp extension maintains its own `SimpTheorems` within a module. +* `SimpTheoremsArray`: Array of `SimpTheorems`, to avoid the need for merging `SimpTheorems` when + more than one simp extension is enabled. + +-/ + + namespace Lean.Meta register_builtin_option backward.dsimp.useDefEqAttr : Bool := { @@ -219,23 +240,6 @@ def ppSimpTheorem [Monad m] [MonadEnv m] [MonadError m] (s : SimpTheorem) : m Me instance : BEq SimpTheorem where beq e₁ e₂ := e₁.proof == e₂.proof -abbrev SimpTheoremTree := DiscrTree SimpTheorem - -/-- -The theorems in a simp set. --/ -structure SimpTheorems where - pre : SimpTheoremTree := DiscrTree.empty - post : SimpTheoremTree := DiscrTree.empty - lemmaNames : PHashSet Origin := {} - /-- - Constants (and let-declaration `FVarId`) to unfold. - When `zetaDelta := false`, the simplifier will expand a let-declaration if it is in this set. - -/ - toUnfold : PHashSet Name := {} - erased : PHashSet Origin := {} - toUnfoldThms : PHashMap Name (Array Name) := {} - deriving Inhabited /-- Configuration for `MetaM` used to process global simp theorems @@ -250,81 +254,7 @@ def simpGlobalConfig : ConfigWithKey := @[inline] def withSimpGlobalConfig : MetaM α → MetaM α := withConfigWithKey simpGlobalConfig -partial def SimpTheorems.eraseCore (d : SimpTheorems) (thmId : Origin) : SimpTheorems := - let d := { d with erased := d.erased.insert thmId, lemmaNames := d.lemmaNames.erase thmId } - if let .decl declName .. := thmId then - let d := { d with toUnfold := d.toUnfold.erase declName } - if let some thms := d.toUnfoldThms.find? declName then - let dummy := true - thms.foldl (init := d) (eraseCore · <| .decl · dummy (inv := false)) - else - d - else - d -private def eraseIfExists (d : SimpTheorems) (thmId : Origin) : SimpTheorems := - if d.lemmaNames.contains thmId then - d.eraseCore thmId - else - d - -/-- -If `e` is a backwards theorem `← thm`, we must ensure the forward theorem is erased -from `d`. See issue #4290 --/ -private def eraseFwdIfBwd (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems := - if let some converseOrigin := e.origin.converse then - eraseIfExists d converseOrigin - else - d - -def addSimpTheoremEntry (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems := - let d := eraseFwdIfBwd d e - if e.post then - { d with post := d.post.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames } - else - { d with pre := d.pre.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames } -where - updateLemmaNames (s : PHashSet Origin) : PHashSet Origin := - s.insert e.origin - -def SimpTheorems.addDeclToUnfoldCore (d : SimpTheorems) (declName : Name) : SimpTheorems := - { d with toUnfold := d.toUnfold.insert declName } - -def SimpTheorems.addLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : SimpTheorems := - -- A small hack that relies on the fact that constants and `FVarId` names should be disjoint. - { d with toUnfold := d.toUnfold.insert fvarId.name } - -/-- Return `true` if `declName` is tagged to be unfolded using `unfoldDefinition?` (i.e., without using equational theorems). -/ -def SimpTheorems.isDeclToUnfold (d : SimpTheorems) (declName : Name) : Bool := - d.toUnfold.contains declName - -def SimpTheorems.isLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : Bool := - d.toUnfold.contains fvarId.name -- See comment at `addLetDeclToUnfold` - -def SimpTheorems.isLemma (d : SimpTheorems) (thmId : Origin) : Bool := - d.lemmaNames.contains thmId - -/-- Register the equational theorems for the given definition. -/ -def SimpTheorems.registerDeclToUnfoldThms (d : SimpTheorems) (declName : Name) (eqThms : Array Name) : SimpTheorems := - { d with toUnfoldThms := d.toUnfoldThms.insert declName eqThms } - -def SimpTheorems.erase [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] - (d : SimpTheorems) (thmId : Origin) : m SimpTheorems := do - if d.isLemma thmId || - match thmId with - | .decl declName .. => d.isDeclToUnfold declName || d.toUnfoldThms.contains declName - | _ => false - then - return d.eraseCore thmId - - -- `attribute [-simp] foo` should also undo `attribute [simp ←] foo`. - if let some thmId' := thmId.converse then - if d.isLemma thmId' then - return d.eraseCore thmId' - - logWarning m!"'{thmId.key}' does not have [simp] attribute" - return d private partial def isPerm : Expr → Expr → MetaM Bool | .app f₁ a₁, .app f₂ a₂ => isPerm f₁ f₂ <&&> isPerm a₁ a₂ @@ -432,7 +362,12 @@ private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array | none => throwError "unexpected kind of 'simp' theorem{indentExpr type}" return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) } -private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool) (prio : Nat) : MetaM (Array SimpTheorem) := do +/-- +Creates a `SimpTheorem` from a global theorem. +Because some theorems lead to multiple `SimpTheorems` (in particular conjunctions), returns an array. +-/ +def mkSimpTheoremFromConst (declName : Name) (post := true) (inv := false) + (prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) := do let cinfo ← getConstVal declName let us := cinfo.levelParams.map mkLevelParam let origin := .decl declName post inv @@ -449,52 +384,6 @@ private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool) else return #[← withoutExporting do mkSimpTheoremCore origin (mkConst declName us) #[] (mkConst declName) post prio (noIndexAtArgs := false)] -inductive SimpEntry where - | thm : SimpTheorem → SimpEntry - | toUnfold : Name → SimpEntry - | toUnfoldThms : Name → Array Name → SimpEntry - deriving Inhabited - -/-- -The environment extension that contains a simp set, returned by `Lean.Meta.registerSimpAttr`. - -Use the simp set's attribute or `Lean.Meta.addSimpTheorem` to add theorems to the simp set. Use -`Lean.Meta.SimpExtension.getTheorems` to get the contents. --/ -abbrev SimpExtension := SimpleScopedEnvExtension SimpEntry SimpTheorems - -def SimpExtension.getTheorems (ext : SimpExtension) : CoreM SimpTheorems := - return ext.getState (← getEnv) - -def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv : Bool) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do - let simpThms ← withExporting (isExporting := !isPrivateName declName) do mkSimpTheoremsFromConst declName post inv prio - for simpThm in simpThms do - ext.add (SimpEntry.thm simpThm) attrKind - -def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension := - registerSimpleScopedEnvExtension { - name := name - initial := {} - addEntry := fun d e => - match e with - | .thm e => addSimpTheoremEntry d e - | .toUnfold n => d.addDeclToUnfoldCore n - | .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms - } - -abbrev SimpExtensionMap := Std.HashMap Name SimpExtension - -builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {} - -def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) := - return (← simpExtensionMapRef.get)[attrName]? - -/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/ -def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do - let s := { s with erased := s.erased.erase (.decl declName post inv) } - let simpThms ← mkSimpTheoremsFromConst declName post inv prio - return simpThms.foldl addSimpTheoremEntry s - def SimpTheorem.getValue (simpThm : SimpTheorem) : MetaM Expr := do if simpThm.proof.isConst && simpThm.levelParams.isEmpty then let info ← getConstVal simpThm.proof.constName! @@ -512,10 +401,27 @@ private def preprocessProof (val : Expr) (inv : Bool) : MetaM (Array Expr) := do let ps ← preprocess val type inv (isGlobal := false) return ps.toArray.map fun (val, _) => val -/-- Auxiliary method for creating simp theorems from a proof term `val`. -/ -private def mkSimpTheorems (id : Origin) (levelParams : Array Name) (proof : Expr) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) := - withReducible do - (← preprocessProof proof inv).mapM fun val => mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true) +def mkSimpTheoremFromExpr (id : Origin) (levelParams : Array Name) (proof : Expr) (inv := false) + (post := true) (prio : Nat := eval_prio default) (config : ConfigWithKey := simpGlobalConfig) : + MetaM (Array SimpTheorem) := do + if proof.isConst then + -- Recall that we use `simpGlobalConfig` for processing global declarations. + mkSimpTheoremFromConst proof.constName! post inv prio + else + withConfigWithKey config do + withReducible do + (← preprocessProof proof inv).mapM fun val => + mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true) + +/-- +A simp theorem or information about a declaration to unfold by simp. +This is stored in the oleans to implement the `simp` attribute and user-defined simp sets. +-/ +inductive SimpEntry where + | thm : SimpTheorem → SimpEntry + | toUnfold : Name → SimpEntry + | toUnfoldThms : Name → Array Name → SimpEntry + deriving Inhabited /-- Reducible functions and projection functions should always be put in `toUnfold`, instead @@ -524,7 +430,7 @@ of trying to use equational theorems. The simplifiers has special support for structure and class projections, and gets confused when they suddenly rewrite, so ignore equations for them -/ -def SimpTheorems.ignoreEquations (declName : Name) : CoreM Bool := do +def Simp.ignoreEquations (declName : Name) : CoreM Bool := do return (← isProjectionFn declName) || (← isReducible declName) /-- @@ -540,20 +446,24 @@ behavior unless `unfoldPartialApp := true`. Moreover, users will have to use `f.eq_def` if they want to force the definition to be unfolded. -/ -def SimpTheorems.unfoldEvenWithEqns (declName : Name) : CoreM Bool := do +def Simp.unfoldEvenWithEqns (declName : Name) : CoreM Bool := do if hasSmartUnfoldingDecl (← getEnv) declName then return true unless (← isRecursiveDefinition declName) do return true return false -def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM SimpTheorems := do +/-- +Given the name of a declaration to unfold, return the `SimpEntry` (or entries) that +implement this unfolding, using either the equational theorems, or `SimpEntry.toUnfold`, or both. +-/ +def mkSimpEntryOfDeclToUnfold (declName : Name) : MetaM (Array SimpEntry) := do + let mut entries : Array SimpEntry := #[] -- NOTE: the latter condition is only to preserve previous behavior where simp accepts even things -- that neither theorems nor unfoldable. This should likely be tightened up in the future. if !(← getConstInfo declName).isDefinition && getOriginalConstKind? (← getEnv) declName == some .defn then throwError "invalid 'simp', definition with exposed body expected: {.ofConstName declName}" - if (← ignoreEquations declName) then - return d.addDeclToUnfoldCore declName + if (← Simp.ignoreEquations declName) then + entries := entries.push (.toUnfold declName) else if let some eqns ← getEqnsFor? declName then - let mut d := d for h : i in [:eqns.size] do let eqn := eqns[i] /- @@ -571,24 +481,187 @@ def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM Si if i + 1 = eqns.size then 0 else 1 else 100 - i - d ← SimpTheorems.addConst d eqn (prio := prio) - if (← unfoldEvenWithEqns declName) then - d := d.addDeclToUnfoldCore declName - return d + let thms ← mkSimpTheoremFromConst eqn (prio := prio) + entries := entries ++ thms.map (.thm ·) + if (← Simp.unfoldEvenWithEqns declName) then + entries := entries.push (.toUnfold declName) else - return d.addDeclToUnfoldCore declName + entries := entries.push (.toUnfold declName) + return entries + + +abbrev SimpTheoremTree := DiscrTree SimpTheorem + +/-- +The theorems in a simp set. +-/ +structure SimpTheorems where + pre : SimpTheoremTree := DiscrTree.empty + post : SimpTheoremTree := DiscrTree.empty + lemmaNames : PHashSet Origin := {} + /-- + Constants (and let-declaration `FVarId`) to unfold. + When `zetaDelta := false`, the simplifier will expand a let-declaration if it is in this set. + -/ + toUnfold : PHashSet Name := {} + erased : PHashSet Origin := {} + toUnfoldThms : PHashMap Name (Array Name) := {} + deriving Inhabited + +partial def SimpTheorems.eraseCore (d : SimpTheorems) (thmId : Origin) : SimpTheorems := + let d := { d with erased := d.erased.insert thmId, lemmaNames := d.lemmaNames.erase thmId } + if let .decl declName .. := thmId then + let d := { d with toUnfold := d.toUnfold.erase declName } + if let some thms := d.toUnfoldThms.find? declName then + let dummy := true + thms.foldl (init := d) (eraseCore · <| .decl · dummy (inv := false)) + else + d + else + d + +private def eraseIfExists (d : SimpTheorems) (thmId : Origin) : SimpTheorems := + if d.lemmaNames.contains thmId then + d.eraseCore thmId + else + d + +/-- +If `e` is a backwards theorem `← thm`, we must ensure the forward theorem is erased +from `d`. See issue #4290 +-/ +private def eraseFwdIfBwd (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems := + if let some converseOrigin := e.origin.converse then + eraseIfExists d converseOrigin + else + d + +def SimpTheorems.unerase (d : SimpTheorems) (thmId : Origin) : SimpTheorems := + { d with erased := d.erased.erase thmId } + +def SimpTheorems.addSimpTheorem (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems := + -- Erase the converse, if it exists + let d := eraseFwdIfBwd d e + if e.post then + { d with post := d.post.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames } + else + { d with pre := d.pre.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames } +where + updateLemmaNames (s : PHashSet Origin) : PHashSet Origin := + s.insert e.origin + +@[deprecated SimpTheorems.addSimpTheorem (since := "2025-06-17")] +def addSimpTheoremEntry := SimpTheorems.addSimpTheorem + +def SimpTheorems.addDeclToUnfoldCore (d : SimpTheorems) (declName : Name) : SimpTheorems := + { d with toUnfold := d.toUnfold.insert declName } + +def SimpTheorems.addLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : SimpTheorems := + -- A small hack that relies on the fact that constants and `FVarId` names should be disjoint. + { d with toUnfold := d.toUnfold.insert fvarId.name } + +/-- Return `true` if `declName` is tagged to be unfolded using `unfoldDefinition?` (i.e., without using equational theorems). -/ +def SimpTheorems.isDeclToUnfold (d : SimpTheorems) (declName : Name) : Bool := + d.toUnfold.contains declName + +def SimpTheorems.isLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : Bool := + d.toUnfold.contains fvarId.name -- See comment at `addLetDeclToUnfold` + +def SimpTheorems.isLemma (d : SimpTheorems) (thmId : Origin) : Bool := + d.lemmaNames.contains thmId + +/-- Register the equational theorems for the given definition. -/ +def SimpTheorems.registerDeclToUnfoldThms (d : SimpTheorems) (declName : Name) (eqThms : Array Name) : SimpTheorems := + { d with toUnfoldThms := d.toUnfoldThms.insert declName eqThms } + +def SimpTheorems.erase [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] + (d : SimpTheorems) (thmId : Origin) : m SimpTheorems := do + if d.isLemma thmId || + match thmId with + | .decl declName .. => d.isDeclToUnfold declName || d.toUnfoldThms.contains declName + | _ => false + then + return d.eraseCore thmId + + -- `attribute [-simp] foo` should also undo `attribute [simp ←] foo`. + if let some thmId' := thmId.converse then + if d.isLemma thmId' then + return d.eraseCore thmId' + + logWarning m!"'{thmId.key}' does not have [simp] attribute" + return d + +def SimpTheorems.addSimpEntry (d : SimpTheorems) (e : SimpEntry) : SimpTheorems := + match e with + | .thm e => d.addSimpTheorem e + | .toUnfold n => d.addDeclToUnfoldCore n + | .toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms + +/-- +`simp [foo]` should undo a previous `attribute @[-simp] foo`. +(Note that `attribute @[simp] foo` does not undo a `attribute @[simp] foo`, see #5852) +-/ +def SimpTheorems.uneraseSimpEntry (d : SimpTheorems) (e : SimpEntry) : SimpTheorems := + match e with + | .thm e => d.unerase e.origin + | _ => d + +/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/ +def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do + let simpThms ← mkSimpTheoremFromConst declName post inv prio + return simpThms.foldl SimpTheorems.addSimpTheorem s + +/-- +The environment extension that contains a simp set, returned by `Lean.Meta.registerSimpAttr`. + +Use the simp set's attribute or `Lean.Meta.addSimpTheorem` to add theorems to the simp set. Use +`Lean.Meta.SimpExtension.getTheorems` to get the contents. +-/ +abbrev SimpExtension := SimpleScopedEnvExtension SimpEntry SimpTheorems + +def SimpExtension.getTheorems (ext : SimpExtension) : CoreM SimpTheorems := + return ext.getState (← getEnv) + +/-- +Adds a simp theorem to a simp extension +-/ +def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv : Bool) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do + let simpThms ← withExporting (isExporting := !isPrivateName declName) do mkSimpTheoremFromConst declName post inv prio + for simpThm in simpThms do + ext.add (SimpEntry.thm simpThm) attrKind + + +def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension := + registerSimpleScopedEnvExtension { + name := name + initial := {} + addEntry := fun d e => d.addSimpEntry e + } + +abbrev SimpExtensionMap := Std.HashMap Name SimpExtension + +builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {} + +def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) := + return (← simpExtensionMapRef.get)[attrName]? + +def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM SimpTheorems := do + let entries ← mkSimpEntryOfDeclToUnfold declName + return entries.foldl (init := d) fun d e => d.addSimpEntry e /-- Auxiliary method for adding a local simp theorem to a `SimpTheorems` datastructure. -/ def SimpTheorems.add (s : SimpTheorems) (id : Origin) (levelParams : Array Name) (proof : Expr) (inv := false) (post := true) (prio : Nat := eval_prio default) (config : ConfigWithKey := simpGlobalConfig) : MetaM SimpTheorems := do - if proof.isConst then - -- Recall that we use `simpGlobalConfig` for processing global declarations. - s.addConst proof.constName! post inv prio - else - let simpThms ← withConfigWithKey config <| mkSimpTheorems id levelParams proof post inv prio - return simpThms.foldl addSimpTheoremEntry s + let simpThms ← mkSimpTheoremFromExpr id levelParams proof inv post prio config + return simpThms.foldl SimpTheorems.addSimpTheorem s +/-- +A `SimpTheoremsArray` is a collection of `SimpTheorems`. The first entry is the default simp set +and possible extensions as simp args (`simp [thm]`), further entries are custom simp sets added +a s simp arguments (`simp [my_simp_set]`). The array is scanned linear during rewriting. +This avoids the need for efficiently merging the `SimpTheorems` data structure. +-/ abbrev SimpTheoremsArray := Array SimpTheorems def SimpTheoremsArray.addTheorem (thmsArray : SimpTheoremsArray) (id : Origin) (h : Expr) (config : ConfigWithKey := simpGlobalConfig) : MetaM SimpTheoremsArray := diff --git a/tests/lean/run/8815.lean b/tests/lean/run/8815.lean new file mode 100644 index 0000000000..22de009ed6 --- /dev/null +++ b/tests/lean/run/8815.lean @@ -0,0 +1,62 @@ +/-! +Assortion of tests to make sure the #8815 simp arg elaboration refactoring did not change +behavior. +-/ + +set_option linter.unusedVariables false + +example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp [*, -hQ] + +/-- error: simp made no progress -/ +#guard_msgs in example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp [*, -hP] + +/-- error: unknown constant 'hQ' -/ +#guard_msgs in example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp [-hQ, *] + +#guard_msgs in example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp_all [-hQ] + +/-- +error: unknown constant 'hQ' +--- +error: simp made no progress +-/ +#guard_msgs in example (P Q : Prop) (hQ : Q) (hP : P) : P := by simp [-hQ] + + +theorem a_thm : True := trivial + +def f : Nat → Nat +| 0 => 1 +| n + 1 => f n + 1 + +example : f 0 > 0 := by simp [f] +example : f 0 > 0 := by simp! + + +-- NB: simp! disables all warnings, not just for declarations to unfold +-- Mild bug, but not a regresion. + +/-- +error: unsolved goals +⊢ 0 < f 0 +-/ +#guard_msgs in example : f 0 > 0 := by simp! [-f, -a_thm] + +/-- +warning: 'f' does not have [simp] attribute +--- +warning: 'a_thm' does not have [simp] attribute +--- +error: unsolved goals +⊢ 0 < f 0 +-/ +#guard_msgs in example : f 0 > 0 := by + simp [-f, -a_thm] + + +/-- +error: invalid 'simp', proposition expected + Type 32 +-/ +#guard_msgs in +example : True := by simp [Sort 32] -- mostly about error location, once guard_msgs shows that diff --git a/tests/lean/run/eqnsPrio.lean b/tests/lean/run/eqnsPrio.lean index e08e0999f9..23e2a6efbc 100644 --- a/tests/lean/run/eqnsPrio.lean +++ b/tests/lean/run/eqnsPrio.lean @@ -22,13 +22,13 @@ termination_by _ n => n #check foo.eq_3 -- In order to reliably check if simp is not attempting to rewrite with a certain lemma --- we can look at te diagnostics. But simply dumping all diangostics is too noisy for a test, +-- we can look at the diagnostics. But simply dumping all diangostics is too noisy for a test, -- so here we try to get our hands at the `Simp.Stats` and look there. open Lean Meta Elab Tactic in elab "simp_foo_with_check" : tactic => withOptions (fun o => diagnostics.set o true) do withMainContext do let stx ← `(tactic|simp [foo]) - let { ctx, simprocs, dischargeWrapper } ← mkSimpContext stx (eraseLocal := false) + let { ctx, simprocs, dischargeWrapper, .. } ← mkSimpContext stx (eraseLocal := false) let stats ← dischargeWrapper.with fun discharge? => do simpLocation ctx simprocs discharge? (expandOptLocation stx.raw[5]) unless stats.diag.triedThmCounter.toList.any (fun (o, _n) => o.key = ``foo.eq_2) do