chore: new LCNF representation
This is the first of a series of commits to change the LCNF representation.
This commit is contained in:
parent
22cdac914d
commit
6d46829599
14 changed files with 361 additions and 167 deletions
|
|
@ -19,25 +19,47 @@ def eqvFVar (fvarId₁ fvarId₂ : FVarId) : EqvM Bool := do
|
|||
let fvarId₂ := (← read).find? fvarId₂ |>.getD fvarId₂
|
||||
return fvarId₁ == fvarId₂
|
||||
|
||||
def eqvExpr (e₁ e₂ : Expr) : EqvM Bool := do
|
||||
def eqvType (e₁ e₂ : Expr) : EqvM Bool := do
|
||||
match e₁, e₂ with
|
||||
| .app f₁ a₁, .app f₂ a₂ => eqvExpr a₁ a₂ <&&> eqvExpr f₁ f₂
|
||||
| .proj s₁ i₁ e₁, .proj s₂ i₂ e₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvExpr e₁ e₂
|
||||
| .mdata m₁ e₁, .mdata m₂ e₂ => pure (m₁ == m₂) <&&> eqvExpr e₁ e₂
|
||||
| .app f₁ a₁, .app f₂ a₂ => eqvType a₁ a₂ <&&> eqvType f₁ f₂
|
||||
| .fvar fvarId₁, .fvar fvarId₂ => eqvFVar fvarId₁ fvarId₂
|
||||
| .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => eqvExpr d₁ d₂ <&&> eqvExpr b₁ b₂
|
||||
| .letE .., _ | _, .letE .. => unreachable!
|
||||
| .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => eqvType d₁ d₂ <&&> eqvType b₁ b₂
|
||||
| _, _ => return e₁ == e₂
|
||||
|
||||
def eqvExprs (es₁ es₂ : Array Expr) : EqvM Bool := do
|
||||
def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do
|
||||
if es₁.size = es₂.size then
|
||||
for e₁ in es₁, e₂ in es₂ do
|
||||
unless (← eqvExpr e₁ e₂) do
|
||||
unless (← eqvType e₁ e₂) do
|
||||
return false
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
def eqvArg (a₁ a₂ : Arg) : EqvM Bool := do
|
||||
match a₁, a₂ with
|
||||
| .type e₁, .type e₂ => eqvType e₁ e₂
|
||||
| .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂
|
||||
| .erased, .erased => return true
|
||||
| _, _ => return false
|
||||
|
||||
def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
|
||||
if as₁.size = as₂.size then
|
||||
for a₁ in as₁, a₂ in as₂ do
|
||||
unless (← eqvArg a₁ a₂) do
|
||||
return false
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
def eqvLetExpr (e₁ e₂ : LetExpr) : EqvM Bool := do
|
||||
match e₁, e₂ with
|
||||
| .value v₁, .value v₂ => return v₁ == v₂
|
||||
| .erased, .erased => return true
|
||||
| .proj s₁ i₁ x₁, .proj s₂ i₂ x₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂
|
||||
| .const n₁ us₁ as₁, .const n₂ us₂ as₂ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂
|
||||
| .fvar f₁ as₁, .fvar f₂ as₂ => eqvFVar f₁ f₂ <&&> eqvArgs as₁ as₂
|
||||
| _, _ => return false
|
||||
|
||||
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
|
||||
withReader (·.insert fvarId₂ fvarId₁) x
|
||||
|
||||
|
|
@ -48,7 +70,7 @@ def eqvExprs (es₁ es₂ : Array Expr) : EqvM Bool := do
|
|||
let p₁ := params₁[i]
|
||||
have : i < params₂.size := by simp_all_arith
|
||||
let p₂ := params₂[i]
|
||||
unless (← eqvExpr p₁.type p₂.type) do return false
|
||||
unless (← eqvType p₁.type p₂.type) do return false
|
||||
withFVar p₁.fvarId p₂.fvarId do
|
||||
go (i+1)
|
||||
else
|
||||
|
|
@ -84,20 +106,20 @@ partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
|
|||
partial def eqv (code₁ code₂ : Code) : EqvM Bool := do
|
||||
match code₁, code₂ with
|
||||
| .let decl₁ k₁, .let decl₂ k₂ =>
|
||||
eqvExpr decl₁.type decl₂.type <&&>
|
||||
eqvExpr decl₁.value decl₂.value <&&>
|
||||
eqvType decl₁.type decl₂.type <&&>
|
||||
eqvLetExpr decl₁.value decl₂.value <&&>
|
||||
withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
|
||||
| .fun decl₁ k₁, .fun decl₂ k₂
|
||||
| .jp decl₁ k₁, .jp decl₂ k₂ =>
|
||||
eqvExpr decl₁.type decl₂.type <&&>
|
||||
eqvType decl₁.type decl₂.type <&&>
|
||||
withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value) <&&>
|
||||
withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
|
||||
| .return fvarId₁, .return fvarId₂ => eqvFVar fvarId₁ fvarId₂
|
||||
| .unreach type₁, .unreach type₂ => eqvExpr type₁ type₂
|
||||
| .jmp fvarId₁ args₁, .jmp fvarId₂ args₂ => eqvFVar fvarId₁ fvarId₂ <&&> eqvExprs args₁ args₂
|
||||
| .unreach type₁, .unreach type₂ => eqvType type₁ type₂
|
||||
| .jmp fvarId₁ args₁, .jmp fvarId₂ args₂ => eqvFVar fvarId₁ fvarId₂ <&&> eqvArgs args₁ args₂
|
||||
| .cases c₁, .cases c₂ =>
|
||||
eqvFVar c₁.discr c₂.discr <&&>
|
||||
eqvExpr c₁.resultType c₂.resultType <&&>
|
||||
eqvType c₁.resultType c₂.resultType <&&>
|
||||
eqvAlts c₁.alts c₂.alts
|
||||
| _, _ => return false
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import Lean.Expr
|
|||
import Lean.Meta.Instances
|
||||
import Lean.Compiler.InlineAttrs
|
||||
import Lean.Compiler.Specialize
|
||||
import Lean.Compiler.LCNF.Types
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
|
|
@ -34,11 +35,90 @@ inductive AltCore (Code : Type) where
|
|||
| default (code : Code)
|
||||
deriving Inhabited
|
||||
|
||||
inductive Value where
|
||||
| natVal (val : Nat)
|
||||
| strVal (val : String)
|
||||
-- TODO: add constructors for `Int`, `Float`, `UInt` ...
|
||||
deriving Inhabited, BEq, Hashable
|
||||
|
||||
inductive Arg where
|
||||
| erased
|
||||
| fvar (fvarId : FVarId)
|
||||
| type (expr : Expr)
|
||||
deriving Inhabited, BEq, Hashable
|
||||
|
||||
def Arg.toExpr (arg : Arg) : Expr :=
|
||||
match arg with
|
||||
| .erased => erasedExpr
|
||||
| .fvar fvarId => .fvar fvarId
|
||||
| .type e => e
|
||||
|
||||
private unsafe def Arg.updateTypeImp (arg : Arg) (type' : Expr) : Arg :=
|
||||
match arg with
|
||||
| .type ty => if ptrEq ty type' then arg else .type type'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg) (type : Expr) : Arg
|
||||
|
||||
private unsafe def Arg.updateFVarImp (arg : Arg) (fvarId' : FVarId) : Arg :=
|
||||
match arg with
|
||||
| .fvar fvarId => if fvarId' == fvarId then arg else .fvar fvarId'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg) (fvarId' : FVarId) : Arg
|
||||
|
||||
inductive LetExpr where
|
||||
| value (value : Value)
|
||||
| erased
|
||||
| proj (typeName : Name) (idx : Nat) (struct : FVarId)
|
||||
| const (declName : Name) (us : List Level) (args : Array Arg)
|
||||
| fvar (fvarId : FVarId) (args : Array Arg)
|
||||
-- TODO: add constructors for mono and impure phases
|
||||
deriving Inhabited, BEq, Hashable
|
||||
|
||||
private unsafe def LetExpr.updateProjImp (e : LetExpr) (fvarId' : FVarId) : LetExpr :=
|
||||
match e with
|
||||
| .proj s i fvarId => if fvarId == fvarId' then e else .proj s i fvarId'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetExpr.updateProjImp] opaque LetExpr.updateProj! (e : LetExpr) (fvarId' : FVarId) : LetExpr
|
||||
|
||||
private unsafe def LetExpr.updateConstImp (e : LetExpr) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetExpr :=
|
||||
match e with
|
||||
| .const declName us args => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetExpr.updateConstImp] opaque LetExpr.updateConst! (e : LetExpr) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetExpr
|
||||
|
||||
private unsafe def LetExpr.updateFVarImp (e : LetExpr) (fvarId' : FVarId) (args' : Array Arg) : LetExpr :=
|
||||
match e with
|
||||
| .fvar fvarId args => if fvarId == fvarId' && ptrEq args args' then e else .fvar fvarId' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetExpr.updateFVarImp] opaque LetExpr.updateFVar! (e : LetExpr) (fvarId' : FVarId) (args' : Array Arg) : LetExpr
|
||||
|
||||
private unsafe def LetExpr.updateArgsImp (e : LetExpr) (args' : Array Arg) : LetExpr :=
|
||||
match e with
|
||||
| .const declName us args => if ptrEq args args' then e else .const declName us args'
|
||||
| .fvar fvarId args => if ptrEq args args' then e else .fvar fvarId args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by LetExpr.updateArgsImp] opaque LetExpr.updateArgs! (e : LetExpr) (args' : Array Arg) : LetExpr
|
||||
|
||||
def LetExpr.toExpr (e : LetExpr) : Expr :=
|
||||
match e with
|
||||
| .value (.natVal val) => .lit (.natVal val)
|
||||
| .value (.strVal val) => .lit (.strVal val)
|
||||
| .erased => erasedExpr
|
||||
| .proj n i s => .proj n i (.fvar s)
|
||||
| .const n us as => mkAppN (.const n us) (as.map Arg.toExpr)
|
||||
| .fvar fvarId as => mkAppN (.fvar fvarId) (as.map Arg.toExpr)
|
||||
|
||||
structure LetDecl where
|
||||
fvarId : FVarId
|
||||
binderName : Name
|
||||
type : Expr
|
||||
value : Expr
|
||||
value : LetExpr
|
||||
deriving Inhabited, BEq
|
||||
|
||||
structure FunDeclCore (Code : Type) where
|
||||
|
|
@ -63,7 +143,7 @@ inductive Code where
|
|||
| let (decl : LetDecl) (k : Code)
|
||||
| fun (decl : FunDeclCore Code) (k : Code)
|
||||
| jp (decl : FunDeclCore Code) (k : Code)
|
||||
| jmp (fvarId : FVarId) (args : Array Expr)
|
||||
| jmp (fvarId : FVarId) (args : Array Arg)
|
||||
| cases (cases : CasesCore Code)
|
||||
| return (fvarId : FVarId)
|
||||
| unreach (type : Expr)
|
||||
|
|
@ -218,12 +298,12 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al
|
|||
|
||||
@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code
|
||||
|
||||
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code :=
|
||||
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code :=
|
||||
match c with
|
||||
| .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code
|
||||
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code
|
||||
|
||||
@[inline] private unsafe def updateUnreachImp (c : Code) (type' : Expr) : Code :=
|
||||
match c with
|
||||
|
|
@ -245,7 +325,7 @@ to be updated.
|
|||
-/
|
||||
@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param) (type : Expr) : Param
|
||||
|
||||
private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : Expr) : LetDecl :=
|
||||
private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : LetExpr) : LetDecl :=
|
||||
if ptrEq type decl.type && ptrEq value decl.value then
|
||||
decl
|
||||
else
|
||||
|
|
@ -256,7 +336,7 @@ Low-level update `LetDecl` function. It does not update the local context.
|
|||
Consider using `LetDecl.update : LetDecl → Expr → Expr → CompilerM LetDecl` if you want the local context
|
||||
to be updated.
|
||||
-/
|
||||
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : Expr) : LetDecl
|
||||
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : LetExpr) : LetDecl
|
||||
|
||||
private unsafe def updateFunDeclCoreImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl :=
|
||||
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
|
||||
|
|
@ -459,6 +539,9 @@ def Decl.instantiateParamsLevelParams (decl : Decl) (us : List Level) : Array Pa
|
|||
partial def Decl.instantiateValueLevelParams (decl : Decl) (us : List Level) : Code :=
|
||||
instCode decl.value
|
||||
where
|
||||
instLevel (u : Level) :=
|
||||
u.instantiateParams decl.levelParams us
|
||||
|
||||
instExpr (e : Expr) :=
|
||||
e.instantiateLevelParamsNoCache decl.levelParams us
|
||||
|
||||
|
|
@ -470,8 +553,19 @@ where
|
|||
| .default k => alt.updateCode (instCode k)
|
||||
| .alt _ ps k => alt.updateAlt! (instParams ps) (instCode k)
|
||||
|
||||
instArg (arg : Arg) : Arg :=
|
||||
match arg with
|
||||
| .type e => arg.updateType! (instExpr e)
|
||||
| .fvar .. | .erased => arg
|
||||
|
||||
instLetExpr (e : LetExpr) : LetExpr :=
|
||||
match e with
|
||||
| .const declName vs args => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg)
|
||||
| .fvar fvarId args => e.updateFVar! fvarId (args.mapMono instArg)
|
||||
| .proj .. | .value .. | .erased => e
|
||||
|
||||
instLetDecl (decl : LetDecl) :=
|
||||
decl.updateCore (instExpr decl.type) (instExpr decl.value)
|
||||
decl.updateCore (instExpr decl.type) (instLetExpr decl.value)
|
||||
|
||||
instFunDecl (decl : FunDecl) :=
|
||||
decl.updateCore (instExpr decl.type) (instParams decl.params) (instCode decl.value)
|
||||
|
|
@ -481,7 +575,7 @@ where
|
|||
| .let decl k => code.updateLet! (instLetDecl decl) (instCode k)
|
||||
| .jp decl k | .fun decl k => code.updateFun! (instFunDecl decl) (instCode k)
|
||||
| .cases c => code.updateCases! (instExpr c.resultType) c.discr (c.alts.mapMono instAlt)
|
||||
| .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instExpr)
|
||||
| .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instArg)
|
||||
| .return .. => code
|
||||
| .unreach type => code.updateUnreach! (instExpr type)
|
||||
|
||||
|
|
@ -506,45 +600,56 @@ def Decl.isTemplateLike (decl : Decl) : CoreM Bool := do
|
|||
else
|
||||
return false
|
||||
|
||||
mutual
|
||||
partial def FunDeclCore.collectUsed (decl : FunDecl) (s : FVarIdSet := {}) : FVarIdSet :=
|
||||
decl.value.collectUsed <| collectParams decl.params <| collectExpr decl.type s
|
||||
private partial def collectType (e : Expr) : FVarIdSet → FVarIdSet :=
|
||||
match e with
|
||||
| .forallE _ d b _ => collectType b ∘ collectType d
|
||||
| .lam _ d b _ => collectType b ∘ collectType d
|
||||
| .app f a => collectType f ∘ collectType a
|
||||
| .fvar fvarId => fun s => s.insert fvarId
|
||||
| .proj .. | .letE .. | .mdata .. => unreachable!
|
||||
| _ => id
|
||||
|
||||
private def collectArg (arg : Arg) (s : FVarIdSet) : FVarIdSet :=
|
||||
match arg with
|
||||
| .erased => s
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
| .type e => collectType e s
|
||||
|
||||
private def collectArgs (args : Array Arg) (s : FVarIdSet) : FVarIdSet :=
|
||||
args.foldl (init := s) fun s arg => collectArg arg s
|
||||
|
||||
private def collectLetExpr (e : LetExpr) (s : FVarIdSet) : FVarIdSet :=
|
||||
match e with
|
||||
| .fvar fvarId args => collectArgs args <| s.insert fvarId
|
||||
| .const _ _ args => collectArgs args s
|
||||
| .proj _ _ fvarId => s.insert fvarId
|
||||
| .value .. | .erased => s
|
||||
|
||||
private partial def collectParams (ps : Array Param) (s : FVarIdSet) : FVarIdSet :=
|
||||
ps.foldl (init := s) fun s p => collectExpr p.type s
|
||||
ps.foldl (init := s) fun s p => collectType p.type s
|
||||
|
||||
private partial def collectExprs (es : Array Expr) (s : FVarIdSet) : FVarIdSet :=
|
||||
es.foldl (init := s) fun s e => collectExpr e s
|
||||
|
||||
private partial def collectExpr (e : Expr) : FVarIdSet → FVarIdSet :=
|
||||
match e with
|
||||
| .proj _ _ e => collectExpr e
|
||||
| .forallE _ d b _ => collectExpr b ∘ collectExpr d
|
||||
| .lam _ d b _ => collectExpr b ∘ collectExpr d
|
||||
| .letE .. => unreachable!
|
||||
| .app f a => collectExpr f ∘ collectExpr a
|
||||
| .mdata _ b => collectExpr b
|
||||
| .fvar fvarId => fun s => s.insert fvarId
|
||||
| _ => id
|
||||
mutual
|
||||
partial def FunDeclCore.collectUsed (decl : FunDecl) (s : FVarIdSet := {}) : FVarIdSet :=
|
||||
decl.value.collectUsed <| collectParams decl.params <| collectType decl.type s
|
||||
|
||||
partial def Code.collectUsed (code : Code) (s : FVarIdSet := {}) : FVarIdSet :=
|
||||
match code with
|
||||
| .let decl k => k.collectUsed <| collectExpr decl.value <| collectExpr decl.type s
|
||||
| .let decl k => k.collectUsed <| collectLetExpr decl.value <| collectType decl.type s
|
||||
| .jp decl k | .fun decl k => k.collectUsed <| decl.collectUsed s
|
||||
| .cases c =>
|
||||
let s := s.insert c.discr
|
||||
let s := collectExpr c.resultType s
|
||||
let s := collectType c.resultType s
|
||||
c.alts.foldl (init := s) fun s alt =>
|
||||
match alt with
|
||||
| .default k => k.collectUsed s
|
||||
| .alt _ ps k => k.collectUsed <| collectParams ps s
|
||||
| .return fvarId => s.insert fvarId
|
||||
| .unreach type => collectExpr type s
|
||||
| .jmp fvarId args => collectExprs args <| s.insert fvarId
|
||||
| .unreach type => collectType type s
|
||||
| .jmp fvarId args => collectArgs args <| s.insert fvarId
|
||||
end
|
||||
|
||||
abbrev collectUsedAtExpr (s : FVarIdSet) (e : Expr) : FVarIdSet :=
|
||||
collectExpr e s
|
||||
collectType e s
|
||||
|
||||
/--
|
||||
Traverse the given block of potentially mutually recursive functions
|
||||
|
|
@ -568,7 +673,7 @@ where
|
|||
| .cases c => c.alts.forM fun alt => visit alt.getCode
|
||||
| .unreach .. | .jmp .. | .return .. => return ()
|
||||
| .let decl k =>
|
||||
if let .const declName _ := decl.value.getAppFn then
|
||||
if let .const declName _ _ := decl.value then
|
||||
if decls.any (·.name == declName) then
|
||||
modify fun s => s.insert declName
|
||||
visit k
|
||||
|
|
|
|||
|
|
@ -57,12 +57,13 @@ where
|
|||
| .let decl k =>
|
||||
let decl ← normLetDecl decl
|
||||
-- We only apply CSE to pure code
|
||||
match (← get).map.find? decl.value with
|
||||
let key := decl.value.toExpr
|
||||
match (← get).map.find? key with
|
||||
| some fvarId =>
|
||||
replaceLet decl fvarId
|
||||
go k
|
||||
| none =>
|
||||
addEntry decl.value decl.fvarId
|
||||
addEntry key decl.fvarId
|
||||
return code.updateLet! decl (← go k)
|
||||
| .fun decl k =>
|
||||
let decl ← goFunDecl decl
|
||||
|
|
@ -91,7 +92,7 @@ where
|
|||
| .default k => withNewScope do return alt.updateCode (← go k)
|
||||
return code.updateCases! resultType discr alts
|
||||
| .return fvarId => return code.updateReturn! (← normFVar fvarId)
|
||||
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args)
|
||||
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normArgs args)
|
||||
| .unreach .. => return code
|
||||
|
||||
end CSE
|
||||
|
|
|
|||
|
|
@ -71,7 +71,20 @@ mutual
|
|||
contain other type parameters.
|
||||
-/
|
||||
partial def collectParams (params : Array Param) : ClosureM Unit :=
|
||||
params.forM (collectExpr ·.type)
|
||||
params.forM (collectType ·.type)
|
||||
|
||||
partial def collectArg (arg : Arg) : ClosureM Unit :=
|
||||
match arg with
|
||||
| .erased => return ()
|
||||
| .type e => collectType e
|
||||
| .fvar fvarId => collectFVar fvarId
|
||||
|
||||
partial def collectLetExpr (e : LetExpr) : ClosureM Unit := do
|
||||
match e with
|
||||
| .erased | .value .. => return ()
|
||||
| .proj _ _ fvarId => collectFVar fvarId
|
||||
| .const _ _ args => args.forM collectArg
|
||||
| .fvar fvarId args => collectFVar fvarId; args.forM collectArg
|
||||
|
||||
/--
|
||||
Collect dependencies in the given code. We need this function to be able
|
||||
|
|
@ -79,22 +92,22 @@ mutual
|
|||
-/
|
||||
partial def collectCode (c : Code) : ClosureM Unit := do
|
||||
match c with
|
||||
| .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k
|
||||
| .let decl k => collectType decl.type; collectLetExpr decl.value; collectCode k
|
||||
| .fun decl k | .jp decl k => collectFunDecl decl; collectCode k
|
||||
| .cases c =>
|
||||
collectExpr c.resultType
|
||||
collectType c.resultType
|
||||
collectFVar c.discr
|
||||
c.alts.forM fun alt => do
|
||||
match alt with
|
||||
| .default k => collectCode k
|
||||
| .alt _ ps k => collectParams ps; collectCode k
|
||||
| .jmp _ args => args.forM collectExpr
|
||||
| .unreach type => collectExpr type
|
||||
| .jmp _ args => args.forM collectArg
|
||||
| .unreach type => collectType type
|
||||
| .return fvarId => collectFVar fvarId
|
||||
|
||||
/-- Collect dependencies of a local function declaration. -/
|
||||
partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do
|
||||
collectExpr decl.type
|
||||
collectType decl.type
|
||||
collectParams decl.params
|
||||
collectCode decl.value
|
||||
|
||||
|
|
@ -114,21 +127,21 @@ mutual
|
|||
collectFunDecl funDecl
|
||||
modify fun s => { s with decls := s.decls.push <| .fun funDecl }
|
||||
else if let some param ← findParam? fvarId then
|
||||
collectExpr param.type
|
||||
collectType param.type
|
||||
modify fun s => { s with params := s.params.push param }
|
||||
else if let some letDecl ← findLetDecl? fvarId then
|
||||
collectExpr letDecl.type
|
||||
collectType letDecl.type
|
||||
if (← read).abstract letDecl.fvarId then
|
||||
modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } }
|
||||
else
|
||||
collectExpr letDecl.value
|
||||
collectLetExpr letDecl.value
|
||||
modify fun s => { s with decls := s.decls.push <| .let letDecl }
|
||||
else
|
||||
unreachable!
|
||||
|
||||
/-- Collect dependencies of the given expression. -/
|
||||
partial def collectExpr (e : Expr) : ClosureM Unit := do
|
||||
e.forEach fun e => do
|
||||
partial def collectType (type : Expr) : ClosureM Unit := do
|
||||
type.forEach fun e => do
|
||||
match e with
|
||||
| .fvar fvarId => collectFVar fvarId
|
||||
| _ => pure ()
|
||||
|
|
|
|||
|
|
@ -190,6 +190,23 @@ where
|
|||
else
|
||||
e
|
||||
|
||||
/--
|
||||
Replace the free variables in `arg` using the given substitution.
|
||||
|
||||
See `normExprImp`
|
||||
-/
|
||||
private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : Arg :=
|
||||
match arg with
|
||||
| .erased => arg
|
||||
| .fvar fvarId =>
|
||||
match s.find? fvarId with
|
||||
| some (.fvar fvarId') =>
|
||||
let arg' := .fvar fvarId'
|
||||
if translator then arg' else normArgImp s arg' translator
|
||||
| some e => if e.isErased then .erased else .type e
|
||||
| none => arg
|
||||
| .type e => arg.updateType! (normExprImp s e translator)
|
||||
|
||||
/--
|
||||
Normalize the given free variable.
|
||||
See `normExprImp` for documentation on the `translator` parameter.
|
||||
|
|
@ -208,6 +225,22 @@ private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator :
|
|||
| some e => panic! s!"invalid LCNF substitution of free variable with expression {e}"
|
||||
| none => fvarId
|
||||
|
||||
|
||||
private def normArgsImp (s : FVarSubst) (args : Array Arg) (translator : Bool) : Array Arg :=
|
||||
args.mapMono (normArgImp s · translator)
|
||||
|
||||
/--
|
||||
Replace the free variables in `e` using the given substitution.
|
||||
|
||||
See `normExprImp`
|
||||
-/
|
||||
private partial def normLetExprImp (s : FVarSubst) (e : LetExpr) (translator : Bool) : LetExpr :=
|
||||
match e with
|
||||
| .erased | .value .. => e
|
||||
| .proj _ _ fvarId => e.updateProj! (normFVarImp s fvarId translator)
|
||||
| .const _ _ args => e.updateArgs! (normArgsImp s args translator)
|
||||
| .fvar fvarId args => e.updateFVar! (normFVarImp s fvarId translator) (normArgsImp s args translator)
|
||||
|
||||
/--
|
||||
Interface for monads that have a free substitutions.
|
||||
-/
|
||||
|
|
@ -248,15 +281,21 @@ See `Check.lean` for the free variable substitution checker.
|
|||
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr :=
|
||||
return normExprImp (← getSubst) e t
|
||||
|
||||
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m t] [Monad m] (arg : Arg) : m Arg :=
|
||||
return normArgImp (← getSubst) arg t
|
||||
|
||||
@[inline, inherit_doc normLetExprImp] def normLetExpr [MonadFVarSubst m t] [Monad m] (e : LetExpr) : m LetExpr :=
|
||||
return normLetExprImp (← getSubst) e t
|
||||
|
||||
@[inherit_doc normExprImp]
|
||||
abbrev normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
|
||||
normExprImp s e translator
|
||||
|
||||
/--
|
||||
Normalize the given expressions using the current substitution.
|
||||
Normalize the given arguments using the current substitution.
|
||||
-/
|
||||
def normExprs [MonadFVarSubst m t] [Monad m] (es : Array Expr) : m (Array Expr) :=
|
||||
es.mapMonoM normExpr
|
||||
def normArgs [MonadFVarSubst m t] [Monad m] (args : Array Arg) : m (Array Arg) :=
|
||||
return normArgsImp (← getSubst) args t
|
||||
|
||||
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
|
||||
let declName := .num binderName (← get).nextIdx
|
||||
|
|
@ -280,7 +319,7 @@ def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param
|
|||
modifyLCtx fun lctx => lctx.addParam param
|
||||
return param
|
||||
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) : CompilerM LetDecl := do
|
||||
def mkLetDecl (binderName : Name) (type : Expr) (value : LetExpr) : CompilerM LetDecl := do
|
||||
let fvarId ← mkFreshFVarId
|
||||
let binderName ← ensureNotAnonymous binderName `_x
|
||||
let decl := { fvarId, binderName, type, value }
|
||||
|
|
@ -304,7 +343,7 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param :=
|
|||
|
||||
@[implemented_by updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
|
||||
|
||||
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do
|
||||
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetExpr) : CompilerM LetDecl := do
|
||||
if ptrEq type decl.type && ptrEq value decl.value then
|
||||
return decl
|
||||
else
|
||||
|
|
@ -312,9 +351,9 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr
|
|||
modifyLCtx fun lctx => lctx.addLetDecl decl
|
||||
return decl
|
||||
|
||||
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
|
||||
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : LetExpr) : CompilerM LetDecl
|
||||
|
||||
def LetDecl.updateValue (decl : LetDecl) (value : Expr) : CompilerM LetDecl :=
|
||||
def LetDecl.updateValue (decl : LetDecl) (value : LetExpr) : CompilerM LetDecl :=
|
||||
decl.update decl.type value
|
||||
|
||||
private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
|
||||
|
|
@ -340,7 +379,7 @@ def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Arr
|
|||
ps.mapMonoM normParam
|
||||
|
||||
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do
|
||||
decl.update (← normExpr decl.type) (← normExpr decl.value)
|
||||
decl.update (← normExpr decl.type) (← normLetExpr decl.value)
|
||||
|
||||
abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM
|
||||
|
||||
|
|
@ -359,7 +398,7 @@ mutual
|
|||
| .let decl k => return code.updateLet! (← normLetDecl decl) (← normCodeImp k)
|
||||
| .fun decl k | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k)
|
||||
| .return fvarId => return code.updateReturn! (← normFVar fvarId)
|
||||
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args)
|
||||
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normArgs args)
|
||||
| .unreach type => return code.updateUnreach! (← normExpr type)
|
||||
| .cases c =>
|
||||
let resultType ← normExpr c.resultType
|
||||
|
|
|
|||
|
|
@ -12,19 +12,32 @@ private abbrev M := ReaderT FVarIdSet Id
|
|||
private def fvarDepOn (fvarId : FVarId) : M Bool :=
|
||||
return (← read).contains fvarId
|
||||
|
||||
private def exprDepOn (e : Expr) : M Bool := do
|
||||
private def typeDepOn (e : Expr) : M Bool := do
|
||||
let s ← read
|
||||
return e.hasAnyFVar fun fvarId => s.contains fvarId
|
||||
|
||||
private def argDepOn (a : Arg) : M Bool := do
|
||||
match a with
|
||||
| .erased => return false
|
||||
| .fvar fvarId => fvarDepOn fvarId
|
||||
| .type e => typeDepOn e
|
||||
|
||||
private def letExprDepOn (e : LetExpr) : M Bool :=
|
||||
match e with
|
||||
| .erased | .value .. => return false
|
||||
| .proj _ _ fvarId => fvarDepOn fvarId
|
||||
| .fvar fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
|
||||
| .const _ _ args => args.anyM argDepOn
|
||||
|
||||
private def LetDecl.depOn (decl : LetDecl) : M Bool :=
|
||||
exprDepOn decl.type <||> exprDepOn decl.value
|
||||
typeDepOn decl.type <||> letExprDepOn decl.value
|
||||
|
||||
private partial def depOn (c : Code) : M Bool :=
|
||||
match c with
|
||||
| .let decl k => decl.depOn <||> depOn k
|
||||
| .jp decl k | .fun decl k => exprDepOn decl.type <||> depOn decl.value <||> depOn k
|
||||
| .cases c => exprDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode
|
||||
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM exprDepOn
|
||||
| .jp decl k | .fun decl k => typeDepOn decl.type <||> depOn decl.value <||> depOn k
|
||||
| .cases c => typeDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode
|
||||
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
|
||||
| .return fvarId => fvarDepOn fvarId
|
||||
| .unreach _ => return false
|
||||
|
||||
|
|
@ -32,7 +45,7 @@ abbrev LetDecl.dependsOn (decl : LetDecl) (s : FVarIdSet) : Bool :=
|
|||
decl.depOn s
|
||||
|
||||
abbrev FunDecl.dependsOn (decl : FunDecl) (s : FVarIdSet) : Bool :=
|
||||
exprDepOn decl.type s || depOn decl.value s
|
||||
typeDepOn decl.type s || depOn decl.value s
|
||||
|
||||
def CodeDecl.dependsOn (decl : CodeDecl) (s : FVarIdSet) : Bool :=
|
||||
match decl with
|
||||
|
|
|
|||
|
|
@ -13,26 +13,43 @@ abbrev UsedLocalDecls := FVarIdHashSet
|
|||
Collect set of (let) free variables in a LCNF value.
|
||||
This code exploits the LCNF property that local declarations do not occur in types.
|
||||
-/
|
||||
def collectLocalDecls (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls :=
|
||||
go s e
|
||||
def collectLocalDeclsType (s : UsedLocalDecls) (type : Expr) : UsedLocalDecls :=
|
||||
go s type
|
||||
where
|
||||
go (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls :=
|
||||
match e with
|
||||
| .proj _ _ e => go s e
|
||||
| .forallE .. => s
|
||||
| .lam _ _ b _ => go s b
|
||||
| .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
|
||||
| .app f a => go (go s a) f
|
||||
| .mdata _ b => go s b
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
| .letE .. | .proj .. | .mdata .. => unreachable! -- Valid LCNF type does not contain this kind of expr
|
||||
| _ => s
|
||||
|
||||
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg) : UsedLocalDecls :=
|
||||
match arg with
|
||||
| .erased => s
|
||||
| .type e => collectLocalDeclsType s e
|
||||
| .fvar fvarId => s.insert fvarId
|
||||
|
||||
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array Arg) : UsedLocalDecls :=
|
||||
args.foldl (init := s) collectLocalDeclsArg
|
||||
|
||||
def collectLocalDeclsLetExpr (s : UsedLocalDecls) (e : LetExpr) : UsedLocalDecls :=
|
||||
match e with
|
||||
| .erased | .value .. => s
|
||||
| .proj _ _ fvarId => s.insert fvarId
|
||||
| .const _ _ args => collectLocalDeclsArgs s args
|
||||
| .fvar fvarId args => collectLocalDeclsArgs (s.insert fvarId) args
|
||||
|
||||
namespace ElimDead
|
||||
|
||||
abbrev M := StateRefT UsedLocalDecls CompilerM
|
||||
|
||||
private abbrev collectExprM (e : Expr) : M Unit :=
|
||||
modify (collectLocalDecls · e)
|
||||
private abbrev collectArgM (arg : Arg) : M Unit :=
|
||||
modify (collectLocalDeclsArg · arg)
|
||||
|
||||
private abbrev collectLetExprM (e : LetExpr) : M Unit :=
|
||||
modify (collectLocalDeclsLetExpr · e)
|
||||
|
||||
private abbrev collectFVarM (fvarId : FVarId) : M Unit :=
|
||||
modify (·.insert fvarId)
|
||||
|
|
@ -48,7 +65,7 @@ partial def elimDead (code : Code) : M Code := do
|
|||
let k ← elimDead k
|
||||
if (← get).contains decl.fvarId then
|
||||
/- Remark: we don't need to collect `decl.type` because LCNF local declarations do not occur in types. -/
|
||||
collectExprM decl.value
|
||||
collectLetExprM decl.value
|
||||
return code.updateCont! k
|
||||
else
|
||||
eraseLetDecl decl
|
||||
|
|
@ -66,7 +83,7 @@ partial def elimDead (code : Code) : M Code := do
|
|||
collectFVarM c.discr
|
||||
return code.updateAlts! alts
|
||||
| .return fvarId => collectFVarM fvarId; return code
|
||||
| .jmp fvarId args => collectFVarM fvarId; args.forM collectExprM; return code
|
||||
| .jmp fvarId args => collectFVarM fvarId; args.forM collectArgM; return code
|
||||
| .unreach .. => return code
|
||||
|
||||
end
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ marked as not fixed.
|
|||
-/
|
||||
|
||||
/-- Abstract value for the "fixed parameter" analysis. -/
|
||||
inductive Value where
|
||||
inductive AbsValue where
|
||||
| top
|
||||
| erased
|
||||
| val (i : Nat)
|
||||
|
|
@ -52,7 +52,7 @@ structure Context where
|
|||
The assignment maps free variable ids in the current code being analyzed to abstract values.
|
||||
We only track the abstract value assigned to parameters.
|
||||
-/
|
||||
assignment : FVarIdMap Value
|
||||
assignment : FVarIdMap AbsValue
|
||||
|
||||
structure State where
|
||||
/--
|
||||
|
|
@ -61,7 +61,7 @@ structure State where
|
|||
Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and
|
||||
we visit with the abstract values assigned to `aᵢ`, but first we record the visit here.
|
||||
-/
|
||||
visited : HashSet (Name × Array Value) := {}
|
||||
visited : HashSet (Name × Array AbsValue) := {}
|
||||
/--
|
||||
Bitmask containing the result, i.e., which parameters of `main` are fixed.
|
||||
We initialize it with `true` everywhere.
|
||||
|
|
@ -76,17 +76,18 @@ abbrev abort : FixParamM α := do
|
|||
modify fun s => { s with fixed := s.fixed.map fun _ => false }
|
||||
throw ()
|
||||
|
||||
def evalArg (arg : Expr) : FixParamM Value := do
|
||||
if arg.isErased then
|
||||
return .erased
|
||||
let .fvar fvarId := arg | return .top
|
||||
let some val := (← read).assignment.find? fvarId | return .top
|
||||
return val
|
||||
def evalArg (arg : Arg) : FixParamM AbsValue := do
|
||||
match arg with
|
||||
| .erased => return .erased
|
||||
| .type _ => return .top
|
||||
| .fvar fvarId =>
|
||||
let some val := (← read).assignment.find? fvarId | return .top
|
||||
return val
|
||||
|
||||
def inMutualBlock (declName : Name) : FixParamM Bool :=
|
||||
return (← read).decls.any (·.name == declName)
|
||||
|
||||
def mkAssignment (decl : Decl) (values : Array Value) : FVarIdMap Value := Id.run do
|
||||
def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do
|
||||
let mut assignment := {}
|
||||
for param in decl.params, value in values do
|
||||
assignment := assignment.insert param.fvarId value
|
||||
|
|
@ -94,23 +95,19 @@ def mkAssignment (decl : Decl) (values : Array Value) : FVarIdMap Value := Id.ru
|
|||
|
||||
mutual
|
||||
|
||||
partial def evalExpr (e : Expr) : FixParamM Unit := do
|
||||
partial def evalLetExpr (e : LetExpr) : FixParamM Unit := do
|
||||
match e with
|
||||
| .const declName _ => evalApp declName #[]
|
||||
| .app .. =>
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
if (← inMutualBlock declName) then
|
||||
evalApp declName e.getAppArgs
|
||||
| .const declName _ args => evalApp declName args
|
||||
| _ => return ()
|
||||
|
||||
partial def evalCode (code : Code) : FixParamM Unit := do
|
||||
match code with
|
||||
| .let decl k => evalExpr decl.value; evalCode k
|
||||
| .let decl k => evalLetExpr decl.value; evalCode k
|
||||
| .fun decl k | .jp decl k => evalCode decl.value; evalCode k
|
||||
| .cases c => c.alts.forM fun alt => evalCode alt.getCode
|
||||
| .unreach .. | .jmp .. | .return .. => return ()
|
||||
|
||||
partial def evalApp (declName : Name) (args : Array Expr) : FixParamM Unit := do
|
||||
partial def evalApp (declName : Name) (args : Array Arg) : FixParamM Unit := do
|
||||
let main := (← read).main
|
||||
if declName == main.name then
|
||||
-- Recursive call to the function being analyzed
|
||||
|
|
@ -145,7 +142,7 @@ partial def evalApp (declName : Name) (args : Array Expr) : FixParamM Unit := do
|
|||
|
||||
end
|
||||
|
||||
def mkInitialValues (numParams : Nat) : Array Value := Id.run do
|
||||
def mkInitialValues (numParams : Nat) : Array AbsValue := Id.run do
|
||||
let mut values := #[]
|
||||
for i in [:numParams] do
|
||||
values := values.push <| .val i
|
||||
|
|
|
|||
|
|
@ -8,44 +8,6 @@ import Lean.Compiler.LCNF.Basic
|
|||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
partial def Code.forEachExpr [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m] (f : Expr → m Unit) (c : Code) (skipTypes := false) : m Unit := do
|
||||
visit c |>.run
|
||||
where
|
||||
visit (c : Code) : MonadCacheT Expr Unit m Unit := do
|
||||
match c with
|
||||
| .let decl k =>
|
||||
visitType decl.type
|
||||
visitExpr decl.value
|
||||
visit k
|
||||
| .jp decl k | .fun decl k =>
|
||||
visitType decl.type
|
||||
decl.params.forM visitParam
|
||||
visit decl.value
|
||||
visit k
|
||||
| .unreach type => visitType type
|
||||
| .return .. => return ()
|
||||
| .jmp _ args => args.forM visitExpr
|
||||
| .cases c => visitType c.resultType; c.alts.forM fun alt => do
|
||||
match alt with
|
||||
| .default k => visit k
|
||||
| .alt _ ps k => ps.forM visitParam; visit k
|
||||
|
||||
visitParam (p : Param) : MonadCacheT Expr Unit m Unit :=
|
||||
visitType p.type
|
||||
|
||||
visitExpr (e : Expr) : MonadCacheT Expr Unit m Unit :=
|
||||
ForEachExpr.visit (fun e => f e *> return true) e
|
||||
|
||||
visitType (e : Expr) : MonadCacheT Expr Unit m Unit :=
|
||||
unless skipTypes do
|
||||
visitExpr e
|
||||
|
||||
def Decl.forEachExpr [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m] (f : Expr → m Unit) (decl : Decl) (skipTypes := false) : m Unit := do
|
||||
visit |>.run
|
||||
where
|
||||
visit : MonadCacheT Expr Unit m Unit := do
|
||||
Code.forEachExpr.visitType f decl.type (skipTypes := skipTypes)
|
||||
decl.params.forM (Code.forEachExpr.visitParam f (skipTypes := skipTypes))
|
||||
Code.forEachExpr.visit f decl.value (skipTypes := skipTypes)
|
||||
-- TODO: delete
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
@ -101,21 +101,22 @@ def inferConstType (declName : Name) (us : List Level) : CompilerM Expr := do
|
|||
getOtherDeclType declName us
|
||||
|
||||
mutual
|
||||
partial def inferArgType (arg : Arg) : InferTypeM Expr :=
|
||||
match arg with
|
||||
| .erased => return erasedExpr
|
||||
| .type e => inferType e
|
||||
| .fvar fvarId => LCNF.getType fvarId
|
||||
|
||||
-- TODO: stopped here
|
||||
partial def inferType (e : Expr) : InferTypeM Expr :=
|
||||
match e with
|
||||
| .const c us => inferConstType c us
|
||||
| .proj n i s => inferProjType n i s
|
||||
| .app .. => inferAppType e
|
||||
| .mvar .. => throwError "unexpected metavariable {e}"
|
||||
| .fvar fvarId => InferType.getType fvarId
|
||||
| .bvar .. => throwError "unexpected bound variable {e}"
|
||||
| .mdata _ e => inferType e
|
||||
| .lit v => return v.type
|
||||
| .sort lvl => return .sort (mkLevelSucc lvl)
|
||||
| .forallE .. => inferForallType e
|
||||
| .lam .. => inferLambdaType e
|
||||
| .letE .. => inferLambdaType e
|
||||
| .letE .. | .mvar .. | .mdata .. | .lit .. | .bvar .. | .proj .. => unreachable!
|
||||
|
||||
partial def inferAppTypeCore (f : Expr) (args : Array Expr) : InferTypeM Expr := do
|
||||
let mut j := 0
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ def LCtx.toLocalContext (lctx : LCtx) : LocalContext := Id.run do
|
|||
for (_, param) in lctx.params.toArray do
|
||||
result := result.addDecl (.cdecl 0 param.fvarId param.binderName param.type .default .default)
|
||||
for (_, decl) in lctx.letDecls.toArray do
|
||||
result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value true .default)
|
||||
result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value.toExpr true .default)
|
||||
for (_, decl) in lctx.funDecls.toArray do
|
||||
result := result.addDecl (.cdecl 0 decl.fvarId decl.binderName decl.type .default .default)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -97,8 +97,25 @@ See `Decl.setLevelParams`.
|
|||
-/
|
||||
open Lean.CollectLevelParams
|
||||
|
||||
abbrev visitType (type : Expr) : Visitor :=
|
||||
visitExpr type
|
||||
|
||||
def visitArg (arg : Arg) : Visitor :=
|
||||
match arg with
|
||||
| .erased | .fvar .. => id
|
||||
| .type e => visitType e
|
||||
|
||||
def visitArgs (args : Array Arg) : Visitor :=
|
||||
fun s => args.foldl (init := s) fun s arg => visitArg arg s
|
||||
|
||||
def visitLetExpr (e : LetExpr) : Visitor :=
|
||||
match e with
|
||||
| .erased | .value .. | .proj .. => id
|
||||
| .const _ us args => visitLevels us ∘ visitArgs args
|
||||
| .fvar _ args => visitArgs args
|
||||
|
||||
def visitParam (p : Param) : Visitor :=
|
||||
visitExpr p.type
|
||||
visitType p.type
|
||||
|
||||
def visitParams (ps : Array Param) : Visitor :=
|
||||
fun s => ps.foldl (init := s) fun s p => visitParam p s
|
||||
|
|
@ -113,12 +130,12 @@ mutual
|
|||
fun s => alts.foldl (init := s) fun s alt => visitAlt alt s
|
||||
|
||||
partial def visitCode : Code → Visitor
|
||||
| .let decl k => visitCode k ∘ visitExpr decl.value ∘ visitExpr decl.type
|
||||
| .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitExpr decl.type
|
||||
| .cases c => visitAlts c.alts ∘ visitExpr c.resultType
|
||||
| .unreach type => visitExpr type
|
||||
| .let decl k => visitCode k ∘ visitLetExpr decl.value ∘ visitType decl.type
|
||||
| .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type
|
||||
| .cases c => visitAlts c.alts ∘ visitType c.resultType
|
||||
| .unreach type => visitType type
|
||||
| .return _ => id
|
||||
| .jmp _ args => fun s => args.foldl (init := s) fun s arg => visitExpr arg s
|
||||
| .jmp _ args => visitArgs args
|
||||
end
|
||||
|
||||
end CollectLevelParams
|
||||
|
|
@ -131,7 +148,7 @@ Collect universe level parameters collecting in the type, parameters, and value,
|
|||
set `decl.levelParams` with the resulting value.
|
||||
-/
|
||||
def Decl.setLevelParams (decl : Decl) : Decl :=
|
||||
let levelParams := (visitCode decl.value ∘ visitParams decl.params ∘ visitExpr decl.type) {} |>.params.toList
|
||||
let levelParams := (visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type) {} |>.params.toList
|
||||
{ decl with levelParams }
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -11,25 +11,29 @@ import Lean.Compiler.LCNF.CompilerM
|
|||
namespace Lean.Compiler.LCNF
|
||||
namespace Simp
|
||||
|
||||
partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
|
||||
/-
|
||||
-- TODO: cleanup
|
||||
partial def findExpr (e : LetExpr) : CompilerM LetExpr := do
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
let some decl ← findLetDecl? fvarId | return e
|
||||
findExpr decl.value
|
||||
| .mdata _ e' => if skipMData then findExpr e' else return e
|
||||
| .fvar fvarId args =>
|
||||
if args.isEmpty then
|
||||
let some decl ← findLetDecl? fvarId | return e
|
||||
findExpr decl.value
|
||||
else
|
||||
return e
|
||||
| _ => return e
|
||||
|
||||
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
|
||||
partial def findFunDecl? (e : LetExpr) : CompilerM (Option FunDecl) := do
|
||||
match e with
|
||||
| .fvar fvarId =>
|
||||
| .fvar fvarId args =>
|
||||
|
||||
if let some decl ← LCNF.findFunDecl? fvarId then
|
||||
return some decl
|
||||
else if let some decl ← findLetDecl? fvarId then
|
||||
findFunDecl? decl.value
|
||||
else
|
||||
return none
|
||||
| .mdata _ e => findFunDecl? e
|
||||
| _ => return none
|
||||
|
||||
-/
|
||||
end Simp
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -74,6 +74,9 @@ end ToExpr
|
|||
|
||||
open ToExpr
|
||||
|
||||
private def Arg.toExprM (arg : Arg) : ToExprM Expr :=
|
||||
return arg.toExpr.abstract' (← read) (← get)
|
||||
|
||||
mutual
|
||||
partial def FunDeclCore.toExprM (decl : FunDecl) : ToExprM Expr :=
|
||||
withParams decl.params do mkLambdaM decl.params (← decl.value.toExprM)
|
||||
|
|
@ -82,7 +85,7 @@ partial def Code.toExprM (code : Code) : ToExprM Expr := do
|
|||
match code with
|
||||
| .let decl k =>
|
||||
let type ← abstractM decl.type
|
||||
let value ← abstractM decl.value
|
||||
let value ← abstractM decl.value.toExpr
|
||||
let body ← withFVar decl.fvarId k.toExprM
|
||||
return .letE decl.binderName type value body true
|
||||
| .fun decl k | .jp decl k =>
|
||||
|
|
@ -91,7 +94,7 @@ partial def Code.toExprM (code : Code) : ToExprM Expr := do
|
|||
let body ← withFVar decl.fvarId k.toExprM
|
||||
return .letE decl.binderName type value body true
|
||||
| .return fvarId => fvarId.toExprM
|
||||
| .jmp fvarId args => return mkAppN (← fvarId.toExprM) (← args.mapM abstractM)
|
||||
| .jmp fvarId args => return mkAppN (← fvarId.toExprM) (← args.mapM Arg.toExprM)
|
||||
| .unreach type => return mkApp (mkConst ``lcUnreachable) (← abstractM type)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapM fun
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue