feat: support recursive types in RpcEncoding
This commit is contained in:
parent
b7bcb1616a
commit
cde339c2fb
1 changed files with 98 additions and 131 deletions
|
|
@ -25,9 +25,7 @@ private def deriveWithRefInstance (typeNm : Name) : CommandElabM Bool := do
|
|||
rpcDecode := WithRpcRef.decodeUnsafeAs $typeId:ident $(quote typeNm)
|
||||
|
||||
@[implementedBy unsafeInst]
|
||||
instance : RpcEncoding (WithRpcRef $typeId:ident) Lsp.RpcRef where
|
||||
rpcEncode r := pure ⟨0⟩
|
||||
rpcDecode r := throw "unreachable"
|
||||
instance : RpcEncoding (WithRpcRef $typeId:ident) Lsp.RpcRef := default
|
||||
)
|
||||
elabCommand cmds
|
||||
return true
|
||||
|
|
@ -73,10 +71,16 @@ def withFieldsFlattened (indVal : InductiveVal) (params : Array Expr)
|
|||
|
||||
end
|
||||
|
||||
def isOptField (n : Name) : Bool :=
|
||||
n.toString.endsWith "?"
|
||||
private def getRpcPacketFor (ty : Expr) : MetaM Expr := do
|
||||
let packetTy ← mkFreshExprMVar (Expr.sort levelOne)
|
||||
let _ ← synthInstance (mkApp2 (mkConst ``RpcEncoding) ty packetTy)
|
||||
instantiateMVars packetTy
|
||||
|
||||
private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr) : TermElabM Command :=
|
||||
private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr)
|
||||
(paramBinders packetParamBinders encInstBinders : Array (TSyntax ``Parser.Term.bracketedBinder)) : TermElabM Command := do
|
||||
let packetParamNames := packetParamBinders.map fun
|
||||
| `(bracketedBinder| ($t:ident : $_)) => t
|
||||
| _ => unreachable!
|
||||
withFields indVal params fun fields => do
|
||||
trace[Elab.Deriving.RpcEncoding] "for structure {indVal.name} with params {params}"
|
||||
-- Postulate that every field have a rpc encoding, storing the encoding type ident
|
||||
|
|
@ -84,34 +88,16 @@ private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr
|
|||
-- as otherwise typeclass synthesis fails.
|
||||
let mut binders := #[]
|
||||
let mut fieldIds := #[]
|
||||
let mut fieldEncIds : Array Term := #[]
|
||||
let mut uniqFieldEncIds : Array Ident := #[]
|
||||
let mut fieldEncIds' : DiscrTree Ident := {}
|
||||
let mut fieldEncTypeStxs := #[]
|
||||
for (fieldName, fieldTp) in fields do
|
||||
let mut fieldTp := fieldTp
|
||||
if isOptField fieldName then
|
||||
if !fieldTp.isAppOf ``Option then
|
||||
throwError "optional field '{fieldName}' has type{indentD m!"{fieldTp}"}\nbut is expected to have type{indentD "Option _"}" --"
|
||||
fieldTp := fieldTp.getArg! 0
|
||||
|
||||
-- postulate that the field has an encoding and remember the encoding's binder name
|
||||
let fieldEncTypeStx ← PrettyPrinter.delab (← getRpcPacketFor fieldTp)
|
||||
let stx ← PrettyPrinter.delab fieldTp
|
||||
fieldIds := fieldIds.push <| mkIdent fieldName
|
||||
let mut fieldEncId : Ident := ⟨Syntax.missing⟩
|
||||
match (← fieldEncIds'.getMatch fieldTp).back? with
|
||||
| none =>
|
||||
fieldEncId ← mkIdent <$> mkFreshUserName fieldName
|
||||
binders := binders.push (← `(bracketedBinder| ( $fieldEncId:ident )))
|
||||
let stx ← PrettyPrinter.delab fieldTp
|
||||
binders := binders.push
|
||||
(← `(bracketedBinder| [ RpcEncoding $stx $fieldEncId:ident ]))
|
||||
fieldEncIds' ← fieldEncIds'.insert fieldTp fieldEncId
|
||||
uniqFieldEncIds := uniqFieldEncIds.push fieldEncId
|
||||
| some fid => fieldEncId := fid
|
||||
|
||||
if isOptField fieldName then
|
||||
fieldEncIds := fieldEncIds.push <| ← ``(Option $fieldEncId:ident)
|
||||
else
|
||||
fieldEncIds := fieldEncIds.push fieldEncId
|
||||
fieldEncTypeStxs := fieldEncTypeStxs.push fieldEncTypeStx
|
||||
binders := binders.push
|
||||
(← `(bracketedBinder| [ RpcEncoding $stx $fieldEncTypeStx ]))
|
||||
|
||||
-- helpers for field initialization syntax
|
||||
let fieldInits (func : Name) := fieldIds.mapM fun fid =>
|
||||
|
|
@ -123,114 +109,92 @@ private def deriveStructureInstance (indVal : InductiveVal) (params : Array Expr
|
|||
let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName
|
||||
let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds
|
||||
let packetId := mkIdent <| indVal.name ++ `RpcEncodingPacket
|
||||
let packetAppliedId := Syntax.mkApp packetId uniqFieldEncIds
|
||||
let packetAppliedId := Syntax.mkApp packetId packetParamNames
|
||||
|
||||
`(variable $binders*
|
||||
|
||||
protected structure $packetId:ident where
|
||||
$[($fieldIds : $fieldEncIds)]*
|
||||
`(protected structure $packetId:ident $packetParamBinders* where
|
||||
$[($fieldIds : $fieldEncTypeStxs)]*
|
||||
deriving FromJson, ToJson
|
||||
|
||||
variable $(paramBinders ++ packetParamBinders ++ encInstBinders)* in
|
||||
instance : RpcEncoding $typeId $packetAppliedId where
|
||||
rpcEncode a := return {
|
||||
$[$encInits],*
|
||||
}
|
||||
rpcDecode a := return {
|
||||
$[$decInits],*
|
||||
}
|
||||
rpcEncode a := return { $[$encInits],* }
|
||||
rpcDecode a := return { $[$decInits],* }
|
||||
)
|
||||
|
||||
private structure CtorState where
|
||||
-- names of encoded argument types in the RPC packet
|
||||
encArgTypes : DiscrTree Name := {}
|
||||
uniqEncArgTypes : Array Name := #[]
|
||||
-- binders for `encArgTypes` as well as the relevant `RpcEncoding`s
|
||||
binders : Array (TSyntax ``Parser.Term.bracketedBinder) := #[]
|
||||
-- the syntax of each constructor in the packet
|
||||
ctors : Array (TSyntax ``Parser.Command.ctor) := #[]
|
||||
-- syntax of each arm of the `rpcEncode` pattern-match
|
||||
encodes : Array (TSyntax ``Parser.Term.matchAlt) := #[]
|
||||
-- syntax of each arm of the `rpcDecode` pattern-match
|
||||
decodes : Array (TSyntax ``Parser.Term.matchAlt) := #[]
|
||||
deriving Inhabited
|
||||
|
||||
private def matchF := Lean.Parser.Term.matchAlt (rhsParser := Lean.Parser.termParser)
|
||||
private def deriveInductiveInstance (indVal : InductiveVal) (params : Array Expr) : TermElabM Command := do
|
||||
private def deriveInductiveInstance (indVal : InductiveVal) (params packetParams : Array Expr)
|
||||
(paramBinders packetParamBinders encInstBinders : Array (TSyntax ``Parser.Term.bracketedBinder)) : TermElabM Command := do
|
||||
trace[Elab.Deriving.RpcEncoding] "for inductive {indVal.name} with params {params}"
|
||||
let packetParamNames := packetParamBinders.map fun
|
||||
| `(bracketedBinder| ($t:ident : $_)) => t
|
||||
| _ => unreachable!
|
||||
|
||||
-- produce all encoding types and binders for them
|
||||
let st ← foldWithConstructors indVal params (init := { : CtorState}) fun acc ctor argVars tp => do
|
||||
trace[Elab.Deriving.RpcEncoding] "{ctor} : {argVars} → {tp}"
|
||||
let mut acc := acc
|
||||
let argFVars ← argVars.mapM (LocalDecl.fvarId <$> getFVarLocalDecl ·)
|
||||
for arg in argVars do
|
||||
withoutModifyingEnv do
|
||||
let packetNm := indVal.name ++ `RpcEncodingPacket
|
||||
addDecl <| .axiomDecl {
|
||||
name := packetNm
|
||||
levelParams := []
|
||||
type := ← mkForallFVars packetParams (mkSort levelOne)
|
||||
isUnsafe := true
|
||||
}
|
||||
let pktCtorTp := mkAppN (mkConst packetNm) packetParams
|
||||
let recInstTp := mkApp2 (mkConst ``RpcEncoding) (mkAppN (mkConst indVal.name) params) pktCtorTp
|
||||
withLocalDecl `inst .instImplicit recInstTp fun _ => do
|
||||
let st ← foldWithConstructors indVal params (init := { : CtorState }) fun acc ctor argVars _ => do
|
||||
let mut pktCtorTp := pktCtorTp
|
||||
-- create the constructor
|
||||
for arg in argVars.reverse do
|
||||
let argTp ← inferType arg
|
||||
if (← findExprDependsOn argTp (pf := fun fv => argFVars.contains fv)) then
|
||||
throwError "cross-argument dependencies are not supported ({arg} : {argTp})"
|
||||
let encTp ← getRpcPacketFor argTp
|
||||
pktCtorTp := mkForall (← getFVarLocalDecl arg).userName BinderInfo.default encTp pktCtorTp
|
||||
-- TODO(WN): this relies on delab printing non-macro-scoped user names in non-dependent foralls
|
||||
-- to generate the expected JSON encoding
|
||||
let pktCtorTpStx ← PrettyPrinter.delab pktCtorTp
|
||||
let pktCtor ← `(Lean.Parser.Command.ctor| | $(mkIdent ctor.getString!):ident : $pktCtorTpStx:term)
|
||||
|
||||
if (← acc.encArgTypes.getMatch argTp).isEmpty then
|
||||
let tid ← mkFreshUserName (← getFVarLocalDecl arg).userName
|
||||
let argTpStx ← PrettyPrinter.delab argTp
|
||||
acc := { acc with encArgTypes := ← acc.encArgTypes.insert argTp tid
|
||||
uniqEncArgTypes := acc.uniqEncArgTypes.push tid
|
||||
binders := acc.binders.append #[
|
||||
(← `(bracketedBinder| ( $(mkIdent tid):ident ))),
|
||||
(← `(bracketedBinder| [ RpcEncoding $argTpStx $(mkIdent tid):ident ]))
|
||||
] }
|
||||
return acc
|
||||
-- create encoder and decoder match arms
|
||||
let nms ← argVars.mapM fun _ => mkIdent <$> mkFreshBinderName
|
||||
let mkPattern (src : Name) := Syntax.mkApp (mkIdent <| Name.mkStr src ctor.getString!) nms
|
||||
let mkBody (tgt : Name) (func : Name) : TermElabM Term := do
|
||||
let items ← nms.mapM fun nm => `(← $(mkIdent func) $nm)
|
||||
let tm := Syntax.mkApp (mkIdent <| Name.mkStr tgt ctor.getString!) items
|
||||
`(return $tm:term)
|
||||
|
||||
-- introduce encoding types into the local context so that we can use the delaborator to print them
|
||||
withLocalDecls
|
||||
(st.uniqEncArgTypes.map fun tid => (tid, BinderInfo.default, fun _ => pure <| mkSort levelOne))
|
||||
fun ts => do
|
||||
trace[Elab.Deriving.RpcEncoding] m!"RpcEncoding type binders : {ts}"
|
||||
let encArm ← `(matchF| | $(mkPattern indVal.name):term => $(← mkBody packetNm ``rpcEncode))
|
||||
let decArm ← `(matchF| | $(mkPattern packetNm):term => $(← mkBody indVal.name ``rpcDecode))
|
||||
|
||||
let packetNm := indVal.name ++ `RpcEncodingPacket
|
||||
let st ← foldWithConstructors indVal params (init := st) fun acc ctor argVars _ => do
|
||||
-- create the constructor
|
||||
let mut pktCtorTp := Lean.mkConst packetNm
|
||||
for arg in argVars.reverse do
|
||||
let argTp ← inferType arg
|
||||
let encTpNm := (← acc.encArgTypes.getMatch argTp).back
|
||||
let encTp ← elabTerm (mkIdent encTpNm) none
|
||||
pktCtorTp := mkForall (← getFVarLocalDecl arg).userName BinderInfo.default encTp pktCtorTp
|
||||
-- TODO(WN): this relies on delab printing non-macro-scoped user names in non-dependent foralls
|
||||
-- to generate the expected JSON encoding
|
||||
let pktCtorTpStx ← PrettyPrinter.delab pktCtorTp
|
||||
let pktCtor ← `(Lean.Parser.Command.ctor| | $(mkIdent ctor.getString!):ident : $pktCtorTpStx:term)
|
||||
return { acc with ctors := acc.ctors.push pktCtor
|
||||
encodes := acc.encodes.push ⟨encArm⟩
|
||||
decodes := acc.decodes.push ⟨decArm⟩ }
|
||||
|
||||
-- create encoder and decoder match arms
|
||||
let nms ← argVars.mapM fun _ => mkIdent <$> mkFreshBinderName
|
||||
let mkPattern (src : Name) := Syntax.mkApp (mkIdent <| Name.mkStr src ctor.getString!) nms
|
||||
let mkBody (tgt : Name) (func : Name) : TermElabM Term := do
|
||||
let items ← nms.mapM fun nm => `(← $(mkIdent func) $nm)
|
||||
let tm := Syntax.mkApp (mkIdent <| Name.mkStr tgt ctor.getString!) items
|
||||
`(return $tm:term)
|
||||
-- helpers for type name syntax
|
||||
let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName
|
||||
let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds
|
||||
let packetId := Syntax.mkApp (mkIdent packetNm) packetParamNames
|
||||
|
||||
let encArm ← `(matchF| | $(mkPattern indVal.name):term => $(← mkBody packetNm ``rpcEncode))
|
||||
let decArm ← `(matchF| | $(mkPattern packetNm):term => $(← mkBody indVal.name ``rpcDecode))
|
||||
`(inductive $(mkIdent packetNm) $packetParamBinders:bracketedBinder* where
|
||||
$[$(st.ctors):ctor]*
|
||||
deriving FromJson, ToJson
|
||||
|
||||
return { acc with ctors := acc.ctors.push pktCtor
|
||||
encodes := acc.encodes.push ⟨encArm⟩
|
||||
decodes := acc.decodes.push ⟨decArm⟩ }
|
||||
|
||||
-- helpers for type name syntax
|
||||
let paramIds ← params.mapM fun p => return mkIdent (← getFVarLocalDecl p).userName
|
||||
let typeId := Syntax.mkApp (← `(@$(mkIdent indVal.name))) paramIds
|
||||
let packetAppliedId := Syntax.mkApp (mkIdent packetNm) (st.uniqEncArgTypes.map (mkIdent ·))
|
||||
|
||||
`(variable $st.binders*
|
||||
|
||||
protected inductive $(mkIdent packetNm) where
|
||||
$[$(st.ctors):ctor]*
|
||||
deriving FromJson, ToJson
|
||||
|
||||
instance : RpcEncoding $typeId $packetAppliedId where
|
||||
rpcEncode := fun x => match x with
|
||||
$[$(st.encodes):matchAlt]*
|
||||
rpcDecode := fun x => match x with
|
||||
$[$(st.decodes):matchAlt]*
|
||||
)
|
||||
variable $(paramBinders ++ packetParamBinders ++ encInstBinders)* in
|
||||
partial instance : RpcEncoding $typeId $packetId :=
|
||||
{ rpcEncode, rpcDecode }
|
||||
where
|
||||
rpcEncode {m} [Monad m] [MonadRpcSession m] (x : $typeId) : ExceptT String m $packetId :=
|
||||
have inst : RpcEncoding $typeId $packetId := { rpcEncode, rpcDecode }
|
||||
match x with $[$(st.encodes):matchAlt]*
|
||||
rpcDecode {m} [Monad m] [MonadRpcSession m] (x : $packetId) : ExceptT String m $typeId :=
|
||||
have inst : RpcEncoding $typeId $packetId := { rpcEncode, rpcDecode }
|
||||
match x with $[$(st.decodes):matchAlt]*
|
||||
)
|
||||
|
||||
/-- Creates an `RpcEncodingPacket` for `typeName`. For structures, the packet is a structure
|
||||
with the same field names. For inductives, it mirrors the inductive structure with every field
|
||||
|
|
@ -243,30 +207,33 @@ private def deriveInstance (typeName : Name) : CommandElabM Bool := do
|
|||
if indVal.numIndices ≠ 0 then
|
||||
throwError "indexed inductive families are not supported"
|
||||
|
||||
let cmd ← liftTermElabM none do
|
||||
let (paramBinders, packetParamBinders, encInstBinders) ← liftTermElabM none do
|
||||
-- introduce fvars for all the parameters
|
||||
forallTelescopeReducing indVal.type fun params _ => do
|
||||
assert! params.size == indVal.numParams
|
||||
let mut paramBinders := #[] -- input parameters
|
||||
let mut packetParamBinders := #[] -- RPC encoding packets for type input parameters
|
||||
let mut encInstBinders := #[] -- RPC encoding instance binders corresponding to packetParamBinders
|
||||
|
||||
-- bind every parameter and *some* (not named) `RpcEncoding` for it
|
||||
let mut binders := #[]
|
||||
for param in params do
|
||||
let paramNm := (←getFVarLocalDecl param).userName
|
||||
binders := binders.push (← `(bracketedBinder| ( $(mkIdent paramNm) )))
|
||||
let paramNm := (← getFVarLocalDecl param).userName
|
||||
paramBinders := paramBinders.push (← `(bracketedBinder| ($(mkIdent paramNm))))
|
||||
-- only look for encodings for `Type` parameters
|
||||
if !(← inferType param).isType then continue
|
||||
binders := binders.push
|
||||
(← `(bracketedBinder| [ RpcEncoding $(mkIdent paramNm) _ ]))
|
||||
if (← inferType param).isType then
|
||||
let packet := mkIdent (← mkFreshUserName (paramNm.appendAfter "Packet"))
|
||||
packetParamBinders := packetParamBinders.push (← `(bracketedBinder| ($packet : Type)))
|
||||
encInstBinders := encInstBinders.push (← `(bracketedBinder| [RpcEncoding $(mkIdent paramNm) $packet]))
|
||||
|
||||
`(section
|
||||
variable $binders*
|
||||
$(← if isStructure (← getEnv) typeName then
|
||||
deriveStructureInstance indVal params
|
||||
else
|
||||
deriveInductiveInstance indVal params):command
|
||||
end)
|
||||
return (paramBinders, packetParamBinders, encInstBinders)
|
||||
|
||||
elabCommand <| ← liftTermElabM none do
|
||||
Term.elabBinders (paramBinders ++ packetParamBinders ++ encInstBinders) fun locals => do
|
||||
let params := locals[:paramBinders.size]
|
||||
let packetParams := locals[paramBinders.size:paramBinders.size+packetParamBinders.size]
|
||||
if isStructure (← getEnv) typeName then
|
||||
deriveStructureInstance indVal params paramBinders packetParamBinders encInstBinders
|
||||
else
|
||||
deriveInductiveInstance indVal params packetParams paramBinders packetParamBinders encInstBinders
|
||||
|
||||
elabCommand cmd
|
||||
return true
|
||||
|
||||
private unsafe def dispatchDeriveInstanceUnsafe (declNames : Array Name) (args? : Option (TSyntax ``Parser.Term.structInst)) : CommandElabM Bool := do
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue