feat(frontends/lean/builtin_exprs): use ; in do-notation
This commit is contained in:
parent
ab487ea4ac
commit
91e1d30cf8
35 changed files with 589 additions and 582 deletions
|
|
@ -24,7 +24,7 @@ if c then pure () else e
|
|||
|
||||
@[macroInline]
|
||||
def mcond {m : Type → Type u} [Monad m] {α : Type} (mbool : m Bool) (tm fm : m α) : m α :=
|
||||
do b ← mbool, cond b tm fm
|
||||
do b ← mbool; cond b tm fm
|
||||
|
||||
@[macroInline]
|
||||
def mwhen {m : Type → Type u} [Monad m] (c : m Bool) (t : m Unit) : m Unit :=
|
||||
|
|
@ -70,20 +70,20 @@ def mfor {m : Type u → Type v} [Applicative m] {α : Type w} {β : Type u} (f
|
|||
@[specialize]
|
||||
def mfilter {m : Type → Type v} [Monad m] {α : Type} (f : α → m Bool) : List α → m (List α)
|
||||
| [] := pure []
|
||||
| (h :: t) := do b ← f h, t' ← mfilter t, cond b (pure (h :: t')) (pure t')
|
||||
| (h :: t) := do b ← f h; t' ← mfilter t; cond b (pure (h :: t')) (pure t')
|
||||
|
||||
@[specialize]
|
||||
def mfoldl {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (s → α → m s) → s → List α → m s
|
||||
| f s [] := pure s
|
||||
| f s (h :: r) := do
|
||||
s' ← f s h,
|
||||
s' ← f s h;
|
||||
mfoldl f s' r
|
||||
|
||||
@[specialize]
|
||||
def mfoldr {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (α → s → m s) → s → List α → m s
|
||||
| f s [] := pure s
|
||||
| f s (h :: r) := do
|
||||
s' ← mfoldr f s r,
|
||||
s' ← mfoldr f s r;
|
||||
f h s'
|
||||
|
||||
@[specialize]
|
||||
|
|
@ -94,14 +94,14 @@ def mfirst {m : Type u → Type v} [Monad m] [Alternative m] {α : Type w} {β :
|
|||
@[specialize]
|
||||
def mexists {m : Type → Type u} [Monad m] {α : Type v} (f : α → m Bool) : List α → m Bool
|
||||
| [] := pure false
|
||||
| (a::as) := do b ← f a, match b with
|
||||
| (a::as) := do b ← f a; match b with
|
||||
| true := pure true
|
||||
| false := mexists as
|
||||
|
||||
@[specialize]
|
||||
def mforall {m : Type → Type u} [Monad m] {α : Type v} (f : α → m Bool) : List α → m Bool
|
||||
| [] := pure true
|
||||
| (a::as) := do b ← f a, match b with
|
||||
| (a::as) := do b ← f a; match b with
|
||||
| true := mforall as
|
||||
| false := pure false
|
||||
|
||||
|
|
|
|||
|
|
@ -21,13 +21,13 @@ match toBool b with
|
|||
| false := f
|
||||
|
||||
@[macroInline] def orM {m : Type u → Type v} {β : Type u} [Monad m] [HasToBool β] (x y : m β) : m β :=
|
||||
do b ← x,
|
||||
do b ← x;
|
||||
match toBool b with
|
||||
| true := pure b
|
||||
| false := y
|
||||
|
||||
@[macroInline] def andM {m : Type u → Type v} {β : Type u} [Monad m] [HasToBool β] (x y : m β) : m β :=
|
||||
do b ← x,
|
||||
do b ← x;
|
||||
match toBool b with
|
||||
| true := y
|
||||
| false := pure b
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ namespace OptionT
|
|||
{ pure := @OptionT.pure _ _, bind := @OptionT.bind _ _ }
|
||||
|
||||
protected def orelse (ma : OptionT m α) (mb : OptionT m α) : OptionT m α :=
|
||||
(do { some a ← ma | mb,
|
||||
(do { some a ← ma | mb;
|
||||
pure (some a) } : m (Option α))
|
||||
|
||||
@[inline] protected def fail : OptionT m α :=
|
||||
|
|
@ -55,7 +55,7 @@ namespace OptionT
|
|||
⟨λ α, OptionT.monadMap⟩
|
||||
|
||||
protected def catch (ma : OptionT m α) (handle : Unit → OptionT m α) : OptionT m α :=
|
||||
(do { some a ← ma | (handle ()),
|
||||
(do { some a ← ma | (handle ());
|
||||
pure a } : m (Option α))
|
||||
|
||||
instance : MonadExcept Unit (OptionT m) :=
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ pure
|
|||
λ r, pure a
|
||||
|
||||
@[inline] protected def bind (x : ReaderT ρ m α) (f : α → ReaderT ρ m β) : ReaderT ρ m β :=
|
||||
λ r, do a ← x r, f a r
|
||||
λ r, do a ← x r; f a r
|
||||
|
||||
@[inline] protected def map (f : α → β) (x : ReaderT ρ m α) : ReaderT ρ m β :=
|
||||
λ r, f <$> x r
|
||||
|
|
|
|||
|
|
@ -30,10 +30,10 @@ variables [Monad m] {α β : Type u}
|
|||
λ s, pure (a, s)
|
||||
|
||||
@[inline] protected def bind (x : StateT σ m α) (f : α → StateT σ m β) : StateT σ m β :=
|
||||
λ s, do (a, s) ← x s, f a s
|
||||
λ s, do (a, s) ← x s; f a s
|
||||
|
||||
@[inline] protected def map (f : α → β) (x : StateT σ m α) : StateT σ m β :=
|
||||
λ s, do (a, s) ← x s, pure (f a, s)
|
||||
λ s, do (a, s) ← x s; pure (f a, s)
|
||||
|
||||
instance : Monad (StateT σ m) :=
|
||||
{ pure := @StateT.pure _ _ _, bind := @StateT.bind _ _ _, map := @StateT.map _ _ _ }
|
||||
|
|
@ -59,7 +59,7 @@ instance [Alternative m] : Alternative (StateT σ m) :=
|
|||
λ s, pure (⟨⟩, f s)
|
||||
|
||||
@[inline] protected def lift {α : Type u} (t : m α) : StateT σ m α :=
|
||||
λ s, do a ← t, pure (a, s)
|
||||
λ s, do a ← t; pure (a, s)
|
||||
|
||||
instance : HasMonadLift m (StateT σ m) :=
|
||||
⟨@StateT.lift σ m _⟩
|
||||
|
|
@ -70,8 +70,8 @@ instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateT σ m) (State
|
|||
@[inline] protected def adapt {σ σ' σ'' α : Type u} {m : Type u → Type v} [Monad m] (split : σ → σ' × σ'')
|
||||
(join : σ' → σ'' → σ) (x : StateT σ' m α) : StateT σ m α :=
|
||||
λ st, do
|
||||
let (st, ctx) := split st,
|
||||
(a, st') ← x st,
|
||||
let (st, ctx) := split st;
|
||||
(a, st') ← x st;
|
||||
pure (a, join st' ctx)
|
||||
|
||||
instance (ε) [MonadExcept ε m] : MonadExcept ε (StateT σ m) :=
|
||||
|
|
@ -100,7 +100,7 @@ section
|
|||
variables {σ : Type u} {m : Type u → Type v}
|
||||
|
||||
@[inline] def getModify [MonadState σ m] [Monad m] (f : σ → σ) : m σ :=
|
||||
do s ← get, modify f, pure s
|
||||
do s ← get; modify f; pure s
|
||||
|
||||
-- NOTE: The Ordering of the following two instances determines that the top-most `StateT` Monad layer
|
||||
-- will be picked first
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ miterate₂ a₁ a₂ b (λ _ a₁ a₂ b, f b a₁ a₂)
|
|||
| i :=
|
||||
if h : i < a.size then
|
||||
let idx : Fin a.size := ⟨i, h⟩;
|
||||
do r ← f (a.fget idx),
|
||||
do r ← f (a.fget idx);
|
||||
match r with
|
||||
| some v := pure r
|
||||
| none := mfindAux (i+1)
|
||||
|
|
@ -216,7 +216,7 @@ variables {m : Type → Type v} [Monad m]
|
|||
| i :=
|
||||
if h : i < a.size then
|
||||
let idx : Fin a.size := ⟨i, h⟩;
|
||||
do b ← p (a.fget idx),
|
||||
do b ← p (a.fget idx);
|
||||
match b with
|
||||
| true := pure true
|
||||
| false := anyMAux (i+1)
|
||||
|
|
@ -266,7 +266,7 @@ variables {m : Type v → Type v} [Monad m]
|
|||
let idx : Fin a.size := ⟨i, h⟩;
|
||||
let v : α := a.fget idx;
|
||||
let a := a.fset idx (@unsafeCast _ _ ⟨v⟩ ());
|
||||
do newV ← f i v, ummapAux (i+1) (a.fset idx (@unsafeCast _ _ ⟨v⟩ newV))
|
||||
do newV ← f i v; ummapAux (i+1) (a.fset idx (@unsafeCast _ _ ⟨v⟩ newV))
|
||||
else
|
||||
pure (unsafeCast a)
|
||||
|
||||
|
|
@ -277,10 +277,10 @@ ummapAux (λ i a, f a) 0 as
|
|||
ummapAux f 0 as
|
||||
|
||||
@[implementedBy Array.ummap] def mmap (f : α → m β) (as : Array α) : m (Array β) :=
|
||||
as.mfoldl (λ bs a, do b ← f a, pure (bs.push b)) (mkEmpty as.size)
|
||||
as.mfoldl (λ bs a, do b ← f a; pure (bs.push b)) (mkEmpty as.size)
|
||||
|
||||
@[implementedBy Array.ummapIdx] def mmapIdx (f : Nat → α → m β) (as : Array α) : m (Array β) :=
|
||||
as.miterate (mkEmpty as.size) (λ i a bs, do b ← f i.val a, pure (bs.push b))
|
||||
as.miterate (mkEmpty as.size) (λ i a bs, do b ← f i.val a; pure (bs.push b))
|
||||
end
|
||||
|
||||
@[inline] def modify [Inhabited α] (a : Array α) (i : Nat) (f : α → α) : Array α :=
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ variables {α : Type u} {β : Type v} {δ : Type w} {m : Type w → Type w} [Mon
|
|||
|
||||
@[specialize] def mfoldl (f : δ → α → β → m δ) : δ → AssocList α β → m δ
|
||||
| d nil := pure d
|
||||
| d (cons a b es) := do d ← f d a b, mfoldl d es
|
||||
| d (cons a b es) := do d ← f d a b; mfoldl d es
|
||||
|
||||
@[inline] def foldl (f : δ → α → β → δ) (d : δ) (as : AssocList α β) : δ :=
|
||||
Id.run (mfoldl f d as)
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ variables {m : Type v → Type v} [Monad m]
|
|||
| (leaf vs) b := vs.mfoldl f b
|
||||
|
||||
@[specialize] def mfoldl (f : β → α → m β) (b : β) (t : PersistentArray α) : m β :=
|
||||
do b ← mfoldlAux f t.root b, t.tail.mfoldl f b
|
||||
do b ← mfoldlAux f t.root b; t.tail.mfoldl f b
|
||||
|
||||
end
|
||||
|
||||
|
|
@ -152,8 +152,8 @@ variables {m : Type v → Type v} [Monad m]
|
|||
|
||||
@[specialize] def mmap (f : α → m β) (t : PersistentArray α) : m (PersistentArray β) :=
|
||||
do
|
||||
root ← mmapAux f t.root,
|
||||
tail ← t.tail.mmap f,
|
||||
root ← mmapAux f t.root;
|
||||
tail ← t.tail.mmap f;
|
||||
pure { tail := tail, root := root, .. t }
|
||||
|
||||
end
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ def IO.setRandSeed (n : Nat) : IO Unit :=
|
|||
IO.stdGenRef.set (mkStdGen n)
|
||||
|
||||
def IO.rand (lo hi : Nat) : IO Nat :=
|
||||
do gen ← IO.stdGenRef.get,
|
||||
let (r, gen) := randNat gen lo hi,
|
||||
IO.stdGenRef.set gen,
|
||||
do gen ← IO.stdGenRef.get;
|
||||
let (r, gen) := randNat gen lo hi;
|
||||
IO.stdGenRef.set gen;
|
||||
pure r
|
||||
|
|
|
|||
|
|
@ -41,8 +41,8 @@ protected def max : RBNode α β → Option (Sigma (λ k : α, β k))
|
|||
@[specialize] def mfold {m : Type w → Type w'} [Monad m] (f : σ → Π (k : α), β k → m σ) : σ → RBNode α β → m σ
|
||||
| b leaf := pure b
|
||||
| b (node _ l k v r) := do
|
||||
b ← mfold b l,
|
||||
b ← f b k v,
|
||||
b ← mfold b l;
|
||||
b ← f b k v;
|
||||
mfold b r
|
||||
|
||||
@[specialize] def revFold (f : σ → Π (k : α), β k → σ) : σ → RBNode α β → σ
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ open Fs
|
|||
|
||||
@[specialize] partial def iterate {α β : Type} : α → (α → IO (Sum α β)) → IO β
|
||||
| a f :=
|
||||
do v ← f a,
|
||||
do v ← f a;
|
||||
match v with
|
||||
| Sum.inl a := iterate a f
|
||||
| Sum.inr b := pure b
|
||||
|
|
@ -154,18 +154,18 @@ do b ← h.read 1,
|
|||
|
||||
def handle.readToEnd (h : handle) : m String :=
|
||||
Prim.liftIO $ Prim.iterate "" $ λ r, do
|
||||
done ← h.isEof,
|
||||
done ← h.isEof;
|
||||
if done
|
||||
then pure (Sum.inr r) -- stop
|
||||
else do
|
||||
-- HACK: use less efficient `getLine` while `read` is broken
|
||||
c ← h.getLine,
|
||||
c ← h.getLine;
|
||||
pure $ Sum.inl (r ++ c) -- continue
|
||||
|
||||
def readFile (fname : String) (bin := false) : m String :=
|
||||
do h ← handle.mk fname Mode.read bin,
|
||||
r ← h.readToEnd,
|
||||
h.close,
|
||||
do h ← handle.mk fname Mode.read bin;
|
||||
r ← h.readToEnd;
|
||||
h.close;
|
||||
pure r
|
||||
|
||||
-- def writeFile (fname : String) (data : String) (bin := false) : m Unit :=
|
||||
|
|
@ -229,8 +229,8 @@ variables {m : Type → Type} [Monad m] [monadIO m]
|
|||
@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := Prim.liftIO (Prim.Ref.swap r a)
|
||||
@[inline] def Ref.reset {α : Type} (r : Ref α) : m Unit := Prim.liftIO (Prim.Ref.reset r)
|
||||
@[inline] def Ref.modify {α : Type} (r : Ref α) (f : α → α) : m Unit :=
|
||||
do v ← r.get,
|
||||
r.reset,
|
||||
do v ← r.get;
|
||||
r.reset;
|
||||
r.set (f v)
|
||||
end
|
||||
end IO
|
||||
|
|
@ -259,7 +259,7 @@ instance HasRepr.HasEval {α : Type u} [HasRepr α] : HasEval α :=
|
|||
⟨λ a, IO.println (repr a)⟩
|
||||
|
||||
instance IO.HasEval {α : Type} [HasEval α] : HasEval (IO α) :=
|
||||
⟨λ x, do a ← x, HasEval.eval a⟩
|
||||
⟨λ x, do a ← x; HasEval.eval a⟩
|
||||
|
||||
-- special case: do not print `()`
|
||||
instance IOUnit.HasEval : HasEval (IO Unit) :=
|
||||
|
|
|
|||
|
|
@ -145,31 +145,31 @@ constant attributeArrayRef : IO.Ref (Array AttributeImpl) := default _
|
|||
|
||||
/- Low level attribute registration function. -/
|
||||
def registerAttribute (attr : AttributeImpl) : IO Unit :=
|
||||
do m ← attributeMapRef.get,
|
||||
when (m.contains attr.name) $ throw (IO.userError ("invalid attribute declaration, '" ++ toString attr.name ++ "' has already been used")),
|
||||
initializing ← IO.initializing,
|
||||
unless initializing $ throw (IO.userError ("failed to register attribute, attributes can only be registered during initialization")),
|
||||
attributeMapRef.modify (λ m, m.insert attr.name attr),
|
||||
do m ← attributeMapRef.get;
|
||||
when (m.contains attr.name) $ throw (IO.userError ("invalid attribute declaration, '" ++ toString attr.name ++ "' has already been used"));
|
||||
initializing ← IO.initializing;
|
||||
unless initializing $ throw (IO.userError ("failed to register attribute, attributes can only be registered during initialization"));
|
||||
attributeMapRef.modify (λ m, m.insert attr.name attr);
|
||||
attributeArrayRef.modify (λ attrs, attrs.push attr)
|
||||
|
||||
/- Return true iff `n` is the name of a registered attribute. -/
|
||||
@[export lean.is_attribute_core]
|
||||
def isAttribute (n : Name) : IO Bool :=
|
||||
do m ← attributeMapRef.get, pure (m.contains n)
|
||||
do m ← attributeMapRef.get; pure (m.contains n)
|
||||
|
||||
/- Return the name of all registered attributes. -/
|
||||
def getAttributeNames : IO (List Name) :=
|
||||
do m ← attributeMapRef.get, pure $ m.fold (λ r n _, n::r) []
|
||||
do m ← attributeMapRef.get; pure $ m.fold (λ r n _, n::r) []
|
||||
|
||||
def getAttributeImpl (attrName : Name) : IO AttributeImpl :=
|
||||
do m ← attributeMapRef.get,
|
||||
do m ← attributeMapRef.get;
|
||||
match m.find attrName with
|
||||
| some attr := pure attr
|
||||
| none := throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
|
||||
|
||||
@[export lean.attribute_application_time_core]
|
||||
def attributeApplicationTime (n : Name) : IO AttributeApplicationTime :=
|
||||
do attr ← getAttributeImpl n,
|
||||
do attr ← getAttributeImpl n;
|
||||
pure attr.applicationTime
|
||||
|
||||
namespace Environment
|
||||
|
|
@ -181,7 +181,7 @@ namespace Environment
|
|||
- `args` is not valid for `attr`. -/
|
||||
@[export lean.add_attribute_core]
|
||||
def addAttribute (env : Environment) (decl : Name) (attrName : Name) (args : Syntax := Syntax.missing) (persistent := true) : IO Environment :=
|
||||
do attr ← getAttributeImpl attrName,
|
||||
do attr ← getAttributeImpl attrName;
|
||||
attr.add env decl args persistent
|
||||
|
||||
/- Add a scoped attribute `attr` to declaration `decl` with arguments `args` and scope `decl.getPrefix`.
|
||||
|
|
@ -194,7 +194,7 @@ do attr ← getAttributeImpl attrName,
|
|||
Remark: the attribute will not be activated if `decl` is not inside the current namespace `env.getNamespace`. -/
|
||||
@[export lean.add_scoped_attribute_core]
|
||||
def addScopedAttribute (env : Environment) (decl : Name) (attrName : Name) (args : Syntax := Syntax.missing) : IO Environment :=
|
||||
do attr ← getAttributeImpl attrName,
|
||||
do attr ← getAttributeImpl attrName;
|
||||
attr.addScoped env decl args
|
||||
|
||||
/- Remove attribute `attr` from declaration `decl`. The effect is the current scope.
|
||||
|
|
@ -204,36 +204,36 @@ do attr ← getAttributeImpl attrName,
|
|||
- `args` is not valid for `attr`. -/
|
||||
@[export lean.erase_attribute_core]
|
||||
def eraseAttribute (env : Environment) (decl : Name) (attrName : Name) (persistent := true) : IO Environment :=
|
||||
do attr ← getAttributeImpl attrName,
|
||||
do attr ← getAttributeImpl attrName;
|
||||
attr.erase env decl persistent
|
||||
|
||||
/- Activate the scoped attribute `attr` for all declarations in scope `scope`.
|
||||
We use this function to implement the command `open foo`. -/
|
||||
@[export lean.activate_scoped_attribute_core]
|
||||
def activateScopedAttribute (env : Environment) (attrName : Name) (scope : Name) : IO Environment :=
|
||||
do attr ← getAttributeImpl attrName,
|
||||
do attr ← getAttributeImpl attrName;
|
||||
attr.activateScoped env scope
|
||||
|
||||
/- Activate all scoped attributes at `scope` -/
|
||||
@[export lean.activate_scoped_attributes_core]
|
||||
def activateScopedAttributes (env : Environment) (scope : Name) : IO Environment :=
|
||||
do attrs ← attributeArrayRef.get,
|
||||
do attrs ← attributeArrayRef.get;
|
||||
attrs.mfoldl (λ env attr, attr.activateScoped env scope) env
|
||||
|
||||
/- We use this function to implement commands `namespace foo` and `section foo`.
|
||||
It activates scoped attributes in the new resulting namespace. -/
|
||||
@[export lean.push_scope_core]
|
||||
def pushScope (env : Environment) (header : Name) (isNamespace : Bool) : IO Environment :=
|
||||
do let env := env.pushScopeCore header isNamespace,
|
||||
let ns := env.getNamespace,
|
||||
attrs ← attributeArrayRef.get,
|
||||
attrs.mfoldl (λ env attr, do env ← attr.pushScope env, if isNamespace then attr.activateScoped env ns else pure env) env
|
||||
do let env := env.pushScopeCore header isNamespace;
|
||||
let ns := env.getNamespace;
|
||||
attrs ← attributeArrayRef.get;
|
||||
attrs.mfoldl (λ env attr, do env ← attr.pushScope env; if isNamespace then attr.activateScoped env ns else pure env) env
|
||||
|
||||
/- We use this function to implement commands `end foo` for closing namespaces and sections. -/
|
||||
@[export lean.pop_scope_core]
|
||||
def popScope (env : Environment) : IO Environment :=
|
||||
do let env := env.popScopeCore,
|
||||
attrs ← attributeArrayRef.get,
|
||||
do let env := env.popScopeCore;
|
||||
attrs ← attributeArrayRef.get;
|
||||
attrs.mfoldl (λ env attr, attr.popScope env) env
|
||||
|
||||
end Environment
|
||||
|
|
@ -260,20 +260,20 @@ ext : PersistentEnvExtension Name NameSet ← registerPersistentEnvExtension {
|
|||
let r : Array Name := es.fold (λ a e, a.push e) Array.empty;
|
||||
r.qsort Name.quickLt,
|
||||
statsFn := λ s, "tag attribute" ++ Format.line ++ "number of local entries: " ++ format s.size
|
||||
},
|
||||
};
|
||||
let attrImpl : AttributeImpl := {
|
||||
name := name,
|
||||
descr := descr,
|
||||
add := λ env decl args persistent, do
|
||||
unless args.isMissing $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', unexpected argument")),
|
||||
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent")),
|
||||
unless args.isMissing $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', unexpected argument"));
|
||||
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent"));
|
||||
unless (env.getModuleIdxFor decl).isNone $
|
||||
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module")),
|
||||
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module"));
|
||||
match validate env decl with
|
||||
| Except.error msg := throw (IO.userError ("invalid attribute '" ++ toString name ++ "', " ++ msg))
|
||||
| _ := pure $ ext.addEntry env decl
|
||||
},
|
||||
registerAttribute attrImpl,
|
||||
};
|
||||
registerAttribute attrImpl;
|
||||
pure { attr := attrImpl, ext := ext }
|
||||
|
||||
namespace TagAttribute
|
||||
|
|
@ -309,23 +309,23 @@ ext : PersistentEnvExtension (Name × α) (NameMap α) ← registerPersistentEnv
|
|||
let r : Array (Name × α) := m.fold (λ a n p, a.push (n, p)) Array.empty;
|
||||
r.qsort (λ a b, Name.quickLt a.1 b.1),
|
||||
statsFn := λ s, "parametric attribute" ++ Format.line ++ "number of local entries: " ++ format s.size
|
||||
},
|
||||
};
|
||||
let attrImpl : AttributeImpl := {
|
||||
name := name,
|
||||
descr := descr,
|
||||
add := λ env decl args persistent, do
|
||||
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent")),
|
||||
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent"));
|
||||
unless (env.getModuleIdxFor decl).isNone $
|
||||
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module")),
|
||||
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module"));
|
||||
match getParam env decl args with
|
||||
| Except.error msg := throw (IO.userError ("invalid attribute '" ++ toString name ++ "', " ++ msg))
|
||||
| Except.ok val := do
|
||||
let env := ext.addEntry env (decl, val),
|
||||
let env := ext.addEntry env (decl, val);
|
||||
match afterSet env decl val with
|
||||
| Except.error msg := throw (IO.userError ("invalid attribute '" ++ toString name ++ "', " ++ msg))
|
||||
| Except.ok env := pure env
|
||||
},
|
||||
registerAttribute attrImpl,
|
||||
};
|
||||
registerAttribute attrImpl;
|
||||
pure { attr := attrImpl, ext := ext }
|
||||
|
||||
namespace ParametricAttribute
|
||||
|
|
@ -368,19 +368,19 @@ ext : PersistentEnvExtension (Name × α) (NameMap α) ← registerPersistentEnv
|
|||
let r : Array (Name × α) := m.fold (λ a n p, a.push (n, p)) Array.empty;
|
||||
r.qsort (λ a b, Name.quickLt a.1 b.1),
|
||||
statsFn := λ s, "enumeration attribute extension" ++ Format.line ++ "number of local entries: " ++ format s.size
|
||||
},
|
||||
};
|
||||
let attrs := attrDescrs.map $ λ ⟨name, descr, val⟩, { AttributeImpl .
|
||||
name := name,
|
||||
descr := descr,
|
||||
add := λ env decl args persistent, do
|
||||
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent")),
|
||||
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent"));
|
||||
unless (env.getModuleIdxFor decl).isNone $
|
||||
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module")),
|
||||
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module"));
|
||||
match validate env decl val with
|
||||
| Except.error msg := throw (IO.userError ("invalid attribute '" ++ toString name ++ "', " ++ msg))
|
||||
| _ := pure $ ext.addEntry env (decl, val)
|
||||
},
|
||||
attrs.mfor registerAttribute,
|
||||
};
|
||||
attrs.mfor registerAttribute;
|
||||
pure { ext := ext, attrs := attrs }
|
||||
|
||||
namespace EnumAttributes
|
||||
|
|
|
|||
|
|
@ -101,11 +101,11 @@ private def consumeNLambdas : Nat → Expr → Option Expr
|
|||
partial def getClassName (env : Environment) : Expr → Option Name
|
||||
| (Expr.pi _ _ _ d) := getClassName d
|
||||
| e := do
|
||||
Expr.const c _ ← pure e.getAppFn | none,
|
||||
info ← env.find c,
|
||||
Expr.const c _ ← pure e.getAppFn | none;
|
||||
info ← env.find c;
|
||||
match info.value with
|
||||
| some val := do
|
||||
body ← consumeNLambdas e.getAppNumArgs val,
|
||||
body ← consumeNLambdas e.getAppNumArgs val;
|
||||
getClassName body
|
||||
| none :=
|
||||
if isClass env c then some c
|
||||
|
|
@ -125,8 +125,8 @@ registerAttribute {
|
|||
name := `class,
|
||||
descr := "type class",
|
||||
add := λ env decl args persistent, do
|
||||
unless args.isMissing $ throw (IO.userError ("invalid attribute 'class', unexpected argument")),
|
||||
unless persistent $ throw (IO.userError ("invalid attribute 'class', must be persistent")),
|
||||
unless args.isMissing $ throw (IO.userError ("invalid attribute 'class', unexpected argument"));
|
||||
unless persistent $ throw (IO.userError ("invalid attribute 'class', must be persistent"));
|
||||
IO.ofExcept (addClass env decl)
|
||||
}
|
||||
|
||||
|
|
@ -135,8 +135,8 @@ registerAttribute {
|
|||
name := `instance,
|
||||
descr := "type class instance",
|
||||
add := λ env decl args persistent, do
|
||||
unless args.isMissing $ throw (IO.userError ("invalid attribute 'instance', unexpected argument")),
|
||||
unless persistent $ throw (IO.userError ("invalid attribute 'instance', must be persistent")),
|
||||
unless args.isMissing $ throw (IO.userError ("invalid attribute 'instance', unexpected argument"));
|
||||
unless persistent $ throw (IO.userError ("invalid attribute 'instance', must be persistent"));
|
||||
IO.ofExcept (addInstance env decl)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -58,9 +58,9 @@ def mkUInt32Lit (n : Nat) : Expr :=
|
|||
mkUIntLit {nbits := 32} n
|
||||
|
||||
def foldBinUInt (fn : NumScalarTypeInfo → Bool → Nat → Nat → Nat) (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
|
||||
do n₁ ← getNumLit a₁,
|
||||
n₂ ← getNumLit a₂,
|
||||
info ← getInfoFromVal a₁,
|
||||
do n₁ ← getNumLit a₁;
|
||||
n₂ ← getNumLit a₂;
|
||||
info ← getInfoFromVal a₁;
|
||||
pure $ mkUIntLit info (fn info beforeErasure n₁ n₂)
|
||||
|
||||
def foldUIntAdd := foldBinUInt $ λ _ _, (+)
|
||||
|
|
@ -77,8 +77,8 @@ def uintBinFoldFns : List (Name × BinFoldFn) :=
|
|||
numScalarTypes.foldl (λ r info, r ++ (preUIntBinFoldFns.map (λ ⟨suffix, fn⟩, (info.id ++ suffix, fn)))) []
|
||||
|
||||
def foldNatBinOp (fn : Nat → Nat → Nat) (a₁ a₂ : Expr) : Option Expr :=
|
||||
do n₁ ← getNumLit a₁,
|
||||
n₂ ← getNumLit a₂,
|
||||
do n₁ ← getNumLit a₁;
|
||||
n₂ ← getNumLit a₂;
|
||||
pure $ Expr.lit (Literal.natVal (fn n₁ n₂))
|
||||
|
||||
def foldNatAdd (_ : Bool) := foldNatBinOp (+)
|
||||
|
|
@ -105,8 +105,8 @@ match beforeErasure, r with
|
|||
|
||||
def foldNatBinPred (mkPred : Expr → Expr → Expr) (fn : Nat → Nat → Bool)
|
||||
(beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
|
||||
do n₁ ← getNumLit a₁,
|
||||
n₂ ← getNumLit a₂,
|
||||
do n₁ ← getNumLit a₁;
|
||||
n₂ ← getNumLit a₂;
|
||||
pure $ toDecidableExpr beforeErasure (mkPred a₁ a₂) (fn n₁ n₂)
|
||||
|
||||
def foldNatDecEq := foldNatBinPred mkNatEq (λ a b, a = b)
|
||||
|
|
@ -156,18 +156,18 @@ def binFoldFns : List (Name × BinFoldFn) :=
|
|||
boolFoldFns ++ uintBinFoldFns ++ natFoldFns
|
||||
|
||||
def foldNatSucc (_ : Bool) (a : Expr) : Option Expr :=
|
||||
do n ← getNumLit a,
|
||||
do n ← getNumLit a;
|
||||
pure $ Expr.lit (Literal.natVal (n+1))
|
||||
|
||||
def foldCharOfNat (beforeErasure : Bool) (a : Expr) : Option Expr :=
|
||||
do guard (!beforeErasure),
|
||||
n ← getNumLit a,
|
||||
do guard (!beforeErasure);
|
||||
n ← getNumLit a;
|
||||
pure $
|
||||
if isValidChar (UInt32.ofNat n) then mkUInt32Lit n
|
||||
else mkUInt32Lit 0
|
||||
|
||||
def foldToNat (_ : Bool) (a : Expr) : Option Expr :=
|
||||
do n ← getNumLit a,
|
||||
do n ← getNumLit a;
|
||||
pure $ Expr.lit (Literal.natVal n)
|
||||
|
||||
def uintFoldToNatFns : List (Name × UnFoldFn) :=
|
||||
|
|
@ -188,7 +188,7 @@ unFoldFns.lookup fn
|
|||
def foldBinOp (beforeErasure : Bool) (f : Expr) (a : Expr) (b : Expr) : Option Expr :=
|
||||
match f with
|
||||
| Expr.const fn _ := do
|
||||
foldFn ← findBinFoldFn fn,
|
||||
foldFn ← findBinFoldFn fn;
|
||||
foldFn beforeErasure a b
|
||||
| _ := none
|
||||
|
||||
|
|
@ -196,7 +196,7 @@ match f with
|
|||
def foldUnOp (beforeErasure : Bool) (f : Expr) (a : Expr) : Option Expr :=
|
||||
match f with
|
||||
| Expr.const fn _ := do
|
||||
foldFn ← findUnFoldFn fn,
|
||||
foldFn ← findUnFoldFn fn;
|
||||
foldFn beforeErasure a
|
||||
| _ := none
|
||||
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ def getExternEntryFor (d : ExternAttrData) (backend : Name) : Option ExternEntry
|
|||
getExternEntryForAux backend d.entries
|
||||
|
||||
def mkExternCall (d : ExternAttrData) (backend : Name) (args : List String) : Option String :=
|
||||
do e ← getExternEntryFor d backend,
|
||||
do e ← getExternEntryFor d backend;
|
||||
expandExternEntry e args
|
||||
|
||||
def isExtern (env : Environment) (fn : Name) : Bool :=
|
||||
|
|
@ -161,8 +161,8 @@ match getExternAttrData env fn with
|
|||
| _ := false
|
||||
|
||||
def getExternNameFor (env : Environment) (backend : Name) (fn : Name) : Option String :=
|
||||
do data ← getExternAttrData env fn,
|
||||
entry ← getExternEntryFor data backend,
|
||||
do data ← getExternAttrData env fn;
|
||||
entry ← getExternEntryFor data backend;
|
||||
match entry with
|
||||
| ExternEntry.standard _ n := pure n
|
||||
| ExternEntry.foreign _ n := pure n
|
||||
|
|
|
|||
|
|
@ -353,7 +353,7 @@ bs.map $ λ b, match b with
|
|||
|
||||
@[inline] def mmodifyJPs {m : Type → Type} [Monad m] (bs : Array FnBody) (f : FnBody → m FnBody) : m (Array FnBody) :=
|
||||
bs.mmap $ λ b, match b with
|
||||
| FnBody.jdecl j xs v k := do v ← f v, pure $ FnBody.jdecl j xs v k
|
||||
| FnBody.jdecl j xs v k := do v ← f v; pure $ FnBody.jdecl j xs v k
|
||||
| other := pure other
|
||||
|
||||
@[export lean.ir.mk_alt_core] def mkAlt (n : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) (b : FnBody) : Alt := Alt.ctor ⟨n, cidx, size, usize, ssize⟩ b
|
||||
|
|
@ -500,7 +500,7 @@ else none
|
|||
|
||||
def addParamsRename (ρ : IndexRenaming) (ps₁ ps₂ : Array Param) : Option IndexRenaming :=
|
||||
if ps₁.size != ps₂.size then none
|
||||
else Array.foldl₂ (λ ρ p₁ p₂, do ρ ← ρ, addParamRename ρ p₁ p₂) (some ρ) ps₁ ps₂
|
||||
else Array.foldl₂ (λ ρ p₁ p₂, do ρ ← ρ; addParamRename ρ p₁ p₂) (some ρ) ps₁ ps₂
|
||||
|
||||
partial def FnBody.alphaEqv : IndexRenaming → FnBody → FnBody → Bool
|
||||
| ρ (FnBody.vdecl x₁ t₁ v₁ b₁) (FnBody.vdecl x₂ t₂ v₂ b₂) := t₁ == t₂ && aeqv ρ v₁ v₂ && FnBody.alphaEqv (addVarRename ρ x₁.idx x₂.idx) b₁ b₂
|
||||
|
|
|
|||
|
|
@ -70,19 +70,19 @@ if exported then ps else initBorrow ps
|
|||
|
||||
partial def visitFnBody (fnid : FunId) : FnBody → State ParamMap Unit
|
||||
| (FnBody.jdecl j xs v b) := do
|
||||
modify $ λ m, m.insert (Key.jp fnid j) (initBorrow xs),
|
||||
visitFnBody v,
|
||||
modify $ λ m, m.insert (Key.jp fnid j) (initBorrow xs);
|
||||
visitFnBody v;
|
||||
visitFnBody b
|
||||
| e :=
|
||||
unless (e.isTerminal) $ do
|
||||
let (instr, b) := e.split,
|
||||
let (instr, b) := e.split;
|
||||
visitFnBody b
|
||||
|
||||
def visitDecls (env : Environment) (decls : Array Decl) : State ParamMap Unit :=
|
||||
decls.mfor $ λ decl, match decl with
|
||||
| Decl.fdecl f xs _ b := do
|
||||
let exported := isExport env f,
|
||||
modify $ λ m, m.insert (Key.decl f) (initBorrowIfNotExported exported xs),
|
||||
let exported := isExport env f;
|
||||
modify $ λ m, m.insert (Key.decl f) (initBorrowIfNotExported exported xs);
|
||||
visitFnBody f b
|
||||
| _ := pure ()
|
||||
end InitParamMap
|
||||
|
|
@ -156,32 +156,32 @@ def ownArgs (xs : Array Arg) : M Unit :=
|
|||
xs.mfor ownArg
|
||||
|
||||
def isOwned (x : VarId) : M Bool :=
|
||||
do s ← get,
|
||||
do s ← get;
|
||||
pure $ s.owned.contains x.idx
|
||||
|
||||
/- Updates `map[k]` using the current set of `owned` variables. -/
|
||||
def updateParamMap (k : Key) : M Unit :=
|
||||
do
|
||||
s ← get,
|
||||
s ← get;
|
||||
match s.map.find k with
|
||||
| some ps := do
|
||||
ps ← ps.mmap $ λ (p : Param),
|
||||
if p.borrow && s.owned.contains p.x.idx then do
|
||||
markModifiedParamMap, pure { borrow := false, .. p }
|
||||
markModifiedParamMap; pure { borrow := false, .. p }
|
||||
else
|
||||
pure p,
|
||||
pure p;
|
||||
modify $ λ s, { map := s.map.insert k ps, .. s }
|
||||
| none := pure ()
|
||||
|
||||
def getParamInfo (k : Key) : M (Array Param) :=
|
||||
do
|
||||
s ← get,
|
||||
s ← get;
|
||||
match s.map.find k with
|
||||
| some ps := pure ps
|
||||
| none :=
|
||||
match k with
|
||||
| (Key.decl fn) := do
|
||||
ctx ← read,
|
||||
ctx ← read;
|
||||
match findEnvDecl ctx.env fn with
|
||||
| some decl := pure decl.params
|
||||
| none := pure Array.empty -- unreachable if well-formed input
|
||||
|
|
@ -190,8 +190,8 @@ match s.map.find k with
|
|||
/- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
|
||||
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.mfor $ λ i, do
|
||||
let x := xs.get i,
|
||||
let p := ps.get i,
|
||||
let x := xs.get i;
|
||||
let p := ps.get i;
|
||||
unless p.borrow $ ownArg x
|
||||
|
||||
/- For each xs[i], if xs[i] is owned, then mark ps[i] as owned.
|
||||
|
|
@ -201,8 +201,8 @@ xs.size.mfor $ λ i, do
|
|||
"break" the tail call. -/
|
||||
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
xs.size.mfor $ λ i, do
|
||||
let x := xs.get i,
|
||||
let p := ps.get i,
|
||||
let x := xs.get i;
|
||||
let p := ps.get i;
|
||||
match x with
|
||||
| Arg.var x := mwhen (isOwned x) $ ownVar p.x
|
||||
| _ := pure ()
|
||||
|
|
@ -219,7 +219,7 @@ xs.size.mfor $ λ i, do
|
|||
-/
|
||||
def ownArgsIfParam (xs : Array Arg) : M Unit :=
|
||||
do
|
||||
ctx ← read,
|
||||
ctx ← read;
|
||||
xs.mfor $ λ x,
|
||||
match x with
|
||||
| Arg.var x := when (ctx.paramSet.contains x.idx) $ ownVar x
|
||||
|
|
@ -230,7 +230,7 @@ def collectExpr (z : VarId) : Expr → M Unit
|
|||
| (Expr.reuse x _ _ ys) := ownVar z *> ownVar x *> ownArgsIfParam ys
|
||||
| (Expr.ctor _ xs) := ownVar z *> ownArgsIfParam xs
|
||||
| (Expr.proj _ x) := mwhen (isOwned z) $ ownVar x
|
||||
| (Expr.fap g xs) := do ps ← getParamInfo (Key.decl g),
|
||||
| (Expr.fap g xs) := do ps ← getParamInfo (Key.decl g);
|
||||
-- dbgTrace ("collectExpr: " ++ toString g ++ " " ++ toString (formatParams ps)) $ λ _,
|
||||
ownVar z *> ownArgsUsingParams xs ps
|
||||
| (Expr.ap x ys) := ownVar z *> ownVar x *> ownArgs ys
|
||||
|
|
@ -238,12 +238,12 @@ def collectExpr (z : VarId) : Expr → M Unit
|
|||
| other := pure ()
|
||||
|
||||
def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
match v, b with
|
||||
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) :=
|
||||
when (ctx.currFn == g && x == z) $ do
|
||||
-- dbgTrace ("preserveTailCall " ++ toString b) $ λ _, do
|
||||
ps ← getParamInfo (Key.decl g),
|
||||
ps ← getParamInfo (Key.decl g);
|
||||
ownParamsUsingArgs ys ps
|
||||
| _, _ := pure ()
|
||||
|
||||
|
|
@ -252,24 +252,24 @@ def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx :=
|
|||
|
||||
partial def collectFnBody : FnBody → M Unit
|
||||
| (FnBody.jdecl j ys v b) := do
|
||||
adaptReader (λ ctx, updateParamSet ctx ys) (collectFnBody v),
|
||||
ctx ← read,
|
||||
updateParamMap (Key.jp ctx.currFn j),
|
||||
adaptReader (λ ctx, updateParamSet ctx ys) (collectFnBody v);
|
||||
ctx ← read;
|
||||
updateParamMap (Key.jp ctx.currFn j);
|
||||
collectFnBody b
|
||||
| (FnBody.vdecl x _ v b) := collectFnBody b *> collectExpr x v *> preserveTailCall x v b
|
||||
| (FnBody.jmp j ys) := do
|
||||
ctx ← read,
|
||||
ps ← getParamInfo (Key.jp ctx.currFn j),
|
||||
ownArgsUsingParams ys ps, -- for making sure the join point can reuse
|
||||
ctx ← read;
|
||||
ps ← getParamInfo (Key.jp ctx.currFn j);
|
||||
ownArgsUsingParams ys ps; -- for making sure the join point can reuse
|
||||
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
|
||||
| (FnBody.case _ _ alts) := alts.mfor $ λ alt, collectFnBody alt.body
|
||||
| e := unless (e.isTerminal) $ collectFnBody e.body
|
||||
|
||||
@[specialize] partial def whileModifingOwnedAux (x : M Unit) : Unit → M Unit
|
||||
| _ := do
|
||||
modify $ λ s, { modifiedOwned := false, .. s },
|
||||
x,
|
||||
s ← get,
|
||||
modify $ λ s, { modifiedOwned := false, .. s };
|
||||
x;
|
||||
s ← get;
|
||||
if s.modifiedOwned then whileModifingOwnedAux ()
|
||||
else pure ()
|
||||
|
||||
|
|
@ -280,18 +280,18 @@ whileModifingOwnedAux x ()
|
|||
partial def collectDecl : Decl → M Unit
|
||||
| (Decl.fdecl f ys _ b) :=
|
||||
adaptReader (λ ctx, let ctx := updateParamSet ctx ys; { currFn := f, .. ctx }) $ do
|
||||
modify $ λ s : BorrowInfState, { owned := {}, .. s },
|
||||
whileModifingOwned (collectFnBody b),
|
||||
modify $ λ s : BorrowInfState, { owned := {}, .. s };
|
||||
whileModifingOwned (collectFnBody b);
|
||||
updateParamMap (Key.decl f)
|
||||
| _ := pure ()
|
||||
|
||||
@[specialize] partial def whileModifingParamMapAux (x : M Unit) : Unit → M Unit
|
||||
| _ := do
|
||||
modify $ λ s, { modifiedParamMap := false, .. s },
|
||||
s ← get,
|
||||
modify $ λ s, { modifiedParamMap := false, .. s };
|
||||
s ← get;
|
||||
-- dbgTrace (toString s.map) $ λ _, do
|
||||
x,
|
||||
s ← get,
|
||||
x;
|
||||
s ← get;
|
||||
if s.modifiedParamMap then whileModifingParamMapAux ()
|
||||
else pure ()
|
||||
|
||||
|
|
@ -301,8 +301,8 @@ whileModifingParamMapAux x ()
|
|||
|
||||
def collectDecls (decls : Array Decl) : M ParamMap :=
|
||||
do
|
||||
whileModifingParamMap (decls.mfor collectDecl),
|
||||
s ← get,
|
||||
whileModifingParamMap (decls.mfor collectDecl);
|
||||
s ← get;
|
||||
pure s.map
|
||||
|
||||
def infer (env : Environment) (decls : Array Decl) : ParamMap :=
|
||||
|
|
@ -312,9 +312,9 @@ end Borrow
|
|||
|
||||
def inferBorrow (decls : Array Decl) : CompilerM (Array Decl) :=
|
||||
do
|
||||
env ← getEnv,
|
||||
let decls := decls.map Decl.normalizeIds,
|
||||
let paramMap := Borrow.infer env decls,
|
||||
env ← getEnv;
|
||||
let decls := decls.map Decl.normalizeIds;
|
||||
let paramMap := Borrow.infer env decls;
|
||||
pure (Borrow.applyParamMap decls paramMap)
|
||||
|
||||
end IR
|
||||
|
|
|
|||
|
|
@ -35,8 +35,8 @@ Name.mkString n "_boxed"
|
|||
abbrev N := State Nat
|
||||
|
||||
private def mkFresh : N VarId :=
|
||||
do idx ← get,
|
||||
modify (+1),
|
||||
do idx ← get;
|
||||
modify (+1);
|
||||
pure {idx := idx }
|
||||
|
||||
def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
|
||||
|
|
@ -45,8 +45,8 @@ ps.size > 0 && (decl.resultType.isScalar || ps.any (λ p, p.ty.isScalar || p.bor
|
|||
|
||||
def mkBoxedVersionAux (decl : Decl) : N Decl :=
|
||||
do
|
||||
let ps := decl.params,
|
||||
qs ← ps.mmap (λ _, do x ← mkFresh, pure { Param . x := x, ty := IRType.object, borrow := false }),
|
||||
let ps := decl.params;
|
||||
qs ← ps.mmap (λ _, do x ← mkFresh; pure { Param . x := x, ty := IRType.object, borrow := false });
|
||||
(newVDecls, xs) ← qs.size.mfold
|
||||
(λ i (r : Array FnBody × Array Arg),
|
||||
let (newVDecls, xs) := r;
|
||||
|
|
@ -54,19 +54,19 @@ qs ← ps.mmap (λ _, do x ← mkFresh, pure { Param . x := x, ty := IRType.obje
|
|||
let q := qs.get i;
|
||||
if !p.ty.isScalar then pure (newVDecls, xs.push (Arg.var q.x))
|
||||
else do
|
||||
x ← mkFresh,
|
||||
x ← mkFresh;
|
||||
pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (default _)), xs.push (Arg.var x)))
|
||||
(Array.empty, Array.empty),
|
||||
r ← mkFresh,
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (default _)),
|
||||
(Array.empty, Array.empty);
|
||||
r ← mkFresh;
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (default _));
|
||||
body ←
|
||||
if !decl.resultType.isScalar then do {
|
||||
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
|
||||
} else do {
|
||||
newR ← mkFresh,
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (default _)),
|
||||
newR ← mkFresh;
|
||||
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (default _));
|
||||
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
|
||||
},
|
||||
};
|
||||
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
|
||||
|
||||
def mkBoxedVersion (decl : Decl) : Decl :=
|
||||
|
|
@ -111,7 +111,7 @@ structure BoxingContext :=
|
|||
abbrev M := ReaderT BoxingContext (StateT Index Id)
|
||||
|
||||
def mkFresh : M VarId :=
|
||||
do idx ← getModify (+1),
|
||||
do idx ← getModify (+1);
|
||||
pure { idx := idx }
|
||||
|
||||
def getEnv : M Environment := BoxingContext.env <$> read
|
||||
|
|
@ -119,17 +119,17 @@ def getLocalContext : M LocalContext := BoxingContext.localCtx <$> read
|
|||
def getResultType : M IRType := BoxingContext.resultType <$> read
|
||||
|
||||
def getVarType (x : VarId) : M IRType :=
|
||||
do localCtx ← getLocalContext,
|
||||
do localCtx ← getLocalContext;
|
||||
match localCtx.getType x with
|
||||
| some t := pure t
|
||||
| none := pure IRType.object -- unreachable, we assume the code is well formed
|
||||
def getJPParams (j : JoinPointId) : M (Array Param) :=
|
||||
do localCtx ← getLocalContext,
|
||||
do localCtx ← getLocalContext;
|
||||
match localCtx.getJPParams j with
|
||||
| some ys := pure ys
|
||||
| none := pure Array.empty -- unreachable, we assume the code is well formed
|
||||
def getDecl (fid : FunId) : M Decl :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
match findEnvDecl' ctx.env fid ctx.decls with
|
||||
| some decl := pure decl
|
||||
| none := pure (default _) -- unreachable if well-formed
|
||||
|
|
@ -150,11 +150,11 @@ def mkCast (x : VarId) (xType : IRType) : Expr :=
|
|||
if xType.isScalar then Expr.box xType x else Expr.unbox x
|
||||
|
||||
@[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody :=
|
||||
do xType ← getVarType x,
|
||||
do xType ← getVarType x;
|
||||
if eqvTypes xType expected then k x
|
||||
else do
|
||||
y ← mkFresh,
|
||||
let v := mkCast x xType,
|
||||
y ← mkFresh;
|
||||
let v := mkCast x xType;
|
||||
FnBody.vdecl y expected v <$> k y
|
||||
|
||||
@[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody :=
|
||||
|
|
@ -169,27 +169,27 @@ xs.miterate (Array.empty, Array.empty) $ λ i (x : Arg) (r : Array Arg × Array
|
|||
match x with
|
||||
| Arg.irrelevant := pure (xs.push x, bs)
|
||||
| Arg.var x := do
|
||||
xType ← getVarType x,
|
||||
xType ← getVarType x;
|
||||
if eqvTypes xType expected then pure (xs.push (Arg.var x), bs)
|
||||
else do
|
||||
y ← mkFresh,
|
||||
let v := mkCast x xType,
|
||||
let b := FnBody.vdecl y expected v FnBody.nil,
|
||||
y ← mkFresh;
|
||||
let v := mkCast x xType;
|
||||
let b := FnBody.vdecl y expected v FnBody.nil;
|
||||
pure (xs.push (Arg.var y), bs.push b)
|
||||
|
||||
@[inline] def castArgsIfNeeded (xs : Array Arg) (ps : Array Param) (k : Array Arg → M FnBody) : M FnBody :=
|
||||
do (ys, bs) ← castArgsIfNeededAux xs (λ i, (ps.get i).ty),
|
||||
b ← k ys,
|
||||
do (ys, bs) ← castArgsIfNeededAux xs (λ i, (ps.get i).ty);
|
||||
b ← k ys;
|
||||
pure (reshape bs b)
|
||||
|
||||
@[inline] def boxArgsIfNeeded (xs : Array Arg) (k : Array Arg → M FnBody) : M FnBody :=
|
||||
do (ys, bs) ← castArgsIfNeededAux xs (λ _, IRType.object),
|
||||
b ← k ys,
|
||||
do (ys, bs) ← castArgsIfNeededAux xs (λ _, IRType.object);
|
||||
b ← k ys;
|
||||
pure (reshape bs b)
|
||||
|
||||
def unboxResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
|
||||
if ty.isScalar then do
|
||||
y ← mkFresh,
|
||||
y ← mkFresh;
|
||||
pure $ FnBody.vdecl y IRType.object e (FnBody.vdecl x ty (Expr.unbox y) b)
|
||||
else
|
||||
pure $ FnBody.vdecl x ty e b
|
||||
|
|
@ -197,7 +197,7 @@ else
|
|||
def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b : FnBody) : M FnBody :=
|
||||
if eqvTypes ty eType then pure $ FnBody.vdecl x ty e b
|
||||
else do
|
||||
y ← mkFresh,
|
||||
y ← mkFresh;
|
||||
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty (mkCast y eType) b)
|
||||
|
||||
def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
|
||||
|
|
@ -210,13 +210,13 @@ match e with
|
|||
| Expr.reuse w c u ys :=
|
||||
boxArgsIfNeeded ys $ λ ys, pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b
|
||||
| Expr.fap f ys := do
|
||||
decl ← getDecl f,
|
||||
decl ← getDecl f;
|
||||
castArgsIfNeeded ys decl.params $ λ ys,
|
||||
castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b
|
||||
| Expr.pap f ys := do
|
||||
env ← getEnv,
|
||||
decl ← getDecl f,
|
||||
let f := if requiresBoxedVersion env decl then mkBoxedName f else f,
|
||||
env ← getEnv;
|
||||
decl ← getDecl f;
|
||||
let f := if requiresBoxedVersion env decl then mkBoxedName f else f;
|
||||
boxArgsIfNeeded ys $ λ ys, pure $ FnBody.vdecl x ty (Expr.pap f ys) b
|
||||
| Expr.ap f ys :=
|
||||
boxArgsIfNeeded ys $ λ ys,
|
||||
|
|
@ -226,32 +226,32 @@ match e with
|
|||
|
||||
partial def visitFnBody : FnBody → M FnBody
|
||||
| (FnBody.vdecl x t v b) := do
|
||||
b ← withVDecl x t v (visitFnBody b),
|
||||
b ← withVDecl x t v (visitFnBody b);
|
||||
visitVDeclExpr x t v b
|
||||
| (FnBody.jdecl j xs v b) := do
|
||||
v ← withParams xs (visitFnBody v),
|
||||
b ← withJDecl j xs v (visitFnBody b),
|
||||
v ← withParams xs (visitFnBody v);
|
||||
b ← withJDecl j xs v (visitFnBody b);
|
||||
pure $ FnBody.jdecl j xs v b
|
||||
| (FnBody.uset x i y b) := do
|
||||
b ← visitFnBody b,
|
||||
b ← visitFnBody b;
|
||||
castVarIfNeeded y IRType.usize $ λ y,
|
||||
pure $ FnBody.uset x i y b
|
||||
| (FnBody.sset x i o y ty b) := do
|
||||
b ← visitFnBody b,
|
||||
b ← visitFnBody b;
|
||||
castVarIfNeeded y ty $ λ y,
|
||||
pure $ FnBody.sset x i o y ty b
|
||||
| (FnBody.mdata d b) :=
|
||||
FnBody.mdata d <$> visitFnBody b
|
||||
| (FnBody.case tid x alts) := do
|
||||
let expected := getScrutineeType alts,
|
||||
alts ← alts.mmap $ λ alt, alt.mmodifyBody visitFnBody,
|
||||
let expected := getScrutineeType alts;
|
||||
alts ← alts.mmap $ λ alt, alt.mmodifyBody visitFnBody;
|
||||
castVarIfNeeded x expected $ λ x,
|
||||
pure $ FnBody.case tid x alts
|
||||
| (FnBody.ret x) := do
|
||||
expected ← getResultType,
|
||||
expected ← getResultType;
|
||||
castArgIfNeeded x expected (λ x, pure $ FnBody.ret x)
|
||||
| (FnBody.jmp j ys) := do
|
||||
ps ← getJPParams j,
|
||||
ps ← getJPParams j;
|
||||
castArgsIfNeeded ys ps (λ ys, pure $ FnBody.jmp j ys)
|
||||
| other :=
|
||||
pure other
|
||||
|
|
@ -269,7 +269,7 @@ addBoxedVersions env decls
|
|||
end ExplicitBoxing
|
||||
|
||||
def explicitBoxing (decls : Array Decl) : CompilerM (Array Decl) :=
|
||||
do env ← getEnv,
|
||||
do env ← getEnv;
|
||||
pure $ ExplicitBoxing.run env decls
|
||||
|
||||
end IR
|
||||
|
|
|
|||
|
|
@ -17,17 +17,17 @@ structure Context :=
|
|||
abbrev M := ExceptT String (ReaderT Context Id)
|
||||
|
||||
def getDecl (c : Name) : M Decl :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
match findEnvDecl' ctx.env c ctx.decls with
|
||||
| none := throw ("unknown declaration '" ++ toString c ++ "'")
|
||||
| some d := pure d
|
||||
|
||||
def checkVar (x : VarId) : M Unit :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
unless (ctx.localCtx.isLocalVar x.idx || ctx.localCtx.isParam x.idx) $ throw ("unknown variable '" ++ toString x ++ "'")
|
||||
|
||||
def checkJP (j : JoinPointId) : M Unit :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
unless (ctx.localCtx.isJP j.idx) $ throw ("unknown join point '" ++ toString j ++ "'")
|
||||
|
||||
def checkArg (a : Arg) : M Unit :=
|
||||
|
|
@ -46,7 +46,7 @@ def checkObjType (ty : IRType) : M Unit := checkType ty IRType.isObj
|
|||
def checkScalarType (ty : IRType) : M Unit := checkType ty IRType.isScalar
|
||||
|
||||
@[inline] def checkVarType (x : VarId) (p : IRType → Bool) : M Unit :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
match ctx.localCtx.getType x with
|
||||
| some ty := checkType ty p
|
||||
| none := throw ("unknown variable '" ++ toString x ++ "'")
|
||||
|
|
@ -59,14 +59,14 @@ checkVarType x IRType.isScalar
|
|||
|
||||
def checkFullApp (c : FunId) (ys : Array Arg) : M Unit :=
|
||||
do
|
||||
decl ← getDecl c,
|
||||
unless (ys.size == decl.params.size) (throw ("incorrect number of arguments to '" ++ toString c ++ "', " ++ toString ys.size ++ " provided, " ++ toString decl.params.size ++ " expected")),
|
||||
decl ← getDecl c;
|
||||
unless (ys.size == decl.params.size) (throw ("incorrect number of arguments to '" ++ toString c ++ "', " ++ toString ys.size ++ " provided, " ++ toString decl.params.size ++ " expected"));
|
||||
checkArgs ys
|
||||
|
||||
def checkPartialApp (c : FunId) (ys : Array Arg) : M Unit :=
|
||||
do
|
||||
decl ← getDecl c,
|
||||
unless (ys.size < decl.params.size) (throw ("too many arguments too partial application '" ++ toString c ++ "', num. args: " ++ toString ys.size ++ ", arity: " ++ toString decl.params.size)),
|
||||
decl ← getDecl c;
|
||||
unless (ys.size < decl.params.size) (throw ("too many arguments too partial application '" ++ toString c ++ "', num. args: " ++ toString ys.size ++ ", arity: " ++ toString decl.params.size));
|
||||
checkArgs ys
|
||||
|
||||
def checkExpr (ty : IRType) : Expr → M Unit
|
||||
|
|
@ -87,22 +87,22 @@ def checkExpr (ty : IRType) : Expr → M Unit
|
|||
| (Expr.lit _) := pure ()
|
||||
|
||||
@[inline] def withParams (ps : Array Param) (k : M Unit) : M Unit :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
localCtx ← ps.mfoldl (λ (ctx : LocalContext) p, do
|
||||
when (ctx.contains p.x.idx) $ throw ("invalid parameter declaration, shadowing is not allowed"),
|
||||
pure $ ctx.addParam p) ctx.localCtx,
|
||||
when (ctx.contains p.x.idx) $ throw ("invalid parameter declaration, shadowing is not allowed");
|
||||
pure $ ctx.addParam p) ctx.localCtx;
|
||||
adaptReader (λ _, { localCtx := localCtx, .. ctx }) k
|
||||
|
||||
partial def checkFnBody : FnBody → M Unit
|
||||
| (FnBody.vdecl x t v b) := do
|
||||
checkExpr t v,
|
||||
ctx ← read,
|
||||
when (ctx.localCtx.contains x.idx) $ throw ("invalid variable declaration, shadowing is not allowed"),
|
||||
checkExpr t v;
|
||||
ctx ← read;
|
||||
when (ctx.localCtx.contains x.idx) $ throw ("invalid variable declaration, shadowing is not allowed");
|
||||
adaptReader (λ ctx : Context, { localCtx := ctx.localCtx.addLocal x t v, .. ctx }) (checkFnBody b)
|
||||
| (FnBody.jdecl j ys v b) := do
|
||||
withParams ys (checkFnBody v),
|
||||
ctx ← read,
|
||||
when (ctx.localCtx.contains j.idx) $ throw ("invalid join point declaration, shadowing is not allowed"),
|
||||
withParams ys (checkFnBody v);
|
||||
ctx ← read;
|
||||
when (ctx.localCtx.contains j.idx) $ throw ("invalid join point declaration, shadowing is not allowed");
|
||||
adaptReader (λ ctx : Context, { localCtx := ctx.localCtx.addJP j ys v, .. ctx }) (checkFnBody b)
|
||||
| (FnBody.set x _ y b) := checkVar x *> checkArg y *> checkFnBody b
|
||||
| (FnBody.uset x _ y b) := checkVar x *> checkVar y *> checkFnBody b
|
||||
|
|
@ -125,7 +125,7 @@ end Checker
|
|||
|
||||
def checkDecl (decls : Array Decl) (decl : Decl) : CompilerM Unit :=
|
||||
do
|
||||
env ← getEnv,
|
||||
env ← getEnv;
|
||||
match Checker.checkDecl decl { env := env, decls := decls } with
|
||||
| Except.error msg := throw ("IR check failed at '" ++ toString decl.name ++ "', error: " ++ msg)
|
||||
| other := pure ()
|
||||
|
|
|
|||
|
|
@ -49,14 +49,14 @@ match opts.find optName with
|
|||
| other := opts.getBool tracePrefixOptionName
|
||||
|
||||
private def logDeclsAux (optName : Name) (cls : Name) (decls : Array Decl) : CompilerM Unit :=
|
||||
do opts ← read,
|
||||
do opts ← read;
|
||||
when (isLogEnabledFor opts optName) $ log (LogEntry.step cls decls)
|
||||
|
||||
@[inline] def logDecls (cls : Name) (decl : Array Decl) : CompilerM Unit :=
|
||||
logDeclsAux (tracePrefixOptionName ++ cls) cls decl
|
||||
|
||||
private def logMessageIfAux {α : Type} [HasFormat α] (optName : Name) (a : α) : CompilerM Unit :=
|
||||
do opts ← read,
|
||||
do opts ← read;
|
||||
when (isLogEnabledFor opts optName) $ log (LogEntry.message (format a))
|
||||
|
||||
@[inline] def logMessageIf {α : Type} [HasFormat α] (cls : Name) (a : α) : CompilerM Unit :=
|
||||
|
|
@ -95,15 +95,15 @@ def findEnvDecl (env : Environment) (n : Name) : Option Decl :=
|
|||
(declMapExt.getState env).find n
|
||||
|
||||
def findDecl (n : Name) : CompilerM (Option Decl) :=
|
||||
do s ← get,
|
||||
do s ← get;
|
||||
pure $ findEnvDecl s.env n
|
||||
|
||||
def containsDecl (n : Name) : CompilerM Bool :=
|
||||
do s ← get,
|
||||
do s ← get;
|
||||
pure $ (declMapExt.getState s.env).contains n
|
||||
|
||||
def getDecl (n : Name) : CompilerM Decl :=
|
||||
do (some decl) ← findDecl n | throw ("unknown declaration '" ++ toString n ++ "'"),
|
||||
do (some decl) ← findDecl n | throw ("unknown declaration '" ++ toString n ++ "'");
|
||||
pure decl
|
||||
|
||||
@[export lean.ir.add_decl_core]
|
||||
|
|
@ -114,7 +114,7 @@ def getDecls (env : Environment) : List Decl :=
|
|||
declMapExt.getEntries env
|
||||
|
||||
def getEnv : CompilerM Environment :=
|
||||
do s ← get, pure s.env
|
||||
do s ← get; pure s.env
|
||||
|
||||
def addDecl (decl : Decl) : CompilerM Unit :=
|
||||
modifyEnv (λ env, declMapExt.addEntry env decl)
|
||||
|
|
@ -128,16 +128,16 @@ match decls.find (λ decl, if decl.name == n then some decl else none) with
|
|||
| none := (declMapExt.getState env).find n
|
||||
|
||||
def findDecl' (n : Name) (decls : Array Decl) : CompilerM (Option Decl) :=
|
||||
do s ← get, pure $ findEnvDecl' s.env n decls
|
||||
do s ← get; pure $ findEnvDecl' s.env n decls
|
||||
|
||||
def containsDecl' (n : Name) (decls : Array Decl) : CompilerM Bool :=
|
||||
if decls.any (λ decl, decl.name == n) then pure true
|
||||
else do
|
||||
s ← get,
|
||||
s ← get;
|
||||
pure $ (declMapExt.getState s.env).contains n
|
||||
|
||||
def getDecl' (n : Name) (decls : Array Decl) : CompilerM Decl :=
|
||||
do (some decl) ← findDecl' n decls | throw ("unknown declaration '" ++ toString n ++ "'"),
|
||||
do (some decl) ← findDecl' n decls | throw ("unknown declaration '" ++ toString n ++ "'");
|
||||
pure decl
|
||||
|
||||
end IR
|
||||
|
|
|
|||
|
|
@ -24,30 +24,30 @@ namespace IR
|
|||
|
||||
private def compileAux (decls : Array Decl) : CompilerM Unit :=
|
||||
do
|
||||
logDecls `init decls,
|
||||
checkDecls decls,
|
||||
let decls := decls.map Decl.pushProj,
|
||||
logDecls `push_proj decls,
|
||||
let decls := decls.map Decl.insertResetReuse,
|
||||
logDecls `reset_reuse decls,
|
||||
let decls := decls.map Decl.elimDead,
|
||||
logDecls `elim_dead decls,
|
||||
let decls := decls.map Decl.simpCase,
|
||||
logDecls `simp_case decls,
|
||||
let decls := decls.map Decl.normalizeIds,
|
||||
decls ← inferBorrow decls,
|
||||
logDecls `borrow decls,
|
||||
decls ← explicitBoxing decls,
|
||||
logDecls `boxing decls,
|
||||
decls ← explicitRC decls,
|
||||
logDecls `rc decls,
|
||||
let decls := decls.map Decl.expandResetReuse,
|
||||
logDecls `expand_reset_reuse decls,
|
||||
let decls := decls.map Decl.pushProj,
|
||||
logDecls `push_proj decls,
|
||||
logDecls `result decls,
|
||||
checkDecls decls,
|
||||
addDecls decls,
|
||||
logDecls `init decls;
|
||||
checkDecls decls;
|
||||
let decls := decls.map Decl.pushProj;
|
||||
logDecls `push_proj decls;
|
||||
let decls := decls.map Decl.insertResetReuse;
|
||||
logDecls `reset_reuse decls;
|
||||
let decls := decls.map Decl.elimDead;
|
||||
logDecls `elim_dead decls;
|
||||
let decls := decls.map Decl.simpCase;
|
||||
logDecls `simp_case decls;
|
||||
let decls := decls.map Decl.normalizeIds;
|
||||
decls ← inferBorrow decls;
|
||||
logDecls `borrow decls;
|
||||
decls ← explicitBoxing decls;
|
||||
logDecls `boxing decls;
|
||||
decls ← explicitRC decls;
|
||||
logDecls `rc decls;
|
||||
let decls := decls.map Decl.expandResetReuse;
|
||||
logDecls `expand_reset_reuse decls;
|
||||
let decls := decls.map Decl.pushProj;
|
||||
logDecls `push_proj decls;
|
||||
logDecls `result decls;
|
||||
checkDecls decls;
|
||||
addDecls decls;
|
||||
pure ()
|
||||
|
||||
@[export lean.ir.compile_core]
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ abbrev M := ReaderT Context (EState String String)
|
|||
def getEnv : M Environment := Context.env <$> read
|
||||
def getModName : M Name := Context.modName <$> read
|
||||
def getDecl (n : Name) : M Decl :=
|
||||
do env ← getEnv,
|
||||
do env ← getEnv;
|
||||
match findEnvDecl env n with
|
||||
| some d := pure d
|
||||
| none := throw ("unknown declaration '" ++ toString n ++ "'")
|
||||
|
|
@ -77,7 +77,7 @@ def openNamespaces (n : Name) : M Unit :=
|
|||
openNamespacesAux n.getPrefix
|
||||
|
||||
def openNamespacesFor (n : Name) : M Unit :=
|
||||
do env ← getEnv,
|
||||
do env ← getEnv;
|
||||
match getExportNameFor env n with
|
||||
| none := pure ()
|
||||
| some n := openNamespaces n
|
||||
|
|
@ -91,7 +91,7 @@ def closeNamespaces (n : Name) : M Unit :=
|
|||
closeNamespacesAux n.getPrefix
|
||||
|
||||
def closeNamespacesFor (n : Name) : M Unit :=
|
||||
do env ← getEnv,
|
||||
do env ← getEnv;
|
||||
match getExportNameFor env n with
|
||||
| none := pure ()
|
||||
| some n := closeNamespaces n
|
||||
|
|
@ -100,14 +100,14 @@ def throwInvalidExportName {α : Type} (n : Name) : M α :=
|
|||
throw ("invalid export name '" ++ toString n ++ "'")
|
||||
|
||||
def toBaseCppName (n : Name) : M String :=
|
||||
do env ← getEnv,
|
||||
do env ← getEnv;
|
||||
match getExportNameFor env n with
|
||||
| some (Name.mkString _ s) := pure s
|
||||
| some _ := throwInvalidExportName n
|
||||
| none := if n == `main then pure leanMainFn else pure n.mangle
|
||||
|
||||
def toCppName (n : Name) : M String :=
|
||||
do env ← getEnv,
|
||||
do env ← getEnv;
|
||||
match getExportNameFor env n with
|
||||
| some s := pure (s.toStringWithSep "::")
|
||||
| none := if n == `main then pure leanMainFn else pure n.mangle
|
||||
|
|
@ -116,7 +116,7 @@ def emitCppName (n : Name) : M Unit :=
|
|||
toCppName n >>= emit
|
||||
|
||||
def toCppInitName (n : Name) : M String :=
|
||||
do env ← getEnv,
|
||||
do env ← getEnv;
|
||||
match getExportNameFor env n with
|
||||
| some (Name.mkString p s) := pure $ (Name.mkString p ("_init_" ++ s)).toStringWithSep "::"
|
||||
| some _ := throwInvalidExportName n
|
||||
|
|
@ -127,24 +127,24 @@ toCppInitName n >>= emit
|
|||
|
||||
def emitFnDeclAux (decl : Decl) (cppBaseName : String) (addExternForConsts : Bool) : M Unit :=
|
||||
do
|
||||
let ps := decl.params,
|
||||
when (ps.isEmpty && addExternForConsts) (emit "extern "),
|
||||
emit (toCppType decl.resultType ++ " " ++ cppBaseName),
|
||||
let ps := decl.params;
|
||||
when (ps.isEmpty && addExternForConsts) (emit "extern ");
|
||||
emit (toCppType decl.resultType ++ " " ++ cppBaseName);
|
||||
unless (ps.isEmpty) $ do {
|
||||
emit "(",
|
||||
emit "(";
|
||||
ps.size.mfor $ λ i, do {
|
||||
when (i > 0) (emit ", "),
|
||||
when (i > 0) (emit ", ");
|
||||
emit (toCppType (ps.get i).ty)
|
||||
},
|
||||
};
|
||||
emit ")"
|
||||
},
|
||||
};
|
||||
emitLn ";"
|
||||
|
||||
def emitFnDecl (decl : Decl) (addExternForConsts : Bool) : M Unit :=
|
||||
do
|
||||
openNamespacesFor decl.name,
|
||||
cppBaseName ← toBaseCppName decl.name,
|
||||
emitFnDeclAux decl cppBaseName addExternForConsts,
|
||||
openNamespacesFor decl.name;
|
||||
cppBaseName ← toBaseCppName decl.name;
|
||||
emitFnDeclAux decl cppBaseName addExternForConsts;
|
||||
closeNamespacesFor decl.name
|
||||
|
||||
def cppQualifiedNameToName (s : String) : Name :=
|
||||
|
|
@ -152,48 +152,48 @@ def cppQualifiedNameToName (s : String) : Name :=
|
|||
|
||||
def emitExternDeclAux (decl : Decl) (cppName : String) : M Unit :=
|
||||
do
|
||||
let qCppName := cppQualifiedNameToName cppName,
|
||||
openNamespaces qCppName,
|
||||
env ← getEnv,
|
||||
let extC := isExternC env decl.name,
|
||||
when extC (emit "extern \"C\" "),
|
||||
(Name.mkString _ qCppBaseName) ← pure qCppName | throw "invalid name",
|
||||
emitFnDeclAux decl qCppBaseName (!extC),
|
||||
let qCppName := cppQualifiedNameToName cppName;
|
||||
openNamespaces qCppName;
|
||||
env ← getEnv;
|
||||
let extC := isExternC env decl.name;
|
||||
when extC (emit "extern \"C\" ");
|
||||
(Name.mkString _ qCppBaseName) ← pure qCppName | throw "invalid name";
|
||||
emitFnDeclAux decl qCppBaseName (!extC);
|
||||
closeNamespaces qCppName
|
||||
|
||||
def emitFnDecls : M Unit :=
|
||||
do
|
||||
env ← getEnv,
|
||||
let decls := getDecls env,
|
||||
let modDecls : NameSet := decls.foldl (λ s d, s.insert d.name) {},
|
||||
let usedDecls : NameSet := decls.foldl (λ s d, collectUsedDecls env d (s.insert d.name)) {},
|
||||
let usedDecls := usedDecls.toList,
|
||||
env ← getEnv;
|
||||
let decls := getDecls env;
|
||||
let modDecls : NameSet := decls.foldl (λ s d, s.insert d.name) {};
|
||||
let usedDecls : NameSet := decls.foldl (λ s d, collectUsedDecls env d (s.insert d.name)) {};
|
||||
let usedDecls := usedDecls.toList;
|
||||
usedDecls.mfor $ λ n, do
|
||||
decl ← getDecl n,
|
||||
decl ← getDecl n;
|
||||
match getExternNameFor env `cpp decl.name with
|
||||
| some cppName := emitExternDeclAux decl cppName
|
||||
| none := emitFnDecl decl (!modDecls.contains n)
|
||||
|
||||
def emitMainFn : M Unit :=
|
||||
do
|
||||
d ← getDecl `main,
|
||||
d ← getDecl `main;
|
||||
match d with
|
||||
| Decl.fdecl f xs t b := do
|
||||
unless (xs.size == 2 || xs.size == 1) (throw "invalid main function, incorrect arity when generating code"),
|
||||
env ← getEnv,
|
||||
let usesLeanAPI := usesLeanNamespace env d,
|
||||
when usesLeanAPI (emitLn "namespace lean { void initialize(); }"),
|
||||
emitLn "int main(int argc, char ** argv) {",
|
||||
unless (xs.size == 2 || xs.size == 1) (throw "invalid main function, incorrect arity when generating code");
|
||||
env ← getEnv;
|
||||
let usesLeanAPI := usesLeanNamespace env d;
|
||||
when usesLeanAPI (emitLn "namespace lean { void initialize(); }");
|
||||
emitLn "int main(int argc, char ** argv) {";
|
||||
if usesLeanAPI then
|
||||
emitLn "lean::initialize();"
|
||||
else
|
||||
emitLn "lean::initialize_runtime_module();",
|
||||
emitLn "obj * w = lean::io_mk_world();",
|
||||
modName ← getModName,
|
||||
emitLn ("w = initialize_" ++ (modName.mangle "") ++ "(w);"),
|
||||
emitLn "lean::initialize_runtime_module();";
|
||||
emitLn "obj * w = lean::io_mk_world();";
|
||||
modName ← getModName;
|
||||
emitLn ("w = initialize_" ++ (modName.mangle "") ++ "(w);");
|
||||
emitLns ["lean::io_mark_end_initialization();",
|
||||
"if (io_result_is_ok(w)) {",
|
||||
"lean::scoped_task_manager tmanager(lean::hardware_concurrency());"],
|
||||
"lean::scoped_task_manager tmanager(lean::hardware_concurrency());"];
|
||||
if xs.size == 2 then do {
|
||||
emitLns ["obj* in = lean::box(0);",
|
||||
"int i = argc;",
|
||||
|
|
@ -201,12 +201,12 @@ match d with
|
|||
" i--;",
|
||||
" obj* n = lean::alloc_cnstr(1,2,0); lean::cnstr_set(n, 0, lean::mk_string(argv[i])); lean::cnstr_set(n, 1, in);",
|
||||
" in = n;",
|
||||
"}"],
|
||||
"}"];
|
||||
emitLn ("w = " ++ leanMainFn ++ "(in, w);")
|
||||
} else do {
|
||||
emitLn ("w = " ++ leanMainFn ++ "(w);")
|
||||
},
|
||||
emitLn "}",
|
||||
};
|
||||
emitLn "}";
|
||||
emitLns ["if (io_result_is_ok(w)) {",
|
||||
" int ret = lean::unbox(io_result_get_value(w));",
|
||||
" lean::dec_ref(w);",
|
||||
|
|
@ -215,13 +215,13 @@ match d with
|
|||
" lean::io_result_show_error(w);",
|
||||
" lean::dec_ref(w);",
|
||||
" return 1;",
|
||||
"}"],
|
||||
"}"];
|
||||
emitLn "}"
|
||||
| other := throw "function declaration expected"
|
||||
|
||||
def hasMainFn : M Bool :=
|
||||
do env ← getEnv,
|
||||
let decls := getDecls env,
|
||||
do env ← getEnv;
|
||||
let decls := getDecls env;
|
||||
pure $ decls.any (λ d, d.name == `main)
|
||||
|
||||
def emitMainFnIfNeeded : M Unit :=
|
||||
|
|
@ -229,16 +229,16 @@ mwhen hasMainFn emitMainFn
|
|||
|
||||
def emitFileHeader : M Unit :=
|
||||
do
|
||||
env ← getEnv,
|
||||
modName ← getModName,
|
||||
emitLn "// Lean compiler output",
|
||||
emitLn ("// Module: " ++ toString modName),
|
||||
emit "// Imports:",
|
||||
env.imports.mfor $ λ m, emit (" " ++ toString m),
|
||||
emitLn "",
|
||||
emitLn "#include \"runtime/object.h\"",
|
||||
emitLn "#include \"runtime/apply.h\"",
|
||||
mwhen hasMainFn $ emitLn "#include \"runtime/init_module.h\"",
|
||||
env ← getEnv;
|
||||
modName ← getModName;
|
||||
emitLn "// Lean compiler output";
|
||||
emitLn ("// Module: " ++ toString modName);
|
||||
emit "// Imports:";
|
||||
env.imports.mfor $ λ m, emit (" " ++ toString m);
|
||||
emitLn "";
|
||||
emitLn "#include \"runtime/object.h\"";
|
||||
emitLn "#include \"runtime/apply.h\"";
|
||||
mwhen hasMainFn $ emitLn "#include \"runtime/init_module.h\"";
|
||||
emitLns [
|
||||
"typedef lean::object obj; typedef lean::usize usize;",
|
||||
"typedef lean::uint8 uint8; typedef lean::uint16 uint16;",
|
||||
|
|
@ -256,26 +256,26 @@ def throwUnknownVar {α : Type} (x : VarId) : M α :=
|
|||
throw ("unknown variable '" ++ toString x ++ "'")
|
||||
|
||||
def isObj (x : VarId) : M Bool :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
match ctx.varMap.find x with
|
||||
| some t := pure t.isObj
|
||||
| none := throwUnknownVar x
|
||||
|
||||
def getJPParams (j : JoinPointId) : M (Array Param) :=
|
||||
do ctx ← read,
|
||||
do ctx ← read;
|
||||
match ctx.jpMap.find j with
|
||||
| some ps := pure ps
|
||||
| none := throw "unknown join point"
|
||||
|
||||
def declareVar (x : VarId) (t : IRType) : M Unit :=
|
||||
do emit (toCppType t), emit " ", emit x, emit "; "
|
||||
do emit (toCppType t); emit " "; emit x; emit "; "
|
||||
|
||||
def declareParams (ps : Array Param) : M Unit :=
|
||||
ps.mfor $ λ p, declareVar p.x p.ty
|
||||
|
||||
partial def declareVars : FnBody → Bool → M Bool
|
||||
| e@(FnBody.vdecl x t _ b) d := do
|
||||
ctx ← read,
|
||||
ctx ← read;
|
||||
if isTailCallTo ctx.mainFn e then
|
||||
pure d
|
||||
else
|
||||
|
|
@ -285,9 +285,9 @@ partial def declareVars : FnBody → Bool → M Bool
|
|||
|
||||
def emitTag (x : VarId) : M Unit :=
|
||||
do
|
||||
xIsObj ← isObj x,
|
||||
xIsObj ← isObj x;
|
||||
if xIsObj then do
|
||||
emit "lean::obj_tag(", emit x, emit ")"
|
||||
emit "lean::obj_tag("; emit x; emit ")"
|
||||
else
|
||||
emit x
|
||||
|
||||
|
|
@ -299,75 +299,75 @@ else match alts.get 0 with
|
|||
|
||||
def emitIf (emitBody : FnBody → M Unit) (x : VarId) (tag : Nat) (t : FnBody) (e : FnBody) : M Unit :=
|
||||
do
|
||||
emit "if (", emitTag x, emit " == ", emit tag, emitLn ")",
|
||||
emitBody t,
|
||||
emitLn "else",
|
||||
emit "if ("; emitTag x; emit " == "; emit tag; emitLn ")";
|
||||
emitBody t;
|
||||
emitLn "else";
|
||||
emitBody e
|
||||
|
||||
def emitCase (emitBody : FnBody → M Unit) (x : VarId) (alts : Array Alt) : M Unit :=
|
||||
match isIf alts with
|
||||
| some (tag, t, e) := emitIf emitBody x tag t e
|
||||
| _ := do
|
||||
emit "switch (", emitTag x, emitLn ") {",
|
||||
let alts := ensureHasDefault alts,
|
||||
emit "switch ("; emitTag x; emitLn ") {";
|
||||
let alts := ensureHasDefault alts;
|
||||
alts.mfor $ λ alt, match alt with
|
||||
| Alt.ctor c b := emit "case " *> emit c.cidx *> emitLn ":" *> emitBody b
|
||||
| Alt.default b := emitLn "default: " *> emitBody b,
|
||||
| Alt.default b := emitLn "default: " *> emitBody b;
|
||||
emitLn "}"
|
||||
|
||||
def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit :=
|
||||
do
|
||||
emit (if checkRef then "lean::inc" else "lean::inc_ref"),
|
||||
emit "(" *> emit x,
|
||||
when (n != 1) (emit ", " *> emit n),
|
||||
emit (if checkRef then "lean::inc" else "lean::inc_ref");
|
||||
emit "(" *> emit x;
|
||||
when (n != 1) (emit ", " *> emit n);
|
||||
emitLn ");"
|
||||
|
||||
def emitDec (x : VarId) (n : Nat) (checkRef : Bool) : M Unit :=
|
||||
do
|
||||
emit (if checkRef then "lean::dec" else "lean::dec_ref"),
|
||||
emit "(" *> emit x,
|
||||
when (n != 1) (emit ", " *> emit n),
|
||||
emit (if checkRef then "lean::dec" else "lean::dec_ref");
|
||||
emit "("; emit x;
|
||||
when (n != 1) (do emit ", "; emit n);
|
||||
emitLn ");"
|
||||
|
||||
def emitDel (x : VarId) : M Unit :=
|
||||
do emit "lean::free_heap_obj(", emit x, emitLn ");"
|
||||
do emit "lean::free_heap_obj("; emit x; emitLn ");"
|
||||
|
||||
def emitSetTag (x : VarId) (i : Nat) : M Unit :=
|
||||
do emit "lean::cnstr_set_tag(", emit x, emit ", ", emit i, emitLn ");"
|
||||
do emit "lean::cnstr_set_tag("; emit x; emit ", "; emit i; emitLn ");"
|
||||
|
||||
def emitSet (x : VarId) (i : Nat) (y : Arg) : M Unit :=
|
||||
do emit "lean::cnstr_set(", emit x, emit ", ", emit i, emit ", ", emitArg y, emitLn ");"
|
||||
do emit "lean::cnstr_set("; emit x; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
|
||||
|
||||
def emitOffset (n : Nat) (offset : Nat) : M Unit :=
|
||||
if n > 0 then do
|
||||
emit "sizeof(void*)*", emit n,
|
||||
emit "sizeof(void*)*"; emit n;
|
||||
when (offset > 0) (emit " + " *> emit offset)
|
||||
else
|
||||
emit offset
|
||||
|
||||
def emitUSet (x : VarId) (n : Nat) (y : VarId) : M Unit :=
|
||||
do emit "lean::cnstr_set_scalar(", emit x, emit ", ", emitOffset n 0, emit ", ", emit y, emitLn ");"
|
||||
do emit "lean::cnstr_set_scalar("; emit x; emit ", "; emitOffset n 0; emit ", "; emit y; emitLn ");"
|
||||
|
||||
def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) : M Unit :=
|
||||
do emit "lean::cnstr_set_scalar(", emit x, emit ", ", emitOffset n offset, emit ", ", emit y, emitLn ");"
|
||||
do emit "lean::cnstr_set_scalar("; emit x; emit ", "; emitOffset n offset; emit ", "; emit y; emitLn ");"
|
||||
|
||||
def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit :=
|
||||
do
|
||||
ps ← getJPParams j,
|
||||
unless (xs.size == ps.size) (throw "invalid goto"),
|
||||
ps ← getJPParams j;
|
||||
unless (xs.size == ps.size) (throw "invalid goto");
|
||||
xs.size.mfor $ λ i, do {
|
||||
let p := ps.get i,
|
||||
let x := xs.get i,
|
||||
emit p.x, emit " = ", emitArg x, emitLn ";"
|
||||
},
|
||||
emit "goto ", emit j, emitLn ";"
|
||||
let p := ps.get i;
|
||||
let x := xs.get i;
|
||||
emit p.x; emit " = "; emitArg x; emitLn ";"
|
||||
};
|
||||
emit "goto "; emit j; emitLn ";"
|
||||
|
||||
def emitLhs (z : VarId) : M Unit :=
|
||||
do emit z, emit " = "
|
||||
do emit z; emit " = "
|
||||
|
||||
def emitArgs (ys : Array Arg) : M Unit :=
|
||||
ys.size.mfor $ λ i, do
|
||||
when (i > 0) (emit ", "),
|
||||
when (i > 0) (emit ", ");
|
||||
emitArg (ys.get i)
|
||||
|
||||
def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit :=
|
||||
|
|
@ -377,82 +377,82 @@ else emit "sizeof(size_t)*" *> emit usize *> emit " + " *> emit ssize
|
|||
|
||||
def emitAllocCtor (c : CtorInfo) : M Unit :=
|
||||
do
|
||||
emit "lean::alloc_cnstr(", emit c.cidx, emit ", ", emit c.size, emit ", ",
|
||||
emitCtorScalarSize c.usize c.ssize, emitLn ");"
|
||||
emit "lean::alloc_cnstr("; emit c.cidx; emit ", "; emit c.size; emit ", ";
|
||||
emitCtorScalarSize c.usize c.ssize; emitLn ");"
|
||||
|
||||
def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit :=
|
||||
ys.size.mfor $ λ i, do
|
||||
emit "lean::cnstr_set(", emit z, emit ", ", emit i, emit ", ", emitArg (ys.get i), emitLn ");"
|
||||
emit "lean::cnstr_set("; emit z; emit ", "; emit i; emit ", "; emitArg (ys.get i); emitLn ");"
|
||||
|
||||
def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit :=
|
||||
do
|
||||
emitLhs z,
|
||||
emitLhs z;
|
||||
if c.size == 0 && c.usize == 0 && c.ssize == 0 then do
|
||||
emit "lean::box(", emit c.cidx, emitLn ");"
|
||||
emit "lean::box("; emit c.cidx; emitLn ");"
|
||||
else do
|
||||
emitAllocCtor c, emitCtorSetArgs z ys
|
||||
emitAllocCtor c; emitCtorSetArgs z ys
|
||||
|
||||
def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit :=
|
||||
do
|
||||
emit "if (lean::is_exclusive(", emit x, emitLn ")) {",
|
||||
emit "if (lean::is_exclusive("; emit x; emitLn ")) {";
|
||||
n.mfor $ λ i, do {
|
||||
emit " lean::cnstr_release(", emit x, emit ", ", emit i, emitLn ");"
|
||||
},
|
||||
emit " ", emitLhs z, emit x, emitLn ";",
|
||||
emitLn "} else {",
|
||||
emit " lean::dec_ref(", emit x, emitLn ");",
|
||||
emit " ", emitLhs z, emitLn "lean::box(0);",
|
||||
emit " lean::cnstr_release("; emit x; emit ", "; emit i; emitLn ");"
|
||||
};
|
||||
emit " "; emitLhs z; emit x; emitLn ";";
|
||||
emitLn "} else {";
|
||||
emit " lean::dec_ref("; emit x; emitLn ");";
|
||||
emit " "; emitLhs z; emitLn "lean::box(0);";
|
||||
emitLn "}"
|
||||
|
||||
def emitReuse (z : VarId) (x : VarId) (c : CtorInfo) (updtHeader : Bool) (ys : Array Arg) : M Unit :=
|
||||
do
|
||||
emit "if (lean::is_scalar(", emit x, emitLn ")) {",
|
||||
emit " ", emitLhs z, emitAllocCtor c,
|
||||
emitLn "} else {",
|
||||
emit " ", emitLhs z, emit x, emitLn ";",
|
||||
when updtHeader (do emit " lean::cnstr_set_tag(", emit z, emit ", ", emit c.cidx, emitLn ");"),
|
||||
emitLn "}",
|
||||
emit "if (lean::is_scalar("; emit x; emitLn ")) {";
|
||||
emit " "; emitLhs z; emitAllocCtor c;
|
||||
emitLn "} else {";
|
||||
emit " "; emitLhs z; emit x; emitLn ";";
|
||||
when updtHeader (do emit " lean::cnstr_set_tag("; emit z; emit ", "; emit c.cidx; emitLn ");");
|
||||
emitLn "}";
|
||||
emitCtorSetArgs z ys
|
||||
|
||||
def emitProj (z : VarId) (i : Nat) (x : VarId) : M Unit :=
|
||||
do emitLhs z, emit "lean::cnstr_get(", emit x, emit ", ", emit i, emitLn ");"
|
||||
do emitLhs z; emit "lean::cnstr_get("; emit x; emit ", "; emit i; emitLn ");"
|
||||
|
||||
def emitUProj (z : VarId) (i : Nat) (x : VarId) : M Unit :=
|
||||
do emitLhs z, emit "lean::cnstr_get_scalar<usize>(", emit x, emit ", sizeof(void*)*", emit i, emitLn ");"
|
||||
do emitLhs z; emit "lean::cnstr_get_scalar<usize>("; emit x; emit ", sizeof(void*)*"; emit i; emitLn ");"
|
||||
|
||||
def emitSProj (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M Unit :=
|
||||
do emitLhs z, emit "lean::cnstr_get_scalar<", emit (toCppType t), emit ">(", emit x, emit ", ", emitOffset n offset, emitLn ");"
|
||||
do emitLhs z; emit "lean::cnstr_get_scalar<"; emit (toCppType t); emit ">("; emit x; emit ", "; emitOffset n offset; emitLn ");"
|
||||
|
||||
def toStringArgs (ys : Array Arg) : List String :=
|
||||
ys.toList.map argToCppString
|
||||
|
||||
def emitFullApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit :=
|
||||
do
|
||||
emitLhs z,
|
||||
decl ← getDecl f,
|
||||
emitLhs z;
|
||||
decl ← getDecl f;
|
||||
match decl with
|
||||
| Decl.extern _ _ _ extData :=
|
||||
match mkExternCall extData `cpp (toStringArgs ys) with
|
||||
| some c := emit c *> emitLn ";"
|
||||
| none := throw "failed to emit extern application"
|
||||
| _ := do emitCppName f, when (ys.size > 0) (do emit "(", emitArgs ys, emit ")"), emitLn ";"
|
||||
| _ := do emitCppName f; when (ys.size > 0) (do emit "("; emitArgs ys; emit ")"); emitLn ";"
|
||||
|
||||
def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit :=
|
||||
do
|
||||
decl ← getDecl f,
|
||||
let arity := decl.params.size,
|
||||
emitLhs z, emit "lean::alloc_closure(reinterpret_cast<void*>(", emitCppName f, emit "), ", emit arity, emit ", ", emit ys.size, emitLn ");",
|
||||
decl ← getDecl f;
|
||||
let arity := decl.params.size;
|
||||
emitLhs z; emit "lean::alloc_closure(reinterpret_cast<void*>("; emitCppName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");";
|
||||
ys.size.mfor $ λ i, do {
|
||||
let y := ys.get i,
|
||||
emit "lean::closure_set(", emit z, emit ", ", emit i, emit ", ", emitArg y, emitLn ");"
|
||||
let y := ys.get i;
|
||||
emit "lean::closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
|
||||
}
|
||||
|
||||
def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=
|
||||
if ys.size > closureMaxArgs then do
|
||||
emit "{ obj* _aargs[] = {", emitArgs ys, emitLn "};",
|
||||
emitLhs z, emit "lean::apply_m(", emit f, emit ", ", emit ys.size, emitLn ", _aargs); }"
|
||||
emit "{ obj* _aargs[] = {"; emitArgs ys; emitLn "};";
|
||||
emitLhs z; emit "lean::apply_m("; emit f; emit ", "; emit ys.size; emitLn ", _aargs); }"
|
||||
else do
|
||||
emitLhs z, emit "lean::apply_", emit ys.size, emit "(", emit f, emit ", ", emitArgs ys, emitLn ");"
|
||||
emitLhs z; emit "lean::apply_"; emit ys.size; emit "("; emit f; emit ", "; emitArgs ys; emitLn ");"
|
||||
|
||||
def emitBoxFn (xType : IRType) : M Unit :=
|
||||
match xType with
|
||||
|
|
@ -463,24 +463,24 @@ match xType with
|
|||
| other := emit "lean::box"
|
||||
|
||||
def emitBox (z : VarId) (x : VarId) (xType : IRType) : M Unit :=
|
||||
do emitLhs z, emitBoxFn xType, emit "(", emit x, emitLn ");"
|
||||
do emitLhs z; emitBoxFn xType; emit "("; emit x; emitLn ");"
|
||||
|
||||
def emitUnbox (z : VarId) (t : IRType) (x : VarId) : M Unit :=
|
||||
do
|
||||
emitLhs z,
|
||||
emitLhs z;
|
||||
match t with
|
||||
| IRType.usize := emit "lean::unbox_size_t"
|
||||
| IRType.uint32 := emit "lean::unbox_uint32"
|
||||
| IRType.uint64 := emit "lean::unbox_uint64"
|
||||
| IRType.float := throw "floats are not supported yet"
|
||||
| other := emit "lean::unbox",
|
||||
emit "(", emit x, emitLn ");"
|
||||
| other := emit "lean::unbox";
|
||||
emit "("; emit x; emitLn ");"
|
||||
|
||||
def emitIsShared (z : VarId) (x : VarId) : M Unit :=
|
||||
do emitLhs z, emit "!lean::is_exclusive(", emit x, emitLn ");"
|
||||
do emitLhs z; emit "!lean::is_exclusive("; emit x; emitLn ");"
|
||||
|
||||
def emitIsTaggedPtr (z : VarId) (x : VarId) : M Unit :=
|
||||
do emitLhs z, emit "!lean::is_scalar(", emit x, emitLn ");"
|
||||
do emitLhs z; emit "!lean::is_scalar("; emit x; emitLn ");"
|
||||
|
||||
def toHexDigit (c : Nat) : String :=
|
||||
String.singleton c.digitChar
|
||||
|
|
@ -501,9 +501,9 @@ q ++ "\""
|
|||
|
||||
def emitNumLit (t : IRType) (v : Nat) : M Unit :=
|
||||
if t.isObj then do
|
||||
emit "lean::mk_nat_obj(",
|
||||
emit "lean::mk_nat_obj(";
|
||||
if v < uint32Sz then emit v *> emit "u"
|
||||
else emit "lean::mpz(\"" *> emit v *> emit "\")",
|
||||
else emit "lean::mpz(\"" *> emit v *> emit "\")";
|
||||
emit ")"
|
||||
else
|
||||
emit v
|
||||
|
|
@ -512,7 +512,7 @@ def emitLit (z : VarId) (t : IRType) (v : LitVal) : M Unit :=
|
|||
emitLhs z *>
|
||||
match v with
|
||||
| LitVal.num v := emitNumLit t v *> emitLn ";"
|
||||
| LitVal.str v := do emit "lean::mk_string(", emit (quoteString v), emitLn ");"
|
||||
| LitVal.str v := do emit "lean::mk_string("; emit (quoteString v); emitLn ");"
|
||||
|
||||
def emitVDecl (z : VarId) (t : IRType) (v : Expr) : M Unit :=
|
||||
match v with
|
||||
|
|
@ -533,7 +533,7 @@ match v with
|
|||
|
||||
def isTailCall (x : VarId) (v : Expr) (b : FnBody) : M Bool :=
|
||||
do
|
||||
ctx ← read,
|
||||
ctx ← read;
|
||||
match v, b with
|
||||
| Expr.fap f _, FnBody.ret (Arg.var y) := pure $ f == ctx.mainFn && x == y
|
||||
| _, _ := pure false
|
||||
|
|
@ -567,35 +567,38 @@ n.any $ λ i,
|
|||
def emitTailCall (v : Expr) : M Unit :=
|
||||
match v with
|
||||
| Expr.fap _ ys := do
|
||||
ctx ← read,
|
||||
let ps := ctx.mainParams,
|
||||
unless (ps.size == ys.size) (throw "invalid tail call"),
|
||||
ctx ← read;
|
||||
let ps := ctx.mainParams;
|
||||
unless (ps.size == ys.size) (throw "invalid tail call");
|
||||
if overwriteParam ps ys then do {
|
||||
emitLn "{",
|
||||
emitLn "{";
|
||||
ps.size.mfor $ λ i, do {
|
||||
let p := ps.get i, let y := ys.get i,
|
||||
let p := ps.get i;
|
||||
let y := ys.get i;
|
||||
unless (paramEqArg p y) $ do {
|
||||
emit (toCppType p.ty), emit " _tmp_", emit i, emit " = ", emitArg y, emitLn ";"
|
||||
emit (toCppType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
|
||||
}
|
||||
},
|
||||
};
|
||||
ps.size.mfor $ λ i, do {
|
||||
let p := ps.get i, let y := ys.get i,
|
||||
unless (paramEqArg p y) (do emit p.x, emit " = _tmp_", emit i, emitLn ";")
|
||||
},
|
||||
let p := ps.get i;
|
||||
let y := ys.get i;
|
||||
unless (paramEqArg p y) (do emit p.x; emit " = _tmp_"; emit i; emitLn ";")
|
||||
};
|
||||
emitLn "}"
|
||||
} else do {
|
||||
ys.size.mfor $ λ i, do {
|
||||
let p := ps.get i, let y := ys.get i,
|
||||
unless (paramEqArg p y) (do emit p.x, emit " = ", emitArg y, emitLn ";")
|
||||
let p := ps.get i;
|
||||
let y := ys.get i;
|
||||
unless (paramEqArg p y) (do emit p.x; emit " = "; emitArg y; emitLn ";")
|
||||
}
|
||||
},
|
||||
};
|
||||
emitLn "goto _start;"
|
||||
| _ := throw "bug at emitTailCall"
|
||||
|
||||
partial def emitBlock (emitBody : FnBody → M Unit) : FnBody → M Unit
|
||||
| (FnBody.jdecl j xs v b) := emitBlock b
|
||||
| d@(FnBody.vdecl x t v b) :=
|
||||
do ctx ← read, if isTailCallTo ctx.mainFn d then emitTailCall v else emitVDecl x t v *> emitBlock b
|
||||
do ctx ← read; if isTailCallTo ctx.mainFn d then emitTailCall v else emitVDecl x t v *> emitBlock b
|
||||
| (FnBody.inc x n c b) := emitInc x n c *> emitBlock b
|
||||
| (FnBody.dec x n c b) := emitDec x n c *> emitBlock b
|
||||
| (FnBody.del x b) := emitDel x *> emitBlock b
|
||||
|
|
@ -610,45 +613,45 @@ partial def emitBlock (emitBody : FnBody → M Unit) : FnBody → M Unit
|
|||
| FnBody.unreachable := emitLn "lean_unreachable();"
|
||||
|
||||
partial def emitJPs (emitBody : FnBody → M Unit) : FnBody → M Unit
|
||||
| (FnBody.jdecl j xs v b) := do emit j, emitLn ":", emitBody v, emitJPs b
|
||||
| (FnBody.jdecl j xs v b) := do emit j; emitLn ":"; emitBody v; emitJPs b
|
||||
| e := unless e.isTerminal (emitJPs e.body)
|
||||
|
||||
partial def emitFnBody : FnBody → M Unit
|
||||
| b := do
|
||||
emitLn "{",
|
||||
declared ← declareVars b false,
|
||||
when declared (emitLn ""),
|
||||
emitBlock emitFnBody b,
|
||||
emitJPs emitFnBody b,
|
||||
emitLn "{";
|
||||
declared ← declareVars b false;
|
||||
when declared (emitLn "");
|
||||
emitBlock emitFnBody b;
|
||||
emitJPs emitFnBody b;
|
||||
emitLn "}"
|
||||
|
||||
def emitDeclAux (d : Decl) : M Unit :=
|
||||
do
|
||||
env ← getEnv,
|
||||
let (vMap, jpMap) := mkVarJPMaps d,
|
||||
env ← getEnv;
|
||||
let (vMap, jpMap) := mkVarJPMaps d;
|
||||
adaptReader (λ ctx : Context, { varMap := vMap, jpMap := jpMap, .. ctx }) $ do
|
||||
unless (hasInitAttr env d.name) $
|
||||
match d with
|
||||
| Decl.fdecl f xs t b := do
|
||||
openNamespacesFor f,
|
||||
baseName ← toBaseCppName f,
|
||||
emit (toCppType t), emit " ",
|
||||
openNamespacesFor f;
|
||||
baseName ← toBaseCppName f;
|
||||
emit (toCppType t); emit " ";
|
||||
if xs.size > 0 then do {
|
||||
emit baseName,
|
||||
emit "(",
|
||||
emit baseName;
|
||||
emit "(";
|
||||
xs.size.mfor $ λ i, do {
|
||||
when (i > 0) (emit ", "),
|
||||
let x := xs.get i,
|
||||
emit (toCppType x.ty), emit " ", emit(x.x)
|
||||
},
|
||||
when (i > 0) (emit ", ");
|
||||
let x := xs.get i;
|
||||
emit (toCppType x.ty); emit " "; emit(x.x)
|
||||
};
|
||||
emit ")"
|
||||
} else do {
|
||||
emit ("_init_" ++ baseName ++ "()")
|
||||
},
|
||||
emitLn " {",
|
||||
emitLn "_start:",
|
||||
adaptReader (λ ctx : Context, { mainFn := f, mainParams := xs, .. ctx }) (emitFnBody b),
|
||||
emitLn "}",
|
||||
};
|
||||
emitLn " {";
|
||||
emitLn "_start:";
|
||||
adaptReader (λ ctx : Context, { mainFn := f, mainParams := xs, .. ctx }) (emitFnBody b);
|
||||
emitLn "}";
|
||||
closeNamespacesFor f
|
||||
| _ := pure ()
|
||||
|
||||
|
|
@ -660,8 +663,8 @@ catch
|
|||
|
||||
def emitFns : M Unit :=
|
||||
do
|
||||
env ← getEnv,
|
||||
let decls := getDecls env,
|
||||
env ← getEnv;
|
||||
let decls := getDecls env;
|
||||
decls.reverse.mfor emitDecl
|
||||
|
||||
def quoteNameAux : Name → Option String
|
||||
|
|
@ -677,67 +680,67 @@ else quoteNameAux n
|
|||
|
||||
def emitDeclInit (d : Decl) : M Unit :=
|
||||
do
|
||||
env ← getEnv,
|
||||
let n := d.name,
|
||||
env ← getEnv;
|
||||
let n := d.name;
|
||||
if isIOUnitInitFn env n then do {
|
||||
emit "w = ", emitCppName n, emitLn "(w);",
|
||||
emit "w = "; emitCppName n; emitLn "(w);";
|
||||
emitLn "if (io_result_is_error(w)) return w;"
|
||||
} else if (d.params.size == 0) then do {
|
||||
match getInitFnNameFor env d.name with
|
||||
| some initFn := do {
|
||||
emit "w = ", emitCppName initFn, emitLn "(w);",
|
||||
emitLn "if (io_result_is_error(w)) return w;",
|
||||
emitCppName n, emitLn " = io_result_get_value(w);"
|
||||
emit "w = "; emitCppName initFn; emitLn "(w);";
|
||||
emitLn "if (io_result_is_error(w)) return w;";
|
||||
emitCppName n; emitLn " = io_result_get_value(w);"
|
||||
}
|
||||
| _ := do {
|
||||
emitCppName n, emit " = ", emitCppInitName n, emitLn "();"
|
||||
},
|
||||
emitCppName n; emit " = "; emitCppInitName n; emitLn "();"
|
||||
};
|
||||
if d.resultType.isObj then do {
|
||||
emit "lean::mark_persistent(", emitCppName n, emitLn ");",
|
||||
emit "lean::mark_persistent("; emitCppName n; emitLn ");";
|
||||
match quoteName n with
|
||||
| some q := do emit ("lean::register_constant(" ++ q ++ ", "), emitCppName n, emitLn ");"
|
||||
| some q := do emit ("lean::register_constant(" ++ q ++ ", "); emitCppName n; emitLn ");"
|
||||
| none := pure ()
|
||||
} else unless d.resultType.isIrrelevant $ do {
|
||||
match quoteName n with
|
||||
| some q := do emit ("lean::register_constant(" ++ q ++ ", "), emitBoxFn d.resultType, emit "(", emitCppName n, emitLn "));"
|
||||
| some q := do emit ("lean::register_constant(" ++ q ++ ", "); emitBoxFn d.resultType; emit "("; emitCppName n; emitLn "));"
|
||||
| none := pure ()
|
||||
}
|
||||
} else
|
||||
/- TODO(Leo): perhaps we should add a flag to disable closure registration. -/
|
||||
match quoteName d.name with
|
||||
| some q := do
|
||||
let clsName := if requiresBoxedVersion env d then mkBoxedName d.name else d.name,
|
||||
emit ("REGISTER_LEAN_FUNCTION(" ++ q ++ ", " ++ toString d.params.size ++ ", "), emitCppName clsName, emitLn ");"
|
||||
let clsName := if requiresBoxedVersion env d then mkBoxedName d.name else d.name;
|
||||
emit ("REGISTER_LEAN_FUNCTION(" ++ q ++ ", " ++ toString d.params.size ++ ", "); emitCppName clsName; emitLn ");"
|
||||
| _ := pure ()
|
||||
|
||||
def emitInitFn : M Unit :=
|
||||
do
|
||||
env ← getEnv,
|
||||
modName ← getModName,
|
||||
env.imports.mfor $ λ m, emitLn ("obj* initialize_" ++ m.mangle "" ++ "(obj*);"),
|
||||
env ← getEnv;
|
||||
modName ← getModName;
|
||||
env.imports.mfor $ λ m, emitLn ("obj* initialize_" ++ m.mangle "" ++ "(obj*);");
|
||||
emitLns [
|
||||
"static bool _G_initialized = false;",
|
||||
"obj* initialize_" ++ modName.mangle "" ++ "(obj* w) {",
|
||||
"if (_G_initialized) return w;",
|
||||
"_G_initialized = true;",
|
||||
"if (io_result_is_error(w)) return w;"
|
||||
],
|
||||
];
|
||||
env.imports.mfor $ λ m, emitLns [
|
||||
"w = initialize_" ++ m.mangle "" ++ "(w);",
|
||||
"if (io_result_is_error(w)) return w;"
|
||||
],
|
||||
let decls := getDecls env,
|
||||
decls.reverse.mfor emitDeclInit,
|
||||
];
|
||||
let decls := getDecls env;
|
||||
decls.reverse.mfor emitDeclInit;
|
||||
emitLns [
|
||||
"return w;",
|
||||
"}"]
|
||||
|
||||
def main : M Unit :=
|
||||
do
|
||||
emitFileHeader,
|
||||
emitFnDecls,
|
||||
emitFns,
|
||||
emitInitFn,
|
||||
emitFileHeader;
|
||||
emitFnDecls;
|
||||
emitFns;
|
||||
emitInitFn;
|
||||
emitMainFnIfNeeded
|
||||
|
||||
end EmitCpp
|
||||
|
|
|
|||
|
|
@ -29,12 +29,12 @@ partial def visitFnBody : FnBody → M Bool
|
|||
let checkFn (f : FunId) : M Bool :=
|
||||
if leanNameSpacePrefix.isPrefixOf f then pure true
|
||||
else do {
|
||||
s ← get,
|
||||
s ← get;
|
||||
if s.contains f then
|
||||
visitFnBody b
|
||||
else do
|
||||
modify (λ s, s.insert f),
|
||||
env ← read,
|
||||
modify (λ s, s.insert f);
|
||||
env ← read;
|
||||
match findEnvDecl env f with
|
||||
| some (Decl.fdecl _ _ _ fbody) := visitFnBody fbody <||> visitFnBody b
|
||||
| other := visitFnBody b
|
||||
|
|
@ -74,7 +74,7 @@ partial def collectFnBody : FnBody → M Unit
|
|||
| e := unless e.isTerminal $ collectFnBody e.body
|
||||
|
||||
def collectInitDecl (fn : Name) : M Unit :=
|
||||
do env ← read,
|
||||
do env ← read;
|
||||
match getInitFnNameFor env fn with
|
||||
| some initFn := collect initFn
|
||||
| _ := pure ()
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@ mask.foldl
|
|||
|
||||
abbrev M := ReaderT Context (State Nat)
|
||||
def mkFresh : M VarId :=
|
||||
do idx ← get, modify (+1), pure { idx := idx }
|
||||
do idx ← get; modify (+1); pure { idx := idx }
|
||||
|
||||
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
|
||||
mask.size.mfold
|
||||
|
|
@ -145,7 +145,7 @@ mask.size.mfold
|
|||
match mask.get i with
|
||||
| some _ := pure b -- code took ownership of this field
|
||||
| none := do
|
||||
fld ← mkFresh,
|
||||
fld ← mkFresh;
|
||||
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true b)))
|
||||
b
|
||||
|
||||
|
|
@ -240,22 +240,22 @@ and `z := reuse x ctor_i ws; F` is replaced with
|
|||
-/
|
||||
def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
|
||||
do
|
||||
ctx ← read,
|
||||
let b := reuseToSet ctx x y b,
|
||||
ctx ← read;
|
||||
let b := reuseToSet ctx x y b;
|
||||
releaseUnreadFields y mask b
|
||||
|
||||
-- Expand `bs; x := reset[n] y; b`
|
||||
partial def expand (mainFn : FnBody → Array FnBody → M FnBody)
|
||||
(bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody :=
|
||||
do
|
||||
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b,
|
||||
let (bs, mask) := eraseProjIncFor n y bs,
|
||||
let bSlow := mkSlowPath x y mask b,
|
||||
bFast ← mkFastPath x y mask b,
|
||||
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b;
|
||||
let (bs, mask) := eraseProjIncFor n y bs;
|
||||
let bSlow := mkSlowPath x y mask b;
|
||||
bFast ← mkFastPath x y mask b;
|
||||
/- We only optimize recursively the fast. -/
|
||||
bFast ← mainFn bFast Array.empty,
|
||||
c ← mkFresh,
|
||||
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast),
|
||||
bFast ← mainFn bFast Array.empty;
|
||||
c ← mkFresh;
|
||||
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast);
|
||||
pure $ reshape bs b
|
||||
|
||||
partial def searchAndExpand : FnBody → Array FnBody → M FnBody
|
||||
|
|
@ -265,10 +265,10 @@ partial def searchAndExpand : FnBody → Array FnBody → M FnBody
|
|||
else
|
||||
searchAndExpand b (push bs d)
|
||||
| (FnBody.jdecl j xs v b) bs := do
|
||||
v ← searchAndExpand v Array.empty,
|
||||
v ← searchAndExpand v Array.empty;
|
||||
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
|
||||
| (FnBody.case tid x alts) bs := do
|
||||
alts ← alts.mmap $ λ alt, alt.mmodifyBody $ λ b, searchAndExpand b Array.empty,
|
||||
alts ← alts.mmap $ λ alt, alt.mmodifyBody $ λ b, searchAndExpand b Array.empty;
|
||||
pure $ reshape bs (FnBody.case tid x alts)
|
||||
| b bs :=
|
||||
if b.isTerminal then pure $ reshape bs b
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ partial def visitFnBody (w : Index) : FnBody → M Bool
|
|||
| (FnBody.del x b) := visitVar w x <||> visitFnBody b
|
||||
| (FnBody.mdata _ b) := visitFnBody b
|
||||
| (FnBody.jmp j ys) := visitArgs w ys <||> do {
|
||||
ctx ← get,
|
||||
ctx ← get;
|
||||
match ctx.getJPBody j with
|
||||
| some b :=
|
||||
-- `j` is not a local join point since we assume we cannot shadow join point declarations.
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ namespace UniqueIds
|
|||
abbrev M := StateT IndexSet Id
|
||||
|
||||
def checkId (id : Index) : M Bool :=
|
||||
do found ← get,
|
||||
do found ← get;
|
||||
if found.contains id then pure false
|
||||
else modify (λ s, s.insert id) *> pure true
|
||||
|
||||
|
|
@ -80,39 +80,39 @@ abbrev N := ReaderT IndexRenaming (State Nat)
|
|||
|
||||
@[inline] def withVar {α : Type} (x : VarId) (k : VarId → N α) : N α :=
|
||||
λ m, do
|
||||
n ← getModify (+1),
|
||||
n ← getModify (+1);
|
||||
k { idx := n } (m.insert x.idx n)
|
||||
|
||||
@[inline] def withJP {α : Type} (x : JoinPointId) (k : JoinPointId → N α) : N α :=
|
||||
λ m, do
|
||||
n ← getModify (+1),
|
||||
n ← getModify (+1);
|
||||
k { idx := n } (m.insert x.idx n)
|
||||
|
||||
@[inline] def withParams {α : Type} (ps : Array Param) (k : Array Param → N α) : N α :=
|
||||
λ m, do
|
||||
m ← ps.mfoldl (λ (m : IndexRenaming) p, do n ← getModify (+1), pure $ m.insert p.x.idx n) m,
|
||||
let ps := ps.map $ λ p, { x := normVar p.x m, .. p },
|
||||
m ← ps.mfoldl (λ (m : IndexRenaming) p, do n ← getModify (+1); pure $ m.insert p.x.idx n) m;
|
||||
let ps := ps.map $ λ p, { x := normVar p.x m, .. p };
|
||||
k ps m
|
||||
|
||||
instance MtoN {α} : HasCoe (M α) (N α) :=
|
||||
⟨λ x m, pure $ x m⟩
|
||||
|
||||
partial def normFnBody : FnBody → N FnBody
|
||||
| (FnBody.vdecl x t v b) := do v ← normExpr v, withVar x $ λ x, FnBody.vdecl x t v <$> normFnBody b
|
||||
| (FnBody.vdecl x t v b) := do v ← normExpr v; withVar x $ λ x, FnBody.vdecl x t v <$> normFnBody b
|
||||
| (FnBody.jdecl j ys v b) := do
|
||||
(ys, v) ← withParams ys $ λ ys, do { v ← normFnBody v, pure (ys, v) },
|
||||
(ys, v) ← withParams ys $ λ ys, do { v ← normFnBody v; pure (ys, v) };
|
||||
withJP j $ λ j, FnBody.jdecl j ys v <$> normFnBody b
|
||||
| (FnBody.set x i y b) := do x ← normVar x, y ← normArg y, FnBody.set x i y <$> normFnBody b
|
||||
| (FnBody.uset x i y b) := do x ← normVar x, y ← normVar y, FnBody.uset x i y <$> normFnBody b
|
||||
| (FnBody.sset x i o y t b) := do x ← normVar x, y ← normVar y, FnBody.sset x i o y t <$> normFnBody b
|
||||
| (FnBody.setTag x i b) := do x ← normVar x, FnBody.setTag x i <$> normFnBody b
|
||||
| (FnBody.inc x n c b) := do x ← normVar x, FnBody.inc x n c <$> normFnBody b
|
||||
| (FnBody.dec x n c b) := do x ← normVar x, FnBody.dec x n c <$> normFnBody b
|
||||
| (FnBody.del x b) := do x ← normVar x, FnBody.del x <$> normFnBody b
|
||||
| (FnBody.set x i y b) := do x ← normVar x; y ← normArg y; FnBody.set x i y <$> normFnBody b
|
||||
| (FnBody.uset x i y b) := do x ← normVar x; y ← normVar y; FnBody.uset x i y <$> normFnBody b
|
||||
| (FnBody.sset x i o y t b) := do x ← normVar x; y ← normVar y; FnBody.sset x i o y t <$> normFnBody b
|
||||
| (FnBody.setTag x i b) := do x ← normVar x; FnBody.setTag x i <$> normFnBody b
|
||||
| (FnBody.inc x n c b) := do x ← normVar x; FnBody.inc x n c <$> normFnBody b
|
||||
| (FnBody.dec x n c b) := do x ← normVar x; FnBody.dec x n c <$> normFnBody b
|
||||
| (FnBody.del x b) := do x ← normVar x; FnBody.del x <$> normFnBody b
|
||||
| (FnBody.mdata d b) := FnBody.mdata d <$> normFnBody b
|
||||
| (FnBody.case tid x alts) := do
|
||||
x ← normVar x,
|
||||
alts ← alts.mmap $ λ alt, alt.mmodifyBody normFnBody,
|
||||
x ← normVar x;
|
||||
alts ← alts.mmap $ λ alt, alt.mmodifyBody normFnBody;
|
||||
pure $ FnBody.case tid x alts
|
||||
| (FnBody.jmp j ys) := FnBody.jmp <$> normJP j <*> normArgs ys
|
||||
| (FnBody.ret x) := FnBody.ret <$> normArg x
|
||||
|
|
|
|||
|
|
@ -283,7 +283,7 @@ partial def visitDecl (env : Environment) (decls : Array Decl) : Decl → Decl
|
|||
end ExplicitRC
|
||||
|
||||
def explicitRC (decls : Array Decl) : CompilerM (Array Decl) :=
|
||||
do env ← getEnv,
|
||||
do env ← getEnv;
|
||||
pure $ decls.map (ExplicitRC.visitDecl env decls)
|
||||
|
||||
end IR
|
||||
|
|
|
|||
|
|
@ -59,12 +59,12 @@ private partial def S (w : VarId) (c : CtorInfo) : FnBody → FnBody
|
|||
abbrev M := ReaderT LocalContext (StateT Index Id)
|
||||
|
||||
private def mkFresh : M VarId :=
|
||||
do idx ← getModify (+1),
|
||||
do idx ← getModify (+1);
|
||||
pure { idx := idx }
|
||||
|
||||
private def tryS (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
|
||||
do w ← mkFresh,
|
||||
let b' := S w c b,
|
||||
do w ← mkFresh;
|
||||
let b' := S w c b;
|
||||
if b == b' then pure b
|
||||
else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b'
|
||||
|
||||
|
|
@ -90,38 +90,38 @@ match b with
|
|||
is expensive: `O(n^2)` where `n` is the size of the function body. -/
|
||||
private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody → M (FnBody × Bool)
|
||||
| e@(FnBody.case tid y alts) := do
|
||||
ctx ← read,
|
||||
ctx ← read;
|
||||
if e.hasLiveVar ctx x then do
|
||||
/- If `x` is live in `e`, we recursively process each branch. -/
|
||||
alts ← alts.mmap $ λ alt, alt.mmodifyBody (λ b, Dmain b >>= Dfinalize x c),
|
||||
alts ← alts.mmap $ λ alt, alt.mmodifyBody (λ b, Dmain b >>= Dfinalize x c);
|
||||
pure (FnBody.case tid y alts, true)
|
||||
else pure (e, false)
|
||||
| (FnBody.jdecl j ys v b) := do
|
||||
(b, _) ← adaptReader (λ ctx : LocalContext, ctx.addJP j ys v) (Dmain b),
|
||||
(v, found) ← Dmain v,
|
||||
(b, _) ← adaptReader (λ ctx : LocalContext, ctx.addJP j ys v) (Dmain b);
|
||||
(v, found) ← Dmain v;
|
||||
/- If `found == true`, then `Dmain b` must also have returned `(b, true)` since
|
||||
we assume the IR does not have dead join points. So, if `x` is live in `j`,
|
||||
then it must also live in `b` since `j` is reachable from `b` with a `jmp`. -/
|
||||
pure (FnBody.jdecl j ys v b, found)
|
||||
| e := do
|
||||
ctx ← read,
|
||||
ctx ← read;
|
||||
if e.isTerminal then
|
||||
pure (e, e.hasLiveVar ctx x)
|
||||
else do
|
||||
let (instr, b) := e.split,
|
||||
let (instr, b) := e.split;
|
||||
if isCtorUsing instr x then
|
||||
/- If the scrutinee `x` (the one that is providing memory) is being
|
||||
stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
|
||||
It may work only if the new cell is consumed, but we ignore this case. -/
|
||||
pure (e, true)
|
||||
else do
|
||||
(b, found) ← Dmain b,
|
||||
(b, found) ← Dmain b;
|
||||
/- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar`
|
||||
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/
|
||||
if found || !instr.hasFreeVar x then
|
||||
pure (instr.setBody b, found)
|
||||
else do
|
||||
b ← tryS x c b,
|
||||
b ← tryS x c b;
|
||||
pure (instr.setBody b, true)
|
||||
|
||||
private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
|
||||
|
|
@ -130,23 +130,23 @@ Dmain x c b >>= Dfinalize x c
|
|||
partial def R : FnBody → M FnBody
|
||||
| (FnBody.case tid x alts) := do
|
||||
alts ← alts.mmap $ λ alt, do {
|
||||
alt ← alt.mmodifyBody R,
|
||||
alt ← alt.mmodifyBody R;
|
||||
match alt with
|
||||
| Alt.ctor c b :=
|
||||
if c.isScalar then pure alt
|
||||
else Alt.ctor c <$> D x c b
|
||||
| _ := pure alt
|
||||
},
|
||||
};
|
||||
pure $ FnBody.case tid x alts
|
||||
| (FnBody.jdecl j ys v b) := do
|
||||
v ← R v,
|
||||
b ← adaptReader (λ ctx : LocalContext, ctx.addJP j ys v) (R b),
|
||||
v ← R v;
|
||||
b ← adaptReader (λ ctx : LocalContext, ctx.addJP j ys v) (R b);
|
||||
pure $ FnBody.jdecl j ys v b
|
||||
| e := do
|
||||
if e.isTerminal then pure e
|
||||
else do
|
||||
let (instr, b) := e.split,
|
||||
b ← R b,
|
||||
let (instr, b) := e.split;
|
||||
b ← R b;
|
||||
pure (instr.setBody b)
|
||||
|
||||
end ResetReuse
|
||||
|
|
|
|||
|
|
@ -140,15 +140,15 @@ instance EnvExtension.Inhabited (σ : Type) [Inhabited σ] : Inhabited (EnvExten
|
|||
|
||||
unsafe def registerEnvExtensionUnsafe {σ : Type} (initState : σ) : IO (EnvExtension σ) :=
|
||||
do
|
||||
initializing ← IO.initializing,
|
||||
unless initializing $ throw (IO.userError ("failed to register environment, extensions can only be registered during initialization")),
|
||||
exts ← envExtensionsRef.get,
|
||||
let idx := exts.size,
|
||||
initializing ← IO.initializing;
|
||||
unless initializing $ throw (IO.userError ("failed to register environment, extensions can only be registered during initialization"));
|
||||
exts ← envExtensionsRef.get;
|
||||
let idx := exts.size;
|
||||
let ext : EnvExtension σ := {
|
||||
idx := idx,
|
||||
initial := initState
|
||||
},
|
||||
envExtensionsRef.modify (λ exts, exts.push (unsafeCast ext)),
|
||||
};
|
||||
envExtensionsRef.modify (λ exts, exts.push (unsafeCast ext));
|
||||
pure ext
|
||||
|
||||
/- Environment extensions can only be registered during initialization.
|
||||
|
|
@ -159,14 +159,14 @@ pure ext
|
|||
constant registerEnvExtension {σ : Type} (initState : σ) : IO (EnvExtension σ) := default _
|
||||
|
||||
private def mkInitialExtensionStates : IO (Array EnvExtensionState) :=
|
||||
do exts ← envExtensionsRef.get, pure $ exts.map $ λ ext, ext.initial
|
||||
do exts ← envExtensionsRef.get; pure $ exts.map $ λ ext, ext.initial
|
||||
|
||||
@[export lean.mk_empty_environment_core]
|
||||
def mkEmptyEnvironment (trustLevel : UInt32 := 0) : IO Environment :=
|
||||
do
|
||||
initializing ← IO.initializing,
|
||||
when initializing $ throw (IO.userError "environment objects cannot be created during initialization"),
|
||||
exts ← mkInitialExtensionStates,
|
||||
initializing ← IO.initializing;
|
||||
when initializing $ throw (IO.userError "environment objects cannot be created during initialization");
|
||||
exts ← mkInitialExtensionStates;
|
||||
pure { const2ModIdx := {},
|
||||
constants := {},
|
||||
header := { trustLevel := trustLevel },
|
||||
|
|
@ -243,10 +243,10 @@ unsafe def registerPersistentEnvExtensionUnsafe {α σ : Type} (descr : Persiste
|
|||
do
|
||||
let s : PersistentEnvExtensionState α σ := {
|
||||
importedEntries := Array.empty,
|
||||
state := descr.addImportedFn Array.empty },
|
||||
pExts ← persistentEnvExtensionsRef.get,
|
||||
when (pExts.any (λ ext, ext.name == descr.name)) $ throw (IO.userError ("invalid environment extension, '" ++ toString descr.name ++ "' has already been used")),
|
||||
ext ← registerEnvExtension s,
|
||||
state := descr.addImportedFn Array.empty };
|
||||
pExts ← persistentEnvExtensionsRef.get;
|
||||
when (pExts.any (λ ext, ext.name == descr.name)) $ throw (IO.userError ("invalid environment extension, '" ++ toString descr.name ++ "' has already been used"));
|
||||
ext ← registerEnvExtension s;
|
||||
let pExt : PersistentEnvExtension α σ := {
|
||||
toEnvExtension := ext,
|
||||
name := descr.name,
|
||||
|
|
@ -254,8 +254,8 @@ let pExt : PersistentEnvExtension α σ := {
|
|||
addEntryFn := descr.addEntryFn,
|
||||
exportEntriesFn := descr.exportEntriesFn,
|
||||
statsFn := descr.statsFn
|
||||
},
|
||||
persistentEnvExtensionsRef.modify (λ pExts, pExts.push (unsafeCast pExt)),
|
||||
};
|
||||
persistentEnvExtensionsRef.modify (λ pExts, pExts.push (unsafeCast pExt));
|
||||
pure pExt
|
||||
|
||||
@[implementedBy registerPersistentEnvExtensionUnsafe]
|
||||
|
|
@ -311,15 +311,15 @@ instance CPPExtensionState.inhabited : Inhabited CPPExtensionState := inferInsta
|
|||
|
||||
@[export lean.register_extension_core]
|
||||
unsafe def registerCPPExtension (initial : CPPExtensionState) : Option Nat :=
|
||||
unsafeIO (do ext ← registerEnvExtension initial, pure ext.idx)
|
||||
unsafeIO (do ext ← registerEnvExtension initial; pure ext.idx)
|
||||
|
||||
@[export lean.set_extension_core]
|
||||
unsafe def setCPPExtensionState (env : Environment) (idx : Nat) (s : CPPExtensionState) : Option Environment :=
|
||||
unsafeIO (do exts ← envExtensionsRef.get, pure $ (exts.get idx).setState env s)
|
||||
unsafeIO (do exts ← envExtensionsRef.get; pure $ (exts.get idx).setState env s)
|
||||
|
||||
@[export lean.get_extension_core]
|
||||
unsafe def getCPPExtensionState (env : Environment) (idx : Nat) : Option CPPExtensionState :=
|
||||
unsafeIO (do exts ← envExtensionsRef.get, pure $ (exts.get idx).getState env)
|
||||
unsafeIO (do exts ← envExtensionsRef.get; pure $ (exts.get idx).getState env)
|
||||
|
||||
/- Legacy support for Modification objects -/
|
||||
|
||||
|
|
@ -371,15 +371,15 @@ constant readModuleData (fname : @& String) : IO ModuleData := default _
|
|||
|
||||
def mkModuleData (env : Environment) : IO ModuleData :=
|
||||
do
|
||||
pExts ← persistentEnvExtensionsRef.get,
|
||||
pExts ← persistentEnvExtensionsRef.get;
|
||||
let entries : Array (Name × Array EnvExtensionEntry) := pExts.size.fold
|
||||
(λ i result,
|
||||
let state := (pExts.get i).getState env;
|
||||
let exportEntriesFn := (pExts.get i).exportEntriesFn;
|
||||
let extName := (pExts.get i).name;
|
||||
result.push (extName, exportEntriesFn state))
|
||||
Array.empty,
|
||||
bytes ← serializeModifications (modListExtension.getState env),
|
||||
Array.empty;
|
||||
bytes ← serializeModifications (modListExtension.getState env);
|
||||
pure {
|
||||
imports := env.header.imports,
|
||||
constants := env.constants.foldStage2 (λ cs _ c, cs.push c) Array.empty,
|
||||
|
|
@ -389,7 +389,7 @@ serialized := bytes
|
|||
|
||||
@[export lean.write_module_core]
|
||||
def writeModule (env : Environment) (fname : String) : IO Unit :=
|
||||
do modData ← mkModuleData env, saveModuleData fname modData
|
||||
do modData ← mkModuleData env; saveModuleData fname modData
|
||||
|
||||
@[extern 2 "lean_find_olean"]
|
||||
constant findOLean (modName : Name) : IO String := default _
|
||||
|
|
@ -400,11 +400,11 @@ partial def importModulesAux : List Name → (NameSet × Array ModuleData) → I
|
|||
if s.contains m then
|
||||
importModulesAux ms (s, mods)
|
||||
else do
|
||||
let s := s.insert m,
|
||||
mFile ← findOLean m,
|
||||
mod ← readModuleData mFile,
|
||||
(s, mods) ← importModulesAux mod.imports.toList (s, mods),
|
||||
let mods := mods.push mod,
|
||||
let s := s.insert m;
|
||||
mFile ← findOLean m;
|
||||
mod ← readModuleData mFile;
|
||||
(s, mods) ← importModulesAux mod.imports.toList (s, mods);
|
||||
let mods := mods.push mod;
|
||||
importModulesAux ms (s, mods)
|
||||
|
||||
private partial def getEntriesFor (mod : ModuleData) (extId : Name) : Nat → Array EnvExtensionState
|
||||
|
|
@ -417,7 +417,7 @@ private partial def getEntriesFor (mod : ModuleData) (extId : Name) : Nat → Ar
|
|||
|
||||
private def setImportedEntries (env : Environment) (mods : Array ModuleData) : IO Environment :=
|
||||
do
|
||||
pExtDescrs ← persistentEnvExtensionsRef.get,
|
||||
pExtDescrs ← persistentEnvExtensionsRef.get;
|
||||
pure $ mods.iterate env $ λ _ mod env,
|
||||
pExtDescrs.iterate env $ λ _ extDescr env,
|
||||
let entries := getEntriesFor mod extDescr.name 0;
|
||||
|
|
@ -427,7 +427,7 @@ pure $ mods.iterate env $ λ _ mod env,
|
|||
|
||||
private def finalizePersistentExtensions (env : Environment) : IO Environment :=
|
||||
do
|
||||
pExtDescrs ← persistentEnvExtensionsRef.get,
|
||||
pExtDescrs ← persistentEnvExtensionsRef.get;
|
||||
pure $ pExtDescrs.iterate env $ λ _ extDescr env,
|
||||
extDescr.toEnvExtension.modifyState env $ λ s,
|
||||
{ state := extDescr.addImportedFn s.importedEntries, .. s }
|
||||
|
|
@ -435,17 +435,17 @@ pure $ pExtDescrs.iterate env $ λ _ extDescr env,
|
|||
@[export lean.import_modules_core]
|
||||
def importModules (modNames : List Name) (trustLevel : UInt32 := 0) : IO Environment :=
|
||||
do
|
||||
(_, mods) ← importModulesAux modNames ({}, Array.empty),
|
||||
(_, mods) ← importModulesAux modNames ({}, Array.empty);
|
||||
let const2ModIdx := mods.iterate {} $ λ modIdx (mod : ModuleData) (m : HashMap Name ModuleIdx),
|
||||
mod.constants.iterate m $ λ _ cinfo m,
|
||||
m.insert cinfo.name modIdx.val,
|
||||
m.insert cinfo.name modIdx.val;
|
||||
constants ← mods.miterate SMap.empty $ λ _ (mod : ModuleData) (cs : ConstMap),
|
||||
mod.constants.miterate cs $ λ _ cinfo cs, do {
|
||||
when (cs.contains cinfo.name) $ throw (IO.userError ("import failed, environment already contains '" ++ toString cinfo.name ++ "'")),
|
||||
when (cs.contains cinfo.name) $ throw (IO.userError ("import failed, environment already contains '" ++ toString cinfo.name ++ "'"));
|
||||
pure $ cs.insert cinfo.name cinfo
|
||||
},
|
||||
let constants := constants.switch,
|
||||
exts ← mkInitialExtensionStates,
|
||||
};
|
||||
let constants := constants.switch;
|
||||
exts ← mkInitialExtensionStates;
|
||||
let env : Environment := {
|
||||
const2ModIdx := const2ModIdx,
|
||||
constants := constants,
|
||||
|
|
@ -455,10 +455,10 @@ let env : Environment := {
|
|||
trustLevel := trustLevel,
|
||||
imports := modNames.toArray
|
||||
}
|
||||
},
|
||||
env ← setImportedEntries env mods,
|
||||
env ← finalizePersistentExtensions env,
|
||||
env ← mods.miterate env $ λ _ mod env, performModifications env mod.serialized,
|
||||
};
|
||||
env ← setImportedEntries env mods;
|
||||
env ← finalizePersistentExtensions env;
|
||||
env ← mods.miterate env $ λ _ mod env, performModifications env mod.serialized;
|
||||
pure env
|
||||
|
||||
namespace Environment
|
||||
|
|
@ -466,25 +466,25 @@ namespace Environment
|
|||
@[export lean.display_stats_core]
|
||||
def displayStats (env : Environment) : IO Unit :=
|
||||
do
|
||||
pExtDescrs ← persistentEnvExtensionsRef.get,
|
||||
let numModules := ((pExtDescrs.get 0).toEnvExtension.getState env).importedEntries.size,
|
||||
IO.println ("direct imports: " ++ toString env.header.imports),
|
||||
IO.println ("number of imported modules: " ++ toString numModules),
|
||||
IO.println ("number of consts: " ++ toString env.constants.size),
|
||||
IO.println ("number of imported consts: " ++ toString env.constants.stageSizes.1),
|
||||
IO.println ("number of local consts: " ++ toString env.constants.stageSizes.2),
|
||||
IO.println ("number of buckets for imported consts: " ++ toString env.constants.numBuckets),
|
||||
IO.println ("map depth for local consts: " ++ toString env.constants.maxDepth),
|
||||
IO.println ("trust level: " ++ toString env.header.trustLevel),
|
||||
IO.println ("number of extensions: " ++ toString env.extensions.size),
|
||||
pExtDescrs ← persistentEnvExtensionsRef.get;
|
||||
let numModules := ((pExtDescrs.get 0).toEnvExtension.getState env).importedEntries.size;
|
||||
IO.println ("direct imports: " ++ toString env.header.imports);
|
||||
IO.println ("number of imported modules: " ++ toString numModules);
|
||||
IO.println ("number of consts: " ++ toString env.constants.size);
|
||||
IO.println ("number of imported consts: " ++ toString env.constants.stageSizes.1);
|
||||
IO.println ("number of local consts: " ++ toString env.constants.stageSizes.2);
|
||||
IO.println ("number of buckets for imported consts: " ++ toString env.constants.numBuckets);
|
||||
IO.println ("map depth for local consts: " ++ toString env.constants.maxDepth);
|
||||
IO.println ("trust level: " ++ toString env.header.trustLevel);
|
||||
IO.println ("number of extensions: " ++ toString env.extensions.size);
|
||||
pExtDescrs.mfor $ λ extDescr, do {
|
||||
IO.println ("extension '" ++ toString extDescr.name ++ "'"),
|
||||
let s := extDescr.toEnvExtension.getState env,
|
||||
let fmt := extDescr.statsFn s.state,
|
||||
unless fmt.isNil (IO.println (" " ++ toString (Format.nest 2 (extDescr.statsFn s.state)))),
|
||||
IO.println (" number of imported entries: " ++ toString (s.importedEntries.foldl (λ sum es, sum + es.size) 0)),
|
||||
IO.println ("extension '" ++ toString extDescr.name ++ "'");
|
||||
let s := extDescr.toEnvExtension.getState env;
|
||||
let fmt := extDescr.statsFn s.state;
|
||||
unless fmt.isNil (IO.println (" " ++ toString (Format.nest 2 (extDescr.statsFn s.state))));
|
||||
IO.println (" number of imported entries: " ++ toString (s.importedEntries.foldl (λ sum es, sum + es.size) 0));
|
||||
pure ()
|
||||
},
|
||||
};
|
||||
pure ()
|
||||
|
||||
end Environment
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ modifyConstTable (λ cs, cs.qsort (λ e₁ e₂, Name.quickLt e₁.1 e₂.1))
|
|||
the program may crash if the type provided by the user is incorrect.
|
||||
It also assumes there are no threads trying to update the table concurrently. -/
|
||||
unsafe def evalConst (α : Type) [Inhabited α] (c : Name) : IO α :=
|
||||
do cs ← getConstTable,
|
||||
do cs ← getConstTable;
|
||||
match cs.binSearch (c, default _) (λ e₁ e₂, Name.quickLt e₁.1 e₂.1) with
|
||||
| some (_, v) := pure (unsafeCast v)
|
||||
| none := throw (IO.userError ("unknow constant '" ++ toString c ++ "'"))
|
||||
|
|
|
|||
|
|
@ -29,30 +29,30 @@ IO.mkRef (mkNameMap OptionDecl)
|
|||
private constant optionDeclsRef : IO.Ref OptionDecls := default _
|
||||
|
||||
def registerOption (name : Name) (decl : OptionDecl) : IO Unit :=
|
||||
do decls ← optionDeclsRef.get,
|
||||
do decls ← optionDeclsRef.get;
|
||||
when (decls.contains name) $
|
||||
throw $ IO.userError ("invalid option declaration '" ++ toString name ++ "', option already exists"),
|
||||
throw $ IO.userError ("invalid option declaration '" ++ toString name ++ "', option already exists");
|
||||
optionDeclsRef.set $ decls.insert name decl
|
||||
|
||||
def getOptionDecls : IO OptionDecls := optionDeclsRef.get
|
||||
|
||||
def getOptionDecl (name : Name) : IO OptionDecl :=
|
||||
do decls ← getOptionDecls,
|
||||
(some decl) ← pure (decls.find name) | throw $ IO.userError ("unknown option '" ++ toString name ++ "'"),
|
||||
do decls ← getOptionDecls;
|
||||
(some decl) ← pure (decls.find name) | throw $ IO.userError ("unknown option '" ++ toString name ++ "'");
|
||||
pure decl
|
||||
|
||||
def getOptionDefaulValue (name : Name) : IO DataValue :=
|
||||
do decl ← getOptionDecl name,
|
||||
do decl ← getOptionDecl name;
|
||||
pure decl.defValue
|
||||
|
||||
def getOptionDescr (name : Name) : IO String :=
|
||||
do decl ← getOptionDecl name,
|
||||
do decl ← getOptionDecl name;
|
||||
pure decl.descr
|
||||
|
||||
def setOptionFromString (opts : Options) (entry : String) : IO Options :=
|
||||
do let ps := (entry.split "=").map String.trim,
|
||||
[key, val] ← pure ps | throw "invalid configuration option entry, it must be of the form '<key> = <value>'",
|
||||
defValue ← getOptionDefaulValue key.toName,
|
||||
do let ps := (entry.split "=").map String.trim;
|
||||
[key, val] ← pure ps | throw "invalid configuration option entry, it must be of the form '<key> = <value>'";
|
||||
defValue ← getOptionDefaulValue key.toName;
|
||||
match defValue with
|
||||
| DataValue.ofString v := pure $ opts.setString key val
|
||||
| DataValue.ofBool v :=
|
||||
|
|
@ -61,10 +61,10 @@ do let ps := (entry.split "=").map String.trim,
|
|||
else throw $ IO.userError ("invalid Bool option value '" ++ val ++ "'")
|
||||
| DataValue.ofName v := pure $ opts.setName key val.toName
|
||||
| DataValue.ofNat v := do
|
||||
unless val.isNat $ throw (IO.userError ("invalid Nat option value '" ++ val ++ "'")),
|
||||
unless val.isNat $ throw (IO.userError ("invalid Nat option value '" ++ val ++ "'"));
|
||||
pure $ opts.setNat key val.toNat
|
||||
| DataValue.ofInt v := do
|
||||
unless val.isInt $ throw (IO.userError ("invalid Int option value '" ++ val ++ "'")),
|
||||
unless val.isInt $ throw (IO.userError ("invalid Int option value '" ++ val ++ "'"));
|
||||
pure $ opts.setInt key val.toInt
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -972,9 +972,9 @@ match info.updateTokens tables.tokens with
|
|||
| Except.error msg := throw (IO.userError msg)
|
||||
|
||||
def addBuiltinLeadingParser (tablesRef : IO.Ref ParsingTables) (declName : Name) (p : Parser) : IO Unit :=
|
||||
do tables ← tablesRef.get,
|
||||
tablesRef.reset,
|
||||
tables ← updateTokens tables p.info,
|
||||
do tables ← tablesRef.get;
|
||||
tablesRef.reset;
|
||||
tables ← updateTokens tables p.info;
|
||||
match p.info.firstTokens with
|
||||
| FirstTokens.tokens tks :=
|
||||
let tables := tks.foldl (λ (tables : ParsingTables) tk, { leadingTable := tables.leadingTable.insert (mkSimpleName tk.val) p, .. tables }) tables;
|
||||
|
|
@ -983,9 +983,9 @@ do tables ← tablesRef.get,
|
|||
throw (IO.userError ("invalid builtin parser '" ++ toString declName ++ "', initial token is not statically known"))
|
||||
|
||||
def addBuiltinTrailingParser (tablesRef : IO.Ref ParsingTables) (declName : Name) (p : TrailingParser) : IO Unit :=
|
||||
do tables ← tablesRef.get,
|
||||
tablesRef.reset,
|
||||
tables ← updateTokens tables p.info,
|
||||
do tables ← tablesRef.get;
|
||||
tablesRef.reset;
|
||||
tables ← updateTokens tables p.info;
|
||||
match p.info.firstTokens with
|
||||
| FirstTokens.tokens tks :=
|
||||
let tables := tks.foldl (λ (tables : ParsingTables) tk, { trailingTable := tables.trailingTable.insert (mkSimpleName tk.val) p, .. tables }) tables;
|
||||
|
|
@ -1017,8 +1017,8 @@ registerAttribute {
|
|||
name := attrName,
|
||||
descr := "Builtin parser",
|
||||
add := λ env declName args persistent, do {
|
||||
unless args.isMissing $ throw (IO.userError ("invalid attribute '" ++ toString attrName ++ "', unexpected argument")),
|
||||
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString attrName ++ "', must be persistent")),
|
||||
unless args.isMissing $ throw (IO.userError ("invalid attribute '" ++ toString attrName ++ "', unexpected argument"));
|
||||
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString attrName ++ "', must be persistent"));
|
||||
match env.find declName with
|
||||
| none := throw "unknown declaration"
|
||||
| some decl :=
|
||||
|
|
@ -1045,7 +1045,7 @@ registerBuiltinParserAttribute `builtinTermParser `Lean.Parser.builtinTermParsin
|
|||
|
||||
@[noinline] unsafe def runBuiltinParserUnsafe (kind : String) (ref : IO.Ref ParsingTables) : ParserFn leading :=
|
||||
λ a c s,
|
||||
match unsafeIO (do tables ← ref.get, pure $ prattParser kind tables a c s) with
|
||||
match unsafeIO (do tables ← ref.get; pure $ prattParser kind tables a c s) with
|
||||
| some s := s
|
||||
| none := s.mkError "failed to access builtin reference"
|
||||
|
||||
|
|
|
|||
|
|
@ -109,11 +109,11 @@ match s with
|
|||
|
||||
@[specialize] partial def mreplace {m : Type → Type} [Monad m] (fn : Syntax → m (Option Syntax)) : Syntax → m Syntax
|
||||
| stx@(node kind args scopes) := do
|
||||
o ← fn stx,
|
||||
o ← fn stx;
|
||||
(match o with
|
||||
| some stx := pure stx
|
||||
| none := do args ← args.mmap mreplace, pure (node kind args scopes))
|
||||
| stx := do o ← fn stx, pure (o.getOrElse stx)
|
||||
| none := do args ← args.mmap mreplace; pure (node kind args scopes))
|
||||
| stx := do o ← fn stx; pure (o.getOrElse stx)
|
||||
|
||||
@[inline] def replace {m : Type → Type} [Monad m] (fn : Syntax → m (Option Syntax)) := @mreplace Id _
|
||||
|
||||
|
|
@ -126,14 +126,14 @@ private def updateInfo : SourceInfo → String.Pos → SourceInfo
|
|||
@[inline]
|
||||
private def updateLeadingAux : Syntax → State String.Pos (Option Syntax)
|
||||
| (atom (some info) val) := do
|
||||
last ← get,
|
||||
set info.trailing.stopPos,
|
||||
let newInfo := updateInfo info last in
|
||||
last ← get;
|
||||
set info.trailing.stopPos;
|
||||
let newInfo := updateInfo info last;
|
||||
pure $ some (atom (some newInfo) val)
|
||||
| (ident (some info) rawVal val pre scopes) := do
|
||||
last ← get,
|
||||
set info.trailing.stopPos,
|
||||
let newInfo := updateInfo info last in
|
||||
last ← get;
|
||||
set info.trailing.stopPos;
|
||||
let newInfo := updateInfo info last;
|
||||
pure $ some (ident (some newInfo) rawVal val pre scopes)
|
||||
| _ := pure none
|
||||
|
||||
|
|
@ -184,9 +184,9 @@ partial def reprint : Syntax → Option String
|
|||
if kind == choiceKind then
|
||||
if args.size == 0 then failure
|
||||
else do
|
||||
s ← reprint (args.get 0),
|
||||
args.mfoldlFrom (λ s stx, do s' ← reprint stx, guard (s == s'), pure s) s 1
|
||||
else args.mfoldl (λ r stx, do s ← reprint stx, pure $ r ++ s) ""
|
||||
s ← reprint (args.get 0);
|
||||
args.mfoldlFrom (λ s stx, do s' ← reprint stx; guard (s == s'); pure s) s 1
|
||||
else args.mfoldl (λ r stx, do s ← reprint stx; pure $ r ++ s) ""
|
||||
| missing := ""
|
||||
|
||||
open Lean.Format
|
||||
|
|
|
|||
|
|
@ -48,23 +48,23 @@ traceCtx cls msg (pure () : m Unit)
|
|||
|
||||
instance (m) [Monad m] : MonadTracer (TraceT m) :=
|
||||
{ traceRoot := λ α pos cls msg ctx, do {
|
||||
st ← get,
|
||||
st ← get;
|
||||
if st.opts.getBool cls = true then do {
|
||||
modify $ λ st, {curPos := pos, curTraces := [], ..st},
|
||||
a ← ctx.get,
|
||||
modify $ λ (st : TraceState), {roots := st.roots.insert pos ⟨msg, st.curTraces⟩, ..st},
|
||||
modify $ λ st, {curPos := pos, curTraces := [], ..st};
|
||||
a ← ctx.get;
|
||||
modify $ λ (st : TraceState), {roots := st.roots.insert pos ⟨msg, st.curTraces⟩, ..st};
|
||||
pure a
|
||||
} else ctx.get
|
||||
},
|
||||
traceCtx := λ α cls msg ctx, do {
|
||||
st ← get,
|
||||
st ← get;
|
||||
-- tracing enabled?
|
||||
some _ ← pure st.curPos | ctx.get,
|
||||
some _ ← pure st.curPos | ctx.get;
|
||||
-- Trace class enabled?
|
||||
if st.opts.getBool cls = true then do {
|
||||
set {curTraces := [], ..st},
|
||||
a ← ctx.get,
|
||||
modify $ λ (st' : TraceState), {curTraces := st.curTraces ++ [⟨msg, st'.curTraces⟩], ..st'},
|
||||
set {curTraces := [], ..st};
|
||||
a ← ctx.get;
|
||||
modify $ λ (st' : TraceState), {curTraces := st.curTraces ++ [⟨msg, st'.curTraces⟩], ..st'};
|
||||
pure a
|
||||
} else
|
||||
-- disable tracing inside 'ctx'
|
||||
|
|
@ -76,7 +76,7 @@ instance (m) [Monad m] : MonadTracer (TraceT m) :=
|
|||
}
|
||||
|
||||
def TraceT.run {m α} [Monad m] (opts : Options) (x : TraceT m α) : m (α × TraceMap) :=
|
||||
do (a, st) ← StateT.run x {opts := opts, roots := mkRBMap _ _ _, curPos := none, curTraces := []},
|
||||
do (a, st) ← StateT.run x {opts := opts, roots := RBMap.empty, curPos := none, curTraces := []};
|
||||
pure (a, st.roots)
|
||||
|
||||
end Trace
|
||||
|
|
|
|||
|
|
@ -95,7 +95,11 @@ static expr parse_let_body(parser & p, pos_info const & pos, bool in_do_block) {
|
|||
p.next();
|
||||
return p.parse_expr();
|
||||
} else {
|
||||
p.check_token_next(get_comma_tk(), "invalid 'do' block 'let' declaration, ',' or 'in' expected");
|
||||
if (p.curr_is_token(get_semicolon_tk()) || p.curr_is_token(get_comma_tk())) {
|
||||
p.next();
|
||||
} else {
|
||||
p.check_token_next(get_semicolon_tk(), "invalid 'do' block 'let' declaration, ',', ';' or 'in' expected");
|
||||
}
|
||||
if (p.curr_is_token(get_let_tk())) {
|
||||
p.next();
|
||||
return parse_let(p, pos, in_do_block);
|
||||
|
|
@ -277,7 +281,7 @@ static expr parse_do(parser & p, bool has_braces) {
|
|||
expr type, curr;
|
||||
std::tie(lhs, type, curr, else_case) = parse_do_action(p, new_locals);
|
||||
es.push_back(curr);
|
||||
if (p.curr_is_token(get_comma_tk())) {
|
||||
if (/* p.curr_is_token(get_comma_tk()) || */ p.curr_is_token(get_semicolon_tk())) {
|
||||
p.next();
|
||||
for (expr const & l : new_locals)
|
||||
p.add_local(l);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue