chore: cleanup

This commit is contained in:
Leonardo de Moura 2020-10-29 10:24:16 -07:00
parent f31b0d7d19
commit 6858cb5fb6
29 changed files with 2420 additions and 2433 deletions

View file

@ -1271,7 +1271,7 @@ def Prod.map.{u₁, u₂, v₁, v₂} {α₁ : Type u₁} {α₂ : Type u₂} {
/- Dependent products -/
theorem exOfPsig {α : Type u} {p : α → Prop} : (PSigma (fun x => p x)) → Exists (fun x => p x)
| ⟨x, hx⟩ => ⟨x, hx⟩
| ⟨x, hx⟩ => ⟨x, hx⟩
protected theorem PSigma.eta {α : Sort u} {β : α → Sort v} {a₁ a₂ : α} {b₁ : β a₁} {b₂ : β a₂}
(h₁ : a₁ = a₂) (h₂ : Eq.rec (motive := fun a _ => β a) b₁ h₁ = b₂) : PSigma.mk a₁ b₁ = PSigma.mk a₂ b₂ := by

View file

@ -14,22 +14,22 @@ namespace Array
-- TODO: remove `partial` using well-founded recursion
@[specialize] partial def binSearchAux {α : Type u} {β : Type v} [Inhabited α] [Inhabited β] (lt : αα → Bool) (found : Option α → β) (as : Array α) (k : α) : Nat → Nat → β
| lo, hi =>
if lo <= hi then
let m := (lo + hi)/2;
let a := as.get! m;
if lt a k then binSearchAux lt found as k (m+1) hi
else if lt k a then
if m == 0 then found none
else binSearchAux lt found as k lo (m-1)
else found (some a)
else found none
| lo, hi =>
if lo <= hi then
let m := (lo + hi)/2;
let a := as.get! m;
if lt a k then binSearchAux lt found as k (m+1) hi
else if lt k a then
if m == 0 then found none
else binSearchAux lt found as k lo (m-1)
else found (some a)
else found none
@[inline] def binSearch {α : Type} [Inhabited α] (as : Array α) (k : α) (lt : αα → Bool) (lo := 0) (hi := as.size - 1) : Option α :=
binSearchAux lt id as k lo hi
binSearchAux lt id as k lo hi
@[inline] def binSearchContains {α : Type} [Inhabited α] (as : Array α) (k : α) (lt : αα → Bool) (lo := 0) (hi := as.size - 1) : Bool :=
binSearchAux lt Option.isSome as k lo hi
binSearchAux lt Option.isSome as k lo hi
@[specialize] private partial def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α]
(lt : αα → Bool)
@ -37,17 +37,17 @@ binSearchAux lt Option.isSome as k lo hi
(add : Unit → m α)
(as : Array α)
(k : α) : Nat → Nat → m (Array α)
| lo, hi =>
-- as[lo] < k < as[hi]
let mid := (lo + hi)/2;
let midVal := as.get! mid;
if lt midVal k then
if mid == lo then do let v ← add (); pure $ as.insertAt (lo+1) v
else binInsertAux lt merge add as k mid hi
else if lt k midVal then
binInsertAux lt merge add as k lo mid
else do
as.modifyM mid $ fun v => merge v
| lo, hi =>
-- as[lo] < k < as[hi]
let mid := (lo + hi)/2;
let midVal := as.get! mid;
if lt midVal k then
if mid == lo then do let v ← add (); pure $ as.insertAt (lo+1) v
else binInsertAux lt merge add as k mid hi
else if lt k midVal then
binInsertAux lt merge add as k lo mid
else do
as.modifyM mid $ fun v => merge v
@[specialize] partial def binInsertM {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α]
(lt : αα → Bool)
@ -55,14 +55,14 @@ binSearchAux lt Option.isSome as k lo hi
(add : Unit → m α)
(as : Array α)
(k : α) : m (Array α) :=
if as.isEmpty then do let v ← add (); pure $ as.push v
else if lt k (as.get! 0) then do let v ← add (); pure $ as.insertAt 0 v
else if !lt (as.get! 0) k then as.modifyM 0 $ merge
else if lt as.back k then do let v ← add (); pure $ as.push v
else if !lt k as.back then as.modifyM (as.size - 1) $ merge
else binInsertAux lt merge add as k 0 (as.size - 1)
if as.isEmpty then do let v ← add (); pure $ as.push v
else if lt k (as.get! 0) then do let v ← add (); pure $ as.insertAt 0 v
else if !lt (as.get! 0) k then as.modifyM 0 $ merge
else if lt as.back k then do let v ← add (); pure $ as.push v
else if !lt k as.back then as.modifyM (as.size - 1) $ merge
else binInsertAux lt merge add as k 0 (as.size - 1)
@[inline] def binInsert {α : Type u} [Inhabited α] (lt : αα → Bool) (as : Array α) (k : α) : Array α :=
Id.run $ binInsertM lt (fun _ => k) (fun _ => k) as k
Id.run $ binInsertM lt (fun _ => k) (fun _ => k) as k
end Array

View file

@ -16,8 +16,8 @@ Remark: we can define `mapM`, `mapM₂` and `forM` using `Applicative` instead o
Example:
```
def mapM {m : Type u → Type v} [Applicative m] {α : Type w} {β : Type u} (f : α → m β) : List α → m (List β)
| [] => pure []
| a::as => List.cons <$> (f a) <*> mapM as
| [] => pure []
| a::as => List.cons <$> (f a) <*> mapM as
```
However, we consider `f <$> a <*> b` an anti-idiom because the generated code

View file

@ -201,16 +201,16 @@ protected theorem oneMul (n : Nat) : 1 * n = n :=
Nat.mulComm n 1 ▸ Nat.mulOne n
protected theorem leftDistrib : ∀ (n m k : Nat), n * (m + k) = n * m + n * k
| 0, m, k => (Nat.zeroMul (m + k)).symm ▸ (Nat.zeroMul m).symm ▸ (Nat.zeroMul k).symm ▸ rfl
| succ n, m, k =>
have h₁ : succ n * (m + k) = n * (m + k) + (m + k) from succMul ..
have h₂ : n * (m + k) + (m + k) = (n * m + n * k) + (m + k) from Nat.leftDistrib n m k ▸ rfl
have h₃ : (n * m + n * k) + (m + k) = n * m + (n * k + (m + k)) from Nat.addAssoc ..
have h₄ : n * m + (n * k + (m + k)) = n * m + (m + (n * k + k)) from congrArg (fun x => n*m + x) (Nat.addLeftComm ..)
have h₅ : n * m + (m + (n * k + k)) = (n * m + m) + (n * k + k) from (Nat.addAssoc ..).symm
have h₆ : (n * m + m) + (n * k + k) = (n * m + m) + succ n * k from succMul n k ▸ rfl
have h₇ : (n * m + m) + succ n * k = succ n * m + succ n * k from succMul n m ▸ rfl
(((((h₁.trans h₂).trans h₃).trans h₄).trans h₅).trans h₆).trans h₇
| 0, m, k => (Nat.zeroMul (m + k)).symm ▸ (Nat.zeroMul m).symm ▸ (Nat.zeroMul k).symm ▸ rfl
| succ n, m, k =>
have h₁ : succ n * (m + k) = n * (m + k) + (m + k) from succMul ..
have h₂ : n * (m + k) + (m + k) = (n * m + n * k) + (m + k) from Nat.leftDistrib n m k ▸ rfl
have h₃ : (n * m + n * k) + (m + k) = n * m + (n * k + (m + k)) from Nat.addAssoc ..
have h₄ : n * m + (n * k + (m + k)) = n * m + (m + (n * k + k)) from congrArg (fun x => n*m + x) (Nat.addLeftComm ..)
have h₅ : n * m + (m + (n * k + k)) = (n * m + m) + (n * k + k) from (Nat.addAssoc ..).symm
have h₆ : (n * m + m) + (n * k + k) = (n * m + m) + succ n * k from succMul n k ▸ rfl
have h₇ : (n * m + m) + succ n * k = succ n * m + succ n * k from succMul n m ▸ rfl
(((((h₁.trans h₂).trans h₃).trans h₄).trans h₅).trans h₆).trans h₇
protected theorem rightDistrib (n m k : Nat) : (n + m) * k = n * k + m * k :=
have h₁ : (n + m) * k = k * (n + m) from Nat.mulComm ..

View file

@ -10,7 +10,7 @@ import Init.Data.ToString.Basic
syntax:max "s!" (interpolatedStr term) : term
macro_rules
| `(s! $interpStr) => do
let chunks := interpStr.getArgs
let r ← Lean.Syntax.expandInterpolatedStrChunks chunks (fun a b => `($a ++ $b)) (fun a => `(toString $a))
`(($r : String))
| `(s! $interpStr) => do
let chunks := interpStr.getArgs
let r ← Lean.Syntax.expandInterpolatedStrChunks chunks (fun a b => `($a ++ $b)) (fun a => `(toString $a))
`(($r : String))

View file

@ -109,21 +109,21 @@ protected def append : Name → Name → Name
instance : Append Name := ⟨Name.append⟩
def capitalize : Name → Name
| Name.str p s _ => mkNameStr p s.capitalize
| n => n
| Name.str p s _ => mkNameStr p s.capitalize
| n => n
def appendAfter : Name → String → Name
| str p s _, suffix => mkNameStr p (s ++ suffix)
| n, suffix => mkNameStr n suffix
| str p s _, suffix => mkNameStr p (s ++ suffix)
| n, suffix => mkNameStr n suffix
def appendIndexAfter : Name → Nat → Name
| str p s _, idx => mkNameStr p (s ++ "_" ++ toString idx)
| n, idx => mkNameStr n ("_" ++ toString idx)
| str p s _, idx => mkNameStr p (s ++ "_" ++ toString idx)
| n, idx => mkNameStr n ("_" ++ toString idx)
def appendBefore : Name → String → Name
| anonymous, pre => mkNameStr anonymous pre
| str p s _, pre => mkNameStr p (pre ++ s)
| num p n _, pre => mkNameNum (mkNameStr p pre) n
| anonymous, pre => mkNameStr anonymous pre
| str p s _, pre => mkNameStr p (pre ++ s)
| num p n _, pre => mkNameNum (mkNameStr p pre) n
end Name

View file

@ -11,7 +11,7 @@ universes u v
set_option codegen false
inductive Acc {α : Sort u} (r : αα → Prop) : α → Prop
| intro (x : α) (h : (y : α) → r y x → Acc r y) : Acc r x
| intro (x : α) (h : (y : α) → r y x → Acc r y) : Acc r x
@[elabAsEliminator, inline, reducible]
def Acc.ndrec.{u1, u2} {α : Sort u2} {r : αα → Prop} {C : α → Sort u1}
@ -289,8 +289,8 @@ variables {α : Sort u} {β : Sort v}
-- Reverse lexicographical order based on r and s
inductive RevLex (r : αα → Prop) (s : β → β → Prop) : @PSigma α (fun a => β) → @PSigma α (fun a => β) → Prop
| left : {a₁ a₂ : α} → (b : β) → r a₁ a₂ → RevLex r s ⟨a₁, b⟩ ⟨a₂, b⟩
| right : (a₁ : α) → {b₁ : β} → (a₂ : α) → {b₂ : β} → s b₁ b₂ → RevLex r s ⟨a₁, b₁⟩ ⟨a₂, b₂⟩
| left : {a₁ a₂ : α} → (b : β) → r a₁ a₂ → RevLex r s ⟨a₁, b⟩ ⟨a₂, b⟩
| right : (a₁ : α) → {b₁ : β} → (a₂ : α) → {b₂ : β} → s b₁ b₂ → RevLex r s ⟨a₁, b₁⟩ ⟨a₂, b₂⟩
end
section

View file

@ -15,11 +15,12 @@ namespace OwnedSet
abbrev Key := FunId × Index
def beq : Key → Key → Bool
| (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂
| (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂
instance : BEq Key := ⟨beq⟩
def getHash : Key → USize
| (f, x) => mixHash (hash f) (hash x)
| (f, x) => mixHash (hash f) (hash x)
instance : Hashable Key := ⟨getHash⟩
end OwnedSet
@ -35,19 +36,19 @@ def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := Std.HashMap.
Recall that `Param` contains the field `borrow`. -/
namespace ParamMap
inductive Key
| decl (name : FunId)
| jp (name : FunId) (jpid : JoinPointId)
| decl (name : FunId)
| jp (name : FunId) (jpid : JoinPointId)
def beq : Key → Key → Bool
| Key.decl n₁, Key.decl n₂ => n₁ == n₂
| Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂
| _, _ => false
| Key.decl n₁, Key.decl n₂ => n₁ == n₂
| Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂
| _, _ => false
instance : BEq Key := ⟨beq⟩
def getHash : Key → USize
| Key.decl n => hash n
| Key.jp n id => mixHash (hash n) (hash id)
| Key.decl n => hash n
| Key.jp n id => mixHash (hash n) (hash id)
instance : Hashable Key := ⟨getHash⟩
end ParamMap
@ -56,13 +57,13 @@ open ParamMap (Key)
abbrev ParamMap := Std.HashMap Key (Array Param)
def ParamMap.fmt (map : ParamMap) : Format :=
let fmts := map.fold (fun fmt k ps =>
let k := match k with
| ParamMap.Key.decl n => format n
| ParamMap.Key.jp n id => format n ++ ":" ++ format id
fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps)
Format.nil
"{" ++ (Format.nest 1 fmts) ++ "}"
let fmts := map.fold (fun fmt k ps =>
let k := match k with
| ParamMap.Key.decl n => format n
| ParamMap.Key.jp n id => format n ++ ":" ++ format id
fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps)
Format.nil
"{" ++ (Format.nest 1 fmts) ++ "}"
instance : ToFormat ParamMap := ⟨ParamMap.fmt⟩
instance : ToString ParamMap := ⟨fun m => Format.pretty (format m)⟩
@ -70,7 +71,7 @@ instance : ToString ParamMap := ⟨fun m => Format.pretty (format m)⟩
namespace InitParamMap
/- Mark parameters that take a reference as borrow -/
def initBorrow (ps : Array Param) : Array Param :=
ps.map $ fun p => { p with borrow := p.ty.isObj }
ps.map $ fun p => { p with borrow := p.ty.isObj }
/- We do perform borrow inference for constants marked as `export`.
Reason: we current write wrappers in C++ for using exported functions.
@ -80,26 +81,26 @@ ps.map $ fun p => { p with borrow := p.ty.isObj }
We can revise this decision when we implement code for generating
the wrappers automatically. -/
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
if exported then ps else initBorrow ps
if exported then ps else initBorrow ps
partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
| FnBody.jdecl j xs v b => do
modify fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs)
visitFnBody fnid v
visitFnBody fnid b
| FnBody.case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body
| e => do
unless e.isTerminal do
let (instr, b) := e.split
| FnBody.jdecl j xs v b => do
modify fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs)
visitFnBody fnid v
visitFnBody fnid b
| FnBody.case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body
| e => do
unless e.isTerminal do
let (instr, b) := e.split
visitFnBody fnid b
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
decls.forM fun decl => match decl with
| Decl.fdecl f xs _ b => do
let exported := isExport env f
modify fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs)
visitFnBody f b
| _ => pure ()
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
decls.forM fun decl => match decl with
| Decl.fdecl f xs _ b => do
let exported := isExport env f
modify fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs)
visitFnBody f b
| _ => pure ()
end InitParamMap
def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap :=
@ -110,111 +111,111 @@ def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap :=
namespace ApplyParamMap
partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
| FnBody.jdecl j xs v b =>
let v := visitFnBody fn paramMap v
let b := visitFnBody fn paramMap b
match paramMap.find? (ParamMap.Key.jp fn j) with
| some ys => FnBody.jdecl j ys v b
| none => unreachable!
| FnBody.case tid x xType alts =>
FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (visitFnBody fn paramMap)
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
| FnBody.jdecl j xs v b =>
let v := visitFnBody fn paramMap v
let b := visitFnBody fn paramMap b
instr.setBody b
match paramMap.find? (ParamMap.Key.jp fn j) with
| some ys => FnBody.jdecl j ys v b
| none => unreachable!
| FnBody.case tid x xType alts =>
FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (visitFnBody fn paramMap)
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := visitFnBody fn paramMap b
instr.setBody b
def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
decls.map fun decl => match decl with
| Decl.fdecl f xs ty b =>
let b := visitFnBody f paramMap b
match paramMap.find? (ParamMap.Key.decl f) with
| some xs => Decl.fdecl f xs ty b
| none => unreachable!
| other => other
decls.map fun decl => match decl with
| Decl.fdecl f xs ty b =>
let b := visitFnBody f paramMap b
match paramMap.find? (ParamMap.Key.decl f) with
| some xs => Decl.fdecl f xs ty b
| none => unreachable!
| other => other
end ApplyParamMap
def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl :=
-- dbgTrace ("applyParamMap " ++ toString map) $ fun _ =>
ApplyParamMap.visitDecls decls map
-- dbgTrace ("applyParamMap " ++ toString map) $ fun _ =>
ApplyParamMap.visitDecls decls map
structure BorrowInfCtx :=
(env : Environment)
(currFn : FunId := arbitrary _) -- Function being analyzed.
(paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams`
(env : Environment)
(currFn : FunId := arbitrary _) -- Function being analyzed.
(paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams`
structure BorrowInfState :=
/- Set of variables that must be `owned`. -/
(owned : OwnedSet := {})
(modified : Bool := false)
(paramMap : ParamMap)
/- Set of variables that must be `owned`. -/
(owned : OwnedSet := {})
(modified : Bool := false)
(paramMap : ParamMap)
abbrev M := ReaderT BorrowInfCtx (StateM BorrowInfState)
def getCurrFn : M FunId := do
let ctx ← read
pure ctx.currFn
let ctx ← read
pure ctx.currFn
def markModified : M Unit :=
modify $ fun s => { s with modified := true }
modify fun s => { s with modified := true }
def ownVar (x : VarId) : M Unit := do
-- dbgTrace ("ownVar " ++ toString x) $ fun _ =>
let currFn ← getCurrFn
modify fun s =>
if s.owned.contains (currFn, x.idx) then s
else { s with owned := s.owned.insert (currFn, x.idx), modified := true }
-- dbgTrace ("ownVar " ++ toString x) $ fun _ =>
let currFn ← getCurrFn
modify fun s =>
if s.owned.contains (currFn, x.idx) then s
else { s with owned := s.owned.insert (currFn, x.idx), modified := true }
def ownArg (x : Arg) : M Unit :=
match x with
| Arg.var x => ownVar x
| _ => pure ()
match x with
| Arg.var x => ownVar x
| _ => pure ()
def ownArgs (xs : Array Arg) : M Unit :=
xs.forM ownArg
xs.forM ownArg
def isOwned (x : VarId) : M Bool := do
let currFn ← getCurrFn
let s ← get
pure $ s.owned.contains (currFn, x.idx)
let currFn ← getCurrFn
let s ← get
pure $ s.owned.contains (currFn, x.idx)
/- Updates `map[k]` using the current set of `owned` variables. -/
def updateParamMap (k : ParamMap.Key) : M Unit := do
let currFn ← getCurrFn
let s ← get
match s.paramMap.find? k with
| some ps => do
let ps ← ps.mapM fun (p : Param) => do
if !p.borrow then pure p
else if (← isOwned p.x) then
markModified
pure { p with borrow := false }
else
pure p
modify fun s => { s with paramMap := s.paramMap.insert k ps }
| none => pure ()
let currFn ← getCurrFn
let s ← get
match s.paramMap.find? k with
| some ps => do
let ps ← ps.mapM fun (p : Param) => do
if !p.borrow then pure p
else if (← isOwned p.x) then
markModified
pure { p with borrow := false }
else
pure p
modify fun s => { s with paramMap := s.paramMap.insert k ps }
| none => pure ()
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
let s ← get
match s.paramMap.find? k with
| some ps => pure ps
| none =>
match k with
| ParamMap.Key.decl fn => do
let ctx ← read
match findEnvDecl ctx.env fn with
| some decl => pure decl.params
| none => unreachable!
| _ => unreachable!
let s ← get
match s.paramMap.find? k with
| some ps => pure ps
| none =>
match k with
| ParamMap.Key.decl fn => do
let ctx ← read
match findEnvDecl ctx.env fn with
| some decl => pure decl.params
| none => unreachable!
| _ => unreachable!
/- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
xs.size.forM fun i => do
let x := xs[i]
let p := ps[i]
unless p.borrow do ownArg x
xs.size.forM fun i => do
let x := xs[i]
let p := ps[i]
unless p.borrow do ownArg x
/- For each xs[i], if xs[i] is owned, then mark ps[i] as owned.
We use this action to preserve tail calls. That is, if we have
@ -222,12 +223,12 @@ xs.size.forM fun i => do
we would have to insert a `dec xs[i]` after `f xs` and consequently
"break" the tail call. -/
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
xs.size.forM fun i => do
let x := xs[i]
let p := ps[i]
match x with
| Arg.var x => if (← isOwned x) then ownVar p.x
| _ => pure ()
xs.size.forM fun i => do
let x := xs[i]
let p := ps[i]
match x with
| Arg.var x => if (← isOwned x) then ownVar p.x
| _ => pure ()
/- Mark `xs[i]` as owned if it is one of the parameters `ps`.
We use this action to mark function parameters that are being "packed" inside constructors.
@ -239,85 +240,85 @@ xs.size.forM fun i => do
ret z
``` -/
def ownArgsIfParam (xs : Array Arg) : M Unit := do
let ctx ← read
xs.forM fun x => do
match x with
| Arg.var x => if ctx.paramSet.contains x.idx then ownVar x
| _ => pure ()
let ctx ← read
xs.forM fun x => do
match x with
| Arg.var x => if ctx.paramSet.contains x.idx then ownVar x
| _ => pure ()
def collectExpr (z : VarId) : Expr → M Unit
| Expr.reset _ x => ownVar z *> ownVar x
| Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys
| Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs
| Expr.proj _ x => do
if (← isOwned x) then ownVar z
if (← isOwned z) then ownVar x
| Expr.fap g xs => do
let ps ← getParamInfo (ParamMap.Key.decl g)
ownVar z *> ownArgsUsingParams xs ps
| Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys
| Expr.pap _ xs => ownVar z *> ownArgs xs
| other => pure ()
| Expr.reset _ x => ownVar z *> ownVar x
| Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys
| Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs
| Expr.proj _ x => do
if (← isOwned x) then ownVar z
if (← isOwned z) then ownVar x
| Expr.fap g xs => do
let ps ← getParamInfo (ParamMap.Key.decl g)
ownVar z *> ownArgsUsingParams xs ps
| Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys
| Expr.pap _ xs => ownVar z *> ownArgs xs
| other => pure ()
def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do
let ctx ← read
match v, b with
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) =>
if ctx.currFn == g && x == z then
-- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do
let ps ← getParamInfo (ParamMap.Key.decl g)
ownParamsUsingArgs ys ps
| _, _ => pure ()
let ctx ← read
match v, b with
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) =>
if ctx.currFn == g && x == z then
-- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do
let ps ← getParamInfo (ParamMap.Key.decl g)
ownParamsUsingArgs ys ps
| _, _ => pure ()
def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx :=
{ ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet }
{ ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet }
partial def collectFnBody : FnBody → M Unit
| FnBody.jdecl j ys v b => do
withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v)
let ctx ← read
updateParamMap (ParamMap.Key.jp ctx.currFn j)
collectFnBody b
| FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b
| FnBody.jmp j ys => do
let ctx ← read
let ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j)
ownArgsUsingParams ys ps -- for making sure the join point can reuse
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
| FnBody.case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body
| e => do unless e.isTerminal do collectFnBody e.body
| FnBody.jdecl j ys v b => do
withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v)
let ctx ← read
updateParamMap (ParamMap.Key.jp ctx.currFn j)
collectFnBody b
| FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b
| FnBody.jmp j ys => do
let ctx ← read
let ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j)
ownArgsUsingParams ys ps -- for making sure the join point can reuse
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
| FnBody.case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body
| e => do unless e.isTerminal do collectFnBody e.body
partial def collectDecl : Decl → M Unit
| Decl.fdecl f ys _ b =>
withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) do
collectFnBody b
updateParamMap (ParamMap.Key.decl f)
| _ => pure ()
| Decl.fdecl f ys _ b =>
withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) do
collectFnBody b
updateParamMap (ParamMap.Key.decl f)
| _ => pure ()
/- Keep executing `x` until it reaches a fixpoint -/
@[inline] partial def whileModifing (x : M Unit) : M Unit := do
modify fun s => { s with modified := false }
x
let s ← get
if s.modified then
whileModifing x
else
pure ()
modify fun s => { s with modified := false }
x
let s ← get
if s.modified then
whileModifing x
else
pure ()
def collectDecls (decls : Array Decl) : M ParamMap := do
whileModifing (decls.forM collectDecl)
let s ← get
pure s.paramMap
whileModifing (decls.forM collectDecl)
let s ← get
pure s.paramMap
def infer (env : Environment) (decls : Array Decl) : ParamMap :=
collectDecls decls { env := env } $.run' { paramMap := mkInitParamMap env decls }
collectDecls decls { env := env } $.run' { paramMap := mkInitParamMap env decls }
end Borrow
def inferBorrow (decls : Array Decl) : CompilerM (Array Decl) := do
let env ← getEnv
let paramMap := Borrow.infer env decls
pure (Borrow.applyParamMap decls paramMap)
let env ← getEnv
let paramMap := Borrow.infer env decls
pure (Borrow.applyParamMap decls paramMap)
end IR
end Lean

View file

@ -30,314 +30,320 @@ Assumptions:
open Std (AssocList)
def mkBoxedName (n : Name) : Name :=
mkNameStr n "_boxed"
mkNameStr n "_boxed"
def isBoxedName : Name → Bool
| Name.str _ "_boxed" _ => true
| _ => false
| Name.str _ "_boxed" _ => true
| _ => false
abbrev N := StateM Nat
private def N.mkFresh : N VarId :=
modifyGet fun n => ({ idx := n }, n + 1)
modifyGet fun n => ({ idx := n }, n + 1)
def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
let ps := decl.params
(ps.size > 0 && (decl.resultType.isScalar || ps.any (fun p => p.ty.isScalar || p.borrow) || isExtern env decl.name))
|| ps.size > closureMaxArgs
let ps := decl.params
(ps.size > 0 && (decl.resultType.isScalar || ps.any (fun p => p.ty.isScalar || p.borrow) || isExtern env decl.name))
|| ps.size > closureMaxArgs
def mkBoxedVersionAux (decl : Decl) : N Decl := do
let ps := decl.params
let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
let p := ps[i]
let q := qs[i]
if !p.ty.isScalar then
pure (newVDecls, xs.push (Arg.var q.x))
else
let x ← N.mkFresh
pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (arbitrary _)), xs.push (Arg.var x))
let r ← N.mkFresh
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (arbitrary _))
let body ←
if !decl.resultType.isScalar then
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
else
let newR ← N.mkFresh
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (arbitrary _))
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
let ps := decl.params
let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
let p := ps[i]
let q := qs[i]
if !p.ty.isScalar then
pure (newVDecls, xs.push (Arg.var q.x))
else
let x ← N.mkFresh
pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (arbitrary _)), xs.push (Arg.var x))
let r ← N.mkFresh
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (arbitrary _))
let body ←
if !decl.resultType.isScalar then
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
else
let newR ← N.mkFresh
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (arbitrary _))
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
def mkBoxedVersion (decl : Decl) : Decl :=
(mkBoxedVersionAux decl).run' 1
(mkBoxedVersionAux decl).run' 1
def addBoxedVersions (env : Environment) (decls : Array Decl) : Array Decl :=
let boxedDecls := decls.foldl (init := #[]) fun newDecls decl =>
if requiresBoxedVersion env decl then newDecls.push (mkBoxedVersion decl) else newDecls
decls ++ boxedDecls
let boxedDecls := decls.foldl (init := #[]) fun newDecls decl =>
if requiresBoxedVersion env decl then newDecls.push (mkBoxedVersion decl) else newDecls
decls ++ boxedDecls
/- Infer scrutinee type using `case` alternatives.
This can be done whenever `alts` does not contain an `Alt.default _` value. -/
def getScrutineeType (alts : Array Alt) : IRType :=
let isScalar :=
alts.size > 1 && -- Recall that we encode Unit and PUnit using `object`.
alts.all fun
| Alt.ctor c _ => c.isScalar
| Alt.default _ => false
match isScalar with
| false => IRType.object
| true =>
let n := alts.size
if n < 256 then IRType.uint8
else if n < 65536 then IRType.uint16
else if n < 4294967296 then IRType.uint32
else IRType.object -- in practice this should be unreachable
let isScalar :=
alts.size > 1 && -- Recall that we encode Unit and PUnit using `object`.
alts.all fun
| Alt.ctor c _ => c.isScalar
| Alt.default _ => false
match isScalar with
| false => IRType.object
| true =>
let n := alts.size
if n < 256 then IRType.uint8
else if n < 65536 then IRType.uint16
else if n < 4294967296 then IRType.uint32
else IRType.object -- in practice this should be unreachable
def eqvTypes (t₁ t₂ : IRType) : Bool :=
(t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂)
(t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂)
structure BoxingContext :=
(f : FunId := arbitrary _) (localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment)
(f : FunId := arbitrary _)
(localCtx : LocalContext := {})
(resultType : IRType := IRType.irrelevant)
(decls : Array Decl)
(env : Environment)
structure BoxingState :=
(nextIdx : Index)
/- We create auxiliary declarations when boxing constant and literals.
The idea is to avoid code such as
```
let x1 := Uint64.inhabited;
let x2 := box x1;
...
```
We currently do not cache these declarations in an environment extension, but
we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when
processing the same IR declaration.
-/
(auxDecls : Array Decl := #[])
(auxDeclCache : AssocList FnBody Expr := Std.AssocList.empty)
(nextAuxId : Nat := 1)
(nextIdx : Index)
/- We create auxiliary declarations when boxing constant and literals.
The idea is to avoid code such as
```
let x1 := Uint64.inhabited;
let x2 := box x1;
...
```
We currently do not cache these declarations in an environment extension, but
we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when
processing the same IR declaration.
-/
(auxDecls : Array Decl := #[])
(auxDeclCache : AssocList FnBody Expr := Std.AssocList.empty)
(nextAuxId : Nat := 1)
abbrev M := ReaderT BoxingContext (StateT BoxingState Id)
private def M.mkFresh : M VarId := do
let oldS ← getModify fun s => { s with nextIdx := s.nextIdx + 1 }
pure { idx := oldS.nextIdx }
let oldS ← getModify fun s => { s with nextIdx := s.nextIdx + 1 }
pure { idx := oldS.nextIdx }
def getEnv : M Environment := BoxingContext.env <$> read
def getLocalContext : M LocalContext := BoxingContext.localCtx <$> read
def getResultType : M IRType := BoxingContext.resultType <$> read
def getVarType (x : VarId) : M IRType := do
let localCtx ← getLocalContext
match localCtx.getType x with
| some t => pure t
| none => pure IRType.object -- unreachable, we assume the code is well formed
let localCtx ← getLocalContext
match localCtx.getType x with
| some t => pure t
| none => pure IRType.object -- unreachable, we assume the code is well formed
def getJPParams (j : JoinPointId) : M (Array Param) := do
let localCtx ← getLocalContext
match localCtx.getJPParams j with
| some ys => pure ys
| none => pure #[] -- unreachable, we assume the code is well formed
let localCtx ← getLocalContext
match localCtx.getJPParams j with
| some ys => pure ys
| none => pure #[] -- unreachable, we assume the code is well formed
def getDecl (fid : FunId) : M Decl := do
let ctx ← read
match findEnvDecl' ctx.env fid ctx.decls with
| some decl => pure decl
| none => pure (arbitrary _) -- unreachable if well-formed
let ctx ← read
match findEnvDecl' ctx.env fid ctx.decls with
| some decl => pure decl
| none => pure (arbitrary _) -- unreachable if well-formed
@[inline] def withParams {α : Type} (xs : Array Param) (k : M α) : M α :=
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addParams xs }) k
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addParams xs }) k
@[inline] def withVDecl {α : Type} (x : VarId) (ty : IRType) (v : Expr) (k : M α) : M α :=
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x ty v }) k
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x ty v }) k
@[inline] def withJDecl {α : Type} (j : JoinPointId) (xs : Array Param) (v : FnBody) (k : M α) : M α :=
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j xs v }) k
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j xs v }) k
/- If `x` declaration is of the form `x := Expr.lit _` or `x := Expr.fap c #[]`,
and `x`'s type is not cheap to box (e.g., it is `UInt64), then return its value. -/
private def isExpensiveConstantValueBoxing (x : VarId) (xType : IRType) : M (Option Expr) :=
if !xType.isScalar then pure none -- We assume unboxing is always cheap
else match xType with
| IRType.uint8 => pure none
| IRType.uint16 => pure none
| _ => do
let localCtx ← getLocalContext
match localCtx.getValue x with
| some val =>
match val with
| Expr.lit _ => pure $ some val
| Expr.fap _ args => pure $ if args.size == 0 then some val else none
| _ => pure none
| _ => pure none
if !xType.isScalar then
pure none -- We assume unboxing is always cheap
else match xType with
| IRType.uint8 => pure none
| IRType.uint16 => pure none
| _ => do
let localCtx ← getLocalContext
match localCtx.getValue x with
| some val =>
match val with
| Expr.lit _ => pure $ some val
| Expr.fap _ args => pure $ if args.size == 0 then some val else none
| _ => pure none
| _ => pure none
/- Auxiliary function used by castVarIfNeeded.
It is used when the expected type does not match `xType`.
If `xType` is scalar, then we need to "box" it. Otherwise, we need to "unbox" it. -/
def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr := do
match (← isExpensiveConstantValueBoxing x xType) with
| some v => do
let ctx ← read
let s ← get
/- Create auxiliary FnBody
```
let x_1 : xType := v;
let x_2 : expectedType := Expr.box xType x_1;
ret x_2
```
-/
let body : FnBody :=
FnBody.vdecl { idx := 1 } xType v $
FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $
FnBody.ret (mkVarArg { idx := 2 })
match s.auxDeclCache.find? body with
| some v => pure v
| none => do
let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId)
let auxConst := Expr.fap auxName #[]
let auxDecl := Decl.fdecl auxName #[] expectedType body
modify fun s => {
s with
auxDecls := s.auxDecls.push auxDecl,
auxDeclCache := s.auxDeclCache.cons body auxConst,
nextAuxId := s.nextAuxId + 1
}
pure auxConst
| none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x
match (← isExpensiveConstantValueBoxing x xType) with
| some v => do
let ctx ← read
let s ← get
/- Create auxiliary FnBody
```
let x_1 : xType := v;
let x_2 : expectedType := Expr.box xType x_1;
ret x_2
```
-/
let body : FnBody :=
FnBody.vdecl { idx := 1 } xType v $
FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $
FnBody.ret (mkVarArg { idx := 2 })
match s.auxDeclCache.find? body with
| some v => pure v
| none => do
let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId)
let auxConst := Expr.fap auxName #[]
let auxDecl := Decl.fdecl auxName #[] expectedType body
modify fun s => {
s with
auxDecls := s.auxDecls.push auxDecl,
auxDeclCache := s.auxDeclCache.cons body auxConst,
nextAuxId := s.nextAuxId + 1
}
pure auxConst
| none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x
@[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody := do
let xType ← getVarType x
if eqvTypes xType expected then k x
else
let y ← M.mkFresh
let v ← mkCast x xType expected
FnBody.vdecl y expected v <$> k y
let xType ← getVarType x
if eqvTypes xType expected then
k x
else
let y ← M.mkFresh
let v ← mkCast x xType expected
FnBody.vdecl y expected v <$> k y
@[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody :=
match x with
| Arg.var x => castVarIfNeeded x expected (fun x => k (Arg.var x))
| _ => k x
match x with
| Arg.var x => castVarIfNeeded x expected (fun x => k (Arg.var x))
| _ => k x
@[specialize] def castArgsIfNeededAux (xs : Array Arg) (typeFromIdx : Nat → IRType) : M (Array Arg × Array FnBody) := do
let xs' := #[]
let bs := #[]
let i := 0
for x in xs do
let expected := typeFromIdx i
match x with
| Arg.irrelevant =>
xs' := xs'.push x
| Arg.var x =>
let xType ← getVarType x
if eqvTypes xType expected then
xs' := xs'.push (Arg.var x)
else
let y ← M.mkFresh
let v ← mkCast x xType expected
let b := FnBody.vdecl y expected v FnBody.nil
xs' := xs'.push (Arg.var y)
bs := bs.push b
i := i + 1
return (xs', bs)
let xs' := #[]
let bs := #[]
let i := 0
for x in xs do
let expected := typeFromIdx i
match x with
| Arg.irrelevant =>
xs' := xs'.push x
| Arg.var x =>
let xType ← getVarType x
if eqvTypes xType expected then
xs' := xs'.push (Arg.var x)
else
let y ← M.mkFresh
let v ← mkCast x xType expected
let b := FnBody.vdecl y expected v FnBody.nil
xs' := xs'.push (Arg.var y)
bs := bs.push b
i := i + 1
return (xs', bs)
@[inline] def castArgsIfNeeded (xs : Array Arg) (ps : Array Param) (k : Array Arg → M FnBody) : M FnBody := do
let (ys, bs) ← castArgsIfNeededAux xs fun i => ps[i].ty
let b ← k ys
pure (reshape bs b)
let (ys, bs) ← castArgsIfNeededAux xs fun i => ps[i].ty
let b ← k ys
pure (reshape bs b)
@[inline] def boxArgsIfNeeded (xs : Array Arg) (k : Array Arg → M FnBody) : M FnBody := do
let (ys, bs) ← castArgsIfNeededAux xs (fun _ => IRType.object)
let b ← k ys
pure (reshape bs b)
let (ys, bs) ← castArgsIfNeededAux xs (fun _ => IRType.object)
let b ← k ys
pure (reshape bs b)
def unboxResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody := do
if ty.isScalar then
let y ← M.mkFresh
pure $ FnBody.vdecl y IRType.object e (FnBody.vdecl x ty (Expr.unbox y) b)
else
pure $ FnBody.vdecl x ty e b
if ty.isScalar then
let y ← M.mkFresh
pure $ FnBody.vdecl y IRType.object e (FnBody.vdecl x ty (Expr.unbox y) b)
else
pure $ FnBody.vdecl x ty e b
def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b : FnBody) : M FnBody := do
if eqvTypes ty eType then
pure $ FnBody.vdecl x ty e b
else
let y ← M.mkFresh
let v ← mkCast y eType ty
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b)
if eqvTypes ty eType then
pure $ FnBody.vdecl x ty e b
else
let y ← M.mkFresh
let v ← mkCast y eType ty
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b)
def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
match e with
| Expr.ctor c ys =>
if c.isScalar && ty.isScalar then
pure $ FnBody.vdecl x ty (Expr.lit (LitVal.num c.cidx)) b
else
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.ctor c ys) b
| Expr.reuse w c u ys =>
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b
| Expr.fap f ys => do
let decl ← getDecl f
castArgsIfNeeded ys decl.params fun ys =>
castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b
| Expr.pap f ys => do
let env ← getEnv
let decl ← getDecl f
let f := if requiresBoxedVersion env decl then mkBoxedName f else f
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.pap f ys) b
| Expr.ap f ys =>
boxArgsIfNeeded ys fun ys =>
unboxResultIfNeeded x ty (Expr.ap f ys) b
| other =>
pure $ FnBody.vdecl x ty e b
match e with
| Expr.ctor c ys =>
if c.isScalar && ty.isScalar then
pure $ FnBody.vdecl x ty (Expr.lit (LitVal.num c.cidx)) b
else
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.ctor c ys) b
| Expr.reuse w c u ys =>
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b
| Expr.fap f ys => do
let decl ← getDecl f
castArgsIfNeeded ys decl.params fun ys =>
castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b
| Expr.pap f ys => do
let env ← getEnv
let decl ← getDecl f
let f := if requiresBoxedVersion env decl then mkBoxedName f else f
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.pap f ys) b
| Expr.ap f ys =>
boxArgsIfNeeded ys fun ys =>
unboxResultIfNeeded x ty (Expr.ap f ys) b
| other =>
pure $ FnBody.vdecl x ty e b
partial def visitFnBody : FnBody → M FnBody
| FnBody.vdecl x t v b => do
let b ← withVDecl x t v (visitFnBody b)
visitVDeclExpr x t v b
| FnBody.jdecl j xs v b => do
let v ← withParams xs (visitFnBody v)
let b ← withJDecl j xs v (visitFnBody b)
pure $ FnBody.jdecl j xs v b
| FnBody.uset x i y b => do
let b ← visitFnBody b
castVarIfNeeded y IRType.usize fun y =>
pure $ FnBody.uset x i y b
| FnBody.sset x i o y ty b => do
let b ← visitFnBody b
castVarIfNeeded y ty fun y =>
pure $ FnBody.sset x i o y ty b
| FnBody.mdata d b =>
FnBody.mdata d <$> visitFnBody b
| FnBody.case tid x _ alts => do
let expected := getScrutineeType alts
let alts ← alts.mapM fun alt => alt.mmodifyBody visitFnBody
castVarIfNeeded x expected fun x => do
pure $ FnBody.case tid x expected alts
| FnBody.ret x => do
let expected ← getResultType
castArgIfNeeded x expected (fun x => pure $ FnBody.ret x)
| FnBody.jmp j ys => do
let ps ← getJPParams j
castArgsIfNeeded ys ps fun ys => pure $ FnBody.jmp j ys
| other =>
pure other
| FnBody.vdecl x t v b => do
let b ← withVDecl x t v (visitFnBody b)
visitVDeclExpr x t v b
| FnBody.jdecl j xs v b => do
let v ← withParams xs (visitFnBody v)
let b ← withJDecl j xs v (visitFnBody b)
pure $ FnBody.jdecl j xs v b
| FnBody.uset x i y b => do
let b ← visitFnBody b
castVarIfNeeded y IRType.usize fun y =>
pure $ FnBody.uset x i y b
| FnBody.sset x i o y ty b => do
let b ← visitFnBody b
castVarIfNeeded y ty fun y =>
pure $ FnBody.sset x i o y ty b
| FnBody.mdata d b =>
FnBody.mdata d <$> visitFnBody b
| FnBody.case tid x _ alts => do
let expected := getScrutineeType alts
let alts ← alts.mapM fun alt => alt.mmodifyBody visitFnBody
castVarIfNeeded x expected fun x => do
pure $ FnBody.case tid x expected alts
| FnBody.ret x => do
let expected ← getResultType
castArgIfNeeded x expected (fun x => pure $ FnBody.ret x)
| FnBody.jmp j ys => do
let ps ← getJPParams j
castArgsIfNeeded ys ps fun ys => pure $ FnBody.jmp j ys
| other =>
pure other
def run (env : Environment) (decls : Array Decl) : Array Decl :=
let ctx : BoxingContext := { decls := decls, env := env }
let decls := decls.foldl (init := #[]) fun newDecls decl =>
match decl with
| Decl.fdecl f xs t b =>
let nextIdx := decl.maxIndex + 1
let (b, s) := (withParams xs (visitFnBody b) { ctx with f := f, resultType := t }).run { nextIdx := nextIdx }
let newDecls := newDecls ++ s.auxDecls
let newDecl := Decl.fdecl f xs t b
let newDecl := newDecl.elimDead
newDecls.push newDecl
| d => newDecls.push d
addBoxedVersions env decls
let ctx : BoxingContext := { decls := decls, env := env }
let decls := decls.foldl (init := #[]) fun newDecls decl =>
match decl with
| Decl.fdecl f xs t b =>
let nextIdx := decl.maxIndex + 1
let (b, s) := (withParams xs (visitFnBody b) { ctx with f := f, resultType := t }).run { nextIdx := nextIdx }
let newDecls := newDecls ++ s.auxDecls
let newDecl := Decl.fdecl f xs t b
let newDecl := newDecl.elimDead
newDecls.push newDecl
| d => newDecls.push d
addBoxedVersions env decls
end ExplicitBoxing
def explicitBoxing (decls : Array Decl) : CompilerM (Array Decl) := do
let env ← getEnv
pure $ ExplicitBoxing.run env decls
let env ← getEnv
pure $ ExplicitBoxing.run env decls
end Lean.IR

View file

@ -9,159 +9,161 @@ import Lean.Compiler.IR.Format
namespace Lean.IR.Checker
structure CheckerContext :=
(env : Environment) (localCtx : LocalContext := {}) (decls : Array Decl)
(env : Environment)
(localCtx : LocalContext := {})
(decls : Array Decl)
structure CheckerState :=
(foundVars : IndexSet := {})
(foundVars : IndexSet := {})
abbrev M := ReaderT CheckerContext (ExceptT String (StateT CheckerState Id))
def markIndex (i : Index) : M Unit := do
let s ← get
if s.foundVars.contains i then
throw s!"variable / joinpoint index {i} has already been used"
modify fun s => { s with foundVars := s.foundVars.insert i }
let s ← get
if s.foundVars.contains i then
throw s!"variable / joinpoint index {i} has already been used"
modify fun s => { s with foundVars := s.foundVars.insert i }
def markVar (x : VarId) : M Unit :=
markIndex x.idx
markIndex x.idx
def markJP (j : JoinPointId) : M Unit :=
markIndex j.idx
markIndex j.idx
def getDecl (c : Name) : M Decl := do
let ctx ← read
match findEnvDecl' ctx.env c ctx.decls with
| none => throw s!"unknown declaration '{c}'"
| some d => pure d
let ctx ← read
match findEnvDecl' ctx.env c ctx.decls with
| none => throw s!"unknown declaration '{c}'"
| some d => pure d
def checkVar (x : VarId) : M Unit := do
let ctx ← read
unless ctx.localCtx.isLocalVar x.idx || ctx.localCtx.isParam x.idx do
throw s!"unknown variable '{x}'"
let ctx ← read
unless ctx.localCtx.isLocalVar x.idx || ctx.localCtx.isParam x.idx do
throw s!"unknown variable '{x}'"
def checkJP (j : JoinPointId) : M Unit := do
let ctx ← read
unless ctx.localCtx.isJP j.idx do
throw s!"unknown join point '{j}'"
let ctx ← read
unless ctx.localCtx.isJP j.idx do
throw s!"unknown join point '{j}'"
def checkArg (a : Arg) : M Unit :=
match a with
| Arg.var x => checkVar x
| other => pure ()
match a with
| Arg.var x => checkVar x
| other => pure ()
def checkArgs (as : Array Arg) : M Unit :=
as.forM checkArg
as.forM checkArg
@[inline] def checkEqTypes (ty₁ ty₂ : IRType) : M Unit := do
unless ty₁ == ty₂ do
throw "unexpected type"
unless ty₁ == ty₂ do
throw "unexpected type"
@[inline] def checkType (ty : IRType) (p : IRType → Bool) : M Unit := do
unless p ty do
throw s!"unexpected type '{ty}'"
unless p ty do
throw s!"unexpected type '{ty}'"
def checkObjType (ty : IRType) : M Unit := checkType ty IRType.isObj
def checkScalarType (ty : IRType) : M Unit := checkType ty IRType.isScalar
def getType (x : VarId) : M IRType := do
let ctx ← read
match ctx.localCtx.getType x with
| some ty => pure ty
| none => throw s!"unknown variable '{x}'"
let ctx ← read
match ctx.localCtx.getType x with
| some ty => pure ty
| none => throw s!"unknown variable '{x}'"
@[inline] def checkVarType (x : VarId) (p : IRType → Bool) : M Unit := do
let ty ← getType x; checkType ty p
let ty ← getType x; checkType ty p
def checkObjVar (x : VarId) : M Unit :=
checkVarType x IRType.isObj
checkVarType x IRType.isObj
def checkScalarVar (x : VarId) : M Unit :=
checkVarType x IRType.isScalar
checkVarType x IRType.isScalar
def checkFullApp (c : FunId) (ys : Array Arg) : M Unit := do
if c == `hugeFuel then
throw "the auxiliary constant `hugeFuel` cannot be used in code, it is used internally for compiling `partial` definitions"
let decl ← getDecl c
unless ys.size == decl.params.size do
throw s!"incorrect number of arguments to '{c}', {ys.size} provided, {decl.params.size} expected"
checkArgs ys
if c == `hugeFuel then
throw "the auxiliary constant `hugeFuel` cannot be used in code, it is used internally for compiling `partial` definitions"
let decl ← getDecl c
unless ys.size == decl.params.size do
throw s!"incorrect number of arguments to '{c}', {ys.size} provided, {decl.params.size} expected"
checkArgs ys
def checkPartialApp (c : FunId) (ys : Array Arg) : M Unit := do
let decl ← getDecl c
unless ys.size < decl.params.size do
throw s!"too many arguments too partial application '{c}', num. args: {ys.size}, arity: {decl.params.size}"
checkArgs ys
let decl ← getDecl c
unless ys.size < decl.params.size do
throw s!"too many arguments too partial application '{c}', num. args: {ys.size}, arity: {decl.params.size}"
checkArgs ys
def checkExpr (ty : IRType) : Expr → M Unit
| Expr.pap f ys => checkPartialApp f ys *> checkObjType ty -- partial applications should always produce a closure object
| Expr.ap x ys => checkObjVar x *> checkArgs ys
| Expr.fap f ys => checkFullApp f ys
| Expr.ctor c ys => when (!ty.isStruct && !ty.isUnion && c.isRef) (checkObjType ty) *> checkArgs ys
| Expr.reset _ x => checkObjVar x *> checkObjType ty
| Expr.reuse x i u ys => checkObjVar x *> checkArgs ys *> checkObjType ty
| Expr.box xty x => checkObjType ty *> checkScalarVar x *> checkVarType x (fun t => t == xty)
| Expr.unbox x => checkScalarType ty *> checkObjVar x
| Expr.proj i x => do
let xType ← getType x;
match xType with
| IRType.object => checkObjType ty
| IRType.tobject => checkObjType ty
| IRType.struct _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index"
| IRType.union _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index"
| other => throw s!"unexpected IR type '{xType}'"
| Expr.uproj _ x => checkObjVar x *> checkType ty (fun t => t == IRType.usize)
| Expr.sproj _ _ x => checkObjVar x *> checkScalarType ty
| Expr.isShared x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8)
| Expr.isTaggedPtr x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8)
| Expr.lit (LitVal.str _) => checkObjType ty
| Expr.lit _ => pure ()
| Expr.pap f ys => checkPartialApp f ys *> checkObjType ty -- partial applications should always produce a closure object
| Expr.ap x ys => checkObjVar x *> checkArgs ys
| Expr.fap f ys => checkFullApp f ys
| Expr.ctor c ys => when (!ty.isStruct && !ty.isUnion && c.isRef) (checkObjType ty) *> checkArgs ys
| Expr.reset _ x => checkObjVar x *> checkObjType ty
| Expr.reuse x i u ys => checkObjVar x *> checkArgs ys *> checkObjType ty
| Expr.box xty x => checkObjType ty *> checkScalarVar x *> checkVarType x (fun t => t == xty)
| Expr.unbox x => checkScalarType ty *> checkObjVar x
| Expr.proj i x => do
let xType ← getType x;
match xType with
| IRType.object => checkObjType ty
| IRType.tobject => checkObjType ty
| IRType.struct _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index"
| IRType.union _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index"
| other => throw s!"unexpected IR type '{xType}'"
| Expr.uproj _ x => checkObjVar x *> checkType ty (fun t => t == IRType.usize)
| Expr.sproj _ _ x => checkObjVar x *> checkScalarType ty
| Expr.isShared x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8)
| Expr.isTaggedPtr x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8)
| Expr.lit (LitVal.str _) => checkObjType ty
| Expr.lit _ => pure ()
@[inline] def withParams (ps : Array Param) (k : M Unit) : M Unit := do
let ctx ← read
let localCtx ← ps.foldlM (init := ctx.localCtx) fun (ctx : LocalContext) p => do
markVar p.x
pure $ ctx.addParam p
withReader (fun _ => { ctx with localCtx := localCtx }) k
let ctx ← read
let localCtx ← ps.foldlM (init := ctx.localCtx) fun (ctx : LocalContext) p => do
markVar p.x
pure $ ctx.addParam p
withReader (fun _ => { ctx with localCtx := localCtx }) k
partial def checkFnBody : FnBody → M Unit
| FnBody.vdecl x t v b => do
checkExpr t v;
markVar x;
let ctx ← read
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x t v }) (checkFnBody b)
| FnBody.jdecl j ys v b => do
markJP j;
withParams ys (checkFnBody v);
let ctx ← read
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j ys v }) (checkFnBody b)
| FnBody.set x _ y b => checkVar x *> checkArg y *> checkFnBody b
| FnBody.uset x _ y b => checkVar x *> checkVar y *> checkFnBody b
| FnBody.sset x _ _ y _ b => checkVar x *> checkVar y *> checkFnBody b
| FnBody.setTag x _ b => checkVar x *> checkFnBody b
| FnBody.inc x _ _ _ b => checkVar x *> checkFnBody b
| FnBody.dec x _ _ _ b => checkVar x *> checkFnBody b
| FnBody.del x b => checkVar x *> checkFnBody b
| FnBody.mdata _ b => checkFnBody b
| FnBody.jmp j ys => checkJP j *> checkArgs ys
| FnBody.ret x => checkArg x
| FnBody.case _ x _ alts => checkVar x *> alts.forM (fun alt => checkFnBody alt.body)
| FnBody.unreachable => pure ()
| FnBody.vdecl x t v b => do
checkExpr t v;
markVar x;
let ctx ← read
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x t v }) (checkFnBody b)
| FnBody.jdecl j ys v b => do
markJP j;
withParams ys (checkFnBody v);
let ctx ← read
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j ys v }) (checkFnBody b)
| FnBody.set x _ y b => checkVar x *> checkArg y *> checkFnBody b
| FnBody.uset x _ y b => checkVar x *> checkVar y *> checkFnBody b
| FnBody.sset x _ _ y _ b => checkVar x *> checkVar y *> checkFnBody b
| FnBody.setTag x _ b => checkVar x *> checkFnBody b
| FnBody.inc x _ _ _ b => checkVar x *> checkFnBody b
| FnBody.dec x _ _ _ b => checkVar x *> checkFnBody b
| FnBody.del x b => checkVar x *> checkFnBody b
| FnBody.mdata _ b => checkFnBody b
| FnBody.jmp j ys => checkJP j *> checkArgs ys
| FnBody.ret x => checkArg x
| FnBody.case _ x _ alts => checkVar x *> alts.forM (fun alt => checkFnBody alt.body)
| FnBody.unreachable => pure ()
def checkDecl : Decl → M Unit
| Decl.fdecl f xs t b => withParams xs (checkFnBody b)
| Decl.extern f xs t _ => withParams xs (pure ())
| Decl.fdecl f xs t b => withParams xs (checkFnBody b)
| Decl.extern f xs t _ => withParams xs (pure ())
end Checker
def checkDecl (decls : Array Decl) (decl : Decl) : CompilerM Unit := do
let env ← getEnv
match (Checker.checkDecl decl { env := env, decls := decls }).run' {} with
| Except.error msg => throw s!"IR check failed at '{decl.name}', error: {msg}"
| other => pure ()
let env ← getEnv
match (Checker.checkDecl decl { env := env, decls := decls }).run' {} with
| Except.error msg => throw s!"IR check failed at '{decl.name}', error: {msg}"
| other => pure ()
def checkDecls (decls : Array Decl) : CompilerM Unit :=
decls.forM (checkDecl decls)
decls.forM (checkDecl decls)
end IR
end Lean

View file

@ -10,13 +10,13 @@ import Lean.Compiler.IR.Format
namespace Lean.IR
inductive LogEntry
| step (cls : Name) (decls : Array Decl)
| message (msg : Format)
| step (cls : Name) (decls : Array Decl)
| message (msg : Format)
namespace LogEntry
protected def fmt : LogEntry → Format
| step cls decls => Format.bracket "[" (format cls) "]" ++ decls.foldl (fun fmt decl => fmt ++ Format.line ++ format decl) Format.nil
| message msg => msg
| step cls decls => Format.bracket "[" (format cls) "]" ++ decls.foldl (fun fmt decl => fmt ++ Format.line ++ format decl) Format.nil
| message msg => msg
instance : ToFormat LogEntry := ⟨LogEntry.fmt⟩
end LogEntry
@ -24,49 +24,50 @@ end LogEntry
abbrev Log := Array LogEntry
def Log.format (log : Log) : Format :=
log.foldl (init := Format.nil) fun fmt entry =>
f!"{fmt}{Format.line}{entry}"
log.foldl (init := Format.nil) fun fmt entry =>
f!"{fmt}{Format.line}{entry}"
@[export lean_ir_log_to_string]
def Log.toString (log : Log) : String :=
log.format.pretty
log.format.pretty
structure CompilerState :=
(env : Environment) (log : Log := #[])
(env : Environment)
(log : Log := #[])
abbrev CompilerM := ReaderT Options (EStateM String CompilerState)
def log (entry : LogEntry) : CompilerM Unit :=
modify $ fun s => { s with log := s.log.push entry }
modify fun s => { s with log := s.log.push entry }
def tracePrefixOptionName := `trace.compiler.ir
private def isLogEnabledFor (opts : Options) (optName : Name) : Bool :=
match opts.find optName with
| some (DataValue.ofBool v) => v
| other => opts.getBool tracePrefixOptionName
match opts.find optName with
| some (DataValue.ofBool v) => v
| other => opts.getBool tracePrefixOptionName
private def logDeclsAux (optName : Name) (cls : Name) (decls : Array Decl) : CompilerM Unit := do
let opts ← read
if isLogEnabledFor opts optName then
log (LogEntry.step cls decls)
let opts ← read
if isLogEnabledFor opts optName then
log (LogEntry.step cls decls)
@[inline] def logDecls (cls : Name) (decl : Array Decl) : CompilerM Unit :=
logDeclsAux (tracePrefixOptionName ++ cls) cls decl
logDeclsAux (tracePrefixOptionName ++ cls) cls decl
private def logMessageIfAux {α : Type} [ToFormat α] (optName : Name) (a : α) : CompilerM Unit := do
let opts ← read
if isLogEnabledFor opts optName then
log (LogEntry.message (format a))
let opts ← read
if isLogEnabledFor opts optName then
log (LogEntry.message (format a))
@[inline] def logMessageIf {α : Type} [ToFormat α] (cls : Name) (a : α) : CompilerM Unit :=
logMessageIfAux (tracePrefixOptionName ++ cls) a
logMessageIfAux (tracePrefixOptionName ++ cls) a
@[inline] def logMessage {α : Type} [ToFormat α] (cls : Name) (a : α) : CompilerM Unit :=
logMessageIfAux tracePrefixOptionName a
logMessageIfAux tracePrefixOptionName a
@[inline] def modifyEnv (f : Environment → Environment) : CompilerM Unit :=
modify fun s => { s with env := f s.env }
modify fun s => { s with env := f s.env }
open Std (HashMap)
@ -75,10 +76,10 @@ abbrev DeclMap := SMap Name Decl
/- Create an array of decls to be saved on .olean file.
`decls` may contain duplicate entries, but we assume the one that occurs last is the most recent one. -/
private def mkEntryArray (decls : List Decl) : Array Decl :=
/- Remove duplicates by adding decls into a map -/
let map : HashMap Name Decl := {}
let map := decls.foldl (init := map) fun map decl => map.insert decl.name decl
map.fold (fun a k v => a.push v) #[]
/- Remove duplicates by adding decls into a map -/
let map : HashMap Name Decl := {}
let map := decls.foldl (init := map) fun map decl => map.insert decl.name decl
map.fold (fun a k v => a.push v) #[]
builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ←
registerSimplePersistentEnvExtension {
@ -92,53 +93,54 @@ builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ←
@[export lean_ir_find_env_decl]
def findEnvDecl (env : Environment) (n : Name) : Option Decl :=
(declMapExt.getState env).find? n
(declMapExt.getState env).find? n
def findDecl (n : Name) : CompilerM (Option Decl) := do
let s ← get
pure $ findEnvDecl s.env n
let s ← get
pure $ findEnvDecl s.env n
def containsDecl (n : Name) : CompilerM Bool := do
let s ← get
pure $ (declMapExt.getState s.env).contains n
def getDecl (n : Name) : CompilerM Decl := do
let (some decl) ← findDecl n | throw s!"unknown declaration '{n}'"
pure decl
@[export lean_ir_add_decl]
def addDeclAux (env : Environment) (decl : Decl) : Environment :=
declMapExt.addEntry env decl
def getDecls (env : Environment) : List Decl :=
declMapExt.getEntries env
def getEnv : CompilerM Environment := do
let s ← get; pure s.env
def addDecl (decl : Decl) : CompilerM Unit :=
modifyEnv fun env => declMapExt.addEntry env decl
def addDecls (decls : Array Decl) : CompilerM Unit :=
decls.forM addDecl
def findEnvDecl' (env : Environment) (n : Name) (decls : Array Decl) : Option Decl :=
match decls.find? (fun decl => decl.name == n) with
| some decl => some decl
| none => (declMapExt.getState env).find? n
def findDecl' (n : Name) (decls : Array Decl) : CompilerM (Option Decl) := do
let s ← get; pure $ findEnvDecl' s.env n decls
def containsDecl' (n : Name) (decls : Array Decl) : CompilerM Bool :=
if decls.any (fun decl => decl.name == n) then pure true
else do
let s ← get
pure $ (declMapExt.getState s.env).contains n
def getDecl (n : Name) : CompilerM Decl := do
let (some decl) ← findDecl n | throw s!"unknown declaration '{n}'"
pure decl
@[export lean_ir_add_decl]
def addDeclAux (env : Environment) (decl : Decl) : Environment :=
declMapExt.addEntry env decl
def getDecls (env : Environment) : List Decl :=
declMapExt.getEntries env
def getEnv : CompilerM Environment := do
let s ← get; pure s.env
def addDecl (decl : Decl) : CompilerM Unit :=
modifyEnv fun env => declMapExt.addEntry env decl
def addDecls (decls : Array Decl) : CompilerM Unit :=
decls.forM addDecl
def findEnvDecl' (env : Environment) (n : Name) (decls : Array Decl) : Option Decl :=
match decls.find? (fun decl => decl.name == n) with
| some decl => some decl
| none => (declMapExt.getState env).find? n
def findDecl' (n : Name) (decls : Array Decl) : CompilerM (Option Decl) := do
let s ← get; pure $ findEnvDecl' s.env n decls
def containsDecl' (n : Name) (decls : Array Decl) : CompilerM Bool := do
if decls.any fun decl => decl.name == n then
pure true
else
let s ← get
pure $ (declMapExt.getState s.env).contains n
def getDecl' (n : Name) (decls : Array Decl) : CompilerM Decl := do
let (some decl) ← findDecl' n decls | throw s!"unknown declaration '{n}'"
pure decl
let (some decl) ← findDecl' n decls | throw s!"unknown declaration '{n}'"
pure decl
end IR
end Lean

View file

@ -10,32 +10,32 @@ namespace Lean
namespace IR
inductive CtorFieldInfo
| irrelevant
| object (i : Nat)
| usize (i : Nat)
| scalar (sz : Nat) (offset : Nat) (type : IRType)
| irrelevant
| object (i : Nat)
| usize (i : Nat)
| scalar (sz : Nat) (offset : Nat) (type : IRType)
namespace CtorFieldInfo
def format : CtorFieldInfo → Format
| irrelevant => "◾"
| object i => f!"obj@{i}"
| usize i => f!"usize@{i}"
| scalar sz offset type => f!"scalar#{sz}@{offset}:{type}"
| irrelevant => "◾"
| object i => f!"obj@{i}"
| usize i => f!"usize@{i}"
| scalar sz offset type => f!"scalar#{sz}@{offset}:{type}"
instance : ToFormat CtorFieldInfo := ⟨format⟩
end CtorFieldInfo
structure CtorLayout :=
(cidx : Nat)
(fieldInfo : List CtorFieldInfo)
(numObjs : Nat)
(numUSize : Nat)
(scalarSize : Nat)
(cidx : Nat)
(fieldInfo : List CtorFieldInfo)
(numObjs : Nat)
(numUSize : Nat)
(scalarSize : Nat)
@[extern "lean_ir_get_ctor_layout"]
constant getCtorLayout (env : @& Environment) (ctorName : @& Name) : Except String CtorLayout := arbitrary _
constant getCtorLayout (env : @& Environment) (ctorName : @& Name) : Except String CtorLayout
end IR
end Lean

View file

@ -11,51 +11,51 @@ namespace Lean.IR.UnreachableBranches
/-- Value used in the abstract interpreter -/
inductive Value
| bot -- undefined
| top -- any value
| ctor (i : CtorInfo) (vs : Array Value)
| choice (vs : List Value)
| bot -- undefined
| top -- any value
| ctor (i : CtorInfo) (vs : Array Value)
| choice (vs : List Value)
namespace Value
instance : Inhabited Value := ⟨top⟩
protected partial def beq : Value → Value → Bool
| bot, bot => true
| top, top => true
| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ Value.beq
| choice vs₁, choice vs₂ =>
vs₁.all (fun v₁ => vs₂.any fun v₂ => Value.beq v₁ v₂)
&&
vs₂.all (fun v₂ => vs₁.any fun v₁ => Value.beq v₁ v₂)
| _, _ => false
| bot, bot => true
| top, top => true
| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ Value.beq
| choice vs₁, choice vs₂ =>
vs₁.all (fun v₁ => vs₂.any fun v₂ => Value.beq v₁ v₂)
&&
vs₂.all (fun v₂ => vs₁.any fun v₁ => Value.beq v₁ v₂)
| _, _ => false
instance : BEq Value := ⟨Value.beq⟩
partial def addChoice (merge : Value → Value → Value) : List Value → Value → List Value
| [], v => [v]
| v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) =>
if i₁ == i₂ then merge v₁ v₂ :: cs
else v₁ :: addChoice merge cs v₂
| _, _ => panic! "invalid addChoice"
| [], v => [v]
| v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) =>
if i₁ == i₂ then merge v₁ v₂ :: cs
else v₁ :: addChoice merge cs v₂
| _, _ => panic! "invalid addChoice"
partial def merge : Value → Value → Value
| bot, v => v
| v, bot => v
| top, _ => top
| _, top => top
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
if i₁ == i₂ then ctor i₁ $ vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i] vs₂[i])
else choice [v₁, v₂]
| choice vs₁, choice vs₂ => choice $ vs₁.foldl (addChoice merge) vs₂
| choice vs, v => choice $ addChoice merge vs v
| v, choice vs => choice $ addChoice merge vs v
| bot, v => v
| v, bot => v
| top, _ => top
| _, top => top
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
if i₁ == i₂ then ctor i₁ $ vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i] vs₂[i])
else choice [v₁, v₂]
| choice vs₁, choice vs₂ => choice $ vs₁.foldl (addChoice merge) vs₂
| choice vs, v => choice $ addChoice merge vs v
| v, choice vs => choice $ addChoice merge vs v
protected partial def format : Value → Format
| top => "top"
| bot => "bot"
| choice vs => fmt "@" ++ @List.format _ ⟨Value.format⟩ vs
| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨Value.format⟩ vs)
| top => "top"
| bot => "bot"
| choice vs => fmt "@" ++ @List.format _ ⟨Value.format⟩ vs
| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨Value.format⟩ vs)
instance : ToFormat Value := ⟨Value.format⟩
instance : ToString Value := ⟨Format.pretty ∘ Value.format⟩
@ -64,240 +64,236 @@ instance : ToString Value := ⟨Format.pretty ∘ Value.format⟩
We use this function this function to implement a simple widening operation for our abstract
interpreter. -/
partial def truncate (env : Environment) : Value → NameSet → Value
| ctor i vs, found =>
let I := i.name.getPrefix
if found.contains I then
top
else
let cont (found' : NameSet) : Value :=
ctor i (vs.map fun v => truncate env v found')
match env.find? I with
| some (ConstantInfo.inductInfo d) =>
if d.isRec then cont (found.insert I)
else cont found
| _ => cont found
| choice vs, found =>
let newVs := vs.map fun v => truncate env v found
if newVs.elem top then top
else choice newVs
| v, _ => v
| ctor i vs, found =>
let I := i.name.getPrefix
if found.contains I then
top
else
let cont (found' : NameSet) : Value :=
ctor i (vs.map fun v => truncate env v found')
match env.find? I with
| some (ConstantInfo.inductInfo d) =>
if d.isRec then cont (found.insert I)
else cont found
| _ => cont found
| choice vs, found =>
let newVs := vs.map fun v => truncate env v found
if newVs.elem top then top
else choice newVs
| v, _ => v
/- Widening operator that guarantees termination in our abstract interpreter. -/
def widening (env : Environment) (v₁ v₂ : Value) : Value :=
truncate env (merge v₁ v₂) {}
truncate env (merge v₁ v₂) {}
end Value
abbrev FunctionSummaries := SMap FunId Value
def mkFunctionSummariesExtension : IO (SimplePersistentEnvExtension (FunId × Value) FunctionSummaries) :=
registerSimplePersistentEnvExtension {
name := `unreachBranchesFunSummary,
addImportedFn := fun as =>
let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as
cache.switch,
addEntryFn := fun s ⟨e, n⟩ => s.insert e n
}
@[builtinInit mkFunctionSummariesExtension]
constant functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries := arbitrary _
builtin_initialize functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries ←
registerSimplePersistentEnvExtension {
name := `unreachBranchesFunSummary,
addImportedFn := fun as =>
let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as
cache.switch,
addEntryFn := fun s ⟨e, n⟩ => s.insert e n
}
def addFunctionSummary (env : Environment) (fid : FunId) (v : Value) : Environment :=
functionSummariesExt.addEntry env (fid, v)
functionSummariesExt.addEntry env (fid, v)
def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
(functionSummariesExt.getState env).find? fid
(functionSummariesExt.getState env).find? fid
abbrev Assignment := Std.HashMap VarId Value
structure InterpContext :=
(currFnIdx : Nat := 0)
(decls : Array Decl)
(env : Environment)
(lctx : LocalContext := {})
(currFnIdx : Nat := 0)
(decls : Array Decl)
(env : Environment)
(lctx : LocalContext := {})
structure InterpState :=
(assignments : Array Assignment)
(funVals : Std.PArray Value) -- we take snapshots during fixpoint computations
(assignments : Array Assignment)
(funVals : Std.PArray Value) -- we take snapshots during fixpoint computations
abbrev M := ReaderT InterpContext (StateM InterpState)
open Value
def findVarValue (x : VarId) : M Value := do
let ctx ← read
let s ← get
let assignment := s.assignments[ctx.currFnIdx]
pure $ assignment.findD x bot
let ctx ← read
let s ← get
let assignment := s.assignments[ctx.currFnIdx]
pure $ assignment.findD x bot
def findArgValue (arg : Arg) : M Value :=
match arg with
| Arg.var x => findVarValue x
| _ => pure top
match arg with
| Arg.var x => findVarValue x
| _ => pure top
def updateVarAssignment (x : VarId) (v : Value) : M Unit := do
let v' ← findVarValue x
let ctx ← read
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') }
let v' ← findVarValue x
let ctx ← read
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') }
def resetVarAssignment (x : VarId) : M Unit := do
let ctx ← read
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot }
let ctx ← read
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot }
def resetParamAssignment (y : Param) : M Unit :=
resetVarAssignment y.x
resetVarAssignment y.x
partial def projValue : Value → Nat → Value
| ctor _ vs, i => vs.getD i bot
| choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot
| v, _ => v
| ctor _ vs, i => vs.getD i bot
| choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot
| v, _ => v
def interpExpr : Expr → M Value
| Expr.ctor i ys => do return ctor i (← ys.mapM fun y => findArgValue y)
| Expr.proj i x => do return projValue (← findVarValue x) i
| Expr.fap fid ys => do
let ctx ← read
match getFunctionSummary? ctx.env fid with
| some v => pure v
| none => do
let s ← get
match ctx.decls.findIdx? (fun decl => decl.name == fid) with
| some idx => pure s.funVals[idx]
| none => pure top
| _ => pure top
| Expr.ctor i ys => do return ctor i (← ys.mapM fun y => findArgValue y)
| Expr.proj i x => do return projValue (← findVarValue x) i
| Expr.fap fid ys => do
let ctx ← read
match getFunctionSummary? ctx.env fid with
| some v => pure v
| none => do
let s ← get
match ctx.decls.findIdx? (fun decl => decl.name == fid) with
| some idx => pure s.funVals[idx]
| none => pure top
| _ => pure top
partial def containsCtor : Value → CtorInfo → Bool
| top, _ => true
| ctor i _, j => i == j
| choice vs, j => vs.any $ fun v => containsCtor v j
| _, _ => false
| top, _ => true
| ctor i _, j => i == j
| choice vs, j => vs.any $ fun v => containsCtor v j
| _, _ => false
def updateCurrFnSummary (v : Value) : M Unit := do
let ctx ← read
let currFnIdx := ctx.currFnIdx
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
let ctx ← read
let currFnIdx := ctx.currFnIdx
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
/-- Return true if the assignment of at least one parameter has been updated. -/
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
let ctx ← read
let currFnIdx := ctx.currFnIdx
ys.size.foldM (init := false) fun i r => do
let y := ys[i]
let x := xs[i]
let yVal ← findVarValue y.x
let xVal ← findArgValue x
let newVal := merge yVal xVal
if newVal == yVal then
pure r
else
modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal }
pure true
private partial def resetNestedJPParams : FnBody → M Unit
| FnBody.jdecl _ ys b k => do
let ctx ← read
let currFnIdx := ctx.currFnIdx
ys.forM resetParamAssignment
/- Remark we don't need to reset the parameters of joint-points
nested in `b` since they will be reset if this JP is used. -/
resetNestedJPParams k
| FnBody.case _ _ _ alts =>
alts.forM fun alt => match alt with
| Alt.ctor _ b => resetNestedJPParams b
| Alt.default b => resetNestedJPParams b
| e => do unless e.isTerminal do resetNestedJPParams e.body
ys.size.foldM (init := false) fun i r => do
let y := ys[i]
let x := xs[i]
let yVal ← findVarValue y.x
let xVal ← findArgValue x
let newVal := merge yVal xVal
if newVal == yVal then
pure r
else
modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal }
pure true
private partial def resetNestedJPParams : FnBody → M Unit
| FnBody.jdecl _ ys b k => do
let ctx ← read
let currFnIdx := ctx.currFnIdx
ys.forM resetParamAssignment
/- Remark we don't need to reset the parameters of joint-points
nested in `b` since they will be reset if this JP is used. -/
resetNestedJPParams k
| FnBody.case _ _ _ alts =>
alts.forM fun alt => match alt with
| Alt.ctor _ b => resetNestedJPParams b
| Alt.default b => resetNestedJPParams b
| e => do unless e.isTerminal do resetNestedJPParams e.body
partial def interpFnBody : FnBody → M Unit
| FnBody.vdecl x _ e b => do
let v ← interpExpr e
updateVarAssignment x v
interpFnBody b
| FnBody.jdecl j ys v b =>
withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) do
| FnBody.vdecl x _ e b => do
let v ← interpExpr e
updateVarAssignment x v
interpFnBody b
| FnBody.case _ x _ alts => do
let v ← findVarValue x
alts.forM fun alt => do
match alt with
| Alt.ctor i b => if containsCtor v i then interpFnBody b
| Alt.default b => interpFnBody b
| FnBody.ret x => do
let v ← findArgValue x
-- dbgTrace ("ret " ++ toString v) $ fun _ =>
updateCurrFnSummary v
| FnBody.jmp j xs => do
let ctx ← read
let ys := (ctx.lctx.getJPParams j).get!
let b := (ctx.lctx.getJPBody j).get!
let updated ← updateJPParamsAssignment ys xs
if updated then
-- We must reset the value of nested join-point parameters since they depend on `ys` values
resetNestedJPParams b
interpFnBody b
| e => do
unless e.isTerminal do
interpFnBody e.body
| FnBody.jdecl j ys v b =>
withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) do
interpFnBody b
| FnBody.case _ x _ alts => do
let v ← findVarValue x
alts.forM fun alt => do
match alt with
| Alt.ctor i b => if containsCtor v i then interpFnBody b
| Alt.default b => interpFnBody b
| FnBody.ret x => do
let v ← findArgValue x
-- dbgTrace ("ret " ++ toString v) $ fun _ =>
updateCurrFnSummary v
| FnBody.jmp j xs => do
let ctx ← read
let ys := (ctx.lctx.getJPParams j).get!
let b := (ctx.lctx.getJPBody j).get!
let updated ← updateJPParamsAssignment ys xs
if updated then
-- We must reset the value of nested join-point parameters since they depend on `ys` values
resetNestedJPParams b
interpFnBody b
| e => do
unless e.isTerminal do
interpFnBody e.body
def inferStep : M Bool := do
let ctx ← read
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
ctx.decls.size.foldM (init := false) fun idx modified => do
match ctx.decls[idx] with
| Decl.fdecl fid ys _ b => do
let s ← get
let currVals := s.funVals[idx]
withReader (fun ctx => { ctx with currFnIdx := idx }) do
ys.forM fun y => updateVarAssignment y.x top
interpFnBody b
let ctx ← read
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
ctx.decls.size.foldM (init := false) fun idx modified => do
match ctx.decls[idx] with
| Decl.fdecl fid ys _ b => do
let s ← get
let newVals := s.funVals[idx]
pure (modified || currVals != newVals)
| Decl.extern _ _ _ _ => pure modified
let currVals := s.funVals[idx]
withReader (fun ctx => { ctx with currFnIdx := idx }) do
ys.forM fun y => updateVarAssignment y.x top
interpFnBody b
let s ← get
let newVals := s.funVals[idx]
pure (modified || currVals != newVals)
| Decl.extern _ _ _ _ => pure modified
partial def inferMain : Unit → M Unit
| _ => do
partial def inferMain : M Unit := do
let modified ← inferStep
if modified then inferMain () else pure ()
if modified then inferMain else pure ()
partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
| FnBody.case tid x xType alts =>
let v := assignment.findD x bot
let alts := alts.map fun alt =>
match alt with
| Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
| Alt.default b => Alt.default (elimDeadAux assignment b)
FnBody.case tid x xType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := elimDeadAux assignment b
instr.setBody b
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
| FnBody.case tid x xType alts =>
let v := assignment.findD x bot
let alts := alts.map fun alt =>
match alt with
| Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
| Alt.default b => Alt.default (elimDeadAux assignment b)
FnBody.case tid x xType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := elimDeadAux assignment b
instr.setBody b
partial def elimDead (assignment : Assignment) : Decl → Decl
| Decl.fdecl fid ys t b => Decl.fdecl fid ys t $ elimDeadAux assignment b
| other => other
| Decl.fdecl fid ys t b => Decl.fdecl fid ys t $ elimDeadAux assignment b
| other => other
end UnreachableBranches
open UnreachableBranches
def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
let s ← get
let env := s.env
let assignments : Array Assignment := decls.map fun _ => {}
let funVals := Std.mkPArray decls.size Value.bot
let ctx : InterpContext := { decls := decls, env := env }
let s : InterpState := { assignments := assignments, funVals := funVals }
let (_, s) := (inferMain () ctx).run s
let funVals := s.funVals
let assignments := s.assignments
modify fun s =>
let env := decls.size.fold (init := s.env) fun i env =>
addFunctionSummary env decls[i].name funVals[i]
{ s with env := env }
pure $ decls.mapIdx fun i decl => elimDead assignments[i] decl
let s ← get
let env := s.env
let assignments : Array Assignment := decls.map fun _ => {}
let funVals := Std.mkPArray decls.size Value.bot
let ctx : InterpContext := { decls := decls, env := env }
let s : InterpState := { assignments := assignments, funVals := funVals }
let (_, s) := (inferMain ctx).run s
let funVals := s.funVals
let assignments := s.assignments
modify fun s =>
let env := decls.size.fold (init := s.env) fun i env =>
addFunctionSummary env decls[i].name funVals[i]
{ s with env := env }
pure $ decls.mapIdx fun i decl => elimDead assignments[i] decl
end Lean.IR

View file

@ -8,30 +8,27 @@ import Lean.Compiler.IR.FreeVars
namespace Lean.IR
partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnBody
| bs, b, used =>
if bs.isEmpty then b
else
let curr := bs.back
let bs := bs.pop
let keep (_ : Unit) :=
let used := curr.collectFreeIndices used
let b := curr.setBody b
reshapeWithoutDeadAux bs b used
let keepIfUsed (vidx : Index) :=
if used.contains vidx then keep ()
else reshapeWithoutDeadAux bs b used
match curr with
| FnBody.vdecl x _ _ _ => keepIfUsed x.idx
-- TODO: we should keep all struct/union projections because they are used to ensure struct/union values are fully consumed.
| FnBody.jdecl j _ _ _ => keepIfUsed j.idx
| _ => keep ()
partial def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody :=
let rec reshape (bs : Array FnBody) (b : FnBody) (used : IndexSet) :=
if bs.isEmpty then b
else
let curr := bs.back
let bs := bs.pop
let keep (_ : Unit) :=
let used := curr.collectFreeIndices used
let b := curr.setBody b
reshape bs b used
let keepIfUsed (vidx : Index) :=
if used.contains vidx then keep ()
else reshape bs b used
match curr with
| FnBody.vdecl x _ _ _ => keepIfUsed x.idx
-- TODO: we should keep all struct/union projections because they are used to ensure struct/union values are fully consumed.
| FnBody.jdecl j _ _ _ => keepIfUsed j.idx
| _ => keep ()
reshape bs term term.freeIndices
def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody :=
reshapeWithoutDeadAux bs term term.freeIndices
partial def FnBody.elimDead : FnBody → FnBody
| b =>
partial def FnBody.elimDead (b : FnBody) : FnBody :=
let (bs, term) := b.flatten
let bs := modifyJPs bs elimDead
let term := match term with
@ -43,7 +40,7 @@ partial def FnBody.elimDead : FnBody → FnBody
/-- Eliminate dead let-declarations and join points -/
def Decl.elimDead : Decl → Decl
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.elimDead
| other => other
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.elimDead
| other => other
end Lean.IR

File diff suppressed because it is too large Load diff

View file

@ -11,44 +11,44 @@ import Lean.Compiler.IR.CompilerM
namespace Lean.IR
/- Return true iff `b` is of the form `let x := g ys; ret x` -/
def isTailCallTo (g : Name) (b : FnBody) : Bool :=
match b with
| FnBody.vdecl x _ (Expr.fap f _) (FnBody.ret (Arg.var y)) => x == y && f == g
| _ => false
match b with
| FnBody.vdecl x _ (Expr.fap f _) (FnBody.ret (Arg.var y)) => x == y && f == g
| _ => false
def usesModuleFrom (env : Environment) (modulePrefix : Name) : Bool :=
env.allImportedModuleNames.toList.any $ fun modName => modulePrefix.isPrefixOf modName
env.allImportedModuleNames.toList.any $ fun modName => modulePrefix.isPrefixOf modName
namespace CollectUsedDecls
abbrev M := ReaderT Environment (StateM NameSet)
@[inline] def collect (f : FunId) : M Unit :=
modify $ fun s => s.insert f
modify fun s => s.insert f
partial def collectFnBody : FnBody → M Unit
| FnBody.vdecl _ _ v b =>
match v with
| Expr.fap f _ => collect f *> collectFnBody b
| Expr.pap f _ => collect f *> collectFnBody b
| other => collectFnBody b
| FnBody.jdecl _ _ v b => collectFnBody v *> collectFnBody b
| FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body
| e => do unless e.isTerminal do collectFnBody e.body
| FnBody.vdecl _ _ v b =>
match v with
| Expr.fap f _ => collect f *> collectFnBody b
| Expr.pap f _ => collect f *> collectFnBody b
| other => collectFnBody b
| FnBody.jdecl _ _ v b => collectFnBody v *> collectFnBody b
| FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body
| e => do unless e.isTerminal do collectFnBody e.body
def collectInitDecl (fn : Name) : M Unit := do
let env ← read
match getInitFnNameFor? env fn with
| some initFn => collect initFn
| _ => pure ()
let env ← read
match getInitFnNameFor? env fn with
| some initFn => collect initFn
| _ => pure ()
def collectDecl : Decl → M NameSet
| Decl.fdecl fn _ _ b => collectInitDecl fn *> CollectUsedDecls.collectFnBody b *> get
| Decl.extern fn _ _ _ => collectInitDecl fn *> get
| Decl.fdecl fn _ _ b => collectInitDecl fn *> CollectUsedDecls.collectFnBody b *> get
| Decl.extern fn _ _ _ => collectInitDecl fn *> get
end CollectUsedDecls
def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet :=
(CollectUsedDecls.collectDecl decl env).run' used
(CollectUsedDecls.collectDecl decl env).run' used
abbrev VarTypeMap := Std.HashMap VarId IRType
abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
@ -56,22 +56,22 @@ abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
namespace CollectMaps
abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap)
@[inline] def collectVar (x : VarId) (t : IRType) : Collector
| (vs, js) => (vs.insert x t, js)
| (vs, js) => (vs.insert x t, js)
def collectParams (ps : Array Param) : Collector :=
fun s => ps.foldl (fun s p => collectVar p.x p.ty s) s
fun s => ps.foldl (fun s p => collectVar p.x p.ty s) s
@[inline] def collectJP (j : JoinPointId) (xs : Array Param) : Collector
| (vs, js) => (vs, js.insert j xs)
| (vs, js) => (vs, js.insert j xs)
/- `collectFnBody` assumes the variables in -/
partial def collectFnBody : FnBody → Collector
| FnBody.vdecl x t _ b => collectVar x t ∘ collectFnBody b
| FnBody.jdecl j xs v b => collectJP j xs ∘ collectParams xs ∘ collectFnBody v ∘ collectFnBody b
| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
| e => if e.isTerminal then id else collectFnBody e.body
| FnBody.vdecl x t _ b => collectVar x t ∘ collectFnBody b
| FnBody.jdecl j xs v b => collectJP j xs ∘ collectParams xs ∘ collectFnBody v ∘ collectFnBody b
| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
| e => if e.isTerminal then id else collectFnBody e.body
def collectDecl : Decl → Collector
| Decl.fdecl _ xs _ b => collectParams xs ∘ collectFnBody b
| _ => id
| Decl.fdecl _ xs _ b => collectParams xs ∘ collectFnBody b
| _ => id
end CollectMaps
@ -79,7 +79,7 @@ end CollectMaps
and `j` is a mapping from join point to parameters.
This function assumes `d` has normalized indexes (see `normids.lean`). -/
def mkVarJPMaps (d : Decl) : VarTypeMap × JPParamsMap :=
CollectMaps.collectDecl d ({}, {})
CollectMaps.collectDecl d ({}, {})
end IR
end Lean

View file

@ -12,46 +12,45 @@ namespace Lean.IR.ExpandResetReuse
abbrev ProjMap := Std.HashMap VarId Expr
namespace CollectProjMap
abbrev Collector := ProjMap → ProjMap
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector :=
fun m => match v with
| Expr.proj _ _ => m.insert x v
| Expr.sproj _ _ _ => m.insert x v
| Expr.uproj _ _ => m.insert x v
| _ => m
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m =>
match v with
| Expr.proj .. => m.insert x v
| Expr.sproj .. => m.insert x v
| Expr.uproj .. => m.insert x v
| _ => m
partial def collectFnBody : FnBody → Collector
| FnBody.vdecl x _ v b => collectVDecl x v ∘ collectFnBody b
| FnBody.jdecl _ _ v b => collectFnBody v ∘ collectFnBody b
| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
| e => if e.isTerminal then id else collectFnBody e.body
| FnBody.vdecl x _ v b => collectVDecl x v ∘ collectFnBody b
| FnBody.jdecl _ _ v b => collectFnBody v ∘ collectFnBody b
| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
| e => if e.isTerminal then id else collectFnBody e.body
end CollectProjMap
/- Create a mapping from variables to projections.
This function assumes variable ids have been normalized -/
def mkProjMap (d : Decl) : ProjMap :=
match d with
| Decl.fdecl _ _ _ b => CollectProjMap.collectFnBody b {}
| _ => {}
match d with
| Decl.fdecl _ _ _ b => CollectProjMap.collectFnBody b {}
| _ => {}
structure Context :=
(projMap : ProjMap)
(projMap : ProjMap)
/- Return true iff `x` is consumed in all branches of the current block.
Here consumption means the block contains a `dec x` or `reuse x ...`. -/
partial def consumed (x : VarId) : FnBody → Bool
| FnBody.vdecl _ _ v b =>
match v with
| Expr.reuse y _ _ _ => x == y || consumed x b
| _ => consumed x b
| FnBody.dec y _ _ _ b => x == y || consumed x b
| FnBody.case _ _ _ alts => alts.all fun alt => consumed x alt.body
| e => !e.isTerminal && consumed x e.body
| FnBody.vdecl _ _ v b =>
match v with
| Expr.reuse y _ _ _ => x == y || consumed x b
| _ => consumed x b
| FnBody.dec y _ _ _ b => x == y || consumed x b
| FnBody.case _ _ _ alts => alts.all fun alt => consumed x alt.body
| e => !e.isTerminal && consumed x e.body
abbrev Mask := Array (Option VarId)
/- Auxiliary function for eraseProjIncFor -/
partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnBody → Array FnBody × Mask
| bs, mask, keep =>
partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
let done (_ : Unit) := (bs ++ keep.reverse, mask)
let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
if bs.size < 2 then done ()
@ -85,30 +84,30 @@ partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnB
/- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`.
Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/
def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask :=
eraseProjIncForAux y bs (mkArray n none) #[]
eraseProjIncForAux y bs (mkArray n none) #[]
/- Replace `reuse x ctor ...` with `ctor ...`, and remoce `dec x` -/
partial def reuseToCtor (x : VarId) : FnBody → FnBody
| FnBody.dec y n c p b =>
if x == y then b -- n must be 1 since `x := reset ...`
else FnBody.dec y n c p (reuseToCtor x b)
| FnBody.vdecl z t v b =>
match v with
| Expr.reuse y c u xs =>
if x == y then FnBody.vdecl z t (Expr.ctor c xs) b
else FnBody.vdecl z t v (reuseToCtor x b)
| _ =>
FnBody.vdecl z t v (reuseToCtor x b)
| FnBody.case tid y yType alts =>
let alts := alts.map fun alt => alt.modifyBody (reuseToCtor x)
FnBody.case tid y yType alts
| e =>
if e.isTerminal then
e
else
let (instr, b) := e.split
let b := reuseToCtor x b
instr.setBody b
| FnBody.dec y n c p b =>
if x == y then b -- n must be 1 since `x := reset ...`
else FnBody.dec y n c p (reuseToCtor x b)
| FnBody.vdecl z t v b =>
match v with
| Expr.reuse y c u xs =>
if x == y then FnBody.vdecl z t (Expr.ctor c xs) b
else FnBody.vdecl z t v (reuseToCtor x b)
| _ =>
FnBody.vdecl z t v (reuseToCtor x b)
| FnBody.case tid y yType alts =>
let alts := alts.map fun alt => alt.modifyBody (reuseToCtor x)
FnBody.case tid y yType alts
| e =>
if e.isTerminal then
e
else
let (instr, b) := e.split
let b := reuseToCtor x b
instr.setBody b
/-
replace
@ -123,97 +122,91 @@ where `z_i`'s are the variables in `mask`,
and `b'` is `b` where we removed `dec x` and replaced `reuse x ctor_i ...` with `ctor_i ...`.
-/
def mkSlowPath (x y : VarId) (mask : Mask) (b : FnBody) : FnBody :=
let b := reuseToCtor x b
let b := FnBody.dec y 1 true false b
mask.foldl
(fun b m => match m with
| some z => FnBody.inc z 1 true false b
| none => b)
b
let b := reuseToCtor x b
let b := FnBody.dec y 1 true false b
mask.foldl (init := b) fun b m => match m with
| some z => FnBody.inc z 1 true false b
| none => b
abbrev M := ReaderT Context (StateM Nat)
def mkFresh : M VarId :=
modifyGet $ fun n => ({ idx := n }, n + 1)
def mkFresh : M VarId :=
modifyGet $ fun n => ({ idx := n }, n + 1)
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
mask.size.foldM
(fun i b =>
mask.size.foldM (init := b) fun i b =>
match mask.get! i with
| some _ => pure b -- code took ownership of this field
| none => do
let fld ← mkFresh
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b)))
b
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b))
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
zs.size.fold
(fun i b => FnBody.set y i (zs.get! i) b)
b
zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b
/- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
match y with
| Arg.var y =>
match ctx.projMap.find? y with
| some (Expr.proj j w) => j == i && w == x
match y with
| Arg.var y =>
match ctx.projMap.find? y with
| some (Expr.proj j w) => j == i && w == x
| _ => false
| _ => false
| _ => false
/- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap.find? y with
| some (Expr.uproj j w) => j == i && w == x
| _ => false
match ctx.projMap.find? y with
| some (Expr.uproj j w) => j == i && w == x
| _ => false
/- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap.find? y with
| some (Expr.sproj m j w) => n == m && j == i && w == x
| _ => false
match ctx.projMap.find? y with
| some (Expr.sproj m j w) => n == m && j == i && w == x
| _ => false
/- Remove unnecessary `set/uset/sset` operations -/
partial def removeSelfSet (ctx : Context) : FnBody → FnBody
| FnBody.set x i y b =>
if isSelfSet ctx x i y then removeSelfSet ctx b
else FnBody.set x i y (removeSelfSet ctx b)
| FnBody.uset x i y b =>
if isSelfUSet ctx x i y then removeSelfSet ctx b
else FnBody.uset x i y (removeSelfSet ctx b)
| FnBody.sset x n i y t b =>
if isSelfSSet ctx x n i y then removeSelfSet ctx b
else FnBody.sset x n i y t (removeSelfSet ctx b)
| FnBody.case tid y yType alts =>
let alts := alts.map fun alt => alt.modifyBody (removeSelfSet ctx)
FnBody.case tid y yType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := removeSelfSet ctx b
instr.setBody b
| FnBody.set x i y b =>
if isSelfSet ctx x i y then removeSelfSet ctx b
else FnBody.set x i y (removeSelfSet ctx b)
| FnBody.uset x i y b =>
if isSelfUSet ctx x i y then removeSelfSet ctx b
else FnBody.uset x i y (removeSelfSet ctx b)
| FnBody.sset x n i y t b =>
if isSelfSSet ctx x n i y then removeSelfSet ctx b
else FnBody.sset x n i y t (removeSelfSet ctx b)
| FnBody.case tid y yType alts =>
let alts := alts.map fun alt => alt.modifyBody (removeSelfSet ctx)
FnBody.case tid y yType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := removeSelfSet ctx b
instr.setBody b
partial def reuseToSet (ctx : Context) (x y : VarId) : FnBody → FnBody
| FnBody.dec z n c p b =>
if x == z then FnBody.del y b
else FnBody.dec z n c p (reuseToSet ctx x y b)
| FnBody.vdecl z t v b =>
match v with
| Expr.reuse w c u zs =>
if x == w then
let b := setFields y zs (b.replaceVar z y)
let b := if u then FnBody.setTag y c.cidx b else b
removeSelfSet ctx b
else FnBody.vdecl z t v (reuseToSet ctx x y b)
| _ => FnBody.vdecl z t v (reuseToSet ctx x y b)
| FnBody.case tid z zType alts =>
let alts := alts.map fun alt => alt.modifyBody (reuseToSet ctx x y)
FnBody.case tid z zType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := reuseToSet ctx x y b
instr.setBody b
| FnBody.dec z n c p b =>
if x == z then FnBody.del y b
else FnBody.dec z n c p (reuseToSet ctx x y b)
| FnBody.vdecl z t v b =>
match v with
| Expr.reuse w c u zs =>
if x == w then
let b := setFields y zs (b.replaceVar z y)
let b := if u then FnBody.setTag y c.cidx b else b
removeSelfSet ctx b
else FnBody.vdecl z t v (reuseToSet ctx x y b)
| _ => FnBody.vdecl z t v (reuseToSet ctx x y b)
| FnBody.case tid z zType alts =>
let alts := alts.map fun alt => alt.modifyBody (reuseToSet ctx x y)
FnBody.case tid z zType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := reuseToSet ctx x y b
instr.setBody b
/-
replace
@ -235,54 +228,54 @@ and `z := reuse x ctor_i ws; F` is replaced with
`set x i ws[i]` operations, and we replace `z` with `x` in `F`
-/
def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody := do
let ctx ← read
let b := reuseToSet ctx x y b
releaseUnreadFields y mask b
let ctx ← read
let b := reuseToSet ctx x y b
releaseUnreadFields y mask b
-- Expand `bs; x := reset[n] y; b`
partial def expand (mainFn : FnBody → Array FnBody → M FnBody)
(bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody := do
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b
let (bs, mask) := eraseProjIncFor n y bs
/- Remark: we may be duplicting variable/JP indices. That is, `bSlow` and `bFast` may
have duplicate indices. We run `normalizeIds` to fix the ids after we have expand them. -/
let bSlow := mkSlowPath x y mask b
let bFast ← mkFastPath x y mask b
/- We only optimize recursively the fast. -/
let bFast ← mainFn bFast #[]
let c ← mkFresh
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast)
pure $ reshape bs b
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b
let (bs, mask) := eraseProjIncFor n y bs
/- Remark: we may be duplicting variable/JP indices. That is, `bSlow` and `bFast` may
have duplicate indices. We run `normalizeIds` to fix the ids after we have expand them. -/
let bSlow := mkSlowPath x y mask b
let bFast ← mkFastPath x y mask b
/- We only optimize recursively the fast. -/
let bFast ← mainFn bFast #[]
let c ← mkFresh
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast)
pure $ reshape bs b
partial def searchAndExpand : FnBody → Array FnBody → M FnBody
| d@(FnBody.vdecl x t (Expr.reset n y) b), bs =>
if consumed x b then do
expand searchAndExpand bs x n y b
else
searchAndExpand b (push bs d)
| FnBody.jdecl j xs v b, bs => do
let v ← searchAndExpand v #[]
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
| FnBody.case tid x xType alts, bs => do
let alts ← alts.mapM $ fun alt => alt.mmodifyBody fun b => searchAndExpand b #[]
pure $ reshape bs (FnBody.case tid x xType alts)
| b, bs =>
if b.isTerminal then pure $ reshape bs b
else searchAndExpand b.body (push bs b)
| d@(FnBody.vdecl x t (Expr.reset n y) b), bs =>
if consumed x b then do
expand searchAndExpand bs x n y b
else
searchAndExpand b (push bs d)
| FnBody.jdecl j xs v b, bs => do
let v ← searchAndExpand v #[]
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
| FnBody.case tid x xType alts, bs => do
let alts ← alts.mapM $ fun alt => alt.mmodifyBody fun b => searchAndExpand b #[]
pure $ reshape bs (FnBody.case tid x xType alts)
| b, bs =>
if b.isTerminal then pure $ reshape bs b
else searchAndExpand b.body (push bs b)
def main (d : Decl) : Decl :=
match d with
| (Decl.fdecl f xs t b) =>
let m := mkProjMap d
let nextIdx := d.maxIndex + 1
let b := (searchAndExpand b #[] { projMap := m }).run' nextIdx
Decl.fdecl f xs t b
| d => d
match d with
| (Decl.fdecl f xs t b) =>
let m := mkProjMap d
let nextIdx := d.maxIndex + 1
let b := (searchAndExpand b #[] { projMap := m }).run' nextIdx
Decl.fdecl f xs t b
| d => d
end ExpandResetReuse
/-- (Try to) expand `reset` and `reuse` instructions. -/
def Decl.expandResetReuse (d : Decl) : Decl :=
(ExpandResetReuse.main d).normalizeIds
(ExpandResetReuse.main d).normalizeIds
end Lean.IR

View file

@ -26,62 +26,62 @@ abbrev Collector := Index → Index
instance : AndThen Collector := ⟨seq⟩
private def collectArg : Arg → Collector
| Arg.var x => collectVar x
| irrelevant => skip
| Arg.var x => collectVar x
| irrelevant => skip
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
fun m => as.foldl (fun m a => f a m) m
fun m => as.foldl (fun m a => f a m) m
private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg
private def collectParam (p : Param) : Collector := collectVar p.x
private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam
private def collectExpr : Expr → Collector
| Expr.ctor _ ys => collectArgs ys
| Expr.reset _ x => collectVar x
| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys
| Expr.proj _ x => collectVar x
| Expr.uproj _ x => collectVar x
| Expr.sproj _ _ x => collectVar x
| Expr.fap _ ys => collectArgs ys
| Expr.pap _ ys => collectArgs ys
| Expr.ap x ys => collectVar x >> collectArgs ys
| Expr.box _ x => collectVar x
| Expr.unbox x => collectVar x
| Expr.lit v => skip
| Expr.isShared x => collectVar x
| Expr.isTaggedPtr x => collectVar x
| Expr.ctor _ ys => collectArgs ys
| Expr.reset _ x => collectVar x
| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys
| Expr.proj _ x => collectVar x
| Expr.uproj _ x => collectVar x
| Expr.sproj _ _ x => collectVar x
| Expr.fap _ ys => collectArgs ys
| Expr.pap _ ys => collectArgs ys
| Expr.ap x ys => collectVar x >> collectArgs ys
| Expr.box _ x => collectVar x
| Expr.unbox x => collectVar x
| Expr.lit v => skip
| Expr.isShared x => collectVar x
| Expr.isTaggedPtr x => collectVar x
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
collectArray alts $ fun alt => f alt.body
collectArray alts $ fun alt => f alt.body
partial def collectFnBody : FnBody → Collector
| FnBody.vdecl x _ v b => collectVar x >> collectExpr v >> collectFnBody b
| FnBody.jdecl j ys v b => collectJP j >> collectFnBody v >> collectParams ys >> collectFnBody b
| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b
| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b
| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b
| FnBody.setTag x _ b => collectVar x >> collectFnBody b
| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b
| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b
| FnBody.del x b => collectVar x >> collectFnBody b
| FnBody.mdata _ b => collectFnBody b
| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts
| FnBody.jmp j ys => collectJP j >> collectArgs ys
| FnBody.ret x => collectArg x
| FnBody.unreachable => skip
| FnBody.vdecl x _ v b => collectVar x >> collectExpr v >> collectFnBody b
| FnBody.jdecl j ys v b => collectJP j >> collectFnBody v >> collectParams ys >> collectFnBody b
| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b
| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b
| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b
| FnBody.setTag x _ b => collectVar x >> collectFnBody b
| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b
| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b
| FnBody.del x b => collectVar x >> collectFnBody b
| FnBody.mdata _ b => collectFnBody b
| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts
| FnBody.jmp j ys => collectJP j >> collectArgs ys
| FnBody.ret x => collectArg x
| FnBody.unreachable => skip
partial def collectDecl : Decl → Collector
| Decl.fdecl _ xs _ b => collectParams xs >> collectFnBody b
| Decl.extern _ xs _ _ => collectParams xs
| Decl.fdecl _ xs _ b => collectParams xs >> collectFnBody b
| Decl.extern _ xs _ _ => collectParams xs
end MaxIndex
def FnBody.maxIndex (b : FnBody) : Index :=
MaxIndex.collectFnBody b 0
MaxIndex.collectFnBody b 0
def Decl.maxIndex (d : Decl) : Index :=
MaxIndex.collectDecl d 0
MaxIndex.collectDecl d 0
namespace FreeIndices
/- We say a variable (join point) index (aka name) is free in a function body
@ -90,89 +90,89 @@ namespace FreeIndices
abbrev Collector := IndexSet → IndexSet → IndexSet
@[inline] private def skip : Collector :=
fun bv fv => fv
fun bv fv => fv
@[inline] private def collectIndex (x : Index) : Collector :=
fun bv fv => if bv.contains x then fv else fv.insert x
fun bv fv => if bv.contains x then fv else fv.insert x
@[inline] private def collectVar (x : VarId) : Collector :=
collectIndex x.idx
collectIndex x.idx
@[inline] private def collectJP (x : JoinPointId) : Collector :=
collectIndex x.idx
collectIndex x.idx
@[inline] private def withIndex (x : Index) : Collector → Collector :=
fun k bv fv => k (bv.insert x) fv
fun k bv fv => k (bv.insert x) fv
@[inline] private def withVar (x : VarId) : Collector → Collector :=
withIndex x.idx
withIndex x.idx
@[inline] private def withJP (x : JoinPointId) : Collector → Collector :=
withIndex x.idx
withIndex x.idx
def insertParams (s : IndexSet) (ys : Array Param) : IndexSet :=
ys.foldl (fun s p => s.insert p.x.idx) s
ys.foldl (init := s) fun s p => s.insert p.x.idx
@[inline] private def withParams (ys : Array Param) : Collector → Collector :=
fun k bv fv => k (insertParams bv ys) fv
fun k bv fv => k (insertParams bv ys) fv
@[inline] private def seq : Collector → Collector → Collector :=
fun k₁ k₂ bv fv => k₂ bv (k₁ bv fv)
fun k₁ k₂ bv fv => k₂ bv (k₁ bv fv)
instance : AndThen Collector := ⟨seq⟩
private def collectArg : Arg → Collector
| Arg.var x => collectVar x
| irrelevant => skip
| Arg.var x => collectVar x
| irrelevant => skip
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
fun bv fv => as.foldl (fun fv a => f a bv fv) fv
fun bv fv => as.foldl (fun fv a => f a bv fv) fv
private def collectArgs (as : Array Arg) : Collector :=
collectArray as collectArg
collectArray as collectArg
private def collectExpr : Expr → Collector
| Expr.ctor _ ys => collectArgs ys
| Expr.reset _ x => collectVar x
| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys
| Expr.proj _ x => collectVar x
| Expr.uproj _ x => collectVar x
| Expr.sproj _ _ x => collectVar x
| Expr.fap _ ys => collectArgs ys
| Expr.pap _ ys => collectArgs ys
| Expr.ap x ys => collectVar x >> collectArgs ys
| Expr.box _ x => collectVar x
| Expr.unbox x => collectVar x
| Expr.lit v => skip
| Expr.isShared x => collectVar x
| Expr.isTaggedPtr x => collectVar x
| Expr.ctor _ ys => collectArgs ys
| Expr.reset _ x => collectVar x
| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys
| Expr.proj _ x => collectVar x
| Expr.uproj _ x => collectVar x
| Expr.sproj _ _ x => collectVar x
| Expr.fap _ ys => collectArgs ys
| Expr.pap _ ys => collectArgs ys
| Expr.ap x ys => collectVar x >> collectArgs ys
| Expr.box _ x => collectVar x
| Expr.unbox x => collectVar x
| Expr.lit v => skip
| Expr.isShared x => collectVar x
| Expr.isTaggedPtr x => collectVar x
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
collectArray alts $ fun alt => f alt.body
collectArray alts $ fun alt => f alt.body
partial def collectFnBody : FnBody → Collector
| FnBody.vdecl x _ v b => collectExpr v >> withVar x (collectFnBody b)
| FnBody.jdecl j ys v b => withParams ys (collectFnBody v) >> withJP j (collectFnBody b)
| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b
| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b
| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b
| FnBody.setTag x _ b => collectVar x >> collectFnBody b
| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b
| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b
| FnBody.del x b => collectVar x >> collectFnBody b
| FnBody.mdata _ b => collectFnBody b
| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts
| FnBody.jmp j ys => collectJP j >> collectArgs ys
| FnBody.ret x => collectArg x
| FnBody.unreachable => skip
| FnBody.vdecl x _ v b => collectExpr v >> withVar x (collectFnBody b)
| FnBody.jdecl j ys v b => withParams ys (collectFnBody v) >> withJP j (collectFnBody b)
| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b
| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b
| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b
| FnBody.setTag x _ b => collectVar x >> collectFnBody b
| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b
| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b
| FnBody.del x b => collectVar x >> collectFnBody b
| FnBody.mdata _ b => collectFnBody b
| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts
| FnBody.jmp j ys => collectJP j >> collectArgs ys
| FnBody.ret x => collectArg x
| FnBody.unreachable => skip
end FreeIndices
def FnBody.collectFreeIndices (b : FnBody) (vs : IndexSet) : IndexSet :=
FreeIndices.collectFnBody b {} vs
FreeIndices.collectFnBody b {} vs
def FnBody.freeIndices (b : FnBody) : IndexSet :=
b.collectFreeIndices {}
b.collectFreeIndices {}
namespace HasIndex
/- In principle, we can check whether a function body `b` contains an index `i` using
@ -183,46 +183,46 @@ def visitVar (w : Index) (x : VarId) : Bool := w == x.idx
def visitJP (w : Index) (x : JoinPointId) : Bool := w == x.idx
def visitArg (w : Index) : Arg → Bool
| Arg.var x => visitVar w x
| _ => false
| Arg.var x => visitVar w x
| _ => false
def visitArgs (w : Index) (xs : Array Arg) : Bool :=
xs.any (visitArg w)
xs.any (visitArg w)
def visitParams (w : Index) (ps : Array Param) : Bool :=
ps.any (fun p => w == p.x.idx)
ps.any (fun p => w == p.x.idx)
def visitExpr (w : Index) : Expr → Bool
| Expr.ctor _ ys => visitArgs w ys
| Expr.reset _ x => visitVar w x
| Expr.reuse x _ _ ys => visitVar w x || visitArgs w ys
| Expr.proj _ x => visitVar w x
| Expr.uproj _ x => visitVar w x
| Expr.sproj _ _ x => visitVar w x
| Expr.fap _ ys => visitArgs w ys
| Expr.pap _ ys => visitArgs w ys
| Expr.ap x ys => visitVar w x || visitArgs w ys
| Expr.box _ x => visitVar w x
| Expr.unbox x => visitVar w x
| Expr.lit v => false
| Expr.isShared x => visitVar w x
| Expr.isTaggedPtr x => visitVar w x
| Expr.ctor _ ys => visitArgs w ys
| Expr.reset _ x => visitVar w x
| Expr.reuse x _ _ ys => visitVar w x || visitArgs w ys
| Expr.proj _ x => visitVar w x
| Expr.uproj _ x => visitVar w x
| Expr.sproj _ _ x => visitVar w x
| Expr.fap _ ys => visitArgs w ys
| Expr.pap _ ys => visitArgs w ys
| Expr.ap x ys => visitVar w x || visitArgs w ys
| Expr.box _ x => visitVar w x
| Expr.unbox x => visitVar w x
| Expr.lit v => false
| Expr.isShared x => visitVar w x
| Expr.isTaggedPtr x => visitVar w x
partial def visitFnBody (w : Index) : FnBody → Bool
| FnBody.vdecl x _ v b => visitExpr w v || visitFnBody w b
| FnBody.jdecl j ys v b => visitFnBody w v || visitFnBody w b
| FnBody.set x _ y b => visitVar w x || visitArg w y || visitFnBody w b
| FnBody.uset x _ y b => visitVar w x || visitVar w y || visitFnBody w b
| FnBody.sset x _ _ y _ b => visitVar w x || visitVar w y || visitFnBody w b
| FnBody.setTag x _ b => visitVar w x || visitFnBody w b
| FnBody.inc x _ _ _ b => visitVar w x || visitFnBody w b
| FnBody.dec x _ _ _ b => visitVar w x || visitFnBody w b
| FnBody.del x b => visitVar w x || visitFnBody w b
| FnBody.mdata _ b => visitFnBody w b
| FnBody.jmp j ys => visitJP w j || visitArgs w ys
| FnBody.ret x => visitArg w x
| FnBody.case _ x _ alts => visitVar w x || alts.any (fun alt => visitFnBody w alt.body)
| FnBody.unreachable => false
| FnBody.vdecl x _ v b => visitExpr w v || visitFnBody w b
| FnBody.jdecl j ys v b => visitFnBody w v || visitFnBody w b
| FnBody.set x _ y b => visitVar w x || visitArg w y || visitFnBody w b
| FnBody.uset x _ y b => visitVar w x || visitVar w y || visitFnBody w b
| FnBody.sset x _ _ y _ b => visitVar w x || visitVar w y || visitFnBody w b
| FnBody.setTag x _ b => visitVar w x || visitFnBody w b
| FnBody.inc x _ _ _ b => visitVar w x || visitFnBody w b
| FnBody.dec x _ _ _ b => visitVar w x || visitFnBody w b
| FnBody.del x b => visitVar w x || visitFnBody w b
| FnBody.mdata _ b => visitFnBody w b
| FnBody.jmp j ys => visitJP w j || visitArgs w ys
| FnBody.ret x => visitArg w x
| FnBody.case _ x _ alts => visitVar w x || alts.any (fun alt => visitFnBody w alt.body)
| FnBody.unreachable => false
end HasIndex

View file

@ -9,8 +9,7 @@ import Lean.Compiler.IR.NormIds
namespace Lean.IR
partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array FnBody → IndexSet → Array FnBody × Array Alt
| bs, alts, altsF, ctx, ctxF =>
partial def pushProjs (bs : Array FnBody) (alts : Array Alt) (altsF : Array IndexSet) (ctx : Array FnBody) (ctxF : IndexSet) : Array FnBody × Array Alt :=
if bs.isEmpty then (ctx.reverse, alts)
else
let b := bs.back
@ -37,8 +36,7 @@ partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array
| _ => done ()
| _ => done ()
partial def FnBody.pushProj : FnBody → FnBody
| b =>
partial def FnBody.pushProj (b : FnBody) : FnBody :=
let (bs, term) := b.flatten
let bs := modifyJPs bs pushProj
match term with
@ -52,7 +50,7 @@ partial def FnBody.pushProj : FnBody → FnBody
/-- Push projections inside `case` branches. -/
def Decl.pushProj : Decl → Decl
| Decl.fdecl f xs t b => (Decl.fdecl f xs t b.pushProj).normalizeIds
| other => other
| Decl.fdecl f xs t b => (Decl.fdecl f xs t b.pushProj).normalizeIds
| other => other
end Lean.IR

View file

@ -14,82 +14,82 @@ namespace Lean.IR.ExplicitRC
-/
structure VarInfo :=
(ref : Bool := true) -- true if the variable may be a reference (aka pointer) at runtime
(persistent : Bool := false) -- true if the variable is statically known to be marked a Persistent at runtime
(consume : Bool := false) -- true if the variable RC must be "consumed"
(ref : Bool := true) -- true if the variable may be a reference (aka pointer) at runtime
(persistent : Bool := false) -- true if the variable is statically known to be marked a Persistent at runtime
(consume : Bool := false) -- true if the variable RC must be "consumed"
abbrev VarMap := Std.RBMap VarId VarInfo (fun x y => x.idx < y.idx)
structure Context :=
(env : Environment)
(decls : Array Decl)
(varMap : VarMap := {})
(jpLiveVarMap : JPLiveVarMap := {}) -- map: join point => live variables
(localCtx : LocalContext := {}) -- we use it to store the join point declarations
(env : Environment)
(decls : Array Decl)
(varMap : VarMap := {})
(jpLiveVarMap : JPLiveVarMap := {}) -- map: join point => live variables
(localCtx : LocalContext := {}) -- we use it to store the join point declarations
def getDecl (ctx : Context) (fid : FunId) : Decl :=
match findEnvDecl' ctx.env fid ctx.decls with
| some decl => decl
| none => arbitrary _ -- unreachable if well-formed
match findEnvDecl' ctx.env fid ctx.decls with
| some decl => decl
| none => arbitrary _ -- unreachable if well-formed
def getVarInfo (ctx : Context) (x : VarId) : VarInfo :=
match ctx.varMap.find? x with
| some info => info
| none => {} -- unreachable in well-formed code
match ctx.varMap.find? x with
| some info => info
| none => {} -- unreachable in well-formed code
def getJPParams (ctx : Context) (j : JoinPointId) : Array Param :=
match ctx.localCtx.getJPParams j with
| some ps => ps
| none => #[] -- unreachable in well-formed code
match ctx.localCtx.getJPParams j with
| some ps => ps
| none => #[] -- unreachable in well-formed code
def getJPLiveVars (ctx : Context) (j : JoinPointId) : LiveVarSet :=
match ctx.jpLiveVarMap.find? j with
| some s => s
| none => {}
match ctx.jpLiveVarMap.find? j with
| some s => s
| none => {}
def mustConsume (ctx : Context) (x : VarId) : Bool :=
let info := getVarInfo ctx x
info.ref && info.consume
let info := getVarInfo ctx x
info.ref && info.consume
@[inline] def addInc (ctx : Context) (x : VarId) (b : FnBody) (n := 1) : FnBody :=
let info := getVarInfo ctx x
if n == 0 then b else FnBody.inc x n true info.persistent b
let info := getVarInfo ctx x
if n == 0 then b else FnBody.inc x n true info.persistent b
@[inline] def addDec (ctx : Context) (x : VarId) (b : FnBody) : FnBody :=
let info := getVarInfo ctx x
FnBody.dec x 1 true info.persistent b
let info := getVarInfo ctx x
FnBody.dec x 1 true info.persistent b
private def updateRefUsingCtorInfo (ctx : Context) (x : VarId) (c : CtorInfo) : Context :=
if c.isRef then ctx
else
let m := ctx.varMap
{ ctx with
varMap := match m.find? x with
| some info => m.insert x { info with ref := false } -- I really want a Lenses library + notation
| none => m }
if c.isRef then
ctx
else
let m := ctx.varMap
{ ctx with
varMap := match m.find? x with
| some info => m.insert x { info with ref := false } -- I really want a Lenses library + notation
| none => m }
private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet) (b : FnBody) : FnBody :=
caseLiveVars.fold
(fun b x => if !altLiveVars.contains x && mustConsume ctx x then addDec ctx x b else b)
b
caseLiveVars.fold (init := b) fun b x =>
if !altLiveVars.contains x && mustConsume ctx x then addDec ctx x b else b
/- `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` -/
private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool :=
let x := xs[i]
i.all fun j => xs[j] != x
let x := xs[i]
i.all fun j => xs[j] != x
/- Return true if `x` also occurs in `ys` in a position that is not consumed.
That is, it is also passed as a borrow reference. -/
@[specialize]
private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Bool :=
ys.size.any $ fun i =>
let y := ys[i]
match y with
| Arg.irrelevant => false
| Arg.var y => x == y && !consumeParamPred i
ys.size.any fun i =>
let y := ys[i]
match y with
| Arg.irrelevant => false
| Arg.var y => x == y && !consumeParamPred i
private def isBorrowParam (x : VarId) (ys : Array Arg) (ps : Array Param) : Bool :=
isBorrowParamAux x ys fun i => not ps[i].borrow
isBorrowParamAux x ys fun i => not ps[i].borrow
/-
Return `n`, the number of times `x` is consumed.
@ -98,190 +98,187 @@ Return `n`, the number of times `x` is consumed.
-/
@[specialize]
private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Nat :=
ys.size.fold (init := 0) fun i n =>
let y := ys[i]
match y with
| Arg.irrelevant => n
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
ys.size.fold (init := 0) fun i n =>
let y := ys[i]
match y with
| Arg.irrelevant => n
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
@[specialize]
private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat → Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
xs.size.fold (init := b) fun i b =>
let x := xs[i]
match x with
| Arg.irrelevant => b
| Arg.var x =>
let info := getVarInfo ctx x
if !info.ref || !isFirstOcc xs i then b
else
let numConsuptions := getNumConsumptions x xs consumeParamPred -- number of times the argument is
let numIncs :=
if !info.consume || -- `x` is not a variable that must be consumed by the current procedure
liveVarsAfter.contains x || -- `x` is live after executing instruction
isBorrowParamAux x xs consumeParamPred -- `x` is used in a position that is passed as a borrow reference
then numConsuptions
else numConsuptions - 1
-- dbgTrace ("addInc " ++ toString x ++ " nconsumptions: " ++ toString numConsuptions ++ " incs: " ++ toString numIncs
-- ++ " consume: " ++ toString info.consume ++ " live: " ++ toString (liveVarsAfter.contains x)
-- ++ " borrowParam : " ++ toString (isBorrowParamAux x xs consumeParamPred)) $ fun _ =>
addInc ctx x b numIncs
xs.size.fold (init := b) fun i b =>
let x := xs[i]
match x with
| Arg.irrelevant => b
| Arg.var x =>
let info := getVarInfo ctx x
if !info.ref || !isFirstOcc xs i then b
else
let numConsuptions := getNumConsumptions x xs consumeParamPred -- number of times the argument is
let numIncs :=
if !info.consume || -- `x` is not a variable that must be consumed by the current procedure
liveVarsAfter.contains x || -- `x` is live after executing instruction
isBorrowParamAux x xs consumeParamPred -- `x` is used in a position that is passed as a borrow reference
then numConsuptions
else numConsuptions - 1
-- dbgTrace ("addInc " ++ toString x ++ " nconsumptions: " ++ toString numConsuptions ++ " incs: " ++ toString numIncs
-- ++ " consume: " ++ toString info.consume ++ " live: " ++ toString (liveVarsAfter.contains x)
-- ++ " borrowParam : " ++ toString (isBorrowParamAux x xs consumeParamPred)) $ fun _ =>
addInc ctx x b numIncs
private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
addIncBeforeAux ctx xs (fun i => not ps[i].borrow) b liveVarsAfter
addIncBeforeAux ctx xs (fun i => not ps[i].borrow) b liveVarsAfter
/- See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before an application. -/
private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
xs.size.fold
(fun i b =>
match xs[i] with
| Arg.irrelevant => b
| Arg.var x =>
/- We must add a `dec` if `x` must be consumed, it is alive after the application,
and it has been borrowed by the application.
Remark: `x` may occur multiple times in the application (e.g., `f x y x`).
This is why we check whether it is the first occurrence. -/
if mustConsume ctx x && isFirstOcc xs i && isBorrowParam x xs ps && !bLiveVars.contains x then
addDec ctx x b
else b)
b
xs.size.fold (init := b) fun i b =>
match xs[i] with
| Arg.irrelevant => b
| Arg.var x =>
/- We must add a `dec` if `x` must be consumed, it is alive after the application,
and it has been borrowed by the application.
Remark: `x` may occur multiple times in the application (e.g., `f x y x`).
This is why we check whether it is the first occurrence. -/
if mustConsume ctx x && isFirstOcc xs i && isBorrowParam x xs ps && !bLiveVars.contains x then
addDec ctx x b
else b
private def addIncBeforeConsumeAll (ctx : Context) (xs : Array Arg) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
addIncBeforeAux ctx xs (fun i => true) b liveVarsAfter
addIncBeforeAux ctx xs (fun i => true) b liveVarsAfter
/- Add `dec` instructions for parameters that are references, are not alive in `b`, and are not borrow.
That is, we must make sure these parameters are consumed. -/
private def addDecForDeadParams (ctx : Context) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
ps.foldl
(fun b p => if !p.borrow && p.ty.isObj && !bLiveVars.contains p.x then addDec ctx p.x b else b)
b
ps.foldl (init := b) fun b p =>
if !p.borrow && p.ty.isObj && !bLiveVars.contains p.x then addDec ctx p.x b else b
private def isPersistent : Expr → Bool
| Expr.fap c xs => xs.isEmpty -- all global constants are persistent objects
| _ => false
| Expr.fap c xs => xs.isEmpty -- all global constants are persistent objects
| _ => false
/- We do not need to consume the projection of a variable that is not consumed -/
private def consumeExpr (m : VarMap) : Expr → Bool
| Expr.proj i x => match m.find? x with
| some info => info.consume
| none => true
| other => true
| Expr.proj i x => match m.find? x with
| some info => info.consume
| none => true
| other => true
/- Return true iff `v` at runtime is a scalar value stored in a tagged pointer.
We do not need RC operations for this kind of value. -/
private def isScalarBoxedInTaggedPtr (v : Expr) : Bool :=
match v with
| Expr.ctor c ys => c.size == 0 && c.ssize == 0 && c.usize == 0
| Expr.lit (LitVal.num n) => n ≤ maxSmallNat
| _ => false
match v with
| Expr.ctor c ys => c.size == 0 && c.ssize == 0 && c.usize == 0
| Expr.lit (LitVal.num n) => n ≤ maxSmallNat
| _ => false
private def updateVarInfo (ctx : Context) (x : VarId) (t : IRType) (v : Expr) : Context :=
{ ctx with
varMap := ctx.varMap.insert x {
ref := t.isObj && !isScalarBoxedInTaggedPtr v,
persistent := isPersistent v,
consume := consumeExpr ctx.varMap v } }
{ ctx with
varMap := ctx.varMap.insert x {
ref := t.isObj && !isScalarBoxedInTaggedPtr v,
persistent := isPersistent v,
consume := consumeExpr ctx.varMap v
}
}
private def addDecIfNeeded (ctx : Context) (x : VarId) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
if mustConsume ctx x && !bLiveVars.contains x then addDec ctx x b else b
if mustConsume ctx x && !bLiveVars.contains x then addDec ctx x b else b
private def processVDecl (ctx : Context) (z : VarId) (t : IRType) (v : Expr) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody × LiveVarSet :=
-- dbgTrace ("processVDecl " ++ toString z ++ " " ++ toString (format v)) $ fun _ =>
let b := match v with
| (Expr.ctor _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
| (Expr.reuse _ _ _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
| (Expr.proj _ x) =>
let b := addDecIfNeeded ctx x b bLiveVars
let b := if (getVarInfo ctx x).consume then addInc ctx z b else b
(FnBody.vdecl z t v b)
| (Expr.uproj _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
| (Expr.sproj _ _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
| (Expr.fap f ys) =>
-- dbgTrace ("processVDecl " ++ toString v) $ fun _ =>
let ps := (getDecl ctx f).params
let b := addDecAfterFullApp ctx ys ps b bLiveVars
let b := FnBody.vdecl z t v b
addIncBefore ctx ys ps b bLiveVars
| (Expr.pap _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
| (Expr.ap x ys) =>
let ysx := ys.push (Arg.var x) -- TODO: avoid temporary array allocation
addIncBeforeConsumeAll ctx ysx (FnBody.vdecl z t v b) bLiveVars
| (Expr.unbox x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
| other => FnBody.vdecl z t v b -- Expr.reset, Expr.box, Expr.lit are handled here
let liveVars := updateLiveVars v bLiveVars
let liveVars := liveVars.erase z
(b, liveVars)
let b := match v with
| (Expr.ctor _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
| (Expr.reuse _ _ _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
| (Expr.proj _ x) =>
let b := addDecIfNeeded ctx x b bLiveVars
let b := if (getVarInfo ctx x).consume then addInc ctx z b else b
(FnBody.vdecl z t v b)
| (Expr.uproj _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
| (Expr.sproj _ _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
| (Expr.fap f ys) =>
-- dbgTrace ("processVDecl " ++ toString v) $ fun _ =>
let ps := (getDecl ctx f).params
let b := addDecAfterFullApp ctx ys ps b bLiveVars
let b := FnBody.vdecl z t v b
addIncBefore ctx ys ps b bLiveVars
| (Expr.pap _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
| (Expr.ap x ys) =>
let ysx := ys.push (Arg.var x) -- TODO: avoid temporary array allocation
addIncBeforeConsumeAll ctx ysx (FnBody.vdecl z t v b) bLiveVars
| (Expr.unbox x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
| other => FnBody.vdecl z t v b -- Expr.reset, Expr.box, Expr.lit are handled here
let liveVars := updateLiveVars v bLiveVars
let liveVars := liveVars.erase z
(b, liveVars)
def updateVarInfoWithParams (ctx : Context) (ps : Array Param) : Context :=
let m := ps.foldl (init := ctx.varMap) fun m p =>
m.insert p.x { ref := p.ty.isObj, consume := !p.borrow }
{ ctx with varMap := m }
let m := ps.foldl (init := ctx.varMap) fun m p =>
m.insert p.x { ref := p.ty.isObj, consume := !p.borrow }
{ ctx with varMap := m }
partial def visitFnBody : FnBody → Context → (FnBody × LiveVarSet)
| FnBody.vdecl x t v b, ctx =>
let ctx := updateVarInfo ctx x t v
let (b, bLiveVars) := visitFnBody b ctx
processVDecl ctx x t v b bLiveVars
| FnBody.jdecl j xs v b, ctx =>
let (v, vLiveVars) := visitFnBody v (updateVarInfoWithParams ctx xs)
let v := addDecForDeadParams ctx xs v vLiveVars
let ctx := { ctx with jpLiveVarMap := updateJPLiveVarMap j xs v ctx.jpLiveVarMap }
let (b, bLiveVars) := visitFnBody b ctx
(FnBody.jdecl j xs v b, bLiveVars)
| FnBody.uset x i y b, ctx =>
let (b, s) := visitFnBody b ctx
-- We don't need to insert `y` since we only need to track live variables that are references at runtime
let s := s.insert x
(FnBody.uset x i y b, s)
| FnBody.sset x i o y t b, ctx =>
let (b, s) := visitFnBody b ctx
-- We don't need to insert `y` since we only need to track live variables that are references at runtime
let s := s.insert x
(FnBody.sset x i o y t b, s)
| FnBody.mdata m b, ctx =>
let (b, s) := visitFnBody b ctx
(FnBody.mdata m b, s)
| b@(FnBody.case tid x xType alts), ctx =>
let caseLiveVars := collectLiveVars b ctx.jpLiveVarMap
let alts := alts.map $ fun alt => match alt with
| Alt.ctor c b =>
let ctx := updateRefUsingCtorInfo ctx x c
let (b, altLiveVars) := visitFnBody b ctx
let b := addDecForAlt ctx caseLiveVars altLiveVars b
Alt.ctor c b
| Alt.default b =>
let (b, altLiveVars) := visitFnBody b ctx
let b := addDecForAlt ctx caseLiveVars altLiveVars b
Alt.default b
(FnBody.case tid x xType alts, caseLiveVars)
| b@(FnBody.ret x), ctx =>
match x with
| Arg.var x =>
let info := getVarInfo ctx x
if info.ref && !info.consume then (addInc ctx x b, mkLiveVarSet x) else (b, mkLiveVarSet x)
| _ => (b, {})
| b@(FnBody.jmp j xs), ctx =>
let jLiveVars := getJPLiveVars ctx j
let ps := getJPParams ctx j
let b := addIncBefore ctx xs ps b jLiveVars
let bLiveVars := collectLiveVars b ctx.jpLiveVarMap
(b, bLiveVars)
| FnBody.unreachable, _ => (FnBody.unreachable, {})
| other, ctx => (other, {}) -- unreachable if well-formed
| FnBody.vdecl x t v b, ctx =>
let ctx := updateVarInfo ctx x t v
let (b, bLiveVars) := visitFnBody b ctx
processVDecl ctx x t v b bLiveVars
| FnBody.jdecl j xs v b, ctx =>
let (v, vLiveVars) := visitFnBody v (updateVarInfoWithParams ctx xs)
let v := addDecForDeadParams ctx xs v vLiveVars
let ctx := { ctx with jpLiveVarMap := updateJPLiveVarMap j xs v ctx.jpLiveVarMap }
let (b, bLiveVars) := visitFnBody b ctx
(FnBody.jdecl j xs v b, bLiveVars)
| FnBody.uset x i y b, ctx =>
let (b, s) := visitFnBody b ctx
-- We don't need to insert `y` since we only need to track live variables that are references at runtime
let s := s.insert x
(FnBody.uset x i y b, s)
| FnBody.sset x i o y t b, ctx =>
let (b, s) := visitFnBody b ctx
-- We don't need to insert `y` since we only need to track live variables that are references at runtime
let s := s.insert x
(FnBody.sset x i o y t b, s)
| FnBody.mdata m b, ctx =>
let (b, s) := visitFnBody b ctx
(FnBody.mdata m b, s)
| b@(FnBody.case tid x xType alts), ctx =>
let caseLiveVars := collectLiveVars b ctx.jpLiveVarMap
let alts := alts.map $ fun alt => match alt with
| Alt.ctor c b =>
let ctx := updateRefUsingCtorInfo ctx x c
let (b, altLiveVars) := visitFnBody b ctx
let b := addDecForAlt ctx caseLiveVars altLiveVars b
Alt.ctor c b
| Alt.default b =>
let (b, altLiveVars) := visitFnBody b ctx
let b := addDecForAlt ctx caseLiveVars altLiveVars b
Alt.default b
(FnBody.case tid x xType alts, caseLiveVars)
| b@(FnBody.ret x), ctx =>
match x with
| Arg.var x =>
let info := getVarInfo ctx x
if info.ref && !info.consume then (addInc ctx x b, mkLiveVarSet x) else (b, mkLiveVarSet x)
| _ => (b, {})
| b@(FnBody.jmp j xs), ctx =>
let jLiveVars := getJPLiveVars ctx j
let ps := getJPParams ctx j
let b := addIncBefore ctx xs ps b jLiveVars
let bLiveVars := collectLiveVars b ctx.jpLiveVarMap
(b, bLiveVars)
| FnBody.unreachable, _ => (FnBody.unreachable, {})
| other, ctx => (other, {}) -- unreachable if well-formed
partial def visitDecl (env : Environment) (decls : Array Decl) : Decl → Decl
| Decl.fdecl f xs t b =>
let ctx : Context := { env := env, decls := decls }
let ctx := updateVarInfoWithParams ctx xs
let (b, bLiveVars) := visitFnBody b ctx
let b := addDecForDeadParams ctx xs b bLiveVars
Decl.fdecl f xs t b
| other => other
| Decl.fdecl f xs t b =>
let ctx : Context := { env := env, decls := decls }
let ctx := updateVarInfoWithParams ctx xs
let (b, bLiveVars) := visitFnBody b ctx
let b := addDecForDeadParams ctx xs b bLiveVars
Decl.fdecl f xs t b
| other => other
end ExplicitRC
def explicitRC (decls : Array Decl) : CompilerM (Array Decl) := do
let env ← getEnv
pure $ decls.map (ExplicitRC.visitDecl env decls)
let env ← getEnv
pure $ decls.map (ExplicitRC.visitDecl env decls)
end Lean.IR

View file

@ -26,56 +26,56 @@ namespace Lean.IR.ResetReuse
-/
private def mayReuse (c₁ c₂ : CtorInfo) : Bool :=
c₁.size == c₂.size && c₁.usize == c₂.usize && c₁.ssize == c₂.ssize &&
/- The following condition is a heuristic.
We don't want to reuse cells from different types even when they are compatible
because it produces counterintuitive behavior. -/
c₁.name.getPrefix == c₂.name.getPrefix
c₁.size == c₂.size && c₁.usize == c₂.usize && c₁.ssize == c₂.ssize &&
/- The following condition is a heuristic.
We don't want to reuse cells from different types even when they are compatible
because it produces counterintuitive behavior. -/
c₁.name.getPrefix == c₂.name.getPrefix
private partial def S (w : VarId) (c : CtorInfo) : FnBody → FnBody
| FnBody.vdecl x t v@(Expr.ctor c' ys) b =>
if mayReuse c c' then
let updtCidx := c.cidx != c'.cidx
FnBody.vdecl x t (Expr.reuse w c' updtCidx ys) b
else
FnBody.vdecl x t v (S w c b)
| FnBody.jdecl j ys v b =>
let v' := S w c v
if v == v' then FnBody.jdecl j ys v (S w c b)
else FnBody.jdecl j ys v' b
| FnBody.case tid x xType alts => FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (S w c)
| b =>
if b.isTerminal then b
else let
(instr, b) := b.split
instr.setBody (S w c b)
| FnBody.vdecl x t v@(Expr.ctor c' ys) b =>
if mayReuse c c' then
let updtCidx := c.cidx != c'.cidx
FnBody.vdecl x t (Expr.reuse w c' updtCidx ys) b
else
FnBody.vdecl x t v (S w c b)
| FnBody.jdecl j ys v b =>
let v' := S w c v
if v == v' then FnBody.jdecl j ys v (S w c b)
else FnBody.jdecl j ys v' b
| FnBody.case tid x xType alts => FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (S w c)
| b =>
if b.isTerminal then b
else let
(instr, b) := b.split
instr.setBody (S w c b)
/- We use `Context` to track join points in scope. -/
abbrev M := ReaderT LocalContext (StateT Index Id)
private def mkFresh : M VarId := do
let idx ← getModify (fun n => n + 1)
pure { idx := idx }
let idx ← getModify (fun n => n + 1)
pure { idx := idx }
private def tryS (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody := do
let w ← mkFresh
let b' := S w c b
if b == b' then pure b
else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b'
let w ← mkFresh
let b' := S w c b
if b == b' then pure b
else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b'
private def Dfinalize (x : VarId) (c : CtorInfo) : FnBody × Bool → M FnBody
| (b, true) => pure b
| (b, false) => tryS x c b
| (b, true) => pure b
| (b, false) => tryS x c b
private def argsContainsVar (ys : Array Arg) (x : VarId) : Bool :=
ys.any fun arg => match arg with
| Arg.var y => x == y
| _ => false
ys.any fun arg => match arg with
| Arg.var y => x == y
| _ => false
private def isCtorUsing (b : FnBody) (x : VarId) : Bool :=
match b with
| (FnBody.vdecl _ _ (Expr.ctor _ ys) _) => argsContainsVar ys x
| _ => false
match b with
| (FnBody.vdecl _ _ (Expr.ctor _ ys) _) => argsContainsVar ys x
| _ => false
/- Given `Dmain b`, the resulting pair `(new_b, flag)` contains the new body `new_b`,
and `flag == true` if `x` is live in `b`.
@ -84,75 +84,75 @@ match b with
`D` checks whether `x` is live in `F` or not. This is great for clarity but it
is expensive: `O(n^2)` where `n` is the size of the function body. -/
private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody → M (FnBody × Bool)
| e@(FnBody.case tid y yType alts) => do
let ctx ← read
if e.hasLiveVar ctx x then do
/- If `x` is live in `e`, we recursively process each branch. -/
let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => Dmain x c b >>= Dfinalize x c
pure (FnBody.case tid y yType alts, true)
else pure (e, false)
| FnBody.jdecl j ys v b => do
let (b, found) ← withReader (fun ctx => ctx.addJP j ys v) (Dmain x c b)
let (v, _ /- found' -/) ← Dmain x c v
/- If `found' == true`, then `Dmain b` must also have returned `(b, true)` since
we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`),
then it must also live in `b` since `j` is reachable from `b` with a `jmp`.
On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/
pure (FnBody.jdecl j ys v b, found)
| e => do
let ctx ← read
if e.isTerminal then
pure (e, e.hasLiveVar ctx x)
else do
let (instr, b) := e.split
if isCtorUsing instr x then
/- If the scrutinee `x` (the one that is providing memory) is being
stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
It may work only if the new cell is consumed, but we ignore this case. -/
pure (e, true)
| e@(FnBody.case tid y yType alts) => do
let ctx ← read
if e.hasLiveVar ctx x then do
/- If `x` is live in `e`, we recursively process each branch. -/
let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => Dmain x c b >>= Dfinalize x c
pure (FnBody.case tid y yType alts, true)
else pure (e, false)
| FnBody.jdecl j ys v b => do
let (b, found) ← withReader (fun ctx => ctx.addJP j ys v) (Dmain x c b)
let (v, _ /- found' -/) ← Dmain x c v
/- If `found' == true`, then `Dmain b` must also have returned `(b, true)` since
we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`),
then it must also live in `b` since `j` is reachable from `b` with a `jmp`.
On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/
pure (FnBody.jdecl j ys v b, found)
| e => do
let ctx ← read
if e.isTerminal then
pure (e, e.hasLiveVar ctx x)
else do
let (b, found) ← Dmain x c b
/- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar`
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/
if found || !instr.hasFreeVar x then
pure (instr.setBody b, found)
else do
let b ← tryS x c b
pure (instr.setBody b, true)
let (instr, b) := e.split
if isCtorUsing instr x then
/- If the scrutinee `x` (the one that is providing memory) is being
stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
It may work only if the new cell is consumed, but we ignore this case. -/
pure (e, true)
else
let (b, found) ← Dmain x c b
/- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar`
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/
if found || !instr.hasFreeVar x then
pure (instr.setBody b, found)
else
let b ← tryS x c b
pure (instr.setBody b, true)
private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
Dmain x c b >>= Dfinalize x c
Dmain x c b >>= Dfinalize x c
partial def R : FnBody → M FnBody
| FnBody.case tid x xType alts => do
let alts ← alts.mapM fun alt => do
let alt ← alt.mmodifyBody R
match alt with
| Alt.ctor c b =>
if c.isScalar then pure alt
else Alt.ctor c <$> D x c b
| _ => pure alt
pure $ FnBody.case tid x xType alts
| FnBody.jdecl j ys v b => do
let v ← R v
let b ← withReader (fun ctx => ctx.addJP j ys v) (R b)
pure $ FnBody.jdecl j ys v b
| e => do
if e.isTerminal then pure e
else do
let (instr, b) := e.split
let b ← R b
pure (instr.setBody b)
| FnBody.case tid x xType alts => do
let alts ← alts.mapM fun alt => do
let alt ← alt.mmodifyBody R
match alt with
| Alt.ctor c b =>
if c.isScalar then pure alt
else Alt.ctor c <$> D x c b
| _ => pure alt
pure $ FnBody.case tid x xType alts
| FnBody.jdecl j ys v b => do
let v ← R v
let b ← withReader (fun ctx => ctx.addJP j ys v) (R b)
pure $ FnBody.jdecl j ys v b
| e => do
if e.isTerminal then pure e
else do
let (instr, b) := e.split
let b ← R b
pure (instr.setBody b)
end ResetReuse
open ResetReuse
def Decl.insertResetReuse : Decl → Decl
| d@(Decl.fdecl f xs t b) =>
let nextIndex := d.maxIndex + 1
let b := (R b {}).run' nextIndex
Decl.fdecl f xs t b
| other => other
| d@(Decl.fdecl f xs t b) =>
let nextIndex := d.maxIndex + 1
let b := (R b {}).run' nextIndex
Decl.fdecl f xs t b
| other => other
end Lean.IR

View file

@ -9,52 +9,51 @@ import Lean.Compiler.IR.Format
namespace Lean.IR
def ensureHasDefault (alts : Array Alt) : Array Alt :=
if alts.any Alt.isDefault then alts
else if alts.size < 2 then alts
else
let last := alts.back;
let alts := alts.pop;
alts.push (Alt.default last.body)
if alts.any Alt.isDefault then alts
else if alts.size < 2 then alts
else
let last := alts.back;
let alts := alts.pop;
alts.push (Alt.default last.body)
private def getOccsOf (alts : Array Alt) (i : Nat) : Nat := do
let aBody := (alts.get! i).body
let n := 1
for j in [i+1:alts.size] do
if alts[j].body == aBody then
n := n+1
return n
let aBody := (alts.get! i).body
let n := 1
for j in [i+1:alts.size] do
if alts[j].body == aBody then
n := n+1
return n
private def maxOccs (alts : Array Alt) : Alt × Nat := do
let maxAlt := alts[0]
let max := getOccsOf alts 0
for i in [1:alts.size] do
let curr := getOccsOf alts i
if curr > max then
maxAlt := alts[i]
max := curr
return (maxAlt, max)
let maxAlt := alts[0]
let max := getOccsOf alts 0
for i in [1:alts.size] do
let curr := getOccsOf alts i
if curr > max then
maxAlt := alts[i]
max := curr
return (maxAlt, max)
private def addDefault (alts : Array Alt) : Array Alt :=
if alts.size <= 1 || alts.any Alt.isDefault then alts
else
let (max, noccs) := maxOccs alts;
if noccs == 1 then alts
if alts.size <= 1 || alts.any Alt.isDefault then alts
else
let alts := alts.filter $ (fun alt => alt.body != max.body);
alts.push (Alt.default max.body)
let (max, noccs) := maxOccs alts;
if noccs == 1 then alts
else
let alts := alts.filter $ (fun alt => alt.body != max.body);
alts.push (Alt.default max.body)
private def mkSimpCase (tid : Name) (x : VarId) (xType : IRType) (alts : Array Alt) : FnBody :=
let alts := alts.filter (fun alt => alt.body != FnBody.unreachable);
let alts := addDefault alts;
if alts.size == 0 then
FnBody.unreachable
else if alts.size == 1 then
(alts.get! 0).body
else
FnBody.case tid x xType alts
let alts := alts.filter (fun alt => alt.body != FnBody.unreachable);
let alts := addDefault alts;
if alts.size == 0 then
FnBody.unreachable
else if alts.size == 1 then
(alts.get! 0).body
else
FnBody.case tid x xType alts
partial def FnBody.simpCase : FnBody → FnBody
| b =>
partial def FnBody.simpCase (b : FnBody) : FnBody :=
let (bs, term) := b.flatten;
let bs := modifyJPs bs simpCase;
match term with
@ -68,7 +67,7 @@ partial def FnBody.simpCase : FnBody → FnBody
- Remove `case` if there is only one branch.
- Merge most common branches using `Alt.default`. -/
def Decl.simpCase : Decl → Decl
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.simpCase
| other => other
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.simpCase
| other => other
end Lean.IR

View file

@ -824,13 +824,13 @@ This is a standard trick we use in the elaborator, and it is also used to elabor
Suppose, we are trying to elaborate
```
match g x with
| ... => ...
| ... => ...
```
`expandNonAtomicDiscrs?` converts it intro
```
let _discr := g x
match _discr with
| ... => ...
| ... => ...
```
Thus, at `tryPostponeIfDiscrTypeIsMVar` we only need to check whether the type of `_discr` is not of the form `(?m ...)`.
Note that, the auxiliary variable `_discr` is expanded at `elabAtomicDiscr`.

View file

@ -108,17 +108,17 @@ syntax "throwError! " ((interpolatedStr term) <|> term) : term
syntax "throwErrorAt! " term:max ((interpolatedStr term) <|> term) : term
macro_rules
| `(throwError! $msg) =>
if msg.getKind == interpolatedStrKind then
`(throwError (msg! $msg))
else
`(throwError $msg)
| `(throwError! $msg) =>
if msg.getKind == interpolatedStrKind then
`(throwError (msg! $msg))
else
`(throwError $msg)
macro_rules
| `(throwErrorAt! $ref $msg) =>
if msg.getKind == interpolatedStrKind then
`(throwErrorAt $ref (msg! $msg))
else
`(throwErrorAt $ref $msg)
| `(throwErrorAt! $ref $msg) =>
if msg.getKind == interpolatedStrKind then
`(throwErrorAt $ref (msg! $msg))
else
`(throwErrorAt $ref $msg)
end Lean

View file

@ -211,12 +211,12 @@ def occurs : Level → Level → Bool
| u, v => u == v
def ctorToNat : Level → Nat
| zero .. => 0
| param .. => 1
| mvar .. => 2
| succ .. => 3
| max .. => 4
| imax .. => 5
| zero .. => 0
| param .. => 1
| mvar .. => 2
| succ .. => 3
| max .. => 4
| imax .. => 5
/- TODO: use well founded recursion. -/
partial def normLtAux : Level → Nat → Level → Nat → Bool

View file

@ -111,15 +111,15 @@ def replaceFVarId (fvarId : FVarId) (v : Expr) (alt : Alt) : Alt :=
```
inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons {n} (head : α) (tail : Vec α n) : Vec α (n+1)
| nil : Vec α 0
| cons {n} (head : α) (tail : Vec α n) : Vec α (n+1)
inductive VecPred {α : Type u} (P : α → Prop) : {n : Nat} → Vec α n → Prop
| nil : VecPred P Vec.nil
| cons {n : Nat} {head : α} {tail : Vec α n} : P head → VecPred P tail → VecPred P (Vec.cons head tail)
| nil : VecPred P Vec.nil
| cons {n : Nat} {head : α} {tail : Vec α n} : P head → VecPred P tail → VecPred P (Vec.cons head tail)
theorem ex {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P
| _, Vec.cons head _, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩
| _, Vec.cons head _, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩
```
Recall that `_` in a pattern can be elaborated into pattern variable or an inaccessible term.
The elaborator uses an inaccessible term when typing constraints restrict its value.
@ -127,7 +127,7 @@ Thus, in the example above, the `_` at `Vec.cons head _` becomes the inaccessibl
because the type ascription `(w : VecPred P Vec.nil)` propagates typing constraints that restrict its value to be `Vec.nil`.
After elaboration the alternative becomes:
```
| .(0), @Vec.cons .(α) .(0) head .(Vec.nil), @VecPred.cons .(α) .(P) .(0) .(head) .(Vec.nil) h w => ⟨head, h⟩
| .(0), @Vec.cons .(α) .(0) head .(Vec.nil), @VecPred.cons .(α) .(P) .(0) .(head) .(Vec.nil) h w => ⟨head, h⟩
```
where
```
@ -138,7 +138,7 @@ Then, when we process this alternative in this module, the following check will
Note that if we had written
```
theorem ex {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P
| _, Vec.cons head Vec.nil, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩
| _, Vec.cons head Vec.nil, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩
```
we would get the easier to digest error message
```

View file

@ -19,7 +19,7 @@ builtin_initialize builtinTokenTable : IO.Ref TokenTable ← IO.mkRef {}
builtin_initialize builtinSyntaxNodeKindSetRef : IO.Ref SyntaxNodeKindSet ← IO.mkRef {}
def registerBuiltinNodeKind (k : SyntaxNodeKind) : IO Unit :=
builtinSyntaxNodeKindSetRef.modify fun s => s.insert k
builtinSyntaxNodeKindSetRef.modify fun s => s.insert k
builtin_initialize
registerBuiltinNodeKind choiceKind
@ -32,215 +32,213 @@ builtin_initialize
builtin_initialize builtinParserCategoriesRef : IO.Ref ParserCategories ← IO.mkRef {}
private def throwParserCategoryAlreadyDefined {α} (catName : Name) : ExceptT String Id α :=
throw s!"parser category '{catName}' has already been defined"
throw s!"parser category '{catName}' has already been defined"
private def addParserCategoryCore (categories : ParserCategories) (catName : Name) (initial : ParserCategory) : Except String ParserCategories :=
if categories.contains catName then
throwParserCategoryAlreadyDefined catName
else
pure $ categories.insert catName initial
if categories.contains catName then
throwParserCategoryAlreadyDefined catName
else
pure $ categories.insert catName initial
/-- All builtin parser categories are Pratt's parsers -/
private def addBuiltinParserCategory (catName : Name) (leadingIdentAsSymbol : Bool) : IO Unit := do
let categories ← builtinParserCategoriesRef.get
let categories ← IO.ofExcept $ addParserCategoryCore categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol}
builtinParserCategoriesRef.set categories
let categories ← builtinParserCategoriesRef.get
let categories ← IO.ofExcept $ addParserCategoryCore categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol}
builtinParserCategoriesRef.set categories
inductive ParserExtensionOleanEntry
| token (val : Token) : ParserExtensionOleanEntry
| kind (val : SyntaxNodeKind) : ParserExtensionOleanEntry
| category (catName : Name) (leadingIdentAsSymbol : Bool)
| parser (catName : Name) (declName : Name) (prio : Nat) : ParserExtensionOleanEntry
| token (val : Token) : ParserExtensionOleanEntry
| kind (val : SyntaxNodeKind) : ParserExtensionOleanEntry
| category (catName : Name) (leadingIdentAsSymbol : Bool)
| parser (catName : Name) (declName : Name) (prio : Nat) : ParserExtensionOleanEntry
inductive ParserExtensionEntry
| token (val : Token) : ParserExtensionEntry
| kind (val : SyntaxNodeKind) : ParserExtensionEntry
| category (catName : Name) (leadingIdentAsSymbol : Bool)
| parser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : ParserExtensionEntry
| token (val : Token) : ParserExtensionEntry
| kind (val : SyntaxNodeKind) : ParserExtensionEntry
| category (catName : Name) (leadingIdentAsSymbol : Bool)
| parser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : ParserExtensionEntry
structure ParserExtensionState :=
(tokens : TokenTable := {})
(kinds : SyntaxNodeKindSet := {})
(categories : ParserCategories := {})
(newEntries : List ParserExtensionOleanEntry := [])
(tokens : TokenTable := {})
(kinds : SyntaxNodeKindSet := {})
(categories : ParserCategories := {})
(newEntries : List ParserExtensionOleanEntry := [])
instance : Inhabited ParserExtensionState := ⟨{}⟩
abbrev ParserExtension := PersistentEnvExtension ParserExtensionOleanEntry ParserExtensionEntry ParserExtensionState
private def ParserExtension.mkInitial : IO ParserExtensionState := do
let tokens ← builtinTokenTable.get
let kinds ← builtinSyntaxNodeKindSetRef.get
let categories ← builtinParserCategoriesRef.get
pure { tokens := tokens, kinds := kinds, categories := categories }
let tokens ← builtinTokenTable.get
let kinds ← builtinSyntaxNodeKindSetRef.get
let categories ← builtinParserCategoriesRef.get
pure { tokens := tokens, kinds := kinds, categories := categories }
private def addTokenConfig (tokens : TokenTable) (tk : Token) : Except String TokenTable := do
if tk == "" then throw "invalid empty symbol"
else match tokens.find? tk with
| none => pure $ tokens.insert tk tk
| some _ => pure tokens
if tk == "" then throw "invalid empty symbol"
else match tokens.find? tk with
| none => pure $ tokens.insert tk tk
| some _ => pure tokens
def throwUnknownParserCategory {α} (catName : Name) : ExceptT String Id α :=
throw s!"unknown parser category '{catName}'"
throw s!"unknown parser category '{catName}'"
abbrev getCategory (categories : ParserCategories) (catName : Name) : Option ParserCategory :=
categories.find? catName
categories.find? catName
def addLeadingParser (categories : ParserCategories) (catName : Name) (parserName : Name) (p : Parser) (prio : Nat) : Except String ParserCategories :=
match getCategory categories catName with
| none =>
throwUnknownParserCategory catName
| some cat =>
let addTokens (tks : List Token) : Except String ParserCategories :=
let tks := tks.map $ fun tk => mkNameSimple tk
let tables := tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with leadingTable := tables.leadingTable.insert tk (p, prio) }) cat.tables
pure $ categories.insert catName { cat with tables := tables }
match getCategory categories catName with
| none =>
throwUnknownParserCategory catName
| some cat =>
let addTokens (tks : List Token) : Except String ParserCategories :=
let tks := tks.map $ fun tk => mkNameSimple tk
let tables := tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with leadingTable := tables.leadingTable.insert tk (p, prio) }) cat.tables
pure $ categories.insert catName { cat with tables := tables }
match p.info.firstTokens with
| FirstTokens.tokens tks => addTokens tks
| FirstTokens.optTokens tks => addTokens tks
| _ =>
let tables := { cat.tables with leadingParsers := (p, prio) :: cat.tables.leadingParsers }
pure $ categories.insert catName { cat with tables := tables }
private def addTrailingParserAux (tables : PrattParsingTables) (p : TrailingParser) (prio : Nat) : PrattParsingTables :=
let addTokens (tks : List Token) : PrattParsingTables :=
let tks := tks.map fun tk => mkNameSimple tk
tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with trailingTable := tables.trailingTable.insert tk (p, prio) }) tables
match p.info.firstTokens with
| FirstTokens.tokens tks => addTokens tks
| FirstTokens.optTokens tks => addTokens tks
| _ =>
let tables := { cat.tables with leadingParsers := (p, prio) :: cat.tables.leadingParsers }
pure $ categories.insert catName { cat with tables := tables }
private def addTrailingParserAux (tables : PrattParsingTables) (p : TrailingParser) (prio : Nat) : PrattParsingTables :=
let addTokens (tks : List Token) : PrattParsingTables :=
let tks := tks.map fun tk => mkNameSimple tk
tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with trailingTable := tables.trailingTable.insert tk (p, prio) }) tables
match p.info.firstTokens with
| FirstTokens.tokens tks => addTokens tks
| FirstTokens.optTokens tks => addTokens tks
| _ => { tables with trailingParsers := (p, prio) :: tables.trailingParsers }
| _ => { tables with trailingParsers := (p, prio) :: tables.trailingParsers }
def addTrailingParser (categories : ParserCategories) (catName : Name) (p : TrailingParser) (prio : Nat) : Except String ParserCategories :=
match getCategory categories catName with
| none => throwUnknownParserCategory catName
| some cat => pure $ categories.insert catName { cat with tables := addTrailingParserAux cat.tables p prio }
match getCategory categories catName with
| none => throwUnknownParserCategory catName
| some cat => pure $ categories.insert catName { cat with tables := addTrailingParserAux cat.tables p prio }
def addParser (categories : ParserCategories) (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : Except String ParserCategories :=
match leading, p with
| true, p => addLeadingParser categories catName declName p prio
| false, p => addTrailingParser categories catName p prio
match leading, p with
| true, p => addLeadingParser categories catName declName p prio
| false, p => addTrailingParser categories catName p prio
def addParserTokens (tokenTable : TokenTable) (info : ParserInfo) : Except String TokenTable :=
let newTokens := info.collectTokens []
newTokens.foldlM addTokenConfig tokenTable
let newTokens := info.collectTokens []
newTokens.foldlM addTokenConfig tokenTable
private def updateBuiltinTokens (info : ParserInfo) (declName : Name) : IO Unit := do
let tokenTable ← builtinTokenTable.swap {}
match addParserTokens tokenTable info with
| Except.ok tokenTable => builtinTokenTable.set tokenTable
| Except.error msg => throw (IO.userError s!"invalid builtin parser '{declName}', {msg}")
let tokenTable ← builtinTokenTable.swap {}
match addParserTokens tokenTable info with
| Except.ok tokenTable => builtinTokenTable.set tokenTable
| Except.error msg => throw (IO.userError s!"invalid builtin parser '{declName}', {msg}")
def addBuiltinParser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : IO Unit := do
let categories ← builtinParserCategoriesRef.get
let categories ← IO.ofExcept $ addParser categories catName declName leading p prio
builtinParserCategoriesRef.set categories
builtinSyntaxNodeKindSetRef.modify p.info.collectKinds
updateBuiltinTokens p.info declName
let categories ← builtinParserCategoriesRef.get
let categories ← IO.ofExcept $ addParser categories catName declName leading p prio
builtinParserCategoriesRef.set categories
builtinSyntaxNodeKindSetRef.modify p.info.collectKinds
updateBuiltinTokens p.info declName
def addBuiltinLeadingParser (catName : Name) (declName : Name) (p : Parser) (prio : Nat) : IO Unit :=
addBuiltinParser catName declName true p prio
addBuiltinParser catName declName true p prio
def addBuiltinTrailingParser (catName : Name) (declName : Name) (p : TrailingParser) (prio : Nat) : IO Unit :=
addBuiltinParser catName declName false p prio
addBuiltinParser catName declName false p prio
private def ParserExtensionAddEntry (s : ParserExtensionState) (e : ParserExtensionEntry) : ParserExtensionState :=
match e with
| ParserExtensionEntry.token tk =>
match addTokenConfig s.tokens tk with
| Except.ok tokens => { s with tokens := tokens, newEntries := ParserExtensionOleanEntry.token tk :: s.newEntries }
| _ => unreachable!
| ParserExtensionEntry.kind k =>
{ s with kinds := s.kinds.insert k, newEntries := ParserExtensionOleanEntry.kind k :: s.newEntries }
| ParserExtensionEntry.category catName leadingIdentAsSymbol =>
if s.categories.contains catName then s
else { s with
categories := s.categories.insert catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol },
newEntries := ParserExtensionOleanEntry.category catName leadingIdentAsSymbol :: s.newEntries }
| ParserExtensionEntry.parser catName declName leading parser prio =>
match addParser s.categories catName declName leading parser prio with
| Except.ok categories => { s with categories := categories, newEntries := ParserExtensionOleanEntry.parser catName declName prio :: s.newEntries }
| _ => unreachable!
match e with
| ParserExtensionEntry.token tk =>
match addTokenConfig s.tokens tk with
| Except.ok tokens => { s with tokens := tokens, newEntries := ParserExtensionOleanEntry.token tk :: s.newEntries }
| _ => unreachable!
| ParserExtensionEntry.kind k =>
{ s with kinds := s.kinds.insert k, newEntries := ParserExtensionOleanEntry.kind k :: s.newEntries }
| ParserExtensionEntry.category catName leadingIdentAsSymbol =>
if s.categories.contains catName then s
else { s with
categories := s.categories.insert catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol },
newEntries := ParserExtensionOleanEntry.category catName leadingIdentAsSymbol :: s.newEntries }
| ParserExtensionEntry.parser catName declName leading parser prio =>
match addParser s.categories catName declName leading parser prio with
| Except.ok categories => { s with categories := categories, newEntries := ParserExtensionOleanEntry.parser catName declName prio :: s.newEntries }
| _ => unreachable!
unsafe def mkParserOfConstantUnsafe
(env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name)
(compileParserDescr : ParserDescr → Except String Parser)
: Except String (Bool × Parser) :=
match env.find? constName with
| none => throw s!"unknow constant '{constName}'"
| some info =>
match info.type with
| Expr.const `Lean.Parser.TrailingParser _ _ => do
let p ← env.evalConst Parser opts constName
pure ⟨false, p⟩
| Expr.const `Lean.Parser.Parser _ _ => do
let p ← env.evalConst Parser opts constName
pure ⟨true, p⟩
| Expr.const `Lean.ParserDescr _ _ => do
let d ← env.evalConst ParserDescr opts constName
let p ← compileParserDescr d
pure ⟨true, p⟩
| Expr.const `Lean.TrailingParserDescr _ _ => do
let d ← env.evalConst TrailingParserDescr opts constName
let p ← compileParserDescr d
pure ⟨false, p⟩
| _ => throw s!"unexpected parser type at '{constName}' (`ParserDescr`, `TrailingParserDescr`, `Parser` or `TrailingParser` expected"
match env.find? constName with
| none => throw s!"unknow constant '{constName}'"
| some info =>
match info.type with
| Expr.const `Lean.Parser.TrailingParser _ _ => do
let p ← env.evalConst Parser opts constName
pure ⟨false, p⟩
| Expr.const `Lean.Parser.Parser _ _ => do
let p ← env.evalConst Parser opts constName
pure ⟨true, p⟩
| Expr.const `Lean.ParserDescr _ _ => do
let d ← env.evalConst ParserDescr opts constName
let p ← compileParserDescr d
pure ⟨true, p⟩
| Expr.const `Lean.TrailingParserDescr _ _ => do
let d ← env.evalConst TrailingParserDescr opts constName
let p ← compileParserDescr d
pure ⟨false, p⟩
| _ => throw s!"unexpected parser type at '{constName}' (`ParserDescr`, `TrailingParserDescr`, `Parser` or `TrailingParser` expected"
@[implementedBy mkParserOfConstantUnsafe]
constant mkParserOfConstantAux
(env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name)
(compileParserDescr : ParserDescr → Except String Parser)
: Except String (Bool × Parser) :=
arbitrary _
: Except String (Bool × Parser)
partial def compileParserDescr (env : Environment) (opts : Options) (categories : ParserCategories) (d : ParserDescr) : Except String Parser :=
let rec visit : ParserDescr → Except String Parser
| ParserDescr.andthen d₁ d₂ => andthen <$> visit d₁ <*> visit d₂
| ParserDescr.orelse d₁ d₂ => orelse <$> visit d₁ <*> visit d₂
| ParserDescr.optional d => optional <$> visit d
| ParserDescr.lookahead d => lookahead <$> visit d
| ParserDescr.«try» d => «try» <$> visit d
| ParserDescr.notFollowedBy d => do let p ← visit d; pure $ notFollowedBy p "element" -- TODO allow user to set msg at ParserDescr
| ParserDescr.many d => many <$> visit d
| ParserDescr.many1 d => many1 <$> visit d
| ParserDescr.sepBy d₁ d₂ => sepBy <$> visit d₁ <*> visit d₂
| ParserDescr.sepBy1 d₁ d₂ => sepBy1 <$> visit d₁ <*> visit d₂
| ParserDescr.node k prec d => leadingNode k prec <$> visit d
| ParserDescr.trailingNode k prec d => trailingNode k prec <$> visit d
| ParserDescr.symbol tk => pure $ symbol tk
| ParserDescr.noWs => pure $ checkNoWsBefore
| ParserDescr.numLit => pure $ numLit
| ParserDescr.strLit => pure $ strLit
| ParserDescr.charLit => pure $ charLit
| ParserDescr.nameLit => pure $ nameLit
| ParserDescr.interpolatedStr d => interpolatedStr <$> visit d
| ParserDescr.ident => pure $ ident
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol tk includeIdent
| ParserDescr.parser constName => do
let (_, p) ← mkParserOfConstantAux env opts categories constName visit;
pure p
| ParserDescr.cat catName prec =>
match getCategory categories catName with
| some _ => pure $ categoryParser catName prec
| none => throwUnknownParserCategory catName
visit d
let rec visit : ParserDescr → Except String Parser
| ParserDescr.andthen d₁ d₂ => andthen <$> visit d₁ <*> visit d₂
| ParserDescr.orelse d₁ d₂ => orelse <$> visit d₁ <*> visit d₂
| ParserDescr.optional d => optional <$> visit d
| ParserDescr.lookahead d => lookahead <$> visit d
| ParserDescr.«try» d => «try» <$> visit d
| ParserDescr.notFollowedBy d => do let p ← visit d; pure $ notFollowedBy p "element" -- TODO allow user to set msg at ParserDescr
| ParserDescr.many d => many <$> visit d
| ParserDescr.many1 d => many1 <$> visit d
| ParserDescr.sepBy d₁ d₂ => sepBy <$> visit d₁ <*> visit d₂
| ParserDescr.sepBy1 d₁ d₂ => sepBy1 <$> visit d₁ <*> visit d₂
| ParserDescr.node k prec d => leadingNode k prec <$> visit d
| ParserDescr.trailingNode k prec d => trailingNode k prec <$> visit d
| ParserDescr.symbol tk => pure $ symbol tk
| ParserDescr.noWs => pure $ checkNoWsBefore
| ParserDescr.numLit => pure $ numLit
| ParserDescr.strLit => pure $ strLit
| ParserDescr.charLit => pure $ charLit
| ParserDescr.nameLit => pure $ nameLit
| ParserDescr.interpolatedStr d => interpolatedStr <$> visit d
| ParserDescr.ident => pure $ ident
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol tk includeIdent
| ParserDescr.parser constName => do
let (_, p) ← mkParserOfConstantAux env opts categories constName visit;
pure p
| ParserDescr.cat catName prec =>
match getCategory categories catName with
| some _ => pure $ categoryParser catName prec
| none => throwUnknownParserCategory catName
visit d
def mkParserOfConstant (env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name) : Except String (Bool × Parser) :=
mkParserOfConstantAux env opts categories constName (compileParserDescr env opts categories)
mkParserOfConstantAux env opts categories constName (compileParserDescr env opts categories)
structure ParserAttributeHook :=
/- Called after a parser attribute is applied to a declaration. -/
(postAdd (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit)
/- Called after a parser attribute is applied to a declaration. -/
(postAdd (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit)
def mkParserAttributeHooks : IO (IO.Ref (List ParserAttributeHook)) := IO.mkRef {}
@[builtinInit mkParserAttributeHooks] constant parserAttributeHooks : IO.Ref (List ParserAttributeHook) := arbitrary _
builtin_initialize parserAttributeHooks : IO.Ref (List ParserAttributeHook) ← IO.mkRef {}
def registerParserAttributeHook (hook : ParserAttributeHook) : IO Unit := do
parserAttributeHooks.modify fun hooks => hook::hooks
parserAttributeHooks.modify fun hooks => hook::hooks
def runParserAttributeHooks (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit := do
let hooks ← parserAttributeHooks.get
hooks.forM fun hook => hook.postAdd catName declName builtin
let hooks ← parserAttributeHooks.get
hooks.forM fun hook => hook.postAdd catName declName builtin
builtin_initialize
registerBuiltinAttribute {
@ -261,27 +259,23 @@ builtin_initialize
}
private def ParserExtension.addImported (es : Array (Array ParserExtensionOleanEntry)) : ImportM ParserExtensionState := do
let ctx ← read
let s ← ParserExtension.mkInitial
es.foldlM
(fun s entries =>
entries.foldlM
(fun s entry =>
match entry with
| ParserExtensionOleanEntry.token tk => do
let tokens ← IO.ofExcept (addTokenConfig s.tokens tk)
pure { s with tokens := tokens }
| ParserExtensionOleanEntry.kind k =>
pure { s with kinds := s.kinds.insert k }
| ParserExtensionOleanEntry.category catName leadingIdentAsSymbol => do
let categories ← IO.ofExcept (addParserCategoryCore s.categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol})
pure { s with categories := categories }
| ParserExtensionOleanEntry.parser catName declName prio => do
let p ← IO.ofExcept $ mkParserOfConstant ctx.env ctx.opts s.categories declName
let categories ← IO.ofExcept $ addParser s.categories catName declName p.1 p.2 prio
pure { s with categories := categories })
s)
s
let ctx ← read
let s ← ParserExtension.mkInitial
es.foldlM (init := s) fun s entries =>
entries.foldlM (init := s) fun s entry =>
match entry with
| ParserExtensionOleanEntry.token tk => do
let tokens ← IO.ofExcept (addTokenConfig s.tokens tk)
pure { s with tokens := tokens }
| ParserExtensionOleanEntry.kind k =>
pure { s with kinds := s.kinds.insert k }
| ParserExtensionOleanEntry.category catName leadingIdentAsSymbol => do
let categories ← IO.ofExcept (addParserCategoryCore s.categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol})
pure { s with categories := categories }
| ParserExtensionOleanEntry.parser catName declName prio => do
let p ← IO.ofExcept $ mkParserOfConstant ctx.env ctx.opts s.categories declName
let categories ← IO.ofExcept $ addParser s.categories catName declName p.1 p.2 prio
pure { s with categories := categories }
builtin_initialize parserExtension : ParserExtension ←
registerPersistentEnvExtension {
@ -294,31 +288,30 @@ builtin_initialize parserExtension : ParserExtension ←
}
def isParserCategory (env : Environment) (catName : Name) : Bool :=
(parserExtension.getState env).categories.contains catName
(parserExtension.getState env).categories.contains catName
def addParserCategory (env : Environment) (catName : Name) (leadingIdentAsSymbol : Bool) : Except String Environment := do
if isParserCategory env catName then
throwParserCategoryAlreadyDefined catName
else
pure $ parserExtension.addEntry env $ ParserExtensionEntry.category catName leadingIdentAsSymbol
if isParserCategory env catName then
throwParserCategoryAlreadyDefined catName
else
pure $ parserExtension.addEntry env $ ParserExtensionEntry.category catName leadingIdentAsSymbol
/-
Return true if in the given category leading identifiers in parsers may be treated as atoms/symbols.
See comment at `ParserCategory`. -/
def leadingIdentAsSymbol (env : Environment) (catName : Name) : Bool :=
match getCategory (parserExtension.getState env).categories catName with
| none => false
| some cat => cat.leadingIdentAsSymbol
match getCategory (parserExtension.getState env).categories catName with
| none => false
| some cat => cat.leadingIdentAsSymbol
def mkCategoryAntiquotParser (kind : Name) : Parser :=
mkAntiquot kind.toString none
mkAntiquot kind.toString none
-- helper decl to work around inlining issue https://github.com/leanprover/lean4/commit/3f6de2af06dd9a25f62294129f64bc05a29ea912#r41340377
@[inline] private def mkCategoryAntiquotParserFn (kind : Name) : ParserFn :=
(mkCategoryAntiquotParser kind).fn
def categoryParserFnImpl (catName : Name) : ParserFn :=
fun ctx s =>
def categoryParserFnImpl (catName : Name) : ParserFn := fun ctx s =>
let catName := if catName == `syntax then `stx else catName -- temporary Hack
let categories := (parserExtension.getState ctx.env).categories
match getCategory categories catName with
@ -327,149 +320,152 @@ fun ctx s =>
| none => s.mkUnexpectedError ("unknown parser category '" ++ toString catName ++ "'")
@[builtinInit] def setCategoryParserFnRef : IO Unit :=
categoryParserFnRef.set categoryParserFnImpl
categoryParserFnRef.set categoryParserFnImpl
def addToken (env : Environment) (tk : Token) : Except String Environment := do
-- Recall that `ParserExtension.addEntry` is pure, and assumes `addTokenConfig` does not fail.
-- So, we must run it here to handle exception.
addTokenConfig (parserExtension.getState env).tokens tk
pure $ parserExtension.addEntry env $ ParserExtensionEntry.token tk
-- Recall that `ParserExtension.addEntry` is pure, and assumes `addTokenConfig` does not fail.
-- So, we must run it here to handle exception.
addTokenConfig (parserExtension.getState env).tokens tk
pure $ parserExtension.addEntry env $ ParserExtensionEntry.token tk
def addSyntaxNodeKind (env : Environment) (k : SyntaxNodeKind) : Environment :=
parserExtension.addEntry env $ ParserExtensionEntry.kind k
parserExtension.addEntry env $ ParserExtensionEntry.kind k
def isValidSyntaxNodeKind (env : Environment) (k : SyntaxNodeKind) : Bool :=
let kinds := (parserExtension.getState env).kinds
kinds.contains k
let kinds := (parserExtension.getState env).kinds
kinds.contains k
def getSyntaxNodeKinds (env : Environment) : List SyntaxNodeKind := do
let kinds := (parserExtension.getState env).kinds
kinds.foldl (fun ks k _ => k::ks) []
let kinds := (parserExtension.getState env).kinds
kinds.foldl (fun ks k _ => k::ks) []
def getTokenTable (env : Environment) : TokenTable :=
(parserExtension.getState env).tokens
(parserExtension.getState env).tokens
def mkInputContext (input : String) (fileName : String) : InputContext :=
{ input := input,
def mkInputContext (input : String) (fileName : String) : InputContext := {
input := input,
fileName := fileName,
fileMap := input.toFileMap }
fileMap := input.toFileMap
}
def mkParserContext (env : Environment) (ctx : InputContext) : ParserContext :=
{ prec := 0,
def mkParserContext (env : Environment) (ctx : InputContext) : ParserContext := {
prec := 0,
toInputContext := ctx,
env := env,
tokens := getTokenTable env }
tokens := getTokenTable env
}
def mkParserState (input : String) : ParserState :=
{ cache := initCacheForInput input }
{ cache := initCacheForInput input }
/- convenience function for testing -/
def runParserCategory (env : Environment) (catName : Name) (input : String) (fileName := "<input>") : Except String Syntax :=
let c := mkParserContext env (mkInputContext input fileName)
let s := mkParserState input
let s := whitespace c s
let s := categoryParserFnImpl catName c s
if s.hasError then
Except.error (s.toErrorMsg c)
else if input.atEnd s.pos then
Except.ok s.stxStack.back
else
Except.error ((s.mkError "end of input").toErrorMsg c)
let c := mkParserContext env (mkInputContext input fileName)
let s := mkParserState input
let s := whitespace c s
let s := categoryParserFnImpl catName c s
if s.hasError then
Except.error (s.toErrorMsg c)
else if input.atEnd s.pos then
Except.ok s.stxStack.back
else
Except.error ((s.mkError "end of input").toErrorMsg c)
def declareBuiltinParser (env : Environment) (addFnName : Name) (catName : Name) (declName : Name) (prio : Nat) : IO Environment :=
let name := `_regBuiltinParser ++ declName
let type := mkApp (mkConst `IO) (mkConst `Unit)
let val := mkAppN (mkConst addFnName) #[toExpr catName, toExpr declName, mkConst declName, mkNatLit prio]
let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false }
match env.addAndCompile {} decl with
-- TODO: pretty print error
| Except.error _ => throw (IO.userError ("failed to emit registration code for builtin parser '" ++ toString declName ++ "'"))
| Except.ok env => IO.ofExcept (setBuiltinInitAttr env name)
let name := `_regBuiltinParser ++ declName
let type := mkApp (mkConst `IO) (mkConst `Unit)
let val := mkAppN (mkConst addFnName) #[toExpr catName, toExpr declName, mkConst declName, mkNatLit prio]
let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false }
match env.addAndCompile {} decl with
-- TODO: pretty print error
| Except.error _ => throw (IO.userError ("failed to emit registration code for builtin parser '" ++ toString declName ++ "'"))
| Except.ok env => IO.ofExcept (setBuiltinInitAttr env name)
def declareLeadingBuiltinParser (env : Environment) (catName : Name) (declName : Name) (prio : Nat) : IO Environment := -- TODO: use CoreM?
declareBuiltinParser env `Lean.Parser.addBuiltinLeadingParser catName declName prio
declareBuiltinParser env `Lean.Parser.addBuiltinLeadingParser catName declName prio
def declareTrailingBuiltinParser (env : Environment) (catName : Name) (declName : Name) (prio : Nat) : IO Environment := -- TODO: use CoreM?
declareBuiltinParser env `Lean.Parser.addBuiltinTrailingParser catName declName prio
declareBuiltinParser env `Lean.Parser.addBuiltinTrailingParser catName declName prio
def getParserPriority (args : Syntax) : Except String Nat :=
match args.getNumArgs with
| 0 => pure 0
| 1 => match (args.getArg 0).isNatLit? with
| some prio => pure prio
| none => throw "invalid parser attribute, numeral expected"
| _ => throw "invalid parser attribute, no argument or numeral expected"
match args.getNumArgs with
| 0 => pure 0
| 1 => match (args.getArg 0).isNatLit? with
| some prio => pure prio
| none => throw "invalid parser attribute, numeral expected"
| _ => throw "invalid parser attribute, no argument or numeral expected"
private def BuiltinParserAttribute.add (attrName : Name) (catName : Name)
(declName : Name) (args : Syntax) (persistent : Bool) : AttrM Unit := do
let prio ← ofExcept (getParserPriority args)
unless persistent do throwError! "invalid attribute '{attrName}', must be persistent"
let decl ← getConstInfo declName
let env ← getEnv
match decl.type with
| Expr.const `Lean.Parser.TrailingParser _ _ => do
let env ← liftIO $ declareTrailingBuiltinParser env catName declName prio
setEnv env
| Expr.const `Lean.Parser.Parser _ _ => do
let env ← liftIO $ declareLeadingBuiltinParser env catName declName prio
setEnv env
| _ => throwError! "unexpected parser type at '{declName}' (`Parser` or `TrailingParser` expected)"
runParserAttributeHooks catName declName (builtin := true)
let prio ← ofExcept (getParserPriority args)
unless persistent do throwError! "invalid attribute '{attrName}', must be persistent"
let decl ← getConstInfo declName
let env ← getEnv
match decl.type with
| Expr.const `Lean.Parser.TrailingParser _ _ => do
let env ← liftIO $ declareTrailingBuiltinParser env catName declName prio
setEnv env
| Expr.const `Lean.Parser.Parser _ _ => do
let env ← liftIO $ declareLeadingBuiltinParser env catName declName prio
setEnv env
| _ => throwError! "unexpected parser type at '{declName}' (`Parser` or `TrailingParser` expected)"
runParserAttributeHooks catName declName (builtin := true)
/-
The parsing tables for builtin parsers are "stored" in the extracted source code.
-/
def registerBuiltinParserAttribute (attrName : Name) (catName : Name) (leadingIdentAsSymbol := false) : IO Unit := do
addBuiltinParserCategory catName leadingIdentAsSymbol
registerBuiltinAttribute {
name := attrName,
descr := "Builtin parser",
add := fun declName args persistent => liftM $ BuiltinParserAttribute.add attrName catName declName args persistent,
applicationTime := AttributeApplicationTime.afterCompilation
}
addBuiltinParserCategory catName leadingIdentAsSymbol
registerBuiltinAttribute {
name := attrName,
descr := "Builtin parser",
add := fun declName args persistent => liftM $ BuiltinParserAttribute.add attrName catName declName args persistent,
applicationTime := AttributeApplicationTime.afterCompilation
}
private def ParserAttribute.add (attrName : Name) (catName : Name) (declName : Name) (args : Syntax) (persistent : Bool) : AttrM Unit := do
let prio ← ofExcept (getParserPriority args)
let env ← getEnv
let opts ← getOptions
let categories := (parserExtension.getState env).categories
match mkParserOfConstant env opts categories declName with
| Except.error ex => throwError ex
| Except.ok p => do
let leading := p.1
let parser := p.2
let tokens := parser.info.collectTokens []
tokens.forM fun token => do
env ← getEnv
match addToken env token with
| Except.ok env => setEnv env
| Except.error msg => throwError! "invalid parser '{declName}', {msg}"
let kinds := parser.info.collectKinds {}
kinds.forM fun kind _ => modifyEnv fun env => addSyntaxNodeKind env kind
match addParser categories catName declName leading parser prio with
| Except.ok _ => modifyEnv fun env => parserExtension.addEntry env $ ParserExtensionEntry.parser catName declName leading parser prio
let prio ← ofExcept (getParserPriority args)
let env ← getEnv
let opts ← getOptions
let categories := (parserExtension.getState env).categories
match mkParserOfConstant env opts categories declName with
| Except.error ex => throwError ex
runParserAttributeHooks catName declName /- builtin -/ false
| Except.ok p => do
let leading := p.1
let parser := p.2
let tokens := parser.info.collectTokens []
tokens.forM fun token => do
env ← getEnv
match addToken env token with
| Except.ok env => setEnv env
| Except.error msg => throwError! "invalid parser '{declName}', {msg}"
let kinds := parser.info.collectKinds {}
kinds.forM fun kind _ => modifyEnv fun env => addSyntaxNodeKind env kind
match addParser categories catName declName leading parser prio with
| Except.ok _ => modifyEnv fun env => parserExtension.addEntry env $ ParserExtensionEntry.parser catName declName leading parser prio
| Except.error ex => throwError ex
runParserAttributeHooks catName declName /- builtin -/ false
def mkParserAttributeImpl (attrName : Name) (catName : Name) : AttributeImpl :=
{ name := attrName,
def mkParserAttributeImpl (attrName : Name) (catName : Name) : AttributeImpl := {
name := attrName,
descr := "parser",
add := fun declName args persistent => liftM $ ParserAttribute.add attrName catName declName args persistent,
applicationTime := AttributeApplicationTime.afterCompilation }
applicationTime := AttributeApplicationTime.afterCompilation
}
/- A builtin parser attribute that can be extended by users. -/
def registerBuiltinDynamicParserAttribute (attrName : Name) (catName : Name) : IO Unit := do
registerBuiltinAttribute (mkParserAttributeImpl attrName catName)
registerBuiltinAttribute (mkParserAttributeImpl attrName catName)
@[builtinInit] private def registerParserAttributeImplBuilder : IO Unit :=
registerAttributeImplBuilder `parserAttr fun args =>
match args with
| [DataValue.ofName attrName, DataValue.ofName catName] => pure $ mkParserAttributeImpl attrName catName
| _ => throw "invalid parser attribute implementation builder arguments"
registerAttributeImplBuilder `parserAttr fun args =>
match args with
| [DataValue.ofName attrName, DataValue.ofName catName] => pure $ mkParserAttributeImpl attrName catName
| _ => throw "invalid parser attribute implementation builder arguments"
def registerParserCategory (env : Environment) (attrName : Name) (catName : Name) (leadingIdentAsSymbol := false) : IO Environment := do
env ← IO.ofExcept $ addParserCategory env catName leadingIdentAsSymbol
registerAttributeOfBuilder env `parserAttr [DataValue.ofName attrName, DataValue.ofName catName]
let env ← IO.ofExcept $ addParserCategory env catName leadingIdentAsSymbol
registerAttributeOfBuilder env `parserAttr [DataValue.ofName attrName, DataValue.ofName catName]
-- declare `termParser` here since it is used everywhere via antiquotations
@ -483,10 +479,9 @@ builtin_initialize registerBuiltinParserAttribute `builtinCommandParser `command
builtin_initialize registerBuiltinDynamicParserAttribute `commandParser `command
@[inline] def commandParser (rbp : Nat := 0) : Parser :=
categoryParser `command rbp
categoryParser `command rbp
def notFollowedByCategoryTokenFn (catName : Name) : ParserFn :=
fun ctx s =>
def notFollowedByCategoryTokenFn (catName : Name) : ParserFn := fun ctx s =>
let categories := (parserExtension.getState ctx.env).categories
match getCategory categories catName with
| none => s.mkUnexpectedError s!"unknown parser category '{catName}'"
@ -500,14 +495,15 @@ fun ctx s =>
| _ => s
| _ => s
@[inline] def notFollowedByCategoryToken (catName : Name) : Parser :=
{ fn := notFollowedByCategoryTokenFn catName }
@[inline] def notFollowedByCategoryToken (catName : Name) : Parser := {
fn := notFollowedByCategoryTokenFn catName
}
abbrev notFollowedByCommandToken : Parser :=
notFollowedByCategoryToken `command
notFollowedByCategoryToken `command
abbrev notFollowedByTermToken : Parser :=
notFollowedByCategoryToken `term
notFollowedByCategoryToken `term
end Parser
end Lean

View file

@ -23,67 +23,67 @@ def ctx (interp) : ParserCompiler.Context Parenthesizer :=
⟨`parenthesizer, parenthesizerAttribute, combinatorParenthesizerAttribute, interp⟩
unsafe def interpretParserDescr : ParserDescr → AttrM Parenthesizer
| ParserDescr.andthen d₁ d₂ => andthen.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.orelse d₁ d₂ => orelse.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.optional d => optional.parenthesizer <$> interpretParserDescr d
| ParserDescr.lookahead d => lookahead.parenthesizer <$> interpretParserDescr d
| ParserDescr.try d => try.parenthesizer <$> interpretParserDescr d
| ParserDescr.notFollowedBy d => notFollowedBy.parenthesizer <$> interpretParserDescr d
| ParserDescr.many d => many.parenthesizer <$> interpretParserDescr d
| ParserDescr.many1 d => many1.parenthesizer <$> interpretParserDescr d
| ParserDescr.sepBy d₁ d₂ => sepBy.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.sepBy1 d₁ d₂ => sepBy1.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.node k prec d => leadingNode.parenthesizer k prec <$> interpretParserDescr d
| ParserDescr.trailingNode k prec d => trailingNode.parenthesizer k prec <$> interpretParserDescr d
| ParserDescr.symbol tk => pure $ symbol.parenthesizer tk
| ParserDescr.numLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "numLit" `numLit) numLitNoAntiquot.parenthesizer
| ParserDescr.strLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "strLit" `strLit) strLitNoAntiquot.parenthesizer
| ParserDescr.charLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "charLit" `charLit) charLitNoAntiquot.parenthesizer
| ParserDescr.nameLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "nameLit" `nameLit) nameLitNoAntiquot.parenthesizer
| ParserDescr.ident => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "ident" `ident) identNoAntiquot.parenthesizer
| ParserDescr.interpolatedStr d => interpolatedStr.parenthesizer <$> interpretParserDescr d
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.parenthesizer tk includeIdent
| ParserDescr.noWs => pure $ checkNoWsBefore.parenthesizer
| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName
| ParserDescr.cat catName prec => pure $ categoryParser.parenthesizer catName prec
| ParserDescr.andthen d₁ d₂ => andthen.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.orelse d₁ d₂ => orelse.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.optional d => optional.parenthesizer <$> interpretParserDescr d
| ParserDescr.lookahead d => lookahead.parenthesizer <$> interpretParserDescr d
| ParserDescr.try d => try.parenthesizer <$> interpretParserDescr d
| ParserDescr.notFollowedBy d => notFollowedBy.parenthesizer <$> interpretParserDescr d
| ParserDescr.many d => many.parenthesizer <$> interpretParserDescr d
| ParserDescr.many1 d => many1.parenthesizer <$> interpretParserDescr d
| ParserDescr.sepBy d₁ d₂ => sepBy.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.sepBy1 d₁ d₂ => sepBy1.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.node k prec d => leadingNode.parenthesizer k prec <$> interpretParserDescr d
| ParserDescr.trailingNode k prec d => trailingNode.parenthesizer k prec <$> interpretParserDescr d
| ParserDescr.symbol tk => pure $ symbol.parenthesizer tk
| ParserDescr.numLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "numLit" `numLit) numLitNoAntiquot.parenthesizer
| ParserDescr.strLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "strLit" `strLit) strLitNoAntiquot.parenthesizer
| ParserDescr.charLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "charLit" `charLit) charLitNoAntiquot.parenthesizer
| ParserDescr.nameLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "nameLit" `nameLit) nameLitNoAntiquot.parenthesizer
| ParserDescr.ident => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "ident" `ident) identNoAntiquot.parenthesizer
| ParserDescr.interpolatedStr d => interpolatedStr.parenthesizer <$> interpretParserDescr d
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.parenthesizer tk includeIdent
| ParserDescr.noWs => pure $ checkNoWsBefore.parenthesizer
| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName
| ParserDescr.cat catName prec => pure $ categoryParser.parenthesizer catName prec
@[builtinInit] unsafe def regHook : IO Unit :=
ParserCompiler.registerParserCompiler (ctx interpretParserDescr)
ParserCompiler.registerParserCompiler (ctx interpretParserDescr)
end Parenthesizer
namespace Formatter
def ctx (interp) : ParserCompiler.Context Formatter :=
⟨`formatter, formatterAttribute, combinatorFormatterAttribute, interp⟩
⟨`formatter, formatterAttribute, combinatorFormatterAttribute, interp⟩
unsafe def interpretParserDescr : ParserDescr → AttrM Formatter
| ParserDescr.andthen d₁ d₂ => andthen.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.orelse d₁ d₂ => orelse.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.optional d => optional.formatter <$> interpretParserDescr d
| ParserDescr.lookahead d => lookahead.formatter <$> interpretParserDescr d
| ParserDescr.try d => try.formatter <$> interpretParserDescr d
| ParserDescr.notFollowedBy d => notFollowedBy.formatter <$> interpretParserDescr d
| ParserDescr.many d => many.formatter <$> interpretParserDescr d
| ParserDescr.many1 d => many1.formatter <$> interpretParserDescr d
| ParserDescr.sepBy d₁ d₂ => sepBy.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.sepBy1 d₁ d₂ => sepBy1.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.node k prec d => node.formatter k <$> interpretParserDescr d
| ParserDescr.trailingNode k prec d => trailingNode.formatter k prec <$> interpretParserDescr d
| ParserDescr.symbol tk => pure $ symbol.formatter tk
| ParserDescr.numLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "numLit" `numLit) numLitNoAntiquot.formatter
| ParserDescr.strLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "strLit" `strLit) strLitNoAntiquot.formatter
| ParserDescr.charLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "charLit" `charLit) charLitNoAntiquot.formatter
| ParserDescr.nameLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "nameLit" `nameLit) nameLitNoAntiquot.formatter
| ParserDescr.interpolatedStr d => interpolatedStr.formatter <$> interpretParserDescr d
| ParserDescr.ident => pure $ withAntiquot.formatter (mkAntiquot.formatter' "ident" `ident) identNoAntiquot.formatter
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.formatter tk
| ParserDescr.noWs => pure $ checkNoWsBefore.formatter
| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName
| ParserDescr.cat catName prec => pure $ categoryParser.formatter catName
| ParserDescr.andthen d₁ d₂ => andthen.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.orelse d₁ d₂ => orelse.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.optional d => optional.formatter <$> interpretParserDescr d
| ParserDescr.lookahead d => lookahead.formatter <$> interpretParserDescr d
| ParserDescr.try d => try.formatter <$> interpretParserDescr d
| ParserDescr.notFollowedBy d => notFollowedBy.formatter <$> interpretParserDescr d
| ParserDescr.many d => many.formatter <$> interpretParserDescr d
| ParserDescr.many1 d => many1.formatter <$> interpretParserDescr d
| ParserDescr.sepBy d₁ d₂ => sepBy.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.sepBy1 d₁ d₂ => sepBy1.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
| ParserDescr.node k prec d => node.formatter k <$> interpretParserDescr d
| ParserDescr.trailingNode k prec d => trailingNode.formatter k prec <$> interpretParserDescr d
| ParserDescr.symbol tk => pure $ symbol.formatter tk
| ParserDescr.numLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "numLit" `numLit) numLitNoAntiquot.formatter
| ParserDescr.strLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "strLit" `strLit) strLitNoAntiquot.formatter
| ParserDescr.charLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "charLit" `charLit) charLitNoAntiquot.formatter
| ParserDescr.nameLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "nameLit" `nameLit) nameLitNoAntiquot.formatter
| ParserDescr.interpolatedStr d => interpolatedStr.formatter <$> interpretParserDescr d
| ParserDescr.ident => pure $ withAntiquot.formatter (mkAntiquot.formatter' "ident" `ident) identNoAntiquot.formatter
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.formatter tk
| ParserDescr.noWs => pure $ checkNoWsBefore.formatter
| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName
| ParserDescr.cat catName prec => pure $ categoryParser.formatter catName
@[builtinInit] unsafe def regHook : IO Unit :=
ParserCompiler.registerParserCompiler (ctx interpretParserDescr)
ParserCompiler.registerParserCompiler (ctx interpretParserDescr)
end Formatter

View file

@ -66,15 +66,15 @@ end ReplaceLevelImpl
@[implementedBy ReplaceLevelImpl.replaceUnsafe]
partial def replaceLevel (f? : Level → Option Level) : Expr → Expr
| e@(Expr.forallE _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateForallE! d b
| e@(Expr.lam _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateLambdaE! d b
| e@(Expr.mdata _ b _) => let b := replaceLevel f? b; e.updateMData! b
| e@(Expr.letE _ t v b _) => let t := replaceLevel f? t; let v := replaceLevel f? v; let b := replaceLevel f? b; e.updateLet! t v b
| e@(Expr.app f a _) => let f := replaceLevel f? f; let a := replaceLevel f? a; e.updateApp! f a
| e@(Expr.proj _ _ b _) => let b := replaceLevel f? b; e.updateProj! b
| e@(Expr.sort u _) => e.updateSort! (u.replace f?)
| e@(Expr.const n us _) => e.updateConst! (us.map (Level.replace f?))
| e => e
| e@(Expr.forallE _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateForallE! d b
| e@(Expr.lam _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateLambdaE! d b
| e@(Expr.mdata _ b _) => let b := replaceLevel f? b; e.updateMData! b
| e@(Expr.letE _ t v b _) => let t := replaceLevel f? t; let v := replaceLevel f? v; let b := replaceLevel f? b; e.updateLet! t v b
| e@(Expr.app f a _) => let f := replaceLevel f? f; let a := replaceLevel f? a; e.updateApp! f a
| e@(Expr.proj _ _ b _) => let b := replaceLevel f? b; e.updateProj! b
| e@(Expr.sort u _) => e.updateSort! (u.replace f?)
| e@(Expr.const n us _) => e.updateConst! (us.map (Level.replace f?))
| e => e
end Expr
end Lean