feat: support unused params in RpcEncoding deriver

This commit is contained in:
Gabriel Ebner 2022-07-18 11:06:45 +02:00 committed by Sebastian Ullrich
parent d36552848c
commit 2c0f8fac99
3 changed files with 61 additions and 40 deletions

View file

@ -80,14 +80,8 @@ private def getRpcPacketFor (ty : Expr) : MetaM Expr := do
private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr)
(paramBinders packetParamBinders encInstBinders : Array (TSyntax ``Parser.Term.bracketedBinder)) : TermElabM Command := do
let packetParamNames := packetParamBinders.map fun
| `(bracketedBinder| ($t:ident : $_)) => t
| _ => unreachable!
withFields indVal params fun fields => do
trace[Elab.Deriving.RpcEncoding] "for structure {indVal.name} with params {params}"
-- Postulate that every field have a rpc encoding, storing the encoding type ident
-- in `fieldEncIds`. When multiple fields have the same type, we reuse the encoding type
-- as otherwise typeclass synthesis fails.
let mut binders := #[]
let mut fieldIds := #[]
let mut fieldEncTypeStxs := #[]
@ -111,16 +105,19 @@ private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr
let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName
let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds
let packetId := mkIdent <| indVal.name ++ `RpcEncodingPacket
let packetAppliedId := Syntax.mkApp packetId packetParamNames
let packetAppliedId ← `($packetId ..)
let instId := mkIdent (`_root_ ++ indVal.name.appendBefore "instRpcEncoding")
`(protected structure $packetId:ident $packetParamBinders* where
`(variable $packetParamBinders* in
protected structure $packetId:ident where
$[($fieldIds : $fieldEncTypeStxs)]*
deriving FromJson, ToJson
variable $(paramBinders ++ packetParamBinders ++ encInstBinders)* in
instance : RpcEncoding $typeId $packetAppliedId where
rpcEncode a := return { $[$encInits],* }
rpcDecode a := return { $[$decInits],* }
@[instance] def $instId := show RpcEncoding $typeId $packetAppliedId from {
rpcEncode := fun a => return { $[$encInits],* }
rpcDecode := fun a => return { $[$decInits],* }
}
)
private structure CtorState where
@ -132,22 +129,18 @@ private structure CtorState where
decodes : Array (TSyntax ``Parser.Term.matchAlt) := #[]
private def matchF := Lean.Parser.Term.matchAlt (rhsParser := Lean.Parser.termParser)
private def deriveInductiveInstance (indVal : InductiveVal) (params packetParams : Array Expr)
private def deriveInductiveInstance (indVal : InductiveVal) (params : Array Expr)
(paramBinders packetParamBinders encInstBinders : Array (TSyntax ``Parser.Term.bracketedBinder)) : TermElabM Command := do
trace[Elab.Deriving.RpcEncoding] "for inductive {indVal.name} with params {params}"
let packetParamNames := packetParamBinders.map fun
| `(bracketedBinder| ($t:ident : $_)) => t
| _ => unreachable!
withoutModifyingEnv do
let packetNm := indVal.name ++ `RpcEncodingPacket
addDecl <| .axiomDecl {
name := packetNm
levelParams := []
type := ← mkForallFVars packetParams (mkSort levelOne)
type := mkSort levelOne
isUnsafe := true
}
let pktCtorTp := mkAppN (mkConst packetNm) packetParams
let pktCtorTp := mkConst packetNm
let recInstTp := mkApp2 (mkConst ``RpcEncoding) (mkAppN (mkConst indVal.name) params) pktCtorTp
withLocalDecl `inst .instImplicit recInstTp fun _ => do
let st ← foldWithConstructors indVal params (init := { : CtorState }) fun acc ctor argVars _ => do
@ -180,14 +173,16 @@ private def deriveInductiveInstance (indVal : InductiveVal) (params packetParams
-- helpers for type name syntax
let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName
let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds
let packetId := Syntax.mkApp (mkIdent packetNm) packetParamNames
let packetId ← `($(mkIdent packetNm) ..)
let instId := mkIdent (`_root_ ++ indVal.name.appendBefore "instRpcEncoding")
`(inductive $(mkIdent packetNm) $packetParamBinders:bracketedBinder* where
`(variable $packetParamBinders:bracketedBinder* in
protected inductive $(mkIdent packetNm) where
$[$(st.ctors):ctor]*
deriving FromJson, ToJson
variable $(paramBinders ++ packetParamBinders ++ encInstBinders)* in
partial instance : RpcEncoding $typeId $packetId :=
@[instance] partial def $instId := show RpcEncoding $typeId $packetId from
{ rpcEncode, rpcDecode }
where
rpcEncode {m} [Monad m] [MonadRpcSession m] (x : $typeId) : ExceptT String m $packetId :=
@ -226,18 +221,17 @@ private def deriveInstance (typeName : Name) : CommandElabM Bool := do
packetParamBinders := packetParamBinders.push (← `(bracketedBinder| ($packet : Type)))
encInstBinders := encInstBinders.push (← `(bracketedBinder| [RpcEncoding $(mkIdent paramNm) $packet]))
else
packetParamBinders := packetParamBinders.push paramBinders.back --(← `(bracketedBinder| ($packet : $ty)))
packetParamBinders := packetParamBinders.push paramBinders.back
return (paramBinders, packetParamBinders, encInstBinders)
elabCommand <| ← liftTermElabM none do
Term.elabBinders (paramBinders ++ packetParamBinders ++ encInstBinders) fun locals => do
let params := locals[:paramBinders.size]
let packetParams := locals[paramBinders.size:paramBinders.size+packetParamBinders.size]
if isStructure (← getEnv) typeName then
deriveStructureInstance indVal params paramBinders packetParamBinders encInstBinders
else
deriveInductiveInstance indVal params packetParams paramBinders packetParamBinders encInstBinders
deriveInductiveInstance indVal params paramBinders packetParamBinders encInstBinders
return true

View file

@ -38,21 +38,21 @@ structure Bar where
deriving RpcEncoding, Inhabited
#print Bar.RpcEncodingPacket
#check instRpcEncodingBarRpcEncodingPacket
#check instRpcEncodingBar
#eval test Bar default
structure BarTrans where
bar : Bar
deriving RpcEncoding, Inhabited
#check instRpcEncodingBarTransRpcEncodingPacket
#check instRpcEncodingBarTrans
#eval test BarTrans default
structure Baz where
arr : Array String -- non-constant field
deriving RpcEncoding, Inhabited
#check instRpcEncodingBazRpcEncodingPacket
#check instRpcEncodingBaz
#eval test Baz default
structure FooGeneric (α : Type) where
@ -61,7 +61,7 @@ structure FooGeneric (α : Type) where
deriving RpcEncoding, Inhabited
#print FooGeneric.RpcEncodingPacket
#check instRpcEncodingFooGenericRpcEncodingPacket
#check instRpcEncodingFooGeneric
#eval test (FooGeneric Nat) default
#eval test (FooGeneric Nat) { a := 3, b? := some 42 }
@ -69,7 +69,7 @@ inductive BazInductive
| baz (arr : Array Bar)
deriving RpcEncoding, Inhabited
#check instRpcEncodingBazInductiveRpcEncodingPacket
#check instRpcEncodingBazInductive
#eval test BazInductive ⟨#[default, default]⟩
inductive FooInductive (α : Type) where
@ -78,7 +78,7 @@ inductive FooInductive (α : Type) where
deriving RpcEncoding, Inhabited
#print FooInductive.RpcEncodingPacket
#check instRpcEncodingFooInductiveRpcEncodingPacket
#check instRpcEncodingFooInductive
#eval test (FooInductive BazInductive) (.a default default)
#eval test (FooInductive BazInductive) (.b 42 default default)
@ -94,5 +94,20 @@ inductive FooParam (n : Nat) where
| a : Nat → FooParam n
deriving RpcEncoding, Inhabited
#check instRpcEncodingFooParamRpcEncodingPacket
#check instRpcEncodingFooParam
#eval test (FooParam 10) (.a 42)
inductive Unused (α : Type) | a
deriving RpcEncoding, Inhabited
#print Unused.RpcEncodingPacket
#check instRpcEncodingUnused
structure NoRpcEncoding
#eval test (Unused NoRpcEncoding) default
structure UnusedStruct (α : Type)
deriving RpcEncoding, Inhabited
#print UnusedStruct.RpcEncodingPacket
#check instRpcEncodingUnusedStruct
#eval test (UnusedStruct NoRpcEncoding) default

View file

@ -4,34 +4,34 @@ protected inductive Bar.RpcEncodingPacket : Type
number of parameters: 0
constructors:
Bar.RpcEncodingPacket.mk : Lsp.RpcRef → FooJson → Bar.RpcEncodingPacket
instRpcEncodingBarRpcEncodingPacket : RpcEncoding Bar Bar.RpcEncodingPacket
instRpcEncodingBar : RpcEncoding Bar Bar.RpcEncodingPacket
ok: {"fooRef": {"p": "0"}, "fooJson": {"s": ""}}
instRpcEncodingBarTransRpcEncodingPacket : RpcEncoding BarTrans BarTrans.RpcEncodingPacket
instRpcEncodingBarTrans : RpcEncoding BarTrans BarTrans.RpcEncodingPacket
ok: {"bar": {"fooRef": {"p": "0"}, "fooJson": {"s": ""}}}
instRpcEncodingBazRpcEncodingPacket : RpcEncoding Baz Baz.RpcEncodingPacket
instRpcEncodingBaz : RpcEncoding Baz Baz.RpcEncodingPacket
ok: {"arr": []}
protected inductive FooGeneric.RpcEncodingPacket : Type → Type
number of parameters: 1
constructors:
FooGeneric.RpcEncodingPacket.mk : {αPacket : Type} → αPacket → Option αPacket → FooGeneric.RpcEncodingPacket αPacket
instRpcEncodingFooGenericRpcEncodingPacket : (α αPacket : Type) →
instRpcEncodingFooGeneric : (α αPacket : Type) →
[inst : RpcEncoding α αPacket] → RpcEncoding (FooGeneric α) (FooGeneric.RpcEncodingPacket αPacket)
ok: {"a": 0}
ok: {"b": 42, "a": 3}
instRpcEncodingBazInductiveRpcEncodingPacket : RpcEncoding BazInductive BazInductive.RpcEncodingPacket
instRpcEncodingBazInductive : RpcEncoding BazInductive BazInductive.RpcEncodingPacket
ok: {"baz":
[{"fooRef": {"p": "0"}, "fooJson": {"s": ""}},
{"fooRef": {"p": "1"}, "fooJson": {"s": ""}}]}
inductive FooInductive.RpcEncodingPacket : Type → Type
protected inductive FooInductive.RpcEncodingPacket : Type → Type
number of parameters: 1
constructors:
FooInductive.RpcEncodingPacket.a : {αPacket : Type} → αPacket → Lsp.RpcRef → FooInductive.RpcEncodingPacket αPacket
FooInductive.RpcEncodingPacket.b : {αPacket : Type} → Nat → αPacket → Nat → FooInductive.RpcEncodingPacket αPacket
instRpcEncodingFooInductiveRpcEncodingPacket : (α αPacket : Type) →
instRpcEncodingFooInductive : (α αPacket : Type) →
[inst : RpcEncoding α αPacket] → RpcEncoding (FooInductive α) (FooInductive.RpcEncodingPacket αPacket)
ok: {"a": [{"baz": []}, {"p": "0"}]}
ok: {"b": [42, {"baz": []}, 0]}
inductive FooNested.RpcEncodingPacket : Type → Type
protected inductive FooNested.RpcEncodingPacket : Type → Type
number of parameters: 1
constructors:
FooNested.RpcEncodingPacket.a : {αPacket : Type} →
@ -39,5 +39,17 @@ FooNested.RpcEncodingPacket.a : {αPacket : Type} →
@FooNested.RpcEncodingPacket.a : {αPacket : Type} →
αPacket → Array (FooNested.RpcEncodingPacket αPacket) → FooNested.RpcEncodingPacket αPacket
ok: {"a": [{"baz": []}, [{"a": [{"baz": []}, []]}]]}
instRpcEncodingFooParamRpcEncodingPacket : (n : Nat) → RpcEncoding (FooParam n) (FooParam.RpcEncodingPacket n)
instRpcEncodingFooParam : (n : Nat) → RpcEncoding (FooParam n) FooParam.RpcEncodingPacket
ok: {"a": 42}
protected inductive Unused.RpcEncodingPacket : Type
number of parameters: 0
constructors:
Unused.RpcEncodingPacket.a : Unused.RpcEncodingPacket
instRpcEncodingUnused : (α : Type) → RpcEncoding (Unused α) Unused.RpcEncodingPacket
ok: "a"
protected inductive UnusedStruct.RpcEncodingPacket : Type
number of parameters: 0
constructors:
UnusedStruct.RpcEncodingPacket.mk : UnusedStruct.RpcEncodingPacket
instRpcEncodingUnusedStruct : (α : Type) → RpcEncoding (UnusedStruct α) UnusedStruct.RpcEncodingPacket
ok: {}