perf: handle per-constructor noConfusion in toLCNF (#11566)

This PR lets the compiler treat per-constructor `noConfusion` like the
general one, and moves some more logic closer to no confusion
generation.
This commit is contained in:
Joachim Breitner 2025-12-10 10:03:55 +01:00 committed by GitHub
parent 06037ade0f
commit 3b40682b22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 99 additions and 71 deletions

View file

@ -62,12 +62,31 @@ public def isSparseCasesOn (env : Environment) (declName : Name) : Bool :=
public def isCasesOnLike (env : Environment) (declName : Name) : Bool :=
isCasesOnRecursor env declName || isSparseCasesOn env declName
builtin_initialize noConfusionExt : TagDeclarationExtension ← mkTagDeclarationExtension
/--
Shape information for no confusion lemmas.
The `arity` does not include the final major argument (which is not there when the constructors differ)
The regular no confusion lemma marks the lhs and rhs arguments for the compiler to look at and
find the number of fields.
The per-constructor no confusion lemmas know the number of (non-prop) fields statically.
-/
inductive NoConfusionInfo where
| regular (arity : Nat) (lhs : Nat) (rhs : Nat)
| perCtor (arity : Nat) (fields : Nat)
deriving Inhabited
def markNoConfusion (env : Environment) (n : Name) : Environment :=
noConfusionExt.tag env n
def NoConfusionInfo.arity : NoConfusionInfo → Nat
| .regular arity _ _ => arity
| .perCtor arity _ => arity
builtin_initialize noConfusionExt : MapDeclarationExtension NoConfusionInfo ← mkMapDeclarationExtension (asyncMode := .mainOnly)
def markNoConfusion (env : Environment) (n : Name) (info : NoConfusionInfo) : Environment :=
noConfusionExt.insert env n info
def isNoConfusion (env : Environment) (n : Name) : Bool :=
noConfusionExt.isTagged env n
noConfusionExt.contains env n
def getNoConfusionInfo (env : Environment) (n : Name) : NoConfusionInfo :=
(noConfusionExt.find? env n).get!
end Lean

View file

@ -664,32 +664,38 @@ where
visitNoConfusion (e : Expr) : M Arg := do
let .const declName _ := e.getAppFn | unreachable!
let info := getNoConfusionInfo (← getEnv) declName
let typeName := declName.getPrefix
let .inductInfo inductVal ← getConstInfo typeName | unreachable!
let arity := inductVal.numParams + 1 /- motive -/ + 3*(inductVal.numIndices + 1) /- lhs/rhs and equalities -/
etaIfUnderApplied e arity do
etaIfUnderApplied e info.arity do
let args := e.getAppArgs
let lhs ← liftMetaM do Meta.whnf args[inductVal.numParams + 1 + inductVal.numIndices]!
let rhs ← liftMetaM do Meta.whnf args[inductVal.numParams + 1 + inductVal.numIndices + 1 + inductVal.numIndices]!
let lhs ← liftMetaM lhs.toCtorIfLit
let rhs ← liftMetaM rhs.toCtorIfLit
match (← liftMetaM <| Meta.isConstructorApp? lhs), (← liftMetaM <| Meta.isConstructorApp? rhs) with
| some lhsCtorVal, some rhsCtorVal =>
if lhsCtorVal.name == rhsCtorVal.name then
etaIfUnderApplied e (arity+1) do
let major := args[arity]!
let visitMajor (numNonPropFields : Nat) := do
etaIfUnderApplied e (info.arity+1) do
let major := args[info.arity]!
let major ← expandNoConfusionMajor major numNonPropFields
let major := mkAppN major args[(info.arity+1)...*]
visit major
match info with
| .regular _ lhsPos rhsPos =>
let lhs ← liftMetaM do Meta.whnf args[lhsPos]!
let rhs ← liftMetaM do Meta.whnf args[rhsPos]!
let lhs ← liftMetaM lhs.toCtorIfLit
let rhs ← liftMetaM rhs.toCtorIfLit
match (← liftMetaM <| Meta.isConstructorApp? lhs), (← liftMetaM <| Meta.isConstructorApp? rhs) with
| some lhsCtorVal, some rhsCtorVal =>
if lhsCtorVal.name == rhsCtorVal.name then
let numNonPropFields ← liftMetaM <| Meta.forallTelescope lhsCtorVal.type fun params _ =>
params[lhsCtorVal.numParams...*].foldlM (init := 0) fun n param => do
let type ← param.fvarId!.getType
return if !(← Meta.isProp type) then n + 1 else n
let major ← expandNoConfusionMajor major numNonPropFields
let major := mkAppN major args[(arity+1)...*]
visit major
else
let type ← toLCNFType (← liftMetaM <| Meta.inferType e)
mkUnreachable type
| _, _ =>
throwError "code generator failed, unsupported occurrence of `{.ofConstName declName}`"
visitMajor numNonPropFields
else
let type ← toLCNFType (← liftMetaM <| Meta.inferType e)
mkUnreachable type
| _, _ =>
throwError "code generator failed, unsupported occurrence of `{.ofConstName declName}`"
| .perCtor _ numNonPropFields =>
visitMajor numNonPropFields
expandNoConfusionMajor (major : Expr) (numFields : Nat) : M Expr := do
match numFields with

View file

@ -257,7 +257,10 @@ def mkNoConfusionCoreImp (indName : Name) : MetaM Unit := do
(value := e)
(hints := ReducibilityHints.abbrev)))
setReducibleAttribute declName
modifyEnv fun env => markNoConfusion env declName
let arity := info.numParams + 1 + 3 * (info.numIndices + 1)
let lhsPos := info.numParams + 1 + info.numIndices
let rhsPos := info.numParams + 1 + info.numIndices + 1 + info.numIndices
modifyEnv fun env => markNoConfusion env declName (.regular arity lhsPos rhsPos)
modifyEnv fun env => addProtected env declName
/--
@ -295,48 +298,47 @@ def mkNoConfusionCtors (declName : Name) : MetaM Unit := do
for ctor in indVal.ctors do
let ctorInfo ← getConstInfoCtor ctor
if ctorInfo.numFields > 0 then
let e ←
forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs t => do
withLocalDeclD `P (.sort v) fun P =>
forallBoundedTelescope t ctorInfo.numFields fun fields1 _ => do
forallBoundedTelescope t ctorInfo.numFields fun fields2 _ => do
withPrimedNames fields2 do
withImplicitBinderInfos (xs ++ #[P] ++ fields1 ++ fields2) do
let ctor1 := mkAppN (mkConst ctor us) (xs ++ fields1)
let ctor2 := mkAppN (mkConst ctor us) (xs ++ fields2)
let is1 := (← whnf (← inferType ctor1)).getAppArgsN indVal.numIndices
let is2 := (← whnf (← inferType ctor2)).getAppArgsN indVal.numIndices
withNeededEqTelescope (is1.push ctor1) (is2.push ctor2) fun eqvs eqs => do
-- When the kernel checks this definition, it will perform the potentially expensive
-- computation that `noConfusionType h` is equal to `$kType → P`
let kType ← mkNoConfusionCtorArg ctor P
let kType := kType.beta (xs ++ fields1 ++ fields2)
withLocalDeclD `k kType fun k => do
let mut e := mkConst noConfusionName (v :: us)
e := mkAppN e (xs ++ #[P] ++ is1 ++ #[ctor1] ++ is2 ++ #[ctor2])
-- eqs may have more Eq rather than HEq than expected by `noConfusion`
for eq in eqs do
let needsHEq := (← whnfForall (← inferType e)).bindingDomain!.isHEq
if needsHEq && (← inferType eq).isEq then
e := mkApp e (← mkHEqOfEq eq)
else
e := mkApp e eq
e := mkApp e k
e ← mkExpectedTypeHint e P
mkLambdaFVars (xs ++ #[P] ++ fields1 ++ fields2 ++ eqvs ++ #[k]) e
let name := ctor.str "noConfusion"
addDecl (.defnDecl (← mkDefinitionValInferringUnsafe
(name := name)
(levelParams := recInfo.levelParams)
(type := (← inferType e))
(value := e)
(hints := ReducibilityHints.abbrev)
))
setReducibleAttribute name
-- The compiler has special support for `noConfusion`. So lets mark this as
-- macroInline to not generate code for all these extra definitions, and instead
-- let the compiler unfold this to then put the custom code there
setInlineAttribute name (kind := .macroInline)
forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs t => do
withLocalDeclD `P (.sort v) fun P =>
forallBoundedTelescope t ctorInfo.numFields fun fields1 _ => do
forallBoundedTelescope t ctorInfo.numFields fun fields2 _ => do
withPrimedNames fields2 do
withImplicitBinderInfos (xs ++ #[P] ++ fields1 ++ fields2) do
let ctor1 := mkAppN (mkConst ctor us) (xs ++ fields1)
let ctor2 := mkAppN (mkConst ctor us) (xs ++ fields2)
let is1 := (← whnf (← inferType ctor1)).getAppArgsN indVal.numIndices
let is2 := (← whnf (← inferType ctor2)).getAppArgsN indVal.numIndices
withNeededEqTelescope (is1.push ctor1) (is2.push ctor2) fun eqvs eqs => do
-- When the kernel checks this definition, it will perform the potentially expensive
-- computation that `noConfusionType h` is equal to `$kType → P`
let kType ← mkNoConfusionCtorArg ctor P
let kType := kType.beta (xs ++ fields1 ++ fields2)
withLocalDeclD `k kType fun k => do
let mut e := mkConst noConfusionName (v :: us)
e := mkAppN e (xs ++ #[P] ++ is1 ++ #[ctor1] ++ is2 ++ #[ctor2])
-- eqs may have more Eq rather than HEq than expected by `noConfusion`
for eq in eqs do
let needsHEq := (← whnfForall (← inferType e)).bindingDomain!.isHEq
if needsHEq && (← inferType eq).isEq then
e := mkApp e (← mkHEqOfEq eq)
else
e := mkApp e eq
e := mkApp e k
e ← mkExpectedTypeHint e P
e ← mkLambdaFVars (xs ++ #[P] ++ fields1 ++ fields2 ++ eqvs ++ #[k]) e
let name := ctor.str "noConfusion"
addDecl (.defnDecl (← mkDefinitionValInferringUnsafe
(name := name)
(levelParams := recInfo.levelParams)
(type := (← inferType e))
(value := e)
(hints := ReducibilityHints.abbrev)
))
setReducibleAttribute name
let arity := ctorInfo.numParams + 1 + 2 * ctorInfo.numFields + indVal.numIndices + 1
let fields := kType.getNumHeadForalls
modifyEnv fun env => markNoConfusion env name (.perCtor arity fields)
def mkNoConfusionCore (declName : Name) : MetaM Unit := do
@ -375,7 +377,7 @@ where
let ctorIdx := mkConst (mkCtorIdxName enumName) us
mkLambdaFVars #[P, x, y] (← mkAppM ``noConfusionTypeEnum #[ctorIdx, P, x, y])
let declName := Name.mkStr enumName "noConfusionType"
addAndCompile <| Declaration.defnDecl {
addDecl <| Declaration.defnDecl {
name := declName
levelParams := v :: info.levelParams
type := declType
@ -404,7 +406,7 @@ where
else
mkAppOptM ``noConfusionEnum #[none, none, none, ctorIdx, P, x, y, h]
let declName := Name.mkStr enumName "noConfusion"
addAndCompile <| Declaration.defnDecl {
addDecl <| Declaration.defnDecl {
name := declName
levelParams := v :: info.levelParams
type := declType
@ -413,7 +415,7 @@ where
hints := ReducibilityHints.abbrev
}
setReducibleAttribute declName
modifyEnv fun env => markNoConfusion env declName
modifyEnv fun env => markNoConfusion env declName (.regular 4 1 2)
public def mkNoConfusion (declName : Name) : MetaM Unit := do
withTraceNode `Meta.mkNoConfusion (fun _ => return m!"{declName}") do

View file

@ -1,5 +1,7 @@
#include "util/options.h"
// please update me
namespace lean {
options get_default_options() {
options opts;

View file

@ -1,3 +1,2 @@
librarySearch.lean:270:0-270:7: warning: declaration uses 'sorry'
librarySearch.lean:375:0-375:7: warning: declaration uses 'sorry'
librarySearch.lean:385:0-385:7: warning: declaration uses 'sorry'