feat: modify notation for providing motive in "match" expressions

This commit is contained in:
Leonardo de Moura 2022-02-14 15:35:38 -08:00
parent 0030208d99
commit 93b5b74b36
8 changed files with 65 additions and 53 deletions

View file

@ -190,7 +190,7 @@ inductive Code where
| «return» (ref : Syntax) (val : Syntax)
/- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/
| ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
| «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt Code))
| «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt Code))
| jmp (ref : Syntax) (jpName : Name) (args : Array Syntax)
deriving Inhabited
@ -545,13 +545,13 @@ def mkUnless (cond : Syntax) (c : CodeBlock) : MacroM CodeBlock := do
let thenBranch ← mkPureUnitAction
pure { c with code := Code.ite (← getRef) none mkNullNode cond thenBranch.code c.code }
def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do
def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do
-- nary version of homogenize
let ws := alts.foldl (union · ·.rhs.uvars) {}
let alts ← alts.mapM fun alt => do
let rhs ← extendUpdatedVars alt.rhs ws
pure { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code }
pure { code := Code.«match» ref genParam discrs optType alts, uvars := ws }
return { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code }
return { code := Code.«match» ref genParam discrs optMotive alts, uvars := ws }
/- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`.
This method assumes `terminal` is a terminal -/
@ -1036,14 +1036,14 @@ partial def toTerm : Code → M Syntax
| Code.reassign _ stx k => do reassignToTerm stx (← toTerm k)
| Code.seq stx k => do seqToTerm stx (← toTerm k)
| Code.ite ref _ o c t e => withRef ref <| do mkIte o c (← toTerm t) (← toTerm e)
| Code.«match» ref genParam discrs optType alts => do
| Code.«match» ref genParam discrs optMotive alts => do
let mut termAlts := #[]
for alt in alts do
let rhs ← toTerm alt.rhs
let termAlt := mkNode `Lean.Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", alt.patterns, mkAtomFrom alt.ref "=>", rhs]
termAlts := termAlts.push termAlt
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts]
pure $ mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, discrs, optType, mkAtomFrom ref "with", termMatchAlts]
return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts]
def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do
let term ← toTerm code { m := m, kind := kind, uvars := uvars }
@ -1445,8 +1445,8 @@ mutual
partial def doMatchToCode (doMatch : Syntax) (doElems: List Syntax) : M CodeBlock := do
let ref := doMatch
let genParam := doMatch[1]
let discrs := doMatch[2]
let optType := doMatch[3]
let optMotive := doMatch[2]
let discrs := doMatch[3]
let matchAlts := doMatch[5][0].getArgs -- Array of `doMatchAlt`
let alts ← matchAlts.mapM fun matchAlt => do
let patterns := matchAlt[1]
@ -1455,7 +1455,7 @@ mutual
let rhs := matchAlt[3]
let rhs ← doSeqToCode (getDoSeqElems rhs)
pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock }
let matchCode ← mkMatch ref genParam discrs optType alts
let matchCode ← mkMatch ref genParam discrs optMotive alts
concatWith matchCode doElems
/--

View file

@ -77,13 +77,14 @@ structure ElabMatchTypeAndDiscrsResult where
isDep : Bool
alts : Array MatchAltView
private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptMotive : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
: TermElabM ElabMatchTypeAndDiscrsResult := do
let numDiscrs := discrStxs.size
if matchOptType.isNone then
if matchOptMotive.isNone then
elabDiscrs 0 #[]
else
let matchTypeStx := matchOptType[0][1]
-- motive := leading_parser atomic ("(" >> nonReservedSymbol "motive" >> " := ") >> termParser >> ")"
let matchTypeStx := matchOptMotive[0][3]
let matchType ← elabType matchTypeStx
let (discrs, isDep) ← elabDiscrsWitMatchType matchType expectedType
return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews }
@ -106,7 +107,7 @@ private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptT
matchType := b.instantiate1 discr
discrs := discrs.push discr
| _ =>
throwError "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected"
throwError "invalid motive provided to match-expression, function type with arity #{discrStxs.size} expected"
return (discrs, isDep)
markIsDep (r : ElabMatchTypeAndDiscrsResult) :=
@ -157,13 +158,13 @@ def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array Matc
pure { matchAlt with patterns := patterns }
private def getMatchGeneralizing? : Syntax → Option Bool
| `(match (generalizing := true) $discrs,* $[: $ty?]? with $alts:matchAlt*) => some true
| `(match (generalizing := false) $discrs,* $[: $ty?]? with $alts:matchAlt*) => some false
| `(match (generalizing := true) $[$motive]? $discrs,* with $alts:matchAlt*) => some true
| `(match (generalizing := false) $[$motive]? $discrs,* with $alts:matchAlt*) => some false
| _ => none
/- Given `stx` a match-expression, return its alternatives. -/
private def getMatchAlts : Syntax → Array MatchAltView
| `(match $[$gen]? $discrs,* $[: $ty?]? with $alts:matchAlt*) =>
| `(match $[$gen]? $[$motive]? $discrs,* with $alts:matchAlt*) =>
alts.filterMap fun alt => match alt with
| `(matchAltExpr| | $patterns,* => $rhs) => some {
ref := alt,
@ -409,7 +410,7 @@ def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (A
let decl ← instantiateLocalDeclMVars decl
decls := decls.push decl
| PatternVarDecl.anonymousVar mvarId fvarId =>
let e ← instantiateMVars (mkMVar mvarId);
let e ← instantiateMVars (mkMVar mvarId)
trace[Elab.match] "finalizePatternDecls: mvarId: {mvarId.name} := {e}, fvar: {mkFVar fvarId}"
match e with
| Expr.mvar newMVarId _ =>
@ -458,7 +459,7 @@ private def mkLocalDeclFor (mvar : Expr) : M Pattern := do
If this generates problems in the future, we should update the metavariable declarations. -/
assignExprMVar mvarId (mkFVar fvarId)
let userName ← mkFreshBinderName
let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default;
let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default
modify fun s =>
{ s with
newLocals := s.newLocals.insert fvarId,
@ -764,15 +765,16 @@ private def isMatchUnit? (altLHSS : List Match.AltLHS) (rhss : Array Expr) : Met
| Expr.lam _ _ b _ => return if b.hasLooseBVars then none else b
| _ => return none
| _ => return none
private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr)
private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptMotive : Syntax) (expectedType : Expr)
: TermElabM Expr := do
let mut generalizing? := generalizing?
if !matchOptType.isNone then
if !matchOptMotive.isNone then
if generalizing? == some true then
throwError "the '(generalizing := true)' parameter is not supported when the 'match' type is explicitly provided"
throwError "the '(generalizing := true)' parameter is not supported when the 'match' motive is explicitly provided"
generalizing? := some false
let (discrs, matchType, altLHSS, isDep, rhss) ← commitIfDidNotPostpone do
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptMotive altViews expectedType
let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
trace[Elab.match] "matchType: {matchType}"
let (discrs, matchType, alts, refined) ← elabMatchAltViews generalizing? discrs matchType matchAlts
@ -845,16 +847,24 @@ private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax
trace[Elab.match] "result: {r}"
return r
private def getDiscrs (matchStx : Syntax) : Array Syntax :=
matchStx[2].getSepArgs
-- leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
private def getMatchOptType (matchStx : Syntax) : Syntax :=
matchStx[3]
private def getDiscrs (matchStx : Syntax) : Array Syntax :=
if matchStx[3].isNone then -- HACK for bootstrapping issues
matchStx[2].getSepArgs
else
matchStx[3].getSepArgs
private def getMatchOptMotive (matchStx : Syntax) : Syntax :=
if !matchStx[2].isNone && matchStx[2][0].isOfKind ``Lean.Parser.Term.matchDiscr then -- HACK for bootstrapping issues
mkNullNode
else
matchStx[2]
private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) :=
let matchOptType := getMatchOptType matchStx;
if matchOptType.isNone then do
let discrs := getDiscrs matchStx;
let matchOptMotive := getMatchOptMotive matchStx
if matchOptMotive.isNone then do
let discrs := getDiscrs matchStx
let allLocal ← discrs.allM fun discr => Option.isSome <$> isAtomicDiscr? discr[1]
if allLocal then
return none
@ -864,17 +874,17 @@ private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Synta
let rec loop (discrs : List Syntax) (discrsNew : Array Syntax) (foundFVars : FVarIdSet) := do
match discrs with
| [] =>
let discrs := Syntax.mkSep discrsNew (mkAtomFrom matchStx ", ");
pure (matchStx.setArg 2 discrs)
let discrs := Syntax.mkSep discrsNew (mkAtomFrom matchStx ", ")
pure (matchStx.setArg 3 discrs)
| discr :: discrs =>
-- Recall that
-- matchDiscr := leading_parser optional (ident >> ":") >> termParser
let term := discr[1]
let addAux : TermElabM Syntax := withFreshMacroScope do
let d ← `(_discr);
let d ← `(_discr)
unless isAuxDiscrName d.getId do -- Use assertion?
throwError "unexpected internal auxiliary discriminant name"
let discrNew := discr.setArg 1 d;
let discrNew := discr.setArg 1 d
let r ← loop discrs (discrsNew.push discrNew) foundFVars
`(let _discr := $term; $r)
match (← isAtomicDiscr? term) with
@ -893,7 +903,7 @@ private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := d
private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := do
-- We don't wait for the discriminants types when match type is provided by user
if getMatchOptType matchStx |>.isNone then
if getMatchOptMotive matchStx |>.isNone then
let discrs := getDiscrs matchStx
for discr in discrs do
let term := discr[1]
@ -943,17 +953,17 @@ private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Optio
/-
```
leading_parser:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts
leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
```
Remark the `optIdent` must be `none` at `matchDiscr`. They are expanded by `expandMatchDiscr?`.
-/
private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
let expectedType ← waitExpectedTypeAndDiscrs stx expectedType?
let discrStxs := (getDiscrs stx).map fun d => d
let gen? := getMatchGeneralizing? stx
let altViews := getMatchAlts stx
let matchOptType := getMatchOptType stx
elabMatchAux gen? discrStxs altViews matchOptType expectedType
let expectedType ← waitExpectedTypeAndDiscrs stx expectedType?
let discrStxs := (getDiscrs stx).map fun d => d
let gen? := getMatchGeneralizing? stx
let altViews := getMatchAlts stx
let matchOptMotive := getMatchOptMotive stx
elabMatchAux gen? discrStxs altViews matchOptMotive expectedType
private def isPatternVar (stx : Syntax) : TermElabM Bool := do
match (← resolveId? stx "pattern") with
@ -969,7 +979,7 @@ where
isAtomicIdent (stx : Syntax) : Bool :=
stx.isIdent && stx.getId.eraseMacroScopes.isAtomic
-- leading_parser "match " >> sepBy1 termParser ", " >> optType >> " with " >> matchAlts
-- leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
/--
Pattern matching. `match e, ... with | p, ... => f | ...` matches each given
term `e` against each pattern `p` of a match alternative. When all patterns
@ -989,10 +999,10 @@ where
match (← expandNonAtomicDiscrs? stx) with
| some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
| none =>
let discrs := getDiscrs stx;
let matchOptType := getMatchOptType stx;
if !matchOptType.isNone && discrs.any fun d => !d[0].isNone then
throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used"
let discrs := getDiscrs stx
let matchOptMotive := getMatchOptMotive stx
if !matchOptMotive.isNone && discrs.any fun d => !d[0].isNone then
throwErrorAt matchOptMotive "match motive should not be provided when discriminants with equality proofs are used"
elabMatchCore stx expectedType?
builtin_initialize

View file

@ -96,7 +96,7 @@ def doForDecl := leading_parser termParser >> " in " >> withForbidden "do" termP
@[builtinDoElemParser] def doFor := leading_parser "for " >> sepBy1 doForDecl ", " >> "do " >> doSeq
def doMatchAlts := ppDedent <| matchAlts (rhsParser := doSeq)
@[builtinDoElemParser] def doMatch := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts
@[builtinDoElemParser] def doMatch := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> optional Term.motive >> sepBy1 matchDiscr ", " >> " with " >> doMatchAlts
def doCatch := leading_parser atomic ("catch " >> binderIdent) >> optional (" : " >> termParser) >> darrow >> doSeq
def doCatchMatch := leading_parser "catch " >> doMatchAlts

View file

@ -20,7 +20,7 @@ builtin_initialize
def matchRhs := Term.hole <|> Term.syntheticHole <|> tacticSeq
def matchAlts := Term.matchAlts (rhsParser := matchRhs)
@[builtinTacticParser] def «match» := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> sepBy1 Term.matchDiscr ", " >> Term.optType >> " with " >> ppDedent matchAlts
@[builtinTacticParser] def «match» := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> optional Term.motive >> sepBy1 Term.matchDiscr ", " >> " with " >> ppDedent matchAlts
@[builtinTacticParser] def introMatch := leading_parser nonReservedSymbol "intro " >> matchAlts
@[builtinTacticParser] def decide := leading_parser nonReservedSymbol "decide"

