feat: add deriving DecidableEq
This commit is contained in:
parent
9a8de1774c
commit
87b6385bea
8 changed files with 216 additions and 39 deletions
|
|
@ -794,4 +794,15 @@ def expandInterpolatedStr (interpStr : Syntax) (type : Syntax) (toTypeFn : Synta
|
|||
def getSepArgs (stx : Syntax) : Array Syntax :=
|
||||
stx.getArgs.getSepElems
|
||||
|
||||
end Lean.Syntax
|
||||
end Syntax
|
||||
|
||||
/- Helper macros for builtin `deriving` -/
|
||||
macro "decEqIsFalse! " h:ident : term =>
|
||||
`(isFalse (by intro hn; injection hn; apply $h:ident _; assumption))
|
||||
|
||||
macro "decEqIsTrue! " hs:(ident*) : term => do
|
||||
let hs ← hs.getArgs.mapM fun h => `(tactic| subst $h:ident)
|
||||
let hs := hs.push (← `(tactic| rfl))
|
||||
`(isTrue (by $[$hs;]*))
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -7,3 +7,4 @@ import Lean.Elab.Deriving.Basic
|
|||
import Lean.Elab.Deriving.Util
|
||||
import Lean.Elab.Deriving.Inhabited
|
||||
import Lean.Elab.Deriving.BEq
|
||||
import Lean.Elab.Deriving.DecEq
|
||||
|
|
|
|||
|
|
@ -11,44 +11,16 @@ namespace Lean.Elab.Deriving.BEq
|
|||
open Lean.Parser.Term
|
||||
open Meta
|
||||
|
||||
structure Header where
|
||||
binders : Array Syntax
|
||||
argNames : Array Name
|
||||
target1Name : Name
|
||||
target2Name : Name
|
||||
open Binary (Header mkDiscrs)
|
||||
|
||||
def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
|
||||
let argNames ← mkInductArgNames indVal
|
||||
let binders ← mkImplicitBinders argNames
|
||||
let targetType ← mkInductiveApp indVal argNames
|
||||
let target1Name ← mkFreshUserName `x
|
||||
let target2Name ← mkFreshUserName `y
|
||||
let binders := binders ++ (← mkInstImplicitBinders `BEq indVal argNames)
|
||||
let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType))
|
||||
let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType))
|
||||
let binders := binders ++ #[target1Binder, target2Binder]
|
||||
return {
|
||||
binders := binders
|
||||
argNames := argNames
|
||||
target1Name := target1Name
|
||||
target2Name := target2Name
|
||||
}
|
||||
Binary.mkHeader ctx `BEq indVal
|
||||
|
||||
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do
|
||||
let discrs ← mkDiscrs
|
||||
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do
|
||||
let discrs ← mkDiscrs header indVal
|
||||
let alts ← mkAlts
|
||||
`(match $[$discrs],* with $alts:matchAlt*)
|
||||
where
|
||||
mkDiscr (varName : Name) : TermElabM Syntax :=
|
||||
`(Parser.Term.matchDiscr| $(mkIdent varName):term)
|
||||
|
||||
mkDiscrs : TermElabM (Array Syntax) := do
|
||||
let mut discrs := #[]
|
||||
-- add indices
|
||||
for argName in argNames[indVal.nparams:] do
|
||||
discrs := discrs.push (← mkDiscr argName)
|
||||
return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name]
|
||||
|
||||
mkElseAlt : TermElabM Syntax := do
|
||||
let mut patterns := #[]
|
||||
-- add `_` pattern for indices
|
||||
|
|
@ -102,7 +74,7 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do
|
|||
let auxFunName ← ctx.auxFunNames[i]
|
||||
let indVal ← ctx.typeInfos[i]
|
||||
let header ← mkHeader ctx indVal
|
||||
let mut body ← mkMatch ctx header indVal auxFunName header.argNames
|
||||
let mut body ← mkMatch ctx header indVal auxFunName
|
||||
if ctx.usePartial then
|
||||
let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
|
||||
body ← mkLet letDecls body
|
||||
|
|
@ -122,7 +94,7 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
|
|||
end)
|
||||
|
||||
private def mkBEqInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext declNames[0]
|
||||
let ctx ← mkContext "beq" declNames[0]
|
||||
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq declNames)
|
||||
trace[Elab.Deriving.beq]! "\n{cmds}"
|
||||
return cmds
|
||||
|
|
|
|||
114
src/Lean/Elab/Deriving/DecEq.lean
Normal file
114
src/Lean/Elab/Deriving/DecEq.lean
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Meta.Transform
|
||||
import Lean.Meta.Inductive
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.Deriving.Util
|
||||
|
||||
namespace Lean.Elab.Deriving.DecEq
|
||||
open Lean.Parser.Term
|
||||
open Meta
|
||||
open Binary (Header mkDiscrs)
|
||||
|
||||
def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
|
||||
Binary.mkHeader ctx `DecidableEq indVal
|
||||
|
||||
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do
|
||||
let discrs ← mkDiscrs header indVal
|
||||
let alts ← mkAlts
|
||||
`(match $[$discrs],* with $alts:matchAlt*)
|
||||
where
|
||||
mkSameCtorRhs : List (Syntax × Syntax × Bool) → Array Syntax → TermElabM Syntax
|
||||
| [], hs => `(decEqIsTrue! $hs:ident*)
|
||||
| (a, b, recField) :: todo, hs => withFreshMacroScope do
|
||||
let discr ←
|
||||
if recField then
|
||||
`($(mkIdent auxFunName) $a $b)
|
||||
else
|
||||
`(decEq $a $b)
|
||||
let h ← `(h)
|
||||
`(match $discr:term with
|
||||
| isTrue h => $(← mkSameCtorRhs todo (hs.push h)):term
|
||||
| isFalse h => decEqIsFalse! h)
|
||||
|
||||
mkAlts : TermElabM (Array Syntax) := do
|
||||
let mut alts := #[]
|
||||
for ctorName₁ in indVal.ctors do
|
||||
let ctorInfo ← getConstInfoCtor ctorName₁
|
||||
for ctorName₂ in indVal.ctors do
|
||||
let mut patterns := #[]
|
||||
-- add `_` pattern for indices
|
||||
for i in [:indVal.nindices] do
|
||||
patterns := patterns.push (← `(_))
|
||||
if ctorName₁ == ctorName₂ then
|
||||
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
|
||||
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
|
||||
let mut patterns := patterns
|
||||
let mut ctorArgs1 := #[]
|
||||
let mut ctorArgs2 := #[]
|
||||
-- add `_` for inductive parameters, they are inaccessible
|
||||
for i in [:indVal.nparams] do
|
||||
ctorArgs1 := ctorArgs1.push (← `(_))
|
||||
ctorArgs2 := ctorArgs2.push (← `(_))
|
||||
let mut todo := #[]
|
||||
for i in [:ctorInfo.nfields] do
|
||||
let x := xs[indVal.nparams + i]
|
||||
if type.containsFVar x.fvarId! then
|
||||
-- If resulting type depends on this field, we don't need to compare
|
||||
ctorArgs1 := ctorArgs1.push (← `(_))
|
||||
ctorArgs2 := ctorArgs2.push (← `(_))
|
||||
else
|
||||
let a := mkIdent (← mkFreshUserName `a)
|
||||
let b := mkIdent (← mkFreshUserName `b)
|
||||
ctorArgs1 := ctorArgs1.push a
|
||||
ctorArgs2 := ctorArgs2.push b
|
||||
let recField := (← inferType x).isAppOf indVal.name
|
||||
todo := todo.push (a, b, recField)
|
||||
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
|
||||
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
|
||||
let rhs ← mkSameCtorRhs todo.toList #[]
|
||||
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
|
||||
alts := alts.push alt
|
||||
else if (← compatibleCtors ctorName₁ ctorName₂) then
|
||||
patterns := patterns ++ #[(← `($(mkIdent ctorName₁) ..)), (← `($(mkIdent ctorName₂) ..))]
|
||||
let rhs ← `(isFalse (by intro h; injection h))
|
||||
alts ← alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
|
||||
return alts
|
||||
|
||||
def mkAuxFunction (ctx : Context) : TermElabM Syntax := do
|
||||
let auxFunName ← ctx.auxFunNames[0]
|
||||
let indVal ← ctx.typeInfos[0]
|
||||
let header ← mkHeader ctx indVal
|
||||
let mut body ← mkMatch ctx header indVal auxFunName header.argNames
|
||||
let binders := header.binders
|
||||
let type ← `(Decidable ($(mkIdent header.target1Name) = $(mkIdent header.target2Name)))
|
||||
`(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : $type:term := $body:term)
|
||||
|
||||
def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext "decEq" indVal.name
|
||||
let cmds := #[← mkAuxFunction ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false))
|
||||
trace[Elab.Deriving.decEq]! "\n{cmds}"
|
||||
return cmds
|
||||
|
||||
open Command
|
||||
|
||||
def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
||||
if declNames.size != 1 then
|
||||
return false -- mutually inductive types are not supported yet
|
||||
else
|
||||
let indVal ← getConstInfoInduct declNames[0]
|
||||
if indVal.isNested then
|
||||
return false -- nested inductive types are not supported yet
|
||||
else
|
||||
let cmds ← liftTermElabM none <| mkDecEqCmds indVal
|
||||
cmds.forM elabCommand
|
||||
return true
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinDerivingHandler `DecidableEq mkDecEqInstanceHandler
|
||||
registerTraceClass `Elab.Deriving.decEq
|
||||
|
||||
end Lean.Elab.Deriving.DecEq
|
||||
|
|
@ -51,7 +51,7 @@ structure Context where
|
|||
auxFunNames : Array Name
|
||||
usePartial : Bool
|
||||
|
||||
def mkContext (typeName : Name) : TermElabM Context := do
|
||||
def mkContext (fnPrefix : String) (typeName : Name) : TermElabM Context := do
|
||||
let indVal ← getConstInfoInduct typeName
|
||||
let mut typeInfos := #[]
|
||||
for typeName in indVal.all do
|
||||
|
|
@ -59,7 +59,7 @@ def mkContext (typeName : Name) : TermElabM Context := do
|
|||
let mut auxFunNames := #[]
|
||||
for typeName in indVal.all do
|
||||
match typeName.eraseMacroScopes with
|
||||
| Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| "beq" ++ t)
|
||||
| Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| fnPrefix ++ t)
|
||||
| _ => auxFunNames := auxFunNames.push (← mkFreshUserName `instFn)
|
||||
trace[Elab.Deriving.beq]! "{auxFunNames}"
|
||||
let usePartial := indVal.isNested || typeInfos.size > 1
|
||||
|
|
@ -91,7 +91,7 @@ def mkLet (letDecls : Array Syntax) (body : Syntax) : TermElabM Syntax :=
|
|||
letDecls.foldrM (init := body) fun letDecl body =>
|
||||
`(let $letDecl:letDecl; $body)
|
||||
|
||||
def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) : TermElabM (Array Syntax) := do
|
||||
def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) (useAnonCtor := true) : TermElabM (Array Syntax) := do
|
||||
let mut instances := #[]
|
||||
for i in [:ctx.typeInfos.size] do
|
||||
let indVal := ctx.typeInfos[i]
|
||||
|
|
@ -102,10 +102,47 @@ def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) :
|
|||
let binders := binders ++ (← mkInstImplicitBinders className indVal argNames)
|
||||
let indType ← mkInductiveApp indVal argNames
|
||||
let type ← `($(mkIdent className) $indType)
|
||||
let val ← `(⟨$(mkIdent auxFunName)⟩)
|
||||
let val ← if useAnonCtor then `(⟨$(mkIdent auxFunName)⟩) else pure <| mkIdent auxFunName
|
||||
let instCmd ← `(instance $binders:implicitBinder* : $type := $val)
|
||||
trace[Meta.debug]! "\n{instCmd}"
|
||||
instances := instances.push instCmd
|
||||
return instances
|
||||
|
||||
namespace Binary
|
||||
|
||||
structure Header where
|
||||
binders : Array Syntax
|
||||
argNames : Array Name
|
||||
target1Name : Name
|
||||
target2Name : Name
|
||||
|
||||
def mkHeader (ctx : Context) (className : Name) (indVal : InductiveVal) : TermElabM Header := do
|
||||
let argNames ← mkInductArgNames indVal
|
||||
let binders ← mkImplicitBinders argNames
|
||||
let targetType ← mkInductiveApp indVal argNames
|
||||
let target1Name ← mkFreshUserName `x
|
||||
let target2Name ← mkFreshUserName `y
|
||||
let binders := binders ++ (← mkInstImplicitBinders className indVal argNames)
|
||||
let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType))
|
||||
let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType))
|
||||
let binders := binders ++ #[target1Binder, target2Binder]
|
||||
return {
|
||||
binders := binders
|
||||
argNames := argNames
|
||||
target1Name := target1Name
|
||||
target2Name := target2Name
|
||||
}
|
||||
|
||||
def mkDiscr (varName : Name) : TermElabM Syntax :=
|
||||
`(Parser.Term.matchDiscr| $(mkIdent varName):term)
|
||||
|
||||
def mkDiscrs (header : Header) (indVal : InductiveVal) : TermElabM (Array Syntax) := do
|
||||
let mut discrs := #[]
|
||||
-- add indices
|
||||
for argName in header.argNames[indVal.nparams:] do
|
||||
discrs := discrs.push (← mkDiscr argName)
|
||||
return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name]
|
||||
|
||||
end Binary
|
||||
|
||||
end Lean.Elab.Deriving
|
||||
|
|
|
|||
|
|
@ -27,3 +27,4 @@ import Lean.Meta.ForEachExpr
|
|||
import Lean.Meta.Transform
|
||||
import Lean.Meta.PPGoal
|
||||
import Lean.Meta.UnificationHint
|
||||
import Lean.Meta.Inductive
|
||||
|
|
|
|||
23
src/Lean/Meta/Inductive.lean
Normal file
23
src/Lean/Meta/Inductive.lean
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Meta.ExprDefEq
|
||||
|
||||
/- Helper methods for inductive datatypes -/
|
||||
|
||||
namespace Lean.Meta
|
||||
|
||||
/- Return true if the types of the given constructors are compatible. -/
|
||||
def compatibleCtors (ctorName₁ ctorName₂ : Name) : MetaM Bool := do
|
||||
let ctorInfo₁ ← getConstInfoCtor ctorName₁
|
||||
let ctorInfo₂ ← getConstInfoCtor ctorName₂
|
||||
if ctorInfo₁.induct != ctorInfo₂.induct then
|
||||
return false
|
||||
else
|
||||
let (_, _, ctorType₁) ← forallMetaTelescope ctorInfo₁.type
|
||||
let (_, _, ctorType₂) ← forallMetaTelescope ctorInfo₂.type
|
||||
isDefEq ctorType₁ ctorType₂
|
||||
|
||||
end Lean.Meta
|
||||
18
tests/lean/run/decEq.lean
Normal file
18
tests/lean/run/decEq.lean
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
import Lean
|
||||
|
||||
inductive Vec (α : Type u) : Nat → Type u
|
||||
| nil : Vec α 0
|
||||
| cons : α → {n : Nat} → Vec α n → Vec α (n+1)
|
||||
deriving DecidableEq
|
||||
|
||||
inductive Test (α : Type)
|
||||
| mk₀
|
||||
| mk₁ : (n : Nat) → (α × α) → List α → Test α
|
||||
| mk₂ : Test α → α → Test α
|
||||
deriving DecidableEq
|
||||
|
||||
def t1 [DecidableEq α] : DecidableEq (Vec α n) :=
|
||||
inferInstance
|
||||
|
||||
def t2 [DecidableEq α] : DecidableEq (Test α) :=
|
||||
inferInstance
|
||||
Loading…
Add table
Reference in a new issue