diff --git a/src/Lean/Compiler/Simp.lean b/src/Lean/Compiler/Simp.lean index 462839ec67..c7e4a76e04 100644 --- a/src/Lean/Compiler/Simp.lean +++ b/src/Lean/Compiler/Simp.lean @@ -11,13 +11,18 @@ import Lean.Compiler.InlineAttrs namespace Lean.Compiler namespace Simp -partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) := do +partial def findLambdaCore? (lctx : LocalContext) (e : Expr) : Option LocalDecl := match e with | .fvar fvarId => - let some d@(.ldecl (value := v) ..) ← findDecl? fvarId | return none - if v.isLambda then return some d else findLambda? v - | .mdata _ e => findLambda? e - | _ => return none + if let some d@(.ldecl (value := v) ..) := lctx.find? fvarId then + if v.isLambda then some d else findLambdaCore? lctx v + else + none + | .mdata _ e => findLambdaCore? lctx e + | _ => none + +partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) := + return findLambdaCore? (← getLCtx) e partial def findExpr (e : Expr) (skipMData := true): CompilerM Expr := do match e with @@ -116,59 +121,109 @@ structure State where abbrev SimpM := ReaderT Context $ StateRefT State CompilerM -/-- +/- Ensure binder names are unique, and update local function information. If `mustInline = true`, then local functions in `e` are marked as `.mustInline`. -/ -partial def internalize (e : Expr) (mustInline := false): SimpM Expr := do - visitLambda e + +structure Internalize.State where + nextIdx : Nat + localInfoMap : LocalFunInfoMap + +private def updateFunInfo (key : Name) (mustInline : Bool) : StateM Internalize.State Unit := + if mustInline then + modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key } + else + modify fun s => { s with localInfoMap := s.localInfoMap.add key } + +/-- +`instantiateRevInternalize` implementation. +-/ +private def instantiateRevInternalizeCore (lctx : LocalContext) (e : Expr) (args : Array Expr) (mustInline : Bool) : StateM Internalize.State Expr := + go e {} where - visitLambda (e : Expr) : SimpM Expr := do - withNewScope do - let (as, e) ← Compiler.visitLambdaCore e - let e ← mkLetUsingScope (← visitLet e as) - mkLambda as e - - visitCases (casesInfo : CasesInfo) (cases : Expr) : SimpM Expr := do - let mut args := cases.getAppArgs - for i in casesInfo.altsRange do - args ← args.modifyM i visitLambda - return mkAppN cases.getAppFn args - - visitValue (e : Expr) : SimpM Unit := do - if e.isApp then - match (← findLambda? e.getAppFn) with - | some localDecl => - if localDecl.value.isLambda then - let key := localDecl.userName - if mustInline then - modify fun s => { s with localInfoMap := s.localInfoMap.addMustInline key } - else - modify fun s => { s with localInfoMap := s.localInfoMap.add key } - | _ => pure () - - visitLet (e : Expr) (xs : Array Expr) : SimpM Expr := do + /-- Auxiliary functions for instantiating `args` in types. -/ + inst (e : Expr) (offset : Nat) : Expr := match e with + | .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => e + | .mdata k b => .mdata k (inst b offset) + | .proj s i b => .proj s i (inst b offset) + | .app f a => if offset >= e.looseBVarRange then e else .app (inst f offset) (inst a offset) + | .bvar idx => if idx >= offset then args[args.size - (idx - offset) - 1]! else e + | .forallE n d b bi => if offset >= e.looseBVarRange then e else .forallE n (inst d offset) (inst b (offset + 1)) bi + | .lam n d b bi => if offset >= e.looseBVarRange then e else .lam n (inst d offset) (inst b (offset + 1)) bi + | .letE n t v b nd => if offset >= e.looseBVarRange then e else .letE n (inst t offset) (inst v offset) (inst b (offset + 1)) nd + + go (e : Expr) (ctx : Std.PArray (Option Name)) : StateM Internalize.State Expr := do + let instantiate (e : Expr) := if args.size == 0 then e else inst e ctx.size + let updtBVar (idx : Nat) := + let offset := ctx.size + if idx >= offset then + args[args.size - (idx - offset) - 1]! + else + .bvar idx + match e with + | .sort .. | .lit .. | .const .. | .mvar .. | .fvar .. => return e + | .mdata k b => return .mdata k (← go b ctx) + | .proj s i b => return .proj s i (← go b ctx) + | .app f a => + let f ← go f ctx + let a ← go a ctx + match f with + | .fvar .. => + match findLambdaCore? lctx f with + | some localDecl => updateFunInfo localDecl.userName mustInline + | _ => pure () + | .bvar idx => + match ctx[ctx.size - idx - 1]! with + | some binderName => updateFunInfo binderName mustInline + | none => pure () + | _ => pure () + return .app f a + | .bvar idx => return updtBVar idx + | .forallE .. => return instantiate e + | .lam n d b bi => return .lam n (instantiate d) (← go b (ctx.push none)) bi | .letE binderName type value body nonDep => - let idx ← mkFreshLetVarIdx + let idx ← modifyGet fun { nextIdx, localInfoMap } => (nextIdx, { nextIdx := nextIdx + 1, localInfoMap }) let binderName' := match binderName with | .num p _ => .num p idx | _ => .num binderName idx - let type := type.instantiateRev xs - let mut value := value.instantiateRev xs - if value.isLambda then - value ← visitLambda value - else - visitValue value - let x ← mkLetDecl binderName' type value nonDep - visitLet body (xs.push x) - | _ => - let e := e.instantiateRev xs - if let some casesInfo ← isCasesApp? e then - visitCases casesInfo e - else - visitValue e - return e + let type := instantiate type + let value ← go value ctx + let ctxVal := match value with + | .lam .. => some binderName' + -- The next two cases simulate findLambdaCore? for `ctx` + | .fvar .. => match findLambdaCore? lctx value with + | some localDecl => some localDecl.userName + | _ => none + | .bvar idx => if idx < ctx.size then ctx[ctx.size - idx - 1]! else none + | _ => none + return .letE binderName' type value (← go body (ctx.push ctxVal)) nonDep + +/-- +This function performs the following operations in the given expression in a single pass. +- Ensure binder names for let-declarations are unique. +- Update local function information. That is, it updates the map `localInfoMap`. +- Apply `e.instantiateRev args`. + +We use it to "internalize" expressions at startup and when performing inlining. +-/ +def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do + let lctx ← getLCtx + let nextIdx := (← getThe CompilerM.State).nextIdx + let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} }) + let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap } + modifyThe CompilerM.State fun s => { s with nextIdx } + modify fun s => { s with localInfoMap } + return e + +/-- +This function performs the following operations in the given expression in a single pass. +- Ensure binder names for let-declarations are unique. +- Update local function information. That is, it updates the map `localInfoMap`. +-/ +def internalize (e : Expr) (mustInline := false) : SimpM Expr := do + instantiateRevInternalize e #[] mustInline def markSimplified : SimpM Unit := modify fun s => { s with simplified := true } @@ -378,9 +433,12 @@ where visit value def betaReduce (e : Expr) (args : Array Expr) : SimpM Expr := do - -- TODO: add necessary casts - let result ← internalize (e.beta args) - -- trace[Meta.debug] "inline:\n{result}" + -- TODO: add necessary casts to `args` + let rec getLambdaBody : Expr → Expr + | .lam _ _ b _ => getLambdaBody b + | b => b + let result ← instantiateRevInternalize (getLambdaBody e) args + trace[Meta.debug] "inline:\n{result}" return result /--