diff --git a/src/Lean/Compiler/Simp.lean b/src/Lean/Compiler/Simp.lean index 16c38ff1ff..00245998cc 100644 --- a/src/Lean/Compiler/Simp.lean +++ b/src/Lean/Compiler/Simp.lean @@ -50,11 +50,6 @@ inductive LocalFunInfo where if it is small. -/ many - | /-- - We always inline this local function. We use this annotation for - type class instance elements. - -/ - mustInline deriving Repr, Inhabited /-- @@ -89,14 +84,6 @@ def LocalFunInfoMap.add (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap := | none => { map := map.insert key .once } | _ => { map } -/-- -Mark the function with binder name `key` as `.mustInline`. -We use this marker for auxiliary functions in type class instances. --/ -def LocalFunInfoMap.addMustInline (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap := - match s with - | { map } => { map := map.insert key .mustInline } - structure Config where smallThreshold : Nat := 1 @@ -123,7 +110,13 @@ abbrev SimpM := ReaderT Context $ StateRefT State CompilerM /- Ensure binder names are unique, and update local function information. -If `mustInline = true`, then local functions in `e` are marked as `.mustInline`. +If `mustInline = true`, then local functions in `e` are marked with binders of the +form `_mustInline.`. +Remark: we used to store the `mustInline` information in the map `localInfoMap`, +using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect +because there is no guarantee that we will be able to inline all occurrences of the +function in the current `simp` step. Since, we recompute `localInfoMap` from scratch +at the beginning of each compiler pass, the information was being lost. -/ structure Internalize.State where @@ -131,9 +124,7 @@ structure Internalize.State where localInfoMap : LocalFunInfoMap private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit := - if mustInline then - modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key } - else + unless mustInline do modify fun s => { s with localInfoMap := s.localInfoMap.add key } /-- @@ -185,9 +176,12 @@ where | .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi | .letE binderName type value body nonDep => let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap }) - let binderName' := match binderName with - | .num p _ => .num p idx - | _ => .num binderName idx + let binderName' := + if mustInline then + .num `_mustInline idx + else match binderName with + | .num p _ => .num p idx + | _ => .num binderName idx let type := instantiate type let value ← go value ctx let ctxVal := match value with @@ -260,9 +254,11 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do return mkAppN f e.getAppArgs def isOnceOrMustInline (binderName : Name) : SimpM Bool := do - match (← get).localInfoMap.map.find? binderName with - | some .once | some .mustInline => return true - | _ => return false + if binderName.getPrefix == `_mustInline then + return true + else match (← get).localInfoMap.map.find? binderName with + | some .once => return true + | _ => return false def isSmallValue (value : Expr) : SimpM Bool := do lcnfSizeLe value (← read).config.smallThreshold