lean4-htt/src/Lean/Compiler/LCNF/Simp.lean
2022-09-13 18:23:42 -07:00

1056 lines
36 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.Recognizers
import Lean.Meta.Instances
import Lean.Compiler.InlineAttrs
import Lean.Compiler.Specialize
import Lean.Compiler.ImplementedByAttr
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.ElimDead
import Lean.Compiler.LCNF.Bind
import Lean.Compiler.LCNF.PrettyPrinter
import Lean.Compiler.LCNF.Stage1
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.AlphaEqv
namespace Lean.Compiler.LCNF
/--
Return `true` if some binder along the chain of nested `forallE` codomains of `type`
is instance implicit. Binder domains are not inspected; any non-`forallE` expression
yields `false`.
-/
def hasLocalInst (type : Expr) : Bool :=
  if let .forallE _ _ body binderInfo := type then
    binderInfo.isInstImplicit || hasLocalInst body
  else
    false
/--
Return `true` if `decl` is "template like", i.e., it is supposed to be inlined/specialized.
This is the case when its type has an instance implicit binder, when it is an instance,
or when it is tagged with the inline or specialize attribute.
-/
def isTemplateLike (decl : Decl) : CoreM Bool := do
  -- Applications of `decl` will be specialized.
  if hasLocalInst decl.type then return true
  -- `decl` is "fuel" for code specialization.
  if (← Meta.isInstance decl.name) then return true
  -- Otherwise, `decl` is template like only if it is going to be inlined or specialized.
  let env ← getEnv
  return hasInlineAttribute env decl.name || hasSpecializeAttribute env decl.name
namespace Simp
/--
Local function usage information used to decide whether it should be inlined or not.
The information is an approximation, but it is on the "safe" side. That is, if we tagged
a function with `.once`, then it is applied only once. A local function may be marked as
`.many`, but after simplifications the number of applications may be reduced to 1. This is not
a big problem in practice because we run the simplifier multiple times, and this information
is recomputed from scratch at the beginning of each simplification step.
-/
inductive FunDeclInfo where
| /--
The local function is applied once, and must be inlined.
-/
once
| /--
The local function is applied many times, and will only be inlined
if it is small.
-/
many
| /--
The function must be inlined, independently of the number of occurrences.
Once set (see `FunDeclInfoMap.addMustInline`), new occurrences do not downgrade this tag.
-/
mustInline
deriving Repr, Inhabited
/--
Local function declaration statistics.
-/
structure FunDeclInfoMap where
/--
Mapping from the free variable id (`FVarId`) of a local function declaration or join point
to its inlining information. Ids that are not in the map have no recorded occurrence.
-/
map : Std.HashMap FVarId FunDeclInfo := {}
deriving Inhabited
/--
Pretty print the map, one `userName ↦ info` entry per line (each entry is preceded by a line break).
The user-facing name of each entry is retrieved from the local context using `getLocalDecl`.
-/
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format :=
  s.map.toList.foldlM
    (fun (result : Format) (fvarId, info) => do
      let localDecl ← getLocalDecl fvarId
      return result ++ "\n" ++ f!"{localDecl.userName} ↦ {repr info}")
    Format.nil
/--
Record a new occurrence of the local function `fvarId`.
- No entry yet: the function is tagged as `.once`.
- Tagged as `.once`: it is promoted to `.many`.
- Tagged as `.many` or `.mustInline`: the map is left unchanged.
-/
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
  -- We destructure `s` before updating the map, exactly as the pattern-matching formulation does.
  let ⟨map⟩ := s
  match map.find? fvarId with
  | none       => ⟨map.insert fvarId .once⟩
  | some .once => ⟨map.insert fvarId .many⟩
  | some _     => ⟨map⟩
/--
Tag the local function `fvarId` as `.mustInline`, overwriting any information
previously stored for it.
-/
def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
  let ⟨map⟩ := s
  ⟨map.insert fvarId .mustInline⟩
/--
Restore the entry for `fvarId` to a previously saved state.
If `saved?` is `some info`, the entry is set to `info`; if it is `none`, the entry is removed.
-/
def FunDeclInfoMap.restore (s : FunDeclInfoMap) (fvarId : FVarId) (saved? : Option FunDeclInfo) : FunDeclInfoMap :=
  let ⟨map⟩ := s
  match saved? with
  | some saved => ⟨map.insert fvarId saved⟩
  | none       => ⟨map.erase fvarId⟩
/--
Try to find the local function declaration `e` refers to.
`mdata` wrappers are skipped, and free variables that are not local functions but are
let-declared (`.ldecl`) are followed through their values. Return `none` otherwise.
-/
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
  match e with
  | .mdata _ e => findFunDecl? e
  | .fvar fvarId =>
    match (← LCNF.findFunDecl? fvarId) with
    | some decl => return some decl
    | none =>
      -- Not a local function: follow the value if `fvarId` is a let-declaration.
      let .ldecl (value := v) .. ← getLocalDecl fvarId | return none
      findFunDecl? v
  | _ => return none
/--
Follow let-declared free variables until an expression that is not a let-declared
free variable is found, and return it.

If `skipMData` is `true` (the default), `mdata` wrappers are skipped too. If it is `false`,
the search stops at the first `mdata` node, and that node is returned.

The flag is propagated to every recursive call. In particular, `skipMData := false`
is still respected after following a let-declared free variable.
-/
partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
  match e with
  | .fvar fvarId =>
    -- Free variables that are not let-declarations (e.g., parameters) are returned as is.
    let .ldecl (value := v) .. ← getLocalDecl fvarId | return e
    findExpr v skipMData
  | .mdata _ e' => if skipMData then findExpr e' skipMData else return e
  | _ => return e
/-- Configuration options for the simplifier. -/
structure Config where
/--
Any local function declaration or join point with size `≤ smallThreshold` is inlined
even if there are multiple occurrences.
We currently don't do the same for global declarations because we are not saving
the stage1 compilation result in .olean files yet.
-/
smallThreshold : Nat := 1
/--
If `etaPoly` is true, we eta expand any global function application when
the function takes local instances. The idea is that we do not generate code
for this kind of application, and we want all of them to be specialized or inlined.
-/
etaPoly : Bool := false
/--
If `inlinePartial` is `true`, we inline partial function applications tagged
with `[inline]`. Note that this option is automatically disabled when processing
declarations tagged with `[inline]`, marked to be specialized, or instances.
-/
inlinePartial := false
/--
If `implementedBy` is `true`, we apply the `implementedBy` replacements.
Remark: we only apply `casesOn` replacements at phase 2 because the `cases` constructor
may not have enough information for reconstructing the original `casesOn` application at
phase 1.
-/
implementedBy := false
deriving Inhabited
/-- Read-only context of the simplifier monad. -/
structure Context where
/--
Name of the declaration being simplified.
We currently use this information because we are generating phase1 declarations on demand,
and it may trigger non-termination when trying to access the phase1 declaration.
-/
declName : Name
/-- Simplifier configuration options. -/
config : Config := {}
/--
Mapping from a `cases` discriminant to the constructor application it is known to be equal to
in the alternative being visited. It is extended by `withDiscrCtor` and consulted by `findCtor`.
-/
discrCtorMap : FVarIdMap Expr := {}
/-- Mutable state of the simplifier monad. -/
structure State where
/--
Free variable substitution. We use it to implement inlining and to remove redundant variables `let _x.i := _x.j`.
-/
subst : FVarSubst := {}
/--
Track used local declarations to be able to eliminate dead variables.
-/
used : UsedLocalDecls := {}
/--
Mapping used to decide whether a local function declaration must be inlined or not.
-/
funDeclInfoMap : FunDeclInfoMap := {}
/--
`true` if some simplification was performed in the current simplification pass.
-/
simplified : Bool := false
/--
Number of visited `let-declarations` and terminal values.
This is a performance counter, and currently has no impact on code generation.
-/
visited : Nat := 0
/--
Number of definitions inlined.
This is a performance counter.
-/
inline : Nat := 0
/--
Number of local functions inlined.
This is a performance counter.
-/
inlineLocal : Nat := 0
/-- The simplifier monad: `CompilerM` extended with a read-only `Context` and a mutable `State`. -/
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
/-- The free variable substitution of `SimpM` is the field `subst` of the simplifier state. -/
instance : MonadFVarSubst SimpM where
getSubst := return (← get).subst
/--
Resolve `e` using `findExpr`. If the result is a free variable occurring in `discrCtorMap`,
return the constructor application stored there; otherwise, return the `findExpr` result.
This method is used when simplifying projections and cases-constructor.
-/
def findCtor (e : Expr) : SimpM Expr := do
  let e ← findExpr e
  match e with
  | .fvar fvarId =>
    match (← read).discrCtorMap.find? fvarId with
    | some ctor => return ctor
    | none => return e
  | _ => return e
/--
Execute `x` knowing that `discr = ctorName ctorFields`.
This information is used to simplify nested cases on the same discriminant.
Remark: the reverse direction is not performed at this phase.
That is, occurrences of `ctorName ctorFields` are not replaced with `discr`.
We wait for more type information to be erased.
-/
def withDiscrCtor (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : SimpM α) : SimpM α := do
  let ctorInfo ← getConstInfoCtor ctorName
  let mut ctor := mkConst ctorName
  -- The constructor parameters are irrelevant for the optimizations that use this information.
  for _ in [:ctorInfo.numParams] do
    ctor := mkApp ctor erasedExpr
  -- The fields are the free variables bound by the `cases` alternative.
  ctor := mkAppN ctor (ctorFields.map (.fvar ·.fvarId))
  withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor }) x
/-- Record that some simplification was performed, i.e., set the `simplified` flag to `true`. -/
def markSimplified : SimpM Unit :=
  modify ({ · with simplified := true })
/-- Bump the `visited` performance counter by one. -/
def incVisited : SimpM Unit :=
  modify fun st => { st with visited := st.visited + 1 }
/-- Bump the `inline` performance counter by one. It counts the inlined global declarations. -/
def incInline : SimpM Unit :=
  modify fun st => { st with inline := st.inline + 1 }
/--
Bump the `inlineLocal` performance counter by one.
It counts the inlined local function and join point declarations.
-/
def incInlineLocal : SimpM Unit :=
  modify fun st => { st with inlineLocal := st.inlineLocal + 1 }
/-- Tag the local function declaration or join point `fvarId` as "must inline" in the simplifier state. -/
def addMustInline (fvarId : FVarId) : SimpM Unit :=
  modify fun st => { st with funDeclInfoMap := st.funDeclInfoMap.addMustInline fvarId }
/-- Record one more occurrence of the local function `fvarId` in the simplifier state. -/
def addFunOcc (fvarId : FVarId) : SimpM Unit :=
  modify fun st => { st with funDeclInfoMap := st.funDeclInfoMap.add fvarId }
/--
Traverse `code` and update the function occurrence map.
The map is used to decide whether local functions are inlined or not.
An occurrence is recorded for each let-declaration whose value is an application of a
local function, and for each jump.
If `mustInline := true`, all local function declarations (`.fun`) occurring in
`code` are tagged as `.mustInline`.
Recall that `.mustInline` is used for local function declarations occurring in type class instances.
-/
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :=
  go code
where
  go (code : Code) : SimpM Unit := do
    match code with
    | .return .. | .unreach .. =>
      return ()
    | .jmp fvarId .. =>
      addFunOcc (← getFunDecl fvarId).fvarId
    | .cases c =>
      for alt in c.alts do
        go alt.getCode
    | .jp decl k =>
      go decl.value
      go k
    | .fun decl k =>
      if mustInline then
        addMustInline decl.fvarId
      go decl.value
      go k
    | .let decl k =>
      if decl.value.isApp then
        if let some funDecl ← findFunDecl? decl.value.getAppFn then
          addFunOcc funDecl.fvarId
      go k
/--
Execute `x` with `fvarId` temporarily tagged as `mustInline`.
The previous entry for `fvarId` (or its absence) is restored afterwards, even if `x` throws.
-/
def withAddMustInline (fvarId : FVarId) (x : SimpM α) : SimpM α := do
  let saved? := (← get).funDeclInfoMap.map.find? fvarId
  let restoreSaved : SimpM Unit :=
    modify fun st => { st with funDeclInfoMap := st.funDeclInfoMap.restore fvarId saved? }
  try
    addMustInline fvarId
    x
  finally
    restoreSaved
/--
Return `true` if the local function declaration or join point `fvarId` is tagged as
`once` or `mustInline`. This information is used to decide whether to inline it.
Ids without an entry yield `false`.
-/
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
  let info? := (← get).funDeclInfoMap.map.find? fvarId
  return (info? matches some .once | some .mustInline)
/--
Return `true` if the body of the given local function declaration is "small", i.e.,
its size is at most `config.smallThreshold`.
-/
def isSmall (decl : FunDecl) : SimpM Bool := do
  let threshold := (← read).config.smallThreshold
  return decl.value.sizeLe threshold
/--
Return `true` if the given local function declaration should be inlined:
it is tagged as `once`/`mustInline`, or it is small.
The size test is only performed when the first test fails.
-/
def shouldInlineLocal (decl : FunDecl) : SimpM Bool :=
  isOnceOrMustInline decl.fvarId <||> isSmall decl
/--
Result of `inlineCandidate?`.
It contains information for inlining local and global functions.
-/
structure InlineCandidateInfo where
/-- `true` if the function to be inlined is a local function declaration, and `false` if it is a global declaration. -/
isLocal : Bool
/-- Parameters of the function to be inlined. -/
params : Array Param
/-- Body (`Code`) of the function to be inlined. -/
value : Code
/-- Head symbol of the application being inlined. -/
f : Expr
/-- Arguments of the application being inlined. There may be fewer or more arguments than `params`. -/
args : Array Expr
/-- Number of parameters (aka arity) of the function to be inlined. -/
def InlineCandidateInfo.arity (info : InlineCandidateInfo) : Nat :=
  info.params.size
/--
Return `some i` if `decl` is of the form
```
def f (a_0 ... a_i ...) :=
  ...
  cases a_i
  | ...
  | ...
```
That is, the body of `f` is a sequence of declarations followed by a `cases` whose
discriminant is the parameter with index `i`.
This function is used to decide whether a declaration tagged with
`[inlineIfReduce]` should be inlined or not.
-/
def isCasesOnParam? (decl : Decl) : Option Nat :=
  go decl.value
where
  go : Code → Option Nat
    -- Skip the leading declarations.
    | .let _ k | .jp _ k | .fun _ k => go k
    | .cases c => decl.params.findIdx? (·.fvarId == c.discr)
    | _ => none
/--
Return `some info` if `e` should be inlined.

If `e` is an application of `inline` with two arguments, the candidate is the expression
its last argument resolves to (using `findExpr`), and it is inlined unconditionally.

Two kinds of candidates are supported:
- Global declarations (the head symbol resolves to a constant). They must be tagged with
  `[inline]` or `[inlineIfReduce]` (or be wrapped with `inline`), and their stage1 declaration
  must be available. Partial applications are only inlined if `config.inlinePartial` is set
  (or `inline` was used). Declarations tagged with `[inlineIfReduce]` are only inlined when
  the argument being matched on is a constructor application.
- Local function declarations. They must be applied to at least one argument and satisfy
  `shouldInlineLocal` (or be wrapped with `inline`).

The performance counters `inline`/`inlineLocal` are incremented exactly once per candidate found.
-/
def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
  let mut e := e
  let mut mustInline := false
  if e.isAppOfArity ``inline 2 then
    e ← findExpr e.appArg!
    mustInline := true
  let numArgs := e.getAppNumArgs
  let f := e.getAppFn
  if let .const declName us ← findExpr f then
    if declName == (← read).declName then return none -- TODO: remove after we start storing phase1 code in .olean files
    let inlineIfReduce := hasInlineIfReduceAttribute (← getEnv) declName
    unless mustInline || hasInlineAttribute (← getEnv) declName || inlineIfReduce do return none
    -- TODO: check whether function is recursive or not.
    -- We can skip the test and store function inline so far.
    let some decl ← getStage1Decl? declName | return none
    let arity := decl.getArity
    let inlinePartial := (← read).config.inlinePartial
    if !mustInline && !inlinePartial && numArgs < arity then return none
    if inlineIfReduce then
      -- Only inline if the parameter being matched on is instantiated with a constructor application.
      let some paramIdx := isCasesOnParam? decl | return none
      unless paramIdx < numArgs do return none
      let arg ← findCtor (e.getArg! paramIdx)
      unless arg.isConstructorApp (← getEnv) do return none
    let params := decl.instantiateParamsLevelParams us
    let value := decl.instantiateValueLevelParams us
    incInline
    return some {
      isLocal := false
      f       := e.getAppFn
      args    := e.getAppArgs
      params, value
    }
  else if let some decl ← findFunDecl? f then
    unless numArgs > 0 do return none -- It is not worth to inline a local function that does not take any arguments
    unless mustInline || (← shouldInlineLocal decl) do return none
    -- Remark: we inline local function declarations even if they are partially applied.
    -- `incInlineLocal` is the only update of the `inlineLocal` counter here; the counter must not be bumped twice.
    incInlineLocal
    return some {
      isLocal := true
      f       := e.getAppFn
      args    := e.getAppArgs
      params  := decl.params
      value   := decl.value
    }
  else
    return none
/--
Extend the substitution with `fvarId ↦ val`, where `val` is a free variable,
a type, a type former, or `lcErased`.
-/
def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit :=
  modify fun st => { st with subst := st.subst.insert fvarId val }
/--
Return `true` if `c` has only one exit point.
This is a quick approximation. For example, it does not detect a `cases` with many
alternatives where all but one are unreachable.
It is only used to avoid the creation of auxiliary join points, and does not need
to be precise.
-/
private partial def oneExitPointQuick (c : Code) : Bool :=
  go c
where
  go (c : Code) : Bool :=
    match c with
    | .return .. | .unreach .. => true
    -- Approximation: any code containing join points or jumps is assumed to have more than one exit point.
    | .jp .. | .jmp .. => false
    | .let _ k | .fun _ k => go k
    -- Approximation: a `cases` may have many alternatives where only one of them is reachable.
    | .cases cs => cs.alts.size == 1 && cs.alts.any (go ·.getCode)
/--
"Beta-reduce" `(fun params => code) args`.
The parameters are replaced with the arguments while internalizing `code`, and
the function occurrence information is updated with the resulting code.
If `mustInline` is true, the local function declarations in the resulting code are marked as `.mustInline`.
See comment at `updateFunDeclInfo`.
-/
def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInline := false) : SimpM Code := do
  -- TODO: add necessary casts to `args`
  let mut binding := {}
  for p in params, a in args do
    binding := binding.insert p.fvarId a
  let body ← code.internalize binding
  updateFunDeclInfo body mustInline
  return body
/--
Create a new local function declaration when `info.args.size < info.params.size`.
The supplied arguments are substituted for the leading parameters, and fresh parameters
are created for the remaining ones.
This function is used to inline/specialize a partial application of a local function.
-/
def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
  let numSupplied := info.args.size
  let mut subst := {}
  for param in info.params, arg in info.args do
    subst := subst.insert param.fvarId arg
  let mut freshParams := #[]
  for param in info.params[numSupplied:] do
    -- The type of a remaining parameter may depend on the parameters already processed.
    let paramType ← replaceExprFVars param.type subst
    let fresh ← mkAuxParam paramType
    freshParams := freshParams.push fresh
    subst := subst.insert param.fvarId (.fvar fresh.fvarId)
  let body ← info.value.internalize subst
  updateFunDeclInfo body
  mkAuxFunDecl freshParams body
/--
If the value of the given let-declaration is an application that can be inlined, inline it.
`k` is the "continuation" for the let declaration.

Return `none` if the value is not an inline candidate (see `inlineCandidate?`).
Otherwise, return the code that replaces `let letDecl; k`. Three scenarios are handled:
- Partial application: a specialized local function is created (see `specializePartialApp`).
- The inlined body has a single exit point (or `k` just returns the result): `k` is attached directly.
- Otherwise, `k` is wrapped in an auxiliary join point, and each exit point jumps to it.
-/
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
-- If the continuation is `.unreach`, the continuation itself is returned, and the let-declaration is dropped.
if k matches .unreach .. then return some k
let some info ← inlineCandidate? letDecl.value | return none
markSimplified
let numArgs := info.args.size
trace[Compiler.simp.inline] "inlining {letDecl.value}"
let fvarId := letDecl.fvarId
if numArgs < info.arity then
-- Partial application: `letDecl.fvarId` is replaced with the new specialized local function.
let funDecl ← specializePartialApp info
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
return some (.fun funDecl k)
else
-- Only the first `info.arity` arguments are consumed here. Extra arguments (over-application) are handled below.
let code ← betaReduce info.params info.value info.args[:info.arity]
if k.isReturnOf fvarId && numArgs == info.arity then
/- Easy case, the continuation `k` is just returning the result of the application. -/
return code
else if oneExitPointQuick code then
/-
`code` has only one exit point, thus we can attach the continuation directly there,
and simplify the result.
-/
code.bind fun fvarId' => do
/- fvarId' is the result of the computation -/
if numArgs > info.arity then
-- Over-application: apply the result to the remaining arguments.
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
let k ← replaceFVar k fvarId decl.fvarId
return .let decl k
else
replaceFVar k fvarId fvarId'
else
/-
`code` has multiple exit points, and the continuation is non-trivial.
Thus, we create an auxiliary join point.
-/
-- The join point parameter has the type of the saturated application.
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
let jpValue ← if numArgs > info.arity then
-- Over-application: apply the join point parameter to the remaining arguments.
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
let k ← replaceFVar k fvarId decl.fvarId
pure <| .let decl k
else
replaceFVar k fvarId jpParam.fvarId
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
-- Each exit point of `code` jumps to the new join point with its result.
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
return Code.jp jpDecl code
/--
Try to inline the jump `jmp fvarId args`.
Return `none` if there is no declaration for `fvarId` or if it should not
be inlined (see `shouldInlineLocal`). Otherwise, return the beta-reduced body.
-/
partial def inlineJp? (fvarId : FVarId) (args : Array Expr) : SimpM (Option Code) := do
  let some decl ← LCNF.findFunDecl? fvarId | return none
  if (← shouldInlineLocal decl) then
    markSimplified
    return some (← betaReduce decl.params decl.value args)
  else
    return none
/--
Mark `fvarId` as a used free variable.
This information is used to eliminate dead local declarations.
-/
def markUsedFVar (fvarId : FVarId) : SimpM Unit :=
  modify fun st => { st with used := st.used.insert fvarId }
/--
Mark all free variables occurring in `e` as used.
This information is used to eliminate dead local declarations.
-/
def markUsedExpr (e : Expr) : SimpM Unit :=
  modify fun st => { st with used := collectLocalDecls st.used e }
/--
Mark all free variables occurring in the value (right-hand side) of the given let-declaration as used.
This information is used to eliminate dead local declarations.
-/
def markUsedLetDecl (letDecl : LetDecl) : SimpM Unit := do
  markUsedExpr letDecl.value
mutual
/--
Mark all free variables occurring in `code` as used: let-declaration values,
nested function/join point bodies, returned variables, jump targets and their arguments,
and `cases` discriminants.
-/
partial def markUsedCode (code : Code) : SimpM Unit := do
  match code with
  | .unreach .. =>
    return ()
  | .return fvarId =>
    markUsedFVar fvarId
  | .jmp fvarId args =>
    markUsedFVar fvarId
    for arg in args do
      markUsedExpr arg
  | .cases c =>
    markUsedFVar c.discr
    for alt in c.alts do
      markUsedCode alt.getCode
  | .let decl k =>
    markUsedLetDecl decl
    markUsedCode k
  | .jp decl k | .fun decl k =>
    markUsedFunDecl decl
    markUsedCode k
/--
Mark all free variables occurring in the body of `funDecl` as used.
-/
partial def markUsedFunDecl (funDecl : FunDecl) : SimpM Unit := do
  markUsedCode funDecl.value
end
/--
Return `true` if `fvarId` has been marked as used, i.e., it is in the `used` set.
-/
def isUsed (fvarId : FVarId) : SimpM Bool := do
  let st ← get
  return st.used.contains fvarId
/--
Attach the given `decls` to `code`. For example, assume `decls` is `#[.let _x.1 := 10, .let _x.2 := true]`,
then the result is
```
let _x.1 := 10
let _x.2 := true
<code>
```
The declarations are processed from last to first. A declaration is only attached if
its variable is in the `used` set; in this case, the variables it depends on are marked as used.
Unused declarations are erased from the local context.
-/
def attachCodeDecls (decls : Array CodeDecl) (code : Code) : SimpM Code := do
  go decls.size code
where
  /-- `go n code` attaches the first `n` declarations of `decls` to `code`. -/
  go : Nat → Code → SimpM Code
    | 0, code => return code
    | i+1, code => do
      let decl := decls[i]!
      if (← isUsed decl.fvarId) then
        match decl with
        | .let letDecl =>
          markUsedLetDecl letDecl
          go i (.let letDecl code)
        | .fun funDecl =>
          markUsedFunDecl funDecl
          go i (.fun funDecl code)
        | .jp jpDecl =>
          markUsedFunDecl jpDecl
          go i (.jp jpDecl code)
      else
        eraseFVar decl.fvarId
        go i code
/--
Erase the free variables declared by `decls` from the local context.
-/
def eraseCodeDecls (decls : Array CodeDecl) : SimpM Unit := do
  for decl in decls do
    eraseFVar decl.fvarId
/--
Auxiliary function for projecting "type class dictionary access".
That is, we are trying to extract one of the type class instance elements.
Remark: We do not consider parent instances to be elements.
For example, suppose `e` is `_x_4.1`, and we have
```
_x_2 : Monad (ReaderT Bool (ExceptT String Id)) := @ReaderT.Monad Bool (ExceptT String Id) _x_1
_x_3 : Applicative (ReaderT Bool (ExceptT String Id)) := _x_2.1
_x_4 : Functor (ReaderT Bool (ExceptT String Id)) := _x_3.1
```
Then, we will expand `_x_4.1` since it corresponds to the `Functor` `map` element,
and its type is not a type class, but is of the form
```
{α β : Type u} → (α → β) → ...
```
In the example above, the compiler should not expand `_x_3.1` or `_x_2.1` because they are
type class applications: `Functor` and `Applicative` respectively.
By eagerly expanding them, we may produce inefficient and bloated code.
For example, we may be using `_x_3.1` to invoke a function that expects a `Functor` instance.
By expanding `_x_3.1` we will be just expanding the code that creates this instance.
The result is a sequence of code containing let-declarations and local function declarations (`Array CodeDecl`)
and the free variable containing the result (`FVarId`). The resulting `FVarId` often depends only on a small
subset of `Array CodeDecl`. However, this method does not try to filter the relevant ones.
We rely on the `used` var set available in `SimpM` to filter them. See `attachCodeDecls`.
-/
partial def inlineProjInst? (e : Expr) : SimpM (Option (Array CodeDecl × FVarId)) := do
let .proj _ i s := e | return none
-- The structure being projected must be a type class instance ...
let sType ← inferType s
unless (← isClass? sType).isSome do return none
-- ... and the projected element must not be a (parent) type class instance.
let eType ← inferType e
unless (← isClass? eType).isNone do return none
let (fvarId?, decls) ← visit s [i] |>.run |>.run #[]
if let some fvarId := fvarId? then
return some (decls, fvarId)
else
-- Failure: the declarations collected so far are discarded.
eraseCodeDecls decls
return none
where
-- `projs` is the list of pending projection indices. The head is the next one to be applied to `e`.
visit (e : Expr) (projs : List Nat) : OptionT (StateRefT (Array CodeDecl) SimpM) FVarId := do
let e ← findExpr e
if let .proj _ i s := e then
visit s (i :: projs)
else if let some (ctorVal, ctorArgs) := e.constructorApp? (← getEnv) then
-- Constructor application: select the field for the next pending projection.
let i :: projs := projs | unreachable!
let e := ctorArgs[ctorVal.numParams + i]!
if projs.isEmpty then
let .fvar fvarId := e | unreachable!
return fvarId
else
visit e projs
else
-- Fully applied global declaration: inline its stage1 body, and keep visiting.
let .const declName us := e.getAppFn | failure
let some decl ← getStage1Decl? declName | failure
guard (decl.getArity == e.getAppNumArgs)
let code := decl.instantiateValueLevelParams us
let code ← betaReduce decl.params code e.getAppArgs (mustInline := true)
visitCode code projs
-- Collect the leading declarations of `code`, and continue at the returned variable.
-- Only straight-line code (`let`/`fun` followed by `return`) is supported.
visitCode (code : Code) (projs : List Nat) : OptionT (StateRefT (Array CodeDecl) SimpM) FVarId := do
match code with
| .let decl k => modify (·.push (.let decl)); visitCode k projs
| .fun decl k => modify (·.push (.fun decl)); visitCode k projs
| .return fvarId => visit (.fvar fvarId) projs
| _ => failure
/--
Return `some _jp.k` if the given `cases` is of the form
```
cases _x.i
(... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁)
...
(... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ)
```
where `_jp.k` is a join point of the form
```
let _jp.k y :=
cases y ...
```
The goal is to mark `_jp.k` as must inline in this scenario.
The function also allows `_jp.k` to have a small prefix before
`cases y`. The small prefix is set using the configuration option
`config.smallThreshold`. It is currently the same threshold used to
decide when to inline a function that has multiple occurrences.
Example: consider the following declarations
```
@[inline] def pred? (x : Nat) : Option Nat :=
match x with
| 0 => none
| x+1 => some x
def isZero (x : Nat) :=
match pred? x with
| some _ => false
| none => true
```
After inlining `pred?` in `isZero`, this simplification is applicable.
Remark: this method does not assume `cases` has already been normalized,
but returns a normalized `FVarId` in case of success.
-/
def isCasesOnCases? (cases : Cases) : OptionT SimpM FVarId := do
-- NOTE(review): `alts[0]!` assumes `cases` has at least one alternative — confirm.
let jpFirst ← isCtorJmp? cases.alts[0]!.getCode
let funDecl ← getFunDecl jpFirst
guard <| (← isJpCases funDecl)
-- All the remaining alternatives must jump to the same join point.
for alt in cases.alts[1:] do
let jp ← isCtorJmp? alt.getCode
guard <| jpFirst == jp
return jpFirst
where
-- Succeeds if `code` is a sequence of declarations ending with a jump whose single argument
-- resolves to a constructor application. The result is the normalized join point id.
isCtorJmp? (code : Code) : OptionT SimpM FVarId := do
match code with
| .let _ k | .jp _ k | .fun _ k => isCtorJmp? k
| .return .. | .unreach .. | .cases .. => failure
| .jmp jpFVarId args =>
let #[arg] := args | failure
let arg ← findExpr (← normExpr arg)
guard <| arg.isConstructorApp (← getEnv)
normFVar jpFVarId
-- Returns `true` if the body of `funDecl` is a prefix of at most `smallThreshold` let-declarations
-- followed by a `cases` on the first parameter of `funDecl`.
isJpCases (funDecl : FunDecl) : SimpM Bool := do
let param := funDecl.params[0]!
let smallThreshold := (← read).config.smallThreshold
let rec go (code : Code) (prefixSize : Nat) : Bool :=
prefixSize <= smallThreshold &&
match code with
| .let _ k => go k (prefixSize + 1)
| .cases c => c.discr == param.fvarId
| _ => false
return go funDecl.value 0
/--
Return the alternative in `alts` whose body occurs in the largest number of arms,
together with its number of occurrences. In case of ties, the earliest alternative wins.
This function is used to decide whether to create a `.default` case or not.
The array `alts` must not be empty.
-/
private def getMaxOccs (alts : Array Alt) : Alt × Nat := Id.run do
  let mut best := alts[0]!
  let mut bestCount := getNumOccsOf alts 0
  for i in [1:alts.size] do
    let count := getNumOccsOf alts i
    if count > bestCount then
      best := alts[i]!
      bestCount := count
  return (best, bestCount)
where
  /--
  Return the number of occurrences of `alts[i]` in `alts`, counting `alts[i]` itself
  and the alternatives after position `i` that are alpha equivalent to it.
  Note that the number of occurrences can be greater than 1 only when
  the alternative does not depend on field parameters.
  -/
  getNumOccsOf (alts : Array Alt) (i : Nat) : Nat := Id.run do
    let target := alts[i]!.getCode
    let mut count := 1
    for j in [i+1:alts.size] do
      if Code.alphaEqv alts[j]!.getCode target then
        count := count + 1
    return count
/--
Add a default case to the given `cases` alternatives if there
are alternatives with equivalent (aka alpha equivalent) right hand sides.
The alternatives are returned unchanged if there is at most one alternative,
if there is already a `.default` alternative, or if no body occurs more than once.
-/
private def addDefault (alts : Array Alt) : SimpM (Array Alt) := do
if alts.size <= 1 || alts.any (· matches .default ..) then
return alts
else
let (max, noccs) := getMaxOccs alts
if noccs == 1 then
return alts
else
let mut altsNew := #[]
let mut first := true
markSimplified
for alt in alts do
if Code.alphaEqv alt.getCode max.getCode then
-- `alt` is merged into the new default case. There are no `.default` alternatives at this point.
let .alt _ ps k := alt | unreachable!
eraseParams ps
-- The body of the first merged alternative is not erased; the remaining copies are.
unless first do
eraseFVarsAt k
first := false
else
altsNew := altsNew.push alt
return altsNew.push (.default max.getCode)
/--
Try to simplify a projection `.proj _ i s` where `s` resolves (see `findCtor`) to a
constructor application. The result is the corresponding constructor field.
-/
def simpProj? (e : Expr) : OptionT SimpM Expr := do
  match e with
  | .proj _ i s =>
    let some (ctorVal, args) := (← findCtor s).constructorApp? (← getEnv) | failure
    markSimplified
    -- The constructor parameters come before the fields.
    return args[ctorVal.numParams + i]!
  | _ => failure
