feat: optimized deriving DecidableEq for enumeration types
The proof term is liner on the number of constructors, but type checking is not linear because the reduction engine in the kernel is not efficient.
This commit is contained in:
parent
9b0dfc4b90
commit
193d4dc9f5
2 changed files with 75 additions and 8 deletions
|
|
@ -95,17 +95,82 @@ def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
|
|||
|
||||
open Command
|
||||
|
||||
def mkDecEq (declName : Name) : CommandElabM Bool := do
|
||||
let indVal ← getConstInfoInduct declName
|
||||
if indVal.isNested then
|
||||
return false -- nested inductive types are not supported yet
|
||||
else
|
||||
let cmds ← liftTermElabM none <| mkDecEqCmds indVal
|
||||
cmds.forM elabCommand
|
||||
return true
|
||||
|
||||
def mkEnumOfNat (declName : Name) : MetaM Unit := do
|
||||
let indVal ← getConstInfoInduct declName
|
||||
let enumType := mkConst declName
|
||||
let ctors := indVal.ctors.toArray
|
||||
withLocalDeclD `n (mkConst ``Nat) fun n => do
|
||||
let cond := mkConst ``cond [levelZero]
|
||||
let mut value := mkConst ctors.back
|
||||
for i in [:ctors.size-1] do
|
||||
let j := ctors.size - i - 2
|
||||
value := mkApp4 cond enumType (mkApp2 (mkConst ``Nat.beq) n (mkNatLit j)) (mkConst ctors[j]) value
|
||||
value ← mkLambdaFVars #[n] value
|
||||
let type ← mkArrow (mkConst ``Nat) enumType
|
||||
addAndCompile <| Declaration.defnDecl {
|
||||
name := Name.mkStr declName "ofNat"
|
||||
levelParams := []
|
||||
safety := DefinitionSafety.safe
|
||||
hints := ReducibilityHints.abbrev
|
||||
value, type
|
||||
}
|
||||
|
||||
def mkEnumOfNatThm (declName : Name) : MetaM Unit := do
|
||||
let indVal ← getConstInfoInduct declName
|
||||
let toCtorIdx := mkConst (Name.mkStr declName "toCtorIdx")
|
||||
let ofNat := mkConst (Name.mkStr declName "ofNat")
|
||||
let enumType := mkConst declName
|
||||
let eqEnum := mkApp (mkConst ``Eq [levelOne]) enumType
|
||||
let rflEnum := mkApp (mkConst ``Eq.refl [levelOne]) enumType
|
||||
let ctors := indVal.ctors
|
||||
withLocalDeclD `x enumType fun x => do
|
||||
let resultType := mkApp2 eqEnum (mkApp ofNat (mkApp toCtorIdx x)) x
|
||||
let motive ← mkLambdaFVars #[x] resultType
|
||||
let casesOn := mkConst (mkCasesOnName declName) [levelZero]
|
||||
let mut value := mkApp2 casesOn motive x
|
||||
for ctor in ctors do
|
||||
value := mkApp value (mkApp rflEnum (mkConst ctor))
|
||||
value ← mkLambdaFVars #[x] value
|
||||
let type ← mkForallFVars #[x] resultType
|
||||
addAndCompile <| Declaration.thmDecl {
|
||||
name := Name.mkStr declName "ofNat_toCtorIdx"
|
||||
levelParams := []
|
||||
value, type
|
||||
}
|
||||
|
||||
def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
|
||||
liftTermElabM none <| mkEnumOfNat declName
|
||||
liftTermElabM none <| mkEnumOfNatThm declName
|
||||
let ofNatIdent := mkIdent (Name.mkStr declName "ofNat")
|
||||
let auxThmIdent := mkIdent (Name.mkStr declName "ofNat_toCtorIdx")
|
||||
let indVal ← getConstInfoInduct declName
|
||||
let cmd ← `(
|
||||
instance : DecidableEq $(mkIdent declName) :=
|
||||
fun x y =>
|
||||
if h : x.toCtorIdx = y.toCtorIdx then
|
||||
isTrue (by have aux := congrArg $ofNatIdent h; rw [$auxThmIdent:ident, $auxThmIdent:ident] at aux; assumption)
|
||||
else
|
||||
isFalse fun h => by subst h; contradiction
|
||||
)
|
||||
elabCommand cmd
|
||||
|
||||
def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
||||
if declNames.size != 1 then
|
||||
return false -- mutually inductive types are not supported yet
|
||||
else if (← isEnumType declNames[0]) then
|
||||
mkDecEqEnum declNames[0]
|
||||
return true
|
||||
else
|
||||
let indVal ← getConstInfoInduct declNames[0]
|
||||
if indVal.isNested then
|
||||
return false -- nested inductive types are not supported yet
|
||||
else
|
||||
let cmds ← liftTermElabM none <| mkDecEqCmds indVal
|
||||
cmds.forM elabCommand
|
||||
return true
|
||||
mkDecEq declNames[0]
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinDerivingHandler `DecidableEq mkDecEqInstanceHandler
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ inductive CXCursorKind where
|
|||
| CXCursor_FirstExtraDecl
|
||||
| CXCursor_LastExtraDecl
|
||||
| CXCursor_OverloadCandidate
|
||||
deriving BEq
|
||||
deriving BEq, DecidableEq
|
||||
|
||||
open CXCursorKind
|
||||
|
||||
|
|
@ -286,3 +286,5 @@ example (h : CXCursor_CUDAGlobalAttr = CXCursor_CUDAHostAttr) : False := by
|
|||
contradiction
|
||||
|
||||
#eval CXCursor_CUDAGlobalAttr == CXCursor_CUDAHostAttr
|
||||
|
||||
#eval decide (CXCursor_CUDAGlobalAttr = CXCursor_CUDAHostAttr)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue