diff --git a/src/Lean/Compiler/LCNF/Simp/SimpM.lean b/src/Lean/Compiler/LCNF/Simp/SimpM.lean index fb846e8abd..0d304ef12f 100644 --- a/src/Lean/Compiler/LCNF/Simp/SimpM.lean +++ b/src/Lean/Compiler/LCNF/Simp/SimpM.lean @@ -24,8 +24,15 @@ structure Context where -/ declName : Name config : Config := {} + /-- + A mapping from discriminant to constructor application it is equal to in the current context. + -/ discrCtorMap : FVarIdMap Expr := {} /-- + A mapping from constructor application to discriminant it is equal to in the current context. + -/ + ctorDiscrMap : PersistentExprMap FVarId := {} + /-- Stack of global declarations being recursively inlined. -/ inlineStack : List Name := [] @@ -90,6 +97,16 @@ def findCtor (e : Expr) : SimpM Expr := do let some ctor := (← read).discrCtorMap.find? fvarId | return e return ctor +/-- +If `type` is an inductive datatype, return its universe levels and parameters. +-/ +def getIndInfo? (type : Expr) : CoreM (Option (List Level × Array Expr)) := do + let type := type.headBeta + let .const declName us := type.getAppFn | return none + let .inductInfo info ← getConstInfo declName | return none + unless type.getAppNumArgs >= info.numParams do return none + return some (us, type.getAppArgs[:info.numParams]) + /-- Execute `x` with the information that `discr = ctorName ctorFields`. We use this information to simplify nested cases on the same discriminant. @@ -100,13 +117,16 @@ We wait more type information to be erased. -/ def withDiscrCtor (discr : FVarId) (ctorName : Name) (ctorFields : Array Param) (x : SimpM α) : SimpM α := do let ctorInfo ← getConstInfoCtor ctorName - let mut ctor := mkConst ctorName - for _ in [:ctorInfo.numParams] do - ctor := .app ctor erasedExpr -- the parameters are irrelevant for optimizations that use this information - for field in ctorFields do - ctor := .app ctor (.fvar field.fvarId) - withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor }) do - x + let fieldArgs := ctorFields.map (.fvar ·.fvarId) + if let some (us, params) ← getIndInfo? (← getType discr) then + let ctor := mkAppN (mkAppN (mkConst ctorName us) params) fieldArgs + withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor, ctorDiscrMap := ctx.ctorDiscrMap.insert ctor discr }) do + x + else + -- For the discrCtor map, the constructor parameters are irrelevant for optimizations that use this information + let ctor := mkAppN (mkAppN (mkConst ctorName) (mkArray ctorInfo.numParams erasedExpr)) fieldArgs + withReader (fun ctx => { ctx with discrCtorMap := ctx.discrCtorMap.insert discr ctor }) do + x /-- Set the `simplified` flag to `true`. -/ def markSimplified : SimpM Unit := diff --git a/src/Lean/Compiler/LCNF/Simp/SimpValue.lean b/src/Lean/Compiler/LCNF/Simp/SimpValue.lean index 77759d4fc4..9ba22c4186 100644 --- a/src/Lean/Compiler/LCNF/Simp/SimpValue.lean +++ b/src/Lean/Compiler/LCNF/Simp/SimpValue.lean @@ -35,6 +35,11 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do markSimplified return mkAppN f e.getAppArgs +def simpCtorDiscr? (e : Expr) : OptionT SimpM Expr := do + let some discr := (← read).ctorDiscrMap.find? e | failure + guard <| compatibleTypes (← getType discr) (← inferType e) + return .fvar discr + def applyImplementedBy? (e : Expr) : OptionT SimpM Expr := do guard <| (← read).config.implementedBy let .const declName us := e.getAppFn | failure @@ -45,4 +50,4 @@ def applyImplementedBy? (e : Expr) : OptionT SimpM Expr := do /-- Try to apply simple simplifications. -/ def simpValue? (e : Expr) : SimpM (Option Expr) := -- TODO: more simplifications - simpProj? e <|> simpAppApp? e <|> applyImplementedBy? e + simpProj? e <|> simpAppApp? e <|> simpCtorDiscr? e <|> applyImplementedBy? e