feat: "discriminant refinement" for match-expressions

This commit is contained in:
Leonardo de Moura 2021-03-23 20:40:07 -07:00
parent 5ac7b1232a
commit 5742b078af
3 changed files with 130 additions and 12 deletions

View file

@ -81,6 +81,8 @@ private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do
structure ElabMatchTypeAndDiscsResult where
discrs : Array Expr
matchType : Expr
/- `true` when performing dependent elimination. We use this to decide whether we optimize the "match unit" case.
See `isMatchUnit?`. -/
isDep : Bool
alts : Array MatchAltView
@ -529,17 +531,53 @@ private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array P
k decls
loop 0 #[]
private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Array Expr × Expr) :=
/-
Remark: we performing dependent pattern matching, we often had to write code such as
```lean
def Vec.map' (f : α → β) (xs : Vec α n) : Vec β n :=
match n, xs with
| _, nil => nil
| _, cons a as => cons (f a) (map' f as)
```
We had to include `n` and the `_`s because the type of `xs` depends on `n`.
Moreover, `nil` and `cons a as` have different types.
This was quite tedious, and we have implemented an automatic "discriminant"
refinement procedure. The procedure is based on the observation that we get
a type error whenenver we forget to include `_`s and the indices a discriminant
depends on. So, we catch the exception, check whether the type of the discriminant
is an indexed family, and add them as new indices.
The current implementation, adds indices as they are found, and does not
try to "sort" the new discriminants.
Moreover, if the refinement process fails, we report the original error message.
-/
/- Auxiliary structure for storing an type mismatch exception when processing the
pattern #`idx` of some alternative. -/
structure PatternElabException where
ex : Exception
idx : Nat
private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : ExceptT PatternElabException TermElabM (Array Expr × Expr) :=
withReader (fun ctx => { ctx with implicitLambda := false }) do
let mut patterns := #[]
let mut matchType := matchType
for patternStx in patternStxs do
for idx in [:patternStxs.size] do
let patternStx := patternStxs[idx]
matchType ← whnf matchType
match matchType with
| Expr.forallE _ d b _ =>
let pattern ← elabTermEnsuringType patternStx d
matchType := b.instantiate1 pattern
patterns := patterns.push pattern
let pattern ← elabTerm patternStx d
let pattern ←
try
withRef patternStx <| ensureHasType d pattern
catch ex =>
-- Wrap the type mismatch exception for the "discriminant refinement" feature.
throwThe PatternElabException { ex := ex, idx := idx }
matchType := b.instantiate1 pattern
patterns := patterns.push pattern
| _ => throwError "unexpected match type"
return (patterns, matchType)
@ -676,14 +714,15 @@ def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k
k localDecls patterns
private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr)
(k : AltLHS → Expr → TermElabM α) : TermElabM α := do
(k : AltLHS → Expr → TermElabM α) : ExceptT PatternElabException TermElabM α := do
let (patterns, matchType) ← withSynthesize <| elabPatterns patternStxs matchType
let localDecls ← finalizePatternDecls patternVarDecls
let patterns ← patterns.mapM (instantiateMVars ·)
withDepElimPatterns localDecls patterns fun localDecls patterns =>
k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
id (α := TermElabM α) do
let localDecls ← finalizePatternDecls patternVarDecls
let patterns ← patterns.mapM (instantiateMVars ·)
withDepElimPatterns localDecls patterns fun localDecls patterns =>
k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) := withRef alt.ref do
private def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : ExceptT PatternElabException TermElabM (AltLHS × Expr) := withRef alt.ref do
let (patternVars, alt) ← collectPatternVars alt
trace[Elab.match] "patternVars: {patternVars}"
withPatternVars patternVars fun patternVarDecls => do
@ -694,6 +733,61 @@ def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS
trace[Elab.match] "rhs: {rhs}"
return (altLHS, rhs)
/--
Collect indices for the "discriminant refinement feature". This method is invoked
when we detect a type mismatch at a pattern #`idx` of some alternative. -/
private def getIndicesToInclude (discrs : Array Expr) (idx : Nat) : TermElabM (Array Expr) := do
let discrType ← whnfD (← inferType discrs[idx])
matchConstInduct discrType.getAppFn (fun _ => return #[]) fun info _ => do
let mut result := #[]
let args := discrType.getAppArgs
for arg in args[info.numParams : args.size] do
unless (← discrs.anyM fun discr => isDefEq discr arg) do
result := result.push arg
return result
private partial def elabMatchAltViews (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) : TermElabM (Array Expr × Expr × Array (AltLHS × Expr) × Bool) := do
loop discrs matchType altViews none
where
/-
"Discriminant refinement" main loop.
`first?` contains the first error message we found before updated the `discrs`. -/
loop (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) (first? : Option (SavedState × Exception))
: TermElabM (Array Expr × Expr × Array (AltLHS × Expr) × Bool) := do
let s ← saveAllState
match ← altViews.mapM (fun alt => elabMatchAltView alt matchType) |>.run with
| Except.ok alts => return (discrs, matchType, alts, first?.isSome)
| Except.error { idx := idx, ex := ex } =>
let indices ← getIndicesToInclude discrs idx
if indices.isEmpty then
match first? with
| none => throw ex
| some (s, ex) => s.restore; throw ex
else
let first? ← updateFirst first? ex
s.restore
let matchType ← updateMatchType indices matchType
let altViews ← addWildcardPatterns indices.size altViews
let discrs := indices ++ discrs
loop discrs matchType altViews first?
updateFirst (first? : Option (SavedState × Exception)) (ex : Exception) : TermElabM (Option (SavedState × Exception)) := do
match first? with
| none => return some (← saveAllState, ex)
| some _ => return first?
updateMatchType (indices : Array Expr) (matchType : Expr) : TermElabM Expr :=
indices.foldrM (init := matchType) fun index matchType => do
let indexType ← inferType index
let matchTypeBody ← kabstract matchType index
let userName ← mkUserNameFor index
return Lean.mkForall userName BinderInfo.default indexType matchTypeBody
addWildcardPatterns (num : Nat) (altViews : Array MatchAltView) : TermElabM (Array MatchAltView) := do
let hole := mkHole (← getRef)
let wildcards := mkArray num hole
return altViews.map fun altView => { altView with patterns := wildcards ++ altView.patterns }
def mkMatcher (elimName : Name) (matchType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult :=
Meta.Match.mkMatcher elimName matchType numDiscrs lhss
@ -733,7 +827,8 @@ private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltVi
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
trace[Elab.match] "matchType: {matchType}"
let alts ← matchAlts.mapM fun alt => elabMatchAltView alt matchType
let (discrs, matchType, alts, refined) ← elabMatchAltViews discrs matchType matchAlts
let isDep := isDep || refined
/-
We should not use `synthesizeSyntheticMVarsNoPostponing` here. Otherwise, we will not be
able to elaborate examples such as:

View file

@ -0,0 +1,16 @@
inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons : α → Vec α n → Vec α (n+1)
def Vec.map (xs : Vec α n) (f : α → β) : Vec β n :=
match xs with
| nil => nil
| cons a as => cons (f a) (map as f)
def Vec.map' (f : α → β) : Vec α n → Vec β n
| nil => nil
| cons a as => cons (f a) (map' f as)
def Vec.map2 (f : αα → β) : Vec α n → Vec α n → Vec β n
| nil, nil => nil
| cons a as, cons b bs => cons (f a b) (map2 f as bs)

View file

@ -39,3 +39,10 @@ def balanceRR' {h c} (left : rbnode h c) (y : Int) (right : hiddenTree h) : almo
| _, _, HB c, R a x b => LR (R a x b) y c
| _, _, HB c, B a x b => V (R (B a x b) y c)
| _, _, HB c, Leaf => V (R Leaf y c)
def balanceRR'' {h c} (left : rbnode h c) (y : Int) (right : hiddenTree h) : almostNode h :=
match left, right with
| left, HR c => RR left y c
| R a x b, HB c => LR (R a x b) y c
| B a x b, HB c => V (R (B a x b) y c)
| Leaf, HB c => V (R Leaf y c)