perf: combine instantiateRev and internalize in a single traversal
This commit is contained in:
parent
b2a99e1b68
commit
fa7769260a
1 changed files with 110 additions and 52 deletions
|
|
@ -11,13 +11,18 @@ import Lean.Compiler.InlineAttrs
|
|||
namespace Lean.Compiler
|
||||
namespace Simp
|
||||
|
||||
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) := do
|
||||
partial def findLambdaCore? (lctx : LocalContext) (e : Expr) : Option LocalDecl :=
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
let some d@(.ldecl (value := v) ..) ← findDecl? fvarId | return none
|
||||
if v.isLambda then return some d else findLambda? v
|
||||
| .mdata _ e => findLambda? e
|
||||
| _ => return none
|
||||
if let some d@(.ldecl (value := v) ..) := lctx.find? fvarId then
|
||||
if v.isLambda then some d else findLambdaCore? lctx v
|
||||
else
|
||||
none
|
||||
| .mdata _ e => findLambdaCore? lctx e
|
||||
| _ => none
|
||||
|
||||
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) :=
|
||||
return findLambdaCore? (← getLCtx) e
|
||||
|
||||
partial def findExpr (e : Expr) (skipMData := true): CompilerM Expr := do
|
||||
match e with
|
||||
|
|
@ -116,59 +121,109 @@ structure State where
|
|||
|
||||
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
/--
|
||||
/-
|
||||
Ensure binder names are unique, and update local function information.
|
||||
If `mustInline = true`, then local functions in `e` are marked as `.mustInline`.
|
||||
-/
|
||||
partial def internalize (e : Expr) (mustInline := false): SimpM Expr := do
|
||||
visitLambda e
|
||||
|
||||
structure Internalize.State where
|
||||
nextIdx : Nat
|
||||
localInfoMap : LocalFunInfoMap
|
||||
|
||||
private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit :=
|
||||
if mustInline then
|
||||
modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key }
|
||||
else
|
||||
modify fun s => { s with localInfoMap := s.localInfoMap.add key }
|
||||
|
||||
/--
|
||||
`instantiateRevInternalize` implementation.
|
||||
-/
|
||||
private def instantiateRevInternalizeCore (lctx : LocalContext) (e : Expr) (args : Array Expr) (mustInline : Bool) : StateM Internalize.State Expr :=
|
||||
go e {}
|
||||
where
|
||||
visitLambda (e : Expr) : SimpM Expr := do
|
||||
withNewScope do
|
||||
let (as, e) ← Compiler.visitLambdaCore e
|
||||
let e ← mkLetUsingScope (← visitLet e as)
|
||||
mkLambda as e
|
||||
|
||||
visitCases (casesInfo : CasesInfo) (cases : Expr) : SimpM Expr := do
|
||||
let mut args := cases.getAppArgs
|
||||
for i in casesInfo.altsRange do
|
||||
args ← args.modifyM i visitLambda
|
||||
return mkAppN cases.getAppFn args
|
||||
|
||||
visitValue (e : Expr) : SimpM Unit := do
|
||||
if e.isApp then
|
||||
match (← findLambda? e.getAppFn) with
|
||||
| some localDecl =>
|
||||
if localDecl.value.isLambda then
|
||||
let key := localDecl.userName
|
||||
if mustInline then
|
||||
modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key }
|
||||
else
|
||||
modify fun s => { s with localInfoMap := s.localInfoMap.add key }
|
||||
| _ => pure ()
|
||||
|
||||
visitLet (e : Expr) (xs : Array Expr) : SimpM Expr := do
|
||||
/-- Auxiliary functions for instantiating `args` in types. -/
|
||||
inst (e : Expr) (offset : Nat) : Expr :=
|
||||
match e with
|
||||
| .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => e
|
||||
| .mdata k b => .mdata k (inst b offset)
|
||||
| .proj s i b => .proj s i (inst b offset)
|
||||
| .app f a => if offset >= e.looseBVarRange then e else .app (inst f offset) (inst a offset)
|
||||
| .bvar idx => if idx >= offset then args[args.size - (idx - offset) - 1]! else e
|
||||
| .forallE n d b bi => if offset >= e.looseBVarRange then e else .forallE n (inst d offset) (inst b (offset + 1)) bi
|
||||
| .lam n d b bi => if offset >= e.looseBVarRange then e else .lam n (inst d offset) (inst b (offset + 1)) bi
|
||||
| .letE n t v b nd => if offset >= e.looseBVarRange then e else .letE n (inst t offset) (inst v offset) (inst b (offset + 1)) nd
|
||||
|
||||
go (e : Expr) (ctx : Std.PArray (Option Name)) : StateM Internalize.State Expr := do
|
||||
let instantiate (e : Expr) := if args.size == 0 then e else inst e ctx.size
|
||||
let updtBVar (idx : Nat) :=
|
||||
let offset := ctx.size
|
||||
if idx >= offset then
|
||||
args[args.size - (idx - offset) - 1]!
|
||||
else
|
||||
.bvar idx
|
||||
match e with
|
||||
| .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => return e
|
||||
| .mdata k b => return .mdata k (← go b ctx)
|
||||
| .proj s i b => return .proj s i (← go b ctx)
|
||||
| .app f a =>
|
||||
let f ← go f ctx
|
||||
let a ← go a ctx
|
||||
match f with
|
||||
| .fvar .. =>
|
||||
match findLambdaCore? lctx f with
|
||||
| some localDecl => updateFunInfo localDecl.userName mustInline
|
||||
| _ => pure ()
|
||||
| .bvar idx =>
|
||||
match ctx[ctx.size - idx - 1]! with
|
||||
| some binderName => updateFunInfo binderName mustInline
|
||||
| none => pure ()
|
||||
| _ => pure ()
|
||||
return .app f a
|
||||
| .bvar idx => return updtBVar idx
|
||||
| .forallE .. => return instantiate e
|
||||
| .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi
|
||||
| .letE binderName type value body nonDep =>
|
||||
let idx ← mkFreshLetVarIdx
|
||||
let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap })
|
||||
let binderName' := match binderName with
|
||||
| .num p _ => .num p idx
|
||||
| _ => .num binderName idx
|
||||
let type := type.instantiateRev xs
|
||||
let mut value := value.instantiateRev xs
|
||||
if value.isLambda then
|
||||
value ← visitLambda value
|
||||
else
|
||||
visitValue value
|
||||
let x ← mkLetDecl binderName' type value nonDep
|
||||
visitLet body (xs.push x)
|
||||
| _ =>
|
||||
let e := e.instantiateRev xs
|
||||
if let some casesInfo ← isCasesApp? e then
|
||||
visitCases casesInfo e
|
||||
else
|
||||
visitValue e
|
||||
return e
|
||||
let type := instantiate type
|
||||
let value ← go value ctx
|
||||
let ctxVal := match value with
|
||||
| .lam .. => some binderName'
|
||||
-- The next two cases simulate findLambdaCore? for `ctx`
|
||||
| .fvar .. => match findLambdaCore? lctx value with
|
||||
| some localDecl => some localDecl.userName
|
||||
| _ => none
|
||||
| .bvar idx => if idx < ctx.size then ctx[ctx.size - idx - 1]! else none
|
||||
| _ => none
|
||||
return .letE binderName' type value (← go body (ctx.push ctxVal)) nonDep
|
||||
|
||||
/--
|
||||
This function performs the following operations in the given expression in a single pass.
|
||||
- Ensure binder names for let-declarations are unique.
|
||||
- Update local function information. That is, it updates the map `localInfoMap`.
|
||||
- Apply `e.instantiateRev args`.
|
||||
|
||||
We use it to "internalize" expressions at startup and when performing inlining.
|
||||
-/
|
||||
def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do
|
||||
let lctx ← getLCtx
|
||||
let nextIdx := (← getThe CompilerM.State).nextIdx
|
||||
let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} })
|
||||
let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap }
|
||||
modifyThe CompilerM.State fun s => { s with nextIdx }
|
||||
modify fun s => { s with localInfoMap }
|
||||
return e
|
||||
|
||||
/--
|
||||
This function performs the following operations in the given expression in a single pass.
|
||||
- Ensure binder names for let-declarations are unique.
|
||||
- Update local function information. That is, it updates the map `localInfoMap`.
|
||||
-/
|
||||
def internalize (e : Expr) (mustInline := false) : SimpM Expr := do
|
||||
instantiateRevInternalize e #[] mustInline
|
||||
|
||||
def markSimplified : SimpM Unit :=
|
||||
modify fun s => { s with simplified := true }
|
||||
|
|
@ -378,9 +433,12 @@ where
|
|||
visit value
|
||||
|
||||
def betaReduce (e : Expr) (args : Array Expr) : SimpM Expr := do
|
||||
-- TODO: add necessary casts
|
||||
let result ← internalize (e.beta args)
|
||||
-- trace[Meta.debug] "inline:\n{result}"
|
||||
-- TODO: add necessary casts to `args`
|
||||
let rec getLambdaBody : Expr → Expr
|
||||
| .lam _ _ b _ => getLambdaBody b
|
||||
| b => b
|
||||
let result ← instantiateRevInternalize (getLambdaBody e) args
|
||||
trace[Meta.debug] "inline:\n{result}"
|
||||
return result
|
||||
|
||||
/--
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue