diff --git a/src/Lean/Elab/Tactic/Grind/Basic.lean b/src/Lean/Elab/Tactic/Grind/Basic.lean index 5fe65c1418..228ff2bfa6 100644 --- a/src/Lean/Elab/Tactic/Grind/Basic.lean +++ b/src/Lean/Elab/Tactic/Grind/Basic.lean @@ -25,10 +25,16 @@ structure Context extends Tactic.Context where open Meta.Grind (Goal) -/-- Cache key for `Sym.simp` variant invocations: variant name + ordered extra theorem names. -/ +/-- An extra theorem passed to `simp` in `sym =>` mode. -/ +inductive ExtraTheorem where + | const (declName : Name) + | fvar (fvarId : FVarId) + deriving BEq, Hashable + +/-- Cache key for `Sym.simp` variant invocations. -/ structure SimpCacheKey where variant : Name - extras : List Name + extras : Array ExtraTheorem deriving BEq, Hashable structure Cache where diff --git a/src/Lean/Elab/Tactic/Grind/Sym.lean b/src/Lean/Elab/Tactic/Grind/Sym.lean index 57473898cf..62c83e8b52 100644 --- a/src/Lean/Elab/Tactic/Grind/Sym.lean +++ b/src/Lean/Elab/Tactic/Grind/Sym.lean @@ -153,39 +153,54 @@ def elabOptSimproc (stx? : Option Syntax) : GrindTacticM Simproc := do let some stx := stx? | return trivialSimproc elabSymSimproc stx -def addExtraTheorems (post : Simproc) (extraNames : Array Name) : GrindTacticM Simproc := do - if extraNames.isEmpty then return post +def resolveExtraTheorems (ids? : Option (Array (TSyntax `ident))) : GrindTacticM (Array ExtraTheorem × Array Theorem) := do + let some ids := ids? | return (#[], #[]) + let mut extras := #[] + let mut thms := #[] + let lctx ← getLCtx + for id in ids do + if let some decl := lctx.findFromUserName? id.getId then + extras := extras.push <| .fvar decl.fvarId + thms := thms.push (← mkTheoremFromExpr decl.toExpr) + else + let declName ← realizeGlobalConstNoOverload id + extras := extras.push <| .const declName + thms := thms.push (← mkTheoremFromDecl declName) + return (extras, thms) + +def addExtraTheorems (post : Simproc) (extraThms : Array Theorem) : GrindTacticM Simproc := do + if extraThms.isEmpty then return post let mut thms : Theorems := {} - for name in extraNames do - thms := thms.insert (← mkTheoremFromDecl name) + for thm in extraThms do + thms := thms.insert thm return post >> thms.rewrite -def mkDefaultMethods (extraNames : Array Name) : GrindTacticM Sym.Simp.Methods := do +def mkDefaultMethods (extraThms : Array Theorem) : GrindTacticM Sym.Simp.Methods := do let thms ← getSymSimpTheorems let pre := simpControl >> simpArrowTelescope - let post ← addExtraTheorems (evalGround >> thms.rewrite) extraNames + let post ← addExtraTheorems (evalGround >> thms.rewrite) extraThms return { pre, post } -def elabVariant (variantName : Name) (extraNames : Array Name) : GrindTacticM (Sym.Simp.Methods × Sym.Simp.Config) := do +def elabVariant (variantName : Name) (extraThms : Array Theorem) : GrindTacticM (Sym.Simp.Methods × Sym.Simp.Config) := do if variantName.isAnonymous then - return (← mkDefaultMethods extraNames, {}) + return (← mkDefaultMethods extraThms, {}) let some v := getSymSimpVariant? (← getEnv) variantName | throwError "unknown Sym.simp variant `{variantName}`" let pre ← elabOptSimproc v.pre? - let post ← addExtraTheorems (← elabOptSimproc v.post?) extraNames + let post ← addExtraTheorems (← elabOptSimproc v.post?) extraThms return ({ pre, post}, v.config) -@[builtin_grind_tactic Parser.Tactic.Grind.symSimp] def evalSymSimp : GrindTactic := fun stx => do +@[builtin_grind_tactic Parser.Tactic.Grind.symSimp] def evalSymSimp : GrindTactic := fun stx => withMainContext do ensureSym let `(grind| simp $[$variantId?]? $[[ $[$extraIds],* ]]?) := stx | throwUnsupportedSyntax -- Resolve variant let variantName := variantId?.map (·.getId) |>.getD .anonymous - -- Compose extra theorems into post - let extraNames ← (extraIds.getD #[]).mapM fun id => realizeGlobalConstNoOverload id + -- Resolve extra theorems (local hypotheses first, then global constants) + let (extras, thms) ← resolveExtraTheorems extraIds -- Cache lookup/creation - let cacheKey : SimpCacheKey := { variant := variantName, extras := extraNames.toList } + let cacheKey : SimpCacheKey := { variant := variantName, extras } let simpState := (← get).cache.simpState[cacheKey]?.getD {} - let (methods, config) ← elabVariant variantName extraNames + let (methods, config) ← elabVariant variantName thms let goal ← getMainGoal let (simpResult, simpState) ← liftGrindM <| goal.withContext do Sym.Simp.SimpM.run (Sym.Simp.simp (← goal.mvarId.getType)) methods config simpState diff --git a/tests/elab/sym_simp_adapt1.lean b/tests/elab/sym_simp_adapt1.lean index 14774ffbdb..26af049e24 100644 --- a/tests/elab/sym_simp_adapt1.lean +++ b/tests/elab/sym_simp_adapt1.lean @@ -70,3 +70,30 @@ example (x : Nat) : ¬ p x := by example (x : Nat) : p x = q x := by sym => simp simple [iff_thm] + +-- Tests for local hypothesis support in `simp [h]` + +-- Local hypothesis `h : p x` rewrites `p x` to `True` +example (x : Nat) (h : p x) : p x = True := by + sym => simp simple [h] + +-- Local hypothesis `h : ¬ p x` rewrites `p x` to `False` +example (x : Nat) (h : ¬ p x) : p x = False := by + sym => simp simple [h] + +-- Local hypothesis `h : p x ↔ q x` rewrites `p x` to `q x` +example (x : Nat) (h : p x ↔ q x) : p x = q x := by + sym => simp simple [h] + +-- Local hypothesis `h : p x = q x` (already an equality) +example (x : Nat) (h : p x = q x) : p x = q x := by + sym => simp simple [h] + +-- Local hypothesis with intro +example (x : Nat) : p x → p x = True := by + sym => + intro h + simp simple [h] + +example (h : ∀ x, p x = q x) : p a = q a ∧ p b = q b := by + sym => simp simple [h, and_true]