View file

@ -150,7 +150,9 @@ def trueVal := leading_parser nonReservedSymbol "true"
def falseVal := leading_parser nonReservedSymbol "false"
def generalizingParam := leading_parser atomic ("(" >> nonReservedSymbol "generalizing") >> " := " >> (trueVal <|> falseVal) >> ")"
@[builtinTermParser] def «match» := leading_parser:leadPrec "match " >> optional generalizingParam >> sepBy1 matchDiscr ", " >> optType >> " with " >> ppDedent matchAlts
def motive := leading_parser atomic ("(" >> nonReservedSymbol "motive" >> " := ") >> termParser >> ")"
@[builtinTermParser] def «match» := leading_parser:leadPrec "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
@[builtinTermParser] def «nomatch» := leading_parser:leadPrec "nomatch " >> termParser
def funImplicitBinder := atomic (lookahead ("{" >> many1 binderIdent >> (symbol " : " <|> "}"))) >> implicitBinder

View file

@ -421,7 +421,7 @@ def delabAppMatch : Delab := whenPPOption getPPNotation <| whenPPOption getPPMat
let opts ← getOptions
-- TODO: disable the match if other implicits are needed?
if ← pure st.motiveNamed <||> shouldShowMotive lamMotive opts then
`(match $[$st.discrs:term],* : $piStx with $[| $pats,* => $st.rhss]*)
`(match (motive := $piStx) $[$st.discrs:term],* with $[| $pats,* => $st.rhss]*)
else
`(match $[$st.discrs:term],* with $[| $pats,* => $st.rhss]*)
return Syntax.mkApp stx st.moreArgs

View file

@ -162,7 +162,7 @@ match h:foo b with
#eval checkWithMkMatcherInput ``Bla.isNat2?.match_1
def foo2 (x : Nat) : Nat :=
match x, rfl : (y : Nat) → x = y → Nat with
match (motive := (y : Nat) → x = y → Nat) x, rfl with
| 0, h => 0
| x+1, h => 1
#eval checkWithMkMatcherInput ``foo2.match_1

View file

@ -16,9 +16,9 @@ def x := 1
#check foo x x
#check match 1 with | x => x + 1
#check match 1 : Int -> _ with | x => x + 1
#check match (motive := Int → _) 1 with | x => x + 1
#check match 1 with | x => x + 1
#check match 1 : Int -> _ with | x => x + 1
#check match (motive := Int → _) 1 with | x => x + 1
def g (x : Nat × Nat) (y : Nat) :=
x.1 + x.2 + y