From 5742b078af448b6deddf729669359d06184eb8ca Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 23 Mar 2021 20:40:07 -0700 Subject: [PATCH] feat: "discriminant refinement" for `match`-expressions --- src/Lean/Elab/Match.lean | 119 +++++++++++++++++++++++++--- tests/lean/run/discrRefinement.lean | 16 ++++ tests/lean/run/matchWithSearch.lean | 7 ++ 3 files changed, 130 insertions(+), 12 deletions(-) create mode 100644 tests/lean/run/discrRefinement.lean diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 9a25a70e43..53b0a41ffa 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -81,6 +81,8 @@ private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do structure ElabMatchTypeAndDiscsResult where discrs : Array Expr matchType : Expr + /- `true` when performing dependent elimination. We use this to decide whether we optimize the "match unit" case. + See `isMatchUnit?`. -/ isDep : Bool alts : Array MatchAltView @@ -529,17 +531,53 @@ private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array P k decls loop 0 #[] -private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Array Expr × Expr) := +/- +Remark: we performing dependent pattern matching, we often had to write code such as + +```lean +def Vec.map' (f : α → β) (xs : Vec α n) : Vec β n := + match n, xs with + | _, nil => nil + | _, cons a as => cons (f a) (map' f as) +``` +We had to include `n` and the `_`s because the type of `xs` depends on `n`. +Moreover, `nil` and `cons a as` have different types. +This was quite tedious, and we have implemented an automatic "discriminant" +refinement procedure. The procedure is based on the observation that we get +a type error whenenver we forget to include `_`s and the indices a discriminant +depends on. So, we catch the exception, check whether the type of the discriminant +is an indexed family, and add them as new indices. + +The current implementation, adds indices as they are found, and does not +try to "sort" the new discriminants. + +Moreover, if the refinement process fails, we report the original error message. +-/ + +/- Auxiliary structure for storing an type mismatch exception when processing the + pattern #`idx` of some alternative. -/ +structure PatternElabException where + ex : Exception + idx : Nat + +private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : ExceptT PatternElabException TermElabM (Array Expr × Expr) := withReader (fun ctx => { ctx with implicitLambda := false }) do let mut patterns := #[] let mut matchType := matchType - for patternStx in patternStxs do + for idx in [:patternStxs.size] do + let patternStx := patternStxs[idx] matchType ← whnf matchType match matchType with | Expr.forallE _ d b _ => - let pattern ← elabTermEnsuringType patternStx d - matchType := b.instantiate1 pattern - patterns := patterns.push pattern + let pattern ← elabTerm patternStx d + let pattern ← + try + withRef patternStx <| ensureHasType d pattern + catch ex => + -- Wrap the type mismatch exception for the "discriminant refinement" feature. + throwThe PatternElabException { ex := ex, idx := idx } + matchType := b.instantiate1 pattern + patterns := patterns.push pattern | _ => throwError "unexpected match type" return (patterns, matchType) @@ -676,14 +714,15 @@ def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k k localDecls patterns private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) - (k : AltLHS → Expr → TermElabM α) : TermElabM α := do + (k : AltLHS → Expr → TermElabM α) : ExceptT PatternElabException TermElabM α := do let (patterns, matchType) ← withSynthesize <| elabPatterns patternStxs matchType - let localDecls ← finalizePatternDecls patternVarDecls - let patterns ← patterns.mapM (instantiateMVars ·) - withDepElimPatterns localDecls patterns fun localDecls patterns => - k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType + id (α := TermElabM α) do + let localDecls ← finalizePatternDecls patternVarDecls + let patterns ← patterns.mapM (instantiateMVars ·) + withDepElimPatterns localDecls patterns fun localDecls patterns => + k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType -def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) := withRef alt.ref do +private def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : ExceptT PatternElabException TermElabM (AltLHS × Expr) := withRef alt.ref do let (patternVars, alt) ← collectPatternVars alt trace[Elab.match] "patternVars: {patternVars}" withPatternVars patternVars fun patternVarDecls => do @@ -694,6 +733,61 @@ def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS trace[Elab.match] "rhs: {rhs}" return (altLHS, rhs) +/-- + Collect indices for the "discriminant refinement feature". This method is invoked + when we detect a type mismatch at a pattern #`idx` of some alternative. -/ +private def getIndicesToInclude (discrs : Array Expr) (idx : Nat) : TermElabM (Array Expr) := do + let discrType ← whnfD (← inferType discrs[idx]) + matchConstInduct discrType.getAppFn (fun _ => return #[]) fun info _ => do + let mut result := #[] + let args := discrType.getAppArgs + for arg in args[info.numParams : args.size] do + unless (← discrs.anyM fun discr => isDefEq discr arg) do + result := result.push arg + return result + +private partial def elabMatchAltViews (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) : TermElabM (Array Expr × Expr × Array (AltLHS × Expr) × Bool) := do + loop discrs matchType altViews none +where + /- + "Discriminant refinement" main loop. + `first?` contains the first error message we found before updated the `discrs`. -/ + loop (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) (first? : Option (SavedState × Exception)) + : TermElabM (Array Expr × Expr × Array (AltLHS × Expr) × Bool) := do + let s ← saveAllState + match ← altViews.mapM (fun alt => elabMatchAltView alt matchType) |>.run with + | Except.ok alts => return (discrs, matchType, alts, first?.isSome) + | Except.error { idx := idx, ex := ex } => + let indices ← getIndicesToInclude discrs idx + if indices.isEmpty then + match first? with + | none => throw ex + | some (s, ex) => s.restore; throw ex + else + let first? ← updateFirst first? ex + s.restore + let matchType ← updateMatchType indices matchType + let altViews ← addWildcardPatterns indices.size altViews + let discrs := indices ++ discrs + loop discrs matchType altViews first? + + updateFirst (first? : Option (SavedState × Exception)) (ex : Exception) : TermElabM (Option (SavedState × Exception)) := do + match first? with + | none => return some (← saveAllState, ex) + | some _ => return first? + + updateMatchType (indices : Array Expr) (matchType : Expr) : TermElabM Expr := + indices.foldrM (init := matchType) fun index matchType => do + let indexType ← inferType index + let matchTypeBody ← kabstract matchType index + let userName ← mkUserNameFor index + return Lean.mkForall userName BinderInfo.default indexType matchTypeBody + + addWildcardPatterns (num : Nat) (altViews : Array MatchAltView) : TermElabM (Array MatchAltView) := do + let hole := mkHole (← getRef) + let wildcards := mkArray num hole + return altViews.map fun altView => { altView with patterns := wildcards ++ altView.patterns } + def mkMatcher (elimName : Name) (matchType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult := Meta.Match.mkMatcher elimName matchType numDiscrs lhss @@ -733,7 +827,8 @@ private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltVi let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews trace[Elab.match] "matchType: {matchType}" - let alts ← matchAlts.mapM fun alt => elabMatchAltView alt matchType + let (discrs, matchType, alts, refined) ← elabMatchAltViews discrs matchType matchAlts + let isDep := isDep || refined /- We should not use `synthesizeSyntheticMVarsNoPostponing` here. Otherwise, we will not be able to elaborate examples such as: diff --git a/tests/lean/run/discrRefinement.lean b/tests/lean/run/discrRefinement.lean new file mode 100644 index 0000000000..0b0af3b20c --- /dev/null +++ b/tests/lean/run/discrRefinement.lean @@ -0,0 +1,16 @@ +inductive Vec (α : Type u) : Nat → Type u + | nil : Vec α 0 + | cons : α → Vec α n → Vec α (n+1) + +def Vec.map (xs : Vec α n) (f : α → β) : Vec β n := + match xs with + | nil => nil + | cons a as => cons (f a) (map as f) + +def Vec.map' (f : α → β) : Vec α n → Vec β n + | nil => nil + | cons a as => cons (f a) (map' f as) + +def Vec.map2 (f : α → α → β) : Vec α n → Vec α n → Vec β n + | nil, nil => nil + | cons a as, cons b bs => cons (f a b) (map2 f as bs) diff --git a/tests/lean/run/matchWithSearch.lean b/tests/lean/run/matchWithSearch.lean index 2d341b1f60..c13b35f7b3 100644 --- a/tests/lean/run/matchWithSearch.lean +++ b/tests/lean/run/matchWithSearch.lean @@ -39,3 +39,10 @@ def balanceRR' {h c} (left : rbnode h c) (y : Int) (right : hiddenTree h) : almo | _, _, HB c, R a x b => LR (R a x b) y c | _, _, HB c, B a x b => V (R (B a x b) y c) | _, _, HB c, Leaf => V (R Leaf y c) + +def balanceRR'' {h c} (left : rbnode h c) (y : Int) (right : hiddenTree h) : almostNode h := + match left, right with + | left, HR c => RR left y c + | R a x b, HB c => LR (R a x b) y c + | B a x b, HB c => V (R (B a x b) y c) + | Leaf, HB c => V (R Leaf y c)