perf: in CaseValues, subst only once (#11510)

This PR avoids running substCore twice in caseValues.
This commit is contained in:
Joachim Breitner 2025-12-04 16:43:46 +01:00 committed by GitHub
parent 5f561bfee2
commit f0738c2cd1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 52 additions and 60 deletions

View file

@ -12,6 +12,7 @@ public import Lean.Elab.BindersUtil
public import Lean.Elab.PatternVar
public import Lean.Elab.Quotation.Precheck
public import Lean.Elab.SyntheticMVars
import Lean.Meta.Match.Value
import Lean.Meta.Match.NamedPatterns
public section

View file

@ -6,8 +6,14 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Basic
public import Lean.Meta.Tactic.FVarSubst
public import Lean.Meta.CollectFVars
public import Lean.Meta.Match.CaseArraySizes
import Lean.Meta.Match.Value
import Lean.Meta.AppBuilder
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Assert
import Lean.Meta.Tactic.Subst
import Lean.Meta.Match.NamedPatterns
public section

View file

@ -6,32 +6,36 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Match.CaseValues
public section
public import Lean.Meta.Basic
public import Lean.Meta.Tactic.FVarSubst
import Lean.Meta.Match.CaseValues
import Lean.Meta.AppBuilder
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Assert
import Lean.Meta.Tactic.Subst
namespace Lean.Meta
structure CaseArraySizesSubgoal where
public structure CaseArraySizesSubgoal where
mvarId : MVarId
elems : Array FVarId := #[]
diseqs : Array FVarId := #[]
subst : FVarSubst := {}
deriving Inhabited
def getArrayArgType (a : Expr) : MetaM Expr := do
public def getArrayArgType (a : Expr) : MetaM Expr := do
let aType ← inferType a
let aType ← whnfD aType
unless aType.isAppOfArity ``Array 1 do
throwError "array expected{indentExpr a}"
pure aType.appArg!
private def mkArrayGetLit (a : Expr) (i : Nat) (n : Nat) (h : Expr) : MetaM Expr := do
def mkArrayGetLit (a : Expr) (i : Nat) (n : Nat) (h : Expr) : MetaM Expr := do
let lt ← mkLt (mkRawNatLit i) (mkRawNatLit n)
let ltPrf ← mkDecideProof lt
mkAppM `Array.getLit #[a, mkRawNatLit i, h, ltPrf]
private partial def introArrayLit (mvarId : MVarId) (a : Expr) (n : Nat) (xNamePrefix : Name) (aSizeEqN : Expr) : MetaM MVarId := do
partial def introArrayLit (mvarId : MVarId) (a : Expr) (n : Nat) (xNamePrefix : Name) (aSizeEqN : Expr) : MetaM MVarId := do
let α ← getArrayArgType a
let rec loop (i : Nat) (xs : Array Expr) (args : Array Expr) := do
if i < n then
@ -61,7 +65,7 @@ private partial def introArrayLit (mvarId : MVarId) (a : Expr) (n : Nat) (xNameP
n) `..., x_1 ... x_{sizes[n-1]} |- C #[x_1, ..., x_{sizes[n-1]}]`
n+1) `..., (h_1 : a.size != sizes[0]), ..., (h_n : a.size != sizes[n-1]) |- C a`
where `n = sizes.size` -/
def caseArraySizes (mvarId : MVarId) (fvarId : FVarId) (sizes : Array Nat) (xNamePrefix := `x) (hNamePrefix := `h) : MetaM (Array CaseArraySizesSubgoal) :=
public def caseArraySizes (mvarId : MVarId) (fvarId : FVarId) (sizes : Array Nat) (xNamePrefix := `x) (hNamePrefix := `h) : MetaM (Array CaseArraySizesSubgoal) :=
mvarId.withContext do
let a := mkFVar fvarId
let aSize ← mkAppM `Array.size #[a]
@ -72,22 +76,20 @@ def caseArraySizes (mvarId : MVarId) (fvarId : FVarId) (sizes : Array Nat) (xNam
subgoals.mapIdxM fun i subgoal => do
let subst := subgoal.subst
let mvarId := subgoal.mvarId
let hEqSz := (subst.get hEq).fvarId!
if h : i < sizes.size then
let n := sizes[i]
let mvarId ← mvarId.clear subgoal.newHs[0]!
let mvarId ← mvarId.clear (subst.get aSizeFVarId).fvarId!
mvarId.withContext do
let hEqSzSymm ← mkEqSymm (mkFVar hEqSz)
let mvarId ← introArrayLit mvarId a n xNamePrefix hEqSzSymm
let (xs, mvarId) ← mvarId.introN n
let (hEqLit, mvarId) ← mvarId.intro1
let mvarId ← mvarId.clear hEqSz
let (subst, mvarId) ← substCore mvarId hEqLit false subst
pure { mvarId := mvarId, elems := xs, subst := subst }
let hEqSz := (subst.get hEq).fvarId!
let n := sizes[i]
mvarId.withContext do
let hEqSzSymm ← mkEqSymm (mkFVar hEqSz)
let mvarId ← introArrayLit mvarId a n xNamePrefix hEqSzSymm
let (xs, mvarId) ← mvarId.introN n
let (hEqLit, mvarId) ← mvarId.intro1
let mvarId ← mvarId.clear hEqSz
let (subst, mvarId) ← substCore mvarId hEqLit (symm := false) subst
pure { mvarId := mvarId, elems := xs, subst := subst }
else
let (subst, mvarId) ← substCore mvarId hEq false subst
let diseqs := subgoal.newHs.map fun fvarId => (subst.get fvarId).fvarId!
pure { mvarId := mvarId, diseqs := diseqs, subst := subst }
let (subst, mvarId) ← substCore mvarId hEq (symm := false) subst
let diseqs := subgoal.newHs.map fun fvarId => (subst.get fvarId).fvarId!
pure { mvarId := mvarId, diseqs := diseqs, subst := subst }
end Lean.Meta

View file

@ -6,28 +6,25 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Tactic.Subst
public import Lean.Meta.Match.Value
public section
public import Lean.Meta.Basic
public import Lean.Meta.Tactic.FVarSubst
import Lean.Meta.Tactic.Subst
namespace Lean.Meta
structure CaseValueSubgoal where
mvarId : MVarId
newH : FVarId
subst : FVarSubst := {}
deriving Inhabited
/--
Split goal `... |- C x` into two subgoals
`..., (h : x = value) |- C value`
`..., (h : x = value) |- C x`
`..., (h : x != value) |- C x`
where `fvarId` is `x`s id.
The type of `x` must have decidable equality.
Remark: `subst` field of the second subgoal is equal to the input `subst`. -/
private def caseValueAux (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hName : Name := `h) (subst : FVarSubst := {})
-/
def caseValue (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hName : Name := `h)
: MetaM (CaseValueSubgoal × CaseValueSubgoal) :=
mvarId.withContext do
let tag ← mvarId.getTag
@ -42,27 +39,16 @@ private def caseValueAux (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hNa
let val ← mkAppOptM `dite #[none, xEqValue, none, thenMVar, elseMVar]
mvarId.assign val
let (elseH, elseMVarId) ← elseMVar.mvarId!.intro1P
let elseSubgoal := { mvarId := elseMVarId, newH := elseH, subst := subst : CaseValueSubgoal }
let elseSubgoal := { mvarId := elseMVarId, newH := elseH }
let (thenH, thenMVarId) ← thenMVar.mvarId!.intro1P
let symm := false
let clearH := false
let (thenSubst, thenMVarId) ← substCore thenMVarId thenH symm subst clearH
thenMVarId.withContext do
trace[Meta] "subst domain: {thenSubst.domain.map (·.name)}"
let thenH := (thenSubst.get thenH).fvarId!
trace[Meta] "searching for decl"
let _ ← thenH.getDecl
trace[Meta] "found decl"
let thenSubgoal := { mvarId := thenMVarId, newH := (thenSubst.get thenH).fvarId!, subst := thenSubst : CaseValueSubgoal }
let thenSubgoal := { mvarId := thenMVarId, newH := thenH }
pure (thenSubgoal, elseSubgoal)
def caseValue (mvarId : MVarId) (fvarId : FVarId) (value : Expr) : MetaM (CaseValueSubgoal × CaseValueSubgoal) := do
let s ← caseValueAux mvarId fvarId value
appendTagSuffix s.1.mvarId `thenBranch
appendTagSuffix s.2.mvarId `elseBranch
pure s
structure CaseValuesSubgoal where
public structure CaseValuesSubgoal where
mvarId : MVarId
newHs : Array FVarId := #[]
subst : FVarSubst := {}
@ -83,22 +69,15 @@ structure CaseValuesSubgoal where
If `substNewEqs = true`, then the new `h_i` equality hypotheses are substituted in the first `n` cases.
-/
def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNamePrefix := `h) (substNewEqs := false) : MetaM (Array CaseValuesSubgoal) :=
public def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNamePrefix := `h) : MetaM (Array CaseValuesSubgoal) :=
let rec loop : Nat → MVarId → List Expr → Array FVarId → Array CaseValuesSubgoal → MetaM (Array CaseValuesSubgoal)
| _, mvarId, [], _, _ => throwTacticEx `caseValues mvarId "list of values must not be empty"
| i, mvarId, v::vs, hs, subgoals => do
let (thenSubgoal, elseSubgoal) ← caseValueAux mvarId fvarId v (hNamePrefix.appendIndexAfter i) {}
let (thenSubgoal, elseSubgoal) ← caseValue mvarId fvarId v (hNamePrefix.appendIndexAfter i)
appendTagSuffix thenSubgoal.mvarId ((`case).appendIndexAfter i)
let thenMVarId ← hs.foldlM
(fun thenMVarId h => match thenSubgoal.subst.get h with
| Expr.fvar fvarId => thenMVarId.tryClear fvarId
| _ => pure thenMVarId)
thenSubgoal.mvarId
let subgoals ← if substNewEqs then
let (subst, mvarId) ← substCore thenMVarId thenSubgoal.newH false thenSubgoal.subst true
pure <| subgoals.push { mvarId := mvarId, newHs := #[], subst := subst }
else
pure <| subgoals.push { mvarId := thenMVarId, newHs := #[thenSubgoal.newH], subst := thenSubgoal.subst }
let thenMVarId ← thenSubgoal.mvarId.tryClearMany hs
let (subst, mvarId) ← substCore thenMVarId thenSubgoal.newH (symm := false) {} (clearH := true)
let subgoals := subgoals.push { mvarId := mvarId, newHs := #[], subst := subst }
match vs with
| [] => do
appendTagSuffix elseSubgoal.mvarId ((`case).appendIndexAfter (i+1))

View file

@ -16,6 +16,8 @@ public import Lean.Meta.Match.MVarRenaming
import Lean.Meta.Match.SimpH
import Lean.Meta.Match.SolveOverlap
import Lean.Meta.HasNotBit
import Lean.Meta.Match.CaseArraySizes
import Lean.Meta.Match.CaseValues
import Lean.Meta.Match.NamedPatterns
public section
@ -724,7 +726,7 @@ private def processValue (p : Problem) : MetaM (Array Problem) := do
trace[Meta.Match.match] "value step"
let x :: xs := p.vars | unreachable!
let values := collectValues p
let subgoals ← caseValues p.mvarId x.fvarId! values (substNewEqs := true)
let subgoals ← caseValues p.mvarId x.fvarId! values
subgoals.mapIdxM fun i subgoal => do
trace[Meta.Match.match] "processValue subgoal\n{MessageData.ofGoal subgoal.mvarId}"
if h : i < values.size then

View file

@ -8,6 +8,7 @@ module
prelude
public import Lean.Meta.Basic
public import Lean.Meta.Match.Basic
public import Lean.Meta.Match.MatcherInfo
import Lean.Meta.Eqns
public section

View file

@ -10,6 +10,7 @@ prelude
public import Lean.Meta.Match.MatcherApp.Basic
public import Lean.Meta.Match.MatchEqsExt
public import Lean.Meta.Match.AltTelescopes
public import Lean.Meta.AppBuilder
import Lean.Meta.Tactic.Split
import Lean.Meta.Tactic.Refl