chore: new LCNF representation

This is the first of a series of commits to change the LCNF representation.
This commit is contained in:
Leonardo de Moura 2022-10-24 19:53:08 -07:00
parent 22cdac914d
commit 6d46829599
14 changed files with 361 additions and 167 deletions

View file

@ -19,25 +19,47 @@ def eqvFVar (fvarId₁ fvarId₂ : FVarId) : EqvM Bool := do
let fvarId₂ := (← read).find? fvarId₂ |>.getD fvarId₂
return fvarId₁ == fvarId₂
def eqvExpr (e₁ e₂ : Expr) : EqvM Bool := do
def eqvType (e₁ e₂ : Expr) : EqvM Bool := do
match e₁, e₂ with
| .app f₁ a₁, .app f₂ a₂ => eqvExpr a₁ a₂ <&&> eqvExpr f₁ f₂
| .proj s₁ i₁ e₁, .proj s₂ i₂ e₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvExpr e₁ e₂
| .mdata m₁ e₁, .mdata m₂ e₂ => pure (m₁ == m₂) <&&> eqvExpr e₁ e₂
| .app f₁ a₁, .app f₂ a₂ => eqvType a₁ a₂ <&&> eqvType f₁ f₂
| .fvar fvarId₁, .fvar fvarId₂ => eqvFVar fvarId₁ fvarId₂
| .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => eqvExpr d₁ d₂ <&&> eqvExpr b₁ b₂
| .letE .., _ | _, .letE .. => unreachable!
| .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => eqvType d₁ d₂ <&&> eqvType b₁ b₂
| _, _ => return e₁ == e₂
def eqvExprs (es₁ es₂ : Array Expr) : EqvM Bool := do
def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do
if es₁.size = es₂.size then
for e₁ in es₁, e₂ in es₂ do
unless (← eqvExpr e₁ e₂) do
unless (← eqvType e₁ e₂) do
return false
return true
else
return false
def eqvArg (a₁ a₂ : Arg) : EqvM Bool := do
match a₁, a₂ with
| .type e₁, .type e₂ => eqvType e₁ e₂
| .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂
| .erased, .erased => return true
| _, _ => return false
def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do
if as₁.size = as₂.size then
for a₁ in as₁, a₂ in as₂ do
unless (← eqvArg a₁ a₂) do
return false
return true
else
return false
def eqvLetExpr (e₁ e₂ : LetExpr) : EqvM Bool := do
match e₁, e₂ with
| .value v₁, .value v₂ => return v₁ == v₂
| .erased, .erased => return true
| .proj s₁ i₁ x₁, .proj s₂ i₂ x₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂
| .const n₁ us₁ as₁, .const n₂ us₂ as₂ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂
| .fvar f₁ as₁, .fvar f₂ as₂ => eqvFVar f₁ f₂ <&&> eqvArgs as₁ as₂
| _, _ => return false
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
withReader (·.insert fvarId₂ fvarId₁) x
@ -48,7 +70,7 @@ def eqvExprs (es₁ es₂ : Array Expr) : EqvM Bool := do
let p₁ := params₁[i]
have : i < params₂.size := by simp_all_arith
let p₂ := params₂[i]
unless (← eqvExpr p₁.type p₂.type) do return false
unless (← eqvType p₁.type p₂.type) do return false
withFVar p₁.fvarId p₂.fvarId do
go (i+1)
else
@ -84,20 +106,20 @@ partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do
partial def eqv (code₁ code₂ : Code) : EqvM Bool := do
match code₁, code₂ with
| .let decl₁ k₁, .let decl₂ k₂ =>
eqvExpr decl₁.type decl₂.type <&&>
eqvExpr decl₁.value decl₂.value <&&>
eqvType decl₁.type decl₂.type <&&>
eqvLetExpr decl₁.value decl₂.value <&&>
withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
| .fun decl₁ k₁, .fun decl₂ k₂
| .jp decl₁ k₁, .jp decl₂ k₂ =>
eqvExpr decl₁.type decl₂.type <&&>
eqvType decl₁.type decl₂.type <&&>
withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value) <&&>
withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂)
| .return fvarId₁, .return fvarId₂ => eqvFVar fvarId₁ fvarId₂
| .unreach type₁, .unreach type₂ => eqvExpr type₁ type₂
| .jmp fvarId₁ args₁, .jmp fvarId₂ args₂ => eqvFVar fvarId₁ fvarId₂ <&&> eqvExprs args₁ args₂
| .unreach type₁, .unreach type₂ => eqvType type₁ type₂
| .jmp fvarId₁ args₁, .jmp fvarId₂ args₂ => eqvFVar fvarId₁ fvarId₂ <&&> eqvArgs args₁ args₂
| .cases c₁, .cases c₂ =>
eqvFVar c₁.discr c₂.discr <&&>
eqvExpr c₁.resultType c₂.resultType <&&>
eqvType c₁.resultType c₂.resultType <&&>
eqvAlts c₁.alts c₂.alts
| _, _ => return false

