refactor: cleanup MatchEqs and simplify SplitIf
This commit is contained in:
parent
aa177dacc3
commit
45d3b85d5a
3 changed files with 61 additions and 109 deletions
|
|
@ -105,17 +105,23 @@ where
|
|||
|
||||
proveLoop (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do
|
||||
let mvarId ← modifyTargetEqLHS mvarId whnfCore
|
||||
trace[Meta.debug] "proveLoop\n{MessageData.ofGoal mvarId}"
|
||||
(applyRefl mvarId)
|
||||
<|>
|
||||
(contradiction mvarId)
|
||||
<|>
|
||||
(commitIfNoEx do
|
||||
let s::ss ← splitIfGoal mvarId | failed
|
||||
if ss.isEmpty && s.mvarId == mvarId then failed
|
||||
(s::ss).forM fun s => proveLoop s.mvarId (depth + 1))
|
||||
(do let mvarId' ← simpIfTarget mvarId (useDecide := true)
|
||||
trace[Meta.debug] "simpIfTarget\n{MessageData.ofGoal mvarId'}"
|
||||
if mvarId' == mvarId then failed
|
||||
proveLoop mvarId' (depth+1))
|
||||
<|>
|
||||
(do
|
||||
trace[Meta.debug] "TODO\n{← ppGoal mvarId}"
|
||||
(do if let some (s₁, s₂) ← splitIfTarget? mvarId then
|
||||
proveLoop s₁.mvarId (depth+1)
|
||||
proveLoop s₂.mvarId (depth+1)
|
||||
else
|
||||
failed)
|
||||
<|>
|
||||
(do trace[Meta.debug] "TODO\n{← ppGoal mvarId}"
|
||||
-- TODO
|
||||
admit mvarId)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,12 +28,26 @@ builtin_initialize ext : LazyInitExtension MetaM Simp.Context ←
|
|||
config.decide := false
|
||||
}
|
||||
|
||||
/--
|
||||
Default `Simp.Context` for `simpIf` methods. It contains all congruence lemmas, but
|
||||
just the rewriting rules for reducing `if` expressions. -/
|
||||
def getSimpContext : MetaM Simp.Context :=
|
||||
ext.get
|
||||
|
||||
def discharge? : Simp.Discharge := fun prop => do
|
||||
/--
|
||||
Default `discharge?` function for `simpIf` methods.
|
||||
It only uses hypotheses from the local context. It is effective
|
||||
after a case-split. -/
|
||||
def discharge? (useDecide := false) : Simp.Discharge := fun prop => do
|
||||
let prop ← instantiateMVars prop
|
||||
trace[Meta.splitIf] "discharge? {prop}, {prop.notNot?}"
|
||||
trace[Meta.Tactic.splitIf] "discharge? {prop}, {prop.notNot?}"
|
||||
if useDecide then
|
||||
let prop ← instantiateMVars prop
|
||||
if !prop.hasFVar && !prop.hasMVar then
|
||||
let d ← mkDecide prop
|
||||
let r ← withDefault <| whnf d
|
||||
if r.isConstOf ``true then
|
||||
return some <| mkApp3 (mkConst ``of_decide_eq_true) prop d.appArg! (← mkEqRefl (mkConst ``true))
|
||||
(← getLCtx).findDeclRevM? fun localDecl => do
|
||||
if localDecl.isAuxDecl then
|
||||
return none
|
||||
|
|
@ -47,115 +61,47 @@ def discharge? : Simp.Discharge := fun prop => do
|
|||
else
|
||||
return none
|
||||
|
||||
/-- Return the condition of an `if` expression to case split. -/
|
||||
partial def findIfToSplit? (e : Expr) : Option Expr :=
|
||||
if let some iteApp := e.find? fun e => !e.hasLooseBVars && (e.isAppOfArity ``ite 5 || e.isAppOfArity ``dite 5) then
|
||||
let cond := iteApp.getArg! 1 5
|
||||
-- Try to find a nested `if` in `cond`
|
||||
findIfToSplit? cond |>.getD cond
|
||||
else
|
||||
none
|
||||
|
||||
def simpIfTarget (mvarId : MVarId) : MetaM MVarId := do
|
||||
trace[Meta.splitIf] "before simpIfTarget\n{MessageData.ofGoal mvarId}"
|
||||
if let some mvarId ← simpTarget mvarId (← getSimpContext) discharge? then
|
||||
trace[Meta.splitIf] "after simpIfTarget\n{MessageData.ofGoal mvarId}"
|
||||
return mvarId
|
||||
def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := do
|
||||
if let some cond := findIfToSplit? e then
|
||||
let hName ← match hName? with
|
||||
| none => mkFreshUserName `h
|
||||
| some hName => pure hName
|
||||
trace[Meta.Tactic.splitIf] "splitting on {cond}"
|
||||
return some (← byCases mvarId cond hName)
|
||||
else
|
||||
unreachable!
|
||||
|
||||
def simpIfLocalDecl (mvarId : MVarId) (fvarId : FVarId) : MetaM (FVarId × MVarId) := do
|
||||
if let some result ← simpLocalDecl mvarId fvarId (← getSimpContext) discharge? then
|
||||
return result
|
||||
else
|
||||
unreachable!
|
||||
|
||||
open Std
|
||||
|
||||
structure TargetSubgoal where
|
||||
mvarId : MVarId
|
||||
condFVarIds : PArray FVarId := {}
|
||||
|
||||
structure State where
|
||||
hNames : List Name
|
||||
|
||||
abbrev M := StateRefT State MetaM
|
||||
|
||||
private def getNextName : M Name := do
|
||||
match (← get).hNames with
|
||||
| [] => mkFreshUserName `h
|
||||
| n::ns =>
|
||||
modify fun s => { s with hNames := ns }
|
||||
return n
|
||||
|
||||
private partial def splitIfTargetCore (mvarId : MVarId) (condFVarIds : PArray FVarId) : M (List TargetSubgoal) := do
|
||||
if let some cond := findIfToSplit? (← getMVarType mvarId) then
|
||||
trace[Meta.splitIf] "splitting on {cond}"
|
||||
let (s₁, s₂) ← byCases mvarId cond (← getNextName)
|
||||
let (progress₁, ss₁) ← recurse s₁
|
||||
let (progress₂, ss₂) ← recurse s₂
|
||||
if progress₁ || progress₂ then
|
||||
return ss₁ ++ ss₂
|
||||
else
|
||||
return [{ mvarId, condFVarIds }]
|
||||
else
|
||||
return [{ mvarId, condFVarIds }]
|
||||
where
|
||||
recurse (s : ByCasesSubgoal) : M (Bool × List TargetSubgoal) := do
|
||||
let mvarId ← simpIfTarget s.mvarId
|
||||
if mvarId == s.mvarId then
|
||||
return (false, [{ mvarId, condFVarIds }])
|
||||
else
|
||||
return (true, (← splitIfTargetCore mvarId (condFVarIds.push s.fvarId)))
|
||||
|
||||
structure LocalDeclSubgoal where
|
||||
mvarId : MVarId
|
||||
fvarId : FVarId
|
||||
condFVarIds : PArray FVarId := {}
|
||||
|
||||
private partial def splitIfLocalDeclCore (mvarId : MVarId) (fvarId : FVarId) (condFVarIds : PArray FVarId) : M (List LocalDeclSubgoal) :=
|
||||
withMVarContext mvarId do
|
||||
if let some cond := findIfToSplit? (← getLocalDecl fvarId).type then
|
||||
let (s₁, s₂) ← byCases mvarId cond (← getNextName)
|
||||
let (progress₁, ss₁) ← recurse s₁
|
||||
let (progress₂, ss₂) ← recurse s₂
|
||||
if progress₁ || progress₂ then
|
||||
return ss₁ ++ ss₂
|
||||
else
|
||||
return [{ mvarId, fvarId, condFVarIds }]
|
||||
else
|
||||
return [{ mvarId, fvarId, condFVarIds }]
|
||||
where
|
||||
recurse (s : ByCasesSubgoal) : M (Bool × List LocalDeclSubgoal) := do
|
||||
let (fvarId', mvarId) ← simpIfLocalDecl s.mvarId fvarId
|
||||
if mvarId == s.mvarId then
|
||||
return (false, [{ mvarId, fvarId, condFVarIds }])
|
||||
else
|
||||
return (true, (← splitIfLocalDeclCore mvarId fvarId' (condFVarIds.push s.fvarId)))
|
||||
|
||||
structure Subgoal where
|
||||
mvarId : MVarId
|
||||
fvarIds : PArray FVarId := {}
|
||||
condFVarIds : PArray FVarId := {}
|
||||
|
||||
def splitIfGoalCore (mvarId : MVarId) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[]) : M (List Subgoal) := do
|
||||
let mut ss ← goTarget
|
||||
for fvarId in fvarIdsToSimp do
|
||||
ss ← goLocalDecl ss fvarId
|
||||
return ss
|
||||
where
|
||||
goTarget : M (List Subgoal) := do
|
||||
let mvarId ← simpIfTarget mvarId
|
||||
let ss ← splitIfTargetCore mvarId {}
|
||||
ss.mapM fun s => { s with : Subgoal }
|
||||
goLocalDecl (ss : List Subgoal) (fvarId : FVarId) : M (List Subgoal) := do
|
||||
let sss ← ss.mapM fun s => do
|
||||
let (fvarId, mvarId) ← simpIfLocalDecl s.mvarId fvarId
|
||||
let ss' ← splitIfLocalDeclCore mvarId fvarId s.condFVarIds
|
||||
ss'.mapM fun s' => { mvarId := s'.mvarId, fvarIds := s.fvarIds.push s'.fvarId, condFVarIds := s'.condFVarIds : Subgoal }
|
||||
return sss.join
|
||||
return none
|
||||
|
||||
end SplitIf
|
||||
|
||||
def splitIfGoal (mvarId : MVarId) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[]) (hNames : List Name := []) : MetaM (List SplitIf.Subgoal) := do
|
||||
SplitIf.splitIfGoalCore mvarId simplifyTarget fvarIdsToSimp |>.run' { hNames }
|
||||
open SplitIf
|
||||
|
||||
def simpIfTarget (mvarId : MVarId) (useDecide := false) : MetaM MVarId := do
|
||||
let mut ctx ← getSimpContext
|
||||
if let some mvarId' ← simpTarget mvarId ctx (discharge? useDecide) then
|
||||
return mvarId'
|
||||
else
|
||||
unreachable!
|
||||
|
||||
def splitIfTarget? (mvarId : MVarId) (hName? : Option Name := none) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := commitWhenSome? do
|
||||
if let some (s₁, s₂) ← splitIfAt? mvarId (← getMVarType mvarId) hName? then
|
||||
let mvarId₁ ← simpIfTarget s₁.mvarId
|
||||
let mvarId₂ ← simpIfTarget s₂.mvarId
|
||||
if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then
|
||||
return none
|
||||
else
|
||||
return some ({ s₁ with mvarId := mvarId₁ }, { s₂ with mvarId := mvarId₂ })
|
||||
else
|
||||
return none
|
||||
|
||||
builtin_initialize registerTraceClass `Meta.Tactic.splitIf
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ def h (x y : Nat) : Nat :=
|
|||
| 10000, _ => 0
|
||||
| 10001, _ => 5
|
||||
| _, 20000 => 4
|
||||
-- | x+1, _ => 3
|
||||
-- | Nat.zero, y+1 => 44
|
||||
| x+1, _ => 3
|
||||
| Nat.zero, y+1 => 44
|
||||
| _, _ => 1
|
||||
|
||||
-- theorem ex1 : h 10000 1 = 0 :=
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue