refactor: split Simp.lean
This commit is contained in:
parent
85119ba9d1
commit
35ca2b203c
15 changed files with 1418 additions and 1288 deletions
|
|
@ -363,6 +363,28 @@ def Decl.size (decl : Decl) : Nat :=
|
|||
def Decl.getArity (decl : Decl) : Nat :=
|
||||
decl.params.size
|
||||
|
||||
/--
|
||||
Return `some i` if `decl` is of the form
|
||||
```
|
||||
def f (a_0 ... a_i ...) :=
|
||||
...
|
||||
cases a_i
|
||||
| ...
|
||||
| ...
|
||||
```
|
||||
That is, `f` is a sequence of declarations followed by a `cases` on the parameter `i`.
|
||||
We use this function to decide whether we should inline a declaration tagged with
|
||||
`[inlineIfReduce]` or not.
|
||||
-/
|
||||
def Decl.isCasesOnParam? (decl : Decl) : Option Nat :=
|
||||
go decl.value
|
||||
where
|
||||
go (code : Code) : Option Nat :=
|
||||
match code with
|
||||
| .let _ k | .jp _ k | .fun _ k => go k
|
||||
| .cases c => decl.params.findIdx? fun param => param.fvarId == c.discr
|
||||
| _ => none
|
||||
|
||||
def Decl.instantiateTypeLevelParams (decl : Decl) (us : List Level) : Expr :=
|
||||
decl.type.instantiateLevelParams decl.levelParams us
|
||||
|
||||
|
|
|
|||
|
|
@ -119,6 +119,12 @@ def eraseCodeDecl (decl : CodeDecl) : CompilerM Unit := do
|
|||
| .let decl => eraseLetDecl decl
|
||||
| .jp decl | .fun decl => eraseFunDecl decl
|
||||
|
||||
/--
|
||||
Erase all free variables occurring in `decls` from the local context.
|
||||
-/
|
||||
def eraseCodeDecls (decls : Array CodeDecl) : CompilerM Unit := do
|
||||
decls.forM fun decl => eraseCodeDecl decl
|
||||
|
||||
def eraseDecl (decl : Decl) : CompilerM Unit := do
|
||||
eraseParams decl.params
|
||||
eraseCode decl.value
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
56
src/Lean/Compiler/LCNF/Simp/Basic.lean
Normal file
56
src/Lean/Compiler/LCNF/Simp/Basic.lean
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Meta.Instances
|
||||
import Lean.Compiler.InlineAttrs
|
||||
import Lean.Compiler.Specialize
|
||||
import Lean.Compiler.LCNF.CompilerM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
/--
|
||||
Return `true` if the arrow type contains an instance implicit argument.
|
||||
-/
|
||||
def hasLocalInst (type : Expr) : Bool :=
|
||||
match type with
|
||||
| .forallE _ _ b bi => bi.isInstImplicit || hasLocalInst b
|
||||
| _ => false
|
||||
|
||||
/--
|
||||
Return `true` if `decl` is supposed to be inlined/specialized.
|
||||
-/
|
||||
def Decl.isTemplateLike (decl : Decl) : CoreM Bool := do
|
||||
if hasLocalInst decl.type then
|
||||
return true -- `decl` applications will be specialized
|
||||
else if (← Meta.isInstance decl.name) then
|
||||
return true -- `decl` is "fuel" for code specialization
|
||||
else if hasInlineAttribute (← getEnv) decl.name || hasSpecializeAttribute (← getEnv) decl.name then
|
||||
return true -- `decl` is going to be inlined or specialized
|
||||
else
|
||||
return false
|
||||
|
||||
namespace Simp
|
||||
|
||||
partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
let some decl ← findLetDecl? fvarId | return e
|
||||
findExpr decl.value
|
||||
| .mdata _ e' => if skipMData then findExpr e' else return e
|
||||
| _ => return e
|
||||
|
||||
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
if let some decl ← LCNF.findFunDecl? fvarId then
|
||||
return some decl
|
||||
else if let some decl ← findLetDecl? fvarId then
|
||||
findFunDecl? decl.value
|
||||
else
|
||||
return none
|
||||
| .mdata _ e => findFunDecl? e
|
||||
| _ => return none
|
||||
|
||||
end Simp
|
||||
end Lean.Compiler.LCNF
|
||||
37
src/Lean/Compiler/LCNF/Simp/Config.lean
Normal file
37
src/Lean/Compiler/LCNF/Simp/Config.lean
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
structure Config where
|
||||
/--
|
||||
Any local function declaration or join point with size `≤ smallThresold` is inlined
|
||||
even if there are multiple occurrences.
|
||||
We currently don't do the same for global declarations because we are not saving
|
||||
the stage1 compilation result in .olean files yet.
|
||||
-/
|
||||
smallThreshold : Nat := 1
|
||||
/--
|
||||
If `etaPoly` is true, we eta expand any global function application when
|
||||
the function takes local instances. The idea is that we do not generate code
|
||||
for this kind of application, and we want all of them to specialized or inlined.
|
||||
-/
|
||||
etaPoly : Bool := false
|
||||
/--
|
||||
If `inlinePartial` is `true`, we inline partial function applications tagged
|
||||
with `[inline]`. Note that this option is automatically disabled when processing
|
||||
declarations tagged with `[inline]`, marked to be specialized, or instances.
|
||||
-/
|
||||
inlinePartial := false
|
||||
/--
|
||||
If `implementedBy` is `true`, we apply the `implementedBy` replacements.
|
||||
Remark: we only apply `casesOn` replacements at phase 2 because `cases` constructor
|
||||
may not have enough information for reconstructing the original `casesOn` application at
|
||||
phase 1.
|
||||
-/
|
||||
implementedBy := false
|
||||
deriving Inhabited
|
||||
|
||||
65
src/Lean/Compiler/LCNF/Simp/DefaultAlt.lean
Normal file
65
src/Lean/Compiler/LCNF/Simp/DefaultAlt.lean
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.Simp.SimpM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Return the alternative in `alts` whose body appears in most arms,
|
||||
and the number of occurrences.
|
||||
We use this function to decide whether to create a `.default` case
|
||||
or not.
|
||||
-/
|
||||
private def getMaxOccs (alts : Array Alt) : Alt × Nat := Id.run do
|
||||
let mut maxAlt := alts[0]!
|
||||
let mut max := getNumOccsOf alts 0
|
||||
for i in [1:alts.size] do
|
||||
let curr := getNumOccsOf alts i
|
||||
if curr > max then
|
||||
maxAlt := alts[i]!
|
||||
max := curr
|
||||
return (maxAlt, max)
|
||||
where
|
||||
/--
|
||||
Return the number of occurrences of `alts[i]` in `alts`.
|
||||
We use alpha equivalence.
|
||||
Note that the number of occurrences can be greater than 1 only when
|
||||
the alternative does not depend on field parameters
|
||||
-/
|
||||
getNumOccsOf (alts : Array Alt) (i : Nat) : Nat := Id.run do
|
||||
let code := alts[i]!.getCode
|
||||
let mut n := 1
|
||||
for j in [i+1:alts.size] do
|
||||
if Code.alphaEqv alts[j]!.getCode code then
|
||||
n := n+1
|
||||
return n
|
||||
|
||||
/--
|
||||
Add a default case to the given `cases` alternatives if there
|
||||
are alternatives with equivalent (aka alpha equivalent) right hand sides.
|
||||
-/
|
||||
def addDefaultAlt (alts : Array Alt) : SimpM (Array Alt) := do
|
||||
if alts.size <= 1 || alts.any (· matches .default ..) then
|
||||
return alts
|
||||
else
|
||||
let (max, noccs) := getMaxOccs alts
|
||||
if noccs == 1 then
|
||||
return alts
|
||||
else
|
||||
let mut altsNew := #[]
|
||||
let mut first := true
|
||||
markSimplified
|
||||
for alt in alts do
|
||||
if Code.alphaEqv alt.getCode max.getCode then
|
||||
let .alt _ ps k := alt | unreachable!
|
||||
eraseParams ps
|
||||
unless first do
|
||||
eraseCode k
|
||||
first := false
|
||||
else
|
||||
altsNew := altsNew.push alt
|
||||
return altsNew.push (.default max.getCode)
|
||||
126
src/Lean/Compiler/LCNF/Simp/FunDeclInfo.lean
Normal file
126
src/Lean/Compiler/LCNF/Simp/FunDeclInfo.lean
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.Simp.Basic
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Local function usage information used to decide whether it should be inlined or not.
|
||||
The information is an approximation, but it is on the "safe" side. That is, if we tagged
|
||||
a function with `.once`, then it is applied only once. A local function may be marked as
|
||||
`.many`, but after simplifications the number of applications may reduce to 1. This is not
|
||||
a big problem in practice because we run the simplifier multiple times, and this information
|
||||
is recomputed from scratch at the beginning of each simplification step.
|
||||
-/
|
||||
inductive FunDeclInfo where
|
||||
/--
|
||||
Local function is applied once, and must be inlined.
|
||||
-/
|
||||
| once
|
||||
/--
|
||||
Local function is applied many times or occur as an argument of another function,
|
||||
and will only be inlined if it is small.
|
||||
-/
|
||||
| many
|
||||
/--
|
||||
Function must be inlined.
|
||||
-/
|
||||
| mustInline
|
||||
deriving Repr, Inhabited
|
||||
|
||||
/--
|
||||
Local function declaration statistics.
|
||||
-/
|
||||
structure FunDeclInfoMap where
|
||||
/--
|
||||
Mapping from local function name to inlining information.
|
||||
-/
|
||||
map : HashMap FVarId FunDeclInfo := {}
|
||||
deriving Inhabited
|
||||
|
||||
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
|
||||
let mut result := Format.nil
|
||||
for (fvarId, info) in s.map.toList do
|
||||
let binderName ← getBinderName fvarId
|
||||
result := result ++ "\n" ++ f!"{binderName} ↦ {repr info}"
|
||||
return result
|
||||
|
||||
/--
|
||||
Add new occurrence for the local function with binder name `key`.
|
||||
-/
|
||||
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } =>
|
||||
match map.find? fvarId with
|
||||
| some .once => { map := map.insert fvarId .many }
|
||||
| none => { map := map.insert fvarId .once }
|
||||
| _ => { map }
|
||||
|
||||
/--
|
||||
Add new occurrence for the local function occurring as an argument for another function.
|
||||
-/
|
||||
def FunDeclInfoMap.addHo (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } =>
|
||||
match map.find? fvarId with
|
||||
| some .once | none => { map := map.insert fvarId .many }
|
||||
| _ => { map }
|
||||
|
||||
/--
|
||||
Add new occurrence for the local function with binder name `key`.
|
||||
-/
|
||||
def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
|
||||
match s with
|
||||
| { map } => { map := map.insert fvarId .mustInline }
|
||||
|
||||
def FunDeclInfoMap.restore (s : FunDeclInfoMap) (fvarId : FVarId) (saved? : Option FunDeclInfo) : FunDeclInfoMap :=
|
||||
match s, saved? with
|
||||
| { map }, none => { map := map.erase fvarId }
|
||||
| { map }, some saved => { map := map.insert fvarId saved }
|
||||
|
||||
/--
|
||||
Traverse `code` and update function occurrence map.
|
||||
This map is used to decide whether we inline local functions or not.
|
||||
If `mustInline := true`, then all local function declarations occurring in
|
||||
`code` are tagged as `.mustInline`.
|
||||
Recall that we use `.mustInline` for local function declarations occurring in type class instances.
|
||||
-/
|
||||
partial def FunDeclInfoMap.update (s : FunDeclInfoMap) (code : Code) (mustInline := false) : CompilerM FunDeclInfoMap := do
|
||||
let (_, s) ← go code |>.run s
|
||||
return s
|
||||
where
|
||||
addArgOcc (e : Expr) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
let some funDecl ← findFunDecl? e | return ()
|
||||
modify fun s => s.addHo funDecl.fvarId
|
||||
|
||||
addOccs (e : Expr) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
match e with
|
||||
| .app f a => addArgOcc a; addOccs f
|
||||
| .fvar _ =>
|
||||
let some funDecl ← findFunDecl? e | return ()
|
||||
modify fun s => s.add funDecl.fvarId
|
||||
| _ => return ()
|
||||
|
||||
go (code : Code) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
addOccs decl.value
|
||||
go k
|
||||
| .fun decl k =>
|
||||
if mustInline then
|
||||
modify fun s => s.addMustInline decl.fvarId
|
||||
go decl.value; go k
|
||||
| .jp decl k => go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .jmp fvarId args =>
|
||||
let funDecl ← getFunDecl fvarId
|
||||
modify fun s => s.add funDecl.fvarId
|
||||
args.forM addArgOcc
|
||||
| .return .. | .unreach .. => return ()
|
||||
|
||||
end Simp
|
||||
end Lean.Compiler.LCNF
|
||||
80
src/Lean/Compiler/LCNF/Simp/InlineCandidate.lean
Normal file
80
src/Lean/Compiler/LCNF/Simp/InlineCandidate.lean
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.Simp.SimpM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Result of `inlineCandidate?`.
|
||||
It contains information for inlining local and global functions.
|
||||
-/
|
||||
structure InlineCandidateInfo where
|
||||
isLocal : Bool
|
||||
params : Array Param
|
||||
/-- Value (lambda expression) of the function to be inlined. -/
|
||||
value : Code
|
||||
f : Expr
|
||||
args : Array Expr
|
||||
/-- `ifReduce = true` if the declaration being inlined was tagged with `inlineIfReduce`. -/
|
||||
ifReduce : Bool
|
||||
|
||||
/-- The arity (aka number of parameters) of the function to be inlined. -/
|
||||
def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
|
||||
| { params, .. } => params.size
|
||||
|
||||
/--
|
||||
Return `some info` if `e` should be inlined.
|
||||
-/
|
||||
def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
|
||||
let mut e := e
|
||||
let mut mustInline := false
|
||||
if e.isAppOfArity ``inline 2 then
|
||||
e ← findExpr e.appArg!
|
||||
mustInline := true
|
||||
let numArgs := e.getAppNumArgs
|
||||
let f := e.getAppFn
|
||||
if let .const declName us ← findExpr f then
|
||||
if declName == (← read).declName then return none -- TODO: remove after we start storing phase1 code in .olean files
|
||||
let inlineIfReduce := hasInlineIfReduceAttribute (← getEnv) declName
|
||||
unless mustInline || hasInlineAttribute (← getEnv) declName || inlineIfReduce do return none
|
||||
-- TODO: check whether function is recursive or not.
|
||||
-- We can skip the test and store function inline so far.
|
||||
let some decl ← getDecl? declName | return none
|
||||
let arity := decl.getArity
|
||||
let inlinePartial := (← read).config.inlinePartial
|
||||
if !mustInline && !inlinePartial && numArgs < arity then return none
|
||||
if inlineIfReduce then
|
||||
let some paramIdx := decl.isCasesOnParam? | return none
|
||||
unless paramIdx < numArgs do return none
|
||||
let arg ← findExpr (e.getArg! paramIdx)
|
||||
unless arg.isConstructorApp (← getEnv) do return none
|
||||
let params := decl.instantiateParamsLevelParams us
|
||||
let value := decl.instantiateValueLevelParams us
|
||||
incInline
|
||||
return some {
|
||||
isLocal := false
|
||||
f := e.getAppFn
|
||||
args := e.getAppArgs
|
||||
ifReduce := inlineIfReduce
|
||||
params, value
|
||||
}
|
||||
else if let some decl ← findFunDecl? f then
|
||||
unless numArgs > 0 do return none -- It is not worth to inline a local function that does not take any arguments
|
||||
unless mustInline || (← shouldInlineLocal decl) do return none
|
||||
-- Remark: we inline local function declarations even if they are partial applied
|
||||
incInlineLocal
|
||||
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
|
||||
return some {
|
||||
isLocal := true
|
||||
f := e.getAppFn
|
||||
args := e.getAppArgs
|
||||
params := decl.params
|
||||
value := decl.value
|
||||
ifReduce := false
|
||||
}
|
||||
else
|
||||
return none
|
||||
75
src/Lean/Compiler/LCNF/Simp/InlineProj.lean
Normal file
75
src/Lean/Compiler/LCNF/Simp/InlineProj.lean
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.Simp.SimpM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Auxiliary function for projecting "type class dictionary access".
|
||||
That is, we are trying to extract one of the type class instance elements.
|
||||
Remark: We do not consider parent instances to be elements.
|
||||
For example, suppose `e` is `_x_4.1`, and we have
|
||||
```
|
||||
_x_2 : Monad (ReaderT Bool (ExceptT String Id)) := @ReaderT.Monad Bool (ExceptT String Id) _x_1
|
||||
_x_3 : Applicative (ReaderT Bool (ExceptT String Id)) := _x_2.1
|
||||
_x_4 : Functor (ReaderT Bool (ExceptT String Id)) := _x_3.1
|
||||
```
|
||||
Then, we will expand `_x_4.1` since it corresponds to the `Functor` `map` element,
|
||||
and its type is not a type class, but is of the form
|
||||
```
|
||||
{α β : Type u} → (α → β) → ...
|
||||
```
|
||||
In the example above, the compiler should not expand `_x_3.1` or `_x_2.1` because they are
|
||||
type class applications: `Functor` and `Applicative` respectively.
|
||||
By eagerly expanding them, we may produce inefficient and bloated code.
|
||||
For example, we may be using `_x_3.1` to invoke a function that expects a `Functor` instance.
|
||||
By expanding `_x_3.1` we will be just expanding the code that creates this instance.
|
||||
|
||||
The result is representing a sequence of code containing let-declarations and local function declarations (`Array CodeDecl`)
|
||||
and the free variable containing the result (`FVarId`). The resulting `FVarId` often depends only on a small
|
||||
subset of `Array CodeDecl`. However, this method does try to filter the relevant ones.
|
||||
We rely on the `used` var set available in `SimpM` to filter them. See `attachCodeDecls`.
|
||||
-/
|
||||
partial def inlineProjInst? (e : Expr) : SimpM (Option (Array CodeDecl × FVarId)) := do
|
||||
let .proj _ i s := e | return none
|
||||
let sType ← inferType s
|
||||
unless (← isClass? sType).isSome do return none
|
||||
let eType ← inferType e
|
||||
unless (← isClass? eType).isNone do return none
|
||||
let (fvarId?, decls) ← visit s [i] |>.run |>.run #[]
|
||||
if let some fvarId := fvarId? then
|
||||
return some (decls, fvarId)
|
||||
else
|
||||
eraseCodeDecls decls
|
||||
return none
|
||||
where
|
||||
visit (e : Expr) (projs : List Nat) : OptionT (StateRefT (Array CodeDecl) SimpM) FVarId := do
|
||||
let e ← findExpr e
|
||||
if let .proj _ i s := e then
|
||||
visit s (i :: projs)
|
||||
else if let some (ctorVal, ctorArgs) := e.constructorApp? (← getEnv) then
|
||||
let i :: projs := projs | unreachable!
|
||||
let e := ctorArgs[ctorVal.numParams + i]!
|
||||
if projs.isEmpty then
|
||||
let .fvar fvarId := e | unreachable!
|
||||
return fvarId
|
||||
else
|
||||
visit e projs
|
||||
else
|
||||
let .const declName us := e.getAppFn | failure
|
||||
let some decl ← getDecl? declName | failure
|
||||
guard (decl.getArity == e.getAppNumArgs)
|
||||
let code := decl.instantiateValueLevelParams us
|
||||
let code ← betaReduce decl.params code e.getAppArgs (mustInline := true)
|
||||
visitCode code projs
|
||||
|
||||
visitCode (code : Code) (projs : List Nat) : OptionT (StateRefT (Array CodeDecl) SimpM) FVarId := do
|
||||
match code with
|
||||
| .let decl k => modify (·.push (.let decl)); visitCode k projs
|
||||
| .fun decl k => modify (·.push (.fun decl)); visitCode k projs
|
||||
| .return fvarId => visit (.fvar fvarId) projs
|
||||
| _ => failure
|
||||
260
src/Lean/Compiler/LCNF/Simp/JpCases.lean
Normal file
260
src/Lean/Compiler/LCNF/Simp/JpCases.lean
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.DependsOn
|
||||
import Lean.Compiler.LCNF.InferType
|
||||
import Lean.Compiler.LCNF.Simp.Basic
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Given the function declaration `decl`, return `true` if it is of the form
|
||||
```
|
||||
f y :=
|
||||
... /- This part is not bigger than smallThreshold. -/
|
||||
cases y
|
||||
| ... => ...
|
||||
...
|
||||
```
|
||||
-/
|
||||
def isJpCases (decl : FunDecl) (smallThreshold : Nat) : CompilerM Bool := do
|
||||
if decl.params.size != 1 then
|
||||
return false
|
||||
else
|
||||
let param := decl.params[0]!
|
||||
let rec go (code : Code) (prefixSize : Nat) : Bool :=
|
||||
prefixSize <= smallThreshold &&
|
||||
match code with
|
||||
| .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/
|
||||
| .cases c => c.discr == param.fvarId
|
||||
| _ => false
|
||||
return go decl.value 0
|
||||
|
||||
abbrev JpCasesInfo := FVarIdMap NameSet
|
||||
|
||||
/-- Return `true` if the collected information suggests opportunities for the `JpCases` optimization. -/
|
||||
def JpCasesInfo.isCandidate (info : JpCasesInfo) : Bool :=
|
||||
info.any fun _ s => !s.isEmpty
|
||||
|
||||
/--
|
||||
Return a map containing entries `jpFVarId ↦ ctorNames` where `jpFVarId` is the id of join point
|
||||
in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor names such that
|
||||
there is a jump `.jmp jpFVarId #[x]` in `code` and `x` is a constructor application.
|
||||
-/
|
||||
partial def collectJpCasesInfo (code : Code) (smallThreshold : Nat): CompilerM JpCasesInfo := do
|
||||
let (_, s) ← go code |>.run {}
|
||||
return s
|
||||
where
|
||||
go (code : Code) : StateRefT JpCasesInfo CompilerM Unit := do
|
||||
match code with
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value; go k
|
||||
| .jp decl k =>
|
||||
if (← isJpCases decl smallThreshold) then
|
||||
modify fun s => s.insert decl.fvarId {}
|
||||
go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .return .. | .unreach .. => return ()
|
||||
| .jmp fvarId args =>
|
||||
if args.size == 1 then
|
||||
if let some ctorNames := (← get).find? fvarId then
|
||||
let arg ← findExpr args[0]!
|
||||
let some (cval, _) := arg.constructorApp? (← getEnv) | return ()
|
||||
modify fun s => s.insert fvarId <| ctorNames.insert cval.name
|
||||
|
||||
/--
|
||||
Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases`.
|
||||
-/
|
||||
private def extractJpCases (code : Code) : Array CodeDecl × Cases :=
|
||||
go code #[]
|
||||
where
|
||||
go (code : Code) (decls : Array CodeDecl) :=
|
||||
match code with
|
||||
| .let decl k => go k <| decls.push (.let decl)
|
||||
| .cases c => (decls, c)
|
||||
| _ => unreachable! -- `code` is not the body of a join point that satisfies `isJpCases`
|
||||
|
||||
structure JpCasesAlt where
|
||||
decl : FunDecl
|
||||
default : Bool
|
||||
dependsOnDiscr : Bool
|
||||
|
||||
abbrev Ctor2JpCasesAlt := FVarIdMap (NameMap JpCasesAlt)
|
||||
|
||||
open Internalize in
|
||||
private def mkJpAlt (decls : Array CodeDecl) (discr : Param) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do
|
||||
go |>.run' {}
|
||||
where
|
||||
go : InternalizeM JpCasesAlt := do
|
||||
let s : FVarIdSet := {}
|
||||
let mut paramsNew := #[]
|
||||
let dependsOnDiscr := k.dependsOn (s.insert discr.fvarId)
|
||||
if dependsOnDiscr then
|
||||
paramsNew := paramsNew.push (← internalizeParam discr)
|
||||
paramsNew := paramsNew ++ (← fields.mapM internalizeParam)
|
||||
let decls ← decls.mapM internalizeCodeDecl
|
||||
let k ← internalizeCode k
|
||||
let value := LCNF.attachCodeDecls decls k
|
||||
return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr }
|
||||
|
||||
/--
|
||||
Try to optimize `jpCases` join points.
|
||||
We say a join point is a `jpCases` when it satifies the predicate `isJpCases`.
|
||||
If we have a jump to `jpCases` with a constructor, then we can optimize the code by creating an new join point for
|
||||
the constructor.
|
||||
Example: suppose we have
|
||||
```lean
|
||||
jp _jp.1 y :=
|
||||
let x.1 := true
|
||||
cases y
|
||||
| nil => let x.2 := g x.1; return x.2
|
||||
| cons h t => let x.3 := h x.1; return x.3
|
||||
...
|
||||
cases x.4
|
||||
| ctor1 =>
|
||||
let x.5 := cons z.1 z.2
|
||||
jmp _jp.1 x.5
|
||||
| ctor2 =>
|
||||
let x.6 := f x.4
|
||||
jmp _jp.1 x.6
|
||||
```
|
||||
This `simpJpCases?` converts it to
|
||||
```lean
|
||||
jp _jp.2 h t :=
|
||||
let x.1 := true
|
||||
let x.3 := h x.1
|
||||
return x.3
|
||||
jp _jp.1 y :=
|
||||
let x.1 := true
|
||||
cases y
|
||||
| nil => let x.2 := g x.1; return x.2
|
||||
| cons h t => jmp _jp.2 h t
|
||||
...
|
||||
cases x.4
|
||||
| ctor1 =>
|
||||
-- The constructor has been eliminated here
|
||||
jmp _jp.2 z.1 z.2
|
||||
| ctor2 =>
|
||||
let x.6 := f x.4
|
||||
jmp _jp.1 x.6
|
||||
```
|
||||
Note that if all jumps to the join point are with constructors,
|
||||
then the join point is eliminated as dead code.
|
||||
-/
|
||||
partial def simpJpCases? (code : Code) (smallThreshold : Nat) : CompilerM (Option Code) := do
|
||||
let info ← collectJpCasesInfo code smallThreshold
|
||||
unless info.isCandidate do return none
|
||||
traceM `Compiler.simp.jpCases do
|
||||
let mut msg : MessageData := "candidates"
|
||||
for (fvarId, ctorName) in info.toList do
|
||||
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {ctorName.toList}"
|
||||
return msg
|
||||
visit code info |>.run' {}
|
||||
where
|
||||
visit (code : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
return code.updateLet! decl (← visit k)
|
||||
| .fun decl k =>
|
||||
let value ← visit decl.value
|
||||
let decl ← decl.updateValue value
|
||||
return code.updateFun! decl (← visit k)
|
||||
| .jp decl k =>
|
||||
if let some code ← visitJp? decl k then
|
||||
return code
|
||||
else
|
||||
let value ← visit decl.value
|
||||
let decl ← decl.updateValue value
|
||||
return code.updateFun! decl (← visit k)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← visit alt.getCode)
|
||||
return code.updateAlts! alts
|
||||
| .return _ | .unreach _ => return code
|
||||
| .jmp fvarId args =>
|
||||
let some code ← visitJmp? fvarId args | return code
|
||||
return code
|
||||
|
||||
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
let some s := (← read).find? decl.fvarId | return none
|
||||
if s.isEmpty then return none
|
||||
-- This join point satisfies `isJp` and there jumps with constructors in `s` to it.
|
||||
let p := decl.params[0]!
|
||||
let (decls, cases) := extractJpCases decl.value
|
||||
let mut jpAltMap := {}
|
||||
let mut jpAltDecls := #[]
|
||||
let mut altsNew := #[]
|
||||
for alt in cases.alts do
|
||||
match alt with
|
||||
| .default k =>
|
||||
let k ← visit k
|
||||
let explicitCtorNames := cases.getCtorNames
|
||||
if s.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then
|
||||
let jpAlt ← mkJpAlt decls p #[] k (default := true)
|
||||
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
|
||||
eraseCode k
|
||||
for ctorNameInJmp in s do
|
||||
unless explicitCtorNames.contains ctorNameInJmp do
|
||||
jpAltMap := jpAltMap.insert ctorNameInJmp jpAlt
|
||||
let args := if jpAlt.dependsOnDiscr then #[.fvar p.fvarId] else #[]
|
||||
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
|
||||
else
|
||||
altsNew := altsNew.push (alt.updateCode k)
|
||||
| .alt ctorName fields k =>
|
||||
let k ← visit k
|
||||
if s.contains ctorName then
|
||||
let jpAlt ← mkJpAlt decls p fields k (default := false)
|
||||
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
|
||||
jpAltMap := jpAltMap.insert ctorName jpAlt
|
||||
let mut args := fields.map (mkFVar ·.fvarId)
|
||||
if jpAlt.dependsOnDiscr then
|
||||
args := #[mkFVar p.fvarId] ++ args
|
||||
eraseCode k
|
||||
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
|
||||
else
|
||||
altsNew := altsNew.push (alt.updateCode k)
|
||||
modify fun s => s.insert decl.fvarId jpAltMap
|
||||
let value := LCNF.attachCodeDecls decls (.cases { cases with alts := altsNew })
|
||||
let decl ← decl.updateValue value
|
||||
let code := .jp decl (← visit k)
|
||||
return LCNF.attachCodeDecls jpAltDecls code
|
||||
|
||||
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
let some ctorJpAltMap := (← get).find? fvarId | return none
|
||||
assert! args.size == 1
|
||||
let arg ← findExpr args[0]!
|
||||
let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none
|
||||
let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none
|
||||
if jpAlt.default then
|
||||
if jpAlt.dependsOnDiscr then
|
||||
return some <| .jmp jpAlt.decl.fvarId args
|
||||
else
|
||||
return some <| .jmp jpAlt.decl.fvarId #[]
|
||||
else
|
||||
let fields := ctorArgs[ctorVal.numParams:]
|
||||
-- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too.
|
||||
-- We use a for-loop because we may have other special cases in the future.
|
||||
let mut auxDecls := #[]
|
||||
let mut fieldsNew := #[]
|
||||
for field in fields do
|
||||
if field.isFVar then
|
||||
fieldsNew := fieldsNew.push field
|
||||
else
|
||||
let letDecl ← mkAuxLetDecl field
|
||||
auxDecls := auxDecls.push (CodeDecl.let letDecl)
|
||||
fieldsNew := fieldsNew.push (.fvar letDecl.fvarId)
|
||||
let code ← if jpAlt.dependsOnDiscr then
|
||||
pure <| .jmp jpAlt.decl.fvarId (args ++ fieldsNew)
|
||||
else
|
||||
pure <| .jmp jpAlt.decl.fvarId fieldsNew
|
||||
return some <| LCNF.attachCodeDecls auxDecls code
|
||||
|
||||
end Simp
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.simp.jpCases
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
||||
310
src/Lean/Compiler/LCNF/Simp/Main.lean
Normal file
310
src/Lean/Compiler/LCNF/Simp/Main.lean
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.ImplementedByAttr
|
||||
import Lean.Compiler.LCNF.ElimDead
|
||||
import Lean.Compiler.LCNF.AlphaEqv
|
||||
import Lean.Compiler.LCNF.PrettyPrinter
|
||||
import Lean.Compiler.LCNF.Bind
|
||||
import Lean.Compiler.LCNF.Simp.FunDeclInfo
|
||||
import Lean.Compiler.LCNF.Simp.InlineCandidate
|
||||
import Lean.Compiler.LCNF.Simp.InlineProj
|
||||
import Lean.Compiler.LCNF.Simp.Used
|
||||
import Lean.Compiler.LCNF.Simp.DefaultAlt
|
||||
import Lean.Compiler.LCNF.Simp.SimpValue
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Return `true` if `c` has only one exit point.
|
||||
This is a quick approximation. It does not check cases
|
||||
such as: a `cases` with many but only one alternative is not reachable.
|
||||
It is only used to avoid the creation of auxiliary join points, and does not need
|
||||
to be precise.
|
||||
-/
|
||||
private partial def oneExitPointQuick (c : Code) : Bool :=
|
||||
go c
|
||||
where
|
||||
go (c : Code) : Bool :=
|
||||
match c with
|
||||
| .let _ k | .fun _ k => go k
|
||||
-- Approximation, the cases may have many unreachable alternatives, and only reachable.
|
||||
| .cases c => c.alts.size == 1 && go c.alts[0]!.getCode
|
||||
-- Approximation, we assume that any code containing join points have more than one exit point
|
||||
| .jp .. | .jmp .. | .unreach .. => false
|
||||
| .return .. => true
|
||||
|
||||
/--
|
||||
Create a new local function declaration when `info.args.size < info.params.size`.
|
||||
We use this function to inline/specialize a partial application of a local function.
|
||||
-/
|
||||
def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
|
||||
let mut subst := {}
|
||||
for param in info.params, arg in info.args do
|
||||
subst := subst.insert param.fvarId arg
|
||||
let mut paramsNew := #[]
|
||||
for param in info.params[info.args.size:] do
|
||||
let type ← replaceExprFVars param.type subst (translator := true)
|
||||
let paramNew ← mkAuxParam type
|
||||
paramsNew := paramsNew.push paramNew
|
||||
subst := subst.insert param.fvarId (.fvar paramNew.fvarId)
|
||||
let code ← info.value.internalize subst
|
||||
updateFunDeclInfo code
|
||||
mkAuxFunDecl paramsNew code
|
||||
|
||||
/--
|
||||
Try to inline a join point.
|
||||
-/
|
||||
partial def inlineJp? (fvarId : FVarId) (args : Array Expr) : SimpM (Option Code) := do
|
||||
let some decl ← LCNF.findFunDecl? fvarId | return none
|
||||
unless (← shouldInlineLocal decl) do return none
|
||||
markSimplified
|
||||
betaReduce decl.params decl.value args
|
||||
|
||||
/--
|
||||
When the configuration flag `etaPoly = true`, we eta-expand
|
||||
partial applications of functions that take local instances as arguments.
|
||||
This kind of function is inlined or specialized, and we create new
|
||||
simplification opportunities by eta-expanding them.
|
||||
-/
|
||||
def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do
|
||||
guard <| (← read).config.etaPoly
|
||||
let .const declName _ := letDecl.value.getAppFn | failure
|
||||
let some info := (← getEnv).find? declName | failure
|
||||
guard <| hasLocalInst info.type
|
||||
guard <| !(← Meta.isInstance declName)
|
||||
let some decl ← getDecl? declName | failure
|
||||
let numArgs := letDecl.value.getAppNumArgs
|
||||
guard <| decl.getArity > numArgs
|
||||
let params ← mkNewParams letDecl.type
|
||||
let value := mkAppN letDecl.value (params.map (.fvar ·.fvarId))
|
||||
let auxDecl ← mkAuxLetDecl value
|
||||
let funDecl ← mkAuxFunDecl params (.let auxDecl (.return auxDecl.fvarId))
|
||||
addFVarSubst letDecl.fvarId funDecl.fvarId
|
||||
eraseLetDecl letDecl
|
||||
return funDecl
|
||||
|
||||
/--
|
||||
Similar to `Code.isReturnOf`, but taking the current substitution into account.
|
||||
-/
|
||||
def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do
|
||||
match c with
|
||||
| .return fvarId' => return (← normFVar fvarId') == fvarId
|
||||
| _ => return false
|
||||
|
||||
mutual
|
||||
/--
|
||||
If the value of the given let-declaration is an application that can be inlined,
|
||||
inline it and simplify the result.
|
||||
|
||||
`k` is the "continuation" for the let declaration, if the application is inlined,
|
||||
it will also be simplified.
|
||||
|
||||
Note: `inlineApp?` did not use to be in this mutually recursive declaration.
|
||||
It used to be invoked by `simp`, and would return `Option Code` that would be
|
||||
then simplified by `simp`. However, this simpler architecture produced an
|
||||
exponential blowup in when processing functions such as `Lean.Elab.Deriving.Ord.mkMatch.mkAlts`.
|
||||
The key problem is that when inlining a declaration we often can reduce the number
|
||||
of exit points by simplified the inlined code, and then connecting the result to the
|
||||
continuation `k`. However, this optimization is only possible if we simplify the
|
||||
inlined code **before** we attach it to the continuation.
|
||||
-/
|
||||
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
|
||||
let some info ← inlineCandidate? letDecl.value | return none
|
||||
let numArgs := info.args.size
|
||||
withInlining letDecl.value do
|
||||
let fvarId := letDecl.fvarId
|
||||
if numArgs < info.arity then
|
||||
let funDecl ← specializePartialApp info
|
||||
addFVarSubst fvarId funDecl.fvarId
|
||||
markSimplified
|
||||
simp (.fun funDecl k)
|
||||
else
|
||||
let code ← betaReduce info.params info.value info.args[:info.arity]
|
||||
if k.isReturnOf fvarId && numArgs == info.arity then
|
||||
/- Easy case, the continuation `k` is just returning the result of the application. -/
|
||||
markSimplified
|
||||
simp code
|
||||
else
|
||||
let code ← simp code
|
||||
if oneExitPointQuick code then
|
||||
-- TODO: if `k` is small, we should also inline it here
|
||||
markSimplified
|
||||
code.bind fun fvarId' => do
|
||||
markUsedFVar fvarId'
|
||||
/- fvarId' is the result of the computation -/
|
||||
if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
|
||||
addFVarSubst fvarId decl.fvarId
|
||||
simp (.let decl k)
|
||||
else
|
||||
addFVarSubst fvarId fvarId'
|
||||
simp k
|
||||
-- else if info.ifReduce then
|
||||
-- eraseCode code
|
||||
-- return none
|
||||
else
|
||||
markSimplified
|
||||
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
|
||||
let jpValue ← if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
|
||||
addFVarSubst fvarId decl.fvarId
|
||||
simp (.let decl k)
|
||||
else
|
||||
addFVarSubst fvarId jpParam.fvarId
|
||||
simp k
|
||||
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
|
||||
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
return Code.jp jpDecl code
|
||||
|
||||
/--
|
||||
Simplify the given local function declaration.
|
||||
-/
|
||||
partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
|
||||
let type ← normExpr decl.type
|
||||
let params ← normParams decl.params
|
||||
let value ← simp decl.value
|
||||
decl.update type params value
|
||||
|
||||
/--
|
||||
Try to simplify `cases` of `constructor`
|
||||
-/
|
||||
partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
||||
let discr ← normFVar cases.discr
|
||||
let discrExpr ← findCtor (.fvar discr)
|
||||
let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) (useRaw := true) | return none
|
||||
let (alt, cases) := cases.extractAlt! ctorVal.name
|
||||
eraseCode (.cases cases)
|
||||
markSimplified
|
||||
match alt with
|
||||
| .default k => simp k
|
||||
| .alt _ params k =>
|
||||
let fields := ctorArgs[ctorVal.numParams:]
|
||||
let mut auxDecls := #[]
|
||||
for param in params, field in fields do
|
||||
/-
|
||||
`field` may not be a free variable. Recall that `constructorApp?` has special support for numerals,
|
||||
and `ctorArgs` contains a numeral if `discrExpr` is a numeral. We may have other cases in the future.
|
||||
To make the code robust, we add auxiliary declarations whenever the `field` is not a free variable.
|
||||
-/
|
||||
if field.isFVar then
|
||||
addFVarSubst param.fvarId field.fvarId!
|
||||
else
|
||||
let auxDecl ← mkAuxLetDecl field
|
||||
auxDecls := auxDecls.push (CodeDecl.let auxDecl)
|
||||
addFVarSubst param.fvarId auxDecl.fvarId
|
||||
let k ← simp k
|
||||
eraseParams params
|
||||
attachCodeDecls auxDecls k
|
||||
|
||||
/--
|
||||
Simplify `code`
|
||||
-/
|
||||
partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
||||
incVisited
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let mut decl ← normLetDecl decl
|
||||
if let some value ← simpValue? decl.value then
|
||||
decl ← decl.updateValue value
|
||||
if let some funDecl ← etaPolyApp? decl then
|
||||
simp (.fun funDecl k)
|
||||
else if decl.value.isFVar then
|
||||
/- Eliminate `let _x_i := _x_j;` -/
|
||||
addFVarSubst decl.fvarId decl.value.fvarId!
|
||||
eraseLetDecl decl
|
||||
simp k
|
||||
else if let some code ← inlineApp? decl k then
|
||||
eraseLetDecl decl
|
||||
return code
|
||||
else if let some (decls, fvarId) ← inlineProjInst? decl.value then
|
||||
addFVarSubst decl.fvarId fvarId
|
||||
eraseLetDecl decl
|
||||
let k ← simp k
|
||||
attachCodeDecls decls k
|
||||
else
|
||||
let k ← simp k
|
||||
if (← isUsed decl.fvarId) then
|
||||
markUsedLetDecl decl
|
||||
return code.updateLet! decl k
|
||||
else
|
||||
/- Dead variable elimination -/
|
||||
eraseLetDecl decl
|
||||
return k
|
||||
| .fun decl k | .jp decl k =>
|
||||
let mut decl := decl
|
||||
let toBeInlined ← isOnceOrMustInline decl.fvarId
|
||||
if toBeInlined then
|
||||
/-
|
||||
If the declaration will be inlined, it is wasteful to eagerly simplify it.
|
||||
So, we just normalize it (i.e., apply the substitution to it).
|
||||
-/
|
||||
decl ← normFunDecl decl
|
||||
else
|
||||
/-
|
||||
Note that functions in `decl` will be marked as used even if `decl` is not actually used.
|
||||
They will only be deleted in the next pass. TODO: investigate whether this is a problem.
|
||||
-/
|
||||
if code.isFun then
|
||||
if decl.isEtaExpandCandidate then
|
||||
/- We must apply substitution before trying to eta-expand, otherwise `inferType` may fail. -/
|
||||
decl ← normFunDecl decl
|
||||
/-
|
||||
We want to eta-expand **before** trying to simplify local function declaration because
|
||||
eta-expansion creates many optimization opportunities.
|
||||
-/
|
||||
decl ← decl.etaExpand
|
||||
markSimplified
|
||||
decl ← simpFunDecl decl
|
||||
let k ← simp k
|
||||
if (← isUsed decl.fvarId) then
|
||||
if toBeInlined then
|
||||
/-
|
||||
`decl` was supposed to be inlined, but there are still references to it.
|
||||
Thus, we must all variables in `decl` as used. Recall it was not eagerly simplified.
|
||||
-/
|
||||
markUsedFunDecl decl
|
||||
return code.updateFun! decl k
|
||||
else
|
||||
/- Dead function elimination -/
|
||||
eraseFunDecl decl
|
||||
return k
|
||||
| .return fvarId =>
|
||||
let fvarId ← normFVar fvarId
|
||||
markUsedFVar fvarId
|
||||
return code.updateReturn! fvarId
|
||||
| .unreach type =>
|
||||
return code.updateUnreach! (← normExpr type)
|
||||
| .jmp fvarId args =>
|
||||
let fvarId ← normFVar fvarId
|
||||
let args ← normExprs args
|
||||
if let some code ← inlineJp? fvarId args then
|
||||
simp code
|
||||
else
|
||||
markUsedFVar fvarId
|
||||
args.forM markUsedExpr
|
||||
return code.updateJmp! fvarId args
|
||||
| .cases c =>
|
||||
if let some k ← simpCasesOnCtor? c then
|
||||
return k
|
||||
else
|
||||
let simpCasesDefault := do
|
||||
let discr ← normFVar c.discr
|
||||
let resultType ← normExpr c.resultType
|
||||
markUsedFVar discr
|
||||
let alts ← c.alts.mapMonoM fun alt =>
|
||||
match alt with
|
||||
| .alt ctorName ps k =>
|
||||
withDiscrCtor discr ctorName ps do
|
||||
return alt.updateCode (← simp k)
|
||||
| .default k => return alt.updateCode (← simp k)
|
||||
let alts ← addDefaultAlt alts
|
||||
if alts.size == 1 && alts[0]! matches .default .. then
|
||||
return alts[0]!.getCode
|
||||
else
|
||||
return code.updateCases! resultType discr alts
|
||||
simpCasesDefault
|
||||
end
|
||||
239
src/Lean/Compiler/LCNF/Simp/SimpM.lean
Normal file
239
src/Lean/Compiler/LCNF/Simp/SimpM.lean
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.ImplementedByAttr
|
||||
import Lean.Compiler.LCNF.ElimDead
|
||||
import Lean.Compiler.LCNF.AlphaEqv
|
||||
import Lean.Compiler.LCNF.PrettyPrinter
|
||||
import Lean.Compiler.LCNF.Bind
|
||||
import Lean.Compiler.LCNF.Simp.JpCases
|
||||
import Lean.Compiler.LCNF.Simp.FunDeclInfo
|
||||
import Lean.Compiler.LCNF.Simp.Config
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
structure Context where
|
||||
/--
|
||||
Name of the declaration being simplified.
|
||||
We currently use this information because we are generating phase1 declarations on demand,
|
||||
and it may trigger non-termination when trying to access the phase1 declaration.
|
||||
-/
|
||||
declName : Name
|
||||
config : Config := {}
|
||||
discrCtorMap : FVarIdMap Expr := {}
|
||||
/--
|
||||
Stack of global declarations being recursively inlined.
|
||||
-/
|
||||
inlineStack : List Name := []
|
||||
|
||||
structure State where
|
||||
/--
|
||||
Free variable substitution. We use it to implement inlining and removing redundant variables `let _x.i := _x.j`
|
||||
-/
|
||||
subst : FVarSubst := {}
|
||||
/--
|
||||
Track used local declarations to be able to eliminate dead variables.
|
||||
-/
|
||||
used : UsedLocalDecls := {}
|
||||
/--
|
||||
Mapping used to decide whether a local function declaration must be inlined or not.
|
||||
-/
|
||||
funDeclInfoMap : FunDeclInfoMap := {}
|
||||
/--
|
||||
`true` if some simplification was performed in the current simplification pass.
|
||||
-/
|
||||
simplified : Bool := false
|
||||
/--
|
||||
Number of visited `let-declarations` and terminal values.
|
||||
This is a performance counter, and currently has no impact on code generation.
|
||||
-/
|
||||
visited : Nat := 0
|
||||
/--
|
||||
Number of definitions inlined.
|
||||
This is a performance counter.
|
||||
-/
|
||||
inline : Nat := 0
|
||||
/--
|
||||
Number of local functions inlined.
|
||||
This is a performance counter.
|
||||
-/
|
||||
inlineLocal : Nat := 0
|
||||
|
||||
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
instance : MonadFVarSubst SimpM false where
|
||||
getSubst := return (← get).subst
|
||||
|
||||
instance : MonadFVarSubstState SimpM where
|
||||
modifySubst f := modify fun s => { s with subst := f s.subst }
|
||||
|
||||
/--
|
||||
Use `findExpr`, and if the result is a free variable, check whether it is in the map `discrCtorMap`.
|
||||
We use this method when simplifying projections and cases-constructor.
|
||||
-/
|
||||
def findCtor (e : Expr) : SimpM Expr := do
|
||||
let e ← findExpr e
|
||||
let .fvar fvarId := e | return e
|
||||
let some ctor := (← read).discrCtorMap.find? fvarId | return e
|
||||
return ctor
|
||||
|
||||
/--
|
||||
Execute `x` with the information that `discr = ctorName ctorFields`.
|
||||
We use this information to simplify nested cases on the same discriminant.
|
||||
|
||||
Remark: we do not perform the reverse direction at this phase.
|
||||
That is, we do not replace occurrences of `ctorName ctorFields` with `discr`.
|
||||
We wait more type information to be erased.
|
||||
-/
|
||||
def withDiscrCtor (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : SimpM α) : SimpM α := do
|
||||
let ctorInfo ← getConstInfoCtor ctorName
|
||||
let mut ctor := mkConst ctorName
|
||||
for _ in [:ctorInfo.numParams] do
|
||||
ctor := .app ctor erasedExpr -- the parameters are irrelevant for optimizations that use this information
|
||||
for field in ctorFields do
|
||||
ctor := .app ctor (.fvar field.fvarId)
|
||||
withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor }) do
|
||||
x
|
||||
|
||||
/-- Set the `simplified` flag to `true`. -/
|
||||
def markSimplified : SimpM Unit :=
|
||||
modify fun s => { s with simplified := true }
|
||||
|
||||
/-- Increment `visited` performance counter. -/
|
||||
def incVisited : SimpM Unit :=
|
||||
modify fun s => { s with visited := s.visited + 1 }
|
||||
|
||||
/-- Increment `inline` performance counter. It is the number of inlined global declarations. -/
|
||||
def incInline : SimpM Unit :=
|
||||
modify fun s => { s with inline := s.inline + 1 }
|
||||
|
||||
/-- Increment `inlineLocal` performance counter. It is the number of inlined local function and join point declarations. -/
|
||||
def incInlineLocal : SimpM Unit :=
|
||||
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
|
||||
|
||||
/-- Mark the local function declaration or join point with the given id as a "must inline". -/
|
||||
def addMustInline (fvarId : FVarId) : SimpM Unit :=
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline fvarId }
|
||||
|
||||
/-- Add a new occurrence of local function `fvarId`. -/
|
||||
def addFunOcc (fvarId : FVarId) : SimpM Unit :=
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add fvarId }
|
||||
|
||||
/-- Add a new occurrence of local function `fvarId` in argument position . -/
|
||||
def addFunHoOcc (fvarId : FVarId) : SimpM Unit :=
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addHo fvarId }
|
||||
|
||||
@[inheritDoc FunDeclInfoMap.update]
|
||||
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit := do
|
||||
let map ← modifyGet fun s => (s.funDeclInfoMap, { s with funDeclInfoMap := {} })
|
||||
let map ← map.update code mustInline
|
||||
modify fun s => { s with funDeclInfoMap := map }
|
||||
|
||||
/--
|
||||
Execute `x` with an updated `inlineStack`. If `value` is of the form `const ...`, add `const` to the stack.
|
||||
Otherwise, do not change the `inlineStack`.
|
||||
-/
|
||||
@[inline] def withInlining (value : Expr) (x : SimpM α) : SimpM α := do
|
||||
trace[Compiler.simp.inline] "inlining {value}"
|
||||
let f := value.getAppFn
|
||||
let stack := (← read).inlineStack
|
||||
let inlineStack := if let .const declName _ := f then declName :: stack else stack
|
||||
withReader (fun ctx => { ctx with inlineStack }) x
|
||||
|
||||
/--
|
||||
Similar to the default `Lean.withIncRecDepth`, but include the `inlineStack` in the error messsage.
|
||||
-/
|
||||
@[inline] def withIncRecDepth (x : SimpM α) : SimpM α := do
|
||||
let curr ← MonadRecDepth.getRecDepth
|
||||
let max ← MonadRecDepth.getMaxRecDepth
|
||||
if curr == max then
|
||||
throwMaxRecDepth
|
||||
else
|
||||
MonadRecDepth.withRecDepth (curr+1) x
|
||||
where
|
||||
throwMaxRecDepth : SimpM α := do
|
||||
match (← read).inlineStack with
|
||||
| [] => throwError maxRecDepthErrorMessage
|
||||
| declName :: stack =>
|
||||
let mut fmt := f!"{declName}\n"
|
||||
let mut prev := declName
|
||||
let mut ellipsis := false
|
||||
for declName in stack do
|
||||
if prev == declName then
|
||||
unless ellipsis do
|
||||
ellipsis := true
|
||||
fmt := fmt ++ "...\n"
|
||||
else
|
||||
fmt := fmt ++ f!"{declName}\n"
|
||||
prev := declName
|
||||
ellipsis := false
|
||||
throwError "maximum recursion depth reached in the code generator\nfunction inline stack:\n{fmt}"
|
||||
|
||||
/--
|
||||
Execute `x` with `fvarId` set as `mustInline`.
|
||||
After execution the original setting is restored.
|
||||
-/
|
||||
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
|
||||
let saved? := (← get).funDeclInfoMap.map.find? fvarId
|
||||
try
|
||||
addMustInline fvarId
|
||||
x
|
||||
finally
|
||||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.restore fvarId saved? }
|
||||
|
||||
/--
|
||||
Return true if the given local function declaration or join point id is marked as
|
||||
`once` or `mustInline`. We use this information to decide whether to inline them.
|
||||
-/
|
||||
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
|
||||
match (← get).funDeclInfoMap.map.find? fvarId with
|
||||
| some .once | some .mustInline => return true
|
||||
| _ => return false
|
||||
|
||||
/--
|
||||
Return `true` if the given local function declaration is considered "small".
|
||||
-/
|
||||
def isSmall (decl : FunDecl) : SimpM Bool :=
|
||||
return decl.value.sizeLe (← read).config.smallThreshold
|
||||
|
||||
/--
|
||||
Return `true` if the given local function declaration should be inlined.
|
||||
-/
|
||||
def shouldInlineLocal (decl : FunDecl) : SimpM Bool := do
|
||||
if (← isOnceOrMustInline decl.fvarId) then
|
||||
return true
|
||||
else
|
||||
isSmall decl
|
||||
|
||||
/--
|
||||
"Beta-reduce" `(fun params => code) args`.
|
||||
If `mustInline` is true, the local function declarations in the resulting code are marked as `.mustInline`.
|
||||
See comment at `updateFunDeclInfo`.
|
||||
-/
|
||||
def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInline := false) : SimpM Code := do
|
||||
-- TODO: add necessary casts to `args`
|
||||
let mut subst := {}
|
||||
for param in params, arg in args do
|
||||
subst := subst.insert param.fvarId arg
|
||||
let code ← code.internalize subst
|
||||
updateFunDeclInfo code mustInline
|
||||
return code
|
||||
|
||||
/--
|
||||
Erase the given let-declaration from the local context,
|
||||
and set the `simplified` flag to true.
|
||||
-/
|
||||
def eraseLetDecl (decl : LetDecl) : SimpM Unit := do
|
||||
LCNF.eraseLetDecl decl
|
||||
markSimplified
|
||||
|
||||
/--
|
||||
Erase the given local function declaration from the local context,
|
||||
and set the `simplified` flag to true.
|
||||
-/
|
||||
def eraseFunDecl (decl : FunDecl) : SimpM Unit := do
|
||||
LCNF.eraseFunDecl decl
|
||||
markSimplified
|
||||
48
src/Lean/Compiler/LCNF/Simp/SimpValue.lean
Normal file
48
src/Lean/Compiler/LCNF/Simp/SimpValue.lean
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.Simp.SimpM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Try to simplify projections `.proj _ i s` where `s` is constructor.
|
||||
-/
|
||||
def simpProj? (e : Expr) : OptionT SimpM Expr := do
|
||||
let .proj _ i s := e | failure
|
||||
let s ← findCtor s
|
||||
let some (ctorVal, args) := s.constructorApp? (← getEnv) | failure
|
||||
markSimplified
|
||||
return args[ctorVal.numParams + i]!
|
||||
|
||||
/--
|
||||
Application over application.
|
||||
```
|
||||
let _x.i := f a
|
||||
_x.i b
|
||||
```
|
||||
is simplified to `f a b`.
|
||||
-/
|
||||
def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
|
||||
guard e.isApp
|
||||
let f := e.getAppFn
|
||||
guard f.isFVar
|
||||
let f ← findExpr f
|
||||
guard <| f.isApp || f.isConst
|
||||
markSimplified
|
||||
return mkAppN f e.getAppArgs
|
||||
|
||||
def applyImplementedBy? (e : Expr) : OptionT SimpM Expr := do
|
||||
guard <| (← read).config.implementedBy
|
||||
let .const declName us := e.getAppFn | failure
|
||||
let some declNameNew := getImplementedBy? (← getEnv) declName | failure
|
||||
markSimplified
|
||||
return mkAppN (.const declNameNew us) e.getAppArgs
|
||||
|
||||
/-- Try to apply simple simplifications. -/
|
||||
def simpValue? (e : Expr) : SimpM (Option Expr) :=
|
||||
-- TODO: more simplifications
|
||||
simpProj? e <|> simpAppApp? e <|> applyImplementedBy? e
|
||||
82
src/Lean/Compiler/LCNF/Simp/Used.lean
Normal file
82
src/Lean/Compiler/LCNF/Simp/Used.lean
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Compiler.LCNF.Simp.SimpM
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
/--
|
||||
Mark `fvarId` as an used free variable.
|
||||
This is information is used to eliminate dead local declarations.
|
||||
-/
|
||||
def markUsedFVar (fvarId : FVarId) : SimpM Unit :=
|
||||
modify fun s => { s with used := s.used.insert fvarId }
|
||||
|
||||
/--
|
||||
Mark all free variables occurring in `e` as used.
|
||||
This is information is used to eliminate dead local declarations.
|
||||
-/
|
||||
def markUsedExpr (e : Expr) : SimpM Unit :=
|
||||
modify fun s => { s with used := collectLocalDecls s.used e }
|
||||
|
||||
/--
|
||||
Mark all free variables occurring on the right-hand side of the given let declaration as used.
|
||||
This is information is used to eliminate dead local declarations.
|
||||
-/
|
||||
def markUsedLetDecl (letDecl : LetDecl) : SimpM Unit :=
|
||||
markUsedExpr letDecl.value
|
||||
|
||||
mutual
|
||||
/--
|
||||
Mark all free variables occurring in `code` as used.
|
||||
-/
|
||||
partial def markUsedCode (code : Code) : SimpM Unit := do
|
||||
match code with
|
||||
| .let decl k => markUsedLetDecl decl; markUsedCode k
|
||||
| .jp decl k | .fun decl k => markUsedFunDecl decl; markUsedCode k
|
||||
| .return fvarId => markUsedFVar fvarId
|
||||
| .unreach .. => return ()
|
||||
| .jmp fvarId args => markUsedFVar fvarId; args.forM markUsedExpr
|
||||
| .cases c => markUsedFVar c.discr; c.alts.forM fun alt => markUsedCode alt.getCode
|
||||
|
||||
/--
|
||||
Mark all free variables occurring in `funDecl` as used.
|
||||
-/
|
||||
partial def markUsedFunDecl (funDecl : FunDecl) : SimpM Unit :=
|
||||
markUsedCode funDecl.value
|
||||
end
|
||||
|
||||
/--
|
||||
Return `true` if `fvarId` is in the `used` set.
|
||||
-/
|
||||
def isUsed (fvarId : FVarId) : SimpM Bool :=
|
||||
return (← get).used.contains fvarId
|
||||
|
||||
/--
|
||||
Attach the given `decls` to `code`. For example, assume `decls` is `#[.let _x.1 := 10, .let _x.2 := true]`,
|
||||
then the result is
|
||||
```
|
||||
let _x.1 := 10
|
||||
let _x.2 := true
|
||||
<code>
|
||||
```
|
||||
-/
|
||||
def attachCodeDecls (decls : Array CodeDecl) (code : Code) : SimpM Code := do
|
||||
go decls.size code
|
||||
where
|
||||
go (i : Nat) (code : Code) : SimpM Code := do
|
||||
if i > 0 then
|
||||
let decl := decls[i-1]!
|
||||
if (← isUsed decl.fvarId) then
|
||||
match decl with
|
||||
| .let decl => markUsedLetDecl decl; go (i-1) (.let decl code)
|
||||
| .fun decl => markUsedFunDecl decl; go (i-1) (.fun decl code)
|
||||
| .jp decl => markUsedFunDecl decl; go (i-1) (.jp decl code)
|
||||
else
|
||||
eraseCodeDecl decl
|
||||
go (i-1) code
|
||||
else
|
||||
return code
|
||||
|
|
@ -408,7 +408,7 @@ mutual
|
|||
end
|
||||
|
||||
def main (decl : Decl) : SpecializeM Decl := do
|
||||
if (← isTemplateLike decl) then
|
||||
if (← decl.isTemplateLike) then
|
||||
return decl
|
||||
else
|
||||
let value ← withParams decl.params <| visitCode decl.value
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue