diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index f3b2680ba5..30c4a06a60 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -210,7 +210,7 @@ private def varsToMessageData (vars : Array Var) : MessageData := MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.getId.simpMacroScopes)) " " partial def CodeBlocl.toMessageData (codeBlock : CodeBlock) : MessageData := - let us := MessageData.ofList $ (varSetToArray codeBlock.uvars).toList.map MessageData.ofSyntax + let us := MessageData.ofList <| (varSetToArray codeBlock.uvars).toList.map MessageData.ofSyntax let rec loop : Code → MessageData | Code.decl xs _ k => m!"let {varsToMessageData xs} := ...\n{loop k}" | Code.reassign xs _ k => m!"{varsToMessageData xs} := ...\n{loop k}" @@ -272,7 +272,7 @@ def mkAuxDeclFor {m} [Monad m] [MonadQuotation m] (e : Syntax) (mkCont : Syntax -- Add elaboration hint for producing sane error message let y ← `(ensure_expected_type% "type mismatch, result value" $y) let k ← mkCont y - pure $ Code.decl #[y] doElem k + return Code.decl #[y] doElem k /- Convert `action _ e` instructions in `c` into `let y ← e; jmp _ jp (xs y)`. -/ partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array Var) : MacroM Code := @@ -350,9 +350,9 @@ def mkJmp (ref : Syntax) (rs : VarSet) (val : Syntax) (mkJPBody : Syntax → Mac let yFresh ← withRef ref `(y) let ps := xs.map fun x => (x, true) let ps := ps.push (yFresh, false) - let jpBody ← liftMacroM $ mkJPBody yFresh + let jpBody ← liftMacroM <| mkJPBody yFresh let jp ← addFreshJP ps jpBody - pure $ Code.jmp ref jp args + return Code.jmp ref jp args /- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/ partial def pullExitPointsAux : VarSet → Code → StateRefT (Array JPDecl) TermElabM Code @@ -365,12 +365,12 @@ partial def pullExitPointsAux : VarSet → Code → StateRefT (Array JPDecl) Ter | rs, c@(Code.jmp _ _ _) => return c | rs, Code.«break» ref => mkSimpleJmp ref rs (Code.«break» ref) | rs, Code.«continue» ref => mkSimpleJmp ref rs (Code.«continue» ref) - | rs, Code.«return» ref val => mkJmp ref rs val (fun y => pure $ Code.«return» ref y) + | rs, Code.«return» ref val => mkJmp ref rs val (fun y => return Code.«return» ref y) | rs, Code.action e => -- We use `mkAuxDeclFor` because `e` is not pure. mkAuxDeclFor e fun y => let ref := e - mkJmp ref rs y (fun yFresh => do pure $ Code.action (← ``(Pure.pure $yFresh))) + mkJmp ref rs y (fun yFresh => return Code.action (← ``(Pure.pure $yFresh))) /- Auxiliary operation for adding new variables to the collection of updated variables in a CodeBlock. @@ -424,9 +424,9 @@ We implement the method as follows. Let `us` be `c.uvars`, then def pullExitPoints (c : Code) : TermElabM Code := do if hasExitPoint c then let (c, jpDecls) ← (pullExitPointsAux {} c).run #[] - pure $ attachJPs jpDecls c + return attachJPs jpDecls c else - pure c + return c partial def extendUpdatedVarsAux (c : Code) (ws : VarSet) : TermElabM Code := let rec update : Code → TermElabM Code @@ -563,8 +563,8 @@ def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Var) (k : CodeBlo let ps := ps.push (y, false) let jpDecl ← mkFreshJP ps k.code let jp := jpDecl.name - let terminal ← liftMacroM $ convertTerminalActionIntoJmp terminal.code jp xs - pure { code := attachJP jpDecl terminal, uvars := k.uvars } + let terminal ← liftMacroM <| convertTerminalActionIntoJmp terminal.code jp xs + return { code := attachJP jpDecl terminal, uvars := k.uvars } def getLetIdDeclVar (letIdDecl : Syntax) : Var := letIdDecl[0] @@ -588,11 +588,11 @@ def getLetEqnsDeclVar (letEqnsDecl : Syntax) : Var := def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Var) := do let arg := letDecl[0] if arg.getKind == ``Lean.Parser.Term.letIdDecl then - pure #[getLetIdDeclVar arg] + return #[getLetIdDeclVar arg] else if arg.getKind == ``Lean.Parser.Term.letPatDecl then getLetPatDeclVars arg else if arg.getKind == ``Lean.Parser.Term.letEqnsDecl then - pure #[getLetEqnsDeclVar arg] + return #[getLetEqnsDeclVar arg] else throwError "unexpected kind of let declaration" @@ -613,12 +613,12 @@ def getDoHaveVars (doHave : Syntax) : TermElabM (Array Var) := do if arg.getKind == ``Lean.Parser.Term.haveIdDecl then -- haveIdDecl := leading_parser atomic (haveIdLhs >> " := ") >> termParser -- haveIdLhs := optional (ident >> many (ppSpace >> (simpleBinderWithoutType <|> bracketedBinder))) >> optType - pure #[← getHaveIdLhsVar arg[0]] + return #[← getHaveIdLhsVar arg[0]] else if arg.getKind == ``Lean.Parser.Term.letPatDecl then getLetPatDeclVars arg else if arg.getKind == ``Lean.Parser.Term.haveEqnsDecl then -- haveEqnsDecl := leading_parser haveIdLhs >> matchAlts - pure #[← getHaveIdLhsVar arg[0]] + return #[← getHaveIdLhsVar arg[0]] else throwError "unexpected kind of have declaration" @@ -630,7 +630,7 @@ def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Var) := do for letDecl in letDecls do let vars ← getLetDeclVars letDecl allVars := allVars ++ vars - pure allVars + return allVars -- ident >> optType >> leftArrow >> termParser def getDoIdDeclVar (doIdDecl : Syntax) : Var := @@ -645,7 +645,7 @@ def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Var) := do def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Var) := do let decl := doLetArrow[2] if decl.getKind == ``Lean.Parser.Term.doIdDecl then - pure #[getDoIdDeclVar decl] + return #[getDoIdDeclVar decl] else if decl.getKind == ``Lean.Parser.Term.doPatDecl then getDoPatDeclVars decl else @@ -654,14 +654,14 @@ def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Var) := do def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Var) := do let arg := doReassign[0] if arg.getKind == ``Lean.Parser.Term.letIdDecl then - pure #[getLetIdDeclVar arg] + return #[getLetIdDeclVar arg] else if arg.getKind == ``Lean.Parser.Term.letPatDecl then getLetPatDeclVars arg else throwError "unexpected kind of reassignment" def mkDoSeq (doElems : Array Syntax) : Syntax := - mkNode `Lean.Parser.Term.doSeqIndent #[mkNullNode $ doElems.map fun doElem => mkNullNode #[doElem, mkNullNode]] + mkNode `Lean.Parser.Term.doSeqIndent #[mkNullNode <| doElems.map fun doElem => mkNullNode #[doElem, mkNullNode]] def mkSingletonDoSeq (doElem : Syntax) : Syntax := mkDoSeq #[doElem] @@ -714,11 +714,10 @@ private def mkTuple (elems : Array Syntax) : MacroM Syntax := do if elems.size == 0 then mkUnit else if elems.size == 1 then - pure elems[0] + return elems[0] else - (elems.extract 0 (elems.size - 1)).foldrM - (fun elem tuple => ``(MProd.mk $elem $tuple)) - (elems.back) + elems.extract 0 (elems.size - 1) |>.foldrM (init := elems.back) fun elem tuple => + ``(MProd.mk $elem $tuple) /- Return `some action` if `doElem` is a `doExpr `-/ def isDoExpr? (doElem : Syntax) : Option Syntax := @@ -938,7 +937,7 @@ def declToTerm (decl : Syntax) (k : Syntax) : M Syntax := withRef decl <| withFr else if kind == ``Lean.Parser.Term.doLetRec then let letRecToken := decl[0] let letRecDecls := decl[1] - pure $ mkNode ``Lean.Parser.Term.letrec #[letRecToken, letRecDecls, mkNullNode, k] + return mkNode ``Lean.Parser.Term.letrec #[letRecToken, letRecDecls, mkNullNode, k] else if kind == ``Lean.Parser.Term.doLetArrow then let arg := decl[2] let ref := arg @@ -958,7 +957,7 @@ def declToTerm (decl : Syntax) (k : Syntax) : M Syntax := withRef decl <| withFr -- The `have` term is of the form `"have " >> haveDecl >> optSemicolon termParser` let args := decl.getArgs let args := args ++ #[mkNullNode /- optional ';' -/, k] - pure $ mkNode `Lean.Parser.Term.«have» args + return mkNode `Lean.Parser.Term.«have» args else Macro.throwErrorAt decl "unexpected kind of 'do' declaration" @@ -1022,18 +1021,19 @@ def mkJoinPoint (j : Name) (ps : Array (Syntax × Bool)) (body : Syntax) (k : Sy def mkJmp (ref : Syntax) (j : Name) (args : Array Syntax) : Syntax := Syntax.mkApp (mkIdentFrom ref j) args -partial def toTerm : Code → M Syntax - | Code.«return» ref val => withRef ref <| returnToTerm val - | Code.«continue» ref => withRef ref continueToTerm - | Code.«break» ref => withRef ref breakToTerm +partial def toTerm (c : Code) : M Syntax := do + match c with + | Code.return ref val => withRef ref <| returnToTerm val + | Code.continue ref => withRef ref continueToTerm + | Code.break ref => withRef ref breakToTerm | Code.action e => actionTerminalToTerm e - | Code.joinpoint j ps b k => do mkJoinPoint j ps (← toTerm b) (← toTerm k) - | Code.jmp ref j args => pure $ mkJmp ref j args - | Code.decl _ stx k => do declToTerm stx (← toTerm k) - | Code.reassign _ stx k => do reassignToTerm stx (← toTerm k) - | Code.seq stx k => do seqToTerm stx (← toTerm k) + | Code.joinpoint j ps b k => mkJoinPoint j ps (← toTerm b) (← toTerm k) + | Code.jmp ref j args => return mkJmp ref j args + | Code.decl _ stx k => declToTerm stx (← toTerm k) + | Code.reassign _ stx k => reassignToTerm stx (← toTerm k) + | Code.seq stx k => seqToTerm stx (← toTerm k) | Code.ite ref _ o c t e => withRef ref <| do mkIte o c (← toTerm t) (← toTerm e) - | Code.«match» ref genParam discrs optMotive alts => do + | Code.«match» ref genParam discrs optMotive alts => let mut termAlts := #[] for alt in alts do let rhs ← toTerm alt.rhs @@ -1042,9 +1042,8 @@ partial def toTerm : Code → M Syntax let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts] return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts] -def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := do - let term ← toTerm code { m := m, kind := kind, uvars := uvars } - pure term +def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := + toTerm code { m := m, kind := kind, uvars := uvars } /- Given - `a` is true if the code block has a `Code.action _` exit point @@ -1054,13 +1053,13 @@ def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regu generate Kind. See comment at the beginning of the `ToTerm` namespace. -/ def mkNestedKind (a r bc : Bool) : Kind := match a, r, bc with - | true, false, false => Kind.regular - | false, true, false => Kind.regular - | false, false, true => Kind.nestedBC - | true, true, false => Kind.nestedPR - | true, false, true => Kind.nestedSBC - | false, true, true => Kind.nestedSBC - | true, true, true => Kind.nestedPRBC + | true, false, false => .regular + | false, true, false => .regular + | false, false, true => .nestedBC + | true, true, false => .nestedPR + | true, false, true => .nestedSBC + | false, true, true => .nestedSBC + | true, true, true => .nestedPRBC | false, false, false => unreachable! def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do @@ -1167,8 +1166,8 @@ def mkForInBody (x : Syntax) (forInBody : CodeBlock) : M ToForInTermResult := d let ctx ← read let uvars := forInBody.uvars let uvars := varSetToArray uvars - let term ← liftMacroM $ ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn) - pure ⟨uvars, term⟩ + let term ← liftMacroM <| ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn) + return ⟨uvars, term⟩ def ensureInsideFor : M Unit := unless (← read).insideFor do @@ -1195,14 +1194,14 @@ private partial def expandLiftMethodAux (inQuot : Bool) (inBinder : Bool) : Synt let inBinder := inBinder || (!inQuot && liftMethodForbiddenBinder stx) let args ← args.mapM (expandLiftMethodAux (inQuot && !inAntiquot || stx.isQuot) inBinder) return Syntax.node i k args - | stx => pure stx + | stx => return stx def expandLiftMethod (doElem : Syntax) : M (List Syntax × Syntax) := do if !hasLiftMethod doElem then - pure ([], doElem) + return ([], doElem) else let (doElem, doElemsNew) ← (expandLiftMethodAux false false doElem).run [] - pure (doElemsNew, doElem) + return (doElemsNew, doElem) def checkLetArrowRHS (doElem : Syntax) : M Unit := do let kind := doElem.getKind @@ -1232,7 +1231,7 @@ structure Catch where def getTryCatchUpdatedVars (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) : VarSet := let ws := tryCode.uvars - let ws := catches.foldl (fun ws alt => union alt.codeBlock.uvars ws) ws + let ws := catches.foldl (init := ws) fun ws alt => union alt.codeBlock.uvars ws let ws := match finallyCode? with | none => ws | some c => union c.uvars ws @@ -1275,7 +1274,7 @@ mutual let doElem := decl[3] let k ← withNewMutableVars #[y] (isMutableLet doLetArrow) (doSeqToCode doElems) match isDoExpr? doElem with - | some action => pure $ mkVarDeclCore #[y] doLetArrow k + | some action => return mkVarDeclCore #[y] doLetArrow k | none => checkLetArrowRHS doElem let c ← doSeqToCode [doElem] @@ -1492,13 +1491,13 @@ mutual withRef x <| checkNotShadowingMutable #[x] let optType := catchStx[2] let c ← doSeqToCode (getDoSeqElems catchStx[4]) - pure { x := x, optType := optType, codeBlock := c : Catch } + return { x := x, optType := optType, codeBlock := c : Catch } else if catchStx.getKind == ``Lean.Parser.Term.doCatchMatch then let matchAlts := catchStx[1] let x ← `(ex) let auxDo ← `(do match ex with $matchAlts) let c ← doSeqToCode (getDoSeqElems (getDoSeq auxDo)) - pure { x := x, codeBlock := c, optType := mkNullNode : Catch } + return { x := x, codeBlock := c, optType := mkNullNode : Catch } else throwError "unexpected kind of 'catch'" let finallyCode? ← if optFinally.isNone then pure none else some <$> doSeqToCode (getDoSeqElems optFinally[0][1]) @@ -1512,17 +1511,15 @@ mutual let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue let toTerm (codeBlock : CodeBlock) : M Syntax := do let codeBlock ← liftM $ extendUpdatedVars codeBlock ws - liftMacroM $ ToTerm.mkNestedTerm codeBlock.code ctx.m uvars a r bc + liftMacroM <| ToTerm.mkNestedTerm codeBlock.code ctx.m uvars a r bc let term ← toTerm tryCode - let term ← catches.foldlM - (fun term «catch» => do - let catchTerm ← toTerm «catch».codeBlock - if catch.optType.isNone then - ``(MonadExcept.tryCatch $term (fun $(«catch».x):ident => $catchTerm)) - else - let type := «catch».optType[1] - ``(tryCatchThe $type $term (fun $(«catch».x):ident => $catchTerm))) - term + let term ← catches.foldlM (init := term) fun term «catch» => do + let catchTerm ← toTerm «catch».codeBlock + if catch.optType.isNone then + ``(MonadExcept.tryCatch $term (fun $(«catch».x):ident => $catchTerm)) + else + let type := «catch».optType[1] + ``(tryCatchThe $type $term (fun $(«catch».x):ident => $catchTerm)) let term ← match finallyCode? with | none => pure term | some finallyCode => withRef optFinally do @@ -1624,16 +1621,16 @@ private def mkMonadAlias (m : Expr) : TermElabM Syntax := do let mType ← inferType m let mvar ← elabTerm result mType assignExprMVar mvar.mvarId! m - pure result + return result @[builtinTermElab «do»] def elabDo : TermElab := fun stx expectedType? => do tryPostponeIfNoneOrMVar expectedType? let bindInfo ← extractBind expectedType? let m ← mkMonadAlias bindInfo.m let codeBlock ← ToCodeBlock.run stx m - let stxNew ← liftMacroM $ ToTerm.run codeBlock.code m + let stxNew ← liftMacroM <| ToTerm.run codeBlock.code m trace[Elab.do] stxNew - withMacroExpansion stx stxNew $ elabTermEnsuringType stxNew bindInfo.expectedType + withMacroExpansion stx stxNew <| elabTermEnsuringType stxNew bindInfo.expectedType end Do