fix: ac_nf0, simp_arith: don't tempt the kernel to reduce atoms (#5708)

this fixes #5699 and fixes #5384.
This commit is contained in:
Joachim Breitner 2024-10-16 10:52:58 +02:00 committed by GitHub
parent b333de1a36
commit a2d2977228
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 109 additions and 36 deletions

View file

@ -86,35 +86,63 @@ def toACExpr (op l r : Expr) : MetaM (Array Expr × ACExpr) := do
| PreExpr.op l r => Data.AC.Expr.op (toACExpr varMap l) (toACExpr varMap r)
| PreExpr.var x => Data.AC.Expr.var (varMap x)
def buildNormProof (preContext : PreContext) (l r : Expr) : MetaM (Lean.Expr × Lean.Expr) := do
let (vars, acExpr) ← toACExpr preContext.op l r
let α ← inferType vars[0]!
/--
In order to prevent the kernel trying to reduce the atoms of the expression, we abstract the proof
over them. But `ac_rfl` proofs are not completely abstract in the value of the atoms it recognizes
neutral elements. So we have to abstract over these proofs as well.
-/
def abstractAtoms (preContext : PreContext) (atoms : Array Expr)
(k : Array (Expr × Option Expr) → MetaM Expr) : MetaM Expr := do
let α ← inferType atoms[0]!
let u ← getLevel α
let (isNeutrals, context) ← mkContext α u vars
let acExprNormed := Data.AC.evalList ACExpr preContext $ Data.AC.norm (preContext, isNeutrals) acExpr
let tgt := convertTarget vars acExprNormed
let lhs := convert acExpr
let rhs := convert acExprNormed
let proof := mkAppN (mkConst ``Context.eq_of_norm [u]) #[α, context, lhs, rhs, ←mkEqRefl (mkConst ``Bool.true)]
let rec go i (acc : Array (Expr × Option Expr)) (vars : Array Expr) (args : Array Expr) := do
if h : i < atoms.size then
withLocalDeclD `x α fun v => do
match (← getInstance ``LawfulIdentity #[preContext.op, atoms[i]]) with
| none =>
go (i+1) (acc.push (v, .none)) (vars.push v) (args.push atoms[i])
| some inst =>
withLocalDeclD `inst (mkApp3 (mkConst ``LawfulIdentity [u]) α preContext.op v) fun iv =>
go (i+1) (acc.push (v, .some iv)) (vars ++ #[v,iv]) (args ++ #[atoms[i], inst])
else
let proof ← k acc
let proof ← mkLambdaFVars vars proof
let proof := mkAppN proof args
return proof
go 0 #[] #[] #[]
def buildNormProof (preContext : PreContext) (l r : Expr) : MetaM (Lean.Expr × Lean.Expr) := do
let (atoms, acExpr) ← toACExpr preContext.op l r
let proof ← abstractAtoms preContext atoms fun varsData => do
let α ← inferType atoms[0]!
let u ← getLevel α
let context ← mkContext α u varsData
let isNeutrals := varsData.map (·.2.isSome)
let vars := varsData.map (·.1)
let acExprNormed := Data.AC.evalList ACExpr preContext $ Data.AC.norm (preContext, isNeutrals) acExpr
let lhs := convert acExpr
let rhs := convert acExprNormed
let proof := mkAppN (mkConst ``Context.eq_of_norm [u]) #[α, context, lhs, rhs, ←mkEqRefl (mkConst ``Bool.true)]
let proofType ← mkEq (convertTarget vars acExpr) (convertTarget vars acExprNormed)
let proof ← mkExpectedTypeHint proof proofType
return proof
let some (_, _, tgt) := (← inferType proof).eq? | panic! "unexpected proof type"
return (proof, tgt)
where
mkContext (α : Expr) (u : Level) (vars : Array Expr) : MetaM (Array Bool × Expr) := do
let arbitrary := vars[0]!
mkContext (α : Expr) (u : Level) (vars : Array (Expr × Option Expr)) : MetaM Expr := do
let arbitrary := vars[0]!.1
let plift := mkApp (mkConst ``PLift [.zero])
let pliftUp := mkApp2 (mkConst ``PLift.up [.zero])
let noneE tp := mkApp (mkConst ``Option.none [.zero]) (plift tp)
let someE tp v := mkApp2 (mkConst ``Option.some [.zero]) (plift tp) (pliftUp tp v)
let vars ← vars.mapM fun x => do
let vars ← vars.mapM fun ⟨x, inst?⟩ =>
let isNeutral :=
let isNeutralClass := mkApp3 (mkConst ``LawfulIdentity [u]) α preContext.op x
match ←getInstance ``LawfulIdentity #[preContext.op, x] with
| none => (false, noneE isNeutralClass)
| some isNeutral => (true, someE isNeutralClass isNeutral)
match inst? with
| none => noneE isNeutralClass
| some isNeutral => someE isNeutralClass isNeutral
return mkApp4 (mkConst ``Variable.mk [u]) α preContext.op x isNeutral
return (isNeutral.1, mkApp4 (mkConst ``Variable.mk [u]) α preContext.op x isNeutral.2)
let (isNeutrals, vars) := vars.unzip
let vars := vars.toList
let vars ← mkListLit (mkApp2 (mkConst ``Variable [u]) α preContext.op) vars
@ -130,7 +158,7 @@ where
| none => noneE idemClass
| some idem => someE idemClass idem
return (isNeutrals, mkApp7 (mkConst ``Lean.Data.AC.Context.mk [u]) α preContext.op preContext.assoc comm idem vars arbitrary)
return mkApp7 (mkConst ``Lean.Data.AC.Context.mk [u]) α preContext.op preContext.assoc comm idem vars arbitrary
convert : ACExpr → Expr
| .op l r => mkApp2 (mkConst ``Data.AC.Expr.op) (convert l) (convert r)

View file

@ -8,24 +8,43 @@ import Lean.Meta.Tactic.LinearArith.Nat.Basic
namespace Lean.Meta.Linear.Nat
/-
To prevent the kernel from accidentially reducing the atoms in the equation while typechecking,
we abstract over them.
-/
def withAbstractAtoms (atoms : Array Expr) (k : Array Expr → MetaM (Option (Expr × Expr))) :
MetaM (Option (Expr × Expr)) := do
let atoms := atoms
let decls : Array (Name × (Array Expr → MetaM Expr)) ← atoms.mapM fun _ => do
return ((← mkFreshUserName `x), fun _ => pure (mkConst ``Nat))
withLocalDeclsD decls fun ctxt => do
let some (r, p) ← k ctxt | return none
let r := (← mkLambdaFVars ctxt r).beta atoms
let p := mkAppN (← mkLambdaFVars ctxt p) atoms
return some (r, p)
def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (some c, ctx) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none
let c₁ := c.toPoly
let c₂ := c₁.norm
if c₂.isUnsat then
let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_false_of_isUnsat) (← toContextExpr ctx) (toExpr c) reflTrue
return some (mkConst ``False, p)
else if c₂.isValid then
let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_true_of_isValid) (← toContextExpr ctx) (toExpr c) reflTrue
return some (mkConst ``True, p)
else
let c₂ : LinearCnstr := c₂.toExpr
let r ← c₂.toArith ctx
if r != e then
let p := mkApp4 (mkConst ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq) (← toContextExpr ctx) (toExpr c) (toExpr c₂) reflTrue
return some (r, ← mkExpectedTypeHint p (← mkEq e r))
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none
withAbstractAtoms atoms fun ctx => do
let lhs ← c.toArith ctx
let c₁ := c.toPoly
let c₂ := c₁.norm
if c₂.isUnsat then
let r := mkConst ``False
let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_false_of_isUnsat) (← toContextExpr ctx) (toExpr c) reflTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
else if c₂.isValid then
let r := mkConst ``True
let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_true_of_isValid) (← toContextExpr ctx) (toExpr c) reflTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
else
return none
let c₂ : LinearCnstr := c₂.toExpr
let r ← c₂.toArith ctx
if r != lhs then
let p := mkApp4 (mkConst ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq) (← toContextExpr ctx) (toExpr c) (toExpr c₂) reflTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
else
return none
def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
if let some arg := e.not? then

View file

@ -0,0 +1,16 @@
-- A function that reduced badly, as a canary for kernel reduction
def bad (n : Nat) : Nat :=
if h : n = 0 then 0 else bad (n / 2)
termination_by n
theorem foo : 2 * bad 42000 = bad 42000 + bad 42000 := by simp_arith
theorem foo2 (h : 2 * bad 42000 = bad 42000 + bad 42000 + 1) : False := by simp_arith at h
theorem foo3 (h : bad 42000 + bad 42000 = x) : (2 * bad 42000 = x) := by simp_arith at h; assumption
@[irreducible] def f : Nat → Nat := fun x => x
theorem doesn't_do_anything : f 3 = 3 := by
fail_if_success simp_arith -- does not apply f_eq and g_eq
rw [f]

View file

@ -0,0 +1,10 @@
axiom foo {p : Prop} {x : BitVec 32} (h : (!x == x + 0#32) = true) : p
theorem add_eq_sub_not_sub_one (x : BitVec 32) (h : (!x == x + (1#32 + 4294967295#32)) = true) : False := by
simp only [BitVec.reduceAdd] at h
exact foo h -- this works
theorem add_eq_sub_not_sub_one' (x : BitVec 32) (h : (!x == x + 1#32 + 4294967295#32) = true) : False := by
ac_nf0 at h
simp only [BitVec.reduceAdd] at h
exact foo h -- this used to hang