fix: mustInline at Compiler.simp

This commit is contained in:
Leonardo de Moura 2022-08-21 13:57:56 -07:00
parent df89717ae3
commit 9700b58114

View file

@ -50,11 +50,6 @@ inductive LocalFunInfo where
if it is small.
-/
many
| /--
We always inline this local function. We use this annotation for
type class instance elements.
-/
mustInline
deriving Repr, Inhabited
/--
@ -89,14 +84,6 @@ def LocalFunInfoMap.add (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap :=
| none => { map := map.insert key .once }
| _ => { map }
/--
Mark the function with binder name `key` as `.mustInline`.
We use this marker for auxiliary functions in type class instances.
-/
def LocalFunInfoMap.addMustInline (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap :=
match s with
| { map } => { map := map.insert key .mustInline }
structure Config where
smallThreshold : Nat := 1
@ -123,7 +110,13 @@ abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
/-
Ensure binder names are unique, and update local function information.
If `mustInline = true`, then local functions in `e` are marked as `.mustInline`.
If `mustInline = true`, then local functions in `e` are marked with binders of the
form `_mustInline.<idx>`.
Remark: we used to store the `mustInline` information in the map `localInfoMap`,
using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect
because there is no guarantee that we will be able to inline all occurrences of the
function in the current `simp` step. Since, we recompute `localInfoMap` from scratch
at the beginning of each compiler pass, the information was being lost.
-/
structure Internalize.State where
@ -131,9 +124,7 @@ structure Internalize.State where
localInfoMap : LocalFunInfoMap
private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit :=
if mustInline then
modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key }
else
unless mustInline do
modify fun s => { s with localInfoMap := s.localInfoMap.add key }
/--
@ -185,9 +176,12 @@ where
| .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi
| .letE binderName type value body nonDep =>
let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap })
let binderName' := match binderName with
| .num p _ => .num p idx
| _ => .num binderName idx
let binderName' :=
if mustInline then
.num `_mustInline idx
else match binderName with
| .num p _ => .num p idx
| _ => .num binderName idx
let type := instantiate type
let value ← go value ctx
let ctxVal := match value with
@ -260,9 +254,11 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
return mkAppN f e.getAppArgs
def isOnceOrMustInline (binderName : Name) : SimpM Bool := do
match (← get).localInfoMap.map.find? binderName with
| some .once | some .mustInline => return true
| _ => return false
if binderName.getPrefix == `_mustInline then
return true
else match (← get).localInfoMap.map.find? binderName with
| some .once => return true
| _ => return false
def isSmallValue (value : Expr) : SimpM Bool := do
lcnfSizeLe value (← read).config.smallThreshold