diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index 118c8b56b3..ac81efe8d4 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -17,6 +17,15 @@ public section namespace Lean.Meta.Match +register_builtin_option backwards.match.rowMajor : Bool := { + defValue := true + group := "bootstrap" + descr := "If true (the default), match compilation will split the discrimnants based \ + on position of the first constructor pattern in the first alternative. If false, \ + it splits them from left to right, which can lead to unnecessary code bloat." +} + + private def mkIncorrectNumberOfPatternsMsg [ToMessageData α] (discrepancyKind : String) (expected actual : Nat) (pats : List α) := let patternsMsg := MessageData.joinSep (pats.map toMessageData) ", " @@ -747,6 +756,46 @@ private def checkNextPatternTypes (p : Problem) : MetaM Unit := do unless (← isDefEq xType eType) do throwError "Type mismatch in pattern: Pattern{indentExpr e}\n{← mkHasTypeButIsExpectedMsg eType xType}" +private def List.moveToFront [Inhabited α] (as : List α) (i : Nat) : List α := + let rec loop : (as : List α) → (i : Nat) → α × List α + | [], _ => unreachable! + | a::as, 0 => (a, as) + | a::as, i+1 => + let (b, bs) := loop as i + (b, a::bs) + let (b, bs) := loop as i + b :: bs + +/-- Move variable `#i` to the beginning of the to-do list `p.vars`. -/ +private def moveToFront (p : Problem) (i : Nat) : Problem := + if i == 0 then + p + else if i < p.vars.length then + { p with + vars := List.moveToFront p.vars i + alts := p.alts.map fun alt => { alt with patterns := List.moveToFront alt.patterns i } + } + else + p + +def Pattern.isRefutable : Pattern → Bool + | .var _ => false + | .inaccessible _ => false + | .as _ p _ => p.isRefutable + | .arrayLit .. => true + | .ctor .. => true + | .val .. => true + +/-- +Returns the index of the first pattern in the first alternative that is refutable +(i.e. not a variable or inaccessible pattern). We want to handle these first +so that the generated code branches in the order suggested by the user's code. +-/ +private def firstRefutablePattern (p : Problem) : Option Nat := + match p.alts with + | alt:: _ => alt.patterns.findIdx? (·.isRefutable) + | _ => none + def isExFalsoTransition (p : Problem) : MetaM Bool := do if p.alts.isEmpty then withGoalOf p do @@ -778,6 +827,21 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do process p return + if backwards.match.rowMajor.get (← getOptions) then + match firstRefutablePattern p with + | some i => + if i > 0 then + traceStep ("move var to front") + process (moveToFront p i) + return + | none => + if 1 < p.alts.length then + traceStep ("drop all but first alt") + -- all patterns are irrefutable, we can drop all other alts + let p := { p with alts := p.alts.take 1 } + process p + return + if (← isNatValueTransition p) then traceStep ("nat value to constructor") process (← expandNatValuePattern p) diff --git a/src/Lean/Meta/Tactic/Cases.lean b/src/Lean/Meta/Tactic/Cases.lean index 2ca2f54608..011b004bf6 100644 --- a/src/Lean/Meta/Tactic/Cases.lean +++ b/src/Lean/Meta/Tactic/Cases.lean @@ -237,11 +237,11 @@ partial def unifyEqs? (numEqs : Nat) (mvarId : MVarId) (subst : FVarSubst) (case return none private def unifyCasesEqs (numEqs : Nat) (subgoals : Array CasesSubgoal) : MetaM (Array CasesSubgoal) := - subgoals.foldlM (init := #[]) fun subgoals s => do + subgoals.filterMapM fun s => do match (← unifyEqs? numEqs s.mvarId s.subst s.ctorName) with - | none => pure subgoals + | none => pure none | some (mvarId, subst) => - return subgoals.push { s with + return some { s with mvarId := mvarId, subst := subst, fields := s.fields.map (subst.apply ·) diff --git a/tests/lean/run/issue10749.lean b/tests/lean/run/issue10749.lean index 069f07ba75..b964a4d560 100644 --- a/tests/lean/run/issue10749.lean +++ b/tests/lean/run/issue10749.lean @@ -12,7 +12,7 @@ def test (a : List Nat) : Nat := /-- info: def test.match_1.{u_1} : (motive : List Nat → Sort u_1) → (a : List Nat) → ((x : List Nat) → motive x) → (Unit → motive []) → motive a := -fun motive a h_1 h_2 => List.casesOn a (h_1 []) fun head tail => h_1 (head :: tail) +fun motive a h_1 h_2 => h_1 a -/ #guard_msgs in #print test.match_1 @@ -31,7 +31,7 @@ info: def test2.match_1.{u_1} : (motive : List Nat → List Nat → Sort u_1) (tail : List Nat) → (head_1 : Nat) → (tail_1 : List Nat) → motive (head :: tail) (head_1 :: tail_1)) → motive a b := fun motive a b h_1 h_2 h_3 => - List.casesOn a (List.casesOn b (h_1 []) fun head tail => h_1 (head :: tail)) fun head tail => + List.casesOn a (h_1 b) fun head tail => List.casesOn b (h_2 (head :: tail)) fun head_1 tail_1 => h_3 head tail head_1 tail_1 -/ #guard_msgs in #print test2.match_1 @@ -51,8 +51,7 @@ info: def test3.match_1.{u_1} : (motive : List Nat → Bool → Sort u_1) → ((x : List Nat) → motive x true) → ((x : Bool) → motive [] x) → ((x : List Nat) → (x_1 : Bool) → motive x x_1) → motive a b := fun motive a b h_1 h_2 h_3 => - List.casesOn a (Bool.casesOn b (h_2 false) (h_1 [])) fun head tail => - Bool.casesOn b (h_3 (head :: tail) false) (h_1 (head :: tail)) + Bool.casesOn b (List.casesOn a (h_2 false) fun head tail => h_3 (head :: tail) false) (h_1 a) -/ #guard_msgs in #print test3.match_1 @@ -79,29 +78,33 @@ info: def test4.match_1.{u_1} : (motive : Bool → Bool → Bool → Bool → Bo ((x x_5 x_6 x_7 : Bool) → motive true x x_5 x_6 x_7) → ((x x_5 x_6 x_7 x_8 : Bool) → motive x x_5 x_6 x_7 x_8) → motive x x_1 x_2 x_3 x_4 := fun motive x x_1 x_2 x_3 x_4 h_1 h_2 h_3 h_4 h_5 h_6 => - Bool.casesOn x - (Bool.casesOn x_1 + Bool.casesOn x_4 + (Bool.casesOn x_3 (Bool.casesOn x_2 - (Bool.casesOn x_3 (Bool.casesOn x_4 (h_6 false false false false false) (h_1 false false false false)) - (Bool.casesOn x_4 (h_2 false false false false) (h_1 false false false true))) - (Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 false false false false) (h_1 false false true false)) - (Bool.casesOn x_4 (h_2 false false true false) (h_1 false false true true)))) - (Bool.casesOn x_2 - (Bool.casesOn x_3 (Bool.casesOn x_4 (h_4 false false false false) (h_1 false true false false)) - (Bool.casesOn x_4 (h_2 false true false false) (h_1 false true false true))) - (Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 false true false false) (h_1 false true true false)) - (Bool.casesOn x_4 (h_2 false true true false) (h_1 false true true true))))) - (Bool.casesOn x_1 - (Bool.casesOn x_2 - (Bool.casesOn x_3 (Bool.casesOn x_4 (h_5 false false false false) (h_1 true false false false)) - (Bool.casesOn x_4 (h_2 true false false false) (h_1 true false false true))) - (Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 true false false false) (h_1 true false true false)) - (Bool.casesOn x_4 (h_2 true false true false) (h_1 true false true true)))) - (Bool.casesOn x_2 - (Bool.casesOn x_3 (Bool.casesOn x_4 (h_4 true false false false) (h_1 true true false false)) - (Bool.casesOn x_4 (h_2 true true false false) (h_1 true true false true))) - (Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 true true false false) (h_1 true true true false)) - (Bool.casesOn x_4 (h_2 true true true false) (h_1 true true true true))))) + (Bool.casesOn x_1 (Bool.casesOn x (h_6 false false false false false) (h_5 false false false false)) + (h_4 x false false false)) + (h_3 x x_1 false false)) + (h_2 x x_1 x_2 false)) + (h_1 x x_1 x_2 x_3) -/ #guard_msgs in #print test4.match_1 + +-- Just testing the backwards compatibility option + +set_option match.ignoreUnusedAlts true in +set_option backwards.match.rowMajor false in +def testOld (a : List Nat) : Nat := + match a with + | _ => 3 + | [] => 4 + +-- Has unnecessary `casesOn` + +/-- +info: def testOld.match_1.{u_1} : (motive : List Nat → Sort u_1) → + (a : List Nat) → ((x : List Nat) → motive x) → (Unit → motive []) → motive a := +fun motive a h_1 h_2 => List.casesOn a (h_1 []) fun head tail => h_1 (head :: tail) +-/ +#guard_msgs in +#print testOld.match_1 diff --git a/tests/lean/run/issue10794.lean b/tests/lean/run/issue10794.lean index 91692f022e..361346c7bd 100644 --- a/tests/lean/run/issue10794.lean +++ b/tests/lean/run/issue10794.lean @@ -1,6 +1,8 @@ /-- -error: Dependent match elimination failed: Could not solve constraints - true ≋ false +error: Dependent elimination failed: Type mismatch when solving this alternative: it has type + motive false +but is expected to have type + motive b✝ -/ #guard_msgs in def test1 b := match b with @@ -8,8 +10,10 @@ def test1 b := match b with | true => 2 /-- -error: Dependent match elimination failed: Could not solve constraints - true ≋ false +error: Dependent elimination failed: Type mismatch when solving this alternative: it has type + motive false ⋯ +but is expected to have type + motive x✝¹ x✝ -/ #guard_msgs in def test2 : (b : Bool) → (h : b = false) → Nat diff --git a/tests/lean/run/match1.lean b/tests/lean/run/match1.lean index f14a4ae7c6..c3ae60a2b3 100644 --- a/tests/lean/run/match1.lean +++ b/tests/lean/run/match1.lean @@ -137,9 +137,9 @@ partial def natToBin' : (n : Nat) → List Bool /-- error: Tactic `cases` failed with a nested error: Dependent elimination failed: Failed to solve equation - Nat.zero = n✝.add n✝ + n✝¹.succ = n✝.add n✝ at case `Parity.even` after processing - Nat.zero, _ + (Nat.succ _), _ the dependent pattern matcher can solve the following kinds of equations - = and = - = where the terms are definitionally equal diff --git a/tests/lean/run/matchOverlapInaccesible.lean b/tests/lean/run/matchOverlapInaccesible.lean index af459f68a3..b2d3e6d7f7 100644 --- a/tests/lean/run/matchOverlapInaccesible.lean +++ b/tests/lean/run/matchOverlapInaccesible.lean @@ -16,9 +16,9 @@ else /-- error: Tactic `cases` failed with a nested error: Dependent elimination failed: Failed to solve equation - Nat.zero = n✝.add n✝ + n✝¹.succ = n✝.add n✝ at case `Parity.even` after processing - Nat.zero, _ + (Nat.succ _), _ the dependent pattern matcher can solve the following kinds of equations - = and = - = where the terms are definitionally equal @@ -56,9 +56,9 @@ def parity (n : MyNat) : Parity n := sorry /-- error: Tactic `cases` failed with a nested error: Dependent elimination failed: Failed to solve equation - zero = n✝.add n✝ + a✝.succ = n✝.add n✝ at case `Parity.even` after processing - zero, _ + (succ _), _ the dependent pattern matcher can solve the following kinds of equations - = and = - = where the terms are definitionally equal