fix: propagate position information of variables in do blocks

This commit is contained in:
Sebastian Ullrich 2022-05-07 22:23:05 +02:00 committed by Leonardo de Moura
parent daa9e03e78
commit 22f8ea147c
3 changed files with 117 additions and 122 deletions

View file

@ -122,10 +122,12 @@ private partial def extractBind (expectedType? : Option Expr) : TermElabM Extrac
namespace Do
abbrev Var := Syntax -- TODO: should be `TSyntax identKind`
/- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/
structure Alt (σ : Type) where
ref : Syntax
vars : Array Name
vars : Array Var
patterns : Syntax
rhs : σ
deriving Inhabited
@ -179,34 +181,36 @@ structure Alt (σ : Type) where
- For every `jmp ref j as` in `C`, there is a `joinpoint j ps b k` and `jmp ref j as` is in `k`, and
`ps.size == as.size` -/
inductive Code where
| decl (xs : Array Name) (doElem : Syntax) (k : Code)
| reassign (xs : Array Name) (doElem : Syntax) (k : Code)
| decl (xs : Array Var) (doElem : Syntax) (k : Code)
| reassign (xs : Array Var) (doElem : Syntax) (k : Code)
/- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/
| joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (k : Code)
| joinpoint (name : Name) (params : Array (Var × Bool)) (body : Code) (k : Code)
| seq (action : Syntax) (k : Code)
| action (action : Syntax)
| «break» (ref : Syntax)
| «continue» (ref : Syntax)
| «return» (ref : Syntax) (val : Syntax)
/- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/
| ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
| ite (ref : Syntax) (h? : Option Var) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
| «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt Code))
| jmp (ref : Syntax) (jpName : Name) (args : Array Syntax)
deriving Inhabited
abbrev VarSet := Std.RBMap Name Syntax Name.cmp
/- A code block, and the collection of variables updated by it. -/
structure CodeBlock where
code : Code
uvars : NameSet := {} -- set of variables updated by `code`
uvars : VarSet := {} -- set of variables updated by `code`
private def nameSetToArray (s : NameSet) : Array Name :=
s.fold (fun (xs : Array Name) x => xs.push x) #[]
private def varSetToArray (s : VarSet) : Array Var :=
s.fold (fun xs _ x => xs.push x) #[]
private def varsToMessageData (vars : Array Name) : MessageData :=
MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.simpMacroScopes)) " "
private def varsToMessageData (vars : Array Var) : MessageData :=
MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.getId.simpMacroScopes)) " "
partial def CodeBlocl.toMessageData (codeBlock : CodeBlock) : MessageData :=
let us := MessageData.ofList $ (nameSetToArray codeBlock.uvars).toList.map MessageData.ofName
let us := MessageData.ofList $ (varSetToArray codeBlock.uvars).toList.map MessageData.ofSyntax
let rec loop : Code → MessageData
| Code.decl xs _ k => m!"let {varsToMessageData xs} := ...\n{loop k}"
| Code.reassign xs _ k => m!"{varsToMessageData xs} := ...\n{loop k}"
@ -264,15 +268,14 @@ def hasBreakContinueReturn (c : Code) : Bool :=
def mkAuxDeclFor {m} [Monad m] [MonadQuotation m] (e : Syntax) (mkCont : Syntax → m Code) : m Code := withRef e <| withFreshMacroScope do
let y ← `(y)
let yName := y.getId
let doElem ← `(doElem| let y ← $e:term)
-- Add elaboration hint for producing sane error message
let y ← `(ensure_expected_type% "type mismatch, result value" $y)
let k ← mkCont y
pure $ Code.decl #[yName] doElem k
pure $ Code.decl #[y] doElem k
/- Convert `action _ e` instructions in `c` into `let y ← e; jmp _ jp (xs y)`. -/
partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array Name) : MacroM Code :=
partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array Var) : MacroM Code :=
let rec loop : Code → MacroM Code
| Code.decl xs stx k => return Code.decl xs stx (← loop k)
| Code.reassign xs stx k => return Code.reassign xs stx (← loop k)
@ -283,15 +286,14 @@ partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array N
| Code.action e => mkAuxDeclFor e fun y =>
let ref := e
-- We jump to `jp` with xs **and** y
let jmpArgs := xs.map $ mkIdentFrom ref
let jmpArgs := jmpArgs.push y
let jmpArgs := xs.push y
return Code.jmp ref jp jmpArgs
| c => return c
loop code
structure JPDecl where
name : Name
params : Array (Name × Bool)
params : Array (Var × Bool)
body : Code
def attachJP (jpDecl : JPDecl) (k : Code) : Code :=
@ -300,10 +302,10 @@ def attachJP (jpDecl : JPDecl) (k : Code) : Code :=
def attachJPs (jpDecls : Array JPDecl) (k : Code) : Code :=
jpDecls.foldr attachJP k
def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do
def mkFreshJP (ps : Array (Var × Bool)) (body : Code) : TermElabM JPDecl := do
let ps ←
if ps.isEmpty then
let y ← mkFreshUserName `y
let y ← `(y)
pure #[(y, false)]
else
pure ps
@ -313,51 +315,47 @@ def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do
let name ← mkFreshUserName `_do_jp
pure { name := name, params := ps, body := body }
def mkFreshJP' (xs : Array Name) (body : Code) : TermElabM JPDecl :=
mkFreshJP (xs.map fun x => (x, true)) body
def addFreshJP (ps : Array (Name × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do
def addFreshJP (ps : Array (Var × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do
let jp ← mkFreshJP ps body
modify fun (jps : Array JPDecl) => jps.push jp
pure jp.name
def insertVars (rs : NameSet) (xs : Array Name) : NameSet :=
xs.foldl (·.insert ·) rs
def insertVars (rs : VarSet) (xs : Array Var) : VarSet :=
xs.foldl (fun rs x => rs.insert x.getId x) rs
def eraseVars (rs : NameSet) (xs : Array Name) : NameSet :=
xs.foldl (·.erase ·) rs
def eraseVars (rs : VarSet) (xs : Array Var) : VarSet :=
xs.foldl (·.erase ·.getId) rs
def eraseOptVar (rs : NameSet) (x? : Option Name) : NameSet :=
def eraseOptVar (rs : VarSet) (x? : Option Var) : VarSet :=
match x? with
| none => rs
| some x => rs.insert x
| some x => rs.insert x.getId x
/- Create a new jointpoint for `c`, and jump to it with the variables `rs` -/
def mkSimpleJmp (ref : Syntax) (rs : NameSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do
let xs := nameSetToArray rs
def mkSimpleJmp (ref : Syntax) (rs : VarSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do
let xs := varSetToArray rs
let jp ← addFreshJP (xs.map fun x => (x, true)) c
if xs.isEmpty then
let unit ← ``(Unit.unit)
return Code.jmp ref jp #[unit]
else
return Code.jmp ref jp (xs.map $ mkIdentFrom ref)
return Code.jmp ref jp xs
/- Create a new joinpoint that takes `rs` and `val` as arguments. `val` must be syntax representing a pure value.
The body of the joinpoint is created using `mkJPBody yFresh`, where `yFresh`
is a fresh variable created by this method. -/
def mkJmp (ref : Syntax) (rs : NameSet) (val : Syntax) (mkJPBody : Syntax → MacroM Code) : StateRefT (Array JPDecl) TermElabM Code := do
let xs := nameSetToArray rs
let args := xs.map $ mkIdentFrom ref
let args := args.push val
let yFresh ← mkFreshUserName `y
def mkJmp (ref : Syntax) (rs : VarSet) (val : Syntax) (mkJPBody : Syntax → MacroM Code) : StateRefT (Array JPDecl) TermElabM Code := do
let xs := varSetToArray rs
let args := xs.push val
let yFresh ← withRef ref `(y)
let ps := xs.map fun x => (x, true)
let ps := ps.push (yFresh, false)
let jpBody ← liftMacroM $ mkJPBody (mkIdentFrom ref yFresh)
let jpBody ← liftMacroM $ mkJPBody yFresh
let jp ← addFreshJP ps jpBody
pure $ Code.jmp ref jp args
/- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/
partial def pullExitPointsAux : NameSet → Code → StateRefT (Array JPDecl) TermElabM Code
partial def pullExitPointsAux : VarSet → Code → StateRefT (Array JPDecl) TermElabM Code
| rs, Code.decl xs stx k => return Code.decl xs stx (← pullExitPointsAux (eraseVars rs xs) k)
| rs, Code.reassign xs stx k => return Code.reassign xs stx (← pullExitPointsAux (insertVars rs xs) k)
| rs, Code.joinpoint j ps b k => return Code.joinpoint j ps (← pullExitPointsAux rs b) (← pullExitPointsAux rs k)
@ -430,26 +428,26 @@ def pullExitPoints (c : Code) : TermElabM Code := do
else
pure c
partial def extendUpdatedVarsAux (c : Code) (ws : NameSet) : TermElabM Code :=
partial def extendUpdatedVarsAux (c : Code) (ws : VarSet) : TermElabM Code :=
let rec update : Code → TermElabM Code
| Code.joinpoint j ps b k => return Code.joinpoint j ps (← update b) (← update k)
| Code.seq e k => return Code.seq e (← update k)
| c@(Code.«match» ref g ds t alts) => do
if alts.any fun alt => alt.vars.any fun x => ws.contains x then
if alts.any fun alt => alt.vars.any fun x => ws.contains x.getId then
-- If a pattern variable is shadowing a variable in ws, we `pullExitPoints`
pullExitPoints c
else
return Code.«match» ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) })
| Code.ite ref none o c t e => return Code.ite ref none o c (← update t) (← update e)
| c@(Code.ite ref (some h) o cond t e) => do
if ws.contains h then
if ws.contains h.getId then
-- if the `h` at `if h:c then t else e` shadows a variable in `ws`, we `pullExitPoints`
pullExitPoints c
else
return Code.ite ref (some h) o cond (← update t) (← update e)
| Code.reassign xs stx k => return Code.reassign xs stx (← update k)
| c@(Code.decl xs stx k) => do
if xs.any fun x => ws.contains x then
if xs.any fun x => ws.contains x.getId then
-- One the declared variables is shadowing a variable in `ws`
pullExitPoints c
else
@ -462,14 +460,14 @@ Extend the set of updated variables. It assumes `ws` is a super set of `c.uvars`
We **cannot** simply update the field `c.uvars`, because `c` may have shadowed some variable in `ws`.
See discussion at `pullExitPoints`.
-/
partial def extendUpdatedVars (c : CodeBlock) (ws : NameSet) : TermElabM CodeBlock := do
if ws.any fun x => !c.uvars.contains x then
partial def extendUpdatedVars (c : CodeBlock) (ws : VarSet) : TermElabM CodeBlock := do
if ws.any fun x _ => !c.uvars.contains x then
-- `ws` contains a variable that is not in `c.uvars`, but in `c.dvars` (i.e., it has been shadowed)
pure { code := (← extendUpdatedVarsAux c.code ws), uvars := ws }
else
pure { c with uvars := ws }
private def union (s₁ s₂ : NameSet) : NameSet :=
private def union (s₁ s₂ : VarSet) : VarSet :=
s₁.fold (·.insert ·) s₂
/-
@ -490,7 +488,7 @@ Remark: `stx` is the syntax for the declaration (e.g., `letDecl`), and `xs` are
declared by it. It is an array because we have let-declarations that declare multiple variables.
Example: `let (x, y) := t`
-/
def mkVarDeclCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : CodeBlock := {
def mkVarDeclCore (xs : Array Var) (stx : Syntax) (c : CodeBlock) : CodeBlock := {
code := Code.decl xs stx c.code,
uvars := eraseVars c.uvars xs
}
@ -501,12 +499,12 @@ Remark: `stx` is the syntax for the declaration (e.g., `letDecl`), and `xs` are
declared by it. It is an array because we have let-declarations that declare multiple variables.
Example: `(x, y) ← t`
-/
def mkReassignCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : TermElabM CodeBlock := do
def mkReassignCore (xs : Array Var) (stx : Syntax) (c : CodeBlock) : TermElabM CodeBlock := do
let us := c.uvars
let ws := insertVars us xs
-- If `xs` contains a new updated variable, then we must use `extendUpdatedVars`.
-- See discussion at `pullExitPoints`
let code ← if xs.any fun x => !us.contains x then extendUpdatedVarsAux c.code ws else pure c.code
let code ← if xs.any fun x => !us.contains x.getId then extendUpdatedVarsAux c.code ws else pure c.code
pure { code := Code.reassign xs stx code, uvars := ws }
def mkSeq (action : Syntax) (c : CodeBlock) : CodeBlock :=
@ -525,7 +523,7 @@ def mkContinue (ref : Syntax) : CodeBlock :=
{ code := Code.«continue» ref }
def mkIte (ref : Syntax) (optIdent : Syntax) (cond : Syntax) (thenBranch : CodeBlock) (elseBranch : CodeBlock) : TermElabM CodeBlock := do
let x? := if optIdent.isNone then none else some optIdent[0].getId
let x? := optIdent.getOptional?
let (thenBranch, elseBranch) ← homogenize thenBranch elseBranch
pure {
code := Code.ite ref x? optIdent cond thenBranch.code elseBranch.code,
@ -555,12 +553,12 @@ def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optMotive : Sy
/- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`.
This method assumes `terminal` is a terminal -/
def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Name) (k : CodeBlock) : TermElabM CodeBlock := do
def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Var) (k : CodeBlock) : TermElabM CodeBlock := do
unless hasTerminalAction terminal.code do
throwErrorAt kRef "'do' element is unreachable"
let (terminal, k) ← homogenize terminal k
let xs := nameSetToArray k.uvars
let y ← match y? with | some y => pure y | none => mkFreshUserName `y
let xs := varSetToArray k.uvars
let y ← match y? with | some y => pure y | none => `(y)
let ps := xs.map fun x => (x, true)
let ps := ps.push (y, false)
let jpDecl ← mkFreshJP ps k.code
@ -568,26 +566,26 @@ def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Name) (k : CodeBl
let terminal ← liftMacroM $ convertTerminalActionIntoJmp terminal.code jp xs
pure { code := attachJP jpDecl terminal, uvars := k.uvars }
def getLetIdDeclVar (letIdDecl : Syntax) : Name :=
letIdDecl[0].getId
def getLetIdDeclVar (letIdDecl : Syntax) : Var :=
letIdDecl[0]
-- support both regular and syntax match
def getPatternVarsEx (pattern : Syntax) : TermElabM (Array Name) :=
getPatternVarNames <$> getPatternVars pattern <|>
Array.map Syntax.getId <$> Quotation.getPatternVars pattern
def getPatternVarsEx (pattern : Syntax) : TermElabM (Array Var) :=
getPatternVars pattern <|>
Quotation.getPatternVars pattern
def getPatternsVarsEx (patterns : Array Syntax) : TermElabM (Array Name) :=
getPatternVarNames <$> getPatternsVars patterns <|>
Array.map Syntax.getId <$> Quotation.getPatternsVars patterns
def getPatternsVarsEx (patterns : Array Syntax) : TermElabM (Array Var) :=
getPatternsVars patterns <|>
Quotation.getPatternsVars patterns
def getLetPatDeclVars (letPatDecl : Syntax) : TermElabM (Array Name) := do
def getLetPatDeclVars (letPatDecl : Syntax) : TermElabM (Array Var) := do
let pattern := letPatDecl[0]
getPatternVarsEx pattern
def getLetEqnsDeclVar (letEqnsDecl : Syntax) : Name :=
letEqnsDecl[0].getId
def getLetEqnsDeclVar (letEqnsDecl : Syntax) : Var :=
letEqnsDecl[0]
def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Name) := do
def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Var) := do
let arg := letDecl[0]
if arg.getKind == ``Lean.Parser.Term.letIdDecl then
pure #[getLetIdDeclVar arg]
@ -598,33 +596,33 @@ def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Name) := do
else
throwError "unexpected kind of let declaration"
def getDoLetVars (doLet : Syntax) : TermElabM (Array Name) :=
def getDoLetVars (doLet : Syntax) : TermElabM (Array Var) :=
-- leading_parser "let " >> optional "mut " >> letDecl
getLetDeclVars doLet[2]
def getHaveIdLhsVar (optIdent : Syntax) : Name :=
def getHaveIdLhsVar (optIdent : Syntax) : TermElabM Var :=
if optIdent.isNone then
`this
`(this)
else
optIdent[0].getId
pure optIdent[0]
def getDoHaveVars (doHave : Syntax) : TermElabM (Array Name) :=
def getDoHaveVars (doHave : Syntax) : TermElabM (Array Var) := do
-- doHave := leading_parser "have " >> Term.haveDecl
-- haveDecl := leading_parser haveIdDecl <|> letPatDecl <|> haveEqnsDecl
let arg := doHave[1][0]
if arg.getKind == ``Lean.Parser.Term.haveIdDecl then
-- haveIdDecl := leading_parser atomic (haveIdLhs >> " := ") >> termParser
-- haveIdLhs := optional (ident >> many (ppSpace >> (simpleBinderWithoutType <|> bracketedBinder))) >> optType
pure #[getHaveIdLhsVar arg[0]]
pure #[getHaveIdLhsVar arg[0]]
else if arg.getKind == ``Lean.Parser.Term.letPatDecl then
getLetPatDeclVars arg
else if arg.getKind == ``Lean.Parser.Term.haveEqnsDecl then
-- haveEqnsDecl := leading_parser haveIdLhs >> matchAlts
pure #[getHaveIdLhsVar arg[0]]
pure #[getHaveIdLhsVar arg[0]]
else
throwError "unexpected kind of have declaration"
def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Name) := do
def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Var) := do
-- letRecDecls is an array of `(group (optional attributes >> letDecl))`
let letRecDecls := doLetRec[1][0].getSepArgs
let letDecls := letRecDecls.map fun p => p[2]
@ -635,16 +633,16 @@ def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Name) := do
pure allVars
-- ident >> optType >> leftArrow >> termParser
def getDoIdDeclVar (doIdDecl : Syntax) : Name :=
doIdDecl[0].getId
def getDoIdDeclVar (doIdDecl : Syntax) : Var :=
doIdDecl[0]
-- termParser >> leftArrow >> termParser >> optional (" | " >> termParser)
def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Name) := do
def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Var) := do
let pattern := doPatDecl[0]
getPatternVarsEx pattern
-- leading_parser "let " >> optional "mut " >> (doIdDecl <|> doPatDecl)
def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do
def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Var) := do
let decl := doLetArrow[2]
if decl.getKind == ``Lean.Parser.Term.doIdDecl then
pure #[getDoIdDeclVar decl]
@ -653,7 +651,7 @@ def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do
else
throwError "unexpected kind of 'do' declaration"
def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Name) := do
def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Var) := do
let arg := doReassign[0]
if arg.getKind == ``Lean.Parser.Term.letIdDecl then
pure #[getLetIdDeclVar arg]
@ -748,20 +746,20 @@ def isDoExpr? (doElem : Syntax) : Option Syntax :=
We use this method when expanding the `for-in` notation.
-/
private def destructTuple (uvars : Array Name) (x : Syntax) (body : Syntax) : MacroM Syntax := do
private def destructTuple (uvars : Array Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do
if uvars.size == 0 then
return body
else if uvars.size == 1 then
`(let $(← mkIdentFromRef uvars[0]):ident := $x; $body)
`(let $(uvars[0]):ident := $x; $body)
else
destruct uvars.toList x body
where
destruct (as : List Name) (x : Syntax) (body : Syntax) : MacroM Syntax := do
destruct (as : List Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do
match as with
| [a, b] => `(let $(← mkIdentFromRef a):ident := $x.1; let $(← mkIdentFromRef b):ident := $x.2; $body)
| [a, b] => `(let $a:ident := $x.1; let $b:ident := $x.2; $body)
| a :: as => withFreshMacroScope do
let rest ← destruct as (← `(x)) body
`(let $(← mkIdentFromRef a):ident := $x.1; let x := $x.2; $rest)
`(let $a:ident := $x.1; let x := $x.2; $rest)
| _ => unreachable!
/-
@ -864,15 +862,14 @@ def Kind.isRegular : Kind → Bool
structure Context where
m : Syntax -- Syntax to reference the monad associated with the do notation.
uvars : Array Name
uvars : Array Var
kind : Kind
abbrev M := ReaderT Context MacroM
def mkUVarTuple : M Syntax := do
let ctx ← read
let uvarIdents ← ctx.uvars.mapM mkIdentFromRef
mkTuple uvarIdents
mkTuple ctx.uvars
def returnToTerm (val : Syntax) : M Syntax := do
let ctx ← read
@ -993,9 +990,9 @@ def mkIte (optIdent : Syntax) (cond : Syntax) (thenBranch : Syntax) (elseBranch
let h := optIdent[0]
``(if $h:ident : $cond then $thenBranch else $elseBranch)
def mkJoinPoint (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withRef body <| withFreshMacroScope do
let pTypes ← ps.mapM fun ⟨id, useTypeOf⟩ => do if useTypeOf then `(type_of% $(← mkIdentFromRef id)) else `(_)
let ps ← ps.mapM fun ⟨id, useTypeOf⟩ => mkIdentFromRef id
def mkJoinPoint (j : Name) (ps : Array (Syntax × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withRef body <| withFreshMacroScope do
let pTypes ← ps.mapM fun ⟨id, useTypeOf⟩ => do if useTypeOf then `(type_of% $id) else `(_)
let ps := ps.map (·.1)
/-
We use `let_delayed` instead of `let` for joinpoints to make sure `$k` is elaborated before `$body`.
By elaborating `$k` first, we "learn" more about `$body`'s type.
@ -1045,7 +1042,7 @@ partial def toTerm : Code → M Syntax
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts]
return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts]
def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do
def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := do
let term ← toTerm code { m := m, kind := kind, uvars := uvars }
pure term
@ -1066,7 +1063,7 @@ def mkNestedKind (a r bc : Bool) : Kind :=
| true, true, true => Kind.nestedPRBC
| false, false, false => unreachable!
def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Name) (a r bc : Bool) : MacroM Syntax := do
def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do
ToTerm.run code m uvars (mkNestedKind a r bc)
/- Given a term `term` produced by `ToTerm.run`, pattern match on its result.
@ -1077,9 +1074,9 @@ def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Name) (a r bc : Bool)
- `bc` is true if the code block has a `Code.break _` or `Code.continue _` exit point
The result is a sequence of `doElem` -/
def matchNestedTermResult (term : Syntax) (uvars : Array Name) (a r bc : Bool) : MacroM (List Syntax) := do
def matchNestedTermResult (term : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM (List Syntax) := do
let toDoElems (auxDo : Syntax) : List Syntax := getDoSeqElems (getDoSeq auxDo)
let u ← mkTuple (← uvars.mapM mkIdentFromRef)
let u ← mkTuple uvars
match a, r, bc with
| true, false, false =>
if uvars.isEmpty then
@ -1135,41 +1132,41 @@ namespace ToCodeBlock
structure Context where
ref : Syntax
m : Syntax -- Syntax representing the monad associated with the do notation.
mutableVars : NameSet := {}
mutableVars : VarSet := {}
insideFor : Bool := false
abbrev M := ReaderT Context TermElabM
def withNewMutableVars {α} (newVars : Array Name) (mutable : Bool) (x : M α) : M α :=
def withNewMutableVars {α} (newVars : Array Var) (mutable : Bool) (x : M α) : M α :=
withReader (fun ctx => if mutable then { ctx with mutableVars := insertVars ctx.mutableVars newVars } else ctx) x
def checkReassignable (xs : Array Name) : M Unit := do
def checkReassignable (xs : Array Var) : M Unit := do
let throwInvalidReassignment (x : Name) : M Unit :=
throwError "'{x.simpMacroScopes}' cannot be reassigned"
let ctx ← read
for x in xs do
unless ctx.mutableVars.contains x do
throwInvalidReassignment x
unless ctx.mutableVars.contains x.getId do
throwInvalidReassignment x.getId
def checkNotShadowingMutable (xs : Array Name) : M Unit := do
def checkNotShadowingMutable (xs : Array Var) : M Unit := do
let throwInvalidShadowing (x : Name) : M Unit :=
throwError "mutable variable '{x.simpMacroScopes}' cannot be shadowed"
let ctx ← read
for x in xs do
if ctx.mutableVars.contains x then
throwInvalidShadowing x
if ctx.mutableVars.contains x.getId then
throwInvalidShadowing x.getId
def withFor {α} (x : M α) : M α :=
withReader (fun ctx => { ctx with insideFor := true }) x
structure ToForInTermResult where
uvars : Array Name
uvars : Array Var
term : Syntax
def mkForInBody (x : Syntax) (forInBody : CodeBlock) : M ToForInTermResult := do
let ctx ← read
let uvars := forInBody.uvars
let uvars := nameSetToArray uvars
let uvars := varSetToArray uvars
let term ← liftMacroM $ ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn)
pure ⟨uvars, term⟩
@ -1233,7 +1230,7 @@ structure Catch where
optType : Syntax
codeBlock : CodeBlock
def getTryCatchUpdatedVars (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) : NameSet :=
def getTryCatchUpdatedVars (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) : VarSet :=
let ws := tryCode.uvars
let ws := catches.foldl (fun ws alt => union alt.codeBlock.uvars ws) ws
let ws := match finallyCode? with
@ -1248,6 +1245,10 @@ def tryCatchPred (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : O
| none => false
| some finallyCode => p finallyCode.code
def mutVarNamesToDefStx (mutVars : Array Syntax) : M (Array Syntax) := do
let ctx ← read
return mutVars.map fun v => ctx.mutableVars.findD v.getId v
mutual
/- "Concatenate" `c` with `doSeqToCode doElems` -/
partial def concatWith (c : CodeBlock) (doElems : List Syntax) : M CodeBlock :=
@ -1273,7 +1274,7 @@ mutual
let ref := doLetArrow
let decl := doLetArrow[2]
if decl.getKind == ``Lean.Parser.Term.doIdDecl then
let y := decl[0].getId
let y := decl[0]
checkNotShadowingMutable #[y]
let doElem := decl[3]
let k ← withNewMutableVars #[y] (isMutableLet doLetArrow) (doSeqToCode doElems)
@ -1422,7 +1423,7 @@ mutual
let forElems := getDoSeqElems doFor[3]
let forInBodyCodeBlock ← withFor (doSeqToCode forElems)
let ⟨uvars, forInBody⟩ ← mkForInBody x forInBodyCodeBlock
let uvarsTuple ← liftMacroM do mkTuple (← uvars.mapM mkIdentFromRef)
let uvarsTuple ← liftMacroM do mkTuple uvars
if hasReturn forInBodyCodeBlock.code then
let forInBody ← liftMacroM <| destructTuple uvars (← `(r)) forInBody
let forInTerm ←
@ -1487,7 +1488,7 @@ mutual
if catchStx.getKind == ``Lean.Parser.Term.doCatch then
let x := catchStx[1]
if x.isIdent then
withRef x <| checkNotShadowingMutable #[x.getId]
withRef x <| checkNotShadowingMutable #[x]
let optType := catchStx[2]
let c ← doSeqToCode (getDoSeqElems catchStx[4])
pure { x := x, optType := optType, codeBlock := c : Catch }
@ -1504,7 +1505,7 @@ mutual
throwError "invalid 'try', it must have a 'catch' or 'finally'"
let ctx ← read
let ws := getTryCatchUpdatedVars tryCode catches finallyCode?
let uvars := nameSetToArray ws
let uvars := varSetToArray ws
let a := tryCatchPred tryCode catches finallyCode? hasTerminalAction
let r := tryCatchPred tryCode catches finallyCode? hasReturn
let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue

View file

@ -182,11 +182,10 @@ structure PatternVarDecl where
private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α :=
let rec loop (i : Nat) (decls : Array PatternVarDecl) (userNames : Array Name) := do
if h : i < pVars.size then
match pVars.get ⟨i, h⟩ with
| { userName } =>
let type ← mkFreshTypeMVar
withLocalDecl userName BinderInfo.default type fun x =>
loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous)
let var := pVars.get ⟨i, h⟩
let type ← mkFreshTypeMVar
withLocalDecl var.getId BinderInfo.default type fun x =>
loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous)
else
k decls
loop 0 #[] #[]
@ -882,7 +881,7 @@ private def generalize (discrs : Array Discr) (matchType : Expr) (altViews : Arr
if ysUserNames.contains yUserName then
yUserName ← mkFreshUserName yUserName
-- Explicitly provided pattern variables shadow `y`
else if patternVars.any fun x => x.userName == yUserName then
else if patternVars.any fun x => x.getId == yUserName then
yUserName ← mkFreshUserName yUserName
return ysUserNames.push yUserName
let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n

View file

@ -11,12 +11,7 @@ namespace Lean.Elab.Term
open Meta
structure PatternVar where
userName : Name
deriving BEq
instance : ToString PatternVar where
toString x := toString x.userName
abbrev PatternVar := Syntax -- TODO: should be `TSyntax identKind`
/-
Patterns define new local variables.
@ -111,7 +106,7 @@ private def processVar (idStx : Syntax) : M Syntax := do
throwError "invalid pattern variable, must be atomic"
if (← get).found.contains id then
throwError "invalid pattern, variable '{id}' occurred more than once"
modify fun s => { s with vars := s.vars.push { userName := id }, found := s.found.insert id }
modify fun s => { s with vars := s.vars.push idStx, found := s.found.insert id }
return idStx
private def nameToPattern : Name → TermElabM Syntax
@ -366,6 +361,6 @@ def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) :=
return s.vars
def getPatternVarNames (pvars : Array PatternVar) : Array Name :=
pvars.map fun x => x.userName
pvars.map fun x => x.getId
end Lean.Elab.Term