From 1750e0d3f50afb631f954a24b8acbd6f0130d464 Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Tue, 18 Nov 2025 09:14:39 -0800 Subject: [PATCH] Allow parallel calling of the visitor --- cmd/proxygenerator/go.mod | 4 +- cmd/proxygenerator/interceptor.go | 87 +++++++++---- go.mod | 3 +- go.sum | 2 + proto/api | 2 +- proxy/interceptor.go | 200 ++++++++++++++++++++---------- 6 files changed, 208 insertions(+), 90 deletions(-) diff --git a/cmd/proxygenerator/go.mod b/cmd/proxygenerator/go.mod index 5cb6b849..da0eb8e9 100644 --- a/cmd/proxygenerator/go.mod +++ b/cmd/proxygenerator/go.mod @@ -1,8 +1,8 @@ module go.temporal.io/api/cmd/proxygenerator -go 1.22.0 +go 1.24 -toolchain go1.24.0 +toolchain go1.24.1 replace go.temporal.io/api => ../.. diff --git a/cmd/proxygenerator/interceptor.go b/cmd/proxygenerator/interceptor.go index af16d5e1..8487ad4c 100644 --- a/cmd/proxygenerator/interceptor.go +++ b/cmd/proxygenerator/interceptor.go @@ -34,8 +34,12 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/protoadapt" + "golang.org/x/sync/errgroup" ) +type contextKey string +var concurrencyKey = contextKey("payload-visitor-concurrency") + // VisitPayloadsContext provides Payload context for visitor functions. type VisitPayloadsContext struct { context.Context @@ -55,6 +59,7 @@ type VisitPayloadsOptions struct { // Will be called for each Any encountered. If not set, the default is to recurse into the Any // object, unmarshal it, visit, and re-marshal it always (even if there are no changes). WellKnownAnyVisitor func(*VisitPayloadsContext, *anypb.Any) error + ConcurrencyLimit int } // VisitPayloads calls the options.Visitor function for every Payload proto within msg. @@ -265,38 +270,72 @@ func visitPayloads( parent proto.Message, objs ...interface{}, ) error { + var waitOnErrGroup bool + eg, ok := ctx.Value(concurrencyKey).(*errgroup.Group) + if !ok { + var cctx context.Context + eg, cctx = errgroup.WithContext(ctx) + cctx = context.WithValue(cctx, concurrencyKey, eg) + ctx = &VisitPayloadsContext{ + Context: cctx, + Parent: ctx.Parent, + SinglePayloadRequired: ctx.SinglePayloadRequired, + } + if options.ConcurrencyLimit > 0 { + eg.SetLimit(options.ConcurrencyLimit) + } else if options.ConcurrencyLimit == 0 { + eg.SetLimit(1) + } + waitOnErrGroup = true + } + // Make a copy of ctx since we may be modifying it in the goroutines + ctx = &VisitPayloadsContext{ + Context: ctx, + Parent: ctx.Parent, + SinglePayloadRequired: ctx.SinglePayloadRequired, + } + for _, obj := range objs { ctx.SinglePayloadRequired = false switch o := obj.(type) { case map[string]*common.Payload: for ix, x := range o { - if nx, err := visitPayload(ctx, options, parent, x); err != nil { - return err - } else { - o[ix] = nx - } + eg.Go(func() error { + if nx, err := visitPayload(ctx, options, parent, x); err != nil { + return err + } else { + o[ix] = nx + } + return nil + }) } case *common.Payloads: if o == nil { continue } - ctx.Parent = parent - newPayloads, err := options.Visitor(ctx, o.Payloads) - ctx.Parent = nil - if err != nil { return err } - o.Payloads = newPayloads + eg.Go(func() error { + ctx.Parent = parent + newPayloads, err := options.Visitor(ctx, o.Payloads) + ctx.Parent = nil + if err != nil { return err } + o.Payloads = newPayloads + return nil + }) case map[string]*common.Payloads: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { - return err - } + eg.Go(func() error { + return visitPayloads(ctx, options, parent, x) + }) } case []*common.Payload: for ix, x := range o { - if nx, err := visitPayload(ctx, options, parent, x); err != nil { - return err - } else { - o[ix] = nx - } + eg.Go(func() error { + if nx, err := visitPayload(ctx, options, parent, x); err != nil { + return err + } else { + o[ix] = nx + } + return nil + }) } case *anypb.Any: if o == nil { @@ -342,9 +381,12 @@ func visitPayloads( if o == nil { continue } {{range $record.Payloads -}} if o.{{.}} != nil { - no, err := visitPayload(ctx, options, o, o.{{.}}) - if err != nil { return err } - o.{{.}} = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.{{.}}) + if err != nil { return err } + o.{{.}} = no + return nil + }) } {{end}} {{if $record.Methods}} @@ -361,6 +403,9 @@ func visitPayloads( } } + if waitOnErrGroup { + return eg.Wait() + } return nil } diff --git a/go.mod b/go.mod index 1ac05b44..8dc2b2aa 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,12 @@ module go.temporal.io/api -go 1.21 +go 1.24 require ( github.com/golang/mock v1.6.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 github.com/stretchr/testify v1.9.0 + golang.org/x/sync v0.8.0 google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed google.golang.org/grpc v1.66.0 diff --git a/go.sum b/go.sum index ec868699..77f4cf04 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/proto/api b/proto/api index 35aacca9..d96bd55e 160000 --- a/proto/api +++ b/proto/api @@ -1 +1 @@ -Subproject commit 35aacca933c2d12197c77004700fff4af1926220 +Subproject commit d96bd55e87799e9f6a33a1c40a56cfa932566bdf diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 7004ac54..4938a456 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -24,12 +24,17 @@ import ( "go.temporal.io/api/update/v1" "go.temporal.io/api/workflow/v1" "go.temporal.io/api/workflowservice/v1" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" ) +type contextKey string + +var concurrencyKey = contextKey("payload-visitor-concurrency") + // VisitPayloadsContext provides Payload context for visitor functions. type VisitPayloadsContext struct { context.Context @@ -49,6 +54,7 @@ type VisitPayloadsOptions struct { // Will be called for each Any encountered. If not set, the default is to recurse into the Any // object, unmarshal it, visit, and re-marshal it always (even if there are no changes). WellKnownAnyVisitor func(*VisitPayloadsContext, *anypb.Any) error + ConcurrencyLimit int } // VisitPayloads calls the options.Visitor function for every Payload proto within msg. @@ -258,42 +264,76 @@ func visitPayloads( parent proto.Message, objs ...interface{}, ) error { + var waitOnErrGroup bool + eg, ok := ctx.Value(concurrencyKey).(*errgroup.Group) + if !ok { + var cctx context.Context + eg, cctx = errgroup.WithContext(ctx) + cctx = context.WithValue(cctx, concurrencyKey, eg) + ctx = &VisitPayloadsContext{ + Context: cctx, + Parent: ctx.Parent, + SinglePayloadRequired: ctx.SinglePayloadRequired, + } + if options.ConcurrencyLimit > 0 { + eg.SetLimit(options.ConcurrencyLimit) + } else if options.ConcurrencyLimit == 0 { + eg.SetLimit(1) + } + waitOnErrGroup = true + } + // Make a copy of ctx since we may be modifying it in the goroutines + ctx = &VisitPayloadsContext{ + Context: ctx, + Parent: ctx.Parent, + SinglePayloadRequired: ctx.SinglePayloadRequired, + } + for _, obj := range objs { ctx.SinglePayloadRequired = false switch o := obj.(type) { case map[string]*common.Payload: for ix, x := range o { - if nx, err := visitPayload(ctx, options, parent, x); err != nil { - return err - } else { - o[ix] = nx - } + eg.Go(func() error { + if nx, err := visitPayload(ctx, options, parent, x); err != nil { + return err + } else { + o[ix] = nx + } + return nil + }) } case *common.Payloads: if o == nil { continue } - ctx.Parent = parent - newPayloads, err := options.Visitor(ctx, o.Payloads) - ctx.Parent = nil - if err != nil { - return err - } - o.Payloads = newPayloads - case map[string]*common.Payloads: - for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + eg.Go(func() error { + ctx.Parent = parent + newPayloads, err := options.Visitor(ctx, o.Payloads) + ctx.Parent = nil + if err != nil { return err } + o.Payloads = newPayloads + return nil + }) + case map[string]*common.Payloads: + for _, x := range o { + eg.Go(func() error { + return visitPayloads(ctx, options, parent, x) + }) } case []*common.Payload: for ix, x := range o { - if nx, err := visitPayload(ctx, options, parent, x); err != nil { - return err - } else { - o[ix] = nx - } + eg.Go(func() error { + if nx, err := visitPayload(ctx, options, parent, x); err != nil { + return err + } else { + o[ix] = nx + } + return nil + }) } case *anypb.Any: if o == nil { @@ -514,11 +554,14 @@ func visitPayloads( continue } if o.Input != nil { - no, err := visitPayload(ctx, options, o, o.Input) - if err != nil { - return err - } - o.Input = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.Input) + if err != nil { + return err + } + o.Input = no + return nil + }) } case *command.SignalExternalWorkflowExecutionCommandAttributes: @@ -811,11 +854,14 @@ func visitPayloads( continue } if o.EncodedAttributes != nil { - no, err := visitPayload(ctx, options, o, o.EncodedAttributes) - if err != nil { - return err - } - o.EncodedAttributes = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.EncodedAttributes) + if err != nil { + return err + } + o.EncodedAttributes = no + return nil + }) } if err := visitPayloads( @@ -1136,11 +1182,14 @@ func visitPayloads( continue } if o.Result != nil { - no, err := visitPayload(ctx, options, o, o.Result) - if err != nil { - return err - } - o.Result = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.Result) + if err != nil { + return err + } + o.Result = no + return nil + }) } case *history.NexusOperationFailedEventAttributes: @@ -1164,11 +1213,14 @@ func visitPayloads( continue } if o.Input != nil { - no, err := visitPayload(ctx, options, o, o.Input) - if err != nil { - return err - } - o.Input = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.Input) + if err != nil { + return err + } + o.Input = no + return nil + }) } case *history.NexusOperationTimedOutEventAttributes: @@ -1485,11 +1537,14 @@ func visitPayloads( continue } if o.Description != nil { - no, err := visitPayload(ctx, options, o, o.Description) - if err != nil { - return err - } - o.Description = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.Description) + if err != nil { + return err + } + o.Description = no + return nil + }) } case *nexus.Request: @@ -1528,11 +1583,14 @@ func visitPayloads( continue } if o.Payload != nil { - no, err := visitPayload(ctx, options, o, o.Payload) - if err != nil { - return err - } - o.Payload = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.Payload) + if err != nil { + return err + } + o.Payload = no + return nil + }) } case *nexus.StartOperationResponse: @@ -1556,11 +1614,14 @@ func visitPayloads( continue } if o.Payload != nil { - no, err := visitPayload(ctx, options, o, o.Payload) - if err != nil { - return err - } - o.Payload = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.Payload) + if err != nil { + return err + } + o.Payload = no + return nil + }) } case *operatorservice.CreateNexusEndpointRequest: @@ -1780,18 +1841,24 @@ func visitPayloads( continue } if o.Details != nil { - no, err := visitPayload(ctx, options, o, o.Details) - if err != nil { - return err - } - o.Details = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.Details) + if err != nil { + return err + } + o.Details = no + return nil + }) } if o.Summary != nil { - no, err := visitPayload(ctx, options, o, o.Summary) - if err != nil { - return err - } - o.Summary = no + eg.Go(func() error { + no, err := visitPayload(ctx, options, o, o.Summary) + if err != nil { + return err + } + o.Summary = no + return nil + }) } case *update.Acceptance: @@ -2974,6 +3041,9 @@ func visitPayloads( } } + if waitOnErrGroup { + return eg.Wait() + } return nil }