refactor: introduce a phase separation to the IR (#12214)
This PR introduces a phase separation to the LCNF IR. This is a preparation for the merge of the old `Lean.Compiler.IR` and the new `Lean.Compiler.LCNF` framework. The change parametrizes all relevant `LCNF` data structures over a `Purity` parameter and additionally carries around proofs that the `Purity` has certain values, depending on what's required. This is done as opposed to indexing the types over `Purity` because we do (almost) never have to store the `Purity` value for phase generic structures this way.
This commit is contained in:
parent
6d370ec3c2
commit
5ce756f350
60 changed files with 1247 additions and 1047 deletions
|
|
@ -40,14 +40,14 @@ structure BuilderState where
|
|||
For this reason we carry around these kinds of bindings in this substitution and apply it whenever
|
||||
we access an fvar in the conversion.
|
||||
-/
|
||||
subst : LCNF.FVarSubst := {}
|
||||
subst : LCNF.FVarSubst .pure := {}
|
||||
|
||||
abbrev M := StateRefT BuilderState CoreM
|
||||
|
||||
instance : LCNF.MonadFVarSubst M false where
|
||||
instance : LCNF.MonadFVarSubst M .pure false where
|
||||
getSubst := return (← get).subst
|
||||
|
||||
instance : LCNF.MonadFVarSubstState M where
|
||||
instance : LCNF.MonadFVarSubstState M .pure where
|
||||
modifySubst f := modify fun s => { s with subst := f s.subst }
|
||||
|
||||
def M.run (x : M α) : CoreM α := do
|
||||
|
|
@ -102,7 +102,7 @@ def lowerLitValue (v : LCNF.LitValue) : LitVal × IRType :=
|
|||
| .uint64 v => ⟨.num (UInt64.toNat v), .uint64⟩
|
||||
| .usize v => ⟨.num (UInt64.toNat v), .usize⟩
|
||||
|
||||
def lowerArg (a : LCNF.Arg) : M Arg := do
|
||||
def lowerArg (a : LCNF.Arg .pure) : M Arg := do
|
||||
match a with
|
||||
| .fvar fvarId => getFVarValue fvarId
|
||||
| .erased | .type .. => return .erased
|
||||
|
|
@ -121,15 +121,15 @@ def lowerProj (base : VarId) (ctorInfo : CtorInfo) (field : CtorFieldInfo)
|
|||
| .erased => ⟨.erased, .erased⟩
|
||||
| .void => ⟨.erased, .void⟩
|
||||
|
||||
def lowerParam (p : LCNF.Param) : M Param := do
|
||||
def lowerParam (p : LCNF.Param .pure) : M Param := do
|
||||
let x ← bindVar p.fvarId
|
||||
let ty ← toIRType p.type
|
||||
if ty.isVoid || ty.isErased then
|
||||
Compiler.LCNF.addSubst p.fvarId .erased
|
||||
Compiler.LCNF.addSubst p.fvarId (.erased : LCNF.Arg .pure)
|
||||
return { x, borrow := p.borrow, ty }
|
||||
|
||||
mutual
|
||||
partial def lowerCode (c : LCNF.Code) : M FnBody := do
|
||||
partial def lowerCode (c : LCNF.Code .pure) : M FnBody := do
|
||||
match c with
|
||||
| .let decl k => lowerLet decl k
|
||||
| .jp decl k =>
|
||||
|
|
@ -149,7 +149,7 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
|
|||
for idx in 0...ps.size do
|
||||
let p := ps[idx]!
|
||||
if idx == info.fieldIdx then
|
||||
LCNF.addSubst p.fvarId (.fvar cases.discr)
|
||||
LCNF.addSubst p.fvarId (.fvar cases.discr : LCNF.Arg .pure)
|
||||
else
|
||||
bindErased p.fvarId
|
||||
lowerCode k
|
||||
|
|
@ -165,7 +165,7 @@ partial def lowerCode (c : LCNF.Code) : M FnBody := do
|
|||
| .unreach .. => return .unreachable
|
||||
| .fun .. => panic! "all local functions should be λ-lifted"
|
||||
|
||||
partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
|
||||
partial def lowerLet (decl : LCNF.LetDecl .pure) (k : LCNF.Code .pure) : M FnBody := do
|
||||
let value ← LCNF.normLetValue decl.value
|
||||
match value with
|
||||
| .lit litValue =>
|
||||
|
|
@ -175,7 +175,7 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
|
|||
| .proj typeName i fvarId =>
|
||||
if let some info ← hasTrivialStructure? typeName then
|
||||
if info.fieldIdx == i then
|
||||
LCNF.addSubst decl.fvarId (.fvar fvarId)
|
||||
LCNF.addSubst decl.fvarId (.fvar fvarId : LCNF.Arg .pure)
|
||||
else
|
||||
bindErased decl.fvarId
|
||||
lowerCode k
|
||||
|
|
@ -302,11 +302,11 @@ where
|
|||
else
|
||||
mkOverApplication name numParams args
|
||||
|
||||
partial def lowerAlt (discr : VarId) (a : LCNF.Alt) : M Alt := do
|
||||
partial def lowerAlt (discr : VarId) (a : LCNF.Alt .pure) : M Alt := do
|
||||
match a with
|
||||
| .alt ctorName params code =>
|
||||
let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
|
||||
let lowerParams (params : Array LCNF.Param) (fields : Array CtorFieldInfo) : M FnBody := do
|
||||
let lowerParams (params : Array (LCNF.Param .pure)) (fields : Array CtorFieldInfo) : M FnBody := do
|
||||
let rec loop (i : Nat) : M FnBody := do
|
||||
match params[i]?, fields[i]? with
|
||||
| some param, some field =>
|
||||
|
|
@ -340,7 +340,7 @@ where resultTypeForArity (type : Lean.Expr) (arity : Nat) : Lean.Expr :=
|
|||
| .const ``lcErased _ => mkConst ``lcErased
|
||||
| _ => panic! "invalid arity"
|
||||
|
||||
def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do
|
||||
def lowerDecl (d : LCNF.Decl .pure) : M (Option Decl) := do
|
||||
let params ← d.params.mapM lowerParam
|
||||
let mut resultType ← lowerResultType d.type d.params.size
|
||||
let taggedReturn := taggedReturnAttr.hasTag (← getEnv) d.name
|
||||
|
|
@ -366,7 +366,7 @@ def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do
|
|||
|
||||
end ToIR
|
||||
|
||||
def toIR (decls: Array LCNF.Decl) : CoreM (Array Decl) := do
|
||||
def toIR (decls: Array (LCNF.Decl .pure)) : CoreM (Array Decl) := do
|
||||
let mut irDecls := #[]
|
||||
for decl in decls do
|
||||
if let some irDecl ← ToIR.lowerDecl decl |>.run then
|
||||
|
|
|
|||
|
|
@ -40,14 +40,14 @@ def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do
|
|||
else
|
||||
return false
|
||||
|
||||
def eqvArg (a₁ a₂ : Arg) : EqvM Bool := do
|
||||
def eqvArg (a₁ a₂ : Arg pu) : EqvM Bool := do
|
||||
match a₁, a₂ with
|
||||
| .type e₁, .type e₂ => eqvType e₁ e₂
|
||||
| .type e₁ _, .type e₂ _ => eqvType e₁ e₂
|
||||
| .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂
|
||||
| .erased, .erased => return true
|
||||
| _, _ => return false
|
||||
|
||||
def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
|
||||
def eqvArgs (as₁ as₂ : Array (Arg pu)) : EqvM Bool := do
|
||||
if as₁.size = as₂.size then
|
||||
for a₁ in as₁, a₂ in as₂ do
|
||||
unless (← eqvArg a₁ a₂) do
|
||||
|
|
@ -56,19 +56,19 @@ def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
|
|||
else
|
||||
return false
|
||||
|
||||
def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do
|
||||
def eqvLetValue (e₁ e₂ : LetValue pu) : EqvM Bool := do
|
||||
match e₁, e₂ with
|
||||
| .lit v₁, .lit v₂ => return v₁ == v₂
|
||||
| .erased, .erased => return true
|
||||
| .proj s₁ i₁ x₁, .proj s₂ i₂ x₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂
|
||||
| .const n₁ us₁ as₁, .const n₂ us₂ as₂ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂
|
||||
| .proj s₁ i₁ x₁ _, .proj s₂ i₂ x₂ _ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂
|
||||
| .const n₁ us₁ as₁ _, .const n₂ us₂ as₂ _ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂
|
||||
| .fvar f₁ as₁, .fvar f₂ as₂ => eqvFVar f₁ f₂ <&&> eqvArgs as₁ as₂
|
||||
| _, _ => return false
|
||||
|
||||
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
|
||||
withReader (·.insert fvarId₂ fvarId₁) x
|
||||
|
||||
@[inline] def withParams (params₁ params₂ : Array Param) (x : EqvM Bool) : EqvM Bool := do
|
||||
@[inline] def withParams (params₁ params₂ : Array (Param pu)) (x : EqvM Bool) : EqvM Bool := do
|
||||
if h : params₂.size = params₁.size then
|
||||
let rec @[specialize] go (i : Nat) : EqvM Bool := do
|
||||
if h : i < params₁.size then
|
||||
|
|
@ -85,7 +85,7 @@ def eqvLetValue (e₁ e₂ : LetValue) : EqvM Bool := do
|
|||
else
|
||||
return false
|
||||
|
||||
def sortAlts (alts : Array Alt) : Array Alt :=
|
||||
def sortAlts (alts : Array (Alt pu)) : Array (Alt pu) :=
|
||||
alts.qsort fun
|
||||
| .alt .., .default .. => true
|
||||
| .alt ctorName₁ .., .alt ctorName₂ .. => Name.lt ctorName₁ ctorName₂
|
||||
|
|
@ -93,13 +93,13 @@ def sortAlts (alts : Array Alt) : Array Alt :=
|
|||
|
||||
mutual
|
||||
|
||||
partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
|
||||
partial def eqvAlts (alts₁ alts₂ : Array (Alt pu)) : EqvM Bool := do
|
||||
if alts₁.size = alts₂.size then
|
||||
let alts₁ := sortAlts alts₁
|
||||
let alts₂ := sortAlts alts₂
|
||||
for alt₁ in alts₁, alt₂ in alts₂ do
|
||||
match alt₁, alt₂ with
|
||||
| .alt ctorName₁ ps₁ k₁, .alt ctorName₂ ps₂ k₂ =>
|
||||
| .alt ctorName₁ ps₁ k₁ _, .alt ctorName₂ ps₂ k₂ _ =>
|
||||
unless ctorName₁ == ctorName₂ do return false
|
||||
unless (← withParams ps₁ ps₂ (eqv k₁ k₂)) do return false
|
||||
| .default k₁, .default k₂ => unless (← eqv k₁ k₂) do return false
|
||||
|
|
@ -108,13 +108,13 @@ partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
|
|||
else
|
||||
return false
|
||||
|
||||
partial def eqv (code₁ code₂ : Code) : EqvM Bool := do
|
||||
partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
|
||||
match code₁, code₂ with
|
||||
| .let decl₁ k₁, .let decl₂ k₂ =>
|
||||
eqvType decl₁.type decl₂.type <&&>
|
||||
eqvLetValue decl₁.value decl₂.value <&&>
|
||||
withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
|
||||
| .fun decl₁ k₁, .fun decl₂ k₂
|
||||
| .fun decl₁ k₁ _, .fun decl₂ k₂ _
|
||||
| .jp decl₁ k₁, .jp decl₂ k₂ =>
|
||||
eqvType decl₁.type decl₂.type <&&>
|
||||
withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value) <&&>
|
||||
|
|
@ -135,7 +135,7 @@ end AlphaEqv
|
|||
/--
|
||||
Return `true` if `c₁` and `c₂` are alpha equivalent.
|
||||
-/
|
||||
def Code.alphaEqv (c₁ c₂ : Code) : Bool :=
|
||||
def Code.alphaEqv (c₁ c₂ : Code pu) : Bool :=
|
||||
AlphaEqv.eqv c₁ c₂ |>.run {}
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -13,15 +13,21 @@ public section
|
|||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
builtin_initialize auxDeclCacheExt : CacheExtension Decl Name ← CacheExtension.register
|
||||
structure AuxDeclCacheKey where
|
||||
pu : Purity
|
||||
decl : Decl pu
|
||||
deriving BEq, Hashable
|
||||
|
||||
builtin_initialize auxDeclCacheExt : CacheExtension AuxDeclCacheKey Name ← CacheExtension.register
|
||||
|
||||
inductive CacheAuxDeclResult where
|
||||
| new
|
||||
| alreadyCached (declName : Name)
|
||||
|
||||
def cacheAuxDecl (decl : Decl) : CompilerM CacheAuxDeclResult := do
|
||||
def cacheAuxDecl (decl : Decl pu) : CompilerM CacheAuxDeclResult := do
|
||||
let key := { decl with name := .anonymous }
|
||||
let key ← normalizeFVarIds key
|
||||
let key := ⟨pu, key⟩
|
||||
match (← auxDeclCacheExt.find? key) with
|
||||
| some declName =>
|
||||
return .alreadyCached declName
|
||||
|
|
|
|||
|
|
@ -24,14 +24,50 @@ and the approach described in the paper
|
|||
|
||||
-/
|
||||
|
||||
structure Param where
|
||||
/--
|
||||
This type is used to index the fundamental LCNF IR data structures. Depending on its value different
|
||||
constructors are available for the different semantic phases of LCNF.
|
||||
|
||||
Notably in order to save memory we never index the IR types over `Purity`. Instead the type is
|
||||
parametrized by the phase and the individual constructors might carry a proof (that will be erased)
|
||||
that they are only allowed in a certain phase.
|
||||
-/
|
||||
inductive Purity where
|
||||
/--
|
||||
The code we are acting on is still pure, things like reordering up to value dependencies are
|
||||
acceptable.
|
||||
-/
|
||||
| pure
|
||||
/--
|
||||
The code we are acting on is to be considered generally impure, doing reorderings is potentially
|
||||
no longer legal.
|
||||
-/
|
||||
| impure
|
||||
deriving Inhabited, DecidableEq, Hashable
|
||||
|
||||
instance : ToString Purity where
|
||||
toString
|
||||
| .pure => "pure"
|
||||
| .impure => "impure"
|
||||
|
||||
@[inline]
|
||||
def Purity.withAssertPurity [Inhabited α] (is : Purity) (should : Purity)
|
||||
(k : (is = should) → α) : α :=
|
||||
if h : is = should then
|
||||
k h
|
||||
else
|
||||
panic! s!"Purity should be {should} but is {is}, this is a bug"
|
||||
|
||||
scoped macro "purity_tac" : tactic => `(tactic| first | with_reducible rfl | assumption)
|
||||
|
||||
structure Param (pu : Purity) where
|
||||
fvarId : FVarId
|
||||
binderName : Name
|
||||
type : Expr
|
||||
borrow : Bool
|
||||
deriving Inhabited, BEq
|
||||
|
||||
def Param.toExpr (p : Param) : Expr :=
|
||||
def Param.toExpr (p : Param pu) : Expr :=
|
||||
.fvar p.fvarId
|
||||
|
||||
inductive LitValue where
|
||||
|
|
@ -55,111 +91,111 @@ def LitValue.toExpr : LitValue → Expr
|
|||
| .uint64 v => .app (.const ``UInt64.ofNat []) (.lit (.natVal (UInt64.toNat v)))
|
||||
| .usize v => .app (.const ``USize.ofNat []) (.lit (.natVal (UInt64.toNat v)))
|
||||
|
||||
inductive Arg where
|
||||
inductive Arg (pu : Purity) where
|
||||
| erased
|
||||
| fvar (fvarId : FVarId)
|
||||
| type (expr : Expr)
|
||||
| type (expr : Expr) (h : pu = .pure := by purity_tac)
|
||||
deriving Inhabited, BEq, Hashable
|
||||
|
||||
def Param.toArg (p : Param) : Arg :=
|
||||
def Param.toArg (p : Param pu) : Arg pu :=
|
||||
.fvar p.fvarId
|
||||
|
||||
def Arg.toExpr (arg : Arg) : Expr :=
|
||||
def Arg.toExpr (arg : Arg pu) : Expr :=
|
||||
match arg with
|
||||
| .erased => erasedExpr
|
||||
| .fvar fvarId => .fvar fvarId
|
||||
| .type e => e
|
||||
| .type e _ => e
|
||||
|
||||
private unsafe def Arg.updateTypeImp (arg : Arg) (type' : Expr) : Arg :=
|
||||
private unsafe def Arg.updateTypeImp (arg : Arg pu) (type' : Expr) : Arg pu :=
|
||||
match arg with
|
||||
| .type ty => if ptrEq ty type' then arg else .type type'
|
||||
| .type ty _ => if ptrEq ty type' then arg else .type type'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg) (type : Expr) : Arg
|
||||
@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg pu) (type : Expr) : Arg pu
|
||||
|
||||
private unsafe def Arg.updateFVarImp (arg : Arg) (fvarId' : FVarId) : Arg :=
|
||||
private unsafe def Arg.updateFVarImp (arg : Arg pu) (fvarId' : FVarId) : Arg pu :=
|
||||
match arg with
|
||||
| .fvar fvarId => if fvarId' == fvarId then arg else .fvar fvarId'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg) (fvarId' : FVarId) : Arg
|
||||
@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg pu) (fvarId' : FVarId) : Arg pu
|
||||
|
||||
inductive LetValue where
|
||||
inductive LetValue (pu : Purity) where
|
||||
| lit (value : LitValue)
|
||||
| erased
|
||||
| proj (typeName : Name) (idx : Nat) (struct : FVarId)
|
||||
| const (declName : Name) (us : List Level) (args : Array Arg)
|
||||
| fvar (fvarId : FVarId) (args : Array Arg)
|
||||
| proj (typeName : Name) (idx : Nat) (struct : FVarId) (h : pu = .pure := by purity_tac)
|
||||
| const (declName : Name) (us : List Level) (args : Array (Arg pu)) (h : pu = .pure := by purity_tac)
|
||||
| fvar (fvarId : FVarId) (args : Array (Arg pu))
|
||||
deriving Inhabited, BEq, Hashable
|
||||
|
||||
def Arg.toLetValue (arg : Arg) : LetValue :=
|
||||
def Arg.toLetValue (arg : Arg pu) : LetValue pu :=
|
||||
match arg with
|
||||
| .fvar fvarId => .fvar fvarId #[]
|
||||
| .erased | .type .. => .erased
|
||||
|
||||
private unsafe def LetValue.updateProjImp (e : LetValue) (fvarId' : FVarId) : LetValue :=
|
||||
private unsafe def LetValue.updateProjImp (e : LetValue pu) (fvarId' : FVarId) : LetValue pu :=
|
||||
match e with
|
||||
| .proj s i fvarId => if fvarId == fvarId' then e else .proj s i fvarId'
|
||||
| .proj s i fvarId _ => if fvarId == fvarId' then e else .proj s i fvarId'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetValue.updateProjImp] opaque LetValue.updateProj! (e : LetValue) (fvarId' : FVarId) : LetValue
|
||||
@[implemented_by LetValue.updateProjImp] opaque LetValue.updateProj! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu
|
||||
|
||||
private unsafe def LetValue.updateConstImp (e : LetValue) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetValue :=
|
||||
private unsafe def LetValue.updateConstImp (e : LetValue pu) (declName' : Name) (us' : List Level) (args' : Array (Arg pu)) : LetValue pu :=
|
||||
match e with
|
||||
| .const declName us args => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args'
|
||||
| .const declName us args _ => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetValue.updateConstImp] opaque LetValue.updateConst! (e : LetValue) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetValue
|
||||
@[implemented_by LetValue.updateConstImp] opaque LetValue.updateConst! (e : LetValue pu) (declName' : Name) (us' : List Level) (args' : Array (Arg pu)) : LetValue pu
|
||||
|
||||
private unsafe def LetValue.updateFVarImp (e : LetValue) (fvarId' : FVarId) (args' : Array Arg) : LetValue :=
|
||||
private unsafe def LetValue.updateFVarImp (e : LetValue pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : LetValue pu :=
|
||||
match e with
|
||||
| .fvar fvarId args => if fvarId == fvarId' && ptrEq args args' then e else .fvar fvarId' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetValue.updateFVarImp] opaque LetValue.updateFVar! (e : LetValue) (fvarId' : FVarId) (args' : Array Arg) : LetValue
|
||||
@[implemented_by LetValue.updateFVarImp] opaque LetValue.updateFVar! (e : LetValue pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : LetValue pu
|
||||
|
||||
private unsafe def LetValue.updateArgsImp (e : LetValue) (args' : Array Arg) : LetValue :=
|
||||
private unsafe def LetValue.updateArgsImp (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu :=
|
||||
match e with
|
||||
| .const declName us args => if ptrEq args args' then e else .const declName us args'
|
||||
| .const declName us args h => if ptrEq args args' then e else .const declName us args'
|
||||
| .fvar fvarId args => if ptrEq args args' then e else .fvar fvarId args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetValue.updateArgsImp] opaque LetValue.updateArgs! (e : LetValue) (args' : Array Arg) : LetValue
|
||||
@[implemented_by LetValue.updateArgsImp] opaque LetValue.updateArgs! (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu
|
||||
|
||||
def LetValue.toExpr (e : LetValue) : Expr :=
|
||||
def LetValue.toExpr (e : LetValue pu) : Expr :=
|
||||
match e with
|
||||
| .lit v => v.toExpr
|
||||
| .erased => erasedExpr
|
||||
| .proj n i s => .proj n i (.fvar s)
|
||||
| .const n us as => mkAppN (.const n us) (as.map Arg.toExpr)
|
||||
| .proj n i s _ => .proj n i (.fvar s)
|
||||
| .const n us as _ => mkAppN (.const n us) (as.map Arg.toExpr)
|
||||
| .fvar fvarId as => mkAppN (.fvar fvarId) (as.map Arg.toExpr)
|
||||
|
||||
structure LetDecl where
|
||||
structure LetDecl (pu : Purity) where
|
||||
fvarId : FVarId
|
||||
binderName : Name
|
||||
type : Expr
|
||||
value : LetValue
|
||||
value : LetValue pu
|
||||
deriving Inhabited, BEq
|
||||
|
||||
mutual
|
||||
|
||||
inductive Alt where
|
||||
| alt (ctorName : Name) (params : Array Param) (code : Code)
|
||||
| default (code : Code)
|
||||
inductive Alt (pu : Purity) where
|
||||
| alt (ctorName : Name) (params : Array (Param pu)) (code : Code pu) (h : pu = .pure := by purity_tac)
|
||||
| default (code : Code pu)
|
||||
|
||||
inductive FunDecl where
|
||||
| mk (fvarId : FVarId) (binderName : Name) (params : Array Param) (type : Expr) (value : Code)
|
||||
inductive FunDecl (pu : Purity) where
|
||||
| mk (fvarId : FVarId) (binderName : Name) (params : Array (Param pu)) (type : Expr) (value : Code pu)
|
||||
|
||||
inductive Cases where
|
||||
| mk (typeName : Name) (resultType : Expr) (discr : FVarId) (alts : Array Alt)
|
||||
inductive Cases (pu : Purity) where
|
||||
| mk (typeName : Name) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu))
|
||||
deriving Inhabited
|
||||
|
||||
inductive Code where
|
||||
| let (decl : LetDecl) (k : Code)
|
||||
| fun (decl : FunDecl) (k : Code)
|
||||
| jp (decl : FunDecl) (k : Code)
|
||||
| jmp (fvarId : FVarId) (args : Array Arg)
|
||||
| cases (cases : Cases)
|
||||
inductive Code (pu : Purity) where
|
||||
| let (decl : LetDecl pu) (k : Code pu)
|
||||
| fun (decl : FunDecl pu) (k : Code pu) (h : pu = .pure := by purity_tac)
|
||||
| jp (decl : FunDecl pu) (k : Code pu)
|
||||
| jmp (fvarId : FVarId) (args : Array (Arg pu))
|
||||
| cases (cases : Cases pu)
|
||||
| return (fvarId : FVarId)
|
||||
| unreach (type : Expr)
|
||||
deriving Inhabited
|
||||
|
|
@ -167,99 +203,99 @@ inductive Code where
|
|||
end
|
||||
|
||||
@[inline]
|
||||
def FunDecl.fvarId : FunDecl → FVarId
|
||||
def FunDecl.fvarId : FunDecl pu → FVarId
|
||||
| .mk (fvarId := fvarId) .. => fvarId
|
||||
|
||||
@[inline]
|
||||
def FunDecl.binderName : FunDecl → Name
|
||||
def FunDecl.binderName : FunDecl pu → Name
|
||||
| .mk (binderName := binderName) .. => binderName
|
||||
|
||||
@[inline]
|
||||
def FunDecl.params : FunDecl → Array Param
|
||||
def FunDecl.params : FunDecl pu → Array (Param pu)
|
||||
| .mk (params := params) .. => params
|
||||
|
||||
@[inline]
|
||||
def FunDecl.type : FunDecl → Expr
|
||||
def FunDecl.type : FunDecl pu → Expr
|
||||
| .mk (type := type) .. => type
|
||||
|
||||
@[inline]
|
||||
def FunDecl.value : FunDecl → Code
|
||||
def FunDecl.value : FunDecl pu → Code pu
|
||||
| .mk (value := value) .. => value
|
||||
|
||||
@[inline]
|
||||
def FunDecl.updateBinderName : FunDecl → Name → FunDecl
|
||||
def FunDecl.updateBinderName : FunDecl pu → Name → FunDecl pu
|
||||
| .mk fvarId _ params type value, new =>
|
||||
.mk fvarId new params type value
|
||||
|
||||
@[inline]
|
||||
def FunDecl.toParam (decl : FunDecl) (borrow : Bool) : Param :=
|
||||
def FunDecl.toParam (decl : FunDecl pu) (borrow : Bool) : Param pu :=
|
||||
match decl with
|
||||
| .mk fvarId binderName _ type .. => ⟨fvarId, binderName, type, borrow⟩
|
||||
|
||||
@[inline]
|
||||
def Cases.typeName : Cases → Name
|
||||
def Cases.typeName : Cases pu → Name
|
||||
| .mk (typeName := typeName) .. => typeName
|
||||
|
||||
@[inline]
|
||||
def Cases.resultType : Cases → Expr
|
||||
def Cases.resultType : Cases pu → Expr
|
||||
| .mk (resultType := resultType) .. => resultType
|
||||
|
||||
@[inline]
|
||||
def Cases.discr : Cases → FVarId
|
||||
def Cases.discr : Cases pu → FVarId
|
||||
| .mk (discr := discr) .. => discr
|
||||
|
||||
@[inline]
|
||||
def Cases.alts : Cases → Array Alt
|
||||
def Cases.alts : Cases pu → Array (Alt pu)
|
||||
| .mk (alts := alts) .. => alts
|
||||
|
||||
@[inline]
|
||||
def Cases.updateAlts : Cases → Array Alt → Cases
|
||||
def Cases.updateAlts : Cases pu → Array (Alt pu) → Cases pu
|
||||
| .mk typeName resultType discr _, new =>
|
||||
.mk typeName resultType discr new
|
||||
|
||||
deriving instance Inhabited for Alt
|
||||
deriving instance Inhabited for FunDecl
|
||||
|
||||
def FunDecl.getArity (decl : FunDecl) : Nat :=
|
||||
def FunDecl.getArity (decl : FunDecl pu) : Nat :=
|
||||
decl.params.size
|
||||
|
||||
/--
|
||||
Return the constructor names that have an explicit (non-default) alternative.
|
||||
-/
|
||||
def Cases.getCtorNames (c : Cases) : NameSet :=
|
||||
def Cases.getCtorNames (c : Cases pu) : NameSet :=
|
||||
c.alts.foldl (init := {}) fun ctorNames alt =>
|
||||
match alt with
|
||||
| .default _ => ctorNames
|
||||
| .alt ctorName .. => ctorNames.insert ctorName
|
||||
|
||||
inductive CodeDecl where
|
||||
| let (decl : LetDecl)
|
||||
| fun (decl : FunDecl)
|
||||
| jp (decl : FunDecl)
|
||||
inductive CodeDecl (pu : Purity) where
|
||||
| let (decl : LetDecl pu)
|
||||
| fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac)
|
||||
| jp (decl : FunDecl pu)
|
||||
deriving Inhabited
|
||||
|
||||
def CodeDecl.fvarId : CodeDecl → FVarId
|
||||
| .let decl | .fun decl | .jp decl => decl.fvarId
|
||||
def CodeDecl.fvarId : CodeDecl pu → FVarId
|
||||
| .let decl | .fun decl _ | .jp decl => decl.fvarId
|
||||
|
||||
def attachCodeDecls (decls : Array CodeDecl) (code : Code) : Code :=
|
||||
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
|
||||
go decls.size code
|
||||
where
|
||||
go (i : Nat) (code : Code) : Code :=
|
||||
go (i : Nat) (code : Code pu) : Code pu :=
|
||||
if i > 0 then
|
||||
match decls[i-1]! with
|
||||
| .let decl => go (i-1) (.let decl code)
|
||||
| .fun decl => go (i-1) (.fun decl code)
|
||||
| .fun decl _ => go (i-1) (.fun decl code)
|
||||
| .jp decl => go (i-1) (.jp decl code)
|
||||
else
|
||||
code
|
||||
|
||||
mutual
|
||||
private unsafe def eqImp (c₁ c₂ : Code) : Bool :=
|
||||
private unsafe def eqImp (c₁ c₂ : Code pu) : Bool :=
|
||||
if ptrEq c₁ c₂ then
|
||||
true
|
||||
else match c₁, c₂ with
|
||||
| .let d₁ k₁, .let d₂ k₂ => d₁ == d₂ && eqImp k₁ k₂
|
||||
| .fun d₁ k₁, .fun d₂ k₂
|
||||
| .fun d₁ k₁ _, .fun d₂ k₂ _
|
||||
| .jp d₁ k₁, .jp d₂ k₂ => eqFunDecl d₁ d₂ && eqImp k₁ k₂
|
||||
| .cases c₁, .cases c₂ => eqCases c₁ c₂
|
||||
| .jmp j₁ as₁, .jmp j₂ as₂ => j₁ == j₂ && as₁ == as₂
|
||||
|
|
@ -267,7 +303,7 @@ mutual
|
|||
| .unreach t₁, .unreach t₂ => t₁ == t₂
|
||||
| _, _ => false
|
||||
|
||||
private unsafe def eqFunDecl (d₁ d₂ : FunDecl) : Bool :=
|
||||
private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool :=
|
||||
if ptrEq d₁ d₂ then
|
||||
true
|
||||
else
|
||||
|
|
@ -275,62 +311,62 @@ mutual
|
|||
d₁.params == d₂.params && d₁.type == d₂.type &&
|
||||
eqImp d₁.value d₂.value
|
||||
|
||||
private unsafe def eqCases (c₁ c₂ : Cases) : Bool :=
|
||||
private unsafe def eqCases (c₁ c₂ : Cases pu) : Bool :=
|
||||
c₁.resultType == c₂.resultType && c₁.discr == c₂.discr &&
|
||||
c₁.typeName == c₂.typeName && c₁.alts.isEqv c₂.alts eqAlt
|
||||
|
||||
private unsafe def eqAlt (a₁ a₂ : Alt) : Bool :=
|
||||
private unsafe def eqAlt (a₁ a₂ : Alt pu) : Bool :=
|
||||
match a₁, a₂ with
|
||||
| .default k₁, .default k₂ => eqImp k₁ k₂
|
||||
| .alt c₁ ps₁ k₁, .alt c₂ ps₂ k₂ => c₁ == c₂ && ps₁ == ps₂ && eqImp k₁ k₂
|
||||
| .alt c₁ ps₁ k₁ _, .alt c₂ ps₂ k₂ _ => c₁ == c₂ && ps₁ == ps₂ && eqImp k₁ k₂
|
||||
| _, _ => false
|
||||
end
|
||||
|
||||
@[implemented_by eqImp] protected opaque Code.beq : Code → Code → Bool
|
||||
@[implemented_by eqImp] protected opaque Code.beq : Code pu → Code pu → Bool
|
||||
|
||||
instance : BEq Code where
|
||||
instance : BEq (Code pu) where
|
||||
beq := Code.beq
|
||||
|
||||
@[implemented_by eqFunDecl] protected opaque FunDecl.beq : FunDecl → FunDecl → Bool
|
||||
@[implemented_by eqFunDecl] protected opaque FunDecl.beq : FunDecl pu → FunDecl pu → Bool
|
||||
|
||||
instance : BEq FunDecl where
|
||||
instance : BEq (FunDecl pu) where
|
||||
beq := FunDecl.beq
|
||||
|
||||
def Alt.getCode : Alt → Code
|
||||
def Alt.getCode : Alt pu → Code pu
|
||||
| .default k => k
|
||||
| .alt _ _ k => k
|
||||
| .alt _ _ k _ => k
|
||||
|
||||
def Alt.getParams : Alt → Array Param
|
||||
def Alt.getParams : Alt pu → Array (Param pu)
|
||||
| .default _ => #[]
|
||||
| .alt _ ps _ => ps
|
||||
| .alt _ ps _ _ => ps
|
||||
|
||||
def Alt.forCodeM [Monad m] (alt : Alt) (f : Code → m Unit) : m Unit := do
|
||||
def Alt.forCodeM [Monad m] (alt : Alt pu) (f : Code pu → m Unit) : m Unit := do
|
||||
match alt with
|
||||
| .default k => f k
|
||||
| .alt _ _ k => f k
|
||||
| .alt _ _ k _ => f k
|
||||
|
||||
private unsafe def updateAltCodeImp (alt : Alt) (k' : Code) : Alt :=
|
||||
private unsafe def updateAltCodeImp (alt : Alt pu) (k' : Code pu) : Alt pu :=
|
||||
match alt with
|
||||
| .default k => if ptrEq k k' then alt else .default k'
|
||||
| .alt ctorName ps k => if ptrEq k k' then alt else .alt ctorName ps k'
|
||||
| .alt ctorName ps k _ => if ptrEq k k' then alt else .alt ctorName ps k'
|
||||
|
||||
@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt) (c : Code) : Alt
|
||||
@[implemented_by updateAltCodeImp] opaque Alt.updateCode (alt : Alt pu) (c : Code pu) : Alt pu
|
||||
|
||||
private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Alt :=
|
||||
private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Code pu) : Alt pu :=
|
||||
match alt with
|
||||
| .alt ctorName ps k => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k'
|
||||
| .alt ctorName ps k _ => if ptrEq k k' && ptrEq ps ps' then alt else .alt ctorName ps' k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt) (ps' : Array Param) (k' : Code) : Alt
|
||||
@[implemented_by updateAltImp] opaque Alt.updateAlt! (alt : Alt pu) (ps' : Array (Param pu)) (k' : Code pu) : Alt pu
|
||||
|
||||
@[inline] private unsafe def updateAltsImp (c : Code) (alts : Array Alt) : Code :=
|
||||
@[inline] private unsafe def updateAltsImp (c : Code pu) (alts : Array (Alt pu)) : Code pu :=
|
||||
match c with
|
||||
| .cases cs => if ptrEq cs.alts alts then c else .cases <| cs.updateAlts alts
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateAltsImp] opaque Code.updateAlts! (c : Code) (alts : Array Alt) : Code
|
||||
@[implemented_by updateAltsImp] opaque Code.updateAlts! (c : Code pu) (alts : Array (Alt pu)) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateCasesImp (c : Code) (resultType : Expr) (discr : FVarId) (alts : Array Alt) : Code :=
|
||||
@[inline] private unsafe def updateCasesImp (c : Code pu) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) : Code pu :=
|
||||
match c with
|
||||
| .cases cs =>
|
||||
if ptrEq cs.alts alts && ptrEq cs.resultType resultType && cs.discr == discr then
|
||||
|
|
@ -339,54 +375,54 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al
|
|||
.cases <| ⟨cs.typeName, resultType, discr, alts⟩
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateCasesImp] opaque Code.updateCases! (c : Code) (resultType : Expr) (discr : FVarId) (alts : Array Alt) : Code
|
||||
@[implemented_by updateCasesImp] opaque Code.updateCases! (c : Code pu) (resultType : Expr) (discr : FVarId) (alts : Array (Alt pu)) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateLetImp (c : Code) (decl' : LetDecl) (k' : Code) : Code :=
|
||||
@[inline] private unsafe def updateLetImp (c : Code pu) (decl' : LetDecl pu) (k' : Code pu) : Code pu :=
|
||||
match c with
|
||||
| .let decl k => if ptrEq k k' && ptrEq decl decl' then c else .let decl' k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateLetImp] opaque Code.updateLet! (c : Code) (decl' : LetDecl) (k' : Code) : Code
|
||||
@[implemented_by updateLetImp] opaque Code.updateLet! (c : Code pu) (decl' : LetDecl pu) (k' : Code pu) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateContImp (c : Code) (k' : Code) : Code :=
|
||||
@[inline] private unsafe def updateContImp (c : Code pu) (k' : Code pu) : Code pu :=
|
||||
match c with
|
||||
| .let decl k => if ptrEq k k' then c else .let decl k'
|
||||
| .fun decl k => if ptrEq k k' then c else .fun decl k'
|
||||
| .fun decl k _ => if ptrEq k k' then c else .fun decl k'
|
||||
| .jp decl k => if ptrEq k k' then c else .jp decl k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code) (k' : Code) : Code
|
||||
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateFunImp (c : Code) (decl' : FunDecl) (k' : Code) : Code :=
|
||||
@[inline] private unsafe def updateFunImp (c : Code pu) (decl' : FunDecl pu) (k' : Code pu) : Code pu :=
|
||||
match c with
|
||||
| .fun decl k => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k'
|
||||
| .fun decl k _ => if ptrEq k k' && ptrEq decl decl' then c else .fun decl' k'
|
||||
| .jp decl k => if ptrEq k k' && ptrEq decl decl' then c else .jp decl' k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateFunImp] opaque Code.updateFun! (c : Code) (decl' : FunDecl) (k' : Code) : Code
|
||||
@[implemented_by updateFunImp] opaque Code.updateFun! (c : Code pu) (decl' : FunDecl pu) (k' : Code pu) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateReturnImp (c : Code) (fvarId' : FVarId) : Code :=
|
||||
@[inline] private unsafe def updateReturnImp (c : Code pu) (fvarId' : FVarId) : Code pu :=
|
||||
match c with
|
||||
| .return fvarId => if fvarId == fvarId' then c else .return fvarId'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code
|
||||
@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code pu) (fvarId' : FVarId) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code :=
|
||||
@[inline] private unsafe def updateJmpImp (c : Code pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : Code pu :=
|
||||
match c with
|
||||
| .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code
|
||||
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code pu) (fvarId' : FVarId) (args' : Array (Arg pu)) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateUnreachImp (c : Code) (type' : Expr) : Code :=
|
||||
@[inline] private unsafe def updateUnreachImp (c : Code pu) (type' : Expr) : Code pu :=
|
||||
match c with
|
||||
| .unreach type => if ptrEq type type' then c else .unreach type'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateUnreachImp] opaque Code.updateUnreach! (c : Code) (type' : Expr) : Code
|
||||
@[implemented_by updateUnreachImp] opaque Code.updateUnreach! (c : Code pu) (type' : Expr) : Code pu
|
||||
|
||||
private unsafe def updateParamCoreImp (p : Param) (type : Expr) : Param :=
|
||||
private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu :=
|
||||
if ptrEq type p.type then
|
||||
p
|
||||
else
|
||||
|
|
@ -397,9 +433,9 @@ Low-level update `Param` function. It does not update the local context.
|
|||
Consider using `Param.update : Param → Expr → CompilerM Param` if you want the local context
|
||||
to be updated.
|
||||
-/
|
||||
@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param) (type : Expr) : Param
|
||||
@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param pu) (type : Expr) : Param pu
|
||||
|
||||
private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : LetValue) : LetDecl :=
|
||||
private unsafe def updateLetDeclCoreImp (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : LetDecl pu :=
|
||||
if ptrEq type decl.type && ptrEq value decl.value then
|
||||
decl
|
||||
else
|
||||
|
|
@ -410,9 +446,9 @@ Low-level update `LetDecl` function. It does not update the local context.
|
|||
Consider using `LetDecl.update : LetDecl → Expr → Expr → CompilerM LetDecl` if you want the local context
|
||||
to be updated.
|
||||
-/
|
||||
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : LetValue) : LetDecl
|
||||
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : LetDecl pu
|
||||
|
||||
private unsafe def updateFunDeclCoreImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl :=
|
||||
private unsafe def updateFunDeclCoreImp (decl: FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : FunDecl pu :=
|
||||
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
|
||||
decl
|
||||
else
|
||||
|
|
@ -423,9 +459,9 @@ Low-level update `FunDecl` function. It does not update the local context.
|
|||
Consider using `FunDecl.update : LetDecl → Expr → Array Param → Code → CompilerM FunDecl` if you want the local context
|
||||
to be updated.
|
||||
-/
|
||||
@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl
|
||||
@[implemented_by updateFunDeclCoreImp] opaque FunDecl.updateCore (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : FunDecl pu
|
||||
|
||||
def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
|
||||
def Cases.extractAlt! (cases : Cases pu) (ctorName : Name) : Alt pu × Cases pu :=
|
||||
let found i := (cases.alts[i]!, cases.updateAlts (cases.alts.eraseIdx! i))
|
||||
if let some i := cases.alts.findFinIdx? fun | .alt ctorName' .. => ctorName == ctorName' | _ => false then
|
||||
found i
|
||||
|
|
@ -434,34 +470,34 @@ def Cases.extractAlt! (cases : Cases) (ctorName : Name) : Alt × Cases :=
|
|||
else
|
||||
unreachable!
|
||||
|
||||
def Alt.mapCodeM [Monad m] (alt : Alt) (f : Code → m Code) : m Alt := do
|
||||
def Alt.mapCodeM [Monad m] (alt : Alt pu) (f : Code pu → m (Code pu)) : m (Alt pu) := do
|
||||
return alt.updateCode (← f alt.getCode)
|
||||
|
||||
def Code.isDecl : Code → Bool
|
||||
def Code.isDecl : Code pu → Bool
|
||||
| .let .. | .fun .. | .jp .. => true
|
||||
| _ => false
|
||||
|
||||
def Code.isFun : Code → Bool
|
||||
def Code.isFun : Code pu → Bool
|
||||
| .fun .. => true
|
||||
| _ => false
|
||||
|
||||
def Code.isReturnOf : Code → FVarId → Bool
|
||||
def Code.isReturnOf : Code pu → FVarId → Bool
|
||||
| .return fvarId, fvarId' => fvarId == fvarId'
|
||||
| _, _ => false
|
||||
|
||||
partial def Code.size (c : Code) : Nat :=
|
||||
partial def Code.size (c : Code pu) : Nat :=
|
||||
go c 0
|
||||
where
|
||||
go (c : Code) (n : Nat) : Nat :=
|
||||
go (c : Code pu) (n : Nat) : Nat :=
|
||||
match c with
|
||||
| .let _ k => go k (n+1)
|
||||
| .jp decl k | .fun decl k => go k <| go decl.value n
|
||||
| .jp decl k | .fun decl k _ => go k <| go decl.value n
|
||||
| .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1)
|
||||
| .jmp .. => n+1
|
||||
| .return .. | unreach .. => n -- `return` & `unreach` have weight zero
|
||||
|
||||
/-- Return true iff `c.size ≤ n` -/
|
||||
partial def Code.sizeLe (c : Code) (n : Nat) : Bool :=
|
||||
partial def Code.sizeLe (c : Code pu) (n : Nat) : Bool :=
|
||||
match go c |>.run 0 with
|
||||
| .ok .. => true
|
||||
| .error .. => false
|
||||
|
|
@ -470,26 +506,26 @@ where
|
|||
modify (·+1)
|
||||
unless (← get) <= n do throw ()
|
||||
|
||||
go (c : Code) : EStateM Unit Nat Unit := do
|
||||
go (c : Code pu) : EStateM Unit Nat Unit := do
|
||||
match c with
|
||||
| .let _ k => inc; go k
|
||||
| .jp decl k | .fun decl k => inc; go decl.value; go k
|
||||
| .jp decl k | .fun decl k _ => inc; go decl.value; go k
|
||||
| .cases c => inc; c.alts.forM fun alt => go alt.getCode
|
||||
| .jmp .. => inc
|
||||
| .return .. | unreach .. => return ()
|
||||
|
||||
partial def Code.forM [Monad m] (c : Code) (f : Code → m Unit) : m Unit :=
|
||||
partial def Code.forM [Monad m] (c : Code pu) (f : Code pu → m Unit) : m Unit :=
|
||||
go c
|
||||
where
|
||||
go (c : Code) : m Unit := do
|
||||
go (c : Code pu) : m Unit := do
|
||||
f c
|
||||
match c with
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value; go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .unreach .. | .return .. | .jmp .. => return ()
|
||||
|
||||
partial def Code.instantiateValueLevelParams (code : Code) (levelParams : List Name) (us : List Level) : Code :=
|
||||
partial def Code.instantiateValueLevelParams (code : Code .pure) (levelParams : List Name) (us : List Level) : Code .pure :=
|
||||
instCode code
|
||||
where
|
||||
instLevel (u : Level) :=
|
||||
|
|
@ -498,67 +534,67 @@ where
|
|||
instExpr (e : Expr) :=
|
||||
e.instantiateLevelParamsNoCache levelParams us
|
||||
|
||||
instParams (ps : Array Param) :=
|
||||
instParams (ps : Array (Param .pure)) :=
|
||||
ps.mapMono fun p => p.updateCore (instExpr p.type)
|
||||
|
||||
instAlt (alt : Alt) :=
|
||||
instAlt (alt : Alt .pure) :=
|
||||
match alt with
|
||||
| .default k => alt.updateCode (instCode k)
|
||||
| .alt _ ps k => alt.updateAlt! (instParams ps) (instCode k)
|
||||
| .alt _ ps k _ => alt.updateAlt! (instParams ps) (instCode k)
|
||||
|
||||
instArg (arg : Arg) : Arg :=
|
||||
instArg (arg : Arg .pure) : Arg .pure :=
|
||||
match arg with
|
||||
| .type e => arg.updateType! (instExpr e)
|
||||
| .type e _ => arg.updateType! (instExpr e)
|
||||
| .fvar .. | .erased => arg
|
||||
|
||||
instLetValue (e : LetValue) : LetValue :=
|
||||
instLetValue (e : LetValue .pure) : LetValue .pure :=
|
||||
match e with
|
||||
| .const declName vs args => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg)
|
||||
| .const declName vs args _ => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg)
|
||||
| .fvar fvarId args => e.updateFVar! fvarId (args.mapMono instArg)
|
||||
| .proj .. | .lit .. | .erased => e
|
||||
|
||||
instLetDecl (decl : LetDecl) :=
|
||||
instLetDecl (decl : LetDecl .pure) :=
|
||||
decl.updateCore (instExpr decl.type) (instLetValue decl.value)
|
||||
|
||||
instFunDecl (decl : FunDecl) :=
|
||||
instFunDecl (decl : FunDecl .pure) :=
|
||||
decl.updateCore (instExpr decl.type) (instParams decl.params) (instCode decl.value)
|
||||
|
||||
instCode (code : Code) :=
|
||||
instCode (code : Code .pure) :=
|
||||
match code with
|
||||
| .let decl k => code.updateLet! (instLetDecl decl) (instCode k)
|
||||
| .jp decl k | .fun decl k => code.updateFun! (instFunDecl decl) (instCode k)
|
||||
| .jp decl k | .fun decl k _ => code.updateFun! (instFunDecl decl) (instCode k)
|
||||
| .cases c => code.updateCases! (instExpr c.resultType) c.discr (c.alts.mapMono instAlt)
|
||||
| .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instArg)
|
||||
| .return .. => code
|
||||
| .unreach type => code.updateUnreach! (instExpr type)
|
||||
|
||||
inductive DeclValue where
|
||||
| code (code : Code)
|
||||
inductive DeclValue (pu : Purity) where
|
||||
| code (code : Code pu)
|
||||
| extern (externAttrData : ExternAttrData)
|
||||
deriving Inhabited, BEq
|
||||
|
||||
partial def DeclValue.size : DeclValue → Nat
|
||||
partial def DeclValue.size : DeclValue pu → Nat
|
||||
| .code c => c.size
|
||||
| .extern .. => 0
|
||||
|
||||
def DeclValue.mapCode (f : Code → Code) : DeclValue → DeclValue :=
|
||||
def DeclValue.mapCode (f : Code pu → Code pu) : DeclValue pu → DeclValue pu :=
|
||||
fun
|
||||
| .code c => .code (f c)
|
||||
| .extern e => .extern e
|
||||
|
||||
def DeclValue.mapCodeM [Monad m] (f : Code → m Code) : DeclValue → m DeclValue :=
|
||||
def DeclValue.mapCodeM [Monad m] (f : Code pu → m (Code pu)) : DeclValue pu → m (DeclValue pu) :=
|
||||
fun v => do
|
||||
match v with
|
||||
| .code c => return .code (← f c)
|
||||
| .extern .. => return v
|
||||
|
||||
def DeclValue.forCodeM [Monad m] (f : Code → m Unit) : DeclValue → m Unit :=
|
||||
def DeclValue.forCodeM [Monad m] (f : Code pu → m Unit) : DeclValue pu → m Unit :=
|
||||
fun v => do
|
||||
match v with
|
||||
| .code c => f c
|
||||
| .extern .. => return ()
|
||||
|
||||
def DeclValue.isCodeAndM [Monad m] (v : DeclValue) (f : Code → m Bool) : m Bool :=
|
||||
def DeclValue.isCodeAndM [Monad m] (v : DeclValue pu) (f : Code pu → m Bool) : m Bool :=
|
||||
match v with
|
||||
| .code c => f c
|
||||
| .extern .. => pure false
|
||||
|
|
@ -566,7 +602,7 @@ def DeclValue.isCodeAndM [Monad m] (v : DeclValue) (f : Code → m Bool) : m Boo
|
|||
/--
|
||||
Declaration being processed by the Lean to Lean compiler passes.
|
||||
-/
|
||||
structure Decl where
|
||||
structure Decl (pu : Purity) where
|
||||
/--
|
||||
The name of the declaration from the `Environment` it came from
|
||||
-/
|
||||
|
|
@ -584,12 +620,12 @@ structure Decl where
|
|||
/--
|
||||
Parameters.
|
||||
-/
|
||||
params : Array Param
|
||||
params : Array (Param pu)
|
||||
/--
|
||||
The body of the declaration, usually changes as it progresses
|
||||
through compiler passes.
|
||||
-/
|
||||
value : DeclValue
|
||||
value : DeclValue pu
|
||||
/--
|
||||
We set this flag to true during LCNF conversion. When we receive
|
||||
a block of functions to be compiled, we set this flag to `true`
|
||||
|
|
@ -631,31 +667,37 @@ structure Decl where
|
|||
inlineAttr? : Option InlineAttributeKind
|
||||
deriving Inhabited, BEq
|
||||
|
||||
def Decl.size (decl : Decl) : Nat :=
|
||||
def Decl.size (decl : Decl pu) : Nat :=
|
||||
decl.value.size
|
||||
|
||||
def Decl.getArity (decl : Decl) : Nat :=
|
||||
def Decl.getArity (decl : Decl pu) : Nat :=
|
||||
decl.params.size
|
||||
|
||||
def Decl.inlineAttr (decl : Decl) : Bool :=
|
||||
def Decl.inlineAttr (decl : Decl pu) : Bool :=
|
||||
decl.inlineAttr? matches some .inline
|
||||
|
||||
def Decl.noinlineAttr (decl : Decl) : Bool :=
|
||||
def Decl.noinlineAttr (decl : Decl pu) : Bool :=
|
||||
decl.inlineAttr? matches some .noinline
|
||||
|
||||
def Decl.inlineIfReduceAttr (decl : Decl) : Bool :=
|
||||
def Decl.inlineIfReduceAttr (decl : Decl pu) : Bool :=
|
||||
decl.inlineAttr? matches some .inlineIfReduce
|
||||
|
||||
def Decl.alwaysInlineAttr (decl : Decl) : Bool :=
|
||||
def Decl.alwaysInlineAttr (decl : Decl pu) : Bool :=
|
||||
decl.inlineAttr? matches some .alwaysInline
|
||||
|
||||
/-- Return `true` if the given declaration has been annotated with `[inline]`, `[inline_if_reduce]`, `[macro_inline]`, or `[always_inline]` -/
|
||||
def Decl.inlineable (decl : Decl) : Bool :=
|
||||
def Decl.inlineable (decl : Decl pu) : Bool :=
|
||||
match decl.inlineAttr? with
|
||||
| some .noinline => false
|
||||
| some _ => true
|
||||
| none => false
|
||||
|
||||
def Decl.castPurity! (decl : Decl pu1) (pu2 : Purity) : Decl pu2 :=
|
||||
if h : pu1 = pu2 then
|
||||
h ▸ decl
|
||||
else
|
||||
panic! s!"Purity {pu1} does not match {pu2}, this is a bug"
|
||||
|
||||
/--
|
||||
Return `some i` if `decl` is of the form
|
||||
```
|
||||
|
|
@ -669,21 +711,21 @@ That is, `f` is a sequence of declarations followed by a `cases` on the paramete
|
|||
We use this function to decide whether we should inline a declaration tagged with
|
||||
`[inline_if_reduce]` or not.
|
||||
-/
|
||||
def Decl.isCasesOnParam? (decl : Decl) : Option Nat :=
|
||||
def Decl.isCasesOnParam? (decl : Decl pu) : Option Nat :=
|
||||
match decl.value with
|
||||
| .code c => go c
|
||||
| .extern .. => none
|
||||
where
|
||||
go (code : Code) : Option Nat :=
|
||||
go {pu : Purity} (code : Code pu) : Option Nat :=
|
||||
match code with
|
||||
| .let _ k | .jp _ k | .fun _ k => go k
|
||||
| .let _ k | .jp _ k | .fun _ k _ => go k
|
||||
| .cases c => decl.params.findIdx? fun param => param.fvarId == c.discr
|
||||
| _ => none
|
||||
|
||||
def Decl.instantiateTypeLevelParams (decl : Decl) (us : List Level) : Expr :=
|
||||
def Decl.instantiateTypeLevelParams (decl : Decl pu) (us : List Level) : Expr :=
|
||||
decl.type.instantiateLevelParamsNoCache decl.levelParams us
|
||||
|
||||
def Decl.instantiateParamsLevelParams (decl : Decl) (us : List Level) : Array Param :=
|
||||
def Decl.instantiateParamsLevelParams (decl : Decl pu) (us : List Level) : Array (Param pu) :=
|
||||
decl.params.mapMono fun param => param.updateCore (param.type.instantiateLevelParamsNoCache decl.levelParams us)
|
||||
|
||||
/--
|
||||
|
|
@ -700,7 +742,7 @@ def hasLocalInst (type : Expr) : CoreM Bool := do
|
|||
/--
|
||||
Return `true` if `decl` is supposed to be inlined/specialized.
|
||||
-/
|
||||
def Decl.isTemplateLike (decl : Decl) : CoreM Bool := do
|
||||
def Decl.isTemplateLike (decl : Decl pu) : CoreM Bool := do
|
||||
let env ← getEnv
|
||||
if ← hasLocalInst decl.type then
|
||||
return true -- `decl` applications will be specialized
|
||||
|
|
@ -721,40 +763,40 @@ private partial def collectType (e : Expr) : FVarIdHashSet → FVarIdHashSet :=
|
|||
| .proj .. | .letE .. => unreachable!
|
||||
| _ => id
|
||||
|
||||
private def collectArg (arg : Arg) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
private def collectArg (arg : Arg pu) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
match arg with
|
||||
| .erased => s
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
| .type e => collectType e s
|
||||
| .type e _ => collectType e s
|
||||
|
||||
private def collectArgs (args : Array Arg) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
private def collectArgs (args : Array (Arg pu)) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
args.foldl (init := s) fun s arg => collectArg arg s
|
||||
|
||||
private def collectLetValue (e : LetValue) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
private def collectLetValue (e : LetValue pu) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
match e with
|
||||
| .fvar fvarId args => collectArgs args <| s.insert fvarId
|
||||
| .const _ _ args => collectArgs args s
|
||||
| .proj _ _ fvarId => s.insert fvarId
|
||||
| .const _ _ args _ => collectArgs args s
|
||||
| .proj _ _ fvarId _ => s.insert fvarId
|
||||
| .lit .. | .erased => s
|
||||
|
||||
private partial def collectParams (ps : Array Param) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
private partial def collectParams (ps : Array (Param pu)) (s : FVarIdHashSet) : FVarIdHashSet :=
|
||||
ps.foldl (init := s) fun s p => collectType p.type s
|
||||
|
||||
mutual
|
||||
partial def FunDecl.collectUsed (decl : FunDecl) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
|
||||
partial def FunDecl.collectUsed (decl : FunDecl pu) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
|
||||
decl.value.collectUsed <| collectParams decl.params <| collectType decl.type s
|
||||
|
||||
partial def Code.collectUsed (code : Code) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
|
||||
partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarIdHashSet :=
|
||||
match code with
|
||||
| .let decl k => k.collectUsed <| collectLetValue decl.value <| collectType decl.type s
|
||||
| .jp decl k | .fun decl k => k.collectUsed <| decl.collectUsed s
|
||||
| .jp decl k | .fun decl k _ => k.collectUsed <| decl.collectUsed s
|
||||
| .cases c =>
|
||||
let s := s.insert c.discr
|
||||
let s := collectType c.resultType s
|
||||
c.alts.foldl (init := s) fun s alt =>
|
||||
match alt with
|
||||
| .default k => k.collectUsed s
|
||||
| .alt _ ps k => k.collectUsed <| collectParams ps s
|
||||
| .alt _ ps k _ => k.collectUsed <| collectParams ps s
|
||||
| .return fvarId => s.insert fvarId
|
||||
| .unreach type => collectType type s
|
||||
| .jmp fvarId args => collectArgs args <| s.insert fvarId
|
||||
|
|
@ -771,7 +813,7 @@ This is an overapproximation, and relies on the fact that our frontend
|
|||
computes strongly connected components.
|
||||
See comment at `recursive` field.
|
||||
-/
|
||||
partial def markRecDecls (decls : Array Decl) : Array Decl :=
|
||||
partial def markRecDecls (decls : Array (Decl pu)) : Array (Decl pu) :=
|
||||
let (_, isRec) := go |>.run {}
|
||||
decls.map fun decl =>
|
||||
if isRec.contains decl.name then
|
||||
|
|
@ -779,13 +821,13 @@ partial def markRecDecls (decls : Array Decl) : Array Decl :=
|
|||
else
|
||||
decl
|
||||
where
|
||||
visit (code : Code) : StateM NameSet Unit := do
|
||||
visit {pu : Purity} (code : Code pu) : StateM NameSet Unit := do
|
||||
match code with
|
||||
| .jp decl k | .fun decl k => visit decl.value; visit k
|
||||
| .jp decl k | .fun decl k _ => visit decl.value; visit k
|
||||
| .cases c => c.alts.forM fun alt => visit alt.getCode
|
||||
| .unreach .. | .jmp .. | .return .. => return ()
|
||||
| .let decl k =>
|
||||
if let .const declName _ _ := decl.value then
|
||||
if let .const declName _ _ _ := decl.value then
|
||||
if decls.any (·.name == declName) then
|
||||
modify fun s => s.insert declName
|
||||
visit k
|
||||
|
|
@ -793,13 +835,13 @@ where
|
|||
go : StateM NameSet Unit :=
|
||||
decls.forM (·.value.forCodeM visit)
|
||||
|
||||
def instantiateRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array Arg) : Expr :=
|
||||
def instantiateRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array (Arg pu)) : Expr :=
|
||||
if !e.hasLooseBVars then
|
||||
e
|
||||
else
|
||||
e.instantiateRange beginIdx endIdx (args.map (·.toExpr))
|
||||
|
||||
def instantiateRevRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array Arg) : Expr :=
|
||||
def instantiateRevRangeArgs (e : Expr) (beginIdx endIdx : Nat) (args : Array (Arg pu)) : Expr :=
|
||||
if !e.hasLooseBVars then
|
||||
e
|
||||
else
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ namespace Lean.Compiler.LCNF
|
|||
|
||||
/-- Helper class for lifting `CompilerM.codeBind` -/
|
||||
class MonadCodeBind (m : Type → Type) where
|
||||
codeBind : (c : Code) → (f : FVarId → m Code) → m Code
|
||||
codeBind : {pu : Purity} → (c : Code pu) → (f : FVarId → m (Code pu)) → m (Code pu)
|
||||
|
||||
/--
|
||||
Return code that is equivalent to `c >>= f`. That is, executes `c`, and then `f x`, where
|
||||
|
|
@ -25,16 +25,17 @@ an invalid block would be generated. It would be invalid because `f` would not
|
|||
be applied to `jp_i`. Note that, we could have decided to create a copy of `jp_i` where we apply `f` to it,
|
||||
by we decided to not do it to avoid code duplication.
|
||||
-/
|
||||
abbrev Code.bind [MonadCodeBind m] (c : Code) (f : FVarId → m Code) : m Code :=
|
||||
abbrev Code.bind [MonadCodeBind m] (c : Code pu) (f : FVarId → m (Code pu)) : m (Code pu) :=
|
||||
MonadCodeBind.codeBind c f
|
||||
|
||||
partial def CompilerM.codeBind (c : Code) (f : FVarId → CompilerM Code) : CompilerM Code := do
|
||||
partial def CompilerM.codeBind (c : Code pu) (f : FVarId → CompilerM (Code pu)) :
|
||||
CompilerM (Code pu) := do
|
||||
go c |>.run {}
|
||||
where
|
||||
go (c : Code) : ReaderT FVarIdSet CompilerM Code := do
|
||||
go (c : Code pu) : ReaderT FVarIdSet CompilerM (Code pu) := do
|
||||
match c with
|
||||
| .let decl k => return .let decl (← go k)
|
||||
| .fun decl k => return .fun decl (← go k)
|
||||
| .fun decl k _ => return .fun decl (← go k)
|
||||
| .jp decl k =>
|
||||
let value ← go decl.value
|
||||
let type ← value.inferParamType decl.params
|
||||
|
|
@ -43,7 +44,7 @@ where
|
|||
return .jp decl (← go k)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapM fun
|
||||
| .alt ctorName params k => return .alt ctorName params (← go k)
|
||||
| .alt ctorName params k _ => return .alt ctorName params (← go k)
|
||||
| .default k => return .default (← go k)
|
||||
if alts.isEmpty then
|
||||
throwError "`Code.bind` failed, empty `cases` found"
|
||||
|
|
@ -60,7 +61,7 @@ where
|
|||
This code is not very efficient, we could ask caller to provide the type of `c >>= f`,
|
||||
but this is more convenient, and this case is seldom reached.
|
||||
-/
|
||||
let auxParam ← mkAuxParam type
|
||||
let auxParam ← mkAuxParam (pu := pu) type
|
||||
let k ← f auxParam.fvarId
|
||||
let typeNew ← k.inferType
|
||||
eraseCode k
|
||||
|
|
@ -81,10 +82,10 @@ Create new parameters for the given arrow type.
|
|||
Example: if `type` is `Nat → Bool → Int`, the result is
|
||||
an array containing two new parameters with types `Nat` and `Bool`.
|
||||
-/
|
||||
partial def mkNewParams (type : Expr) : CompilerM (Array Param) :=
|
||||
partial def mkNewParams (type : Expr) : CompilerM (Array (Param pu)) :=
|
||||
go type #[] #[]
|
||||
where
|
||||
go (type : Expr) (xs : Array Expr) (ps : Array Param) : CompilerM (Array Param) := do
|
||||
go (type : Expr) (xs : Array Expr) (ps : Array (Param pu)) : CompilerM (Array (Param pu)) := do
|
||||
match type with
|
||||
| .forallE _ d b _ =>
|
||||
let d := d.instantiateRev xs
|
||||
|
|
@ -98,15 +99,16 @@ where
|
|||
else
|
||||
return ps
|
||||
|
||||
def isEtaExpandCandidateCore (type : Expr) (params : Array Param) : Bool :=
|
||||
def isEtaExpandCandidateCore (type : Expr) (params : Array (Param .pure)) : Bool :=
|
||||
let typeArity := getArrowArity type
|
||||
let valueArity := params.size
|
||||
typeArity > valueArity
|
||||
|
||||
abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl) : Bool :=
|
||||
abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl .pure) : Bool :=
|
||||
isEtaExpandCandidateCore decl.type decl.params
|
||||
|
||||
def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : CompilerM (Array Param × Code) := do
|
||||
def etaExpandCore (type : Expr) (params : Array (Param .pure)) (value : Code .pure) :
|
||||
CompilerM (Array (Param .pure) × Code .pure) := do
|
||||
let valueType ← instantiateForall type (params.map (mkFVar ·.fvarId))
|
||||
let psNew ← mkNewParams valueType
|
||||
let params := params ++ psNew
|
||||
|
|
@ -116,17 +118,17 @@ def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : Compiler
|
|||
return .let auxDecl (.return auxDecl.fvarId)
|
||||
return (params, value)
|
||||
|
||||
def etaExpandCore? (type : Expr) (params : Array Param) (value : Code) : CompilerM (Option (Array Param × Code)) := do
|
||||
def etaExpandCore? (type : Expr) (params : Array (Param .pure)) (value : Code .pure) : CompilerM (Option (Array (Param .pure) × Code .pure)) := do
|
||||
if isEtaExpandCandidateCore type params then
|
||||
etaExpandCore type params value
|
||||
else
|
||||
return none
|
||||
|
||||
def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
|
||||
def FunDecl.etaExpand (decl : FunDecl .pure) : CompilerM (FunDecl .pure) := do
|
||||
let some (params, value) ← etaExpandCore? decl.type decl.params decl.value | return decl
|
||||
decl.update decl.type params value
|
||||
|
||||
def Decl.etaExpand (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.etaExpand (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
let some (params, newCode) ← etaExpandCore? decl.type decl.params code | return decl
|
||||
|
|
|
|||
|
|
@ -20,17 +20,17 @@ namespace CSE
|
|||
|
||||
structure State where
|
||||
map : PHashMap Expr FVarId := {}
|
||||
subst : FVarSubst := {}
|
||||
subst : FVarSubst .pure := {}
|
||||
|
||||
abbrev M := StateRefT State CompilerM
|
||||
|
||||
instance : MonadFVarSubst M false where
|
||||
instance : MonadFVarSubst M .pure false where
|
||||
getSubst := return (← get).subst
|
||||
|
||||
instance : MonadFVarSubstState M where
|
||||
instance : MonadFVarSubstState M .pure where
|
||||
modifySubst f := modify fun s => { s with subst := f s.subst }
|
||||
|
||||
@[inline] def getSubst : M FVarSubst :=
|
||||
@[inline] def getSubst : M (FVarSubst .pure) :=
|
||||
return (← get).subst
|
||||
|
||||
@[inline] def addEntry (value : Expr) (fvarId : FVarId) : M Unit :=
|
||||
|
|
@ -40,31 +40,32 @@ instance : MonadFVarSubstState M where
|
|||
let map := (← get).map
|
||||
try x finally modify fun s => { s with map }
|
||||
|
||||
def replaceLet (decl : LetDecl) (fvarId : FVarId) : M Unit := do
|
||||
def replaceLet (decl : LetDecl .pure) (fvarId : FVarId) : M Unit := do
|
||||
eraseLetDecl decl
|
||||
addFVarSubst decl.fvarId fvarId
|
||||
|
||||
def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do
|
||||
def replaceFun (decl : FunDecl .pure) (fvarId : FVarId) : M Unit := do
|
||||
eraseFunDecl decl
|
||||
addFVarSubst decl.fvarId fvarId
|
||||
|
||||
def hasNeverExtract (v : LetValue) : CompilerM Bool :=
|
||||
def hasNeverExtract (v : LetValue .pure) : CompilerM Bool :=
|
||||
match v with
|
||||
| .const declName .. =>
|
||||
return hasNeverExtractAttribute (← getEnv) declName
|
||||
| .lit _ | .erased | .proj .. | .fvar .. =>
|
||||
return false
|
||||
|
||||
partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code) : CompilerM Code :=
|
||||
partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code .pure) :
|
||||
CompilerM (Code .pure) :=
|
||||
go code |>.run' {}
|
||||
where
|
||||
goFunDecl (decl : FunDecl) : M FunDecl := do
|
||||
goFunDecl (decl : FunDecl .pure) : M (FunDecl .pure) := do
|
||||
let type ← normExpr decl.type
|
||||
let params ← normParams decl.params
|
||||
let value ← withNewScope do go decl.value
|
||||
decl.update type params value
|
||||
|
||||
go (code : Code) : M Code := do
|
||||
go (code : Code .pure) : M (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let decl ← normLetDecl decl
|
||||
|
|
@ -118,12 +119,13 @@ end CSE
|
|||
/--
|
||||
Common sub-expression elimination
|
||||
-/
|
||||
def Decl.cse (shouldElimFunDecls : Bool) (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.cse (shouldElimFunDecls : Bool) (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let value ← decl.value.mapCodeM (·.cse shouldElimFunDecls)
|
||||
return { decl with value }
|
||||
|
||||
def cse (phase : Phase := .base) (shouldElimFunDecls := false) (occurrence := 0) : Pass :=
|
||||
.mkPerDeclaration `cse (Decl.cse shouldElimFunDecls) phase occurrence
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `cse phase (h ▸ Decl.cse shouldElimFunDecls) occurrence
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.cse (inherited := true)
|
||||
|
|
|
|||
|
|
@ -79,7 +79,8 @@ the subtype relation in sanity checks and add the necessary casts.
|
|||
-/
|
||||
|
||||
namespace Check
|
||||
open InferType
|
||||
namespace Pure
|
||||
open InferType InferType.Pure
|
||||
|
||||
/-
|
||||
Type and structural properties checker for LCNF expressions.
|
||||
|
|
@ -110,7 +111,7 @@ def isCtorParam (f : Expr) (i : Nat) : CoreM Bool := do
|
|||
let .ctorInfo info ← getConstInfo declName | return false
|
||||
return i < info.numParams
|
||||
|
||||
def checkAppArgs (f : Expr) (args : Array Arg) : CheckM Unit := do
|
||||
def checkAppArgs (f : Expr) (args : Array (Arg .pure)) : CheckM Unit := do
|
||||
let mut fType ← inferType f
|
||||
let mut j := 0
|
||||
for h : i in *...args.size do
|
||||
|
|
@ -129,11 +130,11 @@ def checkAppArgs (f : Expr) (args : Array Arg) : CheckM Unit := do
|
|||
let expectedType := instantiateRevRangeArgs d j i args
|
||||
if (← checkTypes) then
|
||||
let argType ← arg.inferType
|
||||
unless (← InferType.compatibleTypes argType expectedType) do
|
||||
unless (← compatibleTypes argType expectedType) do
|
||||
throwError "type mismatch at LCNF application{indentExpr (mkAppN f (args.map Arg.toExpr))}\nargument {arg.toExpr} has type{indentExpr argType}\nbut is expected to have type{indentExpr expectedType}"
|
||||
fType := b
|
||||
|
||||
def checkLetValue (e : LetValue) : CheckM Unit := do
|
||||
def checkLetValue (e : LetValue .pure) : CheckM Unit := do
|
||||
match e with
|
||||
| .lit .. | .erased => pure ()
|
||||
| .const declName us args => checkAppArgs (mkConst declName us) args
|
||||
|
|
@ -154,18 +155,18 @@ def checkJpInScope (jp : FVarId) : CheckM Unit := do
|
|||
-/
|
||||
throwError "invalid jump to out of scope join point `{mkFVar jp}`"
|
||||
|
||||
def checkParam (param : Param) : CheckM Unit := do
|
||||
def checkParam (param : Param .pure) : CheckM Unit := do
|
||||
unless param == (← getParam param.fvarId) do
|
||||
throwError "LCNF parameter mismatch at `{param.binderName}`, does not value in local context"
|
||||
|
||||
def checkParams (params : Array Param) : CheckM Unit :=
|
||||
def checkParams (params : Array (Param .pure)) : CheckM Unit :=
|
||||
params.forM checkParam
|
||||
|
||||
def checkLetDecl (letDecl : LetDecl) : CheckM Unit := do
|
||||
def checkLetDecl (letDecl : LetDecl .pure) : CheckM Unit := do
|
||||
checkLetValue letDecl.value
|
||||
if (← checkTypes) then
|
||||
let valueType ← letDecl.value.inferType
|
||||
unless (← InferType.compatibleTypes letDecl.type valueType) do
|
||||
unless (← compatibleTypes letDecl.type valueType) do
|
||||
throwError "type mismatch at `{letDecl.binderName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr letDecl.type}"
|
||||
unless letDecl == (← getLetDecl letDecl.fvarId) do
|
||||
throwError "LCNF let declaration mismatch at `{letDecl.binderName}`, does not match value in local context"
|
||||
|
|
@ -183,7 +184,7 @@ def addFVarId (fvarId : FVarId) : CheckM Unit := do
|
|||
addFVarId fvarId
|
||||
withReader (fun ctx => { ctx with jps := ctx.jps.insert fvarId }) x
|
||||
|
||||
@[inline] def withParams (params : Array Param) (x : CheckM α) : CheckM α := do
|
||||
@[inline] def withParams (params : Array (Param .pure)) (x : CheckM α) : CheckM α := do
|
||||
params.forM (addFVarId ·.fvarId)
|
||||
withReader (fun ctx => { ctx with vars := params.foldl (init := ctx.vars) fun vars p => vars.insert p.fvarId })
|
||||
x
|
||||
|
|
@ -192,18 +193,18 @@ mutual
|
|||
|
||||
set_option linter.all false
|
||||
|
||||
partial def checkFunDeclCore (declName : Name) (params : Array Param) (type : Expr) (value : Code) : CheckM Unit := do
|
||||
partial def checkFunDeclCore (declName : Name) (params : Array (Param .pure)) (type : Expr) (value : Code .pure) : CheckM Unit := do
|
||||
checkParams params
|
||||
withParams params do
|
||||
discard <| check value
|
||||
if (← checkTypes) then
|
||||
let valueType ← mkForallParams params (← value.inferType)
|
||||
unless (← InferType.compatibleTypes type valueType) do
|
||||
unless (← compatibleTypes type valueType) do
|
||||
throwError "type mismatch at `{.ofConstName declName}`, value has type{indentExpr valueType}\nbut is expected to have type{indentExpr type}"
|
||||
|
||||
partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do
|
||||
partial def checkFunDecl (funDecl : FunDecl .pure) : CheckM Unit := do
|
||||
checkFunDeclCore funDecl.binderName funDecl.params funDecl.type funDecl.value
|
||||
let decl ← getFunDecl funDecl.fvarId
|
||||
let decl ← getFunDecl (pu := .pure) funDecl.fvarId
|
||||
unless decl.binderName == funDecl.binderName do
|
||||
throwError "LCNF local function declaration mismatch at `{funDecl.binderName}`, binder name in local context `{decl.binderName}`"
|
||||
unless decl.type == funDecl.type do
|
||||
|
|
@ -211,7 +212,7 @@ partial def checkFunDecl (funDecl : FunDecl) : CheckM Unit := do
|
|||
unless (← getFunDecl funDecl.fvarId) == funDecl do
|
||||
throwError "LCNF local function declaration mismatch at `{funDecl.binderName}`, declaration in local context does match"
|
||||
|
||||
partial def checkCases (c : Cases) : CheckM Unit := do
|
||||
partial def checkCases (c : Cases .pure) : CheckM Unit := do
|
||||
let mut ctorNames : NameSet := {}
|
||||
let mut hasDefault := false
|
||||
checkFVar c.discr
|
||||
|
|
@ -230,7 +231,7 @@ partial def checkCases (c : Cases) : CheckM Unit := do
|
|||
throwError "invalid LCNF `cases`, `{ctorName}` has # {val.numFields} fields, but alternative has # {params.size} alternatives"
|
||||
withParams params do check k
|
||||
|
||||
partial def check (code : Code) : CheckM Unit := do
|
||||
partial def check (code : Code .pure) : CheckM Unit := do
|
||||
match code with
|
||||
| .let decl k => checkLetDecl decl; withFVarId decl.fvarId do check k
|
||||
| .fun decl k =>
|
||||
|
|
@ -241,7 +242,7 @@ partial def check (code : Code) : CheckM Unit := do
|
|||
| .cases c => checkCases c
|
||||
| .jmp fvarId args =>
|
||||
checkJpInScope fvarId
|
||||
let decl ← getFunDecl fvarId
|
||||
let decl ← getFunDecl (pu := .pure) fvarId
|
||||
unless decl.getArity == args.size do
|
||||
throwError "invalid LCNF `goto`, join point {decl.binderName} has #{decl.getArity} parameters, but #{args.size} were provided"
|
||||
checkAppArgs (.fvar fvarId) args
|
||||
|
|
@ -253,9 +254,12 @@ end
|
|||
def run (x : CheckM α) : CompilerM α :=
|
||||
x |>.run {} |>.run' {} |>.run {}
|
||||
|
||||
end Pure
|
||||
end Check
|
||||
|
||||
def Decl.check (decl : Decl) : CompilerM Unit := do
|
||||
Check.run do decl.value.forCodeM (Check.checkFunDeclCore decl.name decl.params decl.type)
|
||||
def Decl.check (decl : Decl pu) : CompilerM Unit := do
|
||||
match pu with
|
||||
| .pure => Check.Pure.run do decl.value.forCodeM (Check.Pure.checkFunDeclCore decl.name decl.params decl.type)
|
||||
| .impure => panic! "Check for impure unimplemented" -- TODO
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ structure State where
|
|||
/--
|
||||
Free variables that must become new parameters of the code being specialized.
|
||||
-/
|
||||
params : Array Param := #[]
|
||||
params : Array (Param .pure) := #[]
|
||||
/--
|
||||
Let-declarations and local function declarations that are going to be "copied" to the code
|
||||
being processed. For example, when this module is used in the code specializer, the let-declarations
|
||||
|
|
@ -56,7 +56,7 @@ structure State where
|
|||
All customers of this module try to avoid work duplication. If a let-declaration is a ground value,
|
||||
it most likely will be computed during compilation time, and work duplication is not an issue.
|
||||
-/
|
||||
decls : Array CodeDecl := #[]
|
||||
decls : Array (CodeDecl .pure) := #[]
|
||||
|
||||
/--
|
||||
Monad for implementing the dependency collector.
|
||||
|
|
@ -75,16 +75,16 @@ mutual
|
|||
Collect dependencies in parameters. We need this because parameters may
|
||||
contain other type parameters.
|
||||
-/
|
||||
partial def collectParams (params : Array Param) : ClosureM Unit :=
|
||||
partial def collectParams (params : Array (Param .pure)) : ClosureM Unit :=
|
||||
params.forM (collectType ·.type)
|
||||
|
||||
partial def collectArg (arg : Arg) : ClosureM Unit :=
|
||||
partial def collectArg (arg : Arg .pure) : ClosureM Unit :=
|
||||
match arg with
|
||||
| .erased => return ()
|
||||
| .type e => collectType e
|
||||
| .fvar fvarId => collectFVar fvarId
|
||||
|
||||
partial def collectLetValue (e : LetValue) : ClosureM Unit := do
|
||||
partial def collectLetValue (e : LetValue .pure) : ClosureM Unit := do
|
||||
match e with
|
||||
| .erased | .lit .. => return ()
|
||||
| .proj _ _ fvarId => collectFVar fvarId
|
||||
|
|
@ -95,7 +95,7 @@ mutual
|
|||
Collect dependencies in the given code. We need this function to be able
|
||||
to collect dependencies in a local function declaration.
|
||||
-/
|
||||
partial def collectCode (c : Code) : ClosureM Unit := do
|
||||
partial def collectCode (c : Code .pure) : ClosureM Unit := do
|
||||
match c with
|
||||
| .let decl k =>
|
||||
collectType decl.type
|
||||
|
|
@ -114,7 +114,7 @@ mutual
|
|||
| .return fvarId => collectFVar fvarId
|
||||
|
||||
/-- Collect dependencies of a local function declaration. -/
|
||||
partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do
|
||||
partial def collectFunDecl (decl : FunDecl .pure) : ClosureM Unit := do
|
||||
collectType decl.type
|
||||
collectParams decl.params
|
||||
collectCode decl.value
|
||||
|
|
@ -155,7 +155,8 @@ mutual
|
|||
|
||||
end
|
||||
|
||||
def run (x : ClosureM α) (inScope : FVarId → Bool) (abstract : FVarId → Bool := fun _ => true) : CompilerM (α × Array Param × Array CodeDecl) := do
|
||||
def run (x : ClosureM α) (inScope : FVarId → Bool) (abstract : FVarId → Bool := fun _ => true) :
|
||||
CompilerM (α × Array (Param .pure) × Array (CodeDecl .pure)) := do
|
||||
let (a, s) ← x { inScope, abstract } |>.run {}
|
||||
-- If we've abstracted an fvar into a param, exclude its definition. Note that this still allows
|
||||
-- for other decls the removed decl depends upon to be included, but they will be removed later
|
||||
|
|
|
|||
|
|
@ -72,10 +72,13 @@ partial def compatibleTypesQuick (a b : Expr) : Bool :=
|
|||
| .const n us, .const m vs => n == m && List.isEqv us vs Level.isEquiv
|
||||
| _, _ => false
|
||||
|
||||
namespace InferType
|
||||
namespace Pure
|
||||
|
||||
/--
|
||||
Complete check for `compatibleTypes`. It eta-expands type formers. See comment at `compatibleTypes`.
|
||||
-/
|
||||
partial def InferType.compatibleTypesFull (a b : Expr) : InferTypeM Bool := do
|
||||
partial def compatibleTypesFull (a b : Expr) : InferTypeM Bool := do
|
||||
if a.isErased || b.isErased then
|
||||
return true
|
||||
else
|
||||
|
|
@ -141,10 +144,13 @@ This is a simplification. We used to use `isErasedCompatible`, but this only add
|
|||
For item 2, we would have to modify the `toLCNFType` function and make sure a type former is erased if the expected
|
||||
type is not always a type former (see `S.mk` type and example in the note above).
|
||||
-/
|
||||
def InferType.compatibleTypes (a b : Expr) : InferTypeM Bool := do
|
||||
def compatibleTypes (a b : Expr) : InferTypeM Bool := do
|
||||
if compatibleTypesQuick a b then
|
||||
return true
|
||||
else
|
||||
compatibleTypesFull a b
|
||||
|
||||
end Pure
|
||||
end InferType
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -21,7 +21,12 @@ inductive Phase where
|
|||
| base
|
||||
/-- In this phase polymorphism has been eliminated. -/
|
||||
| mono
|
||||
deriving Inhabited, BEq
|
||||
| impure
|
||||
deriving Inhabited, DecidableEq
|
||||
|
||||
@[expose, reducible] def Phase.toPurity : Phase → Purity
|
||||
| .base | .mono => .pure
|
||||
| .impure => .impure
|
||||
|
||||
/--
|
||||
The state managed by the `CompilerM` `Monad`.
|
||||
|
|
@ -52,48 +57,53 @@ instance : Monad CompilerM := let i := inferInstanceAs (Monad CompilerM); { pure
|
|||
def getPhase : CompilerM Phase :=
|
||||
return (← read).phase
|
||||
|
||||
def getPurity : CompilerM Purity :=
|
||||
return (← getPhase).toPurity
|
||||
|
||||
def inBasePhase : CompilerM Bool :=
|
||||
return (← getPhase) matches .base
|
||||
|
||||
instance : AddMessageContext CompilerM where
|
||||
addMessageContext msgData := do
|
||||
let env ← getEnv
|
||||
let lctx := (← get).lctx.toLocalContext
|
||||
let lctx := (← get).lctx.toLocalContext (← getPurity)
|
||||
let opts ← getOptions
|
||||
return MessageData.withContext { env, lctx, opts, mctx := {} } msgData
|
||||
|
||||
def getType (fvarId : FVarId) : CompilerM Expr := do
|
||||
let lctx := (← get).lctx
|
||||
if let some decl := lctx.letDecls[fvarId]? then
|
||||
let pu ← getPurity
|
||||
if let some decl := (lctx.letDecls pu)[fvarId]? then
|
||||
return decl.type
|
||||
else if let some decl := lctx.params[fvarId]? then
|
||||
else if let some decl := (lctx.params pu)[fvarId]? then
|
||||
return decl.type
|
||||
else if let some decl := lctx.funDecls[fvarId]? then
|
||||
else if let some decl := (lctx.funDecls pu)[fvarId]? then
|
||||
return decl.type
|
||||
else
|
||||
throwError "unknown free variable {fvarId.name}"
|
||||
|
||||
def getBinderName (fvarId : FVarId) : CompilerM Name := do
|
||||
let lctx := (← get).lctx
|
||||
if let some decl := lctx.letDecls[fvarId]? then
|
||||
let pu ← getPurity
|
||||
if let some decl := (lctx.letDecls pu)[fvarId]? then
|
||||
return decl.binderName
|
||||
else if let some decl := lctx.params[fvarId]? then
|
||||
else if let some decl := (lctx.params pu)[fvarId]? then
|
||||
return decl.binderName
|
||||
else if let some decl := lctx.funDecls[fvarId]? then
|
||||
else if let some decl := (lctx.funDecls pu)[fvarId]? then
|
||||
return decl.binderName
|
||||
else
|
||||
throwError "unknown free variable {fvarId.name}"
|
||||
|
||||
def findParam? (fvarId : FVarId) : CompilerM (Option Param) :=
|
||||
return (← get).lctx.params[fvarId]?
|
||||
def findParam? (fvarId : FVarId) : CompilerM (Option (Param pu)) := do
|
||||
return ((← get).lctx.params pu)[fvarId]?
|
||||
|
||||
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) :=
|
||||
return (← get).lctx.letDecls[fvarId]?
|
||||
def findLetDecl? (fvarId : FVarId) : CompilerM (Option (LetDecl pu)) := do
|
||||
return ((← get).lctx.letDecls pu)[fvarId]?
|
||||
|
||||
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) :=
|
||||
return (← get).lctx.funDecls[fvarId]?
|
||||
def findFunDecl? (fvarId : FVarId) : CompilerM (Option (FunDecl pu)) := do
|
||||
return ((← get).lctx.funDecls pu)[fvarId]?
|
||||
|
||||
def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
|
||||
def findLetValue? (fvarId : FVarId) : CompilerM (Option (LetValue pu)) := do
|
||||
let some { value, .. } ← findLetDecl? fvarId | return none
|
||||
return some value
|
||||
|
||||
|
|
@ -101,56 +111,56 @@ def isConstructorApp (fvarId : FVarId) : CompilerM Bool := do
|
|||
let some (.const declName _ _) ← findLetValue? fvarId | return false
|
||||
return (← getEnv).find? declName matches some (.ctorInfo ..)
|
||||
|
||||
def Arg.isConstructorApp (arg : Arg) : CompilerM Bool := do
|
||||
def Arg.isConstructorApp (arg : Arg pu) : CompilerM Bool := do
|
||||
let .fvar fvarId := arg | return false
|
||||
LCNF.isConstructorApp fvarId
|
||||
|
||||
def getParam (fvarId : FVarId) : CompilerM Param := do
|
||||
def getParam (fvarId : FVarId) : CompilerM (Param pu) := do
|
||||
let some param ← findParam? fvarId | throwError "unknown parameter {fvarId.name}"
|
||||
return param
|
||||
|
||||
def getLetDecl (fvarId : FVarId) : CompilerM LetDecl := do
|
||||
def getLetDecl (fvarId : FVarId) : CompilerM (LetDecl pu) := do
|
||||
let some decl ← findLetDecl? fvarId | throwError "unknown let-declaration {fvarId.name}"
|
||||
return decl
|
||||
|
||||
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
|
||||
def getFunDecl (fvarId : FVarId) : CompilerM (FunDecl pu) := do
|
||||
let some decl ← findFunDecl? fvarId | throwError "unknown local function {fvarId.name}"
|
||||
return decl
|
||||
|
||||
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit := do
|
||||
modify fun s => { s with lctx := f s.lctx }
|
||||
|
||||
def eraseLetDecl (decl : LetDecl) : CompilerM Unit := do
|
||||
def eraseLetDecl (decl : LetDecl pu) : CompilerM Unit := do
|
||||
modifyLCtx fun lctx => lctx.eraseLetDecl decl
|
||||
|
||||
def eraseFunDecl (decl : FunDecl) (recursive := true) : CompilerM Unit := do
|
||||
def eraseFunDecl (decl : FunDecl pu) (recursive := true) : CompilerM Unit := do
|
||||
modifyLCtx fun lctx => lctx.eraseFunDecl decl recursive
|
||||
|
||||
def eraseCode (code : Code) : CompilerM Unit := do
|
||||
def eraseCode (code : Code pu) : CompilerM Unit := do
|
||||
modifyLCtx fun lctx => lctx.eraseCode code
|
||||
|
||||
def eraseParam (param : Param) : CompilerM Unit :=
|
||||
def eraseParam (param : Param pu) : CompilerM Unit :=
|
||||
modifyLCtx fun lctx => lctx.eraseParam param
|
||||
|
||||
def eraseParams (params : Array Param) : CompilerM Unit :=
|
||||
def eraseParams (params : Array (Param pu)) : CompilerM Unit :=
|
||||
modifyLCtx fun lctx => lctx.eraseParams params
|
||||
|
||||
def eraseCodeDecl (decl : CodeDecl) : CompilerM Unit := do
|
||||
def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
|
||||
match decl with
|
||||
| .let decl => eraseLetDecl decl
|
||||
| .jp decl | .fun decl => eraseFunDecl decl
|
||||
| .jp decl | .fun decl _ => eraseFunDecl decl
|
||||
|
||||
/--
|
||||
Erase all free variables occurring in `decls` from the local context.
|
||||
-/
|
||||
def eraseCodeDecls (decls : Array CodeDecl) : CompilerM Unit := do
|
||||
def eraseCodeDecls (decls : Array (CodeDecl pu)) : CompilerM Unit := do
|
||||
decls.forM fun decl => eraseCodeDecl decl
|
||||
|
||||
def eraseDecl (decl : Decl) : CompilerM Unit := do
|
||||
def eraseDecl (decl : Decl pu) : CompilerM Unit := do
|
||||
eraseParams decl.params
|
||||
decl.value.forCodeM eraseCode
|
||||
|
||||
abbrev Decl.erase (decl : Decl) : CompilerM Unit :=
|
||||
abbrev Decl.erase (decl : Decl pu) : CompilerM Unit :=
|
||||
eraseDecl decl
|
||||
|
||||
/--
|
||||
|
|
@ -166,7 +176,7 @@ it is a free variable, a type (or type former), or `lcErased`.
|
|||
|
||||
`Check.lean` contains a substitution validator.
|
||||
-/
|
||||
abbrev FVarSubst := Std.HashMap FVarId Arg
|
||||
abbrev FVarSubst (pu : Purity) := Std.HashMap FVarId (Arg pu)
|
||||
|
||||
/--
|
||||
Replace the free variables in `e` using the given substitution.
|
||||
|
|
@ -179,7 +189,7 @@ If `translator = false`, we assume the substitution contains free variable repla
|
|||
and given entries such as `x₁ ↦ x₂`, `x₂ ↦ x₃`, ..., `xₙ₋₁ ↦ xₙ`, and the expression `f x₁ x₂`, we want the resulting
|
||||
expression to be `f xₙ xₙ`. We use this setting, for example, in the simplifier.
|
||||
-/
|
||||
private partial def normExprImp (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
|
||||
private partial def normExprImp (s : FVarSubst pu) (e : Expr) (translator : Bool) : Expr :=
|
||||
go e
|
||||
where
|
||||
goApp (e : Expr) : Expr :=
|
||||
|
|
@ -192,7 +202,7 @@ where
|
|||
match e with
|
||||
| .fvar fvarId => match s[fvarId]? with
|
||||
| some (.fvar fvarId') => if translator then .fvar fvarId' else go (.fvar fvarId')
|
||||
| some (.type e) => if translator then e else go e
|
||||
| some (.type e _) => if translator then e else go e
|
||||
| some .erased => erasedExpr
|
||||
| none => e
|
||||
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
|
||||
|
|
@ -225,7 +235,7 @@ This function panics if the substitution is mapping `fvarId` to an expression th
|
|||
That is, it is not a type (or type former), nor `lcErased`. Recall that a valid `FVarSubst` contains only
|
||||
expressions that are free variables, `lcErased`, or type formers.
|
||||
-/
|
||||
partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
|
||||
partial def normFVarImp (s : FVarSubst pu) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
|
||||
match s[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
if translator then
|
||||
|
|
@ -234,7 +244,7 @@ partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) :
|
|||
normFVarImp s fvarId' translator
|
||||
-- Types and type formers are only preserved as hints and
|
||||
-- are erased in computationally relevant contexts.
|
||||
| some .erased | some (.type _) => .erased
|
||||
| some .erased | some (.type _ _) => .erased
|
||||
| none => .fvar fvarId
|
||||
|
||||
/--
|
||||
|
|
@ -242,18 +252,18 @@ Replace the free variables in `arg` using the given substitution.
|
|||
|
||||
See `normExprImp`
|
||||
-/
|
||||
private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : Arg :=
|
||||
private partial def normArgImp (s : FVarSubst pu) (arg : Arg pu) (translator : Bool) : Arg pu :=
|
||||
match arg with
|
||||
| .erased => arg
|
||||
| .fvar fvarId =>
|
||||
match s[fvarId]? with
|
||||
| some (arg'@(.fvar _)) =>
|
||||
if translator then arg' else normArgImp s arg' translator
|
||||
| some (arg'@.erased) | some (arg'@(.type _)) => arg'
|
||||
| some (arg'@.erased) | some (arg'@(.type _ _)) => arg'
|
||||
| none => arg
|
||||
| .type e => arg.updateType! (normExprImp s e translator)
|
||||
| .type e _ => arg.updateType! (normExprImp s e translator)
|
||||
|
||||
private def normArgsImp (s : FVarSubst) (args : Array Arg) (translator : Bool) : Array Arg :=
|
||||
private def normArgsImp (s : FVarSubst pu) (args : Array (Arg pu)) (translator : Bool) : Array (Arg pu) :=
|
||||
args.mapMono (normArgImp s · translator)
|
||||
|
||||
/--
|
||||
|
|
@ -261,13 +271,13 @@ Replace the free variables in `e` using the given substitution.
|
|||
|
||||
See `normExprImp`
|
||||
-/
|
||||
private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator : Bool) : LetValue :=
|
||||
private partial def normLetValueImp (s : FVarSubst pu) (e : LetValue pu) (translator : Bool) : LetValue pu :=
|
||||
match e with
|
||||
| .erased | .lit .. => e
|
||||
| .proj _ _ fvarId => match normFVarImp s fvarId translator with
|
||||
| .proj _ _ fvarId _ => match normFVarImp s fvarId translator with
|
||||
| .fvar fvarId' => e.updateProj! fvarId'
|
||||
| .erased => .erased
|
||||
| .const _ _ args => e.updateArgs! (normArgsImp s args translator)
|
||||
| .const _ _ args _ => e.updateArgs! (normArgsImp s args translator)
|
||||
| .fvar fvarId args => match normFVarImp s fvarId translator with
|
||||
| .fvar fvarId' => e.updateFVar! fvarId' (normArgsImp s args translator)
|
||||
| .erased => .erased
|
||||
|
|
@ -275,20 +285,20 @@ private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator :
|
|||
/--
|
||||
Interface for monads that have a free substitutions.
|
||||
-/
|
||||
class MonadFVarSubst (m : Type → Type) (translator : outParam Bool) where
|
||||
getSubst : m FVarSubst
|
||||
class MonadFVarSubst (m : Type → Type) (pu : outParam Purity) (translator : outParam Bool) where
|
||||
getSubst : m (FVarSubst pu)
|
||||
|
||||
export MonadFVarSubst (getSubst)
|
||||
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubst m t] : MonadFVarSubst n t where
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubst m pu t] : MonadFVarSubst n pu t where
|
||||
getSubst := liftM (getSubst : m _)
|
||||
|
||||
class MonadFVarSubstState (m : Type → Type) where
|
||||
modifySubst : (FVarSubst → FVarSubst) → m Unit
|
||||
class MonadFVarSubstState (m : Type → Type) (pu : outParam Purity) where
|
||||
modifySubst : (FVarSubst pu → FVarSubst pu) → m Unit
|
||||
|
||||
export MonadFVarSubstState (modifySubst)
|
||||
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where
|
||||
instance (m n) [MonadLift m n] [MonadFVarSubstState m pu] : MonadFVarSubstState n pu where
|
||||
modifySubst f := liftM (modifySubst f : m _)
|
||||
|
||||
/--
|
||||
|
|
@ -296,35 +306,35 @@ Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF `Arg`.
|
|||
|
||||
See `Check.lean` for the free variable substitution checker.
|
||||
-/
|
||||
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (arg : Arg) : m Unit :=
|
||||
@[inline] def addSubst [MonadFVarSubstState m pu] (fvarId : FVarId) (arg : Arg pu) : m Unit :=
|
||||
modifySubst fun s => s.insert fvarId arg
|
||||
|
||||
/--
|
||||
Add the entry `fvarId ↦ fvarId'` to the free variable substitution.
|
||||
-/
|
||||
@[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
|
||||
@[inline] def addFVarSubst [MonadFVarSubstState m ph] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
|
||||
modifySubst fun s => s.insert fvarId (.fvar fvarId')
|
||||
|
||||
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m NormFVarResult :=
|
||||
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m pu t] [Monad m] (fvarId : FVarId) : m NormFVarResult :=
|
||||
return normFVarImp (← getSubst) fvarId t
|
||||
|
||||
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr :=
|
||||
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m pu t] [Monad m] (e : Expr) : m Expr :=
|
||||
return normExprImp (← getSubst) e t
|
||||
|
||||
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m t] [Monad m] (arg : Arg) : m Arg :=
|
||||
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m pu t] [Monad m] (arg : Arg pu) : m (Arg pu) :=
|
||||
return normArgImp (← getSubst) arg t
|
||||
|
||||
@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m t] [Monad m] (e : LetValue) : m LetValue :=
|
||||
@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m pu t] [Monad m] (e : LetValue pu) : m (LetValue pu) :=
|
||||
return normLetValueImp (← getSubst) e t
|
||||
|
||||
@[inherit_doc normExprImp, inline]
|
||||
def normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
|
||||
def normExprCore (s : FVarSubst pu) (e : Expr) (translator : Bool) : Expr :=
|
||||
normExprImp s e translator
|
||||
|
||||
/--
|
||||
Normalize the given arguments using the current substitution.
|
||||
-/
|
||||
def normArgs [MonadFVarSubst m t] [Monad m] (args : Array Arg) : m (Array Arg) :=
|
||||
def normArgs [MonadFVarSubst m pu t] [Monad m] (args : Array (Arg pu)) : m (Array (Arg pu)) :=
|
||||
return normArgsImp (← getSubst) args t
|
||||
|
||||
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
|
||||
|
|
@ -342,35 +352,35 @@ def ensureNotAnonymous (binderName : Name) (baseName : Name) : CompilerM Name :=
|
|||
Helper functions for creating LCNF local declarations.
|
||||
-/
|
||||
|
||||
def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param := do
|
||||
def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM (Param pu) := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
let binderName ← ensureNotAnonymous binderName `_y
|
||||
let param := { fvarId, binderName, type, borrow }
|
||||
modifyLCtx fun lctx => lctx.addParam param
|
||||
return param
|
||||
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue) : CompilerM LetDecl := do
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
let binderName ← ensureNotAnonymous binderName `_x
|
||||
let decl := { fvarId, binderName, type, value }
|
||||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||||
return decl
|
||||
|
||||
def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||||
def mkFunDecl (binderName : Name) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
let binderName ← ensureNotAnonymous binderName `_f
|
||||
let funDecl := ⟨fvarId, binderName, params, type, value⟩
|
||||
modifyLCtx fun lctx => lctx.addFunDecl funDecl
|
||||
return funDecl
|
||||
|
||||
def mkLetDeclErased : CompilerM LetDecl := do
|
||||
def mkLetDeclErased : CompilerM (LetDecl pu) := do
|
||||
mkLetDecl (← mkFreshBinderName `_x) erasedExpr .erased
|
||||
|
||||
def mkReturnErased : CompilerM Code := do
|
||||
def mkReturnErased : CompilerM (Code pu) := do
|
||||
let auxDecl ← mkLetDeclErased
|
||||
return .let auxDecl (.return auxDecl.fvarId)
|
||||
|
||||
private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := do
|
||||
private unsafe def updateParamImp (p : Param pu) (type : Expr) : CompilerM (Param pu) := do
|
||||
if ptrEq type p.type then
|
||||
return p
|
||||
else
|
||||
|
|
@ -378,9 +388,9 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param :=
|
|||
modifyLCtx fun lctx => lctx.addParam p
|
||||
return p
|
||||
|
||||
@[implemented_by updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
|
||||
@[implemented_by updateParamImp] opaque Param.update (p : Param pu) (type : Expr) : CompilerM (Param pu)
|
||||
|
||||
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl := do
|
||||
private unsafe def updateLetDeclImp (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu) := do
|
||||
if ptrEq type decl.type && ptrEq value decl.value then
|
||||
return decl
|
||||
else
|
||||
|
|
@ -388,12 +398,12 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetV
|
|||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||||
return decl
|
||||
|
||||
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl
|
||||
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl pu) (type : Expr) (value : LetValue pu) : CompilerM (LetDecl pu)
|
||||
|
||||
def LetDecl.updateValue (decl : LetDecl) (value : LetValue) : CompilerM LetDecl :=
|
||||
def LetDecl.updateValue (decl : LetDecl pu) (value : LetValue pu) : CompilerM (LetDecl pu) :=
|
||||
decl.update decl.type value
|
||||
|
||||
private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||||
private unsafe def updateFunDeclImp (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu) := do
|
||||
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
|
||||
return decl
|
||||
else
|
||||
|
|
@ -401,48 +411,48 @@ private unsafe def updateFunDeclImp (decl : FunDecl) (type : Expr) (params : Arr
|
|||
modifyLCtx fun lctx => lctx.addFunDecl decl
|
||||
return decl
|
||||
|
||||
@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
|
||||
@[implemented_by updateFunDeclImp] opaque FunDecl.update (decl : FunDecl pu) (type : Expr) (params : Array (Param pu)) (value : Code pu) : CompilerM (FunDecl pu)
|
||||
|
||||
abbrev FunDecl.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
|
||||
abbrev FunDecl.update' (decl : FunDecl pu) (type : Expr) (value : Code pu) : CompilerM (FunDecl pu) :=
|
||||
decl.update type decl.params value
|
||||
|
||||
abbrev FunDecl.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
|
||||
abbrev FunDecl.updateValue (decl : FunDecl pu) (value : Code pu) : CompilerM (FunDecl pu) :=
|
||||
decl.update decl.type decl.params value
|
||||
|
||||
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (p : Param) : m Param := do
|
||||
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (p : Param pu) : m (Param pu) := do
|
||||
p.update (← normExpr p.type)
|
||||
|
||||
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Array Param) : m (Array Param) :=
|
||||
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (ps : Array (Param pu)) : m (Array (Param pu)) :=
|
||||
ps.mapMonoM normParam
|
||||
|
||||
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do
|
||||
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : LetDecl pu) : m (LetDecl pu) := do
|
||||
decl.update (← normExpr decl.type) (← normLetValue decl.value)
|
||||
|
||||
abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM
|
||||
abbrev NormalizerM (pu : Purity) (_translator : Bool) := ReaderT (FVarSubst pu) CompilerM
|
||||
|
||||
instance : MonadFVarSubst (NormalizerM t) t where
|
||||
instance : MonadFVarSubst (NormalizerM pu t) pu t where
|
||||
getSubst := read
|
||||
|
||||
/--
|
||||
If `result` is `.fvar fvarId`, then return `x fvarId`. Otherwise, it is `.erased`,
|
||||
and method returns `let _x.i := .erased; return _x.i`.
|
||||
-/
|
||||
@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId → m Code) : m Code := do
|
||||
@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId → m (Code pu)) : m (Code pu) := do
|
||||
match result with
|
||||
| .fvar fvarId => x fvarId
|
||||
| .erased => mkReturnErased
|
||||
|
||||
mutual
|
||||
partial def normFunDeclImp (decl : FunDecl) : NormalizerM t FunDecl := do
|
||||
partial def normFunDeclImp (decl : FunDecl pu) : NormalizerM pu t (FunDecl pu) := do
|
||||
let type ← normExpr decl.type
|
||||
let params ← normParams decl.params
|
||||
let value ← normCodeImp decl.value
|
||||
decl.update type params value
|
||||
|
||||
partial def normCodeImp (code : Code) : NormalizerM t Code := do
|
||||
partial def normCodeImp (code : Code pu) : NormalizerM pu t (Code pu) := do
|
||||
match code with
|
||||
| .let decl k => return code.updateLet! (← normLetDecl decl) (← normCodeImp k)
|
||||
| .fun decl k | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k)
|
||||
| .fun decl k _ | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k)
|
||||
| .return fvarId => withNormFVarResult (← normFVar fvarId) fun fvarId => return code.updateReturn! fvarId
|
||||
| .jmp fvarId args => withNormFVarResult (← normFVar fvarId) fun fvarId => return code.updateJmp! fvarId (← normArgs args)
|
||||
| .unreach type => return code.updateUnreach! (← normExpr type)
|
||||
|
|
@ -451,28 +461,28 @@ mutual
|
|||
withNormFVarResult (← normFVar c.discr) fun discr => do
|
||||
let alts ← c.alts.mapMonoM fun alt =>
|
||||
match alt with
|
||||
| .alt _ params k => return alt.updateAlt! (← normParams params) (← normCodeImp k)
|
||||
| .alt _ params k _ => return alt.updateAlt! (← normParams params) (← normCodeImp k)
|
||||
| .default k => return alt.updateCode (← normCodeImp k)
|
||||
return code.updateCases! resultType discr alts
|
||||
end
|
||||
|
||||
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : FunDecl) : m FunDecl := do
|
||||
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do
|
||||
normFunDeclImp (t := t) decl (← getSubst)
|
||||
|
||||
/-- Similar to `internalize`, but does not refresh `FVarId`s. -/
|
||||
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (code : Code) : m Code := do
|
||||
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (code : Code pu) : m (Code pu) := do
|
||||
normCodeImp (t := t) code (← getSubst)
|
||||
|
||||
def replaceExprFVars (e : Expr) (s : FVarSubst) (translator : Bool) : CompilerM Expr :=
|
||||
(normExpr e : NormalizerM translator Expr).run s
|
||||
def replaceExprFVars (e : Expr) (s : FVarSubst pu) (translator : Bool) : CompilerM Expr :=
|
||||
(normExpr e : NormalizerM pu translator Expr).run s
|
||||
|
||||
def replaceFVars (code : Code) (s : FVarSubst) (translator : Bool) : CompilerM Code :=
|
||||
(normCode code : NormalizerM translator Code).run s
|
||||
def replaceFVars (code : Code pu) (s : FVarSubst pu) (translator : Bool) : CompilerM (Code pu) :=
|
||||
(normCode code : NormalizerM pu translator (Code pu)).run s
|
||||
|
||||
def mkFreshJpName : CompilerM Name := do
|
||||
mkFreshBinderName `_jp
|
||||
|
||||
def mkAuxParam (type : Expr) (borrow := false) : CompilerM Param := do
|
||||
def mkAuxParam (type : Expr) (borrow := false) : CompilerM (Param pu) := do
|
||||
mkParam (← mkFreshBinderName `_y) type borrow
|
||||
|
||||
def getConfig : CompilerM ConfigOptions :=
|
||||
|
|
|
|||
|
|
@ -12,25 +12,25 @@ public section
|
|||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
instance : Hashable Param where
|
||||
instance : Hashable (Param pu) where
|
||||
hash p := mixHash (hash p.fvarId) (hash p.type)
|
||||
|
||||
def hashParams (ps : Array Param) : UInt64 :=
|
||||
def hashParams (ps : Array (Param pu)) : UInt64 :=
|
||||
hash ps
|
||||
|
||||
mutual
|
||||
partial def hashAlt (alt : Alt) : UInt64 :=
|
||||
partial def hashAlt (alt : Alt pu) : UInt64 :=
|
||||
match alt with
|
||||
| .alt ctorName ps k => mixHash (mixHash (hash ctorName) (hash ps)) (hashCode k)
|
||||
| .alt ctorName ps k _ => mixHash (mixHash (hash ctorName) (hash ps)) (hashCode k)
|
||||
| .default k => hashCode k
|
||||
|
||||
partial def hashAlts (alts : Array Alt) : UInt64 :=
|
||||
partial def hashAlts (alts : Array (Alt pu)) : UInt64 :=
|
||||
alts.foldl (fun r a => mixHash r (hashAlt a)) 7
|
||||
|
||||
partial def hashCode (code : Code) : UInt64 :=
|
||||
partial def hashCode (code : Code pu) : UInt64 :=
|
||||
match code with
|
||||
| .let decl k => mixHash (mixHash (hash decl.fvarId) (hash decl.type)) (mixHash (hash decl.value) (hashCode k))
|
||||
| .fun decl k | .jp decl k =>
|
||||
| .fun decl k _ | .jp decl k =>
|
||||
mixHash (mixHash (mixHash (hash decl.fvarId) (hash decl.type)) (mixHash (hashCode decl.value) (hashCode k))) (hash decl.params)
|
||||
| .return fvarId => hash fvarId
|
||||
| .unreach type => hash type
|
||||
|
|
@ -39,7 +39,7 @@ partial def hashCode (code : Code) : UInt64 :=
|
|||
|
||||
end
|
||||
|
||||
instance : Hashable Code where
|
||||
instance : Hashable (Code pu) where
|
||||
hash c := hashCode c
|
||||
|
||||
deriving instance Hashable for DeclValue
|
||||
|
|
|
|||
|
|
@ -21,46 +21,46 @@ private def typeDepOn (e : Expr) : M Bool := do
|
|||
let s ← read
|
||||
return e.hasAnyFVar fun fvarId => s.contains fvarId
|
||||
|
||||
private def argDepOn (a : Arg) : M Bool := do
|
||||
private def argDepOn (a : Arg pu) : M Bool := do
|
||||
match a with
|
||||
| .erased => return false
|
||||
| .fvar fvarId => fvarDepOn fvarId
|
||||
| .type e => typeDepOn e
|
||||
| .type e _ => typeDepOn e
|
||||
|
||||
private def letValueDepOn (e : LetValue) : M Bool :=
|
||||
private def letValueDepOn (e : LetValue pu) : M Bool :=
|
||||
match e with
|
||||
| .erased | .lit .. => return false
|
||||
| .proj _ _ fvarId => fvarDepOn fvarId
|
||||
| .proj _ _ fvarId _ => fvarDepOn fvarId
|
||||
| .fvar fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
|
||||
| .const _ _ args => args.anyM argDepOn
|
||||
| .const _ _ args _ => args.anyM argDepOn
|
||||
|
||||
private def LetDecl.depOn (decl : LetDecl) : M Bool :=
|
||||
private def LetDecl.depOn (decl : LetDecl pu) : M Bool :=
|
||||
typeDepOn decl.type <||> letValueDepOn decl.value
|
||||
|
||||
private partial def depOn (c : Code) : M Bool :=
|
||||
private partial def depOn (c : Code pu) : M Bool :=
|
||||
match c with
|
||||
| .let decl k => decl.depOn <||> depOn k
|
||||
| .jp decl k | .fun decl k => typeDepOn decl.type <||> depOn decl.value <||> depOn k
|
||||
| .jp decl k | .fun decl k _ => typeDepOn decl.type <||> depOn decl.value <||> depOn k
|
||||
| .cases c => typeDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode
|
||||
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
|
||||
| .return fvarId => fvarDepOn fvarId
|
||||
| .unreach _ => return false
|
||||
|
||||
@[inline] def LetDecl.dependsOn (decl : LetDecl) (s : FVarIdSet) : Bool :=
|
||||
@[inline] def LetDecl.dependsOn (decl : LetDecl pu) (s : FVarIdSet) : Bool :=
|
||||
decl.depOn s
|
||||
|
||||
@[inline] def FunDecl.dependsOn (decl : FunDecl) (s : FVarIdSet) : Bool :=
|
||||
@[inline] def FunDecl.dependsOn (decl : FunDecl pu) (s : FVarIdSet) : Bool :=
|
||||
typeDepOn decl.type s || depOn decl.value s
|
||||
|
||||
def CodeDecl.dependsOn (decl : CodeDecl) (s : FVarIdSet) : Bool :=
|
||||
def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
|
||||
match decl with
|
||||
| .let decl => decl.dependsOn s
|
||||
| .jp decl | .fun decl => decl.dependsOn s
|
||||
| .jp decl | .fun decl _ => decl.dependsOn s
|
||||
|
||||
/--
|
||||
Return `true` is `c` depends on a free variable in `s`.
|
||||
-/
|
||||
def Code.dependsOn (c : Code) (s : FVarIdSet) : Bool :=
|
||||
def Code.dependsOn (c : Code pu) (s : FVarIdSet) : Bool :=
|
||||
depOn c s
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -19,16 +19,16 @@ Collect set of (let) free variables in a LCNF value.
|
|||
This code exploits the LCNF property that local declarations do not occur in types.
|
||||
-/
|
||||
|
||||
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg) : UsedLocalDecls :=
|
||||
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg .pure) : UsedLocalDecls :=
|
||||
match arg with
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
-- Locally declared variables do not occur in types.
|
||||
| .type _ | .erased => s
|
||||
|
||||
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array Arg) : UsedLocalDecls :=
|
||||
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array (Arg .pure)) : UsedLocalDecls :=
|
||||
args.foldl (init := s) collectLocalDeclsArg
|
||||
|
||||
def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue) : UsedLocalDecls :=
|
||||
def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue .pure) : UsedLocalDecls :=
|
||||
match e with
|
||||
| .erased | .lit .. => s
|
||||
| .proj _ _ fvarId => s.insert fvarId
|
||||
|
|
@ -39,21 +39,22 @@ namespace ElimDead
|
|||
|
||||
abbrev M := StateRefT UsedLocalDecls CompilerM
|
||||
|
||||
private abbrev collectArgM (arg : Arg) : M Unit :=
|
||||
private abbrev collectArgM (arg : Arg .pure) : M Unit :=
|
||||
modify (collectLocalDeclsArg · arg)
|
||||
|
||||
private abbrev collectLetValueM (e : LetValue) : M Unit :=
|
||||
private abbrev collectLetValueM (e : LetValue .pure) : M Unit :=
|
||||
modify (collectLocalDeclsLetValue · e)
|
||||
|
||||
private abbrev collectFVarM (fvarId : FVarId) : M Unit :=
|
||||
modify (·.insert fvarId)
|
||||
|
||||
mutual
|
||||
partial def visitFunDecl (funDecl : FunDecl) : M FunDecl := do
|
||||
|
||||
partial def visitFunDecl (funDecl : FunDecl .pure) : M (FunDecl .pure) := do
|
||||
let value ← elimDead funDecl.value
|
||||
funDecl.updateValue value
|
||||
|
||||
partial def elimDead (code : Code) : M Code := do
|
||||
partial def elimDead (code : Code .pure) : M (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let k ← elimDead k
|
||||
|
|
@ -84,10 +85,11 @@ end
|
|||
|
||||
end ElimDead
|
||||
|
||||
def Code.elimDead (code : Code) : CompilerM Code :=
|
||||
-- TODO: Generalize this to arbitrary phases, keep in mind that in impure elim dead is not as easy though
|
||||
def Code.elimDead (code : Code .pure) : CompilerM (Code .pure) :=
|
||||
ElimDead.elimDead code |>.run' {}
|
||||
|
||||
def Decl.elimDead (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.elimDead (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
return { decl with value := (← decl.value.mapCodeM Code.elimDead) }
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -239,14 +239,14 @@ Attempt to turn a `Value` that is representing a literal into a set of
|
|||
auxiliary declarations + the final `FVarId` of the declaration that
|
||||
contains the actual literal. If it is not a literal return none.
|
||||
-/
|
||||
partial def getLiteral (v : Value) : CompilerM (Option ((Array CodeDecl) × FVarId)) := do
|
||||
partial def getLiteral (v : Value) : CompilerM (Option ((Array (CodeDecl .pure)) × FVarId)) := do
|
||||
if isLiteral v then
|
||||
let literal ← go v
|
||||
return some literal
|
||||
else
|
||||
return none
|
||||
where
|
||||
go : Value → CompilerM ((Array CodeDecl) × FVarId)
|
||||
go : Value → CompilerM ((Array (CodeDecl .pure)) × FVarId)
|
||||
| .ctor ``Nat.zero #[] .. => do
|
||||
let decl ← mkAuxLetDecl <| .lit <| .nat <| 0
|
||||
return (#[.let decl], decl.fvarId)
|
||||
|
|
@ -260,7 +260,7 @@ where
|
|||
let flatten acc := fun (decls, var) => (acc.fst ++ decls, acc.snd.push <| .fvar var)
|
||||
let (decls, args) :=
|
||||
fields.foldl (init := (#[], Array.replicate ctorInfo.numParams .erased)) flatten
|
||||
let letVal : LetValue := .const ctorName [] args
|
||||
let letVal : LetValue .pure := .const ctorName [] args
|
||||
let letDecl ← mkAuxLetDecl letVal
|
||||
return (decls.push <| .let letDecl, letDecl.fvarId)
|
||||
| _ => unreachable!
|
||||
|
|
@ -328,7 +328,7 @@ structure InterpContext where
|
|||
a single declaration or a mutual block of declarations where their
|
||||
analysis might influence each other as we approach the fixpoint.
|
||||
-/
|
||||
decls : Array Decl
|
||||
decls : Array (Decl .pure)
|
||||
/--
|
||||
The index of the function we are currently operating on in `decls.`
|
||||
-/
|
||||
|
|
@ -386,7 +386,7 @@ def findVarValue (var : FVarId) : InterpM Value := do
|
|||
/--
|
||||
Find the value of `arg` using the logic of `findVarValue`.
|
||||
-/
|
||||
def findArgValue (arg : Arg) : InterpM Value := do
|
||||
def findArgValue (arg : Arg .pure) : InterpM Value := do
|
||||
match arg with
|
||||
| .fvar fvarId => findVarValue fvarId
|
||||
| _ => return .top
|
||||
|
|
@ -421,7 +421,8 @@ Furthermore if we see that `params.size != args.size` we know that this is
|
|||
a partial application and set the values of the remaining parameters to
|
||||
`top` since it is impossible to track what will happen with them from here on.
|
||||
-/
|
||||
def updateFunDeclParamsAssignment (params : Array Param) (args : Array Arg) : InterpM Bool := do
|
||||
def updateFunDeclParamsAssignment (params : Array (Param .pure)) (args : Array (Arg .pure)) :
|
||||
InterpM Bool := do
|
||||
let mut ret := false
|
||||
let env ← getEnv
|
||||
for param in params, arg in args do
|
||||
|
|
@ -443,7 +444,7 @@ def updateFunDeclParamsAssignment (params : Array Param) (args : Array Arg) : In
|
|||
updateVarAssignment param.fvarId .top
|
||||
return ret
|
||||
|
||||
def updateFunDeclParamsTop (params : Array Param) : InterpM Bool := do
|
||||
def updateFunDeclParamsTop (params : Array (Param .pure)) : InterpM Bool := do
|
||||
let mut ret := false
|
||||
for param in params do
|
||||
let paramVal ← findVarValue param.fvarId
|
||||
|
|
@ -453,7 +454,7 @@ def updateFunDeclParamsTop (params : Array Param) : InterpM Bool := do
|
|||
ret := true
|
||||
return ret
|
||||
|
||||
private partial def resetNestedFunDeclParams : Code → InterpM Unit
|
||||
private partial def resetNestedFunDeclParams : Code .pure → InterpM Unit
|
||||
| .let _ k => resetNestedFunDeclParams k
|
||||
| .jp decl k | .fun decl k => do
|
||||
decl.params.forM (resetVarAssignment ·.fvarId)
|
||||
|
|
@ -467,7 +468,7 @@ private partial def resetNestedFunDeclParams : Code → InterpM Unit
|
|||
/--
|
||||
The actual abstract interpreter on a block of `Code`.
|
||||
-/
|
||||
partial def interpCode : Code → InterpM Unit
|
||||
partial def interpCode : Code .pure → InterpM Unit
|
||||
| .let decl k => do
|
||||
let val ← interpLetValue decl.value
|
||||
updateVarAssignment decl.fvarId val
|
||||
|
|
@ -503,7 +504,7 @@ where
|
|||
/--
|
||||
The abstract interpreter on a `LetValue`.
|
||||
-/
|
||||
interpLetValue (letVal : LetValue) : InterpM Value := do
|
||||
interpLetValue (letVal : LetValue .pure) : InterpM Value := do
|
||||
match letVal with
|
||||
| .lit val => return .ofLCNFLit val
|
||||
| .proj _ idx struct =>
|
||||
|
|
@ -513,7 +514,7 @@ where
|
|||
let env ← getEnv
|
||||
args.forM handleFunArg
|
||||
match (← getDecl? declName) with
|
||||
| some decl =>
|
||||
| some ⟨_, decl⟩ =>
|
||||
if decl.getArity == args.size then
|
||||
match getFunctionSummary? env declName with
|
||||
| some v => return v
|
||||
|
|
@ -538,7 +539,7 @@ where
|
|||
return .top
|
||||
| .erased => return .top
|
||||
|
||||
handleFunArg (arg : Arg) : InterpM Unit := do
|
||||
handleFunArg (arg : Arg .pure) : InterpM Unit := do
|
||||
if let .fvar fvarId := arg then
|
||||
handleFunVar fvarId
|
||||
|
||||
|
|
@ -557,7 +558,7 @@ where
|
|||
resetNestedFunDeclParams funDecl.value
|
||||
interpCode funDecl.value
|
||||
|
||||
interpFunCall (funDecl : FunDecl) (args : Array Arg) : InterpM Unit := do
|
||||
interpFunCall (funDecl : FunDecl .pure) (args : Array (Arg .pure)) : InterpM Unit := do
|
||||
let updated ← updateFunDeclParamsAssignment funDecl.params args
|
||||
if updated then
|
||||
/- We must reset the value of nested function declaration
|
||||
|
|
@ -608,11 +609,11 @@ Use the information produced by the abstract interpreter to:
|
|||
- Eliminate branches that we know cannot be hit
|
||||
- Eliminate values that we know have to be constants.
|
||||
-/
|
||||
partial def elimDead (assignment : Assignment) (decl : Decl) : CompilerM Decl := do
|
||||
partial def elimDead (assignment : Assignment) (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
trace[Compiler.elimDeadBranches] s!"Eliminating {decl.name} with {repr (← assignment.toArray |>.mapM (fun (name, val) => do return (toString (← getBinderName name), val)))}"
|
||||
return { decl with value := (← decl.value.mapCodeM go) }
|
||||
where
|
||||
go (code : Code) : CompilerM Code := do
|
||||
go (code : Code .pure) : CompilerM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
return code.updateLet! decl (← go k)
|
||||
|
|
@ -624,16 +625,14 @@ where
|
|||
match alt with
|
||||
| .alt ctor args body =>
|
||||
if discrVal.containsCtor ctor then
|
||||
let filter param := do
|
||||
let constantInfos ← args.filterMapM fun param => do
|
||||
if let some val := assignment[param.fvarId]? then
|
||||
if let some literal ← val.getLiteral then
|
||||
return some (param, literal)
|
||||
return none
|
||||
let constantInfos ← args.filterMapM filter
|
||||
if constantInfos.size != 0 then
|
||||
let folder := fun (body, subst) (param, decls, var) => do
|
||||
let (body, subst) ← constantInfos.foldlM (init := (← go body, {})) fun (body, subst) (param, decls, var) => do
|
||||
return (attachCodeDecls decls body, subst.insert param.fvarId (.fvar var))
|
||||
let (body, subst) ← constantInfos.foldlM (init := (← go body, {})) folder
|
||||
let body ← replaceFVars body subst false
|
||||
return alt.updateCode body
|
||||
else
|
||||
|
|
@ -649,7 +648,7 @@ where
|
|||
end UnreachableBranches
|
||||
|
||||
open UnreachableBranches in
|
||||
def Decl.elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
def Decl.elimDeadBranches (decls : Array (Decl .pure)) : CompilerM (Array (Decl .pure)) := do
|
||||
/-
|
||||
We sort declarations by size here to ensure that when we restart in inferStep it will mostly be
|
||||
small declarations that get re-analyzed.
|
||||
|
|
|
|||
|
|
@ -16,11 +16,11 @@ public section
|
|||
namespace Lean.Compiler.LCNF
|
||||
namespace ExtractClosed
|
||||
|
||||
abbrev ExtractM := StateRefT (Array CodeDecl) CompilerM
|
||||
abbrev ExtractM := StateRefT (Array (CodeDecl .pure)) CompilerM
|
||||
|
||||
mutual
|
||||
|
||||
partial def extractLetValue (v : LetValue) : ExtractM Unit := do
|
||||
partial def extractLetValue (v : LetValue .pure) : ExtractM Unit := do
|
||||
match v with
|
||||
| .const _ _ args => args.forM extractArg
|
||||
| .fvar fnVar args =>
|
||||
|
|
@ -29,7 +29,7 @@ partial def extractLetValue (v : LetValue) : ExtractM Unit := do
|
|||
| .proj _ _ baseVar => extractFVar baseVar
|
||||
| .lit _ | .erased => return ()
|
||||
|
||||
partial def extractArg (arg : Arg) : ExtractM Unit := do
|
||||
partial def extractArg (arg : Arg .pure) : ExtractM Unit := do
|
||||
match arg with
|
||||
| .fvar fvarId => extractFVar fvarId
|
||||
| .type _ | .erased => return ()
|
||||
|
|
@ -41,17 +41,17 @@ partial def extractFVar (fvarId : FVarId) : ExtractM Unit := do
|
|||
|
||||
end
|
||||
|
||||
def isIrrelevantArg (arg : Arg) : Bool :=
|
||||
def isIrrelevantArg (arg : Arg .pure) : Bool :=
|
||||
match arg with
|
||||
| .erased | .type _ => true
|
||||
| .fvar _ => false
|
||||
|
||||
structure Context where
|
||||
baseName : Name
|
||||
sccDecls : Array Decl
|
||||
sccDecls : Array (Decl .pure)
|
||||
|
||||
structure State where
|
||||
decls : Array Decl := {}
|
||||
decls : Array (Decl .pure) := {}
|
||||
/--
|
||||
Cache for `shouldExtractFVar` in order to avoid superlinear behavior.
|
||||
-/
|
||||
|
|
@ -61,7 +61,7 @@ abbrev M := ReaderT Context $ StateRefT State CompilerM
|
|||
|
||||
mutual
|
||||
|
||||
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
|
||||
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue .pure) : M Bool := do
|
||||
match v with
|
||||
| .lit (.str _) => return true
|
||||
| .lit (.nat v) =>
|
||||
|
|
@ -90,7 +90,7 @@ partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
|
|||
| .fvar fnVar args => return (← shouldExtractFVar fnVar) && (← args.allM shouldExtractArg)
|
||||
| .proj _ _ baseVar => shouldExtractFVar baseVar
|
||||
|
||||
partial def shouldExtractArg (arg : Arg) : M Bool := do
|
||||
partial def shouldExtractArg (arg : Arg .pure) : M Bool := do
|
||||
match arg with
|
||||
| .fvar fvarId => shouldExtractFVar fvarId
|
||||
| .type _ | .erased => return true
|
||||
|
|
@ -113,7 +113,7 @@ end
|
|||
|
||||
mutual
|
||||
|
||||
partial def visitCode (code : Code) : M Code := do
|
||||
partial def visitCode (code : Code .pure) : M (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
if (← shouldExtractLetValue true decl.value) then
|
||||
|
|
@ -151,13 +151,14 @@ partial def visitCode (code : Code) : M Code := do
|
|||
|
||||
end
|
||||
|
||||
def visitDecl (decl : Decl) : M Decl := do
|
||||
def visitDecl (decl : Decl .pure) : M (Decl .pure) := do
|
||||
let value ← decl.value.mapCodeM visitCode
|
||||
return { decl with value }
|
||||
|
||||
end ExtractClosed
|
||||
|
||||
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do
|
||||
partial def Decl.extractClosed (decl : Decl .pure) (sccDecls : Array (Decl .pure)) :
|
||||
CompilerM (Array (Decl .pure)) := do
|
||||
let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
|
||||
return s.decls.push decl
|
||||
|
||||
|
|
|
|||
|
|
@ -48,67 +48,67 @@ instance : TraverseFVar Expr where
|
|||
mapFVarM := Expr.mapFVarM
|
||||
forFVarM := Expr.forFVarM
|
||||
|
||||
def Arg.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (arg : Arg) : m Arg := do
|
||||
def Arg.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (arg : Arg pu) : m (Arg pu) := do
|
||||
match arg with
|
||||
| .erased => return .erased
|
||||
| .type e => return arg.updateType! (← TraverseFVar.mapFVarM f e)
|
||||
| .type e _ => return arg.updateType! (← TraverseFVar.mapFVarM f e)
|
||||
| .fvar fvarId => return arg.updateFVar! (← f fvarId)
|
||||
|
||||
def Arg.forFVarM [Monad m] (f : FVarId → m Unit) (arg : Arg) : m Unit := do
|
||||
def Arg.forFVarM [Monad m] (f : FVarId → m Unit) (arg : Arg pu) : m Unit := do
|
||||
match arg with
|
||||
| .erased => return ()
|
||||
| .type e => TraverseFVar.forFVarM f e
|
||||
| .type e _ => TraverseFVar.forFVarM f e
|
||||
| .fvar fvarId => f fvarId
|
||||
|
||||
instance : TraverseFVar Arg where
|
||||
instance : TraverseFVar (Arg pu) where
|
||||
mapFVarM := Arg.mapFVarM
|
||||
forFVarM := Arg.forFVarM
|
||||
|
||||
def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (e : LetValue) : m LetValue := do
|
||||
def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (e : LetValue pu) : m (LetValue pu) := do
|
||||
match e with
|
||||
| .lit .. | .erased => return e
|
||||
| .proj _ _ fvarId => return e.updateProj! (← f fvarId)
|
||||
| .const _ _ args => return e.updateArgs! (← args.mapM (TraverseFVar.mapFVarM f))
|
||||
| .proj _ _ fvarId _ => return e.updateProj! (← f fvarId)
|
||||
| .const _ _ args _ => return e.updateArgs! (← args.mapM (TraverseFVar.mapFVarM f))
|
||||
| .fvar fvarId args => return e.updateFVar! (← f fvarId) (← args.mapM (TraverseFVar.mapFVarM f))
|
||||
|
||||
def LetValue.forFVarM [Monad m] (f : FVarId → m Unit) (e : LetValue) : m Unit := do
|
||||
def LetValue.forFVarM [Monad m] (f : FVarId → m Unit) (e : LetValue pu) : m Unit := do
|
||||
match e with
|
||||
| .lit .. | .erased => return ()
|
||||
| .proj _ _ fvarId => f fvarId
|
||||
| .const _ _ args => args.forM (TraverseFVar.forFVarM f)
|
||||
| .proj _ _ fvarId _ => f fvarId
|
||||
| .const _ _ args _ => args.forM (TraverseFVar.forFVarM f)
|
||||
| .fvar fvarId args => f fvarId; args.forM (TraverseFVar.forFVarM f)
|
||||
|
||||
instance : TraverseFVar LetValue where
|
||||
instance : TraverseFVar (LetValue pu) where
|
||||
mapFVarM := LetValue.mapFVarM
|
||||
forFVarM := LetValue.forFVarM
|
||||
|
||||
partial def LetDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : LetDecl) : m LetDecl := do
|
||||
partial def LetDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : LetDecl pu) : m (LetDecl pu) := do
|
||||
decl.update (← Expr.mapFVarM f decl.type) (← LetValue.mapFVarM f decl.value)
|
||||
|
||||
partial def LetDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : LetDecl) : m Unit := do
|
||||
partial def LetDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : LetDecl pu) : m Unit := do
|
||||
Expr.forFVarM f decl.type
|
||||
LetValue.forFVarM f decl.value
|
||||
|
||||
instance : TraverseFVar LetDecl where
|
||||
instance : TraverseFVar (LetDecl pu) where
|
||||
mapFVarM := LetDecl.mapFVarM
|
||||
forFVarM := LetDecl.forFVarM
|
||||
|
||||
partial def Param.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (param : Param) : m Param := do
|
||||
partial def Param.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (param : Param pu) : m (Param pu) := do
|
||||
param.update (← Expr.mapFVarM f param.type)
|
||||
|
||||
partial def Param.forFVarM [Monad m] (f : FVarId → m Unit) (param : Param) : m Unit := do
|
||||
partial def Param.forFVarM [Monad m] (f : FVarId → m Unit) (param : Param pu) : m Unit := do
|
||||
Expr.forFVarM f param.type
|
||||
|
||||
instance : TraverseFVar Param where
|
||||
instance : TraverseFVar (Param pu) where
|
||||
mapFVarM := Param.mapFVarM
|
||||
forFVarM := Param.forFVarM
|
||||
|
||||
partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (c : Code) : m Code := do
|
||||
partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (c : Code pu) : m (Code pu) := do
|
||||
match c with
|
||||
| .let decl k =>
|
||||
let decl ← LetDecl.mapFVarM f decl
|
||||
return Code.updateLet! c decl (← mapFVarM f k)
|
||||
| .fun decl k =>
|
||||
| .fun decl k _ =>
|
||||
let params ← decl.params.mapM (Param.mapFVarM f)
|
||||
let decl ← decl.update (← Expr.mapFVarM f decl.type) params (← mapFVarM f decl.value)
|
||||
return Code.updateFun! c decl (← mapFVarM f k)
|
||||
|
|
@ -125,12 +125,12 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F
|
|||
| .unreach typ =>
|
||||
return Code.updateUnreach! c (← Expr.mapFVarM f typ)
|
||||
|
||||
partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code) : m Unit := do
|
||||
partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Unit := do
|
||||
match c with
|
||||
| .let decl k =>
|
||||
LetDecl.forFVarM f decl
|
||||
forFVarM f k
|
||||
| .fun decl k =>
|
||||
| .fun decl k _ =>
|
||||
decl.params.forM (Param.forFVarM f)
|
||||
Expr.forFVarM f decl.type
|
||||
forFVarM f decl.value
|
||||
|
|
@ -151,45 +151,45 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code) : m Unit
|
|||
| .unreach typ =>
|
||||
Expr.forFVarM f typ
|
||||
|
||||
instance : TraverseFVar Code where
|
||||
instance : TraverseFVar (Code pu) where
|
||||
mapFVarM := Code.mapFVarM
|
||||
forFVarM := Code.forFVarM
|
||||
|
||||
def FunDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : FunDecl) : m FunDecl := do
|
||||
def FunDecl.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarId) (decl : FunDecl pu) : m (FunDecl pu) := do
|
||||
let params ← decl.params.mapM (Param.mapFVarM f)
|
||||
decl.update (← Expr.mapFVarM f decl.type) params (← Code.mapFVarM f decl.value)
|
||||
|
||||
def FunDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : FunDecl) : m Unit := do
|
||||
def FunDecl.forFVarM [Monad m] (f : FVarId → m Unit) (decl : FunDecl pu) : m Unit := do
|
||||
decl.params.forM (Param.forFVarM f)
|
||||
Expr.forFVarM f decl.type
|
||||
Code.forFVarM f decl.value
|
||||
|
||||
instance : TraverseFVar FunDecl where
|
||||
instance : TraverseFVar (FunDecl pu) where
|
||||
mapFVarM := FunDecl.mapFVarM
|
||||
forFVarM := FunDecl.forFVarM
|
||||
|
||||
instance : TraverseFVar CodeDecl where
|
||||
instance : TraverseFVar (CodeDecl pu) where
|
||||
mapFVarM f decl := do
|
||||
match decl with
|
||||
| .fun decl => return .fun (← mapFVarM f decl)
|
||||
| .fun decl _ => return .fun (← mapFVarM f decl)
|
||||
| .jp decl => return .jp (← mapFVarM f decl)
|
||||
| .let decl => return .let (← mapFVarM f decl)
|
||||
forFVarM f decl :=
|
||||
match decl with
|
||||
| .fun decl => forFVarM f decl
|
||||
| .fun decl _ => forFVarM f decl
|
||||
| .jp decl => forFVarM f decl
|
||||
| .let decl => forFVarM f decl
|
||||
|
||||
instance : TraverseFVar Alt where
|
||||
instance : TraverseFVar (Alt pu) where
|
||||
mapFVarM f alt := do
|
||||
match alt with
|
||||
| .alt ctor params c =>
|
||||
| .alt ctor params c _ =>
|
||||
let params ← params.mapM (Param.mapFVarM f)
|
||||
return .alt ctor params (← Code.mapFVarM f c)
|
||||
| .default c => return .default (← Code.mapFVarM f c)
|
||||
forFVarM f alt := do
|
||||
match alt with
|
||||
| .alt _ params c =>
|
||||
| .alt _ params c _ =>
|
||||
params.forM (Param.forFVarM f)
|
||||
Code.forFVarM f c
|
||||
| .default c => Code.forFVarM f c
|
||||
|
|
|
|||
|
|
@ -46,12 +46,12 @@ inductive AbsValue where
|
|||
|
||||
structure Context where
|
||||
/-- Declaration in the same mutual block. -/
|
||||
decls : Array Decl
|
||||
decls : Array (Decl .pure)
|
||||
/--
|
||||
Function being analyzed. We check every recursive call to this function.
|
||||
Remark: `main` is in `decls`.
|
||||
-/
|
||||
main : Decl
|
||||
main : Decl .pure
|
||||
/--
|
||||
The assignment maps free variable ids in the current code being analyzed to abstract values.
|
||||
We only track the abstract value assigned to parameters.
|
||||
|
|
@ -84,17 +84,17 @@ def evalFVar (fvarId : FVarId) : FixParamM AbsValue := do
|
|||
let some val := (← read).assignment.get? fvarId | return .top
|
||||
return val
|
||||
|
||||
def evalArg (arg : Arg) : FixParamM AbsValue := do
|
||||
def evalArg (arg : Arg .pure) : FixParamM AbsValue := do
|
||||
match arg with
|
||||
| .erased => return .erased
|
||||
| .type (.fvar fvarId) => evalFVar fvarId
|
||||
| .type _ => return .top
|
||||
| .type (.fvar fvarId) _ => evalFVar fvarId
|
||||
| .type _ _ => return .top
|
||||
| .fvar fvarId => evalFVar fvarId
|
||||
|
||||
def inMutualBlock (declName : Name) : FixParamM Bool :=
|
||||
return (← read).decls.any (·.name == declName)
|
||||
|
||||
def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do
|
||||
def mkAssignment (decl : Decl .pure) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do
|
||||
let mut assignment := {}
|
||||
for param in decl.params, value in values do
|
||||
assignment := assignment.insert param.fvarId value
|
||||
|
|
@ -102,12 +102,12 @@ def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue :=
|
|||
|
||||
mutual
|
||||
|
||||
partial def evalLetValue (e : LetValue) : FixParamM Unit := do
|
||||
partial def evalLetValue (e : LetValue .pure) : FixParamM Unit := do
|
||||
match e with
|
||||
| .const declName _ args => evalApp declName args
|
||||
| .const declName _ args _ => evalApp declName args
|
||||
| _ => return ()
|
||||
|
||||
partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do
|
||||
partial def isEquivalentFunDecl? (decl : FunDecl .pure) : FixParamM (Option Nat) := do
|
||||
let .let { fvarId, value := (.fvar funFvarId args), .. } k := decl.value | return none
|
||||
if args.size != decl.params.size then return none
|
||||
let .return retFVarId := k | return none
|
||||
|
|
@ -120,10 +120,10 @@ partial def isEquivalentFunDecl? (decl : FunDecl) : FixParamM (Option Nat) := do
|
|||
if arg != .fvar param.fvarId && arg != .erased then return none
|
||||
return some funIdx
|
||||
|
||||
partial def evalCode (code : Code) : FixParamM Unit := do
|
||||
partial def evalCode (code : Code .pure) : FixParamM Unit := do
|
||||
match code with
|
||||
| .let decl k => evalLetValue decl.value; evalCode k
|
||||
| .fun decl k =>
|
||||
| .fun decl k _ =>
|
||||
if let some paramIdx ← isEquivalentFunDecl? decl then
|
||||
withReader (fun ctx =>
|
||||
{ ctx with assignment := ctx.assignment.insert decl.fvarId (.val paramIdx) })
|
||||
|
|
@ -135,7 +135,7 @@ partial def evalCode (code : Code) : FixParamM Unit := do
|
|||
| .cases c => c.alts.forM fun alt => evalCode alt.getCode
|
||||
| .unreach .. | .jmp .. | .return .. => return ()
|
||||
|
||||
partial def evalApp (declName : Name) (args : Array Arg) : FixParamM Unit := do
|
||||
partial def evalApp (declName : Name) (args : Array (Arg .pure)) : FixParamM Unit := do
|
||||
let main := (← read).main
|
||||
if declName == main.name then
|
||||
-- Recursive call to the function being analyzed
|
||||
|
|
@ -180,6 +180,9 @@ def mkInitialValues (numParams : Nat) : Array AbsValue := Id.run do
|
|||
end FixedParams
|
||||
open FixedParams
|
||||
|
||||
-- TODO: consider making it phase polymorphic, this requires detecting in place mutations of
|
||||
-- variables etc in addition to just graph theory
|
||||
|
||||
/--
|
||||
Given the (potentially mutually) recursive declarations `decls`,
|
||||
return a map from declaration name `decl.name` to a bit-mask `m` where `m[i]` is true
|
||||
|
|
@ -188,7 +191,7 @@ applications.
|
|||
The function assumes that if a function `f` was declared in a mutual block, then `decls`
|
||||
contains all (computationally relevant) functions in the mutual block.
|
||||
-/
|
||||
def mkFixedParamsMap (decls : Array Decl) : NameMap (Array Bool) := Id.run do
|
||||
def mkFixedParamsMap (decls : Array (Decl .pure)) : NameMap (Array Bool) := Id.run do
|
||||
let mut result := {}
|
||||
for decl in decls do
|
||||
let values := mkInitialValues decl.params.size
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ inductive Decision where
|
|||
| unknown
|
||||
deriving Hashable, BEq, Inhabited, Repr
|
||||
|
||||
def Decision.ofAlt : Alt → Decision
|
||||
def Decision.ofAlt : Alt .pure → Decision
|
||||
| .alt name _ _ => .arm name
|
||||
| .default _ => .default
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ structure BaseFloatContext where
|
|||
All the declarations that were collected in the current LCNF basic
|
||||
block up to the current statement (in reverse order for efficiency).
|
||||
-/
|
||||
decls : List CodeDecl := []
|
||||
decls : List (CodeDecl .pure) := []
|
||||
|
||||
/--
|
||||
The state for `FloatM`
|
||||
|
|
@ -67,7 +67,7 @@ structure FloatState where
|
|||
- Which declarations do we move into a certain arm
|
||||
- Which declarations do we move into the default arm
|
||||
-/
|
||||
newArms : Std.HashMap Decision (List CodeDecl)
|
||||
newArms : Std.HashMap Decision (List (CodeDecl .pure))
|
||||
|
||||
/--
|
||||
Use to collect relevant declarations for the floating mechanism.
|
||||
|
|
@ -82,7 +82,7 @@ abbrev FloatM := StateRefT FloatState BaseFloatM
|
|||
/--
|
||||
Add `decl` to the list of declarations and run `x` with that updated context.
|
||||
-/
|
||||
def withNewCandidate (decl : CodeDecl) (x : BaseFloatM α) : BaseFloatM α :=
|
||||
def withNewCandidate (decl : CodeDecl .pure) (x : BaseFloatM α) : BaseFloatM α :=
|
||||
withReader (fun r => { r with decls := decl :: r.decls }) do
|
||||
x
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ Whether to ignore `decl` for the floating mechanism. We want to do this if:
|
|||
- `decl`' is storing a typeclass instance
|
||||
- `decl` is a projection from a variable that is storing a typeclass instance
|
||||
-/
|
||||
def ignore? (decl : LetDecl) : BaseFloatM Bool := do
|
||||
def ignore? (decl : LetDecl .pure) : BaseFloatM Bool := do
|
||||
if (← isArrowClass? decl.type).isSome then
|
||||
return true
|
||||
else if let .proj _ _ fvarId := decl.value then
|
||||
|
|
@ -117,7 +117,7 @@ up to this point, with respect to `cs`. The initial decisions are:
|
|||
- `arm` or `default` if we see the declaration only being used in exactly one cases arm
|
||||
- `unknown` otherwise
|
||||
-/
|
||||
def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) := do
|
||||
def initialDecisions (cs : Cases .pure) : BaseFloatM (Std.HashMap FVarId Decision) := do
|
||||
let mut map := Std.HashMap.emptyWithCapacity (← read).decls.length
|
||||
let owned : Std.HashSet FVarId := ∅
|
||||
(map, _) ← (← read).decls.foldlM (init := (map, owned)) fun (acc, owned) val => do
|
||||
|
|
@ -135,12 +135,12 @@ def initialDecisions (cs : Cases) : BaseFloatM (Std.HashMap FVarId Decision) :=
|
|||
(_, map) ← goCases cs |>.run map
|
||||
return map
|
||||
where
|
||||
visitDecl (env : Environment) (value : CodeDecl) : StateM (Std.HashSet FVarId) Bool := do
|
||||
visitDecl (env : Environment) (value : CodeDecl .pure) : StateM (Std.HashSet FVarId) Bool := do
|
||||
match value with
|
||||
| .let decl => visitLetValue env decl.value
|
||||
| _ => return false -- will need to investigate whether that can be a problem
|
||||
|
||||
visitLetValue (env : Environment) (value : LetValue) : StateM (Std.HashSet FVarId) Bool := do
|
||||
visitLetValue (env : Environment) (value : LetValue .pure) : StateM (Std.HashSet FVarId) Bool := do
|
||||
match value with
|
||||
| .proj _ _ x => visitArg (.fvar x) true
|
||||
| .const nm _ args =>
|
||||
|
|
@ -158,7 +158,7 @@ where
|
|||
(← visitArg (.fvar x) false)
|
||||
| .erased | .lit _ => return false
|
||||
|
||||
visitArg (var : Arg) (borrowed : Bool) : StateM (Std.HashSet FVarId) Bool := do
|
||||
visitArg (var : Arg .pure) (borrowed : Bool) : StateM (Std.HashSet FVarId) Bool := do
|
||||
let .fvar v := var | return false
|
||||
let res := (← get).contains v
|
||||
unless borrowed do
|
||||
|
|
@ -173,16 +173,16 @@ where
|
|||
modify fun s => s.insert var .dont
|
||||
-- otherwise we already have the proper decision
|
||||
|
||||
goAlt (alt : Alt) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
goAlt (alt : Alt .pure) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
forFVarM (goFVar (.ofAlt alt)) alt
|
||||
goCases (cs : Cases) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
goCases (cs : Cases .pure) : StateRefT (Std.HashMap FVarId Decision) BaseFloatM Unit :=
|
||||
cs.alts.forM goAlt
|
||||
|
||||
/--
|
||||
Compute the initial new arms. This will just set up a map from all arms of
|
||||
`cs` to empty `Array`s, plus one additional entry for `dont`.
|
||||
-/
|
||||
def initialNewArms (cs : Cases) : Std.HashMap Decision (List CodeDecl) := Id.run do
|
||||
def initialNewArms (cs : Cases .pure) : Std.HashMap Decision (List (CodeDecl .pure)) := Id.run do
|
||||
let mut map := Std.HashMap.emptyWithCapacity (cs.alts.size + 1)
|
||||
map := map.insert .dont []
|
||||
cs.alts.foldr (init := map) fun val acc => acc.insert (.ofAlt val) []
|
||||
|
|
@ -203,7 +203,7 @@ cases z with
|
|||
Here `x` and `y` are originally marked as getting floated into `n` and `m`
|
||||
respectively but since `z` can't be moved we don't want that to move `x` and `y`.
|
||||
-/
|
||||
def dontFloat (decl : CodeDecl) : FloatM Unit := do
|
||||
def dontFloat (decl : CodeDecl .pure) : FloatM Unit := do
|
||||
forFVarM goFVar decl
|
||||
modify fun s => { s with newArms := s.newArms.insert .dont (decl :: s.newArms[Decision.dont]!) }
|
||||
where
|
||||
|
|
@ -257,7 +257,7 @@ Will:
|
|||
```
|
||||
If we are at `y` `x` is still marked to be moved but we don't want that.
|
||||
-/
|
||||
def float (decl : CodeDecl) : FloatM Unit := do
|
||||
def float (decl : CodeDecl .pure) : FloatM Unit := do
|
||||
let arm := (← get).decision[decl.fvarId]!
|
||||
forFVarM (goFVar · arm) decl
|
||||
modify fun s => { s with newArms := s.newArms.insert arm (decl :: s.newArms[arm]!) }
|
||||
|
|
@ -273,7 +273,7 @@ where
|
|||
Iterate through `decl`, pushing local declarations that are only used in one
|
||||
control flow arm into said arm in order to avoid useless computations.
|
||||
-/
|
||||
partial def floatLetIn (decl : Decl) : CompilerM Decl := do
|
||||
partial def floatLetIn (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let newValue ← decl.value.mapCodeM go |>.run {}
|
||||
return { decl with value := newValue }
|
||||
where
|
||||
|
|
@ -296,7 +296,7 @@ where
|
|||
else
|
||||
float decl
|
||||
|
||||
go (code : Code) : BaseFloatM Code := do
|
||||
go (code : Code .pure) : BaseFloatM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
withNewCandidate (.let decl) do
|
||||
|
|
@ -334,11 +334,12 @@ where
|
|||
|
||||
end FloatLetIn
|
||||
|
||||
def Decl.floatLetIn (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.floatLetIn (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
FloatLetIn.floatLetIn decl
|
||||
|
||||
def floatLetIn (phase := Phase.base) (occurrence := 0) : Pass :=
|
||||
.mkPerDeclaration `floatLetIn Decl.floatLetIn phase occurrence
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `floatLetIn phase (h ▸ Decl.floatLetIn) occurrence
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.floatLetIn (inherited := true)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,10 @@ public section
|
|||
namespace Lean.Compiler.LCNF
|
||||
/-! # Type inference for LCNF -/
|
||||
|
||||
namespace InferType
|
||||
|
||||
namespace Pure
|
||||
|
||||
/-
|
||||
Note about **erasure confusion**.
|
||||
|
||||
|
|
@ -53,10 +57,9 @@ but the expected type is `S Nat Type (fun x => Nat)`. `fun x => Nat` is not eras
|
|||
here because it is a type former.
|
||||
-/
|
||||
|
||||
namespace InferType
|
||||
|
||||
/-
|
||||
Type inference algorithm for LCNF. Invoked by the LCNF type checker
|
||||
Type inference algorithm for pure LCNF. Invoked by the LCNF type checker
|
||||
to check correctness of LCNF IR.
|
||||
-/
|
||||
|
||||
|
|
@ -80,12 +83,12 @@ def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
|
|||
let b := type.abstract xs
|
||||
xs.size.foldRevM (init := b) fun i _ b => do
|
||||
let x := xs[i]
|
||||
let n ← InferType.getBinderName x.fvarId!
|
||||
let ty ← InferType.getType x.fvarId!
|
||||
let n ← getBinderName x.fvarId!
|
||||
let ty ← getType x.fvarId!
|
||||
let ty := ty.abstractRange i xs;
|
||||
return .forallE n ty b .default
|
||||
|
||||
def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr :=
|
||||
def mkForallParams (params : Array (Param .pure)) (type : Expr) : InferTypeM Expr :=
|
||||
let xs := params.map fun p => .fvar p.fvarId
|
||||
mkForallFVars xs type |>.run {}
|
||||
|
||||
|
|
@ -97,7 +100,7 @@ def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr :=
|
|||
def inferConstType (declName : Name) (us : List Level) : CompilerM Expr := do
|
||||
if declName == ``lcErased then
|
||||
return erasedExpr
|
||||
else if let some decl ← getDecl? declName then
|
||||
else if let some ⟨_, decl⟩ ← getDecl? declName then
|
||||
return decl.instantiateTypeLevelParams us
|
||||
else
|
||||
/- Declaration does not have code associated with it: constructor, inductive type, foreign function -/
|
||||
|
|
@ -114,7 +117,7 @@ def inferLitValueType (value : LitValue) : Expr :=
|
|||
| .usize .. => mkConst ``USize
|
||||
|
||||
mutual
|
||||
partial def inferArgType (arg : Arg) : InferTypeM Expr :=
|
||||
partial def inferArgType (arg : Arg .pure) : InferTypeM Expr :=
|
||||
match arg with
|
||||
| .erased => return erasedExpr
|
||||
| .type e => inferType e
|
||||
|
|
@ -124,13 +127,13 @@ mutual
|
|||
match e with
|
||||
| .const c us => inferConstType c us
|
||||
| .app .. => inferAppType e
|
||||
| .fvar fvarId => InferType.getType fvarId
|
||||
| .fvar fvarId => getType fvarId
|
||||
| .sort lvl => return .sort (mkLevelSucc lvl)
|
||||
| .forallE .. => inferForallType e
|
||||
| .lam .. => inferLambdaType e
|
||||
| .letE .. | .mvar .. | .mdata .. | .lit .. | .bvar .. | .proj .. => unreachable!
|
||||
|
||||
partial def inferLetValueType (e : LetValue) : InferTypeM Expr := do
|
||||
partial def inferLetValueType (e : LetValue .pure) : InferTypeM Expr := do
|
||||
match e with
|
||||
| .erased => return erasedExpr
|
||||
| .lit v => return inferLitValueType v
|
||||
|
|
@ -138,7 +141,7 @@ mutual
|
|||
| .const declName us args => inferAppTypeCore (← inferConstType declName us) args
|
||||
| .fvar fvarId args => inferAppTypeCore (← getType fvarId) args
|
||||
|
||||
partial def inferAppTypeCore (fType : Expr) (args : Array Arg) : InferTypeM Expr := do
|
||||
partial def inferAppTypeCore (fType : Expr) (args : Array (Arg .pure)) : InferTypeM Expr := do
|
||||
let mut j := 0
|
||||
let mut fType := fType
|
||||
for i in *...args.size do
|
||||
|
|
@ -237,60 +240,79 @@ mutual
|
|||
mkForallFVars fvars type
|
||||
|
||||
end
|
||||
end Pure
|
||||
|
||||
namespace Impure
|
||||
end Impure
|
||||
|
||||
end InferType
|
||||
|
||||
-- TODO
|
||||
def inferType (e : Expr) : CompilerM Expr :=
|
||||
InferType.inferType e |>.run {}
|
||||
InferType.Pure.inferType e |>.run {}
|
||||
|
||||
def inferAppType (fnType : Expr) (args : Array Arg) : CompilerM Expr :=
|
||||
InferType.inferAppTypeCore fnType args |>.run {}
|
||||
def inferAppType (fnType : Expr) (args : Array (Arg pu)) : CompilerM Expr :=
|
||||
match pu with
|
||||
| .pure => InferType.Pure.inferAppTypeCore fnType args |>.run {}
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def getLevel (type : Expr) : CompilerM Level := do
|
||||
match (← inferType type) with
|
||||
| .sort u => return u
|
||||
| e => if e.isErased then return levelOne else throwError "type expected{indentExpr type}"
|
||||
def Arg.inferType (arg : Arg pu) : CompilerM Expr :=
|
||||
match pu with
|
||||
| .pure => InferType.Pure.inferArgType arg |>.run {}
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def Arg.inferType (arg : Arg) : CompilerM Expr :=
|
||||
InferType.inferArgType arg |>.run {}
|
||||
def LetValue.inferType (e : LetValue pu) : CompilerM Expr :=
|
||||
match pu with
|
||||
| .pure => InferType.Pure.inferLetValueType e |>.run {}
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def LetValue.inferType (e : LetValue) : CompilerM Expr :=
|
||||
InferType.inferLetValueType e |>.run {}
|
||||
def Code.inferType (code : Code pu) : CompilerM Expr := do
|
||||
match pu with
|
||||
| .pure =>
|
||||
match code with
|
||||
| .let _ k | .fun _ k _ | .jp _ k => k.inferType
|
||||
| .return fvarId => getType fvarId
|
||||
| .jmp fvarId args => InferType.Pure.inferAppTypeCore (← getType fvarId) args |>.run {}
|
||||
| .unreach type => return type
|
||||
| .cases c => return c.resultType
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def Code.inferType (code : Code) : CompilerM Expr := do
|
||||
match code with
|
||||
| .let _ k | .fun _ k | .jp _ k => k.inferType
|
||||
| .return fvarId => getType fvarId
|
||||
| .jmp fvarId args => InferType.inferAppTypeCore (← getType fvarId) args |>.run {}
|
||||
| .unreach type => return type
|
||||
| .cases c => return c.resultType
|
||||
|
||||
def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr := do
|
||||
def Code.inferParamType (params : Array (Param pu)) (code : Code pu) : CompilerM Expr := do
|
||||
let type ← code.inferType
|
||||
let xs := params.map fun p => .fvar p.fvarId
|
||||
InferType.mkForallFVars xs type |>.run {}
|
||||
InferType.Pure.mkForallFVars xs type |>.run {}
|
||||
|
||||
def Alt.inferType (alt : Alt) : CompilerM Expr :=
|
||||
def Alt.inferType (alt : Alt pu) : CompilerM Expr :=
|
||||
alt.getCode.inferType
|
||||
|
||||
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : CompilerM LetDecl := do
|
||||
def mkAuxLetDecl (e : LetValue pu) (prefixName := `_x) : CompilerM (LetDecl pu) := do
|
||||
mkLetDecl (← mkFreshBinderName prefixName) (← e.inferType) e
|
||||
|
||||
def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr :=
|
||||
InferType.mkForallParams params type |>.run {}
|
||||
def mkForallParams (params : Array (Param pu)) (type : Expr) : CompilerM Expr :=
|
||||
match pu with
|
||||
| .pure => InferType.Pure.mkForallParams params type |>.run {}
|
||||
| .impure => panic! "Infer type for impure unimplemented" -- TODO
|
||||
|
||||
def mkAuxFunDecl (params : Array Param) (code : Code) (prefixName := `_f) : CompilerM FunDecl := do
|
||||
private def mkAuxFunDeclAux (params : Array (Param pu)) (code : Code pu) (prefixName : Name) :
|
||||
CompilerM (FunDecl pu) := do
|
||||
let type ← mkForallParams params (← code.inferType)
|
||||
let binderName ← mkFreshBinderName prefixName
|
||||
mkFunDecl binderName type params code
|
||||
|
||||
def mkAuxJpDecl (params : Array Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do
|
||||
mkAuxFunDecl params code prefixName
|
||||
def mkAuxFunDecl (params : Array (Param .pure)) (code : Code .pure) (prefixName := `_f) :
|
||||
CompilerM (FunDecl .pure) := do
|
||||
mkAuxFunDeclAux params code prefixName
|
||||
|
||||
def mkAuxJpDecl' (param : Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl := do
|
||||
def mkAuxJpDecl (params : Array (Param pu)) (code : Code pu) (prefixName := `_jp) :
|
||||
CompilerM (FunDecl pu) := do
|
||||
mkAuxFunDeclAux params code prefixName
|
||||
|
||||
def mkAuxJpDecl' (param : Param pu) (code : Code pu) (prefixName := `_jp) :
|
||||
CompilerM (FunDecl pu) := do
|
||||
let params := #[param]
|
||||
mkAuxFunDecl params code prefixName
|
||||
mkAuxFunDeclAux params code prefixName
|
||||
|
||||
def mkCasesResultType (alts : Array Alt) : CompilerM Expr := do
|
||||
def mkCasesResultType (alts : Array (Alt pu)) : CompilerM Expr := do
|
||||
if alts.isEmpty then
|
||||
throwError "`Code.bind` failed, empty `cases` found"
|
||||
let mut resultType ← alts[0]!.inferType
|
||||
|
|
|
|||
|
|
@ -22,44 +22,45 @@ private def refreshBinderName (binderName : Name) : CompilerM Name := do
|
|||
|
||||
namespace Internalize
|
||||
|
||||
abbrev InternalizeM := StateRefT FVarSubst CompilerM
|
||||
abbrev InternalizeM (pu : Purity) := StateRefT (FVarSubst pu) CompilerM
|
||||
|
||||
/--
|
||||
The `InternalizeM` monad is a translator. It "translates" the free variables
|
||||
in the input expressions and `Code`, into new fresh free variables in the
|
||||
local context.
|
||||
-/
|
||||
instance : MonadFVarSubst InternalizeM true where
|
||||
instance : MonadFVarSubst (InternalizeM pu) pu true where
|
||||
getSubst := get
|
||||
|
||||
instance : MonadFVarSubstState InternalizeM where
|
||||
instance : MonadFVarSubstState (InternalizeM pu) pu where
|
||||
modifySubst := modify
|
||||
|
||||
private def mkNewFVarId (fvarId : FVarId) : InternalizeM FVarId := do
|
||||
private def mkNewFVarId (fvarId : FVarId) : InternalizeM pu FVarId := do
|
||||
let fvarId' ← Lean.mkFreshFVarId
|
||||
addFVarSubst fvarId fvarId'
|
||||
return fvarId'
|
||||
|
||||
private partial def internalizeExpr (e : Expr) : InternalizeM Expr :=
|
||||
private partial def internalizeExpr (e : Expr) : InternalizeM pu Expr :=
|
||||
go e
|
||||
where
|
||||
goApp (e : Expr) : InternalizeM Expr := do
|
||||
goApp (e : Expr) : InternalizeM pu Expr := do
|
||||
match e with
|
||||
| .app f a => return e.updateApp! (← goApp f) (← go a)
|
||||
| _ => go e
|
||||
|
||||
go (e : Expr) : InternalizeM Expr := do
|
||||
go (e : Expr) : InternalizeM pu Expr := do
|
||||
if e.hasFVar then
|
||||
match e with
|
||||
| .fvar fvarId => match (← get)[fvarId]? with
|
||||
| .fvar fvarId =>
|
||||
match (← get)[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
-- In LCNF, types can't depend on let-bound fvars.
|
||||
if (← findParam? fvarId').isSome then
|
||||
if (← findParam? (pu := pu) fvarId').isSome then
|
||||
return .fvar fvarId'
|
||||
else
|
||||
return anyExpr
|
||||
| some .erased => return erasedExpr
|
||||
| some (.type e) | none => return e
|
||||
| some (.type e _) | none => return e
|
||||
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => return e
|
||||
| .app f a => return e.updateApp! (← goApp f) (← go a) |>.headBeta
|
||||
| .mdata _ b => return e.updateMData! (← go b)
|
||||
|
|
@ -70,7 +71,7 @@ where
|
|||
else
|
||||
return e
|
||||
|
||||
def internalizeParam (p : Param) : InternalizeM Param := do
|
||||
def internalizeParam (p : Param pu) : InternalizeM pu (Param pu) := do
|
||||
let binderName ← refreshBinderName p.binderName
|
||||
let type ← internalizeExpr p.type
|
||||
let fvarId ← mkNewFVarId p.fvarId
|
||||
|
|
@ -78,31 +79,31 @@ def internalizeParam (p : Param) : InternalizeM Param := do
|
|||
modifyLCtx fun lctx => lctx.addParam p
|
||||
return p
|
||||
|
||||
def internalizeArg (arg : Arg) : InternalizeM Arg := do
|
||||
def internalizeArg (arg : Arg pu) : InternalizeM pu (Arg pu) := do
|
||||
match arg with
|
||||
| .fvar fvarId =>
|
||||
match (← get)[fvarId]? with
|
||||
| some arg'@(.fvar _) => return arg'
|
||||
| some arg'@.erased | some arg'@(.type _) => return arg'
|
||||
| some arg'@.erased | some arg'@(.type _ _) => return arg'
|
||||
| none => return arg
|
||||
| .type e => return arg.updateType! (← internalizeExpr e)
|
||||
| .type e _ => return arg.updateType! (← internalizeExpr e)
|
||||
| .erased => return arg
|
||||
|
||||
def internalizeArgs (args : Array Arg) : InternalizeM (Array Arg) :=
|
||||
def internalizeArgs (args : Array (Arg pu)) : InternalizeM pu (Array (Arg pu)) :=
|
||||
args.mapM internalizeArg
|
||||
|
||||
private partial def internalizeLetValue (e : LetValue) : InternalizeM LetValue := do
|
||||
private partial def internalizeLetValue (e : LetValue pu) : InternalizeM pu (LetValue pu) := do
|
||||
match e with
|
||||
| .erased | .lit .. => return e
|
||||
| .proj _ _ fvarId => match (← normFVar fvarId) with
|
||||
| .proj _ _ fvarId _ => match (← normFVar fvarId) with
|
||||
| .fvar fvarId' => return e.updateProj! fvarId'
|
||||
| .erased => return .erased
|
||||
| .const _ _ args => return e.updateArgs! (← internalizeArgs args)
|
||||
| .const _ _ args _ => return e.updateArgs! (← internalizeArgs args)
|
||||
| .fvar fvarId args => match (← normFVar fvarId) with
|
||||
| .fvar fvarId' => return e.updateFVar! fvarId' (← internalizeArgs args)
|
||||
| .erased => return .erased
|
||||
|
||||
def internalizeLetDecl (decl : LetDecl) : InternalizeM LetDecl := do
|
||||
def internalizeLetDecl (decl : LetDecl pu) : InternalizeM pu (LetDecl pu) := do
|
||||
let binderName ← refreshBinderName decl.binderName
|
||||
let type ← internalizeExpr decl.type
|
||||
let value ← internalizeLetValue decl.value
|
||||
|
|
@ -113,7 +114,7 @@ def internalizeLetDecl (decl : LetDecl) : InternalizeM LetDecl := do
|
|||
|
||||
mutual
|
||||
|
||||
partial def internalizeFunDecl (decl : FunDecl) : InternalizeM FunDecl := do
|
||||
partial def internalizeFunDecl (decl : FunDecl pu) : InternalizeM pu (FunDecl pu) := do
|
||||
let type ← internalizeExpr decl.type
|
||||
let binderName ← refreshBinderName decl.binderName
|
||||
let params ← decl.params.mapM internalizeParam
|
||||
|
|
@ -123,10 +124,10 @@ partial def internalizeFunDecl (decl : FunDecl) : InternalizeM FunDecl := do
|
|||
modifyLCtx fun lctx => lctx.addFunDecl decl
|
||||
return decl
|
||||
|
||||
partial def internalizeCode (code : Code) : InternalizeM Code := do
|
||||
partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do
|
||||
match code with
|
||||
| .let decl k => return .let (← internalizeLetDecl decl) (← internalizeCode k)
|
||||
| .fun decl k => return .fun (← internalizeFunDecl decl) (← internalizeCode k)
|
||||
| .fun decl k _ => return .fun (← internalizeFunDecl decl) (← internalizeCode k)
|
||||
| .jp decl k => return .jp (← internalizeFunDecl decl) (← internalizeCode k)
|
||||
| .return fvarId => withNormFVarResult (← normFVar fvarId) fun fvarId => return .return fvarId
|
||||
| .jmp fvarId args => withNormFVarResult (← normFVar fvarId) fun fvarId => return .jmp fvarId (← internalizeArgs args)
|
||||
|
|
@ -134,19 +135,19 @@ partial def internalizeCode (code : Code) : InternalizeM Code := do
|
|||
| .cases c =>
|
||||
withNormFVarResult (← normFVar c.discr) fun discr => do
|
||||
let resultType ← internalizeExpr c.resultType
|
||||
let internalizeAltCode (k : Code) : InternalizeM Code :=
|
||||
let internalizeAltCode (k : Code pu) : InternalizeM pu (Code pu) :=
|
||||
internalizeCode k
|
||||
let alts ← c.alts.mapM fun
|
||||
| .alt ctorName params k => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k)
|
||||
| .alt ctorName params k _ => return .alt ctorName (← params.mapM internalizeParam) (← internalizeAltCode k)
|
||||
| .default k => return .default (← internalizeAltCode k)
|
||||
return .cases ⟨c.typeName, resultType, discr, alts⟩
|
||||
|
||||
end
|
||||
|
||||
partial def internalizeCodeDecl (decl : CodeDecl) : InternalizeM CodeDecl := do
|
||||
partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl pu) := do
|
||||
match decl with
|
||||
| .let decl => return .let (← internalizeLetDecl decl)
|
||||
| .fun decl => return .fun (← internalizeFunDecl decl)
|
||||
| .fun decl _ => return .fun (← internalizeFunDecl decl)
|
||||
| .jp decl => return .jp (← internalizeFunDecl decl)
|
||||
|
||||
end Internalize
|
||||
|
|
@ -154,14 +155,14 @@ end Internalize
|
|||
/--
|
||||
Refresh free variables ids in `code`, and store their declarations in the local context.
|
||||
-/
|
||||
partial def Code.internalize (code : Code) (s : FVarSubst := {}) : CompilerM Code :=
|
||||
partial def Code.internalize (code : Code pu) (s : FVarSubst pu := {}) : CompilerM (Code pu) :=
|
||||
Internalize.internalizeCode code |>.run' s
|
||||
|
||||
open Internalize in
|
||||
def Decl.internalize (decl : Decl) (s : FVarSubst := {}): CompilerM Decl :=
|
||||
def Decl.internalize (decl : Decl pu) (s : FVarSubst pu := {}): CompilerM (Decl pu) :=
|
||||
go decl |>.run' s
|
||||
where
|
||||
go (decl : Decl) : InternalizeM Decl := do
|
||||
go (decl : Decl pu) : InternalizeM pu (Decl pu) := do
|
||||
let type ← internalizeExpr decl.type
|
||||
let params ← decl.params.mapM internalizeParam
|
||||
let value ← decl.value.mapCodeM internalizeCode
|
||||
|
|
@ -170,13 +171,13 @@ where
|
|||
/--
|
||||
Create a fresh local context and internalize the given decls.
|
||||
-/
|
||||
def cleanup (decl : Array Decl) : CompilerM (Array Decl) := do
|
||||
def cleanup (decl : Array (Decl pu)) : CompilerM (Array (Decl pu)) := do
|
||||
modify fun _ => {}
|
||||
decl.mapM fun decl => do
|
||||
modify fun s => { s with nextIdx := 1 }
|
||||
decl.internalize
|
||||
|
||||
def normalizeFVarIds (decl : Decl) : CoreM Decl := do
|
||||
def normalizeFVarIds (decl : Decl pu) : CoreM (Decl pu) := do
|
||||
let ngenSaved ← getNGen
|
||||
setNGen {}
|
||||
try
|
||||
|
|
|
|||
|
|
@ -92,13 +92,13 @@ private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
|
|||
/--
|
||||
Remove all join point candidates contained in `a`.
|
||||
-/
|
||||
private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
|
||||
private partial def removeCandidatesInArg (a : Arg .pure) : FindM Unit := do
|
||||
forFVarM eraseCandidate a
|
||||
|
||||
/--
|
||||
Remove all join point candidates contained in `a`.
|
||||
-/
|
||||
private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
|
||||
private partial def removeCandidatesInLetValue (e : LetValue .pure) : FindM Unit := do
|
||||
forFVarM eraseCandidate e
|
||||
|
||||
/--
|
||||
|
|
@ -117,7 +117,7 @@ private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
|
|||
{ targetInfo with associated := targetInfo.associated.insert src }
|
||||
|
||||
@[inline]
|
||||
private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α :=
|
||||
private def withFnBody (decl : FunDecl .pure) (x : FindM α) : FindM α :=
|
||||
withReader (fun ctx => {
|
||||
ctx with
|
||||
definitionDepth := ctx.definitionDepth + 1,
|
||||
|
|
@ -125,7 +125,7 @@ private def withFnBody (decl : FunDecl) (x : FindM α) : FindM α :=
|
|||
x
|
||||
|
||||
@[inline]
|
||||
private def withFnDefined (decl : FunDecl) (x : FindM α) : FindM α :=
|
||||
private def withFnDefined (decl : FunDecl .pure) (x : FindM α) : FindM α :=
|
||||
withReader (fun ctx => {
|
||||
ctx with
|
||||
scope := ctx.scope.insert decl.fvarId ctx.definitionDepth }) do
|
||||
|
|
@ -163,11 +163,11 @@ def test (b : Bool) (x y : Nat) : Nat :=
|
|||
this. This is because otherwise the calls to `myjp` in `f` and `g` would
|
||||
produce out of scope join point jumps.
|
||||
-/
|
||||
partial def find (decl : Decl) : CompilerM FindState := do
|
||||
partial def find (decl : Decl .pure) : CompilerM FindState := do
|
||||
let (_, candidates) ← decl.value.forCodeM go |>.run {} |>.run {}
|
||||
return candidates
|
||||
where
|
||||
go : Code → FindM Unit
|
||||
go : Code .pure → FindM Unit
|
||||
| .let decl k => do
|
||||
match k, decl.value with
|
||||
| .return valId, .fvar fvarId args =>
|
||||
|
|
@ -207,13 +207,13 @@ where
|
|||
Replace all join point candidate `fun` declarations with `jp` ones
|
||||
and all calls to them with `jmp`s.
|
||||
-/
|
||||
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
|
||||
partial def replace (decl : Decl .pure) (state : FindState) : CompilerM (Decl .pure) := do
|
||||
let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName)
|
||||
let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := ∅) mapper
|
||||
let newValue ← decl.value.mapCodeM go |>.run replaceCtx
|
||||
return { decl with value := newValue }
|
||||
where
|
||||
go (code : Code) : ReplaceM Code := do
|
||||
go (code : Code .pure) : ReplaceM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
match k, decl.value with
|
||||
|
|
@ -274,7 +274,7 @@ structure ExtendState where
|
|||
to `Param`s. The free variables in this map are the once that the context
|
||||
of said join point will be extended by passing in the respective parameter.
|
||||
-/
|
||||
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId Param) := {}
|
||||
fvarMap : Std.HashMap FVarId (Std.HashMap FVarId (Param .pure)) := {}
|
||||
|
||||
/--
|
||||
The monad for the `extendJoinPointContext` pass.
|
||||
|
|
@ -388,7 +388,7 @@ the join point. This is so in the case of nested join points that refer
|
|||
to parameters of the current one we extend the context of the nested
|
||||
join points by said parameters.
|
||||
-/
|
||||
def withNewJpScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
|
||||
def withNewJpScope (decl : FunDecl .pure) (x : ExtendM α): ExtendM α := do
|
||||
withReader (fun ctx => { ctx with currentJp? := some decl.fvarId }) do
|
||||
modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId {} }
|
||||
withNewScope do
|
||||
|
|
@ -401,7 +401,7 @@ It will back up the current scope (since we are doing a case split
|
|||
and want to continue with other arms afterwards) and add all of the
|
||||
parameters of the match arm to the list of candidates.
|
||||
-/
|
||||
def withNewAltScope (alt : Alt) (x : ExtendM α) : ExtendM α := do
|
||||
def withNewAltScope (alt : Alt .pure) (x : ExtendM α) : ExtendM α := do
|
||||
withBackTrackingScope do
|
||||
withNewCandidates (alt.getParams.map (·.fvarId)) do
|
||||
x
|
||||
|
|
@ -418,7 +418,7 @@ All of this is done to eliminate dependencies of join points onto their
|
|||
position within the code so we can pull them out as far as possible, hopefully
|
||||
enabling new inlining possibilities in the next simplifier run.
|
||||
-/
|
||||
partial def extend (decl : Decl) : CompilerM Decl := do
|
||||
partial def extend (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let newValue ← decl.value.mapCodeM go |>.run {} |>.run' {} |>.run' {}
|
||||
let decl := { decl with value := newValue }
|
||||
decl.pullFunDecls
|
||||
|
|
@ -426,7 +426,7 @@ where
|
|||
goFVar (fvar : FVarId) : ExtendM FVarId := do
|
||||
extendByIfNecessary fvar
|
||||
replaceFVar fvar
|
||||
go (code : Code) : ExtendM Code := do
|
||||
go (code : Code .pure) : ExtendM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let decl ← decl.updateValue (← mapFVarM goFVar decl.value)
|
||||
|
|
@ -491,7 +491,7 @@ structure AnalysisState where
|
|||
A map, that for each join point id contains a map from all (so far)
|
||||
duplicated argument ids to the respective duplicate value
|
||||
-/
|
||||
jpJmpArgs : FVarIdMap FVarSubst := {}
|
||||
jpJmpArgs : FVarIdMap (FVarSubst .pure) := {}
|
||||
|
||||
abbrev ReduceAnalysisM := ReaderT AnalysisCtx StateRefT AnalysisState ScopeM
|
||||
abbrev ReduceActionM := ReaderT AnalysisState CompilerM
|
||||
|
|
@ -539,17 +539,17 @@ After we have performed all of these optimizations we can take away the
|
|||
(remaining) common arguments and end up with nicely floated and optimized
|
||||
code that has as little arguments as possible in the join points.
|
||||
-/
|
||||
partial def reduce (decl : Decl) : CompilerM Decl := do
|
||||
partial def reduce (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let (_, analysis) ← decl.value.forCodeM goAnalyze |>.run {} |>.run {} |>.run' {}
|
||||
let newValue ← decl.value.mapCodeM goReduce |>.run analysis
|
||||
return { decl with value := newValue }
|
||||
where
|
||||
goAnalyzeFunDecl (fn : FunDecl) : ReduceAnalysisM Unit := do
|
||||
goAnalyzeFunDecl (fn : FunDecl .pure) : ReduceAnalysisM Unit := do
|
||||
withNewScope do
|
||||
fn.params.forM (addToScope ·.fvarId)
|
||||
goAnalyze fn.value
|
||||
|
||||
goAnalyze (code : Code) : ReduceAnalysisM Unit := do
|
||||
goAnalyze (code : Code .pure) : ReduceAnalysisM Unit := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
addToScope decl.fvarId
|
||||
|
|
@ -571,7 +571,7 @@ where
|
|||
goAnalyze alt.getCode
|
||||
cs.alts.forM visitor
|
||||
| .jmp fn args =>
|
||||
let decl ← getFunDecl fn
|
||||
let decl ← getFunDecl (pu := .pure) fn
|
||||
if let some knownArgs := (← get).jpJmpArgs.get? fn then
|
||||
let mut newArgs := knownArgs
|
||||
for (param, arg) in decl.params.zip args do
|
||||
|
|
@ -589,7 +589,7 @@ where
|
|||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn interestingArgs }
|
||||
| .return .. | .unreach .. => return ()
|
||||
|
||||
goReduce (code : Code) : ReduceActionM Code := do
|
||||
goReduce (code : Code .pure) : ReduceActionM (Code .pure) := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
if let some reducibleArgs := (← read).jpJmpArgs.get? decl.fvarId then
|
||||
|
|
@ -613,7 +613,7 @@ where
|
|||
return Code.updateFun! code decl (← goReduce k)
|
||||
| .jmp fn args =>
|
||||
let reducibleArgs := (← read).jpJmpArgs.get! fn
|
||||
let decl ← getFunDecl fn
|
||||
let decl ← getFunDecl (pu := .pure) fn
|
||||
let newParams := decl.params.zip args
|
||||
|>.filter (!reducibleArgs.contains ·.fst.fvarId)
|
||||
|>.map Prod.snd
|
||||
|
|
@ -630,7 +630,7 @@ where
|
|||
|
||||
end JoinPointCommonArgs
|
||||
|
||||
def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do
|
||||
def Decl.findJoinPoints? (decl : Decl .pure) : CompilerM (Option (Decl .pure)) := do
|
||||
let findResult ← JoinPointFinder.find decl
|
||||
trace[Compiler.findJoinPoints] "Found {findResult.candidates.size} jp candidates for {decl.name}"
|
||||
if findResult.candidates.isEmpty then
|
||||
|
|
@ -642,29 +642,32 @@ def Decl.findJoinPoints? (decl : Decl) : CompilerM (Option Decl) := do
|
|||
Find all `fun` declarations in `decl` that qualify as join points then replace
|
||||
their definitions and call sites with `jp`/`jmp`.
|
||||
-/
|
||||
def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.findJoinPoints (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
return (← Decl.findJoinPoints? decl).getD decl
|
||||
|
||||
def findJoinPoints (occurrence : Nat := 0) : Pass :=
|
||||
.mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base (occurrence := occurrence)
|
||||
.mkPerDeclaration `findJoinPoints .base Decl.findJoinPoints (occurrence := occurrence)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.findJoinPoints (inherited := true)
|
||||
|
||||
def Decl.extendJoinPointContext (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.extendJoinPointContext (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
JoinPointContextExtender.extend decl
|
||||
|
||||
-- TODO: It might make sense to extend this to impure one day
|
||||
def extendJoinPointContext (occurrence : Nat := 0) (phase := Phase.mono) (_h : phase ≠ .base := by simp): Pass :=
|
||||
.mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext phase (occurrence := occurrence)
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `extendJoinPointContext phase (h ▸ Decl.extendJoinPointContext) (occurrence := occurrence)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
|
||||
|
||||
def Decl.commonJoinPointArgs (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.commonJoinPointArgs (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
JoinPointCommonArgs.reduce decl
|
||||
|
||||
-- TODO: It might make sense to extend this to impure one day
|
||||
def commonJoinPointArgs : Pass :=
|
||||
.mkPerDeclaration `commonJoinPointArgs Decl.commonJoinPointArgs .mono
|
||||
.mkPerDeclaration `commonJoinPointArgs .mono Decl.commonJoinPointArgs
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.commonJoinPointArgs (inherited := true)
|
||||
|
|
|
|||
|
|
@ -16,61 +16,97 @@ namespace Lean.Compiler.LCNF
|
|||
LCNF local context.
|
||||
-/
|
||||
structure LCtx where
|
||||
params : Std.HashMap FVarId Param := {}
|
||||
letDecls : Std.HashMap FVarId LetDecl := {}
|
||||
funDecls : Std.HashMap FVarId FunDecl := {}
|
||||
paramsPure : Std.HashMap FVarId (Param .pure) := {}
|
||||
paramsImpure : Std.HashMap FVarId (Param .impure) := {}
|
||||
letDeclsPure : Std.HashMap FVarId (LetDecl .pure) := {}
|
||||
letDeclsImpure : Std.HashMap FVarId (LetDecl .impure) := {}
|
||||
funDeclsPure : Std.HashMap FVarId (FunDecl .pure) := {}
|
||||
funDeclsImpure : Std.HashMap FVarId (FunDecl .impure) := {}
|
||||
deriving Inhabited
|
||||
|
||||
def LCtx.addParam (lctx : LCtx) (param : Param) : LCtx :=
|
||||
{ lctx with params := lctx.params.insert param.fvarId param }
|
||||
def LCtx.addParam (lctx : LCtx) (param : Param pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with paramsPure := lctx.paramsPure.insert param.fvarId param }
|
||||
| .impure => { lctx with paramsImpure := lctx.paramsImpure.insert param.fvarId param }
|
||||
|
||||
def LCtx.addLetDecl (lctx : LCtx) (letDecl : LetDecl) : LCtx :=
|
||||
{ lctx with letDecls := lctx.letDecls.insert letDecl.fvarId letDecl }
|
||||
def LCtx.addLetDecl (lctx : LCtx) (letDecl : LetDecl pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with letDeclsPure := lctx.letDeclsPure.insert letDecl.fvarId letDecl }
|
||||
| .impure => { lctx with letDeclsImpure := lctx.letDeclsImpure.insert letDecl.fvarId letDecl }
|
||||
|
||||
def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl) : LCtx :=
|
||||
{ lctx with funDecls := lctx.funDecls.insert funDecl.fvarId funDecl }
|
||||
def LCtx.addFunDecl (lctx : LCtx) (funDecl : FunDecl pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with funDeclsPure := lctx.funDeclsPure.insert funDecl.fvarId funDecl }
|
||||
| .impure => { lctx with funDeclsImpure := lctx.funDeclsImpure.insert funDecl.fvarId funDecl }
|
||||
|
||||
def LCtx.eraseParam (lctx : LCtx) (param : Param) : LCtx :=
|
||||
{ lctx with params := lctx.params.erase param.fvarId }
|
||||
def LCtx.eraseParam (lctx : LCtx) (param : Param pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with paramsPure := lctx.paramsPure.erase param.fvarId }
|
||||
| .impure => { lctx with paramsImpure := lctx.paramsImpure.erase param.fvarId }
|
||||
|
||||
def LCtx.eraseParams (lctx : LCtx) (ps : Array Param) : LCtx :=
|
||||
{ lctx with params := ps.foldl (init := lctx.params) fun params p => params.erase p.fvarId }
|
||||
def LCtx.eraseParams (lctx : LCtx) (ps : Array (Param pu)) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with paramsPure := ps.foldl (init := lctx.paramsPure) fun params p => params.erase p.fvarId }
|
||||
| .impure => { lctx with paramsImpure := ps.foldl (init := lctx.paramsImpure) fun params p => params.erase p.fvarId }
|
||||
|
||||
def LCtx.eraseLetDecl (lctx : LCtx) (decl : LetDecl) : LCtx :=
|
||||
{ lctx with letDecls := lctx.letDecls.erase decl.fvarId }
|
||||
def LCtx.eraseLetDecl (lctx : LCtx) (decl : LetDecl pu) : LCtx :=
|
||||
match pu with
|
||||
| .pure => { lctx with letDeclsPure := lctx.letDeclsPure.erase decl.fvarId }
|
||||
| .impure => { lctx with letDeclsImpure := lctx.letDeclsImpure.erase decl.fvarId }
|
||||
|
||||
mutual
|
||||
partial def LCtx.eraseFunDecl (lctx : LCtx) (decl : FunDecl) (recursive := true) : LCtx :=
|
||||
let lctx := { lctx with funDecls := lctx.funDecls.erase decl.fvarId }
|
||||
partial def LCtx.eraseFunDecl (lctx : LCtx) (decl : FunDecl pu) (recursive := true) : LCtx :=
|
||||
let lctx :=
|
||||
match pu with
|
||||
| .pure => { lctx with funDeclsPure := lctx.funDeclsPure.erase decl.fvarId }
|
||||
| .impure => { lctx with funDeclsImpure := lctx.funDeclsImpure.erase decl.fvarId }
|
||||
if recursive then
|
||||
eraseCode decl.value <| eraseParams lctx decl.params
|
||||
else
|
||||
lctx
|
||||
|
||||
partial def LCtx.eraseAlts (alts : Array Alt) (lctx : LCtx) : LCtx :=
|
||||
partial def LCtx.eraseAlts (alts : Array (Alt pu)) (lctx : LCtx) : LCtx :=
|
||||
alts.foldl (init := lctx) fun lctx alt =>
|
||||
match alt with
|
||||
| .default k => eraseCode k lctx
|
||||
| .alt _ ps k => eraseCode k <| eraseParams lctx ps
|
||||
| .alt _ ps k _ => eraseCode k <| eraseParams lctx ps
|
||||
|
||||
partial def LCtx.eraseCode (code : Code) (lctx : LCtx) : LCtx :=
|
||||
partial def LCtx.eraseCode (code : Code pu) (lctx : LCtx) : LCtx :=
|
||||
match code with
|
||||
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
|
||||
| .jp decl k | .fun decl k => eraseCode k <| eraseFunDecl lctx decl
|
||||
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
|
||||
| .cases c => eraseAlts c.alts lctx
|
||||
| _ => lctx
|
||||
end
|
||||
|
||||
@[inline]
|
||||
def LCtx.params (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (Param pu) :=
|
||||
match pu with
|
||||
| .pure => lctx.paramsPure
|
||||
| .impure => lctx.paramsImpure
|
||||
|
||||
@[inline]
|
||||
def LCtx.letDecls (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (LetDecl pu) :=
|
||||
match pu with
|
||||
| .pure => lctx.letDeclsPure
|
||||
| .impure => lctx.letDeclsImpure
|
||||
|
||||
@[inline]
|
||||
def LCtx.funDecls (lctx : LCtx) (pu : Purity) : Std.HashMap FVarId (FunDecl pu) :=
|
||||
match pu with
|
||||
| .pure => lctx.funDeclsPure
|
||||
| .impure => lctx.funDeclsImpure
|
||||
|
||||
/--
|
||||
Convert a LCNF local context into a regular Lean local context.
|
||||
-/
|
||||
def LCtx.toLocalContext (lctx : LCtx) : LocalContext := Id.run do
|
||||
def LCtx.toLocalContext (lctx : LCtx) (pu : Purity) : LocalContext := Id.run do
|
||||
let mut result := {}
|
||||
for (_, param) in lctx.params.toArray do
|
||||
for (_, param) in lctx.params pu do
|
||||
result := result.addDecl (.cdecl 0 param.fvarId param.binderName param.type .default .default)
|
||||
for (_, decl) in lctx.letDecls.toArray do
|
||||
for (_, decl) in lctx.letDecls pu do
|
||||
result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value.toExpr true .default)
|
||||
for (_, decl) in lctx.funDecls.toArray do
|
||||
for (_, decl) in lctx.funDecls pu do
|
||||
result := result.addDecl (.cdecl 0 decl.fvarId decl.binderName decl.type .default .default)
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ structure Context where
|
|||
Declaration where lambda lifting is being applied.
|
||||
We use it to provide the "base name" for auxiliary declarations and the flag `safe`.
|
||||
-/
|
||||
mainDecl : Decl
|
||||
mainDecl : Decl .pure
|
||||
/--
|
||||
If true, the lambda-lifted functions inherit the inline attribute from `mainDecl`.
|
||||
We use this feature to implement `@[inline] instance ...` and `@[always_inline] instance ...`
|
||||
|
|
@ -51,7 +51,7 @@ structure State where
|
|||
/--
|
||||
New auxiliary declarations
|
||||
-/
|
||||
decls : Array Decl := #[]
|
||||
decls : Array (Decl .pure) := #[]
|
||||
/--
|
||||
Next index for generating auxiliary declaration name.
|
||||
-/
|
||||
|
|
@ -64,13 +64,13 @@ abbrev LiftM := ReaderT Context (StateRefT State (ScopeT CompilerM))
|
|||
Return `true` if the given declaration takes a local instance as a parameter.
|
||||
We lambda lift this kind of local function declaration before specialization.
|
||||
-/
|
||||
def hasInstParam (decl : FunDecl) : CompilerM Bool :=
|
||||
def hasInstParam (decl : FunDecl .pure) : CompilerM Bool :=
|
||||
decl.params.anyM fun param => return (← isArrowClass? param.type).isSome
|
||||
|
||||
/--
|
||||
Return `true` if the given declaration should be lambda lifted.
|
||||
-/
|
||||
def shouldLift (decl : FunDecl) : LiftM Bool := do
|
||||
def shouldLift (decl : FunDecl .pure) : LiftM Bool := do
|
||||
let minSize := (← read).minSize
|
||||
if decl.value.size < minSize then
|
||||
return false
|
||||
|
|
@ -85,7 +85,7 @@ partial def mkAuxDeclName : LiftM Name := do
|
|||
if (← getDecl? nameNew).isNone then return nameNew
|
||||
mkAuxDeclName
|
||||
|
||||
def replaceFunDecl (decl : FunDecl) (value : LetValue) : LiftM LetDecl := do
|
||||
def replaceFunDecl (decl : FunDecl .pure) (value : LetValue .pure) : LiftM (LetDecl .pure) := do
|
||||
/- We reuse `decl`s `fvarId` to avoid substitution -/
|
||||
let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value }
|
||||
modifyLCtx fun lctx => lctx.addLetDecl declNew
|
||||
|
|
@ -97,7 +97,7 @@ open Internalize in
|
|||
Create a new auxiliary declaration. The array `closure` contains all free variables
|
||||
occurring in `decl`.
|
||||
-/
|
||||
def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
|
||||
def mkAuxDecl (closure : Array (Param .pure)) (decl : FunDecl .pure) : LiftM (LetDecl .pure) := do
|
||||
let nameNew ← mkAuxDeclName
|
||||
let inlineAttr? ← if (← read).inheritInlineAttrs then pure (← read).mainDecl.inlineAttr? else pure none
|
||||
let auxDecl ← go nameNew (← read).mainDecl.safe inlineAttr? |>.run' {}
|
||||
|
|
@ -113,16 +113,16 @@ def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
|
|||
let value := .const auxDeclName us (closure.map (.fvar ·.fvarId))
|
||||
replaceFunDecl decl value
|
||||
where
|
||||
go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM Decl := do
|
||||
go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM .pure (Decl .pure):= do
|
||||
let params := (← closure.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam)
|
||||
let code ← internalizeCode decl.value
|
||||
let type ← code.inferType
|
||||
let type ← mkForallParams params type
|
||||
let value := .code code
|
||||
let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl }
|
||||
let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl .pure }
|
||||
return decl.setLevelParams
|
||||
|
||||
def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do
|
||||
def etaContractibleDecl? (decl : FunDecl .pure) : LiftM (Option (LetDecl .pure)) := do
|
||||
if !(← read).allowEtaContraction then return none
|
||||
let .let { fvarId := letVar, value := .const declName us args, .. } (.return retVar) := decl.value
|
||||
| return none
|
||||
|
|
@ -137,11 +137,11 @@ def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do
|
|||
replaceFunDecl decl value
|
||||
|
||||
mutual
|
||||
partial def visitFunDecl (funDecl : FunDecl) : LiftM FunDecl := do
|
||||
partial def visitFunDecl (funDecl : FunDecl .pure) : LiftM (FunDecl .pure) := do
|
||||
let value ← withParams funDecl.params <| visitCode funDecl.value
|
||||
funDecl.update' funDecl.type value
|
||||
|
||||
partial def visitCode (code : Code) : LiftM Code := do
|
||||
partial def visitCode (code : Code .pure) : LiftM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let k ← withFVar decl.fvarId <| visitCode k
|
||||
|
|
@ -174,14 +174,14 @@ mutual
|
|||
| .unreach .. | .jmp .. | .return .. => return code
|
||||
end
|
||||
|
||||
def main (decl : Decl) : LiftM Decl := do
|
||||
def main (decl : Decl .pure) : LiftM (Decl .pure) := do
|
||||
let value ← withParams decl.params <| decl.value.mapCodeM visitCode
|
||||
return { decl with value }
|
||||
|
||||
end LambdaLifting
|
||||
|
||||
partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (allowEtaContraction : Bool)
|
||||
(suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do
|
||||
partial def Decl.lambdaLifting (decl : Decl .pure) (liftInstParamOnly : Bool) (allowEtaContraction : Bool)
|
||||
(suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array (Decl .pure)) := do
|
||||
let ctx := {
|
||||
mainDecl := decl,
|
||||
liftInstParamOnly,
|
||||
|
|
|
|||
|
|
@ -105,45 +105,45 @@ open Lean.CollectLevelParams
|
|||
abbrev visitType (type : Expr) : Visitor :=
|
||||
visitExpr type
|
||||
|
||||
def visitArg (arg : Arg) : Visitor :=
|
||||
def visitArg (arg : Arg .pure) : Visitor :=
|
||||
match arg with
|
||||
| .erased | .fvar .. => id
|
||||
| .type e => visitType e
|
||||
| .type e _ => visitType e
|
||||
|
||||
def visitArgs (args : Array Arg) : Visitor :=
|
||||
def visitArgs (args : Array (Arg .pure)) : Visitor :=
|
||||
fun s => args.foldl (init := s) fun s arg => visitArg arg s
|
||||
|
||||
def visitLetValue (e : LetValue) : Visitor :=
|
||||
def visitLetValue (e : LetValue .pure) : Visitor :=
|
||||
match e with
|
||||
| .erased | .lit .. | .proj .. => id
|
||||
| .const _ us args => visitLevels us ∘ visitArgs args
|
||||
| .const _ us args _ => visitLevels us ∘ visitArgs args
|
||||
| .fvar _ args => visitArgs args
|
||||
|
||||
def visitParam (p : Param) : Visitor :=
|
||||
def visitParam (p : Param .pure) : Visitor :=
|
||||
visitType p.type
|
||||
|
||||
def visitParams (ps : Array Param) : Visitor :=
|
||||
def visitParams (ps : Array (Param .pure)) : Visitor :=
|
||||
fun s => ps.foldl (init := s) fun s p => visitParam p s
|
||||
|
||||
mutual
|
||||
partial def visitAlt (alt : Alt) : Visitor :=
|
||||
partial def visitAlt (alt : Alt .pure) : Visitor :=
|
||||
match alt with
|
||||
| .default k => visitCode k
|
||||
| .alt _ ps k => visitCode k ∘ visitParams ps
|
||||
| .alt _ ps k _ => visitCode k ∘ visitParams ps
|
||||
|
||||
partial def visitAlts (alts : Array Alt) : Visitor :=
|
||||
partial def visitAlts (alts : Array (Alt .pure)) : Visitor :=
|
||||
fun s => alts.foldl (init := s) fun s alt => visitAlt alt s
|
||||
|
||||
partial def visitCode : Code → Visitor
|
||||
partial def visitCode : Code .pure → Visitor
|
||||
| .let decl k => visitCode k ∘ visitLetValue decl.value ∘ visitType decl.type
|
||||
| .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type
|
||||
| .fun decl k _ | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type
|
||||
| .cases c => visitAlts c.alts ∘ visitType c.resultType
|
||||
| .unreach type => visitType type
|
||||
| .return _ => id
|
||||
| .jmp _ args => visitArgs args
|
||||
end
|
||||
|
||||
def visitDeclValue : DeclValue → Visitor
|
||||
def visitDeclValue : DeclValue .pure → Visitor
|
||||
| .code c => visitCode c
|
||||
| .extern .. => id
|
||||
|
||||
|
|
@ -156,7 +156,7 @@ open CollectLevelParams
|
|||
Collect universe level parameters collecting in the type, parameters, and value, and then
|
||||
set `decl.levelParams` with the resulting value.
|
||||
-/
|
||||
def Decl.setLevelParams (decl : Decl) : Decl :=
|
||||
def Decl.setLevelParams (decl : Decl .pure) : Decl .pure :=
|
||||
let levelParams := (visitDeclValue decl.value ∘ visitParams decl.params ∘ visitType decl.type) {} |>.params.toList
|
||||
{ decl with levelParams }
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import Lean.Meta.Match.MatcherInfo
|
|||
import Lean.Compiler.LCNF.SplitSCC
|
||||
public import Lean.Compiler.IR.Basic
|
||||
public import Lean.Compiler.LCNF.CompilerM
|
||||
|
||||
public section
|
||||
namespace Lean.Compiler.LCNF
|
||||
/--
|
||||
|
|
@ -50,7 +51,7 @@ A checkpoint in code generation to print all declarations in between
|
|||
compiler passes in order to ease debugging.
|
||||
The trace can be viewed with `set_option trace.Compiler.step true`.
|
||||
-/
|
||||
def checkpoint (stepName : Name) (decls : Array Decl) (shouldCheck : Bool) : CompilerM Unit := do
|
||||
def checkpoint (stepName : Name) (decls : Array (Decl pu)) (shouldCheck : Bool) : CompilerM Unit := do
|
||||
for decl in decls do
|
||||
trace[Compiler.stat] "{decl.name} : {decl.size}"
|
||||
withOptions (fun opts => opts.set `pp.motives.pi false) do
|
||||
|
|
@ -101,12 +102,12 @@ def run (declNames : Array Name) : CompilerM (Array (Array IR.Decl)) := withAtLe
|
|||
let decls := markRecDecls decls
|
||||
let manager ← getPassManager
|
||||
let isCheckEnabled := compiler.check.get (← getOptions)
|
||||
let decls ← runPassManagerPart "compilation (LCNF base)" manager.basePasses decls isCheckEnabled
|
||||
let decls ← runPassManagerPart "compilation (LCNF mono)" manager.monoPasses decls isCheckEnabled
|
||||
let decls ← runPassManagerPart .pure .pure "compilation (LCNF base)" manager.basePasses decls isCheckEnabled
|
||||
let decls ← runPassManagerPart .pure .pure "compilation (LCNF mono)" manager.monoPasses decls isCheckEnabled
|
||||
let sccs ← withTraceNode `Compiler.splitSCC (fun _ => return m!"Splitting up SCC") do
|
||||
splitScc decls
|
||||
sccs.mapM fun decls => do
|
||||
let decls ← runPassManagerPart "compilation (LCNF mono)" manager.monoPassesNoLambda decls isCheckEnabled
|
||||
let decls ← runPassManagerPart .pure .pure "compilation (LCNF mono)" manager.monoPassesNoLambda decls isCheckEnabled
|
||||
if (← Lean.isTracingEnabledFor `Compiler.result) then
|
||||
for decl in decls do
|
||||
let decl ← normalizeFVarIds decl
|
||||
|
|
@ -115,14 +116,19 @@ def run (declNames : Array Name) : CompilerM (Array (Array IR.Decl)) := withAtLe
|
|||
let irDecls ← IR.toIR decls
|
||||
IR.compile irDecls
|
||||
where
|
||||
runPassManagerPart (profilerName : String) (passes : Array Pass) (decls : Array Decl)
|
||||
(isCheckEnabled : Bool) : CompilerM (Array Decl) := do
|
||||
runPassManagerPart (inPhase outPhase : Purity) (profilerName : String)
|
||||
(passes : Array Pass) (decls : Array (Decl inPhase)) (isCheckEnabled : Bool) :
|
||||
CompilerM (Array (Decl outPhase)) := do
|
||||
profileitM Exception profilerName (← getOptions) do
|
||||
let mut decls := decls
|
||||
let mut state : (pu : Purity) × Array (Decl pu) := ⟨inPhase, decls⟩
|
||||
for pass in passes do
|
||||
decls ← withTraceNode `Compiler (fun _ => return m!"compiler phase: {pass.phase}, pass: {pass.name}") do
|
||||
withPhase pass.phase <| pass.run decls
|
||||
withPhase pass.phaseOut <| checkpoint pass.name decls (isCheckEnabled || pass.shouldAlwaysRunCheck)
|
||||
state ← withTraceNode `Compiler (fun _ => return m!"compiler phase: {pass.phase}, pass: {pass.name}") do
|
||||
let decls ← withPhase pass.phase do
|
||||
state.fst.withAssertPurity pass.phase.toPurity fun h => do
|
||||
pass.run (h ▸ state.snd)
|
||||
pure ⟨_, decls⟩
|
||||
withPhase pass.phaseOut <| checkpoint pass.name state.snd (isCheckEnabled || pass.shouldAlwaysRunCheck)
|
||||
let decls := state.fst.withAssertPurity outPhase fun h => h ▸ state.snd
|
||||
return decls
|
||||
|
||||
end PassManager
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ instance (m n) [MonadLift m n] [MonadFunctor m n] [MonadScope m] : MonadScope n
|
|||
def inScope [MonadScope m] [Monad m] (fvarId : FVarId) : m Bool :=
|
||||
return (← getScope).contains fvarId
|
||||
|
||||
@[inline] def withParams [MonadScope m] [Monad m] (ps : Array Param) (x : m α) : m α :=
|
||||
@[inline] def withParams [MonadScope m] [Monad m] (ps : Array (Param pu)) (x : m α) : m α :=
|
||||
withScope (fun s => ps.foldl (init := s) fun s p => s.insert p.fvarId) x
|
||||
|
||||
@[inline] def withFVar [MonadScope m] [Monad m] (fvarId : FVarId) (x : m α) : m α :=
|
||||
|
|
|
|||
|
|
@ -99,4 +99,7 @@ def getOtherDeclMonoType (declName : Name) : CoreM Expr := do
|
|||
monoTypeExt.insert declName type
|
||||
return type
|
||||
|
||||
def getOtherDeclImpureType (_declName : Name) : CoreM Expr := do
|
||||
panic! "Other decl impure type unimplemented" -- TODO
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -19,5 +19,6 @@ def getOtherDeclType (declName : Name) (us : List Level := []) : CompilerM Expr
|
|||
match (← getPhase) with
|
||||
| .base => getOtherDeclBaseType declName us
|
||||
| .mono => getOtherDeclMonoType declName
|
||||
| .impure => getOtherDeclImpureType declName
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -15,6 +15,20 @@ namespace Lean.Compiler.LCNF
|
|||
@[expose] def Phase.toNat : Phase → Nat
|
||||
| .base => 0
|
||||
| .mono => 1
|
||||
| .impure => 2
|
||||
|
||||
instance : ToString Phase where
|
||||
toString
|
||||
| .base => "base"
|
||||
| .mono => "mono"
|
||||
| .impure => "impure"
|
||||
|
||||
def Phase.withPurityCheck [Inhabited α] (pp : Phase) (ip : Purity)
|
||||
(x : pp.toPurity = ip → α) : α :=
|
||||
if h : pp.toPurity = ip then
|
||||
x h
|
||||
else
|
||||
panic! s!"Compiler error: {pp} is not equivalent to IR phase {ip}, this is a bug"
|
||||
|
||||
instance : LT Phase where
|
||||
lt l r := l.toNat < r.toNat
|
||||
|
|
@ -60,7 +74,7 @@ structure Pass where
|
|||
/--
|
||||
The actual pass function, operating on the `Decl`s.
|
||||
-/
|
||||
run : Array Decl → CompilerM (Array Decl)
|
||||
run : Array (Decl phase.toPurity) → CompilerM (Array (Decl phase.toPurity))
|
||||
|
||||
instance : Inhabited Pass where
|
||||
default := { phase := .base, name := default, run := fun decls => return decls }
|
||||
|
|
@ -90,14 +104,10 @@ structure PassManager where
|
|||
monoPassesNoLambda : Array Pass
|
||||
deriving Inhabited
|
||||
|
||||
instance : ToString Phase where
|
||||
toString
|
||||
| .base => "base"
|
||||
| .mono => "mono"
|
||||
|
||||
namespace Pass
|
||||
|
||||
def mkPerDeclaration (name : Name) (run : Decl → CompilerM Decl) (phase : Phase) (occurrence : Nat := 0) : Pass where
|
||||
def mkPerDeclaration (name : Name) (phase : Phase)
|
||||
(run : Decl phase.toPurity → CompilerM (Decl phase.toPurity)) (occurrence : Nat := 0) : Pass where
|
||||
occurrence := occurrence
|
||||
phase := phase
|
||||
name := name
|
||||
|
|
@ -190,6 +200,7 @@ def run (manager : PassManager) (installer : PassInstaller) : CoreM PassManager
|
|||
return { manager with basePasses := (← installer.install manager.basePasses) }
|
||||
| .mono =>
|
||||
return { manager with monoPasses := (← installer.install manager.monoPasses) }
|
||||
| .impure => panic! "Pass manager support for impure unimplemented" -- TODO
|
||||
|
||||
private unsafe def getPassInstallerUnsafe (declName : Name) : CoreM PassInstaller := do
|
||||
ofExcept <| (← getEnv).evalConstCheck PassInstaller (← getOptions) ``PassInstaller declName
|
||||
|
|
|
|||
|
|
@ -45,10 +45,15 @@ private builtin_initialize baseTransparentDeclsExt : EnvExtension (List Name ×
|
|||
Set of public declarations whose mono bodies should be exported to other modules
|
||||
-/
|
||||
private builtin_initialize monoTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt
|
||||
/--
|
||||
Set of public declarations whose impure bodies should be exported to other modules
|
||||
-/
|
||||
private builtin_initialize impureTransparentDeclsExt : EnvExtension (List Name × NameSet) ← mkDeclSetExt
|
||||
|
||||
private def getTransparencyExt : Phase → EnvExtension (List Name × NameSet)
|
||||
| .base => baseTransparentDeclsExt
|
||||
| .mono => monoTransparentDeclsExt
|
||||
| .impure => impureTransparentDeclsExt
|
||||
|
||||
def isDeclPublic (env : Environment) (declName : Name) : Bool := Id.run do
|
||||
if !env.header.isModule then
|
||||
|
|
@ -81,26 +86,28 @@ def setDeclTransparent (env : Environment) (phase : Phase) (declName : Name) : E
|
|||
getTransparencyExt phase |>.modifyState env fun s =>
|
||||
(declName :: s.1, s.2.insert declName)
|
||||
|
||||
abbrev DeclExtState := PHashMap Name Decl
|
||||
abbrev DeclExtState (pu : Purity) := PHashMap Name (Decl pu)
|
||||
|
||||
private abbrev declLt (a b : Decl) :=
|
||||
private abbrev declLt (a b : Decl pu) :=
|
||||
Name.quickLt a.name b.name
|
||||
|
||||
private def sortedDecls (s : DeclExtState) : Array Decl :=
|
||||
private def sortedDecls (s : DeclExtState pu) : Array (Decl pu) :=
|
||||
let decls := s.foldl (init := #[]) fun ps _ v => ps.push v
|
||||
decls.qsort declLt
|
||||
|
||||
private abbrev findAtSorted? (decls : Array Decl) (declName : Name) : Option Decl :=
|
||||
let tmpDecl : Decl := default
|
||||
private abbrev findAtSorted? (decls : Array (Decl pu)) (declName : Name) : Option (Decl pu) :=
|
||||
let tmpDecl : Decl pu := default
|
||||
let tmpDecl := { tmpDecl with name := declName }
|
||||
decls.binSearch tmpDecl declLt
|
||||
|
||||
@[expose] def DeclExt := PersistentEnvExtension Decl Decl DeclExtState
|
||||
@[expose] def DeclExt (pu : Purity) :=
|
||||
PersistentEnvExtension (Decl pu) (Decl pu) (DeclExtState pu)
|
||||
|
||||
instance : Inhabited DeclExt :=
|
||||
inferInstanceAs (Inhabited (PersistentEnvExtension Decl Decl DeclExtState))
|
||||
instance : Inhabited (DeclExt pu) :=
|
||||
inferInstanceAs (Inhabited (PersistentEnvExtension (Decl pu) (Decl pu) (DeclExtState pu)))
|
||||
|
||||
def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : IO DeclExt :=
|
||||
def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) :
|
||||
IO (DeclExt phase.toPurity) :=
|
||||
registerPersistentEnvExtension {
|
||||
name,
|
||||
mkInitial := pure {},
|
||||
|
|
@ -128,74 +135,77 @@ def mkDeclExt (phase : Phase) (name : Name := by exact decl_name%) : IO DeclExt
|
|||
otherState.insert k v
|
||||
}
|
||||
|
||||
builtin_initialize baseExt : DeclExt ← mkDeclExt .base
|
||||
builtin_initialize monoExt : DeclExt ← mkDeclExt .mono
|
||||
builtin_initialize baseExt : DeclExt .pure ← mkDeclExt .base
|
||||
builtin_initialize monoExt : DeclExt .pure ← mkDeclExt .mono
|
||||
builtin_initialize impureExt : DeclExt .impure ← mkDeclExt .impure
|
||||
|
||||
def getDeclCore? (env : Environment) (ext : DeclExt) (declName : Name) : Option Decl :=
|
||||
def getDeclCore? (env : Environment) (ext : DeclExt pu) (declName : Name) : Option (Decl pu) :=
|
||||
match env.getModuleIdxFor? declName with
|
||||
| some modIdx => findAtSorted? (ext.getModuleEntries env modIdx) declName
|
||||
| none => ext.getState env |>.find? declName
|
||||
|
||||
def getBaseDecl? (declName : Name) : CoreM (Option Decl) := do
|
||||
def getBaseDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do
|
||||
return getDeclCore? (← getEnv) baseExt declName
|
||||
|
||||
def getMonoDecl? (declName : Name) : CoreM (Option Decl) := do
|
||||
def getMonoDecl? (declName : Name) : CoreM (Option (Decl .pure)) := do
|
||||
return getDeclCore? (← getEnv) monoExt declName
|
||||
|
||||
def saveBaseDeclCore (env : Environment) (decl : Decl) : Environment :=
|
||||
def getImpureDecl? (declName : Name) : CoreM (Option (Decl .impure)) := do
|
||||
return getDeclCore? (← getEnv) impureExt declName
|
||||
|
||||
def saveBaseDeclCore (env : Environment) (decl : Decl .pure) : Environment :=
|
||||
baseExt.addEntry env decl
|
||||
|
||||
def saveMonoDeclCore (env : Environment) (decl : Decl) : Environment :=
|
||||
def saveMonoDeclCore (env : Environment) (decl : Decl .pure) : Environment :=
|
||||
monoExt.addEntry env decl
|
||||
|
||||
def Decl.saveBase (decl : Decl) : CoreM Unit :=
|
||||
def saveImpureDeclCore (env : Environment) (decl : Decl .impure) : Environment :=
|
||||
impureExt.addEntry env decl
|
||||
|
||||
def Decl.saveBase (decl : Decl .pure) : CoreM Unit :=
|
||||
modifyEnv (saveBaseDeclCore · decl)
|
||||
|
||||
def Decl.saveMono (decl : Decl) : CoreM Unit :=
|
||||
def Decl.saveMono (decl : Decl .pure) : CoreM Unit :=
|
||||
modifyEnv (saveMonoDeclCore · decl)
|
||||
|
||||
def Decl.save (decl : Decl) : CompilerM Unit := do
|
||||
match (← getPhase) with
|
||||
| .base => decl.saveBase
|
||||
| .mono => decl.saveMono
|
||||
def Decl.saveImpure (decl : Decl .impure) : CoreM Unit :=
|
||||
modifyEnv (saveImpureDeclCore · decl)
|
||||
|
||||
def getDeclAt? (declName : Name) (phase : Phase) : CoreM (Option Decl) :=
|
||||
def Decl.save (decl : Decl pu) : CompilerM Unit := do
|
||||
match (← getPhase) with
|
||||
| .base => Phase.withPurityCheck .base pu fun h =>
|
||||
(h.symm ▸ decl).saveBase
|
||||
| .mono => Phase.withPurityCheck .mono pu fun h =>
|
||||
(h.symm ▸ decl).saveMono
|
||||
| .impure => Phase.withPurityCheck .impure pu fun h =>
|
||||
(h.symm ▸ decl).saveImpure
|
||||
|
||||
def getDeclAt? (declName : Name) (phase : Phase) : CoreM (Option (Decl phase.toPurity)) :=
|
||||
match phase with
|
||||
| .base => getBaseDecl? declName
|
||||
| .mono => getMonoDecl? declName
|
||||
| .impure => getImpureDecl? declName
|
||||
|
||||
def getDecl? (declName : Name) : CompilerM (Option Decl) := do
|
||||
getDeclAt? declName (← getPhase)
|
||||
@[inline]
|
||||
def getDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl pu)) := do
|
||||
let some decl ← getDeclAt? declName (← getPhase) | return none
|
||||
return some ⟨_, decl⟩
|
||||
|
||||
def getLocalDeclAt? (declName : Name) (phase : Phase) : CompilerM (Option Decl) := do
|
||||
def getLocalDeclAt? (declName : Name) (phase : Phase) : CompilerM (Option (Decl phase.toPurity)) := do
|
||||
match phase with
|
||||
| .base => return baseExt.getState (← getEnv) |>.find? declName
|
||||
| .mono => return monoExt.getState (← getEnv) |>.find? declName
|
||||
| .impure => return impureExt.getState (← getEnv) |>.find? declName
|
||||
|
||||
def getLocalDecl? (declName : Name) : CompilerM (Option Decl) := do
|
||||
getLocalDeclAt? declName (← getPhase)
|
||||
@[inline]
|
||||
def getLocalDecl? (declName : Name) : CompilerM (Option ((pu : Purity) × Decl pu)) := do
|
||||
let some decl ← getLocalDeclAt? declName (← getPhase) | return none
|
||||
return some ⟨_, decl⟩
|
||||
|
||||
def getExt (phase : Phase) : DeclExt :=
|
||||
def getExt (phase : Phase) : DeclExt phase.toPurity :=
|
||||
match phase with
|
||||
| .base => baseExt
|
||||
| .mono => monoExt
|
||||
|
||||
def forEachDecl (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
for modIdx in *...env.allImportedModuleNames.size do
|
||||
for decl in ext.getModuleEntries env modIdx do
|
||||
f decl
|
||||
ext.getState env |>.forM fun _ decl => f decl
|
||||
|
||||
def forEachModuleDecl (moduleName : Name) (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found"
|
||||
for decl in ext.getModuleEntries env modIdx do
|
||||
f decl
|
||||
|
||||
def forEachMainModuleDecl (f : Decl → CoreM Unit) (phase := Phase.base) : CoreM Unit := do
|
||||
(getExt phase).getState (← getEnv) |>.forM fun _ decl => f decl
|
||||
| .impure => impureExt
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -43,11 +43,11 @@ def ppFVar (fvarId : FVarId) : M Format :=
|
|||
def ppExpr (e : Expr) : M Format := do
|
||||
Meta.ppExpr e |>.run' { lctx := (← read) }
|
||||
|
||||
def ppArg (e : Arg) : M Format := do
|
||||
def ppArg (e : Arg pu) : M Format := do
|
||||
match e with
|
||||
| .erased => return "◾"
|
||||
| .fvar fvarId => ppFVar fvarId
|
||||
| .type e =>
|
||||
| .type e _ =>
|
||||
if pp.explicit.get (← getOptions) then
|
||||
if e.isConst || e.isProp || e.isType0 || e.isFVar then
|
||||
ppExpr e
|
||||
|
|
@ -56,7 +56,7 @@ def ppArg (e : Arg) : M Format := do
|
|||
else
|
||||
return "_"
|
||||
|
||||
def ppArgs (args : Array Arg) : M Format := do
|
||||
def ppArgs (args : Array (Arg pu)) : M Format := do
|
||||
prefixJoin " " args ppArg
|
||||
|
||||
def ppLitValue (lit : LitValue) : M Format := do
|
||||
|
|
@ -64,49 +64,49 @@ def ppLitValue (lit : LitValue) : M Format := do
|
|||
| .nat v | .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return format v
|
||||
| .str v => return format (repr v)
|
||||
|
||||
def ppLetValue (e : LetValue) : M Format := do
|
||||
def ppLetValue (e : LetValue pu) : M Format := do
|
||||
match e with
|
||||
| .erased => return "◾"
|
||||
| .lit v => ppLitValue v
|
||||
| .proj _ i fvarId => return f!"{← ppFVar fvarId} # {i}"
|
||||
| .proj _ i fvarId _ => return f!"{← ppFVar fvarId} # {i}"
|
||||
| .fvar fvarId args => return f!"{← ppFVar fvarId}{← ppArgs args}"
|
||||
| .const declName us args => return f!"{← ppExpr (.const declName us)}{← ppArgs args}"
|
||||
| .const declName us args _ => return f!"{← ppExpr (.const declName us)}{← ppArgs args}"
|
||||
|
||||
def ppParam (param : Param) : M Format := do
|
||||
def ppParam (param : Param pu) : M Format := do
|
||||
let borrow := if param.borrow then "@&" else ""
|
||||
if pp.funBinderTypes.get (← getOptions) then
|
||||
return Format.paren f!"{param.binderName} : {borrow}{← ppExpr param.type}"
|
||||
else
|
||||
return format s!"{borrow}{param.binderName}"
|
||||
|
||||
def ppParams (params : Array Param) : M Format := do
|
||||
def ppParams (params : Array (Param pu)) : M Format := do
|
||||
prefixJoin " " params ppParam
|
||||
|
||||
def ppLetDecl (letDecl : LetDecl) : M Format := do
|
||||
def ppLetDecl (letDecl : LetDecl pu) : M Format := do
|
||||
if pp.letVarTypes.get (← getOptions) then
|
||||
return f!"let {letDecl.binderName} : {← ppExpr letDecl.type} := {← ppLetValue letDecl.value}"
|
||||
else
|
||||
return f!"let {letDecl.binderName} := {← ppLetValue letDecl.value}"
|
||||
|
||||
def getFunType (ps : Array Param) (type : Expr) : CoreM Expr :=
|
||||
def getFunType (ps : Array (Param pu)) (type : Expr) : CoreM Expr :=
|
||||
if type.isErased then
|
||||
pure type
|
||||
else
|
||||
instantiateForall type (ps.map (mkFVar ·.fvarId))
|
||||
|
||||
mutual
|
||||
partial def ppFunDecl (funDecl : FunDecl) : M Format := do
|
||||
partial def ppFunDecl (funDecl : FunDecl pu) : M Format := do
|
||||
return f!"{funDecl.binderName}{← ppParams funDecl.params} : {← ppExpr (← getFunType funDecl.params funDecl.type)} :={indentD (← ppCode funDecl.value)}"
|
||||
|
||||
partial def ppAlt (alt : Alt) : M Format := do
|
||||
partial def ppAlt (alt : Alt pu) : M Format := do
|
||||
match alt with
|
||||
| .default k => return f!"| _ =>{indentD (← ppCode k)}"
|
||||
| .alt ctorName params k => return f!"| {ctorName}{← ppParams params} =>{indentD (← ppCode k)}"
|
||||
| .alt ctorName params k _ => return f!"| {ctorName}{← ppParams params} =>{indentD (← ppCode k)}"
|
||||
|
||||
partial def ppCode (c : Code) : M Format := do
|
||||
partial def ppCode (c : Code pu) : M Format := do
|
||||
match c with
|
||||
| .let decl k => return (← ppLetDecl decl) ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .fun decl k => return f!"fun " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .fun decl k _ => return f!"fun " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .jp decl k => return f!"jp " ++ (← ppFunDecl decl) ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .cases c => return f!"cases {← ppFVar c.discr} : {← ppExpr c.resultType}{← prefixJoin .line c.alts ppAlt}"
|
||||
| .return fvarId => return f!"return {← ppFVar fvarId}"
|
||||
|
|
@ -117,7 +117,7 @@ mutual
|
|||
else
|
||||
return "⊥"
|
||||
|
||||
partial def ppDeclValue (b : DeclValue) : M Format := do
|
||||
partial def ppDeclValue (b : DeclValue pu) : M Format := do
|
||||
match b with
|
||||
| .code c => ppCode c
|
||||
| .extern .. => return "extern"
|
||||
|
|
@ -125,21 +125,21 @@ end
|
|||
|
||||
def run (x : M α) : CompilerM α :=
|
||||
withOptions (pp.sanitizeNames.set · false) do
|
||||
x |>.run (← get).lctx.toLocalContext
|
||||
x |>.run ((← get).lctx.toLocalContext (← getPurity))
|
||||
|
||||
end PP
|
||||
|
||||
def ppCode (code : Code) : CompilerM Format :=
|
||||
def ppCode (code : Code pu) : CompilerM Format :=
|
||||
PP.run <| PP.ppCode code
|
||||
|
||||
def ppLetValue (e : LetValue) : CompilerM Format :=
|
||||
def ppLetValue (e : LetValue pu) : CompilerM Format :=
|
||||
PP.run <| PP.ppLetValue e
|
||||
|
||||
def ppDecl (decl : Decl) : CompilerM Format :=
|
||||
def ppDecl (decl : Decl pu) : CompilerM Format :=
|
||||
PP.run do
|
||||
return f!"def {decl.name}{← PP.ppParams decl.params} : {← PP.ppExpr (← PP.getFunType decl.params decl.type)} :={indentD (← PP.ppDeclValue decl.value)}"
|
||||
|
||||
def ppFunDecl (decl : FunDecl) : CompilerM Format :=
|
||||
def ppFunDecl (decl : FunDecl pu) : CompilerM Format :=
|
||||
PP.run do
|
||||
return f!"fun {← PP.ppFunDecl decl}"
|
||||
|
||||
|
|
@ -159,7 +159,7 @@ Similar to `ppDecl`, but in `CoreM`, and it does not assume
|
|||
`decl` has already been internalized.
|
||||
This function is used for debugging purposes.
|
||||
-/
|
||||
def ppDecl' (decl : Decl) : CoreM Format := do
|
||||
def ppDecl' (decl : Decl pu) : CoreM Format := do
|
||||
runCompilerWithoutModifyingState do
|
||||
ppDecl (← decl.internalize)
|
||||
|
||||
|
|
@ -167,7 +167,7 @@ def ppDecl' (decl : Decl) : CoreM Format := do
|
|||
Similar to `ppCode`, but in `CoreM`, and it does not assume
|
||||
`code` has already been internalized.
|
||||
-/
|
||||
def ppCode' (code : Code) : CoreM Format := do
|
||||
def ppCode' (code : Code pu) : CoreM Format := do
|
||||
runCompilerWithoutModifyingState do
|
||||
ppCode (← code.internalize)
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ def filter (f : α → CompilerM Bool) : Probe α α := fun data => data.filterM
|
|||
def sorted [Inhabited α] [LT α] [DecidableLT α] : Probe α α := fun data => return data.qsort (· < ·)
|
||||
|
||||
@[inline]
|
||||
def sortedBySize : Probe Decl (Nat × Decl) := fun decls =>
|
||||
def sortedBySize (pu : Purity) : Probe (Decl pu) (Nat × Decl pu) := fun decls =>
|
||||
let decls := decls.map fun decl => (decl.size, decl)
|
||||
return decls.qsort fun (sz₁, decl₁) (sz₂, decl₂) =>
|
||||
if sz₁ == sz₂ then Name.lt decl₁.name decl₂.name else sz₁ < sz₂
|
||||
|
|
@ -44,116 +44,118 @@ def countUnique [ToString α] [BEq α] [Hashable α] : Probe α (α × Nat) := f
|
|||
def countUniqueSorted [ToString α] [BEq α] [Hashable α] [Inhabited α] : Probe α (α × Nat) :=
|
||||
countUnique >=> fun data => return data.qsort (fun l r => l.snd < r.snd)
|
||||
|
||||
partial def getLetValues : Probe Decl LetValue := fun decls => do
|
||||
partial def getLetValues (pu : Purity) : Probe (Decl pu) (LetValue pu) := fun decls => do
|
||||
let (_, res) ← start decls |>.run #[]
|
||||
return res
|
||||
where
|
||||
go (c : Code) : StateRefT (Array LetValue) CompilerM Unit := do
|
||||
go (c : Code pu) : StateRefT (Array (LetValue pu)) CompilerM Unit := do
|
||||
match c with
|
||||
| .let (decl : LetDecl) (k : Code) =>
|
||||
| .let decl k =>
|
||||
modify fun s => s.push decl.value
|
||||
go k
|
||||
| .fun decl k | .jp decl k =>
|
||||
| .fun decl k _ | .jp decl k =>
|
||||
go decl.value
|
||||
go k
|
||||
| .cases cs => cs.alts.forM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return ()
|
||||
start (decls : Array Decl) : StateRefT (Array LetValue) CompilerM Unit :=
|
||||
start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit :=
|
||||
decls.forM (·.value.forCodeM go)
|
||||
|
||||
partial def getJps : Probe Decl FunDecl := fun decls => do
|
||||
partial def getJps (pu : Purity) : Probe (Decl pu) (FunDecl pu) := fun decls => do
|
||||
let (_, res) ← start decls |>.run #[]
|
||||
return res
|
||||
where
|
||||
go (code : Code) : StateRefT (Array FunDecl) CompilerM Unit := do
|
||||
go (code : Code pu) : StateRefT (Array (FunDecl pu)) CompilerM Unit := do
|
||||
match code with
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value; go k
|
||||
| .fun decl k _ => go decl.value; go k
|
||||
| .jp decl k => modify (·.push decl); go decl.value; go k
|
||||
| .cases cs => cs.alts.forM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return ()
|
||||
|
||||
start (decls : Array Decl) : StateRefT (Array FunDecl) CompilerM Unit :=
|
||||
start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit :=
|
||||
decls.forM (·.value.forCodeM go)
|
||||
|
||||
partial def filterByLet (f : LetDecl → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByLet (pu : Purity) (f : LetDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let decl k => do if (← f decl) then return true else go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByFun (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByFun (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k | .jp _ k => go k
|
||||
| .fun decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .fun decl k _ => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByJp (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByJp (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value <||> go k
|
||||
| .fun decl k _ => go decl.value <||> go k
|
||||
| .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByFunDecl (f : FunDecl → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByFunDecl (pu : Purity) (f : FunDecl pu → CompilerM Bool) :
|
||||
Probe (Decl pu) (Decl pu):=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByCases (f : Cases → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByCases (pu : Purity) (f : Cases pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByJmp (f : FVarId → Array Arg → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByJmp (pu : Purity) (f : FVarId → Array (Arg pu) → CompilerM Bool) :
|
||||
Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp fn var => f fn var
|
||||
| .return .. | .unreach .. => return false
|
||||
|
||||
partial def filterByReturn (f : FVarId → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByReturn (pu : Purity) (f : FVarId → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .unreach .. => return false
|
||||
| .return var => f var
|
||||
|
||||
partial def filterByUnreach (f : Expr → CompilerM Bool) : Probe Decl Decl :=
|
||||
partial def filterByUnreach (pu : Purity) (f : Expr → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
where
|
||||
go : Code → CompilerM Bool
|
||||
go : Code pu → CompilerM Bool
|
||||
| .let _ k => go k
|
||||
| .fun decl k | .jp decl k => go decl.value <||> go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. => return false
|
||||
| .unreach typ => f typ
|
||||
|
||||
@[inline]
|
||||
def declNames : Probe Decl Name :=
|
||||
def declNames (pu : Purity) : Probe (Decl pu) Name :=
|
||||
Probe.map (fun decl => return decl.name)
|
||||
|
||||
@[inline]
|
||||
|
|
@ -172,7 +174,8 @@ def tail (n : Nat) : Probe α α := fun data => return data[(data.size - n)...*]
|
|||
@[inline]
|
||||
def head (n : Nat) : Probe α α := fun data => return data[*...n]
|
||||
|
||||
def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
|
||||
def runOnDeclsNamed (declNames : Array Name) (phase : Phase := Phase.base)
|
||||
(probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
let decls ← declNames.mapM fun name => do
|
||||
|
|
@ -180,14 +183,15 @@ def runOnDeclsNamed (declNames : Array Name) (probe : Probe Decl β) (phase : Ph
|
|||
return decl
|
||||
probe decls |>.run (phase := phase)
|
||||
|
||||
def runOnModule (moduleName : Name) (probe : Probe Decl β) (phase : Phase := Phase.base): CoreM (Array β) := do
|
||||
def runOnModule (moduleName : Name) (phase : Phase := Phase.base)
|
||||
(probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
let some modIdx := env.getModuleIdx? moduleName | throwError "module `{moduleName}` not found"
|
||||
let decls := ext.getModuleEntries env modIdx
|
||||
probe decls |>.run (phase := phase)
|
||||
|
||||
def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (Array β) := do
|
||||
def runGlobally (phase : Phase := Phase.base) (probe : Probe (Decl phase.toPurity) β) : CoreM (Array β) := do
|
||||
let ext := getExt phase
|
||||
let env ← getEnv
|
||||
let mut decls := #[]
|
||||
|
|
@ -195,7 +199,7 @@ def runGlobally (probe : Probe Decl β) (phase : Phase := Phase.base) : CoreM (A
|
|||
decls := decls.append <| ext.getModuleEntries env modIdx
|
||||
probe decls |>.run (phase := phase)
|
||||
|
||||
def toPass [ToString β] (probe : Probe Decl β) (phase : Phase) : Pass where
|
||||
def toPass [ToString β] (phase : Phase) (probe : Probe (Decl phase.toPurity) β) : Pass where
|
||||
phase := phase
|
||||
name := `probe
|
||||
run := fun decls => do
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ Local function declaration and join point being pulled.
|
|||
-/
|
||||
structure ToPull where
|
||||
isFun : Bool
|
||||
decl : FunDecl
|
||||
decl : FunDecl .pure
|
||||
used : FVarIdHashSet
|
||||
deriving Inhabited
|
||||
|
||||
|
|
@ -50,7 +50,8 @@ where
|
|||
else
|
||||
go as (a :: keep) dep
|
||||
|
||||
partial def findFVarDepsFixpoint (todo : List ToPull) (acc : Array ToPull := #[]) : PullM (Array ToPull) := do
|
||||
partial def findFVarDepsFixpoint (todo : List ToPull) (acc : Array ToPull := #[]) :
|
||||
PullM (Array ToPull) := do
|
||||
match todo with
|
||||
| [] => return acc
|
||||
| p :: ps =>
|
||||
|
|
@ -65,7 +66,7 @@ partial def findFVarDeps (fvarId : FVarId) : PullM (Array ToPull) := do
|
|||
Similar to `findFVarDeps`. Extract from the state any local function declarations that depends on the given
|
||||
parameters.
|
||||
-/
|
||||
def findParamsDeps (params : Array Param) : PullM (Array ToPull) := do
|
||||
def findParamsDeps (params : Array (Param pu)) : PullM (Array ToPull) := do
|
||||
let mut acc := #[]
|
||||
for param in params do
|
||||
acc := acc ++ (← findFVarDeps param.fvarId)
|
||||
|
|
@ -74,7 +75,7 @@ def findParamsDeps (params : Array Param) : PullM (Array ToPull) := do
|
|||
/--
|
||||
Construct the code `fun p.decl k` or `jp p.decl k`.
|
||||
-/
|
||||
def ToPull.attach (p : ToPull) (k : Code) : Code :=
|
||||
def ToPull.attach (p : ToPull) (k : Code .pure) : Code .pure :=
|
||||
if p.isFun then
|
||||
.fun p.decl k
|
||||
else
|
||||
|
|
@ -83,19 +84,19 @@ def ToPull.attach (p : ToPull) (k : Code) : Code :=
|
|||
/--
|
||||
Attach the given array of local function declarations and join points to `k`.
|
||||
-/
|
||||
partial def attach (ps : Array ToPull) (k : Code) : Code := Id.run do
|
||||
partial def attach (ps : Array ToPull) (k : Code .pure) : Code .pure := Id.run do
|
||||
let visited := ps.map fun _ => false
|
||||
let (_, (k, _)) := go |>.run (k, visited)
|
||||
return k
|
||||
where
|
||||
go : StateM (Code × Array Bool) Unit := do
|
||||
go : StateM (Code .pure × Array Bool) Unit := do
|
||||
for i in *...ps.size do
|
||||
visit i
|
||||
|
||||
visited (i : Nat) : StateM (Code × Array Bool) Bool :=
|
||||
visited (i : Nat) : StateM (Code .pure × Array Bool) Bool :=
|
||||
return (← get).2[i]!
|
||||
|
||||
visit (i : Nat) : StateM (Code × Array Bool) Unit := do
|
||||
visit (i : Nat) : StateM (Code .pure × Array Bool) Unit := do
|
||||
unless (← visited i) do
|
||||
modify fun (k, visited) => (k, visited.set! i true)
|
||||
let pi := ps[i]!
|
||||
|
|
@ -110,7 +111,7 @@ where
|
|||
Extract from the state any local function declarations that depends on the given
|
||||
free variable, **and** attach to code `k`.
|
||||
-/
|
||||
partial def attachFVarDeps (fvarId : FVarId) (k : Code) : PullM Code := do
|
||||
partial def attachFVarDeps (fvarId : FVarId) (k : Code .pure) : PullM (Code .pure) := do
|
||||
let ps ← findFVarDeps fvarId
|
||||
return attach ps k
|
||||
|
||||
|
|
@ -118,11 +119,11 @@ partial def attachFVarDeps (fvarId : FVarId) (k : Code) : PullM Code := do
|
|||
Similar to `attachFVarDeps`. Extract from the state any local function declarations that depends on the given
|
||||
parameters, **and** attach to code `k`.
|
||||
-/
|
||||
def attachParamsDeps (params : Array Param) (k : Code) : PullM Code := do
|
||||
def attachParamsDeps (params : Array (Param .pure)) (k : Code .pure) : PullM (Code .pure) := do
|
||||
let ps ← findParamsDeps params
|
||||
return attach ps k
|
||||
|
||||
def attachJps (k : Code) : PullM Code := do
|
||||
def attachJps (k : Code .pure) : PullM (Code .pure) := do
|
||||
let jps := (← get).filter fun info => !info.isFun
|
||||
modify fun s => s.filter fun info => info.isFun
|
||||
let jps ← findFVarDepsFixpoint jps
|
||||
|
|
@ -132,7 +133,7 @@ mutual
|
|||
/--
|
||||
Add local function declaration (or join point if `isFun = false`) to the state.
|
||||
-/
|
||||
partial def addToPull (isFun : Bool) (decl : FunDecl) : PullM Unit := do
|
||||
partial def addToPull (isFun : Bool) (decl : FunDecl .pure) : PullM Unit := do
|
||||
let saved ← get
|
||||
modify fun _ => []
|
||||
let mut value ← pull decl.value
|
||||
|
|
@ -147,19 +148,19 @@ partial def addToPull (isFun : Bool) (decl : FunDecl) : PullM Unit := do
|
|||
Pull local function declarations and join points in `code`.
|
||||
The state contains the declarations being pulled.
|
||||
-/
|
||||
partial def pull (code : Code) : PullM Code := do
|
||||
partial def pull (code : Code .pure) : PullM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let k ← pull k
|
||||
let k ← attachFVarDeps decl.fvarId k
|
||||
return code.updateLet! decl k
|
||||
| .fun decl k => addToPull true decl; pull k
|
||||
| .fun decl k _ => addToPull true decl; pull k
|
||||
| .jp decl k => addToPull false decl; pull k
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapMonoM fun alt => do
|
||||
match alt with
|
||||
| .default k => return alt.updateCode (← pull k)
|
||||
| .alt _ ps k =>
|
||||
| .alt _ ps k _ =>
|
||||
let k ← pull k
|
||||
let k ← attachParamsDeps ps k
|
||||
return alt.updateCode k
|
||||
|
|
@ -174,13 +175,13 @@ open PullFunDecls
|
|||
/--
|
||||
Pull local function declarations and join points in the given declaration.
|
||||
-/
|
||||
def Decl.pullFunDecls (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.pullFunDecls (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let (value, ps) ← decl.value.mapCodeM pull |>.run []
|
||||
let value := value.mapCode (attach ps.toArray)
|
||||
return { decl with value }
|
||||
|
||||
def pullFunDecls : Pass :=
|
||||
.mkPerDeclaration `pullFunDecls Decl.pullFunDecls .base
|
||||
.mkPerDeclaration `pullFunDecls .base Decl.pullFunDecls
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.pullFunDecls (inherited := true)
|
||||
|
|
|
|||
|
|
@ -15,28 +15,28 @@ namespace Lean.Compiler.LCNF
|
|||
namespace PullLetDecls
|
||||
|
||||
structure Context where
|
||||
isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool
|
||||
isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool
|
||||
included : FVarIdSet := {}
|
||||
|
||||
structure State where
|
||||
toPull : Array LetDecl := #[]
|
||||
toPull : Array (LetDecl .pure) := #[]
|
||||
|
||||
abbrev PullM := ReaderT Context $ StateRefT State CompilerM
|
||||
|
||||
@[inline] def withFVar (fvarId : FVarId) (x : PullM α) : PullM α :=
|
||||
withReader (fun ctx => { ctx with included := ctx.included.insert fvarId }) x
|
||||
|
||||
@[inline] def withParams (ps : Array Param) (x : PullM α) : PullM α :=
|
||||
@[inline] def withParams (ps : Array (Param .pure)) (x : PullM α) : PullM α :=
|
||||
withReader (fun ctx => { ctx with included := ps.foldl (init := ctx.included) fun s p => s.insert p.fvarId }) x
|
||||
|
||||
@[inline] def withNewScope (x : PullM α) : PullM α :=
|
||||
withReader (fun ctx => { ctx with included := {} }) x
|
||||
|
||||
partial def withCheckpoint (x : PullM Code) : PullM Code := do
|
||||
partial def withCheckpoint (x : PullM (Code .pure)) : PullM (Code .pure) := do
|
||||
let toPullSizeSaved := (← get).toPull.size
|
||||
let c ← withNewScope x
|
||||
let toPull := (← get).toPull
|
||||
let rec go (i : Nat) (included : FVarIdSet) : StateM (Array LetDecl) Code := do
|
||||
let rec go (i : Nat) (included : FVarIdSet) : StateM (Array (LetDecl .pure)) (Code .pure) := do
|
||||
if h : i < toPull.size then
|
||||
let letDecl := toPull[i]
|
||||
if letDecl.dependsOn included then
|
||||
|
|
@ -51,11 +51,11 @@ partial def withCheckpoint (x : PullM Code) : PullM Code := do
|
|||
modify fun s => { s with toPull := s.toPull.shrink toPullSizeSaved ++ keep }
|
||||
return c
|
||||
|
||||
def attachToPull (c : Code) : PullM Code := do
|
||||
def attachToPull (c : Code .pure) : PullM (Code .pure) := do
|
||||
let toPull := (← get).toPull
|
||||
return toPull.foldr (init := c) fun decl c => .let decl c
|
||||
|
||||
def shouldPull (decl : LetDecl) : PullM Bool := do
|
||||
def shouldPull (decl : LetDecl .pure) : PullM Bool := do
|
||||
unless decl.dependsOn (← read).included do
|
||||
if (← (← read).isCandidateFn decl (← read).included) then
|
||||
modify fun s => { s with toPull := s.toPull.push decl }
|
||||
|
|
@ -63,12 +63,12 @@ def shouldPull (decl : LetDecl) : PullM Bool := do
|
|||
return false
|
||||
|
||||
mutual
|
||||
partial def pullAlt (alt : Alt) : PullM Alt :=
|
||||
partial def pullAlt (alt : (Alt .pure)) : PullM (Alt .pure) :=
|
||||
match alt with
|
||||
| .default k => return alt.updateCode (← withNewScope <| pullDecls k)
|
||||
| .alt _ params k => return alt.updateCode (← withNewScope <| withParams params <| pullDecls k)
|
||||
|
||||
partial def pullDecls (code : Code) : PullM Code := do
|
||||
partial def pullDecls (code : Code .pure) : PullM (Code .pure) := do
|
||||
match code with
|
||||
| .cases c =>
|
||||
-- At the present time, we can't correctly enforce the dependencies required for lifting
|
||||
|
|
@ -93,21 +93,21 @@ mutual
|
|||
|
||||
end
|
||||
|
||||
def PullM.run (x : PullM α) (isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool) : CompilerM α :=
|
||||
def PullM.run (x : PullM α) (isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool) : CompilerM α :=
|
||||
x { isCandidateFn } |>.run' {}
|
||||
|
||||
end PullLetDecls
|
||||
|
||||
open PullLetDecls
|
||||
|
||||
def Decl.pullLetDecls (decl : Decl) (isCandidateFn : LetDecl → FVarIdSet → CompilerM Bool) : CompilerM Decl := do
|
||||
def Decl.pullLetDecls (decl : Decl .pure) (isCandidateFn : LetDecl .pure → FVarIdSet → CompilerM Bool) : CompilerM (Decl .pure) := do
|
||||
PullM.run (isCandidateFn := isCandidateFn) do
|
||||
withParams decl.params do
|
||||
let value ← decl.value.mapCodeM pullDecls
|
||||
let value ← value.mapCodeM attachToPull
|
||||
return { decl with value }
|
||||
|
||||
def Decl.pullInstances (decl : Decl) : CompilerM Decl :=
|
||||
def Decl.pullInstances (decl : Decl .pure) : CompilerM (Decl .pure) :=
|
||||
decl.pullLetDecls fun letDecl candidates => do
|
||||
-- TODO: Correctly represent these dependencies so this check isn't required.
|
||||
if let .const _ _ args := letDecl.value then
|
||||
|
|
@ -122,7 +122,7 @@ def Decl.pullInstances (decl : Decl) : CompilerM Decl :=
|
|||
return false
|
||||
|
||||
def pullInstances : Pass :=
|
||||
.mkPerDeclaration `pullInstances Decl.pullInstances .base
|
||||
.mkPerDeclaration `pullInstances .base Decl.pullInstances
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.pullInstances (inherited := true)
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ We assume this limitation is irrelevant in practice.
|
|||
namespace FindUsed
|
||||
|
||||
structure Context where
|
||||
decl : Decl
|
||||
decl : Decl .pure
|
||||
params : FVarIdSet
|
||||
|
||||
structure State where
|
||||
|
|
@ -64,12 +64,12 @@ def visitFVar (fvarId : FVarId) : FindUsedM Unit := do
|
|||
if (← read).params.contains fvarId then
|
||||
modify fun s => { s with used := s.used.insert fvarId }
|
||||
|
||||
def visitArg (arg : Arg) : FindUsedM Unit := do
|
||||
def visitArg (arg : Arg .pure) : FindUsedM Unit := do
|
||||
match arg with
|
||||
| .erased | .type .. => return ()
|
||||
| .fvar fvarId => visitFVar fvarId
|
||||
|
||||
def visitLetValue (e : LetValue) : FindUsedM Unit := do
|
||||
def visitLetValue (e : LetValue .pure) : FindUsedM Unit := do
|
||||
match e with
|
||||
| .erased | .lit .. => return ()
|
||||
| .proj _ _ fvarId => visitFVar fvarId
|
||||
|
|
@ -93,7 +93,7 @@ def visitLetValue (e : LetValue) : FindUsedM Unit := do
|
|||
else
|
||||
args.forM visitArg
|
||||
|
||||
partial def visit (code : Code) : FindUsedM Unit := do
|
||||
partial def visit (code : Code .pure) : FindUsedM Unit := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
visitLetValue decl.value
|
||||
|
|
@ -107,7 +107,7 @@ partial def visit (code : Code) : FindUsedM Unit := do
|
|||
| .return fvarId => visitFVar fvarId
|
||||
| .unreach _ => return ()
|
||||
|
||||
def collectUsedParams (decl : Decl) : CompilerM FVarIdHashSet := do
|
||||
def collectUsedParams (decl : Decl .pure) : CompilerM FVarIdHashSet := do
|
||||
let params := decl.params.foldl (init := {}) fun s p => s.insert p.fvarId
|
||||
let (_, { used, .. }) ← decl.value.forCodeM visit |>.run { decl, params } |>.run {}
|
||||
return used
|
||||
|
|
@ -123,7 +123,7 @@ structure Context where
|
|||
|
||||
abbrev ReduceM := ReaderT Context CompilerM
|
||||
|
||||
partial def reduce (code : Code) : ReduceM Code := do
|
||||
partial def reduce (code : Code .pure) : ReduceM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let .const declName _ args := decl.value | do return code.updateLet! decl (← reduce k)
|
||||
|
|
@ -148,7 +148,7 @@ end ReduceArity
|
|||
|
||||
open FindUsed ReduceArity Internalize
|
||||
|
||||
def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
|
||||
def Decl.reduceArity (decl : Decl .pure) : CompilerM (Array (Decl .pure)) := do
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
let used ← collectUsedParams decl
|
||||
|
|
@ -160,7 +160,7 @@ def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
|
|||
trace[Compiler.reduceArity] "{decl.name}, used params: {used.toList.map mkFVar}"
|
||||
let mask := decl.params.map fun param => used.contains param.fvarId
|
||||
let auxName := decl.name ++ `_redArg
|
||||
let mkAuxDecl : CompilerM Decl := do
|
||||
let mkAuxDecl : CompilerM (Decl .pure) := do
|
||||
let params := decl.params.filter fun param => used.contains param.fvarId
|
||||
let value ← decl.value.mapCodeM reduce |>.run { declName := decl.name, auxDeclName := auxName, paramMask := mask }
|
||||
let type ← code.inferType
|
||||
|
|
@ -168,7 +168,7 @@ def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
|
|||
let auxDecl := { decl with name := auxName, levelParams := [], type, params, value }
|
||||
auxDecl.saveMono
|
||||
return auxDecl
|
||||
let updateDecl : InternalizeM Decl := do
|
||||
let updateDecl : InternalizeM .pure (Decl .pure) := do
|
||||
let params ← decl.params.mapM internalizeParam
|
||||
let mut args := #[]
|
||||
for used in mask, param in params do
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ namespace ReduceJpArity
|
|||
|
||||
abbrev ReduceM := ReaderT (FVarIdMap (Array Bool)) CompilerM
|
||||
|
||||
partial def reduce (code : Code) : ReduceM Code := do
|
||||
partial def reduce (code : Code .pure) : ReduceM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k => return code.updateLet! decl (← reduce k)
|
||||
| .fun decl k =>
|
||||
|
|
@ -69,12 +69,14 @@ open ReduceJpArity
|
|||
/--
|
||||
Try to reduce arity of join points
|
||||
-/
|
||||
def Decl.reduceJpArity (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.reduceJpArity (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
let value ← decl.value.mapCodeM reduce |>.run {}
|
||||
return { decl with value }
|
||||
|
||||
-- TODO: This can be made Purity generic
|
||||
def reduceJpArity (phase := Phase.base) : Pass :=
|
||||
.mkPerDeclaration `reduceJpArity Decl.reduceJpArity phase
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `reduceJpArity phase (h ▸ Decl.reduceJpArity)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.reduceJpArity (inherited := true)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ A mapping from free variable id to binder name.
|
|||
-/
|
||||
abbrev Renaming := FVarIdMap Name
|
||||
|
||||
def Param.applyRenaming (param : Param) (r : Renaming) : CompilerM Param := do
|
||||
def Param.applyRenaming (param : Param pu) (r : Renaming) : CompilerM (Param pu) := do
|
||||
if let some binderName := r.get? param.fvarId then
|
||||
let param := { param with binderName }
|
||||
modifyLCtx fun lctx => lctx.addParam param
|
||||
|
|
@ -24,7 +24,7 @@ def Param.applyRenaming (param : Param) (r : Renaming) : CompilerM Param := do
|
|||
else
|
||||
return param
|
||||
|
||||
def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl := do
|
||||
def LetDecl.applyRenaming (decl : LetDecl pu) (r : Renaming) : CompilerM (LetDecl pu) := do
|
||||
if let some binderName := r.get? decl.fvarId then
|
||||
let decl := { decl with binderName }
|
||||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||||
|
|
@ -33,7 +33,7 @@ def LetDecl.applyRenaming (decl : LetDecl) (r : Renaming) : CompilerM LetDecl :=
|
|||
return decl
|
||||
|
||||
mutual
|
||||
partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM FunDecl := do
|
||||
partial def FunDecl.applyRenaming (decl : (FunDecl pu)) (r : Renaming) : CompilerM (FunDecl pu) := do
|
||||
if let some binderName := r.get? decl.fvarId then
|
||||
let decl := decl.updateBinderName binderName
|
||||
modifyLCtx fun lctx => lctx.addFunDecl decl
|
||||
|
|
@ -41,20 +41,20 @@ partial def FunDecl.applyRenaming (decl : FunDecl) (r : Renaming) : CompilerM Fu
|
|||
else
|
||||
decl.updateValue (← decl.value.applyRenaming r)
|
||||
|
||||
partial def Code.applyRenaming (code : Code) (r : Renaming) : CompilerM Code := do
|
||||
partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code pu) := do
|
||||
match code with
|
||||
| .let decl k => return code.updateLet! (← decl.applyRenaming r) (← k.applyRenaming r)
|
||||
| .fun decl k | .jp decl k => return code.updateFun! (← decl.applyRenaming r) (← k.applyRenaming r)
|
||||
| .fun decl k _ | .jp decl k => return code.updateFun! (← decl.applyRenaming r) (← k.applyRenaming r)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapMonoM fun alt =>
|
||||
match alt with
|
||||
| .default k => return alt.updateCode (← k.applyRenaming r)
|
||||
| .alt _ ps k => return alt.updateAlt! (← ps.mapMonoM (·.applyRenaming r)) (← k.applyRenaming r)
|
||||
| .alt _ ps k _ => return alt.updateAlt! (← ps.mapMonoM (·.applyRenaming r)) (← k.applyRenaming r)
|
||||
return code.updateAlts! alts
|
||||
| .jmp .. | .unreach .. | .return .. => return code
|
||||
end
|
||||
|
||||
def Decl.applyRenaming (decl : Decl) (r : Renaming) : CompilerM Decl := do
|
||||
def Decl.applyRenaming (decl : Decl pu) (r : Renaming) : CompilerM (Decl pu) := do
|
||||
if r.isEmpty then
|
||||
return decl
|
||||
else
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ public section
|
|||
namespace Lean.Compiler.LCNF
|
||||
open Simp
|
||||
|
||||
def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
|
||||
def Decl.simp? (decl : Decl .pure) : SimpM (Option (Decl .pure)) := do
|
||||
let .code code := decl.value | return none
|
||||
updateFunDeclInfo code
|
||||
traceM `Compiler.simp.inline.info do return m!"{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}"
|
||||
|
|
@ -42,7 +42,7 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
|
|||
else
|
||||
return none
|
||||
|
||||
partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
|
||||
partial def Decl.simp (decl : Decl .pure) (config : Config) : CompilerM (Decl .pure) := do
|
||||
let mut config := config
|
||||
if (← isTemplateLike decl) then
|
||||
/-
|
||||
|
|
@ -54,7 +54,7 @@ partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
|
|||
config := { config with etaPoly := false, inlinePartial := false }
|
||||
go decl config
|
||||
where
|
||||
go (decl : Decl) (config : Config) : CompilerM Decl := do
|
||||
go (decl : Decl .pure) (config : Config) : CompilerM (Decl .pure) := do
|
||||
if let some decl ← decl.simp? |>.run { config, declName := decl.name } |>.run' {} |>.run {} then
|
||||
-- TODO: bound number of steps?
|
||||
go decl config
|
||||
|
|
@ -62,7 +62,8 @@ where
|
|||
return decl
|
||||
|
||||
def simp (config : Config := {}) (occurrence : Nat := 0) (phase := Phase.base) : Pass :=
|
||||
.mkPerDeclaration `simp (Decl.simp · config) phase (occurrence := occurrence)
|
||||
phase.withPurityCheck .pure fun h =>
|
||||
.mkPerDeclaration `simp phase (h ▸ (Decl.simp · config)) (occurrence := occurrence)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.simp (inherited := true)
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ let _x.2 := _f.1
|
|||
```
|
||||
`findFunDecl? _x.2` returns `none`, but `findFunDecl'? _x.2` returns the declaration for `_f.1`.
|
||||
-/
|
||||
partial def findFunDecl'? (fvarId : FVarId) : CompilerM (Option FunDecl) := do
|
||||
if let some decl ← findFunDecl? fvarId then
|
||||
partial def findFunDecl'? (fvarId : FVarId) : CompilerM (Option (FunDecl pu)) := do
|
||||
if let some decl ← findFunDecl? (pu := pu) fvarId then
|
||||
return decl
|
||||
else if let some (.fvar fvarId' #[]) ← findLetValue? fvarId then
|
||||
else if let some (.fvar fvarId' #[]) ← findLetValue? (pu := pu) fvarId then
|
||||
findFunDecl'? fvarId'
|
||||
else
|
||||
return none
|
||||
|
|
|
|||
|
|
@ -18,14 +18,14 @@ namespace ConstantFold
|
|||
A constant folding monad, the additional state stores auxiliary declarations
|
||||
required to build the new constant.
|
||||
-/
|
||||
abbrev FolderM := StateRefT (Array CodeDecl) CompilerM
|
||||
abbrev FolderM := StateRefT (Array (CodeDecl .pure)) CompilerM
|
||||
|
||||
/--
|
||||
A constant folder for a specific function, takes all the arguments of a
|
||||
certain function and produces a new `Expr` + auxiliary declarations in
|
||||
the `FolderM` monad on success. If the folding fails it returns `none`.
|
||||
-/
|
||||
abbrev Folder := Array Arg → FolderM (Option LetValue)
|
||||
abbrev Folder := Array (Arg .pure) → FolderM (Option (LetValue .pure))
|
||||
|
||||
/--
|
||||
A typeclass for detecting and producing literals of arbitrary types
|
||||
|
|
@ -43,7 +43,7 @@ class Literal (α : Type) where
|
|||
final `Expr` putting them all together into a literal of type `α`,
|
||||
where again the idea of what a literal is depends on `α`.
|
||||
-/
|
||||
mkLit : α → FolderM LetValue
|
||||
mkLit : α → FolderM (LetValue .pure)
|
||||
|
||||
export Literal (getLit mkLit)
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ export Literal (getLit mkLit)
|
|||
A wrapper around `LCNF.mkAuxLetDecl` that will automatically store the
|
||||
`LetDecl` in the state of `FolderM`.
|
||||
-/
|
||||
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : FolderM FVarId := do
|
||||
def mkAuxLetDecl (e : LetValue .pure) (prefixName := `_x) : FolderM FVarId := do
|
||||
let decl ← LCNF.mkAuxLetDecl e prefixName
|
||||
modify fun s => s.push <| .let decl
|
||||
return decl.fvarId
|
||||
|
|
@ -66,10 +66,10 @@ def mkAuxLit [Literal α] (x : α) (prefixName := `_x) : FolderM FVarId := do
|
|||
mkAuxLetDecl lit prefixName
|
||||
|
||||
partial def getNatLit (fvarId : FVarId) : CompilerM (Option Nat) := do
|
||||
let some (.lit (.nat n)) ← findLetValue? fvarId | return none
|
||||
let some (.lit (.nat n)) ← findLetValue? (pu := .pure) fvarId | return none
|
||||
return n
|
||||
|
||||
def mkNatLit (n : Nat) : FolderM LetValue :=
|
||||
def mkNatLit (n : Nat) : FolderM (LetValue .pure) :=
|
||||
return .lit (.nat n)
|
||||
|
||||
instance : Literal Nat where
|
||||
|
|
@ -77,10 +77,10 @@ instance : Literal Nat where
|
|||
mkLit := mkNatLit
|
||||
|
||||
def getStringLit (fvarId : FVarId) : CompilerM (Option String) := do
|
||||
let some (.lit (.str s)) ← findLetValue? fvarId | return none
|
||||
let some (.lit (.str s)) ← findLetValue? (pu := .pure) fvarId | return none
|
||||
return s
|
||||
|
||||
def mkStringLit (n : String) : FolderM LetValue :=
|
||||
def mkStringLit (n : String) : FolderM (LetValue .pure) :=
|
||||
return .lit (.str n)
|
||||
|
||||
instance : Literal String where
|
||||
|
|
@ -91,7 +91,7 @@ def getBoolLit (fvarId : FVarId) : CompilerM (Option Bool) := do
|
|||
let some (.const ctor [] #[]) ← findLetValue? fvarId | return none
|
||||
return ctor == ``Bool.true
|
||||
|
||||
def mkBoolLit (b : Bool) : FolderM LetValue :=
|
||||
def mkBoolLit (b : Bool) : FolderM (LetValue .pure) :=
|
||||
let ctor := if b then ``Bool.true else ``Bool.false
|
||||
return .const ctor [] #[]
|
||||
|
||||
|
|
@ -115,7 +115,7 @@ instance : Literal Char := mkNatWrapperInstance Char.ofNat ``Char.ofNat Char.toN
|
|||
|
||||
def mkUIntInstance (matchLit : LitValue → Option α) (litValueCtor : α → LitValue) : Literal α where
|
||||
getLit fvarId := do
|
||||
let some (.lit litVal) ← findLetValue? fvarId | return none
|
||||
let some (.lit litVal) ← findLetValue? (pu := .pure) fvarId | return none
|
||||
return matchLit litVal
|
||||
mkLit x :=
|
||||
return .lit <| litValueCtor x
|
||||
|
|
@ -162,7 +162,7 @@ let _x.26 := @Array.push _ _x.24 z
|
|||
_x.26
|
||||
```
|
||||
-/
|
||||
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM LetValue := do
|
||||
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM (LetValue .pure) := do
|
||||
let sizeLit ← mkAuxLit elements.size
|
||||
let mut literal ← mkAuxLetDecl <| .const ``Array.mkEmpty [typLevel] #[.type typ, .fvar sizeLit]
|
||||
for element in elements do
|
||||
|
|
@ -335,7 +335,7 @@ def Folder.mulShift [Literal α] [BEq α] (shiftLeft : Name) (pow2 : α → α)
|
|||
-- TODO: add option for controlling the limit
|
||||
def natPowThreshold := 256
|
||||
|
||||
def foldNatPow (args : Array Arg) : FolderM (Option LetValue) := do
|
||||
def foldNatPow (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
|
||||
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
|
||||
let some value₁ ← getNatLit fvarId₁ | return none
|
||||
let some value₂ ← getNatLit fvarId₂ | return none
|
||||
|
|
@ -347,14 +347,14 @@ def foldNatPow (args : Array Arg) : FolderM (Option LetValue) := do
|
|||
/--
|
||||
Folder for ofNat operations on fixed-sized integer types.
|
||||
-/
|
||||
def Folder.ofNat (f : Nat → LitValue) (args : Array Arg) : FolderM (Option LetValue) := do
|
||||
def Folder.ofNat (f : Nat → LitValue) (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
|
||||
let #[.fvar fvarId] := args | return none
|
||||
let some value ← getNatLit fvarId | return none
|
||||
return some (.lit (f value))
|
||||
|
||||
def Folder.toNat (args : Array Arg) : FolderM (Option LetValue) := do
|
||||
def Folder.toNat (args : Array (Arg .pure)) : FolderM (Option (LetValue .pure)) := do
|
||||
let #[.fvar fvarId] := args | return none
|
||||
let some (.lit lit) ← findLetValue? fvarId | return none
|
||||
let some (.lit lit) ← findLetValue? (pu := .pure) fvarId | return none
|
||||
match lit with
|
||||
| .uint8 v | .uint16 v | .uint32 v | .uint64 v | .usize v => return some (.lit (.nat v.toNat))
|
||||
| .nat _ | .str _ => return none
|
||||
|
|
@ -436,7 +436,7 @@ def stringFolders : List (Name × Folder) := [
|
|||
/--
|
||||
Apply all known folders to `decl`.
|
||||
-/
|
||||
def applyFolders (decl : LetDecl) (folders : SMap Name Folder) : CompilerM (Option (Array CodeDecl)) := do
|
||||
def applyFolders (decl : LetDecl .pure) (folders : SMap Name Folder) : CompilerM (Option (Array (CodeDecl .pure))) := do
|
||||
match decl.value with
|
||||
| .const name _ args =>
|
||||
if let some folder := folders.find? name then
|
||||
|
|
@ -495,7 +495,7 @@ def getFolders : CoreM (SMap Name Folder) :=
|
|||
/--
|
||||
Apply a list of default folders to `decl`
|
||||
-/
|
||||
def foldConstants (decl : LetDecl) : CompilerM (Option (Array CodeDecl)) := do
|
||||
def foldConstants (decl : LetDecl .pure) : CompilerM (Option (Array (CodeDecl .pure))) := do
|
||||
applyFolders decl (← getFolders)
|
||||
|
||||
end ConstantFold
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ and the number of occurrences.
|
|||
We use this function to decide whether to create a `.default` case
|
||||
or not.
|
||||
-/
|
||||
private def getMaxOccs (alts : Array Alt) : Alt × Nat := Id.run do
|
||||
private def getMaxOccs (alts : Array (Alt .pure)) : Alt .pure × Nat := Id.run do
|
||||
let mut maxAlt := alts[0]!
|
||||
let mut max := getNumOccsOf alts 0
|
||||
for h : i in 1...alts.size do
|
||||
|
|
@ -35,7 +35,7 @@ where
|
|||
Note that the number of occurrences can be greater than 1 only when
|
||||
the alternative does not depend on field parameters
|
||||
-/
|
||||
getNumOccsOf (alts : Array Alt) (i : Nat) : Nat := Id.run do
|
||||
getNumOccsOf (alts : Array (Alt .pure)) (i : Nat) : Nat := Id.run do
|
||||
let code := alts[i]!.getCode
|
||||
let mut n := 1
|
||||
for h : j in (i+1)...alts.size do
|
||||
|
|
@ -47,7 +47,7 @@ where
|
|||
Add a default case to the given `cases` alternatives if there
|
||||
are alternatives with equivalent (aka alpha equivalent) right hand sides.
|
||||
-/
|
||||
def addDefaultAlt (alts : Array Alt) : SimpM (Array Alt) := do
|
||||
def addDefaultAlt (alts : Array (Alt .pure)) : SimpM (Array (Alt .pure)) := do
|
||||
if alts.size <= 1 || alts.any (· matches .default ..) then
|
||||
return alts
|
||||
else
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ namespace Lean.Compiler.LCNF
|
|||
namespace Simp
|
||||
|
||||
inductive CtorInfo where
|
||||
| ctor (val : ConstructorVal) (args : Array Arg)
|
||||
| ctor (val : ConstructorVal) (args : Array (Arg .pure))
|
||||
| /-- Natural numbers are morally constructor applications -/
|
||||
natVal (n : Nat)
|
||||
|
||||
|
|
@ -70,7 +70,7 @@ def findCtorName? (fvarId : FVarId) : DiscrM (Option Name) := do
|
|||
/--
|
||||
If `type` is an application of the inductive type `ind`, return its universe levels and parameters.
|
||||
-/
|
||||
def getIndInfo? (type : Expr) (ind : Name) : CoreM (Option (List Level × Array Arg)) := do
|
||||
def getIndInfo? (type : Expr) (ind : Name) : CoreM (Option (List Level × Array (Arg .pure))) := do
|
||||
let type := type.headBeta
|
||||
let .const declName us := type.getAppFn | return none
|
||||
unless declName == ind do return none
|
||||
|
|
@ -85,7 +85,8 @@ def getIndInfo? (type : Expr) (ind : Name) : CoreM (Option (List Level × Array
|
|||
Execute `x` with the information that `discr = ctorName ctorFields`.
|
||||
We use this information to simplify nested cases on the same discriminant.
|
||||
-/
|
||||
@[inline] def withDiscrCtorImp (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : DiscrM α) : DiscrM α := do
|
||||
@[inline] def withDiscrCtorImp (discr : FVarId) (ctorName : Name)
|
||||
(ctorFields : Array (Param .pure)) (x : DiscrM α) : DiscrM α := do
|
||||
let ctx ← updateCtx
|
||||
withReader (fun _ => ctx) x
|
||||
where
|
||||
|
|
@ -103,7 +104,9 @@ where
|
|||
let ctorInfo := .ctor ctorVal (.replicate ctorVal.numParams Arg.erased ++ fieldArgs)
|
||||
return { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctorInfo }
|
||||
|
||||
@[inline, inherit_doc withDiscrCtorImp] def withDiscrCtor [MonadFunctorT DiscrM m] (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) : m α → m α :=
|
||||
@[inline, inherit_doc withDiscrCtorImp]
|
||||
def withDiscrCtor [MonadFunctorT DiscrM m] (discr : FVarId) (ctorName : Name)
|
||||
(ctorFields : Array (Param .pure)) : m α → m α :=
|
||||
monadMap (m := DiscrM) <| withDiscrCtorImp discr ctorName ctorFields
|
||||
|
||||
def simpCtorDiscrCore? (e : Expr) : DiscrM (Option FVarId) := do
|
||||
|
|
|
|||
|
|
@ -94,27 +94,27 @@ If `mustInline := true`, then all local function declarations occurring in
|
|||
`code` are tagged as `.mustInline`.
|
||||
Recall that we use `.mustInline` for local function declarations occurring in type class instances.
|
||||
-/
|
||||
partial def FunDeclInfoMap.update (s : FunDeclInfoMap) (code : Code) (mustInline := false) : CompilerM FunDeclInfoMap := do
|
||||
partial def FunDeclInfoMap.update (s : FunDeclInfoMap) (code : Code .pure) (mustInline := false) : CompilerM FunDeclInfoMap := do
|
||||
let (_, s) ← go code |>.run s
|
||||
return s
|
||||
where
|
||||
addArgOcc (arg : Arg) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
addArgOcc (arg : Arg .pure) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
match arg with
|
||||
| .fvar fvarId =>
|
||||
let some funDecl ← findFunDecl'? fvarId | return ()
|
||||
let some funDecl ← findFunDecl'? (pu := .pure) fvarId | return ()
|
||||
modify fun s => s.addHo funDecl.fvarId
|
||||
| .erased .. | .type .. => return ()
|
||||
|
||||
addLetValueOccs (e : LetValue) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
addLetValueOccs (e : LetValue .pure) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
match e with
|
||||
| .erased | .lit .. | .proj .. => return ()
|
||||
| .const _ _ args => args.forM addArgOcc
|
||||
| .fvar fvarId args =>
|
||||
let some funDecl ← findFunDecl'? fvarId | return ()
|
||||
let some funDecl ← findFunDecl'? (pu := .pure) fvarId | return ()
|
||||
modify fun s => s.add funDecl.fvarId
|
||||
args.forM addArgOcc
|
||||
|
||||
go (code : Code) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
go (code : Code .pure) : StateRefT FunDeclInfoMap CompilerM Unit := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
addLetValueOccs decl.value
|
||||
|
|
@ -126,7 +126,7 @@ where
|
|||
| .jp decl k => go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .jmp fvarId args =>
|
||||
let funDecl ← getFunDecl fvarId
|
||||
let funDecl ← getFunDecl (pu := .pure) fvarId
|
||||
modify fun s => s.add funDecl.fvarId
|
||||
args.forM addArgOcc
|
||||
| .return .. | .unreach .. => return ()
|
||||
|
|
|
|||
|
|
@ -19,11 +19,11 @@ It contains information for inlining local and global functions.
|
|||
-/
|
||||
structure InlineCandidateInfo where
|
||||
isLocal : Bool
|
||||
params : Array Param
|
||||
params : Array (Param .pure)
|
||||
/-- Value (lambda expression) of the function to be inlined. -/
|
||||
value : Code
|
||||
value : Code .pure
|
||||
fType : Expr
|
||||
args : Array Arg
|
||||
args : Array (Arg .pure)
|
||||
/-- `ifReduce = true` if the declaration being inlined was tagged with `inline_if_reduce`. -/
|
||||
ifReduce : Bool
|
||||
/-- `recursive = true` if the declaration being inline is in a mutually recursive block. -/
|
||||
|
|
@ -36,7 +36,7 @@ def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
|
|||
/--
|
||||
Return `some info` if `e` should be inlined.
|
||||
-/
|
||||
def inlineCandidate? (e : LetValue) : SimpM (Option InlineCandidateInfo) := do
|
||||
def inlineCandidate? (e : LetValue .pure) : SimpM (Option InlineCandidateInfo) := do
|
||||
let mut e := e
|
||||
let mut mustInline := false
|
||||
if let .const ``inline _ #[_, .fvar argFVarId] := e then
|
||||
|
|
@ -46,7 +46,7 @@ def inlineCandidate? (e : LetValue) : SimpM (Option InlineCandidateInfo) := do
|
|||
if let .const declName us args := e then
|
||||
unless (← read).config.inlineDefs do
|
||||
return none
|
||||
let some decl ← getDecl? declName | return none
|
||||
let some ⟨.pure, decl⟩ ← getDecl? declName | return none
|
||||
let .code code := decl.value | return none
|
||||
let shouldInline : SimpM Bool := do
|
||||
if !decl.inlineIfReduceAttr && decl.recursive then return false
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ and the free variable containing the result (`FVarId`). The resulting `FVarId` o
|
|||
subset of `Array CodeDecl`. However, this method does try to filter the relevant ones.
|
||||
We rely on the `used` var set available in `SimpM` to filter them. See `attachCodeDecls`.
|
||||
-/
|
||||
partial def inlineProjInst? (e : LetValue) : SimpM (Option (Array CodeDecl × FVarId)) := do
|
||||
partial def inlineProjInst? (e : LetValue .pure) : SimpM (Option (Array (CodeDecl .pure) × FVarId)) := do
|
||||
let .proj _ i s := e | return none
|
||||
let sType ← getType s
|
||||
unless (← isClass? sType).isSome do return none
|
||||
|
|
@ -52,7 +52,7 @@ partial def inlineProjInst? (e : LetValue) : SimpM (Option (Array CodeDecl × FV
|
|||
eraseCodeDecls decls
|
||||
return none
|
||||
where
|
||||
visit (fvarId : FVarId) (projs : List Nat) : OptionT (StateRefT (Array CodeDecl) SimpM) FVarId := do
|
||||
visit (fvarId : FVarId) (projs : List Nat) : OptionT (StateRefT (Array (CodeDecl .pure)) SimpM) FVarId := do
|
||||
let some letDecl ← findLetDecl? fvarId | failure
|
||||
match letDecl.value with
|
||||
| .proj _ i s => visit s (i :: projs)
|
||||
|
|
@ -72,7 +72,7 @@ where
|
|||
else
|
||||
visit fvarId projs
|
||||
else
|
||||
let some decl ← getDecl? declName | failure
|
||||
let some ⟨.pure, decl⟩ ← getDecl? declName | failure
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
guard (!decl.recursive && decl.getArity == args.size)
|
||||
|
|
@ -82,7 +82,7 @@ where
|
|||
visitCode code projs
|
||||
| .extern .. => failure
|
||||
|
||||
visitCode (code : Code) (projs : List Nat) : OptionT (StateRefT (Array CodeDecl) SimpM) FVarId := do
|
||||
visitCode (code : Code .pure) (projs : List Nat) : OptionT (StateRefT (Array (CodeDecl .pure)) SimpM) FVarId := do
|
||||
match code with
|
||||
| .let decl k => modify (·.push (.let decl)); visitCode k projs
|
||||
| .fun decl k => modify (·.push (.fun decl)); visitCode k projs
|
||||
|
|
|
|||
|
|
@ -26,12 +26,12 @@ f y :=
|
|||
```
|
||||
`idx` is the index of the parameter used in the `cases` statement.
|
||||
-/
|
||||
def isJpCases? (decl : FunDecl) : CompilerM (Option Nat) := do
|
||||
def isJpCases? (decl : FunDecl .pure) : CompilerM (Option Nat) := do
|
||||
if decl.params.size == 0 then
|
||||
return none
|
||||
else
|
||||
let small := (← getConfig).smallThreshold
|
||||
let rec go (code : Code) (prefixSize : Nat) : Option Nat :=
|
||||
let rec go (code : Code .pure) (prefixSize : Nat) : Option Nat :=
|
||||
if prefixSize > small then none else
|
||||
match code with
|
||||
| .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/
|
||||
|
|
@ -64,11 +64,11 @@ in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor name
|
|||
there is a jump `.jmp jpFVarId #[..., x, ...]` in `code` and `x` is a constructor application.
|
||||
`paramIdx` is the index of the parameter
|
||||
-/
|
||||
partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfoMap := do
|
||||
partial def collectJpCasesInfo (code : Code .pure) : CompilerM JpCasesInfoMap := do
|
||||
let (_, s) ← go code |>.run {} |>.run {}
|
||||
return s
|
||||
where
|
||||
go (code : Code) : StateRefT JpCasesInfoMap DiscrM Unit := do
|
||||
go (code : Code .pure) : StateRefT JpCasesInfoMap DiscrM Unit := do
|
||||
match code with
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value; go k
|
||||
|
|
@ -90,17 +90,17 @@ where
|
|||
/--
|
||||
Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases?`.
|
||||
-/
|
||||
private def extractJpCases (code : Code) : Array CodeDecl × Cases :=
|
||||
private def extractJpCases (code : Code .pure) : Array (CodeDecl .pure) × Cases .pure :=
|
||||
go code #[]
|
||||
where
|
||||
go (code : Code) (decls : Array CodeDecl) :=
|
||||
go (code : Code .pure) (decls : Array (CodeDecl .pure)) :=
|
||||
match code with
|
||||
| .let decl k => go k <| decls.push (.let decl)
|
||||
| .cases c => (decls, c)
|
||||
| _ => unreachable! -- `code` is not the body of a join point that satisfies `isJpCases`
|
||||
|
||||
structure JpCasesAlt where
|
||||
decl : FunDecl
|
||||
decl : FunDecl .pure
|
||||
default : Bool
|
||||
dependsOnDiscr : Bool
|
||||
|
||||
|
|
@ -116,10 +116,12 @@ Construct an auxiliary join point for a particular alternative in a join-point t
|
|||
- `k` is the body of the alternative.
|
||||
- `default` is true if it is a default alternative.
|
||||
-/
|
||||
private def mkJpAlt (decls : Array CodeDecl) (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do
|
||||
private def mkJpAlt (decls : Array (CodeDecl .pure)) (params : Array (Param .pure))
|
||||
(targetParamIdx : Nat) (fields : Array (Param .pure)) (k : Code .pure) (default : Bool) :
|
||||
CompilerM JpCasesAlt := do
|
||||
go |>.run' {}
|
||||
where
|
||||
go : InternalizeM JpCasesAlt := do
|
||||
go : InternalizeM .pure JpCasesAlt := do
|
||||
let mut paramsNew := #[]
|
||||
let singleton : FVarIdSet := ({} : FVarIdSet).insert params[targetParamIdx]!.fvarId
|
||||
let dependsOnDiscr := k.dependsOn singleton || decls.any (·.dependsOn singleton)
|
||||
|
|
@ -137,7 +139,8 @@ where
|
|||
return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr }
|
||||
|
||||
/-- Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`. -/
|
||||
private def mkJmpNewArgs (args : Array Arg) (targetParamIdx : Nat) (fields : Array Arg) (dependsOnTarget : Bool) : Array Arg :=
|
||||
private def mkJmpNewArgs (args : Array (Arg .pure)) (targetParamIdx : Nat)
|
||||
(fields : Array (Arg .pure)) (dependsOnTarget : Bool) : Array (Arg .pure) :=
|
||||
if dependsOnTarget then
|
||||
args[*...=targetParamIdx] ++ fields ++ args[targetParamIdx<...*]
|
||||
else
|
||||
|
|
@ -147,7 +150,8 @@ private def mkJmpNewArgs (args : Array Arg) (targetParamIdx : Nat) (fields : Arr
|
|||
Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`.
|
||||
This function is used to create jumps from the join point satisfying `isJpCases?` to the new auxiliary join points created using `mkJpAlt`.
|
||||
-/
|
||||
private def mkJmpArgsAtJp (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (dependsOnTarget : Bool) : Array Arg := Id.run do
|
||||
private def mkJmpArgsAtJp (params : Array (Param .pure)) (targetParamIdx : Nat)
|
||||
(fields : Array (Param .pure)) (dependsOnTarget : Bool) : Array (Arg .pure) := Id.run do
|
||||
mkJmpNewArgs (params.map (Arg.fvar ·.fvarId)) targetParamIdx (fields.map (Arg.fvar ·.fvarId)) dependsOnTarget
|
||||
|
||||
/--
|
||||
|
|
@ -194,7 +198,7 @@ cases x.4
|
|||
Note that if all jumps to the join point are with constructors,
|
||||
then the join point is eliminated as dead code.
|
||||
-/
|
||||
partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do
|
||||
partial def simpJpCases? (code : Code .pure) : CompilerM (Option (Code .pure)) := do
|
||||
let map ← collectJpCasesInfo code
|
||||
unless map.isCandidate do return none
|
||||
traceM `Compiler.simp.jpCases do
|
||||
|
|
@ -204,7 +208,7 @@ partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do
|
|||
return msg
|
||||
visit code map |>.run' {} |>.run {}
|
||||
where
|
||||
visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) Code := do
|
||||
visit (code : Code .pure) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
return code.updateLet! decl (← visit k)
|
||||
|
|
@ -232,7 +236,8 @@ where
|
|||
let some code ← visitJmp? fvarId args | return code
|
||||
return code
|
||||
|
||||
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do
|
||||
visitJp? (decl : FunDecl .pure) (k : Code .pure) :
|
||||
ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option (Code .pure)) := do
|
||||
let some info := (← read).get? decl.fvarId | return none
|
||||
if info.ctorNames.isEmpty then return none
|
||||
-- This join point satisfies `isJpCases?` and there are jumps with constructors in `info` to it.
|
||||
|
|
@ -273,7 +278,8 @@ where
|
|||
let code := .jp decl (← visit k)
|
||||
return LCNF.attachCodeDecls jpAltDecls code
|
||||
|
||||
visitJmp? (fvarId : FVarId) (args : Array Arg) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option Code) := do
|
||||
visitJmp? (fvarId : FVarId) (args : Array (Arg .pure)) :
|
||||
ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt DiscrM) (Option (Code .pure)) := do
|
||||
let some ctorJpAltMap := (← get).get? fvarId | return none
|
||||
let some info := (← read).get? fvarId | return none
|
||||
let .fvar argFVarId := args[info.paramIdx]! | return none
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ such as: a `cases` with many but only one alternative is not reachable.
|
|||
It is only used to avoid the creation of auxiliary join points, and does not need
|
||||
to be precise.
|
||||
-/
|
||||
private partial def oneExitPointQuick (c : Code) : Bool :=
|
||||
private partial def oneExitPointQuick (c : Code .pure) : Bool :=
|
||||
go c
|
||||
where
|
||||
go (c : Code) : Bool :=
|
||||
go (c : Code .pure) : Bool :=
|
||||
match c with
|
||||
| .let _ k | .fun _ k => go k
|
||||
-- Approximation, the cases may have many unreachable alternatives, and only reachable.
|
||||
|
|
@ -41,7 +41,7 @@ where
|
|||
Create a new local function declaration when `info.args.size < info.params.size`.
|
||||
We use this function to inline/specialize a partial application of a local function.
|
||||
-/
|
||||
def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
|
||||
def specializePartialApp (info : InlineCandidateInfo) : SimpM (FunDecl .pure) := do
|
||||
let mut subst := {}
|
||||
for param in info.params, arg in info.args do
|
||||
subst := subst.insert param.fvarId arg
|
||||
|
|
@ -58,7 +58,7 @@ def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
|
|||
/--
|
||||
Try to inline a join point.
|
||||
-/
|
||||
partial def inlineJp? (fvarId : FVarId) (args : Array Arg) : SimpM (Option Code) := do
|
||||
partial def inlineJp? (fvarId : FVarId) (args : Array (Arg .pure)) : SimpM (Option (Code .pure)) := do
|
||||
/- Remark: we don't need to use `findFunDecl'?` here. -/
|
||||
let some decl ← findFunDecl? fvarId | return none
|
||||
unless (← shouldInlineLocal decl) do return none
|
||||
|
|
@ -71,13 +71,13 @@ partial applications of functions that take local instances as arguments.
|
|||
This kind of function is inlined or specialized, and we create new
|
||||
simplification opportunities by eta-expanding them.
|
||||
-/
|
||||
def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do
|
||||
def etaPolyApp? (letDecl : LetDecl .pure) : OptionT SimpM (FunDecl .pure) := do
|
||||
guard <| (← read).config.etaPoly
|
||||
let .const declName us args := letDecl.value | failure
|
||||
let some info := (← getEnv).find? declName | failure
|
||||
guard <| (← hasLocalInst info.type)
|
||||
guard <| !(← Meta.isInstance declName)
|
||||
let some decl ← getDecl? declName | failure
|
||||
let some ⟨.pure, decl⟩ ← getDecl? declName | failure
|
||||
guard <| decl.getArity > args.size
|
||||
let params ← mkNewParams letDecl.type
|
||||
let auxDecl ← mkAuxLetDecl (.const declName us (args ++ params.map (.fvar ·.fvarId)))
|
||||
|
|
@ -89,14 +89,14 @@ def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do
|
|||
/--
|
||||
Similar to `Code.isReturnOf`, but taking the current substitution into account.
|
||||
-/
|
||||
def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do
|
||||
def isReturnOf (c : Code .pure) (fvarId : FVarId) : SimpM Bool := do
|
||||
match c with
|
||||
| .return fvarId' => match (← normFVar fvarId') with
|
||||
| .fvar fvarId'' => return fvarId'' == fvarId
|
||||
| .erased => return false
|
||||
| _ => return false
|
||||
|
||||
def elimVar? (value : LetValue) : SimpM (Option FVarId) := do
|
||||
def elimVar? (value : LetValue .pure) : SimpM (Option FVarId) := do
|
||||
let .fvar fvarId #[] := value | return none
|
||||
return fvarId
|
||||
|
||||
|
|
@ -117,7 +117,7 @@ of exit points by simplified the inlined code, and then connecting the result to
|
|||
continuation `k`. However, this optimization is only possible if we simplify the
|
||||
inlined code **before** we attach it to the continuation.
|
||||
-/
|
||||
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
|
||||
partial def inlineApp? (letDecl : LetDecl .pure) (k : Code .pure) : SimpM (Option (Code .pure)) := do
|
||||
let some info ← inlineCandidate? letDecl.value | return none
|
||||
let numArgs := info.args.size
|
||||
withInlining letDecl.value info.recursive do
|
||||
|
|
@ -135,7 +135,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
|
|||
simp code
|
||||
else
|
||||
let code ← simp code
|
||||
let simpK (result : FVarId) : SimpM Code := do
|
||||
let simpK (result : FVarId) : SimpM (Code .pure) := do
|
||||
/- `result` contains the result of the inlined code -/
|
||||
if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (.fvar result info.args[info.arity...*])
|
||||
|
|
@ -151,7 +151,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
|
|||
markUsedFVar fvarId'
|
||||
simpK fvarId'
|
||||
else
|
||||
let expectedType ← inferAppType info.fType info.args[*...info.arity]
|
||||
let expectedType ← inferAppType info.fType (info.args[*...info.arity]).toArray
|
||||
if expectedType.headBeta.isForall then
|
||||
/-
|
||||
If `code` returns a function, we create an auxiliary local function declaration (and eta-expand it)
|
||||
|
|
@ -171,7 +171,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
|
|||
/--
|
||||
Simplify the given local function declaration.
|
||||
-/
|
||||
partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
|
||||
partial def simpFunDecl (decl : FunDecl .pure) : SimpM (FunDecl .pure) := do
|
||||
let type ← normExpr decl.type
|
||||
let params ← normParams decl.params
|
||||
let value ← simp decl.value
|
||||
|
|
@ -180,9 +180,11 @@ partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
|
|||
/--
|
||||
Try to simplify `cases` of `constructor`
|
||||
-/
|
||||
partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
||||
partial def simpCasesOnCtor? (cases : Cases .pure) : SimpM (Option (Code .pure)) := do
|
||||
match (← normFVar cases.discr) with
|
||||
| .erased => mkReturnErased
|
||||
| .erased =>
|
||||
let ret ← mkReturnErased
|
||||
return some ret
|
||||
| .fvar discr =>
|
||||
let some ctorInfo ← findCtor? discr | return none
|
||||
let (alt, cases) := cases.extractAlt! ctorInfo.getName
|
||||
|
|
@ -210,7 +212,7 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
|||
/--
|
||||
Simplify `code`
|
||||
-/
|
||||
partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
||||
partial def simp (code : Code .pure) : SimpM (Code .pure) := withIncRecDepth do
|
||||
incVisited
|
||||
match code with
|
||||
| .let decl k =>
|
||||
|
|
@ -225,7 +227,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
|||
-- and `FVarId` rather than `Arg`, and the substitution will end up
|
||||
-- creating a new erased let decl in that case.
|
||||
if decl.type.isErased && decl.value != .erased then
|
||||
addSubst decl.fvarId .erased
|
||||
addSubst decl.fvarId (.erased : Arg .pure)
|
||||
eraseLetDecl decl
|
||||
simp k
|
||||
else if let some decls ← ConstantFold.foldConstants decl then
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ structure State where
|
|||
/--
|
||||
Free variable substitution. We use it to implement inlining and removing redundant variables `let _x.i := _x.j`
|
||||
-/
|
||||
subst : FVarSubst := {}
|
||||
subst : FVarSubst .pure := {}
|
||||
/--
|
||||
Track used local declarations to be able to eliminate dead variables.
|
||||
-/
|
||||
|
|
@ -80,10 +80,10 @@ abbrev SimpM := ReaderT Context $ StateRefT State DiscrM
|
|||
@[always_inline]
|
||||
instance : Monad SimpM := let i := inferInstanceAs (Monad SimpM); { pure := i.pure, bind := i.bind }
|
||||
|
||||
instance : MonadFVarSubst SimpM false where
|
||||
instance : MonadFVarSubst SimpM .pure false where
|
||||
getSubst := return (← get).subst
|
||||
|
||||
instance : MonadFVarSubstState SimpM where
|
||||
instance : MonadFVarSubstState SimpM .pure where
|
||||
modifySubst f := modify fun s => { s with subst := f s.subst }
|
||||
|
||||
/-- Set the `simplified` flag to `true`. -/
|
||||
|
|
@ -115,7 +115,7 @@ def addFunHoOcc (fvarId : FVarId) : SimpM Unit :=
|
|||
modify fun s => { s with funDeclInfoMap := s.funDeclInfoMap.addHo fvarId }
|
||||
|
||||
@[inherit_doc FunDeclInfoMap.update]
|
||||
partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit := do
|
||||
partial def updateFunDeclInfo (code : Code .pure) (mustInline := false) : SimpM Unit := do
|
||||
let map ← modifyGet fun s => (s.funDeclInfoMap, { s with funDeclInfoMap := {} })
|
||||
let map ← map.update code mustInline
|
||||
modify fun s => { s with funDeclInfoMap := map }
|
||||
|
|
@ -124,7 +124,7 @@ partial def updateFunDeclInfo (code : Code) (mustInline := false) : SimpM Unit :
|
|||
Execute `x` with an updated `inlineStack`. If `value` is of the form `const ...`, add `const` to the stack.
|
||||
Otherwise, do not change the `inlineStack`.
|
||||
-/
|
||||
@[inline] def withInlining (value : LetValue) (recursive : Bool) (x : SimpM α) : SimpM α := do
|
||||
@[inline] def withInlining (value : LetValue .pure) (recursive : Bool) (x : SimpM α) : SimpM α := do
|
||||
if let .const declName _ _ := value then
|
||||
let numOccs ← check declName
|
||||
withReader (fun ctx => { ctx with inlineStack := declName :: ctx.inlineStack, inlineStackOccs := ctx.inlineStackOccs.insert declName numOccs }) x
|
||||
|
|
@ -135,7 +135,7 @@ where
|
|||
trace[Compiler.simp.inline] "{.ofConstName declName}"
|
||||
let numOccs := (← read).inlineStackOccs.find? declName |>.getD 0
|
||||
let numOccs := numOccs + 1
|
||||
let inlineIfReduce ← if let some decl ← getDecl? declName then pure decl.inlineIfReduceAttr else pure false
|
||||
let inlineIfReduce ← if let some ⟨_, decl⟩ ← getDecl? declName then pure decl.inlineIfReduceAttr else pure false
|
||||
if recursive && inlineIfReduce && numOccs > (← getConfig).maxRecInlineIfReduce then
|
||||
throwError "function `{.ofConstName declName}` has been recursively inlined more than #{(← getConfig).maxRecInlineIfReduce}, consider removing the attribute `[inline_if_reduce]` from this declaration or increasing the limit using `set_option compiler.maxRecInlineIfReduce <num>`"
|
||||
return numOccs
|
||||
|
|
@ -193,13 +193,13 @@ def isOnceOrMustInline (fvarId : FVarId) : SimpM Bool := do
|
|||
/--
|
||||
Return `true` if the given code is considered "small".
|
||||
-/
|
||||
def isSmall (code : Code) : SimpM Bool :=
|
||||
def isSmall (code : Code .pure) : SimpM Bool :=
|
||||
return code.sizeLe (← getConfig).smallThreshold
|
||||
|
||||
/--
|
||||
Return `true` if the given local function declaration should be inlined.
|
||||
-/
|
||||
def shouldInlineLocal (decl : FunDecl) : SimpM Bool := do
|
||||
def shouldInlineLocal (decl : FunDecl .pure) : SimpM Bool := do
|
||||
if (← isOnceOrMustInline decl.fvarId) then
|
||||
return true
|
||||
else
|
||||
|
|
@ -210,7 +210,8 @@ LCNF "Beta-reduce". The equivalent of `(fun params => code) args`.
|
|||
If `mustInline` is true, the local function declarations in the resulting code are marked as `.mustInline`.
|
||||
See comment at `updateFunDeclInfo`.
|
||||
-/
|
||||
def betaReduce (params : Array Param) (code : Code) (args : Array Arg) (mustInline := false) : SimpM Code := do
|
||||
def betaReduce (params : Array (Param .pure)) (code : Code .pure) (args : Array (Arg .pure))
|
||||
(mustInline := false) : SimpM (Code .pure) := do
|
||||
let mut subst := {}
|
||||
for param in params, arg in args do
|
||||
subst := subst.insert param.fvarId arg
|
||||
|
|
@ -222,7 +223,7 @@ def betaReduce (params : Array Param) (code : Code) (args : Array Arg) (mustInli
|
|||
Erase the given let-declaration from the local context,
|
||||
and set the `simplified` flag to true.
|
||||
-/
|
||||
def eraseLetDecl (decl : LetDecl) : SimpM Unit := do
|
||||
def eraseLetDecl (decl : LetDecl .pure) : SimpM Unit := do
|
||||
LCNF.eraseLetDecl decl
|
||||
markSimplified
|
||||
|
||||
|
|
@ -230,7 +231,7 @@ def eraseLetDecl (decl : LetDecl) : SimpM Unit := do
|
|||
Erase the given local function declaration from the local context,
|
||||
and set the `simplified` flag to true.
|
||||
-/
|
||||
def eraseFunDecl (decl : FunDecl) : SimpM Unit := do
|
||||
def eraseFunDecl (decl : FunDecl .pure) : SimpM Unit := do
|
||||
LCNF.eraseFunDecl decl
|
||||
markSimplified
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ namespace Simp
|
|||
/--
|
||||
Try to simplify projections `.proj _ i s` where `s` is constructor.
|
||||
-/
|
||||
def simpProj? (e : LetValue) : OptionT SimpM LetValue := do
|
||||
def simpProj? (e : LetValue .pure) : OptionT SimpM (LetValue .pure) := do
|
||||
let .proj _ i s := e | failure
|
||||
let some ctorInfo ← findCtor? s | failure
|
||||
match ctorInfo with
|
||||
|
|
@ -31,7 +31,7 @@ g b
|
|||
```
|
||||
is simplified to `f a b`.
|
||||
-/
|
||||
def simpAppApp? (e : LetValue) : OptionT SimpM LetValue := do
|
||||
def simpAppApp? (e : LetValue .pure) : OptionT SimpM (LetValue .pure) := do
|
||||
let .fvar g args := e | failure
|
||||
let some decl ← findLetDecl? g | failure
|
||||
match decl.value with
|
||||
|
|
@ -46,19 +46,19 @@ def simpAppApp? (e : LetValue) : OptionT SimpM LetValue := do
|
|||
| .erased => return .erased
|
||||
| .proj .. | .lit .. => failure
|
||||
|
||||
def simpCtorDiscr? (e : LetValue) : OptionT SimpM LetValue := do
|
||||
def simpCtorDiscr? (e : LetValue .pure) : OptionT SimpM (LetValue .pure) := do
|
||||
let .const declName _ _ := e | failure
|
||||
let some (.ctorInfo _) := (← getEnv).find? declName | failure
|
||||
let some fvarId ← simpCtorDiscrCore? e.toExpr | failure
|
||||
return .fvar fvarId #[]
|
||||
|
||||
def applyImplementedBy? (e : LetValue) : OptionT SimpM LetValue := do
|
||||
def applyImplementedBy? (e : LetValue .pure) : OptionT SimpM (LetValue .pure) := do
|
||||
guard <| (← read).config.implementedBy
|
||||
let .const declName us args := e | failure
|
||||
let some declNameNew := getImplementedBy? (← getEnv) declName | failure
|
||||
return .const declNameNew us args
|
||||
|
||||
/-- Try to apply simple simplifications. -/
|
||||
def simpValue? (e : LetValue) : SimpM (Option LetValue) :=
|
||||
def simpValue? (e : LetValue .pure) : SimpM (Option (LetValue .pure)) :=
|
||||
-- TODO: more simplifications
|
||||
simpProj? e <|> simpAppApp? e <|> simpCtorDiscr? e <|> applyImplementedBy? e
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def markUsedFVar (fvarId : FVarId) : SimpM Unit :=
|
|||
/--
|
||||
Mark all free variables occurring in `arg` as used.
|
||||
-/
|
||||
def markUsedArg (arg : Arg) : SimpM Unit :=
|
||||
def markUsedArg (arg : Arg .pure) : SimpM Unit :=
|
||||
match arg with
|
||||
| .fvar fvarId => markUsedFVar fvarId
|
||||
-- Locally declared variables do not occur in types.
|
||||
|
|
@ -32,7 +32,7 @@ def markUsedArg (arg : Arg) : SimpM Unit :=
|
|||
/--
|
||||
Mark all free variables occurring in `e` as used.
|
||||
-/
|
||||
def markUsedLetValue (e : LetValue) : SimpM Unit := do
|
||||
def markUsedLetValue (e : LetValue .pure) : SimpM Unit := do
|
||||
match e with
|
||||
| .lit .. | .erased => return ()
|
||||
| .proj _ _ fvarId => markUsedFVar fvarId
|
||||
|
|
@ -43,14 +43,14 @@ def markUsedLetValue (e : LetValue) : SimpM Unit := do
|
|||
Mark all free variables occurring on the right-hand side of the given let declaration as used.
|
||||
This is information is used to eliminate dead local declarations.
|
||||
-/
|
||||
def markUsedLetDecl (letDecl : LetDecl) : SimpM Unit :=
|
||||
def markUsedLetDecl (letDecl : LetDecl .pure) : SimpM Unit :=
|
||||
markUsedLetValue letDecl.value
|
||||
|
||||
mutual
|
||||
/--
|
||||
Mark all free variables occurring in `code` as used.
|
||||
-/
|
||||
partial def markUsedCode (code : Code) : SimpM Unit := do
|
||||
partial def markUsedCode (code : Code .pure) : SimpM Unit := do
|
||||
match code with
|
||||
| .let decl k => markUsedLetDecl decl; markUsedCode k
|
||||
| .jp decl k | .fun decl k => markUsedFunDecl decl; markUsedCode k
|
||||
|
|
@ -62,7 +62,7 @@ partial def markUsedCode (code : Code) : SimpM Unit := do
|
|||
/--
|
||||
Mark all free variables occurring in `funDecl` as used.
|
||||
-/
|
||||
partial def markUsedFunDecl (funDecl : FunDecl) : SimpM Unit :=
|
||||
partial def markUsedFunDecl (funDecl : FunDecl .pure) : SimpM Unit :=
|
||||
markUsedCode funDecl.value
|
||||
end
|
||||
|
||||
|
|
@ -81,10 +81,10 @@ let _x.2 := true
|
|||
<code>
|
||||
```
|
||||
-/
|
||||
def attachCodeDecls (decls : Array CodeDecl) (code : Code) : SimpM Code := do
|
||||
def attachCodeDecls (decls : Array (CodeDecl .pure)) (code : Code .pure) : SimpM (Code .pure) := do
|
||||
go decls.size code
|
||||
where
|
||||
go (i : Nat) (code : Code) : SimpM Code := do
|
||||
go (i : Nat) (code : Code .pure) : SimpM (Code .pure) := do
|
||||
if i > 0 then
|
||||
let decl := decls[i-1]!
|
||||
if (← isUsed decl.fvarId) then
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ and `k` is tagged as `.user`, `.fixedHO`, or `.fixedInst`.
|
|||
|
||||
See comment at `.fixedNeutral`.
|
||||
-/
|
||||
private def hasFwdDeps (decl : Decl) (paramsInfo : Array SpecParamInfo) (j : Nat) : Bool := Id.run do
|
||||
private def hasFwdDeps (decl : Decl .pure) (paramsInfo : Array SpecParamInfo) (j : Nat) : Bool := Id.run do
|
||||
let param := decl.params[j]!
|
||||
for h : k in (j+1)...decl.params.size do
|
||||
if paramsInfo[k]!.causesSpecialization then
|
||||
|
|
@ -175,7 +175,7 @@ computationally relevant declarations. Furthermore this function takes:
|
|||
- `alreadySpecialized` which is a mask that says whether a decl is already a specialized declaration
|
||||
itself.
|
||||
-/
|
||||
def computeSpecEntries (decls : Array Decl) (autoSpecialize : Name → Option (Array Nat) → Bool)
|
||||
def computeSpecEntries (decls : Array (Decl .pure)) (autoSpecialize : Name → Option (Array Nat) → Bool)
|
||||
(alreadySpecialized : Array Bool) : CompilerM (Array SpecEntry) := do
|
||||
let mut declsInfo := #[]
|
||||
for decl in decls do
|
||||
|
|
@ -245,7 +245,7 @@ def computeSpecEntries (decls : Array Decl) (autoSpecialize : Name → Option (A
|
|||
Compute and save specialization information for `decls`. Assumes that `decls` is an SCC of user
|
||||
defined declarations.
|
||||
-/
|
||||
def saveSpecEntries (decls : Array Decl) : CompilerM Unit := do
|
||||
def saveSpecEntries (decls : Array (Decl .pure)) : CompilerM Unit := do
|
||||
let entries ← computeSpecEntries
|
||||
decls
|
||||
(fun _ specArgs? => specArgs? == some #[])
|
||||
|
|
|
|||
|
|
@ -66,11 +66,11 @@ structure State where
|
|||
/--
|
||||
The set of `Decl` that we are done processing.
|
||||
-/
|
||||
processedDecls : Array Decl := #[]
|
||||
processedDecls : Array (Decl .pure) := #[]
|
||||
/--
|
||||
The set of `Decl` that we will attempt recursive specialization on in the next iteration.
|
||||
-/
|
||||
workingDecls : Array Decl := #[]
|
||||
workingDecls : Array (Decl .pure) := #[]
|
||||
/--
|
||||
Specialization information about specialized declarations generated in this SCC so far.
|
||||
-/
|
||||
|
|
@ -101,7 +101,7 @@ def isGround [TraverseFVar α] (e : α) : SpecializeM Bool := do
|
|||
let s := (← read).ground
|
||||
return allFVar (s.contains ·) e
|
||||
|
||||
@[inline] def withLetDecl (decl : LetDecl) (x : SpecializeM α) : SpecializeM α := do
|
||||
@[inline] def withLetDecl (decl : LetDecl .pure) (x : SpecializeM α) : SpecializeM α := do
|
||||
let grd ← isGround decl.value <||> (pure (← isArrowClass? decl.type).isSome)
|
||||
let isUnderApplied ←
|
||||
match decl.value with
|
||||
|
|
@ -109,10 +109,10 @@ def isGround [TraverseFVar α] (e : α) : SpecializeM Bool := do
|
|||
match ← getDecl? fnName with
|
||||
-- This ascription to `Bool` is required to avoid this being inferred as `Prop`,
|
||||
-- even with a type specified on the `let` binding.
|
||||
| some { params, .. } => pure ((args.size < params.size) : Bool)
|
||||
| some ⟨_, { params, .. }⟩ => pure ((args.size < params.size) : Bool)
|
||||
| none => pure false
|
||||
| .fvar fnFVarId args =>
|
||||
match ← findFunDecl? fnFVarId with
|
||||
match ← findFunDecl? (pu := .pure) fnFVarId with
|
||||
-- This ascription to `Bool` is required to avoid this being inferred as `Prop`,
|
||||
-- even with a type specified on the `let` binding.
|
||||
| some (.mk (params := params) ..) => pure ((args.size < params.size) : Bool)
|
||||
|
|
@ -125,7 +125,7 @@ def isGround [TraverseFVar α] (e : α) : SpecializeM Bool := do
|
|||
ground := if grd then ctx.ground.insert fvarId else ctx.ground
|
||||
}
|
||||
|
||||
@[inline] def withFunDecl (decl : FunDecl) (x : SpecializeM α) : SpecializeM α := do
|
||||
@[inline] def withFunDecl (decl : FunDecl .pure) (x : SpecializeM α) : SpecializeM α := do
|
||||
let ctx ← read
|
||||
let grd := allFVar (x := decl.value) fun fvarId =>
|
||||
!(ctx.scope.contains fvarId) || ctx.ground.contains fvarId
|
||||
|
|
@ -193,12 +193,13 @@ That is, `mask` contains only the arguments that are contributing to the code sp
|
|||
We use this information to compute a "key" to uniquely identify the code specialization, and
|
||||
creating the specialized code.
|
||||
-/
|
||||
def collect (paramsInfo : Array SpecParamInfo) (args : Array Arg) : SpecializeM (Array (Option Arg) × Array Param × Array CodeDecl) := do
|
||||
def collect (paramsInfo : Array SpecParamInfo) (args : Array (Arg .pure)) :
|
||||
SpecializeM (Array (Option (Arg .pure)) × Array (Param .pure) × Array (CodeDecl .pure)) := do
|
||||
let ctx ← read
|
||||
let lctx := (← getThe CompilerM.State).lctx
|
||||
let abstract (fvarId : FVarId) : Bool :=
|
||||
-- We convert let-declarations that are not ground into parameters
|
||||
!lctx.funDecls.contains fvarId &&
|
||||
!(lctx.funDecls .pure).contains fvarId &&
|
||||
!ctx.underApplied.contains fvarId &&
|
||||
!ctx.ground.contains fvarId
|
||||
Closure.run (inScope := ctx.scope.contains) (abstract := abstract) do
|
||||
|
|
@ -217,7 +218,7 @@ end Collector
|
|||
/--
|
||||
Return `true` if it is worth using arguments `args` for specialization given the parameter specialization information.
|
||||
-/
|
||||
def shouldSpecialize (specEntry : SpecEntry) (args : Array Arg) : SpecializeM Bool := do
|
||||
def shouldSpecialize (specEntry : SpecEntry) (args : Array (Arg .pure)) : SpecializeM Bool := do
|
||||
let hoCheck :=
|
||||
if specEntry.alreadySpecialized then
|
||||
fun arg => do
|
||||
|
|
@ -248,7 +249,7 @@ def shouldSpecialize (specEntry : SpecEntry) (args : Array Arg) : SpecializeM Bo
|
|||
-/
|
||||
match arg with
|
||||
| .erased | .type .. => return false
|
||||
| .fvar fvar => return (← findParam? fvar).isNone
|
||||
| .fvar fvar => return (← findParam? (pu := .pure) fvar).isNone
|
||||
else
|
||||
fun _ => pure true
|
||||
for paramInfo in specEntry.paramsInfo, arg in args do
|
||||
|
|
@ -264,7 +265,7 @@ def shouldSpecialize (specEntry : SpecEntry) (args : Array Arg) : SpecializeM Bo
|
|||
Convert the given declarations into `Expr`, and "zeta-reduce" them into body.
|
||||
This function is used to compute the key that uniquely identifies an code specialization.
|
||||
-/
|
||||
def expandCodeDecls (decls : Array CodeDecl) (body : LetValue) : CompilerM Expr := do
|
||||
def expandCodeDecls (decls : Array (CodeDecl .pure)) (body : LetValue .pure) : CompilerM Expr := do
|
||||
let xs := decls.map (mkFVar ·.fvarId)
|
||||
let values := decls.map fun
|
||||
| .let decl => decl.value.toExpr
|
||||
|
|
@ -285,7 +286,8 @@ Create the "key" that uniquely identifies a code specialization.
|
|||
The result contains the list of universe level parameter names the key that `params`, `decls`, and `body` depends on.
|
||||
We use this information to create the new auxiliary declaration and resulting application.
|
||||
-/
|
||||
def mkKey (params : Array Param) (decls : Array CodeDecl) (body : LetValue) : CompilerM (Expr × List Name) := do
|
||||
def mkKey (params : Array (Param .pure)) (decls : Array (CodeDecl .pure)) (body : LetValue .pure) :
|
||||
CompilerM (Expr × List Name) := do
|
||||
let body ← expandCodeDecls decls body
|
||||
let key := ToExpr.run do
|
||||
ToExpr.withParams params do
|
||||
|
|
@ -308,7 +310,9 @@ Specialize `decl` using
|
|||
- `decls`: local declarations that arguments in `argMask` depend on.
|
||||
- `levelParamsNew`: the universe level parameters for the new declaration.
|
||||
-/
|
||||
def mkSpecDecl (decl : Decl) (us : List Level) (argMask : Array (Option Arg)) (params : Array Param) (decls : Array CodeDecl) (levelParamsNew : List Name) : SpecializeM Decl := do
|
||||
def mkSpecDecl (decl : Decl .pure) (us : List Level) (argMask : Array (Option (Arg .pure)))
|
||||
(params : Array (Param .pure)) (decls : Array (CodeDecl .pure)) (levelParamsNew : List Name) :
|
||||
SpecializeM (Decl .pure) := do
|
||||
let nameNew := decl.name.appendCore `_at_
|
||||
|>.appendCore (← read).declName
|
||||
|>.appendCore `spec
|
||||
|
|
@ -325,7 +329,7 @@ def mkSpecDecl (decl : Decl) (us : List Level) (argMask : Array (Option Arg)) (p
|
|||
finally
|
||||
eraseDecl decl
|
||||
where
|
||||
go (decl : Decl) (nameNew : Name) : InternalizeM Decl := do
|
||||
go (decl : Decl .pure) (nameNew : Name) : InternalizeM .pure (Decl .pure) := do
|
||||
let .code code := decl.value | panic! "can only specialize decls with code"
|
||||
let mut params ← params.mapM internalizeParam
|
||||
let decls ← decls.mapM internalizeCodeDecl
|
||||
|
|
@ -348,21 +352,21 @@ where
|
|||
let value := .code code
|
||||
let safe := decl.safe
|
||||
let recursive := decl.recursive
|
||||
let decl := { name := nameNew, levelParams := levelParamsNew, params, type, value, safe, recursive, inlineAttr? := none : Decl }
|
||||
let decl := { name := nameNew, levelParams := levelParamsNew, params, type, value, safe, recursive, inlineAttr? := none : Decl .pure }
|
||||
return decl.setLevelParams
|
||||
|
||||
/--
|
||||
Given the specialization mask `paramsInfo` and the arguments `args`,
|
||||
return the arguments that have not been considered for specialization.
|
||||
-/
|
||||
def getRemainingArgs (paramsInfo : Array SpecParamInfo) (args : Array Arg) : Array Arg := Id.run do
|
||||
def getRemainingArgs (paramsInfo : Array SpecParamInfo) (args : Array (Arg .pure)) : Array (Arg .pure) := Id.run do
|
||||
let mut result := #[]
|
||||
for info in paramsInfo, arg in args do
|
||||
if info matches .other then
|
||||
result := result.push arg
|
||||
return result ++ args[paramsInfo.size...*]
|
||||
|
||||
def paramsToGroundVars (params : Array Param) : CompilerM FVarIdSet :=
|
||||
def paramsToGroundVars (params : Array (Param .pure)) : CompilerM FVarIdSet :=
|
||||
params.foldlM (init := {}) fun r p => do
|
||||
if isTypeFormerType p.type || (← isArrowClass? p.type).isSome then
|
||||
return r.insert p.fvarId
|
||||
|
|
@ -392,13 +396,13 @@ mutual
|
|||
Try to specialize the function application in the given let-declaration.
|
||||
`k` is the continuation for the let-declaration.
|
||||
-/
|
||||
partial def specializeApp? (e : LetValue) : SpecializeM (Option LetValue) := do
|
||||
partial def specializeApp? (e : LetValue .pure) : SpecializeM (Option (LetValue .pure)) := do
|
||||
let .const declName us args := e | return none
|
||||
if args.isEmpty then return none
|
||||
if (← Meta.isInstance declName) then return none
|
||||
let some specEntry ← getSpecEntry? declName | return none
|
||||
unless (← shouldSpecialize specEntry args) do return none
|
||||
let some decl ← getDecl? declName | return none
|
||||
let some ⟨.pure, decl⟩ ← getDecl? declName | return none
|
||||
let .code _ := decl.value | return none
|
||||
trace[Compiler.specialize.candidate] "{e.toExpr}, {specEntry}"
|
||||
let paramsInfo := specEntry.paramsInfo
|
||||
|
|
@ -419,7 +423,7 @@ mutual
|
|||
fun
|
||||
| .type .. | .erased => return false
|
||||
| .fvar fvar => do
|
||||
if let some param ← findParam? fvar then
|
||||
if let some param ← findParam? (pu := .pure) fvar then
|
||||
/-
|
||||
For now we only allow recursive specialization on non class parameters, reason:
|
||||
We can encounter situations where we repeatedly re-abstract over type classes
|
||||
|
|
@ -442,11 +446,11 @@ mutual
|
|||
}
|
||||
return some (.const specDecl.name usNew argsNew)
|
||||
|
||||
partial def visitFunDecl (funDecl : FunDecl) : SpecializeM FunDecl := do
|
||||
partial def visitFunDecl (funDecl : FunDecl .pure) : SpecializeM (FunDecl .pure) := do
|
||||
let value ← withParams funDecl.params <| visitCode funDecl.value
|
||||
funDecl.update' funDecl.type value
|
||||
|
||||
partial def visitCode (code : Code) : SpecializeM Code := do
|
||||
partial def visitCode (code : Code .pure) : SpecializeM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let mut decl := decl
|
||||
|
|
@ -476,7 +480,7 @@ end
|
|||
/--
|
||||
Run specialization on the body of `decl`.
|
||||
-/
|
||||
def specializeDecl (decl : Decl) : SpecializeM (Decl × Bool) := do
|
||||
def specializeDecl (decl : Decl .pure) : SpecializeM (Decl .pure × Bool) := do
|
||||
trace[Compiler.specialize.step] m!"Working {decl.name}"
|
||||
if (← decl.isTemplateLike) then
|
||||
return (decl, false)
|
||||
|
|
@ -544,7 +548,7 @@ partial def loop (round : Nat := 0) : SpecializeM Unit := do
|
|||
|
||||
loop (round + 1)
|
||||
|
||||
def main (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
def main (decls : Array (Decl .pure)) : CompilerM (Array (Decl .pure)) := do
|
||||
saveSpecEntries decls
|
||||
let (_, s) ← loop |>.run { declName := .anonymous } |>.run { workingDecls := decls }
|
||||
return s.processedDecls
|
||||
|
|
|
|||
|
|
@ -13,21 +13,21 @@ namespace Lean.Compiler.LCNF
|
|||
|
||||
namespace SplitScc
|
||||
|
||||
partial def findSccCalls (scc : Std.HashMap Name Decl) (decl : Decl) : BaseIO (Std.HashSet Name) := do
|
||||
partial def findSccCalls (scc : Std.HashMap Name (Decl pu)) (decl : Decl pu) : BaseIO (Std.HashSet Name) := do
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
let (_, calls) ← goCode code |>.run {}
|
||||
return calls
|
||||
| .extern .. => return {}
|
||||
where
|
||||
goCode (c : Code) : StateRefT (Std.HashSet Name) BaseIO Unit := do
|
||||
goCode (c : Code pu) : StateRefT (Std.HashSet Name) BaseIO Unit := do
|
||||
match c with
|
||||
| .let decl k =>
|
||||
if let .const name .. := decl.value then
|
||||
if scc.contains name then
|
||||
modify fun s => s.insert name
|
||||
goCode k
|
||||
| .fun decl k | .jp decl k =>
|
||||
| .fun decl k _ | .jp decl k =>
|
||||
goCode decl.value
|
||||
goCode k
|
||||
| .cases cases => cases.alts.forM (·.forCodeM goCode)
|
||||
|
|
@ -35,7 +35,7 @@ where
|
|||
|
||||
end SplitScc
|
||||
|
||||
public def splitScc (scc : Array Decl) : CompilerM (Array (Array Decl)) := do
|
||||
public def splitScc (scc : Array (Decl pu)) : CompilerM (Array (Array (Decl pu))) := do
|
||||
if scc.size == 1 then
|
||||
return #[scc]
|
||||
let declMap := Std.HashMap.ofArray <| scc.map fun decl => (decl.name, decl)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ def findStructCtorInfo? (typeName : Name) : CoreM (Option ConstructorVal) := do
|
|||
return ctorInfo
|
||||
|
||||
def mkFieldParamsForCtorType (ctorType : Expr) (numParams : Nat) (numFields : Nat) :
|
||||
CompilerM (Array Param) := do
|
||||
CompilerM (Array (Param .pure)) := do
|
||||
let mut type ← Meta.MetaM.run' <| toLCNFType ctorType
|
||||
type ← toMonoType type
|
||||
for _ in *...numParams do
|
||||
|
|
@ -52,7 +52,7 @@ def remapFVar (fvarId : FVarId) : M FVarId := do
|
|||
|
||||
mutual
|
||||
|
||||
partial def visitCode (code : Code) : M Code := do
|
||||
partial def visitCode (code : Code .pure) : M (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
match decl.value with
|
||||
|
|
@ -105,7 +105,7 @@ partial def visitCode (code : Code) : M Code := do
|
|||
| .return fvarId => return code.updateReturn! (← remapFVar fvarId)
|
||||
| .unreach .. => return code
|
||||
|
||||
partial def visitLetValue (v : LetValue) : M LetValue := do
|
||||
partial def visitLetValue (v : LetValue .pure) : M (LetValue .pure) := do
|
||||
match v with
|
||||
| .const _ _ args =>
|
||||
return v.updateArgs! (← args.mapM visitArg)
|
||||
|
|
@ -115,24 +115,24 @@ partial def visitLetValue (v : LetValue) : M LetValue := do
|
|||
-- Projections should be handled directly by `visitCode`.
|
||||
| .proj .. => unreachable!
|
||||
|
||||
partial def visitAlt (alt : Alt) : M Alt := do
|
||||
partial def visitAlt (alt : Alt .pure) : M (Alt .pure) := do
|
||||
return alt.updateCode (← visitCode alt.getCode)
|
||||
|
||||
partial def visitArg (arg : Arg) : M Arg :=
|
||||
partial def visitArg (arg : Arg .pure) : M (Arg .pure) :=
|
||||
match arg with
|
||||
| .fvar fvarId => return arg.updateFVar! (← remapFVar fvarId)
|
||||
| .type _ | .erased => return arg
|
||||
|
||||
end
|
||||
|
||||
def visitDecl (decl : Decl) : M Decl := do
|
||||
def visitDecl (decl : Decl .pure) : M (Decl .pure) := do
|
||||
let value ← decl.value.mapCodeM (visitCode ·)
|
||||
return { decl with value }
|
||||
|
||||
end StructProjCases
|
||||
|
||||
def structProjCases : Pass :=
|
||||
.mkPerDeclaration `structProjCases (StructProjCases.visitDecl · |>.run) .mono
|
||||
.mkPerDeclaration `structProjCases .mono (StructProjCases.visitDecl · |>.run)
|
||||
|
||||
builtin_initialize registerTraceClass `Compiler.structProjCases (inherited := true)
|
||||
|
||||
|
|
|
|||
|
|
@ -105,13 +105,13 @@ The steps for this are roughly:
|
|||
- expand declarations tagged with the `[macro_inline]` attribute
|
||||
- turn the resulting term into LCNF declaration
|
||||
-/
|
||||
def toDecl (declName : Name) : CompilerM Decl := do
|
||||
def toDecl (declName : Name) : CompilerM (Decl .pure) := do
|
||||
let declName := if let some name := isUnsafeRecName? declName then name else declName
|
||||
let some info ← getDeclInfo? declName | throwError "declaration `{.ofConstName declName}` not found"
|
||||
let safe ← declIsNotUnsafe declName
|
||||
let env ← getEnv
|
||||
let inlineAttr? := getInlineAttribute? env declName
|
||||
let paramsFromTypeBinders (expr : Expr) : CompilerM (Array Param) := do
|
||||
let paramsFromTypeBinders (expr : Expr) : CompilerM (Array (Param .pure)) := do
|
||||
let mut params := #[]
|
||||
let mut currentExpr := expr
|
||||
repeat
|
||||
|
|
@ -145,7 +145,7 @@ def toDecl (declName : Name) : CompilerM Decl := do
|
|||
let code ← toLCNF value
|
||||
let decl ← if let .fun decl (.return _) := code then
|
||||
eraseFunDecl decl (recursive := false)
|
||||
pure { name := declName, params := decl.params, type, value := .code decl.value, levelParams := info.levelParams, safe, inlineAttr? : Decl }
|
||||
pure { name := declName, params := decl.params, type, value := .code decl.value, levelParams := info.levelParams, safe, inlineAttr? : Decl .pure }
|
||||
else
|
||||
pure { name := declName, params := #[], type, value := .code code, levelParams := info.levelParams, safe, inlineAttr? }
|
||||
/- `toLCNF` may eta-reduce simple declarations. -/
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ where
|
|||
|
||||
abbrev ToExprM := ReaderT Nat $ StateM LevelMap
|
||||
|
||||
@[inline] def mkLambdaM (params : Array Param) (e : Expr) : ToExprM Expr :=
|
||||
@[inline] def mkLambdaM (params : Array (Param pu)) (e : Expr) : ToExprM Expr :=
|
||||
return go (← read) (← get) params.size e
|
||||
where
|
||||
go (offset : Nat) (m : LevelMap) (i : Nat) (e : Expr) : Expr :=
|
||||
|
|
@ -59,7 +59,7 @@ private abbrev _root_.Lean.FVarId.toExprM (fvarId : FVarId) : ToExprM Expr :=
|
|||
modify fun s => s.insert fvarId offset
|
||||
withReader (·+1) k
|
||||
|
||||
@[inline] partial def withParams (params : Array Param) (k : ToExprM α) : ToExprM α :=
|
||||
@[inline] partial def withParams (params : Array (Param pu)) (k : ToExprM α) : ToExprM α :=
|
||||
go 0
|
||||
where
|
||||
@[specialize] go (i : Nat) : ToExprM α := do
|
||||
|
|
@ -79,21 +79,21 @@ end ToExpr
|
|||
|
||||
open ToExpr
|
||||
|
||||
private def Arg.toExprM (arg : Arg) : ToExprM Expr :=
|
||||
private def Arg.toExprM (arg : Arg pu) : ToExprM Expr :=
|
||||
return arg.toExpr.abstract' (← read) (← get)
|
||||
|
||||
mutual
|
||||
partial def FunDecl.toExprM (decl : FunDecl) : ToExprM Expr :=
|
||||
partial def FunDecl.toExprM (decl : FunDecl pu) : ToExprM Expr :=
|
||||
withParams decl.params do mkLambdaM decl.params (← decl.value.toExprM)
|
||||
|
||||
partial def Code.toExprM (code : Code) : ToExprM Expr := do
|
||||
partial def Code.toExprM (code : Code pu) : ToExprM Expr := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let type ← abstractM decl.type
|
||||
let value ← abstractM decl.value.toExpr
|
||||
let body ← withFVar decl.fvarId k.toExprM
|
||||
return .letE decl.binderName type value body true
|
||||
| .fun decl k | .jp decl k =>
|
||||
| .fun decl k _ | .jp decl k =>
|
||||
let type ← abstractM decl.type
|
||||
let value ← decl.toExprM
|
||||
let body ← withFVar decl.fvarId k.toExprM
|
||||
|
|
@ -103,17 +103,17 @@ partial def Code.toExprM (code : Code) : ToExprM Expr := do
|
|||
| .unreach type => return mkApp (mkConst ``lcUnreachable) (← abstractM type)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapM fun
|
||||
| .alt ctorName params k => do
|
||||
| .alt ctorName params k _ => do
|
||||
let body ← withParams params do mkLambdaM params (← k.toExprM)
|
||||
return mkApp (mkConst ctorName) body
|
||||
| .default k => k.toExprM
|
||||
return mkAppN (mkConst `cases) (#[← c.discr.toExprM] ++ alts)
|
||||
end
|
||||
|
||||
public def Code.toExpr (code : Code) (xs : Array FVarId := #[]) : Expr :=
|
||||
public def Code.toExpr (code : Code pu) (xs : Array FVarId := #[]) : Expr :=
|
||||
run' code.toExprM xs
|
||||
|
||||
public def FunDecl.toExpr (decl : FunDecl) (xs : Array FVarId := #[]) : Expr :=
|
||||
public def FunDecl.toExpr (decl : FunDecl pu) (xs : Array FVarId := #[]) : Expr :=
|
||||
run' decl.toExprM xs
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -33,18 +33,18 @@ The `toLCNF` function maintains a sequence of elements that is eventually
|
|||
converted into `Code`.
|
||||
-/
|
||||
inductive Element where
|
||||
| jp (decl : FunDecl)
|
||||
| fun (decl : FunDecl)
|
||||
| let (decl : LetDecl)
|
||||
| cases (p : Param) (cases : Cases)
|
||||
| unreach (p : Param)
|
||||
| jp (decl : FunDecl .pure)
|
||||
| fun (decl : FunDecl .pure)
|
||||
| let (decl : LetDecl .pure)
|
||||
| cases (p : Param .pure) (cases : Cases .pure)
|
||||
| unreach (p : Param .pure)
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
State for `BindCasesM` monad
|
||||
Mapping from `_alt.<idx>` variables to new join points
|
||||
-/
|
||||
abbrev BindCasesM.State := FVarIdMap FunDecl
|
||||
abbrev BindCasesM.State := FVarIdMap (FunDecl .pure)
|
||||
|
||||
/-- Auxiliary monad for implementing `bindCases` -/
|
||||
abbrev BindCasesM := StateRefT BindCasesM.State CompilerM
|
||||
|
|
@ -60,25 +60,25 @@ and then jumps to `jpDecl`. The goal is to make sure the auxiliary join point is
|
|||
of `_alt.<idx>`, then `simp` will inline it.
|
||||
That is, our goal is to try to promote the pre join points `_alt.<idx>` into a proper join point.
|
||||
-/
|
||||
partial def bindCases (jpDecl : FunDecl) (cases : Cases) : CompilerM Code := do
|
||||
partial def bindCases (jpDecl : FunDecl .pure) (cases : Cases .pure) : CompilerM (Code .pure) := do
|
||||
let (alts, s) ← visitAlts cases.alts |>.run {}
|
||||
let resultType ← mkCasesResultType alts
|
||||
let result := .cases ⟨cases.typeName, resultType, cases.discr, alts⟩
|
||||
let result := s.foldl (init := result) fun result _ altJp => .jp altJp result
|
||||
return .jp jpDecl result
|
||||
where
|
||||
visitAlts (alts : Array Alt) : BindCasesM (Array Alt) :=
|
||||
visitAlts (alts : Array (Alt .pure)) : BindCasesM (Array (Alt .pure)) :=
|
||||
alts.mapM fun alt => return alt.updateCode (← go alt.getCode)
|
||||
|
||||
findFun? (f : FVarId) : CompilerM (Option FunDecl) := do
|
||||
if let some funDecl ← findFunDecl? f then
|
||||
findFun? (f : FVarId) : CompilerM (Option (FunDecl .pure)) := do
|
||||
if let some funDecl ← findFunDecl? (pu := .pure) f then
|
||||
return funDecl
|
||||
else if let some (.fvar f' #[]) ← findLetValue? f then
|
||||
else if let some (.fvar f' #[]) ← findLetValue? (pu := .pure) f then
|
||||
findFun? f'
|
||||
else
|
||||
return none
|
||||
|
||||
go (code : Code) : BindCasesM Code := do
|
||||
go (code : Code .pure) : BindCasesM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
if let .return fvarId := k then
|
||||
|
|
@ -112,7 +112,7 @@ where
|
|||
Then, we replace the current `let`-declaration with `jmp altJp args`
|
||||
-/
|
||||
let mut jpParams := #[]
|
||||
let mut subst := {}
|
||||
let mut subst : FVarSubst .pure := {}
|
||||
let mut jpArgs := #[]
|
||||
/- Remark: `funDecl.params.size` may be greater than `args.size`. -/
|
||||
for param in funDecl.params[*...args.size] do
|
||||
|
|
@ -151,10 +151,10 @@ where
|
|||
| .return fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
| .jmp .. | .unreach .. => return code
|
||||
|
||||
def seqToCode (seq : Array Element) (k : Code) : CompilerM Code := do
|
||||
def seqToCode (seq : Array Element) (k : Code .pure) : CompilerM (Code .pure) := do
|
||||
go seq seq.size k
|
||||
where
|
||||
go (seq : Array Element) (i : Nat) (c : Code) : CompilerM Code := do
|
||||
go (seq : Array Element) (i : Nat) (c : Code .pure) : CompilerM (Code .pure) := do
|
||||
if i > 0 then
|
||||
match seq[i-1]! with
|
||||
| .jp decl => go seq (i - 1) (.jp decl c)
|
||||
|
|
@ -198,7 +198,7 @@ structure State where
|
|||
/-- Local context containing the original Lean types (not LCNF ones). -/
|
||||
lctx : LocalContext := {}
|
||||
/-- Cache from Lean regular expression to LCNF argument. -/
|
||||
cache : PHashMap Expr Arg := {}
|
||||
cache : PHashMap Expr (Arg .pure) := {}
|
||||
/--
|
||||
Determines whether caching has been disabled due to finding a use of
|
||||
a constant marked with `never_extract`.
|
||||
|
|
@ -228,12 +228,12 @@ abbrev M := StateRefT State CompilerM
|
|||
def pushElement (elem : Element) : M Unit := do
|
||||
modify fun s => { s with seq := s.seq.push elem }
|
||||
|
||||
def mkUnreachable (type : Expr) : M Arg := do
|
||||
def mkUnreachable (type : Expr) : M (Arg .pure) := do
|
||||
let p ← mkAuxParam type
|
||||
pushElement (.unreach p)
|
||||
return .fvar p.fvarId
|
||||
|
||||
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : M FVarId := do
|
||||
def mkAuxLetDecl (e : LetValue .pure) (prefixName := `_x) : M FVarId := do
|
||||
match e with
|
||||
| .fvar fvarId #[] => return fvarId
|
||||
| _ =>
|
||||
|
|
@ -241,11 +241,11 @@ def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : M FVarId := do
|
|||
pushElement (.let letDecl)
|
||||
return letDecl.fvarId
|
||||
|
||||
def letValueToArg (e : LetValue) (prefixName := `_x) : M Arg :=
|
||||
def letValueToArg (e : LetValue .pure) (prefixName := `_x) : M (Arg .pure) :=
|
||||
return .fvar (← mkAuxLetDecl e prefixName)
|
||||
|
||||
/-- Create `Code` that executes the current `seq` and then returns `result` -/
|
||||
def toCode (result : Arg) : M Code := do
|
||||
def toCode (result : Arg .pure) : M (Code .pure) := do
|
||||
match result with
|
||||
| .fvar fvarId => seqToCode (← get).seq (.return fvarId)
|
||||
| .erased | .type .. =>
|
||||
|
|
@ -327,7 +327,7 @@ def cleanupBinderName (binderName : Name) : CompilerM Name :=
|
|||
return binderName
|
||||
|
||||
/-- Create a new local declaration using a Lean regular type. -/
|
||||
def mkParam (binderName : Name) (type : Expr) : M Param := do
|
||||
def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do
|
||||
let binderName ← cleanupBinderName binderName
|
||||
let borrow := isMarkedBorrowed type
|
||||
let type' ← toLCNFType type
|
||||
|
|
@ -335,7 +335,8 @@ def mkParam (binderName : Name) (type : Expr) : M Param := do
|
|||
modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default }
|
||||
return param
|
||||
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg) : M LetDecl := do
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg .pure) :
|
||||
M (LetDecl .pure) := do
|
||||
let binderName ← cleanupBinderName binderName
|
||||
let value' ← match arg with
|
||||
| .fvar fvarId => pure <| .fvar fvarId #[]
|
||||
|
|
@ -347,10 +348,10 @@ def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (a
|
|||
}
|
||||
return letDecl
|
||||
|
||||
def visitLambda (e : Expr) : M (Array Param × Expr) :=
|
||||
def visitLambda (e : Expr) : M (Array (Param .pure) × Expr) :=
|
||||
go e #[] #[]
|
||||
where
|
||||
go (e : Expr) (xs : Array Expr) (ps : Array Param) := do
|
||||
go (e : Expr) (xs : Array Expr) (ps : Array (Param .pure)) := do
|
||||
if let .lam binderName type body _ := e then
|
||||
let type := type.instantiateRev xs
|
||||
let p ← mkParam binderName type
|
||||
|
|
@ -358,10 +359,10 @@ where
|
|||
else
|
||||
return (ps, e.instantiateRev xs)
|
||||
|
||||
def visitBoundedLambda (e : Expr) (n : Nat) : M (Array Param × Expr) :=
|
||||
def visitBoundedLambda (e : Expr) (n : Nat) : M (Array (Param .pure) × Expr) :=
|
||||
go e n #[] #[]
|
||||
where
|
||||
go (e : Expr) (n : Nat) (xs : Array Expr) (ps : Array Param) := do
|
||||
go (e : Expr) (n : Nat) (xs : Array Expr) (ps : Array (Param .pure)) := do
|
||||
if n == 0 then
|
||||
return (ps, e.instantiateRev xs)
|
||||
else if let .lam binderName type body _ := e then
|
||||
|
|
@ -422,10 +423,10 @@ Put the given expression in `LCNF`.
|
|||
- Eta-expand applications of declarations that satisfy `shouldEtaExpand`.
|
||||
- Put computationally relevant expressions in A-normal form.
|
||||
-/
|
||||
partial def toLCNF (e : Expr) : CompilerM Code := do
|
||||
partial def toLCNF (e : Expr) : CompilerM (Code .pure) := do
|
||||
run do toCode (← visit e)
|
||||
where
|
||||
visitCore (e : Expr) : M Arg := withIncRecDepth do
|
||||
visitCore (e : Expr) : M (Arg .pure) := withIncRecDepth do
|
||||
if let some arg := (← get).cache.find? e then
|
||||
return arg
|
||||
let r : Arg ← match e with
|
||||
|
|
@ -441,7 +442,7 @@ where
|
|||
modify fun s => if s.shouldCache then { s with cache := s.cache.insert e r } else s
|
||||
return r
|
||||
|
||||
visit (e : Expr) : M Arg := withIncRecDepth do
|
||||
visit (e : Expr) : M (Arg .pure) := withIncRecDepth do
|
||||
if isLCProof e then
|
||||
return .erased
|
||||
let type ← liftMetaM <| Meta.inferType e
|
||||
|
|
@ -457,10 +458,10 @@ where
|
|||
return .erased
|
||||
visitCore e
|
||||
|
||||
visitLit (lit : Literal) : M Arg :=
|
||||
visitLit (lit : Literal) : M (Arg .pure) :=
|
||||
letValueToArg (.lit (litToValue lit))
|
||||
|
||||
visitAppArg (e : Expr) : M Arg := do
|
||||
visitAppArg (e : Expr) : M (Arg .pure) := do
|
||||
if isLCProof e then
|
||||
return .erased
|
||||
let type ← liftMetaM <| Meta.inferType e
|
||||
|
|
@ -478,7 +479,7 @@ where
|
|||
visitCore e
|
||||
|
||||
/-- Giving `f` a constant `.const declName us`, convert `args` into `args'`, and return `.const declName us args'` -/
|
||||
visitAppDefaultConst (f : Expr) (args : Array Expr) : M Arg := do
|
||||
visitAppDefaultConst (f : Expr) (args : Array Expr) : M (Arg .pure) := do
|
||||
let env ← getEnv
|
||||
let .const declName us := CSimp.replaceConstants env f | unreachable!
|
||||
let args ← args.mapM visitAppArg
|
||||
|
|
@ -487,7 +488,7 @@ where
|
|||
letValueToArg <| .const declName us args
|
||||
|
||||
/-- Eta expand if under applied, otherwise apply k -/
|
||||
etaIfUnderApplied (e : Expr) (arity : Nat) (k : M Arg) : M Arg := do
|
||||
etaIfUnderApplied (e : Expr) (arity : Nat) (k : M (Arg .pure)) : M (Arg .pure) := do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if numArgs < arity then
|
||||
visit (← etaExpandN e (arity - numArgs))
|
||||
|
|
@ -502,7 +503,7 @@ where
|
|||
k args[arity...*]
|
||||
```
|
||||
-/
|
||||
mkOverApplication (app : Arg) (args : Array Expr) (arity : Nat) : M Arg := do
|
||||
mkOverApplication (app : (Arg .pure)) (args : Array Expr) (arity : Nat) : M (Arg .pure) := do
|
||||
if args.size == arity then
|
||||
return app
|
||||
else
|
||||
|
|
@ -517,7 +518,7 @@ where
|
|||
/--
|
||||
Visit a `matcher`/`casesOn` alternative.
|
||||
-/
|
||||
visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × Alt) := do
|
||||
visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × (Alt .pure)) := do
|
||||
withNewScope do
|
||||
match casesAltInfo with
|
||||
| .default numHyps =>
|
||||
|
|
@ -552,7 +553,7 @@ where
|
|||
let altType ← c.inferType
|
||||
return (altType, .alt ctorName ps c)
|
||||
|
||||
visitCases (casesInfo : CasesInfo) (e : Expr) : M Arg :=
|
||||
visitCases (casesInfo : CasesInfo) (e : Expr) : M (Arg .pure) :=
|
||||
etaIfUnderApplied e casesInfo.arity do
|
||||
let args := e.getAppArgs
|
||||
let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...casesInfo.arity]))
|
||||
|
|
@ -603,11 +604,11 @@ where
|
|||
let result := .fvar auxDecl.fvarId
|
||||
mkOverApplication result args casesInfo.arity
|
||||
|
||||
visitCtor (arity : Nat) (e : Expr) : M Arg :=
|
||||
visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) :=
|
||||
etaIfUnderApplied e arity do
|
||||
visitAppDefaultConst e.getAppFn e.getAppArgs
|
||||
|
||||
visitQuotLift (e : Expr) : M Arg := do
|
||||
visitQuotLift (e : Expr) : M (Arg .pure) := do
|
||||
let arity := 6
|
||||
etaIfUnderApplied e arity do
|
||||
let mut args := e.getAppArgs
|
||||
|
|
@ -622,7 +623,7 @@ where
|
|||
| .type _ => unreachable!
|
||||
| .fvar fvarId => mkOverApplication (← letValueToArg <| .fvar fvarId #[.fvar invq]) args arity
|
||||
|
||||
visitEqRec (e : Expr) : M Arg :=
|
||||
visitEqRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 6
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
|
|
@ -630,7 +631,7 @@ where
|
|||
let minor ← visit minor
|
||||
mkOverApplication minor args arity
|
||||
|
||||
visitHEqRec (e : Expr) : M Arg :=
|
||||
visitHEqRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 7
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
|
|
@ -638,19 +639,19 @@ where
|
|||
let minor ← visit minor
|
||||
mkOverApplication minor args arity
|
||||
|
||||
visitFalseRec (e : Expr) : M Arg :=
|
||||
visitFalseRec (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 2
|
||||
etaIfUnderApplied e arity do
|
||||
let type ← toLCNFType (← liftMetaM do Meta.inferType e)
|
||||
mkUnreachable type
|
||||
|
||||
visitLcUnreachable (e : Expr) : M Arg :=
|
||||
visitLcUnreachable (e : Expr) : M (Arg .pure) :=
|
||||
let arity := 1
|
||||
etaIfUnderApplied e arity do
|
||||
let type ← toLCNFType (← liftMetaM do Meta.inferType e)
|
||||
mkUnreachable type
|
||||
|
||||
visitAndIffRecCore (e : Expr) (minorPos : Nat) : M Arg :=
|
||||
visitAndIffRecCore (e : Expr) (minorPos : Nat) : M (Arg .pure) :=
|
||||
let arity := 5
|
||||
etaIfUnderApplied e arity do
|
||||
let args := e.getAppArgs
|
||||
|
|
@ -660,7 +661,7 @@ where
|
|||
let minor := minor.beta #[ha, hb]
|
||||
visit (mkAppN minor args[arity...*])
|
||||
|
||||
visitNoConfusion (e : Expr) : M Arg := do
|
||||
visitNoConfusion (e : Expr) : M (Arg .pure) := do
|
||||
let .const declName _ := e.getAppFn | unreachable!
|
||||
let info := getNoConfusionInfo (← getEnv) declName
|
||||
let typeName := declName.getPrefix
|
||||
|
|
@ -705,7 +706,7 @@ where
|
|||
else
|
||||
expandNoConfusionMajor (← etaExpandN major (n+1)) (n+1)
|
||||
|
||||
visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M Arg := do
|
||||
visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M (Arg .pure) := do
|
||||
let typeName := projInfo.ctorName.getPrefix
|
||||
if isRuntimeBuiltinType typeName then
|
||||
let numArgs := e.getAppNumArgs
|
||||
|
|
@ -720,7 +721,7 @@ where
|
|||
let f ← Core.instantiateValueLevelParams info us
|
||||
visit (f.beta e.getAppArgs)
|
||||
|
||||
visitApp (e : Expr) : M Arg := do
|
||||
visitApp (e : Expr) : M (Arg .pure) := do
|
||||
if let .const declName us := CSimp.replaceConstants (← getEnv) e.getAppFn then
|
||||
if declName == ``Quot.lift then
|
||||
visitQuotLift e
|
||||
|
|
@ -754,7 +755,7 @@ where
|
|||
let args ← args.mapM visitAppArg
|
||||
letValueToArg <| .fvar fvarId args
|
||||
|
||||
visitLambda (e : Expr) : M Arg := do
|
||||
visitLambda (e : Expr) : M (Arg .pure) := do
|
||||
let b := etaReduceImplicit e
|
||||
/-
|
||||
Note: we don't want to eta-reduce arbitrary lambda expressions since it can
|
||||
|
|
@ -790,10 +791,10 @@ where
|
|||
pushElement (.fun funDecl)
|
||||
return .fvar funDecl.fvarId
|
||||
|
||||
visitMData (_mdata : MData) (e : Expr) : M Arg := do
|
||||
visitMData (_mdata : MData) (e : Expr) : M (Arg .pure) := do
|
||||
visit e
|
||||
|
||||
visitProj (s : Name) (i : Nat) (e : Expr) : M Arg := do
|
||||
visitProj (s : Name) (i : Nat) (e : Expr) : M (Arg .pure) := do
|
||||
if isRuntimeBuiltinType s then
|
||||
let structInfo := getStructureInfo (← getEnv) s
|
||||
let projExpr ← liftMetaM <| Meta.mkProjection e structInfo.fieldNames[i]!
|
||||
|
|
@ -803,7 +804,7 @@ where
|
|||
| .erased | .type .. => return .erased
|
||||
| .fvar fvarId => letValueToArg <| .proj s i fvarId
|
||||
|
||||
visitLet (e : Expr) (xs : Array Expr) : M Arg := do
|
||||
visitLet (e : Expr) (xs : Array Expr) : M (Arg .pure) := do
|
||||
match e with
|
||||
| .letE binderName type value body _ =>
|
||||
let type := type.instantiateRev xs
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ structure ToMonoM.State where
|
|||
|
||||
abbrev ToMonoM := StateRefT ToMonoM.State CompilerM
|
||||
|
||||
def Param.toMono (param : Param) : ToMonoM Param := do
|
||||
def Param.toMono (param : Param .pure) : ToMonoM (Param .pure) := do
|
||||
if isTypeFormerType param.type then
|
||||
modify fun s => { s with typeParams := s.typeParams.insert param.fvarId }
|
||||
param.update (← toMonoType param.type)
|
||||
|
|
@ -37,7 +37,7 @@ def checkFVarUseDeferred (resultFVar fvarId : FVarId) : ToMonoM Unit := do
|
|||
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName }
|
||||
|
||||
@[inline]
|
||||
def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg) : ToMonoM Arg := do
|
||||
def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg .pure) : ToMonoM (Arg .pure) := do
|
||||
match arg with
|
||||
| .erased | .type .. => return .erased
|
||||
| .fvar fvarId =>
|
||||
|
|
@ -47,13 +47,13 @@ def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg) : ToMonoM Arg :=
|
|||
check fvarId
|
||||
return arg
|
||||
|
||||
def argToMono (arg : Arg) : ToMonoM Arg := argToMonoBase checkFVarUse arg
|
||||
def argToMono (arg : Arg .pure) : ToMonoM (Arg .pure) := argToMonoBase checkFVarUse arg
|
||||
|
||||
def argToMonoDeferredCheck (resultFVar : FVarId) (arg : Arg) : ToMonoM Arg :=
|
||||
def argToMonoDeferredCheck (resultFVar : FVarId) (arg : Arg .pure) : ToMonoM (Arg .pure) :=
|
||||
argToMonoBase (checkFVarUseDeferred resultFVar) arg
|
||||
|
||||
def argsToMonoWithFnType (resultFVar : FVarId) (args : Array Arg) (type : Expr)
|
||||
: ToMonoM (Array Arg) := do
|
||||
def argsToMonoWithFnType (resultFVar : FVarId) (args : Array (Arg .pure)) (type : Expr)
|
||||
: ToMonoM (Array (Arg .pure)) := do
|
||||
let mut remainingType : Option Expr := some type
|
||||
let mut result := Array.emptyWithCapacity args.size
|
||||
for arg in args do
|
||||
|
|
@ -69,8 +69,8 @@ def argsToMonoWithFnType (resultFVar : FVarId) (args : Array Arg) (type : Expr)
|
|||
result := result.push monoArg
|
||||
return result
|
||||
|
||||
def argsToMonoRedArg (resultFVar : FVarId) (args : Array Arg) (params : Array Param)
|
||||
(redArgs : Array Arg) : ToMonoM (Array Arg) := do
|
||||
def argsToMonoRedArg (resultFVar : FVarId) (args : Array (Arg .pure)) (params : Array (Param .pure))
|
||||
(redArgs : Array (Arg .pure)) : ToMonoM (Array (Arg .pure)) := do
|
||||
let mut result := #[]
|
||||
let mut argIdx := 0
|
||||
for redArg in redArgs do
|
||||
|
|
@ -87,14 +87,14 @@ def argsToMonoRedArg (resultFVar : FVarId) (args : Array Arg) (params : Array Pa
|
|||
result := result.push arg
|
||||
return result
|
||||
|
||||
def ctorAppToMono (resultFVar : FVarId) (ctorInfo : ConstructorVal) (args : Array Arg)
|
||||
: ToMonoM LetValue := do
|
||||
let argsNewParams : Array Arg := .replicate ctorInfo.numParams .erased
|
||||
def ctorAppToMono (resultFVar : FVarId) (ctorInfo : ConstructorVal) (args : Array (Arg .pure))
|
||||
: ToMonoM (LetValue .pure) := do
|
||||
let argsNewParams : Array (Arg .pure) := .replicate ctorInfo.numParams .erased
|
||||
let argsNewFields ← args[ctorInfo.numParams...*].toArray.mapM (argToMonoDeferredCheck resultFVar)
|
||||
let argsNew := argsNewParams ++ argsNewFields
|
||||
return .const ctorInfo.name [] argsNew
|
||||
|
||||
partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetValue := do
|
||||
partial def LetValue.toMono (e : LetValue .pure) (resultFVar : FVarId) : ToMonoM (LetValue .pure) := do
|
||||
match e with
|
||||
| .erased | .lit .. => return e
|
||||
| .const declName _ args =>
|
||||
|
|
@ -111,7 +111,7 @@ partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetVa
|
|||
else if declName == ``Quot.lcInv then
|
||||
match args[2]! with
|
||||
| .fvar fvarId =>
|
||||
let mut extraArgs : Array Arg := .emptyWithCapacity (args.size - 3)
|
||||
let mut extraArgs : Array (Arg .pure) := .emptyWithCapacity (args.size - 3)
|
||||
for i in 3...args.size do
|
||||
let arg ← argToMono args[i]!
|
||||
extraArgs := extraArgs.push arg
|
||||
|
|
@ -164,13 +164,13 @@ partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetVa
|
|||
else
|
||||
return e
|
||||
|
||||
def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do
|
||||
def LetDecl.toMono (decl : LetDecl .pure) : ToMonoM (LetDecl .pure) := do
|
||||
let type ← toMonoType decl.type
|
||||
let value ← decl.value.toMono decl.fvarId
|
||||
decl.update type value
|
||||
|
||||
def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFields : Nat)
|
||||
(oldFields : Array Param) : ToMonoM (Array Param) := do
|
||||
(oldFields : Array (Param .pure)) : ToMonoM (Array (Param .pure)) := do
|
||||
let mut type := ctorType
|
||||
for _ in *...numParams do
|
||||
match type with
|
||||
|
|
@ -189,14 +189,14 @@ def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFi
|
|||
|
||||
mutual
|
||||
|
||||
partial def FunDecl.toMono (decl : FunDecl) : ToMonoM FunDecl := do
|
||||
partial def FunDecl.toMono (decl : FunDecl .pure) : ToMonoM (FunDecl .pure) := do
|
||||
let type ← toMonoType decl.type
|
||||
let params ← decl.params.mapM (·.toMono)
|
||||
let value ← decl.value.toMono
|
||||
decl.update type params value
|
||||
|
||||
/-- Convert `cases` `Decidable` => `Bool` -/
|
||||
partial def decToMono (c : Cases) (_ : c.typeName == ``Decidable) : ToMonoM Code := do
|
||||
partial def decToMono (c : Cases .pure) (_ : c.typeName == ``Decidable) : ToMonoM (Code .pure) := do
|
||||
let resultType ← toMonoType c.resultType
|
||||
let alts ← c.alts.mapM fun alt => do
|
||||
match alt with
|
||||
|
|
@ -208,7 +208,7 @@ partial def decToMono (c : Cases) (_ : c.typeName == ``Decidable) : ToMonoM Code
|
|||
return .cases ⟨``Bool, resultType, c.discr, alts⟩
|
||||
|
||||
/-- Eliminate `cases` for `Nat`. -/
|
||||
partial def casesNatToMono (c: Cases) (_ : c.typeName == ``Nat) : ToMonoM Code := do
|
||||
partial def casesNatToMono (c: Cases .pure) (_ : c.typeName == ``Nat) : ToMonoM (Code .pure) := do
|
||||
let resultType ← toMonoType c.resultType
|
||||
let natType := mkConst ``Nat
|
||||
let zeroDecl ← mkLetDecl `zero natType (.lit (.nat 0))
|
||||
|
|
@ -229,7 +229,7 @@ partial def casesNatToMono (c: Cases) (_ : c.typeName == ``Nat) : ToMonoM Code :
|
|||
return .let zeroDecl (.let isZeroDecl (.cases ⟨``Bool, resultType, isZeroDecl.fvarId, alts⟩))
|
||||
|
||||
/-- Eliminate `cases` for `Int`. -/
|
||||
partial def casesIntToMono (c: Cases) (_ : c.typeName == ``Int) : ToMonoM Code := do
|
||||
partial def casesIntToMono (c: Cases .pure) (_ : c.typeName == ``Int) : ToMonoM (Code .pure) := do
|
||||
let resultType ← toMonoType c.resultType
|
||||
let natType := mkConst ``Nat
|
||||
let zeroNatDecl ← mkLetDecl `natZero natType (.lit (.nat 0))
|
||||
|
|
@ -254,7 +254,8 @@ partial def casesIntToMono (c: Cases) (_ : c.typeName == ``Int) : ToMonoM Code :
|
|||
return .let zeroNatDecl (.let zeroIntDecl (.let isNegDecl (.cases ⟨``Bool, resultType, isNegDecl.fvarId, alts⟩)))
|
||||
|
||||
/-- Eliminate `cases` for `UInt` types. -/
|
||||
partial def casesUIntToMono (c : Cases) (uintName : Name) (_ : c.typeName == uintName) : ToMonoM Code := do
|
||||
partial def casesUIntToMono (c : Cases .pure) (uintName : Name) (_ : c.typeName == uintName) :
|
||||
ToMonoM (Code .pure) := do
|
||||
assert! c.alts.size == 1
|
||||
let .alt _ ps k := c.alts[0]! | unreachable!
|
||||
eraseParams ps
|
||||
|
|
@ -265,7 +266,7 @@ partial def casesUIntToMono (c : Cases) (uintName : Name) (_ : c.typeName == uin
|
|||
return .let decl k
|
||||
|
||||
/-- Eliminate `cases` for `Array. -/
|
||||
partial def casesArrayToMono (c : Cases) (_ : c.typeName == ``Array) : ToMonoM Code := do
|
||||
partial def casesArrayToMono (c : Cases .pure) (_ : c.typeName == ``Array) : ToMonoM (Code .pure) := do
|
||||
assert! c.alts.size == 1
|
||||
let .alt _ ps k := c.alts[0]! | unreachable!
|
||||
eraseParams ps
|
||||
|
|
@ -276,7 +277,8 @@ partial def casesArrayToMono (c : Cases) (_ : c.typeName == ``Array) : ToMonoM C
|
|||
return .let decl k
|
||||
|
||||
/-- Eliminate `cases` for `ByteArray. -/
|
||||
partial def casesByteArrayToMono (c : Cases) (_ : c.typeName == ``ByteArray) : ToMonoM Code := do
|
||||
partial def casesByteArrayToMono (c : Cases .pure) (_ : c.typeName == ``ByteArray) :
|
||||
ToMonoM (Code .pure) := do
|
||||
assert! c.alts.size == 1
|
||||
let .alt _ ps k := c.alts[0]! | unreachable!
|
||||
eraseParams ps
|
||||
|
|
@ -287,7 +289,8 @@ partial def casesByteArrayToMono (c : Cases) (_ : c.typeName == ``ByteArray) : T
|
|||
return .let decl k
|
||||
|
||||
/-- Eliminate `cases` for `FloatArray. -/
|
||||
partial def casesFloatArrayToMono (c : Cases) (_ : c.typeName == ``FloatArray) : ToMonoM Code := do
|
||||
partial def casesFloatArrayToMono (c : Cases .pure) (_ : c.typeName == ``FloatArray) :
|
||||
ToMonoM (Code .pure) := do
|
||||
assert! c.alts.size == 1
|
||||
let .alt _ ps k := c.alts[0]! | unreachable!
|
||||
eraseParams ps
|
||||
|
|
@ -298,7 +301,7 @@ partial def casesFloatArrayToMono (c : Cases) (_ : c.typeName == ``FloatArray) :
|
|||
return .let decl k
|
||||
|
||||
/-- Eliminate `cases` for `String. -/
|
||||
partial def casesStringToMono (c : Cases) (_ : c.typeName == ``String) : ToMonoM Code := do
|
||||
partial def casesStringToMono (c : Cases .pure) (_ : c.typeName == ``String) : ToMonoM (Code .pure) := do
|
||||
assert! c.alts.size == 1
|
||||
let .alt _ ps k := c.alts[0]! | unreachable!
|
||||
eraseParams ps
|
||||
|
|
@ -309,7 +312,7 @@ partial def casesStringToMono (c : Cases) (_ : c.typeName == ``String) : ToMonoM
|
|||
return .let decl k
|
||||
|
||||
/-- Eliminate `cases` for `Thunk. -/
|
||||
partial def casesThunkToMono (c : Cases) (_ : c.typeName == ``Thunk) : ToMonoM Code := do
|
||||
partial def casesThunkToMono (c : Cases .pure) (_ : c.typeName == ``Thunk) : ToMonoM (Code .pure) := do
|
||||
assert! c.alts.size == 1
|
||||
let .alt _ ps k := c.alts[0]! | unreachable!
|
||||
eraseParams ps
|
||||
|
|
@ -329,7 +332,7 @@ partial def casesThunkToMono (c : Cases) (_ : c.typeName == ``Thunk) : ToMonoM C
|
|||
return .fun decl k
|
||||
|
||||
/-- Eliminate `cases` for `Task. -/
|
||||
partial def casesTaskToMono (c : Cases) (_ : c.typeName == ``Task) : ToMonoM Code := do
|
||||
partial def casesTaskToMono (c : Cases .pure) (_ : c.typeName == ``Task) : ToMonoM (Code .pure) := do
|
||||
assert! c.alts.size == 1
|
||||
let .alt _ ps k := c.alts[0]! | unreachable!
|
||||
eraseParams ps
|
||||
|
|
@ -340,7 +343,7 @@ partial def casesTaskToMono (c : Cases) (_ : c.typeName == ``Task) : ToMonoM Cod
|
|||
return .let decl k
|
||||
|
||||
/-- Eliminate `cases` for trivial structure. See `hasTrivialStructure?` -/
|
||||
partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases) : ToMonoM Code := do
|
||||
partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases .pure) : ToMonoM (Code .pure) := do
|
||||
assert! c.alts.size == 1
|
||||
let .alt ctorName ps k := c.alts[0]! | unreachable!
|
||||
assert! ctorName == info.ctorName
|
||||
|
|
@ -353,7 +356,7 @@ partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases) : ToMo
|
|||
let k ← k.toMono
|
||||
return .let decl k
|
||||
|
||||
partial def Code.toMono (code : Code) : ToMonoM Code := do
|
||||
partial def Code.toMono (code : Code .pure) : ToMonoM (Code .pure) := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
match decl.value with
|
||||
|
|
@ -428,10 +431,10 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do
|
|||
|
||||
end
|
||||
|
||||
def Decl.toMono (decl : Decl) : CompilerM Decl := do
|
||||
def Decl.toMono (decl : Decl .pure) : CompilerM (Decl .pure) := do
|
||||
go |>.run' {}
|
||||
where
|
||||
go : ToMonoM Decl := do
|
||||
go : ToMonoM (Decl .pure) := do
|
||||
let type ← toMonoType decl.type
|
||||
let params ← decl.params.mapM (·.toMono)
|
||||
let value ← decl.value.mapCodeM (·.toMono)
|
||||
|
|
|
|||
|
|
@ -14,23 +14,23 @@ public section
|
|||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
private partial def collectUsedDecls (code : Code) (s : NameSet := {}) : NameSet :=
|
||||
private partial def collectUsedDecls (code : Code pu) (s : NameSet := {}) : NameSet :=
|
||||
match code with
|
||||
| .let decl k => collectUsedDecls k <| collectLetValue decl.value s
|
||||
| .jp decl k | .fun decl k => collectUsedDecls decl.value <| collectUsedDecls k s
|
||||
| .jp decl k | .fun decl k _ => collectUsedDecls decl.value <| collectUsedDecls k s
|
||||
| .cases c =>
|
||||
c.alts.foldl (init := s) fun s alt =>
|
||||
match alt with
|
||||
| .default k => collectUsedDecls k s
|
||||
| .alt _ _ k => collectUsedDecls k s
|
||||
| .alt _ _ k _ => collectUsedDecls k s
|
||||
| _ => s
|
||||
where
|
||||
collectLetValue (e : LetValue) (s : NameSet) : NameSet :=
|
||||
collectLetValue (e : LetValue pu) (s : NameSet) : NameSet :=
|
||||
match e with
|
||||
| .const declName .. => s.insert declName
|
||||
| _ => s
|
||||
|
||||
private def shouldExportBody (decl : Decl) : CompilerM Bool := do
|
||||
private def shouldExportBody (decl : Decl pu) : CompilerM Bool := do
|
||||
-- Export body if template-like...
|
||||
decl.isTemplateLike <||>
|
||||
-- ...or it is below the (local) opportunistic inlining threshold and its `Expr` is exported
|
||||
|
|
@ -44,7 +44,7 @@ private def shouldExportBody (decl : Decl) : CompilerM Bool := do
|
|||
Marks the given declaration as to be exported and recursively infers the correct visibility of its
|
||||
body and referenced declarations based on that.
|
||||
-/
|
||||
partial def markDeclPublicRec (phase : Phase) (decl : Decl) : CompilerM Unit := do
|
||||
partial def markDeclPublicRec (phase : Phase) (decl : Decl pu) : CompilerM Unit := do
|
||||
modifyEnv (setDeclPublic · decl.name)
|
||||
if (← shouldExportBody decl) && !isDeclTransparent (← getEnv) phase decl.name then
|
||||
trace[Compiler.inferVisibility] m!"Marking {decl.name} as transparent because it is opaque and its body looks relevant"
|
||||
|
|
@ -57,7 +57,7 @@ partial def markDeclPublicRec (phase : Phase) (decl : Decl) : CompilerM Unit :=
|
|||
markDeclPublicRec phase refDecl
|
||||
|
||||
/-- Checks whether references in the given declaration adhere to phase distinction. -/
|
||||
partial def checkMeta (origDecl : Decl) : CompilerM Unit := do
|
||||
partial def checkMeta (origDecl : Decl pu) : CompilerM Unit := do
|
||||
if !(← getEnv).header.isModule || !compiler.checkMeta.get (← getOptions) then
|
||||
return
|
||||
let irPhases := getIRPhases (← getEnv) origDecl.name
|
||||
|
|
@ -68,7 +68,7 @@ partial def checkMeta (origDecl : Decl) : CompilerM Unit := do
|
|||
-- decls with relevant global attrs are public (`Lean.ensureAttrDeclIsMeta`).
|
||||
let isPublic := !isPrivateName origDecl.name
|
||||
go (irPhases == .comptime) isPublic origDecl |>.run' {}
|
||||
where go (isMeta isPublic : Bool) (decl : Decl) : StateT NameSet CompilerM Unit := do
|
||||
where go (isMeta isPublic : Bool) (decl : Decl pu) : StateT NameSet CompilerM Unit := do
|
||||
decl.value.forCodeM fun code =>
|
||||
for ref in collectUsedDecls code do
|
||||
if (← get).contains ref then
|
||||
|
|
@ -112,8 +112,8 @@ where go (isMeta isPublic : Bool) (decl : Decl) : StateT NameSet CompilerM Unit
|
|||
-- *their* references in this case. We also need to do this for non-auxiliary defs in case a
|
||||
-- public meta def tries to use a private meta import via a local private meta def :/ .
|
||||
if irPhases == .all || isPublic && isPrivateName ref then
|
||||
if let some refDecl ← getLocalDecl? ref then
|
||||
go isMeta isPublic refDecl
|
||||
if let some ⟨_, refDecl⟩ ← getLocalDecl? ref then
|
||||
go isMeta isPublic (refDecl.castPurity! pu)
|
||||
|
||||
/--
|
||||
Checks that imports necessary for inlining/specialization are public as otherwise we may run into
|
||||
|
|
@ -134,7 +134,7 @@ partial def checkTemplateVisibility : Pass where
|
|||
let isMeta := isMarkedMeta (← getEnv) decl.name
|
||||
go decl decl |>.run' {}
|
||||
return decls
|
||||
where go (origDecl decl : Decl) : StateT NameSet CompilerM Unit := do
|
||||
where go (origDecl decl : Decl .pure) : StateT NameSet CompilerM Unit := do
|
||||
decl.value.forCodeM fun code =>
|
||||
for ref in collectUsedDecls code do
|
||||
if (← get).contains ref then
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ def f (a : Nat) : Bool :=
|
|||
-- This is only required until the new code generator is enabled.
|
||||
run_meta Lean.Compiler.compile #[``f]
|
||||
|
||||
def countCalls : Probe Decl Nat :=
|
||||
Probe.getLetValues >=>
|
||||
def countCalls : Probe (Decl .pure) Nat :=
|
||||
Probe.getLetValues .pure >=>
|
||||
Probe.filter (fun e => return e matches .const `Decidable.decide ..) >=>
|
||||
Probe.count
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue