feat: add grind core module (#4249)

This commit is contained in:
Leonardo de Moura 2024-05-22 05:50:36 +02:00 committed by GitHub
parent c2b8a1e618
commit ff37e5d512
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 370 additions and 81 deletions

View file

@ -7,3 +7,4 @@ prelude
import Init.Grind.Norm
import Init.Grind.Tactics
import Init.Grind.Lemmas
import Init.Grind.Cases

15
src/Init/Grind/Cases.lean Normal file
View file

@ -0,0 +1,15 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Core
attribute [grind_cases] And Prod False Empty True Unit Exists
namespace Lean.Grind.Eager
attribute [scoped grind_cases] Or
end Lean.Grind.Eager

View file

@ -7,8 +7,19 @@ prelude
import Init.Tactics
namespace Lean.Grind
/--
The configuration for `grind`.
Passed to `grind` using, for example, the `grind (config := { eager := true })` syntax.
-/
structure Config where
/--
When `eager` is true (default: `false`), `grind` eagerly splits `if-then-else` and `match`
expressions.
-/
eager : Bool := false
deriving Inhabited, BEq
/-!
`grind` tactic and related tactics.
-/
end Lean.Grind

View file

@ -99,6 +99,8 @@ def getUInt64Value? (e : Expr) : MetaM (Option UInt64) := OptionT.run do
let (n, _) ← getOfNatValue? e ``UInt64
return UInt64.ofNat n
-- TODO: extensibility
/--
If `e` is a literal value, ensure it is encoded using the standard representation.
Otherwise, just return `e`.
@ -117,6 +119,23 @@ def normLitValue (e : Expr) : MetaM Expr := do
if let some n ← getUInt64Value? e then return toExpr n
return e
/--
Returns `true` if `e` is a literal value.
-/
def isLitValue (e : Expr) : MetaM Bool := do
let e ← instantiateMVars e
if (← getNatValue? e).isSome then return true
if (← getIntValue? e).isSome then return true
if (← getFinValue? e).isSome then return true
if (← getBitVecValue? e).isSome then return true
if (getStringValue? e).isSome then return true
if (← getCharValue? e).isSome then return true
if (← getUInt8Value? e).isSome then return true
if (← getUInt16Value? e).isSome then return true
if (← getUInt32Value? e).isSome then return true
if (← getUInt64Value? e).isSome then return true
return false
/--
If `e` is a `Nat`, `Int`, or `Fin` literal value, converts it into a constructor application.
Otherwise, just return `e`.

View file

@ -11,3 +11,4 @@ import Lean.Meta.Tactic.Grind.Preprocessor
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Injection
import Lean.Meta.Tactic.Grind.Core

View file

@ -0,0 +1,157 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.LitValues
namespace Lean.Meta.Grind
/--
Returns `true` if `e` is `True`, `False`, or a literal value.
See `LitValues` for supported literals.
-/
def isInterpreted (e : Expr) : MetaM Bool := do
if e.isTrue || e.isFalse then return true
isLitValue e
/--
Creates an `ENode` for `e` if one does not already exist.
This method assumes `e` has been hashconsed.
-/
def mkENode (e : Expr) (generation : Nat := 0) : GoalM Unit := do
if (← getENode? e).isSome then return ()
let ctor := (← isConstructorAppCore? e).isSome
let interpreted ← isInterpreted e
mkENodeCore e interpreted ctor generation
/--
Returns the root element in the equivalence class of `e`.
-/
def getRoot (e : Expr) : GoalM Expr := do
let some n ← getENode? e | return e
return n.root
/--
Returns the next element in the equivalence class of `e`.
-/
def getNext (e : Expr) : GoalM Expr := do
let some n ← getENode? e | return e
return n.next
@[inline] def isSameExpr (a b : Expr) : Bool :=
-- It is safe to use pointer equality because we hashcons all expressions
-- inserted into the E-graph
unsafe ptrEq a b
private def pushNewEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit :=
modify fun s => { s with newEqs := s.newEqs.push { lhs, rhs, proof, isHEq } }
@[inline] private def pushNewEq (lhs rhs proof : Expr) : GoalM Unit :=
pushNewEqCore lhs rhs proof (isHEq := false)
@[inline] private def pushNewHEq (lhs rhs proof : Expr) : GoalM Unit :=
pushNewEqCore lhs rhs proof (isHEq := true)
/--
The fields `target?` and `proof?` in `e`'s `ENode` are encoding a transitivity proof
from `e` to the root of the equivalence class
This method "inverts" the proof, and makes it to go from the root of the equivalence class to `e`.
We use this method when merging two equivalence classes.
-/
private partial def invertTrans (e : Expr) : GoalM Unit := do
go e false none none
where
go (e : Expr) (flippedNew : Bool) (targetNew? : Option Expr) (proofNew? : Option Expr) : GoalM Unit := do
let some node ← getENode? e | unreachable!
if let some target := node.target? then
go target (!node.flipped) (some e) node.proof?
setENode e { node with
target? := targetNew?
flipped := flippedNew
proof? := proofNew?
}
private def markAsInconsistent : GoalM Unit :=
modify fun s => { s with inconsistent := true }
private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
let some lhsNode ← getENode? lhs | return () -- `lhs` has not been internalized yet
let some rhsNode ← getENode? rhs | return () -- `rhs` has not been internalized yet
if isSameExpr lhsNode.root rhsNode.root then return () -- `lhs` and `rhs` are already in the same equivalence class.
let some lhsRoot ← getENode? lhsNode.root | unreachable!
let some rhsRoot ← getENode? rhsNode.root | unreachable!
if (lhsRoot.interpreted && !rhsRoot.interpreted)
|| (lhsRoot.ctor && !rhsRoot.ctor)
|| (lhsRoot.size > rhsRoot.size && !rhsRoot.interpreted && !rhsRoot.ctor) then
go rhs lhs rhsNode lhsNode rhsRoot lhsRoot true
else
go lhs rhs lhsNode rhsNode lhsRoot rhsRoot false
where
go (lhs rhs : Expr) (lhsNode rhsNode lhsRoot rhsRoot : ENode) (flipped : Bool) : GoalM Unit := do
let mut valueInconsistency := false
if lhsRoot.interpreted && rhsRoot.interpreted then
if lhsNode.root.isTrue || rhsNode.root.isTrue then
markAsInconsistent
else
valueInconsistency := true
-- TODO: process valueInconsistency := true
/-
We have the following `target?/proof?`
`lhs -> ... -> lhsNode.root`
`rhs -> ... -> rhsNode.root`
We want to convert it to
`lhsNode.root -> ... -> lhs -*-> rhs -> ... -> rhsNode.root`
where step `-*->` is justified by `proof` (or `proof.symm` if `flipped := true`)
-/
invertTrans lhs
setENode lhs { lhsNode with
target? := rhs
proof? := proof
flipped
}
-- TODO: Remove parents from congruence table
-- TODO: set propagateBool
updateRoots lhs rhsNode.root true -- TODO
-- TODO: Reinsert parents into congruence table
setENode lhsNode.root { lhsRoot with
next := rhsRoot.next
}
setENode rhsNode.root { rhsRoot with
next := lhsRoot.next
size := rhsRoot.size + lhsRoot.size
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
}
-- TODO: copy parentst from lhsRoot parents to rhsRoot parents
updateRoots (lhs : Expr) (rootNew : Expr) (_propagateBool : Bool) : GoalM Unit := do
let rec loop (e : Expr) : GoalM Unit := do
-- TODO: propagateBool
let some n ← getENode? e | unreachable!
setENode e { n with root := rootNew }
if isSameExpr lhs n.next then return ()
loop n.next
loop lhs
partial def addEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
addEqStep lhs rhs proof isHEq
processTodo
where
processTodo : GoalM Unit := do
if (← get).inconsistent then
modify fun s => { s with newEqs := #[] }
return ()
let some { lhs, rhs, proof, isHEq } := (← get).newEqs.back? | return ()
addEqStep lhs rhs proof isHEq
processTodo
def addEq (lhs rhs proof : Expr) : GoalM Unit := do
addEqCore lhs rhs proof false
def addHEq (lhs rhs proof : Expr) : GoalM Unit := do
addEqCore lhs rhs proof true
end Lean.Meta.Grind

View file

@ -15,6 +15,7 @@ import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Injection
import Lean.Meta.Tactic.Grind.Core
namespace Lean.Meta.Grind
namespace Preprocessor
@ -29,6 +30,7 @@ structure Context where
structure State where
simpStats : Simp.Stats := {}
goals : PArray Goal := {}
deriving Inhabited
abbrev PreM := ReaderT Context $ StateRefT State GrindM
@ -43,18 +45,13 @@ def PreM.run (x : PreM α) : GrindM α := do
}
x { simp, simprocs } |>.run' {}
def simp (e : Expr) : PreM Simp.Result := do
def simp (_goal : Goal) (e : Expr) : PreM Simp.Result := do
-- TODO: use `goal` state in the simplifier
let simpStats := (← get).simpStats
let (r, simpStats) ← Meta.simp e (← read).simp (← read).simprocs (stats := simpStats)
modify fun s => { s with simpStats }
return r
def simpHyp? (mvarId : MVarId) (fvarId : FVarId) : PreM (Option (FVarId × MVarId)) := do
let simpStats := (← get).simpStats
let (result, simpStats) ← simpLocalDecl mvarId fvarId (← read).simp (← read).simprocs (stats := simpStats)
modify fun s => { s with simpStats }
return result
inductive IntroResult where
| done
| newHyp (fvarId : FVarId) (goal : Goal)
@ -72,7 +69,7 @@ def introNext (goal : Goal) : PreM IntroResult := do
else
let tag ← goal.mvarId.getTag
let q := target.bindingBody!
let r ← simp p
let r ← simp goal p
let p' := r.expr
let p' ← canon p'
let p' ← shareCommon p'
@ -105,7 +102,7 @@ def introNext (goal : Goal) : PreM IntroResult := do
return .done
def pushResult (goal : Goal) : PreM Unit :=
modifyThe Grind.State fun s => { s with goals := s.goals.push goal }
modify fun s => { s with goals := s.goals.push goal }
def isCasesCandidate (fvarId : FVarId) : MetaM Bool := do
let .const declName _ := (← fvarId.getType).getAppFn | return false
@ -124,42 +121,47 @@ def applyInjection? (goal : Goal) (fvarId : FVarId) : MetaM (Option Goal) := do
else
return none
partial def preprocess (goal : Goal) : PreM Unit := do
partial def loop (goal : Goal) : PreM Unit := do
match (← introNext goal) with
| .done =>
if let some mvarId ← goal.mvarId.byContra? then
preprocess { goal with mvarId }
loop { goal with mvarId }
else
pushResult goal
| .newHyp fvarId goal =>
if let some goals ← applyCases? goal fvarId then
goals.forM preprocess
goals.forM loop
else if let some goal ← applyInjection? goal fvarId then
preprocess goal
loop goal
else
let clause ← goal.mvarId.withContext do mkInputClause fvarId
preprocess { goal with clauses := goal.clauses.push clause }
loop { goal with clauses := goal.clauses.push clause }
| .newDepHyp goal =>
preprocess goal
loop goal
| .newLocal fvarId goal =>
if let some goals ← applyCases? goal fvarId then
goals.forM preprocess
goals.forM loop
else
preprocess goal
loop goal
def preprocess (mvarId : MVarId) : PreM State := do
loop (← mkGoal mvarId)
get
end Preprocessor
open Preprocessor
partial def main (mvarId : MVarId) (mainDeclName : Name) : MetaM Grind.State := do
partial def main (mvarId : MVarId) (mainDeclName : Name) : MetaM (List MVarId) := do
mvarId.ensureProp
mvarId.ensureNoMVar
let mvarId ← mvarId.clearAuxDecls
let mvarId ← mvarId.revertAll
mvarId.ensureNoMVar
let mvarId ← mvarId.abstractNestedProofs mainDeclName
let mvarId ← mvarId.unfoldReducible
let mvarId ← mvarId.betaReduce
let s ← (preprocess { mvarId } *> getThe Grind.State) |>.run |>.run mainDeclName
return s
let s ← preprocess mvarId |>.run |>.run mainDeclName
return s.goals.toList.map (·.mvarId)
end Lean.Meta.Grind

View file

@ -11,6 +11,47 @@ import Lean.Meta.Canonicalizer
import Lean.Meta.Tactic.Util
namespace Lean.Meta.Grind
structure Context where
mainDeclName : Name
structure State where
canon : Canonicalizer.State := {}
/-- `ShareCommon` (aka `Hashconsing`) state. -/
scState : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCommon.State.mk _
/-- Next index for creating auxiliary theorems. -/
nextThmIdx : Nat := 1
abbrev GrindM := ReaderT Context $ StateRefT State MetaM
@[inline] def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α :=
x { mainDeclName } |>.run' {}
def abstractNestedProofs (e : Expr) : GrindM Expr := do
let nextIdx := (← get).nextThmIdx
let (e, s') ← AbstractNestedProofs.visit e |>.run { baseName := (← read).mainDeclName } |>.run |>.run { nextIdx }
modify fun s => { s with nextThmIdx := s'.nextIdx }
return e
def shareCommon (e : Expr) : GrindM Expr := do
modifyGet fun { canon, scState, nextThmIdx } =>
let (e, scState) := ShareCommon.State.shareCommon scState e
(e, { canon, scState, nextThmIdx })
def canon (e : Expr) : GrindM Expr := do
let canonS ← modifyGet fun s => (s.canon, { s with canon := {} })
let (e, canonS) ← Canonicalizer.CanonM.run (canonRec e) (s := canonS)
modify fun s => { s with canon := canonS }
return e
where
canonRec (e : Expr) : CanonM Expr := do
let post (e : Expr) : CanonM TransformStep := do
if e.isApp then
return .done (← Meta.canon e)
else
return .done e
transform e post
/--
Stores information for a node in the egraph.
Each internalized expression `e` has an `ENode` associated with it.
@ -43,6 +84,9 @@ structure ENode where
on heterogeneous equality.
-/
heqProofs : Bool := false
generation : Nat := 0
/-- Modification time -/
mt : Nat := 0
-- TODO: see Lean 3 implementation
structure Clause where
@ -53,58 +97,57 @@ structure Clause where
def mkInputClause (fvarId : FVarId) : MetaM Clause :=
return { expr := (← fvarId.getType), proof := mkFVar fvarId }
structure Goal where
mvarId : MVarId
clauses : PArray Clause := {}
enodes : PHashMap UInt64 ENode := {}
-- TODO: occurrences for propagation, etc
deriving Inhabited
structure NewEq where
lhs : Expr
rhs : Expr
proof : Expr
isHEq : Bool
def mkGoal (mvarId : MVarId) : Goal :=
{ mvarId }
structure Goal where
mvarId : MVarId
clauses : PArray Clause := {}
enodes : PHashMap USize ENode := {}
newEqs : Array NewEq := #[]
/-- `inconsistent := true` if `ENode`s for `True` and `False` are in the same equivalence class. -/
inconsistent : Bool := false
/-- Goal modification time. -/
gmt : Nat := 0
deriving Inhabited
def Goal.admit (goal : Goal) : MetaM Unit :=
goal.mvarId.admit
structure Context where
mainDeclName : Name
abbrev GoalM := StateRefT Goal GrindM
structure State where
canon : Canonicalizer.State := {}
/-- `ShareCommon` (aka `Hashconsing`) state. -/
scState : ShareCommon.State.{0} ShareCommon.objectFactory := ShareCommon.State.mk _
/-- Next index for creating auxiliary theorems. -/
nextThmIdx : Nat := 1
goals : PArray Goal := {}
@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
StateRefT'.run x goal
abbrev GrindM := ReaderT Context $ StateRefT State MetaM
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
StateRefT'.run' (x *> get) goal
def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α :=
x { mainDeclName } |>.run' {}
/--
Returns `some n` if `e` has already been "internalized" into the
Otherwise, returns `none`s.
-/
def getENode? (e : Expr) : GoalM (Option ENode) :=
return (← get).enodes.find? (unsafe ptrAddrUnsafe e)
def abstractNestedProofs (e : Expr) : GrindM Expr := do
let nextIdx := (← get).nextThmIdx
let (e, s') ← AbstractNestedProofs.visit e |>.run { baseName := (← read).mainDeclName } |>.run |>.run { nextIdx }
modify fun s => { s with nextThmIdx := s'.nextIdx }
return e
def setENode (e : Expr) (n : ENode) : GoalM Unit :=
modify fun s => { s with enodes := s.enodes.insert (unsafe ptrAddrUnsafe e) n }
def shareCommon (e : Expr) : GrindM Expr := do
modifyGet fun { canon, scState, nextThmIdx, goals } =>
let (e, scState) := ShareCommon.State.shareCommon scState e
(e, { canon, scState, nextThmIdx, goals })
def mkENodeCore (e : Expr) (interpreted ctor : Bool) (generation : Nat) : GoalM Unit := do
setENode e {
next := e, root := e, cgRoot := e, size := 1
flipped := false
heqProofs := false
hasLambdas := e.isLambda
mt := (← get).gmt
interpreted, ctor, generation
}
def canon (e : Expr) : GrindM Expr := do
let canonS ← modifyGet fun s => (s.canon, { s with canon := {} })
let (e, canonS) ← Canonicalizer.CanonM.run (canonRec e) (s := canonS)
modify fun s => { s with canon := canonS }
return e
where
canonRec (e : Expr) : CanonM Expr := do
let post (e : Expr) : CanonM TransformStep := do
if e.isApp then
return .done (← Meta.canon e)
else
return .done e
transform e post
def mkGoal (mvarId : MVarId) : GrindM Goal := do
GoalM.run' { mvarId } do
mkENodeCore (← shareCommon (mkConst ``True)) (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore (← shareCommon (mkConst ``False)) (interpreted := true) (ctor := false) (generation := 0)
end Lean.Meta.Grind

View file

@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.AbstractNestedProofs
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Clear
namespace Lean.Meta.Grind
/--
@ -66,7 +67,7 @@ def _root_.Lean.MVarId.betaReduce (mvarId : MVarId) : MetaM MVarId :=
If the target is not `False`, apply `byContradiction`.
-/
def _root_.Lean.MVarId.byContra? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withContext do
mvarId.checkNotAssigned `grind
mvarId.checkNotAssigned `grind.by_contra
let target ← mvarId.getType
if target.isFalse then return none
let targetNew ← mkArrow (mkNot target) (mkConst ``False)
@ -75,4 +76,24 @@ def _root_.Lean.MVarId.byContra? (mvarId : MVarId) : MetaM (Option MVarId) := mv
mvarId.assign <| mkApp2 (mkConst ``Classical.byContradiction) target mvarNew
return mvarNew.mvarId!
/--
Clear auxiliary decls used to encode recursive declarations.
`grind` eliminates them to ensure they are not accidentaly used by its proof automation.
-/
def _root_.Lean.MVarId.clearAuxDecls (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
mvarId.checkNotAssigned `grind.clear_aux_decls
let mut toClear := []
for localDecl in (← getLCtx) do
if localDecl.isAuxDecl then
toClear := localDecl.fvarId :: toClear
if toClear.isEmpty then
return mvarId
let mut mvarId := mvarId
for fvarId in toClear do
try
mvarId ← mvarId.clear fvarId
catch _ =>
throwTacticEx `grind.clear_aux_decls mvarId "failed to clear local auxiliary declaration"
return mvarId
end Lean.Meta.Grind

View file

@ -4,13 +4,29 @@ open Lean Meta Elab Tactic Grind in
elab "grind_pre" : tactic => do
let declName := (← Term.getDeclName?).getD `_main
liftMetaTactic fun mvarId => do
let result ← Meta.Grind.main mvarId declName
return result.goals.map (·.mvarId) |>.toList
Meta.Grind.main mvarId declName
abbrev f (a : α) := a
attribute [grind_cases] And Or
/--
warning: declaration uses 'sorry'
---
info: a b c : Bool
p q : Prop
left✝ : a = true
right✝ : b = true c = true
left : p
right : q
x✝ : b = false a = false
⊢ False
-/
#guard_msgs in
theorem ex (h : (f a && (b || f (f c))) = true) (h' : p ∧ q) : b && a := by
grind_pre
trace_state
all_goals sorry
open Lean.Grind.Eager in
/--
warning: declaration uses 'sorry'
---
@ -51,11 +67,12 @@ h : a = false
⊢ False
-/
#guard_msgs in
theorem ex (h : (f a && (b || f (f c))) = true) (h' : p ∧ q) : b && a := by
theorem ex2 (h : (f a && (b || f (f c))) = true) (h' : p ∧ q) : b && a := by
grind_pre
trace_state
all_goals sorry
def g (i : Nat) (j : Nat) (_ : i > j := by omega) := i + j
example (i j : Nat) (h : i + 1 > j + 1) : g (i+1) j = f ((fun x => x) i) + f j + 1 := by
@ -65,27 +82,29 @@ example (i j : Nat) (h : i + 1 > j + 1) : g (i+1) j = f ((fun x => x) i) + f j +
guard_hyp hn : ¬g (i + 1) j _ = i + j + 1
simp_arith [g] at hn
structure Point where
x : Nat
y : Int
/--
warning: declaration uses 'sorry'
---
info: α✝ : Type u_1
β✝ : Type u_2
a₁ : α✝ × β✝
a₂ : α✝
a₃ : β✝
as : List (α✝ × β✝)
b₁ : α✝ × β✝
b₂ : α✝
b₃ : β✝
bs : List (α✝ × β✝)
info: a₁ : Point
a₂ : Nat
a₃ : Int
as : List Point
b₁ : Point
bs : List Point
b₂ : Nat
b₃ : Int
head_eq : a₁ = b₁
fst_eq : a₂ = b₂
snd_eq : a₃ = b₃
x_eq : a₂ = b₂
y_eq : a₃ = b₃
tail_eq : as = bs
⊢ False
-/
#guard_msgs in
theorem ex2 (h : a₁ :: (a₂, a₃) :: as = b₁ :: (b₂, b₃) :: bs) : False := by
theorem ex3 (h : a₁ :: { x := a₂, y := a₃ : Point } :: as = b₁ :: { x := b₂, y := b₃} :: bs) : False := by
grind_pre
trace_state
sorry