diff --git a/src/Lean/Compiler/LCNF/Simp/DiscrM.lean b/src/Lean/Compiler/LCNF/Simp/DiscrM.lean index 7043832292..8276faf25c 100644 --- a/src/Lean/Compiler/LCNF/Simp/DiscrM.lean +++ b/src/Lean/Compiler/LCNF/Simp/DiscrM.lean @@ -51,18 +51,21 @@ def getIndInfo? (type : Expr) : CoreM (Option (List Level × Array Expr)) := do Execute `x` with the information that `discr = ctorName ctorFields`. We use this information to simplify nested cases on the same discriminant. -/ -def withDiscrCtorImp (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : DiscrM α) : DiscrM α := do - let ctorInfo ← getConstInfoCtor ctorName - let fieldArgs := ctorFields.map (.fvar ·.fvarId) - if let some (us, params) ← getIndInfo? (← getType discr) then - let ctor := mkAppN (mkAppN (mkConst ctorName us) params) fieldArgs - withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor, ctorDiscrMap := ctx.ctorDiscrMap.insert ctor discr }) do - x - else - -- For the discrCtor map, the constructor parameters are irrelevant for optimizations that use this information - let ctor := mkAppN (mkAppN (mkConst ctorName) (mkArray ctorInfo.numParams erasedExpr)) fieldArgs - withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor }) do - x +@[inline] def withDiscrCtorImp (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : DiscrM α) : DiscrM α := do + let ctx ← updateCtx + withReader (fun _ => ctx) x +where + updateCtx : DiscrM DiscrM.Context := do + let ctorInfo ← getConstInfoCtor ctorName + let fieldArgs := ctorFields.map (.fvar ·.fvarId) + let ctx ← read + if let some (us, params) ← getIndInfo? (← getType discr) then + let ctor := mkAppN (mkAppN (mkConst ctorName us) params) fieldArgs + return { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor, ctorDiscrMap := ctx.ctorDiscrMap.insert ctor discr } + else + -- For the discrCtor map, the constructor parameters are irrelevant for optimizations that use this information + let ctor := mkAppN (mkAppN (mkConst ctorName) (mkArray ctorInfo.numParams erasedExpr)) fieldArgs + return { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor } @[inline, inheritDoc withDiscrCtorImp] def withDiscrCtor [MonadFunctorT DiscrM m] (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) : m α → m α := monadMap (m := DiscrM) <| withDiscrCtorImp discr ctorName ctorFields diff --git a/src/Lean/Compiler/LCNF/Simp/SimpM.lean b/src/Lean/Compiler/LCNF/Simp/SimpM.lean index 3dee973497..e7a0a8e6cc 100644 --- a/src/Lean/Compiler/LCNF/Simp/SimpM.lean +++ b/src/Lean/Compiler/LCNF/Simp/SimpM.lean @@ -119,18 +119,22 @@ partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit : Execute `x` with an updated `inlineStack`. If `value` is of the form `const ...`, add `const` to the stack. Otherwise, do not change the `inlineStack`. -/ -def withInlining (value : Expr) (recursive : Bool) (x : SimpM α) : SimpM α := do +@[inline] def withInlining (value : Expr) (recursive : Bool) (x : SimpM α) : SimpM α := do let f := value.getAppFn if let .const declName _ := f then + let numOccs ← check declName + withReader (fun ctx => { ctx with inlineStack := declName :: ctx.inlineStack, inlineStackOccs := ctx.inlineStackOccs.insert declName numOccs }) x + else + x +where + check (declName : Name) : SimpM Nat := do trace[Compiler.simp.inline] "{declName}" let numOccs := (← read).inlineStackOccs.find? declName |>.getD 0 let numOccs := numOccs + 1 let inlineIfReduce ← if let some decl ← getDecl? declName then pure decl.inlineIfReduceAttr else pure false if recursive && inlineIfReduce && numOccs > (← getConfig).maxRecInlineIfReduce then throwError "function `{declName}` has been recursively inlined more than #{(← getConfig).maxRecInlineIfReduce}, consider removing the attribute `[inlineIfReduce]` from this declaration or increasing the limit using `set_option compiler.maxRecInlineIfReduce `" - withReader (fun ctx => { ctx with inlineStack := declName :: ctx.inlineStack, inlineStackOccs := ctx.inlineStackOccs.insert declName numOccs }) x - else - x + return numOccs /-- Similar to the default `Lean.withIncRecDepth`, but include the `inlineStack` in the error messsage.