diff --git a/src/Init/Grind/Tactics.lean b/src/Init/Grind/Tactics.lean index 60e9976de2..4bccf6c073 100644 --- a/src/Init/Grind/Tactics.lean +++ b/src/Init/Grind/Tactics.lean @@ -111,6 +111,10 @@ structure Config where -/ cutsat := true /-- + When `true` (default: `true`), uses procedure for handling associative (and commutative) operators. + -/ + ac := true + /-- Maximum exponent eagerly evaluated while computing bounds for `ToInt` and the characteristic of a ring. -/ diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index ec8a736216..da96eb091d 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -36,6 +36,7 @@ public import Lean.Meta.Tactic.Grind.Lookahead public import Lean.Meta.Tactic.Grind.LawfulEqCmp public import Lean.Meta.Tactic.Grind.ReflCmp public import Lean.Meta.Tactic.Grind.SynthInstance +public import Lean.Meta.Tactic.Grind.AC public section diff --git a/src/Lean/Meta/Tactic/Grind/AC.lean b/src/Lean/Meta/Tactic/Grind/AC.lean new file mode 100644 index 0000000000..6f8183b879 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC.lean @@ -0,0 +1,19 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.AC.Types +public import Lean.Meta.Tactic.Grind.AC.Util +public import Lean.Meta.Tactic.Grind.AC.Var +public import Lean.Meta.Tactic.Grind.AC.Internalize +public section +namespace Lean +builtin_initialize registerTraceClass `grind.ac +builtin_initialize registerTraceClass `grind.ac.assert +builtin_initialize registerTraceClass `grind.ac.internalize + +builtin_initialize registerTraceClass `grind.debug.ac.op +end Lean diff --git a/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean b/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean new file mode 100644 index 0000000000..a890237b83 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Internalize.lean @@ -0,0 +1,27 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.Types +public import Lean.Meta.Tactic.Grind.AC.Util +public section +namespace Lean.Meta.Grind.AC + +private def isParentSameOpApp (parent? : Option Expr) (op : Expr) : GoalM Bool := do + let some e := parent? | return false + unless e.isApp && e.appFn!.isApp do return false + return isSameExpr e.appFn!.appFn! op + +def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do + unless (← getConfig).ac do return () + unless e.isApp && e.appFn!.isApp do return () + let op := e.appFn!.appFn! + let some id ← getOpId? op | return () + if (← isParentSameOpApp parent? op) then return () + trace[grind.ac.internalize] "[{id}] {e}" + -- TODO: internalize `e` + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Types.lean b/src/Lean/Meta/Tactic/Grind/AC/Types.lean new file mode 100644 index 0000000000..e111f6fee2 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Types.lean @@ -0,0 +1,54 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Init.Core +public import Init.Grind.AC +public import Std.Data.HashMap +public import Lean.Expr +public import Lean.Data.PersistentArray +public import Lean.Meta.Tactic.Grind.ExprPtr +public section + +namespace Lean.Meta.Grind.AC +open Lean.Grind.AC + +structure Struct where + id : Nat + type : Expr + /-- Cached `getLevel type` -/ + u : Level + op : Expr + neutral? : Option Expr + assocInst : Expr + idempotentInst? : Option Expr + commInst? : Option Expr + neutralInst? : Option Expr + /-- + Mapping from variables to their denotations. + Remark each variable can be in only one ring. + -/ + vars : PArray Expr := {} + /-- Mapping from `Expr` to a variable representing it. -/ + varMap : PHashMap ExprPtr Var := {} + deriving Inhabited + +/-- State for all associative operators detected by `grind`. -/ +structure State where + /-- + Structures/operators detected. + We expect to find a small number of associative operators in a given goal. Thus, using `Array` is fine here. + -/ + structs : Array Struct := {} + /-- + Mapping from operators to its "operator id". We cache failures using `none`. + `opIdOf[op]` is `some id`, then `id < structs.size`. -/ + opIdOf : PHashMap ExprPtr (Option Nat) := {} + -- Remark: a term may be argument of different associative operators. + -- TODO: add mappings + deriving Inhabited + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Util.lean b/src/Lean/Meta/Tactic/Grind/AC/Util.lean new file mode 100644 index 0000000000..8cc982012c --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Util.lean @@ -0,0 +1,119 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.Types +public import Lean.Meta.Tactic.Grind.ProveEq +public import Lean.Meta.Tactic.Grind.SynthInstance +public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId +public section + +namespace Lean.Meta.Grind.AC + +def get' : GoalM State := do + return (← get).ac + +@[inline] def modify' (f : State → State) : GoalM Unit := do + modify fun s => { s with ac := f s.ac } + +structure ACM.Context where + opId : Nat + +class MonadGetStruct (m : Type → Type) where + getStruct : m Struct + +export MonadGetStruct (getStruct) + +@[always_inline] +instance (m n) [MonadLift m n] [MonadGetStruct m] : MonadGetStruct n where + getStruct := liftM (getStruct : m Struct) + +abbrev ACM := ReaderT ACM.Context GoalM + +abbrev ACM.run (opId : Nat) (x : ACM α) : GoalM α := + x { opId } + +abbrev getOpId : ACM Nat := + return (← read).opId + +protected def ACM.getStruct : ACM Struct := do + let s ← get' + let opId ← getOpId + if h : opId < s.structs.size then + return s.structs[opId] + else + throwError "`grind` internal error, invalid structure id" + +instance : MonadGetStruct ACM where + getStruct := ACM.getStruct + +def getOp : ACM Expr := + return (← getStruct).op + +private def notAssoc : Std.HashSet Name := + Std.HashSet.ofList [``Eq, ``And, ``Or, ``Iff, ``getElem, ``OfNat.ofNat, ``ite, ``dite, ``cond, ``LT.lt, ``LE.le] + +/-- +Returns `true` if `op` is an arithmetic operator supported in other modules. +Remark: `f == op.getAppFn!` +-/ +private def isArithOpInOtherModules (op : Expr) (f : Expr) : GoalM Bool := do + unless (← getConfig).ring do return false + -- Remark: if `ring` is disabled we could check whether `cutsat` is enabled and discard `+` and `-`, but this is just a filter. + let .const declName _ := f | return false + if declName == ``HAdd.hAdd || declName == ``HMul.hMul || declName == ``HSub.hSub || declName == ``HDiv.hDiv || declName == ``HPow.hPow then + if op.getAppNumArgs == 4 then + let α := op.appFn!.appFn!.appArg! + if (← Arith.CommRing.getRingId? α).isSome then return true + if (← Arith.CommRing.getSemiringId? α).isSome then return true + return false + +def getOpId? (op : Expr) : GoalM (Option Nat) := do + if let some id? := (← get').opIdOf.find? { expr := op } then + return id? + let id? ← go + modify' fun s => { s with opIdOf := s.opIdOf.insert { expr := op } id? } + return id? +where + go : GoalM (Option Nat) := do + let f := op.getAppFn + if let .const declName _ := f then + if notAssoc.contains declName then return none + let .forallE _ α b _ ← whnf (← inferType op) | return none + if b.hasLooseBVars then return none + let .forallE _ α₂ α₃ _ ← whnf b | return none + if α₃.hasLooseBVars then return none + unless (← isDefEq α α₂) do return none + unless (← isDefEq α α₃) do return none + if (← isArithOpInOtherModules op f) then return none + let u ← getLevel α + let assocType := mkApp2 (mkConst ``Std.Associative [u]) α op + let some assocInst ← synthInstance? assocType | return none + let commType := mkApp2 (mkConst ``Std.Commutative [u]) α op + let commInst? ← synthInstance? commType + let idempotentType := mkApp2 (mkConst ``Std.IdempotentOp [u]) α op + let idempotentInst? ← synthInstance? idempotentType + let (neutralInst?, neutral?) ← do + let neutral ← mkFreshExprMVar α + let identityType := mkApp3 (mkConst ``Std.Identity [u]) α op neutral + if let some identityInst ← synthInstance? identityType then + let neutral ← instantiateExprMVars neutral + let neutral ← preprocessLight neutral + internalize neutral (← getGeneration op) + pure (some identityInst, some neutral) + else + pure (none, none) + let id := (← get').structs.size + modify' fun s => { s with + structs := s.structs.push { + id, type := α, u, op, neutral?, assocInst, commInst?, + idempotentInst?, neutralInst? + }} + -- TODO: neutral element must be variable 0 + trace[grind.debug.ac.op] "{op}, comm: {commInst?.isSome}, idempotent: {idempotentInst?.isSome}, neutral?: {neutral?}" + return some id + +end Lean.Meta.Grind.AC diff --git a/src/Lean/Meta/Tactic/Grind/AC/Var.lean b/src/Lean/Meta/Tactic/Grind/AC/Var.lean new file mode 100644 index 0000000000..084a258181 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/AC/Var.lean @@ -0,0 +1,14 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Tactic.Grind.Util +public section +namespace Lean.Meta.Grind.Arith.Var + +-- TODO: add mkVar + +end Lean.Meta.Grind.Arith.Var diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean index ca8ce5aa39..148b3f122a 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean @@ -267,7 +267,6 @@ private def propagateNonlinearPow (x : Var) : GoalM Bool := do pure (kb.toNat, some cb) else return false - trace[Meta.debug] ">> e: {e}, k: {ka^kb}" let c' ← pure { p := .add 1 x (.num (-(ka^kb))), h := .pow ka ca? kb cb? : EqCnstr } c'.assert return true diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean index a40fbd78d2..09020f6f48 100644 --- a/src/Lean/Meta/Tactic/Grind/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -19,6 +19,7 @@ public import Lean.Meta.Tactic.Grind.Canon public import Lean.Meta.Tactic.Grind.Beta public import Lean.Meta.Tactic.Grind.MatchCond public import Lean.Meta.Tactic.Grind.Arith.Internalize +public import Lean.Meta.Tactic.Grind.AC.Internalize public section @@ -361,6 +362,10 @@ private def tryEta (e : Expr) (generation : Nat) : GoalM Unit := do internalize e' generation pushEq e e' (← mkEqRefl e) +private def internalizeTheories (e : Expr) (parent? : Option Expr := none) : GoalM Unit := do + Arith.internalize e parent? + AC.internalize e parent? + @[export lean_grind_internalize] private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Option Expr := none) : GoalM Unit := withIncRecDepth do if (← alreadyInternalized e) then @@ -373,7 +378,7 @@ private partial def internalizeImpl (e : Expr) (generation : Nat) (parent? : Opt Later, if we try to internalize `f 1`, the arithmetic module must create a node for `1`. Otherwise, it will not be able to propagate that `a + 1 = 1` when `a = 0` -/ - Arith.internalize e parent? + internalizeTheories e parent? else go propagateEtaStruct e generation @@ -422,7 +427,7 @@ where if (← isLitValue e) then -- We do not want to internalize the components of a literal value. mkENode e generation - Arith.internalize e parent? + internalizeTheories e parent? else if e.isAppOfArity ``Grind.MatchCond 1 then internalizeMatchCond e generation else e.withApp fun f args => do @@ -459,7 +464,7 @@ where internalize arg generation e registerParent e arg addCongrTable e - Arith.internalize e parent? + internalizeTheories e parent? propagateUp e propagateBetaForNewApp e diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index bfc9f0e150..e39b0468a4 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -22,6 +22,7 @@ public import Lean.Meta.Tactic.Grind.Attr public import Lean.Meta.Tactic.Grind.ExtAttr public import Lean.Meta.Tactic.Grind.Cases public import Lean.Meta.Tactic.Grind.Arith.Types +public import Lean.Meta.Tactic.Grind.AC.Types public import Lean.Meta.Tactic.Grind.EMatchTheorem meta import Lean.Parser.Do import Lean.Meta.Match.MatchEqsExt @@ -764,6 +765,8 @@ structure Goal where split : Split.State := {} /-- State of arithmetic procedures. -/ arith : Arith.State := {} + /-- State of the ac solver. -/ + ac : AC.State := {} /-- State of the clean name generator. -/ clean : Clean.State := {} deriving Inhabited diff --git a/tests/lean/run/grind_sort_eqc.lean b/tests/lean/run/grind_sort_eqc.lean index c38293a632..3bc0cd6639 100644 --- a/tests/lean/run/grind_sort_eqc.lean +++ b/tests/lean/run/grind_sort_eqc.lean @@ -70,8 +70,8 @@ h_2 : ¬f (f x) = g x x [eqc] Equivalence classes [eqc] {x, g x x} [eqc] {z, g y z} - [eqc] {g z y, g y x} [eqc] {0, f x + -1 * f (f x) + -1, f (f x) + -1 * f (f (f x)) + -1, f (f (f x)) + -1 * f (f (f (f x))) + -1} + [eqc] {g z y, g y x} [ematch] E-matching patterns [thm] feq: [@f #4 #3 #0] [thm] geq: [@g #2 #1 #0 #0]