Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pkg/grpc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ go_library(
"@org_golang_google_protobuf//proto",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//types/known/emptypb",
"@org_golang_x_sync//errgroup",
"@org_golang_x_sync//semaphore",
] + select({
"@rules_go//go/platform:android": [
Expand Down
92 changes: 48 additions & 44 deletions pkg/grpc/forwarding_stream_handler.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package grpc

import (
"context"
"io"

"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
)
Expand Down Expand Up @@ -38,55 +38,59 @@ func (s *forwardingStreamHandler) HandleStream(srv any, incomingStream grpc.Serv
ServerStreams: true,
ClientStreams: true,
}
group, groupCtx := errgroup.WithContext(incomingStream.Context())
group.Go(func() error {
// groupCtx is guaranteed to be canceled before returning from this method, so outgoingStream will not leak resources.
outgoingStream, err := s.backend.NewStream(groupCtx, &desc, method)
if err != nil {
return err
}
// Avoid group.Go because incomingStream.RecvMsg might block returning
// an error from the outgoingStream and getting the context for
// incomingStream canceled.
go func() {
for {
msg := &emptypb.Empty{}
if err := incomingStream.RecvMsg(msg); err != nil {
if err == io.EOF {
// Let's continue to receive on outgoingStream, so don't
// cancel grouptCtx.
outgoingStream.CloseSend()
return
}
// Cancel groupCtx immediately.
group.Go(func() error { return err })
return
}
if err := outgoingStream.SendMsg(msg); err != nil {
if err == io.EOF {
// The error will be returned by outgoingStream.RecvMsg(),
// no need to cancel groupCtx now.
return
}
// Cancel groupCtx immediately.
group.Go(func() error { return err })
return
}
}
}()
ctx, cancel := context.WithCancelCause(incomingStream.Context())
defer cancel(nil)

// ctx is guaranteed to be canceled when returning from this method, so
// outgoingStream will not leak resources.
outgoingStream, err := s.backend.NewStream(ctx, &desc, method)
if err != nil {
return err
}

// The only way to cancel a blocking incomingStream.RecvMsg is to return
// from this method. Therefore, an error from outgoingStream.RecvMsg
// needs to be returned without waiting for incomingStream.RecvMsg, so
// it cannot be run inside e.g. errgroup.Go.
go func() {
for {
msg := &emptypb.Empty{}
if err := outgoingStream.RecvMsg(msg); err != nil {
if err := incomingStream.RecvMsg(msg); err != nil {
if err == io.EOF {
return nil
// Let's continue to receive on outgoingStream, so don't
// cancel grouptCtx.
outgoingStream.CloseSend()
return
}
return err
// Cancel ctx immediately.
cancel(err)
return
}
if err := incomingStream.SendMsg(msg); err != nil {
return err
if err := outgoingStream.SendMsg(msg); err != nil {
if err == io.EOF {
// The error will be returned by outgoingStream.RecvMsg(),
// no need to cancel ctx now.
return
}
// Cancel ctx immediately.
cancel(err)
return
}
}
}()

for {
msg := &emptypb.Empty{}
if err := outgoingStream.RecvMsg(msg); err != nil {
if err != io.EOF {
cancel(err)
}
break
}
})
return group.Wait()
if err := incomingStream.SendMsg(msg); err != nil {
cancel(err)
break
}
}
return context.Cause(ctx)
}