diff --git a/library/init/lean/compiler/ir/livevars.lean b/library/init/lean/compiler/ir/livevars.lean index f38075193f..d14b4b54fd 100644 --- a/library/init/lean/compiler/ir/livevars.lean +++ b/library/init/lean/compiler/ir/livevars.lean @@ -81,7 +81,7 @@ end IsLive Recall that we say that a join point `j` is free in `b` if `b` contains `FnBody.jmp j ys` and `j` is not local. -/ -def FnBody.isLive (b : FnBody) (ctx : Context) (x : VarId) : Bool := +def FnBody.hasLiveVar (b : FnBody) (ctx : Context) (x : VarId) : Bool := (IsLive.visitFnBody x.idx b).run' ctx end IR diff --git a/library/init/lean/compiler/ir/resetreuse.lean b/library/init/lean/compiler/ir/resetreuse.lean index 50d7eb7119..87f3ed8763 100644 --- a/library/init/lean/compiler/ir/resetreuse.lean +++ b/library/init/lean/compiler/ir/resetreuse.lean @@ -5,8 +5,9 @@ Authors: Leonardo de Moura -/ prelude import init.control.state +import init.control.reader import init.lean.compiler.ir.basic -import init.lean.compiler.ir.freevars +import init.lean.compiler.ir.livevars namespace Lean namespace IR @@ -20,9 +21,11 @@ namespace IR Here are the main differences: - We use the State monad to manage the generation of fresh variable names. - Support for join points, and `uset` and `sset` instructions for unboxed data. - - `R` uses the `flatten` and `reshape` idiom. - - `D` returns a pair `(b, found)` to avoid quadratic behavior when checking - the last occurrence of the variable `x` + - `D` uses the auxiliary function `Dmain`. + - `Dmain` returns a pair `(b, found)` to avoid quadratic behavior when checking + the last occurrence of the variable `x`. + - Because we have join points in the actual implementation, a variable may be live even if it + does not occur in a function body. See example at `livevars.lean`. -/ private def mayReuse (c₁ c₂ : CtorInfo) : Bool := @@ -50,7 +53,8 @@ private partial def S (w : VarId) (c : CtorInfo) : FnBody → FnBody (instr, b) := b.split in instr <;> S b -abbrev M := State Index +/- We use `Context` to track join points in scope. -/ +abbrev M := ReaderT Context (StateT Index Id) local attribute [instance] monadInhabited private def mkFresh : M VarId := @@ -68,56 +72,76 @@ private def Dfinalize (x : VarId) (c : CtorInfo) : FnBody × Bool → M FnBody | (b, false) := tryS x c b private partial def Dmain (x : VarId) (c : CtorInfo) : FnBody → M (FnBody × Bool) -| b@(FnBody.case tid y alts) := - if b.hasFreeVar x then do +| e@(FnBody.case tid y alts) := do + ctx ← read, + if e.hasLiveVar ctx x then do + /- If `x` is live in `e`, we recursively process each branch. -/ alts ← alts.hmmap $ λ alt, alt.mmodifyBody (λ b, Dmain b >>= Dfinalize x c), pure (FnBody.case tid y alts, true) - else - pure (b, false) -| e := + else pure (e, false) +| e@(FnBody.jdecl j ys t v b) := do + (b, _) ← adaptReader (λ ctx : Context, ctx.addDecl e) (Dmain b), + (v, found) ← Dmain v, + /- If `found == true`, then `Dmain b` must also have returned `(b, true)` since + we assume the IR does not have dead join points. So, if `x` is live in `j`, + then it must also live in `b` since `j` is reachable from `b` with a `jmp`. -/ + pure (FnBody.jdecl j ys t v b, found) +| e := do + ctx ← read, if e.isTerminal then - pure (e, e.hasFreeVar x) + pure (e, e.hasLiveVar ctx x) else do let (instr, b) := e.split, (b, found) ← Dmain b, + /- Remark: it is fine to use `hasFreeVar` instead of `hasLiveVar` + since `instr` is not a `FnBody.jmp` (it is not a terminal) nor it is a `FnBody.jdecl`. -/ if found || !instr.hasFreeVar x then pure (instr <;> b, found) else do b ← tryS x c b, pure (instr <;> b, true) +/- Auxiliary function used to implement an additional heuristic at `D`. -/ private partial def hasCtorUsing (x : VarId) : FnBody → Bool | (FnBody.vdecl x _ (Expr.ctor _ ys) b) := - ys.any (λ arg, Arg.hasFreeVar arg x) || hasCtorUsing b -| b := !b.isTerminal && hasCtorUsing b.body + ys.any (λ arg, match arg with + | Arg.var y := x == y + | _ := false) + || hasCtorUsing b +| (FnBody.jdecl _ _ _ v b) := hasCtorUsing v || hasCtorUsing b +| b := !b.isTerminal && hasCtorUsing b.body private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody := /- If the scrutinee `x` (the one that is providing memory) is being - stored in a constructor, then reuse will probably not work. + stored in a constructor, then reuse will probably not be able to reuse memory at runtime. It may work only if the new cell is consumed, but we ignore this case. -/ if hasCtorUsing x b then pure b else Dmain x c b >>= Dfinalize x c private partial def R : FnBody → M FnBody -| b := do - let (bs, term) := b.flatten, - bs ← mmodifyJPs bs R, - match term with - | FnBody.case tid x alts := do +| (FnBody.case tid x alts) := do alts ← alts.hmmap $ λ alt, do { alt ← alt.mmodifyBody R, match alt with | Alt.ctor c b := Alt.ctor c <$> D x c b | _ := pure alt }, - let term := FnBody.case tid x alts, - pure $ reshape bs term - | other := pure $ reshape bs term + pure $ FnBody.case tid x alts +| e@(FnBody.jdecl j ys t v b) := do + v ← R v, + b ← adaptReader (λ ctx : Context, ctx.addDecl e) (R b), + pure $ FnBody.jdecl j ys t v b +| e := do + if e.isTerminal then pure e + else do + let (instr, b) := e.split, + b ← R b, + pure (instr <;> b) def Decl.insertResetReuse : Decl → Decl | d@(Decl.fdecl f xs t b) := let nextIndex := d.maxIndex + 1 in - let b := (R b).run' nextIndex in + let b := (R b {}).run' nextIndex in Decl.fdecl f xs t b | other := other diff --git a/tests/playground/badreset.lean b/tests/playground/badreset.lean index 461278d7ec..aa4ea33e55 100644 --- a/tests/playground/badreset.lean +++ b/tests/playground/badreset.lean @@ -1,12 +1,23 @@ @[noinline] def g (x : Nat × Nat) := x set_option trace.compiler.boxed true +set_option trace.compiler.lambda_pure true @[noinline] def f (b : Bool) (x : Nat × Nat) : (Nat × Nat) × (Nat × Nat) := -let done (y : Nat × Nat) := (g (g (g y)), x) in +let done (y : Nat × Nat) := (g (g (g x)), y) in match b with | true := match x with | (a, b) := done (a, 0) | false := match x with | (a, b) := done (0, b) +@[noinline] def h {α : Type} (x : Nat × α) := x.1 + +def tst2 (p : Nat × (Except Nat Nat)) : Nat × Nat := +match p with +| (a, b) := + let done (x : Nat) := (h p + 1, x) in + match b with + | Except.ok v := done v + | Except.error w := done w + def main (xs : List String) : IO Unit := IO.println $ f true (xs.head.toNat, xs.tail.head.toNat)