diff --git a/src/grpc_client.app.src b/src/grpc_client.app.src index 5ea4b44..6ec2512 100644 --- a/src/grpc_client.app.src +++ b/src/grpc_client.app.src @@ -2,6 +2,7 @@ [{description,"gRPC client in Erlang"}, {vsn,"0.1.0"}, {modules,[]}, + {applications, [grpc_lib,http2_client]}, {registered, []}, {env, []}, {applications,[]}]}. diff --git a/src/grpc_client.erl b/src/grpc_client.erl index 9a79ff5..917a02c 100644 --- a/src/grpc_client.erl +++ b/src/grpc_client.erl @@ -71,10 +71,12 @@ -type metadata() :: #{metadata_key() => metadata_value()}. -type compression_method() :: none | gzip. +-type msg_type() :: map() | tuple(). -type stream_option() :: {metadata, metadata()} | {compression, compression_method()} | - {http2_options, [term()]}. + {http2_options, [term()]} | + {msgs_as_records,module()}. -type client_stream() :: pid(). @@ -178,18 +180,16 @@ new_stream(Connection, Service, Rpc, DecoderModule) -> new_stream(Connection, Service, Rpc, DecoderModule, Options) -> grpc_client_stream:new(Connection, Service, Rpc, DecoderModule, Options). --spec send(Stream::client_stream(), Msg::map()) -> ok. +-spec send(Stream::client_stream(), Msg::msg_type()) -> ok. %% @doc Send a message from the client to the server. -send(Stream, Msg) when is_pid(Stream), - is_map(Msg) -> +send(Stream, Msg) when is_pid(Stream) -> grpc_client_stream:send(Stream, Msg). --spec send_last(Stream::client_stream(), Msg::map()) -> ok. +-spec send_last(Stream::client_stream(), Msg::msg_type()) -> ok. %% @doc Send a message to server and mark it as the last message %% on the stream. For simple RPC and client-streaming RPCs that %% will trigger the response from the server. -send_last(Stream, Msg) when is_pid(Stream), - is_map(Msg) -> +send_last(Stream, Msg) when is_pid(Stream) -> grpc_client_stream:send_last(Stream, Msg). -spec rcv(Stream::client_stream()) -> rcv_response(). @@ -240,7 +240,7 @@ stop_connection(Connection) -> grpc_client_connection:stop(Connection). -spec unary(Connection::connection(), - Message::map(), Service::atom(), Rpc::atom(), + Message::msg_type(), Service::atom(), Rpc::atom(), Decoder::module(), Options::[stream_option() | {timeout, timeout()}]) -> unary_response(map()). diff --git a/src/grpc_client_stream.erl b/src/grpc_client_stream.erl index f466301..b763056 100644 --- a/src/grpc_client_stream.erl +++ b/src/grpc_client_stream.erl @@ -226,6 +226,8 @@ new_stream(Connection, Service, Rpc, Encoder, Options) -> Compression = proplists:get_value(compression, Options, none), Metadata = proplists:get_value(metadata, Options, #{}), TransportOptions = proplists:get_value(http2_options, Options, []), + RecordsEncoder = proplists:get_value(msgs_as_records, Options, []), + ClientPid = proplists:get_value(async_notification, Options), {ok, StreamId} = grpc_client_connection:new_stream(Connection, TransportOptions), Package = Encoder:get_package_name(), RpcDef = Encoder:find_rpc_def(Service, Rpc), @@ -238,8 +240,10 @@ new_stream(Connection, Service, Rpc, Encoder, Options) -> rpc => Rpc, queue => queue:new(), response_pending => false, + async_notification => ClientPid, state => idle, encoder => Encoder, + records_encoder => RecordsEncoder, connection => Connection, headers_sent => false, metadata => Metadata, @@ -314,6 +318,9 @@ add_metadata(Headers, Metadata) -> lists:keystore(K, 1, Acc, {K,V}) end, Headers, maps:to_list(Metadata)). +info_response(Response, #{async_notification := Client} = Stream) when is_pid(Client) -> + Client ! {grpc_notification,Response}, + {noreply, Stream}; info_response(Response, #{response_pending := true, client := Client} = Stream) -> gen_server:reply(Client, Response), @@ -325,17 +332,24 @@ info_response(Response, #{queue := Queue} = Stream) -> %% TODO: fix the error handling, currently it is very hard to understand the %% error that results from a bad message (Map). encode(#{encoder := Encoder, - input := MsgType, - compression := CompressionMethod}, Map) -> - %% RequestData = Encoder:encode_msg(Map, MsgType), - try Encoder:encode_msg(Map, MsgType) of - RequestData -> + records_encoder := RecordsEncoder, + input := MsgType, + compression := CompressionMethod}, Msg) -> + try + begin + RequestData = case is_map(Msg) of + true -> + Encoder:encode_msg(Msg, MsgType); + false when is_tuple(Msg) -> + RecordsEncoder:encode_msg(Msg) + end, maybe_compress(RequestData, CompressionMethod) + end catch error:function_clause -> - throw({error, {failed_to_encode, MsgType, Map}}); + throw({error, {failed_to_encode, MsgType, Msg}}); Error:Reason -> - throw({error, {Error, Reason}}) + throw({error, {Error, Reason}}) end. maybe_compress(Encoded, none) -> @@ -351,12 +365,18 @@ maybe_compress(_Encoded, Other) -> decode(Encoded, Binary, #{response_encoding := Method, encoder := Encoder, + records_encoder := RecordsEncoder, output := MsgType}) -> - Message = case Encoded of + Message = case Encoded of 1 -> decompress(Binary, Method); 0 -> Binary end, - Encoder:decode_msg(Message, MsgType). + case RecordsEncoder == [] of + true -> + Encoder:decode_msg(Message, MsgType); + _ -> + RecordsEncoder:decode_msg(Message, MsgType) + end. decompress(Compressed, <<"gzip">>) -> zlib:gunzip(Compressed);