/--
Application over application.
```
let _x.i := f a
_x.i b
```
is simplified to `f a b`.
The simplification only applies when the head of `e` is a free variable
that resolves (see `findExpr`) to an application or a constant.
-/
def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
  unless e.isApp do failure
  let head := e.getAppFn
  unless head.isFVar do failure
  let head ← findExpr head
  unless head.isApp || head.isConst do failure
  markSimplified
  return mkAppN head e.getAppArgs
/--
If `config.implementedBy` is set and the head of `e` is a constant with an
`implementedBy` replacement, return `e` with its head replaced (universe levels
and arguments are preserved).
-/
def applyImplementedBy? (e : Expr) : OptionT SimpM Expr := do
  guard <| (← read).config.implementedBy
  match e.getAppFn with
  | .const declName us =>
    let some declNameNew := getImplementedBy? (← getEnv) declName | failure
    markSimplified
    return mkAppN (.const declNameNew us) e.getAppArgs
  | _ => failure
/--
Try to apply simple simplifications to `e`, in order: projection of constructor,
application over application, and `implementedBy` replacement.
The result of the first one that succeeds is returned.
-/
def simpValue? (e : Expr) : SimpM (Option Expr) :=
  -- TODO: more simplifications
  OptionT.run <|
    simpProj? e <|> simpAppApp? e <|> applyImplementedBy? e
/--
Remove the free variable `fvarId` from the local context, and record that
a simplification was performed (i.e., the `simplified` flag is set to `true`).
-/
def eraseLocalDecl (fvarId : FVarId) : SimpM Unit := do
  eraseFVar fvarId
  markSimplified
/--
When the configuration flag `etaPoly = true`, we eta-expand
partial applications of functions that take local instances as arguments.
This kind of function is inlined or specialized, and new simplification
opportunities are created by eta-expanding them.
On success, `letDecl` is erased, its variable is replaced with the new
local function declaration, and the new declaration is returned.
-/
def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do
  unless (← read).config.etaPoly do failure
  let .const declName _ := letDecl.value.getAppFn | failure
  let constInfo ← getConstInfo declName
  unless hasLocalInst constInfo.type do failure
  -- Instances are not eta-expanded.
  if (← Meta.isInstance declName) then failure
  let some decl ← getStage1Decl? declName | failure
  -- Only partial applications are eta-expanded.
  unless letDecl.value.getAppNumArgs < decl.getArity do failure
  let params ← mkNewParams letDecl.type
  let etaBody := mkAppN letDecl.value (params.map (.fvar ·.fvarId))
  let resultDecl ← mkAuxLetDecl etaBody
  let funDecl ← mkAuxFunDecl params (.let resultDecl (.return resultDecl.fvarId))
  addSubst letDecl.fvarId (.fvar funDecl.fvarId)
  eraseLocalDecl letDecl.fvarId
  return funDecl
mutual
/--
Simplify the given local function declaration: its type and parameters are
normalized (in this order), and then its body is simplified.
-/
partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
  decl.update (← normExpr decl.type) (← normParams decl.params) (← simp decl.value)
/--
Try to simplify `cases` of `constructor`.
If the discriminant resolves (see `findCtor`) to a constructor application, the `cases`
is replaced with the simplified body of the matching alternative, where the alternative
parameters are replaced with the constructor fields. Return `none` otherwise.
-/
partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
let discr ← normFVar cases.discr
let discrExpr ← findCtor (.fvar discr)
let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) (useRaw := true) | return none
-- `alt` is the selected alternative, and `cases` now contains the remaining ones, which are erased.
let (alt, cases) := cases.extractAlt! ctorVal.name
eraseFVarsAt (.cases cases)
markSimplified
match alt with
| .default k => simp k
| .alt _ params k =>
-- Skip the constructor parameters. Only the fields are bound by the alternative.
let fields := ctorArgs[ctorVal.numParams:]
let mut auxDecls := #[]
for param in params, field in fields do
/-
`field` may not be a free variable. Recall that `constructorApp?` has special support for numerals,
and `ctorArgs` contains a numeral if `discrExpr` is a numeral. We may have other cases in the future.
To make the code robust, we add auxiliary declarations whenever the `field` is not a free variable.
-/
if field.isFVar then
addSubst param.fvarId field
else
let auxDecl ← mkAuxLetDecl field
auxDecls := auxDecls.push (CodeDecl.let auxDecl)
addSubst param.fvarId (.fvar auxDecl.fvarId)
let k ← simp k
eraseParams params
-- Only the auxiliary declarations that are actually used are attached. See `attachCodeDecls`.
attachCodeDecls auxDecls k
/--
Simplify `code`.
This is the main traversal of the simplifier. It applies the current substitution
(`norm*` functions), performs inlining, simplifies values and `cases`, and eliminates
declarations that are not in the `used` set after their continuation has been simplified.
-/
partial def simp (code : Code) : SimpM Code := withIncRecDepth do
incVisited
match code with
| .let decl k =>
let mut decl ← normLetDecl decl
if let some value ← simpValue? decl.value then
decl ← decl.updateValue value
if let some funDecl ← etaPolyApp? decl then
simp (.fun funDecl k)
else if decl.value.isFVar then
/- Eliminate `let _x_i := _x_j;` -/
addSubst decl.fvarId decl.value
eraseLocalDecl decl.fvarId
simp k
else if let some code ← inlineApp? decl k then
-- `code` replaces the let-declaration and its continuation `k`.
eraseFVar decl.fvarId
simp code
else if let some (decls, fvarId) ← inlineProjInst? decl.value then
addSubst decl.fvarId (.fvar fvarId)
eraseLocalDecl decl.fvarId
let k ← simp k
-- Only the declarations in `decls` that are used by `k` are attached.
attachCodeDecls decls k
else
-- The continuation is simplified first. Thus, the `used` set is up to date when we test `decl`.
let k ← simp k
if (← isUsed decl.fvarId) then
markUsedLetDecl decl
return code.updateLet! decl k
else
/- Dead variable elimination -/
eraseLocalDecl decl.fvarId
return k
| .fun decl k | .jp decl k =>
let mut decl := decl
let toBeInlined ← isOnceOrMustInline decl.fvarId
if toBeInlined then
/-
If the declaration will be inlined, it is wasteful to eagerly simplify it.
So, we just normalize it (i.e., apply the substitution to it).
-/
decl ← normFunDecl decl
else
/-
Note that functions in `decl` will be marked as used even if `decl` is not actually used.
They will only be deleted in the next pass.
-/
if code.isFun then
decl ← decl.etaExpand
decl ← simpFunDecl decl
let k ← simp k
if (← isUsed decl.fvarId) then
if toBeInlined then
/-
`decl` was supposed to be inlined, but there are still references to it.
Thus, we must mark all variables in `decl` as used. Recall it was not eagerly simplified.
-/
markUsedFunDecl decl
return code.updateFun! decl k
else
/- Dead function elimination -/
eraseLocalDecl decl.fvarId
return k
| .return fvarId =>
let fvarId ← normFVar fvarId
markUsedFVar fvarId
return code.updateReturn! fvarId
| .unreach type =>
return code.updateUnreach! (← normExpr type)
| .jmp fvarId args =>
let fvarId ← normFVar fvarId
let args ← normExprs args
if let some code ← inlineJp? fvarId args then
simp code
else
markUsedFVar fvarId
args.forM markUsedExpr
return code.updateJmp! fvarId args
| .cases c =>
if let some k ← simpCasesOnCtor? c then
return k
else
-- Default strategy: simplify each alternative, and then try to merge equivalent ones into a default case.
let simpCasesDefault := do
let discr ← normFVar c.discr
let resultType ← normExpr c.resultType
markUsedFVar discr
let alts ← c.alts.mapMonoM fun alt =>
match alt with
| .alt ctorName ps k =>
-- Inside the alternative, `discr` is known to be `ctorName ps`.
withDiscrCtor discr ctorName ps do
return alt.updateCode (← simp k)
| .default k => return alt.updateCode (← simp k)
let alts ← addDefault alts
-- A `cases` containing only a default alternative is replaced with the alternative body.
if alts.size == 1 && alts[0]! matches .default .. then
return alts[0]!.getCode
else
return code.updateCases! resultType discr alts
-- Cases-on-cases: the shared join point is temporarily tagged as "must inline". See `isCasesOnCases?`.
if let some jpFVarId ← isCasesOnCases? c then
withAddMustInline jpFVarId simpCasesDefault
else
simpCasesDefault
end
end Simp
open Simp
/--
Execute one simplification step on `decl`.
The function occurrence information is computed from scratch, and then the body is simplified.
Return `some declNew` if some simplification was performed, and `none` otherwise.
-/
def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
  updateFunDeclInfo decl.value
  trace[Compiler.simp.inline.info] "{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}"
  traceM `Compiler.simp.step do ppDecl decl
  let value ← simp decl.value
  traceM `Compiler.simp.step.new do return m!"{decl.name} :=\n{← ppCode value}"
  let s ← get
  trace[Compiler.simp.stat] "{decl.name}, size: {value.size}, # visited: {s.visited}, # inline: {s.inline}, # inline local: {s.inlineLocal}"
  unless s.simplified do
    return none
  return some { decl with value }
/--
Simplify `decl` by repeatedly executing simplification steps (`Decl.simp?`)
until no further simplification is performed. Each step starts with a fresh simplifier state.
-/
partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
  /-
  We do not eta-expand or inline partial applications in template like code.
  Recall we don't want to generate code for them.
  Remark: by eta-expanding partial applications in instances, we also make the simplifier
  work harder when inlining instance projections.
  -/
  let templateLike ← isTemplateLike decl
  let config :=
    if templateLike then
      { config with etaPoly := false, inlinePartial := false }
    else
      config
  go decl config
where
  go (decl : Decl) (config : Config) : CompilerM Decl := do
    match (← decl.simp? |>.run { config, declName := decl.name } |>.run' {}) with
    | some declNew =>
      -- TODO: bound number of steps?
      go declNew config
    | none =>
      return decl
/--
The simplifier as a compiler pass named `simp`. It is applied to each declaration
using `Decl.simp` with the given `config`.
-/
def simp (config : Config := {}) (occurence : Nat := 0) : Pass :=
  .mkPerDeclaration `simp (fun decl => decl.simp config) .base (occurence := occurence)
-- Register the trace classes used by the simplifier.
-- NOTE(review): `Decl.simp?` also emits messages using `trace[Compiler.simp.inline.info]`, but
-- `Compiler.simp.inline.info` is not registered in this block — confirm it is registered elsewhere.
builtin_initialize
registerTraceClass `Compiler.simp (inherited := true)
registerTraceClass `Compiler.simp.inline
registerTraceClass `Compiler.simp.stat
registerTraceClass `Compiler.simp.step
registerTraceClass `Compiler.simp.step.new
end Lean.Compiler.LCNF