/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.CompilerM
import Lean.Compiler.Decl
import Lean.Compiler.Stage1
import Lean.Compiler.InlineAttrs

/-!
# Compiler simplifier

One pass of this simplifier over a compiler `Decl`:
* drops `let` declarations whose value is just another free variable,
* reduces `casesOn` applications whose major premise resolves to a constructor application, and
* inlines calls to global functions that have an inline attribute (using their stage-1
  declaration, `getStage1Decl?`) and calls to local functions selected by `InlineStats`
  (used exactly once, or of LCNF size 1).

`Decl.simp` repeats the pass until it reports no change.
-/

namespace Lean.Compiler
namespace Simp

/--
Follow `let`-bound free variables (and `mdata` wrappers) starting at `e` until reaching a
local declaration whose value is a lambda, and return that declaration.
Returns `none` if `e` is neither a free variable nor an `mdata` node, if the free variable is
not found or is not a `let`-declaration, or if the chain ends in a non-lambda value.
-/
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) := do
  match e with
  | .fvar fvarId =>
    let some d@(.ldecl (value := v) ..) ← findDecl? fvarId | return none
    if v.isLambda then
      return some d
    else
      -- The value is itself an alias (or `mdata`); keep following it.
      findLambda? v
  | .mdata _ e => findLambda? e
  | _ => return none

/--
Return the expression that `e` ultimately denotes by following `let`-bound free variables
to their values. A free variable that is not found or is not a `let`-declaration is returned
unchanged. If `skipMData` is `true` (the default), `mdata` annotations are looked through;
otherwise an `mdata` node is returned as is.
-/
partial def findExpr (e : Expr) (skipMData := true): CompilerM Expr := do
  match e with
  | .fvar fvarId =>
    let some (.ldecl (value := v) ..) ← findDecl? fvarId | return e
    findExpr v
  | .mdata _ e' => if skipMData then findExpr e' else return e
  | _ => return e

/--
Local function declaration statistics.
Remark: we use the `userName` as the key. Thus, `ensureUniqueLetVarNames` must be used
before collecting statistics.
-/
structure InlineStats where
  /--
  Mapping from local function name to the number of times it is used
  in a declaration.
  -/
  numOccs : Std.HashMap Name Nat := {}
  /--
  Mapping from local function name to their LCNF size.
  -/
  size : Std.HashMap Name Nat := {}

/--
Render the statistics as one `name ↦ numOccs, size` line per function.
Functions with no recorded size are skipped. `collectInlineStats` inserts a name into both
maps at the same time, so this should not happen for statistics it produced.
-/
def InlineStats.format (s : InlineStats) : Format := Id.run do
  let mut result := Format.nil
  for (k, n) in s.numOccs.toList do
    let some size := s.size.find? k | pure ()
    result := result ++ "\n" ++ f!"{k} ↦ {n}, {size}"
    pure ()
  return result

/--
Return `true` if the local function named `k` should be inlined: it occurs exactly once,
or its LCNF size is `1`. Names without statistics are never inlined.
-/
def InlineStats.shouldInline (s : InlineStats) (k : Name) : Bool := Id.run do
  let some numOccs := s.numOccs.find? k | return false
  if numOccs == 1 then return true
  let some sz := s.size.find? k | return false
  return sz == 1

/-- Pretty-print statistics with `InlineStats.format` (used by the `Compiler.simp.inline.stats` trace). -/
instance : ToFormat InlineStats where
  format := InlineStats.format

/--
Collect `InlineStats` for the declaration value `e`, which is opened with `visitLambda`.
For every application whose head resolves (via `findLambda?`) to a `let`-bound local lambda,
that function's occurrence count is incremented. Its LCNF size is recorded the first time it
is seen. Only application heads are counted. Arguments of an application are not inspected,
so a local function that is merely passed as an argument is not counted.
-/
partial def collectInlineStats (e : Expr) : CoreM InlineStats := do
  -- The first `run` supplies the empty initial `InlineStats`; the second presumably supplies
  -- the initial `CompilerM` state.
  let ((_, s), _) ← goLambda e |>.run {} |>.run {}
  return s
where
  -- Enter the binders of lambda `e` in a fresh scope and traverse its body.
  -- (`visitLambda`/`visitLet` here are the `CompilerM` helpers, not the `Simp` ones below.)
  goLambda (e : Expr) : StateRefT InlineStats CompilerM Unit := do
    withNewScope do
      let (_, body) ← visitLambda e
      go body

  -- Inspect one `let`-value: recurse into lambdas; count applications of local lambdas.
  goValue (value : Expr) : StateRefT InlineStats CompilerM Unit := do
    match value with
    | .lam .. => goLambda value
    | .app .. =>
      match (← findLambda? value.getAppFn) with
      | some localDecl =>
        -- NOTE(review): looks like leftover debugging output; consider removing it or moving
        -- it to a `Compiler.simp.*` trace class.
        trace[Meta.debug] "found decl {localDecl.userName}"
        -- Always true: `findLambda?` only returns declarations whose value is a lambda.
        if localDecl.value.isLambda then
          let key := localDecl.userName
          match (← get).numOccs.find? localDecl.userName with
          | some numOccs => modify fun s => { s with numOccs := s.numOccs.insert key (numOccs + 1) }
          | _ =>
            -- First occurrence: record a count of 1 together with the function's LCNF size.
            let sz ← getLCNFSize localDecl.value
            modify fun { numOccs, size } => { numOccs := numOccs.insert key 1, size := size.insert key sz }
      -- The head is not a local lambda.
      | _ => pure ()
    -- Other kinds of values contribute nothing.
    | _ => pure ()

  -- Walk a `let`-block: record statistics for each `let`-value, then handle the terminal
  -- expression. A `casesOn` gets every alternative visited; anything else is treated as a value.
  go (e : Expr) : StateRefT InlineStats CompilerM Unit := do
    match e with
    | .letE .. =>
      withNewScope do
        -- The callback presumably runs on each `let`-value. It only records statistics and
        -- returns the value unchanged.
        let body ← visitLet e fun value => do goValue value; return value
        go body
    | e =>
      if let some casesInfo ← isCasesApp? e then
        let args := e.getAppArgs
        for i in casesInfo.altsRange do
          goLambda args[i]!
      else
        goValue e

/-- Simplifier configuration. -/
structure Config where
  -- NOTE(review): not read anywhere in this file; confirm its meaning at its use sites.
  increaseFactor : Nat := 2

/-- Read-only context of a simplification pass. -/
structure Context where
  /-- Simplifier configuration. -/
  config : Config := {}
  /--
  Statistics for deciding whether to inline local function declarations.
  -/
  stats : InlineStats
  /--
  We only inline local declarations when `localInline` is `true`.
  We set it to `false` when we are inlining a non-local definition that may have
  let-declarations whose names collide with the ones stored in `stats`.
  -/
  localInline : Bool := true

/-- Mutable state of a simplification pass. -/
structure State where
  /-- Set to `true` by `markSimplified` whenever the pass changes the expression. -/
  simplified : Bool := false

/-- The simplifier monad: a read-only `Context` and a mutable `State` on top of `CompilerM`. -/
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM

/-- Record that the current pass made progress, so `Decl.simp` will run another pass. -/
def markSimplified : SimpM Unit :=
  modify fun s => { s with simplified := true }

/--
`true` if local inlining is enabled in the current context and the statistics select
`localDecl`, which is looked up by its `userName`.
-/
def shouldInline (localDecl : LocalDecl) : SimpM Bool :=
  return (← read).localInline && (← read).stats.shouldInline localDecl.userName

/--
If `e` is an application that should be inlined, return `some n`, where `n` is the number of
arguments supplied beyond the callee's arity (`0` for an exact application).
Returns `none` if `e` is under-applied, or if its head is neither
* a global constant with an inline attribute and a stage-1 declaration, nor
* a local lambda (found via `findLambda?`) selected by `shouldInline`.
-/
def inlineCandidate? (e : Expr) : SimpM (Option Nat) := do
  let f := e.getAppFn
  let arity ← match f with
    | .const declName _ =>
      unless hasInlineAttribute (← getEnv) declName do return none
      -- TODO: check whether function is recursive or not.
      -- We can skip the test and store function inline so far.
      let some decl ← getStage1Decl? declName | return none
      pure decl.getArity
    | _ =>
      match (← findLambda? f) with
      | none => return none
      | some localDecl =>
        unless (← shouldInline localDecl) do return none
        pure (getLambdaArity localDecl.value)
  -- Under-applied calls are not inlined.
  if e.getAppNumArgs < arity then return none
  return e.getAppNumArgs - arity

/--
If `e` is a free variable that expands (via `findExpr`) to a non-lambda expression `e'`,
return `e'`; otherwise return `e` unchanged. The expansion is expected to be a valid LCNF
terminal `let`-block expression. The pass is marked as simplified only when `e'` differs
from `e`.
-/
def expandTrivialExpr (e : Expr) : SimpM Expr := do
  if e.isFVar then
    let e' ← findExpr e
    -- A lambda is not substituted into terminal position; the free variable is kept instead.
    unless e'.isLambda do
      if e != e' then
        markSimplified
      return e'
  return e

-- The simplifier proper. These functions call each other recursively.
mutual

/--
Simplify the body of the lambda `e` in a fresh scope, then rebuild the lambda over the same
parameters. `mkLetUsingScope` presumably wraps the simplified body in the `let`-declarations
created in this scope.
-/
partial def visitLambda (e : Expr) : SimpM Expr :=
  withNewScope do
    let (as, e) ← Compiler.visitLambda e
    let e ← mkLetUsingScope (← visitLet e)
    mkLambda as e

/--
Simplify the `casesOn` application `e` described by `casesInfo`.
If the major premise (the last discriminant) resolves via `findExpr` to a constructor
application, the `casesOn` is replaced by the matching alternative applied to the constructor's
fields (the inductive type's parameters are dropped), and that alternative is simplified.
Otherwise every alternative is simplified with `visitLambda` and the application is rebuilt.
-/
partial def visitCases (casesInfo : CasesInfo) (e : Expr) : SimpM Expr := do
  let mut args := e.getAppArgs
  let major := args[casesInfo.discrsRange.stop - 1]!
  let major ← findExpr major
  if let some (ctorVal, ctorArgs) := major.constructorApp? (← getEnv) then
    /- Simplify `casesOn` constructor -/
    let ctorIdx := ctorVal.cidx
    -- Alternatives are laid out in constructor-index order starting at `altsRange.start`.
    let alt := args[casesInfo.altsRange.start + ctorIdx]!
    let ctorFields := ctorArgs[ctorVal.numParams:]
    let alt := alt.beta ctorFields
    -- After all fields are applied, the alternative should be a `let`-block, not a lambda.
    assert! !alt.isLambda
    markSimplified
    visitLet alt
  else
    for i in casesInfo.altsRange do
      args ← args.modifyM i visitLambda
    return mkAppN e.getAppFn args

/--
Inline the application `e`.
1. Take the callee's body: for a constant, its stage-1 value with universe levels instantiated;
   for a local function, the value of its lambda.
2. Beta-reduce that body with all of `e`'s arguments.
3. Attach the optional join point `jp?` with `attachOptJp`.
4. Simplify the result with `visitLet`.

Callers must first check `e` with `inlineCandidate?`. Both `unreachable!`s rely on that check.
-/
partial def inlineApp (e : Expr) (jp? : Option Expr := none) : SimpM Expr := do
  let f := e.getAppFn
  trace[Compiler.simp.inline] "inlining {e}"
  let value ← match f with
    | .const declName us =>
      let some decl ← getStage1Decl? declName | unreachable!
      pure <| decl.value.instantiateLevelParams decl.levelParams us
    | _ =>
      let some localDecl ← findLambda? f | unreachable!
      pure localDecl.value
  let args := e.getAppArgs
  let value := value.beta args
  let value ← attachOptJp value jp?
  -- In this file `e` always carries exactly the callee's arity in arguments, so no lambda is
  -- expected to remain after `beta`.
  assert! !value.isLambda
  markSimplified
  -- Inlining a global definition disables local inlining in its body, because its let-names may
  -- collide with the keys of `stats` (see `Context.localInline`).
  withReader (fun ctx => { ctx with localInline := !f.isConst }) do
    visitLet value

/--
If `e` is an application that can be inlined, inline it.

`k?` is the optional "continuation" for `e`, and it may contain loose bound variables
that need to be instantiated with `xs`. That is, if `k? = some k`, then `k.instantiateRev xs`
is an expression without loose bound variables.

Returns `none` if `e` is not an inlining candidate.
-/
partial def inlineApp? (e : Expr) (xs : Array Expr) (k? : Option Expr) : SimpM (Option Expr) := do
  let some numExtraArgs ← inlineCandidate? e | return none
  let args := e.getAppArgs
  if k?.isNone && numExtraArgs == 0 then
    -- Easy case: there is no continuation and `e` is not over-applied.
    inlineApp e
  else
    /-
    There is a continuation `k` or `e` is over-applied.
    If `e` is over-applied, the extra arguments act as continuation.
    -/
    let toInline := mkAppN e.getAppFn args[:args.size - numExtraArgs]
    /-
    `toInline` is the application that is going to be inlined.
    We create a new join point
    ```
    let jp := fun y =>
      let x := y -- if `e` is over-applied
      k
    ```
    Recall that `visitLet` incorporates the current continuation into the new join point `jp`.
    -/
    let jpDomain ← inferType toInline
    let binderName ← mkFreshUserName `_y
    let jp ← withNewScope do
      let y ← mkLocalDecl binderName jpDomain
      let body ← if numExtraArgs == 0 then
        -- `get!` is safe: with no continuation and no extra arguments we took the easy case.
        visitLet k?.get! (xs.push y)
      else
        -- Apply the extra arguments to the result `y` of the inlined call.
        let x ← mkAuxLetDecl (mkAppN y args[args.size - numExtraArgs:])
        if let some k := k? then
          visitLet k (xs.push x)
        else
          -- No continuation: the auxiliary declaration `x` itself is the block's result.
          visitLet x (xs.push x)
      let body ← mkLetUsingScope body
      mkLambda #[y] body
    -- NOTE(review): presumably binds `jp` as a join-point declaration unless it is simple
    -- enough to be used directly; confirm in `mkJpDeclIfNotSimple`.
    let jp ← mkJpDeclIfNotSimple jp
    /- Inline `toInline` and "go-to" `jp` with the result. -/
    inlineApp toInline jp

/--
Let-declaration basic block visitor. `e` may contain loose bound variables that
still have to be instantiated with `xs`.

Each `let`-value is handled as follows:
* a lambda value is simplified recursively;
* a value that is just a free variable is substituted into the body, and the `let` is dropped;
* an inlinable application is inlined, with the rest of the block as its continuation;
* anything else becomes a new local `let`-declaration.

The terminal expression is simplified with `visitCases` if it is a `casesOn`, inlined if
possible, and otherwise passed to `expandTrivialExpr`.
-/
partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
  match e with
  | .letE binderName type value body nonDep =>
    let mut value := value.instantiateRev xs
    if value.isLambda then
      value ← visitLambda value
    if value.isFVar then
      /- Eliminate `let _x_i := _x_j;` -/
      markSimplified
      visitLet body (xs.push value)
    else if let some e ← inlineApp? value xs body then
      return e
    else
      let type := type.instantiateRev xs
      let x ← mkLetDecl binderName type value nonDep
      visitLet body (xs.push x)
  | _ =>
    -- Terminal expression: it is fully instantiated here, so `inlineApp?` gets an empty `xs`.
    let e := e.instantiateRev xs
    if let some casesInfo ← isCasesApp? e then
      visitCases casesInfo e
    else if let some e ← inlineApp? e #[] none then
      return e
    else
      expandTrivialExpr e

end

end Simp

/--
Run one simplification pass over `decl`. Returns `some decl'` if the pass changed anything,
and `none` otherwise.
Let-variable names are made unique first, because `InlineStats` is keyed by `userName`.
-/
def Decl.simp? (decl : Decl) : CoreM (Option Decl) := do
  let decl ← decl.ensureUniqueLetVarNames
  let stats ← Simp.collectInlineStats decl.value
  trace[Compiler.simp.inline.stats] "{decl.name}:{Format.nest 2 (format stats)}"
  -- The three `run`s supply the `Simp.Context`, the initial `Simp.State`, and (presumably) the
  -- initial `CompilerM` state.
  let (value, s) ← Simp.visitLambda decl.value |>.run { stats } |>.run { simplified := false } |>.run' {}
  if s.simplified then
    return some { decl with value }
  else
    return none

/--
Simplify `decl` to a fixpoint by repeating `Decl.simp?` until a pass makes no change.
NOTE(review): the number of passes is not bounded (see the TODO below).
-/
partial def Decl.simp (decl : Decl) : CoreM Decl := do
  if let some decl ← decl.simp? then
    -- TODO: bound number of steps?
    decl.simp
  else
    return decl

-- Trace classes used by the simplifier.
-- NOTE(review): `Compiler.simp.step` is not referenced in this file.
builtin_initialize
  registerTraceClass `Compiler.simp.inline
  registerTraceClass `Compiler.simp.step
  registerTraceClass `Compiler.simp.inline.stats

end Lean.Compiler