From 9f44e9c858a79154936ebfa208acec075f05b651 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 3 Sep 2022 19:23:04 -0700 Subject: [PATCH] feat: simplify nested cases on the same discriminant --- src/Lean/Compiler/LCNF/Simp.lean | 44 +++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index e2f46c8b0c..b2225e583d 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -103,6 +103,7 @@ structure Config where structure Context where config : Config := {} + discrCtorMap : FVarIdMap Expr := {} structure State where /-- @@ -142,6 +143,24 @@ abbrev SimpM := ReaderT Context $ StateRefT State CompilerM instance : MonadFVarSubst SimpM where getSubst := return (← get).subst +/-- +Execute `x` with the information that `discr = ctorName ctorFields`. +We use this information to simplify nested cases on the same discriminant. + +Remark: we do not perform the reverse direction at this phase. +That is, we do not replace occurrences of `ctorName ctorFields` with `discr`. +We wait more type information to be erased. +-/ +def withDiscrCtor (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : SimpM α) : SimpM α := do + let ctorInfo ← getConstInfoCtor ctorName + let mut ctor := mkConst ctorName + for _ in [:ctorInfo.numParams] do + ctor := .app ctor erasedExpr -- the parameters are irrelevant for optimizations that use this information + for field in ctorFields do + ctor := .app ctor (.fvar field.fvarId) + withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor }) do + x + def markSimplified : SimpM Unit := modify fun s => { s with simplified := true } @@ -503,8 +522,10 @@ where return go funDecl.value 0 def findCtor (e : Expr) : SimpM Expr := do - -- TODO: add support for mapping discriminants to constructors in branches - findExpr e + let e ← findExpr e + let .fvar fvarId := e | return e + let some ctor := (← read).discrCtorMap.find? fvarId | return e + return ctor /-- Try to simplify projections `.proj _ i s` where `s` is constructor. @@ -533,6 +554,11 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do markSimplified return mkAppN f e.getAppArgs +/-- Try to apply simple simplifications. -/ +def simpValue? (e : Expr) : SimpM (Option Expr) := + -- TODO: more simplifications + simpProj? e <|> simpAppApp? e + def eraseLocalDecl (fvarId : FVarId) : SimpM Unit := do eraseFVar fvarId markSimplified @@ -544,11 +570,6 @@ it is a type, type former, or `lcErased`. def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit := modify fun s => { s with subst := s.subst.insert fvarId val } -/-- Try to apply simple simplifications. -/ -def simpValue? (e : Expr) : SimpM (Option Expr) := - -- TODO: more simplifications - simpProj? e <|> simpAppApp? e - mutual partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do let type ← normExpr decl.type @@ -559,7 +580,7 @@ partial def simpFunDecl (decl : FunDecl) : SimpM FunDecl := do /-- Try to simplify `cases` of `constructor` -/ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do let discr ← normFVar cases.discr - let discrExpr ← findExpr (.fvar discr) + let discrExpr ← findCtor (.fvar discr) let some (ctorVal, ctorArgs) := discrExpr.constructorApp? (← getEnv) | return none let (alt, cases) := cases.extractAlt! ctorVal.name eraseFVarsAt (.cases cases) @@ -656,7 +677,12 @@ partial def simp (code : Code) : SimpM Code := do let discr ← normFVar c.discr let resultType ← normExpr c.resultType markUsedFVar discr - let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← simp alt.getCode) + let alts ← c.alts.mapMonoM fun alt => + match alt with + | .alt ctorName ps k => + withDiscrCtor discr ctorName ps do + return alt.updateCode (← simp k) + | .default k => return alt.updateCode (← simp k) return code.updateCases! resultType discr alts if let some jpFVarId ← isCasesOnCases? c then withAddMustInline jpFVarId simpCasesDefault