feat: start simp for new LCNF format

This commit is contained in:
Leonardo de Moura 2022-08-27 19:59:31 -07:00
parent 9446ae3056
commit cd0dd4cc2f
2 changed files with 120 additions and 169 deletions

View file

@ -35,8 +35,11 @@ def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do
let some decl := (← get).lctx.localDecls.find? fvarId | throwError "unknown free variable {fvarId.name}"
return decl
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
return (← get).lctx.funDecls.find? fvarId
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
let some decl := (← get).lctx.funDecls.find? fvarId | throwError "unknown local function {fvarId.name}"
let some decl ← findFunDecl? fvarId | throwError "unknown local function {fvarId.name}"
return decl
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do

View file

@ -3,36 +3,14 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
#exit -- TODO: port to new LCNF
import Lean.Compiler.CompilerM
import Lean.Compiler.Decl
import Lean.Compiler.Stage1
import Lean.Util.Recognizers
import Lean.Compiler.InlineAttrs
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.Stage1
namespace Lean.Compiler
namespace Lean.Compiler.LCNF
namespace Simp
partial def findLambdaCore? (lctx : LocalContext) (e : Expr) : Option LocalDecl :=
match e with
| .fvar fvarId =>
if let some d@(.ldecl (value := v) ..) := lctx.find? fvarId then
if v.isLambda then some d else findLambdaCore? lctx v
else
none
| .mdata _ e => findLambdaCore? lctx e
| _ => none
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) :=
return findLambdaCore? (← getLCtx) e
partial def findExpr (e : Expr) (skipMData := true): CompilerM Expr := do
match e with
| .fvar fvarId =>
let some (.ldecl (value := v) ..) ← findDecl? fvarId | return e
findExpr v
| .mdata _ e' => if skipMData then findExpr e' else return e
| _ => return e
/--
Local function usage information used to decide whether it should be inlined or not.
The information is an approximation, but it is on the "safe" side. That is, if we tagged
@ -41,7 +19,7 @@ a function with `.once`, then it is applied only once. A local function may be m
a big problem in practice because we run the simplifier multiple times, and this information
is recomputed from scratch at the beginning of each simplification step.
-/
inductive LocalFunInfo where
inductive FunDeclInfo where
| /--
Local function is applied once, and must be inlined.
-/
@ -51,40 +29,67 @@ inductive LocalFunInfo where
if it is small.
-/
many
| /--
Function must be inlined.
-/
mustInline
deriving Repr, Inhabited
/--
Local function declaration statistics.
Remark: we use the `userName` as the key.
-/
structure LocalFunInfoMap where
structure FunDeclInfoMap where
/--
Mapping from local function name to inlining information.
-/
map : Std.HashMap Name LocalFunInfo := {}
map : Std.HashMap FVarId FunDeclInfo := {}
deriving Inhabited
def LocalFunInfoMap.format (s : LocalFunInfoMap) : Format := Id.run do
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
let mut result := Format.nil
for (k, n) in s.map.toList do
result := result ++ "\n" ++ f!"{k} ↦ {repr n}"
for (fvarId, info) in s.map.toList do
let localDecl ← getLocalDecl fvarId
result := result ++ "\n" ++ f!"{localDecl.userName} ↦ {repr info}"
return result
instance : ToFormat LocalFunInfoMap where
format := LocalFunInfoMap.format
/--
Add new occurrence for the local function with binder name `key`.
-/
def LocalFunInfoMap.add (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap :=
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } =>
match map.find? key with
| some .once => { map := map.insert key .many }
| none => { map := map.insert key .once }
match map.find? fvarId with
| some .once => { map := map.insert fvarId .many }
| none => { map := map.insert fvarId .once }
| _ => { map }
/--
Add new occurrence for the local function with binder name `key`.
-/
def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } => { map := map.insert fvarId .mustInline }
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
match e with
| .fvar fvarId =>
if let some decl ← LCNF.findFunDecl? fvarId then
return some decl
else if let .ldecl (value := v) .. ← getLocalDecl fvarId then
findFunDecl? v
else
return none
| .mdata _ e => findFunDecl? e
| _ => return none
partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
match e with
| .fvar fvarId =>
let .ldecl (value := v) .. ← getLocalDecl fvarId | return e
findExpr v
| .mdata _ e' => if skipMData then findExpr e' else return e
| _ => return e
structure Config where
smallThreshold : Nat := 1
@ -92,10 +97,8 @@ structure Context where
config : Config := {}
structure State where
/--
(Approximate) information for deciding whether to inline local function declarations.
-/
localInfoMap : LocalFunInfoMap := {}
subst : FVarSubst := {}
funDeclInfoMap : FunDeclInfoMap := {}
/--
`true` if some simplification was performed in the current simplification pass.
-/
@ -125,120 +128,31 @@ structure State where
This is a performance counter.
-/
inlineLocal : Nat := 0
deriving Inhabited
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
/-
Ensure binder names are unique, and update local function information.
If `mustInline = true`, then local functions in `e` are marked with binders of the
form `_mustInline.<idx>`.
Remark: we used to store the `mustInline` information in the map `localInfoMap`,
using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect
because there is no guarantee that we will be able to inline all occurrences of the
function in the current `simp` step. Since, we recompute `localInfoMap` from scratch
at the beginning of each compiler pass, the information was being lost.
-/
structure Internalize.State where
nextIdx : Nat
localInfoMap : LocalFunInfoMap
private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit :=
unless mustInline do
modify fun s => { s with localInfoMap := s.localInfoMap.add key }
/--
`instantiateRevInternalize` implementation.
-/
private def instantiateRevInternalizeCore (lctx : LocalContext) (e : Expr) (args : Array Expr) (mustInline : Bool) : StateM Internalize.State Expr :=
go e {}
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :=
go code
where
/-- Auxiliary functions for instantiating `args` in types. -/
inst (e : Expr) (offset : Nat) : Expr :=
match e with
| .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => e
| .mdata k b => .mdata k (inst b offset)
| .proj s i b => .proj s i (inst b offset)
| .app f a => if offset >= e.looseBVarRange then e else .app (inst f offset) (inst a offset)
| .bvar idx => if idx >= offset then args[args.size - (idx - offset) - 1]! else e
| .forallE n d b bi => if offset >= e.looseBVarRange then e else .forallE n (inst d offset) (inst b (offset + 1)) bi
| .lam n d b bi => if offset >= e.looseBVarRange then e else .lam n (inst d offset) (inst b (offset + 1)) bi
| .letE n t v b nd => if offset >= e.looseBVarRange then e else .letE n (inst t offset) (inst v offset) (inst b (offset + 1)) nd
go (code : Code) : SimpM Unit := do
match code with
| .let decl k =>
if decl.value.isApp then
if let some funDecl ← findFunDecl? decl.value.getAppFn then
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add funDecl.fvarId }
go k
| .fun decl k =>
if mustInline then
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline decl.fvarId }
go decl.value; go k
| .jp decl k => go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .return .. | .jmp .. | .unreach .. => return ()
go (e : Expr) (ctx : Std.PArray (Option Name)) : StateM Internalize.State Expr := do
let instantiate (e : Expr) := if args.size == 0 then e else inst e ctx.size
let updtBVar (idx : Nat) :=
let offset := ctx.size
if idx >= offset then
args[args.size - (idx - offset) - 1]!
else
.bvar idx
match e with
| .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => return e
| .mdata k b => return .mdata k (← go b ctx)
| .proj s i b => return .proj s i (← go b ctx)
| .app f a =>
let f ← go f ctx
let a ← go a ctx
match f with
| .fvar .. =>
match findLambdaCore? lctx f with
| some localDecl => updateFunInfo localDecl.userName mustInline
| _ => pure ()
| .bvar idx =>
match ctx[ctx.size - idx - 1]! with
| some binderName => updateFunInfo binderName mustInline
| none => pure ()
| _ => pure ()
return .app f a
| .bvar idx => return updtBVar idx
| .forallE .. => return instantiate e
| .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi
| .letE binderName type value body nonDep =>
let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap })
let binderName' :=
if mustInline && value.isLambda then
.num `_mustInline idx
else match binderName with
| .num p _ => .num p idx
| _ => .num binderName idx
let type := instantiate type
let value ← go value ctx
let ctxVal := match value with
| .lam .. => some binderName'
-- The next two cases simulate findLambdaCore? for `ctx`
| .fvar .. => match findLambdaCore? lctx value with
| some localDecl => some localDecl.userName
| _ => none
| .bvar idx => if idx < ctx.size then ctx[ctx.size - idx - 1]! else none
| _ => none
return .letE binderName' type value (← go body (ctx.push ctxVal)) nonDep
/--
This function performs the following operations in the given expression in a single pass.
- Ensure binder names for let-declarations are unique.
- Update local function information. That is, it updates the map `localInfoMap`.
- Apply `e.instantiateRev args`.
We use it to "internalize" expressions at startup and when performing inlining.
-/
def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do
let lctx ← getLCtx
let nextIdx := (← getThe CompilerM.State).nextIdx
let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} })
let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap }
modifyThe CompilerM.State fun s => { s with nextIdx }
modify fun s => { s with localInfoMap }
return e
/--
This function performs the following operations in the given expression in a single pass.
- Ensure binder names for let-declarations are unique.
- Update local function information. That is, it updates the map `localInfoMap`.
-/
def internalize (e : Expr) (mustInline := false) : SimpM Expr := do
instantiateRevInternalize e #[] mustInline
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
match (← get).funDeclInfoMap.map.find? fvarId with
| some .once | some .mustInline => return true
| _ => return false
def markSimplified : SimpM Unit :=
modify fun s => { s with simplified := true }
@ -274,12 +188,54 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
markSimplified
return mkAppN f e.getAppArgs
def isOnceOrMustInline (binderName : Name) : SimpM Bool := do
if binderName.getPrefix == `_mustInline then
return true
else match (← get).localInfoMap.map.find? binderName with
| some .once => return true
| _ => return false
end Simp
builtin_initialize
registerTraceClass `Compiler.simp.inline
registerTraceClass `Compiler.simp.inline.info
registerTraceClass `Compiler.simp.stat
registerTraceClass `Compiler.simp.step
registerTraceClass `Compiler.simp.step.new
registerTraceClass `Compiler.simp.projInst
end Lean.Compiler.LCNF
#exit -- TODO: port rest of file
namespace Lean.Compiler
namespace Simp
/-
Ensure binder names are unique, and update local function information.
If `mustInline = true`, then local functions in `e` are marked with binders of the
form `_mustInline.<idx>`.
Remark: we used to store the `mustInline` information in the map `localInfoMap`,
using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect
because there is no guarantee that we will be able to inline all occurrences of the
function in the current `simp` step. Since, we recompute `localInfoMap` from scratch
at the beginning of each compiler pass, the information was being lost.
-/
/--
This function performs the following operations in the given expression in a single pass.
- Ensure binder names for let-declarations are unique.
- Update local function information. That is, it updates the map `localInfoMap`.
- Apply `e.instantiateRev args`.
We use it to "internalize" expressions at startup and when performing inlining.
-/
def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do
let lctx ← getLCtx
let nextIdx := (← getThe CompilerM.State).nextIdx
let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} })
let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap }
modifyThe CompilerM.State fun s => { s with nextIdx }
modify fun s => { s with localInfoMap }
return e
def isSmallValue (value : Expr) : SimpM Bool := do
lcnfSizeLe value (← read).config.smallThreshold
@ -748,12 +704,4 @@ partial def Decl.simp (decl : Decl) : CoreM Decl := do
else
return decl
builtin_initialize
registerTraceClass `Compiler.simp.inline
registerTraceClass `Compiler.simp.inline.info
registerTraceClass `Compiler.simp.stat
registerTraceClass `Compiler.simp.step
registerTraceClass `Compiler.simp.step.new
registerTraceClass `Compiler.simp.projInst
end Lean.Compiler