diff --git a/library/init/lean/compiler/ir/default.lean b/library/init/lean/compiler/ir/default.lean index a8f541d045..b505d47290 100644 --- a/library/init/lean/compiler/ir/default.lean +++ b/library/init/lean/compiler/ir/default.lean @@ -19,7 +19,7 @@ namespace IR def test (d : Decl) : IO Unit := do IO.println d, - IO.println $ ("Max variable " ++ toString d.maxVar), + IO.println $ ("Max index " ++ toString d.maxIndex), let d := d.pushProj, IO.println "=== After push projections ===", IO.println d, diff --git a/library/init/lean/compiler/ir/elimdead.lean b/library/init/lean/compiler/ir/elimdead.lean index d86503c8c2..16aebef41f 100644 --- a/library/init/lean/compiler/ir/elimdead.lean +++ b/library/init/lean/compiler/ir/elimdead.lean @@ -17,7 +17,7 @@ partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnB let curr := bs.back in let bs := bs.pop in let keep (_ : Unit) := - let used := curr.collectFreeVars used in + let used := curr.collectFreeIndices used in let b := curr <;> b in reshapeWithoutDeadAux bs b used in let keepIfUsed (vidx : Index) := @@ -29,7 +29,7 @@ partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnB | _ := keep () def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody := -reshapeWithoutDeadAux bs term term.freeVars +reshapeWithoutDeadAux bs term term.freeIndices partial def FnBody.elimDead : FnBody → FnBody | b := diff --git a/library/init/lean/compiler/ir/freevars.lean b/library/init/lean/compiler/ir/freevars.lean index 9adcbda389..0fa446c2e4 100644 --- a/library/init/lean/compiler/ir/freevars.lean +++ b/library/init/lean/compiler/ir/freevars.lean @@ -9,10 +9,89 @@ import init.lean.compiler.ir.basic namespace Lean namespace IR +namespace MaxIndex +/- Compute the maximum index `M` used in a declaration. + We `M` to initialize the fresh index generator used to create fresh + variable and join point names. + + Recall that we variable and join points share the same namespace in + our implementation. +-/ + +abbrev Collector := Index → Index + +@[inline] private def skip : Collector := id +@[inline] private def collect (x : Index) : Collector := λ y, if x > y then x else y +@[inline] private def collectVar (x : VarId) : Collector := collect x.idx +@[inline] private def collectJP (j : JoinPointId) : Collector := collect j.idx +@[inline] private def seq (k₁ k₂ : Collector) : Collector := k₂ ∘ k₁ +instance : HasAndthen Collector := ⟨seq⟩ + +private def collectArg : Arg → Collector +| (Arg.var x) := collectVar x +| irrelevant := skip + +@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := +λ m, as.foldl (λ m a, f a m) m + +private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg +private def collectParam (p : Param) : Collector := collectVar p.x +private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam + +private def collectExpr : Expr → Collector +| (Expr.ctor _ ys) := collectArgs ys +| (Expr.reset x) := collectVar x +| (Expr.reuse x _ _ ys) := collectVar x; collectArgs ys +| (Expr.proj _ x) := collectVar x +| (Expr.uproj _ x) := collectVar x +| (Expr.sproj _ _ x) := collectVar x +| (Expr.fap _ ys) := collectArgs ys +| (Expr.pap _ ys) := collectArgs ys +| (Expr.ap x ys) := collectVar x; collectArgs ys +| (Expr.box _ x) := collectVar x +| (Expr.unbox x) := collectVar x +| (Expr.lit v) := skip +| (Expr.isShared x) := collectVar x +| (Expr.isTaggedPtr x) := collectVar x + +private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector := +collectArray alts $ λ alt, f alt.body + +partial def collectFnBody : FnBody → Collector +| (FnBody.vdecl x _ v b) := collectExpr v; collectFnBody b +| (FnBody.jdecl j ys _ v b) := collectFnBody v; collectParams ys; collectFnBody b +| (FnBody.set x _ y b) := collectVar x; collectVar y; collectFnBody b +| (FnBody.uset x _ y b) := collectVar x; collectVar y; collectFnBody b +| (FnBody.sset x _ _ y _ b) := collectVar x; collectVar y; collectFnBody b +| (FnBody.release x _ b) := collectVar x; collectFnBody b +| (FnBody.inc x _ _ b) := collectVar x; collectFnBody b +| (FnBody.dec x _ _ b) := collectVar x; collectFnBody b +| (FnBody.mdata _ b) := collectFnBody b +| (FnBody.case _ x alts) := collectVar x; collectAlts collectFnBody alts +| (FnBody.jmp j ys) := collectJP j; collectArgs ys +| (FnBody.ret x) := collectArg x +| FnBody.unreachable := skip + +partial def collectDecl : Decl → Collector +| (Decl.fdecl _ xs _ b) := collectParams xs; collectFnBody b +| (Decl.extern _ xs _) := collectParams xs + +end MaxIndex + +def FnBody.maxIndex (b : FnBody) : Index := +MaxIndex.collectFnBody b 0 + +def Decl.maxIndex (d : Decl) : Index := +MaxIndex.collectDecl d 0 + +/-- Set of variable and join point names -/ abbrev IndexSet := RBTree Index (λ a b, a < b) instance vsetInh : Inhabited IndexSet := ⟨{}⟩ -namespace FreeVariables +namespace FreeIndices +/- We say a variable (join point) index (aka name) is free in a function body + if there isn't a `FnBody.vdecl` (`Fnbody.jdecl`) binding it. -/ + abbrev Collector := IndexSet → IndexSet → IndexSet @[inline] private def skip : Collector := @@ -91,83 +170,19 @@ partial def collectFnBody : FnBody → Collector | (FnBody.ret x) := collectArg x | FnBody.unreachable := skip -end FreeVariables +end FreeIndices -def FnBody.collectFreeVars (b : FnBody) (vs : IndexSet) : IndexSet := -FreeVariables.collectFnBody b {} vs +def FnBody.collectFreeIndices (b : FnBody) (vs : IndexSet) : IndexSet := +FreeIndices.collectFnBody b {} vs -def FnBody.freeVars (b : FnBody) : IndexSet := -b.collectFreeVars {} - -namespace MaxVar -abbrev Collector := Index → Index - -@[inline] private def skip : Collector := id -@[inline] private def collect (x : Index) : Collector := λ y, if x > y then x else y -@[inline] private def collectVar (x : VarId) : Collector := collect x.idx -@[inline] private def collectJP (j : JoinPointId) : Collector := collect j.idx -@[inline] private def seq (k₁ k₂ : Collector) : Collector := k₂ ∘ k₁ -instance : HasAndthen Collector := ⟨seq⟩ - -private def collectArg : Arg → Collector -| (Arg.var x) := collectVar x -| irrelevant := skip - -@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := -λ m, as.foldl (λ m a, f a m) m - -private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg -private def collectParam (p : Param) : Collector := collectVar p.x -private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam - -private def collectExpr : Expr → Collector -| (Expr.ctor _ ys) := collectArgs ys -| (Expr.reset x) := collectVar x -| (Expr.reuse x _ _ ys) := collectVar x; collectArgs ys -| (Expr.proj _ x) := collectVar x -| (Expr.uproj _ x) := collectVar x -| (Expr.sproj _ _ x) := collectVar x -| (Expr.fap _ ys) := collectArgs ys -| (Expr.pap _ ys) := collectArgs ys -| (Expr.ap x ys) := collectVar x; collectArgs ys -| (Expr.box _ x) := collectVar x -| (Expr.unbox x) := collectVar x -| (Expr.lit v) := skip -| (Expr.isShared x) := collectVar x -| (Expr.isTaggedPtr x) := collectVar x - -private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector := -collectArray alts $ λ alt, f alt.body - -partial def collectFnBody : FnBody → Collector -| (FnBody.vdecl x _ v b) := collectExpr v; collectFnBody b -| (FnBody.jdecl j ys _ v b) := collectFnBody v; collectParams ys; collectFnBody b -| (FnBody.set x _ y b) := collectVar x; collectVar y; collectFnBody b -| (FnBody.uset x _ y b) := collectVar x; collectVar y; collectFnBody b -| (FnBody.sset x _ _ y _ b) := collectVar x; collectVar y; collectFnBody b -| (FnBody.release x _ b) := collectVar x; collectFnBody b -| (FnBody.inc x _ _ b) := collectVar x; collectFnBody b -| (FnBody.dec x _ _ b) := collectVar x; collectFnBody b -| (FnBody.mdata _ b) := collectFnBody b -| (FnBody.case _ x alts) := collectVar x; collectAlts collectFnBody alts -| (FnBody.jmp j ys) := collectJP j; collectArgs ys -| (FnBody.ret x) := collectArg x -| FnBody.unreachable := skip - -partial def collectDecl : Decl → Collector -| (Decl.fdecl _ xs _ b) := collectParams xs; collectFnBody b -| (Decl.extern _ xs _) := collectParams xs - -end MaxVar - -def FnBody.maxVar (b : FnBody) : Index := -MaxVar.collectFnBody b 0 - -def Decl.maxVar (d : Decl) : Index := -MaxVar.collectDecl d 0 +def FnBody.freeIndices (b : FnBody) : IndexSet := +b.collectFreeIndices {} namespace HasIndex - +/- In principle, we can check whether a function body `b` contains an index `i` using + `b.freeIndices.contains i`, but it is more efficient to avoid the construction + of the set of freeIndices and just search whether `i` occurs in `b` or not. +-/ def visitVar (w : Index) (x : VarId) : Bool := w == x.idx def visitJP (w : Index) (x : JoinPointId) : Bool := w == x.idx diff --git a/library/init/lean/compiler/ir/pushproj.lean b/library/init/lean/compiler/ir/pushproj.lean index a3e6277947..7861604821 100644 --- a/library/init/lean/compiler/ir/pushproj.lean +++ b/library/init/lean/compiler/ir/pushproj.lean @@ -11,23 +11,23 @@ namespace Lean namespace IR partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array FnBody → IndexSet → Array FnBody × Array Alt -| bs alts afvs ps vs := - if bs.isEmpty then (ps.reverse, alts) +| bs alts altsF ctx ctxF := + if bs.isEmpty then (ctx.reverse, alts) else let b := bs.back in let bs := bs.pop in - let done (_ : Unit) := (bs.push b ++ ps.reverse, alts) in - let skip (_ : Unit) := pushProjs bs alts afvs (ps.push b) (b.collectFreeVars vs) in + let done (_ : Unit) := (bs.push b ++ ctx.reverse, alts) in + let skip (_ : Unit) := pushProjs bs alts altsF (ctx.push b) (b.collectFreeIndices ctxF) in match b with | FnBody.vdecl x t v _ := match v with | Expr.proj _ _ := - if !vs.contains x.idx && !afvs.all (λ s, s.contains x.idx) then + if !ctxF.contains x.idx && !altsF.all (λ s, s.contains x.idx) then let alts := alts.hmapIdx $ λ i alt, alt.modifyBody $ λ b', - if (afvs.get i).contains x.idx then b <;> b' + if (altsF.get i).contains x.idx then b <;> b' else b' in - let fvs := afvs.hmap $ λ s, if s.contains x.idx then b.collectFreeVars s else s in - pushProjs bs alts fvs ps vs + let altsF := altsF.hmap $ λ s, if s.contains x.idx then b.collectFreeIndices s else s in + pushProjs bs alts altsF ctx ctxF else skip () | Expr.uproj _ _ := skip () @@ -41,8 +41,8 @@ partial def FnBody.pushProj : FnBody → FnBody let bs := modifyJPs bs FnBody.pushProj in match term with | FnBody.case tid x alts := - let afvs := alts.map $ λ alt, alt.body.freeVars in - let (bs, alts) := pushProjs bs alts afvs Array.empty {x.idx} in + let altsF := alts.map $ λ alt, alt.body.freeIndices in + let (bs, alts) := pushProjs bs alts altsF Array.empty {x.idx} in let alts := alts.hmap $ λ alt, alt.modifyBody FnBody.pushProj in let term := FnBody.case tid x alts in reshape bs term diff --git a/library/init/lean/compiler/ir/resetreuse.lean b/library/init/lean/compiler/ir/resetreuse.lean index 8f23fd5f0c..50d7eb7119 100644 --- a/library/init/lean/compiler/ir/resetreuse.lean +++ b/library/init/lean/compiler/ir/resetreuse.lean @@ -116,8 +116,8 @@ private partial def R : FnBody → M FnBody def Decl.insertResetReuse : Decl → Decl | d@(Decl.fdecl f xs t b) := - let nextVar := d.maxVar + 1 in - let b := (R b).run' nextVar in + let nextIndex := d.maxIndex + 1 in + let b := (R b).run' nextIndex in Decl.fdecl f xs t b | other := other