feat: weight lazy discriminator tree results early matches (#3818)
The matches returned by the lazy discriminator tree are partially constrained by a priority, but ties are broken by the order in which keys are traversed and the order of declarations. This PR changes the match key traversal to use an explicit stack rather than recursion and implicitly changes the order in which results are returned to favor left-matches first e.g., given the term `f a b` with constants `f a b`, and a tree with patterns `f a x -> 1` `f x b -> 2` that have the same priority, this will return `#[1, 2]` since the early matches for the key `a` are returned before the match for `x` which has a star. This appears to address the [lower quality results mentioned on zulip](https://leanprover.zulipchat.com/#narrow/stream/428973-nightly-testing/topic/Mathlib.20status.20updates/near/429955747).
This commit is contained in:
parent
c0027d3987
commit
eacb1790b3
2 changed files with 92 additions and 51 deletions
|
|
@ -25,7 +25,6 @@ elaborated additional parts of the tree.
|
|||
-/
|
||||
namespace Lean.Meta.LazyDiscrTree
|
||||
|
||||
|
||||
/--
|
||||
Discrimination tree key.
|
||||
-/
|
||||
|
|
@ -580,41 +579,76 @@ partial def appendResults (mr : MatchResult α) (a : Array α) : Array α :=
|
|||
|
||||
end MatchResult
|
||||
|
||||
private partial def getMatchLoop (todo : Array Expr) (score : Nat) (c : TrieIndex)
|
||||
(result : MatchResult α) : MatchM α (MatchResult α) := do
|
||||
let (vs, star, cs) ← evalNode c
|
||||
if todo.isEmpty then
|
||||
return result.push score vs
|
||||
else if star == 0 && cs.isEmpty then
|
||||
return result
|
||||
else
|
||||
let e := todo.back
|
||||
let todo := todo.pop
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let visitStar (result : MatchResult α) : MatchM α (MatchResult α) :=
|
||||
if star != 0 then
|
||||
getMatchLoop todo (score + 1) star result
|
||||
else
|
||||
return result
|
||||
let visitNonStar (k : Key) (args : Array Expr) (result : MatchResult α) :=
|
||||
match cs.find? k with
|
||||
| none => return result
|
||||
| some c => getMatchLoop (todo ++ args) (score + 1) c result
|
||||
let result ← visitStar result
|
||||
let (k, args) ← MatchClone.getMatchKeyArgs e (root := false) (←read)
|
||||
match k with
|
||||
| .star => return result
|
||||
/-
|
||||
Note: dep-arrow vs arrow
|
||||
Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are
|
||||
`(Key.arrow, #[a, b])`.
|
||||
A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`).
|
||||
Thus, we also visit the `Key.other` child.
|
||||
-/
|
||||
| .arrow => visitNonStar .other #[] (← visitNonStar k args result)
|
||||
| _ => visitNonStar k args result
|
||||
/-
|
||||
A partial match captures the intermediate state of a match
|
||||
execution.
|
||||
|
||||
N.B. The discriminator tree in Lean has non-determinism due to
|
||||
star and function arrows, so matching loop maintains a stack of
|
||||
partial match results.
|
||||
-/
|
||||
structure PartialMatch where
|
||||
-- Remaining terms to match
|
||||
todo : Array Expr
|
||||
-- Number of non-star matches so far.
|
||||
score : Nat
|
||||
-- Trie to match next
|
||||
c : TrieIndex
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
Evaluate all partial matches and add resulting matches to `MatchResult`.
|
||||
|
||||
The partial matches are stored in an array that is used as a stack. When adding
|
||||
multiple partial matches to explore next, to ensure the order of results matches
|
||||
user expectations, this code must add paths we want to prioritize and return
|
||||
results earlier are added last.
|
||||
-/
|
||||
private partial def getMatchLoop (cases : Array PartialMatch) (result : MatchResult α) : MatchM α (MatchResult α) := do
|
||||
if cases.isEmpty then
|
||||
pure result
|
||||
else do
|
||||
let ca := cases.back
|
||||
let cases := cases.pop
|
||||
let (vs, star, cs) ← evalNode ca.c
|
||||
if ca.todo.isEmpty then
|
||||
let result := result.push ca.score vs
|
||||
getMatchLoop cases result
|
||||
else if star == 0 && cs.isEmpty then
|
||||
getMatchLoop cases result
|
||||
else
|
||||
let e := ca.todo.back
|
||||
let todo := ca.todo.pop
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let pushStar (cases : Array PartialMatch) :=
|
||||
if star = 0 then
|
||||
cases
|
||||
else
|
||||
cases.push { todo, score := ca.score, c := star }
|
||||
let pushNonStar (k : Key) (args : Array Expr) (cases : Array PartialMatch) :=
|
||||
match cs.find? k with
|
||||
| none => cases
|
||||
| some c => cases.push { todo := todo ++ args, score := ca.score + 1, c }
|
||||
let cases := pushStar cases
|
||||
let (k, args) ← MatchClone.getMatchKeyArgs e (root := false) (← read)
|
||||
let cases :=
|
||||
match k with
|
||||
| .star => cases
|
||||
/-
|
||||
Note: dep-arrow vs arrow
|
||||
Recall that dependent arrows are `(Key.other, #[])`, and non-dependent arrows are
|
||||
`(Key.arrow, #[a, b])`.
|
||||
A non-dependent arrow may be an instance of a dependent arrow (stored at `DiscrTree`).
|
||||
Thus, we also visit the `Key.other` child.
|
||||
-/
|
||||
| .arrow =>
|
||||
cases |> pushNonStar .other #[]
|
||||
|> pushNonStar k args
|
||||
| _ =>
|
||||
cases |> pushNonStar k args
|
||||
getMatchLoop cases result
|
||||
|
||||
private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (MatchResult α) :=
|
||||
match root.find? .star with
|
||||
|
|
@ -624,11 +658,14 @@ private def getStarResult (root : Lean.HashMap Key TrieIndex) : MatchM α (Match
|
|||
let (vs, _) ← evalNode idx
|
||||
pure <| ({} : MatchResult α).push (score := 1) vs
|
||||
|
||||
private def getMatchRoot (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
|
||||
(result : MatchResult α) : MatchM α (MatchResult α) :=
|
||||
/-
|
||||
Add partial match to cases if discriminator tree root map has potential matches.
|
||||
-/
|
||||
private def pushRootCase (r : Lean.HashMap Key TrieIndex) (k : Key) (args : Array Expr)
|
||||
(cases : Array PartialMatch) : Array PartialMatch :=
|
||||
match r.find? k with
|
||||
| none => pure result
|
||||
| some c => getMatchLoop args (score := 1) c result
|
||||
| none => cases
|
||||
| some c => cases.push { todo := args, score := 1, c }
|
||||
|
||||
/--
|
||||
Find values that match `e` in `root`.
|
||||
|
|
@ -637,13 +674,17 @@ private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
|
|||
MatchM α (MatchResult α) := do
|
||||
let result ← getStarResult root
|
||||
let (k, args) ← MatchClone.getMatchKeyArgs e (root := true) (← read)
|
||||
match k with
|
||||
| .star => return result
|
||||
/- See note about "dep-arrow vs arrow" at `getMatchLoop` -/
|
||||
| .arrow =>
|
||||
getMatchRoot root k args (← getMatchRoot root .other #[] result)
|
||||
| _ =>
|
||||
getMatchRoot root k args result
|
||||
let cases :=
|
||||
match k with
|
||||
| .star =>
|
||||
#[]
|
||||
/- See note about "dep-arrow vs arrow" at `getMatchLoop` -/
|
||||
| .arrow =>
|
||||
#[] |> pushRootCase root .other #[]
|
||||
|> pushRootCase root k args
|
||||
| _ =>
|
||||
#[] |> pushRootCase root k args
|
||||
getMatchLoop cases result
|
||||
|
||||
/--
|
||||
Find values that match `e` in `d`.
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ noncomputable section
|
|||
#guard_msgs in
|
||||
example (x : Nat) : x ≠ x.succ := Nat.ne_of_lt (by apply?)
|
||||
|
||||
/-- info: Try this: exact Nat.lt_of_sub_eq_succ rfl -/
|
||||
/-- info: Try this: exact Nat.zero_lt_succ 1 -/
|
||||
#guard_msgs in
|
||||
example : 0 ≠ 1 + 1 := Nat.ne_of_lt (by apply?)
|
||||
|
||||
|
|
@ -83,11 +83,11 @@ example (n m k : Nat) : n * m - n * k = n * (m - k) := by
|
|||
#guard_msgs in
|
||||
example {α : Type} (x y : α) : x = y ↔ y = x := by apply?
|
||||
|
||||
/-- info: Try this: exact Nat.lt_add_right b ha -/
|
||||
/-- info: Try this: exact Nat.add_pos_left ha b -/
|
||||
#guard_msgs in
|
||||
example (a b : Nat) (ha : 0 < a) (_hb : 0 < b) : 0 < a + b := by apply?
|
||||
|
||||
/-- info: Try this: exact Nat.lt_add_right b ha -/
|
||||
/-- info: Try this: exact Nat.add_pos_left ha b -/
|
||||
#guard_msgs in
|
||||
-- Verify that if maxHeartbeats is 0 we don't stop immediately.
|
||||
set_option maxHeartbeats 0 in
|
||||
|
|
@ -95,7 +95,7 @@ example (a b : Nat) (ha : 0 < a) (_hb : 0 < b) : 0 < a + b := by apply?
|
|||
|
||||
section synonym
|
||||
|
||||
/-- info: Try this: exact Nat.lt_add_right b ha -/
|
||||
/-- info: Try this: exact Nat.add_pos_left ha b -/
|
||||
#guard_msgs in
|
||||
example (a b : Nat) (ha : a > 0) (_hb : 0 < b) : 0 < a + b := by apply?
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue