lean4-htt/src/Lean/Compiler/Simp.lean
2022-08-17 14:35:07 -07:00

307 lines
No EOL
9.9 KiB
Text

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.CompilerM
import Lean.Compiler.Decl
import Lean.Compiler.Stage1
import Lean.Compiler.InlineAttrs
namespace Lean.Compiler
namespace Simp
/--
Starting at `e`, skip metadata and follow let-bound free variables through their values
until reaching a local declaration whose value is a lambda, and return that declaration.
Return `none` if the chain ends in anything else.
-/
partial def findLambda? (e : Expr) : CompilerM (Option LocalDecl) := do
  match e with
  | .mdata _ inner => findLambda? inner
  | .fvar fvarId =>
    match (← findDecl? fvarId) with
    | some decl@(.ldecl (value := val) ..) =>
      if val.isLambda then
        return some decl
      -- The value is not a lambda yet; keep following it.
      findLambda? val
    | _ => return none
  | _ => return none
/--
Resolve `e` by repeatedly replacing a let-bound free variable with its value.
When `skipMData` is `true` (the default), metadata wrappers are stripped along the way.
When it is `false`, resolution stops at the first `.mdata` node and returns that node.
The `skipMData` choice is applied consistently to every recursive step, including values
reached through free variables.
-/
partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
  match e with
  | .fvar fvarId =>
    let some (.ldecl (value := v) ..) ← findDecl? fvarId | return e
    -- Propagate `skipMData`; otherwise the recursion would silently fall back to `true`.
    findExpr v skipMData
  | .mdata _ e' => if skipMData then findExpr e' skipMData else return e
  | _ => return e
/--
Local function declaration statistics.
Remark: we use the `userName` as the key. Thus, `ensureUniqueLetVarNames`
must be used before collecting statistics.
-/
structure InlineStats where
  /--
  Mapping from local function name to the number of times it is used
  in a declaration.
  -/
  numOccs : Std.HashMap Name Nat := {}
  /--
  Mapping from local function name to its LCNF size.
  -/
  size : Std.HashMap Name Nat := {}
/--
Render the statistics with one line per local function, in the form
`name ↦ occurrences, size`. Names that have no recorded size are skipped.
-/
def InlineStats.format (s : InlineStats) : Format := Id.run do
  let mut out := Format.nil
  for (name, occs) in s.numOccs.toList do
    if let some sz := s.size.find? name then
      out := out ++ "\n" ++ f!"{name} ↦ {occs}, {sz}"
  return out
/--
A local function `k` should be inlined when it occurs exactly once, or when its
recorded LCNF size is `1`. Functions without an occurrence entry are never inlined.
-/
def InlineStats.shouldInline (s : InlineStats) (k : Name) : Bool :=
  match s.numOccs.find? k with
  | none => false
  | some 1 => true
  | some _ => s.size.find? k == some 1
/-- Pretty-print inline statistics using `InlineStats.format`. -/
instance : ToFormat InlineStats := ⟨InlineStats.format⟩
/--
Collect `InlineStats` for the local functions used in the lambda `e`.
For each application whose head `findLambda?` resolves to a local declaration, the
occurrence count stored under that declaration's `userName` is incremented. The first
time a name is seen, its LCNF size (`getLCNFSize`) is recorded as well. Nested lambdas,
`let`-values, and `cases` alternatives are traversed recursively.
-/
partial def collectInlineStats (e : Expr) : CoreM InlineStats := do
  let ((_, s), _) ← goLambda e |>.run {} |>.run {}
  return s
where
  -- Enter the binders of the lambda `e` in a fresh scope and traverse its body.
  goLambda (e : Expr) : StateRefT InlineStats CompilerM Unit := do
    withNewScope do
      let (_, body) ← visitLambda e
      go body
  -- Record a use of a local function if `value` is an application of one;
  -- descend into `value` if it is itself a lambda.
  goValue (value : Expr) : StateRefT InlineStats CompilerM Unit := do
    match value with
    | .lam .. => goLambda value
    | .app .. =>
      let some localDecl ← findLambda? value.getAppFn | return ()
      -- `findLambda?` only returns declarations with lambda values; kept as a defensive check.
      unless localDecl.value.isLambda do return ()
      let key := localDecl.userName
      match (← get).numOccs.find? key with
      | some numOccs =>
        modify fun s => { s with numOccs := s.numOccs.insert key (numOccs + 1) }
      | none =>
        -- First occurrence: also record the size of the function.
        let sz ← getLCNFSize localDecl.value
        modify fun s => { s with numOccs := s.numOccs.insert key 1, size := s.size.insert key sz }
    | _ => pure ()
  -- Traverse a `let`-block: inspect every `let`-value, then the terminal expression
  -- (visiting each alternative when it is a `cases` application).
  go (e : Expr) : StateRefT InlineStats CompilerM Unit := do
    match e with
    | .letE .. =>
      withNewScope do
        let body ← visitLet e fun value => do goValue value; return value
        go body
    | e =>
      if let some casesInfo ← isCasesApp? e then
        let args := e.getAppArgs
        for i in casesInfo.altsRange do
          goLambda args[i]!
      else
        goValue e
/-- Simplifier configuration, stored in `Context.config`. -/
structure Config where
  -- NOTE(review): `increaseFactor` is not read in this part of the file; its intended
  -- meaning (presumably a limit on code growth caused by inlining) should be confirmed.
  increaseFactor : Nat := 2
/-- Read-only context of the simplifier monad `SimpM`. -/
structure Context where
  /-- Simplifier configuration. -/
  config : Config := {}
  /--
  Statistics for deciding whether to inline local function declarations.
  -/
  stats : InlineStats
  /--
  We only inline local declarations when `localInline` is `true`.
  We set it to `false` when we are inlining a non-local definition
  that may have let-declarations whose names collide with the ones
  stored in `stats`.
  -/
  localInline : Bool := true
/-- Mutable state of the simplifier monad `SimpM`. -/
structure State where
  /--
  Set to `true` by `markSimplified` whenever a simplification step is performed.
  `Decl.simp?` returns `none` when this flag is still `false` after a round.
  -/
  simplified : Bool := false
/-- Simplifier monad: a read-only `Context` and a mutable `State` on top of `CompilerM`. -/
abbrev SimpM := ReaderT Context (StateRefT State CompilerM)
/-- Record that at least one simplification step was performed in the current round. -/
def markSimplified : SimpM Unit := do
  modify fun st => { st with simplified := true }
/--
Return `true` if the local function `localDecl` should be inlined. This requires both
that local inlining is enabled in the context and that the inline statistics accept
`localDecl.userName`.
-/
def shouldInline (localDecl : LocalDecl) : SimpM Bool := do
  let ctx ← read
  return ctx.localInline && ctx.stats.shouldInline localDecl.userName
/--
Decide whether the application `e` can be inlined.
Return `none` if it cannot. Otherwise return the number of arguments supplied beyond
the callee's arity (`0` when `e` is applied to exactly that many arguments).
The head of `e` is accepted when it is
- a constant for which `hasInlineAttribute` holds and that has a stage-1 declaration, or
- any other head that `findLambda?` resolves to a local declaration accepted by `shouldInline`.
-/
def inlineCandidate? (e : Expr) : SimpM (Option Nat) := do
  let fn := e.getAppFn
  let numArgs := e.getAppNumArgs
  let arity ← match fn with
    | .const declName _ =>
      if !hasInlineAttribute (← getEnv) declName then
        return none
      -- TODO: check whether the function is recursive.
      -- For now this test is skipped and the function is inlined anyway.
      let some decl ← getStage1Decl? declName | return none
      pure decl.getArity
    | _ =>
      let some localDecl ← findLambda? fn | return none
      if !(← shouldInline localDecl) then
        return none
      pure (getLambdaArity localDecl.value)
  -- Under-applied calls are not inlined.
  if numArgs < arity then
    return none
  return some (numArgs - arity)
/--
If `e` is a free variable that expands (via `findExpr`) to a valid LCNF terminal
`let`-block expression `e'`, i.e. one that is not a lambda, return `e'`.
Otherwise return `e` unchanged. Calls `markSimplified` whenever the returned
expression differs from `e`.
-/
def expandTrivialExpr (e : Expr) : SimpM Expr := do
  unless e.isFVar do
    return e
  let e' ← findExpr e
  -- Lambdas are not valid terminal expressions; keep the variable.
  if e'.isLambda then
    return e
  if e != e' then
    markSimplified
  return e'
mutual
/--
Simplify the lambda `e`. The steps are:
1. Open its binders in a fresh scope (`Compiler.visitLambda`).
2. Simplify the body with `visitLet`.
3. Wrap the result with the scope's declarations (`mkLetUsingScope`).
4. Abstract the binders again (`mkLambda`).
-/
partial def visitLambda (e : Expr) : SimpM Expr :=
  withNewScope do
    let (binders, body) ← Compiler.visitLambda e
    let simpBody ← visitLet body
    let wrapped ← mkLetUsingScope simpBody
    mkLambda binders wrapped
/--
Simplify the `cases` application `e` described by `casesInfo`.
If the major premise (the last discriminant, resolved with `findExpr`) is a constructor
application, the matching alternative is beta-reduced with the constructor's fields and
simplified directly. Otherwise each alternative is simplified with `visitLambda` and the
application is rebuilt.
-/
partial def visitCases (casesInfo : CasesInfo) (e : Expr) : SimpM Expr := do
  let args := e.getAppArgs
  let major ← findExpr args[casesInfo.discrsRange.stop - 1]!
  match major.constructorApp? (← getEnv) with
  | some (ctorVal, ctorArgs) =>
    -- `casesOn` applied to a constructor: select its alternative and pass the fields,
    -- dropping the constructor's parameters.
    let alt := args[casesInfo.altsRange.start + ctorVal.cidx]!
    let alt := alt.beta ctorArgs[ctorVal.numParams:]
    assert! !alt.isLambda
    markSimplified
    visitLet alt
  | none =>
    let mut newArgs := args
    for i in casesInfo.altsRange do
      newArgs ← newArgs.modifyM i visitLambda
    return mkAppN e.getAppFn newArgs
/--
Inline the application `e`. Its head is either a constant with a stage-1 declaration or
a local declaration found by `findLambda?`. The steps are:
1. Beta-reduce the callee's value with all arguments of `e`.
2. Attach the optional join point `jp?` with `attachOptJp`.
3. Simplify the result with `visitLet`.
While simplifying the inlined body, `localInline` is set to `false` if the head is a
constant and to `true` otherwise.
-/
partial def inlineApp (e : Expr) (jp? : Option Expr := none) : SimpM Expr := do
  let fn := e.getAppFn
  trace[Compiler.simp.inline] "inlining {e}"
  let callee ← match fn with
    | .const declName us =>
      let some decl ← getStage1Decl? declName | unreachable!
      pure (decl.value.instantiateLevelParams decl.levelParams us)
    | _ =>
      let some localDecl ← findLambda? fn | unreachable!
      pure localDecl.value
  let inlined ← attachOptJp (callee.beta e.getAppArgs) jp?
  assert! !inlined.isLambda
  markSimplified
  let inlinedConst := fn.isConst
  withReader (fun ctx => { ctx with localInline := !inlinedConst }) do
    visitLet inlined
/--
If `e` is an application that can be inlined, inline it.
`k?` is the optional "continuation" for `e`, and it may contain loose bound variables
that need to be instantiated with `xs`. That is, if `k? = some k`, then `k.instantiateRev xs`
is an expression without loose bound variables.
Return `none` when `inlineCandidate?` rejects `e`.
-/
partial def inlineApp? (e : Expr) (xs : Array Expr) (k? : Option Expr) : SimpM (Option Expr) := do
  let some numExtraArgs ← inlineCandidate? e | return none
  let args := e.getAppArgs
  if k?.isNone && numExtraArgs == 0 then
    -- Easy case: there is no continuation and `e` is not over-applied.
    inlineApp e
  else
    /-
    There is a continuation `k` or `e` is over-applied.
    If `e` is over-applied, the extra arguments act as a continuation.
    -/
    let toInline := mkAppN e.getAppFn args[:args.size - numExtraArgs]
    /-
    `toInline` is the application that is going to be inlined.
    We create a new join point
    ```
    let jp := fun y =>
      let x := y <extra-arguments> -- if `e` is over-applied
      k
    ```
    Recall that `visitLet` incorporates the current continuation
    into the new join point `jp`.
    -/
    let jpDomain ← inferType toInline
    let binderName ← mkFreshUserName `_y
    let jp ← withNewScope do
      let y ← mkLocalDecl binderName jpDomain
      let body ← if numExtraArgs == 0 then
        -- Not over-applied, so `k?` is `some` here (the easy case was handled above).
        visitLet k?.get! (xs.push y)
      else
        -- Apply the join-point parameter `y` to the extra arguments.
        let x ← mkAuxLetDecl (mkAppN y args[args.size - numExtraArgs:])
        if let some k := k? then
          visitLet k (xs.push x)
        else
          -- No continuation: the join point's result is `x` itself.
          visitLet x (xs.push x)
      let body ← mkLetUsingScope body
      mkLambda #[y] body
    -- NOTE(review): presumably binds `jp` as a join-point declaration unless it is simple
    -- enough to be used directly; confirm against `mkJpDeclIfNotSimple`.
    let jp ← mkJpDeclIfNotSimple jp
    /- Inline `toInline` and "go-to" `jp` with the result. -/
    inlineApp toInline jp
/--
Simplify a `let`-declaration basic block. `e` may contain loose bound variables that
still have to be instantiated with `xs` (via `instantiateRev`).
- Lambda values are simplified with `visitLambda`.
- `let _x_i := _x_j; body` is eliminated by substituting `_x_j` into `body`.
- A value accepted by `inlineApp?` is inlined with the rest of the block as its continuation.
- The terminal expression is simplified as a `cases` application, an inlinable call,
  or with `expandTrivialExpr`.
-/
partial def visitLet (e : Expr) (xs : Array Expr := #[]) : SimpM Expr := do
  match e with
  | .letE binderName type value body nonDep =>
    let value := value.instantiateRev xs
    let value ← if value.isLambda then visitLambda value else pure value
    if value.isFVar then
      /- Eliminate `let _x_i := _x_j;` -/
      markSimplified
      visitLet body (xs.push value)
    else
      match (← inlineApp? value xs body) with
      | some result => return result
      | none =>
        let x ← mkLetDecl binderName (type.instantiateRev xs) value nonDep
        visitLet body (xs.push x)
  | _ =>
    let e := e.instantiateRev xs
    match (← isCasesApp? e) with
    | some casesInfo => visitCases casesInfo e
    | none =>
      match (← inlineApp? e #[] none) with
      | some result => return result
      | none => expandTrivialExpr e
end
end Simp
/--
Run one simplification round on `decl`. The round makes `let`-variable names unique
(needed because `InlineStats` is keyed by `userName`), collects inline statistics, and
simplifies the declaration's value. Return `none` if no simplification step was performed.
-/
def Decl.simp? (decl : Decl) : CoreM (Option Decl) := do
  let decl ← decl.ensureUniqueLetVarNames
  let stats ← Simp.collectInlineStats decl.value
  trace[Compiler.simp.inline.stats] "{decl.name}:{Format.nest 2 (format stats)}"
  let (newValue, st) ← (Simp.visitLambda decl.value).run { stats } |>.run { simplified := false } |>.run' {}
  unless st.simplified do
    return none
  return some { decl with value := newValue }
/--
Apply `Decl.simp?` repeatedly until a round performs no simplification.
TODO: bound number of steps? There is currently no limit on the number of rounds.
-/
partial def Decl.simp (decl : Decl) : CoreM Decl := do
  match (← decl.simp?) with
  | none => return decl
  | some decl' => decl'.simp
-- Register the simplifier's trace classes. `Compiler.simp.inline` is used by `Simp.inlineApp`
-- and `Compiler.simp.inline.stats` by `Decl.simp?`. `Compiler.simp.step` is not emitted in
-- this file section; presumably it is used elsewhere (confirm).
builtin_initialize
  registerTraceClass `Compiler.simp.inline
  registerTraceClass `Compiler.simp.step
  registerTraceClass `Compiler.simp.inline.stats
end Lean.Compiler