feat: use dsimprocs at dsimp
This commit is contained in:
parent
63b068a77c
commit
acdb0054d5
7 changed files with 225 additions and 22 deletions
|
|
@ -434,7 +434,7 @@ where
|
|||
if tactic.simp.trace.get (← getOptions) then
|
||||
traceSimpCall stx usedSimps
|
||||
|
||||
def dsimpLocation (ctx : Simp.Context) (loc : Location) : TacticM Unit := do
|
||||
def dsimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (loc : Location) : TacticM Unit := do
|
||||
match loc with
|
||||
| Location.targets hyps simplifyTarget =>
|
||||
withMainContext do
|
||||
|
|
@ -446,7 +446,7 @@ def dsimpLocation (ctx : Simp.Context) (loc : Location) : TacticM Unit := do
|
|||
where
|
||||
go (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) : TacticM Unit := do
|
||||
let mvarId ← getMainGoal
|
||||
let (result?, usedSimps) ← dsimpGoal mvarId ctx (simplifyTarget := simplifyTarget) (fvarIdsToSimp := fvarIdsToSimp)
|
||||
let (result?, usedSimps) ← dsimpGoal mvarId ctx simprocs (simplifyTarget := simplifyTarget) (fvarIdsToSimp := fvarIdsToSimp)
|
||||
match result? with
|
||||
| none => replaceMainGoal []
|
||||
| some mvarId => replaceMainGoal [mvarId]
|
||||
|
|
@ -454,8 +454,8 @@ where
|
|||
mvarId.withContext <| traceSimpCall (← getRef) usedSimps
|
||||
|
||||
@[builtin_tactic Lean.Parser.Tactic.dsimp] def evalDSimp : Tactic := fun stx => do
|
||||
let { ctx, .. } ← withMainContext <| mkSimpContext stx (eraseLocal := false) (kind := .dsimp)
|
||||
dsimpLocation ctx (expandOptLocation stx[5])
|
||||
let { ctx, simprocs, .. } ← withMainContext <| mkSimpContext stx (eraseLocal := false) (kind := .dsimp)
|
||||
dsimpLocation ctx simprocs (expandOptLocation stx[5])
|
||||
|
||||
end Lean.Elab.Tactic
|
||||
|
||||
|
|
|
|||
|
|
@ -400,24 +400,20 @@ def simpLet (e : Expr) : SimpM Result := do
|
|||
let h ← mkLambdaFVars #[x] h
|
||||
return { expr := e', proof? := some (← mkLetBodyCongr v' h) }
|
||||
|
||||
private def dsimpReduce : DSimproc := fun e => do
|
||||
let mut eNew ← reduce e
|
||||
if eNew.isFVar then
|
||||
eNew ← reduceFVar (← getConfig) (← getSimpTheorems) eNew
|
||||
if eNew != e then return .visit eNew else return .done e
|
||||
|
||||
@[export lean_dsimp]
|
||||
private partial def dsimpImpl (e : Expr) : SimpM Expr := do
|
||||
let cfg ← getConfig
|
||||
unless cfg.dsimp do
|
||||
return e
|
||||
let pre (e : Expr) : SimpM TransformStep := do
|
||||
if let Step.visit r ← rewritePre (rflOnly := true) e then
|
||||
if r.expr != e then
|
||||
return .visit r.expr
|
||||
return .continue
|
||||
let post (e : Expr) : SimpM TransformStep := do
|
||||
if let Step.visit r ← rewritePost (rflOnly := true) e then
|
||||
if r.expr != e then
|
||||
return .visit r.expr
|
||||
let mut eNew ← reduce e
|
||||
if eNew.isFVar then
|
||||
eNew ← reduceFVar cfg (← getSimpTheorems) eNew
|
||||
if eNew != e then return .visit eNew else return .done e
|
||||
let m ← getMethods
|
||||
let pre := m.dpre
|
||||
let post := m.dpost >> dsimpReduce
|
||||
transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post)
|
||||
|
||||
def visitFn (e : Expr) : SimpM Result := do
|
||||
|
|
@ -649,9 +645,9 @@ def simp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (disc
|
|||
| none => Simp.main e ctx usedSimps (methods := Simp.mkDefaultMethodsCore simprocs)
|
||||
| some d => Simp.main e ctx usedSimps (methods := Simp.mkMethods simprocs d)
|
||||
|
||||
def dsimp (e : Expr) (ctx : Simp.Context)
|
||||
def dsimp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[])
|
||||
(usedSimps : UsedSimps := {}) : MetaM (Expr × UsedSimps) := do profileitM Exception "dsimp" (← getOptions) do
|
||||
Simp.dsimpMain e ctx usedSimps (methods := Simp.mkDefaultMethodsCore {})
|
||||
Simp.dsimpMain e ctx usedSimps (methods := Simp.mkDefaultMethodsCore simprocs )
|
||||
|
||||
/-- See `simpTarget`. This method assumes `mvarId` is not assigned, and we are already using `mvarId`s local context. -/
|
||||
def simpTargetCore (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (discharge? : Option Simp.Discharge := none)
|
||||
|
|
@ -800,7 +796,7 @@ def simpTargetStar (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsAr
|
|||
else
|
||||
return (TacticResultCNM.modified mvarId', usedSimps')
|
||||
|
||||
def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[])
|
||||
def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[])
|
||||
(usedSimps : UsedSimps := {}) : MetaM (Option MVarId × UsedSimps) := do
|
||||
mvarId.withContext do
|
||||
mvarId.checkNotAssigned `simp
|
||||
|
|
@ -808,7 +804,7 @@ def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := t
|
|||
let mut usedSimps : UsedSimps := usedSimps
|
||||
for fvarId in fvarIdsToSimp do
|
||||
let type ← instantiateMVars (← fvarId.getType)
|
||||
let (typeNew, usedSimps') ← dsimp type ctx
|
||||
let (typeNew, usedSimps') ← dsimp type ctx simprocs
|
||||
usedSimps := usedSimps'
|
||||
if typeNew.isFalse then
|
||||
mvarIdNew.assign (← mkFalseElim (← mvarIdNew.getType) (mkFVar fvarId))
|
||||
|
|
@ -817,7 +813,7 @@ def dsimpGoal (mvarId : MVarId) (ctx : Simp.Context) (simplifyTarget : Bool := t
|
|||
mvarIdNew ← mvarIdNew.replaceLocalDeclDefEq fvarId typeNew
|
||||
if simplifyTarget then
|
||||
let target ← mvarIdNew.getType
|
||||
let (targetNew, usedSimps') ← dsimp target ctx usedSimps
|
||||
let (targetNew, usedSimps') ← dsimp target ctx simprocs usedSimps
|
||||
usedSimps := usedSimps'
|
||||
if targetNew.isTrue then
|
||||
mvarIdNew.assign (mkConst ``True.intro)
|
||||
|
|
|
|||
|
|
@ -319,6 +319,26 @@ def rewritePost (rflOnly := false) : Simproc := fun e => do
|
|||
return .visit r
|
||||
return .continue
|
||||
|
||||
def drewritePre : DSimproc := fun e => do
|
||||
for thms in (← getContext).simpTheorems do
|
||||
if let some r ← rewrite? e thms.pre thms.erased (tag := "pre") (rflOnly := true) then
|
||||
return .visit r.expr
|
||||
return .continue
|
||||
|
||||
def drewritePost : DSimproc := fun e => do
|
||||
for thms in (← getContext).simpTheorems do
|
||||
if let some r ← rewrite? e thms.post thms.erased (tag := "post") (rflOnly := true) then
|
||||
return .visit r.expr
|
||||
return .continue
|
||||
|
||||
def dpreDefault (s : SimprocsArray) : DSimproc :=
|
||||
drewritePre >>
|
||||
userPreDSimprocs s
|
||||
|
||||
def dpostDefault (s : SimprocsArray) : DSimproc :=
|
||||
drewritePost >>
|
||||
userPostDSimprocs s
|
||||
|
||||
/--
|
||||
Discharge procedure for the ground/symbolic evaluator.
|
||||
-/
|
||||
|
|
@ -382,6 +402,8 @@ def mkSEvalMethods : CoreM Methods := do
|
|||
return {
|
||||
pre := preSEval #[s]
|
||||
post := postSEval #[s]
|
||||
dpre := dpreDefault #[s]
|
||||
dpost := dpostDefault #[s]
|
||||
discharge? := dischargeGround
|
||||
}
|
||||
|
||||
|
|
@ -525,6 +547,8 @@ abbrev Discharge := Expr → SimpM (Option Expr)
|
|||
def mkMethods (s : SimprocsArray) (discharge? : Discharge) : Methods := {
|
||||
pre := preDefault s
|
||||
post := postDefault s
|
||||
dpre := dpreDefault s
|
||||
dpost := dpostDefault s
|
||||
discharge? := discharge?
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -200,6 +200,18 @@ def SimprocEntry.try (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM
|
|||
let s ← proc e
|
||||
s.toStep.addExtraArgs extraArgs
|
||||
|
||||
/-- Similar to `try`, but only consider `DSimproc` case. That is, if `s.proc` is a `Simproc`, treat it as a `.continue`. -/
|
||||
def SimprocEntry.tryD (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM DStep := do
|
||||
let mut extraArgs := #[]
|
||||
let mut e := e
|
||||
for _ in [:numExtraArgs] do
|
||||
extraArgs := extraArgs.push e.appArg!
|
||||
e := e.appFn!
|
||||
extraArgs := extraArgs.reverse
|
||||
match s.proc with
|
||||
| .inl _ => return .continue
|
||||
| .inr proc => return (← proc e).addExtraArgs extraArgs
|
||||
|
||||
def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM Step := do
|
||||
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
|
||||
if candidates.isEmpty then
|
||||
|
|
@ -237,6 +249,39 @@ def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Ex
|
|||
else
|
||||
return .continue
|
||||
|
||||
def dsimprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM DStep := do
|
||||
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
|
||||
if candidates.isEmpty then
|
||||
let tag := if post then "post" else "pre"
|
||||
trace[Debug.Meta.Tactic.simp] "no {tag}-simprocs found for {e}"
|
||||
return .continue
|
||||
else
|
||||
let mut e := e
|
||||
let mut found := false
|
||||
for (simprocEntry, numExtraArgs) in candidates do
|
||||
unless erased.contains simprocEntry.declName do
|
||||
let s ← simprocEntry.tryD numExtraArgs e
|
||||
match s with
|
||||
| .visit eNew =>
|
||||
trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}"
|
||||
recordSimpTheorem (.decl simprocEntry.declName post)
|
||||
return .visit eNew
|
||||
| .done eNew =>
|
||||
trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}"
|
||||
recordSimpTheorem (.decl simprocEntry.declName post)
|
||||
return .done eNew
|
||||
| .continue (some eNew) =>
|
||||
trace[Debug.Meta.Tactic.simp] "simproc result {e} => {eNew}"
|
||||
recordSimpTheorem (.decl simprocEntry.declName post)
|
||||
e := eNew
|
||||
found := true
|
||||
| .continue none =>
|
||||
pure ()
|
||||
if found then
|
||||
return .continue (some e)
|
||||
else
|
||||
return .continue
|
||||
|
||||
abbrev SimprocsArray := Array Simprocs
|
||||
|
||||
def SimprocsArray.add (ss : SimprocsArray) (declName : Name) (post : Bool) : CoreM SimprocsArray :=
|
||||
|
|
@ -272,6 +317,22 @@ def simprocArrayCore (post : Bool) (ss : SimprocsArray) (e : Expr) : SimpM Step
|
|||
else
|
||||
return .continue
|
||||
|
||||
def dsimprocArrayCore (post : Bool) (ss : SimprocsArray) (e : Expr) : SimpM DStep := do
|
||||
let mut found := false
|
||||
let mut e := e
|
||||
for s in ss do
|
||||
match (← dsimprocCore (post := post) (if post then s.post else s.pre) s.erased e) with
|
||||
| .visit eNew => return .visit eNew
|
||||
| .done eNew => return .done eNew
|
||||
| .continue none => pure ()
|
||||
| .continue (some eNew) =>
|
||||
e := eNew
|
||||
found := true
|
||||
if found then
|
||||
return .continue (some e)
|
||||
else
|
||||
return .continue
|
||||
|
||||
register_builtin_option simprocs : Bool := {
|
||||
defValue := true
|
||||
group := "backward compatibility"
|
||||
|
|
@ -286,6 +347,14 @@ def userPostSimprocs (s : SimprocsArray) : Simproc := fun e => do
|
|||
unless simprocs.get (← getOptions) do return .continue
|
||||
simprocArrayCore (post := true) s e
|
||||
|
||||
def userPreDSimprocs (s : SimprocsArray) : DSimproc := fun e => do
|
||||
unless simprocs.get (← getOptions) do return .continue
|
||||
dsimprocArrayCore (post := false) s e
|
||||
|
||||
def userPostDSimprocs (s : SimprocsArray) : DSimproc := fun e => do
|
||||
unless simprocs.get (← getOptions) do return .continue
|
||||
dsimprocArrayCore (post := true) s e
|
||||
|
||||
def mkSimprocExt (name : Name := by exact decl_name%) (ref? : Option (IO.Ref Simprocs)) : IO SimprocExtension :=
|
||||
registerScopedEnvExtension {
|
||||
name := name
|
||||
|
|
|
|||
|
|
@ -185,6 +185,17 @@ def andThen (f g : Simproc) : Simproc := fun e => do
|
|||
instance : AndThen Simproc where
|
||||
andThen s₁ s₂ := andThen s₁ (s₂ ())
|
||||
|
||||
@[always_inline]
|
||||
def dandThen (f g : DSimproc) : DSimproc := fun e => do
|
||||
match (← f e) with
|
||||
| .done eNew => return .done eNew
|
||||
| .continue none => g e
|
||||
| .continue (some eNew) => g eNew
|
||||
| .visit eNew => return .visit eNew
|
||||
|
||||
instance : AndThen DSimproc where
|
||||
andThen s₁ s₂ := dandThen s₁ (s₂ ())
|
||||
|
||||
/--
|
||||
`Simproc` .olean entry.
|
||||
-/
|
||||
|
|
@ -217,6 +228,8 @@ structure Simprocs where
|
|||
structure Methods where
|
||||
pre : Simproc := fun _ => return .continue
|
||||
post : Simproc := fun e => return .done { expr := e }
|
||||
dpre : DSimproc := fun _ => return .continue
|
||||
dpost : DSimproc := fun e => return .done e
|
||||
discharge? : Expr → SimpM (Option Expr) := fun _ => return none
|
||||
deriving Inhabited
|
||||
|
||||
|
|
@ -543,6 +556,13 @@ def Step.addExtraArgs (s : Step) (extraArgs : Array Expr) : MetaM Step := do
|
|||
| .continue none => return .continue none
|
||||
| .continue (some r) => return .continue (← r.addExtraArgs extraArgs)
|
||||
|
||||
def DStep.addExtraArgs (s : DStep) (extraArgs : Array Expr) : DStep :=
|
||||
match s with
|
||||
| .visit eNew => .visit (mkAppN eNew extraArgs)
|
||||
| .done eNew => .done (mkAppN eNew extraArgs)
|
||||
| .continue none => .continue none
|
||||
| .continue (some eNew) => .continue (mkAppN eNew extraArgs)
|
||||
|
||||
end Simp
|
||||
|
||||
export Simp (SimpM Simprocs)
|
||||
|
|
|
|||
30
tests/lean/run/dsimp_bv_simproc.lean
Normal file
30
tests/lean/run/dsimp_bv_simproc.lean
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
open BitVec
|
||||
|
||||
variable (write : (n : Nat) → BitVec 64 → BitVec (n * 8) → Type → Type)
|
||||
|
||||
theorem write_simplify_test_0 (a x y : BitVec 64)
|
||||
(h : ((8 * 8) + 8 * 8) = 2 * ((8 * 8) / 8) * 8) :
|
||||
write (2 * ((8 * 8) / 8)) a (BitVec.cast h (zeroExtend (8 * 8) x ++ (zeroExtend (8 * 8) y))) s
|
||||
=
|
||||
write 16 a (x ++ y) s := by
|
||||
simp only [zeroExtend_eq, BitVec.cast_eq]
|
||||
|
||||
/--
|
||||
warning: declaration uses 'sorry'
|
||||
---
|
||||
info: write : (n : Nat) → BitVec 64 → BitVec (n * 8) → Type → Type
|
||||
s aux : Type
|
||||
a x y : BitVec 64
|
||||
h : 128 = 128
|
||||
⊢ write 16 a (x ++ y) s = aux
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (a x y : BitVec 64)
|
||||
(h : ((8 * 8) + 8 * 8) = 2 * ((8 * 8) / 8) * 8) :
|
||||
write (2 * ((8 * 8) / 8)) a (BitVec.cast h (zeroExtend (8 * 8) x ++ (zeroExtend (8 * 8) y))) s
|
||||
=
|
||||
aux := by
|
||||
simp
|
||||
dsimp at h
|
||||
trace_state
|
||||
sorry
|
||||
64
tests/lean/run/dsimproc.lean
Normal file
64
tests/lean/run/dsimproc.lean
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import Lean
|
||||
|
||||
def foo (x : Nat) := x + 1
|
||||
|
||||
open Lean Meta Simp
|
||||
dsimproc reduceFoo (foo _) := fun e => do
|
||||
let_expr foo a ← e | return .continue
|
||||
let some n ← getNatValue? a | return .continue
|
||||
return .done (toExpr (n+1))
|
||||
|
||||
example (h : 3 = x) : foo 2 = x := by
|
||||
simp
|
||||
guard_target =ₛ 3 = x
|
||||
assumption
|
||||
|
||||
example (h : 3 = x) : foo 2 = x := by
|
||||
fail_if_success simp [-reduceFoo]
|
||||
fail_if_success simp only
|
||||
simp only [reduceFoo]
|
||||
guard_target =ₛ 3 = x
|
||||
assumption
|
||||
|
||||
def bla (x : Nat) := 2*x
|
||||
|
||||
dsimproc_decl reduceBla (bla _) := fun e => do
|
||||
let_expr bla a ← e | return .continue
|
||||
let some n ← getNatValue? a | return .continue
|
||||
return .done (toExpr (2*n))
|
||||
|
||||
example (h : 6 = x) : bla 3 = x := by
|
||||
fail_if_success simp
|
||||
fail_if_success simp only
|
||||
simp [bla]
|
||||
guard_target =ₛ 6 = x
|
||||
assumption
|
||||
|
||||
example (h : 6 = x) : bla 3 = x := by
|
||||
fail_if_success simp
|
||||
fail_if_success simp only
|
||||
simp [bla]
|
||||
guard_target =ₛ 6 = x
|
||||
assumption
|
||||
|
||||
example (h : 6 = x) : bla 3 = x := by
|
||||
simp only [bla]
|
||||
guard_target =ₛ 2*3 = x
|
||||
assumption
|
||||
|
||||
example (h : 6 = x) : bla 3 = x := by
|
||||
simp only [bla, Nat.reduceMul]
|
||||
guard_target =ₛ 6 = x
|
||||
assumption
|
||||
|
||||
attribute [simp] reduceBla
|
||||
|
||||
example (h : 6 = x) : bla 3 = x := by
|
||||
simp
|
||||
guard_target =ₛ 6 = x
|
||||
assumption
|
||||
|
||||
example (h : 5 = x) : 2 + 3 = x := by
|
||||
dsimp
|
||||
guard_target =ₛ 5 = x
|
||||
assumption
|
||||
Loading…
Add table
Reference in a new issue