feat: NatModule inequalities and equalities in grind linarith (#10278)

This PR adds support for `NatModule` equalities and inequalities in
`grind linarith`. Examples:
```lean
open Lean Grind Std

example [NatModule α] [LE α] [LT α] 
  [LawfulOrderLT α] [IsLinearOrder α] [OrderedAdd α] 
  (x y : α) : x ≤ y → 2 • x + y ≤ 3 • y := by
  grind

example [NatModule α] [AddRightCancel α] [LE α] [LT α] 
    [LawfulOrderLT α] [IsLinearOrder α] [OrderedAdd α] 
    (a b c d : α) : a ≤ b → a ≥ c + d → d ≤ 0 → d ≥ 0 → b = c → a = b := by
  grind
```
This commit is contained in:
Leonardo de Moura 2025-09-06 13:52:09 -07:00 committed by GitHub
parent 52a9fe3b67
commit 2ff41f43be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 240 additions and 66 deletions

View file

@ -15,22 +15,16 @@ import Lean.Meta.Tactic.Grind.Arith.Linear.StructId
import Lean.Meta.Tactic.Grind.Arith.Linear.Reify
import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr
import Lean.Meta.Tactic.Grind.Arith.Linear.Proof
public section
import Lean.Meta.Tactic.Grind.Arith.Linear.OfNatModule
namespace Lean.Meta.Grind.Arith.Linear
def isLeInst (struct : Struct) (inst : Expr) : Bool :=
if let some leFn := struct.leFn? then
isSameExpr leFn.appArg! inst
def isInstOf (fn? : Option Expr) (inst : Expr) : Bool :=
if let some fn := fn? then
isSameExpr fn.appArg! inst
else
false
def isLtInst (struct : Struct) (inst : Expr) : Bool :=
if let some ltFn := struct.ltFn? then
isSameExpr ltFn.appArg! inst
else
false
def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do
public def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do
trace[grind.linarith.assert] "{← c.denoteExpr}"
match c.p with
| .nil =>
@ -52,19 +46,19 @@ def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do
def propagateCommRingIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : LinearM Unit := do
let some lhs ← withRingM <| CommRing.reify? lhs (skipVar := false) | return ()
let some rhs ← withRingM <| CommRing.reify? rhs (skipVar := false) | return ()
let gen ← getGeneration e
let generation ← getGeneration e
if eqTrue then
let p' := (lhs.sub rhs).toPoly
let lhs' ← p'.toIntModuleExpr gen
let some lhs' ← reify? lhs' (skipVar := false) | return ()
let lhs' ← p'.toIntModuleExpr generation
let some lhs' ← reify? lhs' (skipVar := false) generation | return ()
let p := lhs'.norm
let c : IneqCnstr := { p, strict, h := .coreCommRing e lhs rhs p' lhs' }
c.assert
else if (← isLinearOrder) then
let p' := (rhs.sub lhs).toPoly
let strict := !strict
let lhs' ← p'.toIntModuleExpr gen
let some lhs' ← reify? lhs' (skipVar := false) | return ()
let lhs' ← p'.toIntModuleExpr generation
let some lhs' ← reify? lhs' (skipVar := false) generation | return ()
let p := lhs'.norm
let c : IneqCnstr := { p, strict, h := .notCoreCommRing e lhs rhs p' lhs' }
c.assert
@ -73,8 +67,8 @@ def propagateCommRingIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue :
modifyStruct fun s => { s with ignored := s.ignored.push e }
def propagateIntModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : LinearM Unit := do
let some lhs ← reify? lhs (skipVar := false) | return ()
let some rhs ← reify? rhs (skipVar := false) | return ()
let some lhs ← reify? lhs (skipVar := false) (← getGeneration lhs) | return ()
let some rhs ← reify? rhs (skipVar := false) (← getGeneration rhs) | return ()
if eqTrue then
let p := (lhs.sub rhs).norm
let c : IneqCnstr := { p, strict, h := .core e lhs rhs }
@ -88,27 +82,55 @@ def propagateIntModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue :
-- Negation for preorders is not supported
modifyStruct fun s => { s with ignored := s.ignored.push e }
def propagateIneq (e : Expr) (eqTrue : Bool) : GoalM Unit := do
def propagateNatModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : OfNatModuleM Unit := do
let ns ← getNatStruct
let (lhs₁, _) ← ofNatModule lhs
let (rhs₁, _) ← ofNatModule rhs
LinearM.run ns.structId do
let some lhs₂ ← reify? lhs₁ (skipVar := false) (← getGeneration lhs) | return ()
let some rhs₂ ← reify? rhs₁ (skipVar := false) (← getGeneration rhs) | return ()
if eqTrue then
let p := (lhs₂.sub rhs₂).norm
let c : IneqCnstr := { p, strict, h := .coreOfNat e ns.id lhs₂ rhs₂ }
c.assert
else
let p := (rhs₂.sub lhs₂).norm
let strict := !strict
let c : IneqCnstr := { p, strict, h := .notCoreOfNat e ns.id lhs₂ rhs₂ }
c.assert
public def propagateIneq (e : Expr) (eqTrue : Bool) : GoalM Unit := do
unless (← getConfig).linarith do return ()
let numArgs := e.getAppNumArgs
unless numArgs == 4 do return ()
let α := e.getArg! 0 numArgs
let some structId ← getStructId? α | return ()
LinearM.run structId do
let inst := e.getArg! 1 numArgs
let struct ← getStruct
let strict ← if isLeInst struct inst then
let inst := e.getArg! 1 numArgs
let lhs := e.getArg! 2 numArgs
let rhs := e.getArg! 3 numArgs
if let some structId ← getStructId? α then LinearM.run structId do
let s ← getStruct
let strict ← if isInstOf s.leFn? inst then
pure false
else if isLtInst struct inst then
else if isInstOf s.ltFn? inst then
pure true
else
return ()
let lhs := e.getArg! 2 numArgs
let rhs := e.getArg! 3 numArgs
if (← isOrderedCommRing) then
propagateCommRingIneq e lhs rhs strict eqTrue
-- TODO: non-commutative ring normalizer
else
propagateIntModuleIneq e lhs rhs strict eqTrue
else if let some natStructId ← getNatStructId? α then OfNatModuleM.run natStructId do
let s ← getNatStruct
if s.leInst?.isNone || s.isPreorderInst?.isNone || s.orderedAddInst?.isNone then return ()
let strict ← if some inst == s.leInst? then
pure false
else if some inst == s.ltInst? then
pure true
else
return ()
if strict && s.lawfulOrderLTInst?.isNone then return ()
if !eqTrue && s.isLinearInst?.isNone then return ()
propagateNatModuleIneq e lhs rhs strict eqTrue
end Lean.Meta.Grind.Arith.Linear

View file

@ -60,12 +60,17 @@ def setTermNatStructId (e : Expr) : OfNatModuleM Unit := do
modify' fun s => { s with exprToNatStructId := s.exprToNatStructId.insert { expr := e } id }
private def mkOfNatModuleVar (e : Expr) : OfNatModuleM (Expr × Expr) := do
let s ← getNatStruct
let toQe := mkApp s.toQFn e
let h := mkApp s.rfl_q toQe
setTermNatStructId e
markAsLinarithTerm e
return (toQe, h)
if let some r := (← getNatStruct).termMap.find? { expr := e } then
return r
else
let s ← getNatStruct
let toQe ← shareCommon (mkApp s.toQFn e)
let h := mkApp s.rfl_q toQe
let r := (toQe, h)
modifyNatStruct fun s => { s with termMap := s.termMap.insert { expr := e } r }
setTermNatStructId e
markAsLinarithTerm e
return r
private def isAddInst (natStruct : NatStruct) (inst : Expr) : Bool :=
isSameExpr natStruct.addFn.appArg! inst
@ -102,6 +107,13 @@ private partial def ofNatModule' (e : Expr) : OfNatModuleM (Expr × Expr) := do
pure (e', h)
else
mkOfNatModuleVar e
| OfNat.ofNat _ _ _ =>
if (← isDefEqD e ns.zero) then
let e' := s.zero
let h := mkApp2 (mkConst ``Grind.IntModule.OfNatModule.toQ_zero [ns.u]) ns.type ns.natModuleInst
pure (e', h)
else
mkOfNatModuleVar e
| _ => mkOfNatModuleVar e
def ofNatModule (e : Expr) : OfNatModuleM (Expr × Expr) := do
@ -115,7 +127,6 @@ def ofNatModule (e : Expr) : OfNatModuleM (Expr × Expr) := do
else
pure (r.expr, h)
setTermNatStructId e
internalize e' (← getGeneration e)
modifyNatStruct fun s => { s with termMap := s.termMap.insert { expr := e } (e', h) }
return (e', h)

View file

@ -250,6 +250,38 @@ partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' d
let h' := mkApp5 h' (← mkRingExprDecl lhs) (← mkRingExprDecl rhs) (← mkRingPolyDecl p') eagerReflBoolTrue (mkOfEqFalseCore e (← mkEqFalseProof e))
let h ← if c'.strict then mkIntModLawfulPreOrdThmPrefix ``Grind.Linarith.lt_norm else mkIntModPreOrdThmPrefix ``Grind.Linarith.le_norm
return mkApp5 h (← mkExprDecl lhs') (← mkExprDecl .zero) (← mkPolyDecl c'.p) eagerReflBoolTrue h'
| .coreOfNat e natStructId lhs rhs =>
let h' ← OfNatModuleM.run natStructId do
let a := e.appFn!.appArg!
let b := e.appArg!
let ns ← getNatStruct
let (a', ha) ← ofNatModule a
let (b', hb) ← ofNatModule b
let h := if c'.strict then
mkApp7 (mkConst ``Grind.IntModule.OfNatModule.of_lt [ns.u]) ns.type ns.natModuleInst ns.leInst?.get! ns.ltInst?.get!
ns.lawfulOrderLTInst?.get! ns.isPreorderInst?.get! ns.orderedAddInst?.get!
else
mkApp5 (mkConst ``Grind.IntModule.OfNatModule.of_le [ns.u]) ns.type ns.natModuleInst ns.leInst?.get!
ns.isPreorderInst?.get! ns.orderedAddInst?.get!
return mkApp7 h a b a' b' ha hb (mkOfEqTrueCore e (← mkEqTrueProof e))
let h ← if c'.strict then mkIntModLawfulPreOrdThmPrefix ``Grind.Linarith.lt_norm else mkIntModPreOrdThmPrefix ``Grind.Linarith.le_norm
return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) eagerReflBoolTrue h'
| .notCoreOfNat e natStructId lhs rhs =>
let h' ← OfNatModuleM.run natStructId do
let a := e.appFn!.appArg!
let b := e.appArg!
let ns ← getNatStruct
let (a', ha) ← ofNatModule a
let (b', hb) ← ofNatModule b
let h := if c'.strict then
mkApp5 (mkConst ``Grind.IntModule.OfNatModule.of_not_le [ns.u]) ns.type ns.natModuleInst ns.leInst?.get!
ns.isPreorderInst?.get! ns.orderedAddInst?.get!
else
mkApp7 (mkConst ``Grind.IntModule.OfNatModule.of_not_lt [ns.u]) ns.type ns.natModuleInst ns.leInst?.get! ns.ltInst?.get!
ns.lawfulOrderLTInst?.get! ns.isPreorderInst?.get! ns.orderedAddInst?.get!
return mkApp7 h a b a' b' ha hb (mkOfEqFalseCore e (← mkEqFalseProof e))
let h ← mkIntModLinOrdThmPrefix (if c'.strict then ``Grind.Linarith.not_le_norm else ``Grind.Linarith.not_lt_norm)
return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) eagerReflBoolTrue h'
| .combine c₁ c₂ =>
let (pre, c₁, c₂) :=
match c₁.strict, c₂.strict with
@ -266,6 +298,15 @@ partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' d
| .ofEq a b la lb =>
let h ← mkIntModPreOrdThmPrefix ``Grind.Linarith.le_of_eq
return mkApp5 h (← mkExprDecl la) (← mkExprDecl lb) (← mkPolyDecl c'.p) eagerReflBoolTrue (← mkEqProof a b)
| .ofEqOfNat a b natStructId la lb =>
let h' ← OfNatModuleM.run natStructId do
let ns ← getNatStruct
let (a', ha) ← ofNatModule a
let (b', hb) ← ofNatModule b
return mkApp9 (mkConst ``Grind.IntModule.OfNatModule.of_eq [ns.u]) ns.type ns.natModuleInst
a b a' b' ha hb (← mkEqProof a b)
let h ← mkIntModPreOrdThmPrefix ``Grind.Linarith.le_of_eq
return mkApp5 h (← mkExprDecl la) (← mkExprDecl lb) (← mkPolyDecl c'.p) eagerReflBoolTrue h'
| .ofCommRingEq a b la lb p' lhs' =>
let h' ← mkCommRingThmPrefix ``Grind.CommRing.eq_norm
let h' := mkApp5 h' (← mkRingExprDecl la) (← mkRingExprDecl lb) (← mkRingPolyDecl p') eagerReflBoolTrue (← mkEqProof a b)
@ -278,7 +319,7 @@ partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' d
let hNot := mkLambda `h .default (mkApp2 lt (← c₁.p.denoteExpr) (← getZero)) (hFalse.abstract #[mkFVar fvarId])
let h ← mkIntModLinOrdThmPrefix ``Grind.Linarith.diseq_split_resolve
return mkApp5 h (← mkPolyDecl c₁.p) (← mkPolyDecl c'.p) eagerReflBoolTrue (← c₁.toExprProof) hNot
| _ => throwError "not implemented yet"
| .subst .. | .norm .. => throwError "NIY"
partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c' do
match c'.h with
@ -318,6 +359,15 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do
let h ← mkIntModThmPrefix ``Grind.Linarith.eq_norm
return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) eagerReflBoolTrue (← mkEqProof a b)
| .coreCommRing .. => throwError "not implemented yet"
| .coreOfNat a b natStructId lhs rhs =>
let h' ← OfNatModuleM.run natStructId do
let ns ← getNatStruct
let (a', ha) ← ofNatModule a
let (b', hb) ← ofNatModule b
return mkApp9 (mkConst ``Grind.IntModule.OfNatModule.of_eq [ns.u]) ns.type ns.natModuleInst
a b a' b' ha hb (← mkEqProof a b)
let h ← mkIntModThmPrefix ``Grind.Linarith.eq_norm
return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) eagerReflBoolTrue h'
| .neg c =>
let h ← mkIntModThmPrefix ``Grind.Linarith.eq_neg
return mkApp4 h (← mkPolyDecl c.p) (← mkPolyDecl c'.p) eagerReflBoolTrue (← c.toExprProof)
@ -356,8 +406,8 @@ mutual
partial def IneqCnstr.collectDecVars (c' : IneqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c') do
match c'.h with
| .core .. | .notCore .. | .coreCommRing .. | .notCoreCommRing ..
| .oneGtZero | .ofEq .. | .ofCommRingEq .. => return ()
| .core .. | .notCore .. | .coreCommRing .. | .notCoreCommRing .. | .coreOfNat .. | .notCoreOfNat ..
| .oneGtZero | .ofEq .. | .ofEqOfNat .. | .ofCommRingEq .. => return ()
| .combine c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
| .norm c₁ _ => c₁.collectDecVars
| .dec h => markAsFound h
@ -375,7 +425,7 @@ partial def DiseqCnstr.collectDecVars (c' : DiseqCnstr) : CollectDecVarsM Unit :
partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c') do
match c'.h with
| .subst _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
| .core .. | .coreCommRing .. => return ()
| .core .. | .coreCommRing .. | .coreOfNat .. => return ()
| .neg c | .coeff _ c => c.collectDecVars
end

View file

@ -55,24 +55,24 @@ def inSameStruct? (a b : Expr) : GoalM (Option Nat) := do
private def processNewCommRingEq' (a b : Expr) : LinearM Unit := do
let some lhs ← withRingM <| CommRing.reify? a (skipVar := false) | return ()
let some rhs ← withRingM <| CommRing.reify? b (skipVar := false) | return ()
let gen := max (← getGeneration a) (← getGeneration b)
let generation := max (← getGeneration a) (← getGeneration b)
let p' := (lhs.sub rhs).toPoly
let lhs' ← p'.toIntModuleExpr gen
let some lhs' ← reify? lhs' (skipVar := false) | return ()
let lhs' ← p'.toIntModuleExpr generation
let some lhs' ← reify? lhs' (skipVar := false) generation | return ()
let p := lhs'.norm
if p == .nil then return ()
let c₁ : IneqCnstr := { p, strict := false, h := .ofCommRingEq a b lhs rhs p' lhs' }
c₁.assert
let p := p.mul (-1)
let p' := p'.mulConst (-1)
let lhs' ← p'.toIntModuleExpr gen
let some lhs' ← reify? lhs' (skipVar := false) | return ()
let lhs' ← p'.toIntModuleExpr generation
let some lhs' ← reify? lhs' (skipVar := false) generation | return ()
let c₂ : IneqCnstr := { p, strict := false, h := .ofCommRingEq b a rhs lhs p' lhs' }
c₂.assert
private def processNewIntModuleEq' (a b : Expr) : LinearM Unit := do
let some lhs ← reify? a (skipVar := false) | return ()
let some rhs ← reify? b (skipVar := false) | return ()
let some lhs ← reify? a (skipVar := false) (← getGeneration a) | return ()
let some rhs ← reify? b (skipVar := false) (← getGeneration b) | return ()
let p := (lhs.sub rhs).norm
if p == .nil then return ()
let c₁ : IneqCnstr := { p, strict := false, h := .ofEq a b lhs rhs }
@ -216,20 +216,45 @@ private def processNewCommRingEq (a b : Expr) : LinearM Unit := do
-- TODO
private def processNewIntModuleEq (a b : Expr) : LinearM Unit := do
let some lhs ← reify? a (skipVar := false) | return ()
let some rhs ← reify? b (skipVar := false) | return ()
let some lhs ← reify? a (skipVar := false) (← getGeneration a) | return ()
let some rhs ← reify? b (skipVar := false) (← getGeneration b) | return ()
let p := (lhs.sub rhs).norm
if p == .nil then return ()
let c : EqCnstr := { p, h := .core a b lhs rhs }
c.assert
private def processNewNatModuleEq' (a b : Expr) : OfNatModuleM Unit := do
let ns ← getNatStruct
let (a', _) ← ofNatModule a
let (b', _) ← ofNatModule b
LinearM.run ns.structId do
let some lhs ← reify? a' (skipVar := false) (← getGeneration a) | return ()
let some rhs ← reify? b' (skipVar := false) (← getGeneration b) | return ()
let p := (lhs.sub rhs).norm
if p == .nil then return ()
let c₁ : IneqCnstr := { p, strict := false, h := .ofEqOfNat a b ns.id lhs rhs }
c₁.assert
let p := p.mul (-1)
let c₂ : IneqCnstr := { p, strict := false, h := .ofEqOfNat b a ns.id rhs lhs }
c₂.assert
private def processNewNatModuleEq (a b : Expr) : OfNatModuleM Unit := do
let ns ← getNatStruct
let (a', _) ← ofNatModule a
let (b', _) ← ofNatModule b
LinearM.run ns.structId do
let some lhs ← reify? a' (skipVar := false) (← getGeneration a) | return ()
let some rhs ← reify? b' (skipVar := false) (← getGeneration b) | return ()
let p := (lhs.sub rhs).norm
if p == .nil then return ()
let c : EqCnstr := { p, h := .coreOfNat a b ns.id lhs rhs }
c.assert
@[export lean_process_linarith_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := do
if isSameExpr a b then return () -- TODO: check why this is needed
let some structId ← inSameStruct? a b | return ()
LinearM.run structId do
if let some structId ← inSameStruct? a b then LinearM.run structId do
if (← isOrderedAdd) then
trace_goal[grind.linarith.assert] "{← mkEq a b}"
if (← isCommRing) then
processNewCommRingEq' a b
else
@ -239,35 +264,39 @@ def processNewEqImpl (a b : Expr) : GoalM Unit := do
processNewCommRingEq a b
else
processNewIntModuleEq a b
else if let some natStructId ← inSameNatStruct? a b then OfNatModuleM.run natStructId do
let ns ← getNatStruct
if ns.orderedAddInst?.isSome then
processNewNatModuleEq' a b
else
processNewNatModuleEq a b
private def processNewCommRingDiseq (a b : Expr) : LinearM Unit := do
let some lhs ← withRingM <| CommRing.reify? a (skipVar := false) | return ()
let some rhs ← withRingM <| CommRing.reify? b (skipVar := false) | return ()
let gen := max (← getGeneration a) (← getGeneration b)
let generation := max (← getGeneration a) (← getGeneration b)
let p' := (lhs.sub rhs).toPoly
let lhs' ← p'.toIntModuleExpr gen
let some lhs' ← reify? lhs' (skipVar := false) | return ()
let lhs' ← p'.toIntModuleExpr generation
let some lhs' ← reify? lhs' (skipVar := false) generation | return ()
let p := lhs'.norm
let c : DiseqCnstr := { p, h := .coreCommRing a b lhs rhs p' lhs' }
c.assert
private def processNewIntModuleDiseq (a b : Expr) : LinearM Unit := do
let some lhs ← reify? a (skipVar := false) | return ()
let some rhs ← reify? b (skipVar := false) | return ()
let some lhs ← reify? a (skipVar := false) (← getGeneration a) | return ()
let some rhs ← reify? b (skipVar := false) (← getGeneration b) | return ()
let p := (lhs.sub rhs).norm
let c : DiseqCnstr := { p, h := .core a b lhs rhs }
c.assert
private def processNewNatModuleDiseq (a b : Expr) : OfNatModuleM Unit := do
let ns ← getNatStruct
trace[Meta.debug] "{a}, {b}"
unless ns.addRightCancelInst?.isSome do return ()
let (a', _) ← ofNatModule a
let (b', _) ← ofNatModule b
trace[Meta.debug] "{a'}, {b'}"
LinearM.run ns.structId do
let some lhs ← reify? a' (skipVar := false) | return ()
let some rhs ← reify? b' (skipVar := false) | return ()
let some lhs ← reify? a' (skipVar := false) (← getGeneration a) | return ()
let some rhs ← reify? b' (skipVar := false) (← getGeneration b) | return ()
let p := (lhs.sub rhs).norm
let c : DiseqCnstr := { p, h := .coreOfNat a b ns.id lhs rhs }
c.assert

View file

@ -39,7 +39,7 @@ Converts a Lean `IntModule` expression `e` into a `LinExpr`
If `skipVar` is `true`, then the result is `none` if `e` is not an interpreted `IntModule` term.
We use `skipVar := false` when processing inequalities, and `skipVar := true` for equalities and disequalities
-/
partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do
partial def reify? (e : Expr) (skipVar : Bool) (generation : Nat := 0) : LinearM (Option LinExpr) := do
match_expr e with
| HAdd.hAdd _ _ _ i a b =>
if isAddInst (← getStruct ) i then return some (.add (← go a) (← go b)) else asTopVar e
@ -61,10 +61,14 @@ partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do
return some (← toVar e)
where
toVar (e : Expr) : LinearM LinExpr := do
return .var (← mkVar e)
if (← alreadyInternalized e) then
return .var (← mkVar e)
else
internalize e generation
return .var (← mkVar e)
asVar (e : Expr) : LinearM LinExpr := do
reportInstIssue e
return .var (← mkVar e)
toVar e
asTopVar (e : Expr) : LinearM (Option LinExpr) := do
reportInstIssue e
if skipVar then

View file

@ -321,6 +321,7 @@ where
let ltInst? ← getInst? ``LT u type
let isPreorderInst? ← mkIsPreorderInst? u type leInst?
let lawfulOrderLTInst? ← mkLawfulOrderLTInst? u type ltInst? leInst?
let isLinearInst? ← mkIsLinearOrderInst? u type leInst?
let addInst ← getBinHomoInst ``HAdd u type
let addFn ← internalizeFn <| mkApp4 (mkConst ``HAdd.hAdd [u, u, u]) type type type addInst
let orderedAddInst? ← match leInst?, isPreorderInst? with
@ -338,7 +339,7 @@ where
let id := (← get').natStructs.size
let natStruct : NatStruct := {
id, structId, u, type, natModuleInst,
leInst?, ltInst?, lawfulOrderLTInst?, isPreorderInst?, orderedAddInst?, addRightCancelInst?,
leInst?, ltInst?, lawfulOrderLTInst?, isPreorderInst?, isLinearInst?, orderedAddInst?, addRightCancelInst?,
rfl_q, zero, toQFn, addFn, smulFn
}
modify' fun s => { s with natStructs := s.natStructs.push natStruct }

View file

@ -28,6 +28,7 @@ structure EqCnstr where
inductive EqCnstrProof where
| core (a b : Expr) (lhs rhs : LinExpr)
| coreCommRing (a b : Expr) (ra rb : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr)
| coreOfNat (a b : Expr) (natStructId : Nat) (lhs rhs : LinExpr)
| neg (c : EqCnstr)
| coeff (k : Nat) (c : EqCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : EqCnstr)
@ -43,6 +44,8 @@ inductive IneqCnstrProof where
| notCore (e : Expr) (lhs rhs : LinExpr)
| coreCommRing (e : Expr) (lhs rhs : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr)
| notCoreCommRing (e : Expr) (lhs rhs : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr)
| coreOfNat (e : Expr) (natStructId : Nat) (lhs rhs : LinExpr)
| notCoreOfNat (e : Expr) (natStructId : Nat) (lhs rhs : LinExpr)
| combine (c₁ : IneqCnstr) (c₂ : IneqCnstr)
| norm (c₁ : IneqCnstr) (k : Nat)
| dec (h : FVarId)
@ -50,6 +53,8 @@ inductive IneqCnstrProof where
| oneGtZero
| /-- `a ≤ b` from an equality `a = b` coming from the core. -/
ofEq (a b : Expr) (la lb : LinExpr)
| /-- `a ≤ b` from an equality `a = b` coming from the core. -/
ofEqOfNat (a b : Expr) (natStructId : Nat) (la lb : LinExpr)
| /-- `a ≤ b` from an equality `a = b` coming from the core. -/
ofCommRingEq (a b : Expr) (ra rb : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : IneqCnstr)
@ -220,6 +225,8 @@ structure NatStruct where
isPreorderInst? : Option Expr
/-- `OrderedAdd` instance with `IsPreorder` if available -/
orderedAddInst? : Option Expr
/-- `IsLinearOrder` instance if available -/
isLinearInst? : Option Expr
addRightCancelInst? : Option Expr
rfl_q : Expr -- `@Eq.Refl (OfNatModule.Q type)`
zero : Expr

View file

@ -1,5 +1,56 @@
open Lean Grind
variable (M : Type) [NatModule M] [AddRightCancel M]
open Lean Grind Std
variable (M : Type) [NatModule M]
section
variable [AddRightCancel M]
example (x y : M) : 2 • x + 3 • y + x = 3 • (x + y) := by
grind
end
section
variable [LE M] [LT M] [LawfulOrderLT M] [IsLinearOrder M] [OrderedAdd M]
example {x y : M} (h : x ≤ y) : 2 • x + y ≤ 3 • y := by
grind
end
section
variable [LE M] [LT M] [LawfulOrderLT M] [IsPreorder M] [OrderedAdd M]
example {x y : M} : x ≤ y → 2 • x + y > 3 • y → False := by
grind
example {x y z : M} : x ≤ y → y < z → 2 • x + y ≥ 3 • z → False := by
grind
end
section
variable [LE M] [IsLinearOrder M] [OrderedAdd M] [AddRightCancel M]
example {x y : M} : x + x ≤ y → y ≤ 2 • x → x + x ≠ y → False := by
grind
end
section
variable [AddRightCancel M]
example {x y : M} : x + x = y → 2•x ≠ y → False := by
grind
example {x y z : M} : x + z = y → x = z → 2•x ≠ y → False := by
grind
example {x y z : M} : x + z = y → x = 2•z → 3•z ≠ y → False := by
grind
end
section
variable [LE M] [IsLinearOrder M] [OrderedAdd M] [AddRightCancel M]
example {x y z : M} : x + z = y → x = 2•z → 3•z ≠ y → False := by
grind
end
example [NatModule α] [AddRightCancel α] [LE α] [LT α] [LawfulOrderLT α] [IsLinearOrder α] [OrderedAdd α] (a b c d : α)
: a ≤ b → a ≥ c + d → d ≤ 0 → d ≥ 0 → b = c → a = b := by
grind

View file

@ -1,6 +1,5 @@
open Std Lean.Grind
-- We could solve these problems by embedding the NatModule in its Grothendieck completion.
section NatModule
variable (M : Type) [NatModule M] [AddRightCancel M]