perf: handle per-constructor noConfusion in toLCNF (#11566)
This PR lets the compiler treat per-constructor `noConfusion` like the general one, and moves some more logic closer to no confusion generation.
This commit is contained in:
parent
06037ade0f
commit
3b40682b22
5 changed files with 99 additions and 71 deletions
|
|
@ -62,12 +62,31 @@ public def isSparseCasesOn (env : Environment) (declName : Name) : Bool :=
|
|||
public def isCasesOnLike (env : Environment) (declName : Name) : Bool :=
|
||||
isCasesOnRecursor env declName || isSparseCasesOn env declName
|
||||
|
||||
builtin_initialize noConfusionExt : TagDeclarationExtension ← mkTagDeclarationExtension
|
||||
/--
|
||||
Shape information for no confusion lemmas.
|
||||
The `arity` does not include the final major argument (which is not there when the constructors differ)
|
||||
The regular no confusion lemma marks the lhs and rhs arguments for the compiler to look at and
|
||||
find the number of fields.
|
||||
The per-constructor no confusion lemmas know the number of (non-prop) fields statically.
|
||||
-/
|
||||
inductive NoConfusionInfo where
|
||||
| regular (arity : Nat) (lhs : Nat) (rhs : Nat)
|
||||
| perCtor (arity : Nat) (fields : Nat)
|
||||
deriving Inhabited
|
||||
|
||||
def markNoConfusion (env : Environment) (n : Name) : Environment :=
|
||||
noConfusionExt.tag env n
|
||||
def NoConfusionInfo.arity : NoConfusionInfo → Nat
|
||||
| .regular arity _ _ => arity
|
||||
| .perCtor arity _ => arity
|
||||
|
||||
builtin_initialize noConfusionExt : MapDeclarationExtension NoConfusionInfo ← mkMapDeclarationExtension (asyncMode := .mainOnly)
|
||||
|
||||
def markNoConfusion (env : Environment) (n : Name) (info : NoConfusionInfo) : Environment :=
|
||||
noConfusionExt.insert env n info
|
||||
|
||||
def isNoConfusion (env : Environment) (n : Name) : Bool :=
|
||||
noConfusionExt.isTagged env n
|
||||
noConfusionExt.contains env n
|
||||
|
||||
def getNoConfusionInfo (env : Environment) (n : Name) : NoConfusionInfo :=
|
||||
(noConfusionExt.find? env n).get!
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -664,32 +664,38 @@ where
|
|||
|
||||
visitNoConfusion (e : Expr) : M Arg := do
|
||||
let .const declName _ := e.getAppFn | unreachable!
|
||||
let info := getNoConfusionInfo (← getEnv) declName
|
||||
let typeName := declName.getPrefix
|
||||
let .inductInfo inductVal ← getConstInfo typeName | unreachable!
|
||||
let arity := inductVal.numParams + 1 /- motive -/ + 3*(inductVal.numIndices + 1) /- lhs/rhs and equalities -/
|
||||
etaIfUnderApplied e arity do
|
||||
etaIfUnderApplied e info.arity do
|
||||
let args := e.getAppArgs
|
||||
let lhs ← liftMetaM do Meta.whnf args[inductVal.numParams + 1 + inductVal.numIndices]!
|
||||
let rhs ← liftMetaM do Meta.whnf args[inductVal.numParams + 1 + inductVal.numIndices + 1 + inductVal.numIndices]!
|
||||
let lhs ← liftMetaM lhs.toCtorIfLit
|
||||
let rhs ← liftMetaM rhs.toCtorIfLit
|
||||
match (← liftMetaM <| Meta.isConstructorApp? lhs), (← liftMetaM <| Meta.isConstructorApp? rhs) with
|
||||
| some lhsCtorVal, some rhsCtorVal =>
|
||||
if lhsCtorVal.name == rhsCtorVal.name then
|
||||
etaIfUnderApplied e (arity+1) do
|
||||
let major := args[arity]!
|
||||
let visitMajor (numNonPropFields : Nat) := do
|
||||
etaIfUnderApplied e (info.arity+1) do
|
||||
let major := args[info.arity]!
|
||||
let major ← expandNoConfusionMajor major numNonPropFields
|
||||
let major := mkAppN major args[(info.arity+1)...*]
|
||||
visit major
|
||||
|
||||
match info with
|
||||
| .regular _ lhsPos rhsPos =>
|
||||
let lhs ← liftMetaM do Meta.whnf args[lhsPos]!
|
||||
let rhs ← liftMetaM do Meta.whnf args[rhsPos]!
|
||||
let lhs ← liftMetaM lhs.toCtorIfLit
|
||||
let rhs ← liftMetaM rhs.toCtorIfLit
|
||||
match (← liftMetaM <| Meta.isConstructorApp? lhs), (← liftMetaM <| Meta.isConstructorApp? rhs) with
|
||||
| some lhsCtorVal, some rhsCtorVal =>
|
||||
if lhsCtorVal.name == rhsCtorVal.name then
|
||||
let numNonPropFields ← liftMetaM <| Meta.forallTelescope lhsCtorVal.type fun params _ =>
|
||||
params[lhsCtorVal.numParams...*].foldlM (init := 0) fun n param => do
|
||||
let type ← param.fvarId!.getType
|
||||
return if !(← Meta.isProp type) then n + 1 else n
|
||||
let major ← expandNoConfusionMajor major numNonPropFields
|
||||
let major := mkAppN major args[(arity+1)...*]
|
||||
visit major
|
||||
else
|
||||
let type ← toLCNFType (← liftMetaM <| Meta.inferType e)
|
||||
mkUnreachable type
|
||||
| _, _ =>
|
||||
throwError "code generator failed, unsupported occurrence of `{.ofConstName declName}`"
|
||||
visitMajor numNonPropFields
|
||||
else
|
||||
let type ← toLCNFType (← liftMetaM <| Meta.inferType e)
|
||||
mkUnreachable type
|
||||
| _, _ =>
|
||||
throwError "code generator failed, unsupported occurrence of `{.ofConstName declName}`"
|
||||
| .perCtor _ numNonPropFields =>
|
||||
visitMajor numNonPropFields
|
||||
|
||||
expandNoConfusionMajor (major : Expr) (numFields : Nat) : M Expr := do
|
||||
match numFields with
|
||||
|
|
|
|||
|
|
@ -257,7 +257,10 @@ def mkNoConfusionCoreImp (indName : Name) : MetaM Unit := do
|
|||
(value := e)
|
||||
(hints := ReducibilityHints.abbrev)))
|
||||
setReducibleAttribute declName
|
||||
modifyEnv fun env => markNoConfusion env declName
|
||||
let arity := info.numParams + 1 + 3 * (info.numIndices + 1)
|
||||
let lhsPos := info.numParams + 1 + info.numIndices
|
||||
let rhsPos := info.numParams + 1 + info.numIndices + 1 + info.numIndices
|
||||
modifyEnv fun env => markNoConfusion env declName (.regular arity lhsPos rhsPos)
|
||||
modifyEnv fun env => addProtected env declName
|
||||
|
||||
/--
|
||||
|
|
@ -295,48 +298,47 @@ def mkNoConfusionCtors (declName : Name) : MetaM Unit := do
|
|||
for ctor in indVal.ctors do
|
||||
let ctorInfo ← getConstInfoCtor ctor
|
||||
if ctorInfo.numFields > 0 then
|
||||
let e ←
|
||||
forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs t => do
|
||||
withLocalDeclD `P (.sort v) fun P =>
|
||||
forallBoundedTelescope t ctorInfo.numFields fun fields1 _ => do
|
||||
forallBoundedTelescope t ctorInfo.numFields fun fields2 _ => do
|
||||
withPrimedNames fields2 do
|
||||
withImplicitBinderInfos (xs ++ #[P] ++ fields1 ++ fields2) do
|
||||
let ctor1 := mkAppN (mkConst ctor us) (xs ++ fields1)
|
||||
let ctor2 := mkAppN (mkConst ctor us) (xs ++ fields2)
|
||||
let is1 := (← whnf (← inferType ctor1)).getAppArgsN indVal.numIndices
|
||||
let is2 := (← whnf (← inferType ctor2)).getAppArgsN indVal.numIndices
|
||||
withNeededEqTelescope (is1.push ctor1) (is2.push ctor2) fun eqvs eqs => do
|
||||
-- When the kernel checks this definition, it will perform the potentially expensive
|
||||
-- computation that `noConfusionType h` is equal to `$kType → P`
|
||||
let kType ← mkNoConfusionCtorArg ctor P
|
||||
let kType := kType.beta (xs ++ fields1 ++ fields2)
|
||||
withLocalDeclD `k kType fun k => do
|
||||
let mut e := mkConst noConfusionName (v :: us)
|
||||
e := mkAppN e (xs ++ #[P] ++ is1 ++ #[ctor1] ++ is2 ++ #[ctor2])
|
||||
-- eqs may have more Eq rather than HEq than expected by `noConfusion`
|
||||
for eq in eqs do
|
||||
let needsHEq := (← whnfForall (← inferType e)).bindingDomain!.isHEq
|
||||
if needsHEq && (← inferType eq).isEq then
|
||||
e := mkApp e (← mkHEqOfEq eq)
|
||||
else
|
||||
e := mkApp e eq
|
||||
e := mkApp e k
|
||||
e ← mkExpectedTypeHint e P
|
||||
mkLambdaFVars (xs ++ #[P] ++ fields1 ++ fields2 ++ eqvs ++ #[k]) e
|
||||
let name := ctor.str "noConfusion"
|
||||
addDecl (.defnDecl (← mkDefinitionValInferringUnsafe
|
||||
(name := name)
|
||||
(levelParams := recInfo.levelParams)
|
||||
(type := (← inferType e))
|
||||
(value := e)
|
||||
(hints := ReducibilityHints.abbrev)
|
||||
))
|
||||
setReducibleAttribute name
|
||||
-- The compiler has special support for `noConfusion`. So lets mark this as
|
||||
-- macroInline to not generate code for all these extra definitions, and instead
|
||||
-- let the compiler unfold this to then put the custom code there
|
||||
setInlineAttribute name (kind := .macroInline)
|
||||
forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs t => do
|
||||
withLocalDeclD `P (.sort v) fun P =>
|
||||
forallBoundedTelescope t ctorInfo.numFields fun fields1 _ => do
|
||||
forallBoundedTelescope t ctorInfo.numFields fun fields2 _ => do
|
||||
withPrimedNames fields2 do
|
||||
withImplicitBinderInfos (xs ++ #[P] ++ fields1 ++ fields2) do
|
||||
let ctor1 := mkAppN (mkConst ctor us) (xs ++ fields1)
|
||||
let ctor2 := mkAppN (mkConst ctor us) (xs ++ fields2)
|
||||
let is1 := (← whnf (← inferType ctor1)).getAppArgsN indVal.numIndices
|
||||
let is2 := (← whnf (← inferType ctor2)).getAppArgsN indVal.numIndices
|
||||
withNeededEqTelescope (is1.push ctor1) (is2.push ctor2) fun eqvs eqs => do
|
||||
-- When the kernel checks this definition, it will perform the potentially expensive
|
||||
-- computation that `noConfusionType h` is equal to `$kType → P`
|
||||
let kType ← mkNoConfusionCtorArg ctor P
|
||||
let kType := kType.beta (xs ++ fields1 ++ fields2)
|
||||
withLocalDeclD `k kType fun k => do
|
||||
let mut e := mkConst noConfusionName (v :: us)
|
||||
e := mkAppN e (xs ++ #[P] ++ is1 ++ #[ctor1] ++ is2 ++ #[ctor2])
|
||||
-- eqs may have more Eq rather than HEq than expected by `noConfusion`
|
||||
for eq in eqs do
|
||||
let needsHEq := (← whnfForall (← inferType e)).bindingDomain!.isHEq
|
||||
if needsHEq && (← inferType eq).isEq then
|
||||
e := mkApp e (← mkHEqOfEq eq)
|
||||
else
|
||||
e := mkApp e eq
|
||||
e := mkApp e k
|
||||
e ← mkExpectedTypeHint e P
|
||||
e ← mkLambdaFVars (xs ++ #[P] ++ fields1 ++ fields2 ++ eqvs ++ #[k]) e
|
||||
|
||||
let name := ctor.str "noConfusion"
|
||||
addDecl (.defnDecl (← mkDefinitionValInferringUnsafe
|
||||
(name := name)
|
||||
(levelParams := recInfo.levelParams)
|
||||
(type := (← inferType e))
|
||||
(value := e)
|
||||
(hints := ReducibilityHints.abbrev)
|
||||
))
|
||||
setReducibleAttribute name
|
||||
let arity := ctorInfo.numParams + 1 + 2 * ctorInfo.numFields + indVal.numIndices + 1
|
||||
let fields := kType.getNumHeadForalls
|
||||
modifyEnv fun env => markNoConfusion env name (.perCtor arity fields)
|
||||
|
||||
|
||||
def mkNoConfusionCore (declName : Name) : MetaM Unit := do
|
||||
|
|
@ -375,7 +377,7 @@ where
|
|||
let ctorIdx := mkConst (mkCtorIdxName enumName) us
|
||||
mkLambdaFVars #[P, x, y] (← mkAppM ``noConfusionTypeEnum #[ctorIdx, P, x, y])
|
||||
let declName := Name.mkStr enumName "noConfusionType"
|
||||
addAndCompile <| Declaration.defnDecl {
|
||||
addDecl <| Declaration.defnDecl {
|
||||
name := declName
|
||||
levelParams := v :: info.levelParams
|
||||
type := declType
|
||||
|
|
@ -404,7 +406,7 @@ where
|
|||
else
|
||||
mkAppOptM ``noConfusionEnum #[none, none, none, ctorIdx, P, x, y, h]
|
||||
let declName := Name.mkStr enumName "noConfusion"
|
||||
addAndCompile <| Declaration.defnDecl {
|
||||
addDecl <| Declaration.defnDecl {
|
||||
name := declName
|
||||
levelParams := v :: info.levelParams
|
||||
type := declType
|
||||
|
|
@ -413,7 +415,7 @@ where
|
|||
hints := ReducibilityHints.abbrev
|
||||
}
|
||||
setReducibleAttribute declName
|
||||
modifyEnv fun env => markNoConfusion env declName
|
||||
modifyEnv fun env => markNoConfusion env declName (.regular 4 1 2)
|
||||
|
||||
public def mkNoConfusion (declName : Name) : MetaM Unit := do
|
||||
withTraceNode `Meta.mkNoConfusion (fun _ => return m!"{declName}") do
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
#include "util/options.h"
|
||||
|
||||
// please update me
|
||||
|
||||
namespace lean {
|
||||
options get_default_options() {
|
||||
options opts;
|
||||
|
|
|
|||
|
|
@ -1,3 +1,2 @@
|
|||
librarySearch.lean:270:0-270:7: warning: declaration uses 'sorry'
|
||||
librarySearch.lean:375:0-375:7: warning: declaration uses 'sorry'
|
||||
librarySearch.lean:385:0-385:7: warning: declaration uses 'sorry'
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue