fix: avoid unnecessary branching in match compilation (#10763)

This PR improves match compilation: Branch on variables in the order
suggested by the first remaining alternative, and do not branch when the
first remaining alternative does not require it. This fixes
https://github.com/leanprover/lean4/issues/10749. With `set_option
backwards.match.rowMajor false` the old behavior can be turned on.

(For now this is an experiment to get familiar with the code and the
whole
problem domain. It is likely overly naive.)
This commit is contained in:
Joachim Breitner 2025-10-30 16:05:13 -04:00 committed by GitHub
parent 275f9077b6
commit c7f57d6a0b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 110 additions and 39 deletions

View file

@ -17,6 +17,15 @@ public section
namespace Lean.Meta.Match
register_builtin_option backwards.match.rowMajor : Bool := {
defValue := true
group := "bootstrap"
descr := "If true (the default), match compilation will split the discrimnants based \
on position of the first constructor pattern in the first alternative. If false, \
it splits them from left to right, which can lead to unnecessary code bloat."
}
private def mkIncorrectNumberOfPatternsMsg [ToMessageData α]
(discrepancyKind : String) (expected actual : Nat) (pats : List α) :=
let patternsMsg := MessageData.joinSep (pats.map toMessageData) ", "
@ -747,6 +756,46 @@ private def checkNextPatternTypes (p : Problem) : MetaM Unit := do
unless (← isDefEq xType eType) do
throwError "Type mismatch in pattern: Pattern{indentExpr e}\n{← mkHasTypeButIsExpectedMsg eType xType}"
private def List.moveToFront [Inhabited α] (as : List α) (i : Nat) : List α :=
let rec loop : (as : List α) → (i : Nat) → α × List α
| [], _ => unreachable!
| a::as, 0 => (a, as)
| a::as, i+1 =>
let (b, bs) := loop as i
(b, a::bs)
let (b, bs) := loop as i
b :: bs
/-- Move variable `#i` to the beginning of the to-do list `p.vars`. -/
private def moveToFront (p : Problem) (i : Nat) : Problem :=
if i == 0 then
p
else if i < p.vars.length then
{ p with
vars := List.moveToFront p.vars i
alts := p.alts.map fun alt => { alt with patterns := List.moveToFront alt.patterns i }
}
else
p
def Pattern.isRefutable : Pattern → Bool
| .var _ => false
| .inaccessible _ => false
| .as _ p _ => p.isRefutable
| .arrayLit .. => true
| .ctor .. => true
| .val .. => true
/--
Returns the index of the first pattern in the first alternative that is refutable
(i.e. not a variable or inaccessible pattern). We want to handle these first
so that the generated code branches in the order suggested by the user's code.
-/
private def firstRefutablePattern (p : Problem) : Option Nat :=
match p.alts with
| alt:: _ => alt.patterns.findIdx? (·.isRefutable)
| _ => none
def isExFalsoTransition (p : Problem) : MetaM Bool := do
if p.alts.isEmpty then
withGoalOf p do
@ -778,6 +827,21 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
process p
return
if backwards.match.rowMajor.get (← getOptions) then
match firstRefutablePattern p with
| some i =>
if i > 0 then
traceStep ("move var to front")
process (moveToFront p i)
return
| none =>
if 1 < p.alts.length then
traceStep ("drop all but first alt")
-- all patterns are irrefutable, we can drop all other alts
let p := { p with alts := p.alts.take 1 }
process p
return
if (← isNatValueTransition p) then
traceStep ("nat value to constructor")
process (← expandNatValuePattern p)

View file

@ -237,11 +237,11 @@ partial def unifyEqs? (numEqs : Nat) (mvarId : MVarId) (subst : FVarSubst) (case
return none
private def unifyCasesEqs (numEqs : Nat) (subgoals : Array CasesSubgoal) : MetaM (Array CasesSubgoal) :=
subgoals.foldlM (init := #[]) fun subgoals s => do
subgoals.filterMapM fun s => do
match (← unifyEqs? numEqs s.mvarId s.subst s.ctorName) with
| none => pure subgoals
| none => pure none
| some (mvarId, subst) =>
return subgoals.push { s with
return some { s with
mvarId := mvarId,
subst := subst,
fields := s.fields.map (subst.apply ·)

View file

@ -12,7 +12,7 @@ def test (a : List Nat) : Nat :=
/--
info: def test.match_1.{u_1} : (motive : List Nat → Sort u_1) →
(a : List Nat) → ((x : List Nat) → motive x) → (Unit → motive []) → motive a :=
fun motive a h_1 h_2 => List.casesOn a (h_1 []) fun head tail => h_1 (head :: tail)
fun motive a h_1 h_2 => h_1 a
-/
#guard_msgs in #print test.match_1
@ -31,7 +31,7 @@ info: def test2.match_1.{u_1} : (motive : List Nat → List Nat → Sort u_1)
(tail : List Nat) → (head_1 : Nat) → (tail_1 : List Nat) → motive (head :: tail) (head_1 :: tail_1)) →
motive a b :=
fun motive a b h_1 h_2 h_3 =>
List.casesOn a (List.casesOn b (h_1 []) fun head tail => h_1 (head :: tail)) fun head tail =>
List.casesOn a (h_1 b) fun head tail =>
List.casesOn b (h_2 (head :: tail)) fun head_1 tail_1 => h_3 head tail head_1 tail_1
-/
#guard_msgs in #print test2.match_1
@ -51,8 +51,7 @@ info: def test3.match_1.{u_1} : (motive : List Nat → Bool → Sort u_1) →
((x : List Nat) → motive x true) →
((x : Bool) → motive [] x) → ((x : List Nat) → (x_1 : Bool) → motive x x_1) → motive a b :=
fun motive a b h_1 h_2 h_3 =>
List.casesOn a (Bool.casesOn b (h_2 false) (h_1 [])) fun head tail =>
Bool.casesOn b (h_3 (head :: tail) false) (h_1 (head :: tail))
Bool.casesOn b (List.casesOn a (h_2 false) fun head tail => h_3 (head :: tail) false) (h_1 a)
-/
#guard_msgs in #print test3.match_1
@ -79,29 +78,33 @@ info: def test4.match_1.{u_1} : (motive : Bool → Bool → Bool → Bool → Bo
((x x_5 x_6 x_7 : Bool) → motive true x x_5 x_6 x_7) →
((x x_5 x_6 x_7 x_8 : Bool) → motive x x_5 x_6 x_7 x_8) → motive x x_1 x_2 x_3 x_4 :=
fun motive x x_1 x_2 x_3 x_4 h_1 h_2 h_3 h_4 h_5 h_6 =>
Bool.casesOn x
(Bool.casesOn x_1
Bool.casesOn x_4
(Bool.casesOn x_3
(Bool.casesOn x_2
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_6 false false false false false) (h_1 false false false false))
(Bool.casesOn x_4 (h_2 false false false false) (h_1 false false false true)))
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 false false false false) (h_1 false false true false))
(Bool.casesOn x_4 (h_2 false false true false) (h_1 false false true true))))
(Bool.casesOn x_2
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_4 false false false false) (h_1 false true false false))
(Bool.casesOn x_4 (h_2 false true false false) (h_1 false true false true)))
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 false true false false) (h_1 false true true false))
(Bool.casesOn x_4 (h_2 false true true false) (h_1 false true true true)))))
(Bool.casesOn x_1
(Bool.casesOn x_2
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_5 false false false false) (h_1 true false false false))
(Bool.casesOn x_4 (h_2 true false false false) (h_1 true false false true)))
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 true false false false) (h_1 true false true false))
(Bool.casesOn x_4 (h_2 true false true false) (h_1 true false true true))))
(Bool.casesOn x_2
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_4 true false false false) (h_1 true true false false))
(Bool.casesOn x_4 (h_2 true true false false) (h_1 true true false true)))
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 true true false false) (h_1 true true true false))
(Bool.casesOn x_4 (h_2 true true true false) (h_1 true true true true)))))
(Bool.casesOn x_1 (Bool.casesOn x (h_6 false false false false false) (h_5 false false false false))
(h_4 x false false false))
(h_3 x x_1 false false))
(h_2 x x_1 x_2 false))
(h_1 x x_1 x_2 x_3)
-/
#guard_msgs in
#print test4.match_1
-- Just testing the backwards compatibility option
set_option match.ignoreUnusedAlts true in
set_option backwards.match.rowMajor false in
def testOld (a : List Nat) : Nat :=
match a with
| _ => 3
| [] => 4
-- Has unnecessary `casesOn`
/--
info: def testOld.match_1.{u_1} : (motive : List Nat → Sort u_1) →
(a : List Nat) → ((x : List Nat) → motive x) → (Unit → motive []) → motive a :=
fun motive a h_1 h_2 => List.casesOn a (h_1 []) fun head tail => h_1 (head :: tail)
-/
#guard_msgs in
#print testOld.match_1

View file

@ -1,6 +1,8 @@
/--
error: Dependent match elimination failed: Could not solve constraints
true ≋ false
error: Dependent elimination failed: Type mismatch when solving this alternative: it has type
motive false
but is expected to have type
motive b✝
-/
#guard_msgs in
def test1 b := match b with
@ -8,8 +10,10 @@ def test1 b := match b with
| true => 2
/--
error: Dependent match elimination failed: Could not solve constraints
true ≋ false
error: Dependent elimination failed: Type mismatch when solving this alternative: it has type
motive false ⋯
but is expected to have type
motive x✝¹ x✝
-/
#guard_msgs in
def test2 : (b : Bool) → (h : b = false) → Nat

View file

@ -137,9 +137,9 @@ partial def natToBin' : (n : Nat) → List Bool
/--
error: Tactic `cases` failed with a nested error:
Dependent elimination failed: Failed to solve equation
Nat.zero = n✝.add n✝
n✝¹.succ = n✝.add n✝
at case `Parity.even` after processing
Nat.zero, _
(Nat.succ _), _
the dependent pattern matcher can solve the following kinds of equations
- <var> = <term> and <term> = <var>
- <term> = <term> where the terms are definitionally equal

View file

@ -16,9 +16,9 @@ else
/--
error: Tactic `cases` failed with a nested error:
Dependent elimination failed: Failed to solve equation
Nat.zero = n✝.add n✝
n✝¹.succ = n✝.add n✝
at case `Parity.even` after processing
Nat.zero, _
(Nat.succ _), _
the dependent pattern matcher can solve the following kinds of equations
- <var> = <term> and <term> = <var>
- <term> = <term> where the terms are definitionally equal
@ -56,9 +56,9 @@ def parity (n : MyNat) : Parity n := sorry
/--
error: Tactic `cases` failed with a nested error:
Dependent elimination failed: Failed to solve equation
zero = n✝.add n✝
a✝.succ = n✝.add n✝
at case `Parity.even` after processing
zero, _
(succ _), _
the dependent pattern matcher can solve the following kinds of equations
- <var> = <term> and <term> = <var>
- <term> = <term> where the terms are definitionally equal