lean4-htt/src/Lean/Compiler/LCNF/Probing.lean
2025-07-25 12:02:51 +00:00

213 lines
7.2 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
module
prelude
public import Lean.Compiler.LCNF.CompilerM
public import Lean.Compiler.LCNF.PassManager
public import Lean.Compiler.LCNF.PhaseExt
public section
namespace Lean.Compiler.LCNF
abbrev Probe α β := Array α → CompilerM (Array β)
namespace Probe
@[inline]
def map (f : α → CompilerM β) : Probe α β := fun data => data.mapM f
@[inline]
def filter (f : α → CompilerM Bool) : Probe α α := fun data => data.filterM f
@[inline]
def sorted [Inhabited α] [LT α] [DecidableLT α] : Probe α α := fun data => return data.qsort (· < ·)
@[inline]
def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
let decls := decls.map fun decl => (decl.size, decl)
return decls.qsort fun (sz₁, decl₁) (sz₂, decl₂) =>
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂
def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := fun data => do
let mut map := Std.HashMap.emptyWithCapacity data.size
for d in data do
if let some count := map[d]? then
map := map.insert d (count + 1)
else
map := map.insert d 1
return map.toArray
@[inline]
def countUniqueSorted [ToString α] [BEq α] [Hashable α] [Inhabited α] : Probe α (α × Nat) :=
countUnique >=> fun data => return data.qsort (fun l r => l.snd < r.snd)
partial def getLetValues : Probe Decl LetValue := fun decls => do
let (_, res) ← start decls |>.run #[]
return res
where
go (c : Code) : StateRefT (Array LetValue) CompilerM Unit := do
match c with
| .let (decl : LetDecl) (k : Code) =>
modify fun s => s.push decl.value
go k
| .fun decl k | .jp decl k =>
go decl.value
go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
start (decls : Array Decl) : StateRefT (Array LetValue) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
partial def getJps : Probe Decl FunDecl := fun decls => do
let (_, res) ← start decls |>.run #[]
return res
where
go (code : Code) : StateRefT (Array FunDecl) CompilerM Unit := do
match code with
| .let _ k => go k
| .fun decl k => go decl.value; go k
| .jp decl k => modify (·.push decl); go decl.value; go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
start (decls : Array Decl) : StateRefT (Array FunDecl) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
partial def filterByLet (f : LetDecl → CompilerM Bool) : Probe Decl Decl :=
filter (·.value.isCodeAndM go)
where
go : Code → CompilerM Bool
| .let decl k => do if (← f decl) then return true else go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByFun (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
filter (·.value.isCodeAndM go)
where
go : Code → CompilerM Bool
| .let _ k | .jp _ k => go k
| .fun decl k => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByJp (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
filter (·.value.isCodeAndM go)
where
go : Code → CompilerM Bool
| .let _ k => go k
| .fun decl k => go decl.value <||> go k
| .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByFunDecl (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
filter (·.value.isCodeAndM go)
where
go : Code → CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByCases (f : Cases → CompilerM Bool) : Probe Decl Decl :=
filter (·.value.isCodeAndM go)
where
go : Code → CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
partial def filterByJmp (f : FVarId → Array Arg → CompilerM Bool) : Probe Decl Decl :=
filter (·.value.isCodeAndM go)
where
go : Code → CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp fn var => f fn var
| .return .. | .unreach .. => return false
partial def filterByReturn (f : FVarId → CompilerM Bool) : Probe Decl Decl :=
filter (·.value.isCodeAndM go)
where
go : Code → CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .unreach .. => return false
| .return var => f var
partial def filterByUnreach (f : Expr → CompilerM Bool) : Probe Decl Decl :=
filter (·.value.isCodeAndM go)
where
go : Code → CompilerM Bool
| .let _ k => go k
| .fun decl k | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. => return false
| .unreach typ => f typ
@[inline]
def declNames : Probe Decl Name :=
Probe.map (fun decl => return decl.name)
@[inline]
def toString [ToString α] : Probe α String :=
Probe.map (return ToString.toString ·)
@[inline]
def count : Probe α Nat := fun data => return #[data.size]
@[inline]
def sum : Probe Nat Nat := fun data => return #[data.foldl (init := 0) (·+·)]
@[inline]
def tail (n : Nat) : Probe α α := fun data => return data[(data.size - n)...*]
@[inline]
def head (n : Nat) : Probe α α := fun data => return data[*...n]
def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
let ext := getExt phase
let env ← getEnv
let decls ← declNames.mapM fun name => do
let some decl := getDeclCore? env ext name | throwError "decl `{name}` not found"
return decl
probe decls |>.run (phase := phase)
def runOnModule (moduleName : Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
let ext := getExt phase
let env ← getEnv
let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found"
let decls := ext.getModuleEntries env modIdx
probe decls |>.run (phase := phase)
def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (Array β) := do
let ext := getExt phase
let env ← getEnv
let mut decls := #[]
for modIdx in *...env.allImportedModuleNames.size do
decls := decls.append <| ext.getModuleEntries env modIdx
probe decls |>.run (phase := phase)
def toPass [ToString β] (probe : Probe Decl β) (phase : Phase) : Pass where
phase := phase
name := `probe
run := fun decls => do
let res ← probe decls
trace[Compiler.probe] s!"{res}"
return decls
builtin_initialize
registerTraceClass `Compiler.probe (inherited := true)
end Probe
end Lean.Compiler.LCNF