fix: handling of ite/dite expressions in cbv tactic (#12361)

This PR develops custom simprocs for dealing with `ite`/`dite`
expressions in `cbv` tactics, based on equivalent simprocs from
`Sym.simp`, with the difference that if the condition is not reduced to
`True`/`False`, we make use of the decidable instance and calculate to
what the condition reduces to.

Stacked on top of #12391.
This commit is contained in:
Wojciech Różowski 2026-02-09 15:00:10 +00:00 committed by GitHub
parent 919721c758
commit 57c5efe309
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 222 additions and 7 deletions

View file

@ -0,0 +1,189 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Wojciech Różowski
-/
module
prelude
public import Lean.Meta.Sym.Simp.SimpM
import Lean.Meta.Sym.Simp.Result
import Lean.Meta.Sym.Simp.Rewrite
import Lean.Meta.Sym.Simp.ControlFlow
import Lean.Meta.Sym.AlphaShareBuilder
import Lean.Meta.Sym.InstantiateS
import Lean.Meta.Sym.InferType
import Lean.Meta.Sym.Simp.App
import Lean.Meta.SynthInstance
import Lean.Meta.WHNF
import Lean.Meta.AppBuilder
import Init.Sym.Lemmas
import Lean.Meta.Tactic.Cbv.TheoremsLookup
namespace Lean.Meta.Sym.Simp
open Internal
public def simpIteCbv : Simproc := fun e => do
let numArgs := e.getAppNumArgs
if numArgs < 5 then return .rfl (done := true)
propagateOverApplied e (numArgs - 5) fun e => do
let_expr f@ite α c _ a b := e | return .rfl
match (← simp c) with
| .rfl _ =>
if (← isTrueExpr c) then
return .step a <| mkApp3 (mkConst ``ite_true f.constLevels!) α a b
else if (← isFalseExpr c) then
return .step b <| mkApp3 (mkConst ``ite_false f.constLevels!) α a b
else
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c) | return .rfl
let inst' ← shareCommon inst'
let toEval ← mkAppS₂ (mkConst ``Decidable.decide) c inst'
let evalRes ← simp toEval
match evalRes with
| .rfl _ =>
return .rfl (done := true)
| .step v hv _ =>
if (← isBoolTrueExpr v) then
let h' := mkApp3 (mkConst ``eq_true_of_decide) c inst' hv
let inst' := mkConst ``instDecidableTrue
let c' ← getTrueExpr
let e' := e.getBoundedAppFn 4
let e' ← mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.ite_cond_congr) c' inst' h'
let ha := mkApp3 (mkConst ``ite_true f.constLevels!) α a b
let ha ← mkEqTrans e e' h' a ha
return .step a ha (done := false)
else if (← isBoolFalseExpr v) then
let h' := mkApp3 (mkConst ``eq_false_of_decide) c inst' hv
let inst' := mkConst ``instDecidableFalse
let c' ← getFalseExpr
let e' := e.getBoundedAppFn 4
let e' ← mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.ite_cond_congr) c' inst' h'
let hb := mkApp3 (mkConst ``ite_false f.constLevels!) α a b
let hb ← mkEqTrans e e' h' b hb
return .step b hb (done := false)
else
return .rfl (done := true)
| .step c' h _ =>
if (← isTrueExpr c') then
return .step a <| mkApp (e.replaceFn ``ite_cond_eq_true) h
else if (← isFalseExpr c') then
return .step b <| mkApp (e.replaceFn ``ite_cond_eq_false) h
else
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
let inst' ← shareCommon inst'
let e' := e.getBoundedAppFn 4
let e' ← mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.ite_cond_congr) c' inst' h
return .step e' h'
public def simpDIteCbv : Simproc := fun e => do
let numArgs := e.getAppNumArgs
if numArgs < 5 then return .rfl (done := true)
propagateOverApplied e (numArgs - 5) fun e => do
let_expr f@dite α c _ a b := e | return .rfl
match (← simp c) with
| .rfl _ =>
if (← isTrueExpr c) then
let a' ← share <| a.betaRev #[mkConst ``True.intro]
return .step a' <| mkApp3 (mkConst ``dite_true f.constLevels!) α a b
else if (← isFalseExpr c) then
let b' ← share <| b.betaRev #[mkConst ``not_false]
return .step b' <| mkApp3 (mkConst ``dite_false f.constLevels!) α a b
else
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c) | return .rfl
let inst' ← shareCommon inst'
let toEval ← mkAppS₂ (mkConst ``Decidable.decide) c inst'
let evalRes ← simp toEval
match evalRes with
| .rfl _ => return .rfl (done := true)
| .step v hv _ =>
if (← isBoolTrueExpr v) then
let h' := mkApp3 (mkConst ``eq_true_of_decide) c inst' hv
let inst' := mkConst ``instDecidableTrue
let e' := e.getBoundedAppFn 4
let h ← shareCommon h'
let c' ← getTrueExpr
let a ← share <| mkLambda `h .default c' (a.betaRev #[mkApp4 (mkConst ``Eq.mpr_prop) c c' h (mkBVar 0)])
let b ← share <| mkLambda `h .default (mkNot c') (b.betaRev #[mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)])
let e' ← mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.dite_cond_congr) c' inst' h
let a' ← share <| a.betaRev #[mkConst ``True.intro]
let ha := mkApp3 (mkConst ``dite_true f.constLevels!) α a b
let ha ← mkEqTrans e e' h' a' ha
return .step a' ha
else if (← isBoolFalseExpr v) then
let h' := mkApp3 (mkConst ``eq_false_of_decide) c inst' hv
let inst' := mkConst ``instDecidableFalse
let e' := e.getBoundedAppFn 4
let h ← shareCommon h'
let c' ← getFalseExpr
let a ← share <| mkLambda `h .default c' (a.betaRev #[mkApp4 (mkConst ``Eq.mpr_prop) c c' h (mkBVar 0)])
let b ← share <| mkLambda `h .default (mkNot c') (b.betaRev #[mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)])
let e' ← mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.dite_cond_congr) c' inst' h
let b' ← share <| b.betaRev #[mkConst ``not_false]
let hb := mkApp3 (mkConst ``dite_false f.constLevels!) α a b
let hb ← mkEqTrans e e' h' b' hb
return .step b' hb
else
return .rfl (done := true)
| .step c' h _ =>
if (← isTrueExpr c') then
let h' ← shareCommon <| mkOfEqTrueCore c h
let a ← share <| a.betaRev #[h']
return .step a <| mkApp (e.replaceFn ``dite_cond_eq_true) h
else if (← isFalseExpr c') then
let h' ← shareCommon <| mkOfEqFalseCore c h
let b ← share <| b.betaRev #[h']
return .step b <| mkApp (e.replaceFn ``dite_cond_eq_false) h
else
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
let inst' ← shareCommon inst'
let e' := e.getBoundedAppFn 4
let h ← shareCommon h
let a ← share <| mkLambda `h .default c' (a.betaRev #[mkApp4 (mkConst ``Eq.mpr_prop) c c' h (mkBVar 0)])
let b ← share <| mkLambda `h .default (mkNot c') (b.betaRev #[mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)])
let e' ← mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.dite_cond_congr) c' inst' h
return .step e' h'
end Lean.Meta.Sym.Simp
namespace Lean.Meta.Tactic.Cbv
open Lean.Meta.Sym.Simp
def tryMatchEquations (appFn : Name) : Simproc := fun e => do
let thms ← getMatchTheorems appFn
thms.rewrite (d := dischargeNone) e
public def reduceRecMatcher : Simproc := fun e => do
if let some e' ← reduceRecMatcher? e then
return .step e' (← Sym.mkEqRefl e')
else
return .rfl
def tryMatcher : Simproc := fun e => do
unless e.isApp do
return .rfl
let some appFn := e.getAppFn.constName? | return .rfl
let some info ← getMatcherInfo? appFn | return .rfl
let start := info.numParams + 1
let stop := start + info.numDiscrs
(simpAppArgRange · start stop)
>> tryMatchEquations appFn
<|> reduceRecMatcher
<| e
public def simpControlCbv : Simproc := fun e => do
if !e.isApp then return .rfl
let .const declName _ := e.getAppFn | return .rfl
if declName == ``ite then
simpIteCbv e
else if declName == ``cond then
simpCond e
else if declName == ``dite then
simpDIteCbv e
else
tryMatcher e
end Lean.Meta.Tactic.Cbv

View file

@ -9,6 +9,7 @@ module
prelude
public import Lean.Meta.Sym.Simp.SimpM
public import Lean.Meta.Tactic.Cbv.Opaque
public import Lean.Meta.Tactic.Cbv.ControlFlow
import Lean.Meta.Tactic.Cbv.Util
import Lean.Meta.Tactic.Cbv.TheoremsLookup
import Lean.Meta.Sym
@ -28,12 +29,6 @@ def tryMatchEquations (appFn : Name) : Simproc := fun e => do
let thms ← getMatchTheorems appFn
thms.rewrite (d := dischargeNone) e
def reduceRecMatcher : Simproc := fun e => do
if let some e' ← reduceRecMatcher? e then
return .step e' (← Sym.mkEqRefl e')
else
return .rfl
def tryEquations : Simproc := fun e => do
unless e.isApp do
return .rfl
@ -149,7 +144,7 @@ def handleConst : Simproc := fun e => do
def cbvPre : Simproc :=
isBuiltinValue <|> isProofTerm <|> skipBinders
>> isOpaqueApp
>> (tryMatcher >> simpControl)
>> simpControlCbv
<|> ((isOpaqueConst >> handleConst) <|> simplifyAppFn <|> handleProj)
def cbvPost : Simproc :=

View file

@ -0,0 +1,31 @@
set_option cbv.warning false
example : (if (true = false) then 5 else 7) = 7 := by
conv =>
lhs
cbv
example : (if (true = ((fun x => x) true)) then 5 else 7) = 5 := by
conv =>
lhs
cbv
example : (if (String.Pos.Raw.mk 1 = String.Pos.Raw.mk 2) then 5 else 42) = 42 := by
conv =>
lhs
cbv
example : (if (String.Pos.Raw.mk 1 = String.Pos.Raw.mk 1) then 5 else 42) = 5 := by
conv =>
lhs
cbv
example : (if _ : String.Pos.Raw.mk 1 = String.Pos.Raw.mk 2 then 5 else 42) = 42 := by
conv =>
lhs
cbv
example : (if _ : String.Pos.Raw.mk 1 = String.Pos.Raw.mk 1 then 5 else 42) = 5 := by
conv =>
lhs
cbv