112 lines
3.3 KiB
Text
112 lines
3.3 KiB
Text
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
|
||
import Lean.Compiler.LCNF.CompilerM
|
||
import Lean.Compiler.LCNF.ToExpr
|
||
import Lean.Compiler.LCNF.PassManager
|
||
|
||
namespace Lean.Compiler.LCNF
|
||
|
||
/-! Common Sub-expression Elimination -/
|
||
|
||
namespace CSE
|
||
|
||
/-- State carried by the CSE traversal. Both fields start out empty. -/
structure State where
  /-- Table from an already-seen value to the free variable that was bound to it
  (entries are added by `addEntry` and rolled back by `withNewScope`). -/
  map   : PHashMap Expr FVarId := {}
  /-- Free-variable substitution accumulated while redundant declarations are
  replaced (see `replaceLet` / `replaceFun`). -/
  subst : FVarSubst := {}
|
||
|
||
/-- Monad in which the CSE traversal runs: `CompilerM` extended with a mutable `State`. -/
abbrev M := StateRefT State CompilerM
|
||
|
||
/-- Exposes the `subst` field of the CSE `State` as the current free-variable substitution.
NOTE(review): the meaning of the `false` argument is defined by `MonadFVarSubst`; confirm there. -/
instance : MonadFVarSubst M false where
  getSubst := do
    let st ← get
    return st.subst
|
||
|
||
/-- Lets generic substitution helpers update the `subst` field of the CSE `State`;
every other field is left untouched. -/
instance : MonadFVarSubstState M where
  modifySubst update :=
    modify fun st => { st with subst := update st.subst }
|
||
|
||
/-- Returns the free-variable substitution currently stored in the CSE `State`. -/
@[inline] def getSubst : M FVarSubst := do
  let st ← get
  return st.subst
|
||
|
||
/-- Records in the table that `value` is bound to the free variable `fvarId`. -/
@[inline] def addEntry (value : Expr) (fvarId : FVarId) : M Unit :=
  modify fun st =>
    { st with map := st.map.insert value fvarId }
|
||
|
||
/--
Runs `x` and afterwards restores the value table to what it was on entry, so that
entries added by `x` are not visible to the code that follows. The table is restored
in a `finally` block, i.e. also when `x` throws. Only `map` is restored; `subst`
keeps whatever `x` added to it.
-/
@[inline] def withNewScope (x : M α) : M α := do
  let saved := (← get).map
  try
    x
  finally
    modify fun st => { st with map := saved }
|
||
|
||
/--
Discards the redundant let-declaration `decl`: it is erased with `eraseLetDecl`, and
`decl.fvarId` is registered together with `fvarId` in the substitution via
`addFVarSubst` (presumably so that later uses of `decl.fvarId` are normalized to
`fvarId` — confirm against `addFVarSubst`).
-/
def replaceLet (decl : LetDecl) (fvarId : FVarId) : M Unit := do
  let redundant := decl.fvarId
  eraseLetDecl decl
  addFVarSubst redundant fvarId
|
||
|
||
/--
Discards the redundant local function declaration `decl`: it is erased with
`eraseFunDecl`, and `decl.fvarId` is registered together with `fvarId` in the
substitution via `addFVarSubst` (presumably so that later uses of `decl.fvarId` are
normalized to `fvarId` — confirm against `addFVarSubst`).
-/
def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do
  let redundant := decl.fvarId
  eraseFunDecl decl
  addFVarSubst redundant fvarId
|
||
|
||
/--
Common sub-expression elimination over `code`.

The traversal starts from an empty `CSE.State`. Whenever the (normalized) value of a
`let` or local `fun` declaration is already present in the table, the declaration is
dropped via `replaceLet` / `replaceFun` and only the continuation is kept; otherwise
the value is added to the table and the declaration is retained.
-/
partial def _root_.Lean.Compiler.LCNF.Code.cse (code : Code) : CompilerM Code :=
  (go code).run' {}
where
  -- Normalizes the type and parameters of `decl` and processes its body.
  -- The body runs under `withNewScope`, so table entries added while visiting it
  -- are rolled back before the caller continues.
  goFunDecl (decl : FunDecl) : M FunDecl := do
    let type ← normExpr decl.type
    let params ← normParams decl.params
    let value ← withNewScope (go decl.value)
    decl.update type params value

  -- Main traversal: rebuilds `code` bottom-up with redundant declarations removed.
  go (code : Code) : M Code := do
    match code with
    | .let decl k =>
      let decl ← normLetDecl decl
      -- We only apply CSE to pure code
      if let some fvarId := (← get).map.find? decl.value then
        -- Same value already bound: drop this declaration, keep the continuation.
        replaceLet decl fvarId
        go k
      else
        addEntry decl.value decl.fvarId
        let k ← go k
        return code.updateLet! decl k
    | .fun decl k =>
      let decl ← goFunDecl decl
      -- Local functions are looked up in the table through their `toExpr` form.
      let value := decl.toExpr
      if let some fvarId' := (← get).map.find? value then
        replaceFun decl fvarId'
        go k
      else
        addEntry value decl.fvarId
        let k ← go k
        return code.updateFun! decl k
    | .jp decl k =>
      let decl ← goFunDecl decl
      /-
      We currently don't eliminate common join points because we want to prevent
      jumps to out-of-scope join points.
      -/
      let k ← go k
      return code.updateFun! decl k
    | .cases c =>
      let discr ← normFVar c.discr
      let resultType ← normExpr c.resultType
      -- Every alternative is visited in its own scope, so table entries do not
      -- leak from one alternative into the next.
      let alts ← c.alts.mapMonoM fun alt => withNewScope do
        match alt with
        | .alt _ ps k =>
          let ps ← normParams ps
          let k ← go k
          return alt.updateAlt! ps k
        | .default k =>
          let k ← go k
          return alt.updateCode k
      return code.updateCases! resultType discr alts
    | .return fvarId =>
      let fvarId ← normFVar fvarId
      return code.updateReturn! fvarId
    | .jmp fvarId args =>
      let fvarId ← normFVar fvarId
      let args ← normExprs args
      return code.updateJmp! fvarId args
    | .unreach .. =>
      return code
|
||
|
||
end CSE
|
||
|
||
/--
Common sub-expression elimination for a declaration: runs `Code.cse` on the
declaration's `value` and returns `decl` with only that field replaced.
-/
def Decl.cse (decl : Decl) : CompilerM Decl := do
  let optimized ← decl.value.cse
  return { decl with value := optimized }
|
||
|
||
/-- The `cse` compiler pass: applies `Decl.cse` to each declaration, registered
with the `.base` argument of `Pass.mkPerDeclaration`. -/
def cse : Pass :=
  Pass.mkPerDeclaration `cse Decl.cse .base
|
||
|
||
-- Registers the `Compiler.cse` trace class (with `inherited := true`) at initialization.
builtin_initialize
  registerTraceClass `Compiler.cse (inherited := true)
|
||
|
||
end Lean.Compiler.LCNF
|