diff --git a/src/Lean/Meta/Tactic/Split.lean b/src/Lean/Meta/Tactic/Split.lean index 7cfb476c0d..dbbb928387 100644 --- a/src/Lean/Meta/Tactic/Split.lean +++ b/src/Lean/Meta/Tactic/Split.lean @@ -9,7 +9,43 @@ import Lean.Meta.Tactic.Generalize namespace Lean.Meta namespace Split -private def genMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Array FVarId × MVarId) := do +private def getSimpMatchContext : MetaM Simp.Context := + return { + simpLemmas := {} + congrLemmas := (← getCongrLemmas) + config.zeta := false + config.beta := false + config.eta := false + config.iota := false + config.proj := false + config.decide := false + } + +private def simpMatchPre (matchDeclName : Name) (matchEqDeclName : Name) (e : Expr) : SimpM Simp.Step := do + if e.isAppOf matchDeclName then + -- First try to reduce matcher + match (← reduceRecMatcher? e) with + | some e' => return Simp.Step.done { expr := e' } + | none => + -- Try lemma + match (← Simp.tryLemma? e { proof := mkConst matchEqDeclName, name? := matchEqDeclName } SplitIf.discharge?) with + | none => return Simp.Step.visit { expr := e } + | some r => return Simp.Step.done r + else + return Simp.Step.visit { expr := e } + +private def simpMatch (matchDeclName : Name) (matchEqDeclName : Name) (e : Expr) : MetaM Simp.Result := do + Simp.main e (← getSimpMatchContext) (methods := { pre := simpMatchPre matchDeclName matchEqDeclName }) + +private def simpMatchTarget (mvarId : MVarId) (matchDeclName : Name) (matchEqDeclName : Name) : MetaM MVarId := do + withMVarContext mvarId do + let target ← instantiateMVars (← getMVarType mvarId) + let r ← simpMatch matchDeclName matchEqDeclName target + match r.proof? with + | some proof => replaceTargetEq mvarId r.expr proof + | none => replaceTargetDefEq mvarId r.expr + +private def generalizeMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Array FVarId × MVarId) := do if discrs.all (·.isFVar) then return (discrs.map (·.fvarId!), mvarId) else @@ -19,35 +55,30 @@ private def genMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Arra def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do let some app ← matchMatcherApp? e | throwError "match application expected" - let (discrFVarIds, mvarId) ← genMatchDiscrs mvarId app.discrs + let (discrFVarIds, mvarId) ← generalizeMatchDiscrs mvarId app.discrs trace[Meta.debug] "split [1]:\n{MessageData.ofGoal mvarId}" let (reverted, mvarId) ← revert mvarId discrFVarIds - trace[Meta.debug] "split [2]:\n{MessageData.ofGoal mvarId}" let (discrFVarIds, mvarId) ← introNP mvarId discrFVarIds.size let numExtra := reverted.size - discrFVarIds.size - trace[Meta.debug] "split [3]:\n{MessageData.ofGoal mvarId}" let discrs := discrFVarIds.map mkFVar let matchEqns ← Match.getEquationsFor app.matcherName withMVarContext mvarId do let motive ← mkLambdaFVars discrs (← getMVarType mvarId) - trace[Meta.debug] "split [4]: {motive}" -- Fix universe let mut us := app.matcherLevels if let some uElimPos := app.uElimPos? then -- Set universe elimination level to zero (Prop). us := us.set! uElimPos levelZero - trace[Meta.debug] "us: {us}" let splitter := mkAppN (mkConst matchEqns.splitterName us.toList) app.params let splitter := mkAppN (mkApp splitter motive) discrs - trace[Meta.debug] "splitter: {splitter}" check splitter -- TODO let mvarIds ← apply mvarId splitter let (_, mvarIds) ← mvarIds.foldlM (init := (0, [])) fun (i, mvarIds) mvarId => do - trace[Meta.debug] "split [5]:\n{MessageData.ofGoal mvarId}" let numParams := matchEqns.splitterAltNumParams[i] - -- TODO: use equation lemmas to reduce `match`-expressions let (_, mvarId) ← introN mvarId numParams let (_, mvarId) ← introNP mvarId numExtra + trace[Meta.debug] "before simpMatch:\n{MessageData.ofGoal mvarId}" + let mvarId ← simpMatchTarget mvarId app.matcherName matchEqns.eqnNames[i] return (i+1, mvarId::mvarIds) return mvarIds.reverse diff --git a/tests/lean/run/matchEqs.lean b/tests/lean/run/matchEqs.lean index dfdd55c5e3..2f63824204 100644 --- a/tests/lean/run/matchEqs.lean +++ b/tests/lean/run/matchEqs.lean @@ -22,14 +22,7 @@ test% f.match_1 theorem ex (x : List Nat) : f x > 0 := by simp [f] - induction x using f.match_1.splitter - next => simp [f.match_1.eq_1] - next x => simp [f.match_1.eq_2] - next x h1 h2 => - rw [f.match_1.eq_3] - . decide - . exact h1 - . exact h2 + split <;> decide test% Std.RBNode.balance1.match_1 #check @Std.RBNode.balance1.match_1.splitter diff --git a/tests/lean/run/split1.lean b/tests/lean/run/split1.lean index 0159f857ad..d311257345 100644 --- a/tests/lean/run/split1.lean +++ b/tests/lean/run/split1.lean @@ -7,13 +7,9 @@ def f (xs : List Nat) : Nat := theorem ex1 (xs : List Nat) (hr : xs.reverse = xs) (ys : Nat) : ys > 0 → f xs > 0 := by simp [f] split - next => intro hys; simp - next => intro hys; simp; apply Nat.zero_lt_succ - next zs n₁ n₂ => - intro hys - rw [f.match_1.eq_3] - anyGoals assumption - decide + next => intro hys; decide + next => intro hys; apply Nat.zero_lt_succ + next zs n₁ n₂ => intro hys; decide def g (xs : List Nat) : Nat := match xs with @@ -23,7 +19,5 @@ def g (xs : List Nat) : Nat := theorem ex2 (xs : List Nat) : g xs > 0 := by simp [g] split - . simp; apply Nat.zero_lt_succ - . rw [g.match_1.eq_2] - . decide - . assumption + next a b c d e => apply Nat.zero_lt_succ + next h => decide