diff --git a/src/Lean/Compiler/LCNF/FixedParams.lean b/src/Lean/Compiler/LCNF/FixedParams.lean index 490e1dcbc0..a8b35bb3d6 100644 --- a/src/Lean/Compiler/LCNF/FixedParams.lean +++ b/src/Lean/Compiler/LCNF/FixedParams.lean @@ -104,10 +104,31 @@ partial def evalLetValue (e : LetValue) : FixParamM Unit := do | .const declName _ args => evalApp declName args | _ => return () +partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do + let .let { fvarId, value := (.fvar funFvarId args), .. } k := decl.value | return none + if args.size != decl.params.size then return none + let .return retFVarId := k | return none + if retFVarId != fvarId then return none + let some (.val funIdx) := (← read).assignment.find? funFvarId | return none + for h : i in [:decl.params.size] do + let param := decl.params[i] + -- TODO: Eliminate this dynamic bounds check. + let arg := args[i]! + if arg != .fvar param.fvarId && arg != .erased then return none + return some funIdx + partial def evalCode (code : Code) : FixParamM Unit := do match code with | .let decl k => evalLetValue decl.value; evalCode k - | .fun decl k | .jp decl k => evalCode decl.value; evalCode k + | .fun decl k => + if let some paramIdx ← isEquivalentFunDecl? decl then + withReader (fun ctx => + { ctx with assignment := ctx.assignment.insert decl.fvarId (.val paramIdx) }) + do evalCode k + else + evalCode decl.value + evalCode k + | .jp decl k => evalCode decl.value; evalCode k | .cases c => c.alts.forM fun alt => evalCode alt.getCode | .unreach .. | .jmp .. | .return .. => return () diff --git a/tests/lean/run/specFixedHOParamModuloErased.lean b/tests/lean/run/specFixedHOParamModuloErased.lean new file mode 100644 index 0000000000..dc902da5c0 --- /dev/null +++ b/tests/lean/run/specFixedHOParamModuloErased.lean @@ -0,0 +1,10 @@ +/-- +trace: [Compiler.specialize.info] pmap [true, true, false, true] +[Compiler.specialize.info] pmap [N, N, O, H] +-/ +#guard_msgs in +set_option trace.Compiler.specialize.info true in +@[specialize] +def pmap : (l : List α) → (f : (a : α) → a ∈ l → β) → List β + | [], _ => [] + | a :: l, f => f a List.mem_cons_self :: pmap l (fun a h => f a (List.mem_cons_of_mem _ h))