From 755d9dedbe75de5af4370efdbc8a19f7545933df Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 15 Oct 2020 09:21:58 -0700 Subject: [PATCH] chore: move to new frontend --- src/Lean/Elab/Match.lean | 754 +++++++++++++++++++-------------------- 1 file changed, 370 insertions(+), 384 deletions(-) diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index f2d2ddf190..0688260fb9 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -1,3 +1,4 @@ +#lang lean4 /- Copyright (c) 2020 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. @@ -8,10 +9,7 @@ import Lean.Meta.Match.Match import Lean.Elab.SyntheticMVars import Lean.Elab.App -namespace Lean -namespace Elab -namespace Term - +namespace Lean.Elab.Term open Meta /- This modules assumes "match"-expressions use the following syntax. @@ -38,39 +36,36 @@ structure MatchAltView := (rhs : Syntax) def mkMatchAltView (ref : Syntax) (matchAlt : Syntax) : MatchAltView := -{ ref := ref, patterns := (matchAlt.getArg 0).getSepArgs, rhs := matchAlt.getArg 2 } +{ ref := ref, patterns := matchAlt[0].getSepArgs, rhs := matchAlt[2] } private def expandSimpleMatch (stx discr lhsVar rhs : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do -newStx ← `(let $lhsVar := $discr; $rhs); +let newStx ← `(let $lhsVar := $discr; $rhs) withMacroExpansion stx newStx $ elabTerm newStx expectedType? private def expandSimpleMatchWithType (stx discr lhsVar type rhs : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do -newStx ← `(let $lhsVar : $type := $discr; $rhs); +let newStx ← `(let $lhsVar : $type := $discr; $rhs) withMacroExpansion stx newStx $ elabTerm newStx expectedType? -private partial def elabDiscrsWitMatchTypeAux (discrStxs : Array Syntax) (expectedType : Expr) : Nat → Expr → Array Expr → TermElabM (Array Expr) -| i, matchType, discrs => - if h : i < discrStxs.size then do - let discrStx := (discrStxs.get ⟨i, h⟩).getArg 1; - matchType ← whnf matchType; - match matchType with - | Expr.forallE _ d b _ => do - discr ← fullApproxDefEq $ elabTermEnsuringType discrStx d; - trace `Elab.match fun _ => "discr #" ++ toString i ++ " " ++ discr ++ " : " ++ d; - elabDiscrsWitMatchTypeAux (i+1) (b.instantiate1 discr) (discrs.push discr) - | _ => throwError ("invalid type provided to match-expression, function type with arity #" ++ toString discrStxs ++ " expected") - else do - unlessM (fullApproxDefEq $ isDefEq matchType expectedType) $ - throwError ("invalid result type provided to match-expression" ++ indentExpr matchType ++ Format.line ++ "expected type" ++ indentExpr expectedType); - pure discrs - -private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) := -elabDiscrsWitMatchTypeAux discrStxs expectedType 0 matchType #[] +private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) := do +let discrs := #[] +let i := 0 +for discrStx in discrStxs do + i := i + 1 + matchType ← whnf matchType + match matchType with + | Expr.forallE _ d b _ => + let discr ← fullApproxDefEq $ elabTermEnsuringType discrStx d + trace[Elab.match]! "discr #{i} {discr} : {d}" + matchType ← b.instantiate1 discr + discrs := discrs.push discr + | _ => + throwError! "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected" +pure discrs private def mkUserNameFor (e : Expr) : TermElabM Name := match e with -| Expr.fvar fvarId _ => do localDecl ← getLocalDecl fvarId; pure localDecl.userName -| _ => mkFreshBinderName +| Expr.fvar fvarId _ => do pure (← getLocalDecl fvarId).userName +| _ => mkFreshBinderName -- `expandNonAtomicDiscrs?` create auxiliary variables with base name `_discr` private def isAuxDiscrName (n : Name) : Bool := @@ -78,51 +73,50 @@ n.eraseMacroScopes == `_discr -- See expandNonAtomicDiscrs? private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do -let term := discr.getArg 1; -local? ← isLocalIdent? term; -match local? with -| some e@(Expr.fvar fvarId _) => do - localDecl ← getLocalDecl fvarId; +let term := discr[1] +match (← isLocalIdent? term) with +| some e@(Expr.fvar fvarId _) => + let localDecl ← getLocalDecl fvarId if !isAuxDiscrName localDecl.userName then pure e -- it is not an auxiliary local created by `expandNonAtomicDiscrs?` else pure localDecl.value | _ => throwErrorAt discr "unexpected discriminant" -private def elabMatchTypeAndDiscrsAux (discrStxs : Array Syntax) : Nat → Array Expr → Expr → Array MatchAltView → TermElabM (Array Expr × Expr × Array MatchAltView) -| 0, discrs, matchType, matchAltViews => pure (discrs.reverse, matchType, matchAltViews) -| i+1, discrs, matchType, matchAltViews => do - let discrStx := discrStxs.get! i; - discr ← elabAtomicDiscr discrStx; - discr ← instantiateMVars discr; - discrType ← inferType discr; - discrType ← instantiateMVars discrType; - matchTypeBody ← kabstract matchType discr; - userName ← mkUserNameFor discr; - if (discrStx.getArg 0).isNone then do - elabMatchTypeAndDiscrsAux i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) matchAltViews - else - let identStx := (discrStx.getArg 0).getArg 0; - withLocalDeclD userName discrType fun x => do - eqType ← mkEq discr x; - withLocalDeclD identStx.getId eqType fun h => do - let matchTypeBody := matchTypeBody.instantiate1 x; - matchType ← mkForallFVars #[x, h] matchTypeBody; - refl ← mkEqRefl discr; - let discrs := (discrs.push refl).push discr; - let matchAltViews := matchAltViews.map fun altView => - { altView with patterns := altView.patterns.insertAt (i+1) identStx }; - elabMatchTypeAndDiscrsAux i discrs matchType matchAltViews - private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr) - : TermElabM (Array Expr × Expr × Array MatchAltView) := -let numDiscrs := discrStxs.size; -if matchOptType.isNone then do - elabMatchTypeAndDiscrsAux discrStxs discrStxs.size #[] expectedType matchAltViews -else do - let matchTypeStx := (matchOptType.getArg 0).getArg 1; - matchType ← elabType matchTypeStx; - discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType; + : TermElabM (Array Expr × Expr × Array MatchAltView) := do +let numDiscrs := discrStxs.size +if matchOptType.isNone then + let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (matchAltViews : Array MatchAltView) := do + match i with + | 0 => pure (discrs.reverse, matchType, matchAltViews) + | i+1 => + let discrStx := discrStxs[i] + let discr ← elabAtomicDiscr discrStx + let discr ← instantiateMVars discr + let discrType ← inferType discr + let discrType ← instantiateMVars discrType + let matchTypeBody ← kabstract matchType discr + let userName ← mkUserNameFor discr + if discrStx[0].isNone then + loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) matchAltViews + else + let identStx := discrStx[0][0] + withLocalDeclD userName discrType fun x => do + let eqType ← mkEq discr x + withLocalDeclD identStx.getId eqType fun h => do + let matchTypeBody := matchTypeBody.instantiate1 x + let matchType ← mkForallFVars #[x, h] matchTypeBody + let refl ← mkEqRefl discr + let discrs := (discrs.push refl).push discr + let matchAltViews := matchAltViews.map fun altView => + { altView with patterns := altView.patterns.insertAt (i+1) identStx } + loop i discrs matchType matchAltViews + loop discrStxs.size #[] expectedType matchAltViews +else + let matchTypeStx := matchOptType[0][1] + let matchType ← elabType matchTypeStx + let discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType pure (discrs, matchType, matchAltViews) /- @@ -130,27 +124,23 @@ nodeWithAntiquot "matchAlt" `Lean.Parser.Term.matchAlt $ sepBy1 termParser ", " -/ def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array MatchAltView) := do matchAlts.mapM fun matchAlt => do - patterns ← matchAlt.patterns.mapM $ expandMacros; - pure $ { matchAlt with patterns := patterns } - -private partial def getMatchAltsAux (args : Array Syntax) : Nat → Syntax → Array MatchAltView → Array MatchAltView -| i, ref, result => - if h : i < args.size then - let arg := args.get ⟨i, h⟩; - let ref := if ref.isNone then arg else ref; -- The first vertical is optional - if arg.getKind == `Lean.Parser.Term.matchAlt then - getMatchAltsAux (i+1) ref (result.push (mkMatchAltView ref arg)) - else - -- current `arg` is the vertical bar delimiter - getMatchAltsAux (i+1) arg result - else - result + let patterns ← matchAlt.patterns.mapM expandMacros + pure { matchAlt with patterns := patterns } /- Given `stx` a match-expression, return its alternatives. -/ -private def getMatchAlts (stx : Syntax) : Array MatchAltView := -let matchAlts := stx.getArg 4; -let firstVBar := matchAlts.getArg 0; -getMatchAltsAux (matchAlts.getArg 1).getArgs 0 firstVBar #[] +private def getMatchAlts (stx : Syntax) : Array MatchAltView := do +let matchAlts := stx[4] +let firstVBar := matchAlts[0] +let ref := firstVBar +let result := #[] +for arg in matchAlts[1].getArgs do + if ref.isNone then ref := arg -- The first vertical bar is optional + if arg.getKind == `Lean.Parser.Term.matchAlt then + result := result.push (mkMatchAltView ref arg) + else + ref := arg -- current `arg` is a vertical bar +pure result + /-- Auxiliary annotation used to mark terms marked with the "inaccessible" annotation `.(t)` and @@ -167,9 +157,9 @@ inductive PatternVar | anonymousVar (mvarId : MVarId) instance PatternVar.hasToString : HasToString PatternVar := -⟨fun v => match v with +⟨fun | PatternVar.localVar x => toString x - | PatternVar.anonymousVar mvarId => "?m" ++ toString mvarId⟩ + | PatternVar.anonymousVar mvarId => s!"?m{mvarId}"⟩ @[init] private def registerAuxiliaryNodeKind : IO Unit := Parser.registerBuiltinNodeKind `MVarWithIdKind @@ -179,12 +169,12 @@ Parser.registerBuiltinNodeKind `MVarWithIdKind We use this kind of Syntax for representing `_` occurring in patterns. The metavariables are created before we elaborate the patterns into `Expr`s. -/ private def mkMVarSyntax : TermElabM Syntax := do -mvarId ← mkFreshId; +let mvarId ← mkFreshId pure $ Syntax.node `MVarWithIdKind #[Syntax.node mvarId #[]] /-- Given a syntax node constructed using `mkMVarSyntax`, return its MVarId -/ private def getMVarSyntaxMVarId (stx : Syntax) : MVarId := -(stx.getArg 0).getKind +stx[0].getKind /-- The elaboration function for `Syntax` created using `mkMVarSyntax`. @@ -194,7 +184,7 @@ fun stx expectedType? => pure $ mkInaccessible $ mkMVar (getMVarSyntaxMVarId stx @[builtinTermElab inaccessible] def elabInaccessible : TermElab := fun stx expectedType? => do - e ← elabTerm (stx.getArg 1) expectedType?; + let e ← elabTerm stx[1] expectedType? pure $ mkInaccessible e /- @@ -230,22 +220,23 @@ private def throwCtorExpected {α} : M α := throwError "invalid pattern, constructor or constant marked with '[matchPattern]' expected" private def getNumExplicitCtorParams (ctorVal : ConstructorVal) : TermElabM Nat := -forallBoundedTelescope ctorVal.type ctorVal.nparams fun ps _ => - ps.foldlM - (fun acc p => do - localDecl ← getLocalDecl p.fvarId!; - if localDecl.binderInfo.isExplicit then pure $ acc+1 else pure acc) - 0 +forallBoundedTelescope ctorVal.type ctorVal.nparams fun ps _ => do + let result := 0 + for p in ps do + let localDecl ← getLocalDecl p.fvarId! + if localDecl.binderInfo.isExplicit then + result := result+1 + pure result private def throwAmbiguous {α} (fs : List Expr) : M α := -throwError ("ambiguous pattern, use fully qualified name, possible interpretations " ++ fs) +throwError! "ambiguous pattern, use fully qualified name, possible interpretations {fs}" def resolveId? (stx : Syntax) : M (Option Expr) := match stx with | Syntax.ident _ _ val preresolved => do - rs ← liftM $ catch (resolveName val preresolved []) (fun _ => pure []); - let rs := rs.filter fun ⟨f, projs⟩ => projs.isEmpty; - let fs := rs.map fun ⟨f, _⟩ => f; + let rs ← try resolveName val preresolved [] catch _ => pure [] + let rs := rs.filter fun ⟨f, projs⟩ => projs.isEmpty + let fs := rs.map fun (f, _) => f match fs with | [] => pure none | [f] => pure (some f) @@ -284,34 +275,34 @@ instance Context.inhabited : Inhabited Context := private def isDone (ctx : Context) : Bool := ctx.paramDeclIdx ≥ ctx.paramDecls.size -private def finalize (ctx : Context) : M Syntax := -if ctx.namedArgs.isEmpty && ctx.args.isEmpty then do - fStx ← `(@$(ctx.funId):ident); +private def finalize (ctx : Context) : M Syntax := do +if ctx.namedArgs.isEmpty && ctx.args.isEmpty then + let fStx ← `(@$(ctx.funId):ident) pure $ mkAppStx fStx ctx.newArgs else throwError "too many arguments" private def isNextArgAccessible (ctx : Context) : Bool := -let i := ctx.paramDeclIdx; +let i := ctx.paramDeclIdx match ctx.ctorVal? with | some ctorVal => i ≥ ctorVal.nparams -- For constructor applications only fields are accessible | none => if h : i < ctx.paramDecls.size then -- For `[matchPattern]` applications, only explicit parameters are accessible. - let d := ctx.paramDecls.get ⟨i, h⟩; + let d := ctx.paramDecls.get ⟨i, h⟩ d.binderInfo.isExplicit else false private def getNextParam (ctx : Context) : LocalDecl × Context := -let i := ctx.paramDeclIdx; -let d := ctx.paramDecls.get! i; +let i := ctx.paramDeclIdx +let d := ctx.paramDecls[i] (d, { ctx with paramDeclIdx := ctx.paramDeclIdx + 1 }) private def pushNewArg (collect : Syntax → M Syntax) (accessible : Bool) (ctx : Context) (arg : Arg) : M Context := match arg with | Arg.stx stx => do - stx ← if accessible then collect stx else pure stx; + let stx ← if accessible then collect stx else pure stx pure { ctx with newArgs := ctx.newArgs.push stx } | _ => unreachable! @@ -319,89 +310,89 @@ private def processExplicitArg (collect : Syntax → M Syntax) (accessible : Boo match ctx.args with | [] => if ctx.ellipsis then do - hole ← `(_); + let hole ← `(_) pushNewArg collect accessible ctx (Arg.stx hole) else - throwError ("explicit parameter is missing, unused named arguments " ++ toString (ctx.namedArgs.map $ fun narg => narg.name)) + throwError! "explicit parameter is missing, unused named arguments {ctx.namedArgs.map fun narg => narg.name}" | arg::args => do - let ctx := { ctx with args := args }; + let ctx := { ctx with args := args } pushNewArg collect accessible ctx arg private def processImplicitArg (collect : Syntax → M Syntax) (accessible : Bool) (ctx : Context) : M Context := if ctx.explicit then processExplicitArg collect accessible ctx else do - hole ← `(_); + let hole ← `(_) pushNewArg collect accessible ctx (Arg.stx hole) private partial def processCtorAppAux (collect : Syntax → M Syntax) : Context → M Syntax -| ctx => - if isDone ctx then finalize ctx +| ctx => do + if isDone ctx then + finalize ctx else - let accessible := isNextArgAccessible ctx; - let (d, ctx) := getNextParam ctx; - match ctx.namedArgs.findIdx? (fun namedArg => namedArg.name == d.userName) with - | some idx => do - let arg := ctx.namedArgs.get! idx; - let ctx := { ctx with namedArgs := ctx.namedArgs.eraseIdx idx }; - ctx ← pushNewArg collect accessible ctx arg.val; - processCtorAppAux ctx - | none => do - ctx ← match d.binderInfo with - | BinderInfo.implicit => processImplicitArg collect accessible ctx - | BinderInfo.instImplicit => processImplicitArg collect accessible ctx - | _ => processExplicitArg collect accessible ctx; - processCtorAppAux ctx + let accessible := isNextArgAccessible ctx + let (d, ctx) := getNextParam ctx + match ctx.namedArgs.findIdx? fun namedArg => namedArg.name == d.userName with + | some idx => + let arg := ctx.namedArgs[idx] + let ctx := { ctx with namedArgs := ctx.namedArgs.eraseIdx idx } + let ctx ← pushNewArg collect accessible ctx arg.val + processCtorAppAux collect ctx + | none => + let ctx ← match d.binderInfo with + | BinderInfo.implicit => processImplicitArg collect accessible ctx + | BinderInfo.instImplicit => processImplicitArg collect accessible ctx + | _ => processExplicitArg collect accessible ctx + processCtorAppAux collect ctx def processCtorApp (collect : Syntax → M Syntax) (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (ellipsis : Bool) : M Syntax := do -let args := args.toList; -(fId, explicit) ← match_syntax f with -| `($fId:ident) => pure (fId, false) -| `(@$fId:ident) => pure (fId, true) -| _ => throwError "identifier expected"; -some (Expr.const fName _ _) ← resolveId? fId | throwCtorExpected; -fInfo ← getConstInfo fName; +let args := args.toList +let (fId, explicit) ← match_syntax f with + | `($fId:ident) => pure (fId, false) + | `(@$fId:ident) => pure (fId, true) + | _ => throwError "identifier expected" +let some (Expr.const fName _ _) ← resolveId? fId | throwCtorExpected +let fInfo ← getConstInfo fName forallTelescopeReducing fInfo.type fun xs _ => do -paramDecls ← xs.mapM getFVarLocalDecl; -match fInfo with -| ConstantInfo.ctorInfo val => - processCtorAppAux collect - { funId := fId, explicit := explicit, ctorVal? := val, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis } -| _ => do - env ← getEnv; - if hasMatchPatternAttribute env fName then + let paramDecls ← xs.mapM getFVarLocalDecl + match fInfo with + | ConstantInfo.ctorInfo val => processCtorAppAux collect - { funId := fId, explicit := explicit, ctorVal? := none, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis } - else - throwCtorExpected + { funId := fId, explicit := explicit, ctorVal? := val, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis } + | _ => + let env ← getEnv + if hasMatchPatternAttribute env fName then + processCtorAppAux collect + { funId := fId, explicit := explicit, ctorVal? := none, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis } + else + throwCtorExpected end CtorApp def processCtorApp (collect : Syntax → M Syntax) (stx : Syntax) : M Syntax := do -(f, namedArgs, args, ellipsis) ← liftM $ expandApp stx true; +let (f, namedArgs, args, ellipsis) ← liftM $ expandApp stx true CtorApp.processCtorApp collect f namedArgs args ellipsis def processCtor (collect : Syntax → M Syntax) (stx : Syntax) : M Syntax := do CtorApp.processCtorApp collect stx #[] #[] false private def processVar (idStx : Syntax) : M Syntax := do -unless idStx.isIdent $ - throwErrorAt idStx "identifier expected"; -let id := idStx.getId; -unless id.eraseMacroScopes.isAtomic $ throwError "invalid pattern variable, must be atomic"; -s ← get; -when (s.found.contains id) $ throwError ("invalid pattern, variable '" ++ id ++ "' occurred more than once"); -modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := s.found.insert id }; +unless idStx.isIdent do + throwErrorAt idStx "identifier expected" +let id := idStx.getId +unless id.eraseMacroScopes.isAtomic do throwError "invalid pattern variable, must be atomic" +let s ← get +if s.found.contains id then throwError! "invalid pattern, variable '{id}' occurred more than once" +modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := s.found.insert id } pure idStx /- Check whether `stx` is a pattern variable or constructor-like (i.e., constructor or constant tagged with `[matchPattern]` attribute) -/ private def processId (collect : Syntax → M Syntax) (stx : Syntax) : M Syntax := do -env ← getEnv; -f? ← resolveId? stx; -match f? with +let env ← getEnv +match (← resolveId? stx) with | none => processVar stx | some f => match f with - | Expr.const fName _ _ => do + | Expr.const fName _ _ => match env.find? fName with | some (ConstantInfo.ctorInfo _) => processCtor collect stx | some _ => @@ -414,69 +405,68 @@ match f? with private def nameToPattern : Name → TermElabM Syntax | Name.anonymous => `(Name.anonymous) -| Name.str p s _ => do p ← nameToPattern p; `(Name.str $p $(quote s) _) -| Name.num p n _ => do p ← nameToPattern p; `(Name.num $p $(quote n) _) +| Name.str p s _ => do let p ← nameToPattern p; `(Name.str $p $(quote s) _) +| Name.num p n _ => do let p ← nameToPattern p; `(Name.num $p $(quote n) _) private def quotedNameToPattern (stx : Syntax) : TermElabM Syntax := -match (stx.getArg 0).isNameLit? with +match stx[0].isNameLit? with | some val => nameToPattern val | none => throwIllFormedSyntax partial def collect : Syntax → M Syntax -| stx@(Syntax.node k args) => withRef stx $ withFreshMacroScope $ - if k == `Lean.Parser.Term.app then do +| stx@(Syntax.node k args) => withRef stx $ withFreshMacroScope do + if k == `Lean.Parser.Term.app then processCtorApp collect stx - else if k == `Lean.Parser.Term.anonymousCtor then do - elems ← (args.get! 1).getArgs.mapSepElemsM $ collect; - pure $ Syntax.node k $ args.set! 1 $ mkNullNode elems - else if k == `Lean.Parser.Term.structInst then do + else if k == `Lean.Parser.Term.anonymousCtor then + let elems ← args[1].getArgs.mapSepElemsM collect + pure $ Syntax.node k (args.set! 1 $ mkNullNode elems) + else if k == `Lean.Parser.Term.structInst then /- { " >> optional (try (termParser >> " with ")) >> sepBy structInstField ", " true >> optional ".." >> optional (" : " >> termParser) >> " }" -/ - let withMod := args.get! 1; - unless withMod.isNone $ - throwErrorAt withMod "invalid struct instance pattern, 'with' is not allowed in patterns"; - let fields := (args.get! 2).getArgs; - fields ← fields.mapSepElemsM fun field => do { + let withMod := args[1] + unless withMod.isNone do + throwErrorAt withMod "invalid struct instance pattern, 'with' is not allowed in patterns" + let fields := args[2].getArgs + let fields ← fields.mapSepElemsM fun field => do -- parser! structInstLVal >> " := " >> termParser - newVal ← collect (field.getArg 3); -- `structInstLVal` has arity 2 + let newVal ← collect field[3] -- `structInstLVal` has arity 2 pure $ field.setArg 3 newVal - }; - pure $ Syntax.node k $ args.set! 2 $ mkNullNode fields - else if k == `Lean.Parser.Term.hole then do - r ← liftM mkMVarSyntax; - modify fun s => { s with vars := s.vars.push $ PatternVar.anonymousVar $ getMVarSyntaxMVarId r }; + pure $ Syntax.node k (args.set! 2 $ mkNullNode fields) + else if k == `Lean.Parser.Term.hole then + let r ← mkMVarSyntax + modify fun s => { s with vars := s.vars.push $ PatternVar.anonymousVar $ getMVarSyntaxMVarId r } pure r else if k == `Lean.Parser.Term.paren then - let arg := args.get! 1; + let arg := args[1] if arg.isNone then pure stx -- `()` - else do - let t := arg.getArg 0; - let s := arg.getArg 1; - if s.isNone || (s.getArg 0).isOfKind `Lean.Parser.Term.typeAscription then do + else + let t := arg[0] + let s := arg[1] + if s.isNone || s[0].getKind == `Lean.Parser.Term.typeAscription then -- Ignore `s`, since it empty or it is a type ascription - t ← collect t; - let arg := arg.setArg 0 t; - pure $ Syntax.node k $ args.set! 1 arg - else do + let t ← collect t + let arg := arg.setArg 0 t + pure $ Syntax.node k (args.set! 1 arg) + else -- Tuple literal is a constructor - t ← collect t; - let arg := arg.setArg 0 t; - let tupleTail := s.getArg 0; - let tupleTailElems := (tupleTail.getArg 1).getArgs; - tupleTailElems ← tupleTailElems.mapSepElemsM collect; - let tupleTail := tupleTail.setArg 1 $ mkNullNode tupleTailElems; - let s := s.setArg 0 tupleTail; - let arg := arg.setArg 1 s; - pure $ Syntax.node k $ args.set! 1 arg - else if k == `Lean.Parser.Term.explicitUniv then do - processCtor collect (stx.getArg 0) - else if k == `Lean.Parser.Term.namedPattern then do + let t ← collect t + let arg := arg.setArg 0 t + let tupleTail := s[0] + let tupleTailElems := tupleTail[1].getArgs + let tupleTailElems ← tupleTailElems.mapSepElemsM collect + let tupleTail := tupleTail.setArg 1 $ mkNullNode tupleTailElems + let s := s.setArg 0 tupleTail + let arg := arg.setArg 1 s + pure $ Syntax.node k (args.set! 1 arg) + else if k == `Lean.Parser.Term.explicitUniv then + processCtor collect stx[0] + else if k == `Lean.Parser.Term.namedPattern then /- Recall that def namedPattern := check... >> tparser! "@" >> termParser -/ - let id := stx.getArg 0; - processVar id; - let pat := stx.getArg 2; - pat ← collect pat; + let id := stx[0] + processVar id + let pat := stx[2] + pat ← collect pat `(namedPattern $id $pat) else if k == `Lean.Parser.Term.inaccessible then pure stx @@ -490,7 +480,7 @@ partial def collect : Syntax → M Syntax /- Quoted names have an elaboration function associated with them, and they will not be macro expanded. Note that macro expansion is not a good option since it produces a term using the smart constructors `mkNameStr`, `mkNameNum` instead of the constructors `Name.str` and `Name.num` -/ - liftM $ quotedNameToPattern stx + quotedNameToPattern stx else if k == choiceKind then throwError "invalid pattern, notation is ambiguous" else @@ -501,27 +491,30 @@ partial def collect : Syntax → M Syntax throwInvalidPattern def main (alt : MatchAltView) : M MatchAltView := do -patterns ← alt.patterns.mapM fun p => do { - trace `Elab.match fun _ => "collecting variables at pattern: " ++ p; +let patterns ← alt.patterns.mapM fun p => do + trace[Elab.match]! "collecting variables at pattern: {p}" collect p -}; pure { alt with patterns := patterns } end CollectPatternVars private def collectPatternVars (alt : MatchAltView) : TermElabM (Array PatternVar × MatchAltView) := do -(alt, s) ← (CollectPatternVars.main alt).run {}; +let (alt, s) ← (CollectPatternVars.main alt).run {} pure (s.vars, alt) /- Return the pattern variables in the given pattern. Remark: this method is not used here, but in other macros (e.g., at `Do.lean`). -/ def getPatternVars (patternStx : Syntax) : TermElabM (Array PatternVar) := do -patternStx ← liftMacroM $ expandMacros patternStx; -(_, s) ← (CollectPatternVars.collect patternStx).run {}; +let patternStx ← liftMacroM $ expandMacros patternStx +let (_, s) ← (CollectPatternVars.collect patternStx).run {} pure s.vars def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := do -(_, s) ← (patterns.mapM fun pattern => do { pattern ← liftMacroM $ expandMacros pattern; CollectPatternVars.collect pattern }).run {}; +let collect : CollectPatternVars.M Unit := do + for pattern in patterns do + CollectPatternVars.collect (← liftMacroM $ expandMacros pattern) + pure () +let (_, s) ← collect.run {} pure s.vars /- We convert the collected `PatternVar`s intro `PatternVarDecl` -/ @@ -531,69 +524,65 @@ inductive PatternVarDecl | anonymousVar (mvarId : MVarId) (fvarId : FVarId) | localVar (fvarId : FVarId) -private partial def withPatternVarsAux {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) - : Nat → Array PatternVarDecl → TermElabM α -| i, decls => +private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α := +let rec loop (i : Nat) (decls : Array PatternVarDecl) := do if h : i < pVars.size then match pVars.get ⟨i, h⟩ with - | PatternVar.anonymousVar mvarId => do - type ← mkFreshTypeMVar; - userName ← mkFreshBinderName; + | PatternVar.anonymousVar mvarId => + let type ← mkFreshTypeMVar + let userName ← mkFreshBinderName withLocalDecl userName BinderInfo.default type fun x => - withPatternVarsAux (i+1) (decls.push (PatternVarDecl.anonymousVar mvarId x.fvarId!)) - | PatternVar.localVar userName => do - type ← mkFreshTypeMVar; + loop (i+1) (decls.push (PatternVarDecl.anonymousVar mvarId x.fvarId!)) + | PatternVar.localVar userName => + let type ← mkFreshTypeMVar withLocalDecl userName BinderInfo.default type fun x => - withPatternVarsAux (i+1) (decls.push (PatternVarDecl.localVar x.fvarId!)) - else do + loop (i+1) (decls.push (PatternVarDecl.localVar x.fvarId!)) + else /- We must create the metavariables for `PatternVar.anonymousVar` AFTER we create the new local decls using `withLocalDecl`. Reason: their scope must include the new local decls since some of them are assigned by typing constraints. -/ decls.forM fun decl => match decl with | PatternVarDecl.anonymousVar mvarId fvarId => do - type ← inferType (mkFVar fvarId); - _ ← mkFreshExprMVarWithId mvarId type; + let type ← inferType (mkFVar fvarId) + mkFreshExprMVarWithId mvarId type pure () - | _ => pure (); + | _ => pure () k decls +loop 0 #[] -private def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α := -withPatternVarsAux pVars k 0 #[] +private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Array Expr × Expr) := do +let patterns := #[] +for patternStx in patternStxs do + matchType ← whnf matchType + match matchType with + | Expr.forallE _ d b _ => + let pattern ← elabTermEnsuringType patternStx d + matchType := b.instantiate1 pattern + patterns := patterns.push pattern + | _ => throwError "unexpected match type" +pure (patterns, matchType) -private partial def elabPatternsAux (patternStxs : Array Syntax) : Nat → Expr → Array Expr → TermElabM (Array Expr × Expr) -| i, matchType, patterns => - if h : i < patternStxs.size then do - matchType ← whnf matchType; - match matchType with - | Expr.forallE _ d b _ => do - let patternStx := patternStxs.get ⟨i, h⟩; - pattern ← elabTermEnsuringType patternStx d; - elabPatternsAux (i+1) (b.instantiate1 pattern) (patterns.push pattern) - | _ => throwError "unexpected match type" - else - pure (patterns, matchType) - -def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := -patternVarDecls.foldlM - (fun (decls : Array LocalDecl) pdecl => do - match pdecl with - | PatternVarDecl.localVar fvarId => do - decl ← getLocalDecl fvarId; - decl ← instantiateLocalDeclMVars decl; - pure $ decls.push decl - | PatternVarDecl.anonymousVar mvarId fvarId => do - e ← instantiateMVars (mkMVar mvarId); - trace `Elab.match fun _ => "finalizePatternDecls: mvarId: " ++ mvarId ++ " := " ++ e ++ ", fvarId: " ++ mkFVar fvarId; - match e with - | Expr.mvar newMVarId _ => do - /- Metavariable was not assigned, or assigned to another metavariable. So, - we assign to the auxiliary free variable we created at `withPatternVars` to `newMVarId`. -/ - assignExprMVar newMVarId (mkFVar fvarId); - trace `Elab.match fun _ => "finalizePatternDecls: " ++ mkMVar newMVarId ++ " := " ++ mkFVar fvarId; - decl ← getLocalDecl fvarId; - decl ← instantiateLocalDeclMVars decl; - pure $ decls.push decl - | _ => pure decls) - #[] +def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := do +let decls := #[] +for pdecl in patternVarDecls do + match pdecl with + | PatternVarDecl.localVar fvarId => + let decl ← getLocalDecl fvarId + let decl ← instantiateLocalDeclMVars decl + decls := decls.push decl + | PatternVarDecl.anonymousVar mvarId fvarId => + let e ← instantiateMVars (mkMVar mvarId); + trace[Elab.match]! "finalizePatternDecls: mvarId: {mvarId} := {e}, fvar: {mkFVar fvarId}" + match e with + | Expr.mvar newMVarId _ => + /- Metavariable was not assigned, or assigned to another metavariable. So, + we assign to the auxiliary free variable we created at `withPatternVars` to `newMVarId`. -/ + assignExprMVar newMVarId (mkFVar fvarId) + trace[Elab.match]! "finalizePatternDecls: {mkMVar newMVarId} := {mkFVar fvarId}" + let decl ← getLocalDecl fvarId + let decl ← instantiateLocalDeclMVars decl + decls := decls.push decl + | _ => pure () +pure decls open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS MatcherResult) @@ -607,121 +596,121 @@ structure State := abbrev M := StateRefT State TermElabM private def alreadyVisited (fvarId : FVarId) : M Bool := do -s ← get; +let s ← get pure $ s.found.contains fvarId private def markAsVisited (fvarId : FVarId) : M Unit := -modify $ fun s => { s with found := s.found.insert fvarId } +modify fun s => { s with found := s.found.insert fvarId } private def throwInvalidPattern {α} (e : Expr) : M α := -throwError ("invalid pattern " ++ indentExpr e) +throwError! "invalid pattern {indentExpr e}" /- Create a new LocalDecl `x` for the metavariable `mvar`, and return `Pattern.var x` -/ private def mkLocalDeclFor (mvar : Expr) : M Pattern := do -let mvarId := mvar.mvarId!; -s ← get; -val? ← getExprMVarAssignment? mvarId; -match val? with +let mvarId := mvar.mvarId! +let s ← get +match (← getExprMVarAssignment? mvarId) with | some val => pure $ Pattern.inaccessible val -| none => do - fvarId ← mkFreshId; - type ← inferType mvar; +| none => + let fvarId ← mkFreshId + let type ← inferType mvar /- HACK: `fvarId` is not in the scope of `mvarId` If this generates problems in the future, we should update the metavariable declarations. -/ - assignExprMVar mvarId (mkFVar fvarId); - userName ← liftM $ mkFreshBinderName; + assignExprMVar mvarId (mkFVar fvarId) + let userName ← liftM $ mkFreshBinderName let newDecl := LocalDecl.cdecl (arbitrary _) fvarId userName type BinderInfo.default; - modify $ fun s => + modify fun s => { s with newLocals := s.newLocals.insert fvarId, localDecls := match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with | none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar` - | some i => s.localDecls.insertAt i newDecl }; + | some i => s.localDecls.insertAt i newDecl } pure $ Pattern.var fvarId partial def main : Expr → M Pattern -| e => - let isLocalDecl (fvarId : FVarId) : M Bool := do { - s ← get; +| e => do + let isLocalDecl (fvarId : FVarId) : M Bool := do + let s ← get pure $ s.localDecls.any fun d => d.fvarId == fvarId - }; - let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do { - condM (alreadyVisited fvarId) - (pure $ Pattern.inaccessible e) - (do markAsVisited fvarId; pure $ Pattern.var e.fvarId!) - }; - let mkInaccessible (e : Expr) : M Pattern := do { + let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do + if (← alreadyVisited fvarId) then + pure $ Pattern.inaccessible e + else + markAsVisited fvarId + pure $ Pattern.var e.fvarId! + let mkInaccessible (e : Expr) : M Pattern := do match e with | Expr.fvar fvarId _ => - condM (isLocalDecl fvarId) - (mkPatternVar fvarId e) - (pure $ Pattern.inaccessible e) + if (← isLocalDecl fvarId) then + mkPatternVar fvarId e + else + pure $ Pattern.inaccessible e | _ => pure $ Pattern.inaccessible e - }; match inaccessible? e with | some t => mkInaccessible t | none => match e.arrayLit? with - | some (α, lits) => do - ps ← lits.mapM main; + | some (α, lits) => + let ps ← lits.mapM main; pure $ Pattern.arrayLit α ps | none => - if e.isAppOfArity `namedPattern 3 then do - p ← main $ e.getArg! 2; + if e.isAppOfArity `namedPattern 3 then + let p ← main $ e.getArg! 2; match e.getArg! 1 with | Expr.fvar fvarId _ => pure $ Pattern.as fvarId p | _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'" else if e.isNatLit || e.isStringLit || e.isCharLit then pure $ Pattern.val e - else if e.isFVar then do - let fvarId := e.fvarId!; - unlessM (isLocalDecl fvarId) $ throwInvalidPattern e; + else if e.isFVar then + let fvarId := e.fvarId! + unless(← isLocalDecl fvarId) do throwInvalidPattern e mkPatternVar fvarId e - else if e.isMVar then do + else if e.isMVar then mkLocalDeclFor e - else do - newE ← whnf e; + else + let newE ← whnf e if newE != e then main newE else matchConstCtor e.getAppFn (fun _ => throwInvalidPattern e) fun v us => do - let args := e.getAppArgs; - unless (args.size == v.nparams + v.nfields) $ throwInvalidPattern e; - let params := args.extract 0 v.nparams; - let fields := args.extract v.nparams args.size; - fields ← fields.mapM main; + let args := e.getAppArgs + unless args.size == v.nparams + v.nfields do + throwInvalidPattern e + let params := args.extract 0 v.nparams + let fields := args.extract v.nparams args.size + let fields ← fields.mapM main pure $ Pattern.ctor v.name us params.toList fields.toList end ToDepElimPattern def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k : Array LocalDecl → Array Pattern → TermElabM α) : TermElabM α := do -(patterns, s) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls }; -localDecls ← s.localDecls.mapM fun d => instantiateLocalDeclMVars d; +let (patterns, s) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls } +let localDecls ← s.localDecls.mapM fun d => instantiateLocalDeclMVars d /- toDepElimPatterns may have added new localDecls. Thus, we must update the local context before we execute `k` -/ -lctx ← getLCtx; -let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.erase d.fvarId) lctx; -let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx; +let lctx ← getLCtx +let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.erase d.fvarId) lctx +let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx withTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) $ k localDecls patterns private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) (k : AltLHS → Expr → TermElabM α) : TermElabM α := do -(patterns, matchType) ← withSynthesize $ elabPatternsAux patternStxs 0 matchType #[]; -localDecls ← finalizePatternDecls patternVarDecls; -patterns ← patterns.mapM instantiateMVars; +let (patterns, matchType) ← withSynthesize $ elabPatterns patternStxs matchType +let localDecls ← finalizePatternDecls patternVarDecls +let patterns ← patterns.mapM instantiateMVars withDepElimPatterns localDecls patterns fun localDecls patterns => k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) := withRef alt.ref do -(patternVars, alt) ← collectPatternVars alt; -trace `Elab.match fun _ => "patternVars: " ++ toString patternVars; +let (patternVars, alt) ← collectPatternVars alt +trace[Elab.match]! "patternVars: {patternVars}" withPatternVars patternVars fun patternVarDecls => do withElaboratedLHS alt.ref patternVarDecls alt.patterns matchType fun altLHS matchType => do - rhs ← elabTermEnsuringType alt.rhs matchType; - let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr; - rhs ← if xs.isEmpty then pure $ mkSimpleThunk rhs else mkLambdaFVars xs rhs; - trace `Elab.match fun _ => "rhs: " ++ rhs; + let rhs ← elabTermEnsuringType alt.rhs matchType + let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr + let rhs ← if xs.isEmpty then pure $ mkSimpleThunk rhs else mkLambdaFVars xs rhs + trace[Elab.match]! "rhs: {rhs}" pure (altLHS, rhs) def mkMatcher (elimName : Name) (matchType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult := @@ -729,87 +718,85 @@ liftMetaM $ Meta.Match.mkMatcher elimName matchType numDiscrs lhss def reportMatcherResultErrors (result : MatcherResult) : TermElabM Unit := do -- TODO: improve error messages -unless result.counterExamples.isEmpty $ - throwError ("missing cases:" ++ Format.line ++ Meta.Match.counterExamplesToMessageData result.counterExamples); -unless result.unusedAltIdxs.isEmpty $ - throwError ("unused alternatives: " ++ toString (result.unusedAltIdxs.map fun idx => "#" ++ toString (idx+1))) +unless result.counterExamples.isEmpty do + throwError! "missing cases:\n{Meta.Match.counterExamplesToMessageData result.counterExamples}" +unless result.unusedAltIdxs.isEmpty do + throwError! "unused alternatives: {result.unusedAltIdxs.map fun idx => s!"#{idx+1}"}" private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr) : TermElabM Expr := do -(discrs, matchType, altViews) ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType; -matchAlts ← liftMacroM $ expandMacrosInPatterns altViews; -trace `Elab.match fun _ => "matchType: " ++ matchType; -alts ← matchAlts.mapM $ fun alt => elabMatchAltView alt matchType; -synthesizeSyntheticMVarsNoPostponing; +let (discrs, matchType, altViews) ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType +let matchAlts ← liftMacroM $ expandMacrosInPatterns altViews +trace[Elab.match]! "matchType: {matchType}" +let alts ← matchAlts.mapM $ fun alt => elabMatchAltView alt matchType +synthesizeSyntheticMVarsNoPostponing -- TODO report error if matchType or altLHSS.toList have metavars -let rhss := alts.map Prod.snd; -let altLHSS := alts.map Prod.fst; -let numDiscrs := discrs.size; -matcherName ← mkAuxName `match; -matcherResult ← mkMatcher matcherName matchType numDiscrs altLHSS.toList; -motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType; -reportMatcherResultErrors matcherResult; -let r := mkApp matcherResult.matcher motive; -let r := mkAppN r discrs; -let r := mkAppN r rhss; -trace `Elab.match fun _ => "result: " ++ r; +let rhss := alts.map Prod.snd +let altLHSS := alts.map Prod.fst +let numDiscrs := discrs.size +let matcherName ← mkAuxName `match +let matcherResult ← mkMatcher matcherName matchType numDiscrs altLHSS.toList +let motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType +reportMatcherResultErrors matcherResult +let r := mkApp matcherResult.matcher motive +let r := mkAppN r discrs +let r := mkAppN r rhss +trace[Elab.match]! "result: {r}" pure r private def getDiscrs (matchStx : Syntax) : Array Syntax := -(matchStx.getArg 1).getSepArgs +matchStx[1].getSepArgs private def getMatchOptType (matchStx : Syntax) : Syntax := -matchStx.getArg 2 - -private def expandNonAtomicDiscrsAux (matchStx : Syntax) : List Syntax → Array Syntax → TermElabM Syntax -| [], discrsNew => - let discrs := mkSepStx discrsNew (mkAtomFrom matchStx ", "); - pure $ matchStx.setArg 1 discrs -| discr :: discrs, discrsNew => do - -- Recall that - -- matchDiscr := parser! optional (ident >> ":") >> termParser - let term := discr.getArg 1; - local? ← isLocalIdent? term; - match local? with - | some _ => expandNonAtomicDiscrsAux discrs (discrsNew.push discr) - | none => withFreshMacroScope do - d ← `(_discr); - unless (isAuxDiscrName d.getId) $ -- Use assertion? - throwError "unexpected internal auxiliary discriminant name"; - let discrNew := discr.setArg 1 d; - r ← expandNonAtomicDiscrsAux discrs (discrsNew.push discrNew); - `(let _discr := $term; $r) +matchStx[2] private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) := let matchOptType := getMatchOptType matchStx; if matchOptType.isNone then do let discrs := getDiscrs matchStx; - allLocal ← discrs.allM fun discr => Option.isSome <$> isLocalIdent? (discr.getArg 1); + let allLocal ← discrs.allM fun discr => Option.isSome <$> isLocalIdent? discr[1] if allLocal then pure none else - some <$> expandNonAtomicDiscrsAux matchStx discrs.toList #[] + let rec loop (discrs : List Syntax) (discrsNew : Array Syntax) := do + match discrs with + | [] => + let discrs := mkSepStx discrsNew (mkAtomFrom matchStx ", "); + pure (matchStx.setArg 1 discrs) + | discr :: discrs => + -- Recall that + -- matchDiscr := parser! optional (ident >> ":") >> termParser + let term := discr[1] + match (← isLocalIdent? term) with + | some _ => loop discrs (discrsNew.push discr) + | none => withFreshMacroScope do + let d ← `(_discr); + unless isAuxDiscrName d.getId do -- Use assertion? + throwError "unexpected internal auxiliary discriminant name" + let discrNew := discr.setArg 1 d; + let r ← loop discrs (discrsNew.push discrNew) + `(let _discr := $term; $r) + pure (some (← loop discrs.toList #[])) else -- We do not pull non atomic discriminants when match type is provided explicitly by the user pure none private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := do -tryPostponeIfNoneOrMVar expectedType?; +tryPostponeIfNoneOrMVar expectedType? match expectedType? with | some expectedType => pure expectedType | none => mkFreshTypeMVar -private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := +private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := do -- We don't wait for the discriminants types when match type is provided by user -when (getMatchOptType matchStx).isNone do - let discrs := getDiscrs matchStx; - discrs.forM fun discr => do - let term := discr.getArg 1; - local? ← isLocalIdent? term; - match local? with +if getMatchOptType matchStx $.isNone then + let discrs := getDiscrs matchStx + for discr in discrs do + let term := discr[1] + match (← isLocalIdent? term) with | none => throwErrorAt discr "unexpected discriminant" -- see `expandNonAtomicDiscrs? - | some d => do - dType ← inferType d; + | some d => + let dType ← inferType d tryPostponeIfMVar dType /- @@ -843,8 +830,8 @@ List.filter (fun p => match p with | (a, b) => a > b) xs When we visit `match p with | (a, b) => a > b`, we don't know the type of `p` yet. -/ private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do -tryPostponeIfNoneOrMVar expectedType?; -tryPostponeIfDiscrTypeIsMVar matchStx; +tryPostponeIfNoneOrMVar expectedType? +tryPostponeIfDiscrTypeIsMVar matchStx match expectedType? with | some expectedType => pure expectedType | none => mkFreshTypeMVar @@ -856,10 +843,10 @@ parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> ma Remark the `optIdent` must be `none` at `matchDiscr`. They are expanded by `expandMatchDiscr?`. -/ private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do -expectedType ← waitExpectedTypeAndDiscrs stx expectedType?; -let discrStxs := (getDiscrs stx).map fun d => d; -let altViews := getMatchAlts stx; -let matchOptType := getMatchOptType stx; +let expectedType ← waitExpectedTypeAndDiscrs stx expectedType? +let discrStxs := (getDiscrs stx).map fun d => d +let altViews := getMatchAlts stx +let matchOptType := getMatchOptType stx elabMatchAux discrStxs altViews matchOptType expectedType -- parser! "match " >> sepBy1 termParser ", " >> optType >> " with " >> matchAlts @@ -870,14 +857,13 @@ fun stx expectedType? => match_syntax stx with | `(match $discr:term : $type with $y:ident => $rhs:term) => expandSimpleMatchWithType stx discr y type rhs expectedType? | `(match $discr:term : $type with | $y:ident => $rhs:term) => expandSimpleMatchWithType stx discr y type rhs expectedType? | _ => do - stxNew? ← expandNonAtomicDiscrs? stx; - match stxNew? with + match (← expandNonAtomicDiscrs? stx) with | some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? - | none => do + | none => let discrs := getDiscrs stx; let matchOptType := getMatchOptType stx; - when (!matchOptType.isNone && discrs.any fun d => !(d.getArg 0).isNone) $ - throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used"; + if !matchOptType.isNone && discrs.any fun d => !d[0].isNone then + throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used" elabMatchCore stx expectedType? @[init] private def regTraceClasses : IO Unit := do @@ -888,8 +874,8 @@ pure () @[builtinTermElab «nomatch»] def elabNoMatch : TermElab := fun stx expectedType? => match_syntax stx with | `(nomatch $discrExpr) => do - expectedType ← waitExpectedType expectedType?; - let discr := Syntax.node `Lean.Parser.Term.matchDiscr #[mkNullNode, discrExpr]; + let expectedType ← waitExpectedType expectedType? + let discr := Syntax.node `Lean.Parser.Term.matchDiscr #[mkNullNode, discrExpr] elabMatchAux #[discr] #[] mkNullNode expectedType | _ => throwUnsupportedSyntax