diff --git a/src/Lean/Compiler/LCNF/ToMono.lean b/src/Lean/Compiler/LCNF/ToMono.lean index ad9cf398ac..0aa5694159 100644 --- a/src/Lean/Compiler/LCNF/ToMono.lean +++ b/src/Lean/Compiler/LCNF/ToMono.lean @@ -31,26 +31,38 @@ def checkFVarUse (fvarId : FVarId) : ToMonoM Unit := do if let some declName := (← get).noncomputableVars.get? fvarId then throwError f!"failed to compile definition, consider marking it as 'noncomputable' because it depends on '{declName}', which is 'noncomputable'" -def argToMono (arg : Arg) : ToMonoM Arg := do +def checkFVarUseDeferred (resultFVar fvarId : FVarId) : ToMonoM Unit := do + if let some declName := (← get).noncomputableVars.get? fvarId then + modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName } + +@[inline] +def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg) : ToMonoM Arg := do match arg with | .erased | .type .. => return .erased | .fvar fvarId => if (← get).typeParams.contains fvarId then return .erased else - checkFVarUse fvarId + check fvarId return arg -def ctorAppToMono (ctorInfo : ConstructorVal) (args : Array Arg) : ToMonoM LetValue := do - let argsNew : Array Arg ← args[:ctorInfo.numParams].toArray.mapM fun arg => do +def argToMono (arg : Arg) : ToMonoM Arg := argToMonoBase checkFVarUse arg + +def argToMonoDeferredCheck (resultFVar : FVarId) (arg : Arg) : ToMonoM Arg := + argToMonoBase (checkFVarUseDeferred resultFVar) arg + +def ctorAppToMono (resultFVar : FVarId) (ctorInfo : ConstructorVal) (args : Array Arg) + : ToMonoM LetValue := do + let argsNewParams : Array Arg ← args[:ctorInfo.numParams].toArray.mapM fun arg => do -- We only preserve constructor parameters that are types match arg with | .type type => return .type (← toMonoType type) | .fvar .. | .erased => return .erased - let argsNew := argsNew ++ (← args[ctorInfo.numParams:].toArray.mapM argToMono) + let argsNewFields ← args[ctorInfo.numParams:].toArray.mapM (argToMonoDeferredCheck resultFVar) + let argsNew := argsNewParams ++ argsNewFields return .const ctorInfo.name [] argsNew -partial def LetValue.toMono (e : LetValue) (fvarId : FVarId) : ToMonoM LetValue := do +partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetValue := do match e with | .erased | .lit .. => return e | .const declName _ args => @@ -63,26 +75,25 @@ partial def LetValue.toMono (e : LetValue) (fvarId : FVarId) : ToMonoM LetValue -- and Bool have the same runtime representation. return args[1]!.toLetValue else if let some e' ← isTrivialConstructorApp? declName args then - e'.toMono fvarId + e'.toMono resultFVar else if let some (.ctorInfo ctorInfo) := (← getEnv).find? declName then - ctorAppToMono ctorInfo args + ctorAppToMono resultFVar ctorInfo args else let env ← getEnv if isNoncomputable env declName && !(isExtern env declName) then - modify fun s => { s with noncomputableVars := s.noncomputableVars.insert fvarId declName } - return .const declName [] (← args.mapM argToMono) + modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName } + return .const declName [] (← args.mapM (argToMonoDeferredCheck resultFVar)) | .fvar fvarId args => if (← get).typeParams.contains fvarId then return .erased else - checkFVarUse fvarId - return .fvar fvarId (← args.mapM argToMono) + checkFVarUseDeferred resultFVar fvarId + return .fvar fvarId (← args.mapM (argToMonoDeferredCheck resultFVar)) | .proj structName fieldIdx baseFVar => if (← get).typeParams.contains baseFVar then return .erased else - if let some declName := (← get).noncomputableVars.get? baseFVar then - modify fun s => { s with noncomputableVars := s.noncomputableVars.insert fvarId declName } + checkFVarUseDeferred resultFVar baseFVar if let some info ← hasTrivialStructure? structName then if info.fieldIdx == fieldIdx then return .fvar baseFVar #[]