feat: add deriving DecidableEq

This commit is contained in:
Leonardo de Moura 2020-12-17 15:37:26 -08:00
parent 9a8de1774c
commit 87b6385bea
8 changed files with 216 additions and 39 deletions

View file

@ -794,4 +794,15 @@ def expandInterpolatedStr (interpStr : Syntax) (type : Syntax) (toTypeFn : Synta
def getSepArgs (stx : Syntax) : Array Syntax :=
stx.getArgs.getSepElems
end Lean.Syntax
end Syntax
/- Helper macros for builtin `deriving` -/
macro "decEqIsFalse! " h:ident : term =>
`(isFalse (by intro hn; injection hn; apply $h:ident _; assumption))
macro "decEqIsTrue! " hs:(ident*) : term => do
let hs ← hs.getArgs.mapM fun h => `(tactic| subst $h:ident)
let hs := hs.push (← `(tactic| rfl))
`(isTrue (by $[$hs;]*))
end Lean

View file

@ -7,3 +7,4 @@ import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Lean.Elab.Deriving.Inhabited
import Lean.Elab.Deriving.BEq
import Lean.Elab.Deriving.DecEq

View file

@ -11,44 +11,16 @@ namespace Lean.Elab.Deriving.BEq
open Lean.Parser.Term
open Meta
structure Header where
binders : Array Syntax
argNames : Array Name
target1Name : Name
target2Name : Name
open Binary (Header mkDiscrs)
def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
let argNames ← mkInductArgNames indVal
let binders ← mkImplicitBinders argNames
let targetType ← mkInductiveApp indVal argNames
let target1Name ← mkFreshUserName `x
let target2Name ← mkFreshUserName `y
let binders := binders ++ (← mkInstImplicitBinders `BEq indVal argNames)
let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType))
let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType))
let binders := binders ++ #[target1Binder, target2Binder]
return {
binders := binders
argNames := argNames
target1Name := target1Name
target2Name := target2Name
}
Binary.mkHeader ctx `BEq indVal
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do
let discrs ← mkDiscrs
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
`(match $[$discrs],* with $alts:matchAlt*)
where
mkDiscr (varName : Name) : TermElabM Syntax :=
`(Parser.Term.matchDiscr| $(mkIdent varName):term)
mkDiscrs : TermElabM (Array Syntax) := do
let mut discrs := #[]
-- add indices
for argName in argNames[indVal.nparams:] do
discrs := discrs.push (← mkDiscr argName)
return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name]
mkElseAlt : TermElabM Syntax := do
let mut patterns := #[]
-- add `_` pattern for indices
@ -102,7 +74,7 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do
let auxFunName ← ctx.auxFunNames[i]
let indVal ← ctx.typeInfos[i]
let header ← mkHeader ctx indVal
let mut body ← mkMatch ctx header indVal auxFunName header.argNames
let mut body ← mkMatch ctx header indVal auxFunName
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
body ← mkLet letDecls body
@ -122,7 +94,7 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
end)
private def mkBEqInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext declNames[0]
let ctx ← mkContext "beq" declNames[0]
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq declNames)
trace[Elab.Deriving.beq]! "\n{cmds}"
return cmds

View file

@ -0,0 +1,114 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Transform
import Lean.Meta.Inductive
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
namespace Lean.Elab.Deriving.DecEq
open Lean.Parser.Term
open Meta
open Binary (Header mkDiscrs)
def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
Binary.mkHeader ctx `DecidableEq indVal
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
`(match $[$discrs],* with $alts:matchAlt*)
where
mkSameCtorRhs : List (Syntax × Syntax × Bool) → Array Syntax → TermElabM Syntax
| [], hs => `(decEqIsTrue! $hs:ident*)
| (a, b, recField) :: todo, hs => withFreshMacroScope do
let discr ←
if recField then
`($(mkIdent auxFunName) $a $b)
else
`(decEq $a $b)
let h ← `(h)
`(match $discr:term with
| isTrue h => $(← mkSameCtorRhs todo (hs.push h)):term
| isFalse h => decEqIsFalse! h)
mkAlts : TermElabM (Array Syntax) := do
let mut alts := #[]
for ctorName₁ in indVal.ctors do
let ctorInfo ← getConstInfoCtor ctorName₁
for ctorName₂ in indVal.ctors do
let mut patterns := #[]
-- add `_` pattern for indices
for i in [:indVal.nindices] do
patterns := patterns.push (← `(_))
if ctorName₁ == ctorName₂ then
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
let mut patterns := patterns
let mut ctorArgs1 := #[]
let mut ctorArgs2 := #[]
-- add `_` for inductive parameters, they are inaccessible
for i in [:indVal.nparams] do
ctorArgs1 := ctorArgs1.push (← `(_))
ctorArgs2 := ctorArgs2.push (← `(_))
let mut todo := #[]
for i in [:ctorInfo.nfields] do
let x := xs[indVal.nparams + i]
if type.containsFVar x.fvarId! then
-- If resulting type depends on this field, we don't need to compare
ctorArgs1 := ctorArgs1.push (← `(_))
ctorArgs2 := ctorArgs2.push (← `(_))
else
let a := mkIdent (← mkFreshUserName `a)
let b := mkIdent (← mkFreshUserName `b)
ctorArgs1 := ctorArgs1.push a
ctorArgs2 := ctorArgs2.push b
let recField := (← inferType x).isAppOf indVal.name
todo := todo.push (a, b, recField)
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
let rhs ← mkSameCtorRhs todo.toList #[]
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
alts := alts.push alt
else if (← compatibleCtors ctorName₁ ctorName₂) then
patterns := patterns ++ #[(← `($(mkIdent ctorName₁) ..)), (← `($(mkIdent ctorName₂) ..))]
let rhs ← `(isFalse (by intro h; injection h))
alts ← alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
return alts
def mkAuxFunction (ctx : Context) : TermElabM Syntax := do
let auxFunName ← ctx.auxFunNames[0]
let indVal ← ctx.typeInfos[0]
let header ← mkHeader ctx indVal
let mut body ← mkMatch ctx header indVal auxFunName header.argNames
let binders := header.binders
let type ← `(Decidable ($(mkIdent header.target1Name) = $(mkIdent header.target2Name)))
`(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : $type:term := $body:term)
def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
let ctx ← mkContext "decEq" indVal.name
let cmds := #[← mkAuxFunction ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false))
trace[Elab.Deriving.decEq]! "\n{cmds}"
return cmds
open Command
def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size != 1 then
return false -- mutually inductive types are not supported yet
else
let indVal ← getConstInfoInduct declNames[0]
if indVal.isNested then
return false -- nested inductive types are not supported yet
else
let cmds ← liftTermElabM none <| mkDecEqCmds indVal
cmds.forM elabCommand
return true
builtin_initialize
registerBuiltinDerivingHandler `DecidableEq mkDecEqInstanceHandler
registerTraceClass `Elab.Deriving.decEq
end Lean.Elab.Deriving.DecEq

View file

@ -51,7 +51,7 @@ structure Context where
auxFunNames : Array Name
usePartial : Bool
def mkContext (typeName : Name) : TermElabM Context := do
def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do
let indVal ← getConstInfoInduct typeName
let mut typeInfos := #[]
for typeName in indVal.all do
@ -59,7 +59,7 @@ def mkContext (typeName : Name) : TermElabM Context := do
let mut auxFunNames := #[]
for typeName in indVal.all do
match typeName.eraseMacroScopes with
| Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| "beq" ++ t)
| Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| fnPrefix ++ t)
| _ => auxFunNames := auxFunNames.push (← mkFreshUserName `instFn)
trace[Elab.Deriving.beq]! "{auxFunNames}"
let usePartial := indVal.isNested || typeInfos.size > 1
@ -91,7 +91,7 @@ def mkLet (letDecls : Array Syntax) (body : Syntax) : TermElabM Syntax :=
letDecls.foldrM (init := body) fun letDecl body =>
`(let $letDecl:letDecl; $body)
def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) : TermElabM (Array Syntax) := do
def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) (useAnonCtor := true) : TermElabM (Array Syntax) := do
let mut instances := #[]
for i in [:ctx.typeInfos.size] do
let indVal := ctx.typeInfos[i]
@ -102,10 +102,47 @@ def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) :
let binders := binders ++ (← mkInstImplicitBinders className indVal argNames)
let indType ← mkInductiveApp indVal argNames
let type ← `($(mkIdent className) $indType)
let val ← `(⟨$(mkIdent auxFunName)⟩)
let val ← if useAnonCtor then `(⟨$(mkIdent auxFunName)⟩) else pure <| mkIdent auxFunName
let instCmd ← `(instance $binders:implicitBinder* : $type := $val)
trace[Meta.debug]! "\n{instCmd}"
instances := instances.push instCmd
return instances
namespace Binary
structure Header where
binders : Array Syntax
argNames : Array Name
target1Name : Name
target2Name : Name
def mkHeader (ctx : Context) (className : Name) (indVal : InductiveVal) : TermElabM Header := do
let argNames ← mkInductArgNames indVal
let binders ← mkImplicitBinders argNames
let targetType ← mkInductiveApp indVal argNames
let target1Name ← mkFreshUserName `x
let target2Name ← mkFreshUserName `y
let binders := binders ++ (← mkInstImplicitBinders className indVal argNames)
let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType))
let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType))
let binders := binders ++ #[target1Binder, target2Binder]
return {
binders := binders
argNames := argNames
target1Name := target1Name
target2Name := target2Name
}
def mkDiscr (varName : Name) : TermElabM Syntax :=
`(Parser.Term.matchDiscr| $(mkIdent varName):term)
def mkDiscrs (header : Header) (indVal : InductiveVal) : TermElabM (Array Syntax) := do
let mut discrs := #[]
-- add indices
for argName in header.argNames[indVal.nparams:] do
discrs := discrs.push (← mkDiscr argName)
return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name]
end Binary
end Lean.Elab.Deriving

View file

@ -27,3 +27,4 @@ import Lean.Meta.ForEachExpr
import Lean.Meta.Transform
import Lean.Meta.PPGoal
import Lean.Meta.UnificationHint
import Lean.Meta.Inductive

View file

@ -0,0 +1,23 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.ExprDefEq
/- Helper methods for inductive datatypes -/
namespace Lean.Meta
/- Return true if the types of the given constructors are compatible. -/
def compatibleCtors (ctorName₁ ctorName₂ : Name) : MetaM Bool := do
let ctorInfo₁ ← getConstInfoCtor ctorName₁
let ctorInfo₂ ← getConstInfoCtor ctorName₂
if ctorInfo₁.induct != ctorInfo₂.induct then
return false
else
let (_, _, ctorType₁) ← forallMetaTelescope ctorInfo₁.type
let (_, _, ctorType₂) ← forallMetaTelescope ctorInfo₂.type
isDefEq ctorType₁ ctorType₂
end Lean.Meta

18
tests/lean/run/decEq.lean Normal file
View file

@ -0,0 +1,18 @@
import Lean
inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons : α → {n : Nat} → Vec α n → Vec α (n+1)
deriving DecidableEq
inductive Test (α : Type)
| mk₀
| mk₁ : (n : Nat) → (α × α) → List α → Test α
| mk₂ : Test αα → Test α
deriving DecidableEq
def t1 [DecidableEq α] : DecidableEq (Vec α n) :=
inferInstance
def t2 [DecidableEq α] : DecidableEq (Test α) :=
inferInstance