fix: manage all declarations in a given derive (#3058)

Closes #3057
This commit is contained in:
Arthur Adjedj 2024-01-09 08:42:06 +01:00 committed by GitHub
parent 0aa2b83450
commit b7c3ff6e6d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 182 additions and 134 deletions

View file

@ -91,9 +91,9 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
$auxDefs:command*
end)
private def mkBEqInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "beq" declNames[0]!
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq declNames)
private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "beq" declName
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName])
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds
@ -109,14 +109,18 @@ private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do
open Command
def mkBEqInstance (declName : Name) : CommandElabM Unit := do
let cmds ← liftTermElabM <|
if (← isEnumType declName) then
mkBEqEnumCmd declName
else
mkBEqInstanceCmds declName
cmds.forM elabCommand
def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size == 1 && (← isEnumType declNames[0]!) then
let cmds ← liftTermElabM <| mkBEqEnumCmd declNames[0]!
cmds.forM elabCommand
return true
else if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM <| mkBEqInstanceCmds declNames
cmds.forM elabCommand
if (← declNames.allM isInductive) then
for declName in declNames do
mkBEqInstance declName
return true
else
return false

View file

@ -186,12 +186,15 @@ def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
trace[Elab.Deriving.decEq] "\n{cmd}"
elabCommand cmd
def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← isEnumType declNames[0]!) then
mkDecEqEnum declNames[0]!
def mkDecEqInstance (declName : Name) : CommandElabM Bool := do
if (← isEnumType declName) then
mkDecEqEnum declName
return true
else
mkDecEq declNames[0]!
mkDecEq declName
def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
declNames.foldlM (fun b n => andM (pure b) (mkDecEqInstance n)) true
builtin_initialize
registerDerivingHandler `DecidableEq mkDecEqInstanceHandler

View file

@ -19,60 +19,58 @@ def mkJsonField (n : Name) : CoreM (Bool × Term) := do
let s₁ := s.dropRightWhile (· == '?')
return (s != s₁, Syntax.mkStrLit s₁)
def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size == 1 then
if isStructure (← getEnv) declNames[0]! then
let cmds ← liftTermElabM do
let ctx ← mkContext "toJson" declNames[0]!
let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]!
let fields := getStructureFieldsFlattened (← getEnv) declNames[0]! (includeSubobjectFields := false)
let fields ← fields.mapM fun field => do
let (isOptField, nm) ← mkJsonField field
let target := mkIdent header.targetNames[0]!
if isOptField then ``(opt $nm ($target).$(mkIdent field))
else ``([($nm, toJson ($target).$(mkIdent field))])
let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* : Json :=
mkObj <| List.join [$fields,*])
return #[cmd] ++ (← mkInstanceCmds ctx ``ToJson declNames)
cmds.forM elabCommand
return true
else
let indVal ← getConstInfoInduct declNames[0]!
let cmds ← liftTermElabM do
let ctx ← mkContext "toJson" declNames[0]!
let toJsonFuncId := mkIdent ctx.auxFunNames[0]!
-- Return syntax to JSONify `id`, either via `ToJson` or recursively
-- if `id`'s type is the type we're deriving for.
let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do
if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident)
else ``(toJson $id:ident)
let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]!
let discrs ← mkDiscrs header indVal
let alts ← mkAlts indVal fun ctor args userNames => do
let ctorStr := ctor.name.eraseMacroScopes.getString!
match args, userNames with
| #[], _ => ``(toJson $(quote ctorStr))
| #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))])
| xs, none =>
let xs ← xs.mapM fun (x, t) => mkToJson x t
``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])])
| xs, some userNames =>
let xs ← xs.mapIdxM fun idx (x, t) => do
`(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t)))
``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])])
let auxTerm ← `(match $[$discrs],* with $alts:matchAlt*)
let auxCmd ←
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames
let auxTerm ← mkLet letDecls auxTerm
`(private partial def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm)
else
`(private def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm)
return #[auxCmd] ++ (← mkInstanceCmds ctx ``ToJson declNames)
cmds.forM elabCommand
return true
def mkToJsonInstance (declName : Name) : CommandElabM Bool := do
if isStructure (← getEnv) declName then
let cmds ← liftTermElabM do
let ctx ← mkContext "toJson" declName
let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]!
let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false)
let fields ← fields.mapM fun field => do
let (isOptField, nm) ← mkJsonField field
let target := mkIdent header.targetNames[0]!
if isOptField then ``(opt $nm ($target).$(mkIdent field))
else ``([($nm, toJson ($target).$(mkIdent field))])
let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* : Json :=
mkObj <| List.join [$fields,*])
return #[cmd] ++ (← mkInstanceCmds ctx ``ToJson #[declName])
cmds.forM elabCommand
return true
else
return false
let indVal ← getConstInfoInduct declName
let cmds ← liftTermElabM do
let ctx ← mkContext "toJson" declName
let toJsonFuncId := mkIdent ctx.auxFunNames[0]!
-- Return syntax to JSONify `id`, either via `ToJson` or recursively
-- if `id`'s type is the type we're deriving for.
let mkToJson (id : Ident) (type : Expr) : TermElabM Term := do
if type.isAppOf indVal.name then `($toJsonFuncId:ident $id:ident)
else ``(toJson $id:ident)
let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]!
let discrs ← mkDiscrs header indVal
let alts ← mkAlts indVal fun ctor args userNames => do
let ctorStr := ctor.name.eraseMacroScopes.getString!
match args, userNames with
| #[], _ => ``(toJson $(quote ctorStr))
| #[(x, t)], none => ``(mkObj [($(quote ctorStr), $(← mkToJson x t))])
| xs, none =>
let xs ← xs.mapM fun (x, t) => mkToJson x t
``(mkObj [($(quote ctorStr), Json.arr #[$[$xs:term],*])])
| xs, some userNames =>
let xs ← xs.mapIdxM fun idx (x, t) => do
`(($(quote userNames[idx]!.eraseMacroScopes.getString!), $(← mkToJson x t)))
``(mkObj [($(quote ctorStr), mkObj [$[$xs:term],*])])
let auxTerm ← `(match $[$discrs],* with $alts:matchAlt*)
let auxCmd ←
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx ``ToJson header.argNames
let auxTerm ← mkLet letDecls auxTerm
`(private partial def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm)
else
`(private def $toJsonFuncId:ident $header.binders:bracketedBinder* : Json := $auxTerm)
return #[auxCmd] ++ (← mkInstanceCmds ctx ``ToJson #[declName])
cmds.forM elabCommand
return true
where
mkAlts
(indVal : InductiveVal)
@ -103,54 +101,51 @@ where
let rhs ← rhs ctorInfo binders (if userNames.size == binders.size then some userNames else none)
`(matchAltExpr| | $[$patterns:term],* => $rhs:term)
def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size == 1 then
let declName := declNames[0]!
if isStructure (← getEnv) declName then
let cmds ← liftTermElabM do
let ctx ← mkContext "fromJson" declName
let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]!
let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false)
let getters ← fields.mapM (fun field => do
let getter ← `(getObjValAs? j _ $(Prod.snd <| ← mkJsonField field))
let getter ← `(doElem| Except.mapError (fun s => (toString $(quote declName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter)
return getter
)
let fields := fields.map mkIdent
let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* (j : Json)
: Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := do
$[let $fields:ident ← $getters]*
return { $[$fields:ident := $(id fields)],* })
return #[cmd] ++ (← mkInstanceCmds ctx ``FromJson declNames)
cmds.forM elabCommand
return true
else
let indVal ← getConstInfoInduct declName
let cmds ← liftTermElabM do
let ctx ← mkContext "fromJson" declName
let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]!
let fromJsonFuncId := mkIdent ctx.auxFunNames[0]!
let alts ← mkAlts indVal fromJsonFuncId
let mut auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched"))
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames
auxTerm ← mkLet letDecls auxTerm
-- FromJson is not structurally recursive even non-nested recursive inductives,
-- so we also use `partial` then.
let auxCmd ←
if ctx.usePartial || indVal.isRec then
`(private partial def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json)
: Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) :=
$auxTerm)
else
`(private def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json)
: Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) :=
$auxTerm)
return #[auxCmd] ++ (← mkInstanceCmds ctx ``FromJson declNames)
cmds.forM elabCommand
return true
def mkFromJsonInstance (declName : Name) : CommandElabM Bool := do
if isStructure (← getEnv) declName then
let cmds ← liftTermElabM do
let ctx ← mkContext "fromJson" declName
let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]!
let fields := getStructureFieldsFlattened (← getEnv) declName (includeSubobjectFields := false)
let getters ← fields.mapM (fun field => do
let getter ← `(getObjValAs? j _ $(Prod.snd <| ← mkJsonField field))
let getter ← `(doElem| Except.mapError (fun s => (toString $(quote declName)) ++ "." ++ (toString $(quote field)) ++ ": " ++ s) <| $getter)
return getter
)
let fields := fields.map mkIdent
let cmd ← `(private def $(mkIdent ctx.auxFunNames[0]!):ident $header.binders:bracketedBinder* (j : Json)
: Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) := do
$[let $fields:ident ← $getters]*
return { $[$fields:ident := $(id fields)],* })
return #[cmd] ++ (← mkInstanceCmds ctx ``FromJson #[declName])
cmds.forM elabCommand
return true
else
return false
let indVal ← getConstInfoInduct declName
let cmds ← liftTermElabM do
let ctx ← mkContext "fromJson" declName
let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]!
let fromJsonFuncId := mkIdent ctx.auxFunNames[0]!
let alts ← mkAlts indVal fromJsonFuncId
let mut auxTerm ← alts.foldrM (fun xs x => `(Except.orElseLazy $xs (fun _ => $x))) (← `(Except.error "no inductive constructor matched"))
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx ``FromJson header.argNames
auxTerm ← mkLet letDecls auxTerm
-- FromJson is not structurally recursive even non-nested recursive inductives,
-- so we also use `partial` then.
let auxCmd ←
if ctx.usePartial || indVal.isRec then
`(private partial def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json)
: Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) :=
$auxTerm)
else
`(private def $fromJsonFuncId:ident $header.binders:bracketedBinder* (json : Json)
: Except String $(← mkInductiveApp ctx.typeInfos[0]! header.argNames) :=
$auxTerm)
return #[auxCmd] ++ (← mkInstanceCmds ctx ``FromJson #[declName])
cmds.forM elabCommand
return true
where
mkAlts (indVal : InductiveVal) (fromJsonFuncId : Ident) : TermElabM (Array Term) := do
let alts ←
@ -188,6 +183,12 @@ where
let alts := alts.qsort (fun (_, x) (_, y) => x < y)
return alts.map Prod.fst
def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
declNames.foldlM (fun b n => andM (pure b) (mkToJsonInstance n)) true
def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
declNames.foldlM (fun b n => andM (pure b) (mkFromJsonInstance n)) true
builtin_initialize
registerDerivingHandler ``ToJson mkToJsonInstanceHandler
registerDerivingHandler ``FromJson mkFromJsonInstanceHandler

View file

@ -75,16 +75,17 @@ def mkHashFuncs (ctx : Context) : TermElabM Syntax := do
auxDefs := auxDefs.push (← mkAuxFunction ctx i)
`(mutual $auxDefs:command* end)
private def mkHashableInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "hash" declNames[0]!
let cmds := #[← mkHashFuncs ctx] ++ (← mkInstanceCmds ctx `Hashable declNames)
private def mkHashableInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "hash" declName
let cmds := #[← mkHashFuncs ctx] ++ (← mkInstanceCmds ctx `Hashable #[declName])
trace[Elab.Deriving.hashable] "\n{cmds}"
return cmds
def mkHashableHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM <| mkHashableInstanceCmds declNames
cmds.forM elabCommand
if (← declNames.allM isInductive) then
for declName in declNames do
let cmds ← liftTermElabM <| mkHashableInstanceCmds declName
cmds.forM elabCommand
return true
else
return false

View file

@ -86,18 +86,19 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
$auxDefs:command*
end)
private def mkOrdInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "ord" declNames[0]!
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord declNames)
private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "ord" declName
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName])
trace[Elab.Deriving.ord] "\n{cmds}"
return cmds
open Command
def mkOrdInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM <| mkOrdInstanceCmds declNames
cmds.forM elabCommand
if (← declNames.allM isInductive) then
for declName in declNames do
let cmds ← liftTermElabM <| mkOrdInstanceCmds declName
cmds.forM elabCommand
return true
else
return false

View file

@ -104,18 +104,19 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
$auxDefs:command*
end)
private def mkReprInstanceCmds (declNames : Array Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "repr" declNames[0]!
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr declNames)
private def mkReprInstanceCmd (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext "repr" declName
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Repr #[declName])
trace[Elab.Deriving.repr] "\n{cmds}"
return cmds
open Command
def mkReprInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM <| mkReprInstanceCmds declNames
cmds.forM elabCommand
if (← declNames.allM isInductive) then
for declName in declNames do
let cmds ← liftTermElabM <| mkReprInstanceCmd declName
cmds.forM elabCommand
return true
else
return false

View file

@ -16,8 +16,9 @@ namespace Lean.Elab.Deriving.SizeOf
open Command
def mkSizeOfHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
liftTermElabM <| Meta.mkSizeOfInstances declNames[0]!
if (← declNames.allM isInductive) then
for declName in declNames do
liftTermElabM <| Meta.mkSizeOfInstances declName
return true
else
return false

26
tests/lean/3057.lean Normal file
View file

@ -0,0 +1,26 @@
/-
The derive handlers should manage both inductives and not ignore the second one,
Fixes `#3057`
-/
mutual
inductive Tree : Type :=
| node : ListTree → Tree
deriving Repr, DecidableEq, BEq, Hashable, Ord
inductive ListTree : Type :=
| nil : ListTree
| cons : Tree → ListTree → ListTree
deriving Repr, DecidableEq, BEq, Hashable, Ord
end
#synth Repr Tree
#synth Repr ListTree
#synth DecidableEq Tree
#synth DecidableEq ListTree
#synth BEq Tree
#synth BEq ListTree
#synth Hashable Tree
#synth Hashable ListTree
#synth Ord Tree
#synth Ord ListTree

View file

@ -0,0 +1,10 @@
instReprTree
instReprListTree
fun a b => instDecidableEqTree a b
fun a b => instDecidableEqListTree a b
instBEqTree
instBEqListTree
instHashableTree
instHashableListTree
instOrdTree
instOrdListTree