refactor: remove workaround

We don't need to keep passing `discharge?` method around anymore.
This commit is contained in:
Leonardo de Moura 2024-01-29 17:24:05 -08:00 committed by Scott Morrison
parent 01750e2139
commit 01469bdbd6
5 changed files with 38 additions and 38 deletions

View file

@ -44,7 +44,7 @@ private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
def simpMatchWF? (mvarId : MVarId) : MetaM (Option MVarId) :=
mvarId.withContext do
let target ← instantiateMVars (← mvarId.getType)
let (targetNew, _) ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre })
let (targetNew, _) ← Simp.main target (← Split.getSimpMatchContext) (methods := { pre, discharge? := SplitIf.discharge? })
let mvarIdNew ← applySimpResultToTarget mvarId target targetNew
if mvarId != mvarIdNew then return some mvarIdNew else return none
where
@ -54,7 +54,7 @@ where
-- First try to reduce matcher
match (← reduceRecMatcher? e) with
| some e' => return Simp.Step.done { expr := e' }
| none => Simp.simpMatchCore app.matcherName SplitIf.discharge? e
| none => Simp.simpMatchCore app.matcherName e
/--
Given a goal of the form `|- f.{us} a_1 ... a_n b_1 ... b_m = ...`, return `(us, #[a_1, ..., a_n])`

View file

