From 87b6385bea3e7ea57e82ba778415c733ea5d958d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 17 Dec 2020 15:37:26 -0800 Subject: [PATCH] feat: add `deriving DecidableEq` --- src/Init/Meta.lean | 13 +++- src/Lean/Elab/Deriving.lean | 1 + src/Lean/Elab/Deriving/BEq.lean | 40 ++--------- src/Lean/Elab/Deriving/DecEq.lean | 114 ++++++++++++++++++++++++++++++ src/Lean/Elab/Deriving/Util.lean | 45 ++++++++++-- src/Lean/Meta.lean | 1 + src/Lean/Meta/Inductive.lean | 23 ++++++ tests/lean/run/decEq.lean | 18 +++++ 8 files changed, 216 insertions(+), 39 deletions(-) create mode 100644 src/Lean/Elab/Deriving/DecEq.lean create mode 100644 src/Lean/Meta/Inductive.lean create mode 100644 tests/lean/run/decEq.lean diff --git a/src/Init/Meta.lean b/src/Init/Meta.lean index 1a279f6278..40efeb9a10 100644 --- a/src/Init/Meta.lean +++ b/src/Init/Meta.lean @@ -794,4 +794,15 @@ def expandInterpolatedStr (interpStr : Syntax) (type : Syntax) (toTypeFn : Synta def getSepArgs (stx : Syntax) : Array Syntax := stx.getArgs.getSepElems -end Lean.Syntax +end Syntax + +/- Helper macros for builtin `deriving` -/ +macro "decEqIsFalse! " h:ident : term => + `(isFalse (by intro hn; injection hn; apply $h:ident _; assumption)) + +macro "decEqIsTrue! " hs:(ident*) : term => do + let hs ← hs.getArgs.mapM fun h => `(tactic| subst $h:ident) + let hs := hs.push (← `(tactic| rfl)) + `(isTrue (by $[$hs;]*)) + +end Lean diff --git a/src/Lean/Elab/Deriving.lean b/src/Lean/Elab/Deriving.lean index bbca818a5f..de069d1c17 100644 --- a/src/Lean/Elab/Deriving.lean +++ b/src/Lean/Elab/Deriving.lean @@ -7,3 +7,4 @@ import Lean.Elab.Deriving.Basic import Lean.Elab.Deriving.Util import Lean.Elab.Deriving.Inhabited import Lean.Elab.Deriving.BEq +import Lean.Elab.Deriving.DecEq diff --git a/src/Lean/Elab/Deriving/BEq.lean b/src/Lean/Elab/Deriving/BEq.lean index ff59524da0..3f9169b031 100644 --- a/src/Lean/Elab/Deriving/BEq.lean +++ b/src/Lean/Elab/Deriving/BEq.lean @@ -11,44 +11,16 @@ namespace Lean.Elab.Deriving.BEq open Lean.Parser.Term open Meta -structure Header where - binders : Array Syntax - argNames : Array Name - target1Name : Name - target2Name : Name +open Binary (Header mkDiscrs) def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do - let argNames ← mkInductArgNames indVal - let binders ← mkImplicitBinders argNames - let targetType ← mkInductiveApp indVal argNames - let target1Name ← mkFreshUserName `x - let target2Name ← mkFreshUserName `y - let binders := binders ++ (← mkInstImplicitBinders `BEq indVal argNames) - let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType)) - let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType)) - let binders := binders ++ #[target1Binder, target2Binder] - return { - binders := binders - argNames := argNames - target1Name := target1Name - target2Name := target2Name - } + Binary.mkHeader ctx `BEq indVal -def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do - let discrs ← mkDiscrs +def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do + let discrs ← mkDiscrs header indVal let alts ← mkAlts `(match $[$discrs],* with $alts:matchAlt*) where - mkDiscr (varName : Name) : TermElabM Syntax := - `(Parser.Term.matchDiscr| $(mkIdent varName):term) - - mkDiscrs : TermElabM (Array Syntax) := do - let mut discrs := #[] - -- add indices - for argName in argNames[indVal.nparams:] do - discrs := discrs.push (← mkDiscr argName) - return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name] - mkElseAlt : TermElabM Syntax := do let mut patterns := #[] -- add `_` pattern for indices @@ -102,7 +74,7 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do let auxFunName ← ctx.auxFunNames[i] let indVal ← ctx.typeInfos[i] let header ← mkHeader ctx indVal - let mut body ← mkMatch ctx header indVal auxFunName header.argNames + let mut body ← mkMatch ctx header indVal auxFunName if ctx.usePartial then let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames body ← mkLet letDecls body @@ -122,7 +94,7 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do end) private def mkBEqInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext declNames[0] + let ctx ← mkContext "beq" declNames[0] let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq declNames) trace[Elab.Deriving.beq]! "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/DecEq.lean b/src/Lean/Elab/Deriving/DecEq.lean new file mode 100644 index 0000000000..f5fdbe7e62 --- /dev/null +++ b/src/Lean/Elab/Deriving/DecEq.lean @@ -0,0 +1,114 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Meta.Transform +import Lean.Meta.Inductive +import Lean.Elab.Deriving.Basic +import Lean.Elab.Deriving.Util + +namespace Lean.Elab.Deriving.DecEq +open Lean.Parser.Term +open Meta +open Binary (Header mkDiscrs) + +def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do + Binary.mkHeader ctx `DecidableEq indVal + +def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do + let discrs ← mkDiscrs header indVal + let alts ← mkAlts + `(match $[$discrs],* with $alts:matchAlt*) +where + mkSameCtorRhs : List (Syntax × Syntax × Bool) → Array Syntax → TermElabM Syntax + | [], hs => `(decEqIsTrue! $hs:ident*) + | (a, b, recField) :: todo, hs => withFreshMacroScope do + let discr ← + if recField then + `($(mkIdent auxFunName) $a $b) + else + `(decEq $a $b) + let h ← `(h) + `(match $discr:term with + | isTrue h => $(← mkSameCtorRhs todo (hs.push h)):term + | isFalse h => decEqIsFalse! h) + + mkAlts : TermElabM (Array Syntax) := do + let mut alts := #[] + for ctorName₁ in indVal.ctors do + let ctorInfo ← getConstInfoCtor ctorName₁ + for ctorName₂ in indVal.ctors do + let mut patterns := #[] + -- add `_` pattern for indices + for i in [:indVal.nindices] do + patterns := patterns.push (← `(_)) + if ctorName₁ == ctorName₂ then + let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies + let mut patterns := patterns + let mut ctorArgs1 := #[] + let mut ctorArgs2 := #[] + -- add `_` for inductive parameters, they are inaccessible + for i in [:indVal.nparams] do + ctorArgs1 := ctorArgs1.push (← `(_)) + ctorArgs2 := ctorArgs2.push (← `(_)) + let mut todo := #[] + for i in [:ctorInfo.nfields] do + let x := xs[indVal.nparams + i] + if type.containsFVar x.fvarId! then + -- If resulting type depends on this field, we don't need to compare + ctorArgs1 := ctorArgs1.push (← `(_)) + ctorArgs2 := ctorArgs2.push (← `(_)) + else + let a := mkIdent (← mkFreshUserName `a) + let b := mkIdent (← mkFreshUserName `b) + ctorArgs1 := ctorArgs1.push a + ctorArgs2 := ctorArgs2.push b + let recField := (← inferType x).isAppOf indVal.name + todo := todo.push (a, b, recField) + patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*)) + patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*)) + let rhs ← mkSameCtorRhs todo.toList #[] + `(matchAltExpr| | $[$patterns:term],* => $rhs:term) + alts := alts.push alt + else if (← compatibleCtors ctorName₁ ctorName₂) then + patterns := patterns ++ #[(← `($(mkIdent ctorName₁) ..)), (← `($(mkIdent ctorName₂) ..))] + let rhs ← `(isFalse (by intro h; injection h)) + alts ← alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term)) + return alts + +def mkAuxFunction (ctx : Context) : TermElabM Syntax := do + let auxFunName ← ctx.auxFunNames[0] + let indVal ← ctx.typeInfos[0] + let header ← mkHeader ctx indVal + let mut body ← mkMatch ctx header indVal auxFunName header.argNames + let binders := header.binders + let type ← `(Decidable ($(mkIdent header.target1Name) = $(mkIdent header.target2Name))) + `(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : $type:term := $body:term) + +def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do + let ctx ← mkContext "decEq" indVal.name + let cmds := #[← mkAuxFunction ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false)) + trace[Elab.Deriving.decEq]! "\n{cmds}" + return cmds + +open Command + +def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do + if declNames.size != 1 then + return false -- mutually inductive types are not supported yet + else + let indVal ← getConstInfoInduct declNames[0] + if indVal.isNested then + return false -- nested inductive types are not supported yet + else + let cmds ← liftTermElabM none <| mkDecEqCmds indVal + cmds.forM elabCommand + return true + +builtin_initialize + registerBuiltinDerivingHandler `DecidableEq mkDecEqInstanceHandler + registerTraceClass `Elab.Deriving.decEq + +end Lean.Elab.Deriving.DecEq diff --git a/src/Lean/Elab/Deriving/Util.lean b/src/Lean/Elab/Deriving/Util.lean index c014489711..c90e1ee2ac 100644 --- a/src/Lean/Elab/Deriving/Util.lean +++ b/src/Lean/Elab/Deriving/Util.lean @@ -51,7 +51,7 @@ structure Context where auxFunNames : Array Name usePartial : Bool -def mkContext (typeName : Name) : TermElabM Context := do +def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do let indVal ← getConstInfoInduct typeName let mut typeInfos := #[] for typeName in indVal.all do @@ -59,7 +59,7 @@ def mkContext (typeName : Name) : TermElabM Context := do let mut auxFunNames := #[] for typeName in indVal.all do match typeName.eraseMacroScopes with - | Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| "beq" ++ t) + | Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| fnPrefix ++ t) | _ => auxFunNames := auxFunNames.push (← mkFreshUserName `instFn) trace[Elab.Deriving.beq]! "{auxFunNames}" let usePartial := indVal.isNested || typeInfos.size > 1 @@ -91,7 +91,7 @@ def mkLet (letDecls : Array Syntax) (body : Syntax) : TermElabM Syntax := letDecls.foldrM (init := body) fun letDecl body => `(let $letDecl:letDecl; $body) -def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) : TermElabM (Array Syntax) := do +def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) (useAnonCtor := true) : TermElabM (Array Syntax) := do let mut instances := #[] for i in [:ctx.typeInfos.size] do let indVal := ctx.typeInfos[i] @@ -102,10 +102,47 @@ def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) : let binders := binders ++ (← mkInstImplicitBinders className indVal argNames) let indType ← mkInductiveApp indVal argNames let type ← `($(mkIdent className) $indType) - let val ← `(⟨$(mkIdent auxFunName)⟩) + let val ← if useAnonCtor then `(⟨$(mkIdent auxFunName)⟩) else pure <| mkIdent auxFunName let instCmd ← `(instance $binders:implicitBinder* : $type := $val) trace[Meta.debug]! "\n{instCmd}" instances := instances.push instCmd return instances +namespace Binary + +structure Header where + binders : Array Syntax + argNames : Array Name + target1Name : Name + target2Name : Name + +def mkHeader (ctx : Context) (className : Name) (indVal : InductiveVal) : TermElabM Header := do + let argNames ← mkInductArgNames indVal + let binders ← mkImplicitBinders argNames + let targetType ← mkInductiveApp indVal argNames + let target1Name ← mkFreshUserName `x + let target2Name ← mkFreshUserName `y + let binders := binders ++ (← mkInstImplicitBinders className indVal argNames) + let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType)) + let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType)) + let binders := binders ++ #[target1Binder, target2Binder] + return { + binders := binders + argNames := argNames + target1Name := target1Name + target2Name := target2Name + } + +def mkDiscr (varName : Name) : TermElabM Syntax := + `(Parser.Term.matchDiscr| $(mkIdent varName):term) + +def mkDiscrs (header : Header) (indVal : InductiveVal) : TermElabM (Array Syntax) := do + let mut discrs := #[] + -- add indices + for argName in header.argNames[indVal.nparams:] do + discrs := discrs.push (← mkDiscr argName) + return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name] + +end Binary + end Lean.Elab.Deriving diff --git a/src/Lean/Meta.lean b/src/Lean/Meta.lean index c33a5e1a9d..f7678f14d5 100644 --- a/src/Lean/Meta.lean +++ b/src/Lean/Meta.lean @@ -27,3 +27,4 @@ import Lean.Meta.ForEachExpr import Lean.Meta.Transform import Lean.Meta.PPGoal import Lean.Meta.UnificationHint +import Lean.Meta.Inductive diff --git a/src/Lean/Meta/Inductive.lean b/src/Lean/Meta/Inductive.lean new file mode 100644 index 0000000000..2959699c99 --- /dev/null +++ b/src/Lean/Meta/Inductive.lean @@ -0,0 +1,23 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Meta.ExprDefEq + +/- Helper methods for inductive datatypes -/ + +namespace Lean.Meta + +/- Return true if the types of the given constructors are compatible. -/ +def compatibleCtors (ctorName₁ ctorName₂ : Name) : MetaM Bool := do + let ctorInfo₁ ← getConstInfoCtor ctorName₁ + let ctorInfo₂ ← getConstInfoCtor ctorName₂ + if ctorInfo₁.induct != ctorInfo₂.induct then + return false + else + let (_, _, ctorType₁) ← forallMetaTelescope ctorInfo₁.type + let (_, _, ctorType₂) ← forallMetaTelescope ctorInfo₂.type + isDefEq ctorType₁ ctorType₂ + +end Lean.Meta diff --git a/tests/lean/run/decEq.lean b/tests/lean/run/decEq.lean new file mode 100644 index 0000000000..ab0a9d9943 --- /dev/null +++ b/tests/lean/run/decEq.lean @@ -0,0 +1,18 @@ +import Lean + +inductive Vec (α : Type u) : Nat → Type u + | nil : Vec α 0 + | cons : α → {n : Nat} → Vec α n → Vec α (n+1) + deriving DecidableEq + +inductive Test (α : Type) + | mk₀ + | mk₁ : (n : Nat) → (α × α) → List α → Test α + | mk₂ : Test α → α → Test α + deriving DecidableEq + +def t1 [DecidableEq α] : DecidableEq (Vec α n) := + inferInstance + +def t2 [DecidableEq α] : DecidableEq (Test α) := + inferInstance