diff --git a/src/Init/Lean/Meta/DiscrTree.lean b/src/Init/Lean/Meta/DiscrTree.lean index 1a816092db..a728a00625 100644 --- a/src/Init/Lean/Meta/DiscrTree.lean +++ b/src/Init/Lean/Meta/DiscrTree.lean @@ -106,6 +106,11 @@ def Key.format : Key → Format instance Key.hasFormat : HasFormat Key := ⟨Key.format⟩ +def Key.arity : Key → Nat +| Key.const _ a => a +| Key.fvar _ a => a +| _ => 0 + inductive Trie (α : Type) | node (vs : Array α) (children : Array (Key × Trie)) : Trie @@ -230,6 +235,12 @@ do e ← whnf e; (pure (Key.star, #[])) | _ => pure (Key.other, #[]) +private abbrev getMatchKeyArgs (e : Expr) : MetaM (Key × Array Expr) := +getKeyArgs e true + +private abbrev getUnifyKeyArgs (e : Expr) : MetaM (Key × Array Expr) := +getKeyArgs e false + private partial def getMatchAux {α} : Array Expr → Trie α → Array α → MetaM (Array α) | todo, Trie.node vs cs, result => if todo.isEmpty then pure $ result ++ vs @@ -238,7 +249,7 @@ private partial def getMatchAux {α} : Array Expr → Trie α → Array α → M let e := todo.back; let todo := todo.pop; let first := cs.get! 0; /- Recall that `Key.star` is the minimal key -/ - (k, args) ← getKeyArgs e true; + (k, args) ← getMatchKeyArgs e; /- We must always visit `Key.star` edges since they are wildcards. Thus, `todo` is not used linearly when there is `Key.star` edge and there is an edge for `k` and `k != Key.star`. -/ @@ -259,7 +270,7 @@ match d.root.find Key.star with def getMatch {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) := usingTransparency TransparencyMode.reducible $ do let result := getStarResult d; - (k, args) ← getKeyArgs e true; + (k, args) ← getMatchKeyArgs e; match k with | Key.star => pure result | _ => @@ -270,30 +281,16 @@ usingTransparency TransparencyMode.reducible $ do private partial def getUnifyAux {α} : Nat → Array Expr → Trie α → (Array α) → MetaM (Array α) | skip+1, todo, Trie.node vs cs, result => if cs.isEmpty then pure result - else - cs.foldlM - (fun result ⟨k, c⟩ => - match k with - | Key.const _ a => getUnifyAux (skip + a) todo c result - | Key.fvar _ a => getUnifyAux (skip + a) todo c result - | _ => getUnifyAux skip todo c result) - result + else cs.foldlM (fun result ⟨k, c⟩ => getUnifyAux (skip + k.arity) todo c result) result | 0, todo, Trie.node vs cs, result => if todo.isEmpty then pure (result ++ vs) else if cs.isEmpty then pure result else do let e := todo.back; let todo := todo.pop; - (k, args) ← getKeyArgs e true; + (k, args) ← getUnifyKeyArgs e; match k with - | Key.star => - cs.foldlM - (fun result ⟨k, c⟩ => - match k with - | Key.const _ a => getUnifyAux a todo c result - | Key.fvar _ a => getUnifyAux a todo c result - | _ => getUnifyAux 0 todo c result) - result + | Key.star => cs.foldlM (fun result ⟨k, c⟩ => getUnifyAux k.arity todo c result) result | _ => let first := cs.get! 0; let visitStarChild (result : Array α) : MetaM (Array α) := if first.1 == Key.star then getMatchAux todo first.2 result else pure result; @@ -303,15 +300,9 @@ private partial def getUnifyAux {α} : Nat → Array Expr → Trie α → (Array def getUnify {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) := usingTransparency TransparencyMode.reducible $ do - (k, args) ← getKeyArgs e true; + (k, args) ← getUnifyKeyArgs e; match k with - | Key.star => - d.root.foldlM - (fun result k c => match k with - | Key.const _ a => getUnifyAux a #[] c result - | Key.fvar _ a => getUnifyAux a #[] c result - | _ => getUnifyAux 0 #[] c result) - #[] + | Key.star => d.root.foldlM (fun result k c => getUnifyAux k.arity #[] c result) #[] | _ => let result := getStarResult d; match d.root.find k with