diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index f7d8b40d53..baf49bc745 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -36,10 +36,11 @@ private def expandSimpleMatchWithType (stx discr lhsVar type rhs : Syntax) (expe let newStx ← `(let $lhsVar : $type := $discr; $rhs) withMacroExpansion stx newStx <| elabTerm newStx expectedType? -private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) := do +private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr × Bool) := do let mut discrs := #[] let mut i := 0 let mut matchType := matchType + let mut isDep := false for discrStx in discrStxs do i := i + 1 matchType ← whnf matchType @@ -47,11 +48,13 @@ private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) | Expr.forallE _ d b _ => let discr ← fullApproxDefEq <| elabTermEnsuringType discrStx[1] d trace[Elab.match]! "discr #{i} {discr} : {d}" + if b.hasLooseBVars then + isDep := true matchType ← b.instantiate1 discr discrs := discrs.push discr | _ => throwError! "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected" - pure discrs + pure (discrs, isDep) private def mkUserNameFor (e : Expr) : TermElabM Name := match e with @@ -74,13 +77,19 @@ private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do pure localDecl.value | _ => throwErrorAt discr "unexpected discriminant" +structure ElabMatchTypeAndDiscsResult where + discrs : Array Expr + matchType : Expr + isDep : Bool + alts : Array MatchAltView + private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr) - : TermElabM (Array Expr × Expr × Array MatchAltView) := do + : TermElabM ElabMatchTypeAndDiscsResult := do let numDiscrs := discrStxs.size if matchOptType.isNone then - let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (matchAltViews : Array MatchAltView) := do + let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (isDep : Bool) (matchAltViews : Array MatchAltView) := do match i with - | 0 => pure (discrs.reverse, matchType, matchAltViews) + | 0 => return { discrs := discrs.reverse, matchType := matchType, isDep := isDep, alts := matchAltViews } | i+1 => let discrStx := discrStxs[i] let discr ← elabAtomicDiscr discrStx @@ -88,9 +97,10 @@ private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Sy let discrType ← inferType discr let discrType ← instantiateMVars discrType let matchTypeBody ← kabstract matchType discr + let isDep := isDep || matchTypeBody.hasLooseBVars let userName ← mkUserNameFor discr if discrStx[0].isNone then - loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) matchAltViews + loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) isDep matchAltViews else let identStx := discrStx[0][0] withLocalDeclD userName discrType fun x => do @@ -102,13 +112,13 @@ private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Sy let discrs := (discrs.push refl).push discr let matchAltViews := matchAltViews.map fun altView => { altView with patterns := altView.patterns.insertAt (i+1) identStx } - loop i discrs matchType matchAltViews - loop discrStxs.size #[] expectedType matchAltViews + loop i discrs matchType isDep matchAltViews + loop discrStxs.size (discrs := #[]) (isDep := false) expectedType matchAltViews else let matchTypeStx := matchOptType[0][1] let matchType ← elabType matchTypeStx - let discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType - pure (discrs, matchType, matchAltViews) + let (discrs, isDep) ← elabDiscrsWitMatchType discrStxs matchType expectedType + return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews } def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array MatchAltView) := do matchAlts.mapM fun matchAlt => do @@ -723,8 +733,8 @@ private def isMatchUnit? (altLHSS : List Match.AltLHS) (rhss : Array Expr) : Met private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr) : TermElabM Expr := do - let (discrs, matchType, altLHSS, rhss) ← commitIfDidNotPostpone do - let (discrs, matchType, altViews) ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType + let (discrs, matchType, altLHSS, isDep, rhss) ← commitIfDidNotPostpone do + let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews trace[Elab.match]! "matchType: {matchType}" let alts ← matchAlts.mapM fun alt => elabMatchAltView alt matchType @@ -780,8 +790,8 @@ private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltVi tryPostpone throwMVarError m!"invalid match-expression, pattern contains metavariables{indentExpr (← p.toExpr)}" pure altLHS - return (discrs, matchType, altLHSS, rhss) - if let some r ← isMatchUnit? altLHSS rhss then + return (discrs, matchType, altLHSS, isDep, rhss) + if let some r ← if isDep then pure none else isMatchUnit? altLHSS rhss then return r else let numDiscrs := discrs.size diff --git a/tests/lean/run/match_unit.lean b/tests/lean/run/match_unit.lean new file mode 100644 index 0000000000..1f64d52f86 --- /dev/null +++ b/tests/lean/run/match_unit.lean @@ -0,0 +1,2 @@ +theorem ex : ∀ x : Unit, x = () := by + intro (); rfl