fix: avoid unnecessary branching in match compilation (#10763)
This PR improves match compilation: Branch on variables in the order suggested by the first remaining alternative, and do not branch when the first remaining alternative does not require it. This fixes https://github.com/leanprover/lean4/issues/10749. With `set_option backwards.match.rowMajor false` the old behavior can be turned on. (For now this is an experiment to get familiar with the code and the whole problem domain. It is likely overly naive.)
This commit is contained in:
parent
275f9077b6
commit
c7f57d6a0b
6 changed files with 110 additions and 39 deletions
|
|
@ -17,6 +17,15 @@ public section
|
|||
|
||||
namespace Lean.Meta.Match
|
||||
|
||||
register_builtin_option backwards.match.rowMajor : Bool := {
|
||||
defValue := true
|
||||
group := "bootstrap"
|
||||
descr := "If true (the default), match compilation will split the discrimnants based \
|
||||
on position of the first constructor pattern in the first alternative. If false, \
|
||||
it splits them from left to right, which can lead to unnecessary code bloat."
|
||||
}
|
||||
|
||||
|
||||
private def mkIncorrectNumberOfPatternsMsg [ToMessageData α]
|
||||
(discrepancyKind : String) (expected actual : Nat) (pats : List α) :=
|
||||
let patternsMsg := MessageData.joinSep (pats.map toMessageData) ", "
|
||||
|
|
@ -747,6 +756,46 @@ private def checkNextPatternTypes (p : Problem) : MetaM Unit := do
|
|||
unless (← isDefEq xType eType) do
|
||||
throwError "Type mismatch in pattern: Pattern{indentExpr e}\n{← mkHasTypeButIsExpectedMsg eType xType}"
|
||||
|
||||
private def List.moveToFront [Inhabited α] (as : List α) (i : Nat) : List α :=
|
||||
let rec loop : (as : List α) → (i : Nat) → α × List α
|
||||
| [], _ => unreachable!
|
||||
| a::as, 0 => (a, as)
|
||||
| a::as, i+1 =>
|
||||
let (b, bs) := loop as i
|
||||
(b, a::bs)
|
||||
let (b, bs) := loop as i
|
||||
b :: bs
|
||||
|
||||
/-- Move variable `#i` to the beginning of the to-do list `p.vars`. -/
|
||||
private def moveToFront (p : Problem) (i : Nat) : Problem :=
|
||||
if i == 0 then
|
||||
p
|
||||
else if i < p.vars.length then
|
||||
{ p with
|
||||
vars := List.moveToFront p.vars i
|
||||
alts := p.alts.map fun alt => { alt with patterns := List.moveToFront alt.patterns i }
|
||||
}
|
||||
else
|
||||
p
|
||||
|
||||
def Pattern.isRefutable : Pattern → Bool
|
||||
| .var _ => false
|
||||
| .inaccessible _ => false
|
||||
| .as _ p _ => p.isRefutable
|
||||
| .arrayLit .. => true
|
||||
| .ctor .. => true
|
||||
| .val .. => true
|
||||
|
||||
/--
|
||||
Returns the index of the first pattern in the first alternative that is refutable
|
||||
(i.e. not a variable or inaccessible pattern). We want to handle these first
|
||||
so that the generated code branches in the order suggested by the user's code.
|
||||
-/
|
||||
private def firstRefutablePattern (p : Problem) : Option Nat :=
|
||||
match p.alts with
|
||||
| alt:: _ => alt.patterns.findIdx? (·.isRefutable)
|
||||
| _ => none
|
||||
|
||||
def isExFalsoTransition (p : Problem) : MetaM Bool := do
|
||||
if p.alts.isEmpty then
|
||||
withGoalOf p do
|
||||
|
|
@ -778,6 +827,21 @@ private partial def process (p : Problem) : StateRefT State MetaM Unit := do
|
|||
process p
|
||||
return
|
||||
|
||||
if backwards.match.rowMajor.get (← getOptions) then
|
||||
match firstRefutablePattern p with
|
||||
| some i =>
|
||||
if i > 0 then
|
||||
traceStep ("move var to front")
|
||||
process (moveToFront p i)
|
||||
return
|
||||
| none =>
|
||||
if 1 < p.alts.length then
|
||||
traceStep ("drop all but first alt")
|
||||
-- all patterns are irrefutable, we can drop all other alts
|
||||
let p := { p with alts := p.alts.take 1 }
|
||||
process p
|
||||
return
|
||||
|
||||
if (← isNatValueTransition p) then
|
||||
traceStep ("nat value to constructor")
|
||||
process (← expandNatValuePattern p)
|
||||
|
|
|
|||
|
|
@ -237,11 +237,11 @@ partial def unifyEqs? (numEqs : Nat) (mvarId : MVarId) (subst : FVarSubst) (case
|
|||
return none
|
||||
|
||||
private def unifyCasesEqs (numEqs : Nat) (subgoals : Array CasesSubgoal) : MetaM (Array CasesSubgoal) :=
|
||||
subgoals.foldlM (init := #[]) fun subgoals s => do
|
||||
subgoals.filterMapM fun s => do
|
||||
match (← unifyEqs? numEqs s.mvarId s.subst s.ctorName) with
|
||||
| none => pure subgoals
|
||||
| none => pure none
|
||||
| some (mvarId, subst) =>
|
||||
return subgoals.push { s with
|
||||
return some { s with
|
||||
mvarId := mvarId,
|
||||
subst := subst,
|
||||
fields := s.fields.map (subst.apply ·)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ def test (a : List Nat) : Nat :=
|
|||
/--
|
||||
info: def test.match_1.{u_1} : (motive : List Nat → Sort u_1) →
|
||||
(a : List Nat) → ((x : List Nat) → motive x) → (Unit → motive []) → motive a :=
|
||||
fun motive a h_1 h_2 => List.casesOn a (h_1 []) fun head tail => h_1 (head :: tail)
|
||||
fun motive a h_1 h_2 => h_1 a
|
||||
-/
|
||||
#guard_msgs in #print test.match_1
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ info: def test2.match_1.{u_1} : (motive : List Nat → List Nat → Sort u_1)
|
|||
(tail : List Nat) → (head_1 : Nat) → (tail_1 : List Nat) → motive (head :: tail) (head_1 :: tail_1)) →
|
||||
motive a b :=
|
||||
fun motive a b h_1 h_2 h_3 =>
|
||||
List.casesOn a (List.casesOn b (h_1 []) fun head tail => h_1 (head :: tail)) fun head tail =>
|
||||
List.casesOn a (h_1 b) fun head tail =>
|
||||
List.casesOn b (h_2 (head :: tail)) fun head_1 tail_1 => h_3 head tail head_1 tail_1
|
||||
-/
|
||||
#guard_msgs in #print test2.match_1
|
||||
|
|
@ -51,8 +51,7 @@ info: def test3.match_1.{u_1} : (motive : List Nat → Bool → Sort u_1) →
|
|||
((x : List Nat) → motive x true) →
|
||||
((x : Bool) → motive [] x) → ((x : List Nat) → (x_1 : Bool) → motive x x_1) → motive a b :=
|
||||
fun motive a b h_1 h_2 h_3 =>
|
||||
List.casesOn a (Bool.casesOn b (h_2 false) (h_1 [])) fun head tail =>
|
||||
Bool.casesOn b (h_3 (head :: tail) false) (h_1 (head :: tail))
|
||||
Bool.casesOn b (List.casesOn a (h_2 false) fun head tail => h_3 (head :: tail) false) (h_1 a)
|
||||
-/
|
||||
#guard_msgs in #print test3.match_1
|
||||
|
||||
|
|
@ -79,29 +78,33 @@ info: def test4.match_1.{u_1} : (motive : Bool → Bool → Bool → Bool → Bo
|
|||
((x x_5 x_6 x_7 : Bool) → motive true x x_5 x_6 x_7) →
|
||||
((x x_5 x_6 x_7 x_8 : Bool) → motive x x_5 x_6 x_7 x_8) → motive x x_1 x_2 x_3 x_4 :=
|
||||
fun motive x x_1 x_2 x_3 x_4 h_1 h_2 h_3 h_4 h_5 h_6 =>
|
||||
Bool.casesOn x
|
||||
(Bool.casesOn x_1
|
||||
Bool.casesOn x_4
|
||||
(Bool.casesOn x_3
|
||||
(Bool.casesOn x_2
|
||||
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_6 false false false false false) (h_1 false false false false))
|
||||
(Bool.casesOn x_4 (h_2 false false false false) (h_1 false false false true)))
|
||||
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 false false false false) (h_1 false false true false))
|
||||
(Bool.casesOn x_4 (h_2 false false true false) (h_1 false false true true))))
|
||||
(Bool.casesOn x_2
|
||||
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_4 false false false false) (h_1 false true false false))
|
||||
(Bool.casesOn x_4 (h_2 false true false false) (h_1 false true false true)))
|
||||
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 false true false false) (h_1 false true true false))
|
||||
(Bool.casesOn x_4 (h_2 false true true false) (h_1 false true true true)))))
|
||||
(Bool.casesOn x_1
|
||||
(Bool.casesOn x_2
|
||||
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_5 false false false false) (h_1 true false false false))
|
||||
(Bool.casesOn x_4 (h_2 true false false false) (h_1 true false false true)))
|
||||
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 true false false false) (h_1 true false true false))
|
||||
(Bool.casesOn x_4 (h_2 true false true false) (h_1 true false true true))))
|
||||
(Bool.casesOn x_2
|
||||
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_4 true false false false) (h_1 true true false false))
|
||||
(Bool.casesOn x_4 (h_2 true true false false) (h_1 true true false true)))
|
||||
(Bool.casesOn x_3 (Bool.casesOn x_4 (h_3 true true false false) (h_1 true true true false))
|
||||
(Bool.casesOn x_4 (h_2 true true true false) (h_1 true true true true)))))
|
||||
(Bool.casesOn x_1 (Bool.casesOn x (h_6 false false false false false) (h_5 false false false false))
|
||||
(h_4 x false false false))
|
||||
(h_3 x x_1 false false))
|
||||
(h_2 x x_1 x_2 false))
|
||||
(h_1 x x_1 x_2 x_3)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print test4.match_1
|
||||
|
||||
-- Just testing the backwards compatibility option
|
||||
|
||||
set_option match.ignoreUnusedAlts true in
|
||||
set_option backwards.match.rowMajor false in
|
||||
def testOld (a : List Nat) : Nat :=
|
||||
match a with
|
||||
| _ => 3
|
||||
| [] => 4
|
||||
|
||||
-- Has unnecessary `casesOn`
|
||||
|
||||
/--
|
||||
info: def testOld.match_1.{u_1} : (motive : List Nat → Sort u_1) →
|
||||
(a : List Nat) → ((x : List Nat) → motive x) → (Unit → motive []) → motive a :=
|
||||
fun motive a h_1 h_2 => List.casesOn a (h_1 []) fun head tail => h_1 (head :: tail)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print testOld.match_1
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
/--
|
||||
error: Dependent match elimination failed: Could not solve constraints
|
||||
true ≋ false
|
||||
error: Dependent elimination failed: Type mismatch when solving this alternative: it has type
|
||||
motive false
|
||||
but is expected to have type
|
||||
motive b✝
|
||||
-/
|
||||
#guard_msgs in
|
||||
def test1 b := match b with
|
||||
|
|
@ -8,8 +10,10 @@ def test1 b := match b with
|
|||
| true => 2
|
||||
|
||||
/--
|
||||
error: Dependent match elimination failed: Could not solve constraints
|
||||
true ≋ false
|
||||
error: Dependent elimination failed: Type mismatch when solving this alternative: it has type
|
||||
motive false ⋯
|
||||
but is expected to have type
|
||||
motive x✝¹ x✝
|
||||
-/
|
||||
#guard_msgs in
|
||||
def test2 : (b : Bool) → (h : b = false) → Nat
|
||||
|
|
|
|||
|
|
@ -137,9 +137,9 @@ partial def natToBin' : (n : Nat) → List Bool
|
|||
/--
|
||||
error: Tactic `cases` failed with a nested error:
|
||||
Dependent elimination failed: Failed to solve equation
|
||||
Nat.zero = n✝.add n✝
|
||||
n✝¹.succ = n✝.add n✝
|
||||
at case `Parity.even` after processing
|
||||
Nat.zero, _
|
||||
(Nat.succ _), _
|
||||
the dependent pattern matcher can solve the following kinds of equations
|
||||
- <var> = <term> and <term> = <var>
|
||||
- <term> = <term> where the terms are definitionally equal
|
||||
|
|
|
|||
|
|
@ -16,9 +16,9 @@ else
|
|||
/--
|
||||
error: Tactic `cases` failed with a nested error:
|
||||
Dependent elimination failed: Failed to solve equation
|
||||
Nat.zero = n✝.add n✝
|
||||
n✝¹.succ = n✝.add n✝
|
||||
at case `Parity.even` after processing
|
||||
Nat.zero, _
|
||||
(Nat.succ _), _
|
||||
the dependent pattern matcher can solve the following kinds of equations
|
||||
- <var> = <term> and <term> = <var>
|
||||
- <term> = <term> where the terms are definitionally equal
|
||||
|
|
@ -56,9 +56,9 @@ def parity (n : MyNat) : Parity n := sorry
|
|||
/--
|
||||
error: Tactic `cases` failed with a nested error:
|
||||
Dependent elimination failed: Failed to solve equation
|
||||
zero = n✝.add n✝
|
||||
a✝.succ = n✝.add n✝
|
||||
at case `Parity.even` after processing
|
||||
zero, _
|
||||
(succ _), _
|
||||
the dependent pattern matcher can solve the following kinds of equations
|
||||
- <var> = <term> and <term> = <var>
|
||||
- <term> = <term> where the terms are definitionally equal
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue