fix: propagate position information of variables in do blocks
This commit is contained in:
parent
daa9e03e78
commit
22f8ea147c
3 changed files with 117 additions and 122 deletions
|
|
@ -122,10 +122,12 @@ private partial def extractBind (expectedType? : Option Expr) : TermElabM Extrac
|
|||
|
||||
namespace Do
|
||||
|
||||
abbrev Var := Syntax -- TODO: should be `TSyntax identKind`
|
||||
|
||||
/- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/
|
||||
structure Alt (σ : Type) where
|
||||
ref : Syntax
|
||||
vars : Array Name
|
||||
vars : Array Var
|
||||
patterns : Syntax
|
||||
rhs : σ
|
||||
deriving Inhabited
|
||||
|
|
@ -179,34 +181,36 @@ structure Alt (σ : Type) where
|
|||
- For every `jmp ref j as` in `C`, there is a `joinpoint j ps b k` and `jmp ref j as` is in `k`, and
|
||||
`ps.size == as.size` -/
|
||||
inductive Code where
|
||||
| decl (xs : Array Name) (doElem : Syntax) (k : Code)
|
||||
| reassign (xs : Array Name) (doElem : Syntax) (k : Code)
|
||||
| decl (xs : Array Var) (doElem : Syntax) (k : Code)
|
||||
| reassign (xs : Array Var) (doElem : Syntax) (k : Code)
|
||||
/- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/
|
||||
| joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (k : Code)
|
||||
| joinpoint (name : Name) (params : Array (Var × Bool)) (body : Code) (k : Code)
|
||||
| seq (action : Syntax) (k : Code)
|
||||
| action (action : Syntax)
|
||||
| «break» (ref : Syntax)
|
||||
| «continue» (ref : Syntax)
|
||||
| «return» (ref : Syntax) (val : Syntax)
|
||||
/- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/
|
||||
| ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
|
||||
| ite (ref : Syntax) (h? : Option Var) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
|
||||
| «match» (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt Code))
|
||||
| jmp (ref : Syntax) (jpName : Name) (args : Array Syntax)
|
||||
deriving Inhabited
|
||||
|
||||
abbrev VarSet := Std.RBMap Name Syntax Name.cmp
|
||||
|
||||
/- A code block, and the collection of variables updated by it. -/
|
||||
structure CodeBlock where
|
||||
code : Code
|
||||
uvars : NameSet := {} -- set of variables updated by `code`
|
||||
uvars : VarSet := {} -- set of variables updated by `code`
|
||||
|
||||
private def nameSetToArray (s : NameSet) : Array Name :=
|
||||
s.fold (fun (xs : Array Name) x => xs.push x) #[]
|
||||
private def varSetToArray (s : VarSet) : Array Var :=
|
||||
s.fold (fun xs _ x => xs.push x) #[]
|
||||
|
||||
private def varsToMessageData (vars : Array Name) : MessageData :=
|
||||
MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.simpMacroScopes)) " "
|
||||
private def varsToMessageData (vars : Array Var) : MessageData :=
|
||||
MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.getId.simpMacroScopes)) " "
|
||||
|
||||
partial def CodeBlocl.toMessageData (codeBlock : CodeBlock) : MessageData :=
|
||||
let us := MessageData.ofList $ (nameSetToArray codeBlock.uvars).toList.map MessageData.ofName
|
||||
let us := MessageData.ofList $ (varSetToArray codeBlock.uvars).toList.map MessageData.ofSyntax
|
||||
let rec loop : Code → MessageData
|
||||
| Code.decl xs _ k => m!"let {varsToMessageData xs} := ...\n{loop k}"
|
||||
| Code.reassign xs _ k => m!"{varsToMessageData xs} := ...\n{loop k}"
|
||||
|
|
@ -264,15 +268,14 @@ def hasBreakContinueReturn (c : Code) : Bool :=
|
|||
|
||||
def mkAuxDeclFor {m} [Monad m] [MonadQuotation m] (e : Syntax) (mkCont : Syntax → m Code) : m Code := withRef e <| withFreshMacroScope do
|
||||
let y ← `(y)
|
||||
let yName := y.getId
|
||||
let doElem ← `(doElem| let y ← $e:term)
|
||||
-- Add elaboration hint for producing sane error message
|
||||
let y ← `(ensure_expected_type% "type mismatch, result value" $y)
|
||||
let k ← mkCont y
|
||||
pure $ Code.decl #[yName] doElem k
|
||||
pure $ Code.decl #[y] doElem k
|
||||
|
||||
/- Convert `action _ e` instructions in `c` into `let y ← e; jmp _ jp (xs y)`. -/
|
||||
partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array Name) : MacroM Code :=
|
||||
partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array Var) : MacroM Code :=
|
||||
let rec loop : Code → MacroM Code
|
||||
| Code.decl xs stx k => return Code.decl xs stx (← loop k)
|
||||
| Code.reassign xs stx k => return Code.reassign xs stx (← loop k)
|
||||
|
|
@ -283,15 +286,14 @@ partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array N
|
|||
| Code.action e => mkAuxDeclFor e fun y =>
|
||||
let ref := e
|
||||
-- We jump to `jp` with xs **and** y
|
||||
let jmpArgs := xs.map $ mkIdentFrom ref
|
||||
let jmpArgs := jmpArgs.push y
|
||||
let jmpArgs := xs.push y
|
||||
return Code.jmp ref jp jmpArgs
|
||||
| c => return c
|
||||
loop code
|
||||
|
||||
structure JPDecl where
|
||||
name : Name
|
||||
params : Array (Name × Bool)
|
||||
params : Array (Var × Bool)
|
||||
body : Code
|
||||
|
||||
def attachJP (jpDecl : JPDecl) (k : Code) : Code :=
|
||||
|
|
@ -300,10 +302,10 @@ def attachJP (jpDecl : JPDecl) (k : Code) : Code :=
|
|||
def attachJPs (jpDecls : Array JPDecl) (k : Code) : Code :=
|
||||
jpDecls.foldr attachJP k
|
||||
|
||||
def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do
|
||||
def mkFreshJP (ps : Array (Var × Bool)) (body : Code) : TermElabM JPDecl := do
|
||||
let ps ←
|
||||
if ps.isEmpty then
|
||||
let y ← mkFreshUserName `y
|
||||
let y ← `(y)
|
||||
pure #[(y, false)]
|
||||
else
|
||||
pure ps
|
||||
|
|
@ -313,51 +315,47 @@ def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do
|
|||
let name ← mkFreshUserName `_do_jp
|
||||
pure { name := name, params := ps, body := body }
|
||||
|
||||
def mkFreshJP' (xs : Array Name) (body : Code) : TermElabM JPDecl :=
|
||||
mkFreshJP (xs.map fun x => (x, true)) body
|
||||
|
||||
def addFreshJP (ps : Array (Name × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do
|
||||
def addFreshJP (ps : Array (Var × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do
|
||||
let jp ← mkFreshJP ps body
|
||||
modify fun (jps : Array JPDecl) => jps.push jp
|
||||
pure jp.name
|
||||
|
||||
def insertVars (rs : NameSet) (xs : Array Name) : NameSet :=
|
||||
xs.foldl (·.insert ·) rs
|
||||
def insertVars (rs : VarSet) (xs : Array Var) : VarSet :=
|
||||
xs.foldl (fun rs x => rs.insert x.getId x) rs
|
||||
|
||||
def eraseVars (rs : NameSet) (xs : Array Name) : NameSet :=
|
||||
xs.foldl (·.erase ·) rs
|
||||
def eraseVars (rs : VarSet) (xs : Array Var) : VarSet :=
|
||||
xs.foldl (·.erase ·.getId) rs
|
||||
|
||||
def eraseOptVar (rs : NameSet) (x? : Option Name) : NameSet :=
|
||||
def eraseOptVar (rs : VarSet) (x? : Option Var) : VarSet :=
|
||||
match x? with
|
||||
| none => rs
|
||||
| some x => rs.insert x
|
||||
| some x => rs.insert x.getId x
|
||||
|
||||
/- Create a new jointpoint for `c`, and jump to it with the variables `rs` -/
|
||||
def mkSimpleJmp (ref : Syntax) (rs : NameSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do
|
||||
let xs := nameSetToArray rs
|
||||
def mkSimpleJmp (ref : Syntax) (rs : VarSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do
|
||||
let xs := varSetToArray rs
|
||||
let jp ← addFreshJP (xs.map fun x => (x, true)) c
|
||||
if xs.isEmpty then
|
||||
let unit ← ``(Unit.unit)
|
||||
return Code.jmp ref jp #[unit]
|
||||
else
|
||||
return Code.jmp ref jp (xs.map $ mkIdentFrom ref)
|
||||
return Code.jmp ref jp xs
|
||||
|
||||
/- Create a new joinpoint that takes `rs` and `val` as arguments. `val` must be syntax representing a pure value.
|
||||
The body of the joinpoint is created using `mkJPBody yFresh`, where `yFresh`
|
||||
is a fresh variable created by this method. -/
|
||||
def mkJmp (ref : Syntax) (rs : NameSet) (val : Syntax) (mkJPBody : Syntax → MacroM Code) : StateRefT (Array JPDecl) TermElabM Code := do
|
||||
let xs := nameSetToArray rs
|
||||
let args := xs.map $ mkIdentFrom ref
|
||||
let args := args.push val
|
||||
let yFresh ← mkFreshUserName `y
|
||||
def mkJmp (ref : Syntax) (rs : VarSet) (val : Syntax) (mkJPBody : Syntax → MacroM Code) : StateRefT (Array JPDecl) TermElabM Code := do
|
||||
let xs := varSetToArray rs
|
||||
let args := xs.push val
|
||||
let yFresh ← withRef ref `(y)
|
||||
let ps := xs.map fun x => (x, true)
|
||||
let ps := ps.push (yFresh, false)
|
||||
let jpBody ← liftMacroM $ mkJPBody (mkIdentFrom ref yFresh)
|
||||
let jpBody ← liftMacroM $ mkJPBody yFresh
|
||||
let jp ← addFreshJP ps jpBody
|
||||
pure $ Code.jmp ref jp args
|
||||
|
||||
/- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/
|
||||
partial def pullExitPointsAux : NameSet → Code → StateRefT (Array JPDecl) TermElabM Code
|
||||
partial def pullExitPointsAux : VarSet → Code → StateRefT (Array JPDecl) TermElabM Code
|
||||
| rs, Code.decl xs stx k => return Code.decl xs stx (← pullExitPointsAux (eraseVars rs xs) k)
|
||||
| rs, Code.reassign xs stx k => return Code.reassign xs stx (← pullExitPointsAux (insertVars rs xs) k)
|
||||
| rs, Code.joinpoint j ps b k => return Code.joinpoint j ps (← pullExitPointsAux rs b) (← pullExitPointsAux rs k)
|
||||
|
|
@ -430,26 +428,26 @@ def pullExitPoints (c : Code) : TermElabM Code := do
|
|||
else
|
||||
pure c
|
||||
|
||||
partial def extendUpdatedVarsAux (c : Code) (ws : NameSet) : TermElabM Code :=
|
||||
partial def extendUpdatedVarsAux (c : Code) (ws : VarSet) : TermElabM Code :=
|
||||
let rec update : Code → TermElabM Code
|
||||
| Code.joinpoint j ps b k => return Code.joinpoint j ps (← update b) (← update k)
|
||||
| Code.seq e k => return Code.seq e (← update k)
|
||||
| c@(Code.«match» ref g ds t alts) => do
|
||||
if alts.any fun alt => alt.vars.any fun x => ws.contains x then
|
||||
if alts.any fun alt => alt.vars.any fun x => ws.contains x.getId then
|
||||
-- If a pattern variable is shadowing a variable in ws, we `pullExitPoints`
|
||||
pullExitPoints c
|
||||
else
|
||||
return Code.«match» ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) })
|
||||
| Code.ite ref none o c t e => return Code.ite ref none o c (← update t) (← update e)
|
||||
| c@(Code.ite ref (some h) o cond t e) => do
|
||||
if ws.contains h then
|
||||
if ws.contains h.getId then
|
||||
-- if the `h` at `if h:c then t else e` shadows a variable in `ws`, we `pullExitPoints`
|
||||
pullExitPoints c
|
||||
else
|
||||
return Code.ite ref (some h) o cond (← update t) (← update e)
|
||||
| Code.reassign xs stx k => return Code.reassign xs stx (← update k)
|
||||
| c@(Code.decl xs stx k) => do
|
||||
if xs.any fun x => ws.contains x then
|
||||
if xs.any fun x => ws.contains x.getId then
|
||||
-- One the declared variables is shadowing a variable in `ws`
|
||||
pullExitPoints c
|
||||
else
|
||||
|
|
@ -462,14 +460,14 @@ Extend the set of updated variables. It assumes `ws` is a super set of `c.uvars`
|
|||
We **cannot** simply update the field `c.uvars`, because `c` may have shadowed some variable in `ws`.
|
||||
See discussion at `pullExitPoints`.
|
||||
-/
|
||||
partial def extendUpdatedVars (c : CodeBlock) (ws : NameSet) : TermElabM CodeBlock := do
|
||||
if ws.any fun x => !c.uvars.contains x then
|
||||
partial def extendUpdatedVars (c : CodeBlock) (ws : VarSet) : TermElabM CodeBlock := do
|
||||
if ws.any fun x _ => !c.uvars.contains x then
|
||||
-- `ws` contains a variable that is not in `c.uvars`, but in `c.dvars` (i.e., it has been shadowed)
|
||||
pure { code := (← extendUpdatedVarsAux c.code ws), uvars := ws }
|
||||
else
|
||||
pure { c with uvars := ws }
|
||||
|
||||
private def union (s₁ s₂ : NameSet) : NameSet :=
|
||||
private def union (s₁ s₂ : VarSet) : VarSet :=
|
||||
s₁.fold (·.insert ·) s₂
|
||||
|
||||
/-
|
||||
|
|
@ -490,7 +488,7 @@ Remark: `stx` is the syntax for the declaration (e.g., `letDecl`), and `xs` are
|
|||
declared by it. It is an array because we have let-declarations that declare multiple variables.
|
||||
Example: `let (x, y) := t`
|
||||
-/
|
||||
def mkVarDeclCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : CodeBlock := {
|
||||
def mkVarDeclCore (xs : Array Var) (stx : Syntax) (c : CodeBlock) : CodeBlock := {
|
||||
code := Code.decl xs stx c.code,
|
||||
uvars := eraseVars c.uvars xs
|
||||
}
|
||||
|
|
@ -501,12 +499,12 @@ Remark: `stx` is the syntax for the declaration (e.g., `letDecl`), and `xs` are
|
|||
declared by it. It is an array because we have let-declarations that declare multiple variables.
|
||||
Example: `(x, y) ← t`
|
||||
-/
|
||||
def mkReassignCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : TermElabM CodeBlock := do
|
||||
def mkReassignCore (xs : Array Var) (stx : Syntax) (c : CodeBlock) : TermElabM CodeBlock := do
|
||||
let us := c.uvars
|
||||
let ws := insertVars us xs
|
||||
-- If `xs` contains a new updated variable, then we must use `extendUpdatedVars`.
|
||||
-- See discussion at `pullExitPoints`
|
||||
let code ← if xs.any fun x => !us.contains x then extendUpdatedVarsAux c.code ws else pure c.code
|
||||
let code ← if xs.any fun x => !us.contains x.getId then extendUpdatedVarsAux c.code ws else pure c.code
|
||||
pure { code := Code.reassign xs stx code, uvars := ws }
|
||||
|
||||
def mkSeq (action : Syntax) (c : CodeBlock) : CodeBlock :=
|
||||
|
|
@ -525,7 +523,7 @@ def mkContinue (ref : Syntax) : CodeBlock :=
|
|||
{ code := Code.«continue» ref }
|
||||
|
||||
def mkIte (ref : Syntax) (optIdent : Syntax) (cond : Syntax) (thenBranch : CodeBlock) (elseBranch : CodeBlock) : TermElabM CodeBlock := do
|
||||
let x? := if optIdent.isNone then none else some optIdent[0].getId
|
||||
let x? := optIdent.getOptional?
|
||||
let (thenBranch, elseBranch) ← homogenize thenBranch elseBranch
|
||||
pure {
|
||||
code := Code.ite ref x? optIdent cond thenBranch.code elseBranch.code,
|
||||
|
|
@ -555,12 +553,12 @@ def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optMotive : Sy
|
|||
|
||||
/- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`.
|
||||
This method assumes `terminal` is a terminal -/
|
||||
def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Name) (k : CodeBlock) : TermElabM CodeBlock := do
|
||||
def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Var) (k : CodeBlock) : TermElabM CodeBlock := do
|
||||
unless hasTerminalAction terminal.code do
|
||||
throwErrorAt kRef "'do' element is unreachable"
|
||||
let (terminal, k) ← homogenize terminal k
|
||||
let xs := nameSetToArray k.uvars
|
||||
let y ← match y? with | some y => pure y | none => mkFreshUserName `y
|
||||
let xs := varSetToArray k.uvars
|
||||
let y ← match y? with | some y => pure y | none => `(y)
|
||||
let ps := xs.map fun x => (x, true)
|
||||
let ps := ps.push (y, false)
|
||||
let jpDecl ← mkFreshJP ps k.code
|
||||
|
|
@ -568,26 +566,26 @@ def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Name) (k : CodeBl
|
|||
let terminal ← liftMacroM $ convertTerminalActionIntoJmp terminal.code jp xs
|
||||
pure { code := attachJP jpDecl terminal, uvars := k.uvars }
|
||||
|
||||
def getLetIdDeclVar (letIdDecl : Syntax) : Name :=
|
||||
letIdDecl[0].getId
|
||||
def getLetIdDeclVar (letIdDecl : Syntax) : Var :=
|
||||
letIdDecl[0]
|
||||
|
||||
-- support both regular and syntax match
|
||||
def getPatternVarsEx (pattern : Syntax) : TermElabM (Array Name) :=
|
||||
getPatternVarNames <$> getPatternVars pattern <|>
|
||||
Array.map Syntax.getId <$> Quotation.getPatternVars pattern
|
||||
def getPatternVarsEx (pattern : Syntax) : TermElabM (Array Var) :=
|
||||
getPatternVars pattern <|>
|
||||
Quotation.getPatternVars pattern
|
||||
|
||||
def getPatternsVarsEx (patterns : Array Syntax) : TermElabM (Array Name) :=
|
||||
getPatternVarNames <$> getPatternsVars patterns <|>
|
||||
Array.map Syntax.getId <$> Quotation.getPatternsVars patterns
|
||||
def getPatternsVarsEx (patterns : Array Syntax) : TermElabM (Array Var) :=
|
||||
getPatternsVars patterns <|>
|
||||
Quotation.getPatternsVars patterns
|
||||
|
||||
def getLetPatDeclVars (letPatDecl : Syntax) : TermElabM (Array Name) := do
|
||||
def getLetPatDeclVars (letPatDecl : Syntax) : TermElabM (Array Var) := do
|
||||
let pattern := letPatDecl[0]
|
||||
getPatternVarsEx pattern
|
||||
|
||||
def getLetEqnsDeclVar (letEqnsDecl : Syntax) : Name :=
|
||||
letEqnsDecl[0].getId
|
||||
def getLetEqnsDeclVar (letEqnsDecl : Syntax) : Var :=
|
||||
letEqnsDecl[0]
|
||||
|
||||
def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Name) := do
|
||||
def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Var) := do
|
||||
let arg := letDecl[0]
|
||||
if arg.getKind == ``Lean.Parser.Term.letIdDecl then
|
||||
pure #[getLetIdDeclVar arg]
|
||||
|
|
@ -598,33 +596,33 @@ def getLetDeclVars (letDecl : Syntax) : TermElabM (Array Name) := do
|
|||
else
|
||||
throwError "unexpected kind of let declaration"
|
||||
|
||||
def getDoLetVars (doLet : Syntax) : TermElabM (Array Name) :=
|
||||
def getDoLetVars (doLet : Syntax) : TermElabM (Array Var) :=
|
||||
-- leading_parser "let " >> optional "mut " >> letDecl
|
||||
getLetDeclVars doLet[2]
|
||||
|
||||
def getHaveIdLhsVar (optIdent : Syntax) : Name :=
|
||||
def getHaveIdLhsVar (optIdent : Syntax) : TermElabM Var :=
|
||||
if optIdent.isNone then
|
||||
`this
|
||||
`(this)
|
||||
else
|
||||
optIdent[0].getId
|
||||
pure optIdent[0]
|
||||
|
||||
def getDoHaveVars (doHave : Syntax) : TermElabM (Array Name) :=
|
||||
def getDoHaveVars (doHave : Syntax) : TermElabM (Array Var) := do
|
||||
-- doHave := leading_parser "have " >> Term.haveDecl
|
||||
-- haveDecl := leading_parser haveIdDecl <|> letPatDecl <|> haveEqnsDecl
|
||||
let arg := doHave[1][0]
|
||||
if arg.getKind == ``Lean.Parser.Term.haveIdDecl then
|
||||
-- haveIdDecl := leading_parser atomic (haveIdLhs >> " := ") >> termParser
|
||||
-- haveIdLhs := optional (ident >> many (ppSpace >> (simpleBinderWithoutType <|> bracketedBinder))) >> optType
|
||||
pure #[getHaveIdLhsVar arg[0]]
|
||||
pure #[← getHaveIdLhsVar arg[0]]
|
||||
else if arg.getKind == ``Lean.Parser.Term.letPatDecl then
|
||||
getLetPatDeclVars arg
|
||||
else if arg.getKind == ``Lean.Parser.Term.haveEqnsDecl then
|
||||
-- haveEqnsDecl := leading_parser haveIdLhs >> matchAlts
|
||||
pure #[getHaveIdLhsVar arg[0]]
|
||||
pure #[← getHaveIdLhsVar arg[0]]
|
||||
else
|
||||
throwError "unexpected kind of have declaration"
|
||||
|
||||
def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Name) := do
|
||||
def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Var) := do
|
||||
-- letRecDecls is an array of `(group (optional attributes >> letDecl))`
|
||||
let letRecDecls := doLetRec[1][0].getSepArgs
|
||||
let letDecls := letRecDecls.map fun p => p[2]
|
||||
|
|
@ -635,16 +633,16 @@ def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Name) := do
|
|||
pure allVars
|
||||
|
||||
-- ident >> optType >> leftArrow >> termParser
|
||||
def getDoIdDeclVar (doIdDecl : Syntax) : Name :=
|
||||
doIdDecl[0].getId
|
||||
def getDoIdDeclVar (doIdDecl : Syntax) : Var :=
|
||||
doIdDecl[0]
|
||||
|
||||
-- termParser >> leftArrow >> termParser >> optional (" | " >> termParser)
|
||||
def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Name) := do
|
||||
def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Var) := do
|
||||
let pattern := doPatDecl[0]
|
||||
getPatternVarsEx pattern
|
||||
|
||||
-- leading_parser "let " >> optional "mut " >> (doIdDecl <|> doPatDecl)
|
||||
def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do
|
||||
def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Var) := do
|
||||
let decl := doLetArrow[2]
|
||||
if decl.getKind == ``Lean.Parser.Term.doIdDecl then
|
||||
pure #[getDoIdDeclVar decl]
|
||||
|
|
@ -653,7 +651,7 @@ def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do
|
|||
else
|
||||
throwError "unexpected kind of 'do' declaration"
|
||||
|
||||
def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Name) := do
|
||||
def getDoReassignVars (doReassign : Syntax) : TermElabM (Array Var) := do
|
||||
let arg := doReassign[0]
|
||||
if arg.getKind == ``Lean.Parser.Term.letIdDecl then
|
||||
pure #[getLetIdDeclVar arg]
|
||||
|
|
@ -748,20 +746,20 @@ def isDoExpr? (doElem : Syntax) : Option Syntax :=
|
|||
|
||||
We use this method when expanding the `for-in` notation.
|
||||
-/
|
||||
private def destructTuple (uvars : Array Name) (x : Syntax) (body : Syntax) : MacroM Syntax := do
|
||||
private def destructTuple (uvars : Array Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do
|
||||
if uvars.size == 0 then
|
||||
return body
|
||||
else if uvars.size == 1 then
|
||||
`(let $(← mkIdentFromRef uvars[0]):ident := $x; $body)
|
||||
`(let $(uvars[0]):ident := $x; $body)
|
||||
else
|
||||
destruct uvars.toList x body
|
||||
where
|
||||
destruct (as : List Name) (x : Syntax) (body : Syntax) : MacroM Syntax := do
|
||||
destruct (as : List Var) (x : Syntax) (body : Syntax) : MacroM Syntax := do
|
||||
match as with
|
||||
| [a, b] => `(let $(← mkIdentFromRef a):ident := $x.1; let $(← mkIdentFromRef b):ident := $x.2; $body)
|
||||
| [a, b] => `(let $a:ident := $x.1; let $b:ident := $x.2; $body)
|
||||
| a :: as => withFreshMacroScope do
|
||||
let rest ← destruct as (← `(x)) body
|
||||
`(let $(← mkIdentFromRef a):ident := $x.1; let x := $x.2; $rest)
|
||||
`(let $a:ident := $x.1; let x := $x.2; $rest)
|
||||
| _ => unreachable!
|
||||
|
||||
/-
|
||||
|
|
@ -864,15 +862,14 @@ def Kind.isRegular : Kind → Bool
|
|||
|
||||
structure Context where
|
||||
m : Syntax -- Syntax to reference the monad associated with the do notation.
|
||||
uvars : Array Name
|
||||
uvars : Array Var
|
||||
kind : Kind
|
||||
|
||||
abbrev M := ReaderT Context MacroM
|
||||
|
||||
def mkUVarTuple : M Syntax := do
|
||||
let ctx ← read
|
||||
let uvarIdents ← ctx.uvars.mapM mkIdentFromRef
|
||||
mkTuple uvarIdents
|
||||
mkTuple ctx.uvars
|
||||
|
||||
def returnToTerm (val : Syntax) : M Syntax := do
|
||||
let ctx ← read
|
||||
|
|
@ -993,9 +990,9 @@ def mkIte (optIdent : Syntax) (cond : Syntax) (thenBranch : Syntax) (elseBranch
|
|||
let h := optIdent[0]
|
||||
``(if $h:ident : $cond then $thenBranch else $elseBranch)
|
||||
|
||||
def mkJoinPoint (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withRef body <| withFreshMacroScope do
|
||||
let pTypes ← ps.mapM fun ⟨id, useTypeOf⟩ => do if useTypeOf then `(type_of% $(← mkIdentFromRef id)) else `(_)
|
||||
let ps ← ps.mapM fun ⟨id, useTypeOf⟩ => mkIdentFromRef id
|
||||
def mkJoinPoint (j : Name) (ps : Array (Syntax × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withRef body <| withFreshMacroScope do
|
||||
let pTypes ← ps.mapM fun ⟨id, useTypeOf⟩ => do if useTypeOf then `(type_of% $id) else `(_)
|
||||
let ps := ps.map (·.1)
|
||||
/-
|
||||
We use `let_delayed` instead of `let` for joinpoints to make sure `$k` is elaborated before `$body`.
|
||||
By elaborating `$k` first, we "learn" more about `$body`'s type.
|
||||
|
|
@ -1045,7 +1042,7 @@ partial def toTerm : Code → M Syntax
|
|||
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts]
|
||||
return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts]
|
||||
|
||||
def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do
|
||||
def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := do
|
||||
let term ← toTerm code { m := m, kind := kind, uvars := uvars }
|
||||
pure term
|
||||
|
||||
|
|
@ -1066,7 +1063,7 @@ def mkNestedKind (a r bc : Bool) : Kind :=
|
|||
| true, true, true => Kind.nestedPRBC
|
||||
| false, false, false => unreachable!
|
||||
|
||||
def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Name) (a r bc : Bool) : MacroM Syntax := do
|
||||
def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do
|
||||
ToTerm.run code m uvars (mkNestedKind a r bc)
|
||||
|
||||
/- Given a term `term` produced by `ToTerm.run`, pattern match on its result.
|
||||
|
|
@ -1077,9 +1074,9 @@ def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Name) (a r bc : Bool)
|
|||
- `bc` is true if the code block has a `Code.break _` or `Code.continue _` exit point
|
||||
|
||||
The result is a sequence of `doElem` -/
|
||||
def matchNestedTermResult (term : Syntax) (uvars : Array Name) (a r bc : Bool) : MacroM (List Syntax) := do
|
||||
def matchNestedTermResult (term : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM (List Syntax) := do
|
||||
let toDoElems (auxDo : Syntax) : List Syntax := getDoSeqElems (getDoSeq auxDo)
|
||||
let u ← mkTuple (← uvars.mapM mkIdentFromRef)
|
||||
let u ← mkTuple uvars
|
||||
match a, r, bc with
|
||||
| true, false, false =>
|
||||
if uvars.isEmpty then
|
||||
|
|
@ -1135,41 +1132,41 @@ namespace ToCodeBlock
|
|||
structure Context where
|
||||
ref : Syntax
|
||||
m : Syntax -- Syntax representing the monad associated with the do notation.
|
||||
mutableVars : NameSet := {}
|
||||
mutableVars : VarSet := {}
|
||||
insideFor : Bool := false
|
||||
|
||||
abbrev M := ReaderT Context TermElabM
|
||||
|
||||
def withNewMutableVars {α} (newVars : Array Name) (mutable : Bool) (x : M α) : M α :=
|
||||
def withNewMutableVars {α} (newVars : Array Var) (mutable : Bool) (x : M α) : M α :=
|
||||
withReader (fun ctx => if mutable then { ctx with mutableVars := insertVars ctx.mutableVars newVars } else ctx) x
|
||||
|
||||
def checkReassignable (xs : Array Name) : M Unit := do
|
||||
def checkReassignable (xs : Array Var) : M Unit := do
|
||||
let throwInvalidReassignment (x : Name) : M Unit :=
|
||||
throwError "'{x.simpMacroScopes}' cannot be reassigned"
|
||||
let ctx ← read
|
||||
for x in xs do
|
||||
unless ctx.mutableVars.contains x do
|
||||
throwInvalidReassignment x
|
||||
unless ctx.mutableVars.contains x.getId do
|
||||
throwInvalidReassignment x.getId
|
||||
|
||||
def checkNotShadowingMutable (xs : Array Name) : M Unit := do
|
||||
def checkNotShadowingMutable (xs : Array Var) : M Unit := do
|
||||
let throwInvalidShadowing (x : Name) : M Unit :=
|
||||
throwError "mutable variable '{x.simpMacroScopes}' cannot be shadowed"
|
||||
let ctx ← read
|
||||
for x in xs do
|
||||
if ctx.mutableVars.contains x then
|
||||
throwInvalidShadowing x
|
||||
if ctx.mutableVars.contains x.getId then
|
||||
throwInvalidShadowing x.getId
|
||||
|
||||
def withFor {α} (x : M α) : M α :=
|
||||
withReader (fun ctx => { ctx with insideFor := true }) x
|
||||
|
||||
structure ToForInTermResult where
|
||||
uvars : Array Name
|
||||
uvars : Array Var
|
||||
term : Syntax
|
||||
|
||||
def mkForInBody (x : Syntax) (forInBody : CodeBlock) : M ToForInTermResult := do
|
||||
let ctx ← read
|
||||
let uvars := forInBody.uvars
|
||||
let uvars := nameSetToArray uvars
|
||||
let uvars := varSetToArray uvars
|
||||
let term ← liftMacroM $ ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn)
|
||||
pure ⟨uvars, term⟩
|
||||
|
||||
|
|
@ -1233,7 +1230,7 @@ structure Catch where
|
|||
optType : Syntax
|
||||
codeBlock : CodeBlock
|
||||
|
||||
def getTryCatchUpdatedVars (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) : NameSet :=
|
||||
def getTryCatchUpdatedVars (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : Option CodeBlock) : VarSet :=
|
||||
let ws := tryCode.uvars
|
||||
let ws := catches.foldl (fun ws alt => union alt.codeBlock.uvars ws) ws
|
||||
let ws := match finallyCode? with
|
||||
|
|
@ -1248,6 +1245,10 @@ def tryCatchPred (tryCode : CodeBlock) (catches : Array Catch) (finallyCode? : O
|
|||
| none => false
|
||||
| some finallyCode => p finallyCode.code
|
||||
|
||||
def mutVarNamesToDefStx (mutVars : Array Syntax) : M (Array Syntax) := do
|
||||
let ctx ← read
|
||||
return mutVars.map fun v => ctx.mutableVars.findD v.getId v
|
||||
|
||||
mutual
|
||||
/- "Concatenate" `c` with `doSeqToCode doElems` -/
|
||||
partial def concatWith (c : CodeBlock) (doElems : List Syntax) : M CodeBlock :=
|
||||
|
|
@ -1273,7 +1274,7 @@ mutual
|
|||
let ref := doLetArrow
|
||||
let decl := doLetArrow[2]
|
||||
if decl.getKind == ``Lean.Parser.Term.doIdDecl then
|
||||
let y := decl[0].getId
|
||||
let y := decl[0]
|
||||
checkNotShadowingMutable #[y]
|
||||
let doElem := decl[3]
|
||||
let k ← withNewMutableVars #[y] (isMutableLet doLetArrow) (doSeqToCode doElems)
|
||||
|
|
@ -1422,7 +1423,7 @@ mutual
|
|||
let forElems := getDoSeqElems doFor[3]
|
||||
let forInBodyCodeBlock ← withFor (doSeqToCode forElems)
|
||||
let ⟨uvars, forInBody⟩ ← mkForInBody x forInBodyCodeBlock
|
||||
let uvarsTuple ← liftMacroM do mkTuple (← uvars.mapM mkIdentFromRef)
|
||||
let uvarsTuple ← liftMacroM do mkTuple uvars
|
||||
if hasReturn forInBodyCodeBlock.code then
|
||||
let forInBody ← liftMacroM <| destructTuple uvars (← `(r)) forInBody
|
||||
let forInTerm ←
|
||||
|
|
@ -1487,7 +1488,7 @@ mutual
|
|||
if catchStx.getKind == ``Lean.Parser.Term.doCatch then
|
||||
let x := catchStx[1]
|
||||
if x.isIdent then
|
||||
withRef x <| checkNotShadowingMutable #[x.getId]
|
||||
withRef x <| checkNotShadowingMutable #[x]
|
||||
let optType := catchStx[2]
|
||||
let c ← doSeqToCode (getDoSeqElems catchStx[4])
|
||||
pure { x := x, optType := optType, codeBlock := c : Catch }
|
||||
|
|
@ -1504,7 +1505,7 @@ mutual
|
|||
throwError "invalid 'try', it must have a 'catch' or 'finally'"
|
||||
let ctx ← read
|
||||
let ws := getTryCatchUpdatedVars tryCode catches finallyCode?
|
||||
let uvars := nameSetToArray ws
|
||||
let uvars := varSetToArray ws
|
||||
let a := tryCatchPred tryCode catches finallyCode? hasTerminalAction
|
||||
let r := tryCatchPred tryCode catches finallyCode? hasReturn
|
||||
let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue
|
||||
|
|
|
|||
|
|
@ -182,11 +182,10 @@ structure PatternVarDecl where
|
|||
private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α :=
|
||||
let rec loop (i : Nat) (decls : Array PatternVarDecl) (userNames : Array Name) := do
|
||||
if h : i < pVars.size then
|
||||
match pVars.get ⟨i, h⟩ with
|
||||
| { userName } =>
|
||||
let type ← mkFreshTypeMVar
|
||||
withLocalDecl userName BinderInfo.default type fun x =>
|
||||
loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous)
|
||||
let var := pVars.get ⟨i, h⟩
|
||||
let type ← mkFreshTypeMVar
|
||||
withLocalDecl var.getId BinderInfo.default type fun x =>
|
||||
loop (i+1) (decls.push { fvarId := x.fvarId! }) (userNames.push Name.anonymous)
|
||||
else
|
||||
k decls
|
||||
loop 0 #[] #[]
|
||||
|
|
@ -882,7 +881,7 @@ private def generalize (discrs : Array Discr) (matchType : Expr) (altViews : Arr
|
|||
if ysUserNames.contains yUserName then
|
||||
yUserName ← mkFreshUserName yUserName
|
||||
-- Explicitly provided pattern variables shadow `y`
|
||||
else if patternVars.any fun x => x.userName == yUserName then
|
||||
else if patternVars.any fun x => x.getId == yUserName then
|
||||
yUserName ← mkFreshUserName yUserName
|
||||
return ysUserNames.push yUserName
|
||||
let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n
|
||||
|
|
|
|||
|
|
@ -11,12 +11,7 @@ namespace Lean.Elab.Term
|
|||
|
||||
open Meta
|
||||
|
||||
structure PatternVar where
|
||||
userName : Name
|
||||
deriving BEq
|
||||
|
||||
instance : ToString PatternVar where
|
||||
toString x := toString x.userName
|
||||
abbrev PatternVar := Syntax -- TODO: should be `TSyntax identKind`
|
||||
|
||||
/-
|
||||
Patterns define new local variables.
|
||||
|
|
@ -111,7 +106,7 @@ private def processVar (idStx : Syntax) : M Syntax := do
|
|||
throwError "invalid pattern variable, must be atomic"
|
||||
if (← get).found.contains id then
|
||||
throwError "invalid pattern, variable '{id}' occurred more than once"
|
||||
modify fun s => { s with vars := s.vars.push { userName := id }, found := s.found.insert id }
|
||||
modify fun s => { s with vars := s.vars.push idStx, found := s.found.insert id }
|
||||
return idStx
|
||||
|
||||
private def nameToPattern : Name → TermElabM Syntax
|
||||
|
|
@ -366,6 +361,6 @@ def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) :=
|
|||
return s.vars
|
||||
|
||||
def getPatternVarNames (pvars : Array PatternVar) : Array Name :=
|
||||
pvars.map fun x => x.userName
|
||||
pvars.map fun x => x.getId
|
||||
|
||||
end Lean.Elab.Term
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue