feat: associative operator detection in grind (#10105)
This PR adds support for detecting associative operators in `grind`. The new AC module also detects whether the operator is commutative, idempotent, and whether it has a neutral element. The information is cached.
This commit is contained in:
parent
cc5ff2afb1
commit
9be2eab93d
11 changed files with 250 additions and 5 deletions
|
|
@ -111,6 +111,10 @@ structure Config where
|
|||
-/
|
||||
cutsat := true
|
||||
/--
|
||||
When `true` (default: `true`), uses procedure for handling associative (and commutative) operators.
|
||||
-/
|
||||
ac := true
|
||||
/--
|
||||
Maximum exponent eagerly evaluated while computing bounds for `ToInt` and
|
||||
the characteristic of a ring.
|
||||
-/
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ public import Lean.Meta.Tactic.Grind.Lookahead
|
|||
public import Lean.Meta.Tactic.Grind.LawfulEqCmp
|
||||
public import Lean.Meta.Tactic.Grind.ReflCmp
|
||||
public import Lean.Meta.Tactic.Grind.SynthInstance
|
||||
public import Lean.Meta.Tactic.Grind.AC
|
||||
|
||||
public section
|
||||
|
||||
|
|
|
|||
19
src/Lean/Meta/Tactic/Grind/AC.lean
Normal file
19
src/Lean/Meta/Tactic/Grind/AC.lean
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.AC.Types
|
||||
public import Lean.Meta.Tactic.Grind.AC.Util
|
||||
public import Lean.Meta.Tactic.Grind.AC.Var
|
||||
public import Lean.Meta.Tactic.Grind.AC.Internalize
|
||||
public section
|
||||
namespace Lean
|
||||
builtin_initialize registerTraceClass `grind.ac
|
||||
builtin_initialize registerTraceClass `grind.ac.assert
|
||||
builtin_initialize registerTraceClass `grind.ac.internalize
|
||||
|
||||
builtin_initialize registerTraceClass `grind.debug.ac.op
|
||||
end Lean
|
||||
27
src/Lean/Meta/Tactic/Grind/AC/Internalize.lean
Normal file
27
src/Lean/Meta/Tactic/Grind/AC/Internalize.lean
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
public import Lean.Meta.Tactic.Grind.AC.Util
|
||||
public section
|
||||
namespace Lean.Meta.Grind.AC
|
||||
|
||||
private def isParentSameOpApp (parent? : Option Expr) (op : Expr) : GoalM Bool := do
|
||||
let some e := parent? | return false
|
||||
unless e.isApp && e.appFn!.isApp do return false
|
||||
return isSameExpr e.appFn!.appFn! op
|
||||
|
||||
def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
||||
unless (← getConfig).ac do return ()
|
||||
unless e.isApp && e.appFn!.isApp do return ()
|
||||
let op := e.appFn!.appFn!
|
||||
let some id ← getOpId? op | return ()
|
||||
if (← isParentSameOpApp parent? op) then return ()
|
||||
trace[grind.ac.internalize] "[{id}] {e}"
|
||||
-- TODO: internalize `e`
|
||||
|
||||
end Lean.Meta.Grind.AC
|
||||
54
src/Lean/Meta/Tactic/Grind/AC/Types.lean
Normal file
54
src/Lean/Meta/Tactic/Grind/AC/Types.lean
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Init.Core
|
||||
public import Init.Grind.AC
|
||||
public import Std.Data.HashMap
|
||||
public import Lean.Expr
|
||||
public import Lean.Data.PersistentArray
|
||||
public import Lean.Meta.Tactic.Grind.ExprPtr
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Grind.AC
|
||||
open Lean.Grind.AC
|
||||
|
||||
structure Struct where
|
||||
id : Nat
|
||||
type : Expr
|
||||
/-- Cached `getLevel type` -/
|
||||
u : Level
|
||||
op : Expr
|
||||
neutral? : Option Expr
|
||||
assocInst : Expr
|
||||
idempotentInst? : Option Expr
|
||||
commInst? : Option Expr
|
||||
neutralInst? : Option Expr
|
||||
/--
|
||||
Mapping from variables to their denotations.
|
||||
Remark each variable can be in only one ring.
|
||||
-/
|
||||
vars : PArray Expr := {}
|
||||
/-- Mapping from `Expr` to a variable representing it. -/
|
||||
varMap : PHashMap ExprPtr Var := {}
|
||||
deriving Inhabited
|
||||
|
||||
/-- State for all associative operators detected by `grind`. -/
|
||||
structure State where
|
||||
/--
|
||||
Structures/operators detected.
|
||||
We expect to find a small number of associative operators in a given goal. Thus, using `Array` is fine here.
|
||||
-/
|
||||
structs : Array Struct := {}
|
||||
/--
|
||||
Mapping from operators to its "operator id". We cache failures using `none`.
|
||||
`opIdOf[op]` is `some id`, then `id < structs.size`. -/
|
||||
opIdOf : PHashMap ExprPtr (Option Nat) := {}
|
||||
-- Remark: a term may be argument of different associative operators.
|
||||
-- TODO: add mappings
|
||||
deriving Inhabited
|
||||
|
||||
end Lean.Meta.Grind.AC
|
||||
119
src/Lean/Meta/Tactic/Grind/AC/Util.lean
Normal file
119
src/Lean/Meta/Tactic/Grind/AC/Util.lean
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
public import Lean.Meta.Tactic.Grind.ProveEq
|
||||
public import Lean.Meta.Tactic.Grind.SynthInstance
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Grind.AC
|
||||
|
||||
def get' : GoalM State := do
|
||||
return (← get).ac
|
||||
|
||||
@[inline] def modify' (f : State → State) : GoalM Unit := do
|
||||
modify fun s => { s with ac := f s.ac }
|
||||
|
||||
structure ACM.Context where
|
||||
opId : Nat
|
||||
|
||||
class MonadGetStruct (m : Type → Type) where
|
||||
getStruct : m Struct
|
||||
|
||||
export MonadGetStruct (getStruct)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadGetStruct m] : MonadGetStruct n where
|
||||
getStruct := liftM (getStruct : m Struct)
|
||||
|
||||
abbrev ACM := ReaderT ACM.Context GoalM
|
||||
|
||||
abbrev ACM.run (opId : Nat) (x : ACM α) : GoalM α :=
|
||||
x { opId }
|
||||
|
||||
abbrev getOpId : ACM Nat :=
|
||||
return (← read).opId
|
||||
|
||||
protected def ACM.getStruct : ACM Struct := do
|
||||
let s ← get'
|
||||
let opId ← getOpId
|
||||
if h : opId < s.structs.size then
|
||||
return s.structs[opId]
|
||||
else
|
||||
throwError "`grind` internal error, invalid structure id"
|
||||
|
||||
instance : MonadGetStruct ACM where
|
||||
getStruct := ACM.getStruct
|
||||
|
||||
def getOp : ACM Expr :=
|
||||
return (← getStruct).op
|
||||
|
||||
private def notAssoc : Std.HashSet Name :=
|
||||
Std.HashSet.ofList [``Eq, ``And, ``Or, ``Iff, ``getElem, ``OfNat.ofNat, ``ite, ``dite, ``cond, ``LT.lt, ``LE.le]
|
||||
|
||||
/--
|
||||
Returns `true` if `op` is an arithmetic operator supported in other modules.
|
||||
Remark: `f == op.getAppFn!`
|
||||
-/
|
||||
private def isArithOpInOtherModules (op : Expr) (f : Expr) : GoalM Bool := do
|
||||
unless (← getConfig).ring do return false
|
||||
-- Remark: if `ring` is disabled we could check whether `cutsat` is enabled and discard `+` and `-`, but this is just a filter.
|
||||
let .const declName _ := f | return false
|
||||
if declName == ``HAdd.hAdd || declName == ``HMul.hMul || declName == ``HSub.hSub || declName == ``HDiv.hDiv || declName == ``HPow.hPow then
|
||||
if op.getAppNumArgs == 4 then
|
||||
let α := op.appFn!.appFn!.appArg!
|
||||
if (← Arith.CommRing.getRingId? α).isSome then return true
|
||||
if (← Arith.CommRing.getSemiringId? α).isSome then return true
|
||||
return false
|
||||
|
||||
def getOpId? (op : Expr) : GoalM (Option Nat) := do
|
||||
if let some id? := (← get').opIdOf.find? { expr := op } then
|
||||
return id?
|
||||
let id? ← go
|
||||
modify' fun s => { s with opIdOf := s.opIdOf.insert { expr := op } id? }
|
||||
return id?
|
||||
where
|
||||
go : GoalM (Option Nat) := do
|
||||
let f := op.getAppFn
|
||||
if let .const declName _ := f then
|
||||
if notAssoc.contains declName then return none
|
||||
let .forallE _ α b _ ← whnf (← inferType op) | return none
|
||||
if b.hasLooseBVars then return none
|
||||
let .forallE _ α₂ α₃ _ ← whnf b | return none
|
||||
if α₃.hasLooseBVars then return none
|
||||
unless (← isDefEq α α₂) do return none
|
||||
unless (← isDefEq α α₃) do return none
|
||||
if (← isArithOpInOtherModules op f) then return none
|
||||
let u ← getLevel α
|
||||
let assocType := mkApp2 (mkConst ``Std.Associative [u]) α op
|
||||
let some assocInst ← synthInstance? assocType | return none
|
||||
let commType := mkApp2 (mkConst ``Std.Commutative [u]) α op
|
||||
let commInst? ← synthInstance? commType
|
||||
let idempotentType := mkApp2 (mkConst ``Std.IdempotentOp [u]) α op
|
||||
let idempotentInst? ← synthInstance? idempotentType
|
||||
let (neutralInst?, neutral?) ← do
|
||||
let neutral ← mkFreshExprMVar α
|
||||
let identityType := mkApp3 (mkConst ``Std.Identity [u]) α op neutral
|
||||
if let some identityInst ← synthInstance? identityType then
|
||||
let neutral ← instantiateExprMVars neutral
|
||||
let neutral ← preprocessLight neutral
|
||||
internalize neutral (← getGeneration op)
|
||||
pure (some identityInst, some neutral)
|
||||
else
|
||||
pure (none, none)
|
||||
let id := (← get').structs.size
|
||||
modify' fun s => { s with
|
||||
structs := s.structs.push {
|
||||
id, type := α, u, op, neutral?, assocInst, commInst?,
|
||||
idempotentInst?, neutralInst?
|
||||
}}
|
||||
-- TODO: neutral element must be variable 0
|
||||
trace[grind.debug.ac.op] "{op}, comm: {commInst?.isSome}, idempotent: {idempotentInst?.isSome}, neutral?: {neutral?}"
|
||||
return some id
|
||||
|
||||
end Lean.Meta.Grind.AC
|
||||
14
src/Lean/Meta/Tactic/Grind/AC/Var.lean
Normal file
14
src/Lean/Meta/Tactic/Grind/AC/Var.lean
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Util
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.Var
|
||||
|
||||
-- TODO: add mkVar
|
||||
|
||||
end Lean.Meta.Grind.Arith.Var
|
||||
|
|
@ -267,7 +267,6 @@ private def propagateNonlinearPow (x : Var) : GoalM Bool := do
|
|||
pure (kb.toNat, some cb)
|
||||
else
|
||||
return false
|
||||
trace[Meta.debug] ">> e: {e}, k: {ka^kb}"
|
||||
let c' ← pure { p := .add 1 x (.num (-(ka^kb))), h := .pow ka ca? kb cb? : EqCnstr }
|
||||
c'.assert
|
||||
return true
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ public import Lean.Meta.Tactic.Grind.Canon
|
|||
public import Lean.Meta.Tactic.Grind.Beta
|
||||
public import Lean.Meta.Tactic.Grind.MatchCond
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Internalize
|
||||
public import Lean.Meta.Tactic.Grind.AC.Internalize
|
||||
|
||||
public section
|
||||
|
||||
|
|
@ -361,6 +362,10 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do
|
|||
internalize e' generation
|
||||
pushEq e e' (← mkEqRefl e)
|
||||
|
||||
private def internalizeTheories (e : Expr) (parent? : Option Expr := none) : GoalM Unit := do
|
||||
Arith.internalize e parent?
|
||||
AC.internalize e parent?
|
||||
|
||||
@[export lean_grind_internalize]
|
||||
private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do
|
||||
if (← alreadyInternalized e) then
|
||||
|
|
@ -373,7 +378,7 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt
|
|||
Later, if we try to internalize `f 1`, the arithmetic module must create a node for `1`.
|
||||
Otherwise, it will not be able to propagate that `a + 1 = 1` when `a = 0`
|
||||
-/
|
||||
Arith.internalize e parent?
|
||||
internalizeTheories e parent?
|
||||
else
|
||||
go
|
||||
propagateEtaStruct e generation
|
||||
|
|
@ -422,7 +427,7 @@ where
|
|||
if (← isLitValue e) then
|
||||
-- We do not want to internalize the components of a literal value.
|
||||
mkENode e generation
|
||||
Arith.internalize e parent?
|
||||
internalizeTheories e parent?
|
||||
else if e.isAppOfArity ``Grind.MatchCond 1 then
|
||||
internalizeMatchCond e generation
|
||||
else e.withApp fun f args => do
|
||||
|
|
@ -459,7 +464,7 @@ where
|
|||
internalize arg generation e
|
||||
registerParent e arg
|
||||
addCongrTable e
|
||||
Arith.internalize e parent?
|
||||
internalizeTheories e parent?
|
||||
propagateUp e
|
||||
propagateBetaForNewApp e
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ public import Lean.Meta.Tactic.Grind.Attr
|
|||
public import Lean.Meta.Tactic.Grind.ExtAttr
|
||||
public import Lean.Meta.Tactic.Grind.Cases
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Types
|
||||
public import Lean.Meta.Tactic.Grind.AC.Types
|
||||
public import Lean.Meta.Tactic.Grind.EMatchTheorem
|
||||
meta import Lean.Parser.Do
|
||||
import Lean.Meta.Match.MatchEqsExt
|
||||
|
|
@ -764,6 +765,8 @@ structure Goal where
|
|||
split : Split.State := {}
|
||||
/-- State of arithmetic procedures. -/
|
||||
arith : Arith.State := {}
|
||||
/-- State of the ac solver. -/
|
||||
ac : AC.State := {}
|
||||
/-- State of the clean name generator. -/
|
||||
clean : Clean.State := {}
|
||||
deriving Inhabited
|
||||
|
|
|
|||
|
|
@ -70,8 +70,8 @@ h_2 : ¬f (f x) = g x x
|
|||
[eqc] Equivalence classes
|
||||
[eqc] {x, g x x}
|
||||
[eqc] {z, g y z}
|
||||
[eqc] {g z y, g y x}
|
||||
[eqc] {0, f x + -1 * f (f x) + -1, f (f x) + -1 * f (f (f x)) + -1, f (f (f x)) + -1 * f (f (f (f x))) + -1}
|
||||
[eqc] {g z y, g y x}
|
||||
[ematch] E-matching patterns
|
||||
[thm] feq: [@f #4 #3 #0]
|
||||
[thm] geq: [@g #2 #1 #0 #0]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue