From 6b24eb474f999660f84182fe3df3b9555ecdbdc0 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 18 Aug 2025 19:19:50 -0700 Subject: [PATCH] fix: variable reordering in `grind cutsat` (#9980) This PR fixes a bug in the dynamic variable reordering function used in `grind cutsat`. Closes #9948 --- .../Meta/Tactic/Grind/Arith/Cutsat/Proof.lean | 24 +++++++++++++------ .../Grind/Arith/Cutsat/ReorderVars.lean | 1 + tests/lean/run/grind_9948.lean | 6 +++++ 3 files changed, 24 insertions(+), 7 deletions(-) create mode 100644 tests/lean/run/grind_9948.lean diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean index 61238bba9d..1c83cf663a 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean @@ -148,8 +148,10 @@ private def toContextExprCore (vars : Array Expr) (type : Expr) : MetaM Expr := else RArray.toExpr type id (RArray.leaf (mkIntLit 0)) +-- Remark: the `prime` flag is used just to distinguish variables before/after reordering. +-- Recall that we keep two contexts. The "prime" one is the one **before** reordering. private def mkContext - (ctxVar : Expr) (vars : PArray Expr) (varDecls : Std.HashMap Var Expr) (polyDecls : Std.HashMap Poly Expr) (exprDecls : Std.HashMap Int.Linear.Expr Expr) + (ctxVar : Expr) (prime : Bool) (vars : PArray Expr) (varDecls : Std.HashMap Var Expr) (polyDecls : Std.HashMap Poly Expr) (exprDecls : Std.HashMap Int.Linear.Expr Expr) (h : Expr) : GoalM Expr := do let usedVars := collectMapVars varDecls collectVar >> collectMapVars polyDecls (·.collectVars) >> collectMapVars exprDecls (·.collectVars) <| {} let vars' := usedVars.toArray @@ -158,13 +160,13 @@ private def mkContext let varFVars := vars'.map fun x => varDecls[x]?.getD default let varIdsAsExpr := List.range vars'.size |>.toArray |>.map toExpr let h := h.replaceFVars varFVars varIdsAsExpr - let h := mkLetOfMap exprDecls h `e (mkConst ``Int.Linear.Expr) fun e => toExpr <| e.renameVars varRename - let h := mkLetOfMap polyDecls h `p (mkConst ``Int.Linear.Poly) fun p => toExpr <| p.renameVars varRename + let h := mkLetOfMap exprDecls h (cond prime `e' `e) (mkConst ``Int.Linear.Expr) fun e => toExpr <| e.renameVars varRename + let h := mkLetOfMap polyDecls h (cond prime `p' `p) (mkConst ``Int.Linear.Poly) fun p => toExpr <| p.renameVars varRename let h := h.abstract #[ctxVar] if h.hasLooseBVars then let ctxType := mkApp (mkConst ``RArray [levelZero]) Int.mkType let ctxVal ← toContextExprCore vars Int.mkType - return .letE `ctx ctxType ctxVal h (nondep := false) + return .letE (cond prime `ctx' `ctx) ctxType ctxVal h (nondep := false) else return h @@ -195,8 +197,8 @@ where go : ProofM Expr := do let h ← x let h ← mkRingContext h - let h ← mkContext (← read).ctx' (← get').vars' (← get).varDecls' (← get).polyDecls' (← get).exprDecls' h - mkContext (← read).ctx (← get').vars (← get).varDecls (← get).polyDecls (← get).exprDecls h + let h ← mkContext (← read).ctx' (prime := true) (← get').vars' (← get).varDecls' (← get).polyDecls' (← get).exprDecls' h + mkContext (← read).ctx (prime := false) (← get').vars (← get).varDecls (← get).polyDecls (← get).exprDecls h /-- Returns a Lean expression representing the auxiliary `CommRing` variable context needed for normalizing @@ -210,6 +212,14 @@ private def DvdCnstr.get_d_a (c : DvdCnstr) : GoalM (Int × Int) := do let .add a _ _ := c.p | c.throwUnexpected return (d, a) +/-- +Similar to `denoteExpr'`, but takes into account the `unordered` flag in the `ProofM` context. +Recall that if `unordered` is `true`, we should use `vars'` +-/ +private def _root_.Int.Linear.Poly.denoteExprUsingCurrVars (p : Poly) : ProofM Expr := do + let vars ← if (← read).unordered then pure (← get').vars' else getVars + return (← p.denoteExpr (vars[·]!)) + mutual partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do trace[grind.debug.cutsat.proof] "{← c'.pp}" @@ -378,7 +388,7 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := caching c' do | .ofDiseqSplit c₁ fvarId h _ => let p₂ := c₁.p.addConst 1 let hFalse ← h.toExprProofCore - let hNot := mkLambda `h .default (mkIntLE (← p₂.denoteExpr') (mkIntLit 0)) (hFalse.abstract #[mkFVar fvarId]) + let hNot := mkLambda `h .default (mkIntLE (← p₂.denoteExprUsingCurrVars) (mkIntLit 0)) (hFalse.abstract #[mkFVar fvarId]) return mkApp7 (mkConst ``Int.Linear.diseq_split_resolve) (← getContext) (← mkPolyDecl c₁.p) (← mkPolyDecl p₂) (← mkPolyDecl c'.p) eagerReflBoolTrue (← c₁.toExprProof) hNot | .cooper s => diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ReorderVars.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ReorderVars.lean index ccf741dc9c..159ea06a0e 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ReorderVars.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/ReorderVars.lean @@ -150,6 +150,7 @@ def reorderVars : GoalM Unit := do varMap := s.varMap.map fun x => old2new[x]! vars' := s.vars varMap' := s.varMap + natDef := s.natDef.map fun x => old2new[x]! dvds := s.dvds.map fun _ => none lowers := s.lowers.map fun _ => {} uppers := s.uppers.map fun _ => {} diff --git a/tests/lean/run/grind_9948.lean b/tests/lean/run/grind_9948.lean new file mode 100644 index 0000000000..ca51d1f542 --- /dev/null +++ b/tests/lean/run/grind_9948.lean @@ -0,0 +1,6 @@ +theorem sum_of_n (n : Nat) : + (List.range (n + 1)).sum = n * (n + 1) / 2 := by + induction n with + | zero => rfl + | succ k ih => + grind [List.range_succ]