perf: when matching on values, avoid generating hyps when not needed (#11508)
This PR avoids generating hyps when not needed (i.e. if there is a
catch-all so no completeness checking needed) during matching on values.
This tweak was made possible by #11220.
This commit is contained in:
parent
d4c832ecb0
commit
4b77e226ab
4 changed files with 99 additions and 42 deletions
|
|
@ -12,20 +12,14 @@ import Lean.Meta.Tactic.Subst
|
|||
|
||||
namespace Lean.Meta
|
||||
|
||||
structure CaseValueSubgoal where
|
||||
mvarId : MVarId
|
||||
newH : FVarId
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
Split goal `... |- C x` into two subgoals
|
||||
`..., (h : x = value) |- C x`
|
||||
`..., (h : x != value) |- C x`
|
||||
where `fvarId` is `x`s id.
|
||||
Split goal `... |- C x`,, where `fvarId` is `x`s id, into two subgoals
|
||||
`..., |- (h : x = value) → C x`
|
||||
`..., |- (h : x != value) → C x`
|
||||
The type of `x` must have decidable equality.
|
||||
-/
|
||||
def caseValue (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hName : Name := `h)
|
||||
: MetaM (CaseValueSubgoal × CaseValueSubgoal) :=
|
||||
: MetaM (MVarId × MVarId) :=
|
||||
mvarId.withContext do
|
||||
let tag ← mvarId.getTag
|
||||
mvarId.checkNotAssigned `caseValue
|
||||
|
|
@ -38,15 +32,7 @@ def caseValue (mvarId : MVarId) (fvarId : FVarId) (value : Expr) (hName : Name :
|
|||
let elseMVar ← mkFreshExprSyntheticOpaqueMVar elseTarget tag
|
||||
let val ← mkAppOptM `dite #[none, xEqValue, none, thenMVar, elseMVar]
|
||||
mvarId.assign val
|
||||
let (elseH, elseMVarId) ← elseMVar.mvarId!.intro1P
|
||||
let elseSubgoal := { mvarId := elseMVarId, newH := elseH }
|
||||
let (thenH, thenMVarId) ← thenMVar.mvarId!.intro1P
|
||||
thenMVarId.withContext do
|
||||
trace[Meta] "searching for decl"
|
||||
let _ ← thenH.getDecl
|
||||
trace[Meta] "found decl"
|
||||
let thenSubgoal := { mvarId := thenMVarId, newH := thenH }
|
||||
pure (thenSubgoal, elseSubgoal)
|
||||
return (thenMVar.mvarId!, elseMVar.mvarId!)
|
||||
|
||||
public structure CaseValuesSubgoal where
|
||||
mvarId : MVarId
|
||||
|
|
@ -55,34 +41,44 @@ public structure CaseValuesSubgoal where
|
|||
deriving Inhabited
|
||||
|
||||
/--
|
||||
Split goal `... |- C x` into values.size + 1 subgoals
|
||||
1) `..., (h_1 : x = value[0]) |- C value[0]`
|
||||
Split goal `... |- C x`, where `fvarId` is `x`s id, into `values.size + 1` subgoals
|
||||
1) `..., (h_1 : x = value[0]) |- C value[0]`
|
||||
...
|
||||
n) `..., (h_n : x = value[n - 1]) |- C value[n - 1]`
|
||||
n) `..., (h_n : x = value[n - 1]) |- C value[n - 1]`
|
||||
n+1) `..., (h_1 : x != value[0]) ... (h_n : x != value[n-1]) |- C x`
|
||||
where `n = values.size`
|
||||
where `fvarId` is `x`s id.
|
||||
The type of `x` must have decidable equality.
|
||||
|
||||
Remark: the last subgoal is for the "else" catchall case, and its `subst` is `{}`.
|
||||
Remark: the field `newHs` has size 1 forall but the last subgoal.
|
||||
|
||||
If `substNewEqs = true`, then the new `h_i` equality hypotheses are substituted in the first `n` cases.
|
||||
If `needsHyps = false` then the else case comes without hypotheses.
|
||||
-/
|
||||
public def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNamePrefix := `h) : MetaM (Array CaseValuesSubgoal) :=
|
||||
public def caseValues (mvarId : MVarId) (fvarId : FVarId) (values : Array Expr) (hNamePrefix := `h)
|
||||
(needHyps := true) : MetaM (Array CaseValuesSubgoal) :=
|
||||
let rec loop : Nat → MVarId → List Expr → Array FVarId → Array CaseValuesSubgoal → MetaM (Array CaseValuesSubgoal)
|
||||
| _, mvarId, [], _, _ => throwTacticEx `caseValues mvarId "list of values must not be empty"
|
||||
| i, mvarId, v::vs, hs, subgoals => do
|
||||
let (thenSubgoal, elseSubgoal) ← caseValue mvarId fvarId v (hNamePrefix.appendIndexAfter i)
|
||||
appendTagSuffix thenSubgoal.mvarId ((`case).appendIndexAfter i)
|
||||
let thenMVarId ← thenSubgoal.mvarId.tryClearMany hs
|
||||
let (subst, mvarId) ← substCore thenMVarId thenSubgoal.newH (symm := false) {} (clearH := true)
|
||||
let subgoals := subgoals.push { mvarId := mvarId, newHs := #[], subst := subst }
|
||||
let (thenMVarId, elseMVarId) ← caseValue mvarId fvarId v (hNamePrefix.appendIndexAfter i)
|
||||
appendTagSuffix thenMVarId ((`case).appendIndexAfter i)
|
||||
let thenMVarId ← thenMVarId.tryClearMany hs
|
||||
let (thenH, thenMVarId) ← thenMVarId.intro1P
|
||||
let (subst, thenMVarId) ← substCore thenMVarId thenH (symm := false) {} (clearH := true)
|
||||
let subgoals := subgoals.push { mvarId := thenMVarId, newHs := #[], subst := subst }
|
||||
let (hs', elseMVarId) ←
|
||||
if needHyps then
|
||||
let (elseH, elseMVarId) ← elseMVarId.intro1P
|
||||
pure (hs.push elseH, elseMVarId)
|
||||
else
|
||||
let elseMVarId ← elseMVarId.intro1_
|
||||
pure (hs, elseMVarId)
|
||||
match vs with
|
||||
| [] => do
|
||||
appendTagSuffix elseSubgoal.mvarId ((`case).appendIndexAfter (i+1))
|
||||
pure $ subgoals.push { mvarId := elseSubgoal.mvarId, newHs := hs.push elseSubgoal.newH, subst := {} }
|
||||
| vs => loop (i+1) elseSubgoal.mvarId vs (hs.push elseSubgoal.newH) subgoals
|
||||
appendTagSuffix elseMVarId ((`case).appendIndexAfter (i+1))
|
||||
pure $ subgoals.push { mvarId := elseMVarId, newHs := hs', subst := {} }
|
||||
| vs =>
|
||||
loop (i+1) elseMVarId vs hs' subgoals
|
||||
|
||||
loop 1 mvarId values.toList #[] #[]
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
|
|
@ -722,11 +722,23 @@ private def isFirstPatternVar (alt : Alt) : Bool :=
|
|||
| .var _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
private def Pattern.isRefutable : Pattern → Bool
|
||||
| .var _ => false
|
||||
| .inaccessible _ => false
|
||||
| .as _ p _ => p.isRefutable
|
||||
| .arrayLit .. => true
|
||||
| .ctor .. => true
|
||||
| .val .. => true
|
||||
|
||||
private def triviallyComplete (p : Problem) : Bool :=
|
||||
!p.alts.isEmpty && p.alts.getLast!.patterns.all (!·.isRefutable)
|
||||
|
||||
private def processValue (p : Problem) : MetaM (Array Problem) := do
|
||||
trace[Meta.Match.match] "value step"
|
||||
let x :: xs := p.vars | unreachable!
|
||||
let values := collectValues p
|
||||
let subgoals ← caseValues p.mvarId x.fvarId! values
|
||||
let needHyps := !triviallyComplete p || p.alts.any (!·.notAltIdxs.isEmpty)
|
||||
let subgoals ← caseValues p.mvarId x.fvarId! values (needHyps := needHyps)
|
||||
subgoals.mapIdxM fun i subgoal => do
|
||||
trace[Meta.Match.match] "processValue subgoal\n{MessageData.ofGoal subgoal.mvarId}"
|
||||
if h : i < values.size then
|
||||
|
|
@ -900,14 +912,6 @@ private def moveToFront (p : Problem) (i : Nat) : Problem :=
|
|||
else
|
||||
p
|
||||
|
||||
def Pattern.isRefutable : Pattern → Bool
|
||||
| .var _ => false
|
||||
| .inaccessible _ => false
|
||||
| .as _ p _ => p.isRefutable
|
||||
| .arrayLit .. => true
|
||||
| .ctor .. => true
|
||||
| .val .. => true
|
||||
|
||||
/--
|
||||
Returns the index of the first pattern in the first alternative that is refutable
|
||||
(i.e. not a variable or inaccessible pattern). We want to handle these first
|
||||
|
|
|
|||
|
|
@ -161,6 +161,23 @@ does not start with a forall, lambda or let. -/
|
|||
abbrev _root_.Lean.MVarId.intro1P (mvarId : MVarId) : MetaM (FVarId × MVarId) :=
|
||||
intro1Core mvarId true
|
||||
|
||||
/--
|
||||
Given a goal `... |- β → α`, returns a goal `... ⊢ α`.
|
||||
Like `intro h; clear h`, but without ever appending to the local context.
|
||||
-/
|
||||
def _root_.Lean.MVarId.intro1_ (mvarId : MVarId) : MetaM MVarId := do
|
||||
mvarId.withContext do
|
||||
let target ← mvarId.getType'
|
||||
match target with
|
||||
| .forallE n β α bi =>
|
||||
if α.hasLooseBVars then
|
||||
throwError "intro1_: expected arrow type\n{mvarId}"
|
||||
let tag ← mvarId.getTag
|
||||
let newMVar ← mkFreshExprSyntheticOpaqueMVar α tag
|
||||
mvarId.assign (.lam n β newMVar bi)
|
||||
return newMVar.mvarId!
|
||||
| _ => throwError "intro1_: expected arrow type\n{mvarId}"
|
||||
|
||||
/--
|
||||
Calculate the number of new hypotheses that would be created by `intros`,
|
||||
i.e. the number of binders which can be introduced without unfolding definitions.
|
||||
|
|
|
|||
40
tests/lean/run/match_nat.lean
Normal file
40
tests/lean/run/match_nat.lean
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
def f : Nat → Nat
|
||||
| 0 => 0
|
||||
| 10 => 1
|
||||
| 100 => 2
|
||||
| _ => 3
|
||||
|
||||
|
||||
/--
|
||||
info: def f.match_1.{u_1} : (motive : Nat → Sort u_1) →
|
||||
(x : Nat) → (Unit → motive 0) → (Unit → motive 10) → (Unit → motive 100) → ((x : Nat) → motive x) → motive x
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print sig f.match_1
|
||||
|
||||
|
||||
/--
|
||||
info: private def f.match_1.splitter.{u_1} : (motive : Nat → Sort u_1) →
|
||||
(x : Nat) →
|
||||
(Unit → motive 0) →
|
||||
(Unit → motive 10) →
|
||||
(Unit → motive 100) → ((x : Nat) → (x = 0 → False) → (x = 10 → False) → (x = 100 → False) → motive x) → motive x
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print sig f.match_1.splitter
|
||||
|
||||
/--
|
||||
info: private theorem f.match_1.eq_4.{u_1} : ∀ (motive : Nat → Sort u_1) (x : Nat) (h_1 : Unit → motive 0)
|
||||
(h_2 : Unit → motive 10) (h_3 : Unit → motive 100) (h_4 : (x : Nat) → motive x),
|
||||
(x = 0 → False) →
|
||||
(x = 10 → False) →
|
||||
(x = 100 → False) →
|
||||
(match x with
|
||||
| 0 => h_1 ()
|
||||
| 10 => h_2 ()
|
||||
| 100 => h_3 ()
|
||||
| x => h_4 x) =
|
||||
h_4 x
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print sig f.match_1.eq_4
|
||||
Loading…
Add table
Reference in a new issue