perf: combine instantiateRev and internalize in a single traversal

This commit is contained in:
Leonardo de Moura 2022-08-20 19:57:40 -07:00
parent b2a99e1b68
commit fa7769260a

View file

@ -11,13 +11,18 @@ import Lean.Compiler.InlineAttrs
namespace Lean.Compiler
namespace Simp
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) := do
partial def findLambdaCore? (lctx : LocalContext) (e : Expr) : Option LocalDecl :=
match e with
| .fvar fvarId =>
let some d@(.ldecl (value := v) ..) ← findDecl? fvarId | return none
if v.isLambda then return some d else findLambda? v
| .mdata _ e => findLambda? e
| _ => return none
if let some d@(.ldecl (value := v) ..) := lctx.find? fvarId then
if v.isLambda then some d else findLambdaCore? lctx v
else
none
| .mdata _ e => findLambdaCore? lctx e
| _ => none
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) :=
return findLambdaCore? (← getLCtx) e
partial def findExpr (e : Expr) (skipMData := true): CompilerM Expr := do
match e with
@ -116,59 +121,109 @@ structure State where
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
/--
/-
Ensure binder names are unique, and update local function information.
If `mustInline = true`, then local functions in `e` are marked as `.mustInline`.
-/
partial def internalize (e : Expr) (mustInline := false): SimpM Expr := do
visitLambda e
structure Internalize.State where
nextIdx : Nat
localInfoMap : LocalFunInfoMap
private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit :=
if mustInline then
modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key }
else
modify fun s => { s with localInfoMap := s.localInfoMap.add key }
/--
`instantiateRevInternalize` implementation.
-/
private def instantiateRevInternalizeCore (lctx : LocalContext) (e : Expr) (args : Array Expr) (mustInline : Bool) : StateM Internalize.State Expr :=
go e {}
where
visitLambda (e : Expr) : SimpM Expr := do
withNewScope do
let (as, e) ← Compiler.visitLambdaCore e
let e ← mkLetUsingScope (← visitLet e as)
mkLambda as e
visitCases (casesInfo : CasesInfo) (cases : Expr) : SimpM Expr := do
let mut args := cases.getAppArgs
for i in casesInfo.altsRange do
args ← args.modifyM i visitLambda
return mkAppN cases.getAppFn args
visitValue (e : Expr) : SimpM Unit := do
if e.isApp then
match (← findLambda? e.getAppFn) with
| some localDecl =>
if localDecl.value.isLambda then
let key := localDecl.userName
if mustInline then
modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key }
else
modify fun s => { s with localInfoMap := s.localInfoMap.add key }
| _ => pure ()
visitLet (e : Expr) (xs : Array Expr) : SimpM Expr := do
/-- Auxiliary functions for instantiating `args` in types. -/
inst (e : Expr) (offset : Nat) : Expr :=
match e with
| .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => e
| .mdata k b => .mdata k (inst b offset)
| .proj s i b => .proj s i (inst b offset)
| .app f a => if offset >= e.looseBVarRange then e else .app (inst f offset) (inst a offset)
| .bvar idx => if idx >= offset then args[args.size - (idx - offset) - 1]! else e
| .forallE n d b bi => if offset >= e.looseBVarRange then e else .forallE n (inst d offset) (inst b (offset + 1)) bi
| .lam n d b bi => if offset >= e.looseBVarRange then e else .lam n (inst d offset) (inst b (offset + 1)) bi
| .letE n t v b nd => if offset >= e.looseBVarRange then e else .letE n (inst t offset) (inst v offset) (inst b (offset + 1)) nd
go (e : Expr) (ctx : Std.PArray (Option Name)) : StateM Internalize.State Expr := do
let instantiate (e : Expr) := if args.size == 0 then e else inst e ctx.size
let updtBVar (idx : Nat) :=
let offset := ctx.size
if idx >= offset then
args[args.size - (idx - offset) - 1]!
else
.bvar idx
match e with
| .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => return e
| .mdata k b => return .mdata k (← go b ctx)
| .proj s i b => return .proj s i (← go b ctx)
| .app f a =>
let f ← go f ctx
let a ← go a ctx
match f with
| .fvar .. =>
match findLambdaCore? lctx f with
| some localDecl => updateFunInfo localDecl.userName mustInline
| _ => pure ()
| .bvar idx =>
match ctx[ctx.size - idx - 1]! with
| some binderName => updateFunInfo binderName mustInline
| none => pure ()
| _ => pure ()
return .app f a
| .bvar idx => return updtBVar idx
| .forallE .. => return instantiate e
| .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi
| .letE binderName type value body nonDep =>
let idx ← mkFreshLetVarIdx
let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap })
let binderName' := match binderName with
| .num p _ => .num p idx
| _ => .num binderName idx
let type := type.instantiateRev xs
let mut value := value.instantiateRev xs
if value.isLambda then
value ← visitLambda value
else
visitValue value
let x ← mkLetDecl binderName' type value nonDep
visitLet body (xs.push x)
| _ =>
let e := e.instantiateRev xs
if let some casesInfo ← isCasesApp? e then
visitCases casesInfo e
else
visitValue e
return e
let type := instantiate type
let value ← go value ctx
let ctxVal := match value with
| .lam .. => some binderName'
-- The next two cases simulate findLambdaCore? for `ctx`
| .fvar .. => match findLambdaCore? lctx value with
| some localDecl => some localDecl.userName
| _ => none
| .bvar idx => if idx < ctx.size then ctx[ctx.size - idx - 1]! else none
| _ => none
return .letE binderName' type value (← go body (ctx.push ctxVal)) nonDep
/--
This function performs the following operations in the given expression in a single pass.
- Ensure binder names for let-declarations are unique.
- Update local function information. That is, it updates the map `localInfoMap`.
- Apply `e.instantiateRev args`.
We use it to "internalize" expressions at startup and when performing inlining.
-/
def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do
let lctx ← getLCtx
let nextIdx := (← getThe CompilerM.State).nextIdx
let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} })
let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap }
modifyThe CompilerM.State fun s => { s with nextIdx }
modify fun s => { s with localInfoMap }
return e
/--
This function performs the following operations in the given expression in a single pass.
- Ensure binder names for let-declarations are unique.
- Update local function information. That is, it updates the map `localInfoMap`.
-/
def internalize (e : Expr) (mustInline := false) : SimpM Expr := do
instantiateRevInternalize e #[] mustInline
def markSimplified : SimpM Unit :=
modify fun s => { s with simplified := true }
@ -378,9 +433,12 @@ where
visit value
def betaReduce (e : Expr) (args : Array Expr) : SimpM Expr := do
-- TODO: add necessary casts
let result ← internalize (e.beta args)
-- trace[Meta.debug] "inline:\n{result}"
-- TODO: add necessary casts to `args`
let rec getLambdaBody : Expr → Expr
| .lam _ _ b _ => getLambdaBody b
| b => b
let result ← instantiateRevInternalize (getLambdaBody e) args
trace[Meta.debug] "inline:\n{result}"
return result
/--