feat: basic compiler probing framework with examples

This commit is contained in:
Henrik Böving 2022-10-14 23:50:35 +02:00 committed by Leonardo de Moura
parent 05694a11f3
commit 38788a72be
2 changed files with 189 additions and 3 deletions

View file

@ -5,13 +5,150 @@ Authors: Henrik Böving
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.ForEachExpr
namespace Lean.Compiler.LCNF
namespace Probing
abbrev Probe α β := Array α → CompilerM (Array β)
--abbrev DeclFilter (m : Type → Type) [MonadLiftT] := Decl → OptionT m Decl
namespace Probe
end Probing
@[inline]
def map (f : α → CompilerM β) : Probe α β := fun data => data.mapM f
@[inline]
def filter (f : α → CompilerM Bool) : Probe α α := fun data => data.filterM f
@[inline]
def sorted [Inhabited α] [inst : LT α] [DecidableRel inst.lt] : Probe α α := fun data => return data.qsort (· < ·)
def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := fun data => do
let mut map := HashMap.empty
for d in data do
if let some count := map.find? d then
map := map.insert d (count + 1)
else
map := map.insert d 1
return map.toArray
@[inline]
def countUniqueSorted [ToString α] [BEq α] [Hashable α] [Inhabited α] : Probe α (α × Nat) :=
countUnique >=> fun data => return data.qsort (fun l r => l.snd < r.snd)
def getExprs (skipTypes : Bool := true) : Probe Decl Expr := fun decls => do
let (_, res) ← start decls |>.run #[]
return res
where
go (e : Expr) : StateRefT (Array Expr) CompilerM Unit := do
modify fun s => s.push e
start (decls : Array Decl) : StateRefT (Array Expr) CompilerM Unit :=
decls.forM (fun decl => decl.forEachExpr go skipTypes)
partial def filterByLet (f : LetDecl → CompilerM Bool) : Probe Decl Decl :=
filter (fun decl => go decl.value)
where
go : Code → CompilerM Bool
| .let decl k => do if (← f decl) then return true else go k
| .fun _ k | .jp _ k => go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByFun (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
filter (fun decl => go decl.value)
where
go : Code → CompilerM Bool
| .let _ k | .jp _ k => go k
| .fun decl k => do if (← f decl) then return true else go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByJp (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
filter (fun decl => go decl.value)
where
go : Code → CompilerM Bool
| .let _ k | .fun _ k => go k
| .jp decl k => do if (← f decl) then return true else go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByFunDecl (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
filter (fun decl => go decl.value)
where
go : Code → CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => do if (← f decl) then return true else go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByCases (f : Cases → CompilerM Bool) : Probe Decl Decl :=
filter (fun decl => go decl.value)
where
go : Code → CompilerM Bool
| .let _ k => go k | .fun _ k | .jp _ k => go k
| .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByJmp (f : FVarId → Array Expr → CompilerM Bool) : Probe Decl Decl :=
filter (fun decl => go decl.value)
where
go : Code → CompilerM Bool
| .let _ k | .fun _ k | .jp _ k => go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp fn var => f fn var
| .return .. | .unreach .. => return false
partial def filterByReturn (f : FVarId → CompilerM Bool) : Probe Decl Decl :=
filter (fun decl => go decl.value)
where
go : Code → CompilerM Bool
| .let _ k | .fun _ k | .jp _ k => go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .unreach .. => return false
| .return var => f var
partial def filterByUnreach (f : Expr → CompilerM Bool) : Probe Decl Decl :=
filter (fun decl => go decl.value)
where
go : Code → CompilerM Bool
| .let _ k | .fun _ k | .jp _ k => go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. => return false
| .unreach typ => f typ
@[inline]
def declNames : Probe Decl Name :=
Probe.map (fun decl => return decl.name)
@[inline]
def count : Probe α Nat := fun data => return #[data.size]
def runOnModule (moduleName : Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
let ext := getExt phase
let env ← getEnv
let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found"
let decls := ext.getModuleEntries env modIdx
probe decls |>.run (phase := phase)
def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (Array β) := do
let ext := getExt phase
let env ← getEnv
let mut decls := #[]
for modIdx in [:env.allImportedModuleNames.size] do
decls := decls.append <| ext.getModuleEntries env modIdx
probe decls |>.run (phase := phase)
def toPass [ToString β] (probe : Probe Decl β) (phase : Phase) : Pass where
phase := phase
name := `probe
run := fun decls => do
let res ← probe decls
trace[Compiler.probe] s!"{res}"
return decls
builtin_initialize
registerTraceClass `Compiler.probe (inherited := true)
end Probe
end Lean.Compiler.LCNF

View file

@ -0,0 +1,49 @@
import Lean
import Lean.Compiler.LCNF.Probing
open Lean.Compiler.LCNF
-- Find functions that have jps which take a lambda
#eval
Probe.runGlobally (phase := .mono) <|
Probe.filterByJp (·.params.anyM (fun param => return param.type.isForall)) >=>
Probe.declNames
-- Count lambda lifted functions
def lambdaCounter : Probe Decl Nat :=
Probe.filter (fun decl =>
if let .str _ val := decl.name then
return val.startsWith "_lambda"
else
return false) >=>
Probe.declNames >=>
Probe.count
-- Run everywhere
#eval
Probe.runGlobally (phase := .mono) <|
lambdaCounter
-- Run limited
#eval
Probe.runOnModule `Lean.Compiler.LCNF.JoinPoints (phase := .mono) <|
lambdaCounter
-- Find most commonly used function with threshold
#eval
Probe.runOnModule `Lean.Compiler.LCNF.JoinPoints (phase := .mono) <|
Probe.getExprs >=>
Probe.filter (fun e => return e.isApp && e.getAppFn.isConst) >=>
Probe.map (fun e => return s!"{e.getAppFn.constName!}") >=>
Probe.countUniqueSorted >=>
Probe.filter (fun (_, count) => return count > 100)
-- To get that real shell feeling
infixr:55 " | " => Bind.kleisliRight
#eval
Probe.runOnModule `Lean.Compiler.LCNF.JoinPoints (phase := .mono) <|
Probe.getExprs |
Probe.filter (fun e => return e.isApp && e.getAppFn.isConst) |
Probe.map (fun e => return s!"{e.getAppFn.constName!}") |
Probe.countUniqueSorted |
Probe.filter (fun (_, count) => return count > 100)