perf: treat partial application and eta expansion equally for specialization (#9438)

This commit is contained in:
Cameron Zwarich 2025-07-20 07:57:21 -07:00 committed by GitHub
parent 1a9757d1f6
commit 2f7c0366f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -53,6 +53,7 @@ structure Context where
Set of let-declarations in scope that do not depend on parameters.
-/
ground : FVarIdSet := {}
underApplied : FVarIdSet := {}
/--
Name of the declaration being processed
-/
@ -77,9 +78,25 @@ def isGround [TraverseFVar α] (e : α) : SpecializeM Bool := do
@[inline] def withLetDecl (decl : LetDecl) (x : SpecializeM α) : SpecializeM α := do
let grd ← isGround decl.value
let isUnderApplied ←
match decl.value with
| .const fnName _ args =>
match ← getDecl? fnName with
-- This ascription to `Bool` is required to avoid this being inferred as `Prop`,
-- even with a type specified on the `let` binding.
| some { params, .. } => pure ((args.size < params.size) : Bool)
| none => pure false
| .fvar fnFVarId args =>
match ← findFunDecl? fnFVarId with
-- This ascription to `Bool` is required to avoid this being inferred as `Prop`,
-- even with a type specified on the `let` binding.
| some { params, .. } => pure ((args.size < params.size) : Bool)
| none => pure false
| _ => pure false
let fvarId := decl.fvarId
withReader (fun ctx => { ctx with
scope := ctx.scope.insert fvarId
underApplied := if isUnderApplied then ctx.underApplied.insert fvarId else ctx.underApplied
ground := if grd then ctx.ground.insert fvarId else ctx.ground
}) x
@ -156,7 +173,9 @@ def collect (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM
let lctx := (← getThe CompilerM.State).lctx
let abstract (fvarId : FVarId) : Bool :=
-- We convert let-declarations that are not ground into parameters
!lctx.funDecls.contains fvarId && !ctx.ground.contains fvarId
!lctx.funDecls.contains fvarId &&
!ctx.underApplied.contains fvarId &&
!ctx.ground.contains fvarId
Closure.run (inScope := ctx.scope.contains) (abstract := abstract) do
let mut argMask := #[]
for paramInfo in paramsInfo, arg in args do