@ -403,12 +403,12 @@ private partial def dsimpImpl (e : Expr) : SimpM Expr := do
unless cfg.dsimp do
return e
let pre (e : Expr) : SimpM TransformStep := do
if let Step.visit r ← rewritePre (discharge? := fun _ => pure none) (rflOnly := true) e then
if let Step.visit r ← rewritePre (rflOnly := true) e then
if r.expr != e then
return .visit r.expr
return .continue
let post (e : Expr) : SimpM TransformStep := do
if let Step.visit r ← rewritePost (discharge? := fun _ => pure none) (rflOnly := true) e then
if let Step.visit r ← rewritePost (rflOnly := true) e then
if r.expr != e then
return .visit r.expr
let mut eNew ← reduce e
@ -504,7 +504,7 @@ def trySimpCongrTheorem? (c : SimpCongrTheorem) (e : Expr) : SimpM (Option Resul
unless modified do
trace[Meta.Tactic.simp.congr] "{c.theoremName} not modified"
return none
unless (← synthesizeArgs (.decl c.theoremName) xs bis discharge?) do
unless (← synthesizeArgs (.decl c.theoremName) xs bis) do
trace[Meta.Tactic.simp.congr] "{c.theoremName} synthesizeArgs failed"
return none
let eNew ← instantiateMVars rhs

View file

@ -14,7 +14,7 @@ import Lean.Meta.Tactic.Simp.Simproc
namespace Lean.Meta.Simp
def synthesizeArgs (thmId : Origin) (xs : Array Expr) (bis : Array BinderInfo) (discharge? : Expr → SimpM (Option Expr)) : SimpM Bool := do
def synthesizeArgs (thmId : Origin) (xs : Array Expr) (bis : Array BinderInfo) : SimpM Bool := do
for x in xs, bi in bis do
let type ← inferType x
-- Note that the binderInfo may be misleading here:
@ -59,10 +59,10 @@ where
trace[Meta.Tactic.simp.discharge] "{← ppOrigin thmId}, failed to synthesize instance{indentExpr type}"
return false
private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do
private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) : SimpM (Option Result) := do
let rec go (e : Expr) : SimpM (Option Result) := do
if (← isDefEq lhs e) then
unless (← synthesizeArgs thm.origin xs bis discharge?) do
unless (← synthesizeArgs thm.origin xs bis) do
return none
let proof? ← if thm.rfl then
pure none
@ -109,36 +109,36 @@ private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInf
return none
r.addExtraArgs extraArgs
def tryTheoremWithExtraArgs? (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) :=
def tryTheoremWithExtraArgs? (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) : SimpM (Option Result) :=
withNewMCtxDepth do
let val ← thm.getValue
let type ← inferType val
let (xs, bis, type) ← forallMetaTelescopeReducing type
let type ← whnf (← instantiateMVars type)
let lhs := type.appFn!.appArg!
tryTheoremCore lhs xs bis val type e thm numExtraArgs discharge?
tryTheoremCore lhs xs bis val type e thm numExtraArgs
def tryTheorem? (e : Expr) (thm : SimpTheorem) (discharge? : Expr → SimpM (Option Expr)) : SimpM (Option Result) := do
def tryTheorem? (e : Expr) (thm : SimpTheorem) : SimpM (Option Result) := do
withNewMCtxDepth do
let val ← thm.getValue
let type ← inferType val
let (xs, bis, type) ← forallMetaTelescopeReducing type
let type ← whnf (← instantiateMVars type)
let lhs := type.appFn!.appArg!
match (← tryTheoremCore lhs xs bis val type e thm 0 discharge?) with
match (← tryTheoremCore lhs xs bis val type e thm 0) with
| some result => return some result
| none =>
let lhsNumArgs := lhs.getAppNumArgs
let eNumArgs := e.getAppNumArgs
if eNumArgs > lhsNumArgs then
tryTheoremCore lhs xs bis val type e thm (eNumArgs - lhsNumArgs) discharge?
tryTheoremCore lhs xs bis val type e thm (eNumArgs - lhsNumArgs)
else
return none
/--
Remark: the parameter tag is used for creating trace messages. It is irrelevant otherwise.
-/
def rewrite? (e : Expr) (s : SimpTheoremTree) (erased : PHashSet Origin) (discharge? : Expr → SimpM (Option Expr)) (tag : String) (rflOnly : Bool) : SimpM (Option Result) := do
def rewrite? (e : Expr) (s : SimpTheoremTree) (erased : PHashSet Origin) (tag : String) (rflOnly : Bool) : SimpM (Option Result) := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
@ -147,7 +147,7 @@ def rewrite? (e : Expr) (s : SimpTheoremTree) (erased : PHashSet Origin) (discha
let candidates := candidates.insertionSort fun e₁ e₂ => e₁.1.priority > e₂.1.priority
for (thm, numExtraArgs) in candidates do
unless inErasedSet thm || (rflOnly && !thm.rfl) do
if let some result ← tryTheoremWithExtraArgs? e thm numExtraArgs discharge? then
if let some result ← tryTheoremWithExtraArgs? e thm numExtraArgs then
trace[Debug.Meta.Tactic.simp] "rewrite result {e} => {result.expr}"
return some result
return none
@ -239,15 +239,15 @@ def simpMatchDiscrs? (info : MatcherInfo) (e : Expr) : SimpM (Option Result) :=
r ← mkCongrFun r arg
return some r
def simpMatchCore (matcherName : Name) (discharge? : Expr → SimpM (Option Expr)) (e : Expr) : SimpM Step := do
def simpMatchCore (matcherName : Name) (e : Expr) : SimpM Step := do
for matchEq in (← Match.getEquationsFor matcherName).eqnNames do
-- Try lemma
match (← withReducible <| Simp.tryTheorem? e { origin := .decl matchEq, proof := mkConst matchEq, rfl := (← isRflTheorem matchEq) } discharge?) with
match (← withReducible <| Simp.tryTheorem? e { origin := .decl matchEq, proof := mkConst matchEq, rfl := (← isRflTheorem matchEq) }) with
| none => pure ()
| some r => return .visit r
return .continue
def simpMatch (discharge? : Expr → SimpM (Option Expr)) : Simproc := fun e => do
def simpMatch : Simproc := fun e => do
unless (← getConfig).iota do
return .continue
if let some e ← reduceRecMatcher? e then
@ -258,24 +258,24 @@ def simpMatch (discharge? : Expr → SimpM (Option Expr)) : Simproc := fun e =>
| return .continue
if let some r ← simpMatchDiscrs? info e then
return .visit r
simpMatchCore declName discharge? e
simpMatchCore declName e
def rewritePre (discharge? : Expr → SimpM (Option Expr)) (rflOnly := false) : Simproc := fun e => do
def rewritePre (rflOnly := false) : Simproc := fun e => do
for thms in (← getContext).simpTheorems do
if let some r ← rewrite? e thms.pre thms.erased discharge? (tag := "pre") (rflOnly := rflOnly) then
if let some r ← rewrite? e thms.pre thms.erased (tag := "pre") (rflOnly := rflOnly) then
return .visit r
return .continue
def rewritePost (discharge? : Expr → SimpM (Option Expr)) (rflOnly := false) : Simproc := fun e => do
def rewritePost (rflOnly := false) : Simproc := fun e => do
for thms in (← getContext).simpTheorems do
if let some r ← rewrite? e thms.post thms.erased discharge? (tag := "post") (rflOnly := rflOnly) then
if let some r ← rewrite? e thms.post thms.erased (tag := "post") (rflOnly := rflOnly) then
return .visit r
return .continue
/--
Try to unfold ground term when `Context.unfoldGround := true`.
-/
def simpGround (discharge? : Expr → SimpM (Option Expr)) : Simproc := fun e => do
def simpGround : Simproc := fun e => do
-- Ground term unfolding is disabled.
unless (← getContext).unfoldGround do return .continue
-- `e` is not a ground term.
@ -292,7 +292,7 @@ def simpGround (discharge? : Expr → SimpM (Option Expr)) : Simproc := fun e =>
-- `declName` has equation theorems associated with it.
for eqn in eqns do
-- TODO: cache SimpTheorem to avoid calls to `isRflTheorem`
if let some result ← Simp.tryTheorem? e { origin := .decl eqn, proof := mkConst eqn, rfl := (← isRflTheorem eqn) } discharge? then
if let some result ← Simp.tryTheorem? e { origin := .decl eqn, proof := mkConst eqn, rfl := (← isRflTheorem eqn) } then
trace[Meta.Tactic.simp.ground] "unfolded, {e} => {result.expr}"
return .visit result
return .continue
@ -308,16 +308,16 @@ def simpGround (discharge? : Expr → SimpM (Option Expr)) : Simproc := fun e =>
trace[Meta.Tactic.simp.ground] "delta, {e} => {eNew}"
return .visit { expr := eNew }
partial def preDefault (s : SimprocsArray) (discharge? : Expr → SimpM (Option Expr)) : Simproc :=
rewritePre discharge? >>
simpMatch discharge? >>
partial def preDefault (s : SimprocsArray) : Simproc :=
rewritePre >>
simpMatch >>
userPreSimprocs s >>
simpUsingDecide
def postDefault (s : SimprocsArray) (discharge? : Expr → SimpM (Option Expr)) : Simproc :=
rewritePost discharge? >>
def postDefault (s : SimprocsArray) : Simproc :=
rewritePost >>
userPostSimprocs s >>
simpGround discharge? >>
simpGround >>
simpArith >>
simpCtorEq >>
simpUsingDecide
@ -411,8 +411,8 @@ def dischargeDefault? (e : Expr) : SimpM (Option Expr) := do
abbrev Discharge := Expr → SimpM (Option Expr)
def mkMethods (s : SimprocsArray) (discharge? : Discharge) : Methods := {
pre := preDefault s discharge?
post := postDefault s discharge?
pre := preDefault s
post := postDefault s
discharge? := discharge?
}

View file

@ -19,7 +19,7 @@ def getSimpMatchContext : MetaM Simp.Context :=
}
def simpMatch (e : Expr) : MetaM Simp.Result := do
(·.1) <$> Simp.main e (← getSimpMatchContext) (methods := { pre })
(·.1) <$> Simp.main e (← getSimpMatchContext) (methods := { pre, discharge? := SplitIf.discharge? })
where
pre (e : Expr) : SimpM Simp.Step := do
unless (← isMatcherApp e) do
@ -28,7 +28,7 @@ where
-- First try to reduce matcher
match (← reduceRecMatcher? e) with
| some e' => return Simp.Step.done { expr := e' }
| none => Simp.simpMatchCore matcherDeclName SplitIf.discharge? e
| none => Simp.simpMatchCore matcherDeclName e
def simpMatchTarget (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
let target ← instantiateMVars (← mvarId.getType)
@ -36,7 +36,7 @@ def simpMatchTarget (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
applySimpResultToTarget mvarId target r
private def simpMatchCore (matchDeclName : Name) (matchEqDeclName : Name) (e : Expr) : MetaM Simp.Result := do
(·.1) <$> Simp.main e (← getSimpMatchContext) (methods := { pre })
(·.1) <$> Simp.main e (← getSimpMatchContext) (methods := { pre, discharge? := SplitIf.discharge? })
where
pre (e : Expr) : SimpM Simp.Step := do
if e.isAppOf matchDeclName then
@ -50,7 +50,7 @@ where
proof := mkConst matchEqDeclName
rfl := (← isRflTheorem matchEqDeclName)
}
match (← withReducible <| Simp.tryTheorem? e simpTheorem SplitIf.discharge?) with
match (← withReducible <| Simp.tryTheorem? e simpTheorem) with
| none => return .continue
| some r => return .done r
else

View file

@ -22,7 +22,7 @@ def unfold (e : Expr) (declName : Name) : MetaM Simp.Result := do
return { expr := (← deltaExpand e (· == declName)) }
where
pre (unfoldThm : Name) (e : Expr) : SimpM Simp.Step := do
match (← withReducible <| Simp.tryTheorem? e { origin := .decl unfoldThm, proof := mkConst unfoldThm, rfl := (← isRflTheorem unfoldThm) } (fun _ => return none)) with
match (← withReducible <| Simp.tryTheorem? e { origin := .decl unfoldThm, proof := mkConst unfoldThm, rfl := (← isRflTheorem unfoldThm) }) with
| none => pure ()
| some r => match (← reduceMatcher? r.expr) with
| .reduced e' => return .done { r with expr := e' }