fix: E-matching module for grind (#6488)
This PR fixes and refactors the E-matching module for the (WIP) `grind` tactic. Next step: top-level search procedure for `grind`.
This commit is contained in:
parent
8899c7ed8c
commit
5ba476116f
4 changed files with 124 additions and 61 deletions
|
|
@ -8,14 +8,14 @@ import Lean.Meta.Tactic.Grind.Types
|
|||
import Lean.Meta.Tactic.Grind.Internalize
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-- Returns maximum term generation that is considered during ematching -/
|
||||
private def getMaxGeneration : GoalM Nat := do
|
||||
return 10000 -- TODO
|
||||
|
||||
/-- Returns `true` if the maximum number of instances has been reached. -/
|
||||
private def checkMaxInstancesExceeded : GoalM Bool := do
|
||||
return false -- TODO
|
||||
/--
|
||||
Theorem instance found using E-matching.
|
||||
Recall that we only internalize new instances after we complete a full round of E-matching. -/
|
||||
structure EMatchTheoremInstance where
|
||||
proof : Expr
|
||||
prop : Expr
|
||||
generation : Nat
|
||||
deriving Inhabited
|
||||
|
||||
namespace EMatch
|
||||
/-! This module implements a simple E-matching procedure as a backtracking search. -/
|
||||
|
|
@ -51,13 +51,6 @@ structure Choice where
|
|||
assignment : Array Expr
|
||||
deriving Inhabited
|
||||
|
||||
/-- Theorem instances found so far. We only internalize them after we complete a full round of E-matching. -/
|
||||
structure TheoremInstance where
|
||||
proof : Expr
|
||||
prop : Expr
|
||||
generation : Nat
|
||||
deriving Inhabited
|
||||
|
||||
/-- Context for the E-matching monad. -/
|
||||
structure Context where
|
||||
/-- `useMT` is `true` if we are using the mod-time optimization. It is always set to false for new `EMatchTheorem`s. -/
|
||||
|
|
@ -70,7 +63,7 @@ structure Context where
|
|||
structure State where
|
||||
/-- Choices that still have to be processed. -/
|
||||
choiceStack : List Choice := []
|
||||
newInstances : PArray TheoremInstance := {}
|
||||
newInstances : Array EMatchTheoremInstance := #[]
|
||||
deriving Inhabited
|
||||
|
||||
abbrev M := ReaderT Context $ StateRefT State GoalM
|
||||
|
|
@ -181,6 +174,8 @@ Missing parameters are synthesized using type inference and type class synthesis
|
|||
-/
|
||||
private partial def instantiateTheorem (c : Choice) : M Unit := withDefault do
|
||||
let thm := (← read).thm
|
||||
unless (← markTheorenInstance thm.proof c.assignment) do
|
||||
return ()
|
||||
trace[grind.ematch.instance.assignment] "{← thm.origin.pp}: {assignmentToMessageData c.assignment}"
|
||||
let proof ← thm.getProofWithFreshMVarLevels
|
||||
let numParams := thm.numParams
|
||||
|
|
@ -285,22 +280,26 @@ where
|
|||
def ematchTheorems (thms : PArray EMatchTheorem) : M Unit := do
|
||||
thms.forM ematchTheorem
|
||||
|
||||
def internalizeNewInstances : M Unit := do
|
||||
-- TODO
|
||||
return ()
|
||||
|
||||
end EMatch
|
||||
|
||||
open EMatch
|
||||
|
||||
/-- Performs one round of E-matching, and internalizes new instances. -/
|
||||
def ematch : GoalM Unit := do
|
||||
let go (thms newThms : PArray EMatchTheorem) : EMatch.M Unit := do
|
||||
/-- Performs one round of E-matching, and returns new instances. -/
|
||||
def ematch : GoalM (Array EMatchTheoremInstance) := do
|
||||
let go (thms newThms : PArray EMatchTheorem) : EMatch.M (Array EMatchTheoremInstance) := do
|
||||
withReader (fun ctx => { ctx with useMT := true }) <| ematchTheorems thms
|
||||
withReader (fun ctx => { ctx with useMT := false }) <| ematchTheorems newThms
|
||||
internalizeNewInstances
|
||||
unless (← checkMaxInstancesExceeded) do
|
||||
go (← get).thms (← get).newThms |>.run'
|
||||
modify fun s => { s with thms := s.thms ++ s.newThms, newThms := {}, gmt := s.gmt + 1 }
|
||||
return (← get).newInstances
|
||||
if (← checkMaxInstancesExceeded) then
|
||||
return #[]
|
||||
else
|
||||
let insts ← go (← get).thms (← get).newThms |>.run'
|
||||
modify fun s => { s with
|
||||
thms := s.thms ++ s.newThms
|
||||
newThms := {}
|
||||
gmt := s.gmt + 1
|
||||
numInstances := s.numInstances + insts.size
|
||||
}
|
||||
return insts
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
|
|||
|
|
@ -59,7 +59,10 @@ private partial def activateTheoremPatterns (fName : Name) (generation : Nat) :
|
|||
let thm := { thm with symbols }
|
||||
match symbols with
|
||||
| [] =>
|
||||
let thm := { thm with patterns := (← thm.patterns.mapM (internalizePattern · generation)) }
|
||||
-- Recall that we use the proof as part of the key for a set of instances found so far.
|
||||
-- We don't want to use structural equality when comparing keys.
|
||||
let proof ← shareCommon thm.proof
|
||||
let thm := { thm with proof, patterns := (← thm.patterns.mapM (internalizePattern · generation)) }
|
||||
trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}"
|
||||
modify fun s => { s with newThms := s.newThms.push thm }
|
||||
| _ =>
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ def preprocess (mvarId : MVarId) : PreM State := do
|
|||
loop (← mkGoal mvarId)
|
||||
let goals := (← get).goals
|
||||
-- Testing `ematch` module here. We will rewrite this part later.
|
||||
let goals ← goals.mapM fun goal => GoalM.run' goal ematch
|
||||
let goals ← goals.mapM fun goal => GoalM.run' goal (discard <| ematch)
|
||||
if (← isTracingEnabledFor `grind.pre) then
|
||||
trace[grind.debug.pre] (← ppGoals goals)
|
||||
for goal in goals do
|
||||
|
|
|
|||
|
|
@ -101,6 +101,11 @@ def getMainDeclName : GrindM Name :=
|
|||
@[inline] def getMethodsRef : GrindM MethodsRef :=
|
||||
read
|
||||
|
||||
/--
|
||||
Returns maximum term generation that is considered during ematching. -/
|
||||
def getMaxGeneration : GrindM Nat := do
|
||||
return 10000 -- TODO
|
||||
|
||||
/--
|
||||
Abstracts nested proofs in `e`. This is a preprocessing step performed before internalization.
|
||||
-/
|
||||
|
|
@ -193,31 +198,44 @@ structure NewEq where
|
|||
proof : Expr
|
||||
isHEq : Bool
|
||||
|
||||
abbrev ENodes := PHashMap USize ENode
|
||||
/--
|
||||
Key for the `ENodeMap` and `ParentMap` maps.
|
||||
We use pointer addresses and rely on the fact all internalized expressions
|
||||
have been hash-consed, i.e., we have applied `shareCommon`.
|
||||
-/
|
||||
private structure ENodeKey where
|
||||
expr : Expr
|
||||
|
||||
structure CongrKey (enodes : ENodes) where
|
||||
instance : Hashable ENodeKey where
|
||||
hash k := unsafe (ptrAddrUnsafe k.expr).toUInt64
|
||||
|
||||
instance : BEq ENodeKey where
|
||||
beq k₁ k₂ := isSameExpr k₁.expr k₂.expr
|
||||
|
||||
abbrev ENodeMap := PHashMap ENodeKey ENode
|
||||
|
||||
/--
|
||||
Key for the congruence table.
|
||||
We need access to the `enodes` to be able to retrieve the equivalence class roots.
|
||||
-/
|
||||
structure CongrKey (enodes : ENodeMap) where
|
||||
e : Expr
|
||||
|
||||
private abbrev toENodeKey (e : Expr) : USize :=
|
||||
unsafe ptrAddrUnsafe e
|
||||
|
||||
private def hashRoot (enodes : ENodes) (e : Expr) : UInt64 :=
|
||||
if let some node := enodes.find? (toENodeKey e) then
|
||||
toENodeKey node.root |>.toUInt64
|
||||
private def hashRoot (enodes : ENodeMap) (e : Expr) : UInt64 :=
|
||||
if let some node := enodes.find? { expr := e } then
|
||||
unsafe (ptrAddrUnsafe node.root).toUInt64
|
||||
else
|
||||
13
|
||||
|
||||
private def hasSameRoot (enodes : ENodes) (a b : Expr) : Bool := Id.run do
|
||||
let ka := toENodeKey a
|
||||
let kb := toENodeKey b
|
||||
if ka == kb then
|
||||
private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do
|
||||
if isSameExpr a b then
|
||||
return true
|
||||
else
|
||||
let some n1 := enodes.find? ka | return false
|
||||
let some n2 := enodes.find? kb | return false
|
||||
toENodeKey n1.root == toENodeKey n2.root
|
||||
let some n1 := enodes.find? { expr := a } | return false
|
||||
let some n2 := enodes.find? { expr := b } | return false
|
||||
isSameExpr n1.root n2.root
|
||||
|
||||
def congrHash (enodes : ENodes) (e : Expr) : UInt64 :=
|
||||
def congrHash (enodes : ENodeMap) (e : Expr) : UInt64 :=
|
||||
if e.isAppOfArity ``Lean.Grind.nestedProof 2 then
|
||||
-- We only hash the proposition
|
||||
hashRoot enodes (e.getArg! 0)
|
||||
|
|
@ -229,7 +247,7 @@ where
|
|||
| .app f a => go f (mixHash r (hashRoot enodes a))
|
||||
| _ => mixHash r (hashRoot enodes e)
|
||||
|
||||
partial def isCongruent (enodes : ENodes) (a b : Expr) : Bool :=
|
||||
partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool :=
|
||||
if a.isAppOfArity ``Lean.Grind.nestedProof 2 && b.isAppOfArity ``Lean.Grind.nestedProof 2 then
|
||||
hasSameRoot enodes (a.getArg! 0) (b.getArg! 0)
|
||||
else
|
||||
|
|
@ -249,15 +267,43 @@ instance : Hashable (CongrKey enodes) where
|
|||
instance : BEq (CongrKey enodes) where
|
||||
beq k1 k2 := isCongruent enodes k1.e k2.e
|
||||
|
||||
abbrev CongrTable (enodes : ENodes) := PHashSet (CongrKey enodes)
|
||||
abbrev CongrTable (enodes : ENodeMap) := PHashSet (CongrKey enodes)
|
||||
|
||||
-- Remark: we cannot use pointer addresses here because we have to traverse the tree.
|
||||
abbrev ParentSet := RBTree Expr Expr.quickComp
|
||||
abbrev ParentMap := PHashMap USize ParentSet
|
||||
abbrev ParentMap := PHashMap ENodeKey ParentSet
|
||||
|
||||
/--
|
||||
The E-matching module instantiates theorems using the `EMatchTheorem proof` and a (partial) assignment.
|
||||
We want to avoid instantiating the same theorem with the same assignment more than once.
|
||||
Therefore, we store the (pre-)instance information in a set.
|
||||
Recall that the proofs of activated theorems have been hash-consed.
|
||||
The assignment contains internalized expressions, which have also been hash-consed.
|
||||
-/
|
||||
structure PreInstance where
|
||||
proof : Expr
|
||||
assignment : Array Expr
|
||||
|
||||
instance : Hashable PreInstance where
|
||||
hash i := Id.run do
|
||||
let mut r := unsafe (ptrAddrUnsafe i.proof >>> 3).toUInt64
|
||||
for v in i.assignment do
|
||||
r := mixHash r (unsafe (ptrAddrUnsafe v >>> 3).toUInt64)
|
||||
return r
|
||||
|
||||
instance : BEq PreInstance where
|
||||
beq i₁ i₂ := Id.run do
|
||||
unless isSameExpr i₁.proof i₂.proof do return false
|
||||
unless i₁.assignment.size == i₂.assignment.size do return false
|
||||
for v₁ in i₁.assignment, v₂ in i₂.assignment do
|
||||
unless isSameExpr v₁ v₂ do return false
|
||||
return true
|
||||
|
||||
abbrev PreInstanceSet := PHashSet PreInstance
|
||||
|
||||
structure Goal where
|
||||
mvarId : MVarId
|
||||
enodes : ENodes := {}
|
||||
enodes : ENodeMap := {}
|
||||
parents : ParentMap := {}
|
||||
congrTable : CongrTable enodes := {}
|
||||
/--
|
||||
|
|
@ -285,6 +331,8 @@ structure Goal where
|
|||
thmMap : EMatchTheorems
|
||||
/-- Number of theorem instances generated so far -/
|
||||
numInstances : Nat := 0
|
||||
/-- (pre-)instances found so far -/
|
||||
instances : PreInstanceSet := {}
|
||||
deriving Inhabited
|
||||
|
||||
def Goal.admit (goal : Goal) : MetaM Unit :=
|
||||
|
|
@ -294,6 +342,21 @@ abbrev GoalM := StateRefT Goal GrindM
|
|||
|
||||
abbrev Propagator := Expr → GoalM Unit
|
||||
|
||||
/--
|
||||
A helper function used to mark a theorem instance found by the E-matching module.
|
||||
It returns `true` if it is a new instance and `false` otherwise.
|
||||
-/
|
||||
def markTheorenInstance (proof : Expr) (assignment : Array Expr) : GoalM Bool := do
|
||||
let k := { proof, assignment }
|
||||
if (← get).instances.contains k then
|
||||
return false
|
||||
modify fun s => { s with instances := s.instances.insert k }
|
||||
return true
|
||||
|
||||
/-- Returns `true` if the maximum number of instances has been reached. -/
|
||||
def checkMaxInstancesExceeded : GoalM Bool := do
|
||||
return false -- TODO
|
||||
|
||||
/-- Returns `true` if `e` is the internalized `True` expression. -/
|
||||
def isTrueExpr (e : Expr) : GrindM Bool :=
|
||||
return isSameExpr e (← getTrueExpr)
|
||||
|
|
@ -307,11 +370,11 @@ Returns `some n` if `e` has already been "internalized" into the
|
|||
Otherwise, returns `none`.
|
||||
-/
|
||||
def getENode? (e : Expr) : GoalM (Option ENode) :=
|
||||
return (← get).enodes.find? (unsafe ptrAddrUnsafe e)
|
||||
return (← get).enodes.find? { expr := e }
|
||||
|
||||
/-- Returns node associated with `e`. It assumes `e` has already been internalized. -/
|
||||
def getENode (e : Expr) : GoalM ENode := do
|
||||
let some n := (← get).enodes.find? (unsafe ptrAddrUnsafe e)
|
||||
let some n := (← get).enodes.find? { expr := e }
|
||||
| throwError "internal `grind` error, term has not been internalized{indentExpr e}"
|
||||
return n
|
||||
|
||||
|
|
@ -362,7 +425,7 @@ def getNext (e : Expr) : GoalM Expr :=
|
|||
|
||||
/-- Returns `true` if `e` has already been internalized. -/
|
||||
def alreadyInternalized (e : Expr) : GoalM Bool :=
|
||||
return (← get).enodes.contains (unsafe ptrAddrUnsafe e)
|
||||
return (← get).enodes.contains { expr := e }
|
||||
|
||||
def getTarget? (e : Expr) : GoalM (Option Expr) := do
|
||||
let some n ← getENode? e | return none
|
||||
|
|
@ -407,9 +470,8 @@ information in the root (aka canonical representative) of `child`.
|
|||
-/
|
||||
def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do
|
||||
let some childRoot ← getRoot? child | return ()
|
||||
let key := toENodeKey childRoot
|
||||
let parents := if let some parents := (← get).parents.find? key then parents else {}
|
||||
modify fun s => { s with parents := s.parents.insert key (parents.insert parent) }
|
||||
let parents := if let some parents := (← get).parents.find? { expr := childRoot } then parents else {}
|
||||
modify fun s => { s with parents := s.parents.insert { expr := childRoot } (parents.insert parent) }
|
||||
|
||||
/--
|
||||
Returns the set of expressions `e` is a child of, or an expression in
|
||||
|
|
@ -417,7 +479,7 @@ Returns the set of expressions `e` is a child of, or an expression in
|
|||
The information is only up to date if `e` is the root (aka canonical representative) of the equivalence class.
|
||||
-/
|
||||
def getParents (e : Expr) : GoalM ParentSet := do
|
||||
let some parents := (← get).parents.find? (toENodeKey e) | return {}
|
||||
let some parents := (← get).parents.find? { expr := e } | return {}
|
||||
return parents
|
||||
|
||||
/--
|
||||
|
|
@ -425,7 +487,7 @@ Similar to `getParents`, but also removes the entry `e ↦ parents` from the par
|
|||
-/
|
||||
def getParentsAndReset (e : Expr) : GoalM ParentSet := do
|
||||
let parents ← getParents e
|
||||
modify fun s => { s with parents := s.parents.erase (toENodeKey e) }
|
||||
modify fun s => { s with parents := s.parents.erase { expr := e } }
|
||||
return parents
|
||||
|
||||
/--
|
||||
|
|
@ -433,15 +495,14 @@ Copy `parents` to the parents of `root`.
|
|||
`root` must be the root of its equivalence class.
|
||||
-/
|
||||
def copyParentsTo (parents : ParentSet) (root : Expr) : GoalM Unit := do
|
||||
let key := toENodeKey root
|
||||
let mut curr := if let some parents := (← get).parents.find? key then parents else {}
|
||||
let mut curr := if let some parents := (← get).parents.find? { expr := root } then parents else {}
|
||||
for parent in parents do
|
||||
curr := curr.insert parent
|
||||
modify fun s => { s with parents := s.parents.insert key curr }
|
||||
modify fun s => { s with parents := s.parents.insert { expr := root } curr }
|
||||
|
||||
def setENode (e : Expr) (n : ENode) : GoalM Unit :=
|
||||
modify fun s => { s with
|
||||
enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n
|
||||
enodes := s.enodes.insert { expr := e } n
|
||||
congrTable := unsafe unsafeCast s.congrTable
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue