From 45d3b85d5ac7db16f3fc3c2fbcb9370a9dc6abf1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 18 Aug 2021 18:22:07 -0700 Subject: [PATCH] refactor: cleanup `MatchEqs` and simplify `SplitIf` --- src/Lean/Meta/Match/MatchEqs.lean | 18 ++-- src/Lean/Meta/Tactic/SplitIf.lean | 148 ++++++++++-------------------- tests/playground/matchEqs.lean | 4 +- 3 files changed, 61 insertions(+), 109 deletions(-) diff --git a/src/Lean/Meta/Match/MatchEqs.lean b/src/Lean/Meta/Match/MatchEqs.lean index 2a5f0b0e2b..d01254c54c 100644 --- a/src/Lean/Meta/Match/MatchEqs.lean +++ b/src/Lean/Meta/Match/MatchEqs.lean @@ -105,17 +105,23 @@ where proveLoop (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do let mvarId ← modifyTargetEqLHS mvarId whnfCore + trace[Meta.debug] "proveLoop\n{MessageData.ofGoal mvarId}" (applyRefl mvarId) <|> (contradiction mvarId) <|> - (commitIfNoEx do - let s::ss ← splitIfGoal mvarId | failed - if ss.isEmpty && s.mvarId == mvarId then failed - (s::ss).forM fun s => proveLoop s.mvarId (depth + 1)) + (do let mvarId' ← simpIfTarget mvarId (useDecide := true) + trace[Meta.debug] "simpIfTarget\n{MessageData.ofGoal mvarId'}" + if mvarId' == mvarId then failed + proveLoop mvarId' (depth+1)) <|> - (do - trace[Meta.debug] "TODO\n{← ppGoal mvarId}" + (do if let some (s₁, s₂) ← splitIfTarget? mvarId then + proveLoop s₁.mvarId (depth+1) + proveLoop s₂.mvarId (depth+1) + else + failed) + <|> + (do trace[Meta.debug] "TODO\n{← ppGoal mvarId}" -- TODO admit mvarId) diff --git a/src/Lean/Meta/Tactic/SplitIf.lean b/src/Lean/Meta/Tactic/SplitIf.lean index e423c8dfa5..43a4f40962 100644 --- a/src/Lean/Meta/Tactic/SplitIf.lean +++ b/src/Lean/Meta/Tactic/SplitIf.lean @@ -28,12 +28,26 @@ builtin_initialize ext : LazyInitExtension MetaM Simp.Context ← config.decide := false } +/-- + Default `Simp.Context` for `simpIf` methods. It contains all congruence lemmas, but + just the rewriting rules for reducing `if` expressions. -/ def getSimpContext : MetaM Simp.Context := ext.get -def discharge? : Simp.Discharge := fun prop => do +/-- + Default `discharge?` function for `simpIf` methods. + It only uses hypotheses from the local context. It is effective + after a case-split. -/ +def discharge? (useDecide := false) : Simp.Discharge := fun prop => do let prop ← instantiateMVars prop - trace[Meta.splitIf] "discharge? {prop}, {prop.notNot?}" + trace[Meta.Tactic.splitIf] "discharge? {prop}, {prop.notNot?}" + if useDecide then + let prop ← instantiateMVars prop + if !prop.hasFVar && !prop.hasMVar then + let d ← mkDecide prop + let r ← withDefault <| whnf d + if r.isConstOf ``true then + return some <| mkApp3 (mkConst ``of_decide_eq_true) prop d.appArg! (← mkEqRefl (mkConst ``true)) (← getLCtx).findDeclRevM? fun localDecl => do if localDecl.isAuxDecl then return none @@ -47,115 +61,47 @@ def discharge? : Simp.Discharge := fun prop => do else return none +/-- Return the condition of an `if` expression to case split. -/ partial def findIfToSplit? (e : Expr) : Option Expr := if let some iteApp := e.find? fun e => !e.hasLooseBVars && (e.isAppOfArity ``ite 5 || e.isAppOfArity ``dite 5) then let cond := iteApp.getArg! 1 5 + -- Try to find a nested `if` in `cond` findIfToSplit? cond |>.getD cond else none -def simpIfTarget (mvarId : MVarId) : MetaM MVarId := do - trace[Meta.splitIf] "before simpIfTarget\n{MessageData.ofGoal mvarId}" - if let some mvarId ← simpTarget mvarId (← getSimpContext) discharge? then - trace[Meta.splitIf] "after simpIfTarget\n{MessageData.ofGoal mvarId}" - return mvarId +def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := do + if let some cond := findIfToSplit? e then + let hName ← match hName? with + | none => mkFreshUserName `h + | some hName => pure hName + trace[Meta.Tactic.splitIf] "splitting on {cond}" + return some (← byCases mvarId cond hName) else - unreachable! - -def simpIfLocalDecl (mvarId : MVarId) (fvarId : FVarId) : MetaM (FVarId × MVarId) := do - if let some result ← simpLocalDecl mvarId fvarId (← getSimpContext) discharge? then - return result - else - unreachable! - -open Std - -structure TargetSubgoal where - mvarId : MVarId - condFVarIds : PArray FVarId := {} - -structure State where - hNames : List Name - -abbrev M := StateRefT State MetaM - -private def getNextName : M Name := do - match (← get).hNames with - | [] => mkFreshUserName `h - | n::ns => - modify fun s => { s with hNames := ns } - return n - -private partial def splitIfTargetCore (mvarId : MVarId) (condFVarIds : PArray FVarId) : M (List TargetSubgoal) := do - if let some cond := findIfToSplit? (← getMVarType mvarId) then - trace[Meta.splitIf] "splitting on {cond}" - let (s₁, s₂) ← byCases mvarId cond (← getNextName) - let (progress₁, ss₁) ← recurse s₁ - let (progress₂, ss₂) ← recurse s₂ - if progress₁ || progress₂ then - return ss₁ ++ ss₂ - else - return [{ mvarId, condFVarIds }] - else - return [{ mvarId, condFVarIds }] -where - recurse (s : ByCasesSubgoal) : M (Bool × List TargetSubgoal) := do - let mvarId ← simpIfTarget s.mvarId - if mvarId == s.mvarId then - return (false, [{ mvarId, condFVarIds }]) - else - return (true, (← splitIfTargetCore mvarId (condFVarIds.push s.fvarId))) - -structure LocalDeclSubgoal where - mvarId : MVarId - fvarId : FVarId - condFVarIds : PArray FVarId := {} - -private partial def splitIfLocalDeclCore (mvarId : MVarId) (fvarId : FVarId) (condFVarIds : PArray FVarId) : M (List LocalDeclSubgoal) := - withMVarContext mvarId do - if let some cond := findIfToSplit? (← getLocalDecl fvarId).type then - let (s₁, s₂) ← byCases mvarId cond (← getNextName) - let (progress₁, ss₁) ← recurse s₁ - let (progress₂, ss₂) ← recurse s₂ - if progress₁ || progress₂ then - return ss₁ ++ ss₂ - else - return [{ mvarId, fvarId, condFVarIds }] - else - return [{ mvarId, fvarId, condFVarIds }] -where - recurse (s : ByCasesSubgoal) : M (Bool × List LocalDeclSubgoal) := do - let (fvarId', mvarId) ← simpIfLocalDecl s.mvarId fvarId - if mvarId == s.mvarId then - return (false, [{ mvarId, fvarId, condFVarIds }]) - else - return (true, (← splitIfLocalDeclCore mvarId fvarId' (condFVarIds.push s.fvarId))) - -structure Subgoal where - mvarId : MVarId - fvarIds : PArray FVarId := {} - condFVarIds : PArray FVarId := {} - -def splitIfGoalCore (mvarId : MVarId) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[]) : M (List Subgoal) := do - let mut ss ← goTarget - for fvarId in fvarIdsToSimp do - ss ← goLocalDecl ss fvarId - return ss -where - goTarget : M (List Subgoal) := do - let mvarId ← simpIfTarget mvarId - let ss ← splitIfTargetCore mvarId {} - ss.mapM fun s => { s with : Subgoal } - goLocalDecl (ss : List Subgoal) (fvarId : FVarId) : M (List Subgoal) := do - let sss ← ss.mapM fun s => do - let (fvarId, mvarId) ← simpIfLocalDecl s.mvarId fvarId - let ss' ← splitIfLocalDeclCore mvarId fvarId s.condFVarIds - ss'.mapM fun s' => { mvarId := s'.mvarId, fvarIds := s.fvarIds.push s'.fvarId, condFVarIds := s'.condFVarIds : Subgoal } - return sss.join + return none end SplitIf -def splitIfGoal (mvarId : MVarId) (simplifyTarget : Bool := true) (fvarIdsToSimp : Array FVarId := #[]) (hNames : List Name := []) : MetaM (List SplitIf.Subgoal) := do - SplitIf.splitIfGoalCore mvarId simplifyTarget fvarIdsToSimp |>.run' { hNames } +open SplitIf + +def simpIfTarget (mvarId : MVarId) (useDecide := false) : MetaM MVarId := do + let mut ctx ← getSimpContext + if let some mvarId' ← simpTarget mvarId ctx (discharge? useDecide) then + return mvarId' + else + unreachable! + +def splitIfTarget? (mvarId : MVarId) (hName? : Option Name := none) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := commitWhenSome? do + if let some (s₁, s₂) ← splitIfAt? mvarId (← getMVarType mvarId) hName? then + let mvarId₁ ← simpIfTarget s₁.mvarId + let mvarId₂ ← simpIfTarget s₂.mvarId + if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then + return none + else + return some ({ s₁ with mvarId := mvarId₁ }, { s₂ with mvarId := mvarId₂ }) + else + return none + +builtin_initialize registerTraceClass `Meta.Tactic.splitIf end Lean.Meta diff --git a/tests/playground/matchEqs.lean b/tests/playground/matchEqs.lean index 6afa8ba4a2..c143e1e768 100644 --- a/tests/playground/matchEqs.lean +++ b/tests/playground/matchEqs.lean @@ -23,8 +23,8 @@ def h (x y : Nat) : Nat := | 10000, _ => 0 | 10001, _ => 5 | _, 20000 => 4 --- | x+1, _ => 3 --- | Nat.zero, y+1 => 44 + | x+1, _ => 3 + | Nat.zero, y+1 => 44 | _, _ => 1 -- theorem ex1 : h 10000 1 = 0 :=