fix: enable more optimizations on inductives with computed fields in the new compiler (#8754)

This PR changes the implementation of computed fields in the new
compiler, which should enable more optimizations (and remove a
questionable hack in `toLCNF` that was only suitable for bringup). We
convert `casesOn` to `cases` like we do for other inductive types, all
constructors get replaced by their real implementations late in the base
phase, and then the `cases` expression is rewritten to use the real
constructors in `toMono`.

In the future, it might be better to move to a model where the `cases`
expression gets rewritten earlier or the constructors get replaced
later, so that both are done at the same time.
This commit is contained in:
Cameron Zwarich 2025-06-12 16:28:09 -07:00 committed by GitHub
parent 8aa003bdfc
commit deda28e6e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 44 additions and 37 deletions

View file

@ -568,33 +568,6 @@ where
let result := .fvar auxDecl.fvarId
mkOverApplication result args casesInfo.arity
visitCasesImplementedBy (casesInfo : CasesInfo) (f : Expr) (args : Array Expr) : M Arg := do
let mut args := args
let discr := args[casesInfo.discrPos]!
if discr matches .fvar _ then
let typeName := casesInfo.declName.getPrefix
let .inductInfo indVal ← getConstInfo typeName | unreachable!
args ← args.mapIdxM fun i arg => do
unless casesInfo.altsRange.start <= i && i < casesInfo.altsRange.stop do return arg
let altIdx := i - casesInfo.altsRange.start
let numParams := casesInfo.altNumParams[altIdx]!
let ctorName := indVal.ctors[altIdx]!
-- We simplify `casesOn` arguments that simply reconstruct the discriminant and replace
-- them with the actual discriminant. This is required for hash consing to work correctly,
-- and should eventually be fixed by changing the elaborated term to use the original
-- variable.
Meta.MetaM.run' <| Meta.lambdaBoundedTelescope arg numParams fun paramExprs body => do
let fn := body.getAppFn
let args := body.getAppArgs
let args := args.map fun arg =>
if arg.getAppFn.constName? == some ctorName && arg.getAppArgs == paramExprs then
discr
else
arg
Meta.mkLambdaFVars paramExprs (mkAppN fn args)
visitAppDefaultConst f args
visitCtor (arity : Nat) (e : Expr) : M Arg :=
etaIfUnderApplied e arity do
visitAppDefaultConst e.getAppFn e.getAppArgs
@ -715,10 +688,7 @@ where
else if declName == ``False.rec || declName == ``Empty.rec || declName == ``False.casesOn || declName == ``Empty.casesOn then
visitFalseRec e
else if let some casesInfo ← getCasesInfo? declName then
if (getImplementedBy? (← getEnv) declName).isSome then
e.withApp (visitCasesImplementedBy casesInfo)
else
visitCases casesInfo e
visitCases casesInfo e
else if let some arity ← getCtorArity? declName then
visitCtor arity e
else if isNoConfusion (← getEnv) declName then

View file

@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Compiler.ExternAttr
import Lean.Compiler.ImplementedByAttr
import Lean.Compiler.LCNF.MonoTypes
import Lean.Compiler.LCNF.InferType
import Lean.Compiler.NoncomputableAttr
@ -107,6 +108,24 @@ def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do
let value ← decl.value.toMono decl.fvarId
decl.update type value
def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFields : Nat)
(oldFields : Array Param) : ToMonoM (Array Param) := do
let mut type := ctorType
for _ in [0:numParams] do
match type with
| .forallE _ _ body _ =>
type := body
| _ => unreachable!
let mut newFields := Array.emptyWithCapacity (oldFields.size + numNewFields)
for _ in [0:numNewFields] do
match type with
| .forallE name fieldType body _ =>
let param ← mkParam name (← toMonoType fieldType) false
newFields := newFields.push param
type := body
| _ => unreachable!
return newFields ++ oldFields
mutual
partial def FunDecl.toMono (decl : FunDecl) : ToMonoM FunDecl := do
@ -278,12 +297,30 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do
else if let some info ← hasTrivialStructure? c.typeName then
trivialStructToMono info c
else
let type ← toMonoType c.resultType
let alts ← c.alts.mapM fun alt =>
match alt with
| .default k => return alt.updateCode (← k.toMono)
| .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
return code.updateCases! type c.discr alts
let resultType ← toMonoType c.resultType
let env ← getEnv
let some (.inductInfo inductInfo) := env.find? c.typeName | panic! "expected inductive type"
let casesOnName := mkCasesOnName inductInfo.name
if (getImplementedBy? env casesOnName).isSome then
-- TODO: Enforce that this is only used for computed fields.
let typeName := c.typeName ++ `_impl
let alts ← c.alts.mapM fun alt => do
match alt with
| .default k => return alt.updateCode (← k.toMono)
| .alt ctorName ps k =>
let implCtorName := ctorName ++ `_impl
let some (.ctorInfo ctorInfo) := env.find? implCtorName | panic! "expected constructor"
let numNewFields := ctorInfo.numFields - ps.size
let ps ← mkFieldParamsForComputedFields ctorInfo.type ctorInfo.numParams numNewFields ps
let k ← k.toMono
return .alt implCtorName ps k
return .cases { discr := c.discr, resultType, typeName, alts }
else
let alts ← c.alts.mapM fun alt =>
match alt with
| .default k => return alt.updateCode (← k.toMono)
| .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
return code.updateCases! resultType c.discr alts
end