chore: cleanup
This commit is contained in:
parent
f31b0d7d19
commit
6858cb5fb6
29 changed files with 2420 additions and 2433 deletions
|
|
@ -1271,7 +1271,7 @@ def Prod.map.{u₁, u₂, v₁, v₂} {α₁ : Type u₁} {α₂ : Type u₂} {
|
|||
/- Dependent products -/
|
||||
|
||||
theorem exOfPsig {α : Type u} {p : α → Prop} : (PSigma (fun x => p x)) → Exists (fun x => p x)
|
||||
| ⟨x, hx⟩ => ⟨x, hx⟩
|
||||
| ⟨x, hx⟩ => ⟨x, hx⟩
|
||||
|
||||
protected theorem PSigma.eta {α : Sort u} {β : α → Sort v} {a₁ a₂ : α} {b₁ : β a₁} {b₂ : β a₂}
|
||||
(h₁ : a₁ = a₂) (h₂ : Eq.rec (motive := fun a _ => β a) b₁ h₁ = b₂) : PSigma.mk a₁ b₁ = PSigma.mk a₂ b₂ := by
|
||||
|
|
|
|||
|
|
@ -14,22 +14,22 @@ namespace Array
|
|||
-- TODO: remove `partial` using well-founded recursion
|
||||
|
||||
@[specialize] partial def binSearchAux {α : Type u} {β : Type v} [Inhabited α] [Inhabited β] (lt : α → α → Bool) (found : Option α → β) (as : Array α) (k : α) : Nat → Nat → β
|
||||
| lo, hi =>
|
||||
if lo <= hi then
|
||||
let m := (lo + hi)/2;
|
||||
let a := as.get! m;
|
||||
if lt a k then binSearchAux lt found as k (m+1) hi
|
||||
else if lt k a then
|
||||
if m == 0 then found none
|
||||
else binSearchAux lt found as k lo (m-1)
|
||||
else found (some a)
|
||||
else found none
|
||||
| lo, hi =>
|
||||
if lo <= hi then
|
||||
let m := (lo + hi)/2;
|
||||
let a := as.get! m;
|
||||
if lt a k then binSearchAux lt found as k (m+1) hi
|
||||
else if lt k a then
|
||||
if m == 0 then found none
|
||||
else binSearchAux lt found as k lo (m-1)
|
||||
else found (some a)
|
||||
else found none
|
||||
|
||||
@[inline] def binSearch {α : Type} [Inhabited α] (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Option α :=
|
||||
binSearchAux lt id as k lo hi
|
||||
binSearchAux lt id as k lo hi
|
||||
|
||||
@[inline] def binSearchContains {α : Type} [Inhabited α] (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Bool :=
|
||||
binSearchAux lt Option.isSome as k lo hi
|
||||
binSearchAux lt Option.isSome as k lo hi
|
||||
|
||||
@[specialize] private partial def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α]
|
||||
(lt : α → α → Bool)
|
||||
|
|
@ -37,17 +37,17 @@ binSearchAux lt Option.isSome as k lo hi
|
|||
(add : Unit → m α)
|
||||
(as : Array α)
|
||||
(k : α) : Nat → Nat → m (Array α)
|
||||
| lo, hi =>
|
||||
-- as[lo] < k < as[hi]
|
||||
let mid := (lo + hi)/2;
|
||||
let midVal := as.get! mid;
|
||||
if lt midVal k then
|
||||
if mid == lo then do let v ← add (); pure $ as.insertAt (lo+1) v
|
||||
else binInsertAux lt merge add as k mid hi
|
||||
else if lt k midVal then
|
||||
binInsertAux lt merge add as k lo mid
|
||||
else do
|
||||
as.modifyM mid $ fun v => merge v
|
||||
| lo, hi =>
|
||||
-- as[lo] < k < as[hi]
|
||||
let mid := (lo + hi)/2;
|
||||
let midVal := as.get! mid;
|
||||
if lt midVal k then
|
||||
if mid == lo then do let v ← add (); pure $ as.insertAt (lo+1) v
|
||||
else binInsertAux lt merge add as k mid hi
|
||||
else if lt k midVal then
|
||||
binInsertAux lt merge add as k lo mid
|
||||
else do
|
||||
as.modifyM mid $ fun v => merge v
|
||||
|
||||
@[specialize] partial def binInsertM {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α]
|
||||
(lt : α → α → Bool)
|
||||
|
|
@ -55,14 +55,14 @@ binSearchAux lt Option.isSome as k lo hi
|
|||
(add : Unit → m α)
|
||||
(as : Array α)
|
||||
(k : α) : m (Array α) :=
|
||||
if as.isEmpty then do let v ← add (); pure $ as.push v
|
||||
else if lt k (as.get! 0) then do let v ← add (); pure $ as.insertAt 0 v
|
||||
else if !lt (as.get! 0) k then as.modifyM 0 $ merge
|
||||
else if lt as.back k then do let v ← add (); pure $ as.push v
|
||||
else if !lt k as.back then as.modifyM (as.size - 1) $ merge
|
||||
else binInsertAux lt merge add as k 0 (as.size - 1)
|
||||
if as.isEmpty then do let v ← add (); pure $ as.push v
|
||||
else if lt k (as.get! 0) then do let v ← add (); pure $ as.insertAt 0 v
|
||||
else if !lt (as.get! 0) k then as.modifyM 0 $ merge
|
||||
else if lt as.back k then do let v ← add (); pure $ as.push v
|
||||
else if !lt k as.back then as.modifyM (as.size - 1) $ merge
|
||||
else binInsertAux lt merge add as k 0 (as.size - 1)
|
||||
|
||||
@[inline] def binInsert {α : Type u} [Inhabited α] (lt : α → α → Bool) (as : Array α) (k : α) : Array α :=
|
||||
Id.run $ binInsertM lt (fun _ => k) (fun _ => k) as k
|
||||
Id.run $ binInsertM lt (fun _ => k) (fun _ => k) as k
|
||||
|
||||
end Array
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ Remark: we can define `mapM`, `mapM₂` and `forM` using `Applicative` instead o
|
|||
Example:
|
||||
```
|
||||
def mapM {m : Type u → Type v} [Applicative m] {α : Type w} {β : Type u} (f : α → m β) : List α → m (List β)
|
||||
| [] => pure []
|
||||
| a::as => List.cons <$> (f a) <*> mapM as
|
||||
| [] => pure []
|
||||
| a::as => List.cons <$> (f a) <*> mapM as
|
||||
```
|
||||
|
||||
However, we consider `f <$> a <*> b` an anti-idiom because the generated code
|
||||
|
|
|
|||
|
|
@ -201,16 +201,16 @@ protected theorem oneMul (n : Nat) : 1 * n = n :=
|
|||
Nat.mulComm n 1 ▸ Nat.mulOne n
|
||||
|
||||
protected theorem leftDistrib : ∀ (n m k : Nat), n * (m + k) = n * m + n * k
|
||||
| 0, m, k => (Nat.zeroMul (m + k)).symm ▸ (Nat.zeroMul m).symm ▸ (Nat.zeroMul k).symm ▸ rfl
|
||||
| succ n, m, k =>
|
||||
have h₁ : succ n * (m + k) = n * (m + k) + (m + k) from succMul ..
|
||||
have h₂ : n * (m + k) + (m + k) = (n * m + n * k) + (m + k) from Nat.leftDistrib n m k ▸ rfl
|
||||
have h₃ : (n * m + n * k) + (m + k) = n * m + (n * k + (m + k)) from Nat.addAssoc ..
|
||||
have h₄ : n * m + (n * k + (m + k)) = n * m + (m + (n * k + k)) from congrArg (fun x => n*m + x) (Nat.addLeftComm ..)
|
||||
have h₅ : n * m + (m + (n * k + k)) = (n * m + m) + (n * k + k) from (Nat.addAssoc ..).symm
|
||||
have h₆ : (n * m + m) + (n * k + k) = (n * m + m) + succ n * k from succMul n k ▸ rfl
|
||||
have h₇ : (n * m + m) + succ n * k = succ n * m + succ n * k from succMul n m ▸ rfl
|
||||
(((((h₁.trans h₂).trans h₃).trans h₄).trans h₅).trans h₆).trans h₇
|
||||
| 0, m, k => (Nat.zeroMul (m + k)).symm ▸ (Nat.zeroMul m).symm ▸ (Nat.zeroMul k).symm ▸ rfl
|
||||
| succ n, m, k =>
|
||||
have h₁ : succ n * (m + k) = n * (m + k) + (m + k) from succMul ..
|
||||
have h₂ : n * (m + k) + (m + k) = (n * m + n * k) + (m + k) from Nat.leftDistrib n m k ▸ rfl
|
||||
have h₃ : (n * m + n * k) + (m + k) = n * m + (n * k + (m + k)) from Nat.addAssoc ..
|
||||
have h₄ : n * m + (n * k + (m + k)) = n * m + (m + (n * k + k)) from congrArg (fun x => n*m + x) (Nat.addLeftComm ..)
|
||||
have h₅ : n * m + (m + (n * k + k)) = (n * m + m) + (n * k + k) from (Nat.addAssoc ..).symm
|
||||
have h₆ : (n * m + m) + (n * k + k) = (n * m + m) + succ n * k from succMul n k ▸ rfl
|
||||
have h₇ : (n * m + m) + succ n * k = succ n * m + succ n * k from succMul n m ▸ rfl
|
||||
(((((h₁.trans h₂).trans h₃).trans h₄).trans h₅).trans h₆).trans h₇
|
||||
|
||||
protected theorem rightDistrib (n m k : Nat) : (n + m) * k = n * k + m * k :=
|
||||
have h₁ : (n + m) * k = k * (n + m) from Nat.mulComm ..
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import Init.Data.ToString.Basic
|
|||
syntax:max "s!" (interpolatedStr term) : term
|
||||
|
||||
macro_rules
|
||||
| `(s! $interpStr) => do
|
||||
let chunks := interpStr.getArgs
|
||||
let r ← Lean.Syntax.expandInterpolatedStrChunks chunks (fun a b => `($a ++ $b)) (fun a => `(toString $a))
|
||||
`(($r : String))
|
||||
| `(s! $interpStr) => do
|
||||
let chunks := interpStr.getArgs
|
||||
let r ← Lean.Syntax.expandInterpolatedStrChunks chunks (fun a b => `($a ++ $b)) (fun a => `(toString $a))
|
||||
`(($r : String))
|
||||
|
|
|
|||
|
|
@ -109,21 +109,21 @@ protected def append : Name → Name → Name
|
|||
instance : Append Name := ⟨Name.append⟩
|
||||
|
||||
def capitalize : Name → Name
|
||||
| Name.str p s _ => mkNameStr p s.capitalize
|
||||
| n => n
|
||||
| Name.str p s _ => mkNameStr p s.capitalize
|
||||
| n => n
|
||||
|
||||
def appendAfter : Name → String → Name
|
||||
| str p s _, suffix => mkNameStr p (s ++ suffix)
|
||||
| n, suffix => mkNameStr n suffix
|
||||
| str p s _, suffix => mkNameStr p (s ++ suffix)
|
||||
| n, suffix => mkNameStr n suffix
|
||||
|
||||
def appendIndexAfter : Name → Nat → Name
|
||||
| str p s _, idx => mkNameStr p (s ++ "_" ++ toString idx)
|
||||
| n, idx => mkNameStr n ("_" ++ toString idx)
|
||||
| str p s _, idx => mkNameStr p (s ++ "_" ++ toString idx)
|
||||
| n, idx => mkNameStr n ("_" ++ toString idx)
|
||||
|
||||
def appendBefore : Name → String → Name
|
||||
| anonymous, pre => mkNameStr anonymous pre
|
||||
| str p s _, pre => mkNameStr p (pre ++ s)
|
||||
| num p n _, pre => mkNameNum (mkNameStr p pre) n
|
||||
| anonymous, pre => mkNameStr anonymous pre
|
||||
| str p s _, pre => mkNameStr p (pre ++ s)
|
||||
| num p n _, pre => mkNameNum (mkNameStr p pre) n
|
||||
|
||||
end Name
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ universes u v
|
|||
set_option codegen false
|
||||
|
||||
inductive Acc {α : Sort u} (r : α → α → Prop) : α → Prop
|
||||
| intro (x : α) (h : (y : α) → r y x → Acc r y) : Acc r x
|
||||
| intro (x : α) (h : (y : α) → r y x → Acc r y) : Acc r x
|
||||
|
||||
@[elabAsEliminator, inline, reducible]
|
||||
def Acc.ndrec.{u1, u2} {α : Sort u2} {r : α → α → Prop} {C : α → Sort u1}
|
||||
|
|
@ -289,8 +289,8 @@ variables {α : Sort u} {β : Sort v}
|
|||
|
||||
-- Reverse lexicographical order based on r and s
|
||||
inductive RevLex (r : α → α → Prop) (s : β → β → Prop) : @PSigma α (fun a => β) → @PSigma α (fun a => β) → Prop
|
||||
| left : {a₁ a₂ : α} → (b : β) → r a₁ a₂ → RevLex r s ⟨a₁, b⟩ ⟨a₂, b⟩
|
||||
| right : (a₁ : α) → {b₁ : β} → (a₂ : α) → {b₂ : β} → s b₁ b₂ → RevLex r s ⟨a₁, b₁⟩ ⟨a₂, b₂⟩
|
||||
| left : {a₁ a₂ : α} → (b : β) → r a₁ a₂ → RevLex r s ⟨a₁, b⟩ ⟨a₂, b⟩
|
||||
| right : (a₁ : α) → {b₁ : β} → (a₂ : α) → {b₂ : β} → s b₁ b₂ → RevLex r s ⟨a₁, b₁⟩ ⟨a₂, b₂⟩
|
||||
end
|
||||
|
||||
section
|
||||
|
|
|
|||
|
|
@ -15,11 +15,12 @@ namespace OwnedSet
|
|||
abbrev Key := FunId × Index
|
||||
|
||||
def beq : Key → Key → Bool
|
||||
| (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂
|
||||
| (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂
|
||||
|
||||
instance : BEq Key := ⟨beq⟩
|
||||
|
||||
def getHash : Key → USize
|
||||
| (f, x) => mixHash (hash f) (hash x)
|
||||
| (f, x) => mixHash (hash f) (hash x)
|
||||
instance : Hashable Key := ⟨getHash⟩
|
||||
end OwnedSet
|
||||
|
||||
|
|
@ -35,19 +36,19 @@ def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := Std.HashMap.
|
|||
Recall that `Param` contains the field `borrow`. -/
|
||||
namespace ParamMap
|
||||
inductive Key
|
||||
| decl (name : FunId)
|
||||
| jp (name : FunId) (jpid : JoinPointId)
|
||||
| decl (name : FunId)
|
||||
| jp (name : FunId) (jpid : JoinPointId)
|
||||
|
||||
def beq : Key → Key → Bool
|
||||
| Key.decl n₁, Key.decl n₂ => n₁ == n₂
|
||||
| Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂
|
||||
| _, _ => false
|
||||
| Key.decl n₁, Key.decl n₂ => n₁ == n₂
|
||||
| Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂
|
||||
| _, _ => false
|
||||
|
||||
instance : BEq Key := ⟨beq⟩
|
||||
|
||||
def getHash : Key → USize
|
||||
| Key.decl n => hash n
|
||||
| Key.jp n id => mixHash (hash n) (hash id)
|
||||
| Key.decl n => hash n
|
||||
| Key.jp n id => mixHash (hash n) (hash id)
|
||||
|
||||
instance : Hashable Key := ⟨getHash⟩
|
||||
end ParamMap
|
||||
|
|
@ -56,13 +57,13 @@ open ParamMap (Key)
|
|||
abbrev ParamMap := Std.HashMap Key (Array Param)
|
||||
|
||||
def ParamMap.fmt (map : ParamMap) : Format :=
|
||||
let fmts := map.fold (fun fmt k ps =>
|
||||
let k := match k with
|
||||
| ParamMap.Key.decl n => format n
|
||||
| ParamMap.Key.jp n id => format n ++ ":" ++ format id
|
||||
fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps)
|
||||
Format.nil
|
||||
"{" ++ (Format.nest 1 fmts) ++ "}"
|
||||
let fmts := map.fold (fun fmt k ps =>
|
||||
let k := match k with
|
||||
| ParamMap.Key.decl n => format n
|
||||
| ParamMap.Key.jp n id => format n ++ ":" ++ format id
|
||||
fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps)
|
||||
Format.nil
|
||||
"{" ++ (Format.nest 1 fmts) ++ "}"
|
||||
|
||||
instance : ToFormat ParamMap := ⟨ParamMap.fmt⟩
|
||||
instance : ToString ParamMap := ⟨fun m => Format.pretty (format m)⟩
|
||||
|
|
@ -70,7 +71,7 @@ instance : ToString ParamMap := ⟨fun m => Format.pretty (format m)⟩
|
|||
namespace InitParamMap
|
||||
/- Mark parameters that take a reference as borrow -/
|
||||
def initBorrow (ps : Array Param) : Array Param :=
|
||||
ps.map $ fun p => { p with borrow := p.ty.isObj }
|
||||
ps.map $ fun p => { p with borrow := p.ty.isObj }
|
||||
|
||||
/- We do perform borrow inference for constants marked as `export`.
|
||||
Reason: we current write wrappers in C++ for using exported functions.
|
||||
|
|
@ -80,26 +81,26 @@ ps.map $ fun p => { p with borrow := p.ty.isObj }
|
|||
We can revise this decision when we implement code for generating
|
||||
the wrappers automatically. -/
|
||||
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
|
||||
if exported then ps else initBorrow ps
|
||||
if exported then ps else initBorrow ps
|
||||
|
||||
partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
|
||||
| FnBody.jdecl j xs v b => do
|
||||
modify fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs)
|
||||
visitFnBody fnid v
|
||||
visitFnBody fnid b
|
||||
| FnBody.case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body
|
||||
| e => do
|
||||
unless e.isTerminal do
|
||||
let (instr, b) := e.split
|
||||
| FnBody.jdecl j xs v b => do
|
||||
modify fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs)
|
||||
visitFnBody fnid v
|
||||
visitFnBody fnid b
|
||||
| FnBody.case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body
|
||||
| e => do
|
||||
unless e.isTerminal do
|
||||
let (instr, b) := e.split
|
||||
visitFnBody fnid b
|
||||
|
||||
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
|
||||
decls.forM fun decl => match decl with
|
||||
| Decl.fdecl f xs _ b => do
|
||||
let exported := isExport env f
|
||||
modify fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs)
|
||||
visitFnBody f b
|
||||
| _ => pure ()
|
||||
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
|
||||
decls.forM fun decl => match decl with
|
||||
| Decl.fdecl f xs _ b => do
|
||||
let exported := isExport env f
|
||||
modify fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs)
|
||||
visitFnBody f b
|
||||
| _ => pure ()
|
||||
end InitParamMap
|
||||
|
||||
def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap :=
|
||||
|
|
@ -110,111 +111,111 @@ def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap :=
|
|||
namespace ApplyParamMap
|
||||
|
||||
partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
|
||||
| FnBody.jdecl j xs v b =>
|
||||
let v := visitFnBody fn paramMap v
|
||||
let b := visitFnBody fn paramMap b
|
||||
match paramMap.find? (ParamMap.Key.jp fn j) with
|
||||
| some ys => FnBody.jdecl j ys v b
|
||||
| none => unreachable!
|
||||
| FnBody.case tid x xType alts =>
|
||||
FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (visitFnBody fn paramMap)
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
| FnBody.jdecl j xs v b =>
|
||||
let v := visitFnBody fn paramMap v
|
||||
let b := visitFnBody fn paramMap b
|
||||
instr.setBody b
|
||||
match paramMap.find? (ParamMap.Key.jp fn j) with
|
||||
| some ys => FnBody.jdecl j ys v b
|
||||
| none => unreachable!
|
||||
| FnBody.case tid x xType alts =>
|
||||
FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (visitFnBody fn paramMap)
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := visitFnBody fn paramMap b
|
||||
instr.setBody b
|
||||
|
||||
def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
|
||||
decls.map fun decl => match decl with
|
||||
| Decl.fdecl f xs ty b =>
|
||||
let b := visitFnBody f paramMap b
|
||||
match paramMap.find? (ParamMap.Key.decl f) with
|
||||
| some xs => Decl.fdecl f xs ty b
|
||||
| none => unreachable!
|
||||
| other => other
|
||||
decls.map fun decl => match decl with
|
||||
| Decl.fdecl f xs ty b =>
|
||||
let b := visitFnBody f paramMap b
|
||||
match paramMap.find? (ParamMap.Key.decl f) with
|
||||
| some xs => Decl.fdecl f xs ty b
|
||||
| none => unreachable!
|
||||
| other => other
|
||||
|
||||
end ApplyParamMap
|
||||
|
||||
def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl :=
|
||||
-- dbgTrace ("applyParamMap " ++ toString map) $ fun _ =>
|
||||
ApplyParamMap.visitDecls decls map
|
||||
-- dbgTrace ("applyParamMap " ++ toString map) $ fun _ =>
|
||||
ApplyParamMap.visitDecls decls map
|
||||
|
||||
structure BorrowInfCtx :=
|
||||
(env : Environment)
|
||||
(currFn : FunId := arbitrary _) -- Function being analyzed.
|
||||
(paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams`
|
||||
(env : Environment)
|
||||
(currFn : FunId := arbitrary _) -- Function being analyzed.
|
||||
(paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams`
|
||||
|
||||
structure BorrowInfState :=
|
||||
/- Set of variables that must be `owned`. -/
|
||||
(owned : OwnedSet := {})
|
||||
(modified : Bool := false)
|
||||
(paramMap : ParamMap)
|
||||
/- Set of variables that must be `owned`. -/
|
||||
(owned : OwnedSet := {})
|
||||
(modified : Bool := false)
|
||||
(paramMap : ParamMap)
|
||||
|
||||
abbrev M := ReaderT BorrowInfCtx (StateM BorrowInfState)
|
||||
|
||||
def getCurrFn : M FunId := do
|
||||
let ctx ← read
|
||||
pure ctx.currFn
|
||||
let ctx ← read
|
||||
pure ctx.currFn
|
||||
|
||||
def markModified : M Unit :=
|
||||
modify $ fun s => { s with modified := true }
|
||||
modify fun s => { s with modified := true }
|
||||
|
||||
def ownVar (x : VarId) : M Unit := do
|
||||
-- dbgTrace ("ownVar " ++ toString x) $ fun _ =>
|
||||
let currFn ← getCurrFn
|
||||
modify fun s =>
|
||||
if s.owned.contains (currFn, x.idx) then s
|
||||
else { s with owned := s.owned.insert (currFn, x.idx), modified := true }
|
||||
-- dbgTrace ("ownVar " ++ toString x) $ fun _ =>
|
||||
let currFn ← getCurrFn
|
||||
modify fun s =>
|
||||
if s.owned.contains (currFn, x.idx) then s
|
||||
else { s with owned := s.owned.insert (currFn, x.idx), modified := true }
|
||||
|
||||
def ownArg (x : Arg) : M Unit :=
|
||||
match x with
|
||||
| Arg.var x => ownVar x
|
||||
| _ => pure ()
|
||||
match x with
|
||||
| Arg.var x => ownVar x
|
||||
| _ => pure ()
|
||||
|
||||
def ownArgs (xs : Array Arg) : M Unit :=
|
||||
xs.forM ownArg
|
||||
xs.forM ownArg
|
||||
|
||||
def isOwned (x : VarId) : M Bool := do
|
||||
let currFn ← getCurrFn
|
||||
let s ← get
|
||||
pure $ s.owned.contains (currFn, x.idx)
|
||||
let currFn ← getCurrFn
|
||||
let s ← get
|
||||
pure $ s.owned.contains (currFn, x.idx)
|
||||
|
||||
/- Updates `map[k]` using the current set of `owned` variables. -/
|
||||
def updateParamMap (k : ParamMap.Key) : M Unit := do
|
||||
let currFn ← getCurrFn
|
||||
let s ← get
|
||||
match s.paramMap.find? k with
|
||||
| some ps => do
|
||||
let ps ← ps.mapM fun (p : Param) => do
|
||||
if !p.borrow then pure p
|
||||
else if (← isOwned p.x) then
|
||||
markModified
|
||||
pure { p with borrow := false }
|
||||
else
|
||||
pure p
|
||||
modify fun s => { s with paramMap := s.paramMap.insert k ps }
|
||||
| none => pure ()
|
||||
let currFn ← getCurrFn
|
||||
let s ← get
|
||||
match s.paramMap.find? k with
|
||||
| some ps => do
|
||||
let ps ← ps.mapM fun (p : Param) => do
|
||||
if !p.borrow then pure p
|
||||
else if (← isOwned p.x) then
|
||||
markModified
|
||||
pure { p with borrow := false }
|
||||
else
|
||||
pure p
|
||||
modify fun s => { s with paramMap := s.paramMap.insert k ps }
|
||||
| none => pure ()
|
||||
|
||||
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
|
||||
let s ← get
|
||||
match s.paramMap.find? k with
|
||||
| some ps => pure ps
|
||||
| none =>
|
||||
match k with
|
||||
| ParamMap.Key.decl fn => do
|
||||
let ctx ← read
|
||||
match findEnvDecl ctx.env fn with
|
||||
| some decl => pure decl.params
|
||||
| none => unreachable!
|
||||
| _ => unreachable!
|
||||
let s ← get
|
||||
match s.paramMap.find? k with
|
||||
| some ps => pure ps
|
||||
| none =>
|
||||
match k with
|
||||
| ParamMap.Key.decl fn => do
|
||||
let ctx ← read
|
||||
match findEnvDecl ctx.env fn with
|
||||
| some decl => pure decl.params
|
||||
| none => unreachable!
|
||||
| _ => unreachable!
|
||||
|
||||
/- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
|
||||
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]
|
||||
let p := ps[i]
|
||||
unless p.borrow do ownArg x
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]
|
||||
let p := ps[i]
|
||||
unless p.borrow do ownArg x
|
||||
|
||||
/- For each xs[i], if xs[i] is owned, then mark ps[i] as owned.
|
||||
We use this action to preserve tail calls. That is, if we have
|
||||
|
|
@ -222,12 +223,12 @@ xs.size.forM fun i => do
|
|||
we would have to insert a `dec xs[i]` after `f xs` and consequently
|
||||
"break" the tail call. -/
|
||||
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]
|
||||
let p := ps[i]
|
||||
match x with
|
||||
| Arg.var x => if (← isOwned x) then ownVar p.x
|
||||
| _ => pure ()
|
||||
xs.size.forM fun i => do
|
||||
let x := xs[i]
|
||||
let p := ps[i]
|
||||
match x with
|
||||
| Arg.var x => if (← isOwned x) then ownVar p.x
|
||||
| _ => pure ()
|
||||
|
||||
/- Mark `xs[i]` as owned if it is one of the parameters `ps`.
|
||||
We use this action to mark function parameters that are being "packed" inside constructors.
|
||||
|
|
@ -239,85 +240,85 @@ xs.size.forM fun i => do
|
|||
ret z
|
||||
``` -/
|
||||
def ownArgsIfParam (xs : Array Arg) : M Unit := do
|
||||
let ctx ← read
|
||||
xs.forM fun x => do
|
||||
match x with
|
||||
| Arg.var x => if ctx.paramSet.contains x.idx then ownVar x
|
||||
| _ => pure ()
|
||||
let ctx ← read
|
||||
xs.forM fun x => do
|
||||
match x with
|
||||
| Arg.var x => if ctx.paramSet.contains x.idx then ownVar x
|
||||
| _ => pure ()
|
||||
|
||||
def collectExpr (z : VarId) : Expr → M Unit
|
||||
| Expr.reset _ x => ownVar z *> ownVar x
|
||||
| Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys
|
||||
| Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs
|
||||
| Expr.proj _ x => do
|
||||
if (← isOwned x) then ownVar z
|
||||
if (← isOwned z) then ownVar x
|
||||
| Expr.fap g xs => do
|
||||
let ps ← getParamInfo (ParamMap.Key.decl g)
|
||||
ownVar z *> ownArgsUsingParams xs ps
|
||||
| Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys
|
||||
| Expr.pap _ xs => ownVar z *> ownArgs xs
|
||||
| other => pure ()
|
||||
| Expr.reset _ x => ownVar z *> ownVar x
|
||||
| Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys
|
||||
| Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs
|
||||
| Expr.proj _ x => do
|
||||
if (← isOwned x) then ownVar z
|
||||
if (← isOwned z) then ownVar x
|
||||
| Expr.fap g xs => do
|
||||
let ps ← getParamInfo (ParamMap.Key.decl g)
|
||||
ownVar z *> ownArgsUsingParams xs ps
|
||||
| Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys
|
||||
| Expr.pap _ xs => ownVar z *> ownArgs xs
|
||||
| other => pure ()
|
||||
|
||||
def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do
|
||||
let ctx ← read
|
||||
match v, b with
|
||||
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) =>
|
||||
if ctx.currFn == g && x == z then
|
||||
-- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do
|
||||
let ps ← getParamInfo (ParamMap.Key.decl g)
|
||||
ownParamsUsingArgs ys ps
|
||||
| _, _ => pure ()
|
||||
let ctx ← read
|
||||
match v, b with
|
||||
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) =>
|
||||
if ctx.currFn == g && x == z then
|
||||
-- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do
|
||||
let ps ← getParamInfo (ParamMap.Key.decl g)
|
||||
ownParamsUsingArgs ys ps
|
||||
| _, _ => pure ()
|
||||
|
||||
def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx :=
|
||||
{ ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet }
|
||||
{ ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet }
|
||||
|
||||
partial def collectFnBody : FnBody → M Unit
|
||||
| FnBody.jdecl j ys v b => do
|
||||
withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v)
|
||||
let ctx ← read
|
||||
updateParamMap (ParamMap.Key.jp ctx.currFn j)
|
||||
collectFnBody b
|
||||
| FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b
|
||||
| FnBody.jmp j ys => do
|
||||
let ctx ← read
|
||||
let ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j)
|
||||
ownArgsUsingParams ys ps -- for making sure the join point can reuse
|
||||
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
|
||||
| FnBody.case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body
|
||||
| e => do unless e.isTerminal do collectFnBody e.body
|
||||
| FnBody.jdecl j ys v b => do
|
||||
withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v)
|
||||
let ctx ← read
|
||||
updateParamMap (ParamMap.Key.jp ctx.currFn j)
|
||||
collectFnBody b
|
||||
| FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b
|
||||
| FnBody.jmp j ys => do
|
||||
let ctx ← read
|
||||
let ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j)
|
||||
ownArgsUsingParams ys ps -- for making sure the join point can reuse
|
||||
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
|
||||
| FnBody.case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body
|
||||
| e => do unless e.isTerminal do collectFnBody e.body
|
||||
|
||||
partial def collectDecl : Decl → M Unit
|
||||
| Decl.fdecl f ys _ b =>
|
||||
withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) do
|
||||
collectFnBody b
|
||||
updateParamMap (ParamMap.Key.decl f)
|
||||
| _ => pure ()
|
||||
| Decl.fdecl f ys _ b =>
|
||||
withReader (fun ctx => let ctx := updateParamSet ctx ys; { ctx with currFn := f }) do
|
||||
collectFnBody b
|
||||
updateParamMap (ParamMap.Key.decl f)
|
||||
| _ => pure ()
|
||||
|
||||
/- Keep executing `x` until it reaches a fixpoint -/
|
||||
@[inline] partial def whileModifing (x : M Unit) : M Unit := do
|
||||
modify fun s => { s with modified := false }
|
||||
x
|
||||
let s ← get
|
||||
if s.modified then
|
||||
whileModifing x
|
||||
else
|
||||
pure ()
|
||||
modify fun s => { s with modified := false }
|
||||
x
|
||||
let s ← get
|
||||
if s.modified then
|
||||
whileModifing x
|
||||
else
|
||||
pure ()
|
||||
|
||||
def collectDecls (decls : Array Decl) : M ParamMap := do
|
||||
whileModifing (decls.forM collectDecl)
|
||||
let s ← get
|
||||
pure s.paramMap
|
||||
whileModifing (decls.forM collectDecl)
|
||||
let s ← get
|
||||
pure s.paramMap
|
||||
|
||||
def infer (env : Environment) (decls : Array Decl) : ParamMap :=
|
||||
collectDecls decls { env := env } $.run' { paramMap := mkInitParamMap env decls }
|
||||
collectDecls decls { env := env } $.run' { paramMap := mkInitParamMap env decls }
|
||||
|
||||
end Borrow
|
||||
|
||||
def inferBorrow (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let env ← getEnv
|
||||
let paramMap := Borrow.infer env decls
|
||||
pure (Borrow.applyParamMap decls paramMap)
|
||||
let env ← getEnv
|
||||
let paramMap := Borrow.infer env decls
|
||||
pure (Borrow.applyParamMap decls paramMap)
|
||||
|
||||
end IR
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -30,314 +30,320 @@ Assumptions:
|
|||
open Std (AssocList)
|
||||
|
||||
def mkBoxedName (n : Name) : Name :=
|
||||
mkNameStr n "_boxed"
|
||||
mkNameStr n "_boxed"
|
||||
|
||||
def isBoxedName : Name → Bool
|
||||
| Name.str _ "_boxed" _ => true
|
||||
| _ => false
|
||||
| Name.str _ "_boxed" _ => true
|
||||
| _ => false
|
||||
|
||||
abbrev N := StateM Nat
|
||||
|
||||
private def N.mkFresh : N VarId :=
|
||||
modifyGet fun n => ({ idx := n }, n + 1)
|
||||
modifyGet fun n => ({ idx := n }, n + 1)
|
||||
|
||||
def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
|
||||
let ps := decl.params
|
||||
(ps.size > 0 && (decl.resultType.isScalar || ps.any (fun p => p.ty.isScalar || p.borrow) || isExtern env decl.name))
|
||||
|| ps.size > closureMaxArgs
|
||||
let ps := decl.params
|
||||
(ps.size > 0 && (decl.resultType.isScalar || ps.any (fun p => p.ty.isScalar || p.borrow) || isExtern env decl.name))
|
||||
|| ps.size > closureMaxArgs
|
||||
|
||||
def mkBoxedVersionAux (decl : Decl) : N Decl := do
|
||||
let ps := decl.params
|
||||
let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
|
||||
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
|
||||
let p := ps[i]
|
||||
let q := qs[i]
|
||||
if !p.ty.isScalar then
|
||||
pure (newVDecls, xs.push (Arg.var q.x))
|
||||
else
|
||||
let x ← N.mkFresh
|
||||
pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (arbitrary _)), xs.push (Arg.var x))
|
||||
let r ← N.mkFresh
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (arbitrary _))
|
||||
let body ←
|
||||
if !decl.resultType.isScalar then
|
||||
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
|
||||
else
|
||||
let newR ← N.mkFresh
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (arbitrary _))
|
||||
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
|
||||
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
|
||||
let ps := decl.params
|
||||
let qs ← ps.mapM fun _ => do let x ← N.mkFresh; pure { x := x, ty := IRType.object, borrow := false : Param }
|
||||
let (newVDecls, xs) ← qs.size.foldM (init := (#[], #[])) fun i (newVDecls, xs) => do
|
||||
let p := ps[i]
|
||||
let q := qs[i]
|
||||
if !p.ty.isScalar then
|
||||
pure (newVDecls, xs.push (Arg.var q.x))
|
||||
else
|
||||
let x ← N.mkFresh
|
||||
pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (arbitrary _)), xs.push (Arg.var x))
|
||||
let r ← N.mkFresh
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (arbitrary _))
|
||||
let body ←
|
||||
if !decl.resultType.isScalar then
|
||||
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
|
||||
else
|
||||
let newR ← N.mkFresh
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (arbitrary _))
|
||||
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
|
||||
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
|
||||
|
||||
def mkBoxedVersion (decl : Decl) : Decl :=
|
||||
(mkBoxedVersionAux decl).run' 1
|
||||
(mkBoxedVersionAux decl).run' 1
|
||||
|
||||
def addBoxedVersions (env : Environment) (decls : Array Decl) : Array Decl :=
|
||||
let boxedDecls := decls.foldl (init := #[]) fun newDecls decl =>
|
||||
if requiresBoxedVersion env decl then newDecls.push (mkBoxedVersion decl) else newDecls
|
||||
decls ++ boxedDecls
|
||||
let boxedDecls := decls.foldl (init := #[]) fun newDecls decl =>
|
||||
if requiresBoxedVersion env decl then newDecls.push (mkBoxedVersion decl) else newDecls
|
||||
decls ++ boxedDecls
|
||||
|
||||
/- Infer scrutinee type using `case` alternatives.
|
||||
This can be done whenever `alts` does not contain an `Alt.default _` value. -/
|
||||
def getScrutineeType (alts : Array Alt) : IRType :=
|
||||
let isScalar :=
|
||||
alts.size > 1 && -- Recall that we encode Unit and PUnit using `object`.
|
||||
alts.all fun
|
||||
| Alt.ctor c _ => c.isScalar
|
||||
| Alt.default _ => false
|
||||
match isScalar with
|
||||
| false => IRType.object
|
||||
| true =>
|
||||
let n := alts.size
|
||||
if n < 256 then IRType.uint8
|
||||
else if n < 65536 then IRType.uint16
|
||||
else if n < 4294967296 then IRType.uint32
|
||||
else IRType.object -- in practice this should be unreachable
|
||||
let isScalar :=
|
||||
alts.size > 1 && -- Recall that we encode Unit and PUnit using `object`.
|
||||
alts.all fun
|
||||
| Alt.ctor c _ => c.isScalar
|
||||
| Alt.default _ => false
|
||||
match isScalar with
|
||||
| false => IRType.object
|
||||
| true =>
|
||||
let n := alts.size
|
||||
if n < 256 then IRType.uint8
|
||||
else if n < 65536 then IRType.uint16
|
||||
else if n < 4294967296 then IRType.uint32
|
||||
else IRType.object -- in practice this should be unreachable
|
||||
|
||||
def eqvTypes (t₁ t₂ : IRType) : Bool :=
|
||||
(t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂)
|
||||
(t₁.isScalar == t₂.isScalar) && (!t₁.isScalar || t₁ == t₂)
|
||||
|
||||
structure BoxingContext :=
|
||||
(f : FunId := arbitrary _) (localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment)
|
||||
(f : FunId := arbitrary _)
|
||||
(localCtx : LocalContext := {})
|
||||
(resultType : IRType := IRType.irrelevant)
|
||||
(decls : Array Decl)
|
||||
(env : Environment)
|
||||
|
||||
structure BoxingState :=
|
||||
(nextIdx : Index)
|
||||
/- We create auxiliary declarations when boxing constant and literals.
|
||||
The idea is to avoid code such as
|
||||
```
|
||||
let x1 := Uint64.inhabited;
|
||||
let x2 := box x1;
|
||||
...
|
||||
```
|
||||
We currently do not cache these declarations in an environment extension, but
|
||||
we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when
|
||||
processing the same IR declaration.
|
||||
-/
|
||||
(auxDecls : Array Decl := #[])
|
||||
(auxDeclCache : AssocList FnBody Expr := Std.AssocList.empty)
|
||||
(nextAuxId : Nat := 1)
|
||||
(nextIdx : Index)
|
||||
/- We create auxiliary declarations when boxing constant and literals.
|
||||
The idea is to avoid code such as
|
||||
```
|
||||
let x1 := Uint64.inhabited;
|
||||
let x2 := box x1;
|
||||
...
|
||||
```
|
||||
We currently do not cache these declarations in an environment extension, but
|
||||
we use auxDeclCache to avoid creating equivalent auxiliary declarations more than once when
|
||||
processing the same IR declaration.
|
||||
-/
|
||||
(auxDecls : Array Decl := #[])
|
||||
(auxDeclCache : AssocList FnBody Expr := Std.AssocList.empty)
|
||||
(nextAuxId : Nat := 1)
|
||||
|
||||
abbrev M := ReaderT BoxingContext (StateT BoxingState Id)
|
||||
|
||||
private def M.mkFresh : M VarId := do
|
||||
let oldS ← getModify fun s => { s with nextIdx := s.nextIdx + 1 }
|
||||
pure { idx := oldS.nextIdx }
|
||||
let oldS ← getModify fun s => { s with nextIdx := s.nextIdx + 1 }
|
||||
pure { idx := oldS.nextIdx }
|
||||
|
||||
def getEnv : M Environment := BoxingContext.env <$> read
|
||||
def getLocalContext : M LocalContext := BoxingContext.localCtx <$> read
|
||||
def getResultType : M IRType := BoxingContext.resultType <$> read
|
||||
|
||||
def getVarType (x : VarId) : M IRType := do
|
||||
let localCtx ← getLocalContext
|
||||
match localCtx.getType x with
|
||||
| some t => pure t
|
||||
| none => pure IRType.object -- unreachable, we assume the code is well formed
|
||||
let localCtx ← getLocalContext
|
||||
match localCtx.getType x with
|
||||
| some t => pure t
|
||||
| none => pure IRType.object -- unreachable, we assume the code is well formed
|
||||
|
||||
def getJPParams (j : JoinPointId) : M (Array Param) := do
|
||||
let localCtx ← getLocalContext
|
||||
match localCtx.getJPParams j with
|
||||
| some ys => pure ys
|
||||
| none => pure #[] -- unreachable, we assume the code is well formed
|
||||
let localCtx ← getLocalContext
|
||||
match localCtx.getJPParams j with
|
||||
| some ys => pure ys
|
||||
| none => pure #[] -- unreachable, we assume the code is well formed
|
||||
|
||||
def getDecl (fid : FunId) : M Decl := do
|
||||
let ctx ← read
|
||||
match findEnvDecl' ctx.env fid ctx.decls with
|
||||
| some decl => pure decl
|
||||
| none => pure (arbitrary _) -- unreachable if well-formed
|
||||
let ctx ← read
|
||||
match findEnvDecl' ctx.env fid ctx.decls with
|
||||
| some decl => pure decl
|
||||
| none => pure (arbitrary _) -- unreachable if well-formed
|
||||
|
||||
@[inline] def withParams {α : Type} (xs : Array Param) (k : M α) : M α :=
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addParams xs }) k
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addParams xs }) k
|
||||
|
||||
@[inline] def withVDecl {α : Type} (x : VarId) (ty : IRType) (v : Expr) (k : M α) : M α :=
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x ty v }) k
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x ty v }) k
|
||||
|
||||
@[inline] def withJDecl {α : Type} (j : JoinPointId) (xs : Array Param) (v : FnBody) (k : M α) : M α :=
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j xs v }) k
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j xs v }) k
|
||||
|
||||
/- If `x` declaration is of the form `x := Expr.lit _` or `x := Expr.fap c #[]`,
|
||||
and `x`'s type is not cheap to box (e.g., it is `UInt64), then return its value. -/
|
||||
private def isExpensiveConstantValueBoxing (x : VarId) (xType : IRType) : M (Option Expr) :=
|
||||
if !xType.isScalar then pure none -- We assume unboxing is always cheap
|
||||
else match xType with
|
||||
| IRType.uint8 => pure none
|
||||
| IRType.uint16 => pure none
|
||||
| _ => do
|
||||
let localCtx ← getLocalContext
|
||||
match localCtx.getValue x with
|
||||
| some val =>
|
||||
match val with
|
||||
| Expr.lit _ => pure $ some val
|
||||
| Expr.fap _ args => pure $ if args.size == 0 then some val else none
|
||||
| _ => pure none
|
||||
| _ => pure none
|
||||
if !xType.isScalar then
|
||||
pure none -- We assume unboxing is always cheap
|
||||
else match xType with
|
||||
| IRType.uint8 => pure none
|
||||
| IRType.uint16 => pure none
|
||||
| _ => do
|
||||
let localCtx ← getLocalContext
|
||||
match localCtx.getValue x with
|
||||
| some val =>
|
||||
match val with
|
||||
| Expr.lit _ => pure $ some val
|
||||
| Expr.fap _ args => pure $ if args.size == 0 then some val else none
|
||||
| _ => pure none
|
||||
| _ => pure none
|
||||
|
||||
/- Auxiliary function used by castVarIfNeeded.
|
||||
It is used when the expected type does not match `xType`.
|
||||
If `xType` is scalar, then we need to "box" it. Otherwise, we need to "unbox" it. -/
|
||||
def mkCast (x : VarId) (xType : IRType) (expectedType : IRType) : M Expr := do
|
||||
match (← isExpensiveConstantValueBoxing x xType) with
|
||||
| some v => do
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
/- Create auxiliary FnBody
|
||||
```
|
||||
let x_1 : xType := v;
|
||||
let x_2 : expectedType := Expr.box xType x_1;
|
||||
ret x_2
|
||||
```
|
||||
-/
|
||||
let body : FnBody :=
|
||||
FnBody.vdecl { idx := 1 } xType v $
|
||||
FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $
|
||||
FnBody.ret (mkVarArg { idx := 2 })
|
||||
match s.auxDeclCache.find? body with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId)
|
||||
let auxConst := Expr.fap auxName #[]
|
||||
let auxDecl := Decl.fdecl auxName #[] expectedType body
|
||||
modify fun s => {
|
||||
s with
|
||||
auxDecls := s.auxDecls.push auxDecl,
|
||||
auxDeclCache := s.auxDeclCache.cons body auxConst,
|
||||
nextAuxId := s.nextAuxId + 1
|
||||
}
|
||||
pure auxConst
|
||||
| none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x
|
||||
match (← isExpensiveConstantValueBoxing x xType) with
|
||||
| some v => do
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
/- Create auxiliary FnBody
|
||||
```
|
||||
let x_1 : xType := v;
|
||||
let x_2 : expectedType := Expr.box xType x_1;
|
||||
ret x_2
|
||||
```
|
||||
-/
|
||||
let body : FnBody :=
|
||||
FnBody.vdecl { idx := 1 } xType v $
|
||||
FnBody.vdecl { idx := 2 } expectedType (Expr.box xType { idx := 1 }) $
|
||||
FnBody.ret (mkVarArg { idx := 2 })
|
||||
match s.auxDeclCache.find? body with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let auxName := ctx.f ++ ((`_boxed_const).appendIndexAfter s.nextAuxId)
|
||||
let auxConst := Expr.fap auxName #[]
|
||||
let auxDecl := Decl.fdecl auxName #[] expectedType body
|
||||
modify fun s => {
|
||||
s with
|
||||
auxDecls := s.auxDecls.push auxDecl,
|
||||
auxDeclCache := s.auxDeclCache.cons body auxConst,
|
||||
nextAuxId := s.nextAuxId + 1
|
||||
}
|
||||
pure auxConst
|
||||
| none => pure $ if xType.isScalar then Expr.box xType x else Expr.unbox x
|
||||
|
||||
@[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody := do
|
||||
let xType ← getVarType x
|
||||
if eqvTypes xType expected then k x
|
||||
else
|
||||
let y ← M.mkFresh
|
||||
let v ← mkCast x xType expected
|
||||
FnBody.vdecl y expected v <$> k y
|
||||
let xType ← getVarType x
|
||||
if eqvTypes xType expected then
|
||||
k x
|
||||
else
|
||||
let y ← M.mkFresh
|
||||
let v ← mkCast x xType expected
|
||||
FnBody.vdecl y expected v <$> k y
|
||||
|
||||
@[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody :=
|
||||
match x with
|
||||
| Arg.var x => castVarIfNeeded x expected (fun x => k (Arg.var x))
|
||||
| _ => k x
|
||||
match x with
|
||||
| Arg.var x => castVarIfNeeded x expected (fun x => k (Arg.var x))
|
||||
| _ => k x
|
||||
|
||||
@[specialize] def castArgsIfNeededAux (xs : Array Arg) (typeFromIdx : Nat → IRType) : M (Array Arg × Array FnBody) := do
|
||||
let xs' := #[]
|
||||
let bs := #[]
|
||||
let i := 0
|
||||
for x in xs do
|
||||
let expected := typeFromIdx i
|
||||
match x with
|
||||
| Arg.irrelevant =>
|
||||
xs' := xs'.push x
|
||||
| Arg.var x =>
|
||||
let xType ← getVarType x
|
||||
if eqvTypes xType expected then
|
||||
xs' := xs'.push (Arg.var x)
|
||||
else
|
||||
let y ← M.mkFresh
|
||||
let v ← mkCast x xType expected
|
||||
let b := FnBody.vdecl y expected v FnBody.nil
|
||||
xs' := xs'.push (Arg.var y)
|
||||
bs := bs.push b
|
||||
i := i + 1
|
||||
return (xs', bs)
|
||||
let xs' := #[]
|
||||
let bs := #[]
|
||||
let i := 0
|
||||
for x in xs do
|
||||
let expected := typeFromIdx i
|
||||
match x with
|
||||
| Arg.irrelevant =>
|
||||
xs' := xs'.push x
|
||||
| Arg.var x =>
|
||||
let xType ← getVarType x
|
||||
if eqvTypes xType expected then
|
||||
xs' := xs'.push (Arg.var x)
|
||||
else
|
||||
let y ← M.mkFresh
|
||||
let v ← mkCast x xType expected
|
||||
let b := FnBody.vdecl y expected v FnBody.nil
|
||||
xs' := xs'.push (Arg.var y)
|
||||
bs := bs.push b
|
||||
i := i + 1
|
||||
return (xs', bs)
|
||||
|
||||
@[inline] def castArgsIfNeeded (xs : Array Arg) (ps : Array Param) (k : Array Arg → M FnBody) : M FnBody := do
|
||||
let (ys, bs) ← castArgsIfNeededAux xs fun i => ps[i].ty
|
||||
let b ← k ys
|
||||
pure (reshape bs b)
|
||||
let (ys, bs) ← castArgsIfNeededAux xs fun i => ps[i].ty
|
||||
let b ← k ys
|
||||
pure (reshape bs b)
|
||||
|
||||
@[inline] def boxArgsIfNeeded (xs : Array Arg) (k : Array Arg → M FnBody) : M FnBody := do
|
||||
let (ys, bs) ← castArgsIfNeededAux xs (fun _ => IRType.object)
|
||||
let b ← k ys
|
||||
pure (reshape bs b)
|
||||
let (ys, bs) ← castArgsIfNeededAux xs (fun _ => IRType.object)
|
||||
let b ← k ys
|
||||
pure (reshape bs b)
|
||||
|
||||
def unboxResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody := do
|
||||
if ty.isScalar then
|
||||
let y ← M.mkFresh
|
||||
pure $ FnBody.vdecl y IRType.object e (FnBody.vdecl x ty (Expr.unbox y) b)
|
||||
else
|
||||
pure $ FnBody.vdecl x ty e b
|
||||
if ty.isScalar then
|
||||
let y ← M.mkFresh
|
||||
pure $ FnBody.vdecl y IRType.object e (FnBody.vdecl x ty (Expr.unbox y) b)
|
||||
else
|
||||
pure $ FnBody.vdecl x ty e b
|
||||
|
||||
def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b : FnBody) : M FnBody := do
|
||||
if eqvTypes ty eType then
|
||||
pure $ FnBody.vdecl x ty e b
|
||||
else
|
||||
let y ← M.mkFresh
|
||||
let v ← mkCast y eType ty
|
||||
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b)
|
||||
if eqvTypes ty eType then
|
||||
pure $ FnBody.vdecl x ty e b
|
||||
else
|
||||
let y ← M.mkFresh
|
||||
let v ← mkCast y eType ty
|
||||
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty v b)
|
||||
|
||||
def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
|
||||
match e with
|
||||
| Expr.ctor c ys =>
|
||||
if c.isScalar && ty.isScalar then
|
||||
pure $ FnBody.vdecl x ty (Expr.lit (LitVal.num c.cidx)) b
|
||||
else
|
||||
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.ctor c ys) b
|
||||
| Expr.reuse w c u ys =>
|
||||
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b
|
||||
| Expr.fap f ys => do
|
||||
let decl ← getDecl f
|
||||
castArgsIfNeeded ys decl.params fun ys =>
|
||||
castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b
|
||||
| Expr.pap f ys => do
|
||||
let env ← getEnv
|
||||
let decl ← getDecl f
|
||||
let f := if requiresBoxedVersion env decl then mkBoxedName f else f
|
||||
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.pap f ys) b
|
||||
| Expr.ap f ys =>
|
||||
boxArgsIfNeeded ys fun ys =>
|
||||
unboxResultIfNeeded x ty (Expr.ap f ys) b
|
||||
| other =>
|
||||
pure $ FnBody.vdecl x ty e b
|
||||
match e with
|
||||
| Expr.ctor c ys =>
|
||||
if c.isScalar && ty.isScalar then
|
||||
pure $ FnBody.vdecl x ty (Expr.lit (LitVal.num c.cidx)) b
|
||||
else
|
||||
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.ctor c ys) b
|
||||
| Expr.reuse w c u ys =>
|
||||
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b
|
||||
| Expr.fap f ys => do
|
||||
let decl ← getDecl f
|
||||
castArgsIfNeeded ys decl.params fun ys =>
|
||||
castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b
|
||||
| Expr.pap f ys => do
|
||||
let env ← getEnv
|
||||
let decl ← getDecl f
|
||||
let f := if requiresBoxedVersion env decl then mkBoxedName f else f
|
||||
boxArgsIfNeeded ys fun ys => pure $ FnBody.vdecl x ty (Expr.pap f ys) b
|
||||
| Expr.ap f ys =>
|
||||
boxArgsIfNeeded ys fun ys =>
|
||||
unboxResultIfNeeded x ty (Expr.ap f ys) b
|
||||
| other =>
|
||||
pure $ FnBody.vdecl x ty e b
|
||||
|
||||
partial def visitFnBody : FnBody → M FnBody
|
||||
| FnBody.vdecl x t v b => do
|
||||
let b ← withVDecl x t v (visitFnBody b)
|
||||
visitVDeclExpr x t v b
|
||||
| FnBody.jdecl j xs v b => do
|
||||
let v ← withParams xs (visitFnBody v)
|
||||
let b ← withJDecl j xs v (visitFnBody b)
|
||||
pure $ FnBody.jdecl j xs v b
|
||||
| FnBody.uset x i y b => do
|
||||
let b ← visitFnBody b
|
||||
castVarIfNeeded y IRType.usize fun y =>
|
||||
pure $ FnBody.uset x i y b
|
||||
| FnBody.sset x i o y ty b => do
|
||||
let b ← visitFnBody b
|
||||
castVarIfNeeded y ty fun y =>
|
||||
pure $ FnBody.sset x i o y ty b
|
||||
| FnBody.mdata d b =>
|
||||
FnBody.mdata d <$> visitFnBody b
|
||||
| FnBody.case tid x _ alts => do
|
||||
let expected := getScrutineeType alts
|
||||
let alts ← alts.mapM fun alt => alt.mmodifyBody visitFnBody
|
||||
castVarIfNeeded x expected fun x => do
|
||||
pure $ FnBody.case tid x expected alts
|
||||
| FnBody.ret x => do
|
||||
let expected ← getResultType
|
||||
castArgIfNeeded x expected (fun x => pure $ FnBody.ret x)
|
||||
| FnBody.jmp j ys => do
|
||||
let ps ← getJPParams j
|
||||
castArgsIfNeeded ys ps fun ys => pure $ FnBody.jmp j ys
|
||||
| other =>
|
||||
pure other
|
||||
| FnBody.vdecl x t v b => do
|
||||
let b ← withVDecl x t v (visitFnBody b)
|
||||
visitVDeclExpr x t v b
|
||||
| FnBody.jdecl j xs v b => do
|
||||
let v ← withParams xs (visitFnBody v)
|
||||
let b ← withJDecl j xs v (visitFnBody b)
|
||||
pure $ FnBody.jdecl j xs v b
|
||||
| FnBody.uset x i y b => do
|
||||
let b ← visitFnBody b
|
||||
castVarIfNeeded y IRType.usize fun y =>
|
||||
pure $ FnBody.uset x i y b
|
||||
| FnBody.sset x i o y ty b => do
|
||||
let b ← visitFnBody b
|
||||
castVarIfNeeded y ty fun y =>
|
||||
pure $ FnBody.sset x i o y ty b
|
||||
| FnBody.mdata d b =>
|
||||
FnBody.mdata d <$> visitFnBody b
|
||||
| FnBody.case tid x _ alts => do
|
||||
let expected := getScrutineeType alts
|
||||
let alts ← alts.mapM fun alt => alt.mmodifyBody visitFnBody
|
||||
castVarIfNeeded x expected fun x => do
|
||||
pure $ FnBody.case tid x expected alts
|
||||
| FnBody.ret x => do
|
||||
let expected ← getResultType
|
||||
castArgIfNeeded x expected (fun x => pure $ FnBody.ret x)
|
||||
| FnBody.jmp j ys => do
|
||||
let ps ← getJPParams j
|
||||
castArgsIfNeeded ys ps fun ys => pure $ FnBody.jmp j ys
|
||||
| other =>
|
||||
pure other
|
||||
|
||||
def run (env : Environment) (decls : Array Decl) : Array Decl :=
|
||||
let ctx : BoxingContext := { decls := decls, env := env }
|
||||
let decls := decls.foldl (init := #[]) fun newDecls decl =>
|
||||
match decl with
|
||||
| Decl.fdecl f xs t b =>
|
||||
let nextIdx := decl.maxIndex + 1
|
||||
let (b, s) := (withParams xs (visitFnBody b) { ctx with f := f, resultType := t }).run { nextIdx := nextIdx }
|
||||
let newDecls := newDecls ++ s.auxDecls
|
||||
let newDecl := Decl.fdecl f xs t b
|
||||
let newDecl := newDecl.elimDead
|
||||
newDecls.push newDecl
|
||||
| d => newDecls.push d
|
||||
addBoxedVersions env decls
|
||||
let ctx : BoxingContext := { decls := decls, env := env }
|
||||
let decls := decls.foldl (init := #[]) fun newDecls decl =>
|
||||
match decl with
|
||||
| Decl.fdecl f xs t b =>
|
||||
let nextIdx := decl.maxIndex + 1
|
||||
let (b, s) := (withParams xs (visitFnBody b) { ctx with f := f, resultType := t }).run { nextIdx := nextIdx }
|
||||
let newDecls := newDecls ++ s.auxDecls
|
||||
let newDecl := Decl.fdecl f xs t b
|
||||
let newDecl := newDecl.elimDead
|
||||
newDecls.push newDecl
|
||||
| d => newDecls.push d
|
||||
addBoxedVersions env decls
|
||||
|
||||
end ExplicitBoxing
|
||||
|
||||
def explicitBoxing (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let env ← getEnv
|
||||
pure $ ExplicitBoxing.run env decls
|
||||
let env ← getEnv
|
||||
pure $ ExplicitBoxing.run env decls
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -9,159 +9,161 @@ import Lean.Compiler.IR.Format
|
|||
namespace Lean.IR.Checker
|
||||
|
||||
structure CheckerContext :=
|
||||
(env : Environment) (localCtx : LocalContext := {}) (decls : Array Decl)
|
||||
(env : Environment)
|
||||
(localCtx : LocalContext := {})
|
||||
(decls : Array Decl)
|
||||
|
||||
structure CheckerState :=
|
||||
(foundVars : IndexSet := {})
|
||||
(foundVars : IndexSet := {})
|
||||
|
||||
abbrev M := ReaderT CheckerContext (ExceptT String (StateT CheckerState Id))
|
||||
|
||||
def markIndex (i : Index) : M Unit := do
|
||||
let s ← get
|
||||
if s.foundVars.contains i then
|
||||
throw s!"variable / joinpoint index {i} has already been used"
|
||||
modify fun s => { s with foundVars := s.foundVars.insert i }
|
||||
let s ← get
|
||||
if s.foundVars.contains i then
|
||||
throw s!"variable / joinpoint index {i} has already been used"
|
||||
modify fun s => { s with foundVars := s.foundVars.insert i }
|
||||
|
||||
def markVar (x : VarId) : M Unit :=
|
||||
markIndex x.idx
|
||||
markIndex x.idx
|
||||
|
||||
def markJP (j : JoinPointId) : M Unit :=
|
||||
markIndex j.idx
|
||||
markIndex j.idx
|
||||
|
||||
def getDecl (c : Name) : M Decl := do
|
||||
let ctx ← read
|
||||
match findEnvDecl' ctx.env c ctx.decls with
|
||||
| none => throw s!"unknown declaration '{c}'"
|
||||
| some d => pure d
|
||||
let ctx ← read
|
||||
match findEnvDecl' ctx.env c ctx.decls with
|
||||
| none => throw s!"unknown declaration '{c}'"
|
||||
| some d => pure d
|
||||
|
||||
def checkVar (x : VarId) : M Unit := do
|
||||
let ctx ← read
|
||||
unless ctx.localCtx.isLocalVar x.idx || ctx.localCtx.isParam x.idx do
|
||||
throw s!"unknown variable '{x}'"
|
||||
let ctx ← read
|
||||
unless ctx.localCtx.isLocalVar x.idx || ctx.localCtx.isParam x.idx do
|
||||
throw s!"unknown variable '{x}'"
|
||||
|
||||
def checkJP (j : JoinPointId) : M Unit := do
|
||||
let ctx ← read
|
||||
unless ctx.localCtx.isJP j.idx do
|
||||
throw s!"unknown join point '{j}'"
|
||||
let ctx ← read
|
||||
unless ctx.localCtx.isJP j.idx do
|
||||
throw s!"unknown join point '{j}'"
|
||||
|
||||
def checkArg (a : Arg) : M Unit :=
|
||||
match a with
|
||||
| Arg.var x => checkVar x
|
||||
| other => pure ()
|
||||
match a with
|
||||
| Arg.var x => checkVar x
|
||||
| other => pure ()
|
||||
|
||||
def checkArgs (as : Array Arg) : M Unit :=
|
||||
as.forM checkArg
|
||||
as.forM checkArg
|
||||
|
||||
@[inline] def checkEqTypes (ty₁ ty₂ : IRType) : M Unit := do
|
||||
unless ty₁ == ty₂ do
|
||||
throw "unexpected type"
|
||||
unless ty₁ == ty₂ do
|
||||
throw "unexpected type"
|
||||
|
||||
@[inline] def checkType (ty : IRType) (p : IRType → Bool) : M Unit := do
|
||||
unless p ty do
|
||||
throw s!"unexpected type '{ty}'"
|
||||
unless p ty do
|
||||
throw s!"unexpected type '{ty}'"
|
||||
|
||||
def checkObjType (ty : IRType) : M Unit := checkType ty IRType.isObj
|
||||
|
||||
def checkScalarType (ty : IRType) : M Unit := checkType ty IRType.isScalar
|
||||
|
||||
def getType (x : VarId) : M IRType := do
|
||||
let ctx ← read
|
||||
match ctx.localCtx.getType x with
|
||||
| some ty => pure ty
|
||||
| none => throw s!"unknown variable '{x}'"
|
||||
let ctx ← read
|
||||
match ctx.localCtx.getType x with
|
||||
| some ty => pure ty
|
||||
| none => throw s!"unknown variable '{x}'"
|
||||
|
||||
@[inline] def checkVarType (x : VarId) (p : IRType → Bool) : M Unit := do
|
||||
let ty ← getType x; checkType ty p
|
||||
let ty ← getType x; checkType ty p
|
||||
|
||||
def checkObjVar (x : VarId) : M Unit :=
|
||||
checkVarType x IRType.isObj
|
||||
checkVarType x IRType.isObj
|
||||
|
||||
def checkScalarVar (x : VarId) : M Unit :=
|
||||
checkVarType x IRType.isScalar
|
||||
checkVarType x IRType.isScalar
|
||||
|
||||
def checkFullApp (c : FunId) (ys : Array Arg) : M Unit := do
|
||||
if c == `hugeFuel then
|
||||
throw "the auxiliary constant `hugeFuel` cannot be used in code, it is used internally for compiling `partial` definitions"
|
||||
let decl ← getDecl c
|
||||
unless ys.size == decl.params.size do
|
||||
throw s!"incorrect number of arguments to '{c}', {ys.size} provided, {decl.params.size} expected"
|
||||
checkArgs ys
|
||||
if c == `hugeFuel then
|
||||
throw "the auxiliary constant `hugeFuel` cannot be used in code, it is used internally for compiling `partial` definitions"
|
||||
let decl ← getDecl c
|
||||
unless ys.size == decl.params.size do
|
||||
throw s!"incorrect number of arguments to '{c}', {ys.size} provided, {decl.params.size} expected"
|
||||
checkArgs ys
|
||||
|
||||
def checkPartialApp (c : FunId) (ys : Array Arg) : M Unit := do
|
||||
let decl ← getDecl c
|
||||
unless ys.size < decl.params.size do
|
||||
throw s!"too many arguments too partial application '{c}', num. args: {ys.size}, arity: {decl.params.size}"
|
||||
checkArgs ys
|
||||
let decl ← getDecl c
|
||||
unless ys.size < decl.params.size do
|
||||
throw s!"too many arguments too partial application '{c}', num. args: {ys.size}, arity: {decl.params.size}"
|
||||
checkArgs ys
|
||||
|
||||
def checkExpr (ty : IRType) : Expr → M Unit
|
||||
| Expr.pap f ys => checkPartialApp f ys *> checkObjType ty -- partial applications should always produce a closure object
|
||||
| Expr.ap x ys => checkObjVar x *> checkArgs ys
|
||||
| Expr.fap f ys => checkFullApp f ys
|
||||
| Expr.ctor c ys => when (!ty.isStruct && !ty.isUnion && c.isRef) (checkObjType ty) *> checkArgs ys
|
||||
| Expr.reset _ x => checkObjVar x *> checkObjType ty
|
||||
| Expr.reuse x i u ys => checkObjVar x *> checkArgs ys *> checkObjType ty
|
||||
| Expr.box xty x => checkObjType ty *> checkScalarVar x *> checkVarType x (fun t => t == xty)
|
||||
| Expr.unbox x => checkScalarType ty *> checkObjVar x
|
||||
| Expr.proj i x => do
|
||||
let xType ← getType x;
|
||||
match xType with
|
||||
| IRType.object => checkObjType ty
|
||||
| IRType.tobject => checkObjType ty
|
||||
| IRType.struct _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index"
|
||||
| IRType.union _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index"
|
||||
| other => throw s!"unexpected IR type '{xType}'"
|
||||
| Expr.uproj _ x => checkObjVar x *> checkType ty (fun t => t == IRType.usize)
|
||||
| Expr.sproj _ _ x => checkObjVar x *> checkScalarType ty
|
||||
| Expr.isShared x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8)
|
||||
| Expr.isTaggedPtr x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8)
|
||||
| Expr.lit (LitVal.str _) => checkObjType ty
|
||||
| Expr.lit _ => pure ()
|
||||
| Expr.pap f ys => checkPartialApp f ys *> checkObjType ty -- partial applications should always produce a closure object
|
||||
| Expr.ap x ys => checkObjVar x *> checkArgs ys
|
||||
| Expr.fap f ys => checkFullApp f ys
|
||||
| Expr.ctor c ys => when (!ty.isStruct && !ty.isUnion && c.isRef) (checkObjType ty) *> checkArgs ys
|
||||
| Expr.reset _ x => checkObjVar x *> checkObjType ty
|
||||
| Expr.reuse x i u ys => checkObjVar x *> checkArgs ys *> checkObjType ty
|
||||
| Expr.box xty x => checkObjType ty *> checkScalarVar x *> checkVarType x (fun t => t == xty)
|
||||
| Expr.unbox x => checkScalarType ty *> checkObjVar x
|
||||
| Expr.proj i x => do
|
||||
let xType ← getType x;
|
||||
match xType with
|
||||
| IRType.object => checkObjType ty
|
||||
| IRType.tobject => checkObjType ty
|
||||
| IRType.struct _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index"
|
||||
| IRType.union _ tys => if h : i < tys.size then checkEqTypes (tys.get ⟨i,h⟩) ty else throw "invalid proj index"
|
||||
| other => throw s!"unexpected IR type '{xType}'"
|
||||
| Expr.uproj _ x => checkObjVar x *> checkType ty (fun t => t == IRType.usize)
|
||||
| Expr.sproj _ _ x => checkObjVar x *> checkScalarType ty
|
||||
| Expr.isShared x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8)
|
||||
| Expr.isTaggedPtr x => checkObjVar x *> checkType ty (fun t => t == IRType.uint8)
|
||||
| Expr.lit (LitVal.str _) => checkObjType ty
|
||||
| Expr.lit _ => pure ()
|
||||
|
||||
@[inline] def withParams (ps : Array Param) (k : M Unit) : M Unit := do
|
||||
let ctx ← read
|
||||
let localCtx ← ps.foldlM (init := ctx.localCtx) fun (ctx : LocalContext) p => do
|
||||
markVar p.x
|
||||
pure $ ctx.addParam p
|
||||
withReader (fun _ => { ctx with localCtx := localCtx }) k
|
||||
let ctx ← read
|
||||
let localCtx ← ps.foldlM (init := ctx.localCtx) fun (ctx : LocalContext) p => do
|
||||
markVar p.x
|
||||
pure $ ctx.addParam p
|
||||
withReader (fun _ => { ctx with localCtx := localCtx }) k
|
||||
|
||||
partial def checkFnBody : FnBody → M Unit
|
||||
| FnBody.vdecl x t v b => do
|
||||
checkExpr t v;
|
||||
markVar x;
|
||||
let ctx ← read
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x t v }) (checkFnBody b)
|
||||
| FnBody.jdecl j ys v b => do
|
||||
markJP j;
|
||||
withParams ys (checkFnBody v);
|
||||
let ctx ← read
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j ys v }) (checkFnBody b)
|
||||
| FnBody.set x _ y b => checkVar x *> checkArg y *> checkFnBody b
|
||||
| FnBody.uset x _ y b => checkVar x *> checkVar y *> checkFnBody b
|
||||
| FnBody.sset x _ _ y _ b => checkVar x *> checkVar y *> checkFnBody b
|
||||
| FnBody.setTag x _ b => checkVar x *> checkFnBody b
|
||||
| FnBody.inc x _ _ _ b => checkVar x *> checkFnBody b
|
||||
| FnBody.dec x _ _ _ b => checkVar x *> checkFnBody b
|
||||
| FnBody.del x b => checkVar x *> checkFnBody b
|
||||
| FnBody.mdata _ b => checkFnBody b
|
||||
| FnBody.jmp j ys => checkJP j *> checkArgs ys
|
||||
| FnBody.ret x => checkArg x
|
||||
| FnBody.case _ x _ alts => checkVar x *> alts.forM (fun alt => checkFnBody alt.body)
|
||||
| FnBody.unreachable => pure ()
|
||||
| FnBody.vdecl x t v b => do
|
||||
checkExpr t v;
|
||||
markVar x;
|
||||
let ctx ← read
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addLocal x t v }) (checkFnBody b)
|
||||
| FnBody.jdecl j ys v b => do
|
||||
markJP j;
|
||||
withParams ys (checkFnBody v);
|
||||
let ctx ← read
|
||||
withReader (fun ctx => { ctx with localCtx := ctx.localCtx.addJP j ys v }) (checkFnBody b)
|
||||
| FnBody.set x _ y b => checkVar x *> checkArg y *> checkFnBody b
|
||||
| FnBody.uset x _ y b => checkVar x *> checkVar y *> checkFnBody b
|
||||
| FnBody.sset x _ _ y _ b => checkVar x *> checkVar y *> checkFnBody b
|
||||
| FnBody.setTag x _ b => checkVar x *> checkFnBody b
|
||||
| FnBody.inc x _ _ _ b => checkVar x *> checkFnBody b
|
||||
| FnBody.dec x _ _ _ b => checkVar x *> checkFnBody b
|
||||
| FnBody.del x b => checkVar x *> checkFnBody b
|
||||
| FnBody.mdata _ b => checkFnBody b
|
||||
| FnBody.jmp j ys => checkJP j *> checkArgs ys
|
||||
| FnBody.ret x => checkArg x
|
||||
| FnBody.case _ x _ alts => checkVar x *> alts.forM (fun alt => checkFnBody alt.body)
|
||||
| FnBody.unreachable => pure ()
|
||||
|
||||
def checkDecl : Decl → M Unit
|
||||
| Decl.fdecl f xs t b => withParams xs (checkFnBody b)
|
||||
| Decl.extern f xs t _ => withParams xs (pure ())
|
||||
| Decl.fdecl f xs t b => withParams xs (checkFnBody b)
|
||||
| Decl.extern f xs t _ => withParams xs (pure ())
|
||||
|
||||
end Checker
|
||||
|
||||
def checkDecl (decls : Array Decl) (decl : Decl) : CompilerM Unit := do
|
||||
let env ← getEnv
|
||||
match (Checker.checkDecl decl { env := env, decls := decls }).run' {} with
|
||||
| Except.error msg => throw s!"IR check failed at '{decl.name}', error: {msg}"
|
||||
| other => pure ()
|
||||
let env ← getEnv
|
||||
match (Checker.checkDecl decl { env := env, decls := decls }).run' {} with
|
||||
| Except.error msg => throw s!"IR check failed at '{decl.name}', error: {msg}"
|
||||
| other => pure ()
|
||||
|
||||
def checkDecls (decls : Array Decl) : CompilerM Unit :=
|
||||
decls.forM (checkDecl decls)
|
||||
decls.forM (checkDecl decls)
|
||||
|
||||
end IR
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -10,13 +10,13 @@ import Lean.Compiler.IR.Format
|
|||
namespace Lean.IR
|
||||
|
||||
inductive LogEntry
|
||||
| step (cls : Name) (decls : Array Decl)
|
||||
| message (msg : Format)
|
||||
| step (cls : Name) (decls : Array Decl)
|
||||
| message (msg : Format)
|
||||
|
||||
namespace LogEntry
|
||||
protected def fmt : LogEntry → Format
|
||||
| step cls decls => Format.bracket "[" (format cls) "]" ++ decls.foldl (fun fmt decl => fmt ++ Format.line ++ format decl) Format.nil
|
||||
| message msg => msg
|
||||
| step cls decls => Format.bracket "[" (format cls) "]" ++ decls.foldl (fun fmt decl => fmt ++ Format.line ++ format decl) Format.nil
|
||||
| message msg => msg
|
||||
|
||||
instance : ToFormat LogEntry := ⟨LogEntry.fmt⟩
|
||||
end LogEntry
|
||||
|
|
@ -24,49 +24,50 @@ end LogEntry
|
|||
abbrev Log := Array LogEntry
|
||||
|
||||
def Log.format (log : Log) : Format :=
|
||||
log.foldl (init := Format.nil) fun fmt entry =>
|
||||
f!"{fmt}{Format.line}{entry}"
|
||||
log.foldl (init := Format.nil) fun fmt entry =>
|
||||
f!"{fmt}{Format.line}{entry}"
|
||||
|
||||
@[export lean_ir_log_to_string]
|
||||
def Log.toString (log : Log) : String :=
|
||||
log.format.pretty
|
||||
log.format.pretty
|
||||
|
||||
structure CompilerState :=
|
||||
(env : Environment) (log : Log := #[])
|
||||
(env : Environment)
|
||||
(log : Log := #[])
|
||||
|
||||
abbrev CompilerM := ReaderT Options (EStateM String CompilerState)
|
||||
|
||||
def log (entry : LogEntry) : CompilerM Unit :=
|
||||
modify $ fun s => { s with log := s.log.push entry }
|
||||
modify fun s => { s with log := s.log.push entry }
|
||||
|
||||
def tracePrefixOptionName := `trace.compiler.ir
|
||||
|
||||
private def isLogEnabledFor (opts : Options) (optName : Name) : Bool :=
|
||||
match opts.find optName with
|
||||
| some (DataValue.ofBool v) => v
|
||||
| other => opts.getBool tracePrefixOptionName
|
||||
match opts.find optName with
|
||||
| some (DataValue.ofBool v) => v
|
||||
| other => opts.getBool tracePrefixOptionName
|
||||
|
||||
private def logDeclsAux (optName : Name) (cls : Name) (decls : Array Decl) : CompilerM Unit := do
|
||||
let opts ← read
|
||||
if isLogEnabledFor opts optName then
|
||||
log (LogEntry.step cls decls)
|
||||
let opts ← read
|
||||
if isLogEnabledFor opts optName then
|
||||
log (LogEntry.step cls decls)
|
||||
|
||||
@[inline] def logDecls (cls : Name) (decl : Array Decl) : CompilerM Unit :=
|
||||
logDeclsAux (tracePrefixOptionName ++ cls) cls decl
|
||||
logDeclsAux (tracePrefixOptionName ++ cls) cls decl
|
||||
|
||||
private def logMessageIfAux {α : Type} [ToFormat α] (optName : Name) (a : α) : CompilerM Unit := do
|
||||
let opts ← read
|
||||
if isLogEnabledFor opts optName then
|
||||
log (LogEntry.message (format a))
|
||||
let opts ← read
|
||||
if isLogEnabledFor opts optName then
|
||||
log (LogEntry.message (format a))
|
||||
|
||||
@[inline] def logMessageIf {α : Type} [ToFormat α] (cls : Name) (a : α) : CompilerM Unit :=
|
||||
logMessageIfAux (tracePrefixOptionName ++ cls) a
|
||||
logMessageIfAux (tracePrefixOptionName ++ cls) a
|
||||
|
||||
@[inline] def logMessage {α : Type} [ToFormat α] (cls : Name) (a : α) : CompilerM Unit :=
|
||||
logMessageIfAux tracePrefixOptionName a
|
||||
logMessageIfAux tracePrefixOptionName a
|
||||
|
||||
@[inline] def modifyEnv (f : Environment → Environment) : CompilerM Unit :=
|
||||
modify fun s => { s with env := f s.env }
|
||||
modify fun s => { s with env := f s.env }
|
||||
|
||||
open Std (HashMap)
|
||||
|
||||
|
|
@ -75,10 +76,10 @@ abbrev DeclMap := SMap Name Decl
|
|||
/- Create an array of decls to be saved on .olean file.
|
||||
`decls` may contain duplicate entries, but we assume the one that occurs last is the most recent one. -/
|
||||
private def mkEntryArray (decls : List Decl) : Array Decl :=
|
||||
/- Remove duplicates by adding decls into a map -/
|
||||
let map : HashMap Name Decl := {}
|
||||
let map := decls.foldl (init := map) fun map decl => map.insert decl.name decl
|
||||
map.fold (fun a k v => a.push v) #[]
|
||||
/- Remove duplicates by adding decls into a map -/
|
||||
let map : HashMap Name Decl := {}
|
||||
let map := decls.foldl (init := map) fun map decl => map.insert decl.name decl
|
||||
map.fold (fun a k v => a.push v) #[]
|
||||
|
||||
builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
|
|
@ -92,53 +93,54 @@ builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ←
|
|||
|
||||
@[export lean_ir_find_env_decl]
|
||||
def findEnvDecl (env : Environment) (n : Name) : Option Decl :=
|
||||
(declMapExt.getState env).find? n
|
||||
(declMapExt.getState env).find? n
|
||||
|
||||
def findDecl (n : Name) : CompilerM (Option Decl) := do
|
||||
let s ← get
|
||||
pure $ findEnvDecl s.env n
|
||||
let s ← get
|
||||
pure $ findEnvDecl s.env n
|
||||
|
||||
def containsDecl (n : Name) : CompilerM Bool := do
|
||||
let s ← get
|
||||
pure $ (declMapExt.getState s.env).contains n
|
||||
|
||||
def getDecl (n : Name) : CompilerM Decl := do
|
||||
let (some decl) ← findDecl n | throw s!"unknown declaration '{n}'"
|
||||
pure decl
|
||||
|
||||
@[export lean_ir_add_decl]
|
||||
def addDeclAux (env : Environment) (decl : Decl) : Environment :=
|
||||
declMapExt.addEntry env decl
|
||||
|
||||
def getDecls (env : Environment) : List Decl :=
|
||||
declMapExt.getEntries env
|
||||
|
||||
def getEnv : CompilerM Environment := do
|
||||
let s ← get; pure s.env
|
||||
|
||||
def addDecl (decl : Decl) : CompilerM Unit :=
|
||||
modifyEnv fun env => declMapExt.addEntry env decl
|
||||
|
||||
def addDecls (decls : Array Decl) : CompilerM Unit :=
|
||||
decls.forM addDecl
|
||||
|
||||
def findEnvDecl' (env : Environment) (n : Name) (decls : Array Decl) : Option Decl :=
|
||||
match decls.find? (fun decl => decl.name == n) with
|
||||
| some decl => some decl
|
||||
| none => (declMapExt.getState env).find? n
|
||||
|
||||
def findDecl' (n : Name) (decls : Array Decl) : CompilerM (Option Decl) := do
|
||||
let s ← get; pure $ findEnvDecl' s.env n decls
|
||||
|
||||
def containsDecl' (n : Name) (decls : Array Decl) : CompilerM Bool :=
|
||||
if decls.any (fun decl => decl.name == n) then pure true
|
||||
else do
|
||||
let s ← get
|
||||
pure $ (declMapExt.getState s.env).contains n
|
||||
|
||||
def getDecl (n : Name) : CompilerM Decl := do
|
||||
let (some decl) ← findDecl n | throw s!"unknown declaration '{n}'"
|
||||
pure decl
|
||||
|
||||
@[export lean_ir_add_decl]
|
||||
def addDeclAux (env : Environment) (decl : Decl) : Environment :=
|
||||
declMapExt.addEntry env decl
|
||||
|
||||
def getDecls (env : Environment) : List Decl :=
|
||||
declMapExt.getEntries env
|
||||
|
||||
def getEnv : CompilerM Environment := do
|
||||
let s ← get; pure s.env
|
||||
|
||||
def addDecl (decl : Decl) : CompilerM Unit :=
|
||||
modifyEnv fun env => declMapExt.addEntry env decl
|
||||
|
||||
def addDecls (decls : Array Decl) : CompilerM Unit :=
|
||||
decls.forM addDecl
|
||||
|
||||
def findEnvDecl' (env : Environment) (n : Name) (decls : Array Decl) : Option Decl :=
|
||||
match decls.find? (fun decl => decl.name == n) with
|
||||
| some decl => some decl
|
||||
| none => (declMapExt.getState env).find? n
|
||||
|
||||
def findDecl' (n : Name) (decls : Array Decl) : CompilerM (Option Decl) := do
|
||||
let s ← get; pure $ findEnvDecl' s.env n decls
|
||||
|
||||
def containsDecl' (n : Name) (decls : Array Decl) : CompilerM Bool := do
|
||||
if decls.any fun decl => decl.name == n then
|
||||
pure true
|
||||
else
|
||||
let s ← get
|
||||
pure $ (declMapExt.getState s.env).contains n
|
||||
|
||||
def getDecl' (n : Name) (decls : Array Decl) : CompilerM Decl := do
|
||||
let (some decl) ← findDecl' n decls | throw s!"unknown declaration '{n}'"
|
||||
pure decl
|
||||
let (some decl) ← findDecl' n decls | throw s!"unknown declaration '{n}'"
|
||||
pure decl
|
||||
|
||||
end IR
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -10,32 +10,32 @@ namespace Lean
|
|||
namespace IR
|
||||
|
||||
inductive CtorFieldInfo
|
||||
| irrelevant
|
||||
| object (i : Nat)
|
||||
| usize (i : Nat)
|
||||
| scalar (sz : Nat) (offset : Nat) (type : IRType)
|
||||
| irrelevant
|
||||
| object (i : Nat)
|
||||
| usize (i : Nat)
|
||||
| scalar (sz : Nat) (offset : Nat) (type : IRType)
|
||||
|
||||
namespace CtorFieldInfo
|
||||
|
||||
def format : CtorFieldInfo → Format
|
||||
| irrelevant => "◾"
|
||||
| object i => f!"obj@{i}"
|
||||
| usize i => f!"usize@{i}"
|
||||
| scalar sz offset type => f!"scalar#{sz}@{offset}:{type}"
|
||||
| irrelevant => "◾"
|
||||
| object i => f!"obj@{i}"
|
||||
| usize i => f!"usize@{i}"
|
||||
| scalar sz offset type => f!"scalar#{sz}@{offset}:{type}"
|
||||
|
||||
instance : ToFormat CtorFieldInfo := ⟨format⟩
|
||||
|
||||
end CtorFieldInfo
|
||||
|
||||
structure CtorLayout :=
|
||||
(cidx : Nat)
|
||||
(fieldInfo : List CtorFieldInfo)
|
||||
(numObjs : Nat)
|
||||
(numUSize : Nat)
|
||||
(scalarSize : Nat)
|
||||
(cidx : Nat)
|
||||
(fieldInfo : List CtorFieldInfo)
|
||||
(numObjs : Nat)
|
||||
(numUSize : Nat)
|
||||
(scalarSize : Nat)
|
||||
|
||||
@[extern "lean_ir_get_ctor_layout"]
|
||||
constant getCtorLayout (env : @& Environment) (ctorName : @& Name) : Except String CtorLayout := arbitrary _
|
||||
constant getCtorLayout (env : @& Environment) (ctorName : @& Name) : Except String CtorLayout
|
||||
|
||||
end IR
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -11,51 +11,51 @@ namespace Lean.IR.UnreachableBranches
|
|||
|
||||
/-- Value used in the abstract interpreter -/
|
||||
inductive Value
|
||||
| bot -- undefined
|
||||
| top -- any value
|
||||
| ctor (i : CtorInfo) (vs : Array Value)
|
||||
| choice (vs : List Value)
|
||||
| bot -- undefined
|
||||
| top -- any value
|
||||
| ctor (i : CtorInfo) (vs : Array Value)
|
||||
| choice (vs : List Value)
|
||||
|
||||
namespace Value
|
||||
|
||||
instance : Inhabited Value := ⟨top⟩
|
||||
|
||||
protected partial def beq : Value → Value → Bool
|
||||
| bot, bot => true
|
||||
| top, top => true
|
||||
| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ Value.beq
|
||||
| choice vs₁, choice vs₂ =>
|
||||
vs₁.all (fun v₁ => vs₂.any fun v₂ => Value.beq v₁ v₂)
|
||||
&&
|
||||
vs₂.all (fun v₂ => vs₁.any fun v₁ => Value.beq v₁ v₂)
|
||||
| _, _ => false
|
||||
| bot, bot => true
|
||||
| top, top => true
|
||||
| ctor i₁ vs₁, ctor i₂ vs₂ => i₁ == i₂ && Array.isEqv vs₁ vs₂ Value.beq
|
||||
| choice vs₁, choice vs₂ =>
|
||||
vs₁.all (fun v₁ => vs₂.any fun v₂ => Value.beq v₁ v₂)
|
||||
&&
|
||||
vs₂.all (fun v₂ => vs₁.any fun v₁ => Value.beq v₁ v₂)
|
||||
| _, _ => false
|
||||
|
||||
instance : BEq Value := ⟨Value.beq⟩
|
||||
|
||||
partial def addChoice (merge : Value → Value → Value) : List Value → Value → List Value
|
||||
| [], v => [v]
|
||||
| v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then merge v₁ v₂ :: cs
|
||||
else v₁ :: addChoice merge cs v₂
|
||||
| _, _ => panic! "invalid addChoice"
|
||||
| [], v => [v]
|
||||
| v₁@(ctor i₁ vs₁) :: cs, v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then merge v₁ v₂ :: cs
|
||||
else v₁ :: addChoice merge cs v₂
|
||||
| _, _ => panic! "invalid addChoice"
|
||||
|
||||
partial def merge : Value → Value → Value
|
||||
| bot, v => v
|
||||
| v, bot => v
|
||||
| top, _ => top
|
||||
| _, top => top
|
||||
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then ctor i₁ $ vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i] vs₂[i])
|
||||
else choice [v₁, v₂]
|
||||
| choice vs₁, choice vs₂ => choice $ vs₁.foldl (addChoice merge) vs₂
|
||||
| choice vs, v => choice $ addChoice merge vs v
|
||||
| v, choice vs => choice $ addChoice merge vs v
|
||||
| bot, v => v
|
||||
| v, bot => v
|
||||
| top, _ => top
|
||||
| _, top => top
|
||||
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
|
||||
if i₁ == i₂ then ctor i₁ $ vs₁.size.fold (init := #[]) fun i r => r.push (merge vs₁[i] vs₂[i])
|
||||
else choice [v₁, v₂]
|
||||
| choice vs₁, choice vs₂ => choice $ vs₁.foldl (addChoice merge) vs₂
|
||||
| choice vs, v => choice $ addChoice merge vs v
|
||||
| v, choice vs => choice $ addChoice merge vs v
|
||||
|
||||
protected partial def format : Value → Format
|
||||
| top => "top"
|
||||
| bot => "bot"
|
||||
| choice vs => fmt "@" ++ @List.format _ ⟨Value.format⟩ vs
|
||||
| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨Value.format⟩ vs)
|
||||
| top => "top"
|
||||
| bot => "bot"
|
||||
| choice vs => fmt "@" ++ @List.format _ ⟨Value.format⟩ vs
|
||||
| ctor i vs => fmt "#" ++ if vs.isEmpty then fmt i.name else Format.paren (fmt i.name ++ @formatArray _ ⟨Value.format⟩ vs)
|
||||
|
||||
instance : ToFormat Value := ⟨Value.format⟩
|
||||
instance : ToString Value := ⟨Format.pretty ∘ Value.format⟩
|
||||
|
|
@ -64,240 +64,236 @@ instance : ToString Value := ⟨Format.pretty ∘ Value.format⟩
|
|||
We use this function this function to implement a simple widening operation for our abstract
|
||||
interpreter. -/
|
||||
partial def truncate (env : Environment) : Value → NameSet → Value
|
||||
| ctor i vs, found =>
|
||||
let I := i.name.getPrefix
|
||||
if found.contains I then
|
||||
top
|
||||
else
|
||||
let cont (found' : NameSet) : Value :=
|
||||
ctor i (vs.map fun v => truncate env v found')
|
||||
match env.find? I with
|
||||
| some (ConstantInfo.inductInfo d) =>
|
||||
if d.isRec then cont (found.insert I)
|
||||
else cont found
|
||||
| _ => cont found
|
||||
| choice vs, found =>
|
||||
let newVs := vs.map fun v => truncate env v found
|
||||
if newVs.elem top then top
|
||||
else choice newVs
|
||||
| v, _ => v
|
||||
| ctor i vs, found =>
|
||||
let I := i.name.getPrefix
|
||||
if found.contains I then
|
||||
top
|
||||
else
|
||||
let cont (found' : NameSet) : Value :=
|
||||
ctor i (vs.map fun v => truncate env v found')
|
||||
match env.find? I with
|
||||
| some (ConstantInfo.inductInfo d) =>
|
||||
if d.isRec then cont (found.insert I)
|
||||
else cont found
|
||||
| _ => cont found
|
||||
| choice vs, found =>
|
||||
let newVs := vs.map fun v => truncate env v found
|
||||
if newVs.elem top then top
|
||||
else choice newVs
|
||||
| v, _ => v
|
||||
|
||||
/- Widening operator that guarantees termination in our abstract interpreter. -/
|
||||
def widening (env : Environment) (v₁ v₂ : Value) : Value :=
|
||||
truncate env (merge v₁ v₂) {}
|
||||
truncate env (merge v₁ v₂) {}
|
||||
|
||||
end Value
|
||||
|
||||
abbrev FunctionSummaries := SMap FunId Value
|
||||
|
||||
def mkFunctionSummariesExtension : IO (SimplePersistentEnvExtension (FunId × Value) FunctionSummaries) :=
|
||||
registerSimplePersistentEnvExtension {
|
||||
name := `unreachBranchesFunSummary,
|
||||
addImportedFn := fun as =>
|
||||
let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as
|
||||
cache.switch,
|
||||
addEntryFn := fun s ⟨e, n⟩ => s.insert e n
|
||||
}
|
||||
|
||||
@[builtinInit mkFunctionSummariesExtension]
|
||||
constant functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries := arbitrary _
|
||||
builtin_initialize functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
name := `unreachBranchesFunSummary,
|
||||
addImportedFn := fun as =>
|
||||
let cache : FunctionSummaries := mkStateFromImportedEntries (fun s (p : FunId × Value) => s.insert p.1 p.2) {} as
|
||||
cache.switch,
|
||||
addEntryFn := fun s ⟨e, n⟩ => s.insert e n
|
||||
}
|
||||
|
||||
def addFunctionSummary (env : Environment) (fid : FunId) (v : Value) : Environment :=
|
||||
functionSummariesExt.addEntry env (fid, v)
|
||||
functionSummariesExt.addEntry env (fid, v)
|
||||
|
||||
def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
|
||||
(functionSummariesExt.getState env).find? fid
|
||||
(functionSummariesExt.getState env).find? fid
|
||||
|
||||
abbrev Assignment := Std.HashMap VarId Value
|
||||
|
||||
structure InterpContext :=
|
||||
(currFnIdx : Nat := 0)
|
||||
(decls : Array Decl)
|
||||
(env : Environment)
|
||||
(lctx : LocalContext := {})
|
||||
(currFnIdx : Nat := 0)
|
||||
(decls : Array Decl)
|
||||
(env : Environment)
|
||||
(lctx : LocalContext := {})
|
||||
|
||||
structure InterpState :=
|
||||
(assignments : Array Assignment)
|
||||
(funVals : Std.PArray Value) -- we take snapshots during fixpoint computations
|
||||
(assignments : Array Assignment)
|
||||
(funVals : Std.PArray Value) -- we take snapshots during fixpoint computations
|
||||
|
||||
abbrev M := ReaderT InterpContext (StateM InterpState)
|
||||
|
||||
open Value
|
||||
|
||||
def findVarValue (x : VarId) : M Value := do
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let assignment := s.assignments[ctx.currFnIdx]
|
||||
pure $ assignment.findD x bot
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let assignment := s.assignments[ctx.currFnIdx]
|
||||
pure $ assignment.findD x bot
|
||||
|
||||
def findArgValue (arg : Arg) : M Value :=
|
||||
match arg with
|
||||
| Arg.var x => findVarValue x
|
||||
| _ => pure top
|
||||
match arg with
|
||||
| Arg.var x => findVarValue x
|
||||
| _ => pure top
|
||||
|
||||
def updateVarAssignment (x : VarId) (v : Value) : M Unit := do
|
||||
let v' ← findVarValue x
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') }
|
||||
let v' ← findVarValue x
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') }
|
||||
|
||||
def resetVarAssignment (x : VarId) : M Unit := do
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot }
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot }
|
||||
|
||||
def resetParamAssignment (y : Param) : M Unit :=
|
||||
resetVarAssignment y.x
|
||||
resetVarAssignment y.x
|
||||
|
||||
partial def projValue : Value → Nat → Value
|
||||
| ctor _ vs, i => vs.getD i bot
|
||||
| choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot
|
||||
| v, _ => v
|
||||
| ctor _ vs, i => vs.getD i bot
|
||||
| choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot
|
||||
| v, _ => v
|
||||
|
||||
def interpExpr : Expr → M Value
|
||||
| Expr.ctor i ys => do return ctor i (← ys.mapM fun y => findArgValue y)
|
||||
| Expr.proj i x => do return projValue (← findVarValue x) i
|
||||
| Expr.fap fid ys => do
|
||||
let ctx ← read
|
||||
match getFunctionSummary? ctx.env fid with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let s ← get
|
||||
match ctx.decls.findIdx? (fun decl => decl.name == fid) with
|
||||
| some idx => pure s.funVals[idx]
|
||||
| none => pure top
|
||||
| _ => pure top
|
||||
| Expr.ctor i ys => do return ctor i (← ys.mapM fun y => findArgValue y)
|
||||
| Expr.proj i x => do return projValue (← findVarValue x) i
|
||||
| Expr.fap fid ys => do
|
||||
let ctx ← read
|
||||
match getFunctionSummary? ctx.env fid with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let s ← get
|
||||
match ctx.decls.findIdx? (fun decl => decl.name == fid) with
|
||||
| some idx => pure s.funVals[idx]
|
||||
| none => pure top
|
||||
| _ => pure top
|
||||
|
||||
partial def containsCtor : Value → CtorInfo → Bool
|
||||
| top, _ => true
|
||||
| ctor i _, j => i == j
|
||||
| choice vs, j => vs.any $ fun v => containsCtor v j
|
||||
| _, _ => false
|
||||
| top, _ => true
|
||||
| ctor i _, j => i == j
|
||||
| choice vs, j => vs.any $ fun v => containsCtor v j
|
||||
| _, _ => false
|
||||
|
||||
def updateCurrFnSummary (v : Value) : M Unit := do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
|
||||
|
||||
/-- Return true if the assignment of at least one parameter has been updated. -/
|
||||
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.size.foldM (init := false) fun i r => do
|
||||
let y := ys[i]
|
||||
let x := xs[i]
|
||||
let yVal ← findVarValue y.x
|
||||
let xVal ← findArgValue x
|
||||
let newVal := merge yVal xVal
|
||||
if newVal == yVal then
|
||||
pure r
|
||||
else
|
||||
modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal }
|
||||
pure true
|
||||
|
||||
private partial def resetNestedJPParams : FnBody → M Unit
|
||||
| FnBody.jdecl _ ys b k => do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.forM resetParamAssignment
|
||||
/- Remark we don't need to reset the parameters of joint-points
|
||||
nested in `b` since they will be reset if this JP is used. -/
|
||||
resetNestedJPParams k
|
||||
| FnBody.case _ _ _ alts =>
|
||||
alts.forM fun alt => match alt with
|
||||
| Alt.ctor _ b => resetNestedJPParams b
|
||||
| Alt.default b => resetNestedJPParams b
|
||||
| e => do unless e.isTerminal do resetNestedJPParams e.body
|
||||
ys.size.foldM (init := false) fun i r => do
|
||||
let y := ys[i]
|
||||
let x := xs[i]
|
||||
let yVal ← findVarValue y.x
|
||||
let xVal ← findArgValue x
|
||||
let newVal := merge yVal xVal
|
||||
if newVal == yVal then
|
||||
pure r
|
||||
else
|
||||
modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal }
|
||||
pure true
|
||||
|
||||
private partial def resetNestedJPParams : FnBody → M Unit
|
||||
| FnBody.jdecl _ ys b k => do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.forM resetParamAssignment
|
||||
/- Remark we don't need to reset the parameters of joint-points
|
||||
nested in `b` since they will be reset if this JP is used. -/
|
||||
resetNestedJPParams k
|
||||
| FnBody.case _ _ _ alts =>
|
||||
alts.forM fun alt => match alt with
|
||||
| Alt.ctor _ b => resetNestedJPParams b
|
||||
| Alt.default b => resetNestedJPParams b
|
||||
| e => do unless e.isTerminal do resetNestedJPParams e.body
|
||||
|
||||
partial def interpFnBody : FnBody → M Unit
|
||||
| FnBody.vdecl x _ e b => do
|
||||
let v ← interpExpr e
|
||||
updateVarAssignment x v
|
||||
interpFnBody b
|
||||
| FnBody.jdecl j ys v b =>
|
||||
withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) do
|
||||
| FnBody.vdecl x _ e b => do
|
||||
let v ← interpExpr e
|
||||
updateVarAssignment x v
|
||||
interpFnBody b
|
||||
| FnBody.case _ x _ alts => do
|
||||
let v ← findVarValue x
|
||||
alts.forM fun alt => do
|
||||
match alt with
|
||||
| Alt.ctor i b => if containsCtor v i then interpFnBody b
|
||||
| Alt.default b => interpFnBody b
|
||||
| FnBody.ret x => do
|
||||
let v ← findArgValue x
|
||||
-- dbgTrace ("ret " ++ toString v) $ fun _ =>
|
||||
updateCurrFnSummary v
|
||||
| FnBody.jmp j xs => do
|
||||
let ctx ← read
|
||||
let ys := (ctx.lctx.getJPParams j).get!
|
||||
let b := (ctx.lctx.getJPBody j).get!
|
||||
let updated ← updateJPParamsAssignment ys xs
|
||||
if updated then
|
||||
-- We must reset the value of nested join-point parameters since they depend on `ys` values
|
||||
resetNestedJPParams b
|
||||
interpFnBody b
|
||||
| e => do
|
||||
unless e.isTerminal do
|
||||
interpFnBody e.body
|
||||
| FnBody.jdecl j ys v b =>
|
||||
withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) do
|
||||
interpFnBody b
|
||||
| FnBody.case _ x _ alts => do
|
||||
let v ← findVarValue x
|
||||
alts.forM fun alt => do
|
||||
match alt with
|
||||
| Alt.ctor i b => if containsCtor v i then interpFnBody b
|
||||
| Alt.default b => interpFnBody b
|
||||
| FnBody.ret x => do
|
||||
let v ← findArgValue x
|
||||
-- dbgTrace ("ret " ++ toString v) $ fun _ =>
|
||||
updateCurrFnSummary v
|
||||
| FnBody.jmp j xs => do
|
||||
let ctx ← read
|
||||
let ys := (ctx.lctx.getJPParams j).get!
|
||||
let b := (ctx.lctx.getJPBody j).get!
|
||||
let updated ← updateJPParamsAssignment ys xs
|
||||
if updated then
|
||||
-- We must reset the value of nested join-point parameters since they depend on `ys` values
|
||||
resetNestedJPParams b
|
||||
interpFnBody b
|
||||
| e => do
|
||||
unless e.isTerminal do
|
||||
interpFnBody e.body
|
||||
|
||||
def inferStep : M Bool := do
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
|
||||
ctx.decls.size.foldM (init := false) fun idx modified => do
|
||||
match ctx.decls[idx] with
|
||||
| Decl.fdecl fid ys _ b => do
|
||||
let s ← get
|
||||
let currVals := s.funVals[idx]
|
||||
withReader (fun ctx => { ctx with currFnIdx := idx }) do
|
||||
ys.forM fun y => updateVarAssignment y.x top
|
||||
interpFnBody b
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
|
||||
ctx.decls.size.foldM (init := false) fun idx modified => do
|
||||
match ctx.decls[idx] with
|
||||
| Decl.fdecl fid ys _ b => do
|
||||
let s ← get
|
||||
let newVals := s.funVals[idx]
|
||||
pure (modified || currVals != newVals)
|
||||
| Decl.extern _ _ _ _ => pure modified
|
||||
let currVals := s.funVals[idx]
|
||||
withReader (fun ctx => { ctx with currFnIdx := idx }) do
|
||||
ys.forM fun y => updateVarAssignment y.x top
|
||||
interpFnBody b
|
||||
let s ← get
|
||||
let newVals := s.funVals[idx]
|
||||
pure (modified || currVals != newVals)
|
||||
| Decl.extern _ _ _ _ => pure modified
|
||||
|
||||
partial def inferMain : Unit → M Unit
|
||||
| _ => do
|
||||
partial def inferMain : M Unit := do
|
||||
let modified ← inferStep
|
||||
if modified then inferMain () else pure ()
|
||||
if modified then inferMain else pure ()
|
||||
|
||||
partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
|
||||
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
|
||||
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
|
||||
| FnBody.case tid x xType alts =>
|
||||
let v := assignment.findD x bot
|
||||
let alts := alts.map fun alt =>
|
||||
match alt with
|
||||
| Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
|
||||
| Alt.default b => Alt.default (elimDeadAux assignment b)
|
||||
FnBody.case tid x xType alts
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := elimDeadAux assignment b
|
||||
instr.setBody b
|
||||
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
|
||||
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
|
||||
| FnBody.case tid x xType alts =>
|
||||
let v := assignment.findD x bot
|
||||
let alts := alts.map fun alt =>
|
||||
match alt with
|
||||
| Alt.ctor i b => Alt.ctor i $ if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
|
||||
| Alt.default b => Alt.default (elimDeadAux assignment b)
|
||||
FnBody.case tid x xType alts
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := elimDeadAux assignment b
|
||||
instr.setBody b
|
||||
|
||||
partial def elimDead (assignment : Assignment) : Decl → Decl
|
||||
| Decl.fdecl fid ys t b => Decl.fdecl fid ys t $ elimDeadAux assignment b
|
||||
| other => other
|
||||
| Decl.fdecl fid ys t b => Decl.fdecl fid ys t $ elimDeadAux assignment b
|
||||
| other => other
|
||||
|
||||
end UnreachableBranches
|
||||
|
||||
open UnreachableBranches
|
||||
|
||||
def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let s ← get
|
||||
let env := s.env
|
||||
let assignments : Array Assignment := decls.map fun _ => {}
|
||||
let funVals := Std.mkPArray decls.size Value.bot
|
||||
let ctx : InterpContext := { decls := decls, env := env }
|
||||
let s : InterpState := { assignments := assignments, funVals := funVals }
|
||||
let (_, s) := (inferMain () ctx).run s
|
||||
let funVals := s.funVals
|
||||
let assignments := s.assignments
|
||||
modify fun s =>
|
||||
let env := decls.size.fold (init := s.env) fun i env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]
|
||||
{ s with env := env }
|
||||
pure $ decls.mapIdx fun i decl => elimDead assignments[i] decl
|
||||
let s ← get
|
||||
let env := s.env
|
||||
let assignments : Array Assignment := decls.map fun _ => {}
|
||||
let funVals := Std.mkPArray decls.size Value.bot
|
||||
let ctx : InterpContext := { decls := decls, env := env }
|
||||
let s : InterpState := { assignments := assignments, funVals := funVals }
|
||||
let (_, s) := (inferMain ctx).run s
|
||||
let funVals := s.funVals
|
||||
let assignments := s.assignments
|
||||
modify fun s =>
|
||||
let env := decls.size.fold (init := s.env) fun i env =>
|
||||
addFunctionSummary env decls[i].name funVals[i]
|
||||
{ s with env := env }
|
||||
pure $ decls.mapIdx fun i decl => elimDead assignments[i] decl
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -8,30 +8,27 @@ import Lean.Compiler.IR.FreeVars
|
|||
|
||||
namespace Lean.IR
|
||||
|
||||
partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnBody
|
||||
| bs, b, used =>
|
||||
if bs.isEmpty then b
|
||||
else
|
||||
let curr := bs.back
|
||||
let bs := bs.pop
|
||||
let keep (_ : Unit) :=
|
||||
let used := curr.collectFreeIndices used
|
||||
let b := curr.setBody b
|
||||
reshapeWithoutDeadAux bs b used
|
||||
let keepIfUsed (vidx : Index) :=
|
||||
if used.contains vidx then keep ()
|
||||
else reshapeWithoutDeadAux bs b used
|
||||
match curr with
|
||||
| FnBody.vdecl x _ _ _ => keepIfUsed x.idx
|
||||
-- TODO: we should keep all struct/union projections because they are used to ensure struct/union values are fully consumed.
|
||||
| FnBody.jdecl j _ _ _ => keepIfUsed j.idx
|
||||
| _ => keep ()
|
||||
partial def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody :=
|
||||
let rec reshape (bs : Array FnBody) (b : FnBody) (used : IndexSet) :=
|
||||
if bs.isEmpty then b
|
||||
else
|
||||
let curr := bs.back
|
||||
let bs := bs.pop
|
||||
let keep (_ : Unit) :=
|
||||
let used := curr.collectFreeIndices used
|
||||
let b := curr.setBody b
|
||||
reshape bs b used
|
||||
let keepIfUsed (vidx : Index) :=
|
||||
if used.contains vidx then keep ()
|
||||
else reshape bs b used
|
||||
match curr with
|
||||
| FnBody.vdecl x _ _ _ => keepIfUsed x.idx
|
||||
-- TODO: we should keep all struct/union projections because they are used to ensure struct/union values are fully consumed.
|
||||
| FnBody.jdecl j _ _ _ => keepIfUsed j.idx
|
||||
| _ => keep ()
|
||||
reshape bs term term.freeIndices
|
||||
|
||||
def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody :=
|
||||
reshapeWithoutDeadAux bs term term.freeIndices
|
||||
|
||||
partial def FnBody.elimDead : FnBody → FnBody
|
||||
| b =>
|
||||
partial def FnBody.elimDead (b : FnBody) : FnBody :=
|
||||
let (bs, term) := b.flatten
|
||||
let bs := modifyJPs bs elimDead
|
||||
let term := match term with
|
||||
|
|
@ -43,7 +40,7 @@ partial def FnBody.elimDead : FnBody → FnBody
|
|||
|
||||
/-- Eliminate dead let-declarations and join points -/
|
||||
def Decl.elimDead : Decl → Decl
|
||||
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.elimDead
|
||||
| other => other
|
||||
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.elimDead
|
||||
| other => other
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -11,44 +11,44 @@ import Lean.Compiler.IR.CompilerM
|
|||
namespace Lean.IR
|
||||
/- Return true iff `b` is of the form `let x := g ys; ret x` -/
|
||||
def isTailCallTo (g : Name) (b : FnBody) : Bool :=
|
||||
match b with
|
||||
| FnBody.vdecl x _ (Expr.fap f _) (FnBody.ret (Arg.var y)) => x == y && f == g
|
||||
| _ => false
|
||||
match b with
|
||||
| FnBody.vdecl x _ (Expr.fap f _) (FnBody.ret (Arg.var y)) => x == y && f == g
|
||||
| _ => false
|
||||
|
||||
def usesModuleFrom (env : Environment) (modulePrefix : Name) : Bool :=
|
||||
env.allImportedModuleNames.toList.any $ fun modName => modulePrefix.isPrefixOf modName
|
||||
env.allImportedModuleNames.toList.any $ fun modName => modulePrefix.isPrefixOf modName
|
||||
|
||||
namespace CollectUsedDecls
|
||||
|
||||
abbrev M := ReaderT Environment (StateM NameSet)
|
||||
|
||||
@[inline] def collect (f : FunId) : M Unit :=
|
||||
modify $ fun s => s.insert f
|
||||
modify fun s => s.insert f
|
||||
|
||||
partial def collectFnBody : FnBody → M Unit
|
||||
| FnBody.vdecl _ _ v b =>
|
||||
match v with
|
||||
| Expr.fap f _ => collect f *> collectFnBody b
|
||||
| Expr.pap f _ => collect f *> collectFnBody b
|
||||
| other => collectFnBody b
|
||||
| FnBody.jdecl _ _ v b => collectFnBody v *> collectFnBody b
|
||||
| FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body
|
||||
| e => do unless e.isTerminal do collectFnBody e.body
|
||||
| FnBody.vdecl _ _ v b =>
|
||||
match v with
|
||||
| Expr.fap f _ => collect f *> collectFnBody b
|
||||
| Expr.pap f _ => collect f *> collectFnBody b
|
||||
| other => collectFnBody b
|
||||
| FnBody.jdecl _ _ v b => collectFnBody v *> collectFnBody b
|
||||
| FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body
|
||||
| e => do unless e.isTerminal do collectFnBody e.body
|
||||
|
||||
def collectInitDecl (fn : Name) : M Unit := do
|
||||
let env ← read
|
||||
match getInitFnNameFor? env fn with
|
||||
| some initFn => collect initFn
|
||||
| _ => pure ()
|
||||
let env ← read
|
||||
match getInitFnNameFor? env fn with
|
||||
| some initFn => collect initFn
|
||||
| _ => pure ()
|
||||
|
||||
def collectDecl : Decl → M NameSet
|
||||
| Decl.fdecl fn _ _ b => collectInitDecl fn *> CollectUsedDecls.collectFnBody b *> get
|
||||
| Decl.extern fn _ _ _ => collectInitDecl fn *> get
|
||||
| Decl.fdecl fn _ _ b => collectInitDecl fn *> CollectUsedDecls.collectFnBody b *> get
|
||||
| Decl.extern fn _ _ _ => collectInitDecl fn *> get
|
||||
|
||||
end CollectUsedDecls
|
||||
|
||||
def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet :=
|
||||
(CollectUsedDecls.collectDecl decl env).run' used
|
||||
(CollectUsedDecls.collectDecl decl env).run' used
|
||||
|
||||
abbrev VarTypeMap := Std.HashMap VarId IRType
|
||||
abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
|
||||
|
|
@ -56,22 +56,22 @@ abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
|
|||
namespace CollectMaps
|
||||
abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap)
|
||||
@[inline] def collectVar (x : VarId) (t : IRType) : Collector
|
||||
| (vs, js) => (vs.insert x t, js)
|
||||
| (vs, js) => (vs.insert x t, js)
|
||||
def collectParams (ps : Array Param) : Collector :=
|
||||
fun s => ps.foldl (fun s p => collectVar p.x p.ty s) s
|
||||
fun s => ps.foldl (fun s p => collectVar p.x p.ty s) s
|
||||
@[inline] def collectJP (j : JoinPointId) (xs : Array Param) : Collector
|
||||
| (vs, js) => (vs, js.insert j xs)
|
||||
| (vs, js) => (vs, js.insert j xs)
|
||||
|
||||
/- `collectFnBody` assumes the variables in -/
|
||||
partial def collectFnBody : FnBody → Collector
|
||||
| FnBody.vdecl x t _ b => collectVar x t ∘ collectFnBody b
|
||||
| FnBody.jdecl j xs v b => collectJP j xs ∘ collectParams xs ∘ collectFnBody v ∘ collectFnBody b
|
||||
| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
|
||||
| e => if e.isTerminal then id else collectFnBody e.body
|
||||
| FnBody.vdecl x t _ b => collectVar x t ∘ collectFnBody b
|
||||
| FnBody.jdecl j xs v b => collectJP j xs ∘ collectParams xs ∘ collectFnBody v ∘ collectFnBody b
|
||||
| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
|
||||
| e => if e.isTerminal then id else collectFnBody e.body
|
||||
|
||||
def collectDecl : Decl → Collector
|
||||
| Decl.fdecl _ xs _ b => collectParams xs ∘ collectFnBody b
|
||||
| _ => id
|
||||
| Decl.fdecl _ xs _ b => collectParams xs ∘ collectFnBody b
|
||||
| _ => id
|
||||
|
||||
end CollectMaps
|
||||
|
||||
|
|
@ -79,7 +79,7 @@ end CollectMaps
|
|||
and `j` is a mapping from join point to parameters.
|
||||
This function assumes `d` has normalized indexes (see `normids.lean`). -/
|
||||
def mkVarJPMaps (d : Decl) : VarTypeMap × JPParamsMap :=
|
||||
CollectMaps.collectDecl d ({}, {})
|
||||
CollectMaps.collectDecl d ({}, {})
|
||||
|
||||
end IR
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -12,46 +12,45 @@ namespace Lean.IR.ExpandResetReuse
|
|||
abbrev ProjMap := Std.HashMap VarId Expr
|
||||
namespace CollectProjMap
|
||||
abbrev Collector := ProjMap → ProjMap
|
||||
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector :=
|
||||
fun m => match v with
|
||||
| Expr.proj _ _ => m.insert x v
|
||||
| Expr.sproj _ _ _ => m.insert x v
|
||||
| Expr.uproj _ _ => m.insert x v
|
||||
| _ => m
|
||||
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := fun m =>
|
||||
match v with
|
||||
| Expr.proj .. => m.insert x v
|
||||
| Expr.sproj .. => m.insert x v
|
||||
| Expr.uproj .. => m.insert x v
|
||||
| _ => m
|
||||
|
||||
partial def collectFnBody : FnBody → Collector
|
||||
| FnBody.vdecl x _ v b => collectVDecl x v ∘ collectFnBody b
|
||||
| FnBody.jdecl _ _ v b => collectFnBody v ∘ collectFnBody b
|
||||
| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
|
||||
| e => if e.isTerminal then id else collectFnBody e.body
|
||||
| FnBody.vdecl x _ v b => collectVDecl x v ∘ collectFnBody b
|
||||
| FnBody.jdecl _ _ v b => collectFnBody v ∘ collectFnBody b
|
||||
| FnBody.case _ _ _ alts => fun s => alts.foldl (fun s alt => collectFnBody alt.body s) s
|
||||
| e => if e.isTerminal then id else collectFnBody e.body
|
||||
end CollectProjMap
|
||||
|
||||
/- Create a mapping from variables to projections.
|
||||
This function assumes variable ids have been normalized -/
|
||||
def mkProjMap (d : Decl) : ProjMap :=
|
||||
match d with
|
||||
| Decl.fdecl _ _ _ b => CollectProjMap.collectFnBody b {}
|
||||
| _ => {}
|
||||
match d with
|
||||
| Decl.fdecl _ _ _ b => CollectProjMap.collectFnBody b {}
|
||||
| _ => {}
|
||||
|
||||
structure Context :=
|
||||
(projMap : ProjMap)
|
||||
(projMap : ProjMap)
|
||||
|
||||
/- Return true iff `x` is consumed in all branches of the current block.
|
||||
Here consumption means the block contains a `dec x` or `reuse x ...`. -/
|
||||
partial def consumed (x : VarId) : FnBody → Bool
|
||||
| FnBody.vdecl _ _ v b =>
|
||||
match v with
|
||||
| Expr.reuse y _ _ _ => x == y || consumed x b
|
||||
| _ => consumed x b
|
||||
| FnBody.dec y _ _ _ b => x == y || consumed x b
|
||||
| FnBody.case _ _ _ alts => alts.all fun alt => consumed x alt.body
|
||||
| e => !e.isTerminal && consumed x e.body
|
||||
| FnBody.vdecl _ _ v b =>
|
||||
match v with
|
||||
| Expr.reuse y _ _ _ => x == y || consumed x b
|
||||
| _ => consumed x b
|
||||
| FnBody.dec y _ _ _ b => x == y || consumed x b
|
||||
| FnBody.case _ _ _ alts => alts.all fun alt => consumed x alt.body
|
||||
| e => !e.isTerminal && consumed x e.body
|
||||
|
||||
abbrev Mask := Array (Option VarId)
|
||||
|
||||
/- Auxiliary function for eraseProjIncFor -/
|
||||
partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnBody → Array FnBody × Mask
|
||||
| bs, mask, keep =>
|
||||
partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
|
||||
let done (_ : Unit) := (bs ++ keep.reverse, mask)
|
||||
let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
|
||||
if bs.size < 2 then done ()
|
||||
|
|
@ -85,30 +84,30 @@ partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnB
|
|||
/- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`.
|
||||
Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/
|
||||
def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask :=
|
||||
eraseProjIncForAux y bs (mkArray n none) #[]
|
||||
eraseProjIncForAux y bs (mkArray n none) #[]
|
||||
|
||||
/- Replace `reuse x ctor ...` with `ctor ...`, and remoce `dec x` -/
|
||||
partial def reuseToCtor (x : VarId) : FnBody → FnBody
|
||||
| FnBody.dec y n c p b =>
|
||||
if x == y then b -- n must be 1 since `x := reset ...`
|
||||
else FnBody.dec y n c p (reuseToCtor x b)
|
||||
| FnBody.vdecl z t v b =>
|
||||
match v with
|
||||
| Expr.reuse y c u xs =>
|
||||
if x == y then FnBody.vdecl z t (Expr.ctor c xs) b
|
||||
else FnBody.vdecl z t v (reuseToCtor x b)
|
||||
| _ =>
|
||||
FnBody.vdecl z t v (reuseToCtor x b)
|
||||
| FnBody.case tid y yType alts =>
|
||||
let alts := alts.map fun alt => alt.modifyBody (reuseToCtor x)
|
||||
FnBody.case tid y yType alts
|
||||
| e =>
|
||||
if e.isTerminal then
|
||||
e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := reuseToCtor x b
|
||||
instr.setBody b
|
||||
| FnBody.dec y n c p b =>
|
||||
if x == y then b -- n must be 1 since `x := reset ...`
|
||||
else FnBody.dec y n c p (reuseToCtor x b)
|
||||
| FnBody.vdecl z t v b =>
|
||||
match v with
|
||||
| Expr.reuse y c u xs =>
|
||||
if x == y then FnBody.vdecl z t (Expr.ctor c xs) b
|
||||
else FnBody.vdecl z t v (reuseToCtor x b)
|
||||
| _ =>
|
||||
FnBody.vdecl z t v (reuseToCtor x b)
|
||||
| FnBody.case tid y yType alts =>
|
||||
let alts := alts.map fun alt => alt.modifyBody (reuseToCtor x)
|
||||
FnBody.case tid y yType alts
|
||||
| e =>
|
||||
if e.isTerminal then
|
||||
e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := reuseToCtor x b
|
||||
instr.setBody b
|
||||
|
||||
/-
|
||||
replace
|
||||
|
|
@ -123,97 +122,91 @@ where `z_i`'s are the variables in `mask`,
|
|||
and `b'` is `b` where we removed `dec x` and replaced `reuse x ctor_i ...` with `ctor_i ...`.
|
||||
-/
|
||||
def mkSlowPath (x y : VarId) (mask : Mask) (b : FnBody) : FnBody :=
|
||||
let b := reuseToCtor x b
|
||||
let b := FnBody.dec y 1 true false b
|
||||
mask.foldl
|
||||
(fun b m => match m with
|
||||
| some z => FnBody.inc z 1 true false b
|
||||
| none => b)
|
||||
b
|
||||
let b := reuseToCtor x b
|
||||
let b := FnBody.dec y 1 true false b
|
||||
mask.foldl (init := b) fun b m => match m with
|
||||
| some z => FnBody.inc z 1 true false b
|
||||
| none => b
|
||||
|
||||
abbrev M := ReaderT Context (StateM Nat)
|
||||
def mkFresh : M VarId :=
|
||||
modifyGet $ fun n => ({ idx := n }, n + 1)
|
||||
def mkFresh : M VarId :=
|
||||
modifyGet $ fun n => ({ idx := n }, n + 1)
|
||||
|
||||
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
|
||||
mask.size.foldM
|
||||
(fun i b =>
|
||||
mask.size.foldM (init := b) fun i b =>
|
||||
match mask.get! i with
|
||||
| some _ => pure b -- code took ownership of this field
|
||||
| none => do
|
||||
let fld ← mkFresh
|
||||
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b)))
|
||||
b
|
||||
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true false b))
|
||||
|
||||
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
|
||||
zs.size.fold
|
||||
(fun i b => FnBody.set y i (zs.get! i) b)
|
||||
b
|
||||
zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b
|
||||
|
||||
/- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
|
||||
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
|
||||
match y with
|
||||
| Arg.var y =>
|
||||
match ctx.projMap.find? y with
|
||||
| some (Expr.proj j w) => j == i && w == x
|
||||
match y with
|
||||
| Arg.var y =>
|
||||
match ctx.projMap.find? y with
|
||||
| some (Expr.proj j w) => j == i && w == x
|
||||
| _ => false
|
||||
| _ => false
|
||||
| _ => false
|
||||
|
||||
/- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
|
||||
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
|
||||
match ctx.projMap.find? y with
|
||||
| some (Expr.uproj j w) => j == i && w == x
|
||||
| _ => false
|
||||
match ctx.projMap.find? y with
|
||||
| some (Expr.uproj j w) => j == i && w == x
|
||||
| _ => false
|
||||
|
||||
/- Given `sset x[n, i] := y`, return true iff `y := sproj[n, i] x` -/
|
||||
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
|
||||
match ctx.projMap.find? y with
|
||||
| some (Expr.sproj m j w) => n == m && j == i && w == x
|
||||
| _ => false
|
||||
match ctx.projMap.find? y with
|
||||
| some (Expr.sproj m j w) => n == m && j == i && w == x
|
||||
| _ => false
|
||||
|
||||
/- Remove unnecessary `set/uset/sset` operations -/
|
||||
partial def removeSelfSet (ctx : Context) : FnBody → FnBody
|
||||
| FnBody.set x i y b =>
|
||||
if isSelfSet ctx x i y then removeSelfSet ctx b
|
||||
else FnBody.set x i y (removeSelfSet ctx b)
|
||||
| FnBody.uset x i y b =>
|
||||
if isSelfUSet ctx x i y then removeSelfSet ctx b
|
||||
else FnBody.uset x i y (removeSelfSet ctx b)
|
||||
| FnBody.sset x n i y t b =>
|
||||
if isSelfSSet ctx x n i y then removeSelfSet ctx b
|
||||
else FnBody.sset x n i y t (removeSelfSet ctx b)
|
||||
| FnBody.case tid y yType alts =>
|
||||
let alts := alts.map fun alt => alt.modifyBody (removeSelfSet ctx)
|
||||
FnBody.case tid y yType alts
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := removeSelfSet ctx b
|
||||
instr.setBody b
|
||||
| FnBody.set x i y b =>
|
||||
if isSelfSet ctx x i y then removeSelfSet ctx b
|
||||
else FnBody.set x i y (removeSelfSet ctx b)
|
||||
| FnBody.uset x i y b =>
|
||||
if isSelfUSet ctx x i y then removeSelfSet ctx b
|
||||
else FnBody.uset x i y (removeSelfSet ctx b)
|
||||
| FnBody.sset x n i y t b =>
|
||||
if isSelfSSet ctx x n i y then removeSelfSet ctx b
|
||||
else FnBody.sset x n i y t (removeSelfSet ctx b)
|
||||
| FnBody.case tid y yType alts =>
|
||||
let alts := alts.map fun alt => alt.modifyBody (removeSelfSet ctx)
|
||||
FnBody.case tid y yType alts
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := removeSelfSet ctx b
|
||||
instr.setBody b
|
||||
|
||||
partial def reuseToSet (ctx : Context) (x y : VarId) : FnBody → FnBody
|
||||
| FnBody.dec z n c p b =>
|
||||
if x == z then FnBody.del y b
|
||||
else FnBody.dec z n c p (reuseToSet ctx x y b)
|
||||
| FnBody.vdecl z t v b =>
|
||||
match v with
|
||||
| Expr.reuse w c u zs =>
|
||||
if x == w then
|
||||
let b := setFields y zs (b.replaceVar z y)
|
||||
let b := if u then FnBody.setTag y c.cidx b else b
|
||||
removeSelfSet ctx b
|
||||
else FnBody.vdecl z t v (reuseToSet ctx x y b)
|
||||
| _ => FnBody.vdecl z t v (reuseToSet ctx x y b)
|
||||
| FnBody.case tid z zType alts =>
|
||||
let alts := alts.map fun alt => alt.modifyBody (reuseToSet ctx x y)
|
||||
FnBody.case tid z zType alts
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := reuseToSet ctx x y b
|
||||
instr.setBody b
|
||||
| FnBody.dec z n c p b =>
|
||||
if x == z then FnBody.del y b
|
||||
else FnBody.dec z n c p (reuseToSet ctx x y b)
|
||||
| FnBody.vdecl z t v b =>
|
||||
match v with
|
||||
| Expr.reuse w c u zs =>
|
||||
if x == w then
|
||||
let b := setFields y zs (b.replaceVar z y)
|
||||
let b := if u then FnBody.setTag y c.cidx b else b
|
||||
removeSelfSet ctx b
|
||||
else FnBody.vdecl z t v (reuseToSet ctx x y b)
|
||||
| _ => FnBody.vdecl z t v (reuseToSet ctx x y b)
|
||||
| FnBody.case tid z zType alts =>
|
||||
let alts := alts.map fun alt => alt.modifyBody (reuseToSet ctx x y)
|
||||
FnBody.case tid z zType alts
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split
|
||||
let b := reuseToSet ctx x y b
|
||||
instr.setBody b
|
||||
|
||||
/-
|
||||
replace
|
||||
|
|
@ -235,54 +228,54 @@ and `z := reuse x ctor_i ws; F` is replaced with
|
|||
`set x i ws[i]` operations, and we replace `z` with `x` in `F`
|
||||
-/
|
||||
def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody := do
|
||||
let ctx ← read
|
||||
let b := reuseToSet ctx x y b
|
||||
releaseUnreadFields y mask b
|
||||
let ctx ← read
|
||||
let b := reuseToSet ctx x y b
|
||||
releaseUnreadFields y mask b
|
||||
|
||||
-- Expand `bs; x := reset[n] y; b`
|
||||
partial def expand (mainFn : FnBody → Array FnBody → M FnBody)
|
||||
(bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody := do
|
||||
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b
|
||||
let (bs, mask) := eraseProjIncFor n y bs
|
||||
/- Remark: we may be duplicting variable/JP indices. That is, `bSlow` and `bFast` may
|
||||
have duplicate indices. We run `normalizeIds` to fix the ids after we have expand them. -/
|
||||
let bSlow := mkSlowPath x y mask b
|
||||
let bFast ← mkFastPath x y mask b
|
||||
/- We only optimize recursively the fast. -/
|
||||
let bFast ← mainFn bFast #[]
|
||||
let c ← mkFresh
|
||||
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast)
|
||||
pure $ reshape bs b
|
||||
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b
|
||||
let (bs, mask) := eraseProjIncFor n y bs
|
||||
/- Remark: we may be duplicting variable/JP indices. That is, `bSlow` and `bFast` may
|
||||
have duplicate indices. We run `normalizeIds` to fix the ids after we have expand them. -/
|
||||
let bSlow := mkSlowPath x y mask b
|
||||
let bFast ← mkFastPath x y mask b
|
||||
/- We only optimize recursively the fast. -/
|
||||
let bFast ← mainFn bFast #[]
|
||||
let c ← mkFresh
|
||||
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast)
|
||||
pure $ reshape bs b
|
||||
|
||||
partial def searchAndExpand : FnBody → Array FnBody → M FnBody
|
||||
| d@(FnBody.vdecl x t (Expr.reset n y) b), bs =>
|
||||
if consumed x b then do
|
||||
expand searchAndExpand bs x n y b
|
||||
else
|
||||
searchAndExpand b (push bs d)
|
||||
| FnBody.jdecl j xs v b, bs => do
|
||||
let v ← searchAndExpand v #[]
|
||||
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
|
||||
| FnBody.case tid x xType alts, bs => do
|
||||
let alts ← alts.mapM $ fun alt => alt.mmodifyBody fun b => searchAndExpand b #[]
|
||||
pure $ reshape bs (FnBody.case tid x xType alts)
|
||||
| b, bs =>
|
||||
if b.isTerminal then pure $ reshape bs b
|
||||
else searchAndExpand b.body (push bs b)
|
||||
| d@(FnBody.vdecl x t (Expr.reset n y) b), bs =>
|
||||
if consumed x b then do
|
||||
expand searchAndExpand bs x n y b
|
||||
else
|
||||
searchAndExpand b (push bs d)
|
||||
| FnBody.jdecl j xs v b, bs => do
|
||||
let v ← searchAndExpand v #[]
|
||||
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
|
||||
| FnBody.case tid x xType alts, bs => do
|
||||
let alts ← alts.mapM $ fun alt => alt.mmodifyBody fun b => searchAndExpand b #[]
|
||||
pure $ reshape bs (FnBody.case tid x xType alts)
|
||||
| b, bs =>
|
||||
if b.isTerminal then pure $ reshape bs b
|
||||
else searchAndExpand b.body (push bs b)
|
||||
|
||||
def main (d : Decl) : Decl :=
|
||||
match d with
|
||||
| (Decl.fdecl f xs t b) =>
|
||||
let m := mkProjMap d
|
||||
let nextIdx := d.maxIndex + 1
|
||||
let b := (searchAndExpand b #[] { projMap := m }).run' nextIdx
|
||||
Decl.fdecl f xs t b
|
||||
| d => d
|
||||
match d with
|
||||
| (Decl.fdecl f xs t b) =>
|
||||
let m := mkProjMap d
|
||||
let nextIdx := d.maxIndex + 1
|
||||
let b := (searchAndExpand b #[] { projMap := m }).run' nextIdx
|
||||
Decl.fdecl f xs t b
|
||||
| d => d
|
||||
|
||||
end ExpandResetReuse
|
||||
|
||||
/-- (Try to) expand `reset` and `reuse` instructions. -/
|
||||
def Decl.expandResetReuse (d : Decl) : Decl :=
|
||||
(ExpandResetReuse.main d).normalizeIds
|
||||
(ExpandResetReuse.main d).normalizeIds
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -26,62 +26,62 @@ abbrev Collector := Index → Index
|
|||
instance : AndThen Collector := ⟨seq⟩
|
||||
|
||||
private def collectArg : Arg → Collector
|
||||
| Arg.var x => collectVar x
|
||||
| irrelevant => skip
|
||||
| Arg.var x => collectVar x
|
||||
| irrelevant => skip
|
||||
|
||||
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
|
||||
fun m => as.foldl (fun m a => f a m) m
|
||||
fun m => as.foldl (fun m a => f a m) m
|
||||
|
||||
private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg
|
||||
private def collectParam (p : Param) : Collector := collectVar p.x
|
||||
private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam
|
||||
|
||||
private def collectExpr : Expr → Collector
|
||||
| Expr.ctor _ ys => collectArgs ys
|
||||
| Expr.reset _ x => collectVar x
|
||||
| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys
|
||||
| Expr.proj _ x => collectVar x
|
||||
| Expr.uproj _ x => collectVar x
|
||||
| Expr.sproj _ _ x => collectVar x
|
||||
| Expr.fap _ ys => collectArgs ys
|
||||
| Expr.pap _ ys => collectArgs ys
|
||||
| Expr.ap x ys => collectVar x >> collectArgs ys
|
||||
| Expr.box _ x => collectVar x
|
||||
| Expr.unbox x => collectVar x
|
||||
| Expr.lit v => skip
|
||||
| Expr.isShared x => collectVar x
|
||||
| Expr.isTaggedPtr x => collectVar x
|
||||
| Expr.ctor _ ys => collectArgs ys
|
||||
| Expr.reset _ x => collectVar x
|
||||
| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys
|
||||
| Expr.proj _ x => collectVar x
|
||||
| Expr.uproj _ x => collectVar x
|
||||
| Expr.sproj _ _ x => collectVar x
|
||||
| Expr.fap _ ys => collectArgs ys
|
||||
| Expr.pap _ ys => collectArgs ys
|
||||
| Expr.ap x ys => collectVar x >> collectArgs ys
|
||||
| Expr.box _ x => collectVar x
|
||||
| Expr.unbox x => collectVar x
|
||||
| Expr.lit v => skip
|
||||
| Expr.isShared x => collectVar x
|
||||
| Expr.isTaggedPtr x => collectVar x
|
||||
|
||||
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
|
||||
collectArray alts $ fun alt => f alt.body
|
||||
collectArray alts $ fun alt => f alt.body
|
||||
|
||||
partial def collectFnBody : FnBody → Collector
|
||||
| FnBody.vdecl x _ v b => collectVar x >> collectExpr v >> collectFnBody b
|
||||
| FnBody.jdecl j ys v b => collectJP j >> collectFnBody v >> collectParams ys >> collectFnBody b
|
||||
| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b
|
||||
| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b
|
||||
| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b
|
||||
| FnBody.setTag x _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.del x b => collectVar x >> collectFnBody b
|
||||
| FnBody.mdata _ b => collectFnBody b
|
||||
| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts
|
||||
| FnBody.jmp j ys => collectJP j >> collectArgs ys
|
||||
| FnBody.ret x => collectArg x
|
||||
| FnBody.unreachable => skip
|
||||
| FnBody.vdecl x _ v b => collectVar x >> collectExpr v >> collectFnBody b
|
||||
| FnBody.jdecl j ys v b => collectJP j >> collectFnBody v >> collectParams ys >> collectFnBody b
|
||||
| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b
|
||||
| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b
|
||||
| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b
|
||||
| FnBody.setTag x _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.del x b => collectVar x >> collectFnBody b
|
||||
| FnBody.mdata _ b => collectFnBody b
|
||||
| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts
|
||||
| FnBody.jmp j ys => collectJP j >> collectArgs ys
|
||||
| FnBody.ret x => collectArg x
|
||||
| FnBody.unreachable => skip
|
||||
|
||||
partial def collectDecl : Decl → Collector
|
||||
| Decl.fdecl _ xs _ b => collectParams xs >> collectFnBody b
|
||||
| Decl.extern _ xs _ _ => collectParams xs
|
||||
| Decl.fdecl _ xs _ b => collectParams xs >> collectFnBody b
|
||||
| Decl.extern _ xs _ _ => collectParams xs
|
||||
|
||||
end MaxIndex
|
||||
|
||||
def FnBody.maxIndex (b : FnBody) : Index :=
|
||||
MaxIndex.collectFnBody b 0
|
||||
MaxIndex.collectFnBody b 0
|
||||
|
||||
def Decl.maxIndex (d : Decl) : Index :=
|
||||
MaxIndex.collectDecl d 0
|
||||
MaxIndex.collectDecl d 0
|
||||
|
||||
namespace FreeIndices
|
||||
/- We say a variable (join point) index (aka name) is free in a function body
|
||||
|
|
@ -90,89 +90,89 @@ namespace FreeIndices
|
|||
abbrev Collector := IndexSet → IndexSet → IndexSet
|
||||
|
||||
@[inline] private def skip : Collector :=
|
||||
fun bv fv => fv
|
||||
fun bv fv => fv
|
||||
|
||||
@[inline] private def collectIndex (x : Index) : Collector :=
|
||||
fun bv fv => if bv.contains x then fv else fv.insert x
|
||||
fun bv fv => if bv.contains x then fv else fv.insert x
|
||||
|
||||
@[inline] private def collectVar (x : VarId) : Collector :=
|
||||
collectIndex x.idx
|
||||
collectIndex x.idx
|
||||
|
||||
@[inline] private def collectJP (x : JoinPointId) : Collector :=
|
||||
collectIndex x.idx
|
||||
collectIndex x.idx
|
||||
|
||||
@[inline] private def withIndex (x : Index) : Collector → Collector :=
|
||||
fun k bv fv => k (bv.insert x) fv
|
||||
fun k bv fv => k (bv.insert x) fv
|
||||
|
||||
@[inline] private def withVar (x : VarId) : Collector → Collector :=
|
||||
withIndex x.idx
|
||||
withIndex x.idx
|
||||
|
||||
@[inline] private def withJP (x : JoinPointId) : Collector → Collector :=
|
||||
withIndex x.idx
|
||||
withIndex x.idx
|
||||
|
||||
def insertParams (s : IndexSet) (ys : Array Param) : IndexSet :=
|
||||
ys.foldl (fun s p => s.insert p.x.idx) s
|
||||
ys.foldl (init := s) fun s p => s.insert p.x.idx
|
||||
|
||||
@[inline] private def withParams (ys : Array Param) : Collector → Collector :=
|
||||
fun k bv fv => k (insertParams bv ys) fv
|
||||
fun k bv fv => k (insertParams bv ys) fv
|
||||
|
||||
@[inline] private def seq : Collector → Collector → Collector :=
|
||||
fun k₁ k₂ bv fv => k₂ bv (k₁ bv fv)
|
||||
fun k₁ k₂ bv fv => k₂ bv (k₁ bv fv)
|
||||
|
||||
instance : AndThen Collector := ⟨seq⟩
|
||||
|
||||
private def collectArg : Arg → Collector
|
||||
| Arg.var x => collectVar x
|
||||
| irrelevant => skip
|
||||
| Arg.var x => collectVar x
|
||||
| irrelevant => skip
|
||||
|
||||
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
|
||||
fun bv fv => as.foldl (fun fv a => f a bv fv) fv
|
||||
fun bv fv => as.foldl (fun fv a => f a bv fv) fv
|
||||
|
||||
private def collectArgs (as : Array Arg) : Collector :=
|
||||
collectArray as collectArg
|
||||
collectArray as collectArg
|
||||
|
||||
private def collectExpr : Expr → Collector
|
||||
| Expr.ctor _ ys => collectArgs ys
|
||||
| Expr.reset _ x => collectVar x
|
||||
| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys
|
||||
| Expr.proj _ x => collectVar x
|
||||
| Expr.uproj _ x => collectVar x
|
||||
| Expr.sproj _ _ x => collectVar x
|
||||
| Expr.fap _ ys => collectArgs ys
|
||||
| Expr.pap _ ys => collectArgs ys
|
||||
| Expr.ap x ys => collectVar x >> collectArgs ys
|
||||
| Expr.box _ x => collectVar x
|
||||
| Expr.unbox x => collectVar x
|
||||
| Expr.lit v => skip
|
||||
| Expr.isShared x => collectVar x
|
||||
| Expr.isTaggedPtr x => collectVar x
|
||||
| Expr.ctor _ ys => collectArgs ys
|
||||
| Expr.reset _ x => collectVar x
|
||||
| Expr.reuse x _ _ ys => collectVar x >> collectArgs ys
|
||||
| Expr.proj _ x => collectVar x
|
||||
| Expr.uproj _ x => collectVar x
|
||||
| Expr.sproj _ _ x => collectVar x
|
||||
| Expr.fap _ ys => collectArgs ys
|
||||
| Expr.pap _ ys => collectArgs ys
|
||||
| Expr.ap x ys => collectVar x >> collectArgs ys
|
||||
| Expr.box _ x => collectVar x
|
||||
| Expr.unbox x => collectVar x
|
||||
| Expr.lit v => skip
|
||||
| Expr.isShared x => collectVar x
|
||||
| Expr.isTaggedPtr x => collectVar x
|
||||
|
||||
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
|
||||
collectArray alts $ fun alt => f alt.body
|
||||
collectArray alts $ fun alt => f alt.body
|
||||
|
||||
partial def collectFnBody : FnBody → Collector
|
||||
| FnBody.vdecl x _ v b => collectExpr v >> withVar x (collectFnBody b)
|
||||
| FnBody.jdecl j ys v b => withParams ys (collectFnBody v) >> withJP j (collectFnBody b)
|
||||
| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b
|
||||
| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b
|
||||
| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b
|
||||
| FnBody.setTag x _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.del x b => collectVar x >> collectFnBody b
|
||||
| FnBody.mdata _ b => collectFnBody b
|
||||
| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts
|
||||
| FnBody.jmp j ys => collectJP j >> collectArgs ys
|
||||
| FnBody.ret x => collectArg x
|
||||
| FnBody.unreachable => skip
|
||||
| FnBody.vdecl x _ v b => collectExpr v >> withVar x (collectFnBody b)
|
||||
| FnBody.jdecl j ys v b => withParams ys (collectFnBody v) >> withJP j (collectFnBody b)
|
||||
| FnBody.set x _ y b => collectVar x >> collectArg y >> collectFnBody b
|
||||
| FnBody.uset x _ y b => collectVar x >> collectVar y >> collectFnBody b
|
||||
| FnBody.sset x _ _ y _ b => collectVar x >> collectVar y >> collectFnBody b
|
||||
| FnBody.setTag x _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.inc x _ _ _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.dec x _ _ _ b => collectVar x >> collectFnBody b
|
||||
| FnBody.del x b => collectVar x >> collectFnBody b
|
||||
| FnBody.mdata _ b => collectFnBody b
|
||||
| FnBody.case _ x _ alts => collectVar x >> collectAlts collectFnBody alts
|
||||
| FnBody.jmp j ys => collectJP j >> collectArgs ys
|
||||
| FnBody.ret x => collectArg x
|
||||
| FnBody.unreachable => skip
|
||||
|
||||
end FreeIndices
|
||||
|
||||
def FnBody.collectFreeIndices (b : FnBody) (vs : IndexSet) : IndexSet :=
|
||||
FreeIndices.collectFnBody b {} vs
|
||||
FreeIndices.collectFnBody b {} vs
|
||||
|
||||
def FnBody.freeIndices (b : FnBody) : IndexSet :=
|
||||
b.collectFreeIndices {}
|
||||
b.collectFreeIndices {}
|
||||
|
||||
namespace HasIndex
|
||||
/- In principle, we can check whether a function body `b` contains an index `i` using
|
||||
|
|
@ -183,46 +183,46 @@ def visitVar (w : Index) (x : VarId) : Bool := w == x.idx
|
|||
def visitJP (w : Index) (x : JoinPointId) : Bool := w == x.idx
|
||||
|
||||
def visitArg (w : Index) : Arg → Bool
|
||||
| Arg.var x => visitVar w x
|
||||
| _ => false
|
||||
| Arg.var x => visitVar w x
|
||||
| _ => false
|
||||
|
||||
def visitArgs (w : Index) (xs : Array Arg) : Bool :=
|
||||
xs.any (visitArg w)
|
||||
xs.any (visitArg w)
|
||||
|
||||
def visitParams (w : Index) (ps : Array Param) : Bool :=
|
||||
ps.any (fun p => w == p.x.idx)
|
||||
ps.any (fun p => w == p.x.idx)
|
||||
|
||||
def visitExpr (w : Index) : Expr → Bool
|
||||
| Expr.ctor _ ys => visitArgs w ys
|
||||
| Expr.reset _ x => visitVar w x
|
||||
| Expr.reuse x _ _ ys => visitVar w x || visitArgs w ys
|
||||
| Expr.proj _ x => visitVar w x
|
||||
| Expr.uproj _ x => visitVar w x
|
||||
| Expr.sproj _ _ x => visitVar w x
|
||||
| Expr.fap _ ys => visitArgs w ys
|
||||
| Expr.pap _ ys => visitArgs w ys
|
||||
| Expr.ap x ys => visitVar w x || visitArgs w ys
|
||||
| Expr.box _ x => visitVar w x
|
||||
| Expr.unbox x => visitVar w x
|
||||
| Expr.lit v => false
|
||||
| Expr.isShared x => visitVar w x
|
||||
| Expr.isTaggedPtr x => visitVar w x
|
||||
| Expr.ctor _ ys => visitArgs w ys
|
||||
| Expr.reset _ x => visitVar w x
|
||||
| Expr.reuse x _ _ ys => visitVar w x || visitArgs w ys
|
||||
| Expr.proj _ x => visitVar w x
|
||||
| Expr.uproj _ x => visitVar w x
|
||||
| Expr.sproj _ _ x => visitVar w x
|
||||
| Expr.fap _ ys => visitArgs w ys
|
||||
| Expr.pap _ ys => visitArgs w ys
|
||||
| Expr.ap x ys => visitVar w x || visitArgs w ys
|
||||
| Expr.box _ x => visitVar w x
|
||||
| Expr.unbox x => visitVar w x
|
||||
| Expr.lit v => false
|
||||
| Expr.isShared x => visitVar w x
|
||||
| Expr.isTaggedPtr x => visitVar w x
|
||||
|
||||
partial def visitFnBody (w : Index) : FnBody → Bool
|
||||
| FnBody.vdecl x _ v b => visitExpr w v || visitFnBody w b
|
||||
| FnBody.jdecl j ys v b => visitFnBody w v || visitFnBody w b
|
||||
| FnBody.set x _ y b => visitVar w x || visitArg w y || visitFnBody w b
|
||||
| FnBody.uset x _ y b => visitVar w x || visitVar w y || visitFnBody w b
|
||||
| FnBody.sset x _ _ y _ b => visitVar w x || visitVar w y || visitFnBody w b
|
||||
| FnBody.setTag x _ b => visitVar w x || visitFnBody w b
|
||||
| FnBody.inc x _ _ _ b => visitVar w x || visitFnBody w b
|
||||
| FnBody.dec x _ _ _ b => visitVar w x || visitFnBody w b
|
||||
| FnBody.del x b => visitVar w x || visitFnBody w b
|
||||
| FnBody.mdata _ b => visitFnBody w b
|
||||
| FnBody.jmp j ys => visitJP w j || visitArgs w ys
|
||||
| FnBody.ret x => visitArg w x
|
||||
| FnBody.case _ x _ alts => visitVar w x || alts.any (fun alt => visitFnBody w alt.body)
|
||||
| FnBody.unreachable => false
|
||||
| FnBody.vdecl x _ v b => visitExpr w v || visitFnBody w b
|
||||
| FnBody.jdecl j ys v b => visitFnBody w v || visitFnBody w b
|
||||
| FnBody.set x _ y b => visitVar w x || visitArg w y || visitFnBody w b
|
||||
| FnBody.uset x _ y b => visitVar w x || visitVar w y || visitFnBody w b
|
||||
| FnBody.sset x _ _ y _ b => visitVar w x || visitVar w y || visitFnBody w b
|
||||
| FnBody.setTag x _ b => visitVar w x || visitFnBody w b
|
||||
| FnBody.inc x _ _ _ b => visitVar w x || visitFnBody w b
|
||||
| FnBody.dec x _ _ _ b => visitVar w x || visitFnBody w b
|
||||
| FnBody.del x b => visitVar w x || visitFnBody w b
|
||||
| FnBody.mdata _ b => visitFnBody w b
|
||||
| FnBody.jmp j ys => visitJP w j || visitArgs w ys
|
||||
| FnBody.ret x => visitArg w x
|
||||
| FnBody.case _ x _ alts => visitVar w x || alts.any (fun alt => visitFnBody w alt.body)
|
||||
| FnBody.unreachable => false
|
||||
|
||||
end HasIndex
|
||||
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ import Lean.Compiler.IR.NormIds
|
|||
|
||||
namespace Lean.IR
|
||||
|
||||
partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array FnBody → IndexSet → Array FnBody × Array Alt
|
||||
| bs, alts, altsF, ctx, ctxF =>
|
||||
partial def pushProjs (bs : Array FnBody) (alts : Array Alt) (altsF : Array IndexSet) (ctx : Array FnBody) (ctxF : IndexSet) : Array FnBody × Array Alt :=
|
||||
if bs.isEmpty then (ctx.reverse, alts)
|
||||
else
|
||||
let b := bs.back
|
||||
|
|
@ -37,8 +36,7 @@ partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array
|
|||
| _ => done ()
|
||||
| _ => done ()
|
||||
|
||||
partial def FnBody.pushProj : FnBody → FnBody
|
||||
| b =>
|
||||
partial def FnBody.pushProj (b : FnBody) : FnBody :=
|
||||
let (bs, term) := b.flatten
|
||||
let bs := modifyJPs bs pushProj
|
||||
match term with
|
||||
|
|
@ -52,7 +50,7 @@ partial def FnBody.pushProj : FnBody → FnBody
|
|||
|
||||
/-- Push projections inside `case` branches. -/
|
||||
def Decl.pushProj : Decl → Decl
|
||||
| Decl.fdecl f xs t b => (Decl.fdecl f xs t b.pushProj).normalizeIds
|
||||
| other => other
|
||||
| Decl.fdecl f xs t b => (Decl.fdecl f xs t b.pushProj).normalizeIds
|
||||
| other => other
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -14,82 +14,82 @@ namespace Lean.IR.ExplicitRC
|
|||
-/
|
||||
|
||||
structure VarInfo :=
|
||||
(ref : Bool := true) -- true if the variable may be a reference (aka pointer) at runtime
|
||||
(persistent : Bool := false) -- true if the variable is statically known to be marked a Persistent at runtime
|
||||
(consume : Bool := false) -- true if the variable RC must be "consumed"
|
||||
(ref : Bool := true) -- true if the variable may be a reference (aka pointer) at runtime
|
||||
(persistent : Bool := false) -- true if the variable is statically known to be marked a Persistent at runtime
|
||||
(consume : Bool := false) -- true if the variable RC must be "consumed"
|
||||
|
||||
abbrev VarMap := Std.RBMap VarId VarInfo (fun x y => x.idx < y.idx)
|
||||
|
||||
structure Context :=
|
||||
(env : Environment)
|
||||
(decls : Array Decl)
|
||||
(varMap : VarMap := {})
|
||||
(jpLiveVarMap : JPLiveVarMap := {}) -- map: join point => live variables
|
||||
(localCtx : LocalContext := {}) -- we use it to store the join point declarations
|
||||
(env : Environment)
|
||||
(decls : Array Decl)
|
||||
(varMap : VarMap := {})
|
||||
(jpLiveVarMap : JPLiveVarMap := {}) -- map: join point => live variables
|
||||
(localCtx : LocalContext := {}) -- we use it to store the join point declarations
|
||||
|
||||
def getDecl (ctx : Context) (fid : FunId) : Decl :=
|
||||
match findEnvDecl' ctx.env fid ctx.decls with
|
||||
| some decl => decl
|
||||
| none => arbitrary _ -- unreachable if well-formed
|
||||
match findEnvDecl' ctx.env fid ctx.decls with
|
||||
| some decl => decl
|
||||
| none => arbitrary _ -- unreachable if well-formed
|
||||
|
||||
def getVarInfo (ctx : Context) (x : VarId) : VarInfo :=
|
||||
match ctx.varMap.find? x with
|
||||
| some info => info
|
||||
| none => {} -- unreachable in well-formed code
|
||||
match ctx.varMap.find? x with
|
||||
| some info => info
|
||||
| none => {} -- unreachable in well-formed code
|
||||
|
||||
def getJPParams (ctx : Context) (j : JoinPointId) : Array Param :=
|
||||
match ctx.localCtx.getJPParams j with
|
||||
| some ps => ps
|
||||
| none => #[] -- unreachable in well-formed code
|
||||
match ctx.localCtx.getJPParams j with
|
||||
| some ps => ps
|
||||
| none => #[] -- unreachable in well-formed code
|
||||
|
||||
def getJPLiveVars (ctx : Context) (j : JoinPointId) : LiveVarSet :=
|
||||
match ctx.jpLiveVarMap.find? j with
|
||||
| some s => s
|
||||
| none => {}
|
||||
match ctx.jpLiveVarMap.find? j with
|
||||
| some s => s
|
||||
| none => {}
|
||||
|
||||
def mustConsume (ctx : Context) (x : VarId) : Bool :=
|
||||
let info := getVarInfo ctx x
|
||||
info.ref && info.consume
|
||||
let info := getVarInfo ctx x
|
||||
info.ref && info.consume
|
||||
|
||||
@[inline] def addInc (ctx : Context) (x : VarId) (b : FnBody) (n := 1) : FnBody :=
|
||||
let info := getVarInfo ctx x
|
||||
if n == 0 then b else FnBody.inc x n true info.persistent b
|
||||
let info := getVarInfo ctx x
|
||||
if n == 0 then b else FnBody.inc x n true info.persistent b
|
||||
|
||||
@[inline] def addDec (ctx : Context) (x : VarId) (b : FnBody) : FnBody :=
|
||||
let info := getVarInfo ctx x
|
||||
FnBody.dec x 1 true info.persistent b
|
||||
let info := getVarInfo ctx x
|
||||
FnBody.dec x 1 true info.persistent b
|
||||
|
||||
private def updateRefUsingCtorInfo (ctx : Context) (x : VarId) (c : CtorInfo) : Context :=
|
||||
if c.isRef then ctx
|
||||
else
|
||||
let m := ctx.varMap
|
||||
{ ctx with
|
||||
varMap := match m.find? x with
|
||||
| some info => m.insert x { info with ref := false } -- I really want a Lenses library + notation
|
||||
| none => m }
|
||||
if c.isRef then
|
||||
ctx
|
||||
else
|
||||
let m := ctx.varMap
|
||||
{ ctx with
|
||||
varMap := match m.find? x with
|
||||
| some info => m.insert x { info with ref := false } -- I really want a Lenses library + notation
|
||||
| none => m }
|
||||
|
||||
private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet) (b : FnBody) : FnBody :=
|
||||
caseLiveVars.fold
|
||||
(fun b x => if !altLiveVars.contains x && mustConsume ctx x then addDec ctx x b else b)
|
||||
b
|
||||
caseLiveVars.fold (init := b) fun b x =>
|
||||
if !altLiveVars.contains x && mustConsume ctx x then addDec ctx x b else b
|
||||
|
||||
/- `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs` -/
|
||||
private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool :=
|
||||
let x := xs[i]
|
||||
i.all fun j => xs[j] != x
|
||||
let x := xs[i]
|
||||
i.all fun j => xs[j] != x
|
||||
|
||||
/- Return true if `x` also occurs in `ys` in a position that is not consumed.
|
||||
That is, it is also passed as a borrow reference. -/
|
||||
@[specialize]
|
||||
private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Bool :=
|
||||
ys.size.any $ fun i =>
|
||||
let y := ys[i]
|
||||
match y with
|
||||
| Arg.irrelevant => false
|
||||
| Arg.var y => x == y && !consumeParamPred i
|
||||
ys.size.any fun i =>
|
||||
let y := ys[i]
|
||||
match y with
|
||||
| Arg.irrelevant => false
|
||||
| Arg.var y => x == y && !consumeParamPred i
|
||||
|
||||
private def isBorrowParam (x : VarId) (ys : Array Arg) (ps : Array Param) : Bool :=
|
||||
isBorrowParamAux x ys fun i => not ps[i].borrow
|
||||
isBorrowParamAux x ys fun i => not ps[i].borrow
|
||||
|
||||
/-
|
||||
Return `n`, the number of times `x` is consumed.
|
||||
|
|
@ -98,190 +98,187 @@ Return `n`, the number of times `x` is consumed.
|
|||
-/
|
||||
@[specialize]
|
||||
private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Nat :=
|
||||
ys.size.fold (init := 0) fun i n =>
|
||||
let y := ys[i]
|
||||
match y with
|
||||
| Arg.irrelevant => n
|
||||
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
|
||||
ys.size.fold (init := 0) fun i n =>
|
||||
let y := ys[i]
|
||||
match y with
|
||||
| Arg.irrelevant => n
|
||||
| Arg.var y => if x == y && consumeParamPred i then n+1 else n
|
||||
|
||||
@[specialize]
|
||||
private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat → Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
let x := xs[i]
|
||||
match x with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
let info := getVarInfo ctx x
|
||||
if !info.ref || !isFirstOcc xs i then b
|
||||
else
|
||||
let numConsuptions := getNumConsumptions x xs consumeParamPred -- number of times the argument is
|
||||
let numIncs :=
|
||||
if !info.consume || -- `x` is not a variable that must be consumed by the current procedure
|
||||
liveVarsAfter.contains x || -- `x` is live after executing instruction
|
||||
isBorrowParamAux x xs consumeParamPred -- `x` is used in a position that is passed as a borrow reference
|
||||
then numConsuptions
|
||||
else numConsuptions - 1
|
||||
-- dbgTrace ("addInc " ++ toString x ++ " nconsumptions: " ++ toString numConsuptions ++ " incs: " ++ toString numIncs
|
||||
-- ++ " consume: " ++ toString info.consume ++ " live: " ++ toString (liveVarsAfter.contains x)
|
||||
-- ++ " borrowParam : " ++ toString (isBorrowParamAux x xs consumeParamPred)) $ fun _ =>
|
||||
addInc ctx x b numIncs
|
||||
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
let x := xs[i]
|
||||
match x with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
let info := getVarInfo ctx x
|
||||
if !info.ref || !isFirstOcc xs i then b
|
||||
else
|
||||
let numConsuptions := getNumConsumptions x xs consumeParamPred -- number of times the argument is
|
||||
let numIncs :=
|
||||
if !info.consume || -- `x` is not a variable that must be consumed by the current procedure
|
||||
liveVarsAfter.contains x || -- `x` is live after executing instruction
|
||||
isBorrowParamAux x xs consumeParamPred -- `x` is used in a position that is passed as a borrow reference
|
||||
then numConsuptions
|
||||
else numConsuptions - 1
|
||||
-- dbgTrace ("addInc " ++ toString x ++ " nconsumptions: " ++ toString numConsuptions ++ " incs: " ++ toString numIncs
|
||||
-- ++ " consume: " ++ toString info.consume ++ " live: " ++ toString (liveVarsAfter.contains x)
|
||||
-- ++ " borrowParam : " ++ toString (isBorrowParamAux x xs consumeParamPred)) $ fun _ =>
|
||||
addInc ctx x b numIncs
|
||||
|
||||
private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
|
||||
addIncBeforeAux ctx xs (fun i => not ps[i].borrow) b liveVarsAfter
|
||||
addIncBeforeAux ctx xs (fun i => not ps[i].borrow) b liveVarsAfter
|
||||
|
||||
/- See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before an application. -/
|
||||
private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
|
||||
xs.size.fold
|
||||
(fun i b =>
|
||||
match xs[i] with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
/- We must add a `dec` if `x` must be consumed, it is alive after the application,
|
||||
and it has been borrowed by the application.
|
||||
Remark: `x` may occur multiple times in the application (e.g., `f x y x`).
|
||||
This is why we check whether it is the first occurrence. -/
|
||||
if mustConsume ctx x && isFirstOcc xs i && isBorrowParam x xs ps && !bLiveVars.contains x then
|
||||
addDec ctx x b
|
||||
else b)
|
||||
b
|
||||
xs.size.fold (init := b) fun i b =>
|
||||
match xs[i] with
|
||||
| Arg.irrelevant => b
|
||||
| Arg.var x =>
|
||||
/- We must add a `dec` if `x` must be consumed, it is alive after the application,
|
||||
and it has been borrowed by the application.
|
||||
Remark: `x` may occur multiple times in the application (e.g., `f x y x`).
|
||||
This is why we check whether it is the first occurrence. -/
|
||||
if mustConsume ctx x && isFirstOcc xs i && isBorrowParam x xs ps && !bLiveVars.contains x then
|
||||
addDec ctx x b
|
||||
else b
|
||||
|
||||
private def addIncBeforeConsumeAll (ctx : Context) (xs : Array Arg) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
|
||||
addIncBeforeAux ctx xs (fun i => true) b liveVarsAfter
|
||||
addIncBeforeAux ctx xs (fun i => true) b liveVarsAfter
|
||||
|
||||
/- Add `dec` instructions for parameters that are references, are not alive in `b`, and are not borrow.
|
||||
That is, we must make sure these parameters are consumed. -/
|
||||
private def addDecForDeadParams (ctx : Context) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
|
||||
ps.foldl
|
||||
(fun b p => if !p.borrow && p.ty.isObj && !bLiveVars.contains p.x then addDec ctx p.x b else b)
|
||||
b
|
||||
ps.foldl (init := b) fun b p =>
|
||||
if !p.borrow && p.ty.isObj && !bLiveVars.contains p.x then addDec ctx p.x b else b
|
||||
|
||||
private def isPersistent : Expr → Bool
|
||||
| Expr.fap c xs => xs.isEmpty -- all global constants are persistent objects
|
||||
| _ => false
|
||||
| Expr.fap c xs => xs.isEmpty -- all global constants are persistent objects
|
||||
| _ => false
|
||||
|
||||
/- We do not need to consume the projection of a variable that is not consumed -/
|
||||
private def consumeExpr (m : VarMap) : Expr → Bool
|
||||
| Expr.proj i x => match m.find? x with
|
||||
| some info => info.consume
|
||||
| none => true
|
||||
| other => true
|
||||
| Expr.proj i x => match m.find? x with
|
||||
| some info => info.consume
|
||||
| none => true
|
||||
| other => true
|
||||
|
||||
/- Return true iff `v` at runtime is a scalar value stored in a tagged pointer.
|
||||
We do not need RC operations for this kind of value. -/
|
||||
private def isScalarBoxedInTaggedPtr (v : Expr) : Bool :=
|
||||
match v with
|
||||
| Expr.ctor c ys => c.size == 0 && c.ssize == 0 && c.usize == 0
|
||||
| Expr.lit (LitVal.num n) => n ≤ maxSmallNat
|
||||
| _ => false
|
||||
match v with
|
||||
| Expr.ctor c ys => c.size == 0 && c.ssize == 0 && c.usize == 0
|
||||
| Expr.lit (LitVal.num n) => n ≤ maxSmallNat
|
||||
| _ => false
|
||||
|
||||
private def updateVarInfo (ctx : Context) (x : VarId) (t : IRType) (v : Expr) : Context :=
|
||||
{ ctx with
|
||||
varMap := ctx.varMap.insert x {
|
||||
ref := t.isObj && !isScalarBoxedInTaggedPtr v,
|
||||
persistent := isPersistent v,
|
||||
consume := consumeExpr ctx.varMap v } }
|
||||
{ ctx with
|
||||
varMap := ctx.varMap.insert x {
|
||||
ref := t.isObj && !isScalarBoxedInTaggedPtr v,
|
||||
persistent := isPersistent v,
|
||||
consume := consumeExpr ctx.varMap v
|
||||
}
|
||||
}
|
||||
|
||||
private def addDecIfNeeded (ctx : Context) (x : VarId) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
|
||||
if mustConsume ctx x && !bLiveVars.contains x then addDec ctx x b else b
|
||||
if mustConsume ctx x && !bLiveVars.contains x then addDec ctx x b else b
|
||||
|
||||
private def processVDecl (ctx : Context) (z : VarId) (t : IRType) (v : Expr) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody × LiveVarSet :=
|
||||
-- dbgTrace ("processVDecl " ++ toString z ++ " " ++ toString (format v)) $ fun _ =>
|
||||
let b := match v with
|
||||
| (Expr.ctor _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
|
||||
| (Expr.reuse _ _ _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
|
||||
| (Expr.proj _ x) =>
|
||||
let b := addDecIfNeeded ctx x b bLiveVars
|
||||
let b := if (getVarInfo ctx x).consume then addInc ctx z b else b
|
||||
(FnBody.vdecl z t v b)
|
||||
| (Expr.uproj _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
|
||||
| (Expr.sproj _ _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
|
||||
| (Expr.fap f ys) =>
|
||||
-- dbgTrace ("processVDecl " ++ toString v) $ fun _ =>
|
||||
let ps := (getDecl ctx f).params
|
||||
let b := addDecAfterFullApp ctx ys ps b bLiveVars
|
||||
let b := FnBody.vdecl z t v b
|
||||
addIncBefore ctx ys ps b bLiveVars
|
||||
| (Expr.pap _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
|
||||
| (Expr.ap x ys) =>
|
||||
let ysx := ys.push (Arg.var x) -- TODO: avoid temporary array allocation
|
||||
addIncBeforeConsumeAll ctx ysx (FnBody.vdecl z t v b) bLiveVars
|
||||
| (Expr.unbox x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
|
||||
| other => FnBody.vdecl z t v b -- Expr.reset, Expr.box, Expr.lit are handled here
|
||||
let liveVars := updateLiveVars v bLiveVars
|
||||
let liveVars := liveVars.erase z
|
||||
(b, liveVars)
|
||||
let b := match v with
|
||||
| (Expr.ctor _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
|
||||
| (Expr.reuse _ _ _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
|
||||
| (Expr.proj _ x) =>
|
||||
let b := addDecIfNeeded ctx x b bLiveVars
|
||||
let b := if (getVarInfo ctx x).consume then addInc ctx z b else b
|
||||
(FnBody.vdecl z t v b)
|
||||
| (Expr.uproj _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
|
||||
| (Expr.sproj _ _ x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
|
||||
| (Expr.fap f ys) =>
|
||||
-- dbgTrace ("processVDecl " ++ toString v) $ fun _ =>
|
||||
let ps := (getDecl ctx f).params
|
||||
let b := addDecAfterFullApp ctx ys ps b bLiveVars
|
||||
let b := FnBody.vdecl z t v b
|
||||
addIncBefore ctx ys ps b bLiveVars
|
||||
| (Expr.pap _ ys) => addIncBeforeConsumeAll ctx ys (FnBody.vdecl z t v b) bLiveVars
|
||||
| (Expr.ap x ys) =>
|
||||
let ysx := ys.push (Arg.var x) -- TODO: avoid temporary array allocation
|
||||
addIncBeforeConsumeAll ctx ysx (FnBody.vdecl z t v b) bLiveVars
|
||||
| (Expr.unbox x) => FnBody.vdecl z t v (addDecIfNeeded ctx x b bLiveVars)
|
||||
| other => FnBody.vdecl z t v b -- Expr.reset, Expr.box, Expr.lit are handled here
|
||||
let liveVars := updateLiveVars v bLiveVars
|
||||
let liveVars := liveVars.erase z
|
||||
(b, liveVars)
|
||||
|
||||
def updateVarInfoWithParams (ctx : Context) (ps : Array Param) : Context :=
|
||||
let m := ps.foldl (init := ctx.varMap) fun m p =>
|
||||
m.insert p.x { ref := p.ty.isObj, consume := !p.borrow }
|
||||
{ ctx with varMap := m }
|
||||
let m := ps.foldl (init := ctx.varMap) fun m p =>
|
||||
m.insert p.x { ref := p.ty.isObj, consume := !p.borrow }
|
||||
{ ctx with varMap := m }
|
||||
|
||||
partial def visitFnBody : FnBody → Context → (FnBody × LiveVarSet)
|
||||
| FnBody.vdecl x t v b, ctx =>
|
||||
let ctx := updateVarInfo ctx x t v
|
||||
let (b, bLiveVars) := visitFnBody b ctx
|
||||
processVDecl ctx x t v b bLiveVars
|
||||
| FnBody.jdecl j xs v b, ctx =>
|
||||
let (v, vLiveVars) := visitFnBody v (updateVarInfoWithParams ctx xs)
|
||||
let v := addDecForDeadParams ctx xs v vLiveVars
|
||||
let ctx := { ctx with jpLiveVarMap := updateJPLiveVarMap j xs v ctx.jpLiveVarMap }
|
||||
let (b, bLiveVars) := visitFnBody b ctx
|
||||
(FnBody.jdecl j xs v b, bLiveVars)
|
||||
| FnBody.uset x i y b, ctx =>
|
||||
let (b, s) := visitFnBody b ctx
|
||||
-- We don't need to insert `y` since we only need to track live variables that are references at runtime
|
||||
let s := s.insert x
|
||||
(FnBody.uset x i y b, s)
|
||||
| FnBody.sset x i o y t b, ctx =>
|
||||
let (b, s) := visitFnBody b ctx
|
||||
-- We don't need to insert `y` since we only need to track live variables that are references at runtime
|
||||
let s := s.insert x
|
||||
(FnBody.sset x i o y t b, s)
|
||||
| FnBody.mdata m b, ctx =>
|
||||
let (b, s) := visitFnBody b ctx
|
||||
(FnBody.mdata m b, s)
|
||||
| b@(FnBody.case tid x xType alts), ctx =>
|
||||
let caseLiveVars := collectLiveVars b ctx.jpLiveVarMap
|
||||
let alts := alts.map $ fun alt => match alt with
|
||||
| Alt.ctor c b =>
|
||||
let ctx := updateRefUsingCtorInfo ctx x c
|
||||
let (b, altLiveVars) := visitFnBody b ctx
|
||||
let b := addDecForAlt ctx caseLiveVars altLiveVars b
|
||||
Alt.ctor c b
|
||||
| Alt.default b =>
|
||||
let (b, altLiveVars) := visitFnBody b ctx
|
||||
let b := addDecForAlt ctx caseLiveVars altLiveVars b
|
||||
Alt.default b
|
||||
(FnBody.case tid x xType alts, caseLiveVars)
|
||||
| b@(FnBody.ret x), ctx =>
|
||||
match x with
|
||||
| Arg.var x =>
|
||||
let info := getVarInfo ctx x
|
||||
if info.ref && !info.consume then (addInc ctx x b, mkLiveVarSet x) else (b, mkLiveVarSet x)
|
||||
| _ => (b, {})
|
||||
| b@(FnBody.jmp j xs), ctx =>
|
||||
let jLiveVars := getJPLiveVars ctx j
|
||||
let ps := getJPParams ctx j
|
||||
let b := addIncBefore ctx xs ps b jLiveVars
|
||||
let bLiveVars := collectLiveVars b ctx.jpLiveVarMap
|
||||
(b, bLiveVars)
|
||||
| FnBody.unreachable, _ => (FnBody.unreachable, {})
|
||||
| other, ctx => (other, {}) -- unreachable if well-formed
|
||||
| FnBody.vdecl x t v b, ctx =>
|
||||
let ctx := updateVarInfo ctx x t v
|
||||
let (b, bLiveVars) := visitFnBody b ctx
|
||||
processVDecl ctx x t v b bLiveVars
|
||||
| FnBody.jdecl j xs v b, ctx =>
|
||||
let (v, vLiveVars) := visitFnBody v (updateVarInfoWithParams ctx xs)
|
||||
let v := addDecForDeadParams ctx xs v vLiveVars
|
||||
let ctx := { ctx with jpLiveVarMap := updateJPLiveVarMap j xs v ctx.jpLiveVarMap }
|
||||
let (b, bLiveVars) := visitFnBody b ctx
|
||||
(FnBody.jdecl j xs v b, bLiveVars)
|
||||
| FnBody.uset x i y b, ctx =>
|
||||
let (b, s) := visitFnBody b ctx
|
||||
-- We don't need to insert `y` since we only need to track live variables that are references at runtime
|
||||
let s := s.insert x
|
||||
(FnBody.uset x i y b, s)
|
||||
| FnBody.sset x i o y t b, ctx =>
|
||||
let (b, s) := visitFnBody b ctx
|
||||
-- We don't need to insert `y` since we only need to track live variables that are references at runtime
|
||||
let s := s.insert x
|
||||
(FnBody.sset x i o y t b, s)
|
||||
| FnBody.mdata m b, ctx =>
|
||||
let (b, s) := visitFnBody b ctx
|
||||
(FnBody.mdata m b, s)
|
||||
| b@(FnBody.case tid x xType alts), ctx =>
|
||||
let caseLiveVars := collectLiveVars b ctx.jpLiveVarMap
|
||||
let alts := alts.map $ fun alt => match alt with
|
||||
| Alt.ctor c b =>
|
||||
let ctx := updateRefUsingCtorInfo ctx x c
|
||||
let (b, altLiveVars) := visitFnBody b ctx
|
||||
let b := addDecForAlt ctx caseLiveVars altLiveVars b
|
||||
Alt.ctor c b
|
||||
| Alt.default b =>
|
||||
let (b, altLiveVars) := visitFnBody b ctx
|
||||
let b := addDecForAlt ctx caseLiveVars altLiveVars b
|
||||
Alt.default b
|
||||
(FnBody.case tid x xType alts, caseLiveVars)
|
||||
| b@(FnBody.ret x), ctx =>
|
||||
match x with
|
||||
| Arg.var x =>
|
||||
let info := getVarInfo ctx x
|
||||
if info.ref && !info.consume then (addInc ctx x b, mkLiveVarSet x) else (b, mkLiveVarSet x)
|
||||
| _ => (b, {})
|
||||
| b@(FnBody.jmp j xs), ctx =>
|
||||
let jLiveVars := getJPLiveVars ctx j
|
||||
let ps := getJPParams ctx j
|
||||
let b := addIncBefore ctx xs ps b jLiveVars
|
||||
let bLiveVars := collectLiveVars b ctx.jpLiveVarMap
|
||||
(b, bLiveVars)
|
||||
| FnBody.unreachable, _ => (FnBody.unreachable, {})
|
||||
| other, ctx => (other, {}) -- unreachable if well-formed
|
||||
|
||||
partial def visitDecl (env : Environment) (decls : Array Decl) : Decl → Decl
|
||||
| Decl.fdecl f xs t b =>
|
||||
let ctx : Context := { env := env, decls := decls }
|
||||
let ctx := updateVarInfoWithParams ctx xs
|
||||
let (b, bLiveVars) := visitFnBody b ctx
|
||||
let b := addDecForDeadParams ctx xs b bLiveVars
|
||||
Decl.fdecl f xs t b
|
||||
| other => other
|
||||
| Decl.fdecl f xs t b =>
|
||||
let ctx : Context := { env := env, decls := decls }
|
||||
let ctx := updateVarInfoWithParams ctx xs
|
||||
let (b, bLiveVars) := visitFnBody b ctx
|
||||
let b := addDecForDeadParams ctx xs b bLiveVars
|
||||
Decl.fdecl f xs t b
|
||||
| other => other
|
||||
|
||||
end ExplicitRC
|
||||
|
||||
def explicitRC (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let env ← getEnv
|
||||
pure $ decls.map (ExplicitRC.visitDecl env decls)
|
||||
let env ← getEnv
|
||||
pure $ decls.map (ExplicitRC.visitDecl env decls)
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -26,56 +26,56 @@ namespace Lean.IR.ResetReuse
|
|||
-/
|
||||
|
||||
private def mayReuse (c₁ c₂ : CtorInfo) : Bool :=
|
||||
c₁.size == c₂.size && c₁.usize == c₂.usize && c₁.ssize == c₂.ssize &&
|
||||
/- The following condition is a heuristic.
|
||||
We don't want to reuse cells from different types even when they are compatible
|
||||
because it produces counterintuitive behavior. -/
|
||||
c₁.name.getPrefix == c₂.name.getPrefix
|
||||
c₁.size == c₂.size && c₁.usize == c₂.usize && c₁.ssize == c₂.ssize &&
|
||||
/- The following condition is a heuristic.
|
||||
We don't want to reuse cells from different types even when they are compatible
|
||||
because it produces counterintuitive behavior. -/
|
||||
c₁.name.getPrefix == c₂.name.getPrefix
|
||||
|
||||
private partial def S (w : VarId) (c : CtorInfo) : FnBody → FnBody
|
||||
| FnBody.vdecl x t v@(Expr.ctor c' ys) b =>
|
||||
if mayReuse c c' then
|
||||
let updtCidx := c.cidx != c'.cidx
|
||||
FnBody.vdecl x t (Expr.reuse w c' updtCidx ys) b
|
||||
else
|
||||
FnBody.vdecl x t v (S w c b)
|
||||
| FnBody.jdecl j ys v b =>
|
||||
let v' := S w c v
|
||||
if v == v' then FnBody.jdecl j ys v (S w c b)
|
||||
else FnBody.jdecl j ys v' b
|
||||
| FnBody.case tid x xType alts => FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (S w c)
|
||||
| b =>
|
||||
if b.isTerminal then b
|
||||
else let
|
||||
(instr, b) := b.split
|
||||
instr.setBody (S w c b)
|
||||
| FnBody.vdecl x t v@(Expr.ctor c' ys) b =>
|
||||
if mayReuse c c' then
|
||||
let updtCidx := c.cidx != c'.cidx
|
||||
FnBody.vdecl x t (Expr.reuse w c' updtCidx ys) b
|
||||
else
|
||||
FnBody.vdecl x t v (S w c b)
|
||||
| FnBody.jdecl j ys v b =>
|
||||
let v' := S w c v
|
||||
if v == v' then FnBody.jdecl j ys v (S w c b)
|
||||
else FnBody.jdecl j ys v' b
|
||||
| FnBody.case tid x xType alts => FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody (S w c)
|
||||
| b =>
|
||||
if b.isTerminal then b
|
||||
else let
|
||||
(instr, b) := b.split
|
||||
instr.setBody (S w c b)
|
||||
|
||||
/- We use `Context` to track join points in scope. -/
|
||||
abbrev M := ReaderT LocalContext (StateT Index Id)
|
||||
|
||||
private def mkFresh : M VarId := do
|
||||
let idx ← getModify (fun n => n + 1)
|
||||
pure { idx := idx }
|
||||
let idx ← getModify (fun n => n + 1)
|
||||
pure { idx := idx }
|
||||
|
||||
private def tryS (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody := do
|
||||
let w ← mkFresh
|
||||
let b' := S w c b
|
||||
if b == b' then pure b
|
||||
else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b'
|
||||
let w ← mkFresh
|
||||
let b' := S w c b
|
||||
if b == b' then pure b
|
||||
else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b'
|
||||
|
||||
private def Dfinalize (x : VarId) (c : CtorInfo) : FnBody × Bool → M FnBody
|
||||
| (b, true) => pure b
|
||||
| (b, false) => tryS x c b
|
||||
| (b, true) => pure b
|
||||
| (b, false) => tryS x c b
|
||||
|
||||
private def argsContainsVar (ys : Array Arg) (x : VarId) : Bool :=
|
||||
ys.any fun arg => match arg with
|
||||
| Arg.var y => x == y
|
||||
| _ => false
|
||||
ys.any fun arg => match arg with
|
||||
| Arg.var y => x == y
|
||||
| _ => false
|
||||
|
||||
private def isCtorUsing (b : FnBody) (x : VarId) : Bool :=
|
||||
match b with
|
||||
| (FnBody.vdecl _ _ (Expr.ctor _ ys) _) => argsContainsVar ys x
|
||||
| _ => false
|
||||
match b with
|
||||
| (FnBody.vdecl _ _ (Expr.ctor _ ys) _) => argsContainsVar ys x
|
||||
| _ => false
|
||||
|
||||
/- Given `Dmain b`, the resulting pair `(new_b, flag)` contains the new body `new_b`,
|
||||
and `flag == true` if `x` is live in `b`.
|
||||
|
|
@ -84,75 +84,75 @@ match b with
|
|||
`D` checks whether `x` is live in `F` or not. This is great for clarity but it
|
||||
is expensive: `O(n^2)` where `n` is the size of the function body. -/
|
||||
private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody → M (FnBody × Bool)
|
||||
| e@(FnBody.case tid y yType alts) => do
|
||||
let ctx ← read
|
||||
if e.hasLiveVar ctx x then do
|
||||
/- If `x` is live in `e`, we recursively process each branch. -/
|
||||
let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => Dmain x c b >>= Dfinalize x c
|
||||
pure (FnBody.case tid y yType alts, true)
|
||||
else pure (e, false)
|
||||
| FnBody.jdecl j ys v b => do
|
||||
let (b, found) ← withReader (fun ctx => ctx.addJP j ys v) (Dmain x c b)
|
||||
let (v, _ /- found' -/) ← Dmain x c v
|
||||
/- If `found' == true`, then `Dmain b` must also have returned `(b, true)` since
|
||||
we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`),
|
||||
then it must also live in `b` since `j` is reachable from `b` with a `jmp`.
|
||||
On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/
|
||||
pure (FnBody.jdecl j ys v b, found)
|
||||
| e => do
|
||||
let ctx ← read
|
||||
if e.isTerminal then
|
||||
pure (e, e.hasLiveVar ctx x)
|
||||
else do
|
||||
let (instr, b) := e.split
|
||||
if isCtorUsing instr x then
|
||||
/- If the scrutinee `x` (the one that is providing memory) is being
|
||||
stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
|
||||
It may work only if the new cell is consumed, but we ignore this case. -/
|
||||
pure (e, true)
|
||||
| e@(FnBody.case tid y yType alts) => do
|
||||
let ctx ← read
|
||||
if e.hasLiveVar ctx x then do
|
||||
/- If `x` is live in `e`, we recursively process each branch. -/
|
||||
let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => Dmain x c b >>= Dfinalize x c
|
||||
pure (FnBody.case tid y yType alts, true)
|
||||
else pure (e, false)
|
||||
| FnBody.jdecl j ys v b => do
|
||||
let (b, found) ← withReader (fun ctx => ctx.addJP j ys v) (Dmain x c b)
|
||||
let (v, _ /- found' -/) ← Dmain x c v
|
||||
/- If `found' == true`, then `Dmain b` must also have returned `(b, true)` since
|
||||
we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`),
|
||||
then it must also live in `b` since `j` is reachable from `b` with a `jmp`.
|
||||
On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`). -/
|
||||
pure (FnBody.jdecl j ys v b, found)
|
||||
| e => do
|
||||
let ctx ← read
|
||||
if e.isTerminal then
|
||||
pure (e, e.hasLiveVar ctx x)
|
||||
else do
|
||||
let (b, found) ← Dmain x c b
|
||||
/- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar`
|
||||
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/
|
||||
if found || !instr.hasFreeVar x then
|
||||
pure (instr.setBody b, found)
|
||||
else do
|
||||
let b ← tryS x c b
|
||||
pure (instr.setBody b, true)
|
||||
let (instr, b) := e.split
|
||||
if isCtorUsing instr x then
|
||||
/- If the scrutinee `x` (the one that is providing memory) is being
|
||||
stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
|
||||
It may work only if the new cell is consumed, but we ignore this case. -/
|
||||
pure (e, true)
|
||||
else
|
||||
let (b, found) ← Dmain x c b
|
||||
/- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar`
|
||||
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/
|
||||
if found || !instr.hasFreeVar x then
|
||||
pure (instr.setBody b, found)
|
||||
else
|
||||
let b ← tryS x c b
|
||||
pure (instr.setBody b, true)
|
||||
|
||||
private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
|
||||
Dmain x c b >>= Dfinalize x c
|
||||
Dmain x c b >>= Dfinalize x c
|
||||
|
||||
partial def R : FnBody → M FnBody
|
||||
| FnBody.case tid x xType alts => do
|
||||
let alts ← alts.mapM fun alt => do
|
||||
let alt ← alt.mmodifyBody R
|
||||
match alt with
|
||||
| Alt.ctor c b =>
|
||||
if c.isScalar then pure alt
|
||||
else Alt.ctor c <$> D x c b
|
||||
| _ => pure alt
|
||||
pure $ FnBody.case tid x xType alts
|
||||
| FnBody.jdecl j ys v b => do
|
||||
let v ← R v
|
||||
let b ← withReader (fun ctx => ctx.addJP j ys v) (R b)
|
||||
pure $ FnBody.jdecl j ys v b
|
||||
| e => do
|
||||
if e.isTerminal then pure e
|
||||
else do
|
||||
let (instr, b) := e.split
|
||||
let b ← R b
|
||||
pure (instr.setBody b)
|
||||
| FnBody.case tid x xType alts => do
|
||||
let alts ← alts.mapM fun alt => do
|
||||
let alt ← alt.mmodifyBody R
|
||||
match alt with
|
||||
| Alt.ctor c b =>
|
||||
if c.isScalar then pure alt
|
||||
else Alt.ctor c <$> D x c b
|
||||
| _ => pure alt
|
||||
pure $ FnBody.case tid x xType alts
|
||||
| FnBody.jdecl j ys v b => do
|
||||
let v ← R v
|
||||
let b ← withReader (fun ctx => ctx.addJP j ys v) (R b)
|
||||
pure $ FnBody.jdecl j ys v b
|
||||
| e => do
|
||||
if e.isTerminal then pure e
|
||||
else do
|
||||
let (instr, b) := e.split
|
||||
let b ← R b
|
||||
pure (instr.setBody b)
|
||||
|
||||
end ResetReuse
|
||||
|
||||
open ResetReuse
|
||||
|
||||
def Decl.insertResetReuse : Decl → Decl
|
||||
| d@(Decl.fdecl f xs t b) =>
|
||||
let nextIndex := d.maxIndex + 1
|
||||
let b := (R b {}).run' nextIndex
|
||||
Decl.fdecl f xs t b
|
||||
| other => other
|
||||
| d@(Decl.fdecl f xs t b) =>
|
||||
let nextIndex := d.maxIndex + 1
|
||||
let b := (R b {}).run' nextIndex
|
||||
Decl.fdecl f xs t b
|
||||
| other => other
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -9,52 +9,51 @@ import Lean.Compiler.IR.Format
|
|||
namespace Lean.IR
|
||||
|
||||
def ensureHasDefault (alts : Array Alt) : Array Alt :=
|
||||
if alts.any Alt.isDefault then alts
|
||||
else if alts.size < 2 then alts
|
||||
else
|
||||
let last := alts.back;
|
||||
let alts := alts.pop;
|
||||
alts.push (Alt.default last.body)
|
||||
if alts.any Alt.isDefault then alts
|
||||
else if alts.size < 2 then alts
|
||||
else
|
||||
let last := alts.back;
|
||||
let alts := alts.pop;
|
||||
alts.push (Alt.default last.body)
|
||||
|
||||
private def getOccsOf (alts : Array Alt) (i : Nat) : Nat := do
|
||||
let aBody := (alts.get! i).body
|
||||
let n := 1
|
||||
for j in [i+1:alts.size] do
|
||||
if alts[j].body == aBody then
|
||||
n := n+1
|
||||
return n
|
||||
let aBody := (alts.get! i).body
|
||||
let n := 1
|
||||
for j in [i+1:alts.size] do
|
||||
if alts[j].body == aBody then
|
||||
n := n+1
|
||||
return n
|
||||
|
||||
private def maxOccs (alts : Array Alt) : Alt × Nat := do
|
||||
let maxAlt := alts[0]
|
||||
let max := getOccsOf alts 0
|
||||
for i in [1:alts.size] do
|
||||
let curr := getOccsOf alts i
|
||||
if curr > max then
|
||||
maxAlt := alts[i]
|
||||
max := curr
|
||||
return (maxAlt, max)
|
||||
let maxAlt := alts[0]
|
||||
let max := getOccsOf alts 0
|
||||
for i in [1:alts.size] do
|
||||
let curr := getOccsOf alts i
|
||||
if curr > max then
|
||||
maxAlt := alts[i]
|
||||
max := curr
|
||||
return (maxAlt, max)
|
||||
|
||||
private def addDefault (alts : Array Alt) : Array Alt :=
|
||||
if alts.size <= 1 || alts.any Alt.isDefault then alts
|
||||
else
|
||||
let (max, noccs) := maxOccs alts;
|
||||
if noccs == 1 then alts
|
||||
if alts.size <= 1 || alts.any Alt.isDefault then alts
|
||||
else
|
||||
let alts := alts.filter $ (fun alt => alt.body != max.body);
|
||||
alts.push (Alt.default max.body)
|
||||
let (max, noccs) := maxOccs alts;
|
||||
if noccs == 1 then alts
|
||||
else
|
||||
let alts := alts.filter $ (fun alt => alt.body != max.body);
|
||||
alts.push (Alt.default max.body)
|
||||
|
||||
private def mkSimpCase (tid : Name) (x : VarId) (xType : IRType) (alts : Array Alt) : FnBody :=
|
||||
let alts := alts.filter (fun alt => alt.body != FnBody.unreachable);
|
||||
let alts := addDefault alts;
|
||||
if alts.size == 0 then
|
||||
FnBody.unreachable
|
||||
else if alts.size == 1 then
|
||||
(alts.get! 0).body
|
||||
else
|
||||
FnBody.case tid x xType alts
|
||||
let alts := alts.filter (fun alt => alt.body != FnBody.unreachable);
|
||||
let alts := addDefault alts;
|
||||
if alts.size == 0 then
|
||||
FnBody.unreachable
|
||||
else if alts.size == 1 then
|
||||
(alts.get! 0).body
|
||||
else
|
||||
FnBody.case tid x xType alts
|
||||
|
||||
partial def FnBody.simpCase : FnBody → FnBody
|
||||
| b =>
|
||||
partial def FnBody.simpCase (b : FnBody) : FnBody :=
|
||||
let (bs, term) := b.flatten;
|
||||
let bs := modifyJPs bs simpCase;
|
||||
match term with
|
||||
|
|
@ -68,7 +67,7 @@ partial def FnBody.simpCase : FnBody → FnBody
|
|||
- Remove `case` if there is only one branch.
|
||||
- Merge most common branches using `Alt.default`. -/
|
||||
def Decl.simpCase : Decl → Decl
|
||||
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.simpCase
|
||||
| other => other
|
||||
| Decl.fdecl f xs t b => Decl.fdecl f xs t b.simpCase
|
||||
| other => other
|
||||
|
||||
end Lean.IR
|
||||
|
|
|
|||
|
|
@ -824,13 +824,13 @@ This is a standard trick we use in the elaborator, and it is also used to elabor
|
|||
Suppose, we are trying to elaborate
|
||||
```
|
||||
match g x with
|
||||
| ... => ...
|
||||
| ... => ...
|
||||
```
|
||||
`expandNonAtomicDiscrs?` converts it intro
|
||||
```
|
||||
let _discr := g x
|
||||
match _discr with
|
||||
| ... => ...
|
||||
| ... => ...
|
||||
```
|
||||
Thus, at `tryPostponeIfDiscrTypeIsMVar` we only need to check whether the type of `_discr` is not of the form `(?m ...)`.
|
||||
Note that, the auxiliary variable `_discr` is expanded at `elabAtomicDiscr`.
|
||||
|
|
|
|||
|
|
@ -108,17 +108,17 @@ syntax "throwError! " ((interpolatedStr term) <|> term) : term
|
|||
syntax "throwErrorAt! " term:max ((interpolatedStr term) <|> term) : term
|
||||
|
||||
macro_rules
|
||||
| `(throwError! $msg) =>
|
||||
if msg.getKind == interpolatedStrKind then
|
||||
`(throwError (msg! $msg))
|
||||
else
|
||||
`(throwError $msg)
|
||||
| `(throwError! $msg) =>
|
||||
if msg.getKind == interpolatedStrKind then
|
||||
`(throwError (msg! $msg))
|
||||
else
|
||||
`(throwError $msg)
|
||||
|
||||
macro_rules
|
||||
| `(throwErrorAt! $ref $msg) =>
|
||||
if msg.getKind == interpolatedStrKind then
|
||||
`(throwErrorAt $ref (msg! $msg))
|
||||
else
|
||||
`(throwErrorAt $ref $msg)
|
||||
| `(throwErrorAt! $ref $msg) =>
|
||||
if msg.getKind == interpolatedStrKind then
|
||||
`(throwErrorAt $ref (msg! $msg))
|
||||
else
|
||||
`(throwErrorAt $ref $msg)
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -211,12 +211,12 @@ def occurs : Level → Level → Bool
|
|||
| u, v => u == v
|
||||
|
||||
def ctorToNat : Level → Nat
|
||||
| zero .. => 0
|
||||
| param .. => 1
|
||||
| mvar .. => 2
|
||||
| succ .. => 3
|
||||
| max .. => 4
|
||||
| imax .. => 5
|
||||
| zero .. => 0
|
||||
| param .. => 1
|
||||
| mvar .. => 2
|
||||
| succ .. => 3
|
||||
| max .. => 4
|
||||
| imax .. => 5
|
||||
|
||||
/- TODO: use well founded recursion. -/
|
||||
partial def normLtAux : Level → Nat → Level → Nat → Bool
|
||||
|
|
|
|||
|
|
@ -111,15 +111,15 @@ def replaceFVarId (fvarId : FVarId) (v : Expr) (alt : Alt) : Alt :=
|
|||
|
||||
```
|
||||
inductive Vec (α : Type u) : Nat → Type u
|
||||
| nil : Vec α 0
|
||||
| cons {n} (head : α) (tail : Vec α n) : Vec α (n+1)
|
||||
| nil : Vec α 0
|
||||
| cons {n} (head : α) (tail : Vec α n) : Vec α (n+1)
|
||||
|
||||
inductive VecPred {α : Type u} (P : α → Prop) : {n : Nat} → Vec α n → Prop
|
||||
| nil : VecPred P Vec.nil
|
||||
| cons {n : Nat} {head : α} {tail : Vec α n} : P head → VecPred P tail → VecPred P (Vec.cons head tail)
|
||||
| nil : VecPred P Vec.nil
|
||||
| cons {n : Nat} {head : α} {tail : Vec α n} : P head → VecPred P tail → VecPred P (Vec.cons head tail)
|
||||
|
||||
theorem ex {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P
|
||||
| _, Vec.cons head _, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩
|
||||
| _, Vec.cons head _, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩
|
||||
```
|
||||
Recall that `_` in a pattern can be elaborated into pattern variable or an inaccessible term.
|
||||
The elaborator uses an inaccessible term when typing constraints restrict its value.
|
||||
|
|
@ -127,7 +127,7 @@ Thus, in the example above, the `_` at `Vec.cons head _` becomes the inaccessibl
|
|||
because the type ascription `(w : VecPred P Vec.nil)` propagates typing constraints that restrict its value to be `Vec.nil`.
|
||||
After elaboration the alternative becomes:
|
||||
```
|
||||
| .(0), @Vec.cons .(α) .(0) head .(Vec.nil), @VecPred.cons .(α) .(P) .(0) .(head) .(Vec.nil) h w => ⟨head, h⟩
|
||||
| .(0), @Vec.cons .(α) .(0) head .(Vec.nil), @VecPred.cons .(α) .(P) .(0) .(head) .(Vec.nil) h w => ⟨head, h⟩
|
||||
```
|
||||
where
|
||||
```
|
||||
|
|
@ -138,7 +138,7 @@ Then, when we process this alternative in this module, the following check will
|
|||
Note that if we had written
|
||||
```
|
||||
theorem ex {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P
|
||||
| _, Vec.cons head Vec.nil, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩
|
||||
| _, Vec.cons head Vec.nil, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩
|
||||
```
|
||||
we would get the easier to digest error message
|
||||
```
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ builtin_initialize builtinTokenTable : IO.Ref TokenTable ← IO.mkRef {}
|
|||
builtin_initialize builtinSyntaxNodeKindSetRef : IO.Ref SyntaxNodeKindSet ← IO.mkRef {}
|
||||
|
||||
def registerBuiltinNodeKind (k : SyntaxNodeKind) : IO Unit :=
|
||||
builtinSyntaxNodeKindSetRef.modify fun s => s.insert k
|
||||
builtinSyntaxNodeKindSetRef.modify fun s => s.insert k
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinNodeKind choiceKind
|
||||
|
|
@ -32,215 +32,213 @@ builtin_initialize
|
|||
builtin_initialize builtinParserCategoriesRef : IO.Ref ParserCategories ← IO.mkRef {}
|
||||
|
||||
private def throwParserCategoryAlreadyDefined {α} (catName : Name) : ExceptT String Id α :=
|
||||
throw s!"parser category '{catName}' has already been defined"
|
||||
throw s!"parser category '{catName}' has already been defined"
|
||||
|
||||
private def addParserCategoryCore (categories : ParserCategories) (catName : Name) (initial : ParserCategory) : Except String ParserCategories :=
|
||||
if categories.contains catName then
|
||||
throwParserCategoryAlreadyDefined catName
|
||||
else
|
||||
pure $ categories.insert catName initial
|
||||
if categories.contains catName then
|
||||
throwParserCategoryAlreadyDefined catName
|
||||
else
|
||||
pure $ categories.insert catName initial
|
||||
|
||||
/-- All builtin parser categories are Pratt's parsers -/
|
||||
private def addBuiltinParserCategory (catName : Name) (leadingIdentAsSymbol : Bool) : IO Unit := do
|
||||
let categories ← builtinParserCategoriesRef.get
|
||||
let categories ← IO.ofExcept $ addParserCategoryCore categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol}
|
||||
builtinParserCategoriesRef.set categories
|
||||
let categories ← builtinParserCategoriesRef.get
|
||||
let categories ← IO.ofExcept $ addParserCategoryCore categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol}
|
||||
builtinParserCategoriesRef.set categories
|
||||
|
||||
inductive ParserExtensionOleanEntry
|
||||
| token (val : Token) : ParserExtensionOleanEntry
|
||||
| kind (val : SyntaxNodeKind) : ParserExtensionOleanEntry
|
||||
| category (catName : Name) (leadingIdentAsSymbol : Bool)
|
||||
| parser (catName : Name) (declName : Name) (prio : Nat) : ParserExtensionOleanEntry
|
||||
| token (val : Token) : ParserExtensionOleanEntry
|
||||
| kind (val : SyntaxNodeKind) : ParserExtensionOleanEntry
|
||||
| category (catName : Name) (leadingIdentAsSymbol : Bool)
|
||||
| parser (catName : Name) (declName : Name) (prio : Nat) : ParserExtensionOleanEntry
|
||||
|
||||
inductive ParserExtensionEntry
|
||||
| token (val : Token) : ParserExtensionEntry
|
||||
| kind (val : SyntaxNodeKind) : ParserExtensionEntry
|
||||
| category (catName : Name) (leadingIdentAsSymbol : Bool)
|
||||
| parser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : ParserExtensionEntry
|
||||
| token (val : Token) : ParserExtensionEntry
|
||||
| kind (val : SyntaxNodeKind) : ParserExtensionEntry
|
||||
| category (catName : Name) (leadingIdentAsSymbol : Bool)
|
||||
| parser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : ParserExtensionEntry
|
||||
|
||||
structure ParserExtensionState :=
|
||||
(tokens : TokenTable := {})
|
||||
(kinds : SyntaxNodeKindSet := {})
|
||||
(categories : ParserCategories := {})
|
||||
(newEntries : List ParserExtensionOleanEntry := [])
|
||||
(tokens : TokenTable := {})
|
||||
(kinds : SyntaxNodeKindSet := {})
|
||||
(categories : ParserCategories := {})
|
||||
(newEntries : List ParserExtensionOleanEntry := [])
|
||||
|
||||
instance : Inhabited ParserExtensionState := ⟨{}⟩
|
||||
|
||||
abbrev ParserExtension := PersistentEnvExtension ParserExtensionOleanEntry ParserExtensionEntry ParserExtensionState
|
||||
|
||||
private def ParserExtension.mkInitial : IO ParserExtensionState := do
|
||||
let tokens ← builtinTokenTable.get
|
||||
let kinds ← builtinSyntaxNodeKindSetRef.get
|
||||
let categories ← builtinParserCategoriesRef.get
|
||||
pure { tokens := tokens, kinds := kinds, categories := categories }
|
||||
let tokens ← builtinTokenTable.get
|
||||
let kinds ← builtinSyntaxNodeKindSetRef.get
|
||||
let categories ← builtinParserCategoriesRef.get
|
||||
pure { tokens := tokens, kinds := kinds, categories := categories }
|
||||
|
||||
private def addTokenConfig (tokens : TokenTable) (tk : Token) : Except String TokenTable := do
|
||||
if tk == "" then throw "invalid empty symbol"
|
||||
else match tokens.find? tk with
|
||||
| none => pure $ tokens.insert tk tk
|
||||
| some _ => pure tokens
|
||||
if tk == "" then throw "invalid empty symbol"
|
||||
else match tokens.find? tk with
|
||||
| none => pure $ tokens.insert tk tk
|
||||
| some _ => pure tokens
|
||||
|
||||
def throwUnknownParserCategory {α} (catName : Name) : ExceptT String Id α :=
|
||||
throw s!"unknown parser category '{catName}'"
|
||||
throw s!"unknown parser category '{catName}'"
|
||||
|
||||
abbrev getCategory (categories : ParserCategories) (catName : Name) : Option ParserCategory :=
|
||||
categories.find? catName
|
||||
categories.find? catName
|
||||
|
||||
def addLeadingParser (categories : ParserCategories) (catName : Name) (parserName : Name) (p : Parser) (prio : Nat) : Except String ParserCategories :=
|
||||
match getCategory categories catName with
|
||||
| none =>
|
||||
throwUnknownParserCategory catName
|
||||
| some cat =>
|
||||
let addTokens (tks : List Token) : Except String ParserCategories :=
|
||||
let tks := tks.map $ fun tk => mkNameSimple tk
|
||||
let tables := tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with leadingTable := tables.leadingTable.insert tk (p, prio) }) cat.tables
|
||||
pure $ categories.insert catName { cat with tables := tables }
|
||||
match getCategory categories catName with
|
||||
| none =>
|
||||
throwUnknownParserCategory catName
|
||||
| some cat =>
|
||||
let addTokens (tks : List Token) : Except String ParserCategories :=
|
||||
let tks := tks.map $ fun tk => mkNameSimple tk
|
||||
let tables := tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with leadingTable := tables.leadingTable.insert tk (p, prio) }) cat.tables
|
||||
pure $ categories.insert catName { cat with tables := tables }
|
||||
match p.info.firstTokens with
|
||||
| FirstTokens.tokens tks => addTokens tks
|
||||
| FirstTokens.optTokens tks => addTokens tks
|
||||
| _ =>
|
||||
let tables := { cat.tables with leadingParsers := (p, prio) :: cat.tables.leadingParsers }
|
||||
pure $ categories.insert catName { cat with tables := tables }
|
||||
|
||||
private def addTrailingParserAux (tables : PrattParsingTables) (p : TrailingParser) (prio : Nat) : PrattParsingTables :=
|
||||
let addTokens (tks : List Token) : PrattParsingTables :=
|
||||
let tks := tks.map fun tk => mkNameSimple tk
|
||||
tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with trailingTable := tables.trailingTable.insert tk (p, prio) }) tables
|
||||
match p.info.firstTokens with
|
||||
| FirstTokens.tokens tks => addTokens tks
|
||||
| FirstTokens.optTokens tks => addTokens tks
|
||||
| _ =>
|
||||
let tables := { cat.tables with leadingParsers := (p, prio) :: cat.tables.leadingParsers }
|
||||
pure $ categories.insert catName { cat with tables := tables }
|
||||
|
||||
private def addTrailingParserAux (tables : PrattParsingTables) (p : TrailingParser) (prio : Nat) : PrattParsingTables :=
|
||||
let addTokens (tks : List Token) : PrattParsingTables :=
|
||||
let tks := tks.map fun tk => mkNameSimple tk
|
||||
tks.eraseDups.foldl (fun (tables : PrattParsingTables) tk => { tables with trailingTable := tables.trailingTable.insert tk (p, prio) }) tables
|
||||
match p.info.firstTokens with
|
||||
| FirstTokens.tokens tks => addTokens tks
|
||||
| FirstTokens.optTokens tks => addTokens tks
|
||||
| _ => { tables with trailingParsers := (p, prio) :: tables.trailingParsers }
|
||||
| _ => { tables with trailingParsers := (p, prio) :: tables.trailingParsers }
|
||||
|
||||
def addTrailingParser (categories : ParserCategories) (catName : Name) (p : TrailingParser) (prio : Nat) : Except String ParserCategories :=
|
||||
match getCategory categories catName with
|
||||
| none => throwUnknownParserCategory catName
|
||||
| some cat => pure $ categories.insert catName { cat with tables := addTrailingParserAux cat.tables p prio }
|
||||
match getCategory categories catName with
|
||||
| none => throwUnknownParserCategory catName
|
||||
| some cat => pure $ categories.insert catName { cat with tables := addTrailingParserAux cat.tables p prio }
|
||||
|
||||
def addParser (categories : ParserCategories) (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : Except String ParserCategories :=
|
||||
match leading, p with
|
||||
| true, p => addLeadingParser categories catName declName p prio
|
||||
| false, p => addTrailingParser categories catName p prio
|
||||
match leading, p with
|
||||
| true, p => addLeadingParser categories catName declName p prio
|
||||
| false, p => addTrailingParser categories catName p prio
|
||||
|
||||
def addParserTokens (tokenTable : TokenTable) (info : ParserInfo) : Except String TokenTable :=
|
||||
let newTokens := info.collectTokens []
|
||||
newTokens.foldlM addTokenConfig tokenTable
|
||||
let newTokens := info.collectTokens []
|
||||
newTokens.foldlM addTokenConfig tokenTable
|
||||
|
||||
private def updateBuiltinTokens (info : ParserInfo) (declName : Name) : IO Unit := do
|
||||
let tokenTable ← builtinTokenTable.swap {}
|
||||
match addParserTokens tokenTable info with
|
||||
| Except.ok tokenTable => builtinTokenTable.set tokenTable
|
||||
| Except.error msg => throw (IO.userError s!"invalid builtin parser '{declName}', {msg}")
|
||||
let tokenTable ← builtinTokenTable.swap {}
|
||||
match addParserTokens tokenTable info with
|
||||
| Except.ok tokenTable => builtinTokenTable.set tokenTable
|
||||
| Except.error msg => throw (IO.userError s!"invalid builtin parser '{declName}', {msg}")
|
||||
|
||||
def addBuiltinParser (catName : Name) (declName : Name) (leading : Bool) (p : Parser) (prio : Nat) : IO Unit := do
|
||||
let categories ← builtinParserCategoriesRef.get
|
||||
let categories ← IO.ofExcept $ addParser categories catName declName leading p prio
|
||||
builtinParserCategoriesRef.set categories
|
||||
builtinSyntaxNodeKindSetRef.modify p.info.collectKinds
|
||||
updateBuiltinTokens p.info declName
|
||||
let categories ← builtinParserCategoriesRef.get
|
||||
let categories ← IO.ofExcept $ addParser categories catName declName leading p prio
|
||||
builtinParserCategoriesRef.set categories
|
||||
builtinSyntaxNodeKindSetRef.modify p.info.collectKinds
|
||||
updateBuiltinTokens p.info declName
|
||||
|
||||
def addBuiltinLeadingParser (catName : Name) (declName : Name) (p : Parser) (prio : Nat) : IO Unit :=
|
||||
addBuiltinParser catName declName true p prio
|
||||
addBuiltinParser catName declName true p prio
|
||||
|
||||
def addBuiltinTrailingParser (catName : Name) (declName : Name) (p : TrailingParser) (prio : Nat) : IO Unit :=
|
||||
addBuiltinParser catName declName false p prio
|
||||
addBuiltinParser catName declName false p prio
|
||||
|
||||
private def ParserExtensionAddEntry (s : ParserExtensionState) (e : ParserExtensionEntry) : ParserExtensionState :=
|
||||
match e with
|
||||
| ParserExtensionEntry.token tk =>
|
||||
match addTokenConfig s.tokens tk with
|
||||
| Except.ok tokens => { s with tokens := tokens, newEntries := ParserExtensionOleanEntry.token tk :: s.newEntries }
|
||||
| _ => unreachable!
|
||||
| ParserExtensionEntry.kind k =>
|
||||
{ s with kinds := s.kinds.insert k, newEntries := ParserExtensionOleanEntry.kind k :: s.newEntries }
|
||||
| ParserExtensionEntry.category catName leadingIdentAsSymbol =>
|
||||
if s.categories.contains catName then s
|
||||
else { s with
|
||||
categories := s.categories.insert catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol },
|
||||
newEntries := ParserExtensionOleanEntry.category catName leadingIdentAsSymbol :: s.newEntries }
|
||||
| ParserExtensionEntry.parser catName declName leading parser prio =>
|
||||
match addParser s.categories catName declName leading parser prio with
|
||||
| Except.ok categories => { s with categories := categories, newEntries := ParserExtensionOleanEntry.parser catName declName prio :: s.newEntries }
|
||||
| _ => unreachable!
|
||||
match e with
|
||||
| ParserExtensionEntry.token tk =>
|
||||
match addTokenConfig s.tokens tk with
|
||||
| Except.ok tokens => { s with tokens := tokens, newEntries := ParserExtensionOleanEntry.token tk :: s.newEntries }
|
||||
| _ => unreachable!
|
||||
| ParserExtensionEntry.kind k =>
|
||||
{ s with kinds := s.kinds.insert k, newEntries := ParserExtensionOleanEntry.kind k :: s.newEntries }
|
||||
| ParserExtensionEntry.category catName leadingIdentAsSymbol =>
|
||||
if s.categories.contains catName then s
|
||||
else { s with
|
||||
categories := s.categories.insert catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol },
|
||||
newEntries := ParserExtensionOleanEntry.category catName leadingIdentAsSymbol :: s.newEntries }
|
||||
| ParserExtensionEntry.parser catName declName leading parser prio =>
|
||||
match addParser s.categories catName declName leading parser prio with
|
||||
| Except.ok categories => { s with categories := categories, newEntries := ParserExtensionOleanEntry.parser catName declName prio :: s.newEntries }
|
||||
| _ => unreachable!
|
||||
|
||||
unsafe def mkParserOfConstantUnsafe
|
||||
(env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name)
|
||||
(compileParserDescr : ParserDescr → Except String Parser)
|
||||
: Except String (Bool × Parser) :=
|
||||
match env.find? constName with
|
||||
| none => throw s!"unknow constant '{constName}'"
|
||||
| some info =>
|
||||
match info.type with
|
||||
| Expr.const `Lean.Parser.TrailingParser _ _ => do
|
||||
let p ← env.evalConst Parser opts constName
|
||||
pure ⟨false, p⟩
|
||||
| Expr.const `Lean.Parser.Parser _ _ => do
|
||||
let p ← env.evalConst Parser opts constName
|
||||
pure ⟨true, p⟩
|
||||
| Expr.const `Lean.ParserDescr _ _ => do
|
||||
let d ← env.evalConst ParserDescr opts constName
|
||||
let p ← compileParserDescr d
|
||||
pure ⟨true, p⟩
|
||||
| Expr.const `Lean.TrailingParserDescr _ _ => do
|
||||
let d ← env.evalConst TrailingParserDescr opts constName
|
||||
let p ← compileParserDescr d
|
||||
pure ⟨false, p⟩
|
||||
| _ => throw s!"unexpected parser type at '{constName}' (`ParserDescr`, `TrailingParserDescr`, `Parser` or `TrailingParser` expected"
|
||||
match env.find? constName with
|
||||
| none => throw s!"unknow constant '{constName}'"
|
||||
| some info =>
|
||||
match info.type with
|
||||
| Expr.const `Lean.Parser.TrailingParser _ _ => do
|
||||
let p ← env.evalConst Parser opts constName
|
||||
pure ⟨false, p⟩
|
||||
| Expr.const `Lean.Parser.Parser _ _ => do
|
||||
let p ← env.evalConst Parser opts constName
|
||||
pure ⟨true, p⟩
|
||||
| Expr.const `Lean.ParserDescr _ _ => do
|
||||
let d ← env.evalConst ParserDescr opts constName
|
||||
let p ← compileParserDescr d
|
||||
pure ⟨true, p⟩
|
||||
| Expr.const `Lean.TrailingParserDescr _ _ => do
|
||||
let d ← env.evalConst TrailingParserDescr opts constName
|
||||
let p ← compileParserDescr d
|
||||
pure ⟨false, p⟩
|
||||
| _ => throw s!"unexpected parser type at '{constName}' (`ParserDescr`, `TrailingParserDescr`, `Parser` or `TrailingParser` expected"
|
||||
|
||||
@[implementedBy mkParserOfConstantUnsafe]
|
||||
constant mkParserOfConstantAux
|
||||
(env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name)
|
||||
(compileParserDescr : ParserDescr → Except String Parser)
|
||||
: Except String (Bool × Parser) :=
|
||||
arbitrary _
|
||||
: Except String (Bool × Parser)
|
||||
|
||||
partial def compileParserDescr (env : Environment) (opts : Options) (categories : ParserCategories) (d : ParserDescr) : Except String Parser :=
|
||||
let rec visit : ParserDescr → Except String Parser
|
||||
| ParserDescr.andthen d₁ d₂ => andthen <$> visit d₁ <*> visit d₂
|
||||
| ParserDescr.orelse d₁ d₂ => orelse <$> visit d₁ <*> visit d₂
|
||||
| ParserDescr.optional d => optional <$> visit d
|
||||
| ParserDescr.lookahead d => lookahead <$> visit d
|
||||
| ParserDescr.«try» d => «try» <$> visit d
|
||||
| ParserDescr.notFollowedBy d => do let p ← visit d; pure $ notFollowedBy p "element" -- TODO allow user to set msg at ParserDescr
|
||||
| ParserDescr.many d => many <$> visit d
|
||||
| ParserDescr.many1 d => many1 <$> visit d
|
||||
| ParserDescr.sepBy d₁ d₂ => sepBy <$> visit d₁ <*> visit d₂
|
||||
| ParserDescr.sepBy1 d₁ d₂ => sepBy1 <$> visit d₁ <*> visit d₂
|
||||
| ParserDescr.node k prec d => leadingNode k prec <$> visit d
|
||||
| ParserDescr.trailingNode k prec d => trailingNode k prec <$> visit d
|
||||
| ParserDescr.symbol tk => pure $ symbol tk
|
||||
| ParserDescr.noWs => pure $ checkNoWsBefore
|
||||
| ParserDescr.numLit => pure $ numLit
|
||||
| ParserDescr.strLit => pure $ strLit
|
||||
| ParserDescr.charLit => pure $ charLit
|
||||
| ParserDescr.nameLit => pure $ nameLit
|
||||
| ParserDescr.interpolatedStr d => interpolatedStr <$> visit d
|
||||
| ParserDescr.ident => pure $ ident
|
||||
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol tk includeIdent
|
||||
| ParserDescr.parser constName => do
|
||||
let (_, p) ← mkParserOfConstantAux env opts categories constName visit;
|
||||
pure p
|
||||
| ParserDescr.cat catName prec =>
|
||||
match getCategory categories catName with
|
||||
| some _ => pure $ categoryParser catName prec
|
||||
| none => throwUnknownParserCategory catName
|
||||
visit d
|
||||
let rec visit : ParserDescr → Except String Parser
|
||||
| ParserDescr.andthen d₁ d₂ => andthen <$> visit d₁ <*> visit d₂
|
||||
| ParserDescr.orelse d₁ d₂ => orelse <$> visit d₁ <*> visit d₂
|
||||
| ParserDescr.optional d => optional <$> visit d
|
||||
| ParserDescr.lookahead d => lookahead <$> visit d
|
||||
| ParserDescr.«try» d => «try» <$> visit d
|
||||
| ParserDescr.notFollowedBy d => do let p ← visit d; pure $ notFollowedBy p "element" -- TODO allow user to set msg at ParserDescr
|
||||
| ParserDescr.many d => many <$> visit d
|
||||
| ParserDescr.many1 d => many1 <$> visit d
|
||||
| ParserDescr.sepBy d₁ d₂ => sepBy <$> visit d₁ <*> visit d₂
|
||||
| ParserDescr.sepBy1 d₁ d₂ => sepBy1 <$> visit d₁ <*> visit d₂
|
||||
| ParserDescr.node k prec d => leadingNode k prec <$> visit d
|
||||
| ParserDescr.trailingNode k prec d => trailingNode k prec <$> visit d
|
||||
| ParserDescr.symbol tk => pure $ symbol tk
|
||||
| ParserDescr.noWs => pure $ checkNoWsBefore
|
||||
| ParserDescr.numLit => pure $ numLit
|
||||
| ParserDescr.strLit => pure $ strLit
|
||||
| ParserDescr.charLit => pure $ charLit
|
||||
| ParserDescr.nameLit => pure $ nameLit
|
||||
| ParserDescr.interpolatedStr d => interpolatedStr <$> visit d
|
||||
| ParserDescr.ident => pure $ ident
|
||||
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol tk includeIdent
|
||||
| ParserDescr.parser constName => do
|
||||
let (_, p) ← mkParserOfConstantAux env opts categories constName visit;
|
||||
pure p
|
||||
| ParserDescr.cat catName prec =>
|
||||
match getCategory categories catName with
|
||||
| some _ => pure $ categoryParser catName prec
|
||||
| none => throwUnknownParserCategory catName
|
||||
visit d
|
||||
|
||||
def mkParserOfConstant (env : Environment) (opts : Options) (categories : ParserCategories) (constName : Name) : Except String (Bool × Parser) :=
|
||||
mkParserOfConstantAux env opts categories constName (compileParserDescr env opts categories)
|
||||
mkParserOfConstantAux env opts categories constName (compileParserDescr env opts categories)
|
||||
|
||||
structure ParserAttributeHook :=
|
||||
/- Called after a parser attribute is applied to a declaration. -/
|
||||
(postAdd (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit)
|
||||
/- Called after a parser attribute is applied to a declaration. -/
|
||||
(postAdd (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit)
|
||||
|
||||
def mkParserAttributeHooks : IO (IO.Ref (List ParserAttributeHook)) := IO.mkRef {}
|
||||
@[builtinInit mkParserAttributeHooks] constant parserAttributeHooks : IO.Ref (List ParserAttributeHook) := arbitrary _
|
||||
builtin_initialize parserAttributeHooks : IO.Ref (List ParserAttributeHook) ← IO.mkRef {}
|
||||
|
||||
def registerParserAttributeHook (hook : ParserAttributeHook) : IO Unit := do
|
||||
parserAttributeHooks.modify fun hooks => hook::hooks
|
||||
parserAttributeHooks.modify fun hooks => hook::hooks
|
||||
|
||||
def runParserAttributeHooks (catName : Name) (declName : Name) (builtin : Bool) : AttrM Unit := do
|
||||
let hooks ← parserAttributeHooks.get
|
||||
hooks.forM fun hook => hook.postAdd catName declName builtin
|
||||
let hooks ← parserAttributeHooks.get
|
||||
hooks.forM fun hook => hook.postAdd catName declName builtin
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinAttribute {
|
||||
|
|
@ -261,27 +259,23 @@ builtin_initialize
|
|||
}
|
||||
|
||||
private def ParserExtension.addImported (es : Array (Array ParserExtensionOleanEntry)) : ImportM ParserExtensionState := do
|
||||
let ctx ← read
|
||||
let s ← ParserExtension.mkInitial
|
||||
es.foldlM
|
||||
(fun s entries =>
|
||||
entries.foldlM
|
||||
(fun s entry =>
|
||||
match entry with
|
||||
| ParserExtensionOleanEntry.token tk => do
|
||||
let tokens ← IO.ofExcept (addTokenConfig s.tokens tk)
|
||||
pure { s with tokens := tokens }
|
||||
| ParserExtensionOleanEntry.kind k =>
|
||||
pure { s with kinds := s.kinds.insert k }
|
||||
| ParserExtensionOleanEntry.category catName leadingIdentAsSymbol => do
|
||||
let categories ← IO.ofExcept (addParserCategoryCore s.categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol})
|
||||
pure { s with categories := categories }
|
||||
| ParserExtensionOleanEntry.parser catName declName prio => do
|
||||
let p ← IO.ofExcept $ mkParserOfConstant ctx.env ctx.opts s.categories declName
|
||||
let categories ← IO.ofExcept $ addParser s.categories catName declName p.1 p.2 prio
|
||||
pure { s with categories := categories })
|
||||
s)
|
||||
s
|
||||
let ctx ← read
|
||||
let s ← ParserExtension.mkInitial
|
||||
es.foldlM (init := s) fun s entries =>
|
||||
entries.foldlM (init := s) fun s entry =>
|
||||
match entry with
|
||||
| ParserExtensionOleanEntry.token tk => do
|
||||
let tokens ← IO.ofExcept (addTokenConfig s.tokens tk)
|
||||
pure { s with tokens := tokens }
|
||||
| ParserExtensionOleanEntry.kind k =>
|
||||
pure { s with kinds := s.kinds.insert k }
|
||||
| ParserExtensionOleanEntry.category catName leadingIdentAsSymbol => do
|
||||
let categories ← IO.ofExcept (addParserCategoryCore s.categories catName { tables := {}, leadingIdentAsSymbol := leadingIdentAsSymbol})
|
||||
pure { s with categories := categories }
|
||||
| ParserExtensionOleanEntry.parser catName declName prio => do
|
||||
let p ← IO.ofExcept $ mkParserOfConstant ctx.env ctx.opts s.categories declName
|
||||
let categories ← IO.ofExcept $ addParser s.categories catName declName p.1 p.2 prio
|
||||
pure { s with categories := categories }
|
||||
|
||||
builtin_initialize parserExtension : ParserExtension ←
|
||||
registerPersistentEnvExtension {
|
||||
|
|
@ -294,31 +288,30 @@ builtin_initialize parserExtension : ParserExtension ←
|
|||
}
|
||||
|
||||
def isParserCategory (env : Environment) (catName : Name) : Bool :=
|
||||
(parserExtension.getState env).categories.contains catName
|
||||
(parserExtension.getState env).categories.contains catName
|
||||
|
||||
def addParserCategory (env : Environment) (catName : Name) (leadingIdentAsSymbol : Bool) : Except String Environment := do
|
||||
if isParserCategory env catName then
|
||||
throwParserCategoryAlreadyDefined catName
|
||||
else
|
||||
pure $ parserExtension.addEntry env $ ParserExtensionEntry.category catName leadingIdentAsSymbol
|
||||
if isParserCategory env catName then
|
||||
throwParserCategoryAlreadyDefined catName
|
||||
else
|
||||
pure $ parserExtension.addEntry env $ ParserExtensionEntry.category catName leadingIdentAsSymbol
|
||||
|
||||
/-
|
||||
Return true if in the given category leading identifiers in parsers may be treated as atoms/symbols.
|
||||
See comment at `ParserCategory`. -/
|
||||
def leadingIdentAsSymbol (env : Environment) (catName : Name) : Bool :=
|
||||
match getCategory (parserExtension.getState env).categories catName with
|
||||
| none => false
|
||||
| some cat => cat.leadingIdentAsSymbol
|
||||
match getCategory (parserExtension.getState env).categories catName with
|
||||
| none => false
|
||||
| some cat => cat.leadingIdentAsSymbol
|
||||
|
||||
def mkCategoryAntiquotParser (kind : Name) : Parser :=
|
||||
mkAntiquot kind.toString none
|
||||
mkAntiquot kind.toString none
|
||||
|
||||
-- helper decl to work around inlining issue https://github.com/leanprover/lean4/commit/3f6de2af06dd9a25f62294129f64bc05a29ea912#r41340377
|
||||
@[inline] private def mkCategoryAntiquotParserFn (kind : Name) : ParserFn :=
|
||||
(mkCategoryAntiquotParser kind).fn
|
||||
|
||||
def categoryParserFnImpl (catName : Name) : ParserFn :=
|
||||
fun ctx s =>
|
||||
def categoryParserFnImpl (catName : Name) : ParserFn := fun ctx s =>
|
||||
let catName := if catName == `syntax then `stx else catName -- temporary Hack
|
||||
let categories := (parserExtension.getState ctx.env).categories
|
||||
match getCategory categories catName with
|
||||
|
|
@ -327,149 +320,152 @@ fun ctx s =>
|
|||
| none => s.mkUnexpectedError ("unknown parser category '" ++ toString catName ++ "'")
|
||||
|
||||
@[builtinInit] def setCategoryParserFnRef : IO Unit :=
|
||||
categoryParserFnRef.set categoryParserFnImpl
|
||||
categoryParserFnRef.set categoryParserFnImpl
|
||||
|
||||
def addToken (env : Environment) (tk : Token) : Except String Environment := do
|
||||
-- Recall that `ParserExtension.addEntry` is pure, and assumes `addTokenConfig` does not fail.
|
||||
-- So, we must run it here to handle exception.
|
||||
addTokenConfig (parserExtension.getState env).tokens tk
|
||||
pure $ parserExtension.addEntry env $ ParserExtensionEntry.token tk
|
||||
-- Recall that `ParserExtension.addEntry` is pure, and assumes `addTokenConfig` does not fail.
|
||||
-- So, we must run it here to handle exception.
|
||||
addTokenConfig (parserExtension.getState env).tokens tk
|
||||
pure $ parserExtension.addEntry env $ ParserExtensionEntry.token tk
|
||||
|
||||
def addSyntaxNodeKind (env : Environment) (k : SyntaxNodeKind) : Environment :=
|
||||
parserExtension.addEntry env $ ParserExtensionEntry.kind k
|
||||
parserExtension.addEntry env $ ParserExtensionEntry.kind k
|
||||
|
||||
def isValidSyntaxNodeKind (env : Environment) (k : SyntaxNodeKind) : Bool :=
|
||||
let kinds := (parserExtension.getState env).kinds
|
||||
kinds.contains k
|
||||
let kinds := (parserExtension.getState env).kinds
|
||||
kinds.contains k
|
||||
|
||||
def getSyntaxNodeKinds (env : Environment) : List SyntaxNodeKind := do
|
||||
let kinds := (parserExtension.getState env).kinds
|
||||
kinds.foldl (fun ks k _ => k::ks) []
|
||||
let kinds := (parserExtension.getState env).kinds
|
||||
kinds.foldl (fun ks k _ => k::ks) []
|
||||
|
||||
def getTokenTable (env : Environment) : TokenTable :=
|
||||
(parserExtension.getState env).tokens
|
||||
(parserExtension.getState env).tokens
|
||||
|
||||
def mkInputContext (input : String) (fileName : String) : InputContext :=
|
||||
{ input := input,
|
||||
def mkInputContext (input : String) (fileName : String) : InputContext := {
|
||||
input := input,
|
||||
fileName := fileName,
|
||||
fileMap := input.toFileMap }
|
||||
fileMap := input.toFileMap
|
||||
}
|
||||
|
||||
def mkParserContext (env : Environment) (ctx : InputContext) : ParserContext :=
|
||||
{ prec := 0,
|
||||
def mkParserContext (env : Environment) (ctx : InputContext) : ParserContext := {
|
||||
prec := 0,
|
||||
toInputContext := ctx,
|
||||
env := env,
|
||||
tokens := getTokenTable env }
|
||||
tokens := getTokenTable env
|
||||
}
|
||||
|
||||
def mkParserState (input : String) : ParserState :=
|
||||
{ cache := initCacheForInput input }
|
||||
{ cache := initCacheForInput input }
|
||||
|
||||
/- convenience function for testing -/
|
||||
def runParserCategory (env : Environment) (catName : Name) (input : String) (fileName := "<input>") : Except String Syntax :=
|
||||
let c := mkParserContext env (mkInputContext input fileName)
|
||||
let s := mkParserState input
|
||||
let s := whitespace c s
|
||||
let s := categoryParserFnImpl catName c s
|
||||
if s.hasError then
|
||||
Except.error (s.toErrorMsg c)
|
||||
else if input.atEnd s.pos then
|
||||
Except.ok s.stxStack.back
|
||||
else
|
||||
Except.error ((s.mkError "end of input").toErrorMsg c)
|
||||
let c := mkParserContext env (mkInputContext input fileName)
|
||||
let s := mkParserState input
|
||||
let s := whitespace c s
|
||||
let s := categoryParserFnImpl catName c s
|
||||
if s.hasError then
|
||||
Except.error (s.toErrorMsg c)
|
||||
else if input.atEnd s.pos then
|
||||
Except.ok s.stxStack.back
|
||||
else
|
||||
Except.error ((s.mkError "end of input").toErrorMsg c)
|
||||
|
||||
def declareBuiltinParser (env : Environment) (addFnName : Name) (catName : Name) (declName : Name) (prio : Nat) : IO Environment :=
|
||||
let name := `_regBuiltinParser ++ declName
|
||||
let type := mkApp (mkConst `IO) (mkConst `Unit)
|
||||
let val := mkAppN (mkConst addFnName) #[toExpr catName, toExpr declName, mkConst declName, mkNatLit prio]
|
||||
let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false }
|
||||
match env.addAndCompile {} decl with
|
||||
-- TODO: pretty print error
|
||||
| Except.error _ => throw (IO.userError ("failed to emit registration code for builtin parser '" ++ toString declName ++ "'"))
|
||||
| Except.ok env => IO.ofExcept (setBuiltinInitAttr env name)
|
||||
let name := `_regBuiltinParser ++ declName
|
||||
let type := mkApp (mkConst `IO) (mkConst `Unit)
|
||||
let val := mkAppN (mkConst addFnName) #[toExpr catName, toExpr declName, mkConst declName, mkNatLit prio]
|
||||
let decl := Declaration.defnDecl { name := name, lparams := [], type := type, value := val, hints := ReducibilityHints.opaque, isUnsafe := false }
|
||||
match env.addAndCompile {} decl with
|
||||
-- TODO: pretty print error
|
||||
| Except.error _ => throw (IO.userError ("failed to emit registration code for builtin parser '" ++ toString declName ++ "'"))
|
||||
| Except.ok env => IO.ofExcept (setBuiltinInitAttr env name)
|
||||
|
||||
def declareLeadingBuiltinParser (env : Environment) (catName : Name) (declName : Name) (prio : Nat) : IO Environment := -- TODO: use CoreM?
|
||||
declareBuiltinParser env `Lean.Parser.addBuiltinLeadingParser catName declName prio
|
||||
declareBuiltinParser env `Lean.Parser.addBuiltinLeadingParser catName declName prio
|
||||
|
||||
def declareTrailingBuiltinParser (env : Environment) (catName : Name) (declName : Name) (prio : Nat) : IO Environment := -- TODO: use CoreM?
|
||||
declareBuiltinParser env `Lean.Parser.addBuiltinTrailingParser catName declName prio
|
||||
declareBuiltinParser env `Lean.Parser.addBuiltinTrailingParser catName declName prio
|
||||
|
||||
def getParserPriority (args : Syntax) : Except String Nat :=
|
||||
match args.getNumArgs with
|
||||
| 0 => pure 0
|
||||
| 1 => match (args.getArg 0).isNatLit? with
|
||||
| some prio => pure prio
|
||||
| none => throw "invalid parser attribute, numeral expected"
|
||||
| _ => throw "invalid parser attribute, no argument or numeral expected"
|
||||
match args.getNumArgs with
|
||||
| 0 => pure 0
|
||||
| 1 => match (args.getArg 0).isNatLit? with
|
||||
| some prio => pure prio
|
||||
| none => throw "invalid parser attribute, numeral expected"
|
||||
| _ => throw "invalid parser attribute, no argument or numeral expected"
|
||||
|
||||
private def BuiltinParserAttribute.add (attrName : Name) (catName : Name)
|
||||
(declName : Name) (args : Syntax) (persistent : Bool) : AttrM Unit := do
|
||||
let prio ← ofExcept (getParserPriority args)
|
||||
unless persistent do throwError! "invalid attribute '{attrName}', must be persistent"
|
||||
let decl ← getConstInfo declName
|
||||
let env ← getEnv
|
||||
match decl.type with
|
||||
| Expr.const `Lean.Parser.TrailingParser _ _ => do
|
||||
let env ← liftIO $ declareTrailingBuiltinParser env catName declName prio
|
||||
setEnv env
|
||||
| Expr.const `Lean.Parser.Parser _ _ => do
|
||||
let env ← liftIO $ declareLeadingBuiltinParser env catName declName prio
|
||||
setEnv env
|
||||
| _ => throwError! "unexpected parser type at '{declName}' (`Parser` or `TrailingParser` expected)"
|
||||
runParserAttributeHooks catName declName (builtin := true)
|
||||
let prio ← ofExcept (getParserPriority args)
|
||||
unless persistent do throwError! "invalid attribute '{attrName}', must be persistent"
|
||||
let decl ← getConstInfo declName
|
||||
let env ← getEnv
|
||||
match decl.type with
|
||||
| Expr.const `Lean.Parser.TrailingParser _ _ => do
|
||||
let env ← liftIO $ declareTrailingBuiltinParser env catName declName prio
|
||||
setEnv env
|
||||
| Expr.const `Lean.Parser.Parser _ _ => do
|
||||
let env ← liftIO $ declareLeadingBuiltinParser env catName declName prio
|
||||
setEnv env
|
||||
| _ => throwError! "unexpected parser type at '{declName}' (`Parser` or `TrailingParser` expected)"
|
||||
runParserAttributeHooks catName declName (builtin := true)
|
||||
|
||||
/-
|
||||
The parsing tables for builtin parsers are "stored" in the extracted source code.
|
||||
-/
|
||||
def registerBuiltinParserAttribute (attrName : Name) (catName : Name) (leadingIdentAsSymbol := false) : IO Unit := do
|
||||
addBuiltinParserCategory catName leadingIdentAsSymbol
|
||||
registerBuiltinAttribute {
|
||||
name := attrName,
|
||||
descr := "Builtin parser",
|
||||
add := fun declName args persistent => liftM $ BuiltinParserAttribute.add attrName catName declName args persistent,
|
||||
applicationTime := AttributeApplicationTime.afterCompilation
|
||||
}
|
||||
addBuiltinParserCategory catName leadingIdentAsSymbol
|
||||
registerBuiltinAttribute {
|
||||
name := attrName,
|
||||
descr := "Builtin parser",
|
||||
add := fun declName args persistent => liftM $ BuiltinParserAttribute.add attrName catName declName args persistent,
|
||||
applicationTime := AttributeApplicationTime.afterCompilation
|
||||
}
|
||||
|
||||
private def ParserAttribute.add (attrName : Name) (catName : Name) (declName : Name) (args : Syntax) (persistent : Bool) : AttrM Unit := do
|
||||
let prio ← ofExcept (getParserPriority args)
|
||||
let env ← getEnv
|
||||
let opts ← getOptions
|
||||
let categories := (parserExtension.getState env).categories
|
||||
match mkParserOfConstant env opts categories declName with
|
||||
| Except.error ex => throwError ex
|
||||
| Except.ok p => do
|
||||
let leading := p.1
|
||||
let parser := p.2
|
||||
let tokens := parser.info.collectTokens []
|
||||
tokens.forM fun token => do
|
||||
env ← getEnv
|
||||
match addToken env token with
|
||||
| Except.ok env => setEnv env
|
||||
| Except.error msg => throwError! "invalid parser '{declName}', {msg}"
|
||||
let kinds := parser.info.collectKinds {}
|
||||
kinds.forM fun kind _ => modifyEnv fun env => addSyntaxNodeKind env kind
|
||||
match addParser categories catName declName leading parser prio with
|
||||
| Except.ok _ => modifyEnv fun env => parserExtension.addEntry env $ ParserExtensionEntry.parser catName declName leading parser prio
|
||||
let prio ← ofExcept (getParserPriority args)
|
||||
let env ← getEnv
|
||||
let opts ← getOptions
|
||||
let categories := (parserExtension.getState env).categories
|
||||
match mkParserOfConstant env opts categories declName with
|
||||
| Except.error ex => throwError ex
|
||||
runParserAttributeHooks catName declName /- builtin -/ false
|
||||
| Except.ok p => do
|
||||
let leading := p.1
|
||||
let parser := p.2
|
||||
let tokens := parser.info.collectTokens []
|
||||
tokens.forM fun token => do
|
||||
env ← getEnv
|
||||
match addToken env token with
|
||||
| Except.ok env => setEnv env
|
||||
| Except.error msg => throwError! "invalid parser '{declName}', {msg}"
|
||||
let kinds := parser.info.collectKinds {}
|
||||
kinds.forM fun kind _ => modifyEnv fun env => addSyntaxNodeKind env kind
|
||||
match addParser categories catName declName leading parser prio with
|
||||
| Except.ok _ => modifyEnv fun env => parserExtension.addEntry env $ ParserExtensionEntry.parser catName declName leading parser prio
|
||||
| Except.error ex => throwError ex
|
||||
runParserAttributeHooks catName declName /- builtin -/ false
|
||||
|
||||
def mkParserAttributeImpl (attrName : Name) (catName : Name) : AttributeImpl :=
|
||||
{ name := attrName,
|
||||
def mkParserAttributeImpl (attrName : Name) (catName : Name) : AttributeImpl := {
|
||||
name := attrName,
|
||||
descr := "parser",
|
||||
add := fun declName args persistent => liftM $ ParserAttribute.add attrName catName declName args persistent,
|
||||
applicationTime := AttributeApplicationTime.afterCompilation }
|
||||
applicationTime := AttributeApplicationTime.afterCompilation
|
||||
}
|
||||
|
||||
/- A builtin parser attribute that can be extended by users. -/
|
||||
def registerBuiltinDynamicParserAttribute (attrName : Name) (catName : Name) : IO Unit := do
|
||||
registerBuiltinAttribute (mkParserAttributeImpl attrName catName)
|
||||
registerBuiltinAttribute (mkParserAttributeImpl attrName catName)
|
||||
|
||||
@[builtinInit] private def registerParserAttributeImplBuilder : IO Unit :=
|
||||
registerAttributeImplBuilder `parserAttr fun args =>
|
||||
match args with
|
||||
| [DataValue.ofName attrName, DataValue.ofName catName] => pure $ mkParserAttributeImpl attrName catName
|
||||
| _ => throw "invalid parser attribute implementation builder arguments"
|
||||
registerAttributeImplBuilder `parserAttr fun args =>
|
||||
match args with
|
||||
| [DataValue.ofName attrName, DataValue.ofName catName] => pure $ mkParserAttributeImpl attrName catName
|
||||
| _ => throw "invalid parser attribute implementation builder arguments"
|
||||
|
||||
def registerParserCategory (env : Environment) (attrName : Name) (catName : Name) (leadingIdentAsSymbol := false) : IO Environment := do
|
||||
env ← IO.ofExcept $ addParserCategory env catName leadingIdentAsSymbol
|
||||
registerAttributeOfBuilder env `parserAttr [DataValue.ofName attrName, DataValue.ofName catName]
|
||||
let env ← IO.ofExcept $ addParserCategory env catName leadingIdentAsSymbol
|
||||
registerAttributeOfBuilder env `parserAttr [DataValue.ofName attrName, DataValue.ofName catName]
|
||||
|
||||
-- declare `termParser` here since it is used everywhere via antiquotations
|
||||
|
||||
|
|
@ -483,10 +479,9 @@ builtin_initialize registerBuiltinParserAttribute `builtinCommandParser `command
|
|||
builtin_initialize registerBuiltinDynamicParserAttribute `commandParser `command
|
||||
|
||||
@[inline] def commandParser (rbp : Nat := 0) : Parser :=
|
||||
categoryParser `command rbp
|
||||
categoryParser `command rbp
|
||||
|
||||
def notFollowedByCategoryTokenFn (catName : Name) : ParserFn :=
|
||||
fun ctx s =>
|
||||
def notFollowedByCategoryTokenFn (catName : Name) : ParserFn := fun ctx s =>
|
||||
let categories := (parserExtension.getState ctx.env).categories
|
||||
match getCategory categories catName with
|
||||
| none => s.mkUnexpectedError s!"unknown parser category '{catName}'"
|
||||
|
|
@ -500,14 +495,15 @@ fun ctx s =>
|
|||
| _ => s
|
||||
| _ => s
|
||||
|
||||
@[inline] def notFollowedByCategoryToken (catName : Name) : Parser :=
|
||||
{ fn := notFollowedByCategoryTokenFn catName }
|
||||
@[inline] def notFollowedByCategoryToken (catName : Name) : Parser := {
|
||||
fn := notFollowedByCategoryTokenFn catName
|
||||
}
|
||||
|
||||
abbrev notFollowedByCommandToken : Parser :=
|
||||
notFollowedByCategoryToken `command
|
||||
notFollowedByCategoryToken `command
|
||||
|
||||
abbrev notFollowedByTermToken : Parser :=
|
||||
notFollowedByCategoryToken `term
|
||||
notFollowedByCategoryToken `term
|
||||
|
||||
end Parser
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -23,67 +23,67 @@ def ctx (interp) : ParserCompiler.Context Parenthesizer :=
|
|||
⟨`parenthesizer, parenthesizerAttribute, combinatorParenthesizerAttribute, interp⟩
|
||||
|
||||
unsafe def interpretParserDescr : ParserDescr → AttrM Parenthesizer
|
||||
| ParserDescr.andthen d₁ d₂ => andthen.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.orelse d₁ d₂ => orelse.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.optional d => optional.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.lookahead d => lookahead.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.try d => try.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.notFollowedBy d => notFollowedBy.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.many d => many.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.many1 d => many1.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.sepBy d₁ d₂ => sepBy.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.sepBy1 d₁ d₂ => sepBy1.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.node k prec d => leadingNode.parenthesizer k prec <$> interpretParserDescr d
|
||||
| ParserDescr.trailingNode k prec d => trailingNode.parenthesizer k prec <$> interpretParserDescr d
|
||||
| ParserDescr.symbol tk => pure $ symbol.parenthesizer tk
|
||||
| ParserDescr.numLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "numLit" `numLit) numLitNoAntiquot.parenthesizer
|
||||
| ParserDescr.strLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "strLit" `strLit) strLitNoAntiquot.parenthesizer
|
||||
| ParserDescr.charLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "charLit" `charLit) charLitNoAntiquot.parenthesizer
|
||||
| ParserDescr.nameLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "nameLit" `nameLit) nameLitNoAntiquot.parenthesizer
|
||||
| ParserDescr.ident => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "ident" `ident) identNoAntiquot.parenthesizer
|
||||
| ParserDescr.interpolatedStr d => interpolatedStr.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.parenthesizer tk includeIdent
|
||||
| ParserDescr.noWs => pure $ checkNoWsBefore.parenthesizer
|
||||
| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName
|
||||
| ParserDescr.cat catName prec => pure $ categoryParser.parenthesizer catName prec
|
||||
| ParserDescr.andthen d₁ d₂ => andthen.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.orelse d₁ d₂ => orelse.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.optional d => optional.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.lookahead d => lookahead.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.try d => try.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.notFollowedBy d => notFollowedBy.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.many d => many.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.many1 d => many1.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.sepBy d₁ d₂ => sepBy.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.sepBy1 d₁ d₂ => sepBy1.parenthesizer <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.node k prec d => leadingNode.parenthesizer k prec <$> interpretParserDescr d
|
||||
| ParserDescr.trailingNode k prec d => trailingNode.parenthesizer k prec <$> interpretParserDescr d
|
||||
| ParserDescr.symbol tk => pure $ symbol.parenthesizer tk
|
||||
| ParserDescr.numLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "numLit" `numLit) numLitNoAntiquot.parenthesizer
|
||||
| ParserDescr.strLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "strLit" `strLit) strLitNoAntiquot.parenthesizer
|
||||
| ParserDescr.charLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "charLit" `charLit) charLitNoAntiquot.parenthesizer
|
||||
| ParserDescr.nameLit => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "nameLit" `nameLit) nameLitNoAntiquot.parenthesizer
|
||||
| ParserDescr.ident => pure $ withAntiquot.parenthesizer (mkAntiquot.parenthesizer' "ident" `ident) identNoAntiquot.parenthesizer
|
||||
| ParserDescr.interpolatedStr d => interpolatedStr.parenthesizer <$> interpretParserDescr d
|
||||
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.parenthesizer tk includeIdent
|
||||
| ParserDescr.noWs => pure $ checkNoWsBefore.parenthesizer
|
||||
| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName
|
||||
| ParserDescr.cat catName prec => pure $ categoryParser.parenthesizer catName prec
|
||||
|
||||
@[builtinInit] unsafe def regHook : IO Unit :=
|
||||
ParserCompiler.registerParserCompiler (ctx interpretParserDescr)
|
||||
ParserCompiler.registerParserCompiler (ctx interpretParserDescr)
|
||||
|
||||
end Parenthesizer
|
||||
|
||||
namespace Formatter
|
||||
|
||||
def ctx (interp) : ParserCompiler.Context Formatter :=
|
||||
⟨`formatter, formatterAttribute, combinatorFormatterAttribute, interp⟩
|
||||
⟨`formatter, formatterAttribute, combinatorFormatterAttribute, interp⟩
|
||||
|
||||
unsafe def interpretParserDescr : ParserDescr → AttrM Formatter
|
||||
| ParserDescr.andthen d₁ d₂ => andthen.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.orelse d₁ d₂ => orelse.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.optional d => optional.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.lookahead d => lookahead.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.try d => try.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.notFollowedBy d => notFollowedBy.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.many d => many.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.many1 d => many1.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.sepBy d₁ d₂ => sepBy.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.sepBy1 d₁ d₂ => sepBy1.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.node k prec d => node.formatter k <$> interpretParserDescr d
|
||||
| ParserDescr.trailingNode k prec d => trailingNode.formatter k prec <$> interpretParserDescr d
|
||||
| ParserDescr.symbol tk => pure $ symbol.formatter tk
|
||||
| ParserDescr.numLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "numLit" `numLit) numLitNoAntiquot.formatter
|
||||
| ParserDescr.strLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "strLit" `strLit) strLitNoAntiquot.formatter
|
||||
| ParserDescr.charLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "charLit" `charLit) charLitNoAntiquot.formatter
|
||||
| ParserDescr.nameLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "nameLit" `nameLit) nameLitNoAntiquot.formatter
|
||||
| ParserDescr.interpolatedStr d => interpolatedStr.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.ident => pure $ withAntiquot.formatter (mkAntiquot.formatter' "ident" `ident) identNoAntiquot.formatter
|
||||
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.formatter tk
|
||||
| ParserDescr.noWs => pure $ checkNoWsBefore.formatter
|
||||
| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName
|
||||
| ParserDescr.cat catName prec => pure $ categoryParser.formatter catName
|
||||
| ParserDescr.andthen d₁ d₂ => andthen.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.orelse d₁ d₂ => orelse.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.optional d => optional.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.lookahead d => lookahead.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.try d => try.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.notFollowedBy d => notFollowedBy.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.many d => many.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.many1 d => many1.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.sepBy d₁ d₂ => sepBy.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.sepBy1 d₁ d₂ => sepBy1.formatter <$> interpretParserDescr d₁ <*> interpretParserDescr d₂
|
||||
| ParserDescr.node k prec d => node.formatter k <$> interpretParserDescr d
|
||||
| ParserDescr.trailingNode k prec d => trailingNode.formatter k prec <$> interpretParserDescr d
|
||||
| ParserDescr.symbol tk => pure $ symbol.formatter tk
|
||||
| ParserDescr.numLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "numLit" `numLit) numLitNoAntiquot.formatter
|
||||
| ParserDescr.strLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "strLit" `strLit) strLitNoAntiquot.formatter
|
||||
| ParserDescr.charLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "charLit" `charLit) charLitNoAntiquot.formatter
|
||||
| ParserDescr.nameLit => pure $ withAntiquot.formatter (mkAntiquot.formatter' "nameLit" `nameLit) nameLitNoAntiquot.formatter
|
||||
| ParserDescr.interpolatedStr d => interpolatedStr.formatter <$> interpretParserDescr d
|
||||
| ParserDescr.ident => pure $ withAntiquot.formatter (mkAntiquot.formatter' "ident" `ident) identNoAntiquot.formatter
|
||||
| ParserDescr.nonReservedSymbol tk includeIdent => pure $ nonReservedSymbol.formatter tk
|
||||
| ParserDescr.noWs => pure $ checkNoWsBefore.formatter
|
||||
| ParserDescr.parser constName => interpretParser (ctx interpretParserDescr) constName
|
||||
| ParserDescr.cat catName prec => pure $ categoryParser.formatter catName
|
||||
|
||||
@[builtinInit] unsafe def regHook : IO Unit :=
|
||||
ParserCompiler.registerParserCompiler (ctx interpretParserDescr)
|
||||
ParserCompiler.registerParserCompiler (ctx interpretParserDescr)
|
||||
|
||||
end Formatter
|
||||
|
||||
|
|
|
|||
|
|
@ -66,15 +66,15 @@ end ReplaceLevelImpl
|
|||
|
||||
@[implementedBy ReplaceLevelImpl.replaceUnsafe]
|
||||
partial def replaceLevel (f? : Level → Option Level) : Expr → Expr
|
||||
| e@(Expr.forallE _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateForallE! d b
|
||||
| e@(Expr.lam _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateLambdaE! d b
|
||||
| e@(Expr.mdata _ b _) => let b := replaceLevel f? b; e.updateMData! b
|
||||
| e@(Expr.letE _ t v b _) => let t := replaceLevel f? t; let v := replaceLevel f? v; let b := replaceLevel f? b; e.updateLet! t v b
|
||||
| e@(Expr.app f a _) => let f := replaceLevel f? f; let a := replaceLevel f? a; e.updateApp! f a
|
||||
| e@(Expr.proj _ _ b _) => let b := replaceLevel f? b; e.updateProj! b
|
||||
| e@(Expr.sort u _) => e.updateSort! (u.replace f?)
|
||||
| e@(Expr.const n us _) => e.updateConst! (us.map (Level.replace f?))
|
||||
| e => e
|
||||
| e@(Expr.forallE _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateForallE! d b
|
||||
| e@(Expr.lam _ d b _) => let d := replaceLevel f? d; let b := replaceLevel f? b; e.updateLambdaE! d b
|
||||
| e@(Expr.mdata _ b _) => let b := replaceLevel f? b; e.updateMData! b
|
||||
| e@(Expr.letE _ t v b _) => let t := replaceLevel f? t; let v := replaceLevel f? v; let b := replaceLevel f? b; e.updateLet! t v b
|
||||
| e@(Expr.app f a _) => let f := replaceLevel f? f; let a := replaceLevel f? a; e.updateApp! f a
|
||||
| e@(Expr.proj _ _ b _) => let b := replaceLevel f? b; e.updateProj! b
|
||||
| e@(Expr.sort u _) => e.updateSort! (u.replace f?)
|
||||
| e@(Expr.const n us _) => e.updateConst! (us.map (Level.replace f?))
|
||||
| e => e
|
||||
|
||||
end Expr
|
||||
end Lean
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue