feat: projections in grind (#6465)
This PR adds support for projection functions to the (WIP) `grind` tactic.
This commit is contained in:
parent
f545df9922
commit
fe45ddd610
9 changed files with 238 additions and 136 deletions
|
|
@ -10,56 +10,10 @@ import Lean.Meta.Tactic.Grind.Types
|
|||
import Lean.Meta.Tactic.Grind.Inv
|
||||
import Lean.Meta.Tactic.Grind.PP
|
||||
import Lean.Meta.Tactic.Grind.Ctor
|
||||
import Lean.Meta.Tactic.Grind.Internalize
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-- Adds `e` to congruence table. -/
|
||||
private def addCongrTable (e : Expr) : GoalM Unit := do
|
||||
if let some { e := e' } := (← get).congrTable.find? { e } then
|
||||
trace[grind.congr] "{e} = {e'}"
|
||||
pushEqHEq e e' congrPlaceholderProof
|
||||
-- TODO: we must check whether the types of the functions are the same
|
||||
-- TODO: update cgRoot for `e`
|
||||
else
|
||||
modify fun s => { s with congrTable := s.congrTable.insert { e } }
|
||||
|
||||
partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
|
||||
if (← alreadyInternalized e) then return ()
|
||||
match e with
|
||||
| .bvar .. => unreachable!
|
||||
| .sort .. => return ()
|
||||
| .fvar .. | .letE .. | .lam .. | .forallE .. =>
|
||||
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
|
||||
| .lit .. | .const .. =>
|
||||
mkENode e generation
|
||||
| .mvar ..
|
||||
| .mdata ..
|
||||
| .proj .. =>
|
||||
trace[grind.issues] "unexpected term during internalization{indentExpr e}"
|
||||
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
|
||||
| .app .. =>
|
||||
if (← isLitValue e) then
|
||||
-- We do not want to internalize the components of a literal value.
|
||||
mkENode e generation
|
||||
else e.withApp fun f args => do
|
||||
if f.isConstOf ``Lean.Grind.nestedProof && args.size == 2 then
|
||||
-- We only internalize the proposition. We can skip the proof because of
|
||||
-- proof irrelevance
|
||||
let c := args[0]!
|
||||
internalize c generation
|
||||
registerParent e c
|
||||
else
|
||||
unless f.isConst do
|
||||
internalize f generation
|
||||
registerParent e f
|
||||
for h : i in [: args.size] do
|
||||
let arg := args[i]
|
||||
internalize arg generation
|
||||
registerParent e arg
|
||||
mkENode e generation
|
||||
addCongrTable e
|
||||
propagateUp e
|
||||
|
||||
/--
|
||||
The fields `target?` and `proof?` in `e`'s `ENode` are encoding a transitivity proof
|
||||
from `e` to the root of the equivalence class
|
||||
|
|
|
|||
60
src/Lean/Meta/Tactic/Grind/Internalize.lean
Normal file
60
src/Lean/Meta/Tactic/Grind/Internalize.lean
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
/-
|
||||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Grind.Util
|
||||
import Lean.Meta.LitValues
|
||||
import Lean.Meta.Tactic.Grind.Types
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-- Adds `e` to congruence table. -/
|
||||
def addCongrTable (e : Expr) : GoalM Unit := do
|
||||
if let some { e := e' } := (← get).congrTable.find? { e } then
|
||||
trace[grind.congr] "{e} = {e'}"
|
||||
pushEqHEq e e' congrPlaceholderProof
|
||||
-- TODO: we must check whether the types of the functions are the same
|
||||
-- TODO: update cgRoot for `e`
|
||||
else
|
||||
modify fun s => { s with congrTable := s.congrTable.insert { e } }
|
||||
|
||||
partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
|
||||
if (← alreadyInternalized e) then return ()
|
||||
match e with
|
||||
| .bvar .. => unreachable!
|
||||
| .sort .. => return ()
|
||||
| .fvar .. | .letE .. | .lam .. | .forallE .. =>
|
||||
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
|
||||
| .lit .. | .const .. =>
|
||||
mkENode e generation
|
||||
| .mvar ..
|
||||
| .mdata ..
|
||||
| .proj .. =>
|
||||
trace[grind.issues] "unexpected term during internalization{indentExpr e}"
|
||||
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
|
||||
| .app .. =>
|
||||
if (← isLitValue e) then
|
||||
-- We do not want to internalize the components of a literal value.
|
||||
mkENode e generation
|
||||
else e.withApp fun f args => do
|
||||
if f.isConstOf ``Lean.Grind.nestedProof && args.size == 2 then
|
||||
-- We only internalize the proposition. We can skip the proof because of
|
||||
-- proof irrelevance
|
||||
let c := args[0]!
|
||||
internalize c generation
|
||||
registerParent e c
|
||||
else
|
||||
unless f.isConst do
|
||||
internalize f generation
|
||||
registerParent e f
|
||||
for h : i in [: args.size] do
|
||||
let arg := args[i]
|
||||
internalize arg generation
|
||||
registerParent e arg
|
||||
mkENode e generation
|
||||
addCongrTable e
|
||||
propagateUp e
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
@ -17,6 +17,7 @@ import Lean.Meta.Tactic.Grind.Cases
|
|||
import Lean.Meta.Tactic.Grind.Injection
|
||||
import Lean.Meta.Tactic.Grind.Core
|
||||
import Lean.Meta.Tactic.Grind.Simp
|
||||
import Lean.Meta.Tactic.Grind.Run
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
namespace Preprocessor
|
||||
|
|
|
|||
35
src/Lean/Meta/Tactic/Grind/Proj.lean
Normal file
35
src/Lean/Meta/Tactic/Grind/Proj.lean
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
/-
|
||||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.ProjFns
|
||||
import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.Internalize
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/--
|
||||
If `parent` is a projection-application `proj_i c`,
|
||||
check whether the root of the equivalence class containing `c` is a constructor-application `ctor ... a_i ...`.
|
||||
If so, internalize the term `proj_i (ctor ... a_i ...)` and add the equality `proj_i (ctor ... a_i ...) = a_i`.
|
||||
-/
|
||||
def propagateProjEq (parent : Expr) : GoalM Unit := do
|
||||
let .const declName _ := parent.getAppFn | return ()
|
||||
let some info ← getProjectionFnInfo? declName | return ()
|
||||
unless info.numParams + 1 == parent.getAppNumArgs do return ()
|
||||
let arg := parent.appArg!
|
||||
let ctor ← getRoot arg
|
||||
unless ctor.isAppOf info.ctorName do return ()
|
||||
if isSameExpr arg ctor then
|
||||
let idx := info.numParams + info.i
|
||||
unless idx < ctor.getAppNumArgs do return ()
|
||||
let v := ctor.getArg! idx
|
||||
pushEq parent v (← mkEqRefl v)
|
||||
else
|
||||
let newProj := mkApp parent.appFn! ctor
|
||||
let newProj ← shareCommon newProj
|
||||
internalize newProj (← getGeneration parent)
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
|||
prelude
|
||||
import Init.Grind
|
||||
import Lean.Meta.Tactic.Grind.Proof
|
||||
import Lean.Meta.Tactic.Grind.PropagatorAttr
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
|
|
|
|||
61
src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean
Normal file
61
src/Lean/Meta/Tactic/Grind/PropagatorAttr.lean
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
/-
|
||||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Grind
|
||||
import Lean.Meta.Tactic.Grind.Proof
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/-- Builtin propagators. -/
|
||||
structure BuiltinPropagators where
|
||||
up : Std.HashMap Name Propagator := {}
|
||||
down : Std.HashMap Name Propagator := {}
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize builtinPropagatorsRef : IO.Ref BuiltinPropagators ← IO.mkRef {}
|
||||
|
||||
private def registerBuiltinPropagatorCore (declName : Name) (up : Bool) (proc : Propagator) : IO Unit := do
|
||||
unless (← initializing) do
|
||||
throw (IO.userError s!"invalid builtin `grind` propagator declaration, it can only be registered during initialization")
|
||||
if up then
|
||||
if (← builtinPropagatorsRef.get).up.contains declName then
|
||||
throw (IO.userError s!"invalid builtin `grind` upward propagator `{declName}`, it has already been declared")
|
||||
builtinPropagatorsRef.modify fun { up, down } => { up := up.insert declName proc, down }
|
||||
else
|
||||
if (← builtinPropagatorsRef.get).down.contains declName then
|
||||
throw (IO.userError s!"invalid builtin `grind` downward propagator `{declName}`, it has already been declared")
|
||||
builtinPropagatorsRef.modify fun { up, down } => { up, down := down.insert declName proc }
|
||||
|
||||
def registerBuiltinUpwardPropagator (declName : Name) (proc : Propagator) : IO Unit :=
|
||||
registerBuiltinPropagatorCore declName true proc
|
||||
|
||||
def registerBuiltinDownwardPropagator (declName : Name) (proc : Propagator) : IO Unit :=
|
||||
registerBuiltinPropagatorCore declName false proc
|
||||
|
||||
private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do
|
||||
let go : MetaM Unit := do
|
||||
let up := stx[1].getKind == ``Lean.Parser.Tactic.simpPost
|
||||
let addDeclName := if up then
|
||||
``registerBuiltinUpwardPropagator
|
||||
else
|
||||
``registerBuiltinDownwardPropagator
|
||||
let declName ← resolveGlobalConstNoOverload stx[2]
|
||||
let val := mkAppN (mkConst addDeclName) #[toExpr declName, mkConst propagatorName]
|
||||
let initDeclName ← mkFreshUserName (propagatorName ++ `declare)
|
||||
declareBuiltin initDeclName val
|
||||
go.run' {}
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinAttribute {
|
||||
ref := by exact decl_name%
|
||||
name := `grindPropagatorBuiltinAttr
|
||||
descr := "Builtin `grind` propagator procedure"
|
||||
applicationTime := AttributeApplicationTime.afterCompilation
|
||||
erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]"
|
||||
add := fun declName stx _ => addBuiltin declName stx
|
||||
}
|
||||
|
||||
end Lean.Meta.Grind
|
||||
53
src/Lean/Meta/Tactic/Grind/Run.lean
Normal file
53
src/Lean/Meta/Tactic/Grind/Run.lean
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
/-
|
||||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Init.Grind.Lemmas
|
||||
import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.PropagatorAttr
|
||||
import Lean.Meta.Tactic.Grind.Proj
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
def mkMethods : CoreM Methods := do
|
||||
let builtinPropagators ← builtinPropagatorsRef.get
|
||||
return {
|
||||
propagateUp := fun e => do
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
propagateProjEq e
|
||||
if let some prop := builtinPropagators.up[declName]? then
|
||||
prop e
|
||||
propagateDown := fun e => do
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
if let some prop := builtinPropagators.down[declName]? then
|
||||
prop e
|
||||
}
|
||||
|
||||
def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
|
||||
let scState := ShareCommon.State.mk _
|
||||
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
|
||||
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
|
||||
let thms ← grindNormExt.getTheorems
|
||||
let simprocs := #[(← grindNormSimprocExt.getSimprocs)]
|
||||
let simp ← Simp.mkContext
|
||||
(config := { arith := true })
|
||||
(simpTheorems := #[thms])
|
||||
(congrTheorems := (← getSimpCongrTheorems))
|
||||
x (← mkMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }
|
||||
|
||||
@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
|
||||
goal.mvarId.withContext do StateRefT'.run x goal
|
||||
|
||||
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
|
||||
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal
|
||||
|
||||
def mkGoal (mvarId : MVarId) : GrindM Goal := do
|
||||
let trueExpr ← getTrueExpr
|
||||
let falseExpr ← getFalseExpr
|
||||
GoalM.run' { mvarId } do
|
||||
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
@ -296,6 +296,10 @@ def getENode (e : Expr) : GoalM ENode := do
|
|||
| throwError "internal `grind` error, term has not been internalized{indentExpr e}"
|
||||
return n
|
||||
|
||||
/-- Returns the generation of the given term. Is assumes it has been internalized -/
|
||||
def getGeneration (e : Expr) : GoalM Nat :=
|
||||
return (← getENode e).generation
|
||||
|
||||
/-- Returns `true` if `e` is in the equivalence class of `True`. -/
|
||||
def isEqTrue (e : Expr) : GoalM Bool := do
|
||||
let n ← getENode e
|
||||
|
|
@ -508,12 +512,12 @@ def forEachEqc (f : ENode → GoalM Unit) : GoalM Unit := do
|
|||
if isSameExpr n.self n.root then
|
||||
f n
|
||||
|
||||
private structure Methods where
|
||||
structure Methods where
|
||||
propagateUp : Propagator := fun _ => return ()
|
||||
propagateDown : Propagator := fun _ => return ()
|
||||
deriving Inhabited
|
||||
|
||||
private def Methods.toMethodsRef (m : Methods) : MethodsRef :=
|
||||
def Methods.toMethodsRef (m : Methods) : MethodsRef :=
|
||||
unsafe unsafeCast m
|
||||
|
||||
private def MethodsRef.toMethods (m : MethodsRef) : Methods :=
|
||||
|
|
@ -528,93 +532,6 @@ def propagateUp (e : Expr) : GoalM Unit := do
|
|||
def propagateDown (e : Expr) : GoalM Unit := do
|
||||
(← getMethods).propagateDown e
|
||||
|
||||
/-- Builtin propagators. -/
|
||||
structure BuiltinPropagators where
|
||||
up : Std.HashMap Name Propagator := {}
|
||||
down : Std.HashMap Name Propagator := {}
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize builtinPropagatorsRef : IO.Ref BuiltinPropagators ← IO.mkRef {}
|
||||
|
||||
private def registerBuiltinPropagatorCore (declName : Name) (up : Bool) (proc : Propagator) : IO Unit := do
|
||||
unless (← initializing) do
|
||||
throw (IO.userError s!"invalid builtin `grind` propagator declaration, it can only be registered during initialization")
|
||||
if up then
|
||||
if (← builtinPropagatorsRef.get).up.contains declName then
|
||||
throw (IO.userError s!"invalid builtin `grind` upward propagator `{declName}`, it has already been declared")
|
||||
builtinPropagatorsRef.modify fun { up, down } => { up := up.insert declName proc, down }
|
||||
else
|
||||
if (← builtinPropagatorsRef.get).down.contains declName then
|
||||
throw (IO.userError s!"invalid builtin `grind` downward propagator `{declName}`, it has already been declared")
|
||||
builtinPropagatorsRef.modify fun { up, down } => { up, down := down.insert declName proc }
|
||||
|
||||
def registerBuiltinUpwardPropagator (declName : Name) (proc : Propagator) : IO Unit :=
|
||||
registerBuiltinPropagatorCore declName true proc
|
||||
|
||||
def registerBuiltinDownwardPropagator (declName : Name) (proc : Propagator) : IO Unit :=
|
||||
registerBuiltinPropagatorCore declName false proc
|
||||
|
||||
private def addBuiltin (propagatorName : Name) (stx : Syntax) : AttrM Unit := do
|
||||
let go : MetaM Unit := do
|
||||
let up := stx[1].getKind == ``Lean.Parser.Tactic.simpPost
|
||||
let addDeclName := if up then
|
||||
``registerBuiltinUpwardPropagator
|
||||
else
|
||||
``registerBuiltinDownwardPropagator
|
||||
let declName ← resolveGlobalConstNoOverload stx[2]
|
||||
let val := mkAppN (mkConst addDeclName) #[toExpr declName, mkConst propagatorName]
|
||||
let initDeclName ← mkFreshUserName (propagatorName ++ `declare)
|
||||
declareBuiltin initDeclName val
|
||||
go.run' {}
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinAttribute {
|
||||
ref := by exact decl_name%
|
||||
name := `grindPropagatorBuiltinAttr
|
||||
descr := "Builtin `grind` propagator procedure"
|
||||
applicationTime := AttributeApplicationTime.afterCompilation
|
||||
erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]"
|
||||
add := fun declName stx _ => addBuiltin declName stx
|
||||
}
|
||||
|
||||
def mkMethods : CoreM Methods := do
|
||||
let builtinPropagators ← builtinPropagatorsRef.get
|
||||
return {
|
||||
propagateUp := fun e => do
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
if let some prop := builtinPropagators.up[declName]? then
|
||||
prop e
|
||||
propagateDown := fun e => do
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
if let some prop := builtinPropagators.down[declName]? then
|
||||
prop e
|
||||
}
|
||||
|
||||
def GrindM.run (x : GrindM α) (mainDeclName : Name) : MetaM α := do
|
||||
let scState := ShareCommon.State.mk _
|
||||
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
|
||||
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
|
||||
let thms ← grindNormExt.getTheorems
|
||||
let simprocs := #[(← grindNormSimprocExt.getSimprocs)]
|
||||
let simp ← Simp.mkContext
|
||||
(config := { arith := true })
|
||||
(simpTheorems := #[thms])
|
||||
(congrTheorems := (← getSimpCongrTheorems))
|
||||
x (← mkMethods).toMethodsRef { mainDeclName, simprocs, simp } |>.run' { scState, trueExpr, falseExpr }
|
||||
|
||||
@[inline] def GoalM.run (goal : Goal) (x : GoalM α) : GrindM (α × Goal) :=
|
||||
goal.mvarId.withContext do StateRefT'.run x goal
|
||||
|
||||
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
|
||||
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal
|
||||
|
||||
def mkGoal (mvarId : MVarId) : GrindM Goal := do
|
||||
let trueExpr ← getTrueExpr
|
||||
let falseExpr ← getFalseExpr
|
||||
GoalM.run' { mvarId } do
|
||||
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
|
||||
/-- Returns expressions in the given expression equivalence class. -/
|
||||
partial def getEqc (e : Expr) : GoalM (List Expr) :=
|
||||
go e e []
|
||||
|
|
|
|||
|
|
@ -55,3 +55,23 @@ example (a b c : BitVec 32) : a = c → a = 1#32 → c = 2#32 → c = b → Fals
|
|||
|
||||
example (a b c : UInt32) : a = c → a = 1 → c = 200 → c = b → False := by
|
||||
grind
|
||||
|
||||
structure Boo (α : Type) where
|
||||
a : α
|
||||
b : α
|
||||
c : α
|
||||
|
||||
example (a b d : Nat) (f : Nat → Boo Nat) : (f d).1 ≠ a → f d = ⟨b, v₁, v₂⟩ → b = a → False := by
|
||||
grind
|
||||
|
||||
def ex (a b c d : Nat) (f : Nat → Boo Nat) : (f d).2 ≠ a → f d = ⟨b, c, v₂⟩ → c = a → False := by
|
||||
grind
|
||||
|
||||
example (a b c : Nat) (f : Nat → Nat) : { a := f b, c, b := 4 : Boo Nat }.1 ≠ f a → f b = f c → a = c → False := by
|
||||
grind
|
||||
|
||||
example (a b c : Nat) (f : Nat → Nat) : p = { a := f b, c, b := 4 : Boo Nat } → p.1 ≠ f a → f b = f c → a = c → False := by
|
||||
grind
|
||||
|
||||
example (a b c : Nat) (f : Nat → Nat) : p.1 ≠ f a → p = { a := f b, c, b := 4 : Boo Nat } → f b = f c → a = c → False := by
|
||||
grind
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue