diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean index e122dc75e1..72ef45a833 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Proof.lean @@ -40,6 +40,7 @@ structure ProofM.State where exprDecls : Std.HashMap LinExpr Expr := {} ringPolyDecls : Std.HashMap CommRing.Poly Expr := {} ringExprDecls : Std.HashMap RingExpr Expr := {} + ringVarDecls : Std.HashMap Var Expr := {} structure ProofM.Context where ctx : Expr @@ -90,6 +91,9 @@ def mkRingPolyDecl (p : CommRing.Poly) : ProofM Expr := do def mkRingExprDecl (e : RingExpr) : ProofM Expr := do declare! ringExprDecls e +def mkRingVarDecl (x : Var) : ProofM Expr := do + declare! ringVarDecls x + private def mkContext (h : Expr) : ProofM Expr := do let varDecls := (← get).varDecls let polyDecls := (← get).polyDecls @@ -120,12 +124,17 @@ private def mkRingContext (h : Expr) : ProofM Expr := do unless (← isCommRing) do return h let ring ← withRingM do CommRing.getRing let vars := ring.vars - let usedVars := collectMapVars (← get).ringPolyDecls (·.collectVars) >> collectMapVars (← get).ringExprDecls (·.collectVars) <| {} + let ringVarDecls := (← get).ringVarDecls + let usedVars := collectMapVars (← get).ringPolyDecls (·.collectVars) >> collectMapVars (← get).ringExprDecls (·.collectVars) >> collectMapVars ringVarDecls collectVar <| {} let vars' := usedVars.toArray let varRename := mkVarRename vars' let vars := vars'.map fun x => vars[x]! let h := mkLetOfMap (← get).ringExprDecls h `re (mkConst ``Grind.CommRing.Expr) fun e => toExpr <| e.renameVars varRename let h := mkLetOfMap (← get).ringPolyDecls h `rp (mkConst ``Grind.CommRing.Poly) fun p => toExpr <| p.renameVars varRename + -- Replace ring variable FVars with their renamed indices + let varFVars := ringVarDecls.toArray.map (·.2) + let varIdsAsExpr := ringVarDecls.toArray.map fun (v, _) => toExpr (varRename v) + let h := h.replaceFVars varFVars varIdsAsExpr let h := h.abstract #[(← read).ringCtx] if h.hasLooseBVars then let struct ← getStruct @@ -281,7 +290,7 @@ partial def RingIneqCnstr.toExprProof (c' : RingIneqCnstr) : ProofM Expr := do mkCommRingLTThmPrefix ``Grind.CommRing.lt_cancel_var else mkCommRingLEThmPrefix ``Grind.CommRing.le_cancel_var - return mkApp7 h' (toExpr val) (toExpr x) p₁ (← mkRingPolyDecl c'.p) eagerReflBoolTrue h_eq_one h + return mkApp7 h' (toExpr val) (← mkRingVarDecl x) p₁ (← mkRingPolyDecl c'.p) eagerReflBoolTrue h_eq_one h partial def RingEqCnstr.toExprProof (c' : RingEqCnstr) : ProofM Expr := do match c'.h with @@ -299,7 +308,7 @@ partial def RingEqCnstr.toExprProof (c' : RingEqCnstr) : ProofM Expr := do let h := mkApp5 h (← mkRingPolyDecl c.p) (toExpr (val^n)) p₁ eagerReflBoolTrue (← c.toExprProof) let h_eq_one := mkApp2 (← mkFieldChar0ThmPrefix ``Grind.CommRing.inv_int_eq') (toExpr val) eagerReflBoolTrue let h' ← mkCommRingThmPrefix ``Grind.CommRing.eq_cancel_var - return mkApp7 h' (toExpr val) (toExpr x) p₁ (← mkRingPolyDecl c'.p) eagerReflBoolTrue h_eq_one h + return mkApp7 h' (toExpr val) (← mkRingVarDecl x) p₁ (← mkRingPolyDecl c'.p) eagerReflBoolTrue h_eq_one h partial def RingDiseqCnstr.toExprProof (c' : RingDiseqCnstr) : ProofM Expr := do match c'.h with @@ -313,7 +322,7 @@ partial def RingDiseqCnstr.toExprProof (c' : RingDiseqCnstr) : ProofM Expr := do let h := mkApp5 h (← mkRingPolyDecl c.p) (toExpr (val^n)) p₁ eagerReflBoolTrue (← c.toExprProof) let h_eq_one := mkApp2 (← mkFieldChar0ThmPrefix ``Grind.CommRing.inv_int_eq') (toExpr val) eagerReflBoolTrue let h' ← mkCommRingThmPrefix ``Grind.CommRing.diseq_cancel_var - return mkApp7 h' (toExpr val) (toExpr x) p₁ (← mkRingPolyDecl c'.p) eagerReflBoolTrue h_eq_one h + return mkApp7 h' (toExpr val) (← mkRingVarDecl x) p₁ (← mkRingPolyDecl c'.p) eagerReflBoolTrue h_eq_one h mutual partial def IneqCnstr.toExprProof (c' : IneqCnstr) : ProofM Expr := caching c' do diff --git a/tests/lean/run/grind_diseq_cancel_var_bug.lean b/tests/lean/run/grind_diseq_cancel_var_bug.lean new file mode 100644 index 0000000000..d2ffaae632 --- /dev/null +++ b/tests/lean/run/grind_diseq_cancel_var_bug.lean @@ -0,0 +1,26 @@ +-- Test for fix: grind diseq_cancel_var was using wrong variable index after renaming +-- Previously this would produce a kernel type mismatch error +-- The unused variable `hr` is intentional - it triggers the variable renaming that exposed the bug +set_option linter.unusedVariables false in +example (r : Rat) (hr : r ≤ r) : 2⁻¹ * 2 = (1 : Rat) := by grind + +-- Leo's test cases for diseq_cancel_var (from Zulip) +-- These should still work after the fix +open Std Lean.Grind + +variable {α : Type} [Field α] [LE α] [LT α] [LawfulOrderLT α] [IsLinearOrder α] [OrderedRing α] [NoNatZeroDivisors α] +example (a b : α) (h : a = b / 2) : a + a ≤ b := by grind +example (a : α) (h : a ≠ 1/2) : 2*a > 1 ∨ 2*a < 1 := by grind +example (a : α) (h : a ≠ (1/2)^3) : 8*a > 1 ∨ 8*a < 1 := by grind +example (a : α) (h : a ≠ (1/2)^3) : 8*a > 1 ∨ 8*a < 1 := by grind +example (a : α) (h : a ≠ (1/3)*(1/2)^3) : 24*a > 1 ∨ 24*a < 1 := by grind +example (a : α) (h : a ≠ b*(2⁻¹)^3) : 8*a > b ∨ 8*a < b := by grind +example (a : α) (h : a ≠ (2⁻¹)^3*b) : 8*a > b ∨ 8*a < b := by grind +example (a : α) (h : 5*(2⁻¹)^3*b ≠ a) : 8*a > 5*b ∨ 8*a < 5*b := by grind +example (a : α) (h : 5*(2⁻¹)*(2⁻¹)^3*b + (3/2)*c ≠ a) : 16*a > 5*b + 24*c ∨ 16*a < 5*b + 24*c := by grind +example (x : α) : x ≥ 1/3 → x ≥ 0 := by grind +example (a : α) (h : a ≠ 1/2 + 1/3) : 6*a > 5 ∨ 6*a < 5 := by grind +example (a : α) (h : 1/2 + 1/3 ≠ a) : 6*a > 5 ∨ 6*a < 5 := by grind +example (h : (1/4:α) ≠ (1/2)*(1/2)) : False := by grind +example (h : (1/4:α) + 1/4 ≠ (1/2)) : False := by grind +example (h : (1/2:α) + 1/3 ≠ (5/6)) : False := by grind