diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/IneqCnstr.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/IneqCnstr.lean index d71eb8240b..0b371bf040 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/IneqCnstr.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/IneqCnstr.lean @@ -15,22 +15,16 @@ import Lean.Meta.Tactic.Grind.Arith.Linear.StructId import Lean.Meta.Tactic.Grind.Arith.Linear.Reify import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr import Lean.Meta.Tactic.Grind.Arith.Linear.Proof -public section +import Lean.Meta.Tactic.Grind.Arith.Linear.OfNatModule namespace Lean.Meta.Grind.Arith.Linear -def isLeInst (struct : Struct) (inst : Expr) : Bool := - if let some leFn := struct.leFn? then - isSameExpr leFn.appArg! inst +def isInstOf (fn? : Option Expr) (inst : Expr) : Bool := + if let some fn := fn? then + isSameExpr fn.appArg! inst else false -def isLtInst (struct : Struct) (inst : Expr) : Bool := - if let some ltFn := struct.ltFn? then - isSameExpr ltFn.appArg! inst - else - false - -def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do +public def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do trace[grind.linarith.assert] "{← c.denoteExpr}" match c.p with | .nil => @@ -52,19 +46,19 @@ def IneqCnstr.assert (c : IneqCnstr) : LinearM Unit := do def propagateCommRingIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : LinearM Unit := do let some lhs ← withRingM <| CommRing.reify? lhs (skipVar := false) | return () let some rhs ← withRingM <| CommRing.reify? rhs (skipVar := false) | return () - let gen ← getGeneration e + let generation ← getGeneration e if eqTrue then let p' := (lhs.sub rhs).toPoly - let lhs' ← p'.toIntModuleExpr gen - let some lhs' ← reify? lhs' (skipVar := false) | return () + let lhs' ← p'.toIntModuleExpr generation + let some lhs' ← reify? lhs' (skipVar := false) generation | return () let p := lhs'.norm let c : IneqCnstr := { p, strict, h := .coreCommRing e lhs rhs p' lhs' } c.assert else if (← isLinearOrder) then let p' := (rhs.sub lhs).toPoly let strict := !strict - let lhs' ← p'.toIntModuleExpr gen - let some lhs' ← reify? lhs' (skipVar := false) | return () + let lhs' ← p'.toIntModuleExpr generation + let some lhs' ← reify? lhs' (skipVar := false) generation | return () let p := lhs'.norm let c : IneqCnstr := { p, strict, h := .notCoreCommRing e lhs rhs p' lhs' } c.assert @@ -73,8 +67,8 @@ def propagateCommRingIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : modifyStruct fun s => { s with ignored := s.ignored.push e } def propagateIntModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : LinearM Unit := do - let some lhs ← reify? lhs (skipVar := false) | return () - let some rhs ← reify? rhs (skipVar := false) | return () + let some lhs ← reify? lhs (skipVar := false) (← getGeneration lhs) | return () + let some rhs ← reify? rhs (skipVar := false) (← getGeneration rhs) | return () if eqTrue then let p := (lhs.sub rhs).norm let c : IneqCnstr := { p, strict, h := .core e lhs rhs } @@ -88,27 +82,55 @@ def propagateIntModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : -- Negation for preorders is not supported modifyStruct fun s => { s with ignored := s.ignored.push e } -def propagateIneq (e : Expr) (eqTrue : Bool) : GoalM Unit := do +def propagateNatModuleIneq (e : Expr) (lhs rhs : Expr) (strict : Bool) (eqTrue : Bool) : OfNatModuleM Unit := do + let ns ← getNatStruct + let (lhs₁, _) ← ofNatModule lhs + let (rhs₁, _) ← ofNatModule rhs + LinearM.run ns.structId do + let some lhs₂ ← reify? lhs₁ (skipVar := false) (← getGeneration lhs) | return () + let some rhs₂ ← reify? rhs₁ (skipVar := false) (← getGeneration rhs) | return () + if eqTrue then + let p := (lhs₂.sub rhs₂).norm + let c : IneqCnstr := { p, strict, h := .coreOfNat e ns.id lhs₂ rhs₂ } + c.assert + else + let p := (rhs₂.sub lhs₂).norm + let strict := !strict + let c : IneqCnstr := { p, strict, h := .notCoreOfNat e ns.id lhs₂ rhs₂ } + c.assert + +public def propagateIneq (e : Expr) (eqTrue : Bool) : GoalM Unit := do unless (← getConfig).linarith do return () let numArgs := e.getAppNumArgs unless numArgs == 4 do return () let α := e.getArg! 0 numArgs - let some structId ← getStructId? α | return () - LinearM.run structId do - let inst := e.getArg! 1 numArgs - let struct ← getStruct - let strict ← if isLeInst struct inst then + let inst := e.getArg! 1 numArgs + let lhs := e.getArg! 2 numArgs + let rhs := e.getArg! 3 numArgs + if let some structId ← getStructId? α then LinearM.run structId do + let s ← getStruct + let strict ← if isInstOf s.leFn? inst then pure false - else if isLtInst struct inst then + else if isInstOf s.ltFn? inst then pure true else return () - let lhs := e.getArg! 2 numArgs - let rhs := e.getArg! 3 numArgs if (← isOrderedCommRing) then propagateCommRingIneq e lhs rhs strict eqTrue -- TODO: non-commutative ring normalizer else propagateIntModuleIneq e lhs rhs strict eqTrue + else if let some natStructId ← getNatStructId? α then OfNatModuleM.run natStructId do + let s ← getNatStruct + if s.leInst?.isNone || s.isPreorderInst?.isNone || s.orderedAddInst?.isNone then return () + let strict ← if some inst == s.leInst? then + pure false + else if some inst == s.ltInst? then + pure true + else + return () + if strict && s.lawfulOrderLTInst?.isNone then return () + if !eqTrue && s.isLinearInst?.isNone then return () + propagateNatModuleIneq e lhs rhs strict eqTrue end Lean.Meta.Grind.Arith.Linear diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/OfNatModule.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/OfNatModule.lean index 5e0c7c9900..6bc82adb74 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/OfNatModule.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/OfNatModule.lean @@ -60,12 +60,17 @@ def setTermNatStructId (e : Expr) : OfNatModuleM Unit := do modify' fun s => { s with exprToNatStructId := s.exprToNatStructId.insert { expr := e } id } private def mkOfNatModuleVar (e : Expr) : OfNatModuleM (Expr × Expr) := do - let s ← getNatStruct - let toQe := mkApp s.toQFn e - let h := mkApp s.rfl_q toQe - setTermNatStructId e - markAsLinarithTerm e - return (toQe, h) + if let some r := (← getNatStruct).termMap.find? { expr := e } then + return r + else + let s ← getNatStruct + let toQe ← shareCommon (mkApp s.toQFn e) + let h := mkApp s.rfl_q toQe + let r := (toQe, h) + modifyNatStruct fun s => { s with termMap := s.termMap.insert { expr := e } r } + setTermNatStructId e + markAsLinarithTerm e + return r private def isAddInst (natStruct : NatStruct) (inst : Expr) : Bool := isSameExpr natStruct.addFn.appArg! inst @@ -102,6 +107,13 @@ private partial def ofNatModule' (e : Expr) : OfNatModuleM (Expr × Expr) := do pure (e', h) else mkOfNatModuleVar e + | OfNat.ofNat _ _ _ => + if (← isDefEqD e ns.zero) then + let e' := s.zero + let h := mkApp2 (mkConst ``Grind.IntModule.OfNatModule.toQ_zero [ns.u]) ns.type ns.natModuleInst + pure (e', h) + else + mkOfNatModuleVar e | _ => mkOfNatModuleVar e def ofNatModule (e : Expr) : OfNatModuleM (Expr × Expr) := do @@ -115,7 +127,6 @@ def ofNatModule (e : Expr) : OfNatModuleM (Expr × Expr) := do else pure (r.expr, h) setTermNatStructId e - internalize e' (← getGeneration e) modifyNatStruct fun s => { s with termMap := s.termMap.insert { expr := e } (e', h) } return (e', h) diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean index 050d95c4c3..76ef993b8f 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean @@ -250,6 +250,38 @@ partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' d let h' := mkApp5 h' (← mkRingExprDecl lhs) (← mkRingExprDecl rhs) (← mkRingPolyDecl p') eagerReflBoolTrue (mkOfEqFalseCore e (← mkEqFalseProof e)) let h ← if c'.strict then mkIntModLawfulPreOrdThmPrefix ``Grind.Linarith.lt_norm else mkIntModPreOrdThmPrefix ``Grind.Linarith.le_norm return mkApp5 h (← mkExprDecl lhs') (← mkExprDecl .zero) (← mkPolyDecl c'.p) eagerReflBoolTrue h' + | .coreOfNat e natStructId lhs rhs => + let h' ← OfNatModuleM.run natStructId do + let a := e.appFn!.appArg! + let b := e.appArg! + let ns ← getNatStruct + let (a', ha) ← ofNatModule a + let (b', hb) ← ofNatModule b + let h := if c'.strict then + mkApp7 (mkConst ``Grind.IntModule.OfNatModule.of_lt [ns.u]) ns.type ns.natModuleInst ns.leInst?.get! ns.ltInst?.get! + ns.lawfulOrderLTInst?.get! ns.isPreorderInst?.get! ns.orderedAddInst?.get! + else + mkApp5 (mkConst ``Grind.IntModule.OfNatModule.of_le [ns.u]) ns.type ns.natModuleInst ns.leInst?.get! + ns.isPreorderInst?.get! ns.orderedAddInst?.get! + return mkApp7 h a b a' b' ha hb (mkOfEqTrueCore e (← mkEqTrueProof e)) + let h ← if c'.strict then mkIntModLawfulPreOrdThmPrefix ``Grind.Linarith.lt_norm else mkIntModPreOrdThmPrefix ``Grind.Linarith.le_norm + return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) eagerReflBoolTrue h' + | .notCoreOfNat e natStructId lhs rhs => + let h' ← OfNatModuleM.run natStructId do + let a := e.appFn!.appArg! + let b := e.appArg! + let ns ← getNatStruct + let (a', ha) ← ofNatModule a + let (b', hb) ← ofNatModule b + let h := if c'.strict then + mkApp5 (mkConst ``Grind.IntModule.OfNatModule.of_not_le [ns.u]) ns.type ns.natModuleInst ns.leInst?.get! + ns.isPreorderInst?.get! ns.orderedAddInst?.get! + else + mkApp7 (mkConst ``Grind.IntModule.OfNatModule.of_not_lt [ns.u]) ns.type ns.natModuleInst ns.leInst?.get! ns.ltInst?.get! + ns.lawfulOrderLTInst?.get! ns.isPreorderInst?.get! ns.orderedAddInst?.get! + return mkApp7 h a b a' b' ha hb (mkOfEqFalseCore e (← mkEqFalseProof e)) + let h ← mkIntModLinOrdThmPrefix (if c'.strict then ``Grind.Linarith.not_le_norm else ``Grind.Linarith.not_lt_norm) + return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) eagerReflBoolTrue h' | .combine c₁ c₂ => let (pre, c₁, c₂) := match c₁.strict, c₂.strict with @@ -266,6 +298,15 @@ partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' d | .ofEq a b la lb => let h ← mkIntModPreOrdThmPrefix ``Grind.Linarith.le_of_eq return mkApp5 h (← mkExprDecl la) (← mkExprDecl lb) (← mkPolyDecl c'.p) eagerReflBoolTrue (← mkEqProof a b) + | .ofEqOfNat a b natStructId la lb => + let h' ← OfNatModuleM.run natStructId do + let ns ← getNatStruct + let (a', ha) ← ofNatModule a + let (b', hb) ← ofNatModule b + return mkApp9 (mkConst ``Grind.IntModule.OfNatModule.of_eq [ns.u]) ns.type ns.natModuleInst + a b a' b' ha hb (← mkEqProof a b) + let h ← mkIntModPreOrdThmPrefix ``Grind.Linarith.le_of_eq + return mkApp5 h (← mkExprDecl la) (← mkExprDecl lb) (← mkPolyDecl c'.p) eagerReflBoolTrue h' | .ofCommRingEq a b la lb p' lhs' => let h' ← mkCommRingThmPrefix ``Grind.CommRing.eq_norm let h' := mkApp5 h' (← mkRingExprDecl la) (← mkRingExprDecl lb) (← mkRingPolyDecl p') eagerReflBoolTrue (← mkEqProof a b) @@ -278,7 +319,7 @@ partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' d let hNot := mkLambda `h .default (mkApp2 lt (← c₁.p.denoteExpr) (← getZero)) (hFalse.abstract #[mkFVar fvarId]) let h ← mkIntModLinOrdThmPrefix ``Grind.Linarith.diseq_split_resolve return mkApp5 h (← mkPolyDecl c₁.p) (← mkPolyDecl c'.p) eagerReflBoolTrue (← c₁.toExprProof) hNot - | _ => throwError "not implemented yet" + | .subst .. | .norm .. => throwError "NIY" partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c' do match c'.h with @@ -318,6 +359,15 @@ partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do let h ← mkIntModThmPrefix ``Grind.Linarith.eq_norm return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) eagerReflBoolTrue (← mkEqProof a b) | .coreCommRing .. => throwError "not implemented yet" + | .coreOfNat a b natStructId lhs rhs => + let h' ← OfNatModuleM.run natStructId do + let ns ← getNatStruct + let (a', ha) ← ofNatModule a + let (b', hb) ← ofNatModule b + return mkApp9 (mkConst ``Grind.IntModule.OfNatModule.of_eq [ns.u]) ns.type ns.natModuleInst + a b a' b' ha hb (← mkEqProof a b) + let h ← mkIntModThmPrefix ``Grind.Linarith.eq_norm + return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) eagerReflBoolTrue h' | .neg c => let h ← mkIntModThmPrefix ``Grind.Linarith.eq_neg return mkApp4 h (← mkPolyDecl c.p) (← mkPolyDecl c'.p) eagerReflBoolTrue (← c.toExprProof) @@ -356,8 +406,8 @@ mutual partial def IneqCnstr.collectDecVars (c' : IneqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c') do match c'.h with - | .core .. | .notCore .. | .coreCommRing .. | .notCoreCommRing .. - | .oneGtZero | .ofEq .. | .ofCommRingEq .. => return () + | .core .. | .notCore .. | .coreCommRing .. | .notCoreCommRing .. | .coreOfNat .. | .notCoreOfNat .. + | .oneGtZero | .ofEq .. | .ofEqOfNat .. | .ofCommRingEq .. => return () | .combine c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars | .norm c₁ _ => c₁.collectDecVars | .dec h => markAsFound h @@ -375,7 +425,7 @@ partial def DiseqCnstr.collectDecVars (c' : DiseqCnstr) : CollectDecVarsM Unit : partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c') do match c'.h with | .subst _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars - | .core .. | .coreCommRing .. => return () + | .core .. | .coreCommRing .. | .coreOfNat .. => return () | .neg c | .coeff _ c => c.collectDecVars end diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean index 94b3af8784..fd19bd79b3 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/PropagateEq.lean @@ -55,24 +55,24 @@ def inSameStruct? (a b : Expr) : GoalM (Option Nat) := do private def processNewCommRingEq' (a b : Expr) : LinearM Unit := do let some lhs ← withRingM <| CommRing.reify? a (skipVar := false) | return () let some rhs ← withRingM <| CommRing.reify? b (skipVar := false) | return () - let gen := max (← getGeneration a) (← getGeneration b) + let generation := max (← getGeneration a) (← getGeneration b) let p' := (lhs.sub rhs).toPoly - let lhs' ← p'.toIntModuleExpr gen - let some lhs' ← reify? lhs' (skipVar := false) | return () + let lhs' ← p'.toIntModuleExpr generation + let some lhs' ← reify? lhs' (skipVar := false) generation | return () let p := lhs'.norm if p == .nil then return () let c₁ : IneqCnstr := { p, strict := false, h := .ofCommRingEq a b lhs rhs p' lhs' } c₁.assert let p := p.mul (-1) let p' := p'.mulConst (-1) - let lhs' ← p'.toIntModuleExpr gen - let some lhs' ← reify? lhs' (skipVar := false) | return () + let lhs' ← p'.toIntModuleExpr generation + let some lhs' ← reify? lhs' (skipVar := false) generation | return () let c₂ : IneqCnstr := { p, strict := false, h := .ofCommRingEq b a rhs lhs p' lhs' } c₂.assert private def processNewIntModuleEq' (a b : Expr) : LinearM Unit := do - let some lhs ← reify? a (skipVar := false) | return () - let some rhs ← reify? b (skipVar := false) | return () + let some lhs ← reify? a (skipVar := false) (← getGeneration a) | return () + let some rhs ← reify? b (skipVar := false) (← getGeneration b) | return () let p := (lhs.sub rhs).norm if p == .nil then return () let c₁ : IneqCnstr := { p, strict := false, h := .ofEq a b lhs rhs } @@ -216,20 +216,45 @@ private def processNewCommRingEq (a b : Expr) : LinearM Unit := do -- TODO private def processNewIntModuleEq (a b : Expr) : LinearM Unit := do - let some lhs ← reify? a (skipVar := false) | return () - let some rhs ← reify? b (skipVar := false) | return () + let some lhs ← reify? a (skipVar := false) (← getGeneration a) | return () + let some rhs ← reify? b (skipVar := false) (← getGeneration b) | return () let p := (lhs.sub rhs).norm if p == .nil then return () let c : EqCnstr := { p, h := .core a b lhs rhs } c.assert +private def processNewNatModuleEq' (a b : Expr) : OfNatModuleM Unit := do + let ns ← getNatStruct + let (a', _) ← ofNatModule a + let (b', _) ← ofNatModule b + LinearM.run ns.structId do + let some lhs ← reify? a' (skipVar := false) (← getGeneration a) | return () + let some rhs ← reify? b' (skipVar := false) (← getGeneration b) | return () + let p := (lhs.sub rhs).norm + if p == .nil then return () + let c₁ : IneqCnstr := { p, strict := false, h := .ofEqOfNat a b ns.id lhs rhs } + c₁.assert + let p := p.mul (-1) + let c₂ : IneqCnstr := { p, strict := false, h := .ofEqOfNat b a ns.id rhs lhs } + c₂.assert + +private def processNewNatModuleEq (a b : Expr) : OfNatModuleM Unit := do + let ns ← getNatStruct + let (a', _) ← ofNatModule a + let (b', _) ← ofNatModule b + LinearM.run ns.structId do + let some lhs ← reify? a' (skipVar := false) (← getGeneration a) | return () + let some rhs ← reify? b' (skipVar := false) (← getGeneration b) | return () + let p := (lhs.sub rhs).norm + if p == .nil then return () + let c : EqCnstr := { p, h := .coreOfNat a b ns.id lhs rhs } + c.assert + @[export lean_process_linarith_eq] def processNewEqImpl (a b : Expr) : GoalM Unit := do if isSameExpr a b then return () -- TODO: check why this is needed - let some structId ← inSameStruct? a b | return () - LinearM.run structId do + if let some structId ← inSameStruct? a b then LinearM.run structId do if (← isOrderedAdd) then - trace_goal[grind.linarith.assert] "{← mkEq a b}" if (← isCommRing) then processNewCommRingEq' a b else @@ -239,35 +264,39 @@ def processNewEqImpl (a b : Expr) : GoalM Unit := do processNewCommRingEq a b else processNewIntModuleEq a b + else if let some natStructId ← inSameNatStruct? a b then OfNatModuleM.run natStructId do + let ns ← getNatStruct + if ns.orderedAddInst?.isSome then + processNewNatModuleEq' a b + else + processNewNatModuleEq a b private def processNewCommRingDiseq (a b : Expr) : LinearM Unit := do let some lhs ← withRingM <| CommRing.reify? a (skipVar := false) | return () let some rhs ← withRingM <| CommRing.reify? b (skipVar := false) | return () - let gen := max (← getGeneration a) (← getGeneration b) + let generation := max (← getGeneration a) (← getGeneration b) let p' := (lhs.sub rhs).toPoly - let lhs' ← p'.toIntModuleExpr gen - let some lhs' ← reify? lhs' (skipVar := false) | return () + let lhs' ← p'.toIntModuleExpr generation + let some lhs' ← reify? lhs' (skipVar := false) generation | return () let p := lhs'.norm let c : DiseqCnstr := { p, h := .coreCommRing a b lhs rhs p' lhs' } c.assert private def processNewIntModuleDiseq (a b : Expr) : LinearM Unit := do - let some lhs ← reify? a (skipVar := false) | return () - let some rhs ← reify? b (skipVar := false) | return () + let some lhs ← reify? a (skipVar := false) (← getGeneration a) | return () + let some rhs ← reify? b (skipVar := false) (← getGeneration b) | return () let p := (lhs.sub rhs).norm let c : DiseqCnstr := { p, h := .core a b lhs rhs } c.assert private def processNewNatModuleDiseq (a b : Expr) : OfNatModuleM Unit := do let ns ← getNatStruct - trace[Meta.debug] "{a}, {b}" unless ns.addRightCancelInst?.isSome do return () let (a', _) ← ofNatModule a let (b', _) ← ofNatModule b - trace[Meta.debug] "{a'}, {b'}" LinearM.run ns.structId do - let some lhs ← reify? a' (skipVar := false) | return () - let some rhs ← reify? b' (skipVar := false) | return () + let some lhs ← reify? a' (skipVar := false) (← getGeneration a) | return () + let some rhs ← reify? b' (skipVar := false) (← getGeneration b) | return () let p := (lhs.sub rhs).norm let c : DiseqCnstr := { p, h := .coreOfNat a b ns.id lhs rhs } c.assert diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Reify.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Reify.lean index 800b3f91f4..cbf12d7e2f 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Reify.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Reify.lean @@ -39,7 +39,7 @@ Converts a Lean `IntModule` expression `e` into a `LinExpr` If `skipVar` is `true`, then the result is `none` if `e` is not an interpreted `IntModule` term. We use `skipVar := false` when processing inequalities, and `skipVar := true` for equalities and disequalities -/ -partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do +partial def reify? (e : Expr) (skipVar : Bool) (generation : Nat := 0) : LinearM (Option LinExpr) := do match_expr e with | HAdd.hAdd _ _ _ i a b => if isAddInst (← getStruct ) i then return some (.add (← go a) (← go b)) else asTopVar e @@ -61,10 +61,14 @@ partial def reify? (e : Expr) (skipVar : Bool) : LinearM (Option LinExpr) := do return some (← toVar e) where toVar (e : Expr) : LinearM LinExpr := do - return .var (← mkVar e) + if (← alreadyInternalized e) then + return .var (← mkVar e) + else + internalize e generation + return .var (← mkVar e) asVar (e : Expr) : LinearM LinExpr := do reportInstIssue e - return .var (← mkVar e) + toVar e asTopVar (e : Expr) : LinearM (Option LinExpr) := do reportInstIssue e if skipVar then diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean index 2da4965b1a..1f2bdf9dd0 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/StructId.lean @@ -321,6 +321,7 @@ where let ltInst? ← getInst? ``LT u type let isPreorderInst? ← mkIsPreorderInst? u type leInst? let lawfulOrderLTInst? ← mkLawfulOrderLTInst? u type ltInst? leInst? + let isLinearInst? ← mkIsLinearOrderInst? u type leInst? let addInst ← getBinHomoInst ``HAdd u type let addFn ← internalizeFn <| mkApp4 (mkConst ``HAdd.hAdd [u, u, u]) type type type addInst let orderedAddInst? ← match leInst?, isPreorderInst? with @@ -338,7 +339,7 @@ where let id := (← get').natStructs.size let natStruct : NatStruct := { id, structId, u, type, natModuleInst, - leInst?, ltInst?, lawfulOrderLTInst?, isPreorderInst?, orderedAddInst?, addRightCancelInst?, + leInst?, ltInst?, lawfulOrderLTInst?, isPreorderInst?, isLinearInst?, orderedAddInst?, addRightCancelInst?, rfl_q, zero, toQFn, addFn, smulFn } modify' fun s => { s with natStructs := s.natStructs.push natStruct } diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Types.lean index a88487db72..d3610bae87 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Types.lean @@ -28,6 +28,7 @@ structure EqCnstr where inductive EqCnstrProof where | core (a b : Expr) (lhs rhs : LinExpr) | coreCommRing (a b : Expr) (ra rb : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr) + | coreOfNat (a b : Expr) (natStructId : Nat) (lhs rhs : LinExpr) | neg (c : EqCnstr) | coeff (k : Nat) (c : EqCnstr) | subst (x : Var) (c₁ : EqCnstr) (c₂ : EqCnstr) @@ -43,6 +44,8 @@ inductive IneqCnstrProof where | notCore (e : Expr) (lhs rhs : LinExpr) | coreCommRing (e : Expr) (lhs rhs : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr) | notCoreCommRing (e : Expr) (lhs rhs : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr) + | coreOfNat (e : Expr) (natStructId : Nat) (lhs rhs : LinExpr) + | notCoreOfNat (e : Expr) (natStructId : Nat) (lhs rhs : LinExpr) | combine (c₁ : IneqCnstr) (c₂ : IneqCnstr) | norm (c₁ : IneqCnstr) (k : Nat) | dec (h : FVarId) @@ -50,6 +53,8 @@ inductive IneqCnstrProof where | oneGtZero | /-- `a ≤ b` from an equality `a = b` coming from the core. -/ ofEq (a b : Expr) (la lb : LinExpr) + | /-- `a ≤ b` from an equality `a = b` coming from the core. -/ + ofEqOfNat (a b : Expr) (natStructId : Nat) (la lb : LinExpr) | /-- `a ≤ b` from an equality `a = b` coming from the core. -/ ofCommRingEq (a b : Expr) (ra rb : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr) | subst (x : Var) (c₁ : EqCnstr) (c₂ : IneqCnstr) @@ -220,6 +225,8 @@ structure NatStruct where isPreorderInst? : Option Expr /-- `OrderedAdd` instance with `IsPreorder` if available -/ orderedAddInst? : Option Expr + /-- `IsLinearOrder` instance if available -/ + isLinearInst? : Option Expr addRightCancelInst? : Option Expr rfl_q : Expr -- `@Eq.Refl (OfNatModule.Q type)` zero : Expr diff --git a/tests/lean/run/grind_nat_module.lean b/tests/lean/run/grind_nat_module.lean index 995cb192d4..5f530878e2 100644 --- a/tests/lean/run/grind_nat_module.lean +++ b/tests/lean/run/grind_nat_module.lean @@ -1,5 +1,56 @@ -open Lean Grind -variable (M : Type) [NatModule M] [AddRightCancel M] +open Lean Grind Std +variable (M : Type) [NatModule M] +section +variable [AddRightCancel M] example (x y : M) : 2 • x + 3 • y + x = 3 • (x + y) := by grind +end + +section +variable [LE M] [LT M] [LawfulOrderLT M] [IsLinearOrder M] [OrderedAdd M] + +example {x y : M} (h : x ≤ y) : 2 • x + y ≤ 3 • y := by + grind +end + +section +variable [LE M] [LT M] [LawfulOrderLT M] [IsPreorder M] [OrderedAdd M] + +example {x y : M} : x ≤ y → 2 • x + y > 3 • y → False := by + grind + +example {x y z : M} : x ≤ y → y < z → 2 • x + y ≥ 3 • z → False := by + grind +end + +section +variable [LE M] [IsLinearOrder M] [OrderedAdd M] [AddRightCancel M] + +example {x y : M} : x + x ≤ y → y ≤ 2 • x → x + x ≠ y → False := by + grind +end + +section +variable [AddRightCancel M] + +example {x y : M} : x + x = y → 2•x ≠ y → False := by + grind + +example {x y z : M} : x + z = y → x = z → 2•x ≠ y → False := by + grind + +example {x y z : M} : x + z = y → x = 2•z → 3•z ≠ y → False := by + grind +end + +section +variable [LE M] [IsLinearOrder M] [OrderedAdd M] [AddRightCancel M] + +example {x y z : M} : x + z = y → x = 2•z → 3•z ≠ y → False := by + grind +end + +example [NatModule α] [AddRightCancel α] [LE α] [LT α] [LawfulOrderLT α] [IsLinearOrder α] [OrderedAdd α] (a b c d : α) + : a ≤ b → a ≥ c + d → d ≤ 0 → d ≥ 0 → b = c → a = b := by + grind diff --git a/tests/lean/grind/algebra/nat_module.lean b/tests/lean/run/grind_nat_module_2.lean similarity index 76% rename from tests/lean/grind/algebra/nat_module.lean rename to tests/lean/run/grind_nat_module_2.lean index b7548d7a14..ab85f78ed1 100644 --- a/tests/lean/grind/algebra/nat_module.lean +++ b/tests/lean/run/grind_nat_module_2.lean @@ -1,6 +1,5 @@ open Std Lean.Grind --- We could solve these problems by embedding the NatModule in its Grothendieck completion. section NatModule variable (M : Type) [NatModule M] [AddRightCancel M]