/-
Source: lean4-htt/src/Lean/Compiler/LCNF/CSE.lean
Snapshot: 2022-09-26 05:46:04 -07:00 (112 lines, 3.3 KiB).
Note: the hosting page warned that this file contains Unicode characters that might be
confused with other characters; non-ASCII symbols such as `←` are ordinary Lean syntax.
-/
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.ToExpr
import Lean.Compiler.LCNF.PassManager
namespace Lean.Compiler.LCNF
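/-!
# Common Sub-expression Elimination

A `let` declaration (or local `fun` declaration) whose normalized value is already bound by
an earlier declaration in scope is removed, and its variable is redirected to the earlier one.
-/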
namespace CSE
/-- State threaded through the CSE traversal. -/
structure State where
  /--
  Maps a value expression to the free variable that already binds it in the current scope.
  Extended by `addEntry`, looked up by `Code.cse`, and saved/restored by `withNewScope`.
  -/
  map : PHashMap Expr FVarId := {}
  /-- Substitution that redirects eliminated free variables (see `replaceLet`/`replaceFun`). -/
  subst : FVarSubst := {}
/-- The CSE monad: `CompilerM` extended with a mutable reference to a `State`. -/
abbrev M := StateRefT State CompilerM
/--
Exposes `State.subst` as the current substitution of `M`.
NOTE(review): the meaning of the `false` index is defined by the `MonadFVarSubst` class.
-/
instance : MonadFVarSubst M false where
  getSubst := do
    let st ← get
    return st.subst
/-- Updates the substitution by applying the given function to `State.subst`. -/
instance : MonadFVarSubstState M where
  modifySubst g :=
    modify fun st => { st with subst := g st.subst }
/-- Returns the substitution currently stored in the CSE state. -/
@[inline] def getSubst : M FVarSubst := do
  let st ← get
  return st.subst
/--
Records in `State.map` that `value` is bound to `fvarId`, so that a later declaration
with the same value can be replaced by `fvarId`.
-/
@[inline] def addEntry (value : Expr) (fvarId : FVarId) : M Unit :=
  modify fun st =>
    let map := st.map.insert value fvarId
    { st with map }
/--
Runs `x` and afterwards restores `State.map` to its value before `x` ran, even if `x`
throws. Map entries added inside `x` are therefore not visible afterwards.
Only the map is restored; `State.subst` keeps any changes made by `x`.
-/
@[inline] def withNewScope (x : M α) : M α := do
  let savedMap := (← get).map
  try
    x
  finally
    modify fun st => { st with map := savedMap }
/--
Eliminates the `let` declaration `decl` in favor of `fvarId`: the declaration is erased
with `eraseLetDecl`, and `decl.fvarId ↦ fvarId` is added to the substitution.
-/
def replaceLet (decl : LetDecl) (fvarId : FVarId) : M Unit := do
  let redundant := decl.fvarId
  eraseLetDecl decl
  addFVarSubst redundant fvarId
/--
Eliminates the local function declaration `decl` in favor of `fvarId`: the declaration is
erased with `eraseFunDecl`, and `decl.fvarId ↦ fvarId` is added to the substitution.
-/
def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do
  let redundant := decl.fvarId
  eraseFunDecl decl
  addFVarSubst redundant fvarId
/--
Common sub-expression elimination on `code`.

The traversal keeps a map from normalized `let` values (and from the `Expr` encoding of
local functions) to the free variable that first bound them. When a later `let`/`fun`
has a value that is already in the map, that declaration is erased, its variable is
redirected to the earlier one through the substitution, and traversal continues with the
continuation. Join points are traversed but never merged. The traversal starts from an
empty `State`, and the final state is discarded.
-/
partial def _root_.Lean.Compiler.LCNF.Code.cse (code : Code) : CompilerM Code :=
  go code |>.run' {}
where
  -- Normalizes the type and parameters of a local function or join point and runs CSE on
  -- its body. The body is processed in a fresh scope, so map entries found inside it are
  -- not visible after the declaration.
  goFunDecl (decl : FunDecl) : M FunDecl := do
    let type ← normExpr decl.type
    let params ← normParams decl.params
    let value ← withNewScope do go decl.value
    decl.update type params value

  -- Main traversal. Map entries added while processing a declaration stay visible in its
  -- continuation `k`; callers use `withNewScope` to keep them from escaping nested scopes.
  go (code : Code) : M Code := do
    match code with
    | .let decl k =>
      -- Normalize first (presumably applying the current substitution), so that values
      -- that became equal after earlier eliminations are recognized as equal.
      let decl ← normLetDecl decl
      -- We only apply CSE to pure code
      -- NOTE(review): no purity check happens here; this presumably relies on every LCNF
      -- `let` value being pure at this stage — confirm.
      match (← get).map.find? decl.value with
      | some fvarId =>
        -- The same value is already bound to `fvarId`: drop this binding and reuse `fvarId`.
        replaceLet decl fvarId
        go k
      | none =>
        addEntry decl.value decl.fvarId
        return code.updateLet! decl (← go k)
    | .fun decl k =>
      let decl ← goFunDecl decl
      -- Local functions are compared through their `Expr` encoding, computed after CSE
      -- has already been applied to the function body.
      let value := decl.toExpr
      match (← get).map.find? value with
      | some fvarId' =>
        replaceFun decl fvarId'
        go k
      | none =>
        addEntry value decl.fvarId
        return code.updateFun! decl (← go k)
    | .jp decl k =>
      let decl ← goFunDecl decl
      /-
      We currently don't eliminate common join points because we want to prevent
      jumps to out-of-scope join points.
      -/
      -- NOTE(review): `updateFun!` is applied to a `.jp` node here; this presumably relies
      -- on `updateFun!` preserving the `.jp` constructor — confirm.
      return code.updateFun! decl (← go k)
    | .cases c =>
      let discr ← normFVar c.discr
      let resultType ← normExpr c.resultType
      -- Each alternative is processed in its own scope, so a value bound in one branch is
      -- never reused in a sibling branch or after the `cases`.
      let alts ← c.alts.mapMonoM fun alt => do
        match alt with
        | .alt _ ps k => withNewScope do
          return alt.updateAlt! (← normParams ps) (← go k)
        | .default k => withNewScope do return alt.updateCode (← go k)
      return code.updateCases! resultType discr alts
    | .return fvarId => return code.updateReturn! (← normFVar fvarId)
    | .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args)
    -- `unreach` is returned unchanged; its type is not normalized here.
    | .unreach .. => return code
end CSE
/--
Common sub-expression elimination for a declaration: runs `Code.cse` on `decl.value` and
returns `decl` with the transformed body. All other fields of `decl` are left unchanged.
-/
def Decl.cse (decl : Decl) : CompilerM Decl := do
  return { decl with value := (← decl.value.cse) }
/-- The `cse` compiler pass, built with `Pass.mkPerDeclaration` from `Decl.cse` for the `base` phase. -/
def cse : Pass :=
  Pass.mkPerDeclaration `cse Decl.cse .base
-- Registers the `Compiler.cse` trace class at initialization time.
-- NOTE(review): `inherited := true` presumably lets it follow the parent `Compiler` trace
-- option — confirm against `registerTraceClass`.
builtin_initialize
  registerTraceClass `Compiler.cse (inherited := true)
end Lean.Compiler.LCNF