feat: simplify nested cases on the same discriminant
This commit is contained in:
parent
bff9cdbfb3
commit
9f44e9c858
1 changed files with 35 additions and 9 deletions
|
|
@ -103,6 +103,7 @@ structure Config where
|
|||
|
||||
structure Context where
|
||||
config : Config := {}
|
||||
discrCtorMap : FVarIdMap Expr := {}
|
||||
|
||||
structure State where
|
||||
/--
|
||||
|
|
@ -142,6 +143,24 @@ abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
|
|||
instance : MonadFVarSubst SimpM where
|
||||
getSubst := return (← get).subst
|
||||
|
||||
/--
|
||||
Execute `x` with the information that `discr = ctorName ctorFields`.
|
||||
We use this information to simplify nested cases on the same discriminant.
|
||||
|
||||
Remark: we do not perform the reverse direction at this phase.
|
||||
That is, we do not replace occurrences of `ctorName ctorFields` with `discr`.
|
||||
We wait more type information to be erased.
|
||||
-/
|
||||
def withDiscrCtor (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : SimpM α) : SimpM α := do
|
||||
let ctorInfo ← getConstInfoCtor ctorName
|
||||
let mut ctor := mkConst ctorName
|
||||
for _ in [:ctorInfo.numParams] do
|
||||
ctor := .app ctor erasedExpr -- the parameters are irrelevant for optimizations that use this information
|
||||
for field in ctorFields do
|
||||
ctor := .app ctor (.fvar field.fvarId)
|
||||
withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor }) do
|
||||
x
|
||||
|
||||
def markSimplified : SimpM Unit :=
|
||||
modify fun s => { s with simplified := true }
|
||||
|
||||
|
|
@ -503,8 +522,10 @@ where
|
|||
return go funDecl.value 0
|
||||
|
||||
def findCtor (e : Expr) : SimpM Expr := do
|
||||
-- TODO: add support for mapping discriminants to constructors in branches
|
||||
findExpr e
|
||||
let e ← findExpr e
|
||||
let .fvar fvarId := e | return e
|
||||
let some ctor := (← read).discrCtorMap.find? fvarId | return e
|
||||
return ctor
|
||||
|
||||
/--
|
||||
Try to simplify projections `.proj _ i s` where `s` is constructor.
|
||||
|
|
@ -533,6 +554,11 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
|
|||
markSimplified
|
||||
return mkAppN f e.getAppArgs
|
||||
|
||||
/-- Try to apply simple simplifications. -/
|
||||
def simpValue? (e : Expr) : SimpM (Option Expr) :=
|
||||
-- TODO: more simplifications
|
||||
simpProj? e <|> simpAppApp? e
|
||||
|
||||
def eraseLocalDecl (fvarId : FVarId) : SimpM Unit := do
|
||||
eraseFVar fvarId
|
||||
markSimplified
|
||||
|
|
@ -544,11 +570,6 @@ it is a type, type former, or `lcErased`.
|
|||
def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit :=
|
||||
modify fun s => { s with subst := s.subst.insert fvarId val }
|
||||
|
||||
/-- Try to apply simple simplifications. -/
|
||||
def simpValue? (e : Expr) : SimpM (Option Expr) :=
|
||||
-- TODO: more simplifications
|
||||
simpProj? e <|> simpAppApp? e
|
||||
|
||||
mutual
|
||||
partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
|
||||
let type ← normExpr decl.type
|
||||
|
|
@ -559,7 +580,7 @@ partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
|
|||
/-- Try to simplify `cases` of `constructor` -/
|
||||
partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
||||
let discr ← normFVar cases.discr
|
||||
let discrExpr ← findExpr (.fvar discr)
|
||||
let discrExpr ← findCtor (.fvar discr)
|
||||
let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) | return none
|
||||
let (alt, cases) := cases.extractAlt! ctorVal.name
|
||||
eraseFVarsAt (.cases cases)
|
||||
|
|
@ -656,7 +677,12 @@ partial def simp (code : Code) : SimpM Code := do
|
|||
let discr ← normFVar c.discr
|
||||
let resultType ← normExpr c.resultType
|
||||
markUsedFVar discr
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
|
||||
let alts ← c.alts.mapMonoM fun alt =>
|
||||
match alt with
|
||||
| .alt ctorName ps k =>
|
||||
withDiscrCtor discr ctorName ps do
|
||||
return alt.updateCode (← simp k)
|
||||
| .default k => return alt.updateCode (← simp k)
|
||||
return code.updateCases! resultType discr alts
|
||||
if let some jpFVarId ← isCasesOnCases? c then
|
||||
withAddMustInline jpFVarId simpCasesDefault
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue