feat: "discriminant refinement" for match-expressions
This commit is contained in:
parent
5ac7b1232a
commit
5742b078af
3 changed files with 130 additions and 12 deletions
|
|
@ -81,6 +81,8 @@ private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do
|
|||
structure ElabMatchTypeAndDiscsResult where
|
||||
discrs : Array Expr
|
||||
matchType : Expr
|
||||
/- `true` when performing dependent elimination. We use this to decide whether we optimize the "match unit" case.
|
||||
See `isMatchUnit?`. -/
|
||||
isDep : Bool
|
||||
alts : Array MatchAltView
|
||||
|
||||
|
|
@ -529,17 +531,53 @@ private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array P
|
|||
k decls
|
||||
loop 0 #[]
|
||||
|
||||
private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Array Expr × Expr) :=
|
||||
/-
|
||||
Remark: we performing dependent pattern matching, we often had to write code such as
|
||||
|
||||
```lean
|
||||
def Vec.map' (f : α → β) (xs : Vec α n) : Vec β n :=
|
||||
match n, xs with
|
||||
| _, nil => nil
|
||||
| _, cons a as => cons (f a) (map' f as)
|
||||
```
|
||||
We had to include `n` and the `_`s because the type of `xs` depends on `n`.
|
||||
Moreover, `nil` and `cons a as` have different types.
|
||||
This was quite tedious, and we have implemented an automatic "discriminant"
|
||||
refinement procedure. The procedure is based on the observation that we get
|
||||
a type error whenenver we forget to include `_`s and the indices a discriminant
|
||||
depends on. So, we catch the exception, check whether the type of the discriminant
|
||||
is an indexed family, and add them as new indices.
|
||||
|
||||
The current implementation, adds indices as they are found, and does not
|
||||
try to "sort" the new discriminants.
|
||||
|
||||
Moreover, if the refinement process fails, we report the original error message.
|
||||
-/
|
||||
|
||||
/- Auxiliary structure for storing an type mismatch exception when processing the
|
||||
pattern #`idx` of some alternative. -/
|
||||
structure PatternElabException where
|
||||
ex : Exception
|
||||
idx : Nat
|
||||
|
||||
private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : ExceptT PatternElabException TermElabM (Array Expr × Expr) :=
|
||||
withReader (fun ctx => { ctx with implicitLambda := false }) do
|
||||
let mut patterns := #[]
|
||||
let mut matchType := matchType
|
||||
for patternStx in patternStxs do
|
||||
for idx in [:patternStxs.size] do
|
||||
let patternStx := patternStxs[idx]
|
||||
matchType ← whnf matchType
|
||||
match matchType with
|
||||
| Expr.forallE _ d b _ =>
|
||||
let pattern ← elabTermEnsuringType patternStx d
|
||||
matchType := b.instantiate1 pattern
|
||||
patterns := patterns.push pattern
|
||||
let pattern ← elabTerm patternStx d
|
||||
let pattern ←
|
||||
try
|
||||
withRef patternStx <| ensureHasType d pattern
|
||||
catch ex =>
|
||||
-- Wrap the type mismatch exception for the "discriminant refinement" feature.
|
||||
throwThe PatternElabException { ex := ex, idx := idx }
|
||||
matchType := b.instantiate1 pattern
|
||||
patterns := patterns.push pattern
|
||||
| _ => throwError "unexpected match type"
|
||||
return (patterns, matchType)
|
||||
|
||||
|
|
@ -676,14 +714,15 @@ def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k
|
|||
k localDecls patterns
|
||||
|
||||
private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr)
|
||||
(k : AltLHS → Expr → TermElabM α) : TermElabM α := do
|
||||
(k : AltLHS → Expr → TermElabM α) : ExceptT PatternElabException TermElabM α := do
|
||||
let (patterns, matchType) ← withSynthesize <| elabPatterns patternStxs matchType
|
||||
let localDecls ← finalizePatternDecls patternVarDecls
|
||||
let patterns ← patterns.mapM (instantiateMVars ·)
|
||||
withDepElimPatterns localDecls patterns fun localDecls patterns =>
|
||||
k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
|
||||
id (α := TermElabM α) do
|
||||
let localDecls ← finalizePatternDecls patternVarDecls
|
||||
let patterns ← patterns.mapM (instantiateMVars ·)
|
||||
withDepElimPatterns localDecls patterns fun localDecls patterns =>
|
||||
k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
|
||||
|
||||
def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) := withRef alt.ref do
|
||||
private def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : ExceptT PatternElabException TermElabM (AltLHS × Expr) := withRef alt.ref do
|
||||
let (patternVars, alt) ← collectPatternVars alt
|
||||
trace[Elab.match] "patternVars: {patternVars}"
|
||||
withPatternVars patternVars fun patternVarDecls => do
|
||||
|
|
@ -694,6 +733,61 @@ def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS
|
|||
trace[Elab.match] "rhs: {rhs}"
|
||||
return (altLHS, rhs)
|
||||
|
||||
/--
|
||||
Collect indices for the "discriminant refinement feature". This method is invoked
|
||||
when we detect a type mismatch at a pattern #`idx` of some alternative. -/
|
||||
private def getIndicesToInclude (discrs : Array Expr) (idx : Nat) : TermElabM (Array Expr) := do
|
||||
let discrType ← whnfD (← inferType discrs[idx])
|
||||
matchConstInduct discrType.getAppFn (fun _ => return #[]) fun info _ => do
|
||||
let mut result := #[]
|
||||
let args := discrType.getAppArgs
|
||||
for arg in args[info.numParams : args.size] do
|
||||
unless (← discrs.anyM fun discr => isDefEq discr arg) do
|
||||
result := result.push arg
|
||||
return result
|
||||
|
||||
private partial def elabMatchAltViews (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) : TermElabM (Array Expr × Expr × Array (AltLHS × Expr) × Bool) := do
|
||||
loop discrs matchType altViews none
|
||||
where
|
||||
/-
|
||||
"Discriminant refinement" main loop.
|
||||
`first?` contains the first error message we found before updated the `discrs`. -/
|
||||
loop (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) (first? : Option (SavedState × Exception))
|
||||
: TermElabM (Array Expr × Expr × Array (AltLHS × Expr) × Bool) := do
|
||||
let s ← saveAllState
|
||||
match ← altViews.mapM (fun alt => elabMatchAltView alt matchType) |>.run with
|
||||
| Except.ok alts => return (discrs, matchType, alts, first?.isSome)
|
||||
| Except.error { idx := idx, ex := ex } =>
|
||||
let indices ← getIndicesToInclude discrs idx
|
||||
if indices.isEmpty then
|
||||
match first? with
|
||||
| none => throw ex
|
||||
| some (s, ex) => s.restore; throw ex
|
||||
else
|
||||
let first? ← updateFirst first? ex
|
||||
s.restore
|
||||
let matchType ← updateMatchType indices matchType
|
||||
let altViews ← addWildcardPatterns indices.size altViews
|
||||
let discrs := indices ++ discrs
|
||||
loop discrs matchType altViews first?
|
||||
|
||||
updateFirst (first? : Option (SavedState × Exception)) (ex : Exception) : TermElabM (Option (SavedState × Exception)) := do
|
||||
match first? with
|
||||
| none => return some (← saveAllState, ex)
|
||||
| some _ => return first?
|
||||
|
||||
updateMatchType (indices : Array Expr) (matchType : Expr) : TermElabM Expr :=
|
||||
indices.foldrM (init := matchType) fun index matchType => do
|
||||
let indexType ← inferType index
|
||||
let matchTypeBody ← kabstract matchType index
|
||||
let userName ← mkUserNameFor index
|
||||
return Lean.mkForall userName BinderInfo.default indexType matchTypeBody
|
||||
|
||||
addWildcardPatterns (num : Nat) (altViews : Array MatchAltView) : TermElabM (Array MatchAltView) := do
|
||||
let hole := mkHole (← getRef)
|
||||
let wildcards := mkArray num hole
|
||||
return altViews.map fun altView => { altView with patterns := wildcards ++ altView.patterns }
|
||||
|
||||
def mkMatcher (elimName : Name) (matchType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult :=
|
||||
Meta.Match.mkMatcher elimName matchType numDiscrs lhss
|
||||
|
||||
|
|
@ -733,7 +827,8 @@ private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltVi
|
|||
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
|
||||
let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
|
||||
trace[Elab.match] "matchType: {matchType}"
|
||||
let alts ← matchAlts.mapM fun alt => elabMatchAltView alt matchType
|
||||
let (discrs, matchType, alts, refined) ← elabMatchAltViews discrs matchType matchAlts
|
||||
let isDep := isDep || refined
|
||||
/-
|
||||
We should not use `synthesizeSyntheticMVarsNoPostponing` here. Otherwise, we will not be
|
||||
able to elaborate examples such as:
|
||||
|
|
|
|||
16
tests/lean/run/discrRefinement.lean
Normal file
16
tests/lean/run/discrRefinement.lean
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
inductive Vec (α : Type u) : Nat → Type u
|
||||
| nil : Vec α 0
|
||||
| cons : α → Vec α n → Vec α (n+1)
|
||||
|
||||
def Vec.map (xs : Vec α n) (f : α → β) : Vec β n :=
|
||||
match xs with
|
||||
| nil => nil
|
||||
| cons a as => cons (f a) (map as f)
|
||||
|
||||
def Vec.map' (f : α → β) : Vec α n → Vec β n
|
||||
| nil => nil
|
||||
| cons a as => cons (f a) (map' f as)
|
||||
|
||||
def Vec.map2 (f : α → α → β) : Vec α n → Vec α n → Vec β n
|
||||
| nil, nil => nil
|
||||
| cons a as, cons b bs => cons (f a b) (map2 f as bs)
|
||||
|
|
@ -39,3 +39,10 @@ def balanceRR' {h c} (left : rbnode h c) (y : Int) (right : hiddenTree h) : almo
|
|||
| _, _, HB c, R a x b => LR (R a x b) y c
|
||||
| _, _, HB c, B a x b => V (R (B a x b) y c)
|
||||
| _, _, HB c, Leaf => V (R Leaf y c)
|
||||
|
||||
def balanceRR'' {h c} (left : rbnode h c) (y : Int) (right : hiddenTree h) : almostNode h :=
|
||||
match left, right with
|
||||
| left, HR c => RR left y c
|
||||
| R a x b, HB c => LR (R a x b) y c
|
||||
| B a x b, HB c => V (R (B a x b) y c)
|
||||
| Leaf, HB c => V (R Leaf y c)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue