diff --git a/src/Lean/Compiler/LCNF/JoinPoints.lean b/src/Lean/Compiler/LCNF/JoinPoints.lean index 0ea63eaab6..48f7fc1cc4 100644 --- a/src/Lean/Compiler/LCNF/JoinPoints.lean +++ b/src/Lean/Compiler/LCNF/JoinPoints.lean @@ -612,8 +612,8 @@ builtin_initialize def Decl.extendJoinPointContext (decl : Decl) : CompilerM Decl := do JoinPointContextExtender.extend decl -def extendJoinPointContext : Pass := - .mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext .mono +def extendJoinPointContext (occurrence : Nat := 0) (phase := Phase.mono) (_h : phase ≠ .base := by simp): Pass := + .mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext phase (occurrence := occurrence) builtin_initialize registerTraceClass `Compiler.extendJoinPointContext (inherited := true) diff --git a/src/Lean/Compiler/LCNF/Passes.lean b/src/Lean/Compiler/LCNF/Passes.lean index 1bf81ae7ca..335a7e6d34 100644 --- a/src/Lean/Compiler/LCNF/Passes.lean +++ b/src/Lean/Compiler/LCNF/Passes.lean @@ -59,13 +59,14 @@ def builtinPassManager : PassManager := { toMono, simp (occurrence := 3) (phase := .mono), reduceJpArity (phase := .mono), - extendJoinPointContext, + extendJoinPointContext (phase := .mono) (occurrence := 0), floatLetIn (phase := .mono) (occurrence := 1), reduceArity, commonJoinPointArgs, simp (occurrence := 4) (phase := .mono), floatLetIn (phase := .mono) (occurrence := 2), lambdaLifting, + extendJoinPointContext (phase := .mono) (occurrence := 1), simp (occurrence := 5) (phase := .mono), cse (occurrence := 2) (phase := .mono), -- TODO: reduce function arity