From 2f7c0366f55078d63c8393fb300e4df0db6a6ee3 Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Sun, 20 Jul 2025 07:57:21 -0700 Subject: [PATCH] perf: treat partial application and eta expansion equally for specialization (#9438) --- src/Lean/Compiler/LCNF/Specialize.lean | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/Lean/Compiler/LCNF/Specialize.lean b/src/Lean/Compiler/LCNF/Specialize.lean index f2726bb7ea..6921f2b8fc 100644 --- a/src/Lean/Compiler/LCNF/Specialize.lean +++ b/src/Lean/Compiler/LCNF/Specialize.lean @@ -53,6 +53,7 @@ structure Context where Set of let-declarations in scope that do not depend on parameters. -/ ground : FVarIdSet := {} + underApplied : FVarIdSet := {} /-- Name of the declaration being processed -/ @@ -77,9 +78,25 @@ def isGround [TraverseFVar α] (e : α) : SpecializeM Bool := do @[inline] def withLetDecl (decl : LetDecl) (x : SpecializeM α) : SpecializeM α := do let grd ← isGround decl.value + let isUnderApplied ← + match decl.value with + | .const fnName _ args => + match ← getDecl? fnName with + -- This ascription to `Bool` is required to avoid this being inferred as `Prop`, + -- even with a type specified on the `let` binding. + | some { params, .. } => pure ((args.size < params.size) : Bool) + | none => pure false + | .fvar fnFVarId args => + match ← findFunDecl? fnFVarId with + -- This ascription to `Bool` is required to avoid this being inferred as `Prop`, + -- even with a type specified on the `let` binding. + | some { params, .. } => pure ((args.size < params.size) : Bool) + | none => pure false + | _ => pure false let fvarId := decl.fvarId withReader (fun ctx => { ctx with scope := ctx.scope.insert fvarId + underApplied := if isUnderApplied then ctx.underApplied.insert fvarId else ctx.underApplied ground := if grd then ctx.ground.insert fvarId else ctx.ground }) x @@ -156,7 +173,9 @@ def collect (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM let lctx := (← getThe CompilerM.State).lctx let abstract (fvarId : FVarId) : Bool := -- We convert let-declarations that are not ground into parameters - !lctx.funDecls.contains fvarId && !ctx.ground.contains fvarId + !lctx.funDecls.contains fvarId && + !ctx.underApplied.contains fvarId && + !ctx.ground.contains fvarId Closure.run (inScope := ctx.scope.contains) (abstract := abstract) do let mut argMask := #[] for paramInfo in paramsInfo, arg in args do