feat: support recursive types in RpcEncoding

This commit is contained in:
Gabriel Ebner 2022-07-17 18:49:26 +02:00 committed by Sebastian Ullrich
parent b7bcb1616a
commit cde339c2fb

View file

@ -25,9 +25,7 @@ private def deriveWithRefInstance (typeNm : Name) : CommandElabM Bool := do
rpcDecode := WithRpcRef.decodeUnsafeAs $typeId:ident $(quote typeNm)
@[implementedBy unsafeInst]
instance : RpcEncoding (WithRpcRef $typeId:ident) Lsp.RpcRef where
rpcEncode r := pure ⟨0⟩
rpcDecode r := throw "unreachable"
instance : RpcEncoding (WithRpcRef $typeId:ident) Lsp.RpcRef := default
)
elabCommand cmds
return true
@ -73,10 +71,16 @@ def withFieldsFlattened (indVal : InductiveVal) (params : Array Expr)
end
def isOptField (n : Name) : Bool :=
n.toString.endsWith "?"
private def getRpcPacketFor (ty : Expr) : MetaM Expr := do
let packetTy ← mkFreshExprMVar (Expr.sort levelOne)
let _ ← synthInstance (mkApp2 (mkConst ``RpcEncoding) ty packetTy)
instantiateMVars packetTy
private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr) : TermElabM Command :=
private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr)
(paramBinders packetParamBinders encInstBinders : Array (TSyntax ``Parser.Term.bracketedBinder)) : TermElabM Command := do
let packetParamNames := packetParamBinders.map fun
| `(bracketedBinder| ($t:ident : $_)) => t
| _ => unreachable!
withFields indVal params fun fields => do
trace[Elab.Deriving.RpcEncoding] "for structure {indVal.name} with params {params}"
-- Postulate that every field have a rpc encoding, storing the encoding type ident
@ -84,34 +88,16 @@ private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr
-- as otherwise typeclass synthesis fails.
let mut binders := #[]
let mut fieldIds := #[]
let mut fieldEncIds : Array Term := #[]
let mut uniqFieldEncIds : Array Ident := #[]
let mut fieldEncIds' : DiscrTree Ident := {}
let mut fieldEncTypeStxs := #[]
for (fieldName, fieldTp) in fields do
let mut fieldTp := fieldTp
if isOptField fieldName then
if !fieldTp.isAppOf ``Option then
throwError "optional field '{fieldName}' has type{indentD m!"{fieldTp}"}\nbut is expected to have type{indentD "Option _"}" --"
fieldTp := fieldTp.getArg! 0
-- postulate that the field has an encoding and remember the encoding's binder name
let fieldEncTypeStx ← PrettyPrinter.delab (← getRpcPacketFor fieldTp)
let stx ← PrettyPrinter.delab fieldTp
fieldIds := fieldIds.push <| mkIdent fieldName
let mut fieldEncId : Ident := ⟨Syntax.missing⟩
match (← fieldEncIds'.getMatch fieldTp).back? with
| none =>
fieldEncId ← mkIdent <$> mkFreshUserName fieldName
binders := binders.push (← `(bracketedBinder| ( $fieldEncId:ident )))
let stx ← PrettyPrinter.delab fieldTp
binders := binders.push
(← `(bracketedBinder| [ RpcEncoding $stx $fieldEncId:ident ]))
fieldEncIds' ← fieldEncIds'.insert fieldTp fieldEncId
uniqFieldEncIds := uniqFieldEncIds.push fieldEncId
| some fid => fieldEncId := fid
if isOptField fieldName then
fieldEncIds := fieldEncIds.push <| ← ``(Option $fieldEncId:ident)
else
fieldEncIds := fieldEncIds.push fieldEncId
fieldEncTypeStxs := fieldEncTypeStxs.push fieldEncTypeStx
binders := binders.push
(← `(bracketedBinder| [ RpcEncoding $stx $fieldEncTypeStx ]))
-- helpers for field initialization syntax
let fieldInits (func : Name) := fieldIds.mapM fun fid =>
@ -123,114 +109,92 @@ private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr
let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName
let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds
let packetId := mkIdent <| indVal.name ++ `RpcEncodingPacket
let packetAppliedId := Syntax.mkApp packetId uniqFieldEncIds
let packetAppliedId := Syntax.mkApp packetId packetParamNames
`(variable $binders*
protected structure $packetId:ident where
$[($fieldIds : $fieldEncIds)]*
`(protected structure $packetId:ident $packetParamBinders* where
$[($fieldIds : $fieldEncTypeStxs)]*
deriving FromJson, ToJson
variable $(paramBinders ++ packetParamBinders ++ encInstBinders)* in
instance : RpcEncoding $typeId $packetAppliedId where
rpcEncode a := return {
$[$encInits],*
}
rpcDecode a := return {
$[$decInits],*
}
rpcEncode a := return { $[$encInits],* }
rpcDecode a := return { $[$decInits],* }
)
private structure CtorState where
-- names of encoded argument types in the RPC packet
encArgTypes : DiscrTree Name := {}
uniqEncArgTypes : Array Name := #[]
-- binders for `encArgTypes` as well as the relevant `RpcEncoding`s
binders : Array (TSyntax ``Parser.Term.bracketedBinder) := #[]
-- the syntax of each constructor in the packet
ctors : Array (TSyntax ``Parser.Command.ctor) := #[]
-- syntax of each arm of the `rpcEncode` pattern-match
encodes : Array (TSyntax ``Parser.Term.matchAlt) := #[]
-- syntax of each arm of the `rpcDecode` pattern-match
decodes : Array (TSyntax ``Parser.Term.matchAlt) := #[]
deriving Inhabited
private def matchF := Lean.Parser.Term.matchAlt (rhsParser := Lean.Parser.termParser)
private def deriveInductiveInstance (indVal : InductiveVal) (params : Array Expr) : TermElabM Command := do
private def deriveInductiveInstance (indVal : InductiveVal) (params packetParams : Array Expr)
(paramBinders packetParamBinders encInstBinders : Array (TSyntax ``Parser.Term.bracketedBinder)) : TermElabM Command := do
trace[Elab.Deriving.RpcEncoding] "for inductive {indVal.name} with params {params}"
let packetParamNames := packetParamBinders.map fun
| `(bracketedBinder| ($t:ident : $_)) => t
| _ => unreachable!
-- produce all encoding types and binders for them
let st ← foldWithConstructors indVal params (init := { : CtorState}) fun acc ctor argVars tp => do
trace[Elab.Deriving.RpcEncoding] "{ctor} : {argVars} → {tp}"
let mut acc := acc
let argFVars ← argVars.mapM (LocalDecl.fvarId <$> getFVarLocalDecl ·)
for arg in argVars do
withoutModifyingEnv do
let packetNm := indVal.name ++ `RpcEncodingPacket
addDecl <| .axiomDecl {
name := packetNm
levelParams := []
type := ← mkForallFVars packetParams (mkSort levelOne)
isUnsafe := true
}
let pktCtorTp := mkAppN (mkConst packetNm) packetParams
let recInstTp := mkApp2 (mkConst ``RpcEncoding) (mkAppN (mkConst indVal.name) params) pktCtorTp
withLocalDecl `inst .instImplicit recInstTp fun _ => do
let st ← foldWithConstructors indVal params (init := { : CtorState }) fun acc ctor argVars _ => do
let mut pktCtorTp := pktCtorTp
-- create the constructor
for arg in argVars.reverse do
let argTp ← inferType arg
if (← findExprDependsOn argTp (pf := fun fv => argFVars.contains fv)) then
throwError "cross-argument dependencies are not supported ({arg} : {argTp})"
let encTp ← getRpcPacketFor argTp
pktCtorTp := mkForall (← getFVarLocalDecl arg).userName BinderInfo.default encTp pktCtorTp
-- TODO(WN): this relies on delab printing non-macro-scoped user names in non-dependent foralls
-- to generate the expected JSON encoding
let pktCtorTpStx ← PrettyPrinter.delab pktCtorTp
let pktCtor ← `(Lean.Parser.Command.ctor| | $(mkIdent ctor.getString!):ident : $pktCtorTpStx:term)
if (← acc.encArgTypes.getMatch argTp).isEmpty then
let tid ← mkFreshUserName (← getFVarLocalDecl arg).userName
let argTpStx ← PrettyPrinter.delab argTp
acc := { acc with encArgTypes := ← acc.encArgTypes.insert argTp tid
uniqEncArgTypes := acc.uniqEncArgTypes.push tid
binders := acc.binders.append #[
(← `(bracketedBinder| ( $(mkIdent tid):ident ))),
(← `(bracketedBinder| [ RpcEncoding $argTpStx $(mkIdent tid):ident ]))
] }
return acc
-- create encoder and decoder match arms
let nms ← argVars.mapM fun _ => mkIdent <$> mkFreshBinderName
let mkPattern (src : Name) := Syntax.mkApp (mkIdent <| Name.mkStr src ctor.getString!) nms
let mkBody (tgt : Name) (func : Name) : TermElabM Term := do
let items ← nms.mapM fun nm => `(← $(mkIdent func) $nm)
let tm := Syntax.mkApp (mkIdent <| Name.mkStr tgt ctor.getString!) items
`(return $tm:term)
-- introduce encoding types into the local context so that we can use the delaborator to print them
withLocalDecls
(st.uniqEncArgTypes.map fun tid => (tid, BinderInfo.default, fun _ => pure <| mkSort levelOne))
fun ts => do
trace[Elab.Deriving.RpcEncoding] m!"RpcEncoding type binders : {ts}"
let encArm ← `(matchF| | $(mkPattern indVal.name):term => $(← mkBody packetNm ``rpcEncode))
let decArm ← `(matchF| | $(mkPattern packetNm):term => $(← mkBody indVal.name ``rpcDecode))
let packetNm := indVal.name ++ `RpcEncodingPacket
let st ← foldWithConstructors indVal params (init := st) fun acc ctor argVars _ => do
-- create the constructor
let mut pktCtorTp := Lean.mkConst packetNm
for arg in argVars.reverse do
let argTp ← inferType arg
let encTpNm := (← acc.encArgTypes.getMatch argTp).back
let encTp ← elabTerm (mkIdent encTpNm) none
pktCtorTp := mkForall (← getFVarLocalDecl arg).userName BinderInfo.default encTp pktCtorTp
-- TODO(WN): this relies on delab printing non-macro-scoped user names in non-dependent foralls
-- to generate the expected JSON encoding
let pktCtorTpStx ← PrettyPrinter.delab pktCtorTp
let pktCtor ← `(Lean.Parser.Command.ctor| | $(mkIdent ctor.getString!):ident : $pktCtorTpStx:term)
return { acc with ctors := acc.ctors.push pktCtor
encodes := acc.encodes.push ⟨encArm⟩
decodes := acc.decodes.push ⟨decArm⟩ }
-- create encoder and decoder match arms
let nms ← argVars.mapM fun _ => mkIdent <$> mkFreshBinderName
let mkPattern (src : Name) := Syntax.mkApp (mkIdent <| Name.mkStr src ctor.getString!) nms
let mkBody (tgt : Name) (func : Name) : TermElabM Term := do
let items ← nms.mapM fun nm => `(← $(mkIdent func) $nm)
let tm := Syntax.mkApp (mkIdent <| Name.mkStr tgt ctor.getString!) items
`(return $tm:term)
-- helpers for type name syntax
let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName
let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds
let packetId := Syntax.mkApp (mkIdent packetNm) packetParamNames
let encArm ← `(matchF| | $(mkPattern indVal.name):term => $(← mkBody packetNm ``rpcEncode))
let decArm ← `(matchF| | $(mkPattern packetNm):term => $(← mkBody indVal.name ``rpcDecode))
`(inductive $(mkIdent packetNm) $packetParamBinders:bracketedBinder* where
$[$(st.ctors):ctor]*
deriving FromJson, ToJson
return { acc with ctors := acc.ctors.push pktCtor
encodes := acc.encodes.push ⟨encArm⟩
decodes := acc.decodes.push ⟨decArm⟩ }
-- helpers for type name syntax
let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName
let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds
let packetAppliedId := Syntax.mkApp (mkIdent packetNm) (st.uniqEncArgTypes.map (mkIdent ·))
`(variable $st.binders*
protected inductive $(mkIdent packetNm) where
$[$(st.ctors):ctor]*
deriving FromJson, ToJson
instance : RpcEncoding $typeId $packetAppliedId where
rpcEncode := fun x => match x with
$[$(st.encodes):matchAlt]*
rpcDecode := fun x => match x with
$[$(st.decodes):matchAlt]*
)
variable $(paramBinders ++ packetParamBinders ++ encInstBinders)* in
partial instance : RpcEncoding $typeId $packetId :=
{ rpcEncode, rpcDecode }
where
rpcEncode {m} [Monad m] [MonadRpcSession m] (x : $typeId) : ExceptT String m $packetId :=
have inst : RpcEncoding $typeId $packetId := { rpcEncode, rpcDecode }
match x with $[$(st.encodes):matchAlt]*
rpcDecode {m} [Monad m] [MonadRpcSession m] (x : $packetId) : ExceptT String m $typeId :=
have inst : RpcEncoding $typeId $packetId := { rpcEncode, rpcDecode }
match x with $[$(st.decodes):matchAlt]*
)
/-- Creates an `RpcEncodingPacket` for `typeName`. For structures, the packet is a structure
with the same field names. For inductives, it mirrors the inductive structure with every field
@ -243,30 +207,33 @@ private def deriveInstance (typeName : Name) : CommandElabM Bool := do
if indVal.numIndices ≠ 0 then
throwError "indexed inductive families are not supported"
let cmd ← liftTermElabM none do
let (paramBinders, packetParamBinders, encInstBinders) ← liftTermElabM none do
-- introduce fvars for all the parameters
forallTelescopeReducing indVal.type fun params _ => do
assert! params.size == indVal.numParams
let mut paramBinders := #[] -- input parameters
let mut packetParamBinders := #[] -- RPC encoding packets for type input parameters
let mut encInstBinders := #[] -- RPC encoding instance binders corresponding to packetParamBinders
-- bind every parameter and *some* (not named) `RpcEncoding` for it
let mut binders := #[]
for param in params do
let paramNm := (←getFVarLocalDecl param).userName
binders := binders.push (← `(bracketedBinder| ( $(mkIdent paramNm) )))
let paramNm := (← getFVarLocalDecl param).userName
paramBinders := paramBinders.push (← `(bracketedBinder| ($(mkIdent paramNm))))
-- only look for encodings for `Type` parameters
if !(← inferType param).isType then continue
binders := binders.push
(← `(bracketedBinder| [ RpcEncoding $(mkIdent paramNm) _ ]))
if (← inferType param).isType then
let packet := mkIdent (← mkFreshUserName (paramNm.appendAfter "Packet"))
packetParamBinders := packetParamBinders.push (← `(bracketedBinder| ($packet : Type)))
encInstBinders := encInstBinders.push (← `(bracketedBinder| [RpcEncoding $(mkIdent paramNm) $packet]))
`(section
variable $binders*
$(← if isStructure (← getEnv) typeName then
deriveStructureInstance indVal params
else
deriveInductiveInstance indVal params):command
end)
return (paramBinders, packetParamBinders, encInstBinders)
elabCommand <| ← liftTermElabM none do
Term.elabBinders (paramBinders ++ packetParamBinders ++ encInstBinders) fun locals => do
let params := locals[:paramBinders.size]
let packetParams := locals[paramBinders.size:paramBinders.size+packetParamBinders.size]
if isStructure (← getEnv) typeName then
deriveStructureInstance indVal params paramBinders packetParamBinders encInstBinders
else
deriveInductiveInstance indVal params packetParams paramBinders packetParamBinders encInstBinders
elabCommand cmd
return true
private unsafe def dispatchDeriveInstanceUnsafe (declNames : Array Name) (args? : Option (TSyntax ``Parser.Term.structInst)) : CommandElabM Bool := do