136 lines
5 KiB
Text
136 lines
5 KiB
Text
/-
|
||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Leonardo de Moura
|
||
-/
|
||
module
|
||
|
||
prelude
|
||
public import Lean.Compiler.LCNF.InferType
|
||
|
||
public section
|
||
|
||
namespace Lean.Compiler.LCNF
|
||
|
||
/--
Monads in which LCNF `Code` can be sequenced with a continuation.
This is a helper class so that `CompilerM.codeBind` can be lifted through monad transformers.
-/
class MonadCodeBind (m : Type → Type) where
  /-- Sequence `c` with `f`, where `f` receives the variable holding `c`'s result. -/
  codeBind (c : Code) (f : FVarId → m Code) : m Code
|
||
|
||
/--
Return code that is equivalent to `c >>= f`. That is, it executes `c` and then `f x`, where
`x` is a variable that contains the result of `c`'s computation.

If `c` contains a jump to a join point `jp_i` that is not declared in `c`, an exception is thrown,
because an invalid block would be generated: `f` would not be applied to the result of `jp_i`.
Note that we could have decided to create a copy of `jp_i` to which `f` is applied,
but we decided not to do so in order to avoid code duplication.
-/
abbrev Code.bind [MonadCodeBind m] (c : Code) (f : FVarId → m Code) : m Code :=
  MonadCodeBind.codeBind c f
|
||
|
||
/--
`CompilerM` implementation of code binding: returns code equivalent to `c` in which every
`.return x` has been replaced by the code produced by `f x`.

Throws an error if `c` contains a `cases` with no alternatives, or a jump to a join point
that is not declared inside `c`.
-/
partial def CompilerM.codeBind (c : Code) (f : FVarId → CompilerM Code) : CompilerM Code := do
  -- Start with no join points in scope.
  go c |>.run {}
where
  -- The `FVarIdSet` context holds the join points declared inside `c` that are in scope
  -- at the current position; only jumps to these are allowed.
  go (c : Code) : ReaderT FVarIdSet CompilerM Code := do
    match c with
    -- Binders are kept as-is; only the continuation is rewritten.
    | .let decl k => return .let decl (← go k)
    | .fun decl k => return .fun decl (← go k)
    | .jp decl k =>
      -- Rewrite the join point body first. Its type may change with it, so it is recomputed
      -- from the new body. The join point itself is not yet in scope while rewriting its body.
      let value ← go decl.value
      let type ← value.inferParamType decl.params
      let decl ← decl.update' type value
      -- Make this join point available to jumps in the continuation `k`.
      withReader (fun s => s.insert decl.fvarId) do
        return .jp decl (← go k)
    | .cases c =>
      let alts ← c.alts.mapM fun
        | .alt ctorName params k => return .alt ctorName params (← go k)
        | .default k => return .default (← go k)
      if alts.isEmpty then
        throwError "`Code.bind` failed, empty `cases` found"
      -- The alternatives were rewritten, so the `cases` result type is recomputed from them.
      let resultType ← mkCasesResultType alts
      return .cases { c with alts, resultType }
    -- This is the only place where the continuation `f` is applied to a result variable.
    | .return fvarId => f fvarId
    | .jmp fvarId .. =>
      -- Jumps to join points declared in `c` are kept: those bodies were already rewritten above.
      unless (← read).contains fvarId do
        throwError "`Code.bind` failed, it contains a out of scope join point"
      return c
    | .unreach type =>
      /-
      Create an auxiliary parameter `aux : type` to compute the resulting type of `f aux`.
      This code is not very efficient, we could ask caller to provide the type of `c >>= f`,
      but this is more convenient, and this case is seldom reached.
      -/
      let auxParam ← mkAuxParam type
      let k ← f auxParam.fvarId
      let typeNew ← k.inferType
      -- `k` and `auxParam` were only needed to compute `typeNew`; erase them afterwards.
      eraseCode k
      eraseParam auxParam
      return .unreach typeNew
|
||
|
||
/-- In `CompilerM`, code binding is performed directly by `CompilerM.codeBind`. -/
instance : MonadCodeBind CompilerM where
  codeBind c f := CompilerM.codeBind c f
|
||
|
||
/-- Lift code binding through `ReaderT`: the same context is passed to every continuation call. -/
instance [MonadCodeBind m] : MonadCodeBind (ReaderT ρ m) where
  codeBind c f := fun ctx => Code.bind c (f · ctx)
|
||
|
||
/-- Lift code binding through `StateRefT'`: the same state reference is passed to every continuation call. -/
instance [STWorld ω m] [MonadCodeBind m] : MonadCodeBind (StateRefT' ω σ m) where
  codeBind c f := fun sref => Code.bind c (f · sref)
|
||
|
||
/--
Create one fresh auxiliary parameter for each leading `forallE` binder of `type`.
For example, if `type` is `Nat → Bool → Int`, the result contains two new parameters
with types `Nat` and `Bool`.

Binder domains are instantiated with the free variables of the parameters created so far.
When the remaining type is not a `forallE` but head-beta reduces to a different expression,
the walk continues on the reduced type.
-/
partial def mkNewParams (type : Expr) : CompilerM (Array Param) :=
  go type #[] #[]
where
  -- `fvars`: free variables of the binders already peeled off, in order.
  -- `acc`: the parameters created so far.
  go (ty : Expr) (fvars : Array Expr) (acc : Array Param) : CompilerM (Array Param) := do
    match ty with
    | .forallE _ dom body _ =>
      let param ← mkAuxParam (dom.instantiateRev fvars)
      go body (fvars.push (.fvar param.fvarId)) (acc.push param)
    | _ =>
      let rest := ty.instantiateRev fvars
      let reduced := rest.headBeta
      -- Nothing more to peel off once head-beta reduction makes no progress.
      if reduced == rest then
        return acc
      go reduced #[] acc
|
||
|
||
/--
Return `true` when `getArrowArity type` exceeds the number of `params`, i.e. a value with
these parameters could be eta-expanded with additional ones.
-/
def isEtaExpandCandidateCore (type : Expr) (params : Array Param) : Bool :=
  params.size < getArrowArity type
|
||
|
||
/-- `isEtaExpandCandidateCore` applied to the declaration's own type and parameters. -/
abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl) : Bool :=
  isEtaExpandCandidateCore (type := decl.type) (params := decl.params)
|
||
|
||
/--
Eta-expand a function with type `type`, parameters `params`, and body `value`.

Fresh parameters are created (via `mkNewParams`) for the binders of `type` that remain after
instantiating it with `params`. Every result `x` of `value` is then replaced, via `Code.bind`,
by an auxiliary `let` that applies `x` to the new parameters and returns the application.
Returns the extended parameter list together with the new body.
Any error raised by `Code.bind` (e.g. a jump to a join point not declared in `value`) propagates.
-/
def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : CompilerM (Array Param × Code) := do
  let paramFVars := params.map fun p => mkFVar p.fvarId
  let bodyType ← instantiateForall type paramFVars
  let extraParams ← mkNewParams bodyType
  let extraArgs := extraParams.map fun p => .fvar p.fvarId
  let newValue ← value.bind fun resultFVar => do
    let appDecl ← mkAuxLetDecl (.fvar resultFVar extraArgs)
    return .let appDecl (.return appDecl.fvarId)
  return (params ++ extraParams, newValue)
|
||
|
||
/--
Run `etaExpandCore` only when `isEtaExpandCandidateCore type params` holds.
Returns `none`, without doing any work, when the value is not a candidate.
-/
def etaExpandCore? (type : Expr) (params : Array Param) (value : Code) : CompilerM (Option (Array Param × Code)) := do
  unless isEtaExpandCandidateCore type params do
    return none
  return some (← etaExpandCore type params value)
|
||
|
||
/--
Eta-expand a local function declaration when it is a candidate (see `etaExpandCore?`).
The declaration is updated with the extended parameters and new body; otherwise `decl`
is returned unchanged.
-/
def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
  match (← etaExpandCore? decl.type decl.params decl.value) with
  | none => return decl
  | some (params, value) => decl.update decl.type params value
|
||
|
||
/--
Eta-expand a top-level declaration whose value is `.code` when it is a candidate
(see `etaExpandCore?`). Declarations with any other value (e.g. `.extern`), and
non-candidates, are returned unchanged.
-/
def Decl.etaExpand (decl : Decl) : CompilerM Decl := do
  let .code code := decl.value | return decl
  let some (params, newCode) ← etaExpandCore? decl.type decl.params code | return decl
  return { decl with params, value := .code newCode }
|
||
|
||
end Lean.Compiler.LCNF
|