diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 1f93e5a1c5..c2149ca911 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -35,26 +35,6 @@ private def expandSimpleMatch (stx discr lhsVar rhs : Syntax) (expectedType? : O let newStx ← `(let $lhsVar := $discr; $rhs) withMacroExpansion stx newStx <| elabTerm newStx expectedType? -private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr × Bool) := do - let mut discrs := #[] - let mut i := 0 - let mut matchType := matchType - let mut isDep := false - for discrStx in discrStxs do - i := i + 1 - matchType ← whnf matchType - match matchType with - | Expr.forallE _ d b _ => - let discr ← fullApproxDefEq <| elabTermEnsuringType discrStx[1] d - trace[Elab.match] "discr #{i} {discr} : {d}" - if b.hasLooseBVars then - isDep := true - matchType ← b.instantiate1 discr - discrs := discrs.push discr - | _ => - throwError "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected" - pure (discrs, isDep) - private def mkUserNameFor (e : Expr) : TermElabM Name := do match e with /- Remark: we use `mkFreshUserName` to make sure we don't add a variable to the local context that can be resolved to `e`. -/ @@ -109,11 +89,40 @@ structure ElabMatchTypeAndDiscsResult where isDep : Bool alts : Array MatchAltView -private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr) - : TermElabM ElabMatchTypeAndDiscsResult := do - let numDiscrs := discrStxs.size - if matchOptType.isNone then - let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (isDep : Bool) (matchAltViews : Array MatchAltView) := do + private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr) + : TermElabM ElabMatchTypeAndDiscsResult := do + let numDiscrs := discrStxs.size + if matchOptType.isNone then + elabDiscrs discrStxs.size (discrs := #[]) (isDep := false) expectedType matchAltViews + else + let matchTypeStx := matchOptType[0][1] + let matchType ← elabType matchTypeStx + let (discrs, isDep) ← elabDiscrsWitMatchType matchType expectedType + return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews } + where + /- Easy case: elaborate discriminant when the match-type has been explicitly provided by the user. -/ + elabDiscrsWitMatchType (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr × Bool) := do + let mut discrs := #[] + let mut i := 0 + let mut matchType := matchType + let mut isDep := false + for discrStx in discrStxs do + i := i + 1 + matchType ← whnf matchType + match matchType with + | Expr.forallE _ d b _ => + let discr ← fullApproxDefEq <| elabTermEnsuringType discrStx[1] d + trace[Elab.match] "discr #{i} {discr} : {d}" + if b.hasLooseBVars then + isDep := true + matchType ← b.instantiate1 discr + discrs := discrs.push discr + | _ => + throwError "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected" + return (discrs, isDep) + + /- Elaborate discriminants inferring the match-type -/ + elabDiscrs (i : Nat) (discrs : Array Expr) (matchType : Expr) (isDep : Bool) (matchAltViews : Array MatchAltView) := do match i with | 0 => return { discrs := discrs.reverse, matchType := matchType, isDep := isDep, alts := matchAltViews } | i+1 => @@ -126,7 +135,7 @@ private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Sy let isDep := isDep || matchTypeBody.hasLooseBVars let userName ← mkUserNameFor discr if discrStx[0].isNone then - loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) isDep matchAltViews + elabDiscrs i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) isDep matchAltViews else let identStx := discrStx[0][0] withLocalDeclD userName discrType fun x => do @@ -138,13 +147,7 @@ private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Sy let discrs := (discrs.push refl).push discr let matchAltViews := matchAltViews.map fun altView => { altView with patterns := altView.patterns.insertAt (i+1) identStx } - loop i discrs matchType isDep matchAltViews - loop discrStxs.size (discrs := #[]) (isDep := false) expectedType matchAltViews - else - let matchTypeStx := matchOptType[0][1] - let matchType ← elabType matchTypeStx - let (discrs, isDep) ← elabDiscrsWitMatchType discrStxs matchType expectedType - return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews } + elabDiscrs i discrs matchType isDep matchAltViews def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array MatchAltView) := do matchAlts.mapM fun matchAlt => do