lean4-htt/src/Lean/Compiler/LCNF/Simp.lean
2022-09-01 20:52:08 -07:00

871 lines
29 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.Recognizers
import Lean.Compiler.InlineAttrs
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.ElimDead
import Lean.Compiler.LCNF.Bind
import Lean.Compiler.LCNF.PrettyPrinter
import Lean.Compiler.LCNF.Stage1
namespace Lean.Compiler.LCNF
namespace Simp
/--
Local function usage information used to decide whether it should be inlined or not.
The information is an approximation, but it is on the "safe" side. That is, if we tagged
a function with `.once`, then it is applied only once. A local function may be marked as
`.many`, but after simplifications the number of applications may reduce to 1. This is not
a big problem in practice because we run the simplifier multiple times, and this information
is recomputed from scratch at the beginning of each simplification step.
-/
inductive FunDeclInfo where
| /--
Local function is applied once, and must be inlined.
-/
once
| /--
Local function is applied many times, and will only be inlined
if it is small.
-/
many
| /--
Function must be inlined.
-/
mustInline
deriving Repr, Inhabited
/--
Local function declaration statistics.
-/
structure FunDeclInfoMap where
/--
Mapping from local function name to inlining information.
-/
map : Std.HashMap FVarId FunDeclInfo := {}
deriving Inhabited
def FunDeclInfoMap.format (s : FunDeclInfoMap) : CompilerM Format := do
let mut result := Format.nil
for (fvarId, info) in s.map.toList do
let localDecl ← getLocalDecl fvarId
result := result ++ "\n" ++ f!"{localDecl.userName} ↦ {repr info}"
return result
/--
Add new occurrence for the local function with binder name `key`.
-/
def FunDeclInfoMap.add (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } =>
match map.find? fvarId with
| some .once => { map := map.insert fvarId .many }
| none => { map := map.insert fvarId .once }
| _ => { map }
/--
Add new occurrence for the local function with binder name `key`.
-/
def FunDeclInfoMap.addMustInline (s : FunDeclInfoMap) (fvarId : FVarId) : FunDeclInfoMap :=
match s with
| { map } => { map := map.insert fvarId .mustInline }
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
match e with
| .fvar fvarId =>
if let some decl ← LCNF.findFunDecl? fvarId then
return some decl
else if let .ldecl (value := v) .. ← getLocalDecl fvarId then
findFunDecl? v
else
return none
| .mdata _ e => findFunDecl? e
| _ => return none
partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
match e with
| .fvar fvarId =>
let .ldecl (value := v) .. ← getLocalDecl fvarId | return e
findExpr v
| .mdata _ e' => if skipMData then findExpr e' else return e
| _ => return e
structure Config where
smallThreshold : Nat := 1
structure Context where
config : Config := {}
structure State where
/--
Free variable substitution. We use it to implement inlining and removing redundant variables `let _x.i := _x.j`
-/
subst : FVarSubst := {}
/--
Track used local declarations to be able to eliminate dead variables.
-/
used : UsedLocalDecls := {}
/--
Mapping used to decide whether a local function declaration must be inlined or not.
-/
funDeclInfoMap : FunDeclInfoMap := {}
/--
`true` if some simplification was performed in the current simplification pass.
-/
simplified : Bool := false
/--
Number of visited `let-declarations` and terminal values.
This is a performance counter, and currently has no impact on code generation.
-/
visited : Nat := 0
/--
Number of definitions inlined.
This is a performance counter.
-/
inline : Nat := 0
/--
Number of local functions inlined.
This is a performance counter.
-/
inlineLocal : Nat := 0
abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
instance : MonadFVarSubst SimpM where
getSubst := return (← get).subst
def markSimplified : SimpM Unit :=
modify fun s => { s with simplified := true }
def incVisited : SimpM Unit :=
modify fun s => { s with visited := s.visited + 1 }
def incInline : SimpM Unit :=
modify fun s => { s with inline := s.inline + 1 }
def incInlineLocal : SimpM Unit :=
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :=
go code
where
go (code : Code) : SimpM Unit := do
match code with
| .let decl k =>
if decl.value.isApp then
if let some funDecl ← findFunDecl? decl.value.getAppFn then
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.add funDecl.fvarId }
go k
| .fun decl k =>
if mustInline then
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addMustInline decl.fvarId }
go decl.value; go k
| .jp decl k => go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .return .. | .jmp .. | .unreach .. => return ()
def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
match (← get).funDeclInfoMap.map.find? fvarId with
| some .once | some .mustInline => return true
| _ => return false
def isSmall (decl : FunDecl) : SimpM Bool :=
return decl.value.sizeLe (← read).config.smallThreshold
def shouldInlineLocal (decl : FunDecl) : SimpM Bool := do
if (← isOnceOrMustInline decl.fvarId) then
return true
else
isSmall decl
structure InlineCandidateInfo where
isLocal : Bool
params : Array Param
/-- Value (lambda expression) of the function to be inlined. -/
value : Code
def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
| { params, .. } => params.size
def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
let f := e.getAppFn
if let .const declName us ← findExpr f then
unless hasInlineAttribute (← getEnv) declName do return none
-- TODO: check whether function is recursive or not.
-- We can skip the test and store function inline so far.
let some decl ← getStage1Decl? declName | return none
let numArgs := e.getAppNumArgs
let arity := decl.getArity
if numArgs < arity then return none
let params := decl.instantiateParamsLevelParams us
let value := decl.instantiateValueLevelParams us
incInline
return some {
isLocal := false
params, value
}
else if let some decl ← findFunDecl? f then
unless (← shouldInlineLocal decl) do return none
let numArgs := e.getAppNumArgs
let arity := decl.getArity
if numArgs < arity then return none
incInlineLocal
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
return some {
isLocal := true
params := decl.params
value := decl.value
}
else
return none
private partial def oneExitPointQuick (c : Code) : Bool :=
go c
where
go (c : Code) : Bool :=
match c with
| .let _ k | .fun _ k => go k
-- Approximation, the cases may have many unreachable alternatives, and only reachable.
| .cases c => c.alts.size == 1 && c.alts.any fun alt => go alt.getCode
-- Approximation, we assume that any code containing join points have more than one exit point
| .jp .. | .jmp .. => false
| .return .. | .unreach .. => true
def betaReduce (params : Array Param) (code : Code) (args : Array Expr) : SimpM Code := do
-- TODO: add necessary casts to `args`
let mut subst := {}
for param in params, arg in args do
subst := subst.insert param.fvarId arg
let code ← code.internalize subst
updateFunDeclInfo code
return code
/--
If `e` is an application that can be inlined, inline it.
`k?` is the optional "continuation" for `e`, and it may contain loose bound variables
that need to instantiated with `xs`. That is, if `k? = some k`, then `k.instantiateRev xs`
is an expression without loose bound variables.
-/
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
if k matches .unreach .. then return some k
let e := letDecl.value
let some info ← inlineCandidate? e | return none
markSimplified
let args := e.getAppArgs
let numArgs := args.size
trace[Compiler.simp.inline] "inlining {e}"
let code ← betaReduce info.params info.value args[:info.arity]
let fvarId := letDecl.fvarId
if k.isReturnOf fvarId && numArgs == info.arity then
/- Easy case, the continuation `k` is just returning the result of the application. -/
return code
else if oneExitPointQuick code then
/-
`code` has only one exit point, thus we can attach the continuation directly there,
and simplify the result.
-/
code.bind fun fvarId' => do
/- fvarId' is the result of the computation -/
if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') args[info.arity:])
let k ← replaceFVar k fvarId decl.fvarId
return .let decl k
else
replaceFVar k fvarId fvarId'
else
/-
`code` has multiple exit points, and the continuation is non-trivial
Thus, we create an auxiliary join point.
-/
let jpParam ← mkAuxParam (← inferType (mkAppN e.getAppFn args[:info.arity]))
let jpValue ← if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) args[info.arity:])
let k ← replaceFVar k fvarId decl.fvarId
pure <| .let decl k
else
replaceFVar k fvarId jpParam.fvarId
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
return Code.jp jpDecl code
def findCtor (e : Expr) : SimpM Expr := do
-- TODO: add support for mapping discriminants to constructors in branches
findExpr e
/--
Try to simplify projections `.proj _ i s` where `s` is constructor.
-/
def simpProj? (e : Expr) : OptionT SimpM Expr := do
let .proj _ i s := e | failure
let s ← findCtor s
let some (ctorVal, args) := s.constructorApp? (← getEnv) | failure
markSimplified
return args[ctorVal.numParams + i]!
/--
Application over application.
```
let _x.i := f a
_x.i b
```
is simplified to `f a b`.
-/
def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
guard e.isApp
let f := e.getAppFn
guard f.isFVar
let f ← findExpr f
guard <| f.isApp || f.isConst
markSimplified
return mkAppN f e.getAppArgs
def markUsedFVar (fvarId : FVarId) : SimpM Unit :=
modify fun s => { s with used := s.used.insert fvarId }
def markUsedExpr (e : Expr) : SimpM Unit :=
modify fun s => { s with used := collectLocalDecls s.used e }
def markUsedLetDecl (letDecl : LetDecl) : SimpM Unit :=
markUsedExpr letDecl.value
def isUsed (fvarId : FVarId) : SimpM Bool :=
return (← get).used.contains fvarId
def eraseLocalDecl (fvarId : FVarId) : SimpM Unit := do
eraseFVar fvarId
markSimplified
/--
Add substitution `fvarId ↦ val`. `val` is a free variable, or
it is a type, type former, or `lcErased`.
-/
def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit :=
modify fun s => { s with subst := s.subst.insert fvarId val }
/-- Try to apply simple simplifications. -/
def simpValue? (e : Expr) : SimpM (Option Expr) :=
simpProj? e <|> simpAppApp? e
mutual
partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
let type ← normExpr decl.type
let params ← normParams decl.params
let value ← simp decl.value
decl.update type params value
/-- Try to simplify `cases` of `constructor` -/
partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
let discr ← normFVar cases.discr
let discrExpr ← findExpr (.fvar discr)
let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) | return none
let (alt, cases) := cases.extractAlt! ctorVal.name
eraseFVarsAt (.cases cases)
markSimplified
match alt with
| .default k => simp k
| .alt _ params k =>
let fields := ctorArgs[ctorVal.numParams:]
for param in params, field in fields do
addSubst param.fvarId field
let k ← simp k
eraseParams params
return k
partial def simp (code : Code) : SimpM Code := do
incVisited
match code with
| .let decl k =>
let mut decl ← normLetDecl decl
if decl.value.isFVar then
/- Eliminate `let _x_i := _x_j;` -/
addSubst decl.fvarId decl.value
eraseLocalDecl decl.fvarId
simp k
else if let some code ← inlineApp? decl k then
eraseFVar decl.fvarId
simp code
else
if let some value ← simpValue? decl.value then
decl ← decl.updateValue value
/- TODO: simple value simplifications & inlining -/
let k ← simp k
if !decl.pure || (← isUsed decl.fvarId) then
markUsedLetDecl decl
return code.updateLet! decl k
else
/- Dead variable elimination -/
eraseLocalDecl decl.fvarId
return k
| .fun decl k | .jp decl k =>
/-
Variables in `decl` will be marked as used even if the function is eliminated.
Thus, they will only be deleted in the second pass.
Pontential improvement: if `decl.fvarId` is marked with `once` or `mustInline`, we will probably
inline this declaration, and it may be wasteful to simplify it here.
The alternative option is to just normalize `decl`, and if used mark all occurring there as used.
-/
let decl ← simpFunDecl decl
let k ← simp k
if (← isUsed decl.fvarId) then
return code.updateFun! decl k
else
/- Dead function elimination -/
eraseLocalDecl decl.fvarId
return k
| .return fvarId =>
let fvarId ← normFVar fvarId
markUsedFVar fvarId
return code.updateReturn! fvarId
| .unreach type =>
return code.updateUnreach! (← normExpr type)
| .jmp fvarId args =>
-- TODO: inline join point
let fvarId ← normFVar fvarId
let args ← normExprs args
markUsedFVar fvarId
args.forM markUsedExpr
return code.updateJmp! fvarId args
| .cases c =>
if let some k ← simpCasesOnCtor? c then
return k
else
-- TODO: other cases simplifications
let discr ← normFVar c.discr
let resultType ← normExpr c.resultType
markUsedFVar discr
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
return code.updateCases! resultType discr alts
end
end Simp
open Simp
def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
updateFunDeclInfo decl.value
trace[Compiler.simp.inline.info] "{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}"
traceM `Compiler.simp.step do ppDecl decl
let value ← simp decl.value
traceM `Compiler.simp.step.new do return m!"{decl.name} :=\n{← ppCode value}"
let s ← get
trace[Compiler.simp.stat] "{decl.name}, size: {value.size}, # visited: {s.visited}, # inline: {s.inline}, # inline local: {s.inlineLocal}"
if (← get).simplified then
return some { decl with value }
else
return none
partial def Decl.simp (decl : Decl) : CompilerM Decl := do
if let some decl ← decl.simp? |>.run {} |>.run' {} then
-- TODO: bound number of steps?
decl.simp
else
return decl
builtin_initialize
registerTraceClass `Compiler.simp.inline
registerTraceClass `Compiler.simp.inline.info
registerTraceClass `Compiler.simp.stat
registerTraceClass `Compiler.simp.step
registerTraceClass `Compiler.simp.step.new
registerTraceClass `Compiler.simp.projInst
end Lean.Compiler.LCNF
#exit -- TODO: port rest of file
namespace Lean.Compiler
namespace Simp
/-
Ensure binder names are unique, and update local function information.
If `mustInline = true`, then local functions in `e` are marked with binders of the
form `_mustInline.<idx>`.
Remark: we used to store the `mustInline` information in the map `localInfoMap`,
using a `.mustInline` constructor at `LocalFunInfo`. However, this was incorrect
because there is no guarantee that we will be able to inline all occurrences of the
function in the current `simp` step. Since, we recompute `localInfoMap` from scratch
at the beginning of each compiler pass, the information was being lost.
-/
/--
This function performs the following operations in the given expression in a single pass.
- Ensure binder names for let-declarations are unique.
- Update local function information. That is, it updates the map `localInfoMap`.
- Apply `e.instantiateRev args`.
We use it to "internalize" expressions at startup and when performing inlining.
-/
def instantiateRevInternalize (e : Expr) (args : Array Expr) (mustInline := false) : SimpM Expr := do
let lctx ← getLCtx
let nextIdx := (← getThe CompilerM.State).nextIdx
let localInfoMap ← modifyGet fun s => (s.localInfoMap, { s with localInfoMap := {} })
let (e, { localInfoMap, nextIdx }) := instantiateRevInternalizeCore lctx e args mustInline |>.run { nextIdx, localInfoMap }
modifyThe CompilerM.State fun s => { s with nextIdx }
modify fun s => { s with localInfoMap }
return e
def isSmallValue (value : Expr) : SimpM Bool := do
lcnfSizeLe value (← read).config.smallThreshold
def shouldInlineLocal (localDecl : LocalDecl) : SimpM Bool := do
if (← isOnceOrMustInline localDecl.userName) then
return true
else
isSmallValue localDecl.value
/--
If `e` if a free variable that expands to a valid LCNF terminal `let`-block expression `e'`,
return `e'`.
-/
def expandTrivialExpr (e : Expr) : SimpM Expr := do
if e.isFVar then
let e' ← findExpr e
unless e'.isLambda do
if e != e' then
markSimplified
return e'
return e
/--
Given `value` of the form `let x_1 := v_1; ...; let x_n := v_n; e`,
return `let x_1; ...; let x_n := v_n; let y : type := e; body`.
This methods assumes `type` and `value` do not have loose bound variables.
Remark: `body` may have many loose bound variables, and the loose bound variables > 0
must be lifted by `n`.
-/
private def mkFlatLet (y : Name) (type : Expr) (value : Expr) (body : Expr) (nonDep : Bool := false) : Expr :=
match value with
| .letE binderName type value'@(.lam ..) (.bvar 0) nonDep =>
/- Easy case that is often generated by `inlineProjInst?` -/
.letE binderName type value' body nonDep
| _ => go value 0
where
go (value : Expr) (i : Nat) : Expr :=
match value with
| .letE n t v b d => .letE n t v (go b (i+1)) d
| _ => .letE y type value (body.liftLooseBVars 1 i) nonDep
def mkLetUsingScope (e : Expr) : SimpM Expr := do
modify fun s => { s with mkLet := s.mkLet + 1 }
Compiler.mkLetUsingScope e
def mkLambda (as : Array Expr) (e : Expr) : SimpM Expr := do
modify fun s => { s with mkLambda := s.mkLambda + 1 }
Compiler.mkLambda as e
/--
Helper function for simplifying expressions such as
```
let _x.1 := fun {α β} => ReaderT.bind ... α β;
_x.1
```
to `ReaderT.bing ...`
This kind of expression is often generated by `inlineProjInst?`.
They are a side-effect of the implicit lambda feature when we write
```
instance [Monad m] : Monad (ReaderT ρ m) where
bind := ReaderT.bind
```
Lean introduces implicit lambdas at the `ReaderT.bind` above.
-/
private def simpUsingEtaReduction (e : Expr) : Expr :=
match e with
| .letE _ _ v@(.lam ..) (.bvar 0) _ =>
let v := v.eta
if v.isLambda then e else v
| .letE n t v b d => .letE n t v (simpUsingEtaReduction b) d
| _ => e
private def etaExpand (type : Expr) (value : Expr) : SimpM Expr := do
let typeArity := getArrowArity type
let valueArity := getLambdaArity value
if typeArity <= valueArity then
-- It can be < because of the "any" type
return value
else
withNewScope do
let (xs, _) ← visitArrow type
let value := getLambdaBody value
let value := value.instantiateRev xs[:valueArity]
let valueType ← inferType value
let f ← mkLocalDecl (← mkFreshUserName `_f) valueType
let k ← mkLambda #[f] (mkAppN f xs[valueArity:])
let value ← attachJp value k
mkLambda xs value
/--
Auxiliary function for projecting "type class dictionary access".
That is, we are trying to extract one of the type class instance elements.
Remark: We do not consider parent instances to be elements.
For example, suppose `e` is `_x_4.1`, and we have
```
_x_2 : Monad (ReaderT Bool (ExceptT String Id)) := @ReaderT.Monad Bool (ExceptT String Id) _x_1
_x_3 : Applicative (ReaderT Bool (ExceptT String Id)) := _x_2.1
_x_4 : Functor (ReaderT Bool (ExceptT String Id)) := _x_3.1
```
Then, we will expand `_x_4.1` since it corresponds to the `Functor` `map` element,
and its type is not a type class, but is of the form
```
{α β : Type u} → (α → β) → ...
```
In the example above, the compiler should not expand `_x_3.1` or `_x_2.1` because they are
type class applications: `Functor` and `Applicative` respectively.
By eagerly expanding them, we may produce inefficient and bloated code.
For example, we may be using `_x_3.1` to invoke a function that expects a `Functor` instance.
By expanding `_x_3.1` we will be just expanding the code that creates this instance.
-/
partial def inlineProjInst? (e : Expr) : OptionT SimpM Expr := do
let .proj _ _ s := e | failure
let sType ← inferType s
guard (← isClass? sType).isSome
let eType ← inferType e
guard (← isClass? eType).isNone
/-
We use `withNewScope` + `mkLetUsingScope` to filter the relevant let-declarations.
Recall that we are extracting only one of the type class elements.
-/
let value ← withNewScope do mkLetUsingScope (← visitProj e)
markSimplified
let value := simpUsingEtaReduction value
let value ← internalize value (mustInline := true)
trace[Compiler.simp.projInst] "{e} =>\n{value}"
return value
where
visitProj (e : Expr) : OptionT SimpM Expr := do
let .proj _ i s := e | unreachable!
let s ← visit s
if let some (ctorVal, ctorArgs) := s.constructorApp? (← getEnv) then
return ctorArgs[ctorVal.numParams + i]!
else
failure
visit (e : Expr) : OptionT SimpM Expr := do
let e ← findExpr e
if e.isConstructorApp (← getEnv) then
return e
else if e.isProj then
/- We may have nested projections as we traverse parent classes. -/
visit (← visitProj e)
else
let .const declName us := e.getAppFn | failure
let some decl ← getStage1Decl? declName | failure
guard <| decl.getArity == e.getAppNumArgs
let value := decl.value.instantiateLevelParams decl.levelParams us
let value := value.beta e.getAppArgs
/-
Here, we just go inside of the let-declaration block without trying to simplify it.
Reason: a type class instannce may have many elements, and it does not make sense to simplify
all of them when we are extracting only one of them.
-/
let value ← Compiler.visitLet (m := SimpM) value fun _ value => return value
visit value
def betaReduce (e : Expr) (args : Array Expr) : SimpM Expr := do
-- TODO: add necessary casts to `args`
let result ← instantiateRevInternalize (getLambdaBody e) args
return result
/--
Try "cases on cases" simplification.
If `casesFn args` is of the form
```
casesOn _x.i
(... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁)
...
(... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ)
```
where `_jp.k` is a join point of the form
```
let _jp.k := fun y =>
casesOn y ...
```
Then, inline `_jp.k`. The idea is to force the `casesOn` application in the join point to
reduce after the inlining step.
Example: consider the following declarations
```
@[inline] def pred? (x : Nat) : Option Nat :=
match x with
| 0 => none
| x+1 => some x
def isZero (x : Nat) :=
match pred? x with
| some _ => false
| none => true
```
After inlining `pred?` in `isZero`, we have
```
let _jp.1 := fun y : Option Nat =>
casesOn y true (fun y => false)
casesOn x
(let _x.1 := none; _jp.1 _x.1)
(fun n => let _x.2 := some n; _jp.1 _x.2)
```
and this simplification is applicable, producing
```
casesOn x true (fun n => false)
```
-/
def simpCasesOnCases? (casesInfo : CasesInfo) (casesFn : Expr) (args : Array Expr) : OptionT SimpM Expr := do
let mut jpFirst? := none
for i in casesInfo.altsRange do
let alt := args[i]!
let jp ← isJpCtor? alt
if let some jpFirst := jpFirst? then
guard <| jp == jpFirst
else
let some localDecl ← findDecl? jp | failure
let .lam _ _ jpBody _ := localDecl.value | failure
guard (← isCasesApp? jpBody).isSome
jpFirst? := jp
let some jpFVarId := jpFirst? | failure
let some localDecl ← findDecl? jpFVarId | failure
let .lam _ _ jpBody _ := localDecl.value | failure
let mut args := args
for i in casesInfo.altsRange do
args := args.modify i (inlineJp · jpBody)
return mkAppN casesFn args
where
isJpCtor? (alt : Expr) : OptionT SimpM FVarId := do
match alt with
| .lam _ _ b _ => isJpCtor? b
| .letE _ _ v b _ => match b with
| .letE .. => isJpCtor? b
| .app (.fvar fvarId) (.bvar 0) =>
let some localDecl ← findDecl? fvarId | failure
guard localDecl.isJp
guard <| v.isConstructorApp (← getEnv)
return fvarId
| _ => failure
| _ => failure
inlineJp (alt : Expr) (jpBody : Expr) : Expr :=
match alt with
| .lam n d b bi => .lam n d (inlineJp b jpBody) bi
| .letE n t v b nd => .letE n t v (inlineJp b jpBody) nd
| _ => jpBody
mutual
/--
Simplify the given lambda expression.
If `checkEmptyTypes := true`, then return `fun a_i : t_i => lcUnreachable` if
`t_i` is the `Empty` type.
-/
partial def visitLambda (e : Expr) (checkEmptyTypes := false): SimpM Expr :=
withNewScope do
let (as, e) ← Compiler.visitLambdaCore e
if checkEmptyTypes then
for a in as do
if (← isEmptyType (← inferType a)) then
let e := e.instantiateRev as
let unreach ← mkLcUnreachable (← inferType e)
let r ← mkLambda as unreach
return r
let e ← mkLetUsingScope (← visitLet e as)
mkLambda as e
partial def visitCases (casesInfo : CasesInfo) (e : Expr) : SimpM Expr := do
let f := e.getAppFn
let mut args := e.getAppArgs
let major := args[casesInfo.discrsRange.stop - 1]!
let major ← findExpr major
if let some (ctorVal, ctorArgs) := major.constructorApp? (← getEnv) then
/- Simplify `casesOn` constructor -/
let ctorIdx := ctorVal.cidx
let alt := args[casesInfo.altsRange.start + ctorIdx]!
let ctorFields := ctorArgs[ctorVal.numParams:]
let alt := alt.beta ctorFields
assert! !alt.isLambda
markSimplified
visitLet alt
else if let some e ← simpCasesOnCases? casesInfo f args then
visitCases casesInfo e
else
for i in casesInfo.altsRange do
args ← args.modifyM i (visitLambda · (checkEmptyTypes := true))
return mkAppN f args
/-- Try to apply simple simplifications. -/
partial def simpValue? (e : Expr) : SimpM (Option Expr) :=
simpProj? e <|> simpAppApp? e <|> inlineProjInst? e
/--
Let-declaration basic block visitor. `e` may contain loose bound variables that
still have to be instantiated with `xs`.
-/
partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
modify fun s => { s with visited := s.visited + 1 }
match e with
| .letE binderName type value body nonDep =>
let type := type.instantiateRev xs
let mut value := value.instantiateRev xs
if value.isLambda then
unless (← isOnceOrMustInline binderName) do
/-
If the local function will be inlined anyway, we don't simplify it here,
we do it after its is inlined and we have information about the actual arguments.
-/
value ← visitLambda value
unless isJpBinderName binderName || (← isSmallValue value) do
/-
This lambda is not going to be inlined. So, we eta-expand it IF it is not a join point.
Recall that local function declarations that are not join points will be lambda lifted
anyway. Eta-expanding here also creates new simplification opportunities for
monadic local functions before we perform the lambda-lifting.
For example, consider the local function
```
let _x.23 := fun xs body =>
...
let _x.29 := StateRefT'.lift _x.24
let _x.30 := _x.25 _x.29
let _x.31 := fun a => ...
ReaderT.bind _x.30 _x.31
```
The function applications `StateRefT'.lift` and `ReaderT.bind` are not inlined because
they are partially applied. After, we eta-expand this code, it will be reduced at this stage.
-/
value ← etaExpand type value
else if let some value' ← simpValue? value then
if value'.isLet then
let e := mkFlatLet binderName type value' body nonDep
let e ← visitLet e xs
return e
value := value'
if value.isFVar then
/- Eliminate `let _x_i := _x_j;` -/
markSimplified
visitLet body (xs.push value)
else if let some e ← inlineApp? value xs body then
return e
else
let x ← mkLetDecl binderName type value nonDep
visitLet body (xs.push x)
| _ =>
let e := e.instantiateRev xs
if let some value ← simpValue? e then
visitLet value
else if let some casesInfo ← isCasesApp? e then
visitCases casesInfo e
else if let some e ← inlineApp? e #[] none then
return e
else
expandTrivialExpr e
end
end Simp
end Lean.Compiler