feat: add deriving BEq
This commit is contained in:
parent
0862df9ade
commit
bad714f5e9
4 changed files with 306 additions and 0 deletions
|
|
@ -4,4 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.Deriving.Util
|
||||
import Lean.Elab.Deriving.Inhabited
|
||||
import Lean.Elab.Deriving.BEq
|
||||
|
|
|
|||
144
src/Lean/Elab/Deriving/BEq.lean
Normal file
144
src/Lean/Elab/Deriving/BEq.lean
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Meta.Transform
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.Deriving.Util
|
||||
|
||||
namespace Lean.Elab.Deriving.BEq
|
||||
|
||||
open Meta
|
||||
|
||||
structure Header where
|
||||
binders : Array Syntax
|
||||
argNames : Array Name
|
||||
target1Name : Name
|
||||
target2Name : Name
|
||||
|
||||
def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
|
||||
let argNames ← mkInductArgNames indVal
|
||||
let binders ← mkImplicitBinders argNames
|
||||
let targetType ← mkInductiveApp indVal argNames
|
||||
let target1Name ← mkFreshUserName `x
|
||||
let target2Name ← mkFreshUserName `y
|
||||
let binders := binders ++ (← mkInstImplicitBinders `BEq indVal argNames)
|
||||
let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType))
|
||||
let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType))
|
||||
let binders := binders ++ #[target1Binder, target2Binder]
|
||||
return {
|
||||
binders := binders
|
||||
argNames := argNames
|
||||
target1Name := target1Name
|
||||
target2Name := target2Name
|
||||
}
|
||||
|
||||
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do
|
||||
let discrs ← mkDiscrs
|
||||
let alts ← mkAlts
|
||||
`(match $[$discrs],* with | $[$alts:matchAlt]|*)
|
||||
where
|
||||
mkDiscr (varName : Name) : TermElabM Syntax :=
|
||||
`(Parser.Term.matchDiscr| $(mkIdent varName):term)
|
||||
|
||||
mkDiscrs : TermElabM (Array Syntax) := do
|
||||
let mut discrs := #[]
|
||||
-- add indices
|
||||
for argName in argNames[indVal.nparams:] do
|
||||
discrs := discrs.push (← mkDiscr argName)
|
||||
return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name]
|
||||
|
||||
mkElseAlt : TermElabM Syntax := do
|
||||
let mut patterns := #[]
|
||||
-- add `_` pattern for indices
|
||||
for i in [:indVal.nindices] do
|
||||
patterns := patterns.push (← `(_))
|
||||
patterns := patterns.push (← `(_))
|
||||
patterns := patterns.push (← `(_))
|
||||
let altRhs ← `(false)
|
||||
`(matchAltExpr| $[$patterns:term],* => $altRhs:term)
|
||||
|
||||
mkAlts : TermElabM (Array Syntax) := do
|
||||
let mut alts := #[]
|
||||
for ctorName in indVal.ctors do
|
||||
let ctorInfo ← getConstInfoCtor ctorName
|
||||
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
|
||||
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
|
||||
let mut patterns := #[]
|
||||
-- add `_` pattern for indices
|
||||
for i in [:indVal.nindices] do
|
||||
patterns := patterns.push (← `(_))
|
||||
let mut ctorArgs1 := #[]
|
||||
let mut ctorArgs2 := #[]
|
||||
let mut rhs ← `(true)
|
||||
-- add `_` for inductive parameters, they are inaccessible
|
||||
for i in [:indVal.nparams] do
|
||||
ctorArgs1 := ctorArgs1.push (← `(_))
|
||||
ctorArgs2 := ctorArgs2.push (← `(_))
|
||||
for i in [:ctorInfo.nfields] do
|
||||
let x := xs[indVal.nparams + i]
|
||||
if type.containsFVar x.fvarId! then
|
||||
-- If resulting type depends on this field, we don't need to compare
|
||||
ctorArgs1 := ctorArgs1.push (← `(_))
|
||||
ctorArgs2 := ctorArgs2.push (← `(_))
|
||||
else
|
||||
let a := mkIdent (← mkFreshUserName `a)
|
||||
let b := mkIdent (← mkFreshUserName `b)
|
||||
ctorArgs1 := ctorArgs1.push a
|
||||
ctorArgs2 := ctorArgs2.push b
|
||||
if (← inferType x).isAppOf indVal.name then
|
||||
rhs ← `($rhs && $(mkIdent auxFunName):ident $a:ident $b:ident)
|
||||
else
|
||||
rhs ← `($rhs && $a:ident == $b:ident)
|
||||
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs1:term*))
|
||||
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2:term*))
|
||||
`(matchAltExpr| $[$patterns:term],* => $rhs:term)
|
||||
alts := alts.push alt
|
||||
alts := alts.push (← mkElseAlt)
|
||||
return alts
|
||||
|
||||
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do
|
||||
let auxFunName ← ctx.auxFunNames[i]
|
||||
let indVal ← ctx.typeInfos[i]
|
||||
let header ← mkHeader ctx indVal
|
||||
let mut body ← mkMatch ctx header indVal auxFunName header.argNames
|
||||
if ctx.usePartial then
|
||||
let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
|
||||
body ← mkLet letDecls body
|
||||
let binders := header.binders
|
||||
if ctx.usePartial then
|
||||
`(private partial def $(mkIdent auxFunName):ident $binders:explicitBinder* : Bool := $body:term)
|
||||
else
|
||||
`(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : Bool := $body:term)
|
||||
|
||||
def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
|
||||
let mut auxDefs := #[]
|
||||
for i in [:ctx.typeInfos.size] do
|
||||
auxDefs := auxDefs.push (← mkAuxFunction ctx i)
|
||||
`(mutual
|
||||
set_option match.ignoreUnusedAlts true
|
||||
$auxDefs:command*
|
||||
end)
|
||||
|
||||
private def mkBEqInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext declNames[0]
|
||||
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq declNames)
|
||||
trace[Elab.Deriving.beq]! "\n{cmds}"
|
||||
return cmds
|
||||
|
||||
open Command
|
||||
|
||||
def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
||||
if (← declNames.allM isInductive) && declNames.size > 0 then
|
||||
let cmds ← liftTermElabM none <| mkBEqInstanceCmds declNames
|
||||
cmds.forM elabCommand
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinDerivingHandler `BEq mkBEqInstanceHandler
|
||||
registerTraceClass `Elab.Deriving.beq
|
||||
|
||||
end Lean.Elab.Deriving.BEq
|
||||
112
src/Lean/Elab/Deriving/Util.lean
Normal file
112
src/Lean/Elab/Deriving/Util.lean
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Parser.Term
|
||||
import Lean.Elab.Term
|
||||
|
||||
namespace Lean.Elab.Deriving
|
||||
open Meta
|
||||
|
||||
def implicitBinderF := Parser.Term.implicitBinder
|
||||
def instBinderF := Parser.Term.instBinder
|
||||
def explicitBinderF := Parser.Term.explicitBinder
|
||||
def matchAltExpr := Parser.Term.matchAlt
|
||||
|
||||
def mkInductArgNames (indVal : InductiveVal) : TermElabM (Array Name) := do
|
||||
forallTelescopeReducing indVal.type fun xs _ => do
|
||||
let mut argNames := #[]
|
||||
for x in xs do
|
||||
let localDecl ← getLocalDecl x.fvarId!
|
||||
let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes
|
||||
argNames := argNames.push paramName
|
||||
pure argNames
|
||||
|
||||
def mkInductiveApp (indVal : InductiveVal) (argNames : Array Name) : TermElabM Syntax :=
|
||||
let f := mkIdent indVal.name
|
||||
let args := argNames.map mkIdent
|
||||
`(@$f $args*)
|
||||
|
||||
def mkImplicitBinders (argNames : Array Name) : TermElabM (Array Syntax) :=
|
||||
argNames.mapM fun argName =>
|
||||
`(implicitBinderF| { $(mkIdent argName) })
|
||||
|
||||
def mkInstImplicitBinders (className : Name) (indVal : InductiveVal) (argNames : Array Name) : TermElabM (Array Syntax) :=
|
||||
forallBoundedTelescope indVal.type indVal.nparams fun xs _ => do
|
||||
let mut binders := #[]
|
||||
for i in [:xs.size] do
|
||||
try
|
||||
let x := xs[i]
|
||||
let c ← mkAppM className #[x]
|
||||
if (← isTypeCorrect c) then
|
||||
let argName := argNames[i]
|
||||
let binder ← `(instBinderF| [ $(mkIdent className):ident $(mkIdent argName):ident ])
|
||||
binders := binders.push binder
|
||||
catch _ =>
|
||||
pure ()
|
||||
return binders
|
||||
|
||||
structure Context where
|
||||
typeInfos : Array InductiveVal
|
||||
auxFunNames : Array Name
|
||||
usePartial : Bool
|
||||
|
||||
def mkContext (typeName : Name) : TermElabM Context := do
|
||||
let indVal ← getConstInfoInduct typeName
|
||||
let mut typeInfos := #[]
|
||||
for typeName in indVal.all do
|
||||
typeInfos ← typeInfos.push (← getConstInfoInduct typeName)
|
||||
let mut auxFunNames := #[]
|
||||
for typeName in indVal.all do
|
||||
match typeName.eraseMacroScopes with
|
||||
| Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| "beq" ++ t)
|
||||
| _ => auxFunNames := auxFunNames.push (← mkFreshUserName `instFn)
|
||||
trace[Elab.Deriving.beq]! "{auxFunNames}"
|
||||
let usePartial := indVal.isNested || typeInfos.size > 1
|
||||
return {
|
||||
typeInfos := typeInfos
|
||||
auxFunNames := auxFunNames
|
||||
usePartial := usePartial
|
||||
}
|
||||
|
||||
def mkLocalInstanceLetDecls (ctx : Context) (className : Name) (argNames : Array Name) : TermElabM (Array Syntax) := do
|
||||
let mut letDecls := #[]
|
||||
for i in [:ctx.typeInfos.size] do
|
||||
let indVal := ctx.typeInfos[i]
|
||||
let auxFunName := ctx.auxFunNames[i]
|
||||
let currArgNames ← mkInductArgNames indVal
|
||||
let numParams := indVal.nparams
|
||||
let currIndices := currArgNames[numParams:]
|
||||
let binders ← mkImplicitBinders currIndices
|
||||
let argNamesNew := argNames[:numParams] ++ currIndices
|
||||
let indType ← mkInductiveApp indVal argNamesNew
|
||||
let type ← `($(mkIdent className) $indType)
|
||||
let val ← `(⟨$(mkIdent auxFunName)⟩)
|
||||
let instName ← mkFreshUserName `localinst
|
||||
let letDecl ← `(Parser.Term.letDecl| $(mkIdent instName):ident $binders:implicitBinder* : $type := $val)
|
||||
letDecls := letDecls.push letDecl
|
||||
return letDecls
|
||||
|
||||
def mkLet (letDecls : Array Syntax) (body : Syntax) : TermElabM Syntax :=
|
||||
letDecls.foldrM (init := body) fun letDecl body =>
|
||||
`(let $letDecl:letDecl; $body)
|
||||
|
||||
def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) : TermElabM (Array Syntax) := do
|
||||
let mut instances := #[]
|
||||
for i in [:ctx.typeInfos.size] do
|
||||
let indVal := ctx.typeInfos[i]
|
||||
if typeNames.contains indVal.name then
|
||||
let auxFunName := ctx.auxFunNames[i]
|
||||
let argNames ← mkInductArgNames indVal
|
||||
let binders ← mkImplicitBinders argNames
|
||||
let binders := binders ++ (← mkInstImplicitBinders className indVal argNames)
|
||||
let indType ← mkInductiveApp indVal argNames
|
||||
let type ← `($(mkIdent className) $indType)
|
||||
let val ← `(⟨$(mkIdent auxFunName)⟩)
|
||||
let instCmd ← `(instance $binders:implicitBinder* : $type := $val)
|
||||
trace[Meta.debug]! "\n{instCmd}"
|
||||
instances := instances.push instCmd
|
||||
return instances
|
||||
|
||||
end Lean.Elab.Deriving
|
||||
48
tests/lean/run/derivingBEq.lean
Normal file
48
tests/lean/run/derivingBEq.lean
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
inductive Foo
|
||||
| mk1 | mk2 | mk3
|
||||
deriving BEq
|
||||
|
||||
namespace Foo
|
||||
theorem ex1 : (mk1 == mk2) = false :=
|
||||
rfl
|
||||
theorem ex2 : (mk1 == mk1) = true :=
|
||||
rfl
|
||||
theorem ex3 : (mk2 == mk2) = true :=
|
||||
rfl
|
||||
theorem ex4 : (mk3 == mk3) = true :=
|
||||
rfl
|
||||
theorem ex5 : (mk2 == mk3) = false :=
|
||||
rfl
|
||||
end Foo
|
||||
|
||||
inductive Vec (α : Type u) : Nat → Type u
|
||||
| nil : Vec α 0
|
||||
| cons : α → {n : Nat} → Vec α n → Vec α (n+1)
|
||||
deriving BEq
|
||||
|
||||
namespace Vec
|
||||
theorem ex1 : (cons 10 Vec.nil == cons 20 Vec.nil) = false :=
|
||||
rfl
|
||||
|
||||
theorem ex2 : (cons 10 Vec.nil == cons 10 Vec.nil) = true :=
|
||||
rfl
|
||||
|
||||
theorem ex3 : (cons 20 (cons 11 Vec.nil) == cons 20 (cons 10 Vec.nil)) = false :=
|
||||
rfl
|
||||
|
||||
theorem ex4 : (cons 20 (cons 11 Vec.nil) == cons 20 (cons 11 Vec.nil)) = true :=
|
||||
rfl
|
||||
end Vec
|
||||
|
||||
inductive Bla (α : Type u) where
|
||||
| node : List (Bla α) → Bla α
|
||||
| leaf : α → Bla α
|
||||
deriving BEq
|
||||
|
||||
namespace Bla
|
||||
|
||||
#eval node [] == leaf 10
|
||||
#eval node [leaf 10] == node [leaf 10]
|
||||
#eval node [leaf 10] == node [leaf 10, leaf 20]
|
||||
|
||||
end Bla
|
||||
Loading…
Add table
Reference in a new issue