lean4-htt/src/Lean/Compiler/LCNF/Probing.lean
2022-10-15 20:12:53 -07:00

193 lines
6.7 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.ForEachExpr
namespace Lean.Compiler.LCNF
/--
A `Probe` maps an array of `α` to an array of `β` inside `CompilerM`.
Probes in this namespace are small building blocks (map, filter, sort, count, …) that can be
chained with Kleisli composition (`>=>`) and run over LCNF declarations.
-/
abbrev Probe α β := Array α → CompilerM (Array β)
namespace Probe
/-- Apply `f` to every element, returning the results in the original order. -/
@[inline]
def map (f : α → CompilerM β) : Probe α β :=
  fun items => Array.mapM f items
/-- Keep exactly the elements for which `f` yields `true`, preserving their relative order. -/
@[inline]
def filter (f : α → CompilerM Bool) : Probe α α :=
  fun items => Array.filterM f items
/-- Sort the input in ascending order of the `<` relation given by the `LT` instance (via `Array.qsort`). -/
@[inline]
def sorted [Inhabited α] [inst : LT α] [DecidableRel inst.lt] : Probe α α :=
  fun items => pure <| items.qsort fun a b => a < b
/--
Pair each declaration with its `size` and sort the pairs by ascending size.
Ties are broken by ascending declaration name, using `Name.lt`.
-/
@[inline]
def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
  -- Strict "less than" used by `qsort`: compare sizes first, then names.
  let bySizeThenName (a b : Nat × Decl) : Bool :=
    if a.1 == b.1 then Name.lt a.2.name b.2.name else decide (a.1 < b.1)
  pure <| (decls.map fun decl => (decl.size, decl)).qsort bySizeThenName
/--
Count how many times each distinct element (up to `BEq`) occurs in the input.
Returns `(element, count)` pairs in the order produced by `HashMap.toArray`.
The result is not sorted; see `countUniqueSorted` for an ordered variant.
-/
def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := fun data =>
  -- One insert per element: bump the existing count, or start at 1 when the key is new.
  let counts : HashMap α Nat := data.foldl (init := HashMap.empty) fun acc d =>
    acc.insert d ((acc.find? d).getD 0 + 1)
  return counts.toArray
/-- Like `countUnique`, but the `(element, count)` pairs are ordered by ascending count. -/
@[inline]
def countUniqueSorted [ToString α] [BEq α] [Hashable α] [Inhabited α] : Probe α (α × Nat) := fun data => do
  let counts ← countUnique data
  return counts.qsort fun a b => decide (a.2 < b.2)
/--
Collect, in visitation order, every expression that `Decl.forEachExpr` visits in the given
declarations. `skipTypes` is forwarded unchanged to `Decl.forEachExpr`.
-/
def getExprs (skipTypes : Bool := true) : Probe Decl Expr := fun decls => do
  let (_, collected) ← (start decls).run #[]
  return collected
where
  /-- Append a visited expression to the accumulator. -/
  go (e : Expr) : StateRefT (Array Expr) CompilerM Unit :=
    modify (·.push e)
  /-- Walk every declaration, feeding each visited expression to `go`. -/
  start (decls : Array Decl) : StateRefT (Array Expr) CompilerM Unit := do
    for decl in decls do
      decl.forEachExpr go skipTypes
/--
Collect every join point declaration found in the bodies of the given declarations.
The collection descends into local functions, join points, continuations, and `cases` alternatives.
A join point is recorded before its own body is searched.
-/
partial def getJps : Probe Decl FunDecl := fun decls => do
  let (_, jps) ← (start decls).run #[]
  return jps
where
  go : Code → StateRefT (Array FunDecl) CompilerM Unit
  | .let _ k => go k
  | .fun decl k => do
    go decl.value
    go k
  | .jp decl k => do
    modify fun acc => acc.push decl
    go decl.value
    go k
  | .cases cs => cs.alts.forM fun alt => go alt.getCode
  | .jmp .. | .return .. | .unreach .. => pure ()
  start (decls : Array Decl) : StateRefT (Array FunDecl) CompilerM Unit := do
    for decl in decls do
      go decl.value
/--
Keep the declarations whose body contains a `let` declaration satisfying `f`.
The search descends into local function and join point bodies, continuations, and every
`cases` alternative, and stops at the first match.
-/
partial def filterByLet (f : LetDecl → CompilerM Bool) : Probe Decl Decl :=
  filter fun decl => go decl.value
where
  go (code : Code) : CompilerM Bool := do
    match code with
    | .let decl k =>
      if (← f decl) then return true
      go k
    | .fun decl k | .jp decl k =>
      if (← go decl.value) then return true
      go k
    | .cases cs =>
      for alt in cs.alts do
        if (← go alt.getCode) then return true
      return false
    | .jmp .. | .return .. | .unreach .. => return false
/--
Keep the declarations whose body contains a local function declaration (`fun`) satisfying `f`.
The search descends into the bodies of both local functions and join points, their
continuations, and every `cases` alternative, and stops at the first match.
-/
partial def filterByFun (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
  filter (fun decl => go decl.value)
where
  go : Code → CompilerM Bool
  | .let _ k => go k
  -- A join point's body can itself declare local functions, so it must be searched too,
  -- not just its continuation. This matches how the sibling `filterBy*` probes traverse `jp`.
  | .jp decl k => go decl.value <||> go k
  | .fun decl k => do if (← f decl) then return true else go decl.value <||> go k
  | .cases cs => cs.alts.anyM (go ·.getCode)
  | .jmp .. | .return .. | .unreach .. => return false
/--
Keep the declarations whose body contains a join point declaration satisfying `f`.
The search descends into local function and join point bodies, continuations, and every
`cases` alternative, and stops at the first match.
-/
partial def filterByJp (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
  filter fun decl => go decl.value
where
  go (code : Code) : CompilerM Bool := do
    match code with
    | .let _ k => go k
    | .fun decl k =>
      if (← go decl.value) then return true
      go k
    | .jp decl k =>
      -- Test the join point itself first, then its body, then the continuation.
      if (← f decl) then return true
      if (← go decl.value) then return true
      go k
    | .cases cs =>
      for alt in cs.alts do
        if (← go alt.getCode) then return true
      return false
    | .jmp .. | .return .. | .unreach .. => return false
/--
Keep the declarations whose body contains a local function or join point declaration
satisfying `f`. The search descends into their bodies, continuations, and every `cases`
alternative, and stops at the first match.
-/
partial def filterByFunDecl (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
  filter fun decl => go decl.value
where
  go (code : Code) : CompilerM Bool := do
    match code with
    | .let _ k => go k
    | .fun decl k | .jp decl k =>
      if (← f decl) then return true
      if (← go decl.value) then return true
      go k
    | .cases cs =>
      for alt in cs.alts do
        if (← go alt.getCode) then return true
      return false
    | .jmp .. | .return .. | .unreach .. => return false
/--
Keep the declarations whose body contains a `cases` node satisfying `f`.
Each `cases` node is tested before its alternatives are searched. Local function and join
point bodies and continuations are searched as well, and the search stops at the first match.
-/
partial def filterByCases (f : Cases → CompilerM Bool) : Probe Decl Decl :=
  filter fun decl => go decl.value
where
  go (code : Code) : CompilerM Bool := do
    match code with
    | .let _ k => go k
    | .fun decl k | .jp decl k =>
      if (← go decl.value) then return true
      go k
    | .cases cs =>
      if (← f cs) then return true
      for alt in cs.alts do
        if (← go alt.getCode) then return true
      return false
    | .jmp .. | .return .. | .unreach .. => return false
/--
Keep the declarations whose body contains a `jmp` for which `f` holds.
`f` receives the jump's `FVarId` and argument array. The search descends into local
function and join point bodies, continuations, and `cases` alternatives, and stops at the
first match.
-/
partial def filterByJmp (f : FVarId → Array Expr → CompilerM Bool) : Probe Decl Decl :=
  filter fun decl => go decl.value
where
  go (code : Code) : CompilerM Bool := do
    match code with
    | .let _ k => go k
    | .fun decl k | .jp decl k =>
      if (← go decl.value) then return true
      go k
    | .cases cs =>
      for alt in cs.alts do
        if (← go alt.getCode) then return true
      return false
    | .jmp target args => f target args
    | .return .. | .unreach .. => return false
/--
Keep the declarations whose body contains a `return` of a variable satisfying `f`.
The search descends into local function and join point bodies, continuations, and `cases`
alternatives, and stops at the first match.
-/
partial def filterByReturn (f : FVarId → CompilerM Bool) : Probe Decl Decl :=
  filter fun decl => go decl.value
where
  go (code : Code) : CompilerM Bool := do
    match code with
    | .let _ k => go k
    | .fun decl k | .jp decl k =>
      if (← go decl.value) then return true
      go k
    | .cases cs =>
      for alt in cs.alts do
        if (← go alt.getCode) then return true
      return false
    | .return fvarId => f fvarId
    | .jmp .. | .unreach .. => return false
/--
Keep the declarations whose body contains an `unreach` whose carried expression satisfies `f`.
The search descends into local function and join point bodies, continuations, and `cases`
alternatives, and stops at the first match.
-/
partial def filterByUnreach (f : Expr → CompilerM Bool) : Probe Decl Decl :=
  filter fun decl => go decl.value
where
  go (code : Code) : CompilerM Bool := do
    match code with
    | .let _ k => go k
    | .fun decl k | .jp decl k =>
      if (← go decl.value) then return true
      go k
    | .cases cs =>
      for alt in cs.alts do
        if (← go alt.getCode) then return true
      return false
    | .unreach ty => f ty
    | .jmp .. | .return .. => return false
/-- Project each declaration to its `name`, preserving order. -/
@[inline]
def declNames : Probe Decl Name := fun decls =>
  pure (decls.map (·.name))
/--
Render every element with its `ToString` instance, preserving order.
(`ToString.toString` is fully qualified because a bare `toString` here names this probe.)
-/
@[inline]
def toString [ToString α] : Probe α String := fun data =>
  pure <| data.map fun x => ToString.toString x
/-- Replace the input with a one-element array holding the number of input elements. -/
@[inline]
def count : Probe α Nat := fun data =>
  pure (Array.singleton data.size)
/-- Replace the input with a one-element array holding the sum of its elements (`0` when empty). -/
@[inline]
def sum : Probe Nat Nat := fun data =>
  let total := data.foldl (· + ·) 0
  pure #[total]
/--
Keep the last `n` elements. When `n` is at least the array size, `data.size - n` truncates
to `0` on `Nat`, so the whole array is kept.
-/
@[inline]
def tail (n : Nat) : Probe α α := fun data =>
  pure (data.extract (data.size - n) data.size)
/-- Keep the first `n` elements (the whole array when `n` is at least its size). -/
@[inline]
def head (n : Nat) : Probe α α := fun data =>
  pure (data.extract 0 n)
/--
Run `probe` on the declarations that the `phase` extension stores for the imported module
`moduleName`, executing the probe in `CompilerM` at that phase.
Throws if `moduleName` is not a known module of the current environment.
-/
def runOnModule (moduleName : Name) (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (Array β) := do
  let env ← getEnv
  match env.getModuleIdx? moduleName with
  | none => throwError "module `{moduleName}` not found"
  | some modIdx =>
    let decls := (getExt phase).getModuleEntries env modIdx
    (probe decls).run (phase := phase)
/--
Run `probe` on the concatenation of the `phase` extension's entries for every imported module,
in module-index order, executing the probe in `CompilerM` at that phase.
NOTE(review): only `getModuleEntries` of imported modules are gathered here. Confirm whether
declarations of the module currently being compiled are meant to be included.
-/
def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (Array β) := do
  let env ← getEnv
  let ext := getExt phase
  let moduleCount := env.allImportedModuleNames.size
  let decls := (List.range moduleCount).foldl (init := (#[] : Array Decl)) fun acc modIdx =>
    acc ++ ext.getModuleEntries env modIdx
  (probe decls).run (phase := phase)
/--
Wrap `probe` as a compiler `Pass` named `probe` that runs in `phase`.
The pass evaluates the probe on the current declarations, emits the rendered result under the
`Compiler.probe` trace class, and returns the declarations unchanged.
-/
def toPass [ToString β] (probe : Probe Decl β) (phase : Phase) : Pass := {
  phase := phase,
  name := `probe,
  run := fun decls => do
    let result ← probe decls
    trace[Compiler.probe] s!"{result}"
    pure decls
}
-- Register the `Compiler.probe` trace class used by `toPass` to report probe results.
-- `inherited := true` is passed through to `registerTraceClass`.
builtin_initialize
  registerTraceClass `Compiler.probe (inherited := true)
end Probe
end Lean.Compiler.LCNF