From b7acc3881013db8dded32a6d682add340a4069b6 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 25 Mar 2021 12:20:05 -0700 Subject: [PATCH] feat: add support for "arrow" at `DiscrTree` --- src/Lean/Meta/DiscrTree.lean | 76 ++++++++++++++++++------------- src/Lean/Meta/DiscrTreeTypes.lean | 2 + 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/src/Lean/Meta/DiscrTree.lean b/src/Lean/Meta/DiscrTree.lean index f9e9563969..86dd79215b 100644 --- a/src/Lean/Meta/DiscrTree.lean +++ b/src/Lean/Meta/DiscrTree.lean @@ -54,6 +54,7 @@ def Key.ctorIdx : Key → Nat | Key.lit _ => 2 | Key.fvar _ _ => 3 | Key.const _ _ => 4 + | Key.arrow => 5 def Key.lt : Key → Key → Bool | Key.lit v₁, Key.lit v₂ => v₁ < v₂ @@ -71,40 +72,42 @@ def Key.format : Key → Format | Key.lit (Literal.strVal v) => repr v | Key.const k _ => fmt k | Key.fvar k _ => fmt k + | Key.arrow => "→" instance : ToFormat Key := ⟨Key.format⟩ def Key.arity : Key → Nat | Key.const _ a => a | Key.fvar _ a => a + | Key.arrow => 2 | _ => 0 -instance {α} : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩ +instance : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩ -def empty {α} : DiscrTree α := { root := {} } +def empty : DiscrTree α := { root := {} } -partial def Trie.format {α} [ToFormat α] : Trie α → Format +partial def Trie.format [ToFormat α] : Trie α → Format | Trie.node vs cs => Format.group $ Format.paren $ "node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs) ++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ format c)) -instance {α} [ToFormat α] : ToFormat (Trie α) := ⟨Trie.format⟩ +instance [ToFormat α] : ToFormat (Trie α) := ⟨Trie.format⟩ -partial def format {α} [ToFormat α] (d : DiscrTree α) : Format := +partial def format [ToFormat α] (d : DiscrTree α) : Format := let (_, r) := d.root.foldl (fun (p : Bool × Format) k c => (false, p.2 ++ (if p.1 then Format.nil else Format.line) ++ Format.paren (fmt k ++ " => " ++ fmt c))) (true, Format.nil) Format.group r -instance {α} [ToFormat α] : ToFormat (DiscrTree α) := ⟨format⟩ +instance [ToFormat α] : ToFormat (DiscrTree α) := ⟨format⟩ /- The discrimination tree ignores implicit arguments and proofs. We use the following auxiliary id as a "mark". -/ private def tmpMVarId : MVarId := `_discr_tree_tmp private def tmpStar := mkMVar tmpMVarId -instance {α} : Inhabited (DiscrTree α) where +instance : Inhabited (DiscrTree α) where default := {} /-- @@ -137,13 +140,13 @@ instance {α} : Inhabited (DiscrTree α) where Remark: if users have problems with the solution above, we may provide a `noIndexing` annotation, and `ignoreArg` would return true for any term of the form `noIndexing t`. -/ -private def ignoreArg (a : Expr) (i : Nat) (infos : Array ParamInfo) : MetaM Bool := +private def ignoreArg (a : Expr) (i : Nat) (infos : Array ParamInfo) : MetaM Bool := do if h : i < infos.size then let info := infos.get ⟨i, h⟩ if info.instImplicit then - pure true + return true else if info.implicit then - not <$> isType a + return not (← isType a) else isProof a else @@ -155,13 +158,13 @@ private partial def pushArgsAux (infos : Array ParamInfo) : Nat → Expr → Arr pushArgsAux infos (i-1) f (todo.push tmpStar) else pushArgsAux infos (i-1) f (todo.push a) - | _, _, todo => pure todo + | _, _, todo => return todo private partial def whnfEta (e : Expr) : MetaM Expr := do let e ← whnf e match e.etaExpandedStrict? with | some e => whnfEta e - | none => pure e + | none => return e /-- Return true if `e` is one of the following @@ -254,12 +257,17 @@ private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) : MetaM (Key return (Key.other, todo) else return (Key.star, todo) + | Expr.forallE _ d b _ => + if b.hasLooseBVars then + return (Key.other, todo) + else + return (Key.arrow, todo.push d |>.push b) | _ => return (Key.other, todo) partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array Key) : MetaM (Array Key) := do if todo.isEmpty then - pure keys + return keys else let e := todo.back let todo := todo.pop @@ -274,7 +282,7 @@ def mkPath (e : Expr) : MetaM (Array Key) := do let keys : Array Key := Array.mkEmpty initCapacity mkPathAux (root := true) (todo.push e) keys -private partial def createNodes {α} (keys : Array Key) (v : α) (i : Nat) : Trie α := +private partial def createNodes (keys : Array Key) (v : α) (i : Nat) : Trie α := if h : i < keys.size then let k := keys.get ⟨i, h⟩ let c := createNodes keys v (i+1) @@ -282,10 +290,10 @@ private partial def createNodes {α} (keys : Array Key) (v : α) (i : Nat) : Tri else Trie.node #[v] #[] -private def insertVal {α} [BEq α] (vs : Array α) (v : α) : Array α := +private def insertVal [BEq α] (vs : Array α) (v : α) : Array α := if vs.contains v then vs else vs.push v -private partial def insertAux {α} [BEq α] (keys : Array Key) (v : α) : Nat → Trie α → Trie α +private partial def insertAux [BEq α] (keys : Array Key) (v : α) : Nat → Trie α → Trie α | i, Trie.node vs cs => if h : i < keys.size then let k := keys.get ⟨i, h⟩ @@ -298,7 +306,7 @@ private partial def insertAux {α} [BEq α] (keys : Array Key) (v : α) : Nat else Trie.node (insertVal vs v) cs -def insertCore {α} [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α := +def insertCore [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α := if keys.isEmpty then panic! "invalid key sequence" else let k := keys[0] @@ -310,23 +318,23 @@ def insertCore {α} [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) : Di let c := insertAux keys v 1 c { root := d.root.insert k c } -def insert {α} [BEq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) := do +def insert [BEq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) := do let keys ← mkPath e return d.insertCore keys v private def getKeyArgs (e : Expr) (isMatch : Bool) : MetaM (Key × Array Expr) := do let e ← whnfEta e match e.getAppFn with - | Expr.lit v _ => pure (Key.lit v, #[]) + | Expr.lit v _ => return (Key.lit v, #[]) | Expr.const c _ _ => let nargs := e.getAppNumArgs - pure (Key.const c nargs, e.getAppRevArgs) + return (Key.const c nargs, e.getAppRevArgs) | Expr.fvar fvarId _ => let nargs := e.getAppNumArgs - pure (Key.fvar fvarId nargs, e.getAppRevArgs) + return (Key.fvar fvarId nargs, e.getAppRevArgs) | Expr.mvar mvarId _ => if isMatch then - pure (Key.other, #[]) + return (Key.other, #[]) else do let ctx ← read if ctx.config.isDefEqStuckEx then @@ -344,12 +352,18 @@ private def getKeyArgs (e : Expr) (isMatch : Bool) : MetaM (Key × Array Expr) : a regular metavariable here, otherwise we return the empty set of candidates. This is incorrect because it is equivalent to saying that there is no solution even if the caller assigns `?m` and try again. -/ - pure (Key.star, #[]) + return (Key.star, #[]) else if (← isReadOnlyOrSyntheticOpaqueExprMVar mvarId) then - pure (Key.other, #[]) + return (Key.other, #[]) else - pure (Key.star, #[]) - | _ => pure (Key.other, #[]) + return (Key.star, #[]) + | Expr.forallE _ d b _ => + if b.hasLooseBVars then + return (Key.other, #[]) + else + return (Key.arrow, #[d, b]) + | _ => + return (Key.other, #[]) private abbrev getMatchKeyArgs (e : Expr) : MetaM (Key × Array Expr) := getKeyArgs e (isMatch := true) @@ -357,21 +371,21 @@ private abbrev getMatchKeyArgs (e : Expr) : MetaM (Key × Array Expr) := private abbrev getUnifyKeyArgs (e : Expr) : MetaM (Key × Array Expr) := getKeyArgs e (isMatch := false) -private def getStarResult {α} (d : DiscrTree α) : Array α := +private def getStarResult (d : DiscrTree α) : Array α := let result : Array α := Array.mkEmpty initCapacity match d.root.find? Key.star with | none => result | some (Trie.node vs _) => result ++ vs -partial def getMatch {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) := +partial def getMatch (d : DiscrTree α) (e : Expr) : MetaM (Array α) := withReducible do let result := getStarResult d let (k, args) ← getMatchKeyArgs e match k with - | Key.star => pure result + | Key.star => return result | _ => match d.root.find? k with - | none => pure result + | none => return result | some c => process args c result where process (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do @@ -403,7 +417,7 @@ where let result ← visitStarChild result process (todo ++ args) c.2 result -partial def getUnify {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) := +partial def getUnify (d : DiscrTree α) (e : Expr) : MetaM (Array α) := withReducible do let (k, args) ← getUnifyKeyArgs e match k with diff --git a/src/Lean/Meta/DiscrTreeTypes.lean b/src/Lean/Meta/DiscrTreeTypes.lean index ef0d149a1f..bacbeb219d 100644 --- a/src/Lean/Meta/DiscrTreeTypes.lean +++ b/src/Lean/Meta/DiscrTreeTypes.lean @@ -17,6 +17,7 @@ inductive Key where | lit : Literal → Key | star : Key | other : Key + | arrow : Key deriving Inhabited, BEq protected def Key.hash : Key → USize @@ -25,6 +26,7 @@ protected def Key.hash : Key → USize | Key.lit v => mixHash 1879 $ hash v | Key.star => 7883 | Key.other => 2411 + | Key.arrow => 17 instance : Hashable Key := ⟨Key.hash⟩