From 776a9b0dcb95e6aa5f69748df579f5704cfdfa92 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 20 Aug 2022 17:03:57 -0700 Subject: [PATCH] feat: don't eagerly simplify local functions that will be inlined --- src/Lean/Compiler/Simp.lean | 19 ++++++-- tests/lean/inlineIssue.lean.expected.out | 60 ++++++++++++------------ 2 files changed, 45 insertions(+), 34 deletions(-) diff --git a/src/Lean/Compiler/Simp.lean b/src/Lean/Compiler/Simp.lean index d824bfc956..0a469c18ed 100644 --- a/src/Lean/Compiler/Simp.lean +++ b/src/Lean/Compiler/Simp.lean @@ -204,10 +204,16 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do markSimplified return mkAppN f e.getAppArgs -def shouldInlineLocal (localDecl : LocalDecl) : SimpM Bool := do - match (← get).localInfoMap.map.find? localDecl.userName with +def isOnceOrMustInline (binderName : Name) : SimpM Bool := do + match (← get).localInfoMap.map.find? binderName with | some .once | some .mustInline => return true - | _ => lcnfSizeLe localDecl.value (← read).config.smallThreshold + | _ => return false + +def shouldInlineLocal (localDecl : LocalDecl) : SimpM Bool := do + if (← isOnceOrMustInline localDecl.userName) then + return true + else + lcnfSizeLe localDecl.value (← read).config.smallThreshold structure InlineCandidateInfo where isLocal : Bool @@ -566,7 +572,12 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do | .letE binderName type value body nonDep => let mut value := value.instantiateRev xs if value.isLambda then - value ← visitLambda value + unless (← isOnceOrMustInline binderName) do + /- + If the local function will be inlined anyway, we don't simplify it here, + we do it after its is inlined and we have information about the actual arguments. + -/ + value ← visitLambda value else if let some value' ← simpValue? value then if value'.isLet then let e := mkFlatLet binderName type value' body nonDep diff --git a/tests/lean/inlineIssue.lean.expected.out b/tests/lean/inlineIssue.lean.expected.out index 7c508928c7..e9559ebeee 100644 --- a/tests/lean/inlineIssue.lean.expected.out +++ b/tests/lean/inlineIssue.lean.expected.out @@ -3,43 +3,43 @@ fun x => let h.14 := fun x_1 => Nat.casesOn x_1 - (let _x.86 := Nat.mul x x_1; - let _x.87 := Nat.mul _x.86 x; - let _x.88 := Nat.mul _x.87 x_1; - Nat.mul _x.88 x_1) + (let _x.92 := Nat.mul x x_1; + let _x.93 := Nat.mul _x.92 x; + let _x.94 := Nat.mul _x.93 x_1; + Nat.mul _x.94 x_1) fun n => Nat.add n x_1; let _x.15 := 1; - let _x.89 := Nat.add x _x.15; - let _x.19 := h.14 _x.89; + let _x.95 := Nat.add x _x.15; + let _x.19 := h.14 _x.95; let _x.20 := 2; - let _x.90 := Nat.add x _x.20; - let _x.24 := h.14 _x.90; - let _x.91 := Nat.add _x.19 _x.24; + let _x.96 := Nat.add x _x.20; + let _x.24 := h.14 _x.96; + let _x.97 := Nat.add _x.19 _x.24; let _x.26 := 3; - let _x.92 := Nat.add x _x.26; - let _x.30 := h.14 _x.92; - let _x.93 := Nat.add _x.91 _x.30; + let _x.98 := Nat.add x _x.26; + let _x.30 := h.14 _x.98; + let _x.99 := Nat.add _x.97 _x.30; let _x.32 := 4; - let _x.94 := Nat.add x _x.32; - let _x.36 := h.14 _x.94; - let _x.95 := Nat.add _x.93 _x.36; + let _x.100 := Nat.add x _x.32; + let _x.36 := h.14 _x.100; + let _x.101 := Nat.add _x.99 _x.36; let _x.38 := 5; - let _x.96 := Nat.add x _x.38; - let _x.42 := h.14 _x.96; - let _x.97 := Nat.add _x.95 _x.42; + let _x.102 := Nat.add x _x.38; + let _x.42 := h.14 _x.102; + let _x.103 := Nat.add _x.101 _x.42; let _x.44 := 6; - let _x.98 := Nat.add x _x.44; - let _x.48 := h.14 _x.98; - let _x.99 := Nat.add _x.97 _x.48; + let _x.104 := Nat.add x _x.44; + let _x.48 := h.14 _x.104; + let _x.105 := Nat.add _x.103 _x.48; let _x.50 := 7; - let _x.100 := Nat.add x _x.50; - let _x.54 := h.14 _x.100; - let _x.101 := Nat.add _x.99 _x.54; + let _x.106 := Nat.add x _x.50; + let _x.54 := h.14 _x.106; + let _x.107 := Nat.add _x.105 _x.54; let _x.56 := 8; - let _x.102 := Nat.add x _x.56; - let _x.60 := h.14 _x.102; - let _x.103 := Nat.add _x.101 _x.60; + let _x.108 := Nat.add x _x.56; + let _x.60 := h.14 _x.108; + let _x.109 := Nat.add _x.107 _x.60; let _x.62 := 9; - let _x.104 := Nat.add x _x.62; - let _x.66 := h.14 _x.104; - Nat.add _x.103 _x.66 + let _x.110 := Nat.add x _x.62; + let _x.66 := h.14 _x.110; + Nat.add _x.109 _x.66