diff --git a/src/Lean/Elab/Tactic/Simp.lean b/src/Lean/Elab/Tactic/Simp.lean index 4034cd73aa..d34f2f9f1e 100644 --- a/src/Lean/Elab/Tactic/Simp.lean +++ b/src/Lean/Elab/Tactic/Simp.lean @@ -434,7 +434,7 @@ where if tactic.simp.trace.get (← getOptions) then traceSimpCall stx usedSimps -def dsimpLocation (ctx : Simp.Context) (loc : Location) : TacticM Unit := do +def dsimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (loc : Location) : TacticM Unit := do match loc with | Location.targets hyps simplifyTarget => withMainContext do @@ -446,7 +446,7 @@ def dsimpLocation (ctx : Simp.Context) (loc : Location) : TacticM Unit := do where go (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) : TacticM Unit := do let mvarId ← getMainGoal - let (result?, usedSimps) ← dsimpGoal mvarId ctx (simplifyTarget := simplifyTarget) (fvarIdsToSimp := fvarIdsToSimp) + let (result?, usedSimps) ← dsimpGoal mvarId ctx simprocs (simplifyTarget := simplifyTarget) (fvarIdsToSimp := fvarIdsToSimp) match result? with | none => replaceMainGoal [] | some mvarId => replaceMainGoal [mvarId] @@ -454,8 +454,8 @@ where mvarId.withContext <| traceSimpCall (← getRef) usedSimps @[builtin_tactic Lean.Parser.Tactic.dsimp] def evalDSimp : Tactic := fun stx => do - let { ctx, .. } ← withMainContext <| mkSimpContext stx (eraseLocal := false) (kind := .dsimp) - dsimpLocation ctx (expandOptLocation stx[5]) + let { ctx, simprocs, .. } ← withMainContext <| mkSimpContext stx (eraseLocal := false) (kind := .dsimp) + dsimpLocation ctx simprocs (expandOptLocation stx[5]) end Lean.Elab.Tactic diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index f24bc520a5..0ad49a6f76 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -400,24 +400,20 @@ def simpLet (e : Expr) : SimpM Result := do let h ← mkLambdaFVars #[x] h return { expr := e', proof? := some (← mkLetBodyCongr v' h) } +private def dsimpReduce : DSimproc := fun e => do + let mut eNew ← reduce e + if eNew.isFVar then + eNew ← reduceFVar (← getConfig) (← getSimpTheorems) eNew + if eNew != e then return .visit eNew else return .done e + @[export lean_dsimp] private partial def dsimpImpl (e : Expr) : SimpM Expr := do let cfg ← getConfig unless cfg.dsimp do return e - let pre (e : Expr) : SimpM TransformStep := do - if let Step.visit r ← rewritePre (rflOnly := true) e then - if r.expr != e then - return .visit r.expr - return .continue - let post (e : Expr) : SimpM TransformStep := do - if let Step.visit r ← rewritePost (rflOnly := true) e then - if r.expr != e then - return .visit r.expr - let mut eNew ← reduce e - if eNew.isFVar then - eNew ← reduceFVar cfg (← getSimpTheorems) eNew - if eNew != e then return .visit eNew else return .done e + let m ← getMethods + let pre := m.dpre + let post := m.dpost >> dsimpReduce transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post) def visitFn (e : Expr) : SimpM Result := do @@ -649,9 +645,9 @@ def simp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (disc | none => Simp.main e ctx usedSimps (methods := Simp.mkDefaultMethodsCore simprocs) | some d => Simp.main e ctx usedSimps (methods := Simp.mkMethods simprocs d) -def dsimp (e : Expr) (ctx : Simp.Context) +def dsimp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (usedSimps : UsedSimps := {}) : MetaM (Expr × UsedSimps) := do profileitM Exception "dsimp" (← getOptions) do - Simp.dsimpMain e ctx usedSimps (methods := Simp.mkDefaultMethodsCore {}) + Simp.dsimpMain e ctx usedSimps (methods := Simp.mkDefaultMethodsCore simprocs ) /-- See `simpTarget`. This method assumes `mvarId` is not assigned, and we are already using `mvarId`s local context. -/ def simpTargetCore (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (discharge? : Option Simp.Discharge := none) @@ -800,7 +796,7 @@ def simpTargetStar (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsAr else return (TacticResultCNM.modified mvarId', usedSimps') -def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[]) +def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[]) (usedSimps : UsedSimps := {}) : MetaM (Option MVarId × UsedSimps) := do mvarId.withContext do mvarId.checkNotAssigned `simp @@ -808,7 +804,7 @@ def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := t let mut usedSimps : UsedSimps := usedSimps for fvarId in fvarIdsToSimp do let type ← instantiateMVars (← fvarId.getType) - let (typeNew, usedSimps') ← dsimp type ctx + let (typeNew, usedSimps') ← dsimp type ctx simprocs usedSimps := usedSimps' if typeNew.isFalse then mvarIdNew.assign (← mkFalseElim (← mvarIdNew.getType) (mkFVar fvarId)) @@ -817,7 +813,7 @@ def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := t mvarIdNew ← mvarIdNew.replaceLocalDeclDefEq fvarId typeNew if simplifyTarget then let target ← mvarIdNew.getType - let (targetNew, usedSimps') ← dsimp target ctx usedSimps + let (targetNew, usedSimps') ← dsimp target ctx simprocs usedSimps usedSimps := usedSimps' if targetNew.isTrue then mvarIdNew.assign (mkConst ``True.intro) diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index 08baf1b86d..2ae9800a17 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -319,6 +319,26 @@ def rewritePost (rflOnly := false) : Simproc := fun e => do return .visit r return .continue +def drewritePre : DSimproc := fun e => do + for thms in (← getContext).simpTheorems do + if let some r ← rewrite? e thms.pre thms.erased (tag := "pre") (rflOnly := true) then + return .visit r.expr + return .continue + +def drewritePost : DSimproc := fun e => do + for thms in (← getContext).simpTheorems do + if let some r ← rewrite? e thms.post thms.erased (tag := "post") (rflOnly := true) then + return .visit r.expr + return .continue + +def dpreDefault (s : SimprocsArray) : DSimproc := + drewritePre >> + userPreDSimprocs s + +def dpostDefault (s : SimprocsArray) : DSimproc := + drewritePost >> + userPostDSimprocs s + /-- Discharge procedure for the ground/symbolic evaluator. -/ @@ -382,6 +402,8 @@ def mkSEvalMethods : CoreM Methods := do return { pre := preSEval #[s] post := postSEval #[s] + dpre := dpreDefault #[s] + dpost := dpostDefault #[s] discharge? := dischargeGround } @@ -525,6 +547,8 @@ abbrev Discharge := Expr → SimpM (Option Expr) def mkMethods (s : SimprocsArray) (discharge? : Discharge) : Methods := { pre := preDefault s post := postDefault s + dpre := dpreDefault s + dpost := dpostDefault s discharge? := discharge? } diff --git a/src/Lean/Meta/Tactic/Simp/Simproc.lean b/src/Lean/Meta/Tactic/Simp/Simproc.lean index 84d7c37acc..dc984308e4 100644 --- a/src/Lean/Meta/Tactic/Simp/Simproc.lean +++ b/src/Lean/Meta/Tactic/Simp/Simproc.lean @@ -200,6 +200,18 @@ def SimprocEntry.try (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM let s ← proc e s.toStep.addExtraArgs extraArgs +/-- Similar to `try`, but only consider `DSimproc` case. That is, if `s.proc` is a `Simproc`, treat it as a `.continue`. -/ +def SimprocEntry.tryD (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM DStep := do + let mut extraArgs := #[] + let mut e := e + for _ in [:numExtraArgs] do + extraArgs := extraArgs.push e.appArg! + e := e.appFn! + extraArgs := extraArgs.reverse + match s.proc with + | .inl _ => return .continue + | .inr proc => return (← proc e).addExtraArgs extraArgs + def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM Step := do let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig)) if candidates.isEmpty then @@ -237,6 +249,39 @@ def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Ex else return .continue +def dsimprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM DStep := do + let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig)) + if candidates.isEmpty then + let tag := if post then "post" else "pre" + trace[Debug.Meta.Tactic.simp] "no {tag}-simprocs found for {e}" + return .continue + else + let mut e := e + let mut found := false + for (simprocEntry, numExtraArgs) in candidates do + unless erased.contains simprocEntry.declName do + let s ← simprocEntry.tryD numExtraArgs e + match s with + | .visit eNew => + trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}" + recordSimpTheorem (.decl simprocEntry.declName post) + return .visit eNew + | .done eNew => + trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}" + recordSimpTheorem (.decl simprocEntry.declName post) + return .done eNew + | .continue (some eNew) => + trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}" + recordSimpTheorem (.decl simprocEntry.declName post) + e := eNew + found := true + | .continue none => + pure () + if found then + return .continue (some e) + else + return .continue + abbrev SimprocsArray := Array Simprocs def SimprocsArray.add (ss : SimprocsArray) (declName : Name) (post : Bool) : CoreM SimprocsArray := @@ -272,6 +317,22 @@ def simprocArrayCore (post : Bool) (ss : SimprocsArray) (e : Expr) : SimpM Step else return .continue +def dsimprocArrayCore (post : Bool) (ss : SimprocsArray) (e : Expr) : SimpM DStep := do + let mut found := false + let mut e := e + for s in ss do + match (← dsimprocCore (post := post) (if post then s.post else s.pre) s.erased e) with + | .visit eNew => return .visit eNew + | .done eNew => return .done eNew + | .continue none => pure () + | .continue (some eNew) => + e := eNew + found := true + if found then + return .continue (some e) + else + return .continue + register_builtin_option simprocs : Bool := { defValue := true group := "backward compatibility" @@ -286,6 +347,14 @@ def userPostSimprocs (s : SimprocsArray) : Simproc := fun e => do unless simprocs.get (← getOptions) do return .continue simprocArrayCore (post := true) s e +def userPreDSimprocs (s : SimprocsArray) : DSimproc := fun e => do + unless simprocs.get (← getOptions) do return .continue + dsimprocArrayCore (post := false) s e + +def userPostDSimprocs (s : SimprocsArray) : DSimproc := fun e => do + unless simprocs.get (← getOptions) do return .continue + dsimprocArrayCore (post := true) s e + def mkSimprocExt (name : Name := by exact decl_name%) (ref? : Option (IO.Ref Simprocs)) : IO SimprocExtension := registerScopedEnvExtension { name := name diff --git a/src/Lean/Meta/Tactic/Simp/Types.lean b/src/Lean/Meta/Tactic/Simp/Types.lean index 85b580e27b..659effcec3 100644 --- a/src/Lean/Meta/Tactic/Simp/Types.lean +++ b/src/Lean/Meta/Tactic/Simp/Types.lean @@ -185,6 +185,17 @@ def andThen (f g : Simproc) : Simproc := fun e => do instance : AndThen Simproc where andThen s₁ s₂ := andThen s₁ (s₂ ()) +@[always_inline] +def dandThen (f g : DSimproc) : DSimproc := fun e => do + match (← f e) with + | .done eNew => return .done eNew + | .continue none => g e + | .continue (some eNew) => g eNew + | .visit eNew => return .visit eNew + +instance : AndThen DSimproc where + andThen s₁ s₂ := dandThen s₁ (s₂ ()) + /-- `Simproc` .olean entry. -/ @@ -217,6 +228,8 @@ structure Simprocs where structure Methods where pre : Simproc := fun _ => return .continue post : Simproc := fun e => return .done { expr := e } + dpre : DSimproc := fun _ => return .continue + dpost : DSimproc := fun e => return .done e discharge? : Expr → SimpM (Option Expr) := fun _ => return none deriving Inhabited @@ -543,6 +556,13 @@ def Step.addExtraArgs (s : Step) (extraArgs : Array Expr) : MetaM Step := do | .continue none => return .continue none | .continue (some r) => return .continue (← r.addExtraArgs extraArgs) +def DStep.addExtraArgs (s : DStep) (extraArgs : Array Expr) : DStep := + match s with + | .visit eNew => .visit (mkAppN eNew extraArgs) + | .done eNew => .done (mkAppN eNew extraArgs) + | .continue none => .continue none + | .continue (some eNew) => .continue (mkAppN eNew extraArgs) + end Simp export Simp (SimpM Simprocs) diff --git a/tests/lean/run/dsimp_bv_simproc.lean b/tests/lean/run/dsimp_bv_simproc.lean new file mode 100644 index 0000000000..06be478fc4 --- /dev/null +++ b/tests/lean/run/dsimp_bv_simproc.lean @@ -0,0 +1,30 @@ +open BitVec + +variable (write : (n : Nat) → BitVec 64 → BitVec (n * 8) → Type → Type) + +theorem write_simplify_test_0 (a x y : BitVec 64) + (h : ((8 * 8) + 8 * 8) = 2 * ((8 * 8) / 8) * 8) : + write (2 * ((8 * 8) / 8)) a (BitVec.cast h (zeroExtend (8 * 8) x ++ (zeroExtend (8 * 8) y))) s + = + write 16 a (x ++ y) s := by + simp only [zeroExtend_eq, BitVec.cast_eq] + +/-- +warning: declaration uses 'sorry' +--- +info: write : (n : Nat) → BitVec 64 → BitVec (n * 8) → Type → Type +s aux : Type +a x y : BitVec 64 +h : 128 = 128 +⊢ write 16 a (x ++ y) s = aux +-/ +#guard_msgs in +example (a x y : BitVec 64) + (h : ((8 * 8) + 8 * 8) = 2 * ((8 * 8) / 8) * 8) : + write (2 * ((8 * 8) / 8)) a (BitVec.cast h (zeroExtend (8 * 8) x ++ (zeroExtend (8 * 8) y))) s + = + aux := by + simp + dsimp at h + trace_state + sorry diff --git a/tests/lean/run/dsimproc.lean b/tests/lean/run/dsimproc.lean new file mode 100644 index 0000000000..4116bcd9ce --- /dev/null +++ b/tests/lean/run/dsimproc.lean @@ -0,0 +1,64 @@ +import Lean + +def foo (x : Nat) := x + 1 + +open Lean Meta Simp +dsimproc reduceFoo (foo _) := fun e => do + let_expr foo a ← e | return .continue + let some n ← getNatValue? a | return .continue + return .done (toExpr (n+1)) + +example (h : 3 = x) : foo 2 = x := by + simp + guard_target =ₛ 3 = x + assumption + +example (h : 3 = x) : foo 2 = x := by + fail_if_success simp [-reduceFoo] + fail_if_success simp only + simp only [reduceFoo] + guard_target =ₛ 3 = x + assumption + +def bla (x : Nat) := 2*x + +dsimproc_decl reduceBla (bla _) := fun e => do + let_expr bla a ← e | return .continue + let some n ← getNatValue? a | return .continue + return .done (toExpr (2*n)) + +example (h : 6 = x) : bla 3 = x := by + fail_if_success simp + fail_if_success simp only + simp [bla] + guard_target =ₛ 6 = x + assumption + +example (h : 6 = x) : bla 3 = x := by + fail_if_success simp + fail_if_success simp only + simp [bla] + guard_target =ₛ 6 = x + assumption + +example (h : 6 = x) : bla 3 = x := by + simp only [bla] + guard_target =ₛ 2*3 = x + assumption + +example (h : 6 = x) : bla 3 = x := by + simp only [bla, Nat.reduceMul] + guard_target =ₛ 6 = x + assumption + +attribute [simp] reduceBla + +example (h : 6 = x) : bla 3 = x := by + simp + guard_target =ₛ 6 = x + assumption + +example (h : 5 = x) : 2 + 3 = x := by + dsimp + guard_target =ₛ 5 = x + assumption