From bad714f5e9f8db6af3d4410451ddacb35be884ca Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 13 Dec 2020 16:13:27 -0800 Subject: [PATCH] feat: add `deriving BEq` --- src/Lean/Elab/Deriving.lean | 2 + src/Lean/Elab/Deriving/BEq.lean | 144 +++++++++++++++++++++++++++++++ src/Lean/Elab/Deriving/Util.lean | 112 ++++++++++++++++++++++++ tests/lean/run/derivingBEq.lean | 48 +++++++++++ 4 files changed, 306 insertions(+) create mode 100644 src/Lean/Elab/Deriving/BEq.lean create mode 100644 src/Lean/Elab/Deriving/Util.lean create mode 100644 tests/lean/run/derivingBEq.lean diff --git a/src/Lean/Elab/Deriving.lean b/src/Lean/Elab/Deriving.lean index 3860388c6d..bbca818a5f 100644 --- a/src/Lean/Elab/Deriving.lean +++ b/src/Lean/Elab/Deriving.lean @@ -4,4 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ import Lean.Elab.Deriving.Basic +import Lean.Elab.Deriving.Util import Lean.Elab.Deriving.Inhabited +import Lean.Elab.Deriving.BEq diff --git a/src/Lean/Elab/Deriving/BEq.lean b/src/Lean/Elab/Deriving/BEq.lean new file mode 100644 index 0000000000..155634f766 --- /dev/null +++ b/src/Lean/Elab/Deriving/BEq.lean @@ -0,0 +1,144 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Meta.Transform +import Lean.Elab.Deriving.Basic +import Lean.Elab.Deriving.Util + +namespace Lean.Elab.Deriving.BEq + +open Meta + +structure Header where + binders : Array Syntax + argNames : Array Name + target1Name : Name + target2Name : Name + +def mkHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do + let argNames ← mkInductArgNames indVal + let binders ← mkImplicitBinders argNames + let targetType ← mkInductiveApp indVal argNames + let target1Name ← mkFreshUserName `x + let target2Name ← mkFreshUserName `y + let binders := binders ++ (← mkInstImplicitBinders `BEq indVal argNames) + let target1Binder ← `(explicitBinderF| ($(mkIdent target1Name) : $targetType)) + let target2Binder ← `(explicitBinderF| ($(mkIdent target2Name) : $targetType)) + let binders := binders ++ #[target1Binder, target2Binder] + return { + binders := binders + argNames := argNames + target1Name := target1Name + target2Name := target2Name + } + +def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) (argNames : Array Name) : TermElabM Syntax := do + let discrs ← mkDiscrs + let alts ← mkAlts + `(match $[$discrs],* with | $[$alts:matchAlt]|*) +where + mkDiscr (varName : Name) : TermElabM Syntax := + `(Parser.Term.matchDiscr| $(mkIdent varName):term) + + mkDiscrs : TermElabM (Array Syntax) := do + let mut discrs := #[] + -- add indices + for argName in argNames[indVal.nparams:] do + discrs := discrs.push (← mkDiscr argName) + return discrs ++ #[← mkDiscr header.target1Name, ← mkDiscr header.target2Name] + + mkElseAlt : TermElabM Syntax := do + let mut patterns := #[] + -- add `_` pattern for indices + for i in [:indVal.nindices] do + patterns := patterns.push (← `(_)) + patterns := patterns.push (← `(_)) + patterns := patterns.push (← `(_)) + let altRhs ← `(false) + `(matchAltExpr| $[$patterns:term],* => $altRhs:term) + + mkAlts : TermElabM (Array Syntax) := do + let mut alts := #[] + for ctorName in indVal.ctors do + let ctorInfo ← getConstInfoCtor ctorName + let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies + let mut patterns := #[] + -- add `_` pattern for indices + for i in [:indVal.nindices] do + patterns := patterns.push (← `(_)) + let mut ctorArgs1 := #[] + let mut ctorArgs2 := #[] + let mut rhs ← `(true) + -- add `_` for inductive parameters, they are inaccessible + for i in [:indVal.nparams] do + ctorArgs1 := ctorArgs1.push (← `(_)) + ctorArgs2 := ctorArgs2.push (← `(_)) + for i in [:ctorInfo.nfields] do + let x := xs[indVal.nparams + i] + if type.containsFVar x.fvarId! then + -- If resulting type depends on this field, we don't need to compare + ctorArgs1 := ctorArgs1.push (← `(_)) + ctorArgs2 := ctorArgs2.push (← `(_)) + else + let a := mkIdent (← mkFreshUserName `a) + let b := mkIdent (← mkFreshUserName `b) + ctorArgs1 := ctorArgs1.push a + ctorArgs2 := ctorArgs2.push b + if (← inferType x).isAppOf indVal.name then + rhs ← `($rhs && $(mkIdent auxFunName):ident $a:ident $b:ident) + else + rhs ← `($rhs && $a:ident == $b:ident) + patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs1:term*)) + patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs2:term*)) + `(matchAltExpr| $[$patterns:term],* => $rhs:term) + alts := alts.push alt + alts := alts.push (← mkElseAlt) + return alts + +def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do + let auxFunName ← ctx.auxFunNames[i] + let indVal ← ctx.typeInfos[i] + let header ← mkHeader ctx indVal + let mut body ← mkMatch ctx header indVal auxFunName header.argNames + if ctx.usePartial then + let letDecls ← mkLocalInstanceLetDecls ctx `BEq header.argNames + body ← mkLet letDecls body + let binders := header.binders + if ctx.usePartial then + `(private partial def $(mkIdent auxFunName):ident $binders:explicitBinder* : Bool := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : Bool := $body:term) + +def mkMutualBlock (ctx : Context) : TermElabM Syntax := do + let mut auxDefs := #[] + for i in [:ctx.typeInfos.size] do + auxDefs := auxDefs.push (← mkAuxFunction ctx i) + `(mutual + set_option match.ignoreUnusedAlts true + $auxDefs:command* + end) + +private def mkBEqInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do + let ctx ← mkContext declNames[0] + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq declNames) + trace[Elab.Deriving.beq]! "\n{cmds}" + return cmds + +open Command + +def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do + if (← declNames.allM isInductive) && declNames.size > 0 then + let cmds ← liftTermElabM none <| mkBEqInstanceCmds declNames + cmds.forM elabCommand + return true + else + return false + +builtin_initialize + registerBuiltinDerivingHandler `BEq mkBEqInstanceHandler + registerTraceClass `Elab.Deriving.beq + +end Lean.Elab.Deriving.BEq diff --git a/src/Lean/Elab/Deriving/Util.lean b/src/Lean/Elab/Deriving/Util.lean new file mode 100644 index 0000000000..c4fb1e68da --- /dev/null +++ b/src/Lean/Elab/Deriving/Util.lean @@ -0,0 +1,112 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Parser.Term +import Lean.Elab.Term + +namespace Lean.Elab.Deriving +open Meta + +def implicitBinderF := Parser.Term.implicitBinder +def instBinderF := Parser.Term.instBinder +def explicitBinderF := Parser.Term.explicitBinder +def matchAltExpr := Parser.Term.matchAlt + +def mkInductArgNames (indVal : InductiveVal) : TermElabM (Array Name) := do + forallTelescopeReducing indVal.type fun xs _ => do + let mut argNames := #[] + for x in xs do + let localDecl ← getLocalDecl x.fvarId! + let paramName ← mkFreshUserName localDecl.userName.eraseMacroScopes + argNames := argNames.push paramName + pure argNames + +def mkInductiveApp (indVal : InductiveVal) (argNames : Array Name) : TermElabM Syntax := + let f := mkIdent indVal.name + let args := argNames.map mkIdent + `(@$f $args*) + +def mkImplicitBinders (argNames : Array Name) : TermElabM (Array Syntax) := + argNames.mapM fun argName => + `(implicitBinderF| { $(mkIdent argName) }) + +def mkInstImplicitBinders (className : Name) (indVal : InductiveVal) (argNames : Array Name) : TermElabM (Array Syntax) := + forallBoundedTelescope indVal.type indVal.nparams fun xs _ => do + let mut binders := #[] + for i in [:xs.size] do + try + let x := xs[i] + let c ← mkAppM className #[x] + if (← isTypeCorrect c) then + let argName := argNames[i] + let binder ← `(instBinderF| [ $(mkIdent className):ident $(mkIdent argName):ident ]) + binders := binders.push binder + catch _ => + pure () + return binders + +structure Context where + typeInfos : Array InductiveVal + auxFunNames : Array Name + usePartial : Bool + +def mkContext (typeName : Name) : TermElabM Context := do + let indVal ← getConstInfoInduct typeName + let mut typeInfos := #[] + for typeName in indVal.all do + typeInfos ← typeInfos.push (← getConstInfoInduct typeName) + let mut auxFunNames := #[] + for typeName in indVal.all do + match typeName.eraseMacroScopes with + | Name.str _ t _ => auxFunNames := auxFunNames.push (← mkFreshUserName <| Name.mkSimple <| "beq" ++ t) + | _ => auxFunNames := auxFunNames.push (← mkFreshUserName `instFn) + trace[Elab.Deriving.beq]! "{auxFunNames}" + let usePartial := indVal.isNested || typeInfos.size > 1 + return { + typeInfos := typeInfos + auxFunNames := auxFunNames + usePartial := usePartial + } + +def mkLocalInstanceLetDecls (ctx : Context) (className : Name) (argNames : Array Name) : TermElabM (Array Syntax) := do + let mut letDecls := #[] + for i in [:ctx.typeInfos.size] do + let indVal := ctx.typeInfos[i] + let auxFunName := ctx.auxFunNames[i] + let currArgNames ← mkInductArgNames indVal + let numParams := indVal.nparams + let currIndices := currArgNames[numParams:] + let binders ← mkImplicitBinders currIndices + let argNamesNew := argNames[:numParams] ++ currIndices + let indType ← mkInductiveApp indVal argNamesNew + let type ← `($(mkIdent className) $indType) + let val ← `(⟨$(mkIdent auxFunName)⟩) + let instName ← mkFreshUserName `localinst + let letDecl ← `(Parser.Term.letDecl| $(mkIdent instName):ident $binders:implicitBinder* : $type := $val) + letDecls := letDecls.push letDecl + return letDecls + +def mkLet (letDecls : Array Syntax) (body : Syntax) : TermElabM Syntax := + letDecls.foldrM (init := body) fun letDecl body => + `(let $letDecl:letDecl; $body) + +def mkInstanceCmds (ctx : Context) (className : Name) (typeNames : Array Name) : TermElabM (Array Syntax) := do + let mut instances := #[] + for i in [:ctx.typeInfos.size] do + let indVal := ctx.typeInfos[i] + if typeNames.contains indVal.name then + let auxFunName := ctx.auxFunNames[i] + let argNames ← mkInductArgNames indVal + let binders ← mkImplicitBinders argNames + let binders := binders ++ (← mkInstImplicitBinders className indVal argNames) + let indType ← mkInductiveApp indVal argNames + let type ← `($(mkIdent className) $indType) + let val ← `(⟨$(mkIdent auxFunName)⟩) + let instCmd ← `(instance $binders:implicitBinder* : $type := $val) + trace[Meta.debug]! "\n{instCmd}" + instances := instances.push instCmd + return instances + +end Lean.Elab.Deriving diff --git a/tests/lean/run/derivingBEq.lean b/tests/lean/run/derivingBEq.lean new file mode 100644 index 0000000000..ff8f1146b1 --- /dev/null +++ b/tests/lean/run/derivingBEq.lean @@ -0,0 +1,48 @@ +inductive Foo + | mk1 | mk2 | mk3 + deriving BEq + +namespace Foo +theorem ex1 : (mk1 == mk2) = false := + rfl +theorem ex2 : (mk1 == mk1) = true := + rfl +theorem ex3 : (mk2 == mk2) = true := + rfl +theorem ex4 : (mk3 == mk3) = true := + rfl +theorem ex5 : (mk2 == mk3) = false := + rfl +end Foo + +inductive Vec (α : Type u) : Nat → Type u + | nil : Vec α 0 + | cons : α → {n : Nat} → Vec α n → Vec α (n+1) + deriving BEq + +namespace Vec +theorem ex1 : (cons 10 Vec.nil == cons 20 Vec.nil) = false := + rfl + +theorem ex2 : (cons 10 Vec.nil == cons 10 Vec.nil) = true := + rfl + +theorem ex3 : (cons 20 (cons 11 Vec.nil) == cons 20 (cons 10 Vec.nil)) = false := + rfl + +theorem ex4 : (cons 20 (cons 11 Vec.nil) == cons 20 (cons 11 Vec.nil)) = true := + rfl +end Vec + +inductive Bla (α : Type u) where + | node : List (Bla α) → Bla α + | leaf : α → Bla α + deriving BEq + +namespace Bla + +#eval node [] == leaf 10 +#eval node [leaf 10] == node [leaf 10] +#eval node [leaf 10] == node [leaf 10, leaf 20] + +end Bla