refactor: cleanup MatchEqs and simplify SplitIf

This commit is contained in:
Leonardo de Moura 2021-08-18 18:22:07 -07:00
parent aa177dacc3
commit 45d3b85d5a
3 changed files with 61 additions and 109 deletions

View file

@ -105,17 +105,23 @@ where
proveLoop (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do
let mvarId ← modifyTargetEqLHS mvarId whnfCore
trace[Meta.debug] "proveLoop\n{MessageData.ofGoal mvarId}"
(applyRefl mvarId)
<|>
(contradiction mvarId)
<|>
(commitIfNoEx do
let s::ss ← splitIfGoal mvarId | failed
if ss.isEmpty && s.mvarId == mvarId then failed
(s::ss).forM fun s => proveLoop s.mvarId (depth + 1))
(do let mvarId' ← simpIfTarget mvarId (useDecide := true)
trace[Meta.debug] "simpIfTarget\n{MessageData.ofGoal mvarId'}"
if mvarId' == mvarId then failed
proveLoop mvarId' (depth+1))
<|>
(do
trace[Meta.debug] "TODO\n{← ppGoal mvarId}"
(do if let some (s₁, s₂) ← splitIfTarget? mvarId then
proveLoop s₁.mvarId (depth+1)
proveLoop s₂.mvarId (depth+1)
else
failed)
<|>
(do trace[Meta.debug] "TODO\n{← ppGoal mvarId}"
-- TODO
admit mvarId)

View file

@ -28,12 +28,26 @@ builtin_initialize ext : LazyInitExtension MetaM Simp.Context ←
config.decide := false
}
/--
Default `Simp.Context` for `simpIf` methods. It contains all congruence lemmas, but
just the rewriting rules for reducing `if` expressions. -/
def getSimpContext : MetaM Simp.Context :=
ext.get
def discharge? : Simp.Discharge := fun prop => do
/--
Default `discharge?` function for `simpIf` methods.
It only uses hypotheses from the local context. It is effective
after a case-split. -/
def discharge? (useDecide := false) : Simp.Discharge := fun prop => do
let prop ← instantiateMVars prop
trace[Meta.splitIf] "discharge? {prop}, {prop.notNot?}"
trace[Meta.Tactic.splitIf] "discharge? {prop}, {prop.notNot?}"
if useDecide then
let prop ← instantiateMVars prop
if !prop.hasFVar && !prop.hasMVar then
let d ← mkDecide prop
let r ← withDefault <| whnf d
if r.isConstOf ``true then
return some <| mkApp3 (mkConst ``of_decide_eq_true) prop d.appArg! (← mkEqRefl (mkConst ``true))
(← getLCtx).findDeclRevM? fun localDecl => do
if localDecl.isAuxDecl then
return none
@ -47,115 +61,47 @@ def discharge? : Simp.Discharge := fun prop => do
else
return none
/-- Return the condition of an `if` expression to case split. -/
partial def findIfToSplit? (e : Expr) : Option Expr :=
if let some iteApp := e.find? fun e => !e.hasLooseBVars && (e.isAppOfArity ``ite 5 || e.isAppOfArity ``dite 5) then
let cond := iteApp.getArg! 1 5
-- Try to find a nested `if` in `cond`
findIfToSplit? cond |>.getD cond
else
none
def simpIfTarget (mvarId : MVarId) : MetaM MVarId := do
trace[Meta.splitIf] "before simpIfTarget\n{MessageData.ofGoal mvarId}"
if let some mvarId ← simpTarget mvarId (← getSimpContext) discharge? then
trace[Meta.splitIf] "after simpIfTarget\n{MessageData.ofGoal mvarId}"
return mvarId
def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := do
if let some cond := findIfToSplit? e then
let hName ← match hName? with
| none => mkFreshUserName `h
| some hName => pure hName
trace[Meta.Tactic.splitIf] "splitting on {cond}"
return some (← byCases mvarId cond hName)
else
unreachable!
def simpIfLocalDecl (mvarId : MVarId) (fvarId : FVarId) : MetaM (FVarId × MVarId) := do
if let some result ← simpLocalDecl mvarId fvarId (← getSimpContext) discharge? then
return result
else
unreachable!
open Std
structure TargetSubgoal where
mvarId : MVarId
condFVarIds : PArray FVarId := {}
structure State where
hNames : List Name
abbrev M := StateRefT State MetaM
private def getNextName : M Name := do
match (← get).hNames with
| [] => mkFreshUserName `h
| n::ns =>
modify fun s => { s with hNames := ns }
return n
private partial def splitIfTargetCore (mvarId : MVarId) (condFVarIds : PArray FVarId) : M (List TargetSubgoal) := do
if let some cond := findIfToSplit? (← getMVarType mvarId) then
trace[Meta.splitIf] "splitting on {cond}"
let (s₁, s₂) ← byCases mvarId cond (← getNextName)
let (progress₁, ss₁) ← recurse s₁
let (progress₂, ss₂) ← recurse s₂
if progress₁ || progress₂ then
return ss₁ ++ ss₂
else
return [{ mvarId, condFVarIds }]
else
return [{ mvarId, condFVarIds }]
where
recurse (s : ByCasesSubgoal) : M (Bool × List TargetSubgoal) := do
let mvarId ← simpIfTarget s.mvarId
if mvarId == s.mvarId then
return (false, [{ mvarId, condFVarIds }])
else
return (true, (← splitIfTargetCore mvarId (condFVarIds.push s.fvarId)))
structure LocalDeclSubgoal where
mvarId : MVarId
fvarId : FVarId
condFVarIds : PArray FVarId := {}
private partial def splitIfLocalDeclCore (mvarId : MVarId) (fvarId : FVarId) (condFVarIds : PArray FVarId) : M (List LocalDeclSubgoal) :=
withMVarContext mvarId do
if let some cond := findIfToSplit? (← getLocalDecl fvarId).type then
let (s₁, s₂) ← byCases mvarId cond (← getNextName)
let (progress₁, ss₁) ← recurse s₁
let (progress₂, ss₂) ← recurse s₂
if progress₁ || progress₂ then
return ss₁ ++ ss₂
else
return [{ mvarId, fvarId, condFVarIds }]
else
return [{ mvarId, fvarId, condFVarIds }]
where
recurse (s : ByCasesSubgoal) : M (Bool × List LocalDeclSubgoal) := do
let (fvarId', mvarId) ← simpIfLocalDecl s.mvarId fvarId
if mvarId == s.mvarId then
return (false, [{ mvarId, fvarId, condFVarIds }])
else
return (true, (← splitIfLocalDeclCore mvarId fvarId' (condFVarIds.push s.fvarId)))
structure Subgoal where
mvarId : MVarId
fvarIds : PArray FVarId := {}
condFVarIds : PArray FVarId := {}
def splitIfGoalCore (mvarId : MVarId) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[]) : M (List Subgoal) := do
let mut ss ← goTarget
for fvarId in fvarIdsToSimp do
ss ← goLocalDecl ss fvarId
return ss
where
goTarget : M (List Subgoal) := do
let mvarId ← simpIfTarget mvarId
let ss ← splitIfTargetCore mvarId {}
ss.mapM fun s => { s with : Subgoal }
goLocalDecl (ss : List Subgoal) (fvarId : FVarId) : M (List Subgoal) := do
let sss ← ss.mapM fun s => do
let (fvarId, mvarId) ← simpIfLocalDecl s.mvarId fvarId
let ss' ← splitIfLocalDeclCore mvarId fvarId s.condFVarIds
ss'.mapM fun s' => { mvarId := s'.mvarId, fvarIds := s.fvarIds.push s'.fvarId, condFVarIds := s'.condFVarIds : Subgoal }
return sss.join
return none
end SplitIf
def splitIfGoal (mvarId : MVarId) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[]) (hNames : List Name := []) : MetaM (List SplitIf.Subgoal) := do
SplitIf.splitIfGoalCore mvarId simplifyTarget fvarIdsToSimp |>.run' { hNames }
open SplitIf
def simpIfTarget (mvarId : MVarId) (useDecide := false) : MetaM MVarId := do
let mut ctx ← getSimpContext
if let some mvarId' ← simpTarget mvarId ctx (discharge? useDecide) then
return mvarId'
else
unreachable!
def splitIfTarget? (mvarId : MVarId) (hName? : Option Name := none) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := commitWhenSome? do
if let some (s₁, s₂) ← splitIfAt? mvarId (← getMVarType mvarId) hName? then
let mvarId₁ ← simpIfTarget s₁.mvarId
let mvarId₂ ← simpIfTarget s₂.mvarId
if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then
return none
else
return some ({ s₁ with mvarId := mvarId₁ }, { s₂ with mvarId := mvarId₂ })
else
return none
builtin_initialize registerTraceClass `Meta.Tactic.splitIf
end Lean.Meta

View file

@ -23,8 +23,8 @@ def h (x y : Nat) : Nat :=
| 10000, _ => 0
| 10001, _ => 5
| _, 20000 => 4
-- | x+1, _ => 3
-- | Nat.zero, y+1 => 44
| x+1, _ => 3
| Nat.zero, y+1 => 44
| _, _ => 1
-- theorem ex1 : h 10000 1 = 0 :=