fix: mustInline at Compiler.simp
This commit is contained in:
parent
df89717ae3
commit
9700b58114
1 changed files with 19 additions and 23 deletions
|
|
@ -50,11 +50,6 @@ inductive LocalFunInfo where
|
|||
if it is small.
|
||||
-/
|
||||
many
|
||||
| /--
|
||||
We always inline this local function. We use this annotation for
|
||||
type class instance elements.
|
||||
-/
|
||||
mustInline
|
||||
deriving Repr, Inhabited
|
||||
|
||||
/--
|
||||
|
|
@ -89,14 +84,6 @@ def LocalFunInfoMap.add (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap :=
|
|||
| none => { map := map.insert key .once }
|
||||
| _ => { map }
|
||||
|
||||
/--
|
||||
Mark the function with binder name `key` as `.mustInline`.
|
||||
We use this marker for auxiliary functions in type class instances.
|
||||
-/
|
||||
def LocalFunInfoMap.addMustInline (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap :=
|
||||
match s with
|
||||
| { map } => { map := map.insert key .mustInline }
|
||||
|
||||
structure Config where
|
||||
smallThreshold : Nat := 1
|
||||
|
||||
|
|
@ -123,7 +110,13 @@ abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
|
|||
|
||||
/-
|
||||
Ensure binder names are unique, and update local function information.
|
||||
If `mustInline = true`, then local functions in `e` are marked as `.mustInline`.
|
||||
If `mustInline = true`, then local functions in `e` are marked with binders of the
|
||||
form `_mustInline.<idx>`.
|
||||
Remark: we used to store the `mustInline` information in the map `localInfoMap`,
|
||||
using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect
|
||||
because there is no guarantee that we will be able to inline all occurrences of the
|
||||
function in the current `simp` step. Since, we recompute `localInfoMap` from scratch
|
||||
at the beginning of each compiler pass, the information was being lost.
|
||||
-/
|
||||
|
||||
structure Internalize.State where
|
||||
|
|
@ -131,9 +124,7 @@ structure Internalize.State where
|
|||
localInfoMap : LocalFunInfoMap
|
||||
|
||||
private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit :=
|
||||
if mustInline then
|
||||
modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key }
|
||||
else
|
||||
unless mustInline do
|
||||
modify fun s => { s with localInfoMap := s.localInfoMap.add key }
|
||||
|
||||
/--
|
||||
|
|
@ -185,9 +176,12 @@ where
|
|||
| .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi
|
||||
| .letE binderName type value body nonDep =>
|
||||
let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap })
|
||||
let binderName' := match binderName with
|
||||
| .num p _ => .num p idx
|
||||
| _ => .num binderName idx
|
||||
let binderName' :=
|
||||
if mustInline then
|
||||
.num `_mustInline idx
|
||||
else match binderName with
|
||||
| .num p _ => .num p idx
|
||||
| _ => .num binderName idx
|
||||
let type := instantiate type
|
||||
let value ← go value ctx
|
||||
let ctxVal := match value with
|
||||
|
|
@ -260,9 +254,11 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
|
|||
return mkAppN f e.getAppArgs
|
||||
|
||||
def isOnceOrMustInline (binderName : Name) : SimpM Bool := do
|
||||
match (← get).localInfoMap.map.find? binderName with
|
||||
| some .once | some .mustInline => return true
|
||||
| _ => return false
|
||||
if binderName.getPrefix == `_mustInline then
|
||||
return true
|
||||
else match (← get).localInfoMap.map.find? binderName with
|
||||
| some .once => return true
|
||||
| _ => return false
|
||||
|
||||
def isSmallValue (value : Expr) : SimpM Bool := do
|
||||
lcnfSizeLe value (← read).config.smallThreshold
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue