feat: modify notation for providing motive in "match" expressions
This commit is contained in:
parent
0030208d99
commit
93b5b74b36
8 changed files with 65 additions and 53 deletions
|
|
@ -190,7 +190,7 @@ inductive Code where
|
|||
| «return» (ref : Syntax) (val : Syntax)
|
||||
/- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/
|
||||
| ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
|
||||
| «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt Code))
|
||||
| «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt Code))
|
||||
| jmp (ref : Syntax) (jpName : Name) (args : Array Syntax)
|
||||
deriving Inhabited
|
||||
|
||||
|
|
@ -545,13 +545,13 @@ def mkUnless (cond : Syntax) (c : CodeBlock) : MacroM CodeBlock := do
|
|||
let thenBranch ← mkPureUnitAction
|
||||
pure { c with code := Code.ite (← getRef) none mkNullNode cond thenBranch.code c.code }
|
||||
|
||||
def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do
|
||||
def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do
|
||||
-- nary version of homogenize
|
||||
let ws := alts.foldl (union · ·.rhs.uvars) {}
|
||||
let alts ← alts.mapM fun alt => do
|
||||
let rhs ← extendUpdatedVars alt.rhs ws
|
||||
pure { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code }
|
||||
pure { code := Code.«match» ref genParam discrs optType alts, uvars := ws }
|
||||
return { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code }
|
||||
return { code := Code.«match» ref genParam discrs optMotive alts, uvars := ws }
|
||||
|
||||
/- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`.
|
||||
This method assumes `terminal` is a terminal -/
|
||||
|
|
@ -1036,14 +1036,14 @@ partial def toTerm : Code → M Syntax
|
|||
| Code.reassign _ stx k => do reassignToTerm stx (← toTerm k)
|
||||
| Code.seq stx k => do seqToTerm stx (← toTerm k)
|
||||
| Code.ite ref _ o c t e => withRef ref <| do mkIte o c (← toTerm t) (← toTerm e)
|
||||
| Code.«match» ref genParam discrs optType alts => do
|
||||
| Code.«match» ref genParam discrs optMotive alts => do
|
||||
let mut termAlts := #[]
|
||||
for alt in alts do
|
||||
let rhs ← toTerm alt.rhs
|
||||
let termAlt := mkNode `Lean.Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", alt.patterns, mkAtomFrom alt.ref "=>", rhs]
|
||||
termAlts := termAlts.push termAlt
|
||||
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts]
|
||||
pure $ mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, discrs, optType, mkAtomFrom ref "with", termMatchAlts]
|
||||
return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts]
|
||||
|
||||
def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do
|
||||
let term ← toTerm code { m := m, kind := kind, uvars := uvars }
|
||||
|
|
@ -1445,8 +1445,8 @@ mutual
|
|||
partial def doMatchToCode (doMatch : Syntax) (doElems: List Syntax) : M CodeBlock := do
|
||||
let ref := doMatch
|
||||
let genParam := doMatch[1]
|
||||
let discrs := doMatch[2]
|
||||
let optType := doMatch[3]
|
||||
let optMotive := doMatch[2]
|
||||
let discrs := doMatch[3]
|
||||
let matchAlts := doMatch[5][0].getArgs -- Array of `doMatchAlt`
|
||||
let alts ← matchAlts.mapM fun matchAlt => do
|
||||
let patterns := matchAlt[1]
|
||||
|
|
@ -1455,7 +1455,7 @@ mutual
|
|||
let rhs := matchAlt[3]
|
||||
let rhs ← doSeqToCode (getDoSeqElems rhs)
|
||||
pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock }
|
||||
let matchCode ← mkMatch ref genParam discrs optType alts
|
||||
let matchCode ← mkMatch ref genParam discrs optMotive alts
|
||||
concatWith matchCode doElems
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -77,13 +77,14 @@ structure ElabMatchTypeAndDiscrsResult where
|
|||
isDep : Bool
|
||||
alts : Array MatchAltView
|
||||
|
||||
private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
|
||||
private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptMotive : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
|
||||
: TermElabM ElabMatchTypeAndDiscrsResult := do
|
||||
let numDiscrs := discrStxs.size
|
||||
if matchOptType.isNone then
|
||||
if matchOptMotive.isNone then
|
||||
elabDiscrs 0 #[]
|
||||
else
|
||||
let matchTypeStx := matchOptType[0][1]
|
||||
-- motive := leading_parser atomic ("(" >> nonReservedSymbol "motive" >> " := ") >> termParser >> ")"
|
||||
let matchTypeStx := matchOptMotive[0][3]
|
||||
let matchType ← elabType matchTypeStx
|
||||
let (discrs, isDep) ← elabDiscrsWitMatchType matchType expectedType
|
||||
return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews }
|
||||
|
|
@ -106,7 +107,7 @@ private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptT
|
|||
matchType := b.instantiate1 discr
|
||||
discrs := discrs.push discr
|
||||
| _ =>
|
||||
throwError "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected"
|
||||
throwError "invalid motive provided to match-expression, function type with arity #{discrStxs.size} expected"
|
||||
return (discrs, isDep)
|
||||
|
||||
markIsDep (r : ElabMatchTypeAndDiscrsResult) :=
|
||||
|
|
@ -157,13 +158,13 @@ def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array Matc
|
|||
pure { matchAlt with patterns := patterns }
|
||||
|
||||
private def getMatchGeneralizing? : Syntax → Option Bool
|
||||
| `(match (generalizing := true) $discrs,* $[: $ty?]? with $alts:matchAlt*) => some true
|
||||
| `(match (generalizing := false) $discrs,* $[: $ty?]? with $alts:matchAlt*) => some false
|
||||
| `(match (generalizing := true) $[$motive]? $discrs,* with $alts:matchAlt*) => some true
|
||||
| `(match (generalizing := false) $[$motive]? $discrs,* with $alts:matchAlt*) => some false
|
||||
| _ => none
|
||||
|
||||
/- Given `stx` a match-expression, return its alternatives. -/
|
||||
private def getMatchAlts : Syntax → Array MatchAltView
|
||||
| `(match $[$gen]? $discrs,* $[: $ty?]? with $alts:matchAlt*) =>
|
||||
| `(match $[$gen]? $[$motive]? $discrs,* with $alts:matchAlt*) =>
|
||||
alts.filterMap fun alt => match alt with
|
||||
| `(matchAltExpr| | $patterns,* => $rhs) => some {
|
||||
ref := alt,
|
||||
|
|
@ -409,7 +410,7 @@ def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (A
|
|||
let decl ← instantiateLocalDeclMVars decl
|
||||
decls := decls.push decl
|
||||
| PatternVarDecl.anonymousVar mvarId fvarId =>
|
||||
let e ← instantiateMVars (mkMVar mvarId);
|
||||
let e ← instantiateMVars (mkMVar mvarId)
|
||||
trace[Elab.match] "finalizePatternDecls: mvarId: {mvarId.name} := {e}, fvar: {mkFVar fvarId}"
|
||||
match e with
|
||||
| Expr.mvar newMVarId _ =>
|
||||
|
|
@ -458,7 +459,7 @@ private def mkLocalDeclFor (mvar : Expr) : M Pattern := do
|
|||
If this generates problems in the future, we should update the metavariable declarations. -/
|
||||
assignExprMVar mvarId (mkFVar fvarId)
|
||||
let userName ← mkFreshBinderName
|
||||
let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default;
|
||||
let newDecl := LocalDecl.cdecl default fvarId userName type BinderInfo.default
|
||||
modify fun s =>
|
||||
{ s with
|
||||
newLocals := s.newLocals.insert fvarId,
|
||||
|
|
@ -764,15 +765,16 @@ private def isMatchUnit? (altLHSS : List Match.AltLHS) (rhss : Array Expr) : Met
|
|||
| Expr.lam _ _ b _ => return if b.hasLooseBVars then none else b
|
||||
| _ => return none
|
||||
| _ => return none
|
||||
private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr)
|
||||
|
||||
private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptMotive : Syntax) (expectedType : Expr)
|
||||
: TermElabM Expr := do
|
||||
let mut generalizing? := generalizing?
|
||||
if !matchOptType.isNone then
|
||||
if !matchOptMotive.isNone then
|
||||
if generalizing? == some true then
|
||||
throwError "the '(generalizing := true)' parameter is not supported when the 'match' type is explicitly provided"
|
||||
throwError "the '(generalizing := true)' parameter is not supported when the 'match' motive is explicitly provided"
|
||||
generalizing? := some false
|
||||
let (discrs, matchType, altLHSS, isDep, rhss) ← commitIfDidNotPostpone do
|
||||
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
|
||||
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptMotive altViews expectedType
|
||||
let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
|
||||
trace[Elab.match] "matchType: {matchType}"
|
||||
let (discrs, matchType, alts, refined) ← elabMatchAltViews generalizing? discrs matchType matchAlts
|
||||
|
|
@ -845,16 +847,24 @@ private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax
|
|||
trace[Elab.match] "result: {r}"
|
||||
return r
|
||||
|
||||
private def getDiscrs (matchStx : Syntax) : Array Syntax :=
|
||||
matchStx[2].getSepArgs
|
||||
-- leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
|
||||
|
||||
private def getMatchOptType (matchStx : Syntax) : Syntax :=
|
||||
matchStx[3]
|
||||
private def getDiscrs (matchStx : Syntax) : Array Syntax :=
|
||||
if matchStx[3].isNone then -- HACK for bootstrapping issues
|
||||
matchStx[2].getSepArgs
|
||||
else
|
||||
matchStx[3].getSepArgs
|
||||
|
||||
private def getMatchOptMotive (matchStx : Syntax) : Syntax :=
|
||||
if !matchStx[2].isNone && matchStx[2][0].isOfKind ``Lean.Parser.Term.matchDiscr then -- HACK for bootstrapping issues
|
||||
mkNullNode
|
||||
else
|
||||
matchStx[2]
|
||||
|
||||
private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) :=
|
||||
let matchOptType := getMatchOptType matchStx;
|
||||
if matchOptType.isNone then do
|
||||
let discrs := getDiscrs matchStx;
|
||||
let matchOptMotive := getMatchOptMotive matchStx
|
||||
if matchOptMotive.isNone then do
|
||||
let discrs := getDiscrs matchStx
|
||||
let allLocal ← discrs.allM fun discr => Option.isSome <$> isAtomicDiscr? discr[1]
|
||||
if allLocal then
|
||||
return none
|
||||
|
|
@ -864,17 +874,17 @@ private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Synta
|
|||
let rec loop (discrs : List Syntax) (discrsNew : Array Syntax) (foundFVars : FVarIdSet) := do
|
||||
match discrs with
|
||||
| [] =>
|
||||
let discrs := Syntax.mkSep discrsNew (mkAtomFrom matchStx ", ");
|
||||
pure (matchStx.setArg 2 discrs)
|
||||
let discrs := Syntax.mkSep discrsNew (mkAtomFrom matchStx ", ")
|
||||
pure (matchStx.setArg 3 discrs)
|
||||
| discr :: discrs =>
|
||||
-- Recall that
|
||||
-- matchDiscr := leading_parser optional (ident >> ":") >> termParser
|
||||
let term := discr[1]
|
||||
let addAux : TermElabM Syntax := withFreshMacroScope do
|
||||
let d ← `(_discr);
|
||||
let d ← `(_discr)
|
||||
unless isAuxDiscrName d.getId do -- Use assertion?
|
||||
throwError "unexpected internal auxiliary discriminant name"
|
||||
let discrNew := discr.setArg 1 d;
|
||||
let discrNew := discr.setArg 1 d
|
||||
let r ← loop discrs (discrsNew.push discrNew) foundFVars
|
||||
`(let _discr := $term; $r)
|
||||
match (← isAtomicDiscr? term) with
|
||||
|
|
@ -893,7 +903,7 @@ private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := d
|
|||
|
||||
private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := do
|
||||
-- We don't wait for the discriminants types when match type is provided by user
|
||||
if getMatchOptType matchStx |>.isNone then
|
||||
if getMatchOptMotive matchStx |>.isNone then
|
||||
let discrs := getDiscrs matchStx
|
||||
for discr in discrs do
|
||||
let term := discr[1]
|
||||
|
|
@ -943,17 +953,17 @@ private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Optio
|
|||
|
||||
/-
|
||||
```
|
||||
leading_parser:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts
|
||||
leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
|
||||
```
|
||||
Remark the `optIdent` must be `none` at `matchDiscr`. They are expanded by `expandMatchDiscr?`.
|
||||
-/
|
||||
private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
let expectedType ← waitExpectedTypeAndDiscrs stx expectedType?
|
||||
let discrStxs := (getDiscrs stx).map fun d => d
|
||||
let gen? := getMatchGeneralizing? stx
|
||||
let altViews := getMatchAlts stx
|
||||
let matchOptType := getMatchOptType stx
|
||||
elabMatchAux gen? discrStxs altViews matchOptType expectedType
|
||||
let expectedType ← waitExpectedTypeAndDiscrs stx expectedType?
|
||||
let discrStxs := (getDiscrs stx).map fun d => d
|
||||
let gen? := getMatchGeneralizing? stx
|
||||
let altViews := getMatchAlts stx
|
||||
let matchOptMotive := getMatchOptMotive stx
|
||||
elabMatchAux gen? discrStxs altViews matchOptMotive expectedType
|
||||
|
||||
private def isPatternVar (stx : Syntax) : TermElabM Bool := do
|
||||
match (← resolveId? stx "pattern") with
|
||||
|
|
@ -969,7 +979,7 @@ where
|
|||
isAtomicIdent (stx : Syntax) : Bool :=
|
||||
stx.isIdent && stx.getId.eraseMacroScopes.isAtomic
|
||||
|
||||
-- leading_parser "match " >> sepBy1 termParser ", " >> optType >> " with " >> matchAlts
|
||||
-- leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
|
||||
/--
|
||||
Pattern matching. `match e, ... with | p, ... => f | ...` matches each given
|
||||
term `e` against each pattern `p` of a match alternative. When all patterns
|
||||
|
|
@ -989,10 +999,10 @@ where
|
|||
match (← expandNonAtomicDiscrs? stx) with
|
||||
| some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
|
||||
| none =>
|
||||
let discrs := getDiscrs stx;
|
||||
let matchOptType := getMatchOptType stx;
|
||||
if !matchOptType.isNone && discrs.any fun d => !d[0].isNone then
|
||||
throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used"
|
||||
let discrs := getDiscrs stx
|
||||
let matchOptMotive := getMatchOptMotive stx
|
||||
if !matchOptMotive.isNone && discrs.any fun d => !d[0].isNone then
|
||||
throwErrorAt matchOptMotive "match motive should not be provided when discriminants with equality proofs are used"
|
||||
elabMatchCore stx expectedType?
|
||||
|
||||
builtin_initialize
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ def doForDecl := leading_parser termParser >> " in " >> withForbidden "do" termP
|
|||
@[builtinDoElemParser] def doFor := leading_parser "for " >> sepBy1 doForDecl ", " >> "do " >> doSeq
|
||||
|
||||
def doMatchAlts := ppDedent <| matchAlts (rhsParser := doSeq)
|
||||
@[builtinDoElemParser] def doMatch := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts
|
||||
@[builtinDoElemParser] def doMatch := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> optional Term.motive >> sepBy1 matchDiscr ", " >> " with " >> doMatchAlts
|
||||
|
||||
def doCatch := leading_parser atomic ("catch " >> binderIdent) >> optional (" : " >> termParser) >> darrow >> doSeq
|
||||
def doCatchMatch := leading_parser "catch " >> doMatchAlts
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ builtin_initialize
|
|||
|
||||
def matchRhs := Term.hole <|> Term.syntheticHole <|> tacticSeq
|
||||
def matchAlts := Term.matchAlts (rhsParser := matchRhs)
|
||||
@[builtinTacticParser] def «match» := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> sepBy1 Term.matchDiscr ", " >> Term.optType >> " with " >> ppDedent matchAlts
|
||||
@[builtinTacticParser] def «match» := leading_parser:leadPrec "match " >> optional Term.generalizingParam >> optional Term.motive >> sepBy1 Term.matchDiscr ", " >> " with " >> ppDedent matchAlts
|
||||
@[builtinTacticParser] def introMatch := leading_parser nonReservedSymbol "intro " >> matchAlts
|
||||
|
||||
@[builtinTacticParser] def decide := leading_parser nonReservedSymbol "decide"
|
||||
|
|
|
|||
|
|
@ -150,7 +150,9 @@ def trueVal := leading_parser nonReservedSymbol "true"
|
|||
def falseVal := leading_parser nonReservedSymbol "false"
|
||||
def generalizingParam := leading_parser atomic ("(" >> nonReservedSymbol "generalizing") >> " := " >> (trueVal <|> falseVal) >> ")"
|
||||
|
||||
@[builtinTermParser] def «match» := leading_parser:leadPrec "match " >> optional generalizingParam >> sepBy1 matchDiscr ", " >> optType >> " with " >> ppDedent matchAlts
|
||||
def motive := leading_parser atomic ("(" >> nonReservedSymbol "motive" >> " := ") >> termParser >> ")"
|
||||
|
||||
@[builtinTermParser] def «match» := leading_parser:leadPrec "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
|
||||
@[builtinTermParser] def «nomatch» := leading_parser:leadPrec "nomatch " >> termParser
|
||||
|
||||
def funImplicitBinder := atomic (lookahead ("{" >> many1 binderIdent >> (symbol " : " <|> "}"))) >> implicitBinder
|
||||
|
|
|
|||
|
|
@ -421,7 +421,7 @@ def delabAppMatch : Delab := whenPPOption getPPNotation <| whenPPOption getPPMat
|
|||
let opts ← getOptions
|
||||
-- TODO: disable the match if other implicits are needed?
|
||||
if ← pure st.motiveNamed <||> shouldShowMotive lamMotive opts then
|
||||
`(match $[$st.discrs:term],* : $piStx with $[| $pats,* => $st.rhss]*)
|
||||
`(match (motive := $piStx) $[$st.discrs:term],* with $[| $pats,* => $st.rhss]*)
|
||||
else
|
||||
`(match $[$st.discrs:term],* with $[| $pats,* => $st.rhss]*)
|
||||
return Syntax.mkApp stx st.moreArgs
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ match h:foo b with
|
|||
#eval checkWithMkMatcherInput ``Bla.isNat2?.match_1
|
||||
|
||||
def foo2 (x : Nat) : Nat :=
|
||||
match x, rfl : (y : Nat) → x = y → Nat with
|
||||
match (motive := (y : Nat) → x = y → Nat) x, rfl with
|
||||
| 0, h => 0
|
||||
| x+1, h => 1
|
||||
#eval checkWithMkMatcherInput ``foo2.match_1
|
||||
|
|
|
|||
|
|
@ -16,9 +16,9 @@ def x := 1
|
|||
#check foo x x
|
||||
|
||||
#check match 1 with | x => x + 1
|
||||
#check match 1 : Int -> _ with | x => x + 1
|
||||
#check match (motive := Int → _) 1 with | x => x + 1
|
||||
#check match 1 with | x => x + 1
|
||||
#check match 1 : Int -> _ with | x => x + 1
|
||||
#check match (motive := Int → _) 1 with | x => x + 1
|
||||
|
||||
def g (x : Nat × Nat) (y : Nat) :=
|
||||
x.1 + x.2 + y
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue