perf: generate grind ring instances on demand (#9228)
This PR improves the startup time for `grind ring` by generating the required type classes on demand. This optimization is particularly relevant for files that make hundreds of calls to `grind`, such as `tests/lean/run/grind_bitvec2.lean`. For example, before this change, `grind` spent 6.87 seconds synthesizing type classes, compared to 3.92 seconds after this PR.
This commit is contained in:
parent
c9debdaf2a
commit
5d46391dde
16 changed files with 334 additions and 233 deletions
|
|
@ -14,18 +14,24 @@ namespace Lean.Meta.Grind.Arith.CommRing
|
|||
Helper functions for converting reified terms back into their denotations.
|
||||
-/
|
||||
|
||||
variable [Monad M] [MonadGetRing M]
|
||||
variable [Monad M] [MonadError M] [MonadLiftT MetaM M] [MonadRing M]
|
||||
|
||||
def denoteNum (k : Int) : M Expr := do
|
||||
let ring ← getRing
|
||||
return denoteNumCore ring.u ring.type ring.semiringInst ring.negFn k
|
||||
let n := mkRawNatLit k.natAbs
|
||||
let ofNatInst := mkApp3 (mkConst ``Grind.Semiring.ofNat [ring.u]) ring.type ring.semiringInst n
|
||||
let n := mkApp3 (mkConst ``OfNat.ofNat [ring.u]) ring.type n ofNatInst
|
||||
if k < 0 then
|
||||
return mkApp (← getNegFn) n
|
||||
else
|
||||
return n
|
||||
|
||||
def _root_.Lean.Grind.CommRing.Power.denoteExpr (pw : Power) : M Expr := do
|
||||
let x := (← getRing).vars[pw.x]!
|
||||
if pw.k == 1 then
|
||||
return x
|
||||
else
|
||||
return mkApp2 (← getRing).powFn x (toExpr pw.k)
|
||||
return mkApp2 (← getPowFn) x (toExpr pw.k)
|
||||
|
||||
def _root_.Lean.Grind.CommRing.Mon.denoteExpr (m : Mon) : M Expr := do
|
||||
match m with
|
||||
|
|
@ -35,7 +41,7 @@ where
|
|||
go (m : Mon) (acc : Expr) : M Expr := do
|
||||
match m with
|
||||
| .unit => return acc
|
||||
| .mult pw m => go m (mkApp2 (← getRing).mulFn acc (← pw.denoteExpr))
|
||||
| .mult pw m => go m (mkApp2 (← getMulFn) acc (← pw.denoteExpr))
|
||||
|
||||
def _root_.Lean.Grind.CommRing.Poly.denoteExpr (p : Poly) : M Expr := do
|
||||
match p with
|
||||
|
|
@ -46,13 +52,13 @@ where
|
|||
if k == 1 then
|
||||
m.denoteExpr
|
||||
else
|
||||
return mkApp2 (← getRing).mulFn (← denoteNum k) (← m.denoteExpr)
|
||||
return mkApp2 (← getMulFn) (← denoteNum k) (← m.denoteExpr)
|
||||
|
||||
go (p : Poly) (acc : Expr) : M Expr := do
|
||||
match p with
|
||||
| .num 0 => return acc
|
||||
| .num k => return mkApp2 (← getRing).addFn acc (← denoteNum k)
|
||||
| .add k m p => go p (mkApp2 (← getRing).addFn acc (← denoteTerm k m))
|
||||
| .num k => return mkApp2 (← getAddFn) acc (← denoteNum k)
|
||||
| .add k m p => go p (mkApp2 (← getAddFn) acc (← denoteTerm k m))
|
||||
|
||||
def _root_.Lean.Grind.CommRing.Expr.denoteExpr (e : RingExpr) : M Expr := do
|
||||
go e
|
||||
|
|
@ -60,11 +66,11 @@ where
|
|||
go : RingExpr → M Expr
|
||||
| .num k => denoteNum k
|
||||
| .var x => return (← getRing).vars[x]!
|
||||
| .add a b => return mkApp2 (← getRing).addFn (← go a) (← go b)
|
||||
| .sub a b => return mkApp2 (← getRing).subFn (← go a) (← go b)
|
||||
| .mul a b => return mkApp2 (← getRing).mulFn (← go a) (← go b)
|
||||
| .pow a k => return mkApp2 (← getRing).powFn (← go a) (toExpr k)
|
||||
| .neg a => return mkApp (← getRing).negFn (← go a)
|
||||
| .add a b => return mkApp2 (← getAddFn) (← go a) (← go b)
|
||||
| .sub a b => return mkApp2 (← getSubFn) (← go a) (← go b)
|
||||
| .mul a b => return mkApp2 (← getMulFn) (← go a) (← go b)
|
||||
| .pow a k => return mkApp2 (← getPowFn) (← go a) (toExpr k)
|
||||
| .neg a => return mkApp (← getNegFn) (← go a)
|
||||
|
||||
private def mkEq (a b : Expr) : M Expr := do
|
||||
let r ← getRing
|
||||
|
|
@ -84,9 +90,9 @@ def _root_.Lean.Grind.Ring.OfSemiring.Expr.denoteAsRingExpr (e : SemiringExpr) :
|
|||
where
|
||||
go : SemiringExpr → SemiringM Expr
|
||||
| .num k => denoteNum k
|
||||
| .var x => return mkApp (← getSemiring).toQFn (← getSemiring).vars[x]!
|
||||
| .add a b => return mkApp2 (← getRing).addFn (← go a) (← go b)
|
||||
| .mul a b => return mkApp2 (← getRing).mulFn (← go a) (← go b)
|
||||
| .pow a k => return mkApp2 (← getRing).powFn (← go a) (toExpr k)
|
||||
| .var x => return mkApp (← getToQFn) (← getSemiring).vars[x]!
|
||||
| .add a b => return mkApp2 (← getAddFn) (← go a) (← go b)
|
||||
| .mul a b => return mkApp2 (← getMulFn) (← go a) (← go b)
|
||||
| .pow a k => return mkApp2 (← getPowFn) (← go a) (toExpr k)
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
|
|||
|
|
@ -359,13 +359,13 @@ private def diseqToEq (a b : Expr) : RingM Unit := do
|
|||
let gen := max (← getGeneration a) (← getGeneration b)
|
||||
let ring ← getRing
|
||||
let some fieldInst := ring.fieldInst? | unreachable!
|
||||
let e ← pre <| mkApp2 ring.subFn a b
|
||||
let e ← pre <| mkApp2 (← getSubFn) a b
|
||||
modifyRing fun s => { s with invSet := s.invSet.insert e }
|
||||
let eInv ← pre <| mkApp (← getRing).invFn?.get! e
|
||||
let lhs ← pre <| mkApp2 ring.mulFn e eInv
|
||||
let eInv ← pre <| mkApp (← getInvFn) e
|
||||
let lhs ← pre <| mkApp2 (← getMulFn) e eInv
|
||||
internalize lhs gen none
|
||||
trace[grind.debug.ring.rabinowitsch] "{lhs}"
|
||||
pushEq lhs ring.one <| mkApp5 (mkConst ``Grind.CommRing.diseq_to_eq [ring.u]) ring.type fieldInst a b (← mkDiseqProof a b)
|
||||
pushEq lhs (← getOne) <| mkApp5 (mkConst ``Grind.CommRing.diseq_to_eq [ring.u]) ring.type fieldInst a b (← mkDiseqProof a b)
|
||||
|
||||
private def diseqZeroToEq (a b : Expr) : RingM Unit := do
|
||||
-- Rabinowitsch transformation for `b = 0` case
|
||||
|
|
@ -373,11 +373,11 @@ private def diseqZeroToEq (a b : Expr) : RingM Unit := do
|
|||
let ring ← getRing
|
||||
let some fieldInst := ring.fieldInst? | unreachable!
|
||||
modifyRing fun s => { s with invSet := s.invSet.insert a }
|
||||
let aInv ← pre <| mkApp (← getRing).invFn?.get! a
|
||||
let lhs ← pre <| mkApp2 ring.mulFn a aInv
|
||||
let aInv ← pre <| mkApp (← getInvFn) a
|
||||
let lhs ← pre <| mkApp2 (← getMulFn) a aInv
|
||||
internalize lhs gen none
|
||||
trace[grind.debug.ring.rabinowitsch] "{lhs}"
|
||||
pushEq lhs ring.one <| mkApp4 (mkConst ``Grind.CommRing.diseq0_to_eq [ring.u]) ring.type fieldInst a (← mkDiseqProof a b)
|
||||
pushEq lhs (← getOne) <| mkApp4 (mkConst ``Grind.CommRing.diseq0_to_eq [ring.u]) ring.type fieldInst a (← mkDiseqProof a b)
|
||||
|
||||
@[export lean_process_ring_diseq]
|
||||
def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
|
||||
|
|
@ -400,7 +400,7 @@ def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
|
|||
ofSemiring? := none
|
||||
}
|
||||
else if let some semiringId ← inSameSemiring? a b then SemiringM.run semiringId do
|
||||
if (← getSemiring).addRightCancelInst?.isSome then
|
||||
if (← getAddRightCancelInst?).isSome then
|
||||
if (← getConfig).ringNull then return () -- TODO: remove after we add Nullstellensatz certificates for semiring adapter
|
||||
trace_goal[grind.ring.assert] "{mkNot (← mkEq a b)}"
|
||||
let some sa ← toSemiringExpr? a | return ()
|
||||
|
|
|
|||
|
|
@ -46,11 +46,11 @@ private def isForbiddenParent (parent? : Option Expr) : Bool :=
|
|||
private partial def toInt? (e : Expr) : RingM (Option Int) := do
|
||||
match_expr e with
|
||||
| Neg.neg _ i a =>
|
||||
if isNegInst (← getRing) i then return (- .) <$> (← toInt? a) else return none
|
||||
if (← isNegInst i) then return (- .) <$> (← toInt? a) else return none
|
||||
| IntCast.intCast _ i a =>
|
||||
if isIntCastInst (← getRing) i then getIntValue? a else return none
|
||||
if (← isIntCastInst i) then getIntValue? a else return none
|
||||
| NatCast.natCast _ i a =>
|
||||
if isNatCastInst (← getRing) i then
|
||||
if (← isNatCastInst i) then
|
||||
let some v ← getNatValue? a | return none
|
||||
return some (Int.ofNat v)
|
||||
else
|
||||
|
|
@ -61,8 +61,8 @@ private partial def toInt? (e : Expr) : RingM (Option Int) := do
|
|||
| _ => return none
|
||||
|
||||
private def isInvInst (inst : Expr) : RingM Bool := do
|
||||
let some fn := (← getRing).invFn? | return false
|
||||
return isSameExpr fn.appArg! inst
|
||||
if (← getRing).fieldInst?.isNone then return false
|
||||
return isSameExpr (← getInvFn).appArg! inst
|
||||
|
||||
/--
|
||||
Given `e` of the form `@Inv.inv _ inst a`,
|
||||
|
|
@ -80,7 +80,7 @@ private def processInv (e inst a : Expr) : RingM Unit := do
|
|||
if (← hasChar) then
|
||||
let (charInst, c) ← getCharInst
|
||||
if c == 0 then
|
||||
let expected ← mkEq (mkApp2 ring.mulFn a e) (← denoteNum 1)
|
||||
let expected ← mkEq (mkApp2 (← getMulFn) a e) (← denoteNum 1)
|
||||
pushNewFact <| mkExpectedPropHint
|
||||
(mkApp5 (mkConst ``Grind.CommRing.inv_int_eq [ring.u]) ring.type fieldInst charInst (mkIntLit k) reflBoolTrue)
|
||||
expected
|
||||
|
|
@ -90,7 +90,7 @@ private def processInv (e inst a : Expr) : RingM Unit := do
|
|||
(mkApp6 (mkConst ``Grind.CommRing.inv_zero_eqC [ring.u]) ring.type (mkNatLit c) fieldInst charInst (mkIntLit k) reflBoolTrue)
|
||||
expected
|
||||
else
|
||||
let expected ← mkEq (mkApp2 ring.mulFn a e) (← denoteNum 1)
|
||||
let expected ← mkEq (mkApp2 (← getMulFn) a e) (← denoteNum 1)
|
||||
pushNewFact <| mkExpectedPropHint
|
||||
(mkApp6 (mkConst ``Grind.CommRing.inv_int_eqC [ring.u]) ring.type (mkNatLit c) fieldInst charInst (mkIntLit k) reflBoolTrue)
|
||||
expected
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@ import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
|||
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
instance : MonadGetRing (ReaderT Ring MetaM) where
|
||||
getRing := read
|
||||
private abbrev M := StateT Ring MetaM
|
||||
|
||||
private def M := ReaderT Goal (StateT (Array MessageData) MetaM)
|
||||
instance : MonadRing M where
|
||||
getRing := get
|
||||
modifyRing := modify
|
||||
canonExpr e := return e
|
||||
synthInstance? e := Meta.synthInstance? e none
|
||||
|
||||
private def toOption (cls : Name) (header : Thunk MessageData) (msgs : Array MessageData) : Option MessageData :=
|
||||
if msgs.isEmpty then
|
||||
|
|
@ -22,19 +25,19 @@ private def toOption (cls : Name) (header : Thunk MessageData) (msgs : Array Mes
|
|||
private def push (msgs : Array MessageData) (msg? : Option MessageData) : Array MessageData :=
|
||||
if let some msg := msg? then msgs.push msg else msgs
|
||||
|
||||
def ppBasis? : ReaderT Ring MetaM (Option MessageData) := do
|
||||
def ppBasis? : M (Option MessageData) := do
|
||||
let mut basis := #[]
|
||||
for c in (← getRing).basis do
|
||||
basis := basis.push (toTraceElem (← c.denoteExpr))
|
||||
return toOption `basis "Basis" basis
|
||||
|
||||
def ppDiseqs? : ReaderT Ring MetaM (Option MessageData) := do
|
||||
def ppDiseqs? : M (Option MessageData) := do
|
||||
let mut diseqs := #[]
|
||||
for d in (← getRing).diseqs do
|
||||
diseqs := diseqs.push (toTraceElem (← d.denoteExpr))
|
||||
return toOption `diseqs "Disequalities" diseqs
|
||||
|
||||
def ppRing? : ReaderT Ring MetaM (Option MessageData) := do
|
||||
def ppRing? : M (Option MessageData) := do
|
||||
let msgs := #[]
|
||||
let msgs := push msgs (← ppBasis?)
|
||||
let msgs := push msgs (← ppDiseqs?)
|
||||
|
|
@ -43,7 +46,7 @@ def ppRing? : ReaderT Ring MetaM (Option MessageData) := do
|
|||
def pp? (goal : Goal) : MetaM (Option MessageData) := do
|
||||
let mut msgs := #[]
|
||||
for ring in goal.arith.ring.rings do
|
||||
let some msg ← ppRing? ring | pure ()
|
||||
let some msg ← ppRing? |>.run' ring | pure ()
|
||||
msgs := msgs.push msg
|
||||
if msgs.isEmpty then
|
||||
return none
|
||||
|
|
|
|||
|
|
@ -23,14 +23,14 @@ def toContextExpr : RingM Expr := do
|
|||
if h : 0 < ring.vars.size then
|
||||
RArray.toExpr ring.type id (RArray.ofFn (ring.vars[·]) h)
|
||||
else
|
||||
RArray.toExpr ring.type id (RArray.leaf (mkApp ring.natCastFn (toExpr 0)))
|
||||
RArray.toExpr ring.type id (RArray.leaf (mkApp (← getNatCastFn) (toExpr 0)))
|
||||
|
||||
private def toSContextExpr' : SemiringM Expr := do
|
||||
let semiring ← getSemiring
|
||||
if h : 0 < semiring.vars.size then
|
||||
RArray.toExpr semiring.type id (RArray.ofFn (semiring.vars[·]) h)
|
||||
else
|
||||
RArray.toExpr semiring.type id (RArray.leaf (mkApp semiring.natCastFn (toExpr 0)))
|
||||
RArray.toExpr semiring.type id (RArray.leaf (mkApp (← getNatCastFn') (toExpr 0)))
|
||||
|
||||
/-- Similar to `toContextExpr`, but for semirings. -/
|
||||
private def toSContextExpr (semiringId : Nat) : RingM Expr := do
|
||||
|
|
@ -391,9 +391,12 @@ private def mkStepPrefix (declName declNameC : Name) : ProofM Expr := do
|
|||
else
|
||||
mkStepBasicPrefix declName
|
||||
|
||||
private def getSemiringOf : RingM Semiring := do
|
||||
private def getSemiringIdOf : RingM Nat := do
|
||||
let some semiringId := (← getRing).semiringId? | throwError "`grind` internal error, semiring is not available"
|
||||
SemiringM.run semiringId do getSemiring
|
||||
return semiringId
|
||||
|
||||
private def getSemiringOf : RingM Semiring := do
|
||||
SemiringM.run (← getSemiringIdOf) do getSemiring
|
||||
|
||||
private def mkSemiringPrefix (declName : Name) : ProofM Expr := do
|
||||
let sctx ← getSContext
|
||||
|
|
@ -403,7 +406,7 @@ private def mkSemiringPrefix (declName : Name) : ProofM Expr := do
|
|||
private def mkSemiringAddRightCancelPrefix (declName : Name) : ProofM Expr := do
|
||||
let sctx ← getSContext
|
||||
let semiring ← getSemiringOf
|
||||
let some addRightCancelInst := semiring.addRightCancelInst?
|
||||
let some addRightCancelInst ← SemiringM.run (← getSemiringIdOf) do getAddRightCancelInst?
|
||||
| throwError "`grind` internal error, `AddRightCancel` instance is not available"
|
||||
return mkApp4 (mkConst declName [semiring.u]) semiring.type semiring.semiringInst addRightCancelInst sctx
|
||||
|
||||
|
|
|
|||
|
|
@ -10,20 +10,20 @@ import Lean.Meta.Tactic.Grind.Arith.CommRing.Var
|
|||
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
def isAddInst (ring : Ring) (inst : Expr) : Bool :=
|
||||
isSameExpr ring.addFn.appArg! inst
|
||||
def isMulInst (ring : Ring) (inst : Expr) : Bool :=
|
||||
isSameExpr ring.mulFn.appArg! inst
|
||||
def isSubInst (ring : Ring) (inst : Expr) : Bool :=
|
||||
isSameExpr ring.subFn.appArg! inst
|
||||
def isNegInst (ring : Ring) (inst : Expr) : Bool :=
|
||||
isSameExpr ring.negFn.appArg! inst
|
||||
def isPowInst (ring : Ring) (inst : Expr) : Bool :=
|
||||
isSameExpr ring.powFn.appArg! inst
|
||||
def isIntCastInst (ring : Ring) (inst : Expr) : Bool :=
|
||||
isSameExpr ring.intCastFn.appArg! inst
|
||||
def isNatCastInst (ring : Ring) (inst : Expr) : Bool :=
|
||||
isSameExpr ring.natCastFn.appArg! inst
|
||||
def isAddInst (inst : Expr) : RingM Bool :=
|
||||
return isSameExpr (← getAddFn).appArg! inst
|
||||
def isMulInst (inst : Expr) : RingM Bool :=
|
||||
return isSameExpr (← getMulFn).appArg! inst
|
||||
def isSubInst (inst : Expr) : RingM Bool :=
|
||||
return isSameExpr (← getSubFn).appArg! inst
|
||||
def isNegInst (inst : Expr) : RingM Bool :=
|
||||
return isSameExpr (← getNegFn).appArg! inst
|
||||
def isPowInst (inst : Expr) : RingM Bool :=
|
||||
return isSameExpr (← getPowFn).appArg! inst
|
||||
def isIntCastInst (inst : Expr) : RingM Bool :=
|
||||
return isSameExpr (← getIntCastFn).appArg! inst
|
||||
def isNatCastInst (inst : Expr) : RingM Bool :=
|
||||
return isSameExpr (← getNatCastFn).appArg! inst
|
||||
|
||||
private def reportAppIssue (e : Expr) : GoalM Unit := do
|
||||
reportIssue! "comm ring term with unexpected instance{indentExpr e}"
|
||||
|
|
@ -49,24 +49,24 @@ partial def reify? (e : Expr) (skipVar := true) (gen : Nat := 0) : RingM (Option
|
|||
let rec go (e : Expr) : RingM RingExpr := do
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if isAddInst (← getRing) i then return .add (← go a) (← go b) else asVar e
|
||||
if (← isAddInst i) then return .add (← go a) (← go b) else asVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if isMulInst (← getRing) i then return .mul (← go a) (← go b) else asVar e
|
||||
if (← isMulInst i) then return .mul (← go a) (← go b) else asVar e
|
||||
| HSub.hSub _ _ _ i a b =>
|
||||
if isSubInst (← getRing) i then return .sub (← go a) (← go b) else asVar e
|
||||
if (← isSubInst i) then return .sub (← go a) (← go b) else asVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k ← getNatValue? b | toVar e
|
||||
if isPowInst (← getRing) i then return .pow (← go a) k else asVar e
|
||||
if (← isPowInst i) then return .pow (← go a) k else asVar e
|
||||
| Neg.neg _ i a =>
|
||||
if isNegInst (← getRing) i then return .neg (← go a) else asVar e
|
||||
if (← isNegInst i) then return .neg (← go a) else asVar e
|
||||
| IntCast.intCast _ i a =>
|
||||
if isIntCastInst (← getRing) i then
|
||||
if (← isIntCastInst i) then
|
||||
let some k ← getIntValue? a | toVar e
|
||||
return .num k
|
||||
else
|
||||
asVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if isNatCastInst (← getRing) i then
|
||||
if (← isNatCastInst i) then
|
||||
let some k ← getNatValue? a | toVar e
|
||||
return .num k
|
||||
else
|
||||
|
|
@ -85,24 +85,24 @@ partial def reify? (e : Expr) (skipVar := true) (gen : Nat := 0) : RingM (Option
|
|||
toTopVar e
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if isAddInst (← getRing) i then return some (.add (← go a) (← go b)) else asTopVar e
|
||||
if (← isAddInst i) then return some (.add (← go a) (← go b)) else asTopVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if isMulInst (← getRing) i then return some (.mul (← go a) (← go b)) else asTopVar e
|
||||
if (← isMulInst i) then return some (.mul (← go a) (← go b)) else asTopVar e
|
||||
| HSub.hSub _ _ _ i a b =>
|
||||
if isSubInst (← getRing) i then return some (.sub (← go a) (← go b)) else asTopVar e
|
||||
if (← isSubInst i) then return some (.sub (← go a) (← go b)) else asTopVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k ← getNatValue? b | return none
|
||||
if isPowInst (← getRing) i then return some (.pow (← go a) k) else asTopVar e
|
||||
if (← isPowInst i) then return some (.pow (← go a) k) else asTopVar e
|
||||
| Neg.neg _ i a =>
|
||||
if isNegInst (← getRing) i then return some (.neg (← go a)) else asTopVar e
|
||||
if (← isNegInst i) then return some (.neg (← go a)) else asTopVar e
|
||||
| IntCast.intCast _ i a =>
|
||||
if isIntCastInst (← getRing) i then
|
||||
if (← isIntCastInst i) then
|
||||
let some k ← getIntValue? a | toTopVar e
|
||||
return some (.num k)
|
||||
else
|
||||
asTopVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if isNatCastInst (← getRing) i then
|
||||
if (← isNatCastInst i) then
|
||||
let some k ← getNatValue? a | toTopVar e
|
||||
return some (.num k)
|
||||
else
|
||||
|
|
@ -127,14 +127,14 @@ partial def sreify? (e : Expr) : SemiringM (Option SemiringExpr) := do
|
|||
let rec go (e : Expr) : SemiringM SemiringExpr := do
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if isSameExpr (← getSemiring).addFn.appArg! i then return .add (← go a) (← go b) else asVar e
|
||||
if isSameExpr (← getAddFn').appArg! i then return .add (← go a) (← go b) else asVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if isSameExpr (← getSemiring).mulFn.appArg! i then return .mul (← go a) (← go b) else asVar e
|
||||
if isSameExpr (← getMulFn').appArg! i then return .mul (← go a) (← go b) else asVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k ← getNatValue? b | toVar e
|
||||
if isSameExpr (← getSemiring).powFn.appArg! i then return .pow (← go a) k else asVar e
|
||||
if isSameExpr (← getPowFn').appArg! i then return .pow (← go a) k else asVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if isSameExpr (← getSemiring).natCastFn.appArg! i then
|
||||
if isSameExpr (← getNatCastFn').appArg! i then
|
||||
let some k ← getNatValue? a | toVar e
|
||||
return .num k
|
||||
else
|
||||
|
|
@ -150,14 +150,14 @@ partial def sreify? (e : Expr) : SemiringM (Option SemiringExpr) := do
|
|||
toTopVar e
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if isSameExpr (← getSemiring).addFn.appArg! i then return some (.add (← go a) (← go b)) else asTopVar e
|
||||
if isSameExpr (← getAddFn').appArg! i then return some (.add (← go a) (← go b)) else asTopVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if isSameExpr (← getSemiring).mulFn.appArg! i then return some (.mul (← go a) (← go b)) else asTopVar e
|
||||
if isSameExpr (← getMulFn').appArg! i then return some (.mul (← go a) (← go b)) else asTopVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k ← getNatValue? b | return none
|
||||
if isSameExpr (← getSemiring).powFn.appArg! i then return some (.pow (← go a) k) else asTopVar e
|
||||
if isSameExpr (← getPowFn').appArg! i then return some (.pow (← go a) k) else asTopVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if isSameExpr (← getSemiring).natCastFn.appArg! i then
|
||||
if isSameExpr (← getNatCastFn').appArg! i then
|
||||
let some k ← getNatValue? a | toTopVar e
|
||||
return some (.num k)
|
||||
else
|
||||
|
|
|
|||
|
|
@ -13,88 +13,6 @@ import Lean.Meta.Tactic.Grind.Arith.CommRing.Util
|
|||
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
def denoteNumCore (u : Level) (type : Expr) (semiringInst : Expr) (negFn : Expr) (k : Int) : Expr :=
|
||||
let n := mkRawNatLit k.natAbs
|
||||
let ofNatInst := mkApp3 (mkConst ``Grind.Semiring.ofNat [u]) type semiringInst n
|
||||
let n := mkApp3 (mkConst ``OfNat.ofNat [u]) type n ofNatInst
|
||||
if k < 0 then
|
||||
mkApp negFn n
|
||||
else
|
||||
n
|
||||
|
||||
private def internalizeFn (fn : Expr) : GoalM Expr := do
|
||||
shareCommon (← canon fn)
|
||||
|
||||
private def getUnaryFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) : GoalM Expr := do
|
||||
let inst ← synthInstance <| mkApp (mkConst instDeclName [u]) type
|
||||
internalizeFn <| mkApp2 (mkConst declName [u]) type inst
|
||||
|
||||
private def getBinHomoFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) : GoalM Expr := do
|
||||
let inst ← synthInstance <| mkApp3 (mkConst instDeclName [u, u, u]) type type type
|
||||
internalizeFn <| mkApp4 (mkConst declName [u, u, u]) type type type inst
|
||||
|
||||
-- Remark: we removed consistency checks such as the one that ensures `HAdd` instance matches `Semiring.toAdd`
|
||||
-- That is, we are assuming the type classes were properly setup.
|
||||
|
||||
private def getAddFn (type : Expr) (u : Level) : GoalM Expr := do
|
||||
getBinHomoFn type u ``HAdd ``HAdd.hAdd
|
||||
|
||||
private def getMulFn (type : Expr) (u : Level) : GoalM Expr := do
|
||||
getBinHomoFn type u ``HMul ``HMul.hMul
|
||||
|
||||
private def getSubFn (type : Expr) (u : Level) : GoalM Expr := do
|
||||
getBinHomoFn type u ``HSub ``HSub.hSub
|
||||
|
||||
private def getDivFn (type : Expr) (u : Level) : GoalM Expr := do
|
||||
getBinHomoFn type u ``HDiv ``HDiv.hDiv
|
||||
|
||||
private def getNegFn (type : Expr) (u : Level) : GoalM Expr := do
|
||||
getUnaryFn type u ``Neg ``Neg.neg
|
||||
|
||||
private def getInvFn (type : Expr) (u : Level) : GoalM Expr := do
|
||||
getUnaryFn type u ``Inv ``Inv.inv
|
||||
|
||||
private def getPowFn (type : Expr) (u : Level) (semiringInst : Expr) : GoalM Expr := do
|
||||
let inst ← synthInstance <| mkApp3 (mkConst ``HPow [u, 0, u]) type Nat.mkType type
|
||||
let inst' := mkApp2 (mkConst ``Grind.Semiring.toHPow [u]) type semiringInst
|
||||
unless (← withDefault <| isDefEq inst inst') do
|
||||
throwError "instance for power operator{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
|
||||
internalizeFn <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst
|
||||
|
||||
private def getIntCastFn (type : Expr) (u : Level) (ringInst : Expr) : GoalM Expr := do
|
||||
let inst' := mkApp2 (mkConst ``Grind.Ring.intCast [u]) type ringInst
|
||||
let instType := mkApp (mkConst ``IntCast [u]) type
|
||||
-- Note that `Ring.intCast` is not registered as a global instance
|
||||
-- (to avoid introducing unwanted coercions)
|
||||
-- so merely having a `Ring α` instance
|
||||
-- does not guarantee that an `IntCast α` will be available.
|
||||
-- When both are present we verify that they are defeq,
|
||||
-- and otherwise fall back to the field of the `Ring α` instance that we already have.
|
||||
let inst ← match (← synthInstance? instType) with
|
||||
| none => pure inst'
|
||||
| some inst =>
|
||||
unless (← withDefault <| isDefEq inst inst') do
|
||||
throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.Ring` one{indentExpr inst'}"
|
||||
pure inst
|
||||
internalizeFn <| mkApp2 (mkConst ``IntCast.intCast [u]) type inst
|
||||
|
||||
private def getNatCastFn (type : Expr) (u : Level) (semiringInst : Expr) : GoalM Expr := do
|
||||
let inst' := mkApp2 (mkConst ``Grind.Semiring.natCast [u]) type semiringInst
|
||||
let instType := mkApp (mkConst ``NatCast [u]) type
|
||||
-- Note that `Semiring.natCast` is not registered as a global instance
|
||||
-- (to avoid introducing unwanted coercions)
|
||||
-- so merely having a `Semiring α` instance
|
||||
-- does not guarantee that an `NatCast α` will be available.
|
||||
-- When both are present we verify that they are defeq,
|
||||
-- and otherwise fall back to the field of the `Semiring α` instance that we already have.
|
||||
let inst ← match (← synthInstance? instType) with
|
||||
| none => pure inst'
|
||||
| some inst =>
|
||||
unless (← withDefault <| isDefEq inst inst') do
|
||||
throwError "instance for natCast{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
|
||||
pure inst
|
||||
internalizeFn <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst
|
||||
|
||||
/--
|
||||
Returns the ring id for the given type if there is a `CommRing` instance for it.
|
||||
|
||||
|
|
@ -109,10 +27,6 @@ def getRingId? (type : Expr) : GoalM (Option Nat) := do
|
|||
else
|
||||
let id? ← go?
|
||||
modify' fun s => { s with typeIdOf := s.typeIdOf.insert { expr := type } id? }
|
||||
if let some id := id? then
|
||||
-- Internalize helper constants
|
||||
let ring := (← get').rings[id]!
|
||||
internalize ring.one 0
|
||||
return id?
|
||||
where
|
||||
go? : GoalM (Option Nat) := do
|
||||
|
|
@ -127,24 +41,12 @@ where
|
|||
let noZeroDivInst? ← getNoZeroDivInst? u type
|
||||
trace_goal[grind.ring] "NoNatZeroDivisors available: {noZeroDivInst?.isSome}"
|
||||
let fieldInst? ← synthInstance? <| mkApp (mkConst ``Grind.Field [u]) type
|
||||
let addFn ← getAddFn type u
|
||||
let mulFn ← getMulFn type u
|
||||
let subFn ← getSubFn type u
|
||||
let negFn ← getNegFn type u
|
||||
let powFn ← getPowFn type u semiringInst
|
||||
let intCastFn ← getIntCastFn type u ringInst
|
||||
let natCastFn ← getNatCastFn type u semiringInst
|
||||
let invFn? ← if fieldInst?.isSome then
|
||||
pure (some (← getInvFn type u))
|
||||
else
|
||||
pure none
|
||||
let one ← shareCommon <| (← canon <| denoteNumCore u type semiringInst negFn 1)
|
||||
let semiringId? := none
|
||||
let id := (← get').rings.size
|
||||
let ring : Ring := {
|
||||
id, semiringId?, type, u, semiringInst, ringInst, commSemiringInst,
|
||||
commRingInst, charInst?, noZeroDivInst?, fieldInst?,
|
||||
addFn, mulFn, subFn, negFn, powFn, intCastFn, natCastFn, invFn?, one }
|
||||
}
|
||||
modify' fun s => { s with rings := s.rings.push ring }
|
||||
return some id
|
||||
|
||||
|
|
@ -164,22 +66,12 @@ where
|
|||
let commSemiring := mkApp (mkConst ``Grind.CommSemiring [u]) type
|
||||
let some commSemiringInst ← synthInstance? commSemiring | return none
|
||||
let semiringInst := mkApp2 (mkConst ``Grind.CommSemiring.toSemiring [u]) type commSemiringInst
|
||||
let toQFn ← internalizeFn <| mkApp2 (mkConst ``Grind.Ring.OfSemiring.toQ [u]) type semiringInst
|
||||
let addFn ← getAddFn type u
|
||||
let mulFn ← getMulFn type u
|
||||
let powFn ← getPowFn type u semiringInst
|
||||
let natCastFn ← getNatCastFn type u semiringInst
|
||||
let add := mkApp (mkConst ``Add [u]) type
|
||||
let some addInst ← synthInstance? add | return none
|
||||
let addRightCancel := mkApp2 (mkConst ``Grind.AddRightCancel [u]) type addInst
|
||||
let addRightCancelInst? ← synthInstance? addRightCancel
|
||||
let q ← shareCommon (← canon (mkApp2 (mkConst ``Grind.Ring.OfSemiring.Q [u]) type semiringInst))
|
||||
let some ringId ← getRingId? q
|
||||
| throwError "`grind` unexpected failure, failure to initialize ring{indentExpr q}"
|
||||
let id := (← get').semirings.size
|
||||
let semiring : Semiring := {
|
||||
id, type, ringId, u, semiringInst, commSemiringInst,
|
||||
addFn, mulFn, powFn, natCastFn, toQFn, addRightCancelInst?
|
||||
id, type, ringId, u, semiringInst, commSemiringInst
|
||||
}
|
||||
modify' fun s => { s with semirings := s.semirings.push semiring }
|
||||
setSemiringId ringId id
|
||||
|
|
|
|||
|
|
@ -165,16 +165,16 @@ structure Ring where
|
|||
noZeroDivInst? : Option Expr
|
||||
/-- `Field` instance for `type` if available. -/
|
||||
fieldInst? : Option Expr
|
||||
addFn : Expr
|
||||
mulFn : Expr
|
||||
subFn : Expr
|
||||
negFn : Expr
|
||||
powFn : Expr
|
||||
intCastFn : Expr
|
||||
natCastFn : Expr
|
||||
addFn? : Option Expr := none
|
||||
mulFn? : Option Expr := none
|
||||
subFn? : Option Expr := none
|
||||
negFn? : Option Expr := none
|
||||
powFn? : Option Expr := none
|
||||
intCastFn? : Option Expr := none
|
||||
natCastFn? : Option Expr := none
|
||||
/-- Inverse if `fieldInst?` is `some inst` -/
|
||||
invFn? : Option Expr
|
||||
one : Expr
|
||||
invFn? : Option Expr := none
|
||||
one? : Option Expr := none
|
||||
/--
|
||||
Mapping from variables to their denotations.
|
||||
Remark each variable can be in only one ring.
|
||||
|
|
@ -230,12 +230,12 @@ structure Semiring where
|
|||
/-- `CommSemiring` instance for `type` -/
|
||||
commSemiringInst : Expr
|
||||
/-- `AddRightCancel` instance for `type` if available. -/
|
||||
addRightCancelInst? : Option Expr
|
||||
toQFn : Expr
|
||||
addFn : Expr
|
||||
mulFn : Expr
|
||||
powFn : Expr
|
||||
natCastFn : Expr
|
||||
addRightCancelInst? : Option (Option Expr) := none
|
||||
toQFn? : Option Expr := none
|
||||
addFn? : Option Expr := none
|
||||
mulFn? : Option Expr := none
|
||||
powFn? : Option Expr := none
|
||||
natCastFn? : Option Expr := none
|
||||
/-- Mapping from Lean expressions to their representations as `SemiringExpr` -/
|
||||
denote : PHashMap ExprPtr SemiringExpr := {}
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.Tactic.Grind.Canon
|
||||
import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.SynthInstance
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
@ -36,14 +38,34 @@ structure RingM.Context where
|
|||
-/
|
||||
checkCoeffDvd : Bool := false
|
||||
|
||||
class MonadGetRing (m : Type → Type) where
|
||||
class MonadRing (m : Type → Type) where
|
||||
getRing : m Ring
|
||||
modifyRing : (Ring → Ring) → m Unit
|
||||
/--
|
||||
Helper function for removing dependency on `GoalM`.
|
||||
In `RingM` and `SemiringM`, this is just `sharedCommon (← canon e)`
|
||||
When printing counterexamples, we are at `MetaM`, and this is just the identity function.
|
||||
-/
|
||||
canonExpr : Expr → m Expr
|
||||
/--
|
||||
Helper function for removing dependency on `GoalM`. During search we
|
||||
want to track the instances synthesized by `grind`, and this is `Grind.synthInstance`.
|
||||
-/
|
||||
synthInstance? : Expr → m (Option Expr)
|
||||
|
||||
export MonadGetRing (getRing)
|
||||
export MonadRing (getRing modifyRing canonExpr)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadGetRing m] : MonadGetRing n where
|
||||
instance (m n) [MonadLift m n] [MonadRing m] : MonadRing n where
|
||||
getRing := liftM (getRing : m Ring)
|
||||
modifyRing f := liftM (modifyRing f : m Unit)
|
||||
canonExpr e := liftM (canonExpr e : m Expr)
|
||||
synthInstance? e := liftM (MonadRing.synthInstance? e : m (Option Expr))
|
||||
|
||||
def MonadRing.synthInstance [Monad m] [MonadError m] [MonadRing m] (type : Expr) : m Expr := do
|
||||
let some inst ← synthInstance? type
|
||||
| throwError "`grind` failed to find instance{indentExpr type}"
|
||||
return inst
|
||||
|
||||
/-- We don't want to keep carrying the `RingId` around. -/
|
||||
abbrev RingM := ReaderT RingM.Context GoalM
|
||||
|
|
@ -62,12 +84,13 @@ protected def RingM.getRing : RingM Ring := do
|
|||
else
|
||||
throwError "`grind` internal error, invalid ringId"
|
||||
|
||||
instance : MonadGetRing RingM where
|
||||
instance : MonadRing RingM where
|
||||
getRing := RingM.getRing
|
||||
|
||||
@[inline] def modifyRing (f : Ring → Ring) : RingM Unit := do
|
||||
let ringId ← getRingId
|
||||
modify' fun s => { s with rings := s.rings.modify ringId f }
|
||||
modifyRing f := do
|
||||
let ringId ← getRingId
|
||||
modify' fun s => { s with rings := s.rings.modify ringId f }
|
||||
canonExpr e := do shareCommon (← canon e)
|
||||
synthInstance? e := Grind.synthInstance? e
|
||||
|
||||
structure SemiringM.Context where
|
||||
semiringId : Nat
|
||||
|
|
@ -96,8 +119,13 @@ protected def SemiringM.getRing : SemiringM Ring := do
|
|||
else
|
||||
throwError "`grind` internal error, invalid ringId"
|
||||
|
||||
instance : MonadGetRing SemiringM where
|
||||
instance : MonadRing SemiringM where
|
||||
getRing := SemiringM.getRing
|
||||
modifyRing f := do
|
||||
let ringId := (← getSemiring).ringId
|
||||
modify' fun s => { s with rings := s.rings.modify ringId f }
|
||||
canonExpr e := do shareCommon (← canon e)
|
||||
synthInstance? e := Grind.synthInstance? e
|
||||
|
||||
@[inline] def modifySemiring (f : Semiring → Semiring) : SemiringM Unit := do
|
||||
let semiringId ← getSemiringId
|
||||
|
|
@ -132,14 +160,14 @@ def setTermSemiringId (e : Expr) : SemiringM Unit := do
|
|||
modify' fun s => { s with exprToSemiringId := s.exprToSemiringId.insert { expr := e } semiringId }
|
||||
|
||||
/-- Returns `some c` if the current ring has a nonzero characteristic `c`. -/
|
||||
def nonzeroChar? [Monad m] [MonadGetRing m] : m (Option Nat) := do
|
||||
def nonzeroChar? [Monad m] [MonadRing m] : m (Option Nat) := do
|
||||
if let some (_, c) := (← getRing).charInst? then
|
||||
if c != 0 then
|
||||
return some c
|
||||
return none
|
||||
|
||||
/-- Returns `some (charInst, c)` if the current ring has a nonzero characteristic `c`. -/
|
||||
def nonzeroCharInst? [Monad m] [MonadGetRing m] : m (Option (Expr × Nat)) := do
|
||||
def nonzeroCharInst? [Monad m] [MonadRing m] : m (Option (Expr × Nat)) := do
|
||||
if let some (inst, c) := (← getRing).charInst? then
|
||||
if c != 0 then
|
||||
return some (inst, c)
|
||||
|
|
@ -181,4 +209,176 @@ def getNext? : RingM (Option EqCnstr) := do
|
|||
incSteps
|
||||
return some c
|
||||
|
||||
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadRing m]
|
||||
|
||||
private def mkUnaryFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) : m Expr := do
|
||||
let inst ← MonadRing.synthInstance <| mkApp (mkConst instDeclName [u]) type
|
||||
canonExpr <| mkApp2 (mkConst declName [u]) type inst
|
||||
|
||||
private def mkBinHomoFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) : m Expr := do
|
||||
let inst ← MonadRing.synthInstance <| mkApp3 (mkConst instDeclName [u, u, u]) type type type
|
||||
canonExpr <| mkApp4 (mkConst declName [u, u, u]) type type type inst
|
||||
|
||||
def getAddFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some addFn := ring.addFn? then return addFn
|
||||
let addFn ← mkBinHomoFn ring.type ring.u ``HAdd ``HAdd.hAdd
|
||||
modifyRing fun s => { s with addFn? := some addFn }
|
||||
return addFn
|
||||
|
||||
def getSubFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some subFn := ring.subFn? then return subFn
|
||||
let subFn ← mkBinHomoFn ring.type ring.u ``HSub ``HSub.hSub
|
||||
modifyRing fun s => { s with subFn? := some subFn }
|
||||
return subFn
|
||||
|
||||
def getMulFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some mulFn := ring.mulFn? then return mulFn
|
||||
let mulFn ← mkBinHomoFn ring.type ring.u ``HMul ``HMul.hMul
|
||||
modifyRing fun s => { s with mulFn? := some mulFn }
|
||||
return mulFn
|
||||
|
||||
def getNegFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some negFn := ring.negFn? then return negFn
|
||||
let negFn ← mkUnaryFn ring.type ring.u ``Neg ``Neg.neg
|
||||
modifyRing fun s => { s with negFn? := some negFn }
|
||||
return negFn
|
||||
|
||||
def getInvFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if ring.fieldInst?.isNone then
|
||||
throwError "`grind` internal error, type is not a field{indentExpr ring.type}"
|
||||
if let some invFn := ring.invFn? then return invFn
|
||||
let invFn ← mkUnaryFn ring.type ring.u ``Inv ``Inv.inv
|
||||
modifyRing fun s => { s with invFn? := some invFn }
|
||||
return invFn
|
||||
|
||||
private def mkPowFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
|
||||
let inst ← MonadRing.synthInstance <| mkApp3 (mkConst ``HPow [u, 0, u]) type Nat.mkType type
|
||||
let inst' := mkApp2 (mkConst ``Grind.Semiring.toHPow [u]) type semiringInst
|
||||
checkInst inst inst'
|
||||
canonExpr <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst
|
||||
where
|
||||
checkInst (inst inst' : Expr) : MetaM Unit := do
|
||||
unless (← withDefault <| isDefEq inst inst') do
|
||||
throwError "instance for power operator{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
|
||||
|
||||
def getPowFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some powFn := ring.powFn? then return powFn
|
||||
let powFn ← mkPowFn ring.u ring.type ring.semiringInst
|
||||
modifyRing fun s => { s with powFn? := some powFn }
|
||||
return powFn
|
||||
|
||||
def getIntCastFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some intCastFn := ring.intCastFn? then return intCastFn
|
||||
let inst' := mkApp2 (mkConst ``Grind.Ring.intCast [ring.u]) ring.type ring.ringInst
|
||||
let instType := mkApp (mkConst ``IntCast [ring.u]) ring.type
|
||||
-- Note that `Ring.intCast` is not registered as a global instance
|
||||
-- (to avoid introducing unwanted coercions)
|
||||
-- so merely having a `Ring α` instance
|
||||
-- does not guarantee that an `IntCast α` will be available.
|
||||
-- When both are present we verify that they are defeq,
|
||||
-- and otherwise fall back to the field of the `Ring α` instance that we already have.
|
||||
let inst ← match (← MonadRing.synthInstance? instType) with
|
||||
| none => pure inst'
|
||||
| some inst => checkInst inst inst'; pure inst
|
||||
let intCastFn ← canonExpr <| mkApp2 (mkConst ``IntCast.intCast [ring.u]) ring.type inst
|
||||
modifyRing fun s => { s with intCastFn? := some intCastFn }
|
||||
return intCastFn
|
||||
where
|
||||
checkInst (inst inst' : Expr) : MetaM Unit := do
|
||||
unless (← withDefault <| isDefEq inst inst') do
|
||||
throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.Ring` one{indentExpr inst'}"
|
||||
|
||||
private def mkNatCastFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
|
||||
let inst' := mkApp2 (mkConst ``Grind.Semiring.natCast [u]) type semiringInst
|
||||
let instType := mkApp (mkConst ``NatCast [u]) type
|
||||
-- Note that `Semiring.natCast` is not registered as a global instance
|
||||
-- (to avoid introducing unwanted coercions)
|
||||
-- so merely having a `Semiring α` instance
|
||||
-- does not guarantee that an `NatCast α` will be available.
|
||||
-- When both are present we verify that they are defeq,
|
||||
-- and otherwise fall back to the field of the `Semiring α` instance that we already have.
|
||||
let inst ← match (← MonadRing.synthInstance? instType) with
|
||||
| none => pure inst'
|
||||
| some inst => checkInst inst inst'; pure inst
|
||||
canonExpr <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst
|
||||
where
|
||||
checkInst (inst inst' : Expr) : MetaM Unit := do
|
||||
unless (← withDefault <| isDefEq inst inst') do
|
||||
throwError "instance for natCast{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
|
||||
|
||||
def getNatCastFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some natCastFn := ring.natCastFn? then return natCastFn
|
||||
let natCastFn ← mkNatCastFn ring.u ring.type ring.semiringInst
|
||||
modifyRing fun s => { s with natCastFn? := some natCastFn }
|
||||
return natCastFn
|
||||
|
||||
private def mkOne (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
|
||||
let n := mkRawNatLit 1
|
||||
let ofNatInst := mkApp3 (mkConst ``Grind.Semiring.ofNat [u]) type semiringInst n
|
||||
canonExpr <| mkApp3 (mkConst ``OfNat.ofNat [u]) type n ofNatInst
|
||||
|
||||
def getOne [MonadLiftT GoalM m] : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some one := ring.one? then return one
|
||||
let one ← mkOne ring.u ring.type ring.semiringInst
|
||||
modifyRing fun s => { s with one? := some one }
|
||||
internalize one 0
|
||||
return one
|
||||
|
||||
def getAddFn' : SemiringM Expr := do
|
||||
let s ← getSemiring
|
||||
if let some addFn := s.addFn? then return addFn
|
||||
let addFn ← mkBinHomoFn s.type s.u ``HAdd ``HAdd.hAdd
|
||||
modifySemiring fun s => { s with addFn? := some addFn }
|
||||
return addFn
|
||||
|
||||
def getMulFn' : SemiringM Expr := do
|
||||
let s ← getSemiring
|
||||
if let some mulFn := s.mulFn? then return mulFn
|
||||
let mulFn ← mkBinHomoFn s.type s.u ``HMul ``HMul.hMul
|
||||
modifySemiring fun s => { s with mulFn? := some mulFn }
|
||||
return mulFn
|
||||
|
||||
def getPowFn' : SemiringM Expr := do
|
||||
let s ← getSemiring
|
||||
if let some powFn := s.powFn? then return powFn
|
||||
let powFn ← mkPowFn s.u s.type s.semiringInst
|
||||
modifySemiring fun s => { s with powFn? := some powFn }
|
||||
return powFn
|
||||
|
||||
def getNatCastFn' : SemiringM Expr := do
|
||||
let s ← getSemiring
|
||||
if let some natCastFn := s.natCastFn? then return natCastFn
|
||||
let natCastFn ← mkNatCastFn s.u s.type s.semiringInst
|
||||
modifySemiring fun s => { s with natCastFn? := some natCastFn }
|
||||
return natCastFn
|
||||
|
||||
def getToQFn : SemiringM Expr := do
|
||||
let s ← getSemiring
|
||||
if let some toQFn := s.toQFn? then return toQFn
|
||||
let toQFn ← canonExpr <| mkApp2 (mkConst ``Grind.Ring.OfSemiring.toQ [s.u]) s.type s.semiringInst
|
||||
modifySemiring fun s => { s with toQFn? := some toQFn }
|
||||
return toQFn
|
||||
|
||||
private def mkAddRightCancelInst? (u : Level) (type : Expr) : GoalM (Option Expr) := do
|
||||
let add := mkApp (mkConst ``Add [u]) type
|
||||
let some addInst ← synthInstance? add | return none
|
||||
let addRightCancel := mkApp2 (mkConst ``Grind.AddRightCancel [u]) type addInst
|
||||
synthInstance? addRightCancel
|
||||
|
||||
def getAddRightCancelInst? : SemiringM (Option Expr) := do
|
||||
let s ← getSemiring
|
||||
if let some r := s.addRightCancelInst? then return r
|
||||
let addRightCancelInst? ← mkAddRightCancelInst? s.u s.type
|
||||
modifySemiring fun s => { s with addRightCancelInst? := some addRightCancelInst? }
|
||||
return addRightCancelInst?
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
|
|
|
|||
|
|
@ -67,8 +67,13 @@ def getRing : LinearM Ring := do
|
|||
| throwNotCommRing
|
||||
return ring
|
||||
|
||||
instance : MonadGetRing LinearM where
|
||||
instance : MonadRing LinearM where
|
||||
getRing := Linear.getRing
|
||||
modifyRing f := do
|
||||
let some ringId := (← getStruct).ringId? | throwNotCommRing
|
||||
RingM.run ringId do modifyRing f
|
||||
canonExpr e := do shareCommon (← canon e)
|
||||
synthInstance? e := Grind.synthInstance? e
|
||||
|
||||
def getZero : LinearM Expr :=
|
||||
return (← getStruct).zero
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ h : ¬a = 10
|
|||
[eqc] False propositions
|
||||
[prop] a = 10
|
||||
[cutsat] Assignment satisfying linear constraints
|
||||
[assign] a := 2
|
||||
[assign] a := 1
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : a = 5 + 5 := by
|
||||
|
|
@ -49,8 +49,8 @@ h : ¬f a = 11
|
|||
[eqc] False propositions
|
||||
[prop] f a = 11
|
||||
[cutsat] Assignment satisfying linear constraints
|
||||
[assign] a := 3
|
||||
[assign] f a := 2
|
||||
[assign] a := 2
|
||||
[assign] f a := 1
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : f a = 10 + 1 := by
|
||||
|
|
@ -75,9 +75,9 @@ h : ¬f x = 11
|
|||
[ematch] E-matching patterns
|
||||
[thm] fa: [f `[a]]
|
||||
[cutsat] Assignment satisfying linear constraints
|
||||
[assign] x := 4
|
||||
[assign] a := 3
|
||||
[assign] f x := 2
|
||||
[assign] x := 3
|
||||
[assign] a := 2
|
||||
[assign] f x := 1
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : f x = 10 + 1 := by
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ theorem ex₃ (a b c : Int) : a + b + c = 0 → a = c → b = 4 → c = -2 := by
|
|||
|
||||
/--
|
||||
trace: [grind.cutsat.assert] -1*「a + -2 * b + -2 * c」 + a + -2*b + -2*c = 0
|
||||
[grind.cutsat.assert] -1*「1」 + 1 = 0
|
||||
[grind.cutsat.assert] -1*「0」 = 0
|
||||
[grind.cutsat.assert] 「a + -2 * b + -2 * c」 = 0
|
||||
[grind.cutsat.assert] -1*「a + -2 * b + -2 * d」 + a + -2*b + -2*d = 0
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ theorem ex₄ (f : Int → Int) (a b : Int) (_ : 2 ∣ f (f a) + 1) (h₁ : 3
|
|||
/--
|
||||
trace: [grind.debug.cutsat.search.assign] a := 1
|
||||
[grind.debug.cutsat.search.assign] b := 0
|
||||
[grind.debug.cutsat.search.assign] 「1」 := 1
|
||||
-/
|
||||
#guard_msgs (trace) in -- finds the model without any backtracking
|
||||
set_option trace.grind.debug.cutsat.search.assign true in
|
||||
|
|
@ -31,12 +30,10 @@ example (a b : Int) (_ : 2 ∣ a + 3) (_ : 3 ∣ a + b - 4) : False := by
|
|||
sorry
|
||||
|
||||
/--
|
||||
trace: [grind.cutsat.assert] -1*「1」 + 1 = 0
|
||||
[grind.cutsat.assert] 2 ∣ a + 3
|
||||
trace: [grind.cutsat.assert] 2 ∣ a + 3
|
||||
[grind.cutsat.assert] 3 ∣ a + 3*b + -4
|
||||
[grind.debug.cutsat.search.assign] a := 1
|
||||
[grind.debug.cutsat.search.assign] b := 0
|
||||
[grind.debug.cutsat.search.assign] 「1」 := 1
|
||||
-/
|
||||
#guard_msgs (trace) in
|
||||
set_option trace.grind.cutsat.assert true in
|
||||
|
|
@ -48,7 +45,6 @@ example (a b : Int) (_ : 2 ∣ a + 3) (_ : 3 ∣ a + 3*b - 4) : False := by
|
|||
/--
|
||||
trace: [grind.debug.cutsat.search.assign] a := 1
|
||||
[grind.debug.cutsat.search.assign] b := 15
|
||||
[grind.debug.cutsat.search.assign] 「1」 := 1
|
||||
-/
|
||||
#guard_msgs (trace) in
|
||||
set_option trace.grind.debug.cutsat.search.assign true in
|
||||
|
|
@ -59,7 +55,6 @@ example (a b : Int) (_ : 2 ∣ a + 3) (_ : 3 ∣ a + b - 4) (_ : b < 18): False
|
|||
/--
|
||||
trace: [grind.debug.cutsat.search.assign] a := 1
|
||||
[grind.debug.cutsat.search.assign] b := 12
|
||||
[grind.debug.cutsat.search.assign] 「1」 := 1
|
||||
-/
|
||||
#guard_msgs (trace) in
|
||||
set_option trace.grind.debug.cutsat.search.assign true in
|
||||
|
|
@ -70,8 +65,8 @@ example (a b : Int) (_ : 2 ∣ a + 3) (_ : 3 ∣ a + b - 4) (_ : b ≥ 11): Fals
|
|||
/--
|
||||
trace: [grind.debug.cutsat.search.assign] f 0 := 11
|
||||
[grind.debug.cutsat.search.assign] f 1 := 2
|
||||
[grind.debug.cutsat.search.assign] 「0」 := 0
|
||||
[grind.debug.cutsat.search.assign] 「1」 := 1
|
||||
[grind.debug.cutsat.search.assign] 「0」 := 0
|
||||
-/
|
||||
#guard_msgs (trace) in
|
||||
set_option trace.grind.debug.cutsat.search.assign true in
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ open Int.Linear
|
|||
|
||||
/--
|
||||
trace: [grind.cutsat.assert] -1*「b + f a + 1」 + b + f a + 1 = 0
|
||||
[grind.cutsat.assert] -1*「1」 + 1 = 0
|
||||
[grind.cutsat.assert] -1*「0」 = 0
|
||||
[grind.cutsat.assert] 「b + f a + 1」 = 0
|
||||
-/
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ example (a b : Int) : a + b = Int.ofNat 2 → a - 2 = -b := by
|
|||
trace: [grind.cutsat.assert] -1*↑a ≤ 0
|
||||
[grind.cutsat.assert] -1*↑b ≤ 0
|
||||
[grind.cutsat.assert] -1*「↑a * ↑b」 ≤ 0
|
||||
[grind.cutsat.assert] -1*「1」 + 1 = 0
|
||||
[grind.cutsat.assert] -1*↑c ≤ 0
|
||||
[grind.cutsat.assert] -1*「↑a * ↑b + -1 * ↑c + 1」 + 「↑a * ↑b」 + -1*↑c + 1 = 0
|
||||
[grind.cutsat.assert] 「↑a * ↑b」 + -1*↑c + 1 ≤ 0
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ h_1 : ⋯ ≍ ⋯
|
|||
[cases] [1/2]: X c 0
|
||||
[cases] source: Initial goal
|
||||
[cutsat] Assignment satisfying linear constraints
|
||||
[assign] c := 2
|
||||
[assign] c := 1
|
||||
[assign] s := 0
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue