feat: optimized deriving DecidableEq for enumeration types

The proof term is liner on the number of constructors, but type
checking is not linear because the reduction engine in the kernel is
not efficient.
This commit is contained in:
Leonardo de Moura 2021-09-08 16:19:31 -07:00
parent 9b0dfc4b90
commit 193d4dc9f5
2 changed files with 75 additions and 8 deletions

View file

@ -95,17 +95,82 @@ def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
open Command
def mkDecEq (declName : Name) : CommandElabM Bool := do
let indVal ← getConstInfoInduct declName
if indVal.isNested then
return false -- nested inductive types are not supported yet
else
let cmds ← liftTermElabM none <| mkDecEqCmds indVal
cmds.forM elabCommand
return true
def mkEnumOfNat (declName : Name) : MetaM Unit := do
let indVal ← getConstInfoInduct declName
let enumType := mkConst declName
let ctors := indVal.ctors.toArray
withLocalDeclD `n (mkConst ``Nat) fun n => do
let cond := mkConst ``cond [levelZero]
let mut value := mkConst ctors.back
for i in [:ctors.size-1] do
let j := ctors.size - i - 2
value := mkApp4 cond enumType (mkApp2 (mkConst ``Nat.beq) n (mkNatLit j)) (mkConst ctors[j]) value
value ← mkLambdaFVars #[n] value
let type ← mkArrow (mkConst ``Nat) enumType
addAndCompile <| Declaration.defnDecl {
name := Name.mkStr declName "ofNat"
levelParams := []
safety := DefinitionSafety.safe
hints := ReducibilityHints.abbrev
value, type
}
def mkEnumOfNatThm (declName : Name) : MetaM Unit := do
let indVal ← getConstInfoInduct declName
let toCtorIdx := mkConst (Name.mkStr declName "toCtorIdx")
let ofNat := mkConst (Name.mkStr declName "ofNat")
let enumType := mkConst declName
let eqEnum := mkApp (mkConst ``Eq [levelOne]) enumType
let rflEnum := mkApp (mkConst ``Eq.refl [levelOne]) enumType
let ctors := indVal.ctors
withLocalDeclD `x enumType fun x => do
let resultType := mkApp2 eqEnum (mkApp ofNat (mkApp toCtorIdx x)) x
let motive ← mkLambdaFVars #[x] resultType
let casesOn := mkConst (mkCasesOnName declName) [levelZero]
let mut value := mkApp2 casesOn motive x
for ctor in ctors do
value := mkApp value (mkApp rflEnum (mkConst ctor))
value ← mkLambdaFVars #[x] value
let type ← mkForallFVars #[x] resultType
addAndCompile <| Declaration.thmDecl {
name := Name.mkStr declName "ofNat_toCtorIdx"
levelParams := []
value, type
}
def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
liftTermElabM none <| mkEnumOfNat declName
liftTermElabM none <| mkEnumOfNatThm declName
let ofNatIdent := mkIdent (Name.mkStr declName "ofNat")
let auxThmIdent := mkIdent (Name.mkStr declName "ofNat_toCtorIdx")
let indVal ← getConstInfoInduct declName
let cmd ← `(
instance : DecidableEq $(mkIdent declName) :=
fun x y =>
if h : x.toCtorIdx = y.toCtorIdx then
isTrue (by have aux := congrArg $ofNatIdent h; rw [$auxThmIdent:ident, $auxThmIdent:ident] at aux; assumption)
else
isFalse fun h => by subst h; contradiction
)
elabCommand cmd
def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size != 1 then
return false -- mutually inductive types are not supported yet
else if (← isEnumType declNames[0]) then
mkDecEqEnum declNames[0]
return true
else
let indVal ← getConstInfoInduct declNames[0]
if indVal.isNested then
return false -- nested inductive types are not supported yet
else
let cmds ← liftTermElabM none <| mkDecEqCmds indVal
cmds.forM elabCommand
return true
mkDecEq declNames[0]
builtin_initialize
registerBuiltinDerivingHandler `DecidableEq mkDecEqInstanceHandler

View file

@ -278,7 +278,7 @@ inductive CXCursorKind where
| CXCursor_FirstExtraDecl
| CXCursor_LastExtraDecl
| CXCursor_OverloadCandidate
deriving BEq
deriving BEq, DecidableEq
open CXCursorKind
@ -286,3 +286,5 @@ example (h : CXCursor_CUDAGlobalAttr = CXCursor_CUDAHostAttr) : False := by
contradiction
#eval CXCursor_CUDAGlobalAttr == CXCursor_CUDAHostAttr
#eval decide (CXCursor_CUDAGlobalAttr = CXCursor_CUDAHostAttr)