View file

@ -7,6 +7,7 @@ import Lean.Expr
import Lean.Meta.Instances
import Lean.Compiler.InlineAttrs
import Lean.Compiler.Specialize
import Lean.Compiler.LCNF.Types
namespace Lean.Compiler.LCNF
@ -34,11 +35,90 @@ inductive AltCore (Code : Type) where
| default (code : Code)
deriving Inhabited
inductive Value where
| natVal (val : Nat)
| strVal (val : String)
-- TODO: add constructors for `Int`, `Float`, `UInt` ...
deriving Inhabited, BEq, Hashable
inductive Arg where
| erased
| fvar (fvarId : FVarId)
| type (expr : Expr)
deriving Inhabited, BEq, Hashable
def Arg.toExpr (arg : Arg) : Expr :=
match arg with
| .erased => erasedExpr
| .fvar fvarId => .fvar fvarId
| .type e => e
private unsafe def Arg.updateTypeImp (arg : Arg) (type' : Expr) : Arg :=
match arg with
| .type ty => if ptrEq ty type' then arg else .type type'
| _ => unreachable!
@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg) (type : Expr) : Arg
private unsafe def Arg.updateFVarImp (arg : Arg) (fvarId' : FVarId) : Arg :=
match arg with
| .fvar fvarId => if fvarId' == fvarId then arg else .fvar fvarId'
| _ => unreachable!
@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg) (fvarId' : FVarId) : Arg
inductive LetExpr where
| value (value : Value)
| erased
| proj (typeName : Name) (idx : Nat) (struct : FVarId)
| const (declName : Name) (us : List Level) (args : Array Arg)
| fvar (fvarId : FVarId) (args : Array Arg)
-- TODO: add constructors for mono and impure phases
deriving Inhabited, BEq, Hashable
private unsafe def LetExpr.updateProjImp (e : LetExpr) (fvarId' : FVarId) : LetExpr :=
match e with
| .proj s i fvarId => if fvarId == fvarId' then e else .proj s i fvarId'
| _ => unreachable!
@[implemented_by LetExpr.updateProjImp] opaque LetExpr.updateProj! (e : LetExpr) (fvarId' : FVarId) : LetExpr
private unsafe def LetExpr.updateConstImp (e : LetExpr) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetExpr :=
match e with
| .const declName us args => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args'
| _ => unreachable!
@[implemented_by LetExpr.updateConstImp] opaque LetExpr.updateConst! (e : LetExpr) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetExpr
private unsafe def LetExpr.updateFVarImp (e : LetExpr) (fvarId' : FVarId) (args' : Array Arg) : LetExpr :=
match e with
| .fvar fvarId args => if fvarId == fvarId' && ptrEq args args' then e else .fvar fvarId' args'
| _ => unreachable!
@[implemented_by LetExpr.updateFVarImp] opaque LetExpr.updateFVar! (e : LetExpr) (fvarId' : FVarId) (args' : Array Arg) : LetExpr
private unsafe def LetExpr.updateArgsImp (e : LetExpr) (args' : Array Arg) : LetExpr :=
match e with
| .const declName us args => if ptrEq args args' then e else .const declName us args'
| .fvar fvarId args => if ptrEq args args' then e else .fvar fvarId args'
| _ => unreachable!
@[implemented_by LetExpr.updateArgsImp] opaque LetExpr.updateArgs! (e : LetExpr) (args' : Array Arg) : LetExpr
def LetExpr.toExpr (e : LetExpr) : Expr :=
match e with
| .value (.natVal val) => .lit (.natVal val)
| .value (.strVal val) => .lit (.strVal val)
| .erased => erasedExpr
| .proj n i s => .proj n i (.fvar s)
| .const n us as => mkAppN (.const n us) (as.map Arg.toExpr)
| .fvar fvarId as => mkAppN (.fvar fvarId) (as.map Arg.toExpr)
structure LetDecl where
fvarId : FVarId
binderName : Name
type : Expr
value : Expr
value : LetExpr
deriving Inhabited, BEq
structure FunDeclCore (Code : Type) where
@ -63,7 +143,7 @@ inductive Code where
| let (decl : LetDecl) (k : Code)
| fun (decl : FunDeclCore Code) (k : Code)
| jp (decl : FunDeclCore Code) (k : Code)
| jmp (fvarId : FVarId) (args : Array Expr)
| jmp (fvarId : FVarId) (args : Array Arg)
| cases (cases : CasesCore Code)
| return (fvarId : FVarId)
| unreach (type : Expr)
@ -218,12 +298,12 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al
@[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code :=
@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code :=
match c with
| .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args'
| _ => unreachable!
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code
@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code
@[inline] private unsafe def updateUnreachImp (c : Code) (type' : Expr) : Code :=
match c with
@ -245,7 +325,7 @@ to be updated.
-/
@[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param) (type : Expr) : Param
private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : Expr) : LetDecl :=
private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : LetExpr) : LetDecl :=
if ptrEq type decl.type && ptrEq value decl.value then
decl
else
@ -256,7 +336,7 @@ Low-level update `LetDecl` function. It does not update the local context.
Consider using `LetDecl.update : LetDecl → Expr → Expr → CompilerM LetDecl` if you want the local context
to be updated.
-/
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : Expr) : LetDecl
@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : LetExpr) : LetDecl
private unsafe def updateFunDeclCoreImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl :=
if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
@ -459,6 +539,9 @@ def Decl.instantiateParamsLevelParams (decl : Decl) (us : List Level) : Array Pa
partial def Decl.instantiateValueLevelParams (decl : Decl) (us : List Level) : Code :=
instCode decl.value
where
instLevel (u : Level) :=
u.instantiateParams decl.levelParams us
instExpr (e : Expr) :=
e.instantiateLevelParamsNoCache decl.levelParams us
@ -470,8 +553,19 @@ where
| .default k => alt.updateCode (instCode k)
| .alt _ ps k => alt.updateAlt! (instParams ps) (instCode k)
instArg (arg : Arg) : Arg :=
match arg with
| .type e => arg.updateType! (instExpr e)
| .fvar .. | .erased => arg
instLetExpr (e : LetExpr) : LetExpr :=
match e with
| .const declName vs args => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg)
| .fvar fvarId args => e.updateFVar! fvarId (args.mapMono instArg)
| .proj .. | .value .. | .erased => e
instLetDecl (decl : LetDecl) :=
decl.updateCore (instExpr decl.type) (instExpr decl.value)
decl.updateCore (instExpr decl.type) (instLetExpr decl.value)
instFunDecl (decl : FunDecl) :=
decl.updateCore (instExpr decl.type) (instParams decl.params) (instCode decl.value)
@ -481,7 +575,7 @@ where
| .let decl k => code.updateLet! (instLetDecl decl) (instCode k)
| .jp decl k | .fun decl k => code.updateFun! (instFunDecl decl) (instCode k)
| .cases c => code.updateCases! (instExpr c.resultType) c.discr (c.alts.mapMono instAlt)
| .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instExpr)
| .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instArg)
| .return .. => code
| .unreach type => code.updateUnreach! (instExpr type)
@ -506,45 +600,56 @@ def Decl.isTemplateLike (decl : Decl) : CoreM Bool := do
else
return false
mutual
partial def FunDeclCore.collectUsed (decl : FunDecl) (s : FVarIdSet := {}) : FVarIdSet :=
decl.value.collectUsed <| collectParams decl.params <| collectExpr decl.type s
private partial def collectType (e : Expr) : FVarIdSet → FVarIdSet :=
match e with
| .forallE _ d b _ => collectType b ∘ collectType d
| .lam _ d b _ => collectType b ∘ collectType d
| .app f a => collectType f ∘ collectType a
| .fvar fvarId => fun s => s.insert fvarId
| .proj .. | .letE .. | .mdata .. => unreachable!
| _ => id
private def collectArg (arg : Arg) (s : FVarIdSet) : FVarIdSet :=
match arg with
| .erased => s
| .fvar fvarId => s.insert fvarId
| .type e => collectType e s
private def collectArgs (args : Array Arg) (s : FVarIdSet) : FVarIdSet :=
args.foldl (init := s) fun s arg => collectArg arg s
private def collectLetExpr (e : LetExpr) (s : FVarIdSet) : FVarIdSet :=
match e with
| .fvar fvarId args => collectArgs args <| s.insert fvarId
| .const _ _ args => collectArgs args s
| .proj _ _ fvarId => s.insert fvarId
| .value .. | .erased => s
private partial def collectParams (ps : Array Param) (s : FVarIdSet) : FVarIdSet :=
ps.foldl (init := s) fun s p => collectExpr p.type s
ps.foldl (init := s) fun s p => collectType p.type s
private partial def collectExprs (es : Array Expr) (s : FVarIdSet) : FVarIdSet :=
es.foldl (init := s) fun s e => collectExpr e s
private partial def collectExpr (e : Expr) : FVarIdSet → FVarIdSet :=
match e with
| .proj _ _ e => collectExpr e
| .forallE _ d b _ => collectExpr b ∘ collectExpr d
| .lam _ d b _ => collectExpr b ∘ collectExpr d
| .letE .. => unreachable!
| .app f a => collectExpr f ∘ collectExpr a
| .mdata _ b => collectExpr b
| .fvar fvarId => fun s => s.insert fvarId
| _ => id
mutual
partial def FunDeclCore.collectUsed (decl : FunDecl) (s : FVarIdSet := {}) : FVarIdSet :=
decl.value.collectUsed <| collectParams decl.params <| collectType decl.type s
partial def Code.collectUsed (code : Code) (s : FVarIdSet := {}) : FVarIdSet :=
match code with
| .let decl k => k.collectUsed <| collectExpr decl.value <| collectExpr decl.type s
| .let decl k => k.collectUsed <| collectLetExpr decl.value <| collectType decl.type s
| .jp decl k | .fun decl k => k.collectUsed <| decl.collectUsed s
| .cases c =>
let s := s.insert c.discr
let s := collectExpr c.resultType s
let s := collectType c.resultType s
c.alts.foldl (init := s) fun s alt =>
match alt with
| .default k => k.collectUsed s
| .alt _ ps k => k.collectUsed <| collectParams ps s
| .return fvarId => s.insert fvarId
| .unreach type => collectExpr type s
| .jmp fvarId args => collectExprs args <| s.insert fvarId
| .unreach type => collectType type s
| .jmp fvarId args => collectArgs args <| s.insert fvarId
end
abbrev collectUsedAtExpr (s : FVarIdSet) (e : Expr) : FVarIdSet :=
collectExpr e s
collectType e s
/--
Traverse the given block of potentially mutually recursive functions
@ -568,7 +673,7 @@ where
| .cases c => c.alts.forM fun alt => visit alt.getCode
| .unreach .. | .jmp .. | .return .. => return ()
| .let decl k =>
if let .const declName _ := decl.value.getAppFn then
if let .const declName _ _ := decl.value then
if decls.any (·.name == declName) then
modify fun s => s.insert declName
visit k

