feat: add deriving BEq

This commit is contained in:
Leonardo de Moura 2020-12-13 16:13:27 -08:00
parent 0862df9ade
commit bad714f5e9
4 changed files with 306 additions and 0 deletions

View file

@ -4,4 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Lean.Elab.Deriving.Inhabited
import Lean.Elab.Deriving.BEq

View file

@ -0,0 +1,144 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Transform
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
namespace Lean.Elab.Deriving.BEq
open Meta
structure Header where
binders : Array Syntax
argNames : Array Name
target1Name : Name
target2Name : Name
def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
let argNames ← mkInductArgNames indVal
let binders ← mkImplicitBinders argNames
let targetType ← mkInductiveApp indVal argNames
let target1Name ← mkFreshUserName `x
let target2Name ← mkFreshUserName `y
let binders := binders ++ (← mkInstImplicitBinders `BEq indVal argNames)
let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType))
let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType))
let binders := binders ++ #[target1Binder, target2Binder]
return {
binders := binders
argNames := argNames
target1Name := target1Name
target2Name := target2Name
}
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do
let discrs ← mkDiscrs
let alts ← mkAlts
`(match $[$discrs],* with | $[$alts:matchAlt]|*)
where
mkDiscr (varName : Name) : TermElabM Syntax :=
`(Parser.Term.matchDiscr| $(mkIdent varName):term)
mkDiscrs : TermElabM (Array Syntax) := do
let mut discrs := #[]
-- add indices
for argName in argNames[indVal.nparams:] do
discrs := discrs.push (← mkDiscr argName)
return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name]
mkElseAlt : TermElabM Syntax := do
let mut patterns := #[]
-- add `_` pattern for indices
for i in [:indVal.nindices] do
patterns := patterns.push (← `(_))
patterns := patterns.push (← `(_))
patterns := patterns.push (← `(_))
let altRhs ← `(false)
`(matchAltExpr| $[$patterns:term],* => $altRhs:term)
mkAlts : TermElabM (Array Syntax) := do
let mut alts := #[]
for ctorName in indVal.ctors do
let ctorInfo ← getConstInfoCtor ctorName
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
let mut patterns := #[]
-- add `_` pattern for indices
for i in [:indVal.nindices] do
patterns := patterns.push (← `(_))
let mut ctorArgs1 := #[]
let mut ctorArgs2 := #[]
let mut rhs ← `(true)
-- add `_` for inductive parameters, they are inaccessible
for i in [:indVal.nparams] do
ctorArgs1 := ctorArgs1.push (← `(_))
ctorArgs2 := ctorArgs2.push (← `(_))
for i in [:ctorInfo.nfields] do
let x := xs[indVal.nparams + i]
if type.containsFVar x.fvarId! then
-- If resulting type depends on this field, we don't need to compare
ctorArgs1 := ctorArgs1.push (← `(_))
ctorArgs2 := ctorArgs2.push (← `(_))
else
let a := mkIdent (← mkFreshUserName `a)
let b := mkIdent (← mkFreshUserName `b)
ctorArgs1 := ctorArgs1.push a
ctorArgs2 := ctorArgs2.push b
if (← inferType x).isAppOf indVal.name then
rhs ← `($rhs && $(mkIdent auxFunName):ident $a:ident $b:ident)
else
rhs ← `($rhs && $a:ident == $b:ident)
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs1:term*))
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2:term*))
`(matchAltExpr| $[$patterns:term],* => $rhs:term)
alts := alts.push alt
alts := alts.push (← mkElseAlt)
return alts
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do
let auxFunName ← ctx.auxFunNames[i]
let indVal ← ctx.typeInfos[i]
let header ← mkHeader ctx indVal
let mut body ← mkMatch ctx header indVal auxFunName header.argNames
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames
body ← mkLet letDecls body
let binders := header.binders
if ctx.usePartial then
`(private partial def $(mkIdent auxFunName):ident $binders:explicitBinder* : Bool := $body:term)
else
`(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : Bool := $body:term)
def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
let mut auxDefs := #[]
for i in [:ctx.typeInfos.size] do
auxDefs := auxDefs.push (← mkAuxFunction ctx i)
`(mutual
set_option match.ignoreUnusedAlts true
$auxDefs:command*
end)
private def mkBEqInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext declNames[0]
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq declNames)
trace[Elab.Deriving.beq]! "\n{cmds}"
return cmds
open Command
def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM none <| mkBEqInstanceCmds declNames
cmds.forM elabCommand
return true
else
return false
builtin_initialize
registerBuiltinDerivingHandler `BEq mkBEqInstanceHandler
registerTraceClass `Elab.Deriving.beq
end Lean.Elab.Deriving.BEq

View file

@ -0,0 +1,112 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Parser.Term
import Lean.Elab.Term
namespace Lean.Elab.Deriving
open Meta
def implicitBinderF := Parser.Term.implicitBinder
def instBinderF := Parser.Term.instBinder
def explicitBinderF := Parser.Term.explicitBinder
def matchAltExpr := Parser.Term.matchAlt
def mkInductArgNames (indVal : InductiveVal) : TermElabM (Array Name) := do
forallTelescopeReducing indVal.type fun xs _ => do
let mut argNames := #[]
for x in xs do
let localDecl ← getLocalDecl x.fvarId!
let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes
argNames := argNames.push paramName
pure argNames
def mkInductiveApp (indVal : InductiveVal) (argNames : Array Name) : TermElabM Syntax :=
let f := mkIdent indVal.name
let args := argNames.map mkIdent
`(@$f $args*)
def mkImplicitBinders (argNames : Array Name) : TermElabM (Array Syntax) :=
argNames.mapM fun argName =>
`(implicitBinderF| { $(mkIdent argName) })
def mkInstImplicitBinders (className : Name) (indVal : InductiveVal) (argNames : Array Name) : TermElabM (Array Syntax) :=
forallBoundedTelescope indVal.type indVal.nparams fun xs _ => do
let mut binders := #[]
for i in [:xs.size] do
try
let x := xs[i]
let c ← mkAppM className #[x]
if (← isTypeCorrect c) then
let argName := argNames[i]
let binder ← `(instBinderF| [ $(mkIdent className):ident $(mkIdent argName):ident ])
binders := binders.push binder
catch _ =>
pure ()
return binders
structure Context where
typeInfos : Array InductiveVal
auxFunNames : Array Name
usePartial : Bool
def mkContext (typeName : Name) : TermElabM Context := do
let indVal ← getConstInfoInduct typeName
let mut typeInfos := #[]
for typeName in indVal.all do
typeInfos ← typeInfos.push (← getConstInfoInduct typeName)
let mut auxFunNames := #[]
for typeName in indVal.all do
match typeName.eraseMacroScopes with
| Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| "beq" ++ t)
| _ => auxFunNames := auxFunNames.push (← mkFreshUserName `instFn)
trace[Elab.Deriving.beq]! "{auxFunNames}"
let usePartial := indVal.isNested || typeInfos.size > 1
return {
typeInfos := typeInfos
auxFunNames := auxFunNames
usePartial := usePartial
}
def mkLocalInstanceLetDecls (ctx : Context) (className : Name) (argNames : Array Name) : TermElabM (Array Syntax) := do
let mut letDecls := #[]
for i in [:ctx.typeInfos.size] do
let indVal := ctx.typeInfos[i]
let auxFunName := ctx.auxFunNames[i]
let currArgNames ← mkInductArgNames indVal
let numParams := indVal.nparams
let currIndices := currArgNames[numParams:]
let binders ← mkImplicitBinders currIndices
let argNamesNew := argNames[:numParams] ++ currIndices
let indType ← mkInductiveApp indVal argNamesNew
let type ← `($(mkIdent className) $indType)
let val ← `(⟨$(mkIdent auxFunName)⟩)
let instName ← mkFreshUserName `localinst
let letDecl ← `(Parser.Term.letDecl| $(mkIdent instName):ident $binders:implicitBinder* : $type := $val)
letDecls := letDecls.push letDecl
return letDecls
def mkLet (letDecls : Array Syntax) (body : Syntax) : TermElabM Syntax :=
letDecls.foldrM (init := body) fun letDecl body =>
`(let $letDecl:letDecl; $body)
def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) : TermElabM (Array Syntax) := do
let mut instances := #[]
for i in [:ctx.typeInfos.size] do
let indVal := ctx.typeInfos[i]
if typeNames.contains indVal.name then
let auxFunName := ctx.auxFunNames[i]
let argNames ← mkInductArgNames indVal
let binders ← mkImplicitBinders argNames
let binders := binders ++ (← mkInstImplicitBinders className indVal argNames)
let indType ← mkInductiveApp indVal argNames
let type ← `($(mkIdent className) $indType)
let val ← `(⟨$(mkIdent auxFunName)⟩)
let instCmd ← `(instance $binders:implicitBinder* : $type := $val)
trace[Meta.debug]! "\n{instCmd}"
instances := instances.push instCmd
return instances
end Lean.Elab.Deriving

View file

@ -0,0 +1,48 @@
inductive Foo
| mk1 | mk2 | mk3
deriving BEq
namespace Foo
theorem ex1 : (mk1 == mk2) = false :=
rfl
theorem ex2 : (mk1 == mk1) = true :=
rfl
theorem ex3 : (mk2 == mk2) = true :=
rfl
theorem ex4 : (mk3 == mk3) = true :=
rfl
theorem ex5 : (mk2 == mk3) = false :=
rfl
end Foo
inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons : α → {n : Nat} → Vec α n → Vec α (n+1)
deriving BEq
namespace Vec
theorem ex1 : (cons 10 Vec.nil == cons 20 Vec.nil) = false :=
rfl
theorem ex2 : (cons 10 Vec.nil == cons 10 Vec.nil) = true :=
rfl
theorem ex3 : (cons 20 (cons 11 Vec.nil) == cons 20 (cons 10 Vec.nil)) = false :=
rfl
theorem ex4 : (cons 20 (cons 11 Vec.nil) == cons 20 (cons 11 Vec.nil)) = true :=
rfl
end Vec
inductive Bla (α : Type u) where
| node : List (Bla α) → Bla α
| leaf : α → Bla α
deriving BEq
namespace Bla
#eval node [] == leaf 10
#eval node [leaf 10] == node [leaf 10]
#eval node [leaf 10] == node [leaf 10, leaf 20]
end Bla