From b73fe04710437daf1140cc8d06bef3e0f503fd33 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 16 Jul 2024 17:52:33 +0200 Subject: [PATCH] feat: add `Lean.Expr.numObjs` (#4754) Add helper function for computing the number of allocated sub-expressions in a given expression. Note: Use this function primarily for diagnosing performance issues. --- src/Lean/Util.lean | 1 + src/Lean/Util/NumObjs.lean | 47 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 src/Lean/Util/NumObjs.lean diff --git a/src/Lean/Util.lean b/src/Lean/Util.lean index f8d3b3066a..268df09c6f 100644 --- a/src/Lean/Util.lean +++ b/src/Lean/Util.lean @@ -31,3 +31,4 @@ import Lean.Util.FileSetupInfo import Lean.Util.Heartbeats import Lean.Util.SearchPath import Lean.Util.SafeExponentiation +import Lean.Util.NumObjs diff --git a/src/Lean/Util/NumObjs.lean b/src/Lean/Util/NumObjs.lean new file mode 100644 index 0000000000..54e70ec17b --- /dev/null +++ b/src/Lean/Util/NumObjs.lean @@ -0,0 +1,47 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Expr +import Lean.Util.PtrSet + +namespace Lean.Expr +namespace NumObjs + +unsafe structure State where + visited : PtrSet Expr := mkPtrSet + counter : Nat := 0 + +unsafe abbrev M := StateM State + +unsafe def visit (e : Expr) : M Unit := + unless (← get).visited.contains e do + modify fun { visited, counter } => { visited := visited.insert e, counter := counter + 1 } + match e with + | .forallE _ d b _ => visit d; visit b + | .lam _ d b _ => visit d; visit b + | .mdata _ b => visit b + | .letE _ t v b _ => visit t; visit v; visit b + | .app f a => visit f; visit a + | .proj _ _ b => visit b + | _ => return () + +unsafe def main (e : Expr) : Nat := + let (_, s) := NumObjs.visit e |>.run {} + s.counter + +end NumObjs + +/-- +Returns the number of allocated `Expr` objects in the given expression `e`. + +This operation is performed in `IO` because the result depends on the memory representation of the object. + +Note: Use this function primarily for diagnosing performance issues. +-/ +def numObjs (e : Expr) : IO Nat := + return unsafe NumObjs.main e + +end Lean.Expr