View file

@ -57,12 +57,13 @@ where
| .let decl k =>
let decl ← normLetDecl decl
-- We only apply CSE to pure code
match (← get).map.find? decl.value with
let key := decl.value.toExpr
match (← get).map.find? key with
| some fvarId =>
replaceLet decl fvarId
go k
| none =>
addEntry decl.value decl.fvarId
addEntry key decl.fvarId
return code.updateLet! decl (← go k)
| .fun decl k =>
let decl ← goFunDecl decl
@ -91,7 +92,7 @@ where
| .default k => withNewScope do return alt.updateCode (← go k)
return code.updateCases! resultType discr alts
| .return fvarId => return code.updateReturn! (← normFVar fvarId)
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args)
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normArgs args)
| .unreach .. => return code
end CSE

View file

@ -71,7 +71,20 @@ mutual
contain other type parameters.
-/
partial def collectParams (params : Array Param) : ClosureM Unit :=
params.forM (collectExpr ·.type)
params.forM (collectType ·.type)
partial def collectArg (arg : Arg) : ClosureM Unit :=
match arg with
| .erased => return ()
| .type e => collectType e
| .fvar fvarId => collectFVar fvarId
partial def collectLetExpr (e : LetExpr) : ClosureM Unit := do
match e with
| .erased | .value .. => return ()
| .proj _ _ fvarId => collectFVar fvarId
| .const _ _ args => args.forM collectArg
| .fvar fvarId args => collectFVar fvarId; args.forM collectArg
/--
Collect dependencies in the given code. We need this function to be able
@ -79,22 +92,22 @@ mutual
-/
partial def collectCode (c : Code) : ClosureM Unit := do
match c with
| .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k
| .let decl k => collectType decl.type; collectLetExpr decl.value; collectCode k
| .fun decl k | .jp decl k => collectFunDecl decl; collectCode k
| .cases c =>
collectExpr c.resultType
collectType c.resultType
collectFVar c.discr
c.alts.forM fun alt => do
match alt with
| .default k => collectCode k
| .alt _ ps k => collectParams ps; collectCode k
| .jmp _ args => args.forM collectExpr
| .unreach type => collectExpr type
| .jmp _ args => args.forM collectArg
| .unreach type => collectType type
| .return fvarId => collectFVar fvarId
/-- Collect dependencies of a local function declaration. -/
partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do
collectExpr decl.type
collectType decl.type
collectParams decl.params
collectCode decl.value
@ -114,21 +127,21 @@ mutual
collectFunDecl funDecl
modify fun s => { s with decls := s.decls.push <| .fun funDecl }
else if let some param ← findParam? fvarId then
collectExpr param.type
collectType param.type
modify fun s => { s with params := s.params.push param }
else if let some letDecl ← findLetDecl? fvarId then
collectExpr letDecl.type
collectType letDecl.type
if (← read).abstract letDecl.fvarId then
modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } }
else
collectExpr letDecl.value
collectLetExpr letDecl.value
modify fun s => { s with decls := s.decls.push <| .let letDecl }
else
unreachable!
/-- Collect dependencies of the given expression. -/
partial def collectExpr (e : Expr) : ClosureM Unit := do
e.forEach fun e => do
partial def collectType (type : Expr) : ClosureM Unit := do
type.forEach fun e => do
match e with
| .fvar fvarId => collectFVar fvarId
| _ => pure ()

View file

@ -190,6 +190,23 @@ where
else
e
/--
Replace the free variables in `arg` using the given substitution.
See `normExprImp`
-/
private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : Arg :=
match arg with
| .erased => arg
| .fvar fvarId =>
match s.find? fvarId with
| some (.fvar fvarId') =>
let arg' := .fvar fvarId'
if translator then arg' else normArgImp s arg' translator
| some e => if e.isErased then .erased else .type e
| none => arg
| .type e => arg.updateType! (normExprImp s e translator)
/--
Normalize the given free variable.
See `normExprImp` for documentation on the `translator` parameter.
@ -208,6 +225,22 @@ private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator :
| some e => panic! s!"invalid LCNF substitution of free variable with expression {e}"
| none => fvarId
private def normArgsImp (s : FVarSubst) (args : Array Arg) (translator : Bool) : Array Arg :=
args.mapMono (normArgImp s · translator)
/--
Replace the free variables in `e` using the given substitution.
See `normExprImp`
-/
private partial def normLetExprImp (s : FVarSubst) (e : LetExpr) (translator : Bool) : LetExpr :=
match e with
| .erased | .value .. => e
| .proj _ _ fvarId => e.updateProj! (normFVarImp s fvarId translator)
| .const _ _ args => e.updateArgs! (normArgsImp s args translator)
| .fvar fvarId args => e.updateFVar! (normFVarImp s fvarId translator) (normArgsImp s args translator)
/--
Interface for monads that have a free substitutions.
-/
@ -248,15 +281,21 @@ See `Check.lean` for the free variable substitution checker.
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr :=
return normExprImp (← getSubst) e t
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m t] [Monad m] (arg : Arg) : m Arg :=
return normArgImp (← getSubst) arg t
@[inline, inherit_doc normLetExprImp] def normLetExpr [MonadFVarSubst m t] [Monad m] (e : LetExpr) : m LetExpr :=
return normLetExprImp (← getSubst) e t
@[inherit_doc normExprImp]
abbrev normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
normExprImp s e translator
/--
Normalize the given expressions using the current substitution.
Normalize the given arguments using the current substitution.
-/
def normExprs [MonadFVarSubst m t] [Monad m] (es : Array Expr) : m (Array Expr) :=
es.mapMonoM normExpr
def normArgs [MonadFVarSubst m t] [Monad m] (args : Array Arg) : m (Array Arg) :=
return normArgsImp (← getSubst) args t
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
let declName := .num binderName (← get).nextIdx
@ -280,7 +319,7 @@ def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param
modifyLCtx fun lctx => lctx.addParam param
return param
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) : CompilerM LetDecl := do
def mkLetDecl (binderName : Name) (type : Expr) (value : LetExpr) : CompilerM LetDecl := do
let fvarId ← mkFreshFVarId
let binderName ← ensureNotAnonymous binderName `_x
let decl := { fvarId, binderName, type, value }
@ -304,7 +343,7 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param :=
@[implemented_by updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetExpr) : CompilerM LetDecl := do
if ptrEq type decl.type && ptrEq value decl.value then
return decl
else
@ -312,9 +351,9 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr
modifyLCtx fun lctx => lctx.addLetDecl decl
return decl
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : LetExpr) : CompilerM LetDecl
def LetDecl.updateValue (decl : LetDecl) (value : Expr) : CompilerM LetDecl :=
def LetDecl.updateValue (decl : LetDecl) (value : LetExpr) : CompilerM LetDecl :=
decl.update decl.type value
private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
@ -340,7 +379,7 @@ def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Arr
ps.mapMonoM normParam
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do
decl.update (← normExpr decl.type) (← normExpr decl.value)
decl.update (← normExpr decl.type) (← normLetExpr decl.value)
abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM
@ -359,7 +398,7 @@ mutual
| .let decl k => return code.updateLet! (← normLetDecl decl) (← normCodeImp k)
| .fun decl k | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k)
| .return fvarId => return code.updateReturn! (← normFVar fvarId)
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args)
| .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normArgs args)
| .unreach type => return code.updateUnreach! (← normExpr type)
| .cases c =>
let resultType ← normExpr c.resultType

View file

@ -12,19 +12,32 @@ private abbrev M := ReaderT FVarIdSet Id
private def fvarDepOn (fvarId : FVarId) : M Bool :=
return (← read).contains fvarId
private def exprDepOn (e : Expr) : M Bool := do
private def typeDepOn (e : Expr) : M Bool := do
let s ← read
return e.hasAnyFVar fun fvarId => s.contains fvarId
private def argDepOn (a : Arg) : M Bool := do
match a with
| .erased => return false
| .fvar fvarId => fvarDepOn fvarId
| .type e => typeDepOn e
private def letExprDepOn (e : LetExpr) : M Bool :=
match e with
| .erased | .value .. => return false
| .proj _ _ fvarId => fvarDepOn fvarId
| .fvar fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
| .const _ _ args => args.anyM argDepOn
private def LetDecl.depOn (decl : LetDecl) : M Bool :=
exprDepOn decl.type <||> exprDepOn decl.value
typeDepOn decl.type <||> letExprDepOn decl.value
private partial def depOn (c : Code) : M Bool :=
match c with
| .let decl k => decl.depOn <||> depOn k
| .jp decl k | .fun decl k => exprDepOn decl.type <||> depOn decl.value <||> depOn k
| .cases c => exprDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM exprDepOn
| .jp decl k | .fun decl k => typeDepOn decl.type <||> depOn decl.value <||> depOn k
| .cases c => typeDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
| .return fvarId => fvarDepOn fvarId
| .unreach _ => return false
@ -32,7 +45,7 @@ abbrev LetDecl.dependsOn (decl : LetDecl) (s : FVarIdSet) : Bool :=
decl.depOn s
abbrev FunDecl.dependsOn (decl : FunDecl) (s : FVarIdSet) : Bool :=
exprDepOn decl.type s || depOn decl.value s
typeDepOn decl.type s || depOn decl.value s
def CodeDecl.dependsOn (decl : CodeDecl) (s : FVarIdSet) : Bool :=
match decl with

View file

@ -13,26 +13,43 @@ abbrev UsedLocalDecls := FVarIdHashSet
Collect set of (let) free variables in a LCNF value.
This code exploits the LCNF property that local declarations do not occur in types.
-/
def collectLocalDecls (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls :=
go s e
def collectLocalDeclsType (s : UsedLocalDecls) (type : Expr) : UsedLocalDecls :=
go s type
where
go (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls :=
match e with
| .proj _ _ e => go s e
| .forallE .. => s
| .lam _ _ b _ => go s b
| .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
| .app f a => go (go s a) f
| .mdata _ b => go s b
| .fvar fvarId => s.insert fvarId
| .letE .. | .proj .. | .mdata .. => unreachable! -- Valid LCNF type does not contain this kind of expr
| _ => s
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg) : UsedLocalDecls :=
match arg with
| .erased => s
| .type e => collectLocalDeclsType s e
| .fvar fvarId => s.insert fvarId
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array Arg) : UsedLocalDecls :=
args.foldl (init := s) collectLocalDeclsArg
def collectLocalDeclsLetExpr (s : UsedLocalDecls) (e : LetExpr) : UsedLocalDecls :=
match e with
| .erased | .value .. => s
| .proj _ _ fvarId => s.insert fvarId
| .const _ _ args => collectLocalDeclsArgs s args
| .fvar fvarId args => collectLocalDeclsArgs (s.insert fvarId) args
namespace ElimDead
abbrev M := StateRefT UsedLocalDecls CompilerM
private abbrev collectExprM (e : Expr) : M Unit :=
modify (collectLocalDecls · e)
private abbrev collectArgM (arg : Arg) : M Unit :=
modify (collectLocalDeclsArg · arg)
private abbrev collectLetExprM (e : LetExpr) : M Unit :=
modify (collectLocalDeclsLetExpr · e)
private abbrev collectFVarM (fvarId : FVarId) : M Unit :=
modify (·.insert fvarId)
@ -48,7 +65,7 @@ partial def elimDead (code : Code) : M Code := do
let k ← elimDead k
if (← get).contains decl.fvarId then
/- Remark: we don't need to collect `decl.type` because LCNF local declarations do not occur in types. -/
collectExprM decl.value
collectLetExprM decl.value
return code.updateCont! k
else
eraseLetDecl decl
@ -66,7 +83,7 @@ partial def elimDead (code : Code) : M Code := do
collectFVarM c.discr
return code.updateAlts! alts
| .return fvarId => collectFVarM fvarId; return code
| .jmp fvarId args => collectFVarM fvarId; args.forM collectExprM; return code
| .jmp fvarId args => collectFVarM fvarId; args.forM collectArgM; return code
| .unreach .. => return code
end

View file

@ -34,7 +34,7 @@ marked as not fixed.
-/
/-- Abstract value for the "fixed parameter" analysis. -/
inductive Value where
inductive AbsValue where
| top
| erased
| val (i : Nat)
@ -52,7 +52,7 @@ structure Context where
The assignment maps free variable ids in the current code being analyzed to abstract values.
We only track the abstract value assigned to parameters.
-/
assignment : FVarIdMap Value
assignment : FVarIdMap AbsValue
structure State where
/--
@ -61,7 +61,7 @@ structure State where
Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and
we visit with the abstract values assigned to `aᵢ`, but first we record the visit here.
-/
visited : HashSet (Name × Array Value) := {}
visited : HashSet (Name × Array AbsValue) := {}
/--
Bitmask containing the result, i.e., which parameters of `main` are fixed.
We initialize it with `true` everywhere.
@ -76,17 +76,18 @@ abbrev abort : FixParamM α := do
modify fun s => { s with fixed := s.fixed.map fun _ => false }
throw ()
def evalArg (arg : Expr) : FixParamM Value := do
if arg.isErased then
return .erased
let .fvar fvarId := arg | return .top
let some val := (← read).assignment.find? fvarId | return .top
return val
def evalArg (arg : Arg) : FixParamM AbsValue := do
match arg with
| .erased => return .erased
| .type _ => return .top
| .fvar fvarId =>
let some val := (← read).assignment.find? fvarId | return .top
return val
def inMutualBlock (declName : Name) : FixParamM Bool :=
return (← read).decls.any (·.name == declName)
def mkAssignment (decl : Decl) (values : Array Value) : FVarIdMap Value := Id.run do
def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do
let mut assignment := {}
for param in decl.params, value in values do
assignment := assignment.insert param.fvarId value
@ -94,23 +95,19 @@ def mkAssignment (decl : Decl) (values : Array Value) : FVarIdMap Value := Id.ru
mutual
partial def evalExpr (e : Expr) : FixParamM Unit := do
partial def evalLetExpr (e : LetExpr) : FixParamM Unit := do
match e with
| .const declName _ => evalApp declName #[]
| .app .. =>
let .const declName _ := e.getAppFn | return ()
if (← inMutualBlock declName) then
evalApp declName e.getAppArgs
| .const declName _ args => evalApp declName args
| _ => return ()
partial def evalCode (code : Code) : FixParamM Unit := do
match code with
| .let decl k => evalExpr decl.value; evalCode k
| .let decl k => evalLetExpr decl.value; evalCode k
| .fun decl k | .jp decl k => evalCode decl.value; evalCode k
| .cases c => c.alts.forM fun alt => evalCode alt.getCode
| .unreach .. | .jmp .. | .return .. => return ()
partial def evalApp (declName : Name) (args : Array Expr) : FixParamM Unit := do
partial def evalApp (declName : Name) (args : Array Arg) : FixParamM Unit := do
let main := (← read).main
if declName == main.name then
-- Recursive call to the function being analyzed
@ -145,7 +142,7 @@ partial def evalApp (declName : Name) (args : Array Expr) : FixParamM Unit := do
end
def mkInitialValues (numParams : Nat) : Array Value := Id.run do
def mkInitialValues (numParams : Nat) : Array AbsValue := Id.run do
let mut values := #[]
for i in [:numParams] do
values := values.push <| .val i

View file

@ -8,44 +8,6 @@ import Lean.Compiler.LCNF.Basic
namespace Lean.Compiler.LCNF
partial def Code.forEachExpr [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m] (f : Expr → m Unit) (c : Code) (skipTypes := false) : m Unit := do
visit c |>.run
where
visit (c : Code) : MonadCacheT Expr Unit m Unit := do
match c with
| .let decl k =>
visitType decl.type
visitExpr decl.value
visit k
| .jp decl k | .fun decl k =>
visitType decl.type
decl.params.forM visitParam
visit decl.value
visit k
| .unreach type => visitType type
| .return .. => return ()
| .jmp _ args => args.forM visitExpr
| .cases c => visitType c.resultType; c.alts.forM fun alt => do
match alt with
| .default k => visit k
| .alt _ ps k => ps.forM visitParam; visit k
visitParam (p : Param) : MonadCacheT Expr Unit m Unit :=
visitType p.type
visitExpr (e : Expr) : MonadCacheT Expr Unit m Unit :=
ForEachExpr.visit (fun e => f e *> return true) e
visitType (e : Expr) : MonadCacheT Expr Unit m Unit :=
unless skipTypes do
visitExpr e
def Decl.forEachExpr [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m] (f : Expr → m Unit) (decl : Decl) (skipTypes := false) : m Unit := do
visit |>.run
where
visit : MonadCacheT Expr Unit m Unit := do
Code.forEachExpr.visitType f decl.type (skipTypes := skipTypes)
decl.params.forM (Code.forEachExpr.visitParam f (skipTypes := skipTypes))
Code.forEachExpr.visit f decl.value (skipTypes := skipTypes)
-- TODO: delete
end Lean.Compiler.LCNF

View file

@ -101,21 +101,22 @@ def inferConstType (declName : Name) (us : List Level) : CompilerM Expr := do
getOtherDeclType declName us
mutual
partial def inferArgType (arg : Arg) : InferTypeM Expr :=
match arg with
| .erased => return erasedExpr
| .type e => inferType e
| .fvar fvarId => LCNF.getType fvarId
-- TODO: stopped here
partial def inferType (e : Expr) : InferTypeM Expr :=
match e with
| .const c us => inferConstType c us
| .proj n i s => inferProjType n i s
| .app .. => inferAppType e
| .mvar .. => throwError "unexpected metavariable {e}"
| .fvar fvarId => InferType.getType fvarId
| .bvar .. => throwError "unexpected bound variable {e}"
| .mdata _ e => inferType e
| .lit v => return v.type
| .sort lvl => return .sort (mkLevelSucc lvl)
| .forallE .. => inferForallType e
| .lam .. => inferLambdaType e
| .letE .. => inferLambdaType e
| .letE .. | .mvar .. | .mdata .. | .lit .. | .bvar .. | .proj .. => unreachable!
partial def inferAppTypeCore (f : Expr) (args : Array Expr) : InferTypeM Expr := do
let mut j := 0

View file

@ -65,7 +65,7 @@ def LCtx.toLocalContext (lctx : LCtx) : LocalContext := Id.run do
for (_, param) in lctx.params.toArray do
result := result.addDecl (.cdecl 0 param.fvarId param.binderName param.type .default .default)
for (_, decl) in lctx.letDecls.toArray do
result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value true .default)
result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value.toExpr true .default)
for (_, decl) in lctx.funDecls.toArray do
result := result.addDecl (.cdecl 0 decl.fvarId decl.binderName decl.type .default .default)
return result

View file

@ -97,8 +97,25 @@ See `Decl.setLevelParams`.
-/
open Lean.CollectLevelParams
abbrev visitType (type : Expr) : Visitor :=
visitExpr type
def visitArg (arg : Arg) : Visitor :=
match arg with
| .erased | .fvar .. => id
| .type e => visitType e
def visitArgs (args : Array Arg) : Visitor :=
fun s => args.foldl (init := s) fun s arg => visitArg arg s
def visitLetExpr (e : LetExpr) : Visitor :=
match e with
| .erased | .value .. | .proj .. => id
| .const _ us args => visitLevels us ∘ visitArgs args
| .fvar _ args => visitArgs args
def visitParam (p : Param) : Visitor :=
visitExpr p.type
visitType p.type
def visitParams (ps : Array Param) : Visitor :=
fun s => ps.foldl (init := s) fun s p => visitParam p s
@ -113,12 +130,12 @@ mutual
fun s => alts.foldl (init := s) fun s alt => visitAlt alt s
partial def visitCode : Code → Visitor
| .let decl k => visitCode k ∘ visitExpr decl.value ∘ visitExpr decl.type
| .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitExpr decl.type
| .cases c => visitAlts c.alts ∘ visitExpr c.resultType
| .unreach type => visitExpr type
| .let decl k => visitCode k ∘ visitLetExpr decl.value ∘ visitType decl.type
| .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type
| .cases c => visitAlts c.alts ∘ visitType c.resultType
| .unreach type => visitType type
| .return _ => id
| .jmp _ args => fun s => args.foldl (init := s) fun s arg => visitExpr arg s
| .jmp _ args => visitArgs args
end
end CollectLevelParams
@ -131,7 +148,7 @@ Collect universe level parameters collecting in the type, parameters, and value,
set `decl.levelParams` with the resulting value.
-/
def Decl.setLevelParams (decl : Decl) : Decl :=
let levelParams := (visitCode decl.value ∘ visitParams decl.params ∘ visitExpr decl.type) {} |>.params.toList
let levelParams := (visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type) {} |>.params.toList
{ decl with levelParams }
end Lean.Compiler.LCNF

View file

@ -11,25 +11,29 @@ import Lean.Compiler.LCNF.CompilerM
namespace Lean.Compiler.LCNF
namespace Simp
partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
/-
-- TODO: cleanup
partial def findExpr (e : LetExpr) : CompilerM LetExpr := do
match e with
| .fvar fvarId =>
let some decl ← findLetDecl? fvarId | return e
findExpr decl.value
| .mdata _ e' => if skipMData then findExpr e' else return e
| .fvar fvarId args =>
if args.isEmpty then
let some decl ← findLetDecl? fvarId | return e
findExpr decl.value
else
return e
| _ => return e
partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do
partial def findFunDecl? (e : LetExpr) : CompilerM (Option FunDecl) := do
match e with
| .fvar fvarId =>
| .fvar fvarId args =>
if let some decl ← LCNF.findFunDecl? fvarId then
return some decl
else if let some decl ← findLetDecl? fvarId then
findFunDecl? decl.value
else
return none
| .mdata _ e => findFunDecl? e
| _ => return none
-/
end Simp
end Lean.Compiler.LCNF

View file

@ -74,6 +74,9 @@ end ToExpr
open ToExpr
private def Arg.toExprM (arg : Arg) : ToExprM Expr :=
return arg.toExpr.abstract' (← read) (← get)
mutual
partial def FunDeclCore.toExprM (decl : FunDecl) : ToExprM Expr :=
withParams decl.params do mkLambdaM decl.params (← decl.value.toExprM)
@ -82,7 +85,7 @@ partial def Code.toExprM (code : Code) : ToExprM Expr := do
match code with
| .let decl k =>
let type ← abstractM decl.type
let value ← abstractM decl.value
let value ← abstractM decl.value.toExpr
let body ← withFVar decl.fvarId k.toExprM
return .letE decl.binderName type value body true
| .fun decl k | .jp decl k =>
@ -91,7 +94,7 @@ partial def Code.toExprM (code : Code) : ToExprM Expr := do
let body ← withFVar decl.fvarId k.toExprM
return .letE decl.binderName type value body true
| .return fvarId => fvarId.toExprM
| .jmp fvarId args => return mkAppN (← fvarId.toExprM) (← args.mapM abstractM)
| .jmp fvarId args => return mkAppN (← fvarId.toExprM) (← args.mapM Arg.toExprM)
| .unreach type => return mkApp (mkConst ``lcUnreachable) (← abstractM type)
| .cases c =>
let alts ← c.alts.mapM fun