diff --git a/src/Init/Conv.lean b/src/Init/Conv.lean index b4eb6a3ee3..919a592e3d 100644 --- a/src/Init/Conv.lean +++ b/src/Init/Conv.lean @@ -30,7 +30,7 @@ syntax (name := change) "change " term : conv syntax (name := delta) "delta " ident : conv syntax (name := pattern) "pattern " term : conv syntax (name := rewrite) "rewrite " (config)? rwRuleSeq : conv -syntax (name := simp) "simp " ("(" &"config" " := " term ")")? (&"only ")? ("[" (simpStar <|> simpErase <|> simpLemma),* "]")? : conv +syntax (name := simp) "simp " (config)? (discharger)? (&"only ")? ("[" (simpStar <|> simpErase <|> simpLemma),* "]")? : conv syntax (name := simpMatch) "simpMatch " : conv /-- Execute the given tactic block without converting `conv` goal into a regular goal -/ diff --git a/src/Init/Notation.lean b/src/Init/Notation.lean index d860dd2368..bd41e6ae31 100644 --- a/src/Init/Notation.lean +++ b/src/Init/Notation.lean @@ -319,7 +319,7 @@ macro "sorry" : tactic => `(exact sorry) macro "inferInstance" : tactic => `(exact inferInstance) /-- Optional configuration option for tactics -/ -syntax config := ("(" &"config" " := " term ")") +syntax config := atomic("(" &"config") " := " term ")" syntax locationWildcard := "*" syntax locationHyp := (colGt ident)+ ("⊢" <|> "|-")? -- TODO: delete @@ -352,13 +352,15 @@ syntax (name := injection) "injection " term (" with " (colGt (ident <|> "_"))+) syntax (name := injections) "injections" : tactic +syntax discharger := atomic("(" (&"discharger" <|> &"disch")) " := " tacticSeq ")" + syntax simpPre := "↓" syntax simpPost := "↑" syntax simpLemma := (simpPre <|> simpPost)? ("←" <|> "<-")? term syntax simpErase := "-" ident syntax simpStar := "*" -syntax (name := simp) "simp " (config)? (&"only ")? ("[" (simpStar <|> simpErase <|> simpLemma),* "]")? (location)? : tactic -syntax (name := simpAll) "simp_all " (config)? (&"only ")? ("[" (simpErase <|> simpLemma),* "]")? : tactic +syntax (name := simp) "simp " (config)? (discharger)? (&"only ")? ("[" (simpStar <|> simpErase <|> simpLemma),* "]")? (location)? : tactic +syntax (name := simpAll) "simp_all " (config)? (discharger)? (&"only ")? ("[" (simpErase <|> simpLemma),* "]")? : tactic /-- Delta expand the given definition. diff --git a/src/Lean/Elab/Tactic/Conv/Simp.lean b/src/Lean/Elab/Tactic/Conv/Simp.lean index 5797efd983..76cdb6fee6 100644 --- a/src/Lean/Elab/Tactic/Conv/Simp.lean +++ b/src/Lean/Elab/Tactic/Conv/Simp.lean @@ -17,8 +17,10 @@ def applySimpResult (result : Simp.Result) : TacticM Unit := do updateLhs result.expr (← result.getProof) @[builtinTactic Lean.Parser.Tactic.Conv.simp] def evalSimp : Tactic := fun stx => withMainContext do - let { ctx, .. } ← mkSimpContext stx (eraseLocal := false) - applySimpResult (← simp (← getLhs) ctx) + let { ctx, dischargeWrapper, .. } ← mkSimpContext stx (eraseLocal := false) + let lhs ← getLhs + let result ← dischargeWrapper.with fun d? => simp lhs ctx (discharge? := d?) + applySimpResult result @[builtinTactic Lean.Parser.Tactic.Conv.simpMatch] def evalSimpMatch : Tactic := fun stx => withMainContext do applySimpResult (← Split.simpMatch (← getLhs)) diff --git a/src/Lean/Elab/Tactic/Simp.lean b/src/Lean/Elab/Tactic/Simp.lean index 12330a63fb..280f5eb017 100644 --- a/src/Lean/Elab/Tactic/Simp.lean +++ b/src/Lean/Elab/Tactic/Simp.lean @@ -17,6 +17,59 @@ open Meta declare_config_elab elabSimpConfigCore Meta.Simp.Config declare_config_elab elabSimpConfigCtxCore Meta.Simp.ConfigCtx +/-- + Implement a `simp` discharge function using the given tactic syntax code. + Recall that `simp` dischargers are in `SimpM` which does not have access to `Term.State`. + We need access to `Term.State` to store messages and update the info tree. + Thus, we create an `IO.ref` to track these changes at `Term.State` when we execute `tacticCode`. + We must set this reference with the current `Term.State` before we execute `simp` using the + generated `Simp.Discharge`. -/ +def tacticToDischarge (tacticCode : Syntax) : TacticM (IO.Ref Term.State × Simp.Discharge) := do + let tacticCode ← `(tactic| try ($tacticCode:tacticSeq)) + let ref ← IO.mkRef (← getThe Term.State) + let ctx ← readThe Term.Context + let disch : Simp.Discharge := fun e => do + let mvar ← mkFreshExprSyntheticOpaqueMVar e `simp.discharger + let s ← ref.get + let runTac? : TermElabM (Option Expr) := + try + /- We must only save messages and info tree changes. Recall that `simp` uses temporary metavariables (`withNewMCtxDepth`). + So, we must not save references to them at `Term.State`. -/ + withoutModifyingStateWithInfoAndMessages do + Term.withSynthesize (mayPostpone := false) <| Term.runTactic mvar.mvarId! tacticCode + let result ← instantiateMVars mvar + if result.hasExprMVar then + return none + else + return some result + catch _ => + return none + let (result?, s) ← liftM (m := MetaM) <| Term.TermElabM.run runTac? ctx s + ref.set s + return result? + return (ref, disch) + +inductive Simp.DischargeWrapper where + | default + | custom (ref : IO.Ref Term.State) (discharge : Simp.Discharge) + +def Simp.DischargeWrapper.with (w : Simp.DischargeWrapper) (x : Option Simp.Discharge → MetaM α) : TacticM α := do + match w with + | default => x none + | custom ref d => + ref.set (← getThe Term.State) + try + x d + finally + set (← ref.get) + +private def mkDischargeWrapper (optDischargeSyntax : Syntax) : TacticM Simp.DischargeWrapper := do + if optDischargeSyntax.isNone then + return Simp.DischargeWrapper.default + else + let (ref, d) ← tacticToDischarge optDischargeSyntax[0][3] + return Simp.DischargeWrapper.custom ref d + /- `optConfig` is of the form `("(" "config" ":=" term ")")?` If `ctx == false`, the argument is assumed to have type `Meta.Simp.Config`, and `Meta.Simp.ConfigCtx` otherwise. -/ @@ -121,11 +174,17 @@ private def getPropHyps : MetaM (Array FVarId) := do return result structure MkSimpContextResult where - ctx : Simp.Context - fvarIdToLemmaId : FVarIdToLemmaId + ctx : Simp.Context + dischargeWrapper : Simp.DischargeWrapper + fvarIdToLemmaId : FVarIdToLemmaId --- If `ctx == false`, the argument is assumed to have type `Meta.Simp.Config`, and `Meta.Simp.ConfigCtx` otherwise. -/ +/-- + If `ctx == false`, the config argument is assumed to have type `Meta.Simp.Config`, and `Meta.Simp.ConfigCtx` otherwise. + If `ctx == false`, the `discharge` option must be none -/ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (ctx := false) (ignoreStarArg : Bool := false) : TacticM MkSimpContextResult := do + if ctx && !stx[2].isNone then + throwError "'simp_all' tactic does not support 'discharger' option" + let dischargeWrapper ← mkDischargeWrapper stx[2] let simpOnly := !stx[3].isNone let simpLemmas ← if simpOnly then @@ -138,7 +197,7 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (ctx := false) (ignoreStarA simpLemmas, congrLemmas } if !r.starArg || ignoreStarArg then - return { r with fvarIdToLemmaId := {} } + return { r with fvarIdToLemmaId := {}, dischargeWrapper } else let ctx := r.ctx let erased := ctx.simpLemmas.erased @@ -154,27 +213,30 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (ctx := false) (ignoreStarA fvarIdToLemmaId := fvarIdToLemmaId.insert fvarId id let simpLemmas ← ctx.simpLemmas.add #[] proof (name? := id) ctx := { ctx with simpLemmas } - return { ctx, fvarIdToLemmaId } + return { ctx, fvarIdToLemmaId, dischargeWrapper } /- "simp " (config)? (discharger)? ("only ")? ("[" simpLemma,* "]")? (location)? -/ @[builtinTactic Lean.Parser.Tactic.simp] def evalSimp : Tactic := fun stx => do - let { ctx, fvarIdToLemmaId } ← withMainContext <| mkSimpContext stx (eraseLocal := false) + let { ctx, fvarIdToLemmaId, dischargeWrapper } ← withMainContext <| mkSimpContext stx (eraseLocal := false) -- trace[Meta.debug] "Lemmas {← toMessageData ctx.simpLemmas.post}" let loc := expandOptLocation stx[5] match loc with | Location.targets hUserNames simplifyTarget => withMainContext do let fvarIds ← hUserNames.mapM fun hUserName => return (← getLocalDeclFromUserName hUserName).fvarId - go ctx fvarIds simplifyTarget fvarIdToLemmaId + go ctx dischargeWrapper fvarIds simplifyTarget fvarIdToLemmaId | Location.wildcard => withMainContext do - go ctx (← getNondepPropHyps (← getMainGoal)) (simplifyTarget := true) fvarIdToLemmaId + go ctx dischargeWrapper (← getNondepPropHyps (← getMainGoal)) (simplifyTarget := true) fvarIdToLemmaId where - go (ctx : Simp.Context) (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) (fvarIdToLemmaId : FVarIdToLemmaId) : TacticM Unit := do - liftMetaTactic1 fun mvarId => - return (← simpGoal mvarId ctx (simplifyTarget := simplifyTarget) (fvarIdsToSimp := fvarIdsToSimp) (fvarIdToLemmaId := fvarIdToLemmaId)).map (·.2) + go (ctx : Simp.Context) (dischargeWrapper : Simp.DischargeWrapper) (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) (fvarIdToLemmaId : FVarIdToLemmaId) : TacticM Unit := do + let mvarId ← getMainGoal + let result? ← dischargeWrapper.with fun discharge? => return (← simpGoal mvarId ctx (simplifyTarget := simplifyTarget) (discharge? := discharge?) (fvarIdsToSimp := fvarIdsToSimp) (fvarIdToLemmaId := fvarIdToLemmaId)).map (·.2) + match result? with + | none => replaceMainGoal [] + | some mvarId => replaceMainGoal [mvarId] @[builtinTactic Lean.Parser.Tactic.simpAll] def evalSimpAll : Tactic := fun stx => do let { ctx, .. } ← mkSimpContext stx (eraseLocal := true) (ctx := true) (ignoreStarArg := true) diff --git a/tests/lean/simpDisch.lean b/tests/lean/simpDisch.lean new file mode 100644 index 0000000000..9fce860d85 --- /dev/null +++ b/tests/lean/simpDisch.lean @@ -0,0 +1,24 @@ +constant f : Nat → Nat +@[simp] axiom fEq (x : Nat) (h : x ≠ 0) : f x = x + +example (x : Nat) (h : x ≠ 0) : f x = x + 0 := by + simp (discharger := traceState; exact (fun h' => h') h) + +example (x y : Nat) (h1 : x ≠ 0) (h2 : y ≠ 0) (h3 : x = y) : f x = f y + 0 := by + simp (discharger := traceState; assumption) + assumption + +example (x y : Nat) (h1 : x ≠ 0) (h2 : y ≠ 0) (h3 : x = y) : f x = f y + 0 := by + simp (discharger := assumption) + assumption + +example (x y : Nat) (h1 : x ≠ 0) (h2 : y ≠ 0) (h3 : x = y) : f x = f y + 0 := by + simp (disch := assumption) + assumption + +example (x y : Nat) (h1 : x ≠ 0) (h2 : y ≠ 0) (h3 : x = y) : f x = f y + 0 := by + conv => lhs; simp (disch := assumption) + traceState + conv => rhs; simp (disch := assumption) + traceState + assumption diff --git a/tests/lean/simpDisch.lean.expected.out b/tests/lean/simpDisch.lean.expected.out new file mode 100644 index 0000000000..e7b4bfc6e4 --- /dev/null +++ b/tests/lean/simpDisch.lean.expected.out @@ -0,0 +1,26 @@ +case simp.discharger +x : Nat +h : x ≠ 0 +⊢ x ≠ 0 +case simp.discharger +x y : Nat +h1 : x ≠ 0 +h2 : y ≠ 0 +h3 : x = y +⊢ x ≠ 0 +case simp.discharger +x y : Nat +h1 : x ≠ 0 +h2 : y ≠ 0 +h3 : x = y +⊢ y ≠ 0 +x y : Nat +h1 : x ≠ 0 +h2 : y ≠ 0 +h3 : x = y +⊢ x = f y + 0 +x y : Nat +h1 : x ≠ 0 +h2 : y ≠ 0 +h3 : x = y +⊢ x = y