diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 695b6fcd24..2e05573bd8 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -122,10 +122,12 @@ private partial def extractBind (expectedType? : Option Expr) : TermElabM Extrac namespace Do +abbrev Var := Syntax -- TODO: should be `TSyntax identKind` + /- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/ structure Alt (σ : Type) where ref : Syntax - vars : Array Name + vars : Array Var patterns : Syntax rhs : σ deriving Inhabited @@ -179,34 +181,36 @@ structure Alt (σ : Type) where - For every `jmp ref j as` in `C`, there is a `joinpoint j ps b k` and `jmp ref j as` is in `k`, and `ps.size == as.size` -/ inductive Code where - | decl (xs : Array Name) (doElem : Syntax) (k : Code) - | reassign (xs : Array Name) (doElem : Syntax) (k : Code) + | decl (xs : Array Var) (doElem : Syntax) (k : Code) + | reassign (xs : Array Var) (doElem : Syntax) (k : Code) /- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/ - | joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (k : Code) + | joinpoint (name : Name) (params : Array (Var × Bool)) (body : Code) (k : Code) | seq (action : Syntax) (k : Code) | action (action : Syntax) | «break» (ref : Syntax) | «continue» (ref : Syntax) | «return» (ref : Syntax) (val : Syntax) /- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/ - | ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) + | ite (ref : Syntax) (h? : Option Var) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) | «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt Code)) | jmp (ref : Syntax) (jpName : Name) (args : Array Syntax) deriving Inhabited +abbrev VarSet := Std.RBMap Name Syntax Name.cmp + /- A code block, and the collection of variables updated by it. -/ structure CodeBlock where code : Code - uvars : NameSet := {} -- set of variables updated by `code` + uvars : VarSet := {} -- set of variables updated by `code` -private def nameSetToArray (s : NameSet) : Array Name := - s.fold (fun (xs : Array Name) x => xs.push x) #[] +private def varSetToArray (s : VarSet) : Array Var := + s.fold (fun xs _ x => xs.push x) #[] -private def varsToMessageData (vars : Array Name) : MessageData := - MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.simpMacroScopes)) " " +private def varsToMessageData (vars : Array Var) : MessageData := + MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.getId.simpMacroScopes)) " " partial def CodeBlocl.toMessageData (codeBlock : CodeBlock) : MessageData := - let us := MessageData.ofList $ (nameSetToArray codeBlock.uvars).toList.map MessageData.ofName + let us := MessageData.ofList $ (varSetToArray codeBlock.uvars).toList.map MessageData.ofSyntax let rec loop : Code → MessageData | Code.decl xs _ k => m!"let {varsToMessageData xs} := ...\n{loop k}" | Code.reassign xs _ k => m!"{varsToMessageData xs} := ...\n{loop k}" @@ -264,15 +268,14 @@ def hasBreakContinueReturn (c : Code) : Bool := def mkAuxDeclFor {m} [Monad m] [MonadQuotation m] (e : Syntax) (mkCont : Syntax → m Code) : m Code := withRef e <| withFreshMacroScope do let y ← `(y) - let yName := y.getId let doElem ← `(doElem| let y ← $e:term) -- Add elaboration hint for producing sane error message let y ← `(ensure_expected_type% "type mismatch, result value" $y) let k ← mkCont y - pure $ Code.decl #[yName] doElem k + pure $ Code.decl #[y] doElem k /- Convert `action _ e` instructions in `c` into `let y ← e; jmp _ jp (xs y)`. -/ -partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array Name) : MacroM Code := +partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array Var) : MacroM Code := let rec loop : Code → MacroM Code | Code.decl xs stx k => return Code.decl xs stx (← loop k) | Code.reassign xs stx k => return Code.reassign xs stx (← loop k) @@ -283,15 +286,14 @@ partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array N | Code.action e => mkAuxDeclFor e fun y => let ref := e -- We jump to `jp` with xs **and** y - let jmpArgs := xs.map $ mkIdentFrom ref - let jmpArgs := jmpArgs.push y + let jmpArgs := xs.push y return Code.jmp ref jp jmpArgs | c => return c loop code structure JPDecl where name : Name - params : Array (Name × Bool) + params : Array (Var × Bool) body : Code def attachJP (jpDecl : JPDecl) (k : Code) : Code := @@ -300,10 +302,10 @@ def attachJP (jpDecl : JPDecl) (k : Code) : Code := def attachJPs (jpDecls : Array JPDecl) (k : Code) : Code := jpDecls.foldr attachJP k -def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do +def mkFreshJP (ps : Array (Var × Bool)) (body : Code) : TermElabM JPDecl := do let ps ← if ps.isEmpty then - let y ← mkFreshUserName `y + let y ← `(y) pure #[(y, false)] else pure ps @@ -313,51 +315,47 @@ def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do let name ← mkFreshUserName `_do_jp pure { name := name, params := ps, body := body } -def mkFreshJP' (xs : Array Name) (body : Code) : TermElabM JPDecl := - mkFreshJP (xs.map fun x => (x, true)) body - -def addFreshJP (ps : Array (Name × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do +def addFreshJP (ps : Array (Var × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do let jp ← mkFreshJP ps body modify fun (jps : Array JPDecl) => jps.push jp pure jp.name -def insertVars (rs : NameSet) (xs : Array Name) : NameSet := - xs.foldl (·.insert ·) rs +def insertVars (rs : VarSet) (xs : Array Var) : VarSet := + xs.foldl (fun rs x => rs.insert x.getId x) rs -def eraseVars (rs : NameSet) (xs : Array Name) : NameSet := - xs.foldl (·.erase ·) rs +def eraseVars (rs : VarSet) (xs : Array Var) : VarSet := + xs.foldl (·.erase ·.getId) rs -def eraseOptVar (rs : NameSet) (x? : Option Name) : NameSet := +def eraseOptVar (rs : VarSet) (x? : Option Var) : VarSet := match x? with | none => rs - | some x => rs.insert x + | some x => rs.insert x.getId x /- Create a new jointpoint for `c`, and jump to it with the variables `rs` -/ -def mkSimpleJmp (ref : Syntax) (rs : NameSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do - let xs := nameSetToArray rs +def mkSimpleJmp (ref : Syntax) (rs : VarSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do + let xs := varSetToArray rs let jp ← addFreshJP (xs.map fun x => (x, true)) c if xs.isEmpty then let unit ← ``(Unit.unit) return Code.jmp ref jp #[unit] else - return Code.jmp ref jp (xs.map $ mkIdentFrom ref) + return Code.jmp ref jp xs /- Create a new joinpoint that takes `rs` and `val` as arguments. `val` must be syntax representing a pure value. The body of the joinpoint is created using `mkJPBody yFresh`, where `yFresh` is a fresh variable created by this method. -/ -def mkJmp (ref : Syntax) (rs : NameSet) (val : Syntax) (mkJPBody : Syntax → MacroM Code) : StateRefT (Array JPDecl) TermElabM Code := do - let xs := nameSetToArray rs - let args := xs.map $ mkIdentFrom ref - let args := args.push val - let yFresh ← mkFreshUserName `y +def mkJmp (ref : Syntax) (rs : VarSet) (val : Syntax) (mkJPBody : Syntax → MacroM Code) : StateRefT (Array JPDecl) TermElabM Code := do + let xs := varSetToArray rs + let args := xs.push val + let yFresh ← withRef ref `(y) let ps := xs.map fun x => (x, true) let ps := ps.push (yFresh, false) - let jpBody ← liftMacroM $ mkJPBody (mkIdentFrom ref yFresh) + let jpBody ← liftMacroM $ mkJPBody yFresh let jp ← addFreshJP ps jpBody pure $ Code.jmp ref jp args /- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/ -partial def pullExitPointsAux : NameSet → Code → StateRefT (Array JPDecl) TermElabM Code +partial def pullExitPointsAux : VarSet → Code → StateRefT (Array JPDecl) TermElabM Code | rs, Code.decl xs stx k => return Code.decl xs stx (← pullExitPointsAux (eraseVars rs xs) k) | rs, Code.reassign xs stx k => return Code.reassign xs stx (← pullExitPointsAux (insertVars rs xs) k) | rs, Code.joinpoint j ps b k => return Code.joinpoint j ps (← pullExitPointsAux rs b) (← pullExitPointsAux rs k) @@ -430,26 +428,26 @@ def pullExitPoints (c : Code) : TermElabM Code := do else pure c -partial def extendUpdatedVarsAux (c : Code) (ws : NameSet) : TermElabM Code := +partial def extendUpdatedVarsAux (c : Code) (ws : VarSet) : TermElabM Code := let rec update : Code → TermElabM Code | Code.joinpoint j ps b k => return Code.joinpoint j ps (← update b) (← update k) | Code.seq e k => return Code.seq e (← update k) | c@(Code.«match» ref g ds t alts) => do - if alts.any fun alt => alt.vars.any fun x => ws.contains x then + if alts.any fun alt => alt.vars.any fun x => ws.contains x.getId then -- If a pattern variable is shadowing a variable in ws, we `pullExitPoints` pullExitPoints c else return Code.«match» ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) }) | Code.ite ref none o c t e => return Code.ite ref none o c (← update t) (← update e) | c@(Code.ite ref (some h) o cond t e) => do - if ws.contains h then + if ws.contains h.getId then -- if the `h` at `if h:c then t else e` shadows a variable in `ws`, we `pullExitPoints` pullExitPoints c else return Code.ite ref (some h) o cond (← update t) (← update e) | Code.reassign xs stx k => return Code.reassign xs stx (← update k) | c@(Code.decl xs stx k) => do - if xs.any fun x => ws.contains x then + if xs.any fun x => ws.contains x.getId then -- One the declared variables is shadowing a variable in `ws` pullExitPoints c else @@ -462,14 +460,14 @@ Extend the set of updated variables. It assumes `ws` is a super set of `c.uvars` We **cannot** simply update the field `c.uvars`, because `c` may have shadowed some variable in `ws`. See discussion at `pullExitPoints`. -/ -partial def extendUpdatedVars (c : CodeBlock) (ws : NameSet) : TermElabM CodeBlock := do - if ws.any fun x => !c.uvars.contains x then +partial def extendUpdatedVars (c : CodeBlock) (ws : VarSet) : TermElabM CodeBlock := do + if ws.any fun x _ => !c.uvars.contains x then -- `ws` contains a variable that is not in `c.uvars`, but in `c.dvars` (i.e., it has been shadowed) pure { code := (← extendUpdatedVarsAux c.code ws), uvars := ws } else pure { c with uvars := ws } -private def union (s₁ s₂ : NameSet) : NameSet := +private def union (s₁ s₂ : VarSet) : VarSet := s₁.fold (·.insert ·) s₂ /- @@ -490,7 +488,7 @@ Remark: `stx` is the syntax for the declaration (e.g., `letDecl`), and `xs` are declared by it. It is an array because we have let-declarations that declare multiple variables. Example: `let (x, y) := t` -/ -def mkVarDeclCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : CodeBlock := { +def mkVarDeclCore (xs : Array Var) (stx : Syntax) (c : CodeBlock) : CodeBlock := { code := Code.decl xs stx c.code, uvars := eraseVars c.uvars xs } @@ -501,12 +499,12 @@ Remark: `stx` is the syntax for the declaration (e.g., `letDecl`), and `xs` are declared by it. It is an array because we have let-declarations that declare multiple variables. Example: `(x, y) ← t` -/ -def mkReassignCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : TermElabM CodeBlock := do +def mkReassignCore (xs : Array Var) (stx : Syntax) (c : CodeBlock) : TermElabM CodeBlock := do let us := c.uvars let ws := insertVars us xs -- If `xs` contains a new updated variable, then we must use `extendUpdatedVars`. -- See discussion at `pullExitPoints` - let code ← if xs.any fun x => !us.contains x then extendUpdatedVarsAux c.code ws else pure c.code + let code ← if xs.any fun x => !us.contains x.getId then extendUpdatedVarsAux c.code ws else pure c.code pure { code := Code.reassign xs stx code, uvars := ws } def mkSeq (action : Syntax) (c : CodeBlock) : CodeBlock := @@ -525,7 +523,7 @@ def mkContinue (ref : Syntax) : CodeBlock := { code := Code.«continue» ref } def mkIte (ref : Syntax) (optIdent : Syntax) (cond : Syntax) (thenBranch : CodeBlock) (elseBranch : CodeBlock) : TermElabM CodeBlock := do - let x? := if optIdent.isNone then none else some optIdent[0].getId + let x? := optIdent.getOptional? let (thenBranch, elseBranch) ← homogenize thenBranch elseBranch pure { code := Code.ite ref x? optIdent cond thenBranch.code elseBranch.code, @@ -555,12 +553,12 @@ def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optMotive : Sy /- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`. This method assumes `terminal` is a terminal -/ -def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Name) (k : CodeBlock) : TermElabM CodeBlock := do +def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Var) (k : CodeBlock) : TermElabM CodeBlock := do unless hasTerminalAction terminal.code do throwErrorAt kRef "'do' element is unreachable" let (terminal, k) ← homogenize terminal k - let xs := nameSetToArray k.uvars - let y ← match y? with | some y => pure y | none => mkFreshUserName `y + let xs := varSetToArray k.uvars + let y ← match y? with | some y => pure y | none => `(y) let ps := xs.map fun x => (x, true) let ps := ps.push (y, false) let jpDecl ← mkFreshJP ps k.code @@ -568,26 +566,26 @@ def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Name) (k : CodeBl let terminal ← liftMacroM $ convertTerminalActionIntoJmp terminal.code jp xs pure { code := attachJP jpDecl terminal, uvars := k.uvars } -def getLetIdDeclVar (letIdDecl : Syntax) : Name := - letIdDecl[0].getId +def getLetIdDeclVar (letIdDecl : Syntax) : Var := + letIdDecl[0] -- support both regular and syntax match -def getPatternVarsEx (pattern : Syntax) : TermElabM (Array Name) := - getPatternVarNames <$> getPatternVars pattern <|> - Array.map Syntax.getId <$> Quotation.getPatternVars pattern +def getPatternVarsEx (pattern : Syntax) : TermElabM (Array Var) := + getPatternVars pattern <|> + Quotation.getPatternVars pattern -def getPatternsVarsEx (patterns : Array Syntax) : TermElabM (Array Name) := - getPatternVarNames <$> getPatternsVars patterns <|> - Array.map Syntax.getId <$> Quotation.getPatternsVars patterns +def getPatternsVarsEx (patterns : Array Syntax) : TermElabM (Array Var) := + getPatternsVars patterns <|> + Quotation.getPatternsVars patterns -def getLetPatDeclVars (letPatDecl : Syntax) : TermElabM (Array Name) := do +def getLetPatDeclVars (letPatDecl : Syntax) : TermElabM (Array Var) := do let pattern := letPatDecl[0] getPatternVarsEx pattern -def getLetEqnsDeclVar (letEqnsDecl : Syntax) : Name := - letEqnsDecl[0].getId +def getLetEqnsDeclVar (letEqnsDecl : Syntax) : Var := + letEqnsDecl[0] -def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Name) := do +def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Var) := do let arg := letDecl[0] if arg.getKind == ``Lean.Parser.Term.letIdDecl then pure #[getLetIdDeclVar arg] @@ -598,33 +596,33 @@ def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Name) := do else throwError "unexpected kind of let declaration" -def getDoLetVars (doLet : Syntax) : TermElabM (Array Name) := +def getDoLetVars (doLet : Syntax) : TermElabM (Array Var) := -- leading_parser "let " >> optional "mut " >> letDecl getLetDeclVars doLet[2] -def getHaveIdLhsVar (optIdent : Syntax) : Name := +def getHaveIdLhsVar (optIdent : Syntax) : TermElabM Var := if optIdent.isNone then - `this + `(this) else - optIdent[0].getId + pure optIdent[0] -def getDoHaveVars (doHave : Syntax) : TermElabM (Array Name) := +def getDoHaveVars (doHave : Syntax) : TermElabM (Array Var) := do -- doHave := leading_parser "have " >> Term.haveDecl -- haveDecl := leading_parser haveIdDecl <|> letPatDecl <|> haveEqnsDecl let arg := doHave[1][0] if arg.getKind == ``Lean.Parser.Term.haveIdDecl then -- haveIdDecl := leading_parser atomic (haveIdLhs >> " := ") >> termParser -- haveIdLhs := optional (ident >> many (ppSpace >> (simpleBinderWithoutType <|> bracketedBinder))) >> optType - pure #[getHaveIdLhsVar arg[0]] + pure #[← getHaveIdLhsVar arg[0]] else if arg.getKind == ``Lean.Parser.Term.letPatDecl then getLetPatDeclVars arg else if arg.getKind == ``Lean.Parser.Term.haveEqnsDecl then -- haveEqnsDecl := leading_parser haveIdLhs >> matchAlts - pure #[getHaveIdLhsVar arg[0]] + pure #[← getHaveIdLhsVar arg[0]] else throwError "unexpected kind of have declaration" -def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Name) := do +def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Var) := do -- letRecDecls is an array of `(group (optional attributes >> letDecl))` let letRecDecls := doLetRec[1][0].getSepArgs let letDecls := letRecDecls.map fun p => p[2] @@ -635,16 +633,16 @@ def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Name) := do pure allVars -- ident >> optType >> leftArrow >> termParser -def getDoIdDeclVar (doIdDecl : Syntax) : Name := - doIdDecl[0].getId +def getDoIdDeclVar (doIdDecl : Syntax) : Var := + doIdDecl[0] -- termParser >> leftArrow >> termParser >> optional (" | " >> termParser) -def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Name) := do +def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Var) := do let pattern := doPatDecl[0] getPatternVarsEx pattern -- leading_parser "let " >> optional "mut " >> (doIdDecl <|> doPatDecl) -def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do +def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Var) := do let decl := doLetArrow[2] if decl.getKind == ``Lean.Parser.Term.doIdDecl then pure #[getDoIdDeclVar decl] @@ -653,7 +651,7 @@ def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do else throwError "unexpected kind of 'do' declaration" -def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Name) := do +def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Var) := do let arg := doReassign[0] if arg.getKind == ``Lean.Parser.Term.letIdDecl then pure #[getLetIdDeclVar arg] @@ -748,20 +746,20 @@ def isDoExpr? (doElem : Syntax) : Option Syntax := We use this method when expanding the `for-in` notation. -/ -private def destructTuple (uvars : Array Name) (x : Syntax) (body : Syntax) : MacroM Syntax := do +private def destructTuple (uvars : Array Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do if uvars.size == 0 then return body else if uvars.size == 1 then - `(let $(← mkIdentFromRef uvars[0]):ident := $x; $body) + `(let $(uvars[0]):ident := $x; $body) else destruct uvars.toList x body where - destruct (as : List Name) (x : Syntax) (body : Syntax) : MacroM Syntax := do + destruct (as : List Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do match as with - | [a, b] => `(let $(← mkIdentFromRef a):ident := $x.1; let $(← mkIdentFromRef b):ident := $x.2; $body) + | [a, b] => `(let $a:ident := $x.1; let $b:ident := $x.2; $body) | a :: as => withFreshMacroScope do let rest ← destruct as (← `(x)) body - `(let $(← mkIdentFromRef a):ident := $x.1; let x := $x.2; $rest) + `(let $a:ident := $x.1; let x := $x.2; $rest) | _ => unreachable! /- @@ -864,15 +862,14 @@ def Kind.isRegular : Kind → Bool structure Context where m : Syntax -- Syntax to reference the monad associated with the do notation. - uvars : Array Name + uvars : Array Var kind : Kind abbrev M := ReaderT Context MacroM def mkUVarTuple : M Syntax := do let ctx ← read - let uvarIdents ← ctx.uvars.mapM mkIdentFromRef - mkTuple uvarIdents + mkTuple ctx.uvars def returnToTerm (val : Syntax) : M Syntax := do let ctx ← read @@ -993,9 +990,9 @@ def mkIte (optIdent : Syntax) (cond : Syntax) (thenBranch : Syntax) (elseBranch let h := optIdent[0] ``(if $h:ident : $cond then $thenBranch else $elseBranch) -def mkJoinPoint (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withRef body <| withFreshMacroScope do - let pTypes ← ps.mapM fun ⟨id, useTypeOf⟩ => do if useTypeOf then `(type_of% $(← mkIdentFromRef id)) else `(_) - let ps ← ps.mapM fun ⟨id, useTypeOf⟩ => mkIdentFromRef id +def mkJoinPoint (j : Name) (ps : Array (Syntax × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withRef body <| withFreshMacroScope do + let pTypes ← ps.mapM fun ⟨id, useTypeOf⟩ => do if useTypeOf then `(type_of% $id) else `(_) + let ps := ps.map (·.1) /- We use `let_delayed` instead of `let` for joinpoints to make sure `$k` is elaborated before `$body`. By elaborating `$k` first, we "learn" more about `$body`'s type. @@ -1045,7 +1042,7 @@ partial def toTerm : Code → M Syntax let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts] return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts] -def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do +def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := do let term ← toTerm code { m := m, kind := kind, uvars := uvars } pure term @@ -1066,7 +1063,7 @@ def mkNestedKind (a r bc : Bool) : Kind := | true, true, true => Kind.nestedPRBC | false, false, false => unreachable! -def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Name) (a r bc : Bool) : MacroM Syntax := do +def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do ToTerm.run code m uvars (mkNestedKind a r bc) /- Given a term `term` produced by `ToTerm.run`, pattern match on its result. @@ -1077,9 +1074,9 @@ def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Name) (a r bc : Bool) - `bc` is true if the code block has a `Code.break _` or `Code.continue _` exit point The result is a sequence of `doElem` -/ -def matchNestedTermResult (term : Syntax) (uvars : Array Name) (a r bc : Bool) : MacroM (List Syntax) := do +def matchNestedTermResult (term : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM (List Syntax) := do let toDoElems (auxDo : Syntax) : List Syntax := getDoSeqElems (getDoSeq auxDo) - let u ← mkTuple (← uvars.mapM mkIdentFromRef) + let u ← mkTuple uvars match a, r, bc with | true, false, false => if uvars.isEmpty then @@ -1135,41 +1132,41 @@ namespace ToCodeBlock structure Context where ref : Syntax m : Syntax -- Syntax representing the monad associated with the do notation. - mutableVars : NameSet := {} + mutableVars : VarSet := {} insideFor : Bool := false abbrev M := ReaderT Context TermElabM -def withNewMutableVars {α} (newVars : Array Name) (mutable : Bool) (x : M α) : M α := +def withNewMutableVars {α} (newVars : Array Var) (mutable : Bool) (x : M α) : M α := withReader (fun ctx => if mutable then { ctx with mutableVars := insertVars ctx.mutableVars newVars } else ctx) x -def checkReassignable (xs : Array Name) : M Unit := do +def checkReassignable (xs : Array Var) : M Unit := do let throwInvalidReassignment (x : Name) : M Unit := throwError "'{x.simpMacroScopes}' cannot be reassigned" let ctx ← read for x in xs do - unless ctx.mutableVars.contains x do - throwInvalidReassignment x + unless ctx.mutableVars.contains x.getId do + throwInvalidReassignment x.getId -def checkNotShadowingMutable (xs : Array Name) : M Unit := do +def checkNotShadowingMutable (xs : Array Var) : M Unit := do let throwInvalidShadowing (x : Name) : M Unit := throwError "mutable variable '{x.simpMacroScopes}' cannot be shadowed" let ctx ← read for x in xs do - if ctx.mutableVars.contains x then - throwInvalidShadowing x + if ctx.mutableVars.contains x.getId then + throwInvalidShadowing x.getId def withFor {α} (x : M α) : M α := withReader (fun ctx => { ctx with insideFor := true }) x structure ToForInTermResult where - uvars : Array Name + uvars : Array Var term : Syntax def mkForInBody (x : Syntax) (forInBody : CodeBlock) : M ToForInTermResult := do let ctx ← read let uvars := forInBody.uvars - let uvars := nameSetToArray uvars + let uvars := varSetToArray uvars let term ← liftMacroM $ ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn) pure ⟨uvars, term⟩ @@ -1233,7 +1230,7 @@ structure Catch where optType : Syntax codeBlock : CodeBlock -def getTryCatchUpdatedVars (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) : NameSet := +def getTryCatchUpdatedVars (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) : VarSet := let ws := tryCode.uvars let ws := catches.foldl (fun ws alt => union alt.codeBlock.uvars ws) ws let ws := match finallyCode? with @@ -1248,6 +1245,10 @@ def tryCatchPred (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : O | none => false | some finallyCode => p finallyCode.code +def mutVarNamesToDefStx (mutVars : Array Syntax) : M (Array Syntax) := do + let ctx ← read + return mutVars.map fun v => ctx.mutableVars.findD v.getId v + mutual /- "Concatenate" `c` with `doSeqToCode doElems` -/ partial def concatWith (c : CodeBlock) (doElems : List Syntax) : M CodeBlock := @@ -1273,7 +1274,7 @@ mutual let ref := doLetArrow let decl := doLetArrow[2] if decl.getKind == ``Lean.Parser.Term.doIdDecl then - let y := decl[0].getId + let y := decl[0] checkNotShadowingMutable #[y] let doElem := decl[3] let k ← withNewMutableVars #[y] (isMutableLet doLetArrow) (doSeqToCode doElems) @@ -1422,7 +1423,7 @@ mutual let forElems := getDoSeqElems doFor[3] let forInBodyCodeBlock ← withFor (doSeqToCode forElems) let ⟨uvars, forInBody⟩ ← mkForInBody x forInBodyCodeBlock - let uvarsTuple ← liftMacroM do mkTuple (← uvars.mapM mkIdentFromRef) + let uvarsTuple ← liftMacroM do mkTuple uvars if hasReturn forInBodyCodeBlock.code then let forInBody ← liftMacroM <| destructTuple uvars (← `(r)) forInBody let forInTerm ← @@ -1487,7 +1488,7 @@ mutual if catchStx.getKind == ``Lean.Parser.Term.doCatch then let x := catchStx[1] if x.isIdent then - withRef x <| checkNotShadowingMutable #[x.getId] + withRef x <| checkNotShadowingMutable #[x] let optType := catchStx[2] let c ← doSeqToCode (getDoSeqElems catchStx[4]) pure { x := x, optType := optType, codeBlock := c : Catch } @@ -1504,7 +1505,7 @@ mutual throwError "invalid 'try', it must have a 'catch' or 'finally'" let ctx ← read let ws := getTryCatchUpdatedVars tryCode catches finallyCode? - let uvars := nameSetToArray ws + let uvars := varSetToArray ws let a := tryCatchPred tryCode catches finallyCode? hasTerminalAction let r := tryCatchPred tryCode catches finallyCode? hasReturn let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index d0e2c44509..abeeccc7da 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -182,11 +182,10 @@ structure PatternVarDecl where private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α := let rec loop (i : Nat) (decls : Array PatternVarDecl) (userNames : Array Name) := do if h : i < pVars.size then - match pVars.get ⟨i, h⟩ with - | { userName } => - let type ← mkFreshTypeMVar - withLocalDecl userName BinderInfo.default type fun x => - loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous) + let var := pVars.get ⟨i, h⟩ + let type ← mkFreshTypeMVar + withLocalDecl var.getId BinderInfo.default type fun x => + loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous) else k decls loop 0 #[] #[] @@ -882,7 +881,7 @@ private def generalize (discrs : Array Discr) (matchType : Expr) (altViews : Arr if ysUserNames.contains yUserName then yUserName ← mkFreshUserName yUserName -- Explicitly provided pattern variables shadow `y` - else if patternVars.any fun x => x.userName == yUserName then + else if patternVars.any fun x => x.getId == yUserName then yUserName ← mkFreshUserName yUserName return ysUserNames.push yUserName let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n diff --git a/src/Lean/Elab/PatternVar.lean b/src/Lean/Elab/PatternVar.lean index 2aa5b7cf61..a60d0ae43e 100644 --- a/src/Lean/Elab/PatternVar.lean +++ b/src/Lean/Elab/PatternVar.lean @@ -11,12 +11,7 @@ namespace Lean.Elab.Term open Meta -structure PatternVar where - userName : Name - deriving BEq - -instance : ToString PatternVar where - toString x := toString x.userName +abbrev PatternVar := Syntax -- TODO: should be `TSyntax identKind` /- Patterns define new local variables. @@ -111,7 +106,7 @@ private def processVar (idStx : Syntax) : M Syntax := do throwError "invalid pattern variable, must be atomic" if (← get).found.contains id then throwError "invalid pattern, variable '{id}' occurred more than once" - modify fun s => { s with vars := s.vars.push { userName := id }, found := s.found.insert id } + modify fun s => { s with vars := s.vars.push idStx, found := s.found.insert id } return idStx private def nameToPattern : Name → TermElabM Syntax @@ -366,6 +361,6 @@ def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := return s.vars def getPatternVarNames (pvars : Array PatternVar) : Array Name := - pvars.map fun x => x.userName + pvars.map fun x => x.getId end Lean.Elab.Term