feat: weight lazy discriminator tree results early matches (#3818)

The matches returned by the lazy discriminator tree are partially
constrained by a priority, but ties are broken by the order in which
keys are traversed and the order of declarations.

This PR changes the match key traversal to use an explicit stack rather
than recursion and implicitly changes the order in which results are
returned to favor left-matches first. For example, given the term `f a b`
where `f`, `a`, and `b` are constants, and a tree with patterns
`f a x -> 1` and `f x b -> 2` that have the same priority, this will
return `#[1, 2]`, since the early match for the key `a` is returned
before the match for `x`, which is a star.

This appears to address the [lower quality results mentioned on
zulip](https://leanprover.zulipchat.com/#narrow/stream/428973-nightly-testing/topic/Mathlib.20status.20updates/near/429955747).
This commit is contained in:
Joe Hendrix 2024-04-02 00:19:30 -07:00 committed by GitHub
parent c0027d3987
commit eacb1790b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 92 additions and 51 deletions

View file

@ -25,7 +25,6 @@ elaborated additional parts of the tree.
-/
namespace Lean.Meta.LazyDiscrTree
/--
Discrimination tree key.
-/
@ -580,41 +579,76 @@ partial def appendResults (mr : MatchResult α) (a : Array α) : Array α :=
end MatchResult
/--
Recursively collect the matches reachable from trie node `c`.

* `todo` is used as a stack of expressions still to be matched; the next
  expression is taken from the back.
* `score` is stored together with the values when a match completes. It is
  incremented for every edge followed, including star edges.
* `result` accumulates the matches found so far and is returned extended.

The star edge (if any) is visited before the edge for the key of the next
expression.

NOTE(review): this is the recursive variant; a stack-based `getMatchLoop`
taking `Array PartialMatch` also appears in this hunk — presumably this one
is the pre-change side of the diff. Confirm before editing.
-/
private partial def getMatchLoop (todo : Array Expr) (score : Nat) (c : TrieIndex)
(result : MatchResult α) : MatchM α (MatchResult α) := do
let (vs, star, cs) ← evalNode c
-- Nothing left to match: record the values at this node with the current score.
if todo.isEmpty then
return result.push score vs
-- Terms remain, but there is no star child (`star == 0`) and no keyed child.
else if star == 0 && cs.isEmpty then
return result
else
let e := todo.back
let todo := todo.pop
/- We must always visit `Key.star` edges since they are wildcards.
Thus, `todo` is not used linearly when there is `Key.star` edge
and there is an edge for `k` and `k != Key.star`. -/
let visitStar (result : MatchResult α) : MatchM α (MatchResult α) :=
if star != 0 then
getMatchLoop todo (score + 1) star result
else
return result
let visitNonStar (k : Key) (args : Array Expr) (result : MatchResult α) :=
match cs.find? k with
| none => return result
| some c => getMatchLoop (todo ++ args) (score + 1) c result
let result ← visitStar result
let (k, args) ← MatchClone.getMatchKeyArgs e (root := false) (←read)
match k with
| .star => return result
/-
Note: dep-arrow vs arrow
Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are
`(Key.arrow, #[a, b])`.
A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`).
Thus, we also visit the `Key.other` child.
-/
| .arrow => visitNonStar .other #[] (← visitNonStar k args result)
| _ => visitNonStar k args result
/--
A partial match captures the intermediate state of a match
execution.

N.B. Matching in the discriminator tree is non-deterministic because of
star and function-arrow keys, so the matching loop maintains a stack of
partial matches.
-/
structure PartialMatch where
/-- Remaining terms to match. Used as a stack: the next term is taken from the back. -/
todo : Array Expr
/-- Number of non-star matches so far. -/
score : Nat
/-- Index of the trie node to match next. -/
c : TrieIndex
deriving Inhabited
/--
Evaluate all partial matches in `cases` and add the resulting matches to `result`.

`cases` is used as a stack: the next partial match to evaluate is popped from
the back of the array. Consequently, when several partial matches are pushed
at once, the paths we want to explore first (and whose results we want
returned earlier) must be pushed last.
-/
private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchResult α) : MatchM α (MatchResult α) := do
if cases.isEmpty then
pure result
else do
let ca := cases.back
let cases := cases.pop
let (vs, star, cs) ← evalNode ca.c
-- Nothing left to match: record the values at this node with the accumulated score.
if ca.todo.isEmpty then
let result := result.push ca.score vs
getMatchLoop cases result
-- Terms remain, but there is no star child (`star == 0`) and no keyed child: drop this case.
else if star == 0 && cs.isEmpty then
getMatchLoop cases result
else
let e := ca.todo.back
let todo := ca.todo.pop
/- We must always visit `Key.star` edges since they are wildcards.
Thus, `todo` is not used linearly when there is `Key.star` edge
and there is an edge for `k` and `k != Key.star`. -/
-- Following the star edge leaves the score unchanged.
let pushStar (cases : Array PartialMatch) :=
if star = 0 then
cases
else
cases.push { todo, score := ca.score, c := star }
-- Following a keyed edge increments the score and schedules the key's arguments.
let pushNonStar (k : Key) (args : Array Expr) (cases : Array PartialMatch) :=
match cs.find? k with
| none => cases
| some c => cases.push { todo := todo ++ args, score := ca.score + 1, c }
-- The star case is pushed first, so it is popped after any non-star case pushed below.
let cases := pushStar cases
let (k, args) ← MatchClone.getMatchKeyArgs e (root := false) (← read)
let cases :=
match k with
| .star => cases
/-
Note: dep-arrow vs arrow
Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are
`(Key.arrow, #[a, b])`.
A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`).
Thus, we also visit the `Key.other` child.
-/
| .arrow =>
-- `.other` is pushed before `k`, so the `k` child is popped (explored) first.
cases |> pushNonStar .other #[]
|> pushNonStar k args
| _ =>
cases |> pushNonStar k args
getMatchLoop cases result
private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
match root.find? .star with
@ -624,11 +658,14 @@ private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (Match
let (vs, _) ← evalNode idx
pure <| ({} : MatchResult α).push (score := 1) vs
private def getMatchRoot (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
(result : MatchResult α) : MatchM α (MatchResult α) :=
/-
Push a partial match onto `cases` if the discriminator tree root map `r` has an
entry for key `k`; otherwise return `cases` unchanged. The pushed partial match
has `args` as its remaining terms and starts with a score of 1.

NOTE(review): the lines around this comment appear to interleave two
definitions (`getMatchRoot` and `pushRootCase`) — presumably both sides of a
diff hunk. Confirm before editing.
-/
private def pushRootCase (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
(cases : Array PartialMatch) : Array PartialMatch :=
match r.find? k with
| none => pure result
| some c => getMatchLoop args (score := 1) c result
| none => cases
| some c => cases.push { todo := args, score := 1, c }
/--
Find values that match `e` in `root`.
@ -637,13 +674,17 @@ private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
MatchM α (MatchResult α) := do
let result ← getStarResult root
let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (← read)
match k with
| .star => return result
/- See note about "dep-arrow vs arrow" at `getMatchLoop` -/
| .arrow =>
getMatchRoot root k args (← getMatchRoot root .other #[] result)
| _ =>
getMatchRoot root k args result
let cases :=
match k with
| .star =>
#[]
/- See note about "dep-arrow vs arrow" at `getMatchLoop` -/
| .arrow =>
#[] |> pushRootCase root .other #[]
|> pushRootCase root k args
| _ =>
#[] |> pushRootCase root k args
getMatchLoop cases result
/--
Find values that match `e` in `d`.

View file

@ -21,7 +21,7 @@ noncomputable section
#guard_msgs in
example (x : Nat) : x ≠ x.succ := Nat.ne_of_lt (by apply?)
/-- info: Try this: exact Nat.lt_of_sub_eq_succ rfl -/
/-- info: Try this: exact Nat.zero_lt_succ 1 -/
#guard_msgs in
example : 0 ≠ 1 + 1 := Nat.ne_of_lt (by apply?)
@ -83,11 +83,11 @@ example (n m k : Nat) : n * m - n * k = n * (m - k) := by
#guard_msgs in
example {α : Type} (x y : α) : x = y ↔ y = x := by apply?
/-- info: Try this: exact Nat.lt_add_right b ha -/
/-- info: Try this: exact Nat.add_pos_left ha b -/
#guard_msgs in
example (a b : Nat) (ha : 0 < a) (_hb : 0 < b) : 0 < a + b := by apply?
/-- info: Try this: exact Nat.lt_add_right b ha -/
/-- info: Try this: exact Nat.add_pos_left ha b -/
#guard_msgs in
-- Verify that if maxHeartbeats is 0 we don't stop immediately.
set_option maxHeartbeats 0 in
@ -95,7 +95,7 @@ example (a b : Nat) (ha : 0 < a) (_hb : 0 < b) : 0 < a + b := by apply?
section synonym
/-- info: Try this: exact Nat.lt_add_right b ha -/
/-- info: Try this: exact Nat.add_pos_left ha b -/
#guard_msgs in
example (a b : Nat) (ha : a > 0) (_hb : 0 < b) : 0 < a + b := by apply?