diff --git a/src/Lean/Elab/BuiltinNotation.lean b/src/Lean/Elab/BuiltinNotation.lean index 07df420ec0..c986d185a1 100644 --- a/src/Lean/Elab/BuiltinNotation.lean +++ b/src/Lean/Elab/BuiltinNotation.lean @@ -209,6 +209,27 @@ private def elabCDot (stx : Syntax) (expectedType? : Option Expr) : TermElabM Ex | some stx' => withMacroExpansion stx stx' (elabTerm stx' expectedType?) | none => elabTerm stx expectedType? +/-- + Helper method for elaborating terms such as `(.+.)` where a constant name is expected. + This method is usually used to implement tactics that function names as arguments (e.g., `simp`). +-/ +def elabCDotFunctionAlias? (stx : Syntax) : TermElabM (Option Expr) := do + let some stx ← liftMacroM <| expandCDotArg? stx | pure none + let stx ← liftMacroM <| expandMacros stx + match stx with + | `(fun $binders* => $f:ident $args*) => + if binders == args then + try Term.resolveId? f catch _ => return none + else + return none + | _ => return none +where + expandCDotArg? (stx : Syntax) : MacroM (Option Syntax) := + match stx with + | `(($e)) => Term.expandCDot? e + | _ => Term.expandCDot? stx + + @[builtinTermElab paren] def elabParen : TermElab := fun stx expectedType? => do match stx with | `(()) => return Lean.mkConst `Unit.unit diff --git a/src/Lean/Elab/Tactic/Simp.lean b/src/Lean/Elab/Tactic/Simp.lean index 614a624b43..6e789f14f2 100644 --- a/src/Lean/Elab/Tactic/Simp.lean +++ b/src/Lean/Elab/Tactic/Simp.lean @@ -8,6 +8,7 @@ import Lean.Elab.Tactic.Basic import Lean.Elab.Tactic.ElabTerm import Lean.Elab.Tactic.Location import Lean.Meta.Tactic.Replace +import Lean.Elab.BuiltinNotation namespace Lean.Elab.Tactic open Meta @@ -37,6 +38,24 @@ def elabSimpConfig (optConfig : Syntax) (ctx : Bool) : TermElabM Meta.Simp.Confi else evalSimpConfig (← instantiateMVars c) +private def elabSimpLemmaTerm (stx : Syntax) : TacticM Expr := do + withRef stx <| Term.withoutErrToSorry do + let e ← Term.elabTerm stx none + Term.synthesizeSyntheticMVarsUsingDefault + let e ← instantiateMVars e + return e.eta + +private def addLemma (lemmas : Meta.SimpLemmas) (e : Expr) (post : Bool): MetaM Meta.SimpLemmas := do + if e.isConst then + let declName := e.constName! + let info ← getConstInfo declName + if (← isProp info.type) then + lemmas.addConst declName post + else + lemmas.addDeclToUnfold declName + else + lemmas.add e post + /-- Elaborate extra simp lemmas provided to `simp`. `stx` is of the `simpLemma,*` If `eraseLocal == true`, then we consider local declarations when resolving names for erased lemmas (`- id`), @@ -70,19 +89,10 @@ private def elabSimpLemmas (stx : Syntax) (ctx : Simp.Context) (eraseLocal : Boo else arg[0][0].getKind == ``Parser.Tactic.simpPost match (← resolveSimpIdLemma? arg[1]) with - | some e => - if e.isConst then - let declName := e.constName! - let info ← getConstInfo declName - if (← isProp info.type) then - lemmas ← lemmas.addConst declName post - else - lemmas := lemmas.addDeclToUnfold declName - else - lemmas ← lemmas.add e post + | some e => lemmas ← addLemma lemmas e post | _ => - let arg ← elabTerm arg[1] none (mayPostpone := false) - lemmas ← lemmas.add arg post + let e ← elabSimpLemmaTerm arg[1] + lemmas ← addLemma lemmas e post return { ctx with simpLemmas := lemmas } where resolveSimpIdLemma? (simpArgTerm : Syntax) : TacticM (Option Expr) := do @@ -92,7 +102,7 @@ where catch _ => return none else - return none + Term.elabCDotFunctionAlias? simpArgTerm -- If `ctx == false`, the argument is assumed to have type `Meta.Simp.Config`, and `Meta.Simp.ConfigCtx` otherwise. -/ private def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (ctx := false) : TacticM Simp.Context := do diff --git a/tests/lean/cdotAtSimpArg.lean b/tests/lean/cdotAtSimpArg.lean new file mode 100644 index 0000000000..8d8c3f36c5 --- /dev/null +++ b/tests/lean/cdotAtSimpArg.lean @@ -0,0 +1,16 @@ +example : ¬ true = false := by + simp [(¬ ·)] + +example (h : y = 0) : x + y = x := by + simp [(.+.)] -- Expands `HAdd.hAdd + traceState + simp [Add.add] + simp [h, Nat.add] + done + +example (h : y = 0) : x + y = x := by + simp [.+.] + traceState + simp [Add.add] + simp [h, Nat.add] + done diff --git a/tests/lean/cdotAtSimpArg.lean.expected.out b/tests/lean/cdotAtSimpArg.lean.expected.out new file mode 100644 index 0000000000..37cb79eba0 --- /dev/null +++ b/tests/lean/cdotAtSimpArg.lean.expected.out @@ -0,0 +1,6 @@ +y x : Nat +h : y = 0 +⊢ Add.add x y = x +y x : Nat +h : y = 0 +⊢ Add.add x y = x