lean4-htt/src/Lean/Compiler/CompilerM.lean
2022-08-19 11:56:22 -07:00

344 lines
12 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.ForEachExpr
import Lean.Meta.Transform
import Lean.Compiler.InferType
namespace Lean.Compiler
/--
The state managed by the `CompilerM` `Monad`.
-/
structure CompilerM.State where
/--
A `LocalContext` to store local declarations from let binders
and other constructs in as we move through `Expr`s.
-/
lctx : LocalContext := {}
/--
Free variables of the let-declarations registered through `mkLetDecl`,
in creation order. `withNewScopeImp` resets it to `#[]` while a nested
scope runs, and `mkLetUsingScope` abstracts over exactly these variables.
-/
letFVars : Array Expr := #[]
/-- Next auxiliary variable suffix -/
nextIdx : Nat := 1
deriving Inhabited
/-- The compiler monad: `CoreM` extended with a mutable `CompilerM.State`. -/
abbrev CompilerM := StateRefT CompilerM.State CoreM
/--
Messages produced inside `CompilerM` are wrapped with the current environment,
the `LocalContext` stored in the `CompilerM` state, the current options, and an
empty metavariable context.
-/
instance : AddMessageContext CompilerM where
  addMessageContext msgData := do
    let env ← getEnv
    let st ← get
    let opts ← getOptions
    return MessageData.withContext { env := env, lctx := st.lctx, opts := opts, mctx := {} } msgData
/--
Type inference in `CompilerM` delegates to `InferType.inferType`, supplying the
`LocalContext` currently stored in the `CompilerM` state.
-/
instance : MonadInferType CompilerM where
  inferType e := do
    let lctx := (← get).lctx
    InferType.inferType e { lctx := lctx }
/-- The local context of `CompilerM` is the `lctx` field of its state. -/
instance : MonadLCtx CompilerM where
  getLCtx := do
    let st ← get
    pure st.lctx
/--
Register a new local declaration (name `binderName`, type `type`, binder
annotation `bi`) in the `LocalContext` held by `CompilerM`, using a freshly
generated free-variable id.
Returns the free variable that refers to the new declaration.
-/
def mkLocalDecl (binderName : Name) (type : Expr) (bi := BinderInfo.default) : CompilerM Expr := do
  let id ← mkFreshFVarId
  modify fun st =>
    let lctx' := st.lctx.mkLocalDecl id binderName type bi
    { st with lctx := lctx' }
  return Expr.fvar id
/--
Register a new let declaration `binderName : type := value` in the
`LocalContext` held by `CompilerM`, using a freshly generated free-variable id.
The new free variable is also appended to the `letFVars` of the state.
Returns the free variable that refers to the new declaration.
-/
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (nonDep : Bool) : CompilerM Expr := do
  let id ← mkFreshFVarId
  let fvar : Expr := Expr.fvar id
  modify fun st =>
    { st with lctx := st.lctx.mkLetDecl id binderName type value nonDep,
              letFVars := st.letFVars.push fvar }
  return fvar
/--
Produce the name `prefixName.<nextIdx>` for an auxiliary let declaration and
advance the `nextIdx` counter of the state by one.
-/
def mkAuxLetDeclName (prefixName := `_x) : CompilerM Name := do
  let idx := (← get).nextIdx
  modify fun st => { st with nextIdx := st.nextIdx + 1 }
  return Name.num prefixName idx
/--
Create a new auxiliary let declaration with value `e`. The name of the
declaration is obtained from `mkAuxLetDeclName prefixName`.
If `e` is already a free variable, no declaration is created and `e` itself
is returned.
Returns the free variable representing the (new or existing) declaration.
-/
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM Expr := do
  if e.isFVar then
    return e
  -- The name is allocated before the type is inferred, as in a left-to-right
  -- evaluation of the arguments of `mkLetDecl`.
  let declName ← mkAuxLetDeclName prefixName
  let type ← inferType e
  mkLetDecl declName type e (nonDep := false)
/--
Create an auxiliary let declaration with value `e` that acts as a join point.
Join points are recognizable by the `_jp` name prefix.
Returns the free variable representing the new declaration.
-/
def mkJpDecl (e : Expr) : CompilerM Expr :=
  mkAuxLetDecl e (prefixName := `_jp)
/--
Compute the maximum auxiliary let variable index that is used within `e`.
Only let binders whose name has the shape `s.i`, where `s` is a single string
component starting with `_`, are taken into account. The result is `0` when no
such binder is found.
-/
def getMaxLetVarIdx (e : Expr) : IO Nat := do
  let maxIdx ← IO.mkRef 0
  e.forEach fun sub => do
    if let .letE (.num (.str .anonymous pre) idx) .. := sub then
      if pre.get 0 == '_' then
        maxIdx.modify fun cur => Nat.max cur idx
  maxIdx.get
/--
Make sure all let-declarations have unique user-facing names.
We use this method when we want to retrieve candidates for code transformations. Examples:
let-declarations that are safe to unfold without producing code blowup, and join point detection.
Remark: user-facing names provided by users are preserved. We keep them as the prefix
of the new unique names.

The `Nat` state is the next index to use; it is incremented for every let binder visited.
-/
def ensureUniqueLetVarNamesCore (e : Expr) : StateRefT Nat CoreM Expr :=
  let renameLet (expr : Expr) : StateRefT Nat CoreM TransformStep := do
    if let .letE binderName type value body nonDep := expr then
      let idx ← modifyGet fun n => (n, n + 1)
      -- Replace an existing numeric suffix, otherwise append one.
      let freshName : Name :=
        match binderName with
        | .num p _ => .num p idx
        | _ => .num binderName idx
      return .visit (.letE freshName type value body nonDep)
    else
      return .visit expr
  Core.transform e renameLet
/--
Run `ensureUniqueLetVarNamesCore` on `e`, starting from the `nextIdx` counter of
the `CompilerM` state, and store the resulting counter back into the state.
-/
def ensureUniqueLetVarNames (e : Expr) : CompilerM Expr := do
  let startIdx := (← get).nextIdx
  let (renamed, nextIdx) ← (ensureUniqueLetVarNamesCore e).run startIdx
  modify fun st => { st with nextIdx := nextIdx }
  return renamed
/--
Walk through all consecutive lambda abstractions at the top level of `e`,
creating a new local declaration for each binder.
Returns a pair consisting of:
1. An `Array` of all the newly created free variables
2. The body of the innermost lambda visited, with all bound variables
   instantiated by those free variables
-/
def visitLambda (e : Expr) : CompilerM (Array Expr × Expr) :=
  go e #[]
where
  /-- `fvars` are the free variables created for the binders visited so far. -/
  go (e : Expr) (fvars : Array Expr) : CompilerM (Array Expr × Expr) := do
    match e with
    | .lam binderName type body binderInfo =>
      let fvar ← mkLocalDecl binderName (type.instantiateRev fvars) binderInfo
      go body (fvars.push fvar)
    | _ =>
      return (fvars, e.instantiateRev fvars)
/--
Given an application `cases` representing a `match`, and its `casesInfo`,
return a tuple consisting of:
1. The motive
2. The discriminators
3. The bodies of the match arms, obtained by running `visitLambda` on each alternative
-/
def visitMatch (cases : Expr) (casesInfo : CasesInfo) : CompilerM (Expr × Array Expr × Array Expr) := do
  let matchArgs := cases.getAppArgs
  let motive := matchArgs.get! casesInfo.motivePos
  let mut discriminators : Array Expr := #[]
  for idx in casesInfo.discrsRange do
    discriminators := discriminators.push matchArgs[idx]!
  let mut armBodies : Array Expr := #[]
  for idx in casesInfo.altsRange do
    let (_, body) ← visitLambda matchArgs[idx]!
    armBodies := armBodies.push body
  return (motive, discriminators, armBodies)
/--
Run `x` in a fresh scope: `letFVars` is emptied before `x` runs, and afterwards
(whether `x` succeeds or throws) the state from before the call is restored,
except for `nextIdx`, which keeps the value reached while running `x`.
-/
def withNewScopeImp (x : CompilerM α) : CompilerM α := do
  let outer ← get
  set { outer with letFVars := #[] }
  try
    x
  finally
    let idx := (← get).nextIdx
    set { outer with nextIdx := idx }
/--
Lift `withNewScopeImp` to any monad `m` that provides a
`MonadFunctorT CompilerM m` instance.
-/
def withNewScope [MonadFunctorT CompilerM m] (x : m α) : m α :=
monadMap (m := CompilerM) withNewScopeImp x
/--
A typeclass for `Monad`s that are able to perform a visit operation for
let binders. That is, move through a chain of consecutive let binders
and return the body of the final one.
-/
class VisitLet (m : Type → Type) where
/--
Move through consecutive top level let binders in the first argument,
applying the function in the second argument to the binder name
and the value before the local declaration for the binder is
created. The final return value is the body of the last let binder in
the chain.
-/
visitLet : Expr → (Name → Expr → m Expr) → m Expr
export VisitLet (visitLet)
/--
`CompilerM` implementation of `visitLet`: walk through the consecutive top-level
let binders of `e`. For each binder, the instantiated value is first passed to
`f` together with the binder name, and the result of `f` is used as the value of
the let declaration that is then created.
Returns the body of the last let binder with all bound variables instantiated.
-/
def visitLetImp (e : Expr) (f : Name → Expr → CompilerM Expr) : CompilerM Expr :=
  go e #[]
where
  /-- `fvars` are the free variables created for the let binders visited so far. -/
  go (e : Expr) (fvars : Array Expr) : CompilerM Expr := do
    match e with
    | .letE binderName type value body nonDep =>
      let type := type.instantiateRev fvars
      let value ← f binderName (value.instantiateRev fvars)
      let fvar ← mkLetDecl binderName type value nonDep
      go body (fvars.push fvar)
    | _ =>
      return e.instantiateRev fvars
/-- In `CompilerM`, `visitLet` is implemented by `visitLetImp`. -/
instance : VisitLet CompilerM where
  visitLet e f := visitLetImp e f
/--
Lift `visitLet` through `ReaderT`: the reader context is passed unchanged to
every invocation of the callback `f`.
-/
instance [VisitLet m] : VisitLet (ReaderT ρ m) where
  visitLet e f := fun ctx =>
    visitLet e (fun binderName value => f binderName value ctx)
/--
Reuse the `ReaderT` instance for `StateRefT'`. This relies on
`StateRefT' ω σ m` being definitionally a `ReaderT` over `m`.
-/
instance [VisitLet m] : VisitLet (StateRefT' ω σ m) := inferInstanceAs (VisitLet (ReaderT _ _))
/--
Close the current scope around `e`: abstract `e` over the `letFVars` of the
state using the `LocalContext` of `CompilerM`.
If `e` is a lambda, it is first bound to an auxiliary let declaration.
-/
def mkLetUsingScope (e : Expr) : CompilerM Expr := do
  -- In LCNF, terminal expression in a `let`-block must not be a lambda.
  let terminal ← if e.isLambda then mkAuxLetDecl e else pure e
  let st ← get
  return st.lctx.mkLambda st.letFVars terminal
/--
Shorthand for `LocalContext.mkLambda` using the `LocalContext` stored in the
`CompilerM` state.
-/
def mkLambda (xs : Array Expr) (e : Expr) : CompilerM Expr := do
  let lctx := (← get).lctx
  pure (lctx.mkLambda xs e)
/--
Given a join point `jp` of the form `fun y => body`, return `jp` unchanged when
`body` is simple (see `isSimpleLCNF`).
Otherwise, create a `let jp := fun y => body` declaration and return the free
variable for it.
-/
def mkJpDeclIfNotSimple (jp : Expr) : CompilerM Expr := do
  let bodyIsSimple ← isSimpleLCNF jp.bindingBody!
  match bodyIsSimple with
  | true =>
    -- Join point is too simple, we eagerly inline it.
    return jp
  | false =>
    mkJpDecl jp
/--
Create "jump" to join point `jp` with value `e`.
Remarks:
- If `e` is unreachable, then result is unreachable
- Add `cast` if `e`'s type is not compatible with the type expected by `jp`. It avoids `cast` on `cast`.
- It creates an auxiliary let-declaration if `e` is not a free variable.
-/
def mkJump (jp : Expr) (e : Expr) : CompilerM Expr := do
  let .forallE _ d b _ ← inferType jp | unreachable!
  if isLcUnreachable e then
    return (← mkLcUnreachable b)
  let eType ← inferType e
  -- Compute the free variable that is passed as the argument of the jump.
  let arg ←
    if compatibleTypes eType d then
      mkAuxLetDecl e
    else
      match isLcCast? e with
      | some inner =>
        -- `e` is already a cast: cast its argument directly to `d`.
        mkAuxLetDecl (← mkLcCast inner d)
      | none =>
        let x ← mkAuxLetDecl e
        mkAuxLetDecl (← mkLcCast x d)
  return (mkApp jp arg).headBeta
/--
Given a let-declaration block `e`, return a new block that jumps to `jp` at its "exit points".
`e` must contain all join point declarations used in `e`.
Example: Suppose `e` is of the form
```
let _jp.1 := fun y =>
  let _x.1 := Nat.add y y
  Nat.mul _x.1 y
casesOn _x.2
  (fun x => _jp.1 x)
  (fun x => Nat.add x x)
```
then, `attachJp e _jp.2` produces the new let-block.
```
let _jp.1 := fun y =>
  let _x.1 := Nat.add y y
  let _x.2 := Nat.mul _x.1 y
  _jp.2 _x.2
casesOn _x.2
  (fun x => _jp.1 x)
  (fun x =>
    let _x.3 := Nat.add x x
    _jp.2 _x.3)
```
If `e` contains a jump to a join point `_jp.i` not declared in `e`, we throw an exception because
an invalid block would be generated. It would be invalid because the input join point `jp` would not
be applied to `_jp.i`. Note that we could have decided to create a copy of `_jp.i` where we apply `jp` to it,
but we decided not to do it to avoid code duplication.
-/
partial def attachJp (e : Expr) (jp : Expr) : CompilerM Expr := do
withNewScope do
-- The reader context starts as the empty set of join points declared in `e`.
mkLetUsingScope (← visitLet e #[] |>.run {})
where
-- Processes a lambda in its own scope: instantiates the binders, attaches `jp`
-- to the body, and re-abstracts first the let-declarations and then the binders.
visitLambda (e : Expr) : ReaderT FVarIdSet CompilerM Expr := do
withNewScope do
let (as, e) ← Compiler.visitLambda e
let e ← mkLetUsingScope (← visitLet e #[])
mkLambda as e
-- Updates the resulting type of the `cases` application to the codomain of
-- `jp`'s type and attaches `jp` to every alternative.
visitCases (casesInfo : CasesInfo) (cases : Expr) : ReaderT FVarIdSet CompilerM Expr := do
let mut args := cases.getAppArgs
let .forallE _ _ b _ ← inferType jp | unreachable! -- jp's type is guaranteed to be an nondependent arrow
args := casesInfo.updateResultingType args b
for i in casesInfo.altsRange do
args ← args.modifyM i visitLambda
return mkAppN cases.getAppFn args
-- Walks the let-chain. `xs` are the free variables created for the binders
-- visited so far; the reader context is the set of join points declared so far.
visitLet (e : Expr) (xs : Array Expr) : ReaderT FVarIdSet CompilerM Expr := do
match e with
| .letE binderName type value body nonDep =>
let mkDecl (type value : Expr) := do
let x ← mkLetDecl binderName type value nonDep
-- Record `x` as a locally declared join point while visiting the body.
withReader (fun jps => if isJpBinderName binderName then jps.insert x.fvarId! else jps) do
visitLet body (xs.push x)
let type := type.instantiateRev xs
let value := value.instantiateRev xs
if isJpBinderName binderName then
let value ← visitLambda value
-- Recall that the resulting type of join point may change after the attachment
let type ← inferType value
mkDecl type value
else
mkDecl type value
| _ =>
-- Terminal expression of the block: an "exit point".
let e := e.instantiateRev xs
if let some fvarId ← isJump? e then
-- A jump is only valid if its target join point was declared inside this block.
unless (← read).contains fvarId do
throwError "failed to attach join point to let-block, it contains a out of scope join point"
return e
else if let some casesInfo ← isCasesApp? e then
visitCases casesInfo e
else
mkJump jp e
/--
Optional variant of `attachJp`: when `jp?` is `some jp`, return a new block that
jumps to `jp` at the "exit points" of the let-declaration block `e`; when `jp?`
is `none`, return `e` unchanged.
-/
def attachOptJp (e : Expr) (jp? : Option Expr) : CompilerM Expr :=
  match jp? with
  | some jp => attachJp e jp
  | none => pure e
end Lean.Compiler