feat(frontends/lean/builtin_exprs): use ; in do-notation

This commit is contained in:
Leonardo de Moura 2019-06-27 17:56:00 -07:00
parent ab487ea4ac
commit 91e1d30cf8
35 changed files with 589 additions and 582 deletions

View file

@ -24,7 +24,7 @@ if c then pure () else e
@[macroInline]
def mcond {m : Type → Type u} [Monad m] {α : Type} (mbool : m Bool) (tm fm : m α) : m α :=
do b ← mbool, cond b tm fm
do b ← mbool; cond b tm fm
@[macroInline]
def mwhen {m : Type → Type u} [Monad m] (c : m Bool) (t : m Unit) : m Unit :=
@ -70,20 +70,20 @@ def mfor {m : Type u → Type v} [Applicative m] {α : Type w} {β : Type u} (f
@[specialize]
def mfilter {m : Type → Type v} [Monad m] {α : Type} (f : α → m Bool) : List α → m (List α)
| [] := pure []
| (h :: t) := do b ← f h, t' ← mfilter t, cond b (pure (h :: t')) (pure t')
| (h :: t) := do b ← f h; t' ← mfilter t; cond b (pure (h :: t')) (pure t')
@[specialize]
def mfoldl {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (s → α → m s) → s → List α → m s
| f s [] := pure s
| f s (h :: r) := do
s' ← f s h,
s' ← f s h;
mfoldl f s' r
@[specialize]
def mfoldr {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (α → s → m s) → s → List α → m s
| f s [] := pure s
| f s (h :: r) := do
s' ← mfoldr f s r,
s' ← mfoldr f s r;
f h s'
@[specialize]
@ -94,14 +94,14 @@ def mfirst {m : Type u → Type v} [Monad m] [Alternative m] {α : Type w} {β :
@[specialize]
def mexists {m : Type → Type u} [Monad m] {α : Type v} (f : α → m Bool) : List α → m Bool
| [] := pure false
| (a::as) := do b ← f a, match b with
| (a::as) := do b ← f a; match b with
| true := pure true
| false := mexists as
@[specialize]
def mforall {m : Type → Type u} [Monad m] {α : Type v} (f : α → m Bool) : List α → m Bool
| [] := pure true
| (a::as) := do b ← f a, match b with
| (a::as) := do b ← f a; match b with
| true := mforall as
| false := pure false

View file

@ -21,13 +21,13 @@ match toBool b with
| false := f
@[macroInline] def orM {m : Type u → Type v} {β : Type u} [Monad m] [HasToBool β] (x y : m β) : m β :=
do b ← x,
do b ← x;
match toBool b with
| true := pure b
| false := y
@[macroInline] def andM {m : Type u → Type v} {β : Type u} [Monad m] [HasToBool β] (x y : m β) : m β :=
do b ← x,
do b ← x;
match toBool b with
| true := y
| false := pure b

View file

@ -31,7 +31,7 @@ namespace OptionT
{ pure := @OptionT.pure _ _, bind := @OptionT.bind _ _ }
protected def orelse (ma : OptionT m α) (mb : OptionT m α) : OptionT m α :=
(do { some a ← ma | mb,
(do { some a ← ma | mb;
pure (some a) } : m (Option α))
@[inline] protected def fail : OptionT m α :=
@ -55,7 +55,7 @@ namespace OptionT
⟨λ α, OptionT.monadMap⟩
protected def catch (ma : OptionT m α) (handle : Unit → OptionT m α) : OptionT m α :=
(do { some a ← ma | (handle ()),
(do { some a ← ma | (handle ());
pure a } : m (Option α))
instance : MonadExcept Unit (OptionT m) :=

View file

@ -30,7 +30,7 @@ pure
λ r, pure a
@[inline] protected def bind (x : ReaderT ρ m α) (f : α → ReaderT ρ m β) : ReaderT ρ m β :=
λ r, do a ← x r, f a r
λ r, do a ← x r; f a r
@[inline] protected def map (f : α → β) (x : ReaderT ρ m α) : ReaderT ρ m β :=
λ r, f <$> x r

View file

@ -30,10 +30,10 @@ variables [Monad m] {α β : Type u}
λ s, pure (a, s)
@[inline] protected def bind (x : StateT σ m α) (f : α → StateT σ m β) : StateT σ m β :=
λ s, do (a, s) ← x s, f a s
λ s, do (a, s) ← x s; f a s
@[inline] protected def map (f : α → β) (x : StateT σ m α) : StateT σ m β :=
λ s, do (a, s) ← x s, pure (f a, s)
λ s, do (a, s) ← x s; pure (f a, s)
instance : Monad (StateT σ m) :=
{ pure := @StateT.pure _ _ _, bind := @StateT.bind _ _ _, map := @StateT.map _ _ _ }
@ -59,7 +59,7 @@ instance [Alternative m] : Alternative (StateT σ m) :=
λ s, pure (⟨⟩, f s)
@[inline] protected def lift {α : Type u} (t : m α) : StateT σ m α :=
λ s, do a ← t, pure (a, s)
λ s, do a ← t; pure (a, s)
instance : HasMonadLift m (StateT σ m) :=
⟨@StateT.lift σ m _⟩
@ -70,8 +70,8 @@ instance (σ m m') [Monad m] [Monad m'] : MonadFunctor m m' (StateT σ m) (State
@[inline] protected def adapt {σ σ' σ'' α : Type u} {m : Type u → Type v} [Monad m] (split : σσ' × σ'')
(join : σ' → σ'' → σ) (x : StateT σ' m α) : StateT σ m α :=
λ st, do
let (st, ctx) := split st,
(a, st') ← x st,
let (st, ctx) := split st;
(a, st') ← x st;
pure (a, join st' ctx)
instance (ε) [MonadExcept ε m] : MonadExcept ε (StateT σ m) :=
@ -100,7 +100,7 @@ section
variables {σ : Type u} {m : Type u → Type v}
@[inline] def getModify [MonadState σ m] [Monad m] (f : σσ) : m σ :=
do s ← get, modify f, pure s
do s ← get; modify f; pure s
-- NOTE: The Ordering of the following two instances determines that the top-most `StateT` Monad layer
-- will be picked first

View file

@ -177,7 +177,7 @@ miterate₂ a₁ a₂ b (λ _ a₁ a₂ b, f b a₁ a₂)
| i :=
if h : i < a.size then
let idx : Fin a.size := ⟨i, h⟩;
do r ← f (a.fget idx),
do r ← f (a.fget idx);
match r with
| some v := pure r
| none := mfindAux (i+1)
@ -216,7 +216,7 @@ variables {m : Type → Type v} [Monad m]
| i :=
if h : i < a.size then
let idx : Fin a.size := ⟨i, h⟩;
do b ← p (a.fget idx),
do b ← p (a.fget idx);
match b with
| true := pure true
| false := anyMAux (i+1)
@ -266,7 +266,7 @@ variables {m : Type v → Type v} [Monad m]
let idx : Fin a.size := ⟨i, h⟩;
let v : α := a.fget idx;
let a := a.fset idx (@unsafeCast _ _ ⟨v⟩ ());
do newV ← f i v, ummapAux (i+1) (a.fset idx (@unsafeCast _ _ ⟨v⟩ newV))
do newV ← f i v; ummapAux (i+1) (a.fset idx (@unsafeCast _ _ ⟨v⟩ newV))
else
pure (unsafeCast a)
@ -277,10 +277,10 @@ ummapAux (λ i a, f a) 0 as
ummapAux f 0 as
@[implementedBy Array.ummap] def mmap (f : α → m β) (as : Array α) : m (Array β) :=
as.mfoldl (λ bs a, do b ← f a, pure (bs.push b)) (mkEmpty as.size)
as.mfoldl (λ bs a, do b ← f a; pure (bs.push b)) (mkEmpty as.size)
@[implementedBy Array.ummapIdx] def mmapIdx (f : Nat → α → m β) (as : Array α) : m (Array β) :=
as.miterate (mkEmpty as.size) (λ i a bs, do b ← f i.val a, pure (bs.push b))
as.miterate (mkEmpty as.size) (λ i a bs, do b ← f i.val a; pure (bs.push b))
end
@[inline] def modify [Inhabited α] (a : Array α) (i : Nat) (f : αα) : Array α :=

View file

@ -17,7 +17,7 @@ variables {α : Type u} {β : Type v} {δ : Type w} {m : Type w → Type w} [Mon
@[specialize] def mfoldl (f : δ → α → β → m δ) : δ → AssocList α β → m δ
| d nil := pure d
| d (cons a b es) := do d ← f d a b, mfoldl d es
| d (cons a b es) := do d ← f d a b; mfoldl d es
@[inline] def foldl (f : δ → α → β → δ) (d : δ) (as : AssocList α β) : δ :=
Id.run (mfoldl f d as)

View file

@ -133,7 +133,7 @@ variables {m : Type v → Type v} [Monad m]
| (leaf vs) b := vs.mfoldl f b
@[specialize] def mfoldl (f : β → α → m β) (b : β) (t : PersistentArray α) : m β :=
do b ← mfoldlAux f t.root b, t.tail.mfoldl f b
do b ← mfoldlAux f t.root b; t.tail.mfoldl f b
end
@ -152,8 +152,8 @@ variables {m : Type v → Type v} [Monad m]
@[specialize] def mmap (f : α → m β) (t : PersistentArray α) : m (PersistentArray β) :=
do
root ← mmapAux f t.root,
tail ← t.tail.mmap f,
root ← mmapAux f t.root;
tail ← t.tail.mmap f;
pure { tail := tail, root := root, .. t }
end

View file

@ -115,7 +115,7 @@ def IO.setRandSeed (n : Nat) : IO Unit :=
IO.stdGenRef.set (mkStdGen n)
def IO.rand (lo hi : Nat) : IO Nat :=
do gen ← IO.stdGenRef.get,
let (r, gen) := randNat gen lo hi,
IO.stdGenRef.set gen,
do gen ← IO.stdGenRef.get;
let (r, gen) := randNat gen lo hi;
IO.stdGenRef.set gen;
pure r

View file

@ -41,8 +41,8 @@ protected def max : RBNode α β → Option (Sigma (λ k : α, β k))
@[specialize] def mfold {m : Type w → Type w'} [Monad m] (f : σ → Π (k : α), β k → m σ) : σ → RBNode α β → m σ
| b leaf := pure b
| b (node _ l k v r) := do
b ← mfold b l,
b ← f b k v,
b ← mfold b l;
b ← f b k v;
mfold b r
@[specialize] def revFold (f : σ → Π (k : α), β k → σ) : σ → RBNode α β → σ

View file

@ -85,7 +85,7 @@ open Fs
@[specialize] partial def iterate {α β : Type} : α → (α → IO (Sum α β)) → IO β
| a f :=
do v ← f a,
do v ← f a;
match v with
| Sum.inl a := iterate a f
| Sum.inr b := pure b
@ -154,18 +154,18 @@ do b ← h.read 1,
def handle.readToEnd (h : handle) : m String :=
Prim.liftIO $ Prim.iterate "" $ λ r, do
done ← h.isEof,
done ← h.isEof;
if done
then pure (Sum.inr r) -- stop
else do
-- HACK: use less efficient `getLine` while `read` is broken
c ← h.getLine,
c ← h.getLine;
pure $ Sum.inl (r ++ c) -- continue
def readFile (fname : String) (bin := false) : m String :=
do h ← handle.mk fname Mode.read bin,
r ← h.readToEnd,
h.close,
do h ← handle.mk fname Mode.read bin;
r ← h.readToEnd;
h.close;
pure r
-- def writeFile (fname : String) (data : String) (bin := false) : m Unit :=
@ -229,8 +229,8 @@ variables {m : Type → Type} [Monad m] [monadIO m]
@[inline] def Ref.swap {α : Type} (r : Ref α) (a : α) : m α := Prim.liftIO (Prim.Ref.swap r a)
@[inline] def Ref.reset {α : Type} (r : Ref α) : m Unit := Prim.liftIO (Prim.Ref.reset r)
@[inline] def Ref.modify {α : Type} (r : Ref α) (f : αα) : m Unit :=
do v ← r.get,
r.reset,
do v ← r.get;
r.reset;
r.set (f v)
end
end IO
@ -259,7 +259,7 @@ instance HasRepr.HasEval {α : Type u} [HasRepr α] : HasEval α :=
⟨λ a, IO.println (repr a)⟩
instance IO.HasEval {α : Type} [HasEval α] : HasEval (IO α) :=
⟨λ x, do a ← x, HasEval.eval a⟩
⟨λ x, do a ← x; HasEval.eval a⟩
-- special case: do not print `()`
instance IOUnit.HasEval : HasEval (IO Unit) :=

View file

@ -145,31 +145,31 @@ constant attributeArrayRef : IO.Ref (Array AttributeImpl) := default _
/- Low level attribute registration function. -/
def registerAttribute (attr : AttributeImpl) : IO Unit :=
do m ← attributeMapRef.get,
when (m.contains attr.name) $ throw (IO.userError ("invalid attribute declaration, '" ++ toString attr.name ++ "' has already been used")),
initializing ← IO.initializing,
unless initializing $ throw (IO.userError ("failed to register attribute, attributes can only be registered during initialization")),
attributeMapRef.modify (λ m, m.insert attr.name attr),
do m ← attributeMapRef.get;
when (m.contains attr.name) $ throw (IO.userError ("invalid attribute declaration, '" ++ toString attr.name ++ "' has already been used"));
initializing ← IO.initializing;
unless initializing $ throw (IO.userError ("failed to register attribute, attributes can only be registered during initialization"));
attributeMapRef.modify (λ m, m.insert attr.name attr);
attributeArrayRef.modify (λ attrs, attrs.push attr)
/- Return true iff `n` is the name of a registered attribute. -/
@[export lean.is_attribute_core]
def isAttribute (n : Name) : IO Bool :=
do m ← attributeMapRef.get, pure (m.contains n)
do m ← attributeMapRef.get; pure (m.contains n)
/- Return the name of all registered attributes. -/
def getAttributeNames : IO (List Name) :=
do m ← attributeMapRef.get, pure $ m.fold (λ r n _, n::r) []
do m ← attributeMapRef.get; pure $ m.fold (λ r n _, n::r) []
def getAttributeImpl (attrName : Name) : IO AttributeImpl :=
do m ← attributeMapRef.get,
do m ← attributeMapRef.get;
match m.find attrName with
| some attr := pure attr
| none := throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
@[export lean.attribute_application_time_core]
def attributeApplicationTime (n : Name) : IO AttributeApplicationTime :=
do attr ← getAttributeImpl n,
do attr ← getAttributeImpl n;
pure attr.applicationTime
namespace Environment
@ -181,7 +181,7 @@ namespace Environment
- `args` is not valid for `attr`. -/
@[export lean.add_attribute_core]
def addAttribute (env : Environment) (decl : Name) (attrName : Name) (args : Syntax := Syntax.missing) (persistent := true) : IO Environment :=
do attr ← getAttributeImpl attrName,
do attr ← getAttributeImpl attrName;
attr.add env decl args persistent
/- Add a scoped attribute `attr` to declaration `decl` with arguments `args` and scope `decl.getPrefix`.
@ -194,7 +194,7 @@ do attr ← getAttributeImpl attrName,
Remark: the attribute will not be activated if `decl` is not inside the current namespace `env.getNamespace`. -/
@[export lean.add_scoped_attribute_core]
def addScopedAttribute (env : Environment) (decl : Name) (attrName : Name) (args : Syntax := Syntax.missing) : IO Environment :=
do attr ← getAttributeImpl attrName,
do attr ← getAttributeImpl attrName;
attr.addScoped env decl args
/- Remove attribute `attr` from declaration `decl`. The effect is the current scope.
@ -204,36 +204,36 @@ do attr ← getAttributeImpl attrName,
- `args` is not valid for `attr`. -/
@[export lean.erase_attribute_core]
def eraseAttribute (env : Environment) (decl : Name) (attrName : Name) (persistent := true) : IO Environment :=
do attr ← getAttributeImpl attrName,
do attr ← getAttributeImpl attrName;
attr.erase env decl persistent
/- Activate the scoped attribute `attr` for all declarations in scope `scope`.
We use this function to implement the command `open foo`. -/
@[export lean.activate_scoped_attribute_core]
def activateScopedAttribute (env : Environment) (attrName : Name) (scope : Name) : IO Environment :=
do attr ← getAttributeImpl attrName,
do attr ← getAttributeImpl attrName;
attr.activateScoped env scope
/- Activate all scoped attributes at `scope` -/
@[export lean.activate_scoped_attributes_core]
def activateScopedAttributes (env : Environment) (scope : Name) : IO Environment :=
do attrs ← attributeArrayRef.get,
do attrs ← attributeArrayRef.get;
attrs.mfoldl (λ env attr, attr.activateScoped env scope) env
/- We use this function to implement commands `namespace foo` and `section foo`.
It activates scoped attributes in the new resulting namespace. -/
@[export lean.push_scope_core]
def pushScope (env : Environment) (header : Name) (isNamespace : Bool) : IO Environment :=
do let env := env.pushScopeCore header isNamespace,
let ns := env.getNamespace,
attrs ← attributeArrayRef.get,
attrs.mfoldl (λ env attr, do env ← attr.pushScope env, if isNamespace then attr.activateScoped env ns else pure env) env
do let env := env.pushScopeCore header isNamespace;
let ns := env.getNamespace;
attrs ← attributeArrayRef.get;
attrs.mfoldl (λ env attr, do env ← attr.pushScope env; if isNamespace then attr.activateScoped env ns else pure env) env
/- We use this function to implement commands `end foo` for closing namespaces and sections. -/
@[export lean.pop_scope_core]
def popScope (env : Environment) : IO Environment :=
do let env := env.popScopeCore,
attrs ← attributeArrayRef.get,
do let env := env.popScopeCore;
attrs ← attributeArrayRef.get;
attrs.mfoldl (λ env attr, attr.popScope env) env
end Environment
@ -260,20 +260,20 @@ ext : PersistentEnvExtension Name NameSet ← registerPersistentEnvExtension {
let r : Array Name := es.fold (λ a e, a.push e) Array.empty;
r.qsort Name.quickLt,
statsFn := λ s, "tag attribute" ++ Format.line ++ "number of local entries: " ++ format s.size
},
};
let attrImpl : AttributeImpl := {
name := name,
descr := descr,
add := λ env decl args persistent, do
unless args.isMissing $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', unexpected argument")),
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent")),
unless args.isMissing $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', unexpected argument"));
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent"));
unless (env.getModuleIdxFor decl).isNone $
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module")),
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module"));
match validate env decl with
| Except.error msg := throw (IO.userError ("invalid attribute '" ++ toString name ++ "', " ++ msg))
| _ := pure $ ext.addEntry env decl
},
registerAttribute attrImpl,
};
registerAttribute attrImpl;
pure { attr := attrImpl, ext := ext }
namespace TagAttribute
@ -309,23 +309,23 @@ ext : PersistentEnvExtension (Name × α) (NameMap α) ← registerPersistentEnv
let r : Array (Name × α) := m.fold (λ a n p, a.push (n, p)) Array.empty;
r.qsort (λ a b, Name.quickLt a.1 b.1),
statsFn := λ s, "parametric attribute" ++ Format.line ++ "number of local entries: " ++ format s.size
},
};
let attrImpl : AttributeImpl := {
name := name,
descr := descr,
add := λ env decl args persistent, do
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent")),
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent"));
unless (env.getModuleIdxFor decl).isNone $
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module")),
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module"));
match getParam env decl args with
| Except.error msg := throw (IO.userError ("invalid attribute '" ++ toString name ++ "', " ++ msg))
| Except.ok val := do
let env := ext.addEntry env (decl, val),
let env := ext.addEntry env (decl, val);
match afterSet env decl val with
| Except.error msg := throw (IO.userError ("invalid attribute '" ++ toString name ++ "', " ++ msg))
| Except.ok env := pure env
},
registerAttribute attrImpl,
};
registerAttribute attrImpl;
pure { attr := attrImpl, ext := ext }
namespace ParametricAttribute
@ -368,19 +368,19 @@ ext : PersistentEnvExtension (Name × α) (NameMap α) ← registerPersistentEnv
let r : Array (Name × α) := m.fold (λ a n p, a.push (n, p)) Array.empty;
r.qsort (λ a b, Name.quickLt a.1 b.1),
statsFn := λ s, "enumeration attribute extension" ++ Format.line ++ "number of local entries: " ++ format s.size
},
};
let attrs := attrDescrs.map $ λ ⟨name, descr, val⟩, { AttributeImpl .
name := name,
descr := descr,
add := λ env decl args persistent, do
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent")),
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString name ++ "', must be persistent"));
unless (env.getModuleIdxFor decl).isNone $
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module")),
throw (IO.userError ("invalid attribute '" ++ toString name ++ "', declaration is in an imported module"));
match validate env decl val with
| Except.error msg := throw (IO.userError ("invalid attribute '" ++ toString name ++ "', " ++ msg))
| _ := pure $ ext.addEntry env (decl, val)
},
attrs.mfor registerAttribute,
};
attrs.mfor registerAttribute;
pure { ext := ext, attrs := attrs }
namespace EnumAttributes

View file

@ -101,11 +101,11 @@ private def consumeNLambdas : Nat → Expr → Option Expr
partial def getClassName (env : Environment) : Expr → Option Name
| (Expr.pi _ _ _ d) := getClassName d
| e := do
Expr.const c _ ← pure e.getAppFn | none,
info ← env.find c,
Expr.const c _ ← pure e.getAppFn | none;
info ← env.find c;
match info.value with
| some val := do
body ← consumeNLambdas e.getAppNumArgs val,
body ← consumeNLambdas e.getAppNumArgs val;
getClassName body
| none :=
if isClass env c then some c
@ -125,8 +125,8 @@ registerAttribute {
name := `class,
descr := "type class",
add := λ env decl args persistent, do
unless args.isMissing $ throw (IO.userError ("invalid attribute 'class', unexpected argument")),
unless persistent $ throw (IO.userError ("invalid attribute 'class', must be persistent")),
unless args.isMissing $ throw (IO.userError ("invalid attribute 'class', unexpected argument"));
unless persistent $ throw (IO.userError ("invalid attribute 'class', must be persistent"));
IO.ofExcept (addClass env decl)
}
@ -135,8 +135,8 @@ registerAttribute {
name := `instance,
descr := "type class instance",
add := λ env decl args persistent, do
unless args.isMissing $ throw (IO.userError ("invalid attribute 'instance', unexpected argument")),
unless persistent $ throw (IO.userError ("invalid attribute 'instance', must be persistent")),
unless args.isMissing $ throw (IO.userError ("invalid attribute 'instance', unexpected argument"));
unless persistent $ throw (IO.userError ("invalid attribute 'instance', must be persistent"));
IO.ofExcept (addInstance env decl)
}

View file

@ -58,9 +58,9 @@ def mkUInt32Lit (n : Nat) : Expr :=
mkUIntLit {nbits := 32} n
def foldBinUInt (fn : NumScalarTypeInfo → Bool → Nat → Nat → Nat) (beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
do n₁ ← getNumLit a₁,
n₂ ← getNumLit a₂,
info ← getInfoFromVal a₁,
do n₁ ← getNumLit a₁;
n₂ ← getNumLit a₂;
info ← getInfoFromVal a₁;
pure $ mkUIntLit info (fn info beforeErasure n₁ n₂)
def foldUIntAdd := foldBinUInt $ λ _ _, (+)
@ -77,8 +77,8 @@ def uintBinFoldFns : List (Name × BinFoldFn) :=
numScalarTypes.foldl (λ r info, r ++ (preUIntBinFoldFns.map (λ ⟨suffix, fn⟩, (info.id ++ suffix, fn)))) []
def foldNatBinOp (fn : Nat → Nat → Nat) (a₁ a₂ : Expr) : Option Expr :=
do n₁ ← getNumLit a₁,
n₂ ← getNumLit a₂,
do n₁ ← getNumLit a₁;
n₂ ← getNumLit a₂;
pure $ Expr.lit (Literal.natVal (fn n₁ n₂))
def foldNatAdd (_ : Bool) := foldNatBinOp (+)
@ -105,8 +105,8 @@ match beforeErasure, r with
def foldNatBinPred (mkPred : Expr → Expr → Expr) (fn : Nat → Nat → Bool)
(beforeErasure : Bool) (a₁ a₂ : Expr) : Option Expr :=
do n₁ ← getNumLit a₁,
n₂ ← getNumLit a₂,
do n₁ ← getNumLit a₁;
n₂ ← getNumLit a₂;
pure $ toDecidableExpr beforeErasure (mkPred a₁ a₂) (fn n₁ n₂)
def foldNatDecEq := foldNatBinPred mkNatEq (λ a b, a = b)
@ -156,18 +156,18 @@ def binFoldFns : List (Name × BinFoldFn) :=
boolFoldFns ++ uintBinFoldFns ++ natFoldFns
def foldNatSucc (_ : Bool) (a : Expr) : Option Expr :=
do n ← getNumLit a,
do n ← getNumLit a;
pure $ Expr.lit (Literal.natVal (n+1))
def foldCharOfNat (beforeErasure : Bool) (a : Expr) : Option Expr :=
do guard (!beforeErasure),
n ← getNumLit a,
do guard (!beforeErasure);
n ← getNumLit a;
pure $
if isValidChar (UInt32.ofNat n) then mkUInt32Lit n
else mkUInt32Lit 0
def foldToNat (_ : Bool) (a : Expr) : Option Expr :=
do n ← getNumLit a,
do n ← getNumLit a;
pure $ Expr.lit (Literal.natVal n)
def uintFoldToNatFns : List (Name × UnFoldFn) :=
@ -188,7 +188,7 @@ unFoldFns.lookup fn
def foldBinOp (beforeErasure : Bool) (f : Expr) (a : Expr) (b : Expr) : Option Expr :=
match f with
| Expr.const fn _ := do
foldFn ← findBinFoldFn fn,
foldFn ← findBinFoldFn fn;
foldFn beforeErasure a b
| _ := none
@ -196,7 +196,7 @@ match f with
def foldUnOp (beforeErasure : Bool) (f : Expr) (a : Expr) : Option Expr :=
match f with
| Expr.const fn _ := do
foldFn ← findUnFoldFn fn,
foldFn ← findUnFoldFn fn;
foldFn beforeErasure a
| _ := none

View file

@ -147,7 +147,7 @@ def getExternEntryFor (d : ExternAttrData) (backend : Name) : Option ExternEntry
getExternEntryForAux backend d.entries
def mkExternCall (d : ExternAttrData) (backend : Name) (args : List String) : Option String :=
do e ← getExternEntryFor d backend,
do e ← getExternEntryFor d backend;
expandExternEntry e args
def isExtern (env : Environment) (fn : Name) : Bool :=
@ -161,8 +161,8 @@ match getExternAttrData env fn with
| _ := false
def getExternNameFor (env : Environment) (backend : Name) (fn : Name) : Option String :=
do data ← getExternAttrData env fn,
entry ← getExternEntryFor data backend,
do data ← getExternAttrData env fn;
entry ← getExternEntryFor data backend;
match entry with
| ExternEntry.standard _ n := pure n
| ExternEntry.foreign _ n := pure n

View file

@ -353,7 +353,7 @@ bs.map $ λ b, match b with
@[inline] def mmodifyJPs {m : Type → Type} [Monad m] (bs : Array FnBody) (f : FnBody → m FnBody) : m (Array FnBody) :=
bs.mmap $ λ b, match b with
| FnBody.jdecl j xs v k := do v ← f v, pure $ FnBody.jdecl j xs v k
| FnBody.jdecl j xs v k := do v ← f v; pure $ FnBody.jdecl j xs v k
| other := pure other
@[export lean.ir.mk_alt_core] def mkAlt (n : Name) (cidx : Nat) (size : Nat) (usize : Nat) (ssize : Nat) (b : FnBody) : Alt := Alt.ctor ⟨n, cidx, size, usize, ssize⟩ b
@ -500,7 +500,7 @@ else none
def addParamsRename (ρ : IndexRenaming) (ps₁ ps₂ : Array Param) : Option IndexRenaming :=
if ps₁.size != ps₂.size then none
else Array.foldl₂ (λ ρ p₁ p₂, do ρρ, addParamRename ρ p₁ p₂) (some ρ) ps₁ ps₂
else Array.foldl₂ (λ ρ p₁ p₂, do ρρ; addParamRename ρ p₁ p₂) (some ρ) ps₁ ps₂
partial def FnBody.alphaEqv : IndexRenaming → FnBody → FnBody → Bool
| ρ (FnBody.vdecl x₁ t₁ v₁ b₁) (FnBody.vdecl x₂ t₂ v₂ b₂) := t₁ == t₂ && aeqv ρ v₁ v₂ && FnBody.alphaEqv (addVarRename ρ x₁.idx x₂.idx) b₁ b₂

View file

@ -70,19 +70,19 @@ if exported then ps else initBorrow ps
partial def visitFnBody (fnid : FunId) : FnBody → State ParamMap Unit
| (FnBody.jdecl j xs v b) := do
modify $ λ m, m.insert (Key.jp fnid j) (initBorrow xs),
visitFnBody v,
modify $ λ m, m.insert (Key.jp fnid j) (initBorrow xs);
visitFnBody v;
visitFnBody b
| e :=
unless (e.isTerminal) $ do
let (instr, b) := e.split,
let (instr, b) := e.split;
visitFnBody b
def visitDecls (env : Environment) (decls : Array Decl) : State ParamMap Unit :=
decls.mfor $ λ decl, match decl with
| Decl.fdecl f xs _ b := do
let exported := isExport env f,
modify $ λ m, m.insert (Key.decl f) (initBorrowIfNotExported exported xs),
let exported := isExport env f;
modify $ λ m, m.insert (Key.decl f) (initBorrowIfNotExported exported xs);
visitFnBody f b
| _ := pure ()
end InitParamMap
@ -156,32 +156,32 @@ def ownArgs (xs : Array Arg) : M Unit :=
xs.mfor ownArg
def isOwned (x : VarId) : M Bool :=
do s ← get,
do s ← get;
pure $ s.owned.contains x.idx
/- Updates `map[k]` using the current set of `owned` variables. -/
def updateParamMap (k : Key) : M Unit :=
do
s ← get,
s ← get;
match s.map.find k with
| some ps := do
ps ← ps.mmap $ λ (p : Param),
if p.borrow && s.owned.contains p.x.idx then do
markModifiedParamMap, pure { borrow := false, .. p }
markModifiedParamMap; pure { borrow := false, .. p }
else
pure p,
pure p;
modify $ λ s, { map := s.map.insert k ps, .. s }
| none := pure ()
def getParamInfo (k : Key) : M (Array Param) :=
do
s ← get,
s ← get;
match s.map.find k with
| some ps := pure ps
| none :=
match k with
| (Key.decl fn) := do
ctx ← read,
ctx ← read;
match findEnvDecl ctx.env fn with
| some decl := pure decl.params
| none := pure Array.empty -- unreachable if well-formed input
@ -190,8 +190,8 @@ match s.map.find k with
/- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
xs.size.mfor $ λ i, do
let x := xs.get i,
let p := ps.get i,
let x := xs.get i;
let p := ps.get i;
unless p.borrow $ ownArg x
/- For each xs[i], if xs[i] is owned, then mark ps[i] as owned.
@ -201,8 +201,8 @@ xs.size.mfor $ λ i, do
"break" the tail call. -/
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit :=
xs.size.mfor $ λ i, do
let x := xs.get i,
let p := ps.get i,
let x := xs.get i;
let p := ps.get i;
match x with
| Arg.var x := mwhen (isOwned x) $ ownVar p.x
| _ := pure ()
@ -219,7 +219,7 @@ xs.size.mfor $ λ i, do
-/
def ownArgsIfParam (xs : Array Arg) : M Unit :=
do
ctx ← read,
ctx ← read;
xs.mfor $ λ x,
match x with
| Arg.var x := when (ctx.paramSet.contains x.idx) $ ownVar x
@ -230,7 +230,7 @@ def collectExpr (z : VarId) : Expr → M Unit
| (Expr.reuse x _ _ ys) := ownVar z *> ownVar x *> ownArgsIfParam ys
| (Expr.ctor _ xs) := ownVar z *> ownArgsIfParam xs
| (Expr.proj _ x) := mwhen (isOwned z) $ ownVar x
| (Expr.fap g xs) := do ps ← getParamInfo (Key.decl g),
| (Expr.fap g xs) := do ps ← getParamInfo (Key.decl g);
-- dbgTrace ("collectExpr: " ++ toString g ++ " " ++ toString (formatParams ps)) $ λ _,
ownVar z *> ownArgsUsingParams xs ps
| (Expr.ap x ys) := ownVar z *> ownVar x *> ownArgs ys
@ -238,12 +238,12 @@ def collectExpr (z : VarId) : Expr → M Unit
| other := pure ()
def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit :=
do ctx ← read,
do ctx ← read;
match v, b with
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) :=
when (ctx.currFn == g && x == z) $ do
-- dbgTrace ("preserveTailCall " ++ toString b) $ λ _, do
ps ← getParamInfo (Key.decl g),
ps ← getParamInfo (Key.decl g);
ownParamsUsingArgs ys ps
| _, _ := pure ()
@ -252,24 +252,24 @@ def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx :=
partial def collectFnBody : FnBody → M Unit
| (FnBody.jdecl j ys v b) := do
adaptReader (λ ctx, updateParamSet ctx ys) (collectFnBody v),
ctx ← read,
updateParamMap (Key.jp ctx.currFn j),
adaptReader (λ ctx, updateParamSet ctx ys) (collectFnBody v);
ctx ← read;
updateParamMap (Key.jp ctx.currFn j);
collectFnBody b
| (FnBody.vdecl x _ v b) := collectFnBody b *> collectExpr x v *> preserveTailCall x v b
| (FnBody.jmp j ys) := do
ctx ← read,
ps ← getParamInfo (Key.jp ctx.currFn j),
ownArgsUsingParams ys ps, -- for making sure the join point can reuse
ctx ← read;
ps ← getParamInfo (Key.jp ctx.currFn j);
ownArgsUsingParams ys ps; -- for making sure the join point can reuse
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
| (FnBody.case _ _ alts) := alts.mfor $ λ alt, collectFnBody alt.body
| e := unless (e.isTerminal) $ collectFnBody e.body
@[specialize] partial def whileModifingOwnedAux (x : M Unit) : Unit → M Unit
| _ := do
modify $ λ s, { modifiedOwned := false, .. s },
x,
s ← get,
modify $ λ s, { modifiedOwned := false, .. s };
x;
s ← get;
if s.modifiedOwned then whileModifingOwnedAux ()
else pure ()
@ -280,18 +280,18 @@ whileModifingOwnedAux x ()
partial def collectDecl : Decl → M Unit
| (Decl.fdecl f ys _ b) :=
adaptReader (λ ctx, let ctx := updateParamSet ctx ys; { currFn := f, .. ctx }) $ do
modify $ λ s : BorrowInfState, { owned := {}, .. s },
whileModifingOwned (collectFnBody b),
modify $ λ s : BorrowInfState, { owned := {}, .. s };
whileModifingOwned (collectFnBody b);
updateParamMap (Key.decl f)
| _ := pure ()
@[specialize] partial def whileModifingParamMapAux (x : M Unit) : Unit → M Unit
| _ := do
modify $ λ s, { modifiedParamMap := false, .. s },
s ← get,
modify $ λ s, { modifiedParamMap := false, .. s };
s ← get;
-- dbgTrace (toString s.map) $ λ _, do
x,
s ← get,
x;
s ← get;
if s.modifiedParamMap then whileModifingParamMapAux ()
else pure ()
@ -301,8 +301,8 @@ whileModifingParamMapAux x ()
def collectDecls (decls : Array Decl) : M ParamMap :=
do
whileModifingParamMap (decls.mfor collectDecl),
s ← get,
whileModifingParamMap (decls.mfor collectDecl);
s ← get;
pure s.map
def infer (env : Environment) (decls : Array Decl) : ParamMap :=
@ -312,9 +312,9 @@ end Borrow
def inferBorrow (decls : Array Decl) : CompilerM (Array Decl) :=
do
env ← getEnv,
let decls := decls.map Decl.normalizeIds,
let paramMap := Borrow.infer env decls,
env ← getEnv;
let decls := decls.map Decl.normalizeIds;
let paramMap := Borrow.infer env decls;
pure (Borrow.applyParamMap decls paramMap)
end IR

View file

@ -35,8 +35,8 @@ Name.mkString n "_boxed"
abbrev N := State Nat
private def mkFresh : N VarId :=
do idx ← get,
modify (+1),
do idx ← get;
modify (+1);
pure {idx := idx }
def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
@ -45,8 +45,8 @@ ps.size > 0 && (decl.resultType.isScalar || ps.any (λ p, p.ty.isScalar || p.bor
def mkBoxedVersionAux (decl : Decl) : N Decl :=
do
let ps := decl.params,
qs ← ps.mmap (λ _, do x ← mkFresh, pure { Param . x := x, ty := IRType.object, borrow := false }),
let ps := decl.params;
qs ← ps.mmap (λ _, do x ← mkFresh; pure { Param . x := x, ty := IRType.object, borrow := false });
(newVDecls, xs) ← qs.size.mfold
(λ i (r : Array FnBody × Array Arg),
let (newVDecls, xs) := r;
@ -54,19 +54,19 @@ qs ← ps.mmap (λ _, do x ← mkFresh, pure { Param . x := x, ty := IRType.obje
let q := qs.get i;
if !p.ty.isScalar then pure (newVDecls, xs.push (Arg.var q.x))
else do
x ← mkFresh,
x ← mkFresh;
pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (default _)), xs.push (Arg.var x)))
(Array.empty, Array.empty),
r ← mkFresh,
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (default _)),
(Array.empty, Array.empty);
r ← mkFresh;
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (default _));
body ←
if !decl.resultType.isScalar then do {
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
} else do {
newR ← mkFresh,
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (default _)),
newR ← mkFresh;
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (default _));
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
},
};
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
def mkBoxedVersion (decl : Decl) : Decl :=
@ -111,7 +111,7 @@ structure BoxingContext :=
abbrev M := ReaderT BoxingContext (StateT Index Id)
def mkFresh : M VarId :=
do idx ← getModify (+1),
do idx ← getModify (+1);
pure { idx := idx }
def getEnv : M Environment := BoxingContext.env <$> read
@ -119,17 +119,17 @@ def getLocalContext : M LocalContext := BoxingContext.localCtx <$> read
def getResultType : M IRType := BoxingContext.resultType <$> read
def getVarType (x : VarId) : M IRType :=
do localCtx ← getLocalContext,
do localCtx ← getLocalContext;
match localCtx.getType x with
| some t := pure t
| none := pure IRType.object -- unreachable, we assume the code is well formed
def getJPParams (j : JoinPointId) : M (Array Param) :=
do localCtx ← getLocalContext,
do localCtx ← getLocalContext;
match localCtx.getJPParams j with
| some ys := pure ys
| none := pure Array.empty -- unreachable, we assume the code is well formed
def getDecl (fid : FunId) : M Decl :=
do ctx ← read,
do ctx ← read;
match findEnvDecl' ctx.env fid ctx.decls with
| some decl := pure decl
| none := pure (default _) -- unreachable if well-formed
@ -150,11 +150,11 @@ def mkCast (x : VarId) (xType : IRType) : Expr :=
if xType.isScalar then Expr.box xType x else Expr.unbox x
@[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody :=
do xType ← getVarType x,
do xType ← getVarType x;
if eqvTypes xType expected then k x
else do
y ← mkFresh,
let v := mkCast x xType,
y ← mkFresh;
let v := mkCast x xType;
FnBody.vdecl y expected v <$> k y
@[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody :=
@ -169,27 +169,27 @@ xs.miterate (Array.empty, Array.empty) $ λ i (x : Arg) (r : Array Arg × Array
match x with
| Arg.irrelevant := pure (xs.push x, bs)
| Arg.var x := do
xType ← getVarType x,
xType ← getVarType x;
if eqvTypes xType expected then pure (xs.push (Arg.var x), bs)
else do
y ← mkFresh,
let v := mkCast x xType,
let b := FnBody.vdecl y expected v FnBody.nil,
y ← mkFresh;
let v := mkCast x xType;
let b := FnBody.vdecl y expected v FnBody.nil;
pure (xs.push (Arg.var y), bs.push b)
@[inline] def castArgsIfNeeded (xs : Array Arg) (ps : Array Param) (k : Array Arg → M FnBody) : M FnBody :=
do (ys, bs) ← castArgsIfNeededAux xs (λ i, (ps.get i).ty),
b ← k ys,
do (ys, bs) ← castArgsIfNeededAux xs (λ i, (ps.get i).ty);
b ← k ys;
pure (reshape bs b)
@[inline] def boxArgsIfNeeded (xs : Array Arg) (k : Array Arg → M FnBody) : M FnBody :=
do (ys, bs) ← castArgsIfNeededAux xs (λ _, IRType.object),
b ← k ys,
do (ys, bs) ← castArgsIfNeededAux xs (λ _, IRType.object);
b ← k ys;
pure (reshape bs b)
def unboxResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
if ty.isScalar then do
y ← mkFresh,
y ← mkFresh;
pure $ FnBody.vdecl y IRType.object e (FnBody.vdecl x ty (Expr.unbox y) b)
else
pure $ FnBody.vdecl x ty e b
@ -197,7 +197,7 @@ else
def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b : FnBody) : M FnBody :=
if eqvTypes ty eType then pure $ FnBody.vdecl x ty e b
else do
y ← mkFresh,
y ← mkFresh;
pure $ FnBody.vdecl y eType e (FnBody.vdecl x ty (mkCast y eType) b)
def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
@ -210,13 +210,13 @@ match e with
| Expr.reuse w c u ys :=
boxArgsIfNeeded ys $ λ ys, pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b
| Expr.fap f ys := do
decl ← getDecl f,
decl ← getDecl f;
castArgsIfNeeded ys decl.params $ λ ys,
castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b
| Expr.pap f ys := do
env ← getEnv,
decl ← getDecl f,
let f := if requiresBoxedVersion env decl then mkBoxedName f else f,
env ← getEnv;
decl ← getDecl f;
let f := if requiresBoxedVersion env decl then mkBoxedName f else f;
boxArgsIfNeeded ys $ λ ys, pure $ FnBody.vdecl x ty (Expr.pap f ys) b
| Expr.ap f ys :=
boxArgsIfNeeded ys $ λ ys,
@ -226,32 +226,32 @@ match e with
partial def visitFnBody : FnBody → M FnBody
| (FnBody.vdecl x t v b) := do
b ← withVDecl x t v (visitFnBody b),
b ← withVDecl x t v (visitFnBody b);
visitVDeclExpr x t v b
| (FnBody.jdecl j xs v b) := do
v ← withParams xs (visitFnBody v),
b ← withJDecl j xs v (visitFnBody b),
v ← withParams xs (visitFnBody v);
b ← withJDecl j xs v (visitFnBody b);
pure $ FnBody.jdecl j xs v b
| (FnBody.uset x i y b) := do
b ← visitFnBody b,
b ← visitFnBody b;
castVarIfNeeded y IRType.usize $ λ y,
pure $ FnBody.uset x i y b
| (FnBody.sset x i o y ty b) := do
b ← visitFnBody b,
b ← visitFnBody b;
castVarIfNeeded y ty $ λ y,
pure $ FnBody.sset x i o y ty b
| (FnBody.mdata d b) :=
FnBody.mdata d <$> visitFnBody b
| (FnBody.case tid x alts) := do
let expected := getScrutineeType alts,
alts ← alts.mmap $ λ alt, alt.mmodifyBody visitFnBody,
let expected := getScrutineeType alts;
alts ← alts.mmap $ λ alt, alt.mmodifyBody visitFnBody;
castVarIfNeeded x expected $ λ x,
pure $ FnBody.case tid x alts
| (FnBody.ret x) := do
expected ← getResultType,
expected ← getResultType;
castArgIfNeeded x expected (λ x, pure $ FnBody.ret x)
| (FnBody.jmp j ys) := do
ps ← getJPParams j,
ps ← getJPParams j;
castArgsIfNeeded ys ps (λ ys, pure $ FnBody.jmp j ys)
| other :=
pure other
@ -269,7 +269,7 @@ addBoxedVersions env decls
end ExplicitBoxing
def explicitBoxing (decls : Array Decl) : CompilerM (Array Decl) :=
do env ← getEnv,
do env ← getEnv;
pure $ ExplicitBoxing.run env decls
end IR

View file

@ -17,17 +17,17 @@ structure Context :=
abbrev M := ExceptT String (ReaderT Context Id)
def getDecl (c : Name) : M Decl :=
do ctx ← read,
do ctx ← read;
match findEnvDecl' ctx.env c ctx.decls with
| none := throw ("unknown declaration '" ++ toString c ++ "'")
| some d := pure d
def checkVar (x : VarId) : M Unit :=
do ctx ← read,
do ctx ← read;
unless (ctx.localCtx.isLocalVar x.idx || ctx.localCtx.isParam x.idx) $ throw ("unknown variable '" ++ toString x ++ "'")
def checkJP (j : JoinPointId) : M Unit :=
do ctx ← read,
do ctx ← read;
unless (ctx.localCtx.isJP j.idx) $ throw ("unknown join point '" ++ toString j ++ "'")
def checkArg (a : Arg) : M Unit :=
@ -46,7 +46,7 @@ def checkObjType (ty : IRType) : M Unit := checkType ty IRType.isObj
def checkScalarType (ty : IRType) : M Unit := checkType ty IRType.isScalar
@[inline] def checkVarType (x : VarId) (p : IRType → Bool) : M Unit :=
do ctx ← read,
do ctx ← read;
match ctx.localCtx.getType x with
| some ty := checkType ty p
| none := throw ("unknown variable '" ++ toString x ++ "'")
@ -59,14 +59,14 @@ checkVarType x IRType.isScalar
def checkFullApp (c : FunId) (ys : Array Arg) : M Unit :=
do
decl ← getDecl c,
unless (ys.size == decl.params.size) (throw ("incorrect number of arguments to '" ++ toString c ++ "', " ++ toString ys.size ++ " provided, " ++ toString decl.params.size ++ " expected")),
decl ← getDecl c;
unless (ys.size == decl.params.size) (throw ("incorrect number of arguments to '" ++ toString c ++ "', " ++ toString ys.size ++ " provided, " ++ toString decl.params.size ++ " expected"));
checkArgs ys
def checkPartialApp (c : FunId) (ys : Array Arg) : M Unit :=
do
decl ← getDecl c,
unless (ys.size < decl.params.size) (throw ("too many arguments too partial application '" ++ toString c ++ "', num. args: " ++ toString ys.size ++ ", arity: " ++ toString decl.params.size)),
decl ← getDecl c;
unless (ys.size < decl.params.size) (throw ("too many arguments too partial application '" ++ toString c ++ "', num. args: " ++ toString ys.size ++ ", arity: " ++ toString decl.params.size));
checkArgs ys
def checkExpr (ty : IRType) : Expr → M Unit
@ -87,22 +87,22 @@ def checkExpr (ty : IRType) : Expr → M Unit
| (Expr.lit _) := pure ()
@[inline] def withParams (ps : Array Param) (k : M Unit) : M Unit :=
do ctx ← read,
do ctx ← read;
localCtx ← ps.mfoldl (λ (ctx : LocalContext) p, do
when (ctx.contains p.x.idx) $ throw ("invalid parameter declaration, shadowing is not allowed"),
pure $ ctx.addParam p) ctx.localCtx,
when (ctx.contains p.x.idx) $ throw ("invalid parameter declaration, shadowing is not allowed");
pure $ ctx.addParam p) ctx.localCtx;
adaptReader (λ _, { localCtx := localCtx, .. ctx }) k
partial def checkFnBody : FnBody → M Unit
| (FnBody.vdecl x t v b) := do
checkExpr t v,
ctx ← read,
when (ctx.localCtx.contains x.idx) $ throw ("invalid variable declaration, shadowing is not allowed"),
checkExpr t v;
ctx ← read;
when (ctx.localCtx.contains x.idx) $ throw ("invalid variable declaration, shadowing is not allowed");
adaptReader (λ ctx : Context, { localCtx := ctx.localCtx.addLocal x t v, .. ctx }) (checkFnBody b)
| (FnBody.jdecl j ys v b) := do
withParams ys (checkFnBody v),
ctx ← read,
when (ctx.localCtx.contains j.idx) $ throw ("invalid join point declaration, shadowing is not allowed"),
withParams ys (checkFnBody v);
ctx ← read;
when (ctx.localCtx.contains j.idx) $ throw ("invalid join point declaration, shadowing is not allowed");
adaptReader (λ ctx : Context, { localCtx := ctx.localCtx.addJP j ys v, .. ctx }) (checkFnBody b)
| (FnBody.set x _ y b) := checkVar x *> checkArg y *> checkFnBody b
| (FnBody.uset x _ y b) := checkVar x *> checkVar y *> checkFnBody b
@ -125,7 +125,7 @@ end Checker
def checkDecl (decls : Array Decl) (decl : Decl) : CompilerM Unit :=
do
env ← getEnv,
env ← getEnv;
match Checker.checkDecl decl { env := env, decls := decls } with
| Except.error msg := throw ("IR check failed at '" ++ toString decl.name ++ "', error: " ++ msg)
| other := pure ()

View file

@ -49,14 +49,14 @@ match opts.find optName with
| other := opts.getBool tracePrefixOptionName
private def logDeclsAux (optName : Name) (cls : Name) (decls : Array Decl) : CompilerM Unit :=
do opts ← read,
do opts ← read;
when (isLogEnabledFor opts optName) $ log (LogEntry.step cls decls)
@[inline] def logDecls (cls : Name) (decl : Array Decl) : CompilerM Unit :=
logDeclsAux (tracePrefixOptionName ++ cls) cls decl
private def logMessageIfAux {α : Type} [HasFormat α] (optName : Name) (a : α) : CompilerM Unit :=
do opts ← read,
do opts ← read;
when (isLogEnabledFor opts optName) $ log (LogEntry.message (format a))
@[inline] def logMessageIf {α : Type} [HasFormat α] (cls : Name) (a : α) : CompilerM Unit :=
@ -95,15 +95,15 @@ def findEnvDecl (env : Environment) (n : Name) : Option Decl :=
(declMapExt.getState env).find n
def findDecl (n : Name) : CompilerM (Option Decl) :=
do s ← get,
do s ← get;
pure $ findEnvDecl s.env n
def containsDecl (n : Name) : CompilerM Bool :=
do s ← get,
do s ← get;
pure $ (declMapExt.getState s.env).contains n
def getDecl (n : Name) : CompilerM Decl :=
do (some decl) ← findDecl n | throw ("unknown declaration '" ++ toString n ++ "'"),
do (some decl) ← findDecl n | throw ("unknown declaration '" ++ toString n ++ "'");
pure decl
@[export lean.ir.add_decl_core]
@ -114,7 +114,7 @@ def getDecls (env : Environment) : List Decl :=
declMapExt.getEntries env
def getEnv : CompilerM Environment :=
do s ← get, pure s.env
do s ← get; pure s.env
def addDecl (decl : Decl) : CompilerM Unit :=
modifyEnv (λ env, declMapExt.addEntry env decl)
@ -128,16 +128,16 @@ match decls.find (λ decl, if decl.name == n then some decl else none) with
| none := (declMapExt.getState env).find n
def findDecl' (n : Name) (decls : Array Decl) : CompilerM (Option Decl) :=
do s ← get, pure $ findEnvDecl' s.env n decls
do s ← get; pure $ findEnvDecl' s.env n decls
def containsDecl' (n : Name) (decls : Array Decl) : CompilerM Bool :=
if decls.any (λ decl, decl.name == n) then pure true
else do
s ← get,
s ← get;
pure $ (declMapExt.getState s.env).contains n
def getDecl' (n : Name) (decls : Array Decl) : CompilerM Decl :=
do (some decl) ← findDecl' n decls | throw ("unknown declaration '" ++ toString n ++ "'"),
do (some decl) ← findDecl' n decls | throw ("unknown declaration '" ++ toString n ++ "'");
pure decl
end IR

View file

@ -24,30 +24,30 @@ namespace IR
private def compileAux (decls : Array Decl) : CompilerM Unit :=
do
logDecls `init decls,
checkDecls decls,
let decls := decls.map Decl.pushProj,
logDecls `push_proj decls,
let decls := decls.map Decl.insertResetReuse,
logDecls `reset_reuse decls,
let decls := decls.map Decl.elimDead,
logDecls `elim_dead decls,
let decls := decls.map Decl.simpCase,
logDecls `simp_case decls,
let decls := decls.map Decl.normalizeIds,
decls ← inferBorrow decls,
logDecls `borrow decls,
decls ← explicitBoxing decls,
logDecls `boxing decls,
decls ← explicitRC decls,
logDecls `rc decls,
let decls := decls.map Decl.expandResetReuse,
logDecls `expand_reset_reuse decls,
let decls := decls.map Decl.pushProj,
logDecls `push_proj decls,
logDecls `result decls,
checkDecls decls,
addDecls decls,
logDecls `init decls;
checkDecls decls;
let decls := decls.map Decl.pushProj;
logDecls `push_proj decls;
let decls := decls.map Decl.insertResetReuse;
logDecls `reset_reuse decls;
let decls := decls.map Decl.elimDead;
logDecls `elim_dead decls;
let decls := decls.map Decl.simpCase;
logDecls `simp_case decls;
let decls := decls.map Decl.normalizeIds;
decls ← inferBorrow decls;
logDecls `borrow decls;
decls ← explicitBoxing decls;
logDecls `boxing decls;
decls ← explicitRC decls;
logDecls `rc decls;
let decls := decls.map Decl.expandResetReuse;
logDecls `expand_reset_reuse decls;
let decls := decls.map Decl.pushProj;
logDecls `push_proj decls;
logDecls `result decls;
checkDecls decls;
addDecls decls;
pure ()
@[export lean.ir.compile_core]

View file

@ -35,7 +35,7 @@ abbrev M := ReaderT Context (EState String String)
def getEnv : M Environment := Context.env <$> read
def getModName : M Name := Context.modName <$> read
def getDecl (n : Name) : M Decl :=
do env ← getEnv,
do env ← getEnv;
match findEnvDecl env n with
| some d := pure d
| none := throw ("unknown declaration '" ++ toString n ++ "'")
@ -77,7 +77,7 @@ def openNamespaces (n : Name) : M Unit :=
openNamespacesAux n.getPrefix
def openNamespacesFor (n : Name) : M Unit :=
do env ← getEnv,
do env ← getEnv;
match getExportNameFor env n with
| none := pure ()
| some n := openNamespaces n
@ -91,7 +91,7 @@ def closeNamespaces (n : Name) : M Unit :=
closeNamespacesAux n.getPrefix
def closeNamespacesFor (n : Name) : M Unit :=
do env ← getEnv,
do env ← getEnv;
match getExportNameFor env n with
| none := pure ()
| some n := closeNamespaces n
@ -100,14 +100,14 @@ def throwInvalidExportName {α : Type} (n : Name) : M α :=
throw ("invalid export name '" ++ toString n ++ "'")
def toBaseCppName (n : Name) : M String :=
do env ← getEnv,
do env ← getEnv;
match getExportNameFor env n with
| some (Name.mkString _ s) := pure s
| some _ := throwInvalidExportName n
| none := if n == `main then pure leanMainFn else pure n.mangle
def toCppName (n : Name) : M String :=
do env ← getEnv,
do env ← getEnv;
match getExportNameFor env n with
| some s := pure (s.toStringWithSep "::")
| none := if n == `main then pure leanMainFn else pure n.mangle
@ -116,7 +116,7 @@ def emitCppName (n : Name) : M Unit :=
toCppName n >>= emit
def toCppInitName (n : Name) : M String :=
do env ← getEnv,
do env ← getEnv;
match getExportNameFor env n with
| some (Name.mkString p s) := pure $ (Name.mkString p ("_init_" ++ s)).toStringWithSep "::"
| some _ := throwInvalidExportName n
@ -127,24 +127,24 @@ toCppInitName n >>= emit
def emitFnDeclAux (decl : Decl) (cppBaseName : String) (addExternForConsts : Bool) : M Unit :=
do
let ps := decl.params,
when (ps.isEmpty && addExternForConsts) (emit "extern "),
emit (toCppType decl.resultType ++ " " ++ cppBaseName),
let ps := decl.params;
when (ps.isEmpty && addExternForConsts) (emit "extern ");
emit (toCppType decl.resultType ++ " " ++ cppBaseName);
unless (ps.isEmpty) $ do {
emit "(",
emit "(";
ps.size.mfor $ λ i, do {
when (i > 0) (emit ", "),
when (i > 0) (emit ", ");
emit (toCppType (ps.get i).ty)
},
};
emit ")"
},
};
emitLn ";"
def emitFnDecl (decl : Decl) (addExternForConsts : Bool) : M Unit :=
do
openNamespacesFor decl.name,
cppBaseName ← toBaseCppName decl.name,
emitFnDeclAux decl cppBaseName addExternForConsts,
openNamespacesFor decl.name;
cppBaseName ← toBaseCppName decl.name;
emitFnDeclAux decl cppBaseName addExternForConsts;
closeNamespacesFor decl.name
def cppQualifiedNameToName (s : String) : Name :=
@ -152,48 +152,48 @@ def cppQualifiedNameToName (s : String) : Name :=
def emitExternDeclAux (decl : Decl) (cppName : String) : M Unit :=
do
let qCppName := cppQualifiedNameToName cppName,
openNamespaces qCppName,
env ← getEnv,
let extC := isExternC env decl.name,
when extC (emit "extern \"C\" "),
(Name.mkString _ qCppBaseName) ← pure qCppName | throw "invalid name",
emitFnDeclAux decl qCppBaseName (!extC),
let qCppName := cppQualifiedNameToName cppName;
openNamespaces qCppName;
env ← getEnv;
let extC := isExternC env decl.name;
when extC (emit "extern \"C\" ");
(Name.mkString _ qCppBaseName) ← pure qCppName | throw "invalid name";
emitFnDeclAux decl qCppBaseName (!extC);
closeNamespaces qCppName
def emitFnDecls : M Unit :=
do
env ← getEnv,
let decls := getDecls env,
let modDecls : NameSet := decls.foldl (λ s d, s.insert d.name) {},
let usedDecls : NameSet := decls.foldl (λ s d, collectUsedDecls env d (s.insert d.name)) {},
let usedDecls := usedDecls.toList,
env ← getEnv;
let decls := getDecls env;
let modDecls : NameSet := decls.foldl (λ s d, s.insert d.name) {};
let usedDecls : NameSet := decls.foldl (λ s d, collectUsedDecls env d (s.insert d.name)) {};
let usedDecls := usedDecls.toList;
usedDecls.mfor $ λ n, do
decl ← getDecl n,
decl ← getDecl n;
match getExternNameFor env `cpp decl.name with
| some cppName := emitExternDeclAux decl cppName
| none := emitFnDecl decl (!modDecls.contains n)
def emitMainFn : M Unit :=
do
d ← getDecl `main,
d ← getDecl `main;
match d with
| Decl.fdecl f xs t b := do
unless (xs.size == 2 || xs.size == 1) (throw "invalid main function, incorrect arity when generating code"),
env ← getEnv,
let usesLeanAPI := usesLeanNamespace env d,
when usesLeanAPI (emitLn "namespace lean { void initialize(); }"),
emitLn "int main(int argc, char ** argv) {",
unless (xs.size == 2 || xs.size == 1) (throw "invalid main function, incorrect arity when generating code");
env ← getEnv;
let usesLeanAPI := usesLeanNamespace env d;
when usesLeanAPI (emitLn "namespace lean { void initialize(); }");
emitLn "int main(int argc, char ** argv) {";
if usesLeanAPI then
emitLn "lean::initialize();"
else
emitLn "lean::initialize_runtime_module();",
emitLn "obj * w = lean::io_mk_world();",
modName ← getModName,
emitLn ("w = initialize_" ++ (modName.mangle "") ++ "(w);"),
emitLn "lean::initialize_runtime_module();";
emitLn "obj * w = lean::io_mk_world();";
modName ← getModName;
emitLn ("w = initialize_" ++ (modName.mangle "") ++ "(w);");
emitLns ["lean::io_mark_end_initialization();",
"if (io_result_is_ok(w)) {",
"lean::scoped_task_manager tmanager(lean::hardware_concurrency());"],
"lean::scoped_task_manager tmanager(lean::hardware_concurrency());"];
if xs.size == 2 then do {
emitLns ["obj* in = lean::box(0);",
"int i = argc;",
@ -201,12 +201,12 @@ match d with
" i--;",
" obj* n = lean::alloc_cnstr(1,2,0); lean::cnstr_set(n, 0, lean::mk_string(argv[i])); lean::cnstr_set(n, 1, in);",
" in = n;",
"}"],
"}"];
emitLn ("w = " ++ leanMainFn ++ "(in, w);")
} else do {
emitLn ("w = " ++ leanMainFn ++ "(w);")
},
emitLn "}",
};
emitLn "}";
emitLns ["if (io_result_is_ok(w)) {",
" int ret = lean::unbox(io_result_get_value(w));",
" lean::dec_ref(w);",
@ -215,13 +215,13 @@ match d with
" lean::io_result_show_error(w);",
" lean::dec_ref(w);",
" return 1;",
"}"],
"}"];
emitLn "}"
| other := throw "function declaration expected"
def hasMainFn : M Bool :=
do env ← getEnv,
let decls := getDecls env,
do env ← getEnv;
let decls := getDecls env;
pure $ decls.any (λ d, d.name == `main)
def emitMainFnIfNeeded : M Unit :=
@ -229,16 +229,16 @@ mwhen hasMainFn emitMainFn
def emitFileHeader : M Unit :=
do
env ← getEnv,
modName ← getModName,
emitLn "// Lean compiler output",
emitLn ("// Module: " ++ toString modName),
emit "// Imports:",
env.imports.mfor $ λ m, emit (" " ++ toString m),
emitLn "",
emitLn "#include \"runtime/object.h\"",
emitLn "#include \"runtime/apply.h\"",
mwhen hasMainFn $ emitLn "#include \"runtime/init_module.h\"",
env ← getEnv;
modName ← getModName;
emitLn "// Lean compiler output";
emitLn ("// Module: " ++ toString modName);
emit "// Imports:";
env.imports.mfor $ λ m, emit (" " ++ toString m);
emitLn "";
emitLn "#include \"runtime/object.h\"";
emitLn "#include \"runtime/apply.h\"";
mwhen hasMainFn $ emitLn "#include \"runtime/init_module.h\"";
emitLns [
"typedef lean::object obj; typedef lean::usize usize;",
"typedef lean::uint8 uint8; typedef lean::uint16 uint16;",
@ -256,26 +256,26 @@ def throwUnknownVar {α : Type} (x : VarId) : M α :=
throw ("unknown variable '" ++ toString x ++ "'")
def isObj (x : VarId) : M Bool :=
do ctx ← read,
do ctx ← read;
match ctx.varMap.find x with
| some t := pure t.isObj
| none := throwUnknownVar x
def getJPParams (j : JoinPointId) : M (Array Param) :=
do ctx ← read,
do ctx ← read;
match ctx.jpMap.find j with
| some ps := pure ps
| none := throw "unknown join point"
def declareVar (x : VarId) (t : IRType) : M Unit :=
do emit (toCppType t), emit " ", emit x, emit "; "
do emit (toCppType t); emit " "; emit x; emit "; "
def declareParams (ps : Array Param) : M Unit :=
ps.mfor $ λ p, declareVar p.x p.ty
partial def declareVars : FnBody → Bool → M Bool
| e@(FnBody.vdecl x t _ b) d := do
ctx ← read,
ctx ← read;
if isTailCallTo ctx.mainFn e then
pure d
else
@ -285,9 +285,9 @@ partial def declareVars : FnBody → Bool → M Bool
def emitTag (x : VarId) : M Unit :=
do
xIsObj ← isObj x,
xIsObj ← isObj x;
if xIsObj then do
emit "lean::obj_tag(", emit x, emit ")"
emit "lean::obj_tag("; emit x; emit ")"
else
emit x
@ -299,75 +299,75 @@ else match alts.get 0 with
def emitIf (emitBody : FnBody → M Unit) (x : VarId) (tag : Nat) (t : FnBody) (e : FnBody) : M Unit :=
do
emit "if (", emitTag x, emit " == ", emit tag, emitLn ")",
emitBody t,
emitLn "else",
emit "if ("; emitTag x; emit " == "; emit tag; emitLn ")";
emitBody t;
emitLn "else";
emitBody e
def emitCase (emitBody : FnBody → M Unit) (x : VarId) (alts : Array Alt) : M Unit :=
match isIf alts with
| some (tag, t, e) := emitIf emitBody x tag t e
| _ := do
emit "switch (", emitTag x, emitLn ") {",
let alts := ensureHasDefault alts,
emit "switch ("; emitTag x; emitLn ") {";
let alts := ensureHasDefault alts;
alts.mfor $ λ alt, match alt with
| Alt.ctor c b := emit "case " *> emit c.cidx *> emitLn ":" *> emitBody b
| Alt.default b := emitLn "default: " *> emitBody b,
| Alt.default b := emitLn "default: " *> emitBody b;
emitLn "}"
def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit :=
do
emit (if checkRef then "lean::inc" else "lean::inc_ref"),
emit "(" *> emit x,
when (n != 1) (emit ", " *> emit n),
emit (if checkRef then "lean::inc" else "lean::inc_ref");
emit "(" *> emit x;
when (n != 1) (emit ", " *> emit n);
emitLn ");"
def emitDec (x : VarId) (n : Nat) (checkRef : Bool) : M Unit :=
do
emit (if checkRef then "lean::dec" else "lean::dec_ref"),
emit "(" *> emit x,
when (n != 1) (emit ", " *> emit n),
emit (if checkRef then "lean::dec" else "lean::dec_ref");
emit "("; emit x;
when (n != 1) (do emit ", "; emit n);
emitLn ");"
def emitDel (x : VarId) : M Unit :=
do emit "lean::free_heap_obj(", emit x, emitLn ");"
do emit "lean::free_heap_obj("; emit x; emitLn ");"
def emitSetTag (x : VarId) (i : Nat) : M Unit :=
do emit "lean::cnstr_set_tag(", emit x, emit ", ", emit i, emitLn ");"
do emit "lean::cnstr_set_tag("; emit x; emit ", "; emit i; emitLn ");"
def emitSet (x : VarId) (i : Nat) (y : Arg) : M Unit :=
do emit "lean::cnstr_set(", emit x, emit ", ", emit i, emit ", ", emitArg y, emitLn ");"
do emit "lean::cnstr_set("; emit x; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
def emitOffset (n : Nat) (offset : Nat) : M Unit :=
if n > 0 then do
emit "sizeof(void*)*", emit n,
emit "sizeof(void*)*"; emit n;
when (offset > 0) (emit " + " *> emit offset)
else
emit offset
def emitUSet (x : VarId) (n : Nat) (y : VarId) : M Unit :=
do emit "lean::cnstr_set_scalar(", emit x, emit ", ", emitOffset n 0, emit ", ", emit y, emitLn ");"
do emit "lean::cnstr_set_scalar("; emit x; emit ", "; emitOffset n 0; emit ", "; emit y; emitLn ");"
def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) : M Unit :=
do emit "lean::cnstr_set_scalar(", emit x, emit ", ", emitOffset n offset, emit ", ", emit y, emitLn ");"
do emit "lean::cnstr_set_scalar("; emit x; emit ", "; emitOffset n offset; emit ", "; emit y; emitLn ");"
def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit :=
do
ps ← getJPParams j,
unless (xs.size == ps.size) (throw "invalid goto"),
ps ← getJPParams j;
unless (xs.size == ps.size) (throw "invalid goto");
xs.size.mfor $ λ i, do {
let p := ps.get i,
let x := xs.get i,
emit p.x, emit " = ", emitArg x, emitLn ";"
},
emit "goto ", emit j, emitLn ";"
let p := ps.get i;
let x := xs.get i;
emit p.x; emit " = "; emitArg x; emitLn ";"
};
emit "goto "; emit j; emitLn ";"
def emitLhs (z : VarId) : M Unit :=
do emit z, emit " = "
do emit z; emit " = "
def emitArgs (ys : Array Arg) : M Unit :=
ys.size.mfor $ λ i, do
when (i > 0) (emit ", "),
when (i > 0) (emit ", ");
emitArg (ys.get i)
def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit :=
@ -377,82 +377,82 @@ else emit "sizeof(size_t)*" *> emit usize *> emit " + " *> emit ssize
def emitAllocCtor (c : CtorInfo) : M Unit :=
do
emit "lean::alloc_cnstr(", emit c.cidx, emit ", ", emit c.size, emit ", ",
emitCtorScalarSize c.usize c.ssize, emitLn ");"
emit "lean::alloc_cnstr("; emit c.cidx; emit ", "; emit c.size; emit ", ";
emitCtorScalarSize c.usize c.ssize; emitLn ");"
def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit :=
ys.size.mfor $ λ i, do
emit "lean::cnstr_set(", emit z, emit ", ", emit i, emit ", ", emitArg (ys.get i), emitLn ");"
emit "lean::cnstr_set("; emit z; emit ", "; emit i; emit ", "; emitArg (ys.get i); emitLn ");"
def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit :=
do
emitLhs z,
emitLhs z;
if c.size == 0 && c.usize == 0 && c.ssize == 0 then do
emit "lean::box(", emit c.cidx, emitLn ");"
emit "lean::box("; emit c.cidx; emitLn ");"
else do
emitAllocCtor c, emitCtorSetArgs z ys
emitAllocCtor c; emitCtorSetArgs z ys
def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit :=
do
emit "if (lean::is_exclusive(", emit x, emitLn ")) {",
emit "if (lean::is_exclusive("; emit x; emitLn ")) {";
n.mfor $ λ i, do {
emit " lean::cnstr_release(", emit x, emit ", ", emit i, emitLn ");"
},
emit " ", emitLhs z, emit x, emitLn ";",
emitLn "} else {",
emit " lean::dec_ref(", emit x, emitLn ");",
emit " ", emitLhs z, emitLn "lean::box(0);",
emit " lean::cnstr_release("; emit x; emit ", "; emit i; emitLn ");"
};
emit " "; emitLhs z; emit x; emitLn ";";
emitLn "} else {";
emit " lean::dec_ref("; emit x; emitLn ");";
emit " "; emitLhs z; emitLn "lean::box(0);";
emitLn "}"
def emitReuse (z : VarId) (x : VarId) (c : CtorInfo) (updtHeader : Bool) (ys : Array Arg) : M Unit :=
do
emit "if (lean::is_scalar(", emit x, emitLn ")) {",
emit " ", emitLhs z, emitAllocCtor c,
emitLn "} else {",
emit " ", emitLhs z, emit x, emitLn ";",
when updtHeader (do emit " lean::cnstr_set_tag(", emit z, emit ", ", emit c.cidx, emitLn ");"),
emitLn "}",
emit "if (lean::is_scalar("; emit x; emitLn ")) {";
emit " "; emitLhs z; emitAllocCtor c;
emitLn "} else {";
emit " "; emitLhs z; emit x; emitLn ";";
when updtHeader (do emit " lean::cnstr_set_tag("; emit z; emit ", "; emit c.cidx; emitLn ");");
emitLn "}";
emitCtorSetArgs z ys
def emitProj (z : VarId) (i : Nat) (x : VarId) : M Unit :=
do emitLhs z, emit "lean::cnstr_get(", emit x, emit ", ", emit i, emitLn ");"
do emitLhs z; emit "lean::cnstr_get("; emit x; emit ", "; emit i; emitLn ");"
def emitUProj (z : VarId) (i : Nat) (x : VarId) : M Unit :=
do emitLhs z, emit "lean::cnstr_get_scalar<usize>(", emit x, emit ", sizeof(void*)*", emit i, emitLn ");"
do emitLhs z; emit "lean::cnstr_get_scalar<usize>("; emit x; emit ", sizeof(void*)*"; emit i; emitLn ");"
def emitSProj (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M Unit :=
do emitLhs z, emit "lean::cnstr_get_scalar<", emit (toCppType t), emit ">(", emit x, emit ", ", emitOffset n offset, emitLn ");"
do emitLhs z; emit "lean::cnstr_get_scalar<"; emit (toCppType t); emit ">("; emit x; emit ", "; emitOffset n offset; emitLn ");"
def toStringArgs (ys : Array Arg) : List String :=
ys.toList.map argToCppString
def emitFullApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit :=
do
emitLhs z,
decl ← getDecl f,
emitLhs z;
decl ← getDecl f;
match decl with
| Decl.extern _ _ _ extData :=
match mkExternCall extData `cpp (toStringArgs ys) with
| some c := emit c *> emitLn ";"
| none := throw "failed to emit extern application"
| _ := do emitCppName f, when (ys.size > 0) (do emit "(", emitArgs ys, emit ")"), emitLn ";"
| _ := do emitCppName f; when (ys.size > 0) (do emit "("; emitArgs ys; emit ")"); emitLn ";"
def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit :=
do
decl ← getDecl f,
let arity := decl.params.size,
emitLhs z, emit "lean::alloc_closure(reinterpret_cast<void*>(", emitCppName f, emit "), ", emit arity, emit ", ", emit ys.size, emitLn ");",
decl ← getDecl f;
let arity := decl.params.size;
emitLhs z; emit "lean::alloc_closure(reinterpret_cast<void*>("; emitCppName f; emit "), "; emit arity; emit ", "; emit ys.size; emitLn ");";
ys.size.mfor $ λ i, do {
let y := ys.get i,
emit "lean::closure_set(", emit z, emit ", ", emit i, emit ", ", emitArg y, emitLn ");"
let y := ys.get i;
emit "lean::closure_set("; emit z; emit ", "; emit i; emit ", "; emitArg y; emitLn ");"
}
def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=
if ys.size > closureMaxArgs then do
emit "{ obj* _aargs[] = {", emitArgs ys, emitLn "};",
emitLhs z, emit "lean::apply_m(", emit f, emit ", ", emit ys.size, emitLn ", _aargs); }"
emit "{ obj* _aargs[] = {"; emitArgs ys; emitLn "};";
emitLhs z; emit "lean::apply_m("; emit f; emit ", "; emit ys.size; emitLn ", _aargs); }"
else do
emitLhs z, emit "lean::apply_", emit ys.size, emit "(", emit f, emit ", ", emitArgs ys, emitLn ");"
emitLhs z; emit "lean::apply_"; emit ys.size; emit "("; emit f; emit ", "; emitArgs ys; emitLn ");"
def emitBoxFn (xType : IRType) : M Unit :=
match xType with
@ -463,24 +463,24 @@ match xType with
| other := emit "lean::box"
def emitBox (z : VarId) (x : VarId) (xType : IRType) : M Unit :=
do emitLhs z, emitBoxFn xType, emit "(", emit x, emitLn ");"
do emitLhs z; emitBoxFn xType; emit "("; emit x; emitLn ");"
def emitUnbox (z : VarId) (t : IRType) (x : VarId) : M Unit :=
do
emitLhs z,
emitLhs z;
match t with
| IRType.usize := emit "lean::unbox_size_t"
| IRType.uint32 := emit "lean::unbox_uint32"
| IRType.uint64 := emit "lean::unbox_uint64"
| IRType.float := throw "floats are not supported yet"
| other := emit "lean::unbox",
emit "(", emit x, emitLn ");"
| other := emit "lean::unbox";
emit "("; emit x; emitLn ");"
def emitIsShared (z : VarId) (x : VarId) : M Unit :=
do emitLhs z, emit "!lean::is_exclusive(", emit x, emitLn ");"
do emitLhs z; emit "!lean::is_exclusive("; emit x; emitLn ");"
def emitIsTaggedPtr (z : VarId) (x : VarId) : M Unit :=
do emitLhs z, emit "!lean::is_scalar(", emit x, emitLn ");"
do emitLhs z; emit "!lean::is_scalar("; emit x; emitLn ");"
def toHexDigit (c : Nat) : String :=
String.singleton c.digitChar
@ -501,9 +501,9 @@ q ++ "\""
def emitNumLit (t : IRType) (v : Nat) : M Unit :=
if t.isObj then do
emit "lean::mk_nat_obj(",
emit "lean::mk_nat_obj(";
if v < uint32Sz then emit v *> emit "u"
else emit "lean::mpz(\"" *> emit v *> emit "\")",
else emit "lean::mpz(\"" *> emit v *> emit "\")";
emit ")"
else
emit v
@ -512,7 +512,7 @@ def emitLit (z : VarId) (t : IRType) (v : LitVal) : M Unit :=
emitLhs z *>
match v with
| LitVal.num v := emitNumLit t v *> emitLn ";"
| LitVal.str v := do emit "lean::mk_string(", emit (quoteString v), emitLn ");"
| LitVal.str v := do emit "lean::mk_string("; emit (quoteString v); emitLn ");"
def emitVDecl (z : VarId) (t : IRType) (v : Expr) : M Unit :=
match v with
@ -533,7 +533,7 @@ match v with
def isTailCall (x : VarId) (v : Expr) (b : FnBody) : M Bool :=
do
ctx ← read,
ctx ← read;
match v, b with
| Expr.fap f _, FnBody.ret (Arg.var y) := pure $ f == ctx.mainFn && x == y
| _, _ := pure false
@ -567,35 +567,38 @@ n.any $ λ i,
def emitTailCall (v : Expr) : M Unit :=
match v with
| Expr.fap _ ys := do
ctx ← read,
let ps := ctx.mainParams,
unless (ps.size == ys.size) (throw "invalid tail call"),
ctx ← read;
let ps := ctx.mainParams;
unless (ps.size == ys.size) (throw "invalid tail call");
if overwriteParam ps ys then do {
emitLn "{",
emitLn "{";
ps.size.mfor $ λ i, do {
let p := ps.get i, let y := ys.get i,
let p := ps.get i;
let y := ys.get i;
unless (paramEqArg p y) $ do {
emit (toCppType p.ty), emit " _tmp_", emit i, emit " = ", emitArg y, emitLn ";"
emit (toCppType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
}
},
};
ps.size.mfor $ λ i, do {
let p := ps.get i, let y := ys.get i,
unless (paramEqArg p y) (do emit p.x, emit " = _tmp_", emit i, emitLn ";")
},
let p := ps.get i;
let y := ys.get i;
unless (paramEqArg p y) (do emit p.x; emit " = _tmp_"; emit i; emitLn ";")
};
emitLn "}"
} else do {
ys.size.mfor $ λ i, do {
let p := ps.get i, let y := ys.get i,
unless (paramEqArg p y) (do emit p.x, emit " = ", emitArg y, emitLn ";")
let p := ps.get i;
let y := ys.get i;
unless (paramEqArg p y) (do emit p.x; emit " = "; emitArg y; emitLn ";")
}
},
};
emitLn "goto _start;"
| _ := throw "bug at emitTailCall"
partial def emitBlock (emitBody : FnBody → M Unit) : FnBody → M Unit
| (FnBody.jdecl j xs v b) := emitBlock b
| d@(FnBody.vdecl x t v b) :=
do ctx ← read, if isTailCallTo ctx.mainFn d then emitTailCall v else emitVDecl x t v *> emitBlock b
do ctx ← read; if isTailCallTo ctx.mainFn d then emitTailCall v else emitVDecl x t v *> emitBlock b
| (FnBody.inc x n c b) := emitInc x n c *> emitBlock b
| (FnBody.dec x n c b) := emitDec x n c *> emitBlock b
| (FnBody.del x b) := emitDel x *> emitBlock b
@ -610,45 +613,45 @@ partial def emitBlock (emitBody : FnBody → M Unit) : FnBody → M Unit
| FnBody.unreachable := emitLn "lean_unreachable();"
partial def emitJPs (emitBody : FnBody → M Unit) : FnBody → M Unit
| (FnBody.jdecl j xs v b) := do emit j, emitLn ":", emitBody v, emitJPs b
| (FnBody.jdecl j xs v b) := do emit j; emitLn ":"; emitBody v; emitJPs b
| e := unless e.isTerminal (emitJPs e.body)
partial def emitFnBody : FnBody → M Unit
| b := do
emitLn "{",
declared ← declareVars b false,
when declared (emitLn ""),
emitBlock emitFnBody b,
emitJPs emitFnBody b,
emitLn "{";
declared ← declareVars b false;
when declared (emitLn "");
emitBlock emitFnBody b;
emitJPs emitFnBody b;
emitLn "}"
def emitDeclAux (d : Decl) : M Unit :=
do
env ← getEnv,
let (vMap, jpMap) := mkVarJPMaps d,
env ← getEnv;
let (vMap, jpMap) := mkVarJPMaps d;
adaptReader (λ ctx : Context, { varMap := vMap, jpMap := jpMap, .. ctx }) $ do
unless (hasInitAttr env d.name) $
match d with
| Decl.fdecl f xs t b := do
openNamespacesFor f,
baseName ← toBaseCppName f,
emit (toCppType t), emit " ",
openNamespacesFor f;
baseName ← toBaseCppName f;
emit (toCppType t); emit " ";
if xs.size > 0 then do {
emit baseName,
emit "(",
emit baseName;
emit "(";
xs.size.mfor $ λ i, do {
when (i > 0) (emit ", "),
let x := xs.get i,
emit (toCppType x.ty), emit " ", emit(x.x)
},
when (i > 0) (emit ", ");
let x := xs.get i;
emit (toCppType x.ty); emit " "; emit(x.x)
};
emit ")"
} else do {
emit ("_init_" ++ baseName ++ "()")
},
emitLn " {",
emitLn "_start:",
adaptReader (λ ctx : Context, { mainFn := f, mainParams := xs, .. ctx }) (emitFnBody b),
emitLn "}",
};
emitLn " {";
emitLn "_start:";
adaptReader (λ ctx : Context, { mainFn := f, mainParams := xs, .. ctx }) (emitFnBody b);
emitLn "}";
closeNamespacesFor f
| _ := pure ()
@ -660,8 +663,8 @@ catch
def emitFns : M Unit :=
do
env ← getEnv,
let decls := getDecls env,
env ← getEnv;
let decls := getDecls env;
decls.reverse.mfor emitDecl
def quoteNameAux : Name → Option String
@ -677,67 +680,67 @@ else quoteNameAux n
def emitDeclInit (d : Decl) : M Unit :=
do
env ← getEnv,
let n := d.name,
env ← getEnv;
let n := d.name;
if isIOUnitInitFn env n then do {
emit "w = ", emitCppName n, emitLn "(w);",
emit "w = "; emitCppName n; emitLn "(w);";
emitLn "if (io_result_is_error(w)) return w;"
} else if (d.params.size == 0) then do {
match getInitFnNameFor env d.name with
| some initFn := do {
emit "w = ", emitCppName initFn, emitLn "(w);",
emitLn "if (io_result_is_error(w)) return w;",
emitCppName n, emitLn " = io_result_get_value(w);"
emit "w = "; emitCppName initFn; emitLn "(w);";
emitLn "if (io_result_is_error(w)) return w;";
emitCppName n; emitLn " = io_result_get_value(w);"
}
| _ := do {
emitCppName n, emit " = ", emitCppInitName n, emitLn "();"
},
emitCppName n; emit " = "; emitCppInitName n; emitLn "();"
};
if d.resultType.isObj then do {
emit "lean::mark_persistent(", emitCppName n, emitLn ");",
emit "lean::mark_persistent("; emitCppName n; emitLn ");";
match quoteName n with
| some q := do emit ("lean::register_constant(" ++ q ++ ", "), emitCppName n, emitLn ");"
| some q := do emit ("lean::register_constant(" ++ q ++ ", "); emitCppName n; emitLn ");"
| none := pure ()
} else unless d.resultType.isIrrelevant $ do {
match quoteName n with
| some q := do emit ("lean::register_constant(" ++ q ++ ", "), emitBoxFn d.resultType, emit "(", emitCppName n, emitLn "));"
| some q := do emit ("lean::register_constant(" ++ q ++ ", "); emitBoxFn d.resultType; emit "("; emitCppName n; emitLn "));"
| none := pure ()
}
} else
/- TODO(Leo): perhaps we should add a flag to disable closure registration. -/
match quoteName d.name with
| some q := do
let clsName := if requiresBoxedVersion env d then mkBoxedName d.name else d.name,
emit ("REGISTER_LEAN_FUNCTION(" ++ q ++ ", " ++ toString d.params.size ++ ", "), emitCppName clsName, emitLn ");"
let clsName := if requiresBoxedVersion env d then mkBoxedName d.name else d.name;
emit ("REGISTER_LEAN_FUNCTION(" ++ q ++ ", " ++ toString d.params.size ++ ", "); emitCppName clsName; emitLn ");"
| _ := pure ()
def emitInitFn : M Unit :=
do
env ← getEnv,
modName ← getModName,
env.imports.mfor $ λ m, emitLn ("obj* initialize_" ++ m.mangle "" ++ "(obj*);"),
env ← getEnv;
modName ← getModName;
env.imports.mfor $ λ m, emitLn ("obj* initialize_" ++ m.mangle "" ++ "(obj*);");
emitLns [
"static bool _G_initialized = false;",
"obj* initialize_" ++ modName.mangle "" ++ "(obj* w) {",
"if (_G_initialized) return w;",
"_G_initialized = true;",
"if (io_result_is_error(w)) return w;"
],
];
env.imports.mfor $ λ m, emitLns [
"w = initialize_" ++ m.mangle "" ++ "(w);",
"if (io_result_is_error(w)) return w;"
],
let decls := getDecls env,
decls.reverse.mfor emitDeclInit,
];
let decls := getDecls env;
decls.reverse.mfor emitDeclInit;
emitLns [
"return w;",
"}"]
def main : M Unit :=
do
emitFileHeader,
emitFnDecls,
emitFns,
emitInitFn,
emitFileHeader;
emitFnDecls;
emitFns;
emitInitFn;
emitMainFnIfNeeded
end EmitCpp

View file

@ -29,12 +29,12 @@ partial def visitFnBody : FnBody → M Bool
let checkFn (f : FunId) : M Bool :=
if leanNameSpacePrefix.isPrefixOf f then pure true
else do {
s ← get,
s ← get;
if s.contains f then
visitFnBody b
else do
modify (λ s, s.insert f),
env ← read,
modify (λ s, s.insert f);
env ← read;
match findEnvDecl env f with
| some (Decl.fdecl _ _ _ fbody) := visitFnBody fbody <||> visitFnBody b
| other := visitFnBody b
@ -74,7 +74,7 @@ partial def collectFnBody : FnBody → M Unit
| e := unless e.isTerminal $ collectFnBody e.body
def collectInitDecl (fn : Name) : M Unit :=
do env ← read,
do env ← read;
match getInitFnNameFor env fn with
| some initFn := collect initFn
| _ := pure ()

View file

@ -137,7 +137,7 @@ mask.foldl
abbrev M := ReaderT Context (State Nat)
def mkFresh : M VarId :=
do idx ← get, modify (+1), pure { idx := idx }
do idx ← get; modify (+1); pure { idx := idx }
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
mask.size.mfold
@ -145,7 +145,7 @@ mask.size.mfold
match mask.get i with
| some _ := pure b -- code took ownership of this field
| none := do
fld ← mkFresh,
fld ← mkFresh;
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true b)))
b
@ -240,22 +240,22 @@ and `z := reuse x ctor_i ws; F` is replaced with
-/
def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
do
ctx ← read,
let b := reuseToSet ctx x y b,
ctx ← read;
let b := reuseToSet ctx x y b;
releaseUnreadFields y mask b
-- Expand `bs; x := reset[n] y; b`
partial def expand (mainFn : FnBody → Array FnBody → M FnBody)
(bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody :=
do
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b,
let (bs, mask) := eraseProjIncFor n y bs,
let bSlow := mkSlowPath x y mask b,
bFast ← mkFastPath x y mask b,
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b;
let (bs, mask) := eraseProjIncFor n y bs;
let bSlow := mkSlowPath x y mask b;
bFast ← mkFastPath x y mask b;
/- We only optimize recursively the fast. -/
bFast ← mainFn bFast Array.empty,
c ← mkFresh,
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast),
bFast ← mainFn bFast Array.empty;
c ← mkFresh;
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast);
pure $ reshape bs b
partial def searchAndExpand : FnBody → Array FnBody → M FnBody
@ -265,10 +265,10 @@ partial def searchAndExpand : FnBody → Array FnBody → M FnBody
else
searchAndExpand b (push bs d)
| (FnBody.jdecl j xs v b) bs := do
v ← searchAndExpand v Array.empty,
v ← searchAndExpand v Array.empty;
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
| (FnBody.case tid x alts) bs := do
alts ← alts.mmap $ λ alt, alt.mmodifyBody $ λ b, searchAndExpand b Array.empty,
alts ← alts.mmap $ λ alt, alt.mmodifyBody $ λ b, searchAndExpand b Array.empty;
pure $ reshape bs (FnBody.case tid x alts)
| b bs :=
if b.isTerminal then pure $ reshape bs b

View file

@ -58,7 +58,7 @@ partial def visitFnBody (w : Index) : FnBody → M Bool
| (FnBody.del x b) := visitVar w x <||> visitFnBody b
| (FnBody.mdata _ b) := visitFnBody b
| (FnBody.jmp j ys) := visitArgs w ys <||> do {
ctx ← get,
ctx ← get;
match ctx.getJPBody j with
| some b :=
-- `j` is not a local join point since we assume we cannot shadow join point declarations.

View file

@ -15,7 +15,7 @@ namespace UniqueIds
abbrev M := StateT IndexSet Id
def checkId (id : Index) : M Bool :=
do found ← get,
do found ← get;
if found.contains id then pure false
else modify (λ s, s.insert id) *> pure true
@ -80,39 +80,39 @@ abbrev N := ReaderT IndexRenaming (State Nat)
@[inline] def withVar {α : Type} (x : VarId) (k : VarId → N α) : N α :=
λ m, do
n ← getModify (+1),
n ← getModify (+1);
k { idx := n } (m.insert x.idx n)
@[inline] def withJP {α : Type} (x : JoinPointId) (k : JoinPointId → N α) : N α :=
λ m, do
n ← getModify (+1),
n ← getModify (+1);
k { idx := n } (m.insert x.idx n)
@[inline] def withParams {α : Type} (ps : Array Param) (k : Array Param → N α) : N α :=
λ m, do
m ← ps.mfoldl (λ (m : IndexRenaming) p, do n ← getModify (+1), pure $ m.insert p.x.idx n) m,
let ps := ps.map $ λ p, { x := normVar p.x m, .. p },
m ← ps.mfoldl (λ (m : IndexRenaming) p, do n ← getModify (+1); pure $ m.insert p.x.idx n) m;
let ps := ps.map $ λ p, { x := normVar p.x m, .. p };
k ps m
instance MtoN {α} : HasCoe (M α) (N α) :=
⟨λ x m, pure $ x m⟩
partial def normFnBody : FnBody → N FnBody
| (FnBody.vdecl x t v b) := do v ← normExpr v, withVar x $ λ x, FnBody.vdecl x t v <$> normFnBody b
| (FnBody.vdecl x t v b) := do v ← normExpr v; withVar x $ λ x, FnBody.vdecl x t v <$> normFnBody b
| (FnBody.jdecl j ys v b) := do
(ys, v) ← withParams ys $ λ ys, do { v ← normFnBody v, pure (ys, v) },
(ys, v) ← withParams ys $ λ ys, do { v ← normFnBody v; pure (ys, v) };
withJP j $ λ j, FnBody.jdecl j ys v <$> normFnBody b
| (FnBody.set x i y b) := do x ← normVar x, y ← normArg y, FnBody.set x i y <$> normFnBody b
| (FnBody.uset x i y b) := do x ← normVar x, y ← normVar y, FnBody.uset x i y <$> normFnBody b
| (FnBody.sset x i o y t b) := do x ← normVar x, y ← normVar y, FnBody.sset x i o y t <$> normFnBody b
| (FnBody.setTag x i b) := do x ← normVar x, FnBody.setTag x i <$> normFnBody b
| (FnBody.inc x n c b) := do x ← normVar x, FnBody.inc x n c <$> normFnBody b
| (FnBody.dec x n c b) := do x ← normVar x, FnBody.dec x n c <$> normFnBody b
| (FnBody.del x b) := do x ← normVar x, FnBody.del x <$> normFnBody b
| (FnBody.set x i y b) := do x ← normVar x; y ← normArg y; FnBody.set x i y <$> normFnBody b
| (FnBody.uset x i y b) := do x ← normVar x; y ← normVar y; FnBody.uset x i y <$> normFnBody b
| (FnBody.sset x i o y t b) := do x ← normVar x; y ← normVar y; FnBody.sset x i o y t <$> normFnBody b
| (FnBody.setTag x i b) := do x ← normVar x; FnBody.setTag x i <$> normFnBody b
| (FnBody.inc x n c b) := do x ← normVar x; FnBody.inc x n c <$> normFnBody b
| (FnBody.dec x n c b) := do x ← normVar x; FnBody.dec x n c <$> normFnBody b
| (FnBody.del x b) := do x ← normVar x; FnBody.del x <$> normFnBody b
| (FnBody.mdata d b) := FnBody.mdata d <$> normFnBody b
| (FnBody.case tid x alts) := do
x ← normVar x,
alts ← alts.mmap $ λ alt, alt.mmodifyBody normFnBody,
x ← normVar x;
alts ← alts.mmap $ λ alt, alt.mmodifyBody normFnBody;
pure $ FnBody.case tid x alts
| (FnBody.jmp j ys) := FnBody.jmp <$> normJP j <*> normArgs ys
| (FnBody.ret x) := FnBody.ret <$> normArg x

View file

@ -283,7 +283,7 @@ partial def visitDecl (env : Environment) (decls : Array Decl) : Decl → Decl
end ExplicitRC
def explicitRC (decls : Array Decl) : CompilerM (Array Decl) :=
do env ← getEnv,
do env ← getEnv;
pure $ decls.map (ExplicitRC.visitDecl env decls)
end IR

View file

@ -59,12 +59,12 @@ private partial def S (w : VarId) (c : CtorInfo) : FnBody → FnBody
abbrev M := ReaderT LocalContext (StateT Index Id)
private def mkFresh : M VarId :=
do idx ← getModify (+1),
do idx ← getModify (+1);
pure { idx := idx }
private def tryS (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
do w ← mkFresh,
let b' := S w c b,
do w ← mkFresh;
let b' := S w c b;
if b == b' then pure b
else pure $ FnBody.vdecl w IRType.object (Expr.reset c.size x) b'
@ -90,38 +90,38 @@ match b with
is expensive: `O(n^2)` where `n` is the size of the function body. -/
private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody → M (FnBody × Bool)
| e@(FnBody.case tid y alts) := do
ctx ← read,
ctx ← read;
if e.hasLiveVar ctx x then do
/- If `x` is live in `e`, we recursively process each branch. -/
alts ← alts.mmap $ λ alt, alt.mmodifyBody (λ b, Dmain b >>= Dfinalize x c),
alts ← alts.mmap $ λ alt, alt.mmodifyBody (λ b, Dmain b >>= Dfinalize x c);
pure (FnBody.case tid y alts, true)
else pure (e, false)
| (FnBody.jdecl j ys v b) := do
(b, _) ← adaptReader (λ ctx : LocalContext, ctx.addJP j ys v) (Dmain b),
(v, found) ← Dmain v,
(b, _) ← adaptReader (λ ctx : LocalContext, ctx.addJP j ys v) (Dmain b);
(v, found) ← Dmain v;
/- If `found == true`, then `Dmain b` must also have returned `(b, true)` since
we assume the IR does not have dead join points. So, if `x` is live in `j`,
then it must also live in `b` since `j` is reachable from `b` with a `jmp`. -/
pure (FnBody.jdecl j ys v b, found)
| e := do
ctx ← read,
ctx ← read;
if e.isTerminal then
pure (e, e.hasLiveVar ctx x)
else do
let (instr, b) := e.split,
let (instr, b) := e.split;
if isCtorUsing instr x then
/- If the scrutinee `x` (the one that is providing memory) is being
stored in a constructor, then reuse will probably not be able to reuse memory at runtime.
It may work only if the new cell is consumed, but we ignore this case. -/
pure (e, true)
else do
(b, found) ← Dmain b,
(b, found) ← Dmain b;
/- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar`
since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/
if found || !instr.hasFreeVar x then
pure (instr.setBody b, found)
else do
b ← tryS x c b,
b ← tryS x c b;
pure (instr.setBody b, true)
private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
@ -130,23 +130,23 @@ Dmain x c b >>= Dfinalize x c
partial def R : FnBody → M FnBody
| (FnBody.case tid x alts) := do
alts ← alts.mmap $ λ alt, do {
alt ← alt.mmodifyBody R,
alt ← alt.mmodifyBody R;
match alt with
| Alt.ctor c b :=
if c.isScalar then pure alt
else Alt.ctor c <$> D x c b
| _ := pure alt
},
};
pure $ FnBody.case tid x alts
| (FnBody.jdecl j ys v b) := do
v ← R v,
b ← adaptReader (λ ctx : LocalContext, ctx.addJP j ys v) (R b),
v ← R v;
b ← adaptReader (λ ctx : LocalContext, ctx.addJP j ys v) (R b);
pure $ FnBody.jdecl j ys v b
| e := do
if e.isTerminal then pure e
else do
let (instr, b) := e.split,
b ← R b,
let (instr, b) := e.split;
b ← R b;
pure (instr.setBody b)
end ResetReuse

View file

@ -140,15 +140,15 @@ instance EnvExtension.Inhabited (σ : Type) [Inhabited σ] : Inhabited (EnvExten
unsafe def registerEnvExtensionUnsafe {σ : Type} (initState : σ) : IO (EnvExtension σ) :=
do
initializing ← IO.initializing,
unless initializing $ throw (IO.userError ("failed to register environment, extensions can only be registered during initialization")),
exts ← envExtensionsRef.get,
let idx := exts.size,
initializing ← IO.initializing;
unless initializing $ throw (IO.userError ("failed to register environment, extensions can only be registered during initialization"));
exts ← envExtensionsRef.get;
let idx := exts.size;
let ext : EnvExtension σ := {
idx := idx,
initial := initState
},
envExtensionsRef.modify (λ exts, exts.push (unsafeCast ext)),
};
envExtensionsRef.modify (λ exts, exts.push (unsafeCast ext));
pure ext
/- Environment extensions can only be registered during initialization.
@ -159,14 +159,14 @@ pure ext
constant registerEnvExtension {σ : Type} (initState : σ) : IO (EnvExtension σ) := default _
private def mkInitialExtensionStates : IO (Array EnvExtensionState) :=
do exts ← envExtensionsRef.get, pure $ exts.map $ λ ext, ext.initial
do exts ← envExtensionsRef.get; pure $ exts.map $ λ ext, ext.initial
@[export lean.mk_empty_environment_core]
def mkEmptyEnvironment (trustLevel : UInt32 := 0) : IO Environment :=
do
initializing ← IO.initializing,
when initializing $ throw (IO.userError "environment objects cannot be created during initialization"),
exts ← mkInitialExtensionStates,
initializing ← IO.initializing;
when initializing $ throw (IO.userError "environment objects cannot be created during initialization");
exts ← mkInitialExtensionStates;
pure { const2ModIdx := {},
constants := {},
header := { trustLevel := trustLevel },
@ -243,10 +243,10 @@ unsafe def registerPersistentEnvExtensionUnsafe {α σ : Type} (descr : Persiste
do
let s : PersistentEnvExtensionState α σ := {
importedEntries := Array.empty,
state := descr.addImportedFn Array.empty },
pExts ← persistentEnvExtensionsRef.get,
when (pExts.any (λ ext, ext.name == descr.name)) $ throw (IO.userError ("invalid environment extension, '" ++ toString descr.name ++ "' has already been used")),
ext ← registerEnvExtension s,
state := descr.addImportedFn Array.empty };
pExts ← persistentEnvExtensionsRef.get;
when (pExts.any (λ ext, ext.name == descr.name)) $ throw (IO.userError ("invalid environment extension, '" ++ toString descr.name ++ "' has already been used"));
ext ← registerEnvExtension s;
let pExt : PersistentEnvExtension α σ := {
toEnvExtension := ext,
name := descr.name,
@ -254,8 +254,8 @@ let pExt : PersistentEnvExtension α σ := {
addEntryFn := descr.addEntryFn,
exportEntriesFn := descr.exportEntriesFn,
statsFn := descr.statsFn
},
persistentEnvExtensionsRef.modify (λ pExts, pExts.push (unsafeCast pExt)),
};
persistentEnvExtensionsRef.modify (λ pExts, pExts.push (unsafeCast pExt));
pure pExt
@[implementedBy registerPersistentEnvExtensionUnsafe]
@ -311,15 +311,15 @@ instance CPPExtensionState.inhabited : Inhabited CPPExtensionState := inferInsta
@[export lean.register_extension_core]
unsafe def registerCPPExtension (initial : CPPExtensionState) : Option Nat :=
unsafeIO (do ext ← registerEnvExtension initial, pure ext.idx)
unsafeIO (do ext ← registerEnvExtension initial; pure ext.idx)
@[export lean.set_extension_core]
unsafe def setCPPExtensionState (env : Environment) (idx : Nat) (s : CPPExtensionState) : Option Environment :=
unsafeIO (do exts ← envExtensionsRef.get, pure $ (exts.get idx).setState env s)
unsafeIO (do exts ← envExtensionsRef.get; pure $ (exts.get idx).setState env s)
@[export lean.get_extension_core]
unsafe def getCPPExtensionState (env : Environment) (idx : Nat) : Option CPPExtensionState :=
unsafeIO (do exts ← envExtensionsRef.get, pure $ (exts.get idx).getState env)
unsafeIO (do exts ← envExtensionsRef.get; pure $ (exts.get idx).getState env)
/- Legacy support for Modification objects -/
@ -371,15 +371,15 @@ constant readModuleData (fname : @& String) : IO ModuleData := default _
def mkModuleData (env : Environment) : IO ModuleData :=
do
pExts ← persistentEnvExtensionsRef.get,
pExts ← persistentEnvExtensionsRef.get;
let entries : Array (Name × Array EnvExtensionEntry) := pExts.size.fold
(λ i result,
let state := (pExts.get i).getState env;
let exportEntriesFn := (pExts.get i).exportEntriesFn;
let extName := (pExts.get i).name;
result.push (extName, exportEntriesFn state))
Array.empty,
bytes ← serializeModifications (modListExtension.getState env),
Array.empty;
bytes ← serializeModifications (modListExtension.getState env);
pure {
imports := env.header.imports,
constants := env.constants.foldStage2 (λ cs _ c, cs.push c) Array.empty,
@ -389,7 +389,7 @@ serialized := bytes
@[export lean.write_module_core]
def writeModule (env : Environment) (fname : String) : IO Unit :=
do modData ← mkModuleData env, saveModuleData fname modData
do modData ← mkModuleData env; saveModuleData fname modData
@[extern 2 "lean_find_olean"]
constant findOLean (modName : Name) : IO String := default _
@ -400,11 +400,11 @@ partial def importModulesAux : List Name → (NameSet × Array ModuleData) → I
if s.contains m then
importModulesAux ms (s, mods)
else do
let s := s.insert m,
mFile ← findOLean m,
mod ← readModuleData mFile,
(s, mods) ← importModulesAux mod.imports.toList (s, mods),
let mods := mods.push mod,
let s := s.insert m;
mFile ← findOLean m;
mod ← readModuleData mFile;
(s, mods) ← importModulesAux mod.imports.toList (s, mods);
let mods := mods.push mod;
importModulesAux ms (s, mods)
private partial def getEntriesFor (mod : ModuleData) (extId : Name) : Nat → Array EnvExtensionState
@ -417,7 +417,7 @@ private partial def getEntriesFor (mod : ModuleData) (extId : Name) : Nat → Ar
private def setImportedEntries (env : Environment) (mods : Array ModuleData) : IO Environment :=
do
pExtDescrs ← persistentEnvExtensionsRef.get,
pExtDescrs ← persistentEnvExtensionsRef.get;
pure $ mods.iterate env $ λ _ mod env,
pExtDescrs.iterate env $ λ _ extDescr env,
let entries := getEntriesFor mod extDescr.name 0;
@ -427,7 +427,7 @@ pure $ mods.iterate env $ λ _ mod env,
private def finalizePersistentExtensions (env : Environment) : IO Environment :=
do
pExtDescrs ← persistentEnvExtensionsRef.get,
pExtDescrs ← persistentEnvExtensionsRef.get;
pure $ pExtDescrs.iterate env $ λ _ extDescr env,
extDescr.toEnvExtension.modifyState env $ λ s,
{ state := extDescr.addImportedFn s.importedEntries, .. s }
@ -435,17 +435,17 @@ pure $ pExtDescrs.iterate env $ λ _ extDescr env,
@[export lean.import_modules_core]
def importModules (modNames : List Name) (trustLevel : UInt32 := 0) : IO Environment :=
do
(_, mods) ← importModulesAux modNames ({}, Array.empty),
(_, mods) ← importModulesAux modNames ({}, Array.empty);
let const2ModIdx := mods.iterate {} $ λ modIdx (mod : ModuleData) (m : HashMap Name ModuleIdx),
mod.constants.iterate m $ λ _ cinfo m,
m.insert cinfo.name modIdx.val,
m.insert cinfo.name modIdx.val;
constants ← mods.miterate SMap.empty $ λ _ (mod : ModuleData) (cs : ConstMap),
mod.constants.miterate cs $ λ _ cinfo cs, do {
when (cs.contains cinfo.name) $ throw (IO.userError ("import failed, environment already contains '" ++ toString cinfo.name ++ "'")),
when (cs.contains cinfo.name) $ throw (IO.userError ("import failed, environment already contains '" ++ toString cinfo.name ++ "'"));
pure $ cs.insert cinfo.name cinfo
},
let constants := constants.switch,
exts ← mkInitialExtensionStates,
};
let constants := constants.switch;
exts ← mkInitialExtensionStates;
let env : Environment := {
const2ModIdx := const2ModIdx,
constants := constants,
@ -455,10 +455,10 @@ let env : Environment := {
trustLevel := trustLevel,
imports := modNames.toArray
}
},
env ← setImportedEntries env mods,
env ← finalizePersistentExtensions env,
env ← mods.miterate env $ λ _ mod env, performModifications env mod.serialized,
};
env ← setImportedEntries env mods;
env ← finalizePersistentExtensions env;
env ← mods.miterate env $ λ _ mod env, performModifications env mod.serialized;
pure env
namespace Environment
@ -466,25 +466,25 @@ namespace Environment
@[export lean.display_stats_core]
def displayStats (env : Environment) : IO Unit :=
do
pExtDescrs ← persistentEnvExtensionsRef.get,
let numModules := ((pExtDescrs.get 0).toEnvExtension.getState env).importedEntries.size,
IO.println ("direct imports: " ++ toString env.header.imports),
IO.println ("number of imported modules: " ++ toString numModules),
IO.println ("number of consts: " ++ toString env.constants.size),
IO.println ("number of imported consts: " ++ toString env.constants.stageSizes.1),
IO.println ("number of local consts: " ++ toString env.constants.stageSizes.2),
IO.println ("number of buckets for imported consts: " ++ toString env.constants.numBuckets),
IO.println ("map depth for local consts: " ++ toString env.constants.maxDepth),
IO.println ("trust level: " ++ toString env.header.trustLevel),
IO.println ("number of extensions: " ++ toString env.extensions.size),
pExtDescrs ← persistentEnvExtensionsRef.get;
let numModules := ((pExtDescrs.get 0).toEnvExtension.getState env).importedEntries.size;
IO.println ("direct imports: " ++ toString env.header.imports);
IO.println ("number of imported modules: " ++ toString numModules);
IO.println ("number of consts: " ++ toString env.constants.size);
IO.println ("number of imported consts: " ++ toString env.constants.stageSizes.1);
IO.println ("number of local consts: " ++ toString env.constants.stageSizes.2);
IO.println ("number of buckets for imported consts: " ++ toString env.constants.numBuckets);
IO.println ("map depth for local consts: " ++ toString env.constants.maxDepth);
IO.println ("trust level: " ++ toString env.header.trustLevel);
IO.println ("number of extensions: " ++ toString env.extensions.size);
pExtDescrs.mfor $ λ extDescr, do {
IO.println ("extension '" ++ toString extDescr.name ++ "'"),
let s := extDescr.toEnvExtension.getState env,
let fmt := extDescr.statsFn s.state,
unless fmt.isNil (IO.println (" " ++ toString (Format.nest 2 (extDescr.statsFn s.state)))),
IO.println (" number of imported entries: " ++ toString (s.importedEntries.foldl (λ sum es, sum + es.size) 0)),
IO.println ("extension '" ++ toString extDescr.name ++ "'");
let s := extDescr.toEnvExtension.getState env;
let fmt := extDescr.statsFn s.state;
unless fmt.isNil (IO.println (" " ++ toString (Format.nest 2 (extDescr.statsFn s.state))));
IO.println (" number of imported entries: " ++ toString (s.importedEntries.foldl (λ sum es, sum + es.size) 0));
pure ()
},
};
pure ()
end Environment

View file

@ -32,7 +32,7 @@ modifyConstTable (λ cs, cs.qsort (λ e₁ e₂, Name.quickLt e₁.1 e₂.1))
the program may crash if the type provided by the user is incorrect.
It also assumes there are no threads trying to update the table concurrently. -/
unsafe def evalConst (α : Type) [Inhabited α] (c : Name) : IO α :=
do cs ← getConstTable,
do cs ← getConstTable;
match cs.binSearch (c, default _) (λ e₁ e₂, Name.quickLt e₁.1 e₂.1) with
| some (_, v) := pure (unsafeCast v)
| none := throw (IO.userError ("unknow constant '" ++ toString c ++ "'"))

View file

@ -29,30 +29,30 @@ IO.mkRef (mkNameMap OptionDecl)
private constant optionDeclsRef : IO.Ref OptionDecls := default _
def registerOption (name : Name) (decl : OptionDecl) : IO Unit :=
do decls ← optionDeclsRef.get,
do decls ← optionDeclsRef.get;
when (decls.contains name) $
throw $ IO.userError ("invalid option declaration '" ++ toString name ++ "', option already exists"),
throw $ IO.userError ("invalid option declaration '" ++ toString name ++ "', option already exists");
optionDeclsRef.set $ decls.insert name decl
def getOptionDecls : IO OptionDecls := optionDeclsRef.get
def getOptionDecl (name : Name) : IO OptionDecl :=
do decls ← getOptionDecls,
(some decl) ← pure (decls.find name) | throw $ IO.userError ("unknown option '" ++ toString name ++ "'"),
do decls ← getOptionDecls;
(some decl) ← pure (decls.find name) | throw $ IO.userError ("unknown option '" ++ toString name ++ "'");
pure decl
def getOptionDefaulValue (name : Name) : IO DataValue :=
do decl ← getOptionDecl name,
do decl ← getOptionDecl name;
pure decl.defValue
def getOptionDescr (name : Name) : IO String :=
do decl ← getOptionDecl name,
do decl ← getOptionDecl name;
pure decl.descr
def setOptionFromString (opts : Options) (entry : String) : IO Options :=
do let ps := (entry.split "=").map String.trim,
[key, val] ← pure ps | throw "invalid configuration option entry, it must be of the form '<key> = <value>'",
defValue ← getOptionDefaulValue key.toName,
do let ps := (entry.split "=").map String.trim;
[key, val] ← pure ps | throw "invalid configuration option entry, it must be of the form '<key> = <value>'";
defValue ← getOptionDefaulValue key.toName;
match defValue with
| DataValue.ofString v := pure $ opts.setString key val
| DataValue.ofBool v :=
@ -61,10 +61,10 @@ do let ps := (entry.split "=").map String.trim,
else throw $ IO.userError ("invalid Bool option value '" ++ val ++ "'")
| DataValue.ofName v := pure $ opts.setName key val.toName
| DataValue.ofNat v := do
unless val.isNat $ throw (IO.userError ("invalid Nat option value '" ++ val ++ "'")),
unless val.isNat $ throw (IO.userError ("invalid Nat option value '" ++ val ++ "'"));
pure $ opts.setNat key val.toNat
| DataValue.ofInt v := do
unless val.isInt $ throw (IO.userError ("invalid Int option value '" ++ val ++ "'")),
unless val.isInt $ throw (IO.userError ("invalid Int option value '" ++ val ++ "'"));
pure $ opts.setInt key val.toInt
end Lean

View file

@ -972,9 +972,9 @@ match info.updateTokens tables.tokens with
| Except.error msg := throw (IO.userError msg)
def addBuiltinLeadingParser (tablesRef : IO.Ref ParsingTables) (declName : Name) (p : Parser) : IO Unit :=
do tables ← tablesRef.get,
tablesRef.reset,
tables ← updateTokens tables p.info,
do tables ← tablesRef.get;
tablesRef.reset;
tables ← updateTokens tables p.info;
match p.info.firstTokens with
| FirstTokens.tokens tks :=
let tables := tks.foldl (λ (tables : ParsingTables) tk, { leadingTable := tables.leadingTable.insert (mkSimpleName tk.val) p, .. tables }) tables;
@ -983,9 +983,9 @@ do tables ← tablesRef.get,
throw (IO.userError ("invalid builtin parser '" ++ toString declName ++ "', initial token is not statically known"))
def addBuiltinTrailingParser (tablesRef : IO.Ref ParsingTables) (declName : Name) (p : TrailingParser) : IO Unit :=
do tables ← tablesRef.get,
tablesRef.reset,
tables ← updateTokens tables p.info,
do tables ← tablesRef.get;
tablesRef.reset;
tables ← updateTokens tables p.info;
match p.info.firstTokens with
| FirstTokens.tokens tks :=
let tables := tks.foldl (λ (tables : ParsingTables) tk, { trailingTable := tables.trailingTable.insert (mkSimpleName tk.val) p, .. tables }) tables;
@ -1017,8 +1017,8 @@ registerAttribute {
name := attrName,
descr := "Builtin parser",
add := λ env declName args persistent, do {
unless args.isMissing $ throw (IO.userError ("invalid attribute '" ++ toString attrName ++ "', unexpected argument")),
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString attrName ++ "', must be persistent")),
unless args.isMissing $ throw (IO.userError ("invalid attribute '" ++ toString attrName ++ "', unexpected argument"));
unless persistent $ throw (IO.userError ("invalid attribute '" ++ toString attrName ++ "', must be persistent"));
match env.find declName with
| none := throw "unknown declaration"
| some decl :=
@ -1045,7 +1045,7 @@ registerBuiltinParserAttribute `builtinTermParser `Lean.Parser.builtinTermParsin
@[noinline] unsafe def runBuiltinParserUnsafe (kind : String) (ref : IO.Ref ParsingTables) : ParserFn leading :=
λ a c s,
match unsafeIO (do tables ← ref.get, pure $ prattParser kind tables a c s) with
match unsafeIO (do tables ← ref.get; pure $ prattParser kind tables a c s) with
| some s := s
| none := s.mkError "failed to access builtin reference"

View file

@ -109,11 +109,11 @@ match s with
@[specialize] partial def mreplace {m : Type → Type} [Monad m] (fn : Syntax → m (Option Syntax)) : Syntax → m Syntax
| stx@(node kind args scopes) := do
o ← fn stx,
o ← fn stx;
(match o with
| some stx := pure stx
| none := do args ← args.mmap mreplace, pure (node kind args scopes))
| stx := do o ← fn stx, pure (o.getOrElse stx)
| none := do args ← args.mmap mreplace; pure (node kind args scopes))
| stx := do o ← fn stx; pure (o.getOrElse stx)
@[inline] def replace {m : Type → Type} [Monad m] (fn : Syntax → m (Option Syntax)) := @mreplace Id _
@ -126,14 +126,14 @@ private def updateInfo : SourceInfo → String.Pos → SourceInfo
@[inline]
private def updateLeadingAux : Syntax → State String.Pos (Option Syntax)
| (atom (some info) val) := do
last ← get,
set info.trailing.stopPos,
let newInfo := updateInfo info last in
last ← get;
set info.trailing.stopPos;
let newInfo := updateInfo info last;
pure $ some (atom (some newInfo) val)
| (ident (some info) rawVal val pre scopes) := do
last ← get,
set info.trailing.stopPos,
let newInfo := updateInfo info last in
last ← get;
set info.trailing.stopPos;
let newInfo := updateInfo info last;
pure $ some (ident (some newInfo) rawVal val pre scopes)
| _ := pure none
@ -184,9 +184,9 @@ partial def reprint : Syntax → Option String
if kind == choiceKind then
if args.size == 0 then failure
else do
s ← reprint (args.get 0),
args.mfoldlFrom (λ s stx, do s' ← reprint stx, guard (s == s'), pure s) s 1
else args.mfoldl (λ r stx, do s ← reprint stx, pure $ r ++ s) ""
s ← reprint (args.get 0);
args.mfoldlFrom (λ s stx, do s' ← reprint stx; guard (s == s'); pure s) s 1
else args.mfoldl (λ r stx, do s ← reprint stx; pure $ r ++ s) ""
| missing := ""
open Lean.Format

View file

@ -48,23 +48,23 @@ traceCtx cls msg (pure () : m Unit)
instance (m) [Monad m] : MonadTracer (TraceT m) :=
{ traceRoot := λ α pos cls msg ctx, do {
st ← get,
st ← get;
if st.opts.getBool cls = true then do {
modify $ λ st, {curPos := pos, curTraces := [], ..st},
a ← ctx.get,
modify $ λ (st : TraceState), {roots := st.roots.insert pos ⟨msg, st.curTraces⟩, ..st},
modify $ λ st, {curPos := pos, curTraces := [], ..st};
a ← ctx.get;
modify $ λ (st : TraceState), {roots := st.roots.insert pos ⟨msg, st.curTraces⟩, ..st};
pure a
} else ctx.get
},
traceCtx := λ α cls msg ctx, do {
st ← get,
st ← get;
-- tracing enabled?
some _ ← pure st.curPos | ctx.get,
some _ ← pure st.curPos | ctx.get;
-- Trace class enabled?
if st.opts.getBool cls = true then do {
set {curTraces := [], ..st},
a ← ctx.get,
modify $ λ (st' : TraceState), {curTraces := st.curTraces ++ [⟨msg, st'.curTraces⟩], ..st'},
set {curTraces := [], ..st};
a ← ctx.get;
modify $ λ (st' : TraceState), {curTraces := st.curTraces ++ [⟨msg, st'.curTraces⟩], ..st'};
pure a
} else
-- disable tracing inside 'ctx'
@ -76,7 +76,7 @@ instance (m) [Monad m] : MonadTracer (TraceT m) :=
}
def TraceT.run {m α} [Monad m] (opts : Options) (x : TraceT m α) : m (α × TraceMap) :=
do (a, st) ← StateT.run x {opts := opts, roots := mkRBMap _ _ _, curPos := none, curTraces := []},
do (a, st) ← StateT.run x {opts := opts, roots := RBMap.empty, curPos := none, curTraces := []};
pure (a, st.roots)
end Trace

View file

@ -95,7 +95,11 @@ static expr parse_let_body(parser & p, pos_info const & pos, bool in_do_block) {
p.next();
return p.parse_expr();
} else {
p.check_token_next(get_comma_tk(), "invalid 'do' block 'let' declaration, ',' or 'in' expected");
if (p.curr_is_token(get_semicolon_tk()) || p.curr_is_token(get_comma_tk())) {
p.next();
} else {
p.check_token_next(get_semicolon_tk(), "invalid 'do' block 'let' declaration, ',', ';' or 'in' expected");
}
if (p.curr_is_token(get_let_tk())) {
p.next();
return parse_let(p, pos, in_do_block);
@ -277,7 +281,7 @@ static expr parse_do(parser & p, bool has_braces) {
expr type, curr;
std::tie(lhs, type, curr, else_case) = parse_do_action(p, new_locals);
es.push_back(curr);
if (p.curr_is_token(get_comma_tk())) {
if (/* p.curr_is_token(get_comma_tk()) || */ p.curr_is_token(get_semicolon_tk())) {
p.next();
for (expr const & l : new_locals)
p.add_local(l);