From 4d9483d1fa0ee06ad34a4dec0eb7b309baec3fa7 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 15 Oct 2022 11:44:36 -0700 Subject: [PATCH] fix: if inlined code returns a function and has more than one exit point, create an auxiliary function instead of a join point that takes a closure as argument --- src/Lean/Compiler/LCNF/Simp/Main.lean | 42 +++++++++++++++------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/Lean/Compiler/LCNF/Simp/Main.lean b/src/Lean/Compiler/LCNF/Simp/Main.lean index 322e9c7c18..4239ee5b85 100644 --- a/src/Lean/Compiler/LCNF/Simp/Main.lean +++ b/src/Lean/Compiler/LCNF/Simp/Main.lean @@ -136,35 +136,41 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d simp code else let code ← simp code + let simpK (result : FVarId) : SimpM Code := do + /- `result` contains the result of the inlined code -/ + if numArgs > info.arity then + let decl ← mkAuxLetDecl (mkAppN (.fvar result) info.args[info.arity:]) + addFVarSubst fvarId decl.fvarId + simp (.let decl k) + else + addFVarSubst fvarId result + simp k if oneExitPointQuick code then -- TODO: if `k` is small, we should also inline it here markSimplified code.bind fun fvarId' => do markUsedFVar fvarId' - /- fvarId' is the result of the computation -/ - if numArgs > info.arity then - let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:]) - addFVarSubst fvarId decl.fvarId - simp (.let decl k) - else - addFVarSubst fvarId fvarId' - simp k + simpK fvarId' -- else if info.ifReduce then -- eraseCode code -- return none else markSimplified - let jpParam ← mkAuxParam expectedType - let jpValue ← if numArgs > info.arity then - let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:]) - addFVarSubst fvarId decl.fvarId - simp (.let decl k) + if expectedType.headBeta.isForall then + /- + If `code` returns a function, we create an auxiliary local function declaration (and eta-expand it) + instead of creating a joinpoint that takes a closure as an argument. + -/ + let auxFunDecl ← mkAuxFunDecl #[] code + let auxFunDecl ← auxFunDecl.etaExpand + let k ← simpK auxFunDecl.fvarId + attachCodeDecls #[.fun auxFunDecl] k else - addFVarSubst fvarId jpParam.fvarId - simp k - let jpDecl ← mkAuxJpDecl #[jpParam] jpValue - let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] - return Code.jp jpDecl code + let jpParam ← mkAuxParam expectedType + let jpValue ← simpK jpParam.fvarId + let jpDecl ← mkAuxJpDecl #[jpParam] jpValue + let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] + return Code.jp jpDecl code /-- Simplify the given local function declaration.