diff --git a/src/Lean/Meta/Tactic/Grind/EMatch.lean b/src/Lean/Meta/Tactic/Grind/EMatch.lean index 5ca3b7aaa9..37a3ecc63d 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatch.lean @@ -8,14 +8,14 @@ import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Internalize namespace Lean.Meta.Grind - -/-- Returns maximum term generation that is considered during ematching -/ -private def getMaxGeneration : GoalM Nat := do - return 10000 -- TODO - -/-- Returns `true` if the maximum number of instances has been reached. -/ -private def checkMaxInstancesExceeded : GoalM Bool := do - return false -- TODO +/-- +Theorem instance found using E-matching. +Recall that we only internalize new instances after we complete a full round of E-matching. -/ +structure EMatchTheoremInstance where + proof : Expr + prop : Expr + generation : Nat + deriving Inhabited namespace EMatch /-! This module implements a simple E-matching procedure as a backtracking search. -/ @@ -51,13 +51,6 @@ structure Choice where assignment : Array Expr deriving Inhabited -/-- Theorem instances found so far. We only internalize them after we complete a full round of E-matching. -/ -structure TheoremInstance where - proof : Expr - prop : Expr - generation : Nat - deriving Inhabited - /-- Context for the E-matching monad. -/ structure Context where /-- `useMT` is `true` if we are using the mod-time optimization. It is always set to false for new `EMatchTheorem`s. -/ @@ -70,7 +63,7 @@ structure Context where structure State where /-- Choices that still have to be processed. -/ choiceStack : List Choice := [] - newInstances : PArray TheoremInstance := {} + newInstances : Array EMatchTheoremInstance := #[] deriving Inhabited abbrev M := ReaderT Context $ StateRefT State GoalM @@ -181,6 +174,8 @@ Missing parameters are synthesized using type inference and type class synthesis -/ private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do let thm := (← read).thm + unless (← markTheorenInstance thm.proof c.assignment) do + return () trace[grind.ematch.instance.assignment] "{← thm.origin.pp}: {assignmentToMessageData c.assignment}" let proof ← thm.getProofWithFreshMVarLevels let numParams := thm.numParams @@ -285,22 +280,26 @@ where def ematchTheorems (thms : PArray EMatchTheorem) : M Unit := do thms.forM ematchTheorem -def internalizeNewInstances : M Unit := do - -- TODO - return () - end EMatch open EMatch -/-- Performs one round of E-matching, and internalizes new instances. -/ -def ematch : GoalM Unit := do - let go (thms newThms : PArray EMatchTheorem) : EMatch.M Unit := do +/-- Performs one round of E-matching, and returns new instances. -/ +def ematch : GoalM (Array EMatchTheoremInstance) := do + let go (thms newThms : PArray EMatchTheorem) : EMatch.M (Array EMatchTheoremInstance) := do withReader (fun ctx => { ctx with useMT := true }) <| ematchTheorems thms withReader (fun ctx => { ctx with useMT := false }) <| ematchTheorems newThms - internalizeNewInstances - unless (← checkMaxInstancesExceeded) do - go (← get).thms (← get).newThms |>.run' - modify fun s => { s with thms := s.thms ++ s.newThms, newThms := {}, gmt := s.gmt + 1 } + return (← get).newInstances + if (← checkMaxInstancesExceeded) then + return #[] + else + let insts ← go (← get).thms (← get).newThms |>.run' + modify fun s => { s with + thms := s.thms ++ s.newThms + newThms := {} + gmt := s.gmt + 1 + numInstances := s.numInstances + insts.size + } + return insts end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean index 341e2b7e50..75d650ea2a 100644 --- a/src/Lean/Meta/Tactic/Grind/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -59,7 +59,10 @@ private partial def activateTheoremPatterns (fName : Name) (generation : Nat) : let thm := { thm with symbols } match symbols with | [] => - let thm := { thm with patterns := (← thm.patterns.mapM (internalizePattern · generation)) } + -- Recall that we use the proof as part of the key for a set of instances found so far. + -- We don't want to use structural equality when comparing keys. + let proof ← shareCommon thm.proof + let thm := { thm with proof, patterns := (← thm.patterns.mapM (internalizePattern · generation)) } trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}" modify fun s => { s with newThms := s.newThms.push thm } | _ => diff --git a/src/Lean/Meta/Tactic/Grind/Preprocessor.lean b/src/Lean/Meta/Tactic/Grind/Preprocessor.lean index c8ebf593a1..bb638daffe 100644 --- a/src/Lean/Meta/Tactic/Grind/Preprocessor.lean +++ b/src/Lean/Meta/Tactic/Grind/Preprocessor.lean @@ -144,7 +144,7 @@ def preprocess (mvarId : MVarId) : PreM State := do loop (← mkGoal mvarId) let goals := (← get).goals -- Testing `ematch` module here. We will rewrite this part later. - let goals ← goals.mapM fun goal => GoalM.run' goal ematch + let goals ← goals.mapM fun goal => GoalM.run' goal (discard <| ematch) if (← isTracingEnabledFor `grind.pre) then trace[grind.debug.pre] (← ppGoals goals) for goal in goals do diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 438f665bd4..754f2c1e57 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -101,6 +101,11 @@ def getMainDeclName : GrindM Name := @[inline] def getMethodsRef : GrindM MethodsRef := read +/-- +Returns maximum term generation that is considered during ematching. -/ +def getMaxGeneration : GrindM Nat := do + return 10000 -- TODO + /-- Abtracts nested proofs in `e`. This is a preprocessing step performed before internalization. -/ @@ -193,31 +198,44 @@ structure NewEq where proof : Expr isHEq : Bool -abbrev ENodes := PHashMap USize ENode +/-- +Key for the `ENodeMap` and `ParentMap` map. +We use pointer addresses and rely on the fact all internalized expressions +have been hash-consed, i.e., we have applied `shareCommon`. +-/ +private structure ENodeKey where + expr : Expr -structure CongrKey (enodes : ENodes) where +instance : Hashable ENodeKey where + hash k := unsafe (ptrAddrUnsafe k.expr).toUInt64 + +instance : BEq ENodeKey where + beq k₁ k₂ := isSameExpr k₁.expr k₂.expr + +abbrev ENodeMap := PHashMap ENodeKey ENode + +/-- +Key for the congruence table. +We need access to the `enodes` to be able to retrieve the equivalence class roots. +-/ +structure CongrKey (enodes : ENodeMap) where e : Expr -private abbrev toENodeKey (e : Expr) : USize := - unsafe ptrAddrUnsafe e - -private def hashRoot (enodes : ENodes) (e : Expr) : UInt64 := - if let some node := enodes.find? (toENodeKey e) then - toENodeKey node.root |>.toUInt64 +private def hashRoot (enodes : ENodeMap) (e : Expr) : UInt64 := + if let some node := enodes.find? { expr := e } then + unsafe (ptrAddrUnsafe node.root).toUInt64 else 13 -private def hasSameRoot (enodes : ENodes) (a b : Expr) : Bool := Id.run do - let ka := toENodeKey a - let kb := toENodeKey b - if ka == kb then +private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do + if isSameExpr a b then return true else - let some n1 := enodes.find? ka | return false - let some n2 := enodes.find? kb | return false - toENodeKey n1.root == toENodeKey n2.root + let some n1 := enodes.find? { expr := a } | return false + let some n2 := enodes.find? { expr := b } | return false + isSameExpr n1.root n2.root -def congrHash (enodes : ENodes) (e : Expr) : UInt64 := +def congrHash (enodes : ENodeMap) (e : Expr) : UInt64 := if e.isAppOfArity ``Lean.Grind.nestedProof 2 then -- We only hash the proposition hashRoot enodes (e.getArg! 0) @@ -229,7 +247,7 @@ where | .app f a => go f (mixHash r (hashRoot enodes a)) | _ => mixHash r (hashRoot enodes e) -partial def isCongruent (enodes : ENodes) (a b : Expr) : Bool := +partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool := if a.isAppOfArity ``Lean.Grind.nestedProof 2 && b.isAppOfArity ``Lean.Grind.nestedProof 2 then hasSameRoot enodes (a.getArg! 0) (b.getArg! 0) else @@ -249,15 +267,43 @@ instance : Hashable (CongrKey enodes) where instance : BEq (CongrKey enodes) where beq k1 k2 := isCongruent enodes k1.e k2.e -abbrev CongrTable (enodes : ENodes) := PHashSet (CongrKey enodes) +abbrev CongrTable (enodes : ENodeMap) := PHashSet (CongrKey enodes) -- Remark: we cannot use pointer addresses here because we have to traverse the tree. abbrev ParentSet := RBTree Expr Expr.quickComp -abbrev ParentMap := PHashMap USize ParentSet +abbrev ParentMap := PHashMap ENodeKey ParentSet + +/-- +The E-matching module instantiates theorems using the `EMatchTheorem proof` and a (partial) assignment. +We want to avoid instantiating the same theorem with the same assignment more than once. +Therefore, we store the (pre-)instance information in set. +Recall that the proofs of activated theorems have been hash-consed. +The assignment contains internalized expressions, which have also been hash-consed. +-/ +structure PreInstance where + proof : Expr + assignment : Array Expr + +instance : Hashable PreInstance where + hash i := Id.run do + let mut r := unsafe (ptrAddrUnsafe i.proof >>> 3).toUInt64 + for v in i.assignment do + r := mixHash r (unsafe (ptrAddrUnsafe v >>> 3).toUInt64) + return r + +instance : BEq PreInstance where + beq i₁ i₂ := Id.run do + unless isSameExpr i₁.proof i₂.proof do return false + unless i₁.assignment.size == i₂.assignment.size do return false + for v₁ in i₁.assignment, v₂ in i₂.assignment do + unless isSameExpr v₁ v₂ do return false + return true + +abbrev PreInstanceSet := PHashSet PreInstance structure Goal where mvarId : MVarId - enodes : ENodes := {} + enodes : ENodeMap := {} parents : ParentMap := {} congrTable : CongrTable enodes := {} /-- @@ -285,6 +331,8 @@ structure Goal where thmMap : EMatchTheorems /-- Number of theorem instances generated so far -/ numInstances : Nat := 0 + /-- (pre-)instances found so far -/ + instances : PreInstanceSet := {} deriving Inhabited def Goal.admit (goal : Goal) : MetaM Unit := @@ -294,6 +342,21 @@ abbrev GoalM := StateRefT Goal GrindM abbrev Propagator := Expr → GoalM Unit +/-- +A helper function used to mark a theorem instance found by the E-matching module. +It returns `true` if it is a new instance and `false` otherwise. +-/ +def markTheorenInstance (proof : Expr) (assignment : Array Expr) : GoalM Bool := do + let k := { proof, assignment } + if (← get).instances.contains k then + return false + modify fun s => { s with instances := s.instances.insert k } + return true + +/-- Returns `true` if the maximum number of instances has been reached. -/ +def checkMaxInstancesExceeded : GoalM Bool := do + return false -- TODO + /-- Returns `true` if `e` is the internalized `True` expression. -/ def isTrueExpr (e : Expr) : GrindM Bool := return isSameExpr e (← getTrueExpr) @@ -307,11 +370,11 @@ Returns `some n` if `e` has already been "internalized" into the Otherwise, returns `none`s. -/ def getENode? (e : Expr) : GoalM (Option ENode) := - return (← get).enodes.find? (unsafe ptrAddrUnsafe e) + return (← get).enodes.find? { expr := e } /-- Returns node associated with `e`. It assumes `e` has already been internalized. -/ def getENode (e : Expr) : GoalM ENode := do - let some n := (← get).enodes.find? (unsafe ptrAddrUnsafe e) + let some n := (← get).enodes.find? { expr := e } | throwError "internal `grind` error, term has not been internalized{indentExpr e}" return n @@ -362,7 +425,7 @@ def getNext (e : Expr) : GoalM Expr := /-- Returns `true` if `e` has already been internalized. -/ def alreadyInternalized (e : Expr) : GoalM Bool := - return (← get).enodes.contains (unsafe ptrAddrUnsafe e) + return (← get).enodes.contains { expr := e } def getTarget? (e : Expr) : GoalM (Option Expr) := do let some n ← getENode? e | return none @@ -407,9 +470,8 @@ information in the root (aka canonical representative) of `child`. -/ def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do let some childRoot ← getRoot? child | return () - let key := toENodeKey childRoot - let parents := if let some parents := (← get).parents.find? key then parents else {} - modify fun s => { s with parents := s.parents.insert key (parents.insert parent) } + let parents := if let some parents := (← get).parents.find? { expr := childRoot } then parents else {} + modify fun s => { s with parents := s.parents.insert { expr := childRoot } (parents.insert parent) } /-- Returns the set of expressions `e` is a child of, or an expression in @@ -417,7 +479,7 @@ Returns the set of expressions `e` is a child of, or an expression in The information is only up to date if `e` is the root (aka canonical representative) of the equivalence class. -/ def getParents (e : Expr) : GoalM ParentSet := do - let some parents := (← get).parents.find? (toENodeKey e) | return {} + let some parents := (← get).parents.find? { expr := e } | return {} return parents /-- @@ -425,7 +487,7 @@ Similar to `getParents`, but also removes the entry `e ↦ parents` from the par -/ def getParentsAndReset (e : Expr) : GoalM ParentSet := do let parents ← getParents e - modify fun s => { s with parents := s.parents.erase (toENodeKey e) } + modify fun s => { s with parents := s.parents.erase { expr := e } } return parents /-- @@ -433,15 +495,14 @@ Copy `parents` to the parents of `root`. `root` must be the root of its equivalence class. -/ def copyParentsTo (parents : ParentSet) (root : Expr) : GoalM Unit := do - let key := toENodeKey root - let mut curr := if let some parents := (← get).parents.find? key then parents else {} + let mut curr := if let some parents := (← get).parents.find? { expr := root } then parents else {} for parent in parents do curr := curr.insert parent - modify fun s => { s with parents := s.parents.insert key curr } + modify fun s => { s with parents := s.parents.insert { expr := root } curr } def setENode (e : Expr) (n : ENode) : GoalM Unit := modify fun s => { s with - enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n + enodes := s.enodes.insert { expr := e } n congrTable := unsafe unsafeCast s.congrTable }