diff --git a/src/Lean/Elab/Deriving/DecEq.lean b/src/Lean/Elab/Deriving/DecEq.lean index 45b831fa24..a3ca71b167 100644 --- a/src/Lean/Elab/Deriving/DecEq.lean +++ b/src/Lean/Elab/Deriving/DecEq.lean @@ -95,17 +95,82 @@ def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do open Command +def mkDecEq (declName : Name) : CommandElabM Bool := do + let indVal ← getConstInfoInduct declName + if indVal.isNested then + return false -- nested inductive types are not supported yet + else + let cmds ← liftTermElabM none <| mkDecEqCmds indVal + cmds.forM elabCommand + return true + +def mkEnumOfNat (declName : Name) : MetaM Unit := do + let indVal ← getConstInfoInduct declName + let enumType := mkConst declName + let ctors := indVal.ctors.toArray + withLocalDeclD `n (mkConst ``Nat) fun n => do + let cond := mkConst ``cond [levelZero] + let mut value := mkConst ctors.back + for i in [:ctors.size-1] do + let j := ctors.size - i - 2 + value := mkApp4 cond enumType (mkApp2 (mkConst ``Nat.beq) n (mkNatLit j)) (mkConst ctors[j]) value + value ← mkLambdaFVars #[n] value + let type ← mkArrow (mkConst ``Nat) enumType + addAndCompile <| Declaration.defnDecl { + name := Name.mkStr declName "ofNat" + levelParams := [] + safety := DefinitionSafety.safe + hints := ReducibilityHints.abbrev + value, type + } + +def mkEnumOfNatThm (declName : Name) : MetaM Unit := do + let indVal ← getConstInfoInduct declName + let toCtorIdx := mkConst (Name.mkStr declName "toCtorIdx") + let ofNat := mkConst (Name.mkStr declName "ofNat") + let enumType := mkConst declName + let eqEnum := mkApp (mkConst ``Eq [levelOne]) enumType + let rflEnum := mkApp (mkConst ``Eq.refl [levelOne]) enumType + let ctors := indVal.ctors + withLocalDeclD `x enumType fun x => do + let resultType := mkApp2 eqEnum (mkApp ofNat (mkApp toCtorIdx x)) x + let motive ← mkLambdaFVars #[x] resultType + let casesOn := mkConst (mkCasesOnName declName) [levelZero] + let mut value := mkApp2 casesOn motive x + for ctor in ctors do + value := mkApp value (mkApp rflEnum (mkConst ctor)) + value ← mkLambdaFVars #[x] value + let type ← mkForallFVars #[x] resultType + addAndCompile <| Declaration.thmDecl { + name := Name.mkStr declName "ofNat_toCtorIdx" + levelParams := [] + value, type + } + +def mkDecEqEnum (declName : Name) : CommandElabM Unit := do + liftTermElabM none <| mkEnumOfNat declName + liftTermElabM none <| mkEnumOfNatThm declName + let ofNatIdent := mkIdent (Name.mkStr declName "ofNat") + let auxThmIdent := mkIdent (Name.mkStr declName "ofNat_toCtorIdx") + let indVal ← getConstInfoInduct declName + let cmd ← `( + instance : DecidableEq $(mkIdent declName) := + fun x y => + if h : x.toCtorIdx = y.toCtorIdx then + isTrue (by have aux := congrArg $ofNatIdent h; rw [$auxThmIdent:ident, $auxThmIdent:ident] at aux; assumption) + else + isFalse fun h => by subst h; contradiction + ) + elabCommand cmd + def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do if declNames.size != 1 then return false -- mutually inductive types are not supported yet + else if (← isEnumType declNames[0]) then + mkDecEqEnum declNames[0] + return true else - let indVal ← getConstInfoInduct declNames[0] - if indVal.isNested then - return false -- nested inductive types are not supported yet - else - let cmds ← liftTermElabM none <| mkDecEqCmds indVal - cmds.forM elabCommand - return true + mkDecEq declNames[0] builtin_initialize registerBuiltinDerivingHandler `DecidableEq mkDecEqInstanceHandler diff --git a/tests/lean/run/654.lean b/tests/lean/run/654.lean index 2dcfa25675..4c27eebba6 100644 --- a/tests/lean/run/654.lean +++ b/tests/lean/run/654.lean @@ -278,7 +278,7 @@ inductive CXCursorKind where | CXCursor_FirstExtraDecl | CXCursor_LastExtraDecl | CXCursor_OverloadCandidate - deriving BEq + deriving BEq, DecidableEq open CXCursorKind @@ -286,3 +286,5 @@ example (h : CXCursor_CUDAGlobalAttr = CXCursor_CUDAHostAttr) : False := by contradiction #eval CXCursor_CUDAGlobalAttr == CXCursor_CUDAHostAttr + +#eval decide (CXCursor_CUDAGlobalAttr = CXCursor_CUDAHostAttr)