feat: add deriving Repr
This commit is contained in:
parent
793ffa2f11
commit
d22c5cb1b0
4 changed files with 184 additions and 0 deletions
|
|
@ -8,3 +8,4 @@ import Lean.Elab.Deriving.Util
|
|||
import Lean.Elab.Deriving.Inhabited
|
||||
import Lean.Elab.Deriving.BEq
|
||||
import Lean.Elab.Deriving.DecEq
|
||||
import Lean.Elab.Deriving.Repr
|
||||
|
|
|
|||
128
src/Lean/Elab/Deriving/Repr.lean
Normal file
128
src/Lean/Elab/Deriving/Repr.lean
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
/-
|
||||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Meta.Transform
|
||||
import Lean.Meta.Inductive
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.Deriving.Util
|
||||
|
||||
namespace Lean.Elab.Deriving.Repr
|
||||
open Lean.Parser.Term
|
||||
open Meta
|
||||
open Std
|
||||
|
||||
def mkReprHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do
|
||||
let prec ← `(prec)
|
||||
let header ← mkHeader ctx `Repr 1 indVal
|
||||
return { header with
|
||||
binders := header.binders.push (← `(explicitBinderF| (prec : Nat)))
|
||||
}
|
||||
|
||||
def mkBodyForStruct (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Syntax := do
|
||||
let ctorVal ← getConstInfoCtor indVal.ctors.head!
|
||||
let fieldNames ← getStructureFields (← getEnv) indVal.name
|
||||
let numParams := indVal.nparams
|
||||
let target := mkIdent header.targetNames[0]
|
||||
forallTelescopeReducing ctorVal.type fun xs _ => do
|
||||
let mut fields : Syntax ← `(Format.nil)
|
||||
let mut first := true
|
||||
if xs.size != numParams + fieldNames.size then
|
||||
throwError! "'deriving Repr' failed, unexpected number of fields in structure"
|
||||
for i in [:fieldNames.size] do
|
||||
let fieldName := fieldNames[i]
|
||||
let fieldNameLit := Syntax.mkStrLit (toString fieldName)
|
||||
let x := xs[numParams + i]
|
||||
if first then
|
||||
first := false
|
||||
else
|
||||
fields ← `($fields ++ "," ++ Format.line)
|
||||
if (← isType x <||> isProof x) then
|
||||
fields ← `($fields ++ $fieldNameLit ++ " := " ++ "_")
|
||||
else
|
||||
fields ← `($fields ++ $fieldNameLit ++ " := " ++ repr ($target.$(mkIdent fieldName):ident))
|
||||
`(Format.bracket "{ " $fields:term " }")
|
||||
|
||||
def mkBodyForInduct (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do
|
||||
let discrs ← mkDiscrs header indVal
|
||||
let alts ← mkAlts
|
||||
`(match $[$discrs],* with $alts:matchAlt*)
|
||||
where
|
||||
mkAlts : TermElabM (Array Syntax) := do
|
||||
let mut alts := #[]
|
||||
for ctorName in indVal.ctors do
|
||||
let ctorInfo ← getConstInfoCtor ctorName
|
||||
let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
|
||||
let mut patterns := #[]
|
||||
-- add `_` pattern for indices
|
||||
for i in [:indVal.nindices] do
|
||||
patterns := patterns.push (← `(_))
|
||||
let mut ctorArgs := #[]
|
||||
let mut rhs := Syntax.mkStrLit (toString ctorInfo.name)
|
||||
let mut rhs ← `(Format.text $rhs)
|
||||
-- add `_` for inductive parameters, they are inaccessible
|
||||
for i in [:indVal.nparams] do
|
||||
ctorArgs := ctorArgs.push (← `(_))
|
||||
for i in [:ctorInfo.nfields] do
|
||||
let x := xs[indVal.nparams + i]
|
||||
let a := mkIdent (← mkFreshUserName `a)
|
||||
ctorArgs := ctorArgs.push a
|
||||
if (← inferType x).isAppOf indVal.name then
|
||||
rhs ← `($rhs ++ Format.line ++ $(mkIdent auxFunName):ident $a:ident maxPrec!)
|
||||
else
|
||||
rhs ← `($rhs ++ Format.line ++ reprArg $a)
|
||||
patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*))
|
||||
`(matchAltExpr| | $[$patterns:term],* => Repr.addAppParen (Format.group (Format.nest (if prec >= maxPrec! then 1 else 2) ($rhs:term))) prec)
|
||||
alts := alts.push alt
|
||||
return alts
|
||||
|
||||
def mkBody (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do
|
||||
if isStructureLike (← getEnv) indVal.name then
|
||||
mkBodyForStruct ctx header indVal
|
||||
else
|
||||
mkBodyForInduct ctx header indVal auxFunName
|
||||
|
||||
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do
|
||||
let auxFunName ← ctx.auxFunNames[i]
|
||||
let indVal ← ctx.typeInfos[i]
|
||||
let header ← mkReprHeader ctx indVal
|
||||
let mut body ← mkBody ctx header indVal auxFunName
|
||||
if ctx.usePartial then
|
||||
let letDecls ← mkLocalInstanceLetDecls ctx `Repr header.argNames
|
||||
body ← mkLet letDecls body
|
||||
let binders := header.binders
|
||||
if ctx.usePartial then
|
||||
`(private partial def $(mkIdent auxFunName):ident $binders:explicitBinder* : Format := $body:term)
|
||||
else
|
||||
`(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : Format := $body:term)
|
||||
|
||||
def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
|
||||
let mut auxDefs := #[]
|
||||
for i in [:ctx.typeInfos.size] do
|
||||
auxDefs := auxDefs.push (← mkAuxFunction ctx i)
|
||||
`(mutual
|
||||
$auxDefs:command*
|
||||
end)
|
||||
|
||||
private def mkReprInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext "repr" declNames[0]
|
||||
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr declNames)
|
||||
trace[Elab.Deriving.repr]! "\n{cmds}"
|
||||
return cmds
|
||||
|
||||
open Command
|
||||
|
||||
def mkReprInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
||||
if (← declNames.allM isInductive) && declNames.size > 0 then
|
||||
let cmds ← liftTermElabM none <| mkReprInstanceCmds declNames
|
||||
cmds.forM elabCommand
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinDerivingHandler `Repr mkReprInstanceHandler
|
||||
registerTraceClass `Elab.Deriving.repr
|
||||
|
||||
end Lean.Elab.Deriving.Repr
|
||||
33
tests/lean/derivingRepr.lean
Normal file
33
tests/lean/derivingRepr.lean
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
structure Foo where
|
||||
name : String
|
||||
val : List Nat
|
||||
lower : Nat := List.length val
|
||||
inv : val.length >= lower
|
||||
flag : Bool
|
||||
deriving Repr
|
||||
|
||||
#eval { name := "Joe", val := List.iota 40, flag := true, inv := by decide! : Foo }
|
||||
|
||||
inductive Tree (α : Type) where
|
||||
| node : List (Tree α) → Bool → Tree α
|
||||
| leaf : α → Tree α
|
||||
deriving Repr
|
||||
|
||||
#eval Tree.node (List.iota 10 |>.map fun i => Tree.node [Tree.leaf i] (i%2==0)) true
|
||||
|
||||
namespace Foo
|
||||
mutual
|
||||
inductive Tree (α : Type u) where
|
||||
| node : TreeList α → Tree α
|
||||
| leaf : α → Tree α
|
||||
deriving Repr
|
||||
|
||||
inductive TreeList (α : Type u) where
|
||||
| nil : TreeList α
|
||||
| cons : Tree α → TreeList α → TreeList α
|
||||
deriving Repr
|
||||
end
|
||||
|
||||
#eval Tree.node (TreeList.cons (Tree.leaf 30) (TreeList.cons (Tree.leaf 20) (TreeList.cons (Tree.leaf 10) TreeList.nil)))
|
||||
|
||||
end Foo
|
||||
22
tests/lean/derivingRepr.lean.expected.out
Normal file
22
tests/lean/derivingRepr.lean.expected.out
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
{ name := "Joe",
|
||||
val := [40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14,
|
||||
13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
|
||||
lower := 40,
|
||||
inv := _,
|
||||
flag := true }
|
||||
Tree.node
|
||||
[Tree.node [Tree.leaf 10] true,
|
||||
Tree.node [Tree.leaf 9] false,
|
||||
Tree.node [Tree.leaf 8] true,
|
||||
Tree.node [Tree.leaf 7] false,
|
||||
Tree.node [Tree.leaf 6] true,
|
||||
Tree.node [Tree.leaf 5] false,
|
||||
Tree.node [Tree.leaf 4] true,
|
||||
Tree.node [Tree.leaf 3] false,
|
||||
Tree.node [Tree.leaf 2] true,
|
||||
Tree.node [Tree.leaf 1] false]
|
||||
true
|
||||
Foo.Tree.node
|
||||
(Foo.TreeList.cons
|
||||
(Foo.Tree.leaf 30)
|
||||
(Foo.TreeList.cons (Foo.Tree.leaf 20) (Foo.TreeList.cons (Foo.Tree.leaf 10) (Foo.TreeList.nil))))
|
||||
Loading…
Add table
Reference in a new issue