diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 135240f0d1..62068a62f1 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -72,11 +72,28 @@ match e with | Expr.fvar fvarId _ => do localDecl ← getLocalDecl fvarId; pure localDecl.userName | _ => mkFreshBinderName +-- `expandNonAtomicDiscrs?` create auxiliary variables with base name `_discr` +private def isAuxDiscrName (n : Name) : Bool := +n.eraseMacroScopes == `_discr + +-- See expandNonAtomicDiscrs? +private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do +let term := discr.getArg 1; +local? ← isLocalIdent? term; +match local? with +| some e@(Expr.fvar fvarId _) => do + localDecl ← getLocalDecl fvarId; + if !isAuxDiscrName localDecl.userName then + pure e -- it is not an auxiliary local created by `expandNonAtomicDiscrs?` + else + pure localDecl.value +| _ => throwErrorAt discr "unexpected discriminant" + private def elabMatchTypeAndDiscrsAux (discrStxs : Array Syntax) : Nat → Array Expr → Expr → Array MatchAltView → TermElabM (Array Expr × Expr × Array MatchAltView) | 0, discrs, matchType, matchAltViews => pure (discrs.reverse, matchType, matchAltViews) | i+1, discrs, matchType, matchAltViews => do let discrStx := discrStxs.get! i; - discr ← elabTerm (discrStx.getArg 1) none; + discr ← elabAtomicDiscr discrStx; discr ← instantiateMVars discr; discrType ← inferType discr; discrType ← instantiateMVars discrType; @@ -738,12 +755,100 @@ let r := mkAppN r rhss; trace `Elab.match fun _ => "result: " ++ r; pure r +private def getDiscrs (matchStx : Syntax) : Array Syntax := +(matchStx.getArg 1).getArgs.getSepElems + +private def getMatchOptType (matchStx : Syntax) : Syntax := +matchStx.getArg 2 + +private def expandNonAtomicDiscrsAux (matchStx : Syntax) : List Syntax → Array Syntax → TermElabM Syntax +| [], discrsNew => + let discrs := mkSepStx discrsNew (mkAtomFrom matchStx ", "); + pure $ matchStx.setArg 1 discrs +| discr :: discrs, discrsNew => do + -- Recall that + -- matchDiscr := parser! optional (ident >> ":") >> termParser + let term := discr.getArg 1; + local? ← isLocalIdent? term; + match local? with + | some _ => expandNonAtomicDiscrsAux discrs (discrsNew.push discr) + | none => withFreshMacroScope do + d ← `(_discr); + unless (isAuxDiscrName d.getId) $ -- Use assertion? + throwError "unexpected internal auxiliary discriminant name"; + let discrNew := discr.setArg 1 d; + r ← expandNonAtomicDiscrsAux discrs (discrsNew.push discrNew); + `(let _discr := $term; $r) + +private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) := +let matchOptType := getMatchOptType matchStx; +if matchOptType.isNone then do + let discrs := getDiscrs matchStx; + allLocal ← discrs.allM fun discr => Option.isSome <$> isLocalIdent? (discr.getArg 1); + if allLocal then + pure none + else + some <$> expandNonAtomicDiscrsAux matchStx discrs.toList #[] +else + -- We do not pull non atomic discriminants when match type is provided explicitly by the user + pure none + private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := do tryPostponeIfNoneOrMVar expectedType?; match expectedType? with | some expectedType => pure expectedType | none => mkFreshTypeMVar +private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := +-- We don't wait for the discriminants types when match type is provided by user +when (getMatchOptType matchStx).isNone do + let discrs := getDiscrs matchStx; + discrs.forM fun discr => do + let term := discr.getArg 1; + local? ← isLocalIdent? term; + match local? with + | none => throwErrorAt discr "unexpected discriminant" -- see `expandNonAtomicDiscrs? + | some d => do + dType ← inferType d; + tryPostponeIfMVar dType + +/- +We (try to) elaborate a `match` only when the expected type is available. +If the `matchType` has not been provided by the user, we also try to postpone elaboration if the type +of a discriminant is not available. That is, it is of the form `(?m ...)`. +We use `expandNonAtomicDiscrs?` to make sure all discriminants are local variables. +This is a standard trick we use in the elaborator, and it is also used to elaborate structure instances. +Suppose, we are trying to elaborate +``` +match g x with +| ... => ... +``` +`expandNonAtomicDiscrs?` converts it intro +``` +let _discr := g x +match _discr with +| ... => ... +``` +Thus, at `tryPostponeIfDiscrTypeIsMVar` we only need to check whether the type of `_discr` is not of the form `(?m ...)`. +Note that, the auxiliary variable `_discr` is expanded at `elabAtomicDiscr`. + +This elaboration technique is needed to elaborate terms such as: +```lean +xs.filter fun (a, b) => a > b +``` +which are syntax sugar for +```lean +List.filter (fun p => match p with | (a, b) => a > b) xs +``` +When we visit `match p with | (a, b) => a > b`, we don't know the type of `p` yet. +-/ +private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do +tryPostponeIfNoneOrMVar expectedType?; +tryPostponeIfDiscrTypeIsMVar matchStx; +match expectedType? with +| some expectedType => pure expectedType +| none => mkFreshTypeMVar + /- ``` parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts @@ -751,10 +856,10 @@ parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> ma Remark the `optIdent` must be `none` at `matchDiscr`. They are expanded by `expandMatchDiscr?`. -/ private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do -expectedType ← waitExpectedType expectedType?; -let discrStxs := (stx.getArg 1).getArgs.getSepElems.map fun d => d; -let altViews := getMatchAlts stx; -let matchOptType := stx.getArg 2; +expectedType ← waitExpectedTypeAndDiscrs stx expectedType?; +let discrStxs := (getDiscrs stx).map fun d => d; +let altViews := getMatchAlts stx; +let matchOptType := getMatchOptType stx; elabMatchAux discrStxs altViews matchOptType expectedType -- parser! "match " >> sepBy1 termParser ", " >> optType >> " with " >> matchAlts @@ -765,11 +870,15 @@ fun stx expectedType? => match_syntax stx with | `(match $discr:term : $type with $y:ident => $rhs:term) => expandSimpleMatchWithType stx discr y type rhs expectedType? | `(match $discr:term : $type with | $y:ident => $rhs:term) => expandSimpleMatchWithType stx discr y type rhs expectedType? | _ => do - let discrs := (stx.getArg 1).getArgs; - let matchOptType := stx.getArg 2; - when (!matchOptType.isNone && discrs.getSepElems.any fun d => !(d.getArg 0).isNone) $ - throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used"; - elabMatchCore stx expectedType? + stxNew? ← expandNonAtomicDiscrs? stx; + match stxNew? with + | some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? + | none => do + let discrs := getDiscrs stx; + let matchOptType := getMatchOptType stx; + when (!matchOptType.isNone && discrs.any fun d => !(d.getArg 0).isNone) $ + throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used"; + elabMatchCore stx expectedType? @[init] private def regTraceClasses : IO Unit := do registerTraceClass `Elab.match; diff --git a/tests/lean/run/matchDiscrType.lean b/tests/lean/run/matchDiscrType.lean new file mode 100644 index 0000000000..c29a281c64 --- /dev/null +++ b/tests/lean/run/matchDiscrType.lean @@ -0,0 +1,14 @@ +new_frontend + +def g (x : Nat) : List (Nat × List Nat) := +[(x, [x, x]), (x, [])] + +def h (x : Nat) : List Nat := +let xs := g x $.filter (fun ⟨_, xs⟩ => xs.isEmpty) +xs.map (·.1) + +theorem ex1 : g 10 = [(10, [10, 10]), (10, [])] := +rfl + +theorem ex2 : h 10 = [10] := +rfl