fix: consider Prop-rebundled higher-order params to be fixed (#9198)

This PR changes the compiler's specialization analysis to consider
higher-order params that are rebundled in a way that only changes their
`Prop` arguments to be fixed. This means that they get specialized with
a mere `@[specialize]`, rather than the compiler having to opt-in to
more aggressive parameter-specific specialization.
This commit is contained in:
Cameron Zwarich 2025-07-04 17:02:24 -07:00 committed by GitHub
parent f5e47480f2
commit 37cffbda51
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 32 additions and 1 deletions

View file

@ -104,10 +104,31 @@ partial def evalLetValue (e : LetValue) : FixParamM Unit := do
| .const declName _ args => evalApp declName args
| _ => return ()
partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do
let .let { fvarId, value := (.fvar funFvarId args), .. } k := decl.value | return none
if args.size != decl.params.size then return none
let .return retFVarId := k | return none
if retFVarId != fvarId then return none
let some (.val funIdx) := (← read).assignment.find? funFvarId | return none
for h : i in [:decl.params.size] do
let param := decl.params[i]
-- TODO: Eliminate this dynamic bounds check.
let arg := args[i]!
if arg != .fvar param.fvarId && arg != .erased then return none
return some funIdx
partial def evalCode (code : Code) : FixParamM Unit := do
match code with
| .let decl k => evalLetValue decl.value; evalCode k
| .fun decl k | .jp decl k => evalCode decl.value; evalCode k
| .fun decl k =>
if let some paramIdx ← isEquivalentFunDecl? decl then
withReader (fun ctx =>
{ ctx with assignment := ctx.assignment.insert decl.fvarId (.val paramIdx) })
do evalCode k
else
evalCode decl.value
evalCode k
| .jp decl k => evalCode decl.value; evalCode k
| .cases c => c.alts.forM fun alt => evalCode alt.getCode
| .unreach .. | .jmp .. | .return .. => return ()

View file

@ -0,0 +1,10 @@
/--
trace: [Compiler.specialize.info] pmap [true, true, false, true]
[Compiler.specialize.info] pmap [N, N, O, H]
-/
#guard_msgs in
set_option trace.Compiler.specialize.info true in
@[specialize]
def pmap : (l : List α) → (f : (a : α) → a ∈ l → β) → List β
| [], _ => []
| a :: l, f => f a List.mem_cons_self :: pmap l (fun a h => f a (List.mem_cons_of_mem _ h))