diff --git a/src/Lean/Elab/Deriving.lean b/src/Lean/Elab/Deriving.lean index de069d1c17..840ee63261 100644 --- a/src/Lean/Elab/Deriving.lean +++ b/src/Lean/Elab/Deriving.lean @@ -8,3 +8,4 @@ import Lean.Elab.Deriving.Util import Lean.Elab.Deriving.Inhabited import Lean.Elab.Deriving.BEq import Lean.Elab.Deriving.DecEq +import Lean.Elab.Deriving.Repr diff --git a/src/Lean/Elab/Deriving/Repr.lean b/src/Lean/Elab/Deriving/Repr.lean new file mode 100644 index 0000000000..a59f2b942a --- /dev/null +++ b/src/Lean/Elab/Deriving/Repr.lean @@ -0,0 +1,128 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Meta.Transform +import Lean.Meta.Inductive +import Lean.Elab.Deriving.Basic +import Lean.Elab.Deriving.Util + +namespace Lean.Elab.Deriving.Repr +open Lean.Parser.Term +open Meta +open Std + +def mkReprHeader (ctx : Context) (indVal : InductiveVal) : TermElabM Header := do + let prec ← `(prec) + let header ← mkHeader ctx `Repr 1 indVal + return { header with + binders := header.binders.push (← `(explicitBinderF| (prec : Nat))) + } + +def mkBodyForStruct (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Syntax := do + let ctorVal ← getConstInfoCtor indVal.ctors.head! + let fieldNames ← getStructureFields (← getEnv) indVal.name + let numParams := indVal.nparams + let target := mkIdent header.targetNames[0] + forallTelescopeReducing ctorVal.type fun xs _ => do + let mut fields : Syntax ← `(Format.nil) + let mut first := true + if xs.size != numParams + fieldNames.size then + throwError! "'deriving Repr' failed, unexpected number of fields in structure" + for i in [:fieldNames.size] do + let fieldName := fieldNames[i] + let fieldNameLit := Syntax.mkStrLit (toString fieldName) + let x := xs[numParams + i] + if first then + first := false + else + fields ← `($fields ++ "," ++ Format.line) + if (← isType x <||> isProof x) then + fields ← `($fields ++ $fieldNameLit ++ " := " ++ "_") + else + fields ← `($fields ++ $fieldNameLit ++ " := " ++ repr ($target.$(mkIdent fieldName):ident)) + `(Format.bracket "{ " $fields:term " }") + +def mkBodyForInduct (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do + let discrs ← mkDiscrs header indVal + let alts ← mkAlts + `(match $[$discrs],* with $alts:matchAlt*) +where + mkAlts : TermElabM (Array Syntax) := do + let mut alts := #[] + for ctorName in indVal.ctors do + let ctorInfo ← getConstInfoCtor ctorName + let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do + let mut patterns := #[] + -- add `_` pattern for indices + for i in [:indVal.nindices] do + patterns := patterns.push (← `(_)) + let mut ctorArgs := #[] + let mut rhs := Syntax.mkStrLit (toString ctorInfo.name) + let mut rhs ← `(Format.text $rhs) + -- add `_` for inductive parameters, they are inaccessible + for i in [:indVal.nparams] do + ctorArgs := ctorArgs.push (← `(_)) + for i in [:ctorInfo.nfields] do + let x := xs[indVal.nparams + i] + let a := mkIdent (← mkFreshUserName `a) + ctorArgs := ctorArgs.push a + if (← inferType x).isAppOf indVal.name then + rhs ← `($rhs ++ Format.line ++ $(mkIdent auxFunName):ident $a:ident maxPrec!) + else + rhs ← `($rhs ++ Format.line ++ reprArg $a) + patterns := patterns.push (← `(@$(mkIdent ctorName):ident $ctorArgs:term*)) + `(matchAltExpr| | $[$patterns:term],* => Repr.addAppParen (Format.group (Format.nest (if prec >= maxPrec! then 1 else 2) ($rhs:term))) prec) + alts := alts.push alt + return alts + +def mkBody (ctx : Context) (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Syntax := do + if isStructureLike (← getEnv) indVal.name then + mkBodyForStruct ctx header indVal + else + mkBodyForInduct ctx header indVal auxFunName + +def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Syntax := do + let auxFunName ← ctx.auxFunNames[i] + let indVal ← ctx.typeInfos[i] + let header ← mkReprHeader ctx indVal + let mut body ← mkBody ctx header indVal auxFunName + if ctx.usePartial then + let letDecls ← mkLocalInstanceLetDecls ctx `Repr header.argNames + body ← mkLet letDecls body + let binders := header.binders + if ctx.usePartial then + `(private partial def $(mkIdent auxFunName):ident $binders:explicitBinder* : Format := $body:term) + else + `(private def $(mkIdent auxFunName):ident $binders:explicitBinder* : Format := $body:term) + +def mkMutualBlock (ctx : Context) : TermElabM Syntax := do + let mut auxDefs := #[] + for i in [:ctx.typeInfos.size] do + auxDefs := auxDefs.push (← mkAuxFunction ctx i) + `(mutual + $auxDefs:command* + end) + +private def mkReprInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do + let ctx ← mkContext "repr" declNames[0] + let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr declNames) + trace[Elab.Deriving.repr]! "\n{cmds}" + return cmds + +open Command + +def mkReprInstanceHandler (declNames : Array Name) : CommandElabM Bool := do + if (← declNames.allM isInductive) && declNames.size > 0 then + let cmds ← liftTermElabM none <| mkReprInstanceCmds declNames + cmds.forM elabCommand + return true + else + return false + +builtin_initialize + registerBuiltinDerivingHandler `Repr mkReprInstanceHandler + registerTraceClass `Elab.Deriving.repr + +end Lean.Elab.Deriving.Repr diff --git a/tests/lean/derivingRepr.lean b/tests/lean/derivingRepr.lean new file mode 100644 index 0000000000..597bfae37a --- /dev/null +++ b/tests/lean/derivingRepr.lean @@ -0,0 +1,33 @@ +structure Foo where + name : String + val : List Nat + lower : Nat := List.length val + inv : val.length >= lower + flag : Bool + deriving Repr + +#eval { name := "Joe", val := List.iota 40, flag := true, inv := by decide! : Foo } + +inductive Tree (α : Type) where + | node : List (Tree α) → Bool → Tree α + | leaf : α → Tree α + deriving Repr + +#eval Tree.node (List.iota 10 |>.map fun i => Tree.node [Tree.leaf i] (i%2==0)) true + +namespace Foo +mutual +inductive Tree (α : Type u) where + | node : TreeList α → Tree α + | leaf : α → Tree α + deriving Repr + +inductive TreeList (α : Type u) where + | nil : TreeList α + | cons : Tree α → TreeList α → TreeList α + deriving Repr +end + +#eval Tree.node (TreeList.cons (Tree.leaf 30) (TreeList.cons (Tree.leaf 20) (TreeList.cons (Tree.leaf 10) TreeList.nil))) + +end Foo diff --git a/tests/lean/derivingRepr.lean.expected.out b/tests/lean/derivingRepr.lean.expected.out new file mode 100644 index 0000000000..418b11387c --- /dev/null +++ b/tests/lean/derivingRepr.lean.expected.out @@ -0,0 +1,22 @@ +{ name := "Joe", + val := [40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, + 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + lower := 40, + inv := _, + flag := true } +Tree.node + [Tree.node [Tree.leaf 10] true, + Tree.node [Tree.leaf 9] false, + Tree.node [Tree.leaf 8] true, + Tree.node [Tree.leaf 7] false, + Tree.node [Tree.leaf 6] true, + Tree.node [Tree.leaf 5] false, + Tree.node [Tree.leaf 4] true, + Tree.node [Tree.leaf 3] false, + Tree.node [Tree.leaf 2] true, + Tree.node [Tree.leaf 1] false] + true +Foo.Tree.node + (Foo.TreeList.cons + (Foo.Tree.leaf 30) + (Foo.TreeList.cons (Foo.Tree.leaf 20) (Foo.TreeList.cons (Foo.Tree.leaf 10) (Foo.TreeList.nil))))