From 37cffbda51dbe556ef71ead98d24e7e6f4a7bd74 Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Fri, 4 Jul 2025 17:02:24 -0700 Subject: [PATCH] fix: consider Prop-rebundled higher-order params to be fixed (#9198) This PR changes the compiler's specialization analysis to consider higher-order params that are rebundled in a way that only changes their `Prop` arguments to be fixed. This means that they get specialized with a mere `@[specialize]`, rather than the compiler having to opt-in to more aggressive parameter-specific specialization. --- src/Lean/Compiler/LCNF/FixedParams.lean | 23 ++++++++++++++++++- .../run/specFixedHOParamModuloErased.lean | 10 ++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 tests/lean/run/specFixedHOParamModuloErased.lean diff --git a/src/Lean/Compiler/LCNF/FixedParams.lean b/src/Lean/Compiler/LCNF/FixedParams.lean index 490e1dcbc0..a8b35bb3d6 100644 --- a/src/Lean/Compiler/LCNF/FixedParams.lean +++ b/src/Lean/Compiler/LCNF/FixedParams.lean @@ -104,10 +104,31 @@ partial def evalLetValue (e : LetValue) : FixParamM Unit := do | .const declName _ args => evalApp declName args | _ => return () +partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do + let .let { fvarId, value := (.fvar funFvarId args), .. } k := decl.value | return none + if args.size != decl.params.size then return none + let .return retFVarId := k | return none + if retFVarId != fvarId then return none + let some (.val funIdx) := (← read).assignment.find? funFvarId | return none + for h : i in [:decl.params.size] do + let param := decl.params[i] + -- TODO: Eliminate this dynamic bounds check. + let arg := args[i]! + if arg != .fvar param.fvarId && arg != .erased then return none + return some funIdx + partial def evalCode (code : Code) : FixParamM Unit := do match code with | .let decl k => evalLetValue decl.value; evalCode k - | .fun decl k | .jp decl k => evalCode decl.value; evalCode k + | .fun decl k => + if let some paramIdx ← isEquivalentFunDecl? decl then + withReader (fun ctx => + { ctx with assignment := ctx.assignment.insert decl.fvarId (.val paramIdx) }) + do evalCode k + else + evalCode decl.value + evalCode k + | .jp decl k => evalCode decl.value; evalCode k | .cases c => c.alts.forM fun alt => evalCode alt.getCode | .unreach .. | .jmp .. | .return .. => return () diff --git a/tests/lean/run/specFixedHOParamModuloErased.lean b/tests/lean/run/specFixedHOParamModuloErased.lean new file mode 100644 index 0000000000..dc902da5c0 --- /dev/null +++ b/tests/lean/run/specFixedHOParamModuloErased.lean @@ -0,0 +1,10 @@ +/-- +trace: [Compiler.specialize.info] pmap [true, true, false, true] +[Compiler.specialize.info] pmap [N, N, O, H] +-/ +#guard_msgs in +set_option trace.Compiler.specialize.info true in +@[specialize] +def pmap : (l : List α) → (f : (a : α) → a ∈ l → β) → List β + | [], _ => [] + | a :: l, f => f a List.mem_cons_self :: pmap l (fun a h => f a (List.mem_cons_of_mem _ h))