feat: use dsimprocs at dsimp

This commit is contained in:
Leonardo de Moura 2024-03-05 13:56:52 -08:00 committed by Leonardo de Moura
parent 63b068a77c
commit acdb0054d5
7 changed files with 225 additions and 22 deletions

View file

@ -434,7 +434,7 @@ where
if tactic.simp.trace.get (← getOptions) then
traceSimpCall stx usedSimps
def dsimpLocation (ctx : Simp.Context) (loc : Location) : TacticM Unit := do
def dsimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (loc : Location) : TacticM Unit := do
match loc with
| Location.targets hyps simplifyTarget =>
withMainContext do
@ -446,7 +446,7 @@ def dsimpLocation (ctx : Simp.Context) (loc : Location) : TacticM Unit := do
where
go (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) : TacticM Unit := do
let mvarId ← getMainGoal
let (result?, usedSimps) ← dsimpGoal mvarId ctx (simplifyTarget := simplifyTarget) (fvarIdsToSimp := fvarIdsToSimp)
let (result?, usedSimps) ← dsimpGoal mvarId ctx simprocs (simplifyTarget := simplifyTarget) (fvarIdsToSimp := fvarIdsToSimp)
match result? with
| none => replaceMainGoal []
| some mvarId => replaceMainGoal [mvarId]
@ -454,8 +454,8 @@ where
mvarId.withContext <| traceSimpCall (← getRef) usedSimps
@[builtin_tactic Lean.Parser.Tactic.dsimp] def evalDSimp : Tactic := fun stx => do
let { ctx, .. } ← withMainContext <| mkSimpContext stx (eraseLocal := false) (kind := .dsimp)
dsimpLocation ctx (expandOptLocation stx[5])
let { ctx, simprocs, .. } ← withMainContext <| mkSimpContext stx (eraseLocal := false) (kind := .dsimp)
dsimpLocation ctx simprocs (expandOptLocation stx[5])
end Lean.Elab.Tactic

View file

@ -400,24 +400,20 @@ def simpLet (e : Expr) : SimpM Result := do
let h ← mkLambdaFVars #[x] h
return { expr := e', proof? := some (← mkLetBodyCongr v' h) }
private def dsimpReduce : DSimproc := fun e => do
let mut eNew ← reduce e
if eNew.isFVar then
eNew ← reduceFVar (← getConfig) (← getSimpTheorems) eNew
if eNew != e then return .visit eNew else return .done e
@[export lean_dsimp]
private partial def dsimpImpl (e : Expr) : SimpM Expr := do
let cfg ← getConfig
unless cfg.dsimp do
return e
let pre (e : Expr) : SimpM TransformStep := do
if let Step.visit r ← rewritePre (rflOnly := true) e then
if r.expr != e then
return .visit r.expr
return .continue
let post (e : Expr) : SimpM TransformStep := do
if let Step.visit r ← rewritePost (rflOnly := true) e then
if r.expr != e then
return .visit r.expr
let mut eNew ← reduce e
if eNew.isFVar then
eNew ← reduceFVar cfg (← getSimpTheorems) eNew
if eNew != e then return .visit eNew else return .done e
let m ← getMethods
let pre := m.dpre
let post := m.dpost >> dsimpReduce
transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post)
def visitFn (e : Expr) : SimpM Result := do
@ -649,9 +645,9 @@ def simp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (disc
| none => Simp.main e ctx usedSimps (methods := Simp.mkDefaultMethodsCore simprocs)
| some d => Simp.main e ctx usedSimps (methods := Simp.mkMethods simprocs d)
def dsimp (e : Expr) (ctx : Simp.Context)
def dsimp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[])
(usedSimps : UsedSimps := {}) : MetaM (Expr × UsedSimps) := do profileitM Exception "dsimp" (← getOptions) do
Simp.dsimpMain e ctx usedSimps (methods := Simp.mkDefaultMethodsCore {})
Simp.dsimpMain e ctx usedSimps (methods := Simp.mkDefaultMethodsCore simprocs )
/-- See `simpTarget`. This method assumes `mvarId` is not assigned, and we are already using `mvarId`s local context. -/
def simpTargetCore (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (discharge? : Option Simp.Discharge := none)
@ -800,7 +796,7 @@ def simpTargetStar (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsAr
else
return (TacticResultCNM.modified mvarId', usedSimps')
def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[])
def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[])
(usedSimps : UsedSimps := {}) : MetaM (Option MVarId × UsedSimps) := do
mvarId.withContext do
mvarId.checkNotAssigned `simp
@ -808,7 +804,7 @@ def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := t
let mut usedSimps : UsedSimps := usedSimps
for fvarId in fvarIdsToSimp do
let type ← instantiateMVars (← fvarId.getType)
let (typeNew, usedSimps') ← dsimp type ctx
let (typeNew, usedSimps') ← dsimp type ctx simprocs
usedSimps := usedSimps'
if typeNew.isFalse then
mvarIdNew.assign (← mkFalseElim (← mvarIdNew.getType) (mkFVar fvarId))
@ -817,7 +813,7 @@ def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := t
mvarIdNew ← mvarIdNew.replaceLocalDeclDefEq fvarId typeNew
if simplifyTarget then
let target ← mvarIdNew.getType
let (targetNew, usedSimps') ← dsimp target ctx usedSimps
let (targetNew, usedSimps') ← dsimp target ctx simprocs usedSimps
usedSimps := usedSimps'
if targetNew.isTrue then
mvarIdNew.assign (mkConst ``True.intro)

View file

@ -319,6 +319,26 @@ def rewritePost (rflOnly := false) : Simproc := fun e => do
return .visit r
return .continue
def drewritePre : DSimproc := fun e => do
for thms in (← getContext).simpTheorems do
if let some r ← rewrite? e thms.pre thms.erased (tag := "pre") (rflOnly := true) then
return .visit r.expr
return .continue
def drewritePost : DSimproc := fun e => do
for thms in (← getContext).simpTheorems do
if let some r ← rewrite? e thms.post thms.erased (tag := "post") (rflOnly := true) then
return .visit r.expr
return .continue
def dpreDefault (s : SimprocsArray) : DSimproc :=
drewritePre >>
userPreDSimprocs s
def dpostDefault (s : SimprocsArray) : DSimproc :=
drewritePost >>
userPostDSimprocs s
/--
Discharge procedure for the ground/symbolic evaluator.
-/
@ -382,6 +402,8 @@ def mkSEvalMethods : CoreM Methods := do
return {
pre := preSEval #[s]
post := postSEval #[s]
dpre := dpreDefault #[s]
dpost := dpostDefault #[s]
discharge? := dischargeGround
}
@ -525,6 +547,8 @@ abbrev Discharge := Expr → SimpM (Option Expr)
def mkMethods (s : SimprocsArray) (discharge? : Discharge) : Methods := {
pre := preDefault s
post := postDefault s
dpre := dpreDefault s
dpost := dpostDefault s
discharge? := discharge?
}

View file

@ -200,6 +200,18 @@ def SimprocEntry.try (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM
let s ← proc e
s.toStep.addExtraArgs extraArgs
/-- Similar to `try`, but only consider `DSimproc` case. That is, if `s.proc` is a `Simproc`, treat it as a `.continue`. -/
def SimprocEntry.tryD (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM DStep := do
let mut extraArgs := #[]
let mut e := e
for _ in [:numExtraArgs] do
extraArgs := extraArgs.push e.appArg!
e := e.appFn!
extraArgs := extraArgs.reverse
match s.proc with
| .inl _ => return .continue
| .inr proc => return (← proc e).addExtraArgs extraArgs
def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM Step := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
if candidates.isEmpty then
@ -237,6 +249,39 @@ def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Ex
else
return .continue
def dsimprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM DStep := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
if candidates.isEmpty then
let tag := if post then "post" else "pre"
trace[Debug.Meta.Tactic.simp] "no {tag}-simprocs found for {e}"
return .continue
else
let mut e := e
let mut found := false
for (simprocEntry, numExtraArgs) in candidates do
unless erased.contains simprocEntry.declName do
let s ← simprocEntry.tryD numExtraArgs e
match s with
| .visit eNew =>
trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}"
recordSimpTheorem (.decl simprocEntry.declName post)
return .visit eNew
| .done eNew =>
trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}"
recordSimpTheorem (.decl simprocEntry.declName post)
return .done eNew
| .continue (some eNew) =>
trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}"
recordSimpTheorem (.decl simprocEntry.declName post)
e := eNew
found := true
| .continue none =>
pure ()
if found then
return .continue (some e)
else
return .continue
abbrev SimprocsArray := Array Simprocs
def SimprocsArray.add (ss : SimprocsArray) (declName : Name) (post : Bool) : CoreM SimprocsArray :=
@ -272,6 +317,22 @@ def simprocArrayCore (post : Bool) (ss : SimprocsArray) (e : Expr) : SimpM Step
else
return .continue
def dsimprocArrayCore (post : Bool) (ss : SimprocsArray) (e : Expr) : SimpM DStep := do
let mut found := false
let mut e := e
for s in ss do
match (← dsimprocCore (post := post) (if post then s.post else s.pre) s.erased e) with
| .visit eNew => return .visit eNew
| .done eNew => return .done eNew
| .continue none => pure ()
| .continue (some eNew) =>
e := eNew
found := true
if found then
return .continue (some e)
else
return .continue
register_builtin_option simprocs : Bool := {
defValue := true
group := "backward compatibility"
@ -286,6 +347,14 @@ def userPostSimprocs (s : SimprocsArray) : Simproc := fun e => do
unless simprocs.get (← getOptions) do return .continue
simprocArrayCore (post := true) s e
def userPreDSimprocs (s : SimprocsArray) : DSimproc := fun e => do
unless simprocs.get (← getOptions) do return .continue
dsimprocArrayCore (post := false) s e
def userPostDSimprocs (s : SimprocsArray) : DSimproc := fun e => do
unless simprocs.get (← getOptions) do return .continue
dsimprocArrayCore (post := true) s e
def mkSimprocExt (name : Name := by exact decl_name%) (ref? : Option (IO.Ref Simprocs)) : IO SimprocExtension :=
registerScopedEnvExtension {
name := name

View file

@ -185,6 +185,17 @@ def andThen (f g : Simproc) : Simproc := fun e => do
instance : AndThen Simproc where
andThen s₁ s₂ := andThen s₁ (s₂ ())
@[always_inline]
def dandThen (f g : DSimproc) : DSimproc := fun e => do
match (← f e) with
| .done eNew => return .done eNew
| .continue none => g e
| .continue (some eNew) => g eNew
| .visit eNew => return .visit eNew
instance : AndThen DSimproc where
andThen s₁ s₂ := dandThen s₁ (s₂ ())
/--
`Simproc` .olean entry.
-/
@ -217,6 +228,8 @@ structure Simprocs where
structure Methods where
pre : Simproc := fun _ => return .continue
post : Simproc := fun e => return .done { expr := e }
dpre : DSimproc := fun _ => return .continue
dpost : DSimproc := fun e => return .done e
discharge? : Expr → SimpM (Option Expr) := fun _ => return none
deriving Inhabited
@ -543,6 +556,13 @@ def Step.addExtraArgs (s : Step) (extraArgs : Array Expr) : MetaM Step := do
| .continue none => return .continue none
| .continue (some r) => return .continue (← r.addExtraArgs extraArgs)
def DStep.addExtraArgs (s : DStep) (extraArgs : Array Expr) : DStep :=
match s with
| .visit eNew => .visit (mkAppN eNew extraArgs)
| .done eNew => .done (mkAppN eNew extraArgs)
| .continue none => .continue none
| .continue (some eNew) => .continue (mkAppN eNew extraArgs)
end Simp
export Simp (SimpM Simprocs)

View file

@ -0,0 +1,30 @@
open BitVec
variable (write : (n : Nat) → BitVec 64 → BitVec (n * 8) → Type → Type)
theorem write_simplify_test_0 (a x y : BitVec 64)
(h : ((8 * 8) + 8 * 8) = 2 * ((8 * 8) / 8) * 8) :
write (2 * ((8 * 8) / 8)) a (BitVec.cast h (zeroExtend (8 * 8) x ++ (zeroExtend (8 * 8) y))) s
=
write 16 a (x ++ y) s := by
simp only [zeroExtend_eq, BitVec.cast_eq]
/--
warning: declaration uses 'sorry'
---
info: write : (n : Nat) → BitVec 64 → BitVec (n * 8) → Type → Type
s aux : Type
a x y : BitVec 64
h : 128 = 128
⊢ write 16 a (x ++ y) s = aux
-/
#guard_msgs in
example (a x y : BitVec 64)
(h : ((8 * 8) + 8 * 8) = 2 * ((8 * 8) / 8) * 8) :
write (2 * ((8 * 8) / 8)) a (BitVec.cast h (zeroExtend (8 * 8) x ++ (zeroExtend (8 * 8) y))) s
=
aux := by
simp
dsimp at h
trace_state
sorry

View file

@ -0,0 +1,64 @@
import Lean
def foo (x : Nat) := x + 1
open Lean Meta Simp
dsimproc reduceFoo (foo _) := fun e => do
let_expr foo a ← e | return .continue
let some n ← getNatValue? a | return .continue
return .done (toExpr (n+1))
example (h : 3 = x) : foo 2 = x := by
simp
guard_target =ₛ 3 = x
assumption
example (h : 3 = x) : foo 2 = x := by
fail_if_success simp [-reduceFoo]
fail_if_success simp only
simp only [reduceFoo]
guard_target =ₛ 3 = x
assumption
def bla (x : Nat) := 2*x
dsimproc_decl reduceBla (bla _) := fun e => do
let_expr bla a ← e | return .continue
let some n ← getNatValue? a | return .continue
return .done (toExpr (2*n))
example (h : 6 = x) : bla 3 = x := by
fail_if_success simp
fail_if_success simp only
simp [bla]
guard_target =ₛ 6 = x
assumption
example (h : 6 = x) : bla 3 = x := by
fail_if_success simp
fail_if_success simp only
simp [bla]
guard_target =ₛ 6 = x
assumption
example (h : 6 = x) : bla 3 = x := by
simp only [bla]
guard_target =ₛ 2*3 = x
assumption
example (h : 6 = x) : bla 3 = x := by
simp only [bla, Nat.reduceMul]
guard_target =ₛ 6 = x
assumption
attribute [simp] reduceBla
example (h : 6 = x) : bla 3 = x := by
simp
guard_target =ₛ 6 = x
assumption
example (h : 5 = x) : 2 + 3 = x := by
dsimp
guard_target =ₛ 5 = x
assumption