diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 0d1b4937b4..89ce351ec6 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -568,33 +568,6 @@ where let result := .fvar auxDecl.fvarId mkOverApplication result args casesInfo.arity - visitCasesImplementedBy (casesInfo : CasesInfo) (f : Expr) (args : Array Expr) : M Arg := do - let mut args := args - let discr := args[casesInfo.discrPos]! - if discr matches .fvar _ then - let typeName := casesInfo.declName.getPrefix - let .inductInfo indVal ← getConstInfo typeName | unreachable! - args ← args.mapIdxM fun i arg => do - unless casesInfo.altsRange.start <= i && i < casesInfo.altsRange.stop do return arg - let altIdx := i - casesInfo.altsRange.start - let numParams := casesInfo.altNumParams[altIdx]! - let ctorName := indVal.ctors[altIdx]! - - -- We simplify `casesOn` arguments that simply reconstruct the discriminant and replace - -- them with the actual discriminant. This is required for hash consing to work correctly, - -- and should eventually be fixed by changing the elaborated term to use the original - -- variable. - Meta.MetaM.run' <| Meta.lambdaBoundedTelescope arg numParams fun paramExprs body => do - let fn := body.getAppFn - let args := body.getAppArgs - let args := args.map fun arg => - if arg.getAppFn.constName? == some ctorName && arg.getAppArgs == paramExprs then - discr - else - arg - Meta.mkLambdaFVars paramExprs (mkAppN fn args) - visitAppDefaultConst f args - visitCtor (arity : Nat) (e : Expr) : M Arg := etaIfUnderApplied e arity do visitAppDefaultConst e.getAppFn e.getAppArgs @@ -715,10 +688,7 @@ where else if declName == ``False.rec || declName == ``Empty.rec || declName == ``False.casesOn || declName == ``Empty.casesOn then visitFalseRec e else if let some casesInfo ← getCasesInfo? declName then - if (getImplementedBy? (← getEnv) declName).isSome then - e.withApp (visitCasesImplementedBy casesInfo) - else - visitCases casesInfo e + visitCases casesInfo e else if let some arity ← getCtorArity? declName then visitCtor arity e else if isNoConfusion (← getEnv) declName then diff --git a/src/Lean/Compiler/LCNF/ToMono.lean b/src/Lean/Compiler/LCNF/ToMono.lean index bb8f7013cb..479a7ba69f 100644 --- a/src/Lean/Compiler/LCNF/ToMono.lean +++ b/src/Lean/Compiler/LCNF/ToMono.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ prelude import Lean.Compiler.ExternAttr +import Lean.Compiler.ImplementedByAttr import Lean.Compiler.LCNF.MonoTypes import Lean.Compiler.LCNF.InferType import Lean.Compiler.NoncomputableAttr @@ -107,6 +108,24 @@ def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do let value ← decl.value.toMono decl.fvarId decl.update type value +def mkFieldParamsForComputedFields (ctorType : Expr) (numParams : Nat) (numNewFields : Nat) + (oldFields : Array Param) : ToMonoM (Array Param) := do + let mut type := ctorType + for _ in [0:numParams] do + match type with + | .forallE _ _ body _ => + type := body + | _ => unreachable! + let mut newFields := Array.emptyWithCapacity (oldFields.size + numNewFields) + for _ in [0:numNewFields] do + match type with + | .forallE name fieldType body _ => + let param ← mkParam name (← toMonoType fieldType) false + newFields := newFields.push param + type := body + | _ => unreachable! + return newFields ++ oldFields + mutual partial def FunDecl.toMono (decl : FunDecl) : ToMonoM FunDecl := do @@ -278,12 +297,30 @@ partial def Code.toMono (code : Code) : ToMonoM Code := do else if let some info ← hasTrivialStructure? c.typeName then trivialStructToMono info c else - let type ← toMonoType c.resultType - let alts ← c.alts.mapM fun alt => - match alt with - | .default k => return alt.updateCode (← k.toMono) - | .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono) - return code.updateCases! type c.discr alts + let resultType ← toMonoType c.resultType + let env ← getEnv + let some (.inductInfo inductInfo) := env.find? c.typeName | panic! "expected inductive type" + let casesOnName := mkCasesOnName inductInfo.name + if (getImplementedBy? env casesOnName).isSome then + -- TODO: Enforce that this is only used for computed fields. + let typeName := c.typeName ++ `_impl + let alts ← c.alts.mapM fun alt => do + match alt with + | .default k => return alt.updateCode (← k.toMono) + | .alt ctorName ps k => + let implCtorName := ctorName ++ `_impl + let some (.ctorInfo ctorInfo) := env.find? implCtorName | panic! "expected constructor" + let numNewFields := ctorInfo.numFields - ps.size + let ps ← mkFieldParamsForComputedFields ctorInfo.type ctorInfo.numParams numNewFields ps + let k ← k.toMono + return .alt implCtorName ps k + return .cases { discr := c.discr, resultType, typeName, alts } + else + let alts ← c.alts.mapM fun alt => + match alt with + | .default k => return alt.updateCode (← k.toMono) + | .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono) + return code.updateCases! resultType c.discr alts end