feat: start simp for new LCNF format
This commit is contained in:
parent
9446ae3056
commit
cd0dd4cc2f
2 changed files with 120 additions and 169 deletions
|
|
@ -35,8 +35,11 @@ def getLocalDecl (fvarId : FVarId) : CompilerM LocalDecl := do
|
|||
let some decl := (← get).lctx.localDecls.find? fvarId | throwError "unknown free variable {fvarId.name}"
|
||||
return decl
|
||||
|
||||
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
|
||||
return (← get).lctx.funDecls.find? fvarId
|
||||
|
||||
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
|
||||
let some decl := (← get).lctx.funDecls.find? fvarId | throwError "unknown local function {fvarId.name}"
|
||||
let some decl ← findFunDecl? fvarId | throwError "unknown local function {fvarId.name}"
|
||||
return decl
|
||||
|
||||
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do
|
||||
|
|
|
|||
|
|
@ -3,36 +3,14 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
|||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
#exit -- TODO: port to new LCNF
|
||||
import Lean.Compiler.CompilerM
|
||||
import Lean.Compiler.Decl
|
||||
import Lean.Compiler.Stage1
|
||||
import Lean.Util.Recognizers
|
||||
import Lean.Compiler.InlineAttrs
|
||||
import Lean.Compiler.LCNF.CompilerM
|
||||
import Lean.Compiler.LCNF.Stage1
|
||||
|
||||
namespace Lean.Compiler
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
partial def findLambdaCore? (lctx : LocalContext) (e : Expr) : Option LocalDecl :=
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
if let some d@(.ldecl (value := v) ..) := lctx.find? fvarId then
|
||||
if v.isLambda then some d else findLambdaCore? lctx v
|
||||
else
|
||||
none
|
||||
| .mdata _ e => findLambdaCore? lctx e
|
||||
| _ => none
|
||||
|
||||
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) :=
|
||||
return findLambdaCore? (← getLCtx) e
|
||||
|
||||
partial def findExpr (e : Expr) (skipMData := true): CompilerM Expr := do
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
let some (.ldecl (value := v) ..) ← findDecl? fvarId | return e
|
||||
findExpr v
|
||||
| .mdata _ e' => if skipMData then findExpr e' else return e
|
||||
| _ => return e
|
||||
|
||||
/--
|
||||
Local function usage information used to decide whether it should be inlined or not.
|
||||
The information is an approximation, but it is on the "safe" side. That is, if we tagged
|
||||
|
|
@ -41,7 +19,7 @@ a function with `.once`, then it is applied only once. A local function may be m
|
|||
a big problem in practice because we run the simplifier multiple times, and this information
|
||||
is recomputed from scratch at the beginning of each simplification step.
|
||||
-/
|
||||
inductive LocalFunInfo where
|
||||
inductive FunDeclInfo where
|
||||
| /--
|
||||
Local function is applied once, and must be inlined.
|
||||
-/
|
||||
|
|
@ -51,40 +29,67 @@ inductive LocalFunInfo where
|
|||
if it is small.
|
||||
-/
|
||||
many
|
||||
| /--
|
||||
Function must be inlined.
|
||||
-/
|
||||
mustInline
|
||||
deriving Repr, Inhabited
|
||||
|
||||
/--
|
||||
Local function declaration statistics.
|
||||
|
||||
Remark: we use the `userName` as the key.
|
||||
-/
|
||||
structure LocalFunInfoMap where
|
||||
structure FunDeclInfoMap where
|
||||
/--
|
||||
Mapping from local function name to inlining information.
|
||||
-/
|
||||
map : Std.HashMap Name LocalFunInfo := {}
|
||||
map : Std.HashMap FVarId FunDeclInfo := {}
|
||||
deriving Inhabited
|
||||
|
||||
def LocalFunInfoMap.format (s : LocalFunInfoMap) : Format := Id.run do
|
||||
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
|
||||
let mut result := Format.nil
|
||||
for (k, n) in s.map.toList do
|
||||
result := result ++ "\n" ++ f!"{k} ↦ {repr n}"
|
||||
for (fvarId, info) in s.map.toList do
|
||||
let localDecl ← getLocalDecl fvarId
|
||||
result := result ++ "\n" ++ f!"{localDecl.userName} ↦ {repr info}"
|
||||
return result
|
||||
|
||||
instance : ToFormat LocalFunInfoMap where
|
||||
format := LocalFunInfoMap.format
|
||||
|
||||
/--
|
||||
Add new occurrence for the local function with binder name `key`.
|
||||
-/
|
||||
def LocalFunInfoMap.add (s : LocalFunInfoMap) (key : Name) : LocalFunInfoMap :=
|
||||
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } =>
|
||||
match map.find? key with
|
||||
| some .once => { map := map.insert key .many }
|
||||
| none => { map := map.insert key .once }
|
||||
match map.find? fvarId with
|
||||
| some .once => { map := map.insert fvarId .many }
|
||||
| none => { map := map.insert fvarId .once }
|
||||
| _ => { map }
|
||||
|
||||
/--
|
||||
Add new occurrence for the local function with binder name `key`.
|
||||
-/
|
||||
def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } => { map := map.insert fvarId .mustInline }
|
||||
|
||||
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
if let some decl ← LCNF.findFunDecl? fvarId then
|
||||
return some decl
|
||||
else if let .ldecl (value := v) .. ← getLocalDecl fvarId then
|
||||
findFunDecl? v
|
||||
else
|
||||
return none
|
||||
| .mdata _ e => findFunDecl? e
|
||||
| _ => return none
|
||||
|
||||
partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
let .ldecl (value := v) .. ← getLocalDecl fvarId | return e
|
||||
findExpr v
|
||||
| .mdata _ e' => if skipMData then findExpr e' else return e
|
||||
| _ => return e
|
||||
|
||||
structure Config where
|
||||
smallThreshold : Nat := 1
|
||||
|
||||
|
|
@ -92,10 +97,8 @@ structure Context where
|
|||
config : Config := {}
|
||||
|
||||
structure State where
|
||||
/--
|
||||
(Approximate) information for deciding whether to inline local function declarations.
|
||||
-/
|
||||
localInfoMap : LocalFunInfoMap := {}
|
||||
subst : FVarSubst := {}
|
||||
funDeclInfoMap : FunDeclInfoMap := {}
|
||||
/--
|
||||
`true` if some simplification was performed in the current simplification pass.
|
||||
-/
|
||||
|
|
@ -125,120 +128,31 @@ structure State where
|
|||
This is a performance counter.
|
||||
-/
|
||||
inlineLocal : Nat := 0
|
||||
deriving Inhabited
|
||||
|
||||
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
/-
|
||||
Ensure binder names are unique, and update local function information.
|
||||
If `mustInline = true`, then local functions in `e` are marked with binders of the
|
||||
form `_mustInline.<idx>`.
|
||||
Remark: we used to store the `mustInline` information in the map `localInfoMap`,
|
||||
using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect
|
||||
because there is no guarantee that we will be able to inline all occurrences of the
|
||||
function in the current `simp` step. Since, we recompute `localInfoMap` from scratch
|
||||
at the beginning of each compiler pass, the information was being lost.
|
||||
-/
|
||||
|
||||
structure Internalize.State where
|
||||
nextIdx : Nat
|
||||
localInfoMap : LocalFunInfoMap
|
||||
|
||||
private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit :=
|
||||
unless mustInline do
|
||||
modify fun s => { s with localInfoMap := s.localInfoMap.add key }
|
||||
|
||||
/--
|
||||
`instantiateRevInternalize` implementation.
|
||||
-/
|
||||
private def instantiateRevInternalizeCore (lctx : LocalContext) (e : Expr) (args : Array Expr) (mustInline : Bool) : StateM Internalize.State Expr :=
|
||||
go e {}
|
||||
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :=
|
||||
go code
|
||||
where
|
||||
/-- Auxiliary functions for instantiating `args` in types. -/
|
||||
inst (e : Expr) (offset : Nat) : Expr :=
|
||||
match e with
|
||||
| .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => e
|
||||
| .mdata k b => .mdata k (inst b offset)
|
||||
| .proj s i b => .proj s i (inst b offset)
|
||||
| .app f a => if offset >= e.looseBVarRange then e else .app (inst f offset) (inst a offset)
|
||||
| .bvar idx => if idx >= offset then args[args.size - (idx - offset) - 1]! else e
|
||||
| .forallE n d b bi => if offset >= e.looseBVarRange then e else .forallE n (inst d offset) (inst b (offset + 1)) bi
|
||||
| .lam n d b bi => if offset >= e.looseBVarRange then e else .lam n (inst d offset) (inst b (offset + 1)) bi
|
||||
| .letE n t v b nd => if offset >= e.looseBVarRange then e else .letE n (inst t offset) (inst v offset) (inst b (offset + 1)) nd
|
||||
go (code : Code) : SimpM Unit := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
if decl.value.isApp then
|
||||
if let some funDecl ← findFunDecl? decl.value.getAppFn then
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add funDecl.fvarId }
|
||||
go k
|
||||
| .fun decl k =>
|
||||
if mustInline then
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline decl.fvarId }
|
||||
go decl.value; go k
|
||||
| .jp decl k => go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .return .. | .jmp .. | .unreach .. => return ()
|
||||
|
||||
go (e : Expr) (ctx : Std.PArray (Option Name)) : StateM Internalize.State Expr := do
|
||||
let instantiate (e : Expr) := if args.size == 0 then e else inst e ctx.size
|
||||
let updtBVar (idx : Nat) :=
|
||||
let offset := ctx.size
|
||||
if idx >= offset then
|
||||
args[args.size - (idx - offset) - 1]!
|
||||
else
|
||||
.bvar idx
|
||||
match e with
|
||||
| .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => return e
|
||||
| .mdata k b => return .mdata k (← go b ctx)
|
||||
| .proj s i b => return .proj s i (← go b ctx)
|
||||
| .app f a =>
|
||||
let f ← go f ctx
|
||||
let a ← go a ctx
|
||||
match f with
|
||||
| .fvar .. =>
|
||||
match findLambdaCore? lctx f with
|
||||
| some localDecl => updateFunInfo localDecl.userName mustInline
|
||||
| _ => pure ()
|
||||
| .bvar idx =>
|
||||
match ctx[ctx.size - idx - 1]! with
|
||||
| some binderName => updateFunInfo binderName mustInline
|
||||
| none => pure ()
|
||||
| _ => pure ()
|
||||
return .app f a
|
||||
| .bvar idx => return updtBVar idx
|
||||
| .forallE .. => return instantiate e
|
||||
| .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi
|
||||
| .letE binderName type value body nonDep =>
|
||||
let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap })
|
||||
let binderName' :=
|
||||
if mustInline && value.isLambda then
|
||||
.num `_mustInline idx
|
||||
else match binderName with
|
||||
| .num p _ => .num p idx
|
||||
| _ => .num binderName idx
|
||||
let type := instantiate type
|
||||
let value ← go value ctx
|
||||
let ctxVal := match value with
|
||||
| .lam .. => some binderName'
|
||||
-- The next two cases simulate findLambdaCore? for `ctx`
|
||||
| .fvar .. => match findLambdaCore? lctx value with
|
||||
| some localDecl => some localDecl.userName
|
||||
| _ => none
|
||||
| .bvar idx => if idx < ctx.size then ctx[ctx.size - idx - 1]! else none
|
||||
| _ => none
|
||||
return .letE binderName' type value (← go body (ctx.push ctxVal)) nonDep
|
||||
|
||||
/--
|
||||
This function performs the following operations in the given expression in a single pass.
|
||||
- Ensure binder names for let-declarations are unique.
|
||||
- Update local function information. That is, it updates the map `localInfoMap`.
|
||||
- Apply `e.instantiateRev args`.
|
||||
|
||||
We use it to "internalize" expressions at startup and when performing inlining.
|
||||
-/
|
||||
def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do
|
||||
let lctx ← getLCtx
|
||||
let nextIdx := (← getThe CompilerM.State).nextIdx
|
||||
let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} })
|
||||
let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap }
|
||||
modifyThe CompilerM.State fun s => { s with nextIdx }
|
||||
modify fun s => { s with localInfoMap }
|
||||
return e
|
||||
|
||||
/--
|
||||
This function performs the following operations in the given expression in a single pass.
|
||||
- Ensure binder names for let-declarations are unique.
|
||||
- Update local function information. That is, it updates the map `localInfoMap`.
|
||||
-/
|
||||
def internalize (e : Expr) (mustInline := false) : SimpM Expr := do
|
||||
instantiateRevInternalize e #[] mustInline
|
||||
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
|
||||
match (← get).funDeclInfoMap.map.find? fvarId with
|
||||
| some .once | some .mustInline => return true
|
||||
| _ => return false
|
||||
|
||||
def markSimplified : SimpM Unit :=
|
||||
modify fun s => { s with simplified := true }
|
||||
|
|
@ -274,12 +188,54 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
|
|||
markSimplified
|
||||
return mkAppN f e.getAppArgs
|
||||
|
||||
def isOnceOrMustInline (binderName : Name) : SimpM Bool := do
|
||||
if binderName.getPrefix == `_mustInline then
|
||||
return true
|
||||
else match (← get).localInfoMap.map.find? binderName with
|
||||
| some .once => return true
|
||||
| _ => return false
|
||||
end Simp
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.simp.inline
|
||||
registerTraceClass `Compiler.simp.inline.info
|
||||
registerTraceClass `Compiler.simp.stat
|
||||
registerTraceClass `Compiler.simp.step
|
||||
registerTraceClass `Compiler.simp.step.new
|
||||
registerTraceClass `Compiler.simp.projInst
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
#exit -- TODO: port rest of file
|
||||
|
||||
|
||||
namespace Lean.Compiler
|
||||
namespace Simp
|
||||
|
||||
|
||||
/-
|
||||
Ensure binder names are unique, and update local function information.
|
||||
If `mustInline = true`, then local functions in `e` are marked with binders of the
|
||||
form `_mustInline.<idx>`.
|
||||
Remark: we used to store the `mustInline` information in the map `localInfoMap`,
|
||||
using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect
|
||||
because there is no guarantee that we will be able to inline all occurrences of the
|
||||
function in the current `simp` step. Since, we recompute `localInfoMap` from scratch
|
||||
at the beginning of each compiler pass, the information was being lost.
|
||||
-/
|
||||
|
||||
|
||||
/--
|
||||
This function performs the following operations in the given expression in a single pass.
|
||||
- Ensure binder names for let-declarations are unique.
|
||||
- Update local function information. That is, it updates the map `localInfoMap`.
|
||||
- Apply `e.instantiateRev args`.
|
||||
|
||||
We use it to "internalize" expressions at startup and when performing inlining.
|
||||
-/
|
||||
def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do
|
||||
let lctx ← getLCtx
|
||||
let nextIdx := (← getThe CompilerM.State).nextIdx
|
||||
let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} })
|
||||
let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap }
|
||||
modifyThe CompilerM.State fun s => { s with nextIdx }
|
||||
modify fun s => { s with localInfoMap }
|
||||
return e
|
||||
|
||||
|
||||
def isSmallValue (value : Expr) : SimpM Bool := do
|
||||
lcnfSizeLe value (← read).config.smallThreshold
|
||||
|
|
@ -748,12 +704,4 @@ partial def Decl.simp (decl : Decl) : CoreM Decl := do
|
|||
else
|
||||
return decl
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.simp.inline
|
||||
registerTraceClass `Compiler.simp.inline.info
|
||||
registerTraceClass `Compiler.simp.stat
|
||||
registerTraceClass `Compiler.simp.step
|
||||
registerTraceClass `Compiler.simp.step.new
|
||||
registerTraceClass `Compiler.simp.projInst
|
||||
|
||||
end Lean.Compiler
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue