/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.CompilerM
import Lean.Compiler.Decl
import Lean.Compiler.Stage1
import Lean.Compiler.InlineAttrs

namespace Lean.Compiler
namespace Simp

/--
Follow `e` through let-declared free variables (and `mdata` wrappers) until a
let-declaration whose value is a lambda is found, and return that declaration.
Return `none` if the chain ends in anything else.
-/
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) := do
  match e with
  | .fvar fvarId =>
    let some d@(.ldecl (value := v) ..) ← findDecl? fvarId | return none
    if v.isLambda then
      return some d
    else
      findLambda? v
  | .mdata _ e => findLambda? e
  | _ => return none

/--
Follow `e` through let-declared free variables until reaching an expression that is
not a let-declared free variable, and return it.
`mdata` wrappers are skipped only when `skipMData` is `true`.
-/
partial def findExpr (e : Expr) (skipMData := true): CompilerM Expr := do
  match e with
  | .fvar fvarId =>
    let some (.ldecl (value := v) ..) ← findDecl? fvarId | return e
    findExpr v
  | .mdata _ e' => if skipMData then findExpr e' else return e
  | _ => return e

/--
Local function declaration statistics.

Remark: we use the `userName` as the key. Thus, `ensureUniqueLetVarNames`
must be used before collecting statistics.
-/
structure InlineStats where
  /--
  Mapping from local function name to the number of times it is used
  in a declaration.
  -/
  numOccs : Std.HashMap Name Nat := {}
  /--
  Mapping from local function name to their LCNF size.
  -/
  size : Std.HashMap Name Nat := {}
  deriving Inhabited

/--
Render the statistics, one `name ↦ numOccs, size` entry per line.
Entries that have an occurrence count but no recorded size are skipped.
-/
def InlineStats.format (s : InlineStats) : Format := Id.run do
  let mut result := Format.nil
  for (k, n) in s.numOccs.toList do
    let some size := s.size.find? k | continue
    result := result ++ "\n" ++ f!"{k} ↦ {n}, {size}"
  return result

/--
Return `true` if the local function named `k` has been recorded and either
occurs exactly once or has recorded size `1`.
-/
def InlineStats.shouldInline (s : InlineStats) (k : Name) : Bool := Id.run do
  let some numOccs := s.numOccs.find? k | return false
  if numOccs == 1 then return true
  let some sz := s.size.find? k | return false
  return sz == 1

/--
Register (or overwrite) the entry for `key`: the occurrence counter is set to `1`
and the size to `sz`.
-/
def InlineStats.add (s : InlineStats) (key : Name) (sz : Nat) : InlineStats :=
  match s with
  | { numOccs, size } => { numOccs := numOccs.insert key 1, size := size.insert key sz }

instance : ToFormat InlineStats where
  format := InlineStats.format

/--
Traverse the lambda `e` and collect `InlineStats` for the local functions
that occur in function position of applications.
-/
partial def collectInlineStats (e : Expr) : CoreM InlineStats := do
  let ((_, s), _) ← goLambda e |>.run {} |>.run {}
  return s
where
  /- Open the binders of a lambda in a new scope and visit its body. -/
  goLambda (e : Expr) : StateRefT InlineStats CompilerM Unit := do
    withNewScope do
      let (_, body) ← visitLambda e
      go body

  /- Visit the value of a let-declaration or a terminal expression. -/
  goValue (value : Expr) : StateRefT InlineStats CompilerM Unit := do
    match value with
    | .lam .. => goLambda value
    | .app .. =>
      match (← findLambda? value.getAppFn) with
      | some localDecl =>
        /- `findLambda?` only returns declarations whose value is a lambda; the check is kept as a safeguard. -/
        if localDecl.value.isLambda then
          let key := localDecl.userName
          match (← get).numOccs.find? key with
          | some numOccs =>
            modify fun s => { s with numOccs := s.numOccs.insert key (numOccs + 1) }
          | _ =>
            /- First occurrence: record the size and initialize the counter. -/
            let sz ← getLCNFSize localDecl.value
            modify fun s => s.add key sz
      | _ => pure ()
    | _ => pure ()

  /- Visit a let-block, descending into `cases` alternatives at the end of the block. -/
  go (e : Expr) : StateRefT InlineStats CompilerM Unit := do
    match e with
    | .letE .. =>
      withNewScope do
        let body ← visitLet e fun _ value => do goValue value; return value
        go body
    | e =>
      if let some casesInfo ← isCasesApp? e then
        let args := e.getAppArgs
        for i in casesInfo.altsRange do
          goLambda args[i]!
      else
        goValue e

/-- Configuration options for the simplifier. -/
structure Config where
  /- NOTE(review): this field is not read anywhere in this file. -/
  increaseFactor : Nat := 2

/-- Read-only context of `SimpM`. -/
structure Context where
  config : Config := {}

/-- Mutable state of `SimpM`. -/
structure State where
  /--
  Statistics for deciding whether to inline local function declarations.
  -/
  stats : InlineStats
  /-- Set to `true` (see `markSimplified`) when a simplification step has been applied. -/
  simplified : Bool := false
  deriving Inhabited

abbrev SimpM := ReaderT Context $ StateRefT State CompilerM

/-- Record that the current pass has changed the code. -/
def markSimplified : SimpM Unit :=
  modify fun s => { s with simplified := true }

/-- Find the expression `e` stands for; currently just `findExpr`. -/
def findCtor (e : Expr) : SimpM Expr := do
  -- TODO: add support for mapping discriminants to constructors in branches
  findExpr e

/--
Try to simplify projections `.proj _ i s` where `s` is constructor.

On success the pass is flagged as having simplified the code; otherwise
`Decl.simp?` would discard the result of a pass where this was the only change.
-/
def simpProj? (e : Expr) : OptionT SimpM Expr := do
  let .proj _ i s := e | failure
  let s ← findCtor s
  let some (ctorVal, args) := s.constructorApp? (← getEnv) | failure
  markSimplified
  return args[ctorVal.numParams + i]!

/--
Application over application.
```
let _x.i := f a
_x.i b
```
is simplified to `f a b`.

On success the pass is flagged as having simplified the code.
-/
def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
  guard e.isApp
  let f := e.getAppFn
  guard f.isFVar
  let f ← findExpr f
  guard <| f.isApp || f.isConst
  markSimplified
  return mkAppN f e.getAppArgs

/-- Return `true` if the collected statistics say the local function `localDecl` should be inlined. -/
def shouldInline (localDecl : LocalDecl) : SimpM Bool :=
  return (← get).stats.shouldInline localDecl.userName

/-- Information about a function application that can be inlined. -/
structure InlineCandidateInfo where
  /-- `true` if the function is a local function declaration, `false` if it is a constant. -/
  isLocal : Bool
  /-- Arity of the function to be inlined. -/
  arity   : Nat
  /-- Value (lambda expression) of the function to be inlined. -/
  value   : Expr

/--
If `e` is an application of an inlinable function that is not under-applied,
return the information required to inline it.

The function is either a constant with an inline attribute and a stage 1 declaration,
or a local function declaration for which `shouldInline` holds.
-/
def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
  let f := e.getAppFn
  if let .const declName us ← findExpr f then
    unless hasInlineAttribute (← getEnv) declName do return none
    -- TODO: check whether function is recursive or not.
    -- We can skip the test and store function inline so far.
    let some decl ← getStage1Decl? declName | return none
    let numArgs := e.getAppNumArgs
    let arity := decl.getArity
    if numArgs < arity then return none
    /-
    Recall that we use binder names to build `InlineStats`. Thus, we use `ensureUniqueLetVarNames` to make
    sure there is no name collision.
    -/
    let value ← ensureUniqueLetVarNames (decl.value.instantiateLevelParams decl.levelParams us)
    return some { arity, value, isLocal := false }
  else if let some localDecl ← findLambda? f then
    unless (← shouldInline localDecl) do return none
    let numArgs := e.getAppNumArgs
    let arity := getLambdaArity localDecl.value
    if numArgs < arity then return none
    let value ← ensureUniqueLetVarNames localDecl.value
    return some { arity, value, isLocal := true }
  else
    return none

/--
If `e` is a free variable that expands to a valid LCNF terminal `let`-block expression `e'`,
return `e'`.
-/
def expandTrivialExpr (e : Expr) : SimpM Expr := do
  if e.isFVar then
    let e' ← findExpr e
    unless e'.isLambda do
      if e != e' then
        markSimplified
        return e'
  return e

/--
Given `value` of the form `let x_1 := v_1; ...; let x_n := v_n; e`, return
`let x_1 := v_1; ...; let x_n := v_n; let y : type := e; body`.

This method assumes `type` and `value` do not have loose bound variables.

Remark: `body` may have many loose bound variables, and the loose bound variables > 0 must be lifted by `n`.
-/
def mkFlatLet (y : Name) (type : Expr) (value : Expr) (body : Expr) (nonDep : Bool := false) : Expr :=
  go value 0
where
  /- `i` is the number of `let`s of `value` traversed so far. -/
  go (value : Expr) (i : Nat) : Expr :=
    match value with
    | .letE n t v b d => .letE n t v (go b (i+1)) d
    | _ => .letE y type value (body.liftLooseBVars 1 i) nonDep

/--
Update inlining statistics (`stats` field) with the local function declarations in `e`.
We use this method to make sure type class instance elements are inlined in the current compiler simp pass.
-/
private def updateStatsUsing (e : Expr) : SimpM Unit := do
  match e with
  | .letE binderName _ v b _ =>
    if v.isLambda then
      /- Size `1` with a single occurrence makes `InlineStats.shouldInline` return `true`. -/
      modify fun s => { s with stats := s.stats.add binderName 1 }
    updateStatsUsing b
  | _ => return ()

/--
Auxiliary function for projecting "type class dictionary access".
That is, we are trying to extract one of the type class instance elements.
Remark: We do not consider parent instances to be elements.
For example, suppose `e` is `_x_4.1`, and we have
```
_x_2 : Monad (ReaderT Bool (ExceptT String Id)) := @ReaderT.Monad Bool (ExceptT String Id) _x_1
_x_3 : Applicative (ReaderT Bool (ExceptT String Id)) := _x_2.1
_x_4 : Functor (ReaderT Bool (ExceptT String Id)) := _x_3.1
```
Then, we will expand `_x_4.1` since it corresponds to the `Functor` `map` element,
and its type is not a type class, but is of the form
```
{α β : Type u} → (α → β) → ...
```
In the example above, the compiler should not expand `_x_3.1` or `_x_2.1` because they are
type class applications: `Functor` and `Applicative` respectively.
By eagerly expanding them, we may produce inefficient and bloated code.
For example, we may be using `_x_3.1` to invoke a function that expects a `Functor` instance.
By expanding `_x_3.1` we will be just expanding the code that creates this instance.

On success the pass is flagged as having simplified the code.
-/
partial def inlineProjInst? (e : Expr) : OptionT SimpM Expr := do
  let .proj _ _ s := e | failure
  let sType ← inferType s
  guard (← isClass? sType).isSome
  let eType ← inferType e
  guard (← isClass? eType).isNone
  /-
  We use `withNewScope` + `mkLetUsingScope` to filter the relevant let-declarations.
  Recall that we are extracting only one of the type class elements.
  -/
  let value ← withNewScope do mkLetUsingScope (← visitProj e)
  let value ← ensureUniqueLetVarNames value
  updateStatsUsing value
  markSimplified
  return value
where
  /- Reduce the projection `e` to the corresponding constructor argument, or fail. -/
  visitProj (e : Expr) : OptionT SimpM Expr := do
    let .proj _ i s := e | unreachable!
    let s ← visit s
    if let some (ctorVal, ctorArgs) := s.constructorApp? (← getEnv) then
      return ctorArgs[ctorVal.numParams + i]!
    else
      failure

  /- Expand `e` until a constructor application is reached, or fail. -/
  visit (e : Expr) : OptionT SimpM Expr := do
    let e ← findExpr e
    if e.isConstructorApp (← getEnv) then
      return e
    else if e.isProj then
      /- We may have nested projections as we traverse parent classes. -/
      visit (← visitProj e)
    else
      let .const declName us := e.getAppFn | failure
      let some decl ← getStage1Decl? declName | failure
      guard <| decl.getArity == e.getAppNumArgs
      let value := decl.value.instantiateLevelParams decl.levelParams us
      let value := value.beta e.getAppArgs
      /-
      Here, we just go inside of the let-declaration block without trying to simplify it.
      Reason: a type class instance may have many elements, and it does not make sense to simplify
      all of them when we are extracting only one of them.
      -/
      let value ← Compiler.visitLet (m := SimpM) value fun _ value => return value
      visit value

mutual

/-- Simplify the body of the lambda `e` and rebuild the lambda. -/
partial def visitLambda (e : Expr) : SimpM Expr :=
  withNewScope do
    let (as, e) ← Compiler.visitLambda e
    let e ← mkLetUsingScope (← visitLet e)
    mkLambda as e

/--
Simplify the `cases` application `e`.
If the major premise is a constructor application, the matching alternative is selected
and simplified. Otherwise, every alternative is simplified in place.
-/
partial def visitCases (casesInfo : CasesInfo) (e : Expr) : SimpM Expr := do
  let mut args := e.getAppArgs
  let major := args[casesInfo.discrsRange.stop - 1]!
  let major ← findExpr major
  if let some (ctorVal, ctorArgs) := major.constructorApp? (← getEnv) then
    /- Simplify `casesOn` constructor -/
    let ctorIdx := ctorVal.cidx
    let alt := args[casesInfo.altsRange.start + ctorIdx]!
    let ctorFields := ctorArgs[ctorVal.numParams:]
    let alt := alt.beta ctorFields
    assert! !alt.isLambda
    markSimplified
    visitLet alt
  else
    for i in casesInfo.altsRange do
      args ← args.modifyM i visitLambda
    return mkAppN e.getAppFn args

/--
If `e` is an application that can be inlined, inline it.
`k?` is the optional "continuation" for `e`, and it may contain loose bound variables
that need to be instantiated with `xs`. That is, if `k? = some k`, then `k.instantiateRev xs`
is an expression without loose bound variables.
-/
partial def inlineApp? (e : Expr) (xs : Array Expr) (k? : Option Expr) : SimpM (Option Expr) := do
  let some info ← inlineCandidate? e | return none
  let args := e.getAppArgs
  let numArgs := args.size
  trace[Compiler.simp.inline] "inlining {e}"
  markSimplified
  if k?.isNone && numArgs == info.arity then
    /- Easy case, there is no continuation and `e` is not over applied -/
    visitLet (info.value.beta args)
  else if (← onlyOneExitPoint info.value) then
    /- If `info.value` has only one exit point, we don't need to create a new auxiliary join point -/
    let mut value := info.value.beta args[:info.arity]
    if numArgs > info.arity then
      let type ← inferType (mkAppN e.getAppFn args[:info.arity])
      value := mkFlatLet (← mkAuxLetDeclName) type value (mkAppN (.bvar 0) args[info.arity:])
    if let some k := k? then
      let type ← inferType e
      value := mkFlatLet (← mkAuxLetDeclName) type value k
    visitLet value xs
  else
    /-
    There is a continuation `k` or `e` is over applied.
    If `e` is over applied, the extra arguments act as a continuation.

    We create a new join point
    ```
    let jp := fun y =>
      let x := y <extra-arguments> -- if `e` is over applied
      k
    ```
    Recall that `visitLet` incorporates the current continuation
    to the new join point `jp`.
    -/
    let jpDomain ← inferType (mkAppN e.getAppFn args[:info.arity])
    let binderName ← mkFreshUserName `_y
    let jp ← withNewScope do
      let y ← mkLocalDecl binderName jpDomain
      let body ← if numArgs == info.arity then
        visitLet k?.get! (xs.push y)
      else
        let x ← mkAuxLetDecl (mkAppN y args[info.arity:])
        if let some k := k? then
          visitLet k (xs.push x)
        else
          visitLet x (xs.push x)
      let body ← mkLetUsingScope body
      mkLambda #[y] body
    let jp ← mkJpDeclIfNotSimple jp
    let value := info.value.beta args[:info.arity]
    let value ← attachJp value jp
    visitLet value

/--
Try to apply simple simplifications: `simpProj?`, `simpAppApp?`, and `inlineProjInst?`, in this order.
-/
partial def simpValue? (e : Expr) : SimpM (Option Expr) :=
  simpProj? e <|> simpAppApp? e <|> inlineProjInst? e

/--
Let-declaration basic block visitor. `e` may contain loose bound variables that
still have to be instantiated with `xs`.
-/
partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
  match e with
  | .letE binderName type value body nonDep =>
    let mut value := value.instantiateRev xs
    if value.isLambda then
      value ← visitLambda value
    else if let some value' ← simpValue? value then
      if value'.isLet then
        /- The simplified value is itself a let-block: flatten it into the current block and restart. -/
        let e := mkFlatLet binderName type value' body nonDep
        let e ← visitLet e xs
        return e
      value := value'
    if value.isFVar then
      /- Eliminate `let _x_i := _x_j;` -/
      markSimplified
      visitLet body (xs.push value)
    else if let some e ← inlineApp? value xs body then
      return e
    else
      let type := type.instantiateRev xs
      let x ← mkLetDecl binderName type value nonDep
      visitLet body (xs.push x)
  | _ =>
    let e := e.instantiateRev xs
    if let some value ← simpValue? e then
      visitLet value
    else if let some casesInfo ← isCasesApp? e then
      visitCases casesInfo e
    else if let some e ← inlineApp? e #[] none then
      return e
    else
      expandTrivialExpr e

end

end Simp

/--
Apply one simplification pass to `decl`.
Return `some decl'` if the pass marked the code as simplified, and `none` otherwise.
-/
def Decl.simp? (decl : Decl) : CoreM (Option Decl) := do
  let decl ← decl.ensureUniqueLetVarNames
  let stats ← Simp.collectInlineStats decl.value
  trace[Compiler.simp.inline.stats] "{decl.name}:{Format.nest 2 (format stats)}"
  trace[Compiler.simp.step] "{decl.name} :=\n{decl.value}"
  let (value, s) ← Simp.visitLambda decl.value |>.run {} |>.run { stats, simplified := false } |>.run' { nextIdx := (← getMaxLetVarIdx decl.value) + 1 }
  trace[Compiler.simp.step.new] "{decl.name} :=\n{value}"
  /- Reports the size of the input of this pass. -/
  trace[Compiler.simp.stat] "{decl.name}: {← getLCNFSize decl.value}"
  if s.simplified then
    return some { decl with value }
  else
    return none

/-- Apply `Decl.simp?` repeatedly until a pass reports no simplification. -/
partial def Decl.simp (decl : Decl) : CoreM Decl := do
  if let some decl ← decl.simp? then
    -- TODO: bound number of steps?
    decl.simp
  else
    return decl

builtin_initialize
  registerTraceClass `Compiler.simp.inline
  registerTraceClass `Compiler.simp.stat
  registerTraceClass `Compiler.simp.step
  registerTraceClass `Compiler.simp.step.new
  registerTraceClass `Compiler.simp.inline.stats

end Lean.Compiler