From 93b5b74b362fa3b87c501ebfdd872831697eef32 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 14 Feb 2022 15:35:38 -0800 Subject: [PATCH] feat: modify notation for providing motive in "match" expressions --- src/Lean/Elab/Do.lean | 18 ++-- src/Lean/Elab/Match.lean | 84 +++++++++++-------- src/Lean/Parser/Do.lean | 2 +- src/Lean/Parser/Tactic.lean | 2 +- src/Lean/Parser/Term.lean | 4 +- .../PrettyPrinter/Delaborator/Builtins.lean | 2 +- tests/lean/run/match1.lean | 2 +- tests/lean/run/newfrontend2.lean | 4 +- 8 files changed, 65 insertions(+), 53 deletions(-) diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index f6357ef8b7..8476743f48 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -190,7 +190,7 @@ inductive Code where | «return» (ref : Syntax) (val : Syntax) /- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/ | ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) - | «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt Code)) + | «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt Code)) | jmp (ref : Syntax) (jpName : Name) (args : Array Syntax) deriving Inhabited @@ -545,13 +545,13 @@ def mkUnless (cond : Syntax) (c : CodeBlock) : MacroM CodeBlock := do let thenBranch ← mkPureUnitAction pure { c with code := Code.ite (← getRef) none mkNullNode cond thenBranch.code c.code } -def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do +def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do -- nary version of homogenize let ws := alts.foldl (union · ·.rhs.uvars) {} let alts ← alts.mapM fun alt => do let rhs ← extendUpdatedVars alt.rhs ws - pure { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code } - pure { code := Code.«match» ref genParam discrs optType alts, uvars := ws } + return { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code } + return { code := Code.«match» ref genParam discrs optMotive alts, uvars := ws } /- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`. This method assumes `terminal` is a terminal -/ @@ -1036,14 +1036,14 @@ partial def toTerm : Code → M Syntax | Code.reassign _ stx k => do reassignToTerm stx (← toTerm k) | Code.seq stx k => do seqToTerm stx (← toTerm k) | Code.ite ref _ o c t e => withRef ref <| do mkIte o c (← toTerm t) (← toTerm e) - | Code.«match» ref genParam discrs optType alts => do + | Code.«match» ref genParam discrs optMotive alts => do let mut termAlts := #[] for alt in alts do let rhs ← toTerm alt.rhs let termAlt := mkNode `Lean.Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", alt.patterns, mkAtomFrom alt.ref "=>", rhs] termAlts := termAlts.push termAlt let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts] - pure $ mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, discrs, optType, mkAtomFrom ref "with", termMatchAlts] + return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts] def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do let term ← toTerm code { m := m, kind := kind, uvars := uvars } @@ -1445,8 +1445,8 @@ mutual partial def doMatchToCode (doMatch : Syntax) (doElems: List Syntax) : M CodeBlock := do let ref := doMatch let genParam := doMatch[1] - let discrs := doMatch[2] - let optType := doMatch[3] + let optMotive := doMatch[2] + let discrs := doMatch[3] let matchAlts := doMatch[5][0].getArgs -- Array of `doMatchAlt` let alts ← matchAlts.mapM fun matchAlt => do let patterns := matchAlt[1] @@ -1455,7 +1455,7 @@ mutual let rhs := matchAlt[3] let rhs ← doSeqToCode (getDoSeqElems rhs) pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock } - let matchCode ← mkMatch ref genParam discrs optType alts + let matchCode ← mkMatch ref genParam discrs optMotive alts concatWith matchCode doElems /-- diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 636edddb7d..6c9d2f0eed 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -77,13 +77,14 @@ structure ElabMatchTypeAndDiscrsResult where isDep : Bool alts : Array MatchAltView -private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr) +private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptMotive : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr) : TermElabM ElabMatchTypeAndDiscrsResult := do let numDiscrs := discrStxs.size - if matchOptType.isNone then + if matchOptMotive.isNone then elabDiscrs 0 #[] else - let matchTypeStx := matchOptType[0][1] + -- motive := leading_parser atomic ("(" >> nonReservedSymbol "motive" >> " := ") >> termParser >> ")" + let matchTypeStx := matchOptMotive[0][3] let matchType ← elabType matchTypeStx let (discrs, isDep) ← elabDiscrsWitMatchType matchType expectedType return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews } @@ -106,7 +107,7 @@ private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptT matchType := b.instantiate1 discr discrs := discrs.push discr | _ => - throwError "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected" + throwError "invalid motive provided to match-expression, function type with arity #{discrStxs.size} expected" return (discrs, isDep) markIsDep (r : ElabMatchTypeAndDiscrsResult) := @@ -157,13 +158,13 @@ def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array Matc pure { matchAlt with patterns := patterns } private def getMatchGeneralizing? : Syntax → Option Bool - | `(match (generalizing := true) $discrs,* $[: $ty?]? with $alts:matchAlt*) => some true - | `(match (generalizing := false) $discrs,* $[: $ty?]? with $alts:matchAlt*) => some false + | `(match (generalizing := true) $[$motive]? $discrs,* with $alts:matchAlt*) => some true + | `(match (generalizing := false) $[$motive]? $discrs,* with $alts:matchAlt*) => some false | _ => none /- Given `stx` a match-expression, return its alternatives. -/ private def getMatchAlts : Syntax → Array MatchAltView - | `(match $[$gen]? $discrs,* $[: $ty?]? with $alts:matchAlt*) => + | `(match $[$gen]? $[$motive]? $discrs,* with $alts:matchAlt*) => alts.filterMap fun alt => match alt with | `(matchAltExpr| | $patterns,* => $rhs) => some { ref := alt, @@ -409,7 +410,7 @@ def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (A let decl ← instantiateLocalDeclMVars decl decls := decls.push decl | PatternVarDecl.anonymousVar mvarId fvarId => - let e ← instantiateMVars (mkMVar mvarId); + let e ← instantiateMVars (mkMVar mvarId) trace[Elab.match] "finalizePatternDecls: mvarId: {mvarId.name} := {e}, fvar: {mkFVar fvarId}" match e with | Expr.mvar newMVarId _ => @@ -458,7 +459,7 @@ private def mkLocalDeclFor (mvar : Expr) : M Pattern := do If this generates problems in the future, we should update the metavariable declarations. -/ assignExprMVar mvarId (mkFVar fvarId) let userName ← mkFreshBinderName - let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default; + let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default modify fun s => { s with newLocals := s.newLocals.insert fvarId, @@ -764,15 +765,16 @@ private def isMatchUnit? (altLHSS : List Match.AltLHS) (rhss : Array Expr) : Met | Expr.lam _ _ b _ => return if b.hasLooseBVars then none else b | _ => return none | _ => return none -private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr) + +private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptMotive : Syntax) (expectedType : Expr) : TermElabM Expr := do let mut generalizing? := generalizing? - if !matchOptType.isNone then + if !matchOptMotive.isNone then if generalizing? == some true then - throwError "the '(generalizing := true)' parameter is not supported when the 'match' type is explicitly provided" + throwError "the '(generalizing := true)' parameter is not supported when the 'match' motive is explicitly provided" generalizing? := some false let (discrs, matchType, altLHSS, isDep, rhss) ← commitIfDidNotPostpone do - let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType + let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptMotive altViews expectedType let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews trace[Elab.match] "matchType: {matchType}" let (discrs, matchType, alts, refined) ← elabMatchAltViews generalizing? discrs matchType matchAlts @@ -845,16 +847,24 @@ private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax trace[Elab.match] "result: {r}" return r -private def getDiscrs (matchStx : Syntax) : Array Syntax := - matchStx[2].getSepArgs +-- leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts -private def getMatchOptType (matchStx : Syntax) : Syntax := - matchStx[3] +private def getDiscrs (matchStx : Syntax) : Array Syntax := + if matchStx[3].isNone then -- HACK for bootstrapping issues + matchStx[2].getSepArgs + else + matchStx[3].getSepArgs + +private def getMatchOptMotive (matchStx : Syntax) : Syntax := + if !matchStx[2].isNone && matchStx[2][0].isOfKind ``Lean.Parser.Term.matchDiscr then -- HACK for bootstrapping issues + mkNullNode + else + matchStx[2] private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) := - let matchOptType := getMatchOptType matchStx; - if matchOptType.isNone then do - let discrs := getDiscrs matchStx; + let matchOptMotive := getMatchOptMotive matchStx + if matchOptMotive.isNone then do + let discrs := getDiscrs matchStx let allLocal ← discrs.allM fun discr => Option.isSome <$> isAtomicDiscr? discr[1] if allLocal then return none @@ -864,17 +874,17 @@ private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Synta let rec loop (discrs : List Syntax) (discrsNew : Array Syntax) (foundFVars : FVarIdSet) := do match discrs with | [] => - let discrs := Syntax.mkSep discrsNew (mkAtomFrom matchStx ", "); - pure (matchStx.setArg 2 discrs) + let discrs := Syntax.mkSep discrsNew (mkAtomFrom matchStx ", ") + pure (matchStx.setArg 3 discrs) | discr :: discrs => -- Recall that -- matchDiscr := leading_parser optional (ident >> ":") >> termParser let term := discr[1] let addAux : TermElabM Syntax := withFreshMacroScope do - let d ← `(_discr); + let d ← `(_discr) unless isAuxDiscrName d.getId do -- Use assertion? throwError "unexpected internal auxiliary discriminant name" - let discrNew := discr.setArg 1 d; + let discrNew := discr.setArg 1 d let r ← loop discrs (discrsNew.push discrNew) foundFVars `(let _discr := $term; $r) match (← isAtomicDiscr? term) with @@ -893,7 +903,7 @@ private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := d private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := do -- We don't wait for the discriminants types when match type is provided by user - if getMatchOptType matchStx |>.isNone then + if getMatchOptMotive matchStx |>.isNone then let discrs := getDiscrs matchStx for discr in discrs do let term := discr[1] @@ -943,17 +953,17 @@ private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Optio /- ``` -leading_parser:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts +leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts ``` Remark the `optIdent` must be `none` at `matchDiscr`. They are expanded by `expandMatchDiscr?`. -/ private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do - let expectedType ← waitExpectedTypeAndDiscrs stx expectedType? - let discrStxs := (getDiscrs stx).map fun d => d - let gen? := getMatchGeneralizing? stx - let altViews := getMatchAlts stx - let matchOptType := getMatchOptType stx - elabMatchAux gen? discrStxs altViews matchOptType expectedType + let expectedType ← waitExpectedTypeAndDiscrs stx expectedType? + let discrStxs := (getDiscrs stx).map fun d => d + let gen? := getMatchGeneralizing? stx + let altViews := getMatchAlts stx + let matchOptMotive := getMatchOptMotive stx + elabMatchAux gen? discrStxs altViews matchOptMotive expectedType private def isPatternVar (stx : Syntax) : TermElabM Bool := do match (← resolveId? stx "pattern") with @@ -969,7 +979,7 @@ where isAtomicIdent (stx : Syntax) : Bool := stx.isIdent && stx.getId.eraseMacroScopes.isAtomic --- leading_parser "match " >> sepBy1 termParser ", " >> optType >> " with " >> matchAlts +-- leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts /-- Pattern matching. `match e, ... with | p, ... => f | ...` matches each given term `e` against each pattern `p` of a match alternative. When all patterns @@ -989,10 +999,10 @@ where match (← expandNonAtomicDiscrs? stx) with | some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType? | none => - let discrs := getDiscrs stx; - let matchOptType := getMatchOptType stx; - if !matchOptType.isNone && discrs.any fun d => !d[0].isNone then - throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used" + let discrs := getDiscrs stx + let matchOptMotive := getMatchOptMotive stx + if !matchOptMotive.isNone && discrs.any fun d => !d[0].isNone then + throwErrorAt matchOptMotive "match motive should not be provided when discriminants with equality proofs are used" elabMatchCore stx expectedType? builtin_initialize diff --git a/src/Lean/Parser/Do.lean b/src/Lean/Parser/Do.lean index da1bf0a72c..c0a3084bb4 100644 --- a/src/Lean/Parser/Do.lean +++ b/src/Lean/Parser/Do.lean @@ -96,7 +96,7 @@ def doForDecl := leading_parser termParser >> " in " >> withForbidden "do" termP @[builtinDoElemParser] def doFor := leading_parser "for " >> sepBy1 doForDecl ", " >> "do " >> doSeq def doMatchAlts := ppDedent <| matchAlts (rhsParser := doSeq) -@[builtinDoElemParser] def doMatch := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts +@[builtinDoElemParser] def doMatch := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> optional Term.motive >> sepBy1 matchDiscr ", " >> " with " >> doMatchAlts def doCatch := leading_parser atomic ("catch " >> binderIdent) >> optional (" : " >> termParser) >> darrow >> doSeq def doCatchMatch := leading_parser "catch " >> doMatchAlts diff --git a/src/Lean/Parser/Tactic.lean b/src/Lean/Parser/Tactic.lean index 33a3086d2e..a3f3771a67 100644 --- a/src/Lean/Parser/Tactic.lean +++ b/src/Lean/Parser/Tactic.lean @@ -20,7 +20,7 @@ builtin_initialize def matchRhs := Term.hole <|> Term.syntheticHole <|> tacticSeq def matchAlts := Term.matchAlts (rhsParser := matchRhs) -@[builtinTacticParser] def «match» := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> sepBy1 Term.matchDiscr ", " >> Term.optType >> " with " >> ppDedent matchAlts +@[builtinTacticParser] def «match» := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> optional Term.motive >> sepBy1 Term.matchDiscr ", " >> " with " >> ppDedent matchAlts @[builtinTacticParser] def introMatch := leading_parser nonReservedSymbol "intro " >> matchAlts @[builtinTacticParser] def decide := leading_parser nonReservedSymbol "decide" diff --git a/src/Lean/Parser/Term.lean b/src/Lean/Parser/Term.lean index 41b64a3500..66f2de6029 100644 --- a/src/Lean/Parser/Term.lean +++ b/src/Lean/Parser/Term.lean @@ -150,7 +150,9 @@ def trueVal := leading_parser nonReservedSymbol "true" def falseVal := leading_parser nonReservedSymbol "false" def generalizingParam := leading_parser atomic ("(" >> nonReservedSymbol "generalizing") >> " := " >> (trueVal <|> falseVal) >> ")" -@[builtinTermParser] def «match» := leading_parser:leadPrec "match " >> optional generalizingParam >> sepBy1 matchDiscr ", " >> optType >> " with " >> ppDedent matchAlts +def motive := leading_parser atomic ("(" >> nonReservedSymbol "motive" >> " := ") >> termParser >> ")" + +@[builtinTermParser] def «match» := leading_parser:leadPrec "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts @[builtinTermParser] def «nomatch» := leading_parser:leadPrec "nomatch " >> termParser def funImplicitBinder := atomic (lookahead ("{" >> many1 binderIdent >> (symbol " : " <|> "}"))) >> implicitBinder diff --git a/src/Lean/PrettyPrinter/Delaborator/Builtins.lean b/src/Lean/PrettyPrinter/Delaborator/Builtins.lean index a5a11cef17..276380f456 100644 --- a/src/Lean/PrettyPrinter/Delaborator/Builtins.lean +++ b/src/Lean/PrettyPrinter/Delaborator/Builtins.lean @@ -421,7 +421,7 @@ def delabAppMatch : Delab := whenPPOption getPPNotation <| whenPPOption getPPMat let opts ← getOptions -- TODO: disable the match if other implicits are needed? if ← pure st.motiveNamed <||> shouldShowMotive lamMotive opts then - `(match $[$st.discrs:term],* : $piStx with $[| $pats,* => $st.rhss]*) + `(match (motive := $piStx) $[$st.discrs:term],* with $[| $pats,* => $st.rhss]*) else `(match $[$st.discrs:term],* with $[| $pats,* => $st.rhss]*) return Syntax.mkApp stx st.moreArgs diff --git a/tests/lean/run/match1.lean b/tests/lean/run/match1.lean index e11768707c..87e76b9821 100644 --- a/tests/lean/run/match1.lean +++ b/tests/lean/run/match1.lean @@ -162,7 +162,7 @@ match h:foo b with #eval checkWithMkMatcherInput ``Bla.isNat2?.match_1 def foo2 (x : Nat) : Nat := -match x, rfl : (y : Nat) → x = y → Nat with +match (motive := (y : Nat) → x = y → Nat) x, rfl with | 0, h => 0 | x+1, h => 1 #eval checkWithMkMatcherInput ``foo2.match_1 diff --git a/tests/lean/run/newfrontend2.lean b/tests/lean/run/newfrontend2.lean index fdc9a0acf7..abb02584ee 100644 --- a/tests/lean/run/newfrontend2.lean +++ b/tests/lean/run/newfrontend2.lean @@ -16,9 +16,9 @@ def x := 1 #check foo x x #check match 1 with | x => x + 1 -#check match 1 : Int -> _ with | x => x + 1 +#check match (motive := Int → _) 1 with | x => x + 1 #check match 1 with | x => x + 1 -#check match 1 : Int -> _ with | x => x + 1 +#check match (motive := Int → _) 1 with | x => x + 1 def g (x : Nat × Nat) (y : Nat) := x.1 + x.2 + y