feat: simplify nested cases on the same discriminant

This commit is contained in:
Leonardo de Moura 2022-09-03 19:23:04 -07:00
parent bff9cdbfb3
commit 9f44e9c858

View file

@ -103,6 +103,7 @@ structure Config where
structure Context where
config : Config := {}
discrCtorMap : FVarIdMap Expr := {}
structure State where
/--
@ -142,6 +143,24 @@ abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
instance : MonadFVarSubst SimpM where
getSubst := return (← get).subst
/--
Execute `x` with the information that `discr = ctorName ctorFields`.
We use this information to simplify nested cases on the same discriminant.
Remark: we do not perform the reverse direction at this phase.
That is, we do not replace occurrences of `ctorName ctorFields` with `discr`.
We wait more type information to be erased.
-/
def withDiscrCtor (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : SimpM α) : SimpM α := do
let ctorInfo ← getConstInfoCtor ctorName
let mut ctor := mkConst ctorName
for _ in [:ctorInfo.numParams] do
ctor := .app ctor erasedExpr -- the parameters are irrelevant for optimizations that use this information
for field in ctorFields do
ctor := .app ctor (.fvar field.fvarId)
withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor }) do
x
def markSimplified : SimpM Unit :=
modify fun s => { s with simplified := true }
@ -503,8 +522,10 @@ where
return go funDecl.value 0
def findCtor (e : Expr) : SimpM Expr := do
-- TODO: add support for mapping discriminants to constructors in branches
findExpr e
let e ← findExpr e
let .fvar fvarId := e | return e
let some ctor := (← read).discrCtorMap.find? fvarId | return e
return ctor
/--
Try to simplify projections `.proj _ i s` where `s` is constructor.
@ -533,6 +554,11 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
markSimplified
return mkAppN f e.getAppArgs
/-- Try to apply simple simplifications. -/
def simpValue? (e : Expr) : SimpM (Option Expr) :=
-- TODO: more simplifications
simpProj? e <|> simpAppApp? e
def eraseLocalDecl (fvarId : FVarId) : SimpM Unit := do
eraseFVar fvarId
markSimplified
@ -544,11 +570,6 @@ it is a type, type former, or `lcErased`.
def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit :=
modify fun s => { s with subst := s.subst.insert fvarId val }
/-- Try to apply simple simplifications. -/
def simpValue? (e : Expr) : SimpM (Option Expr) :=
-- TODO: more simplifications
simpProj? e <|> simpAppApp? e
mutual
partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
let type ← normExpr decl.type
@ -559,7 +580,7 @@ partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do
/-- Try to simplify `cases` of `constructor` -/
partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
let discr ← normFVar cases.discr
let discrExpr ← findExpr (.fvar discr)
let discrExpr ← findCtor (.fvar discr)
let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) | return none
let (alt, cases) := cases.extractAlt! ctorVal.name
eraseFVarsAt (.cases cases)
@ -656,7 +677,12 @@ partial def simp (code : Code) : SimpM Code := do
let discr ← normFVar c.discr
let resultType ← normExpr c.resultType
markUsedFVar discr
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode)
let alts ← c.alts.mapMonoM fun alt =>
match alt with
| .alt ctorName ps k =>
withDiscrCtor discr ctorName ps do
return alt.updateCode (← simp k)
| .default k => return alt.updateCode (← simp k)
return code.updateCases! resultType discr alts
if let some jpFVarId ← isCasesOnCases? c then
withAddMustInline jpFVarId simpCasesDefault