diff --git a/src/Lean/Compiler/IR/FreeVars.lean b/src/Lean/Compiler/IR/FreeVars.lean index af118f391c..05fab205f1 100644 --- a/src/Lean/Compiler/IR/FreeVars.lean +++ b/src/Lean/Compiler/IR/FreeVars.lean @@ -108,87 +108,84 @@ namespace FreeIndices /-! We say a variable (join point) index (aka name) is free in a function body if there isn't a `FnBody.vdecl` (`Fnbody.jdecl`) binding it. -/ -abbrev Collector := IndexSet → IndexSet → IndexSet +structure State where + freeIndices : IndexSet := {} -@[inline] private def skip : Collector := - fun _ fv => fv +abbrev M := StateM State -@[inline] private def collectIndex (x : Index) : Collector := - fun bv fv => if bv.contains x then fv else fv.insert x +private def visitIndex (x : Index) : M Unit := do + modify fun s => { s with freeIndices := s.freeIndices.insert x } -@[inline] private def collectVar (x : VarId) : Collector := - collectIndex x.idx +private def visitVar (x : VarId) : M Unit := + visitIndex x.idx -@[inline] private def collectJP (x : JoinPointId) : Collector := - collectIndex x.idx +private def visitJP (j : JoinPointId) : M Unit := + visitIndex j.idx -@[inline] private def withIndex (x : Index) : Collector → Collector := - fun k bv fv => k (bv.insert x) fv +private def visitArg (arg : Arg) : M Unit := + match arg with + | .var x => visitVar x + | .erased => pure () -@[inline] private def withVar (x : VarId) : Collector → Collector := - withIndex x.idx +private def visitParam (p : Param) : M Unit := + visitVar p.x -@[inline] private def withJP (x : JoinPointId) : Collector → Collector := - withIndex x.idx - -def insertParams (s : IndexSet) (ys : Array Param) : IndexSet := - ys.foldl (init := s) fun s p => s.insert p.x.idx - -@[inline] private def withParams (ys : Array Param) : Collector → Collector := - fun k bv fv => k (insertParams bv ys) fv - -@[inline] private def seq : Collector → Collector → Collector := - fun k₁ k₂ bv fv => k₂ bv (k₁ bv fv) - -instance : AndThen Collector where - andThen a b := private seq a (b ()) - -private def collectArg : Arg → Collector - | .var x => collectVar x - | .erased => skip - -private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := - fun bv fv => as.foldl (fun fv a => f a bv fv) fv - -private def collectArgs (as : Array Arg) : Collector := - collectArray as collectArg - -private def collectExpr : Expr → Collector +private def visitExpr (e : Expr) : M Unit := do + match e with | .proj _ x | .uproj _ x | .sproj _ _ x | .box _ x | .unbox x | .reset _ x | .isShared x => - collectVar x + visitVar x | .ctor _ ys | .fap _ ys | .pap _ ys => - collectArgs ys + ys.forM visitArg | .ap x ys | .reuse x _ _ ys => - collectVar x >> collectArgs ys - | .lit _ => skip + visitVar x + ys.forM visitArg + | .lit _ => pure () -private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector := - collectArray alts fun alt => f alt.body - -partial def collectFnBody : FnBody → Collector +partial def visitFnBody (fnBody : FnBody) : M Unit := do + match fnBody with | .vdecl x _ v b => - collectExpr v >> withVar x (collectFnBody b) + visitVar x + visitExpr v + visitFnBody b | .jdecl j ys v b => - withParams ys (collectFnBody v) >> withJP j (collectFnBody b) + visitJP j + visitFnBody v + ys.forM visitParam + visitFnBody b | .set x _ y b => - collectVar x >> collectArg y >> collectFnBody b + visitVar x + visitArg y + visitFnBody b | .uset x _ y b | .sset x _ _ y _ b => - collectVar x >> collectVar y >> collectFnBody b + visitVar x + visitVar y + visitFnBody b | .setTag x _ b | .inc x _ _ _ b | .dec x _ _ _ b | .del x b => - collectVar x >> collectFnBody b + visitVar x + visitFnBody b | .case _ x _ alts => - collectVar x >> - collectAlts collectFnBody alts + visitVar x + alts.forM (visitFnBody ·.body) | .jmp j ys => - collectJP j >> collectArgs ys + visitJP j + ys.forM visitArg | .ret x => - collectArg x - | .unreachable => skip + visitArg x + | .unreachable => pure () + +private def visitDecl (decl : Decl) : M Unit := do + match decl with + | .fdecl (xs := xs) (body := b) .. => + xs.forM visitParam + visitFnBody b + | .extern (xs := xs) .. => + xs.forM visitParam end FreeIndices -def FnBody.collectFreeIndices (b : FnBody) (vs : IndexSet) : IndexSet := - FreeIndices.collectFnBody b {} vs +def FnBody.collectFreeIndices (b : FnBody) (init : IndexSet) : IndexSet := Id.run do + let ⟨_, { freeIndices }⟩ := FreeIndices.visitFnBody b |>.run { freeIndices := init } + return freeIndices def FnBody.freeIndices (b : FnBody) : IndexSet := b.collectFreeIndices {}