diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 2647514691..28229f094a 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -326,12 +326,14 @@ Result of `inlineCandidate?`. It contains information for inlining local and global functions. -/ structure InlineCandidateInfo where - isLocal : Bool - params : Array Param + isLocal : Bool + params : Array Param /-- Value (lambda expression) of the function to be inlined. -/ - value : Code - f : Expr - args : Array Expr + value : Code + f : Expr + args : Array Expr + /-- `ifReduce = true` if the declaration being inlined was tagged with `inlineIfReduce`. -/ + ifReduce : Bool /-- The arity (aka number of parameters) of the function to be inlined. -/ def InlineCandidateInfo.arity : InlineCandidateInfo → Nat @@ -383,15 +385,16 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do if inlineIfReduce then let some paramIdx := isCasesOnParam? decl | return none unless paramIdx < numArgs do return none - let arg ← findCtor (e.getArg! paramIdx) + let arg ← findExpr (e.getArg! paramIdx) unless arg.isConstructorApp (← getEnv) do return none let params := decl.instantiateParamsLevelParams us let value := decl.instantiateValueLevelParams us incInline return some { - isLocal := false - f := e.getAppFn - args := e.getAppArgs + isLocal := false + f := e.getAppFn + args := e.getAppArgs + ifReduce := inlineIfReduce params, value } else if let some decl ← findFunDecl? f then @@ -401,22 +404,16 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do incInlineLocal modify fun s => { s with inlineLocal := s.inlineLocal + 1 } return some { - isLocal := true - f := e.getAppFn - args := e.getAppArgs - params := decl.params - value := decl.value + isLocal := true + f := e.getAppFn + args := e.getAppArgs + params := decl.params + value := decl.value + ifReduce := false } else return none -/-- -Add substitution `fvarId ↦ val`. `val` is a free variable, or -it is a type, type former, or `lcErased`. --/ -def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit := - modify fun s => { s with subst := s.subst.insert fvarId val } - /-- Return `true` if `c` has only one exit point. This is a quick approximation. It does not check cases @@ -431,7 +428,7 @@ where match c with | .let _ k | .fun _ k => go k -- Approximation, the cases may have many unreachable alternatives, and only reachable. - | .cases c => c.alts.size == 1 && c.alts.any fun alt => go alt.getCode + | .cases c => c.alts.size == 1 && go c.alts[0]!.getCode -- Approximation, we assume that any code containing join points have more than one exit point | .jp .. | .jmp .. => false | .return .. | .unreach .. => true @@ -468,56 +465,6 @@ def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do updateFunDeclInfo code mkAuxFunDecl paramsNew code -/-- -If the value of the given let-declaration is an application that can be inlined, inline it. - -`k` is the "continuation" for the let declaration. --/ -partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do - if k matches .unreach .. then return some k - let some info ← inlineCandidate? letDecl.value | return none - markSimplified - let numArgs := info.args.size - trace[Compiler.simp.inline] "inlining {letDecl.value}" - let fvarId := letDecl.fvarId - if numArgs < info.arity then - let funDecl ← specializePartialApp info - addSubst letDecl.fvarId (.fvar funDecl.fvarId) - return some (.fun funDecl k) - else - let code ← betaReduce info.params info.value info.args[:info.arity] - if k.isReturnOf fvarId && numArgs == info.arity then - /- Easy case, the continuation `k` is just returning the result of the application. -/ - return code - else if oneExitPointQuick code then - /- - `code` has only one exit point, thus we can attach the continuation directly there, - and simplify the result. - -/ - code.bind fun fvarId' => do - /- fvarId' is the result of the computation -/ - if numArgs > info.arity then - let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:]) - let k ← replaceFVar k fvarId decl.fvarId - return .let decl k - else - replaceFVar k fvarId fvarId' - else - /- - `code` has multiple exit points, and the continuation is non-trivial - Thus, we create an auxiliary join point. - -/ - let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity])) - let jpValue ← if numArgs > info.arity then - let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:]) - let k ← replaceFVar k fvarId decl.fvarId - pure <| .let decl k - else - replaceFVar k fvarId jpParam.fvarId - let jpDecl ← mkAuxJpDecl #[jpParam] jpValue - let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] - return Code.jp jpDecl code - /-- Try to inline a join point. -/ @@ -868,11 +815,83 @@ def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do let value := mkAppN letDecl.value (params.map (.fvar ·.fvarId)) let auxDecl ← mkAuxLetDecl value let funDecl ← mkAuxFunDecl params (.let auxDecl (.return auxDecl.fvarId)) - addSubst letDecl.fvarId (.fvar funDecl.fvarId) + addFVarSubst letDecl.fvarId funDecl.fvarId eraseLetDecl letDecl return funDecl +/-- +Similar to `Code.isReturnOf`, but taking the current substitution into account. +-/ +def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do + match c with + | .return fvarId' => return (← normFVar fvarId') == fvarId + | _ => return false + mutual +/-- +If the value of the given let-declaration is an application that can be inlined, +inline it and simplify the result. + +`k` is the "continuation" for the let declaration, if the application is inlined, +it will also be simplified. + +Note: `inlineApp?` did not use to be in this mutually recursive declaration. +It used to be invoked by `simp`, and would return `Option Code` that would be +then simplified by `simp`. However, this simpler architecture produced an +exponential blowup in when processing functions such as `Lean.Elab.Deriving.Ord.mkMatch.mkAlts`. +The key problem is that when inlining a declaration we often can reduce the number +of exit points by simplified the inlined code, and then connecting the result to the +continuation `k`. However, this optimization is only possible if we simplify the +inlined code **before** we attach it to the continuation. +-/ +partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do + let some info ← inlineCandidate? letDecl.value | return none + let numArgs := info.args.size + trace[Compiler.simp.inline] "inlining {letDecl.value}" + let fvarId := letDecl.fvarId + if numArgs < info.arity then + let funDecl ← specializePartialApp info + addFVarSubst fvarId funDecl.fvarId + markSimplified + simp (.fun funDecl k) + else + let code ← betaReduce info.params info.value info.args[:info.arity] + if k.isReturnOf fvarId && numArgs == info.arity then + /- Easy case, the continuation `k` is just returning the result of the application. -/ + markSimplified + simp code + else + let code ← simp code + if oneExitPointQuick code then + -- TODO: if `k` is small, we should also inline it here + markSimplified + code.bind fun fvarId' => do + markUsedFVar fvarId' + /- fvarId' is the result of the computation -/ + if numArgs > info.arity then + let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:]) + addFVarSubst fvarId decl.fvarId + simp (.let decl k) + else + addFVarSubst fvarId fvarId' + simp k + -- else if info.ifReduce then + -- eraseCode code + -- return none + else + markSimplified + let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity])) + let jpValue ← if numArgs > info.arity then + let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:]) + addFVarSubst fvarId decl.fvarId + simp (.let decl k) + else + addFVarSubst fvarId jpParam.fvarId + simp k + let jpDecl ← mkAuxJpDecl #[jpParam] jpValue + let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] + return Code.jp jpDecl code + /-- Simplify the given local function declaration. -/ @@ -904,11 +923,11 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do To make the code robust, we add auxiliary declarations whenever the `field` is not a free variable. -/ if field.isFVar then - addSubst param.fvarId field + addFVarSubst param.fvarId field.fvarId! else let auxDecl ← mkAuxLetDecl field auxDecls := auxDecls.push (CodeDecl.let auxDecl) - addSubst param.fvarId (.fvar auxDecl.fvarId) + addFVarSubst param.fvarId auxDecl.fvarId let k ← simp k eraseParams params attachCodeDecls auxDecls k @@ -927,14 +946,14 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do simp (.fun funDecl k) else if decl.value.isFVar then /- Eliminate `let _x_i := _x_j;` -/ - addSubst decl.fvarId decl.value + addFVarSubst decl.fvarId decl.value.fvarId! eraseLetDecl decl simp k else if let some code ← inlineApp? decl k then eraseLetDecl decl - simp code + return code else if let some (decls, fvarId) ← inlineProjInst? decl.value then - addSubst decl.fvarId (.fvar fvarId) + addFVarSubst decl.fvarId fvarId eraseLetDecl decl let k ← simp k attachCodeDecls decls k @@ -959,7 +978,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do else /- Note that functions in `decl` will be marked as used even if `decl` is not actually used. - They will only be deleted in the next pass. + They will only be deleted in the next pass. TODO: investigate whether this is a problem. -/ if code.isFun then if decl.isEtaExpandCandidate then diff --git a/tests/lean/run/simpExpBlowup.lean b/tests/lean/run/simpExpBlowup.lean new file mode 100644 index 0000000000..446ef28bef --- /dev/null +++ b/tests/lean/run/simpExpBlowup.lean @@ -0,0 +1,4 @@ +import Lean + +set_option trace.Compiler.result true +#eval Lean.Compiler.compile #[``Lean.Elab.Deriving.Ord.mkMatch.mkAlts]