fix: increase precision of new compiler's noncomputable check (#8675)

This PR increases the precision of the new compiler's non computable
check, particularly around irrelevant uses of `noncomputable` defs in
applications.

There are no tests included because they don't pass with the old
compiler. They are on the new compiler's branch and they will be merged
when it is enabled.
This commit is contained in:
Cameron Zwarich 2025-06-07 15:20:55 -07:00 committed by GitHub
parent 4abc4430dc
commit 8d8fd0715f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -31,26 +31,38 @@ def checkFVarUse (fvarId : FVarId) : ToMonoM Unit := do
if let some declName := (← get).noncomputableVars.get? fvarId then
throwError f!"failed to compile definition, consider marking it as 'noncomputable' because it depends on '{declName}', which is 'noncomputable'"
def argToMono (arg : Arg) : ToMonoM Arg := do
def checkFVarUseDeferred (resultFVar fvarId : FVarId) : ToMonoM Unit := do
if let some declName := (← get).noncomputableVars.get? fvarId then
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName }
@[inline]
def argToMonoBase (check : FVarId → ToMonoM Unit) (arg : Arg) : ToMonoM Arg := do
match arg with
| .erased | .type .. => return .erased
| .fvar fvarId =>
if (← get).typeParams.contains fvarId then
return .erased
else
checkFVarUse fvarId
check fvarId
return arg
def ctorAppToMono (ctorInfo : ConstructorVal) (args : Array Arg) : ToMonoM LetValue := do
let argsNew : Array Arg ← args[:ctorInfo.numParams].toArray.mapM fun arg => do
def argToMono (arg : Arg) : ToMonoM Arg := argToMonoBase checkFVarUse arg
def argToMonoDeferredCheck (resultFVar : FVarId) (arg : Arg) : ToMonoM Arg :=
argToMonoBase (checkFVarUseDeferred resultFVar) arg
def ctorAppToMono (resultFVar : FVarId) (ctorInfo : ConstructorVal) (args : Array Arg)
: ToMonoM LetValue := do
let argsNewParams : Array Arg ← args[:ctorInfo.numParams].toArray.mapM fun arg => do
-- We only preserve constructor parameters that are types
match arg with
| .type type => return .type (← toMonoType type)
| .fvar .. | .erased => return .erased
let argsNew := argsNew ++ (← args[ctorInfo.numParams:].toArray.mapM argToMono)
let argsNewFields ← args[ctorInfo.numParams:].toArray.mapM (argToMonoDeferredCheck resultFVar)
let argsNew := argsNewParams ++ argsNewFields
return .const ctorInfo.name [] argsNew
partial def LetValue.toMono (e : LetValue) (fvarId : FVarId) : ToMonoM LetValue := do
partial def LetValue.toMono (e : LetValue) (resultFVar : FVarId) : ToMonoM LetValue := do
match e with
| .erased | .lit .. => return e
| .const declName _ args =>
@ -63,26 +75,25 @@ partial def LetValue.toMono (e : LetValue) (fvarId : FVarId) : ToMonoM LetValue
-- and Bool have the same runtime representation.
return args[1]!.toLetValue
else if let some e' ← isTrivialConstructorApp? declName args then
e'.toMono fvarId
e'.toMono resultFVar
else if let some (.ctorInfo ctorInfo) := (← getEnv).find? declName then
ctorAppToMono ctorInfo args
ctorAppToMono resultFVar ctorInfo args
else
let env ← getEnv
if isNoncomputable env declName && !(isExtern env declName) then
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert fvarId declName }
return .const declName [] (← args.mapM argToMono)
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert resultFVar declName }
return .const declName [] (← args.mapM (argToMonoDeferredCheck resultFVar))
| .fvar fvarId args =>
if (← get).typeParams.contains fvarId then
return .erased
else
checkFVarUse fvarId
return .fvar fvarId (← args.mapM argToMono)
checkFVarUseDeferred resultFVar fvarId
return .fvar fvarId (← args.mapM (argToMonoDeferredCheck resultFVar))
| .proj structName fieldIdx baseFVar =>
if (← get).typeParams.contains baseFVar then
return .erased
else
if let some declName := (← get).noncomputableVars.get? baseFVar then
modify fun s => { s with noncomputableVars := s.noncomputableVars.insert fvarId declName }
checkFVarUseDeferred resultFVar baseFVar
if let some info ← hasTrivialStructure? structName then
if info.fieldIdx == fieldIdx then
return .fvar baseFVar #[]