From 33df78c30361b3520c523676167d3393dfa2c6ba Mon Sep 17 00:00:00 2001 From: JeethJJ Date: Tue, 3 Feb 2026 17:30:18 +0530 Subject: [PATCH] Flow service context propagation and transection usage --- backend/internal/application/service.go | 11 +- backend/internal/application/service_test.go | 84 ++-- backend/internal/flow/flowexec/service.go | 66 +-- .../internal/flow/flowexec/service_test.go | 11 +- .../mgt/FlowMgtServiceInterface_mock_test.go | 381 +++++++++------- .../internal/flow/mgt/cache_backed_store.go | 48 +- .../flow/mgt/cache_backed_store_test.go | 148 ++++--- .../internal/flow/mgt/declarative_resource.go | 5 +- .../flow/mgt/declarative_resource_test.go | 52 +-- backend/internal/flow/mgt/file_based_store.go | 31 +- .../flow/mgt/file_based_store_test.go | 91 ++-- .../flow/mgt/flowStoreInterface_mock_test.go | 396 ++++++++++------- .../mgt/graphBuilderInterface_mock_test.go | 4 +- backend/internal/flow/mgt/handler.go | 24 +- backend/internal/flow/mgt/handler_test.go | 37 +- backend/internal/flow/mgt/init.go | 8 +- backend/internal/flow/mgt/service.go | 301 ++++++++----- backend/internal/flow/mgt/service_test.go | 412 +++++++++++------- backend/internal/flow/mgt/store.go | 390 +++++++---------- backend/internal/flow/mgt/store_test.go | 243 ++++++----- backend/internal/mcp/tools/flow/tool.go | 12 +- backend/internal/mcp/tools/flow/tool_test.go | 22 +- .../internal/system/export/service_test.go | 23 +- .../FlowMgtServiceInterface_mock.go | 381 +++++++++------- 24 files changed, 1782 insertions(+), 1399 deletions(-) diff --git a/backend/internal/application/service.go b/backend/internal/application/service.go index 21adb5f2c..75a6d81ba 100644 --- a/backend/internal/application/service.go +++ b/backend/internal/application/service.go @@ -19,6 +19,7 @@ package application import ( + "context" "errors" "slices" @@ -636,7 +637,7 @@ func (as *applicationService) DeleteApplication(appID string) *serviceerror.Serv // If the flow ID is not provided, it sets the default authentication flow ID. func (as *applicationService) validateAuthFlowID(app *model.ApplicationDTO) *serviceerror.ServiceError { if app.AuthFlowID != "" { - isValidFlow := as.flowMgtService.IsValidFlow(app.AuthFlowID) + isValidFlow := as.flowMgtService.IsValidFlow(context.TODO(), app.AuthFlowID) if !isValidFlow { return &ErrorInvalidAuthFlowID } @@ -657,13 +658,13 @@ func (as *applicationService) validateRegistrationFlowID(app *model.ApplicationD logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "ApplicationService")) if app.RegistrationFlowID != "" { - isValidFlow := as.flowMgtService.IsValidFlow(app.RegistrationFlowID) + isValidFlow := as.flowMgtService.IsValidFlow(context.TODO(), app.RegistrationFlowID) if !isValidFlow { return &ErrorInvalidRegistrationFlowID } } else { // Try to get the equivalent registration flow for the auth flow - authFlow, svcErr := as.flowMgtService.GetFlow(app.AuthFlowID) + authFlow, svcErr := as.flowMgtService.GetFlow(context.TODO(), app.AuthFlowID) if svcErr != nil { if svcErr.Type == serviceerror.ServerErrorType { logger.Error("Error while retrieving auth flow definition", @@ -674,7 +675,7 @@ func (as *applicationService) validateRegistrationFlowID(app *model.ApplicationD } registrationFlow, svcErr := as.flowMgtService.GetFlowByHandle( - authFlow.Handle, flowcommon.FlowTypeRegistration) + context.TODO(), authFlow.Handle, flowcommon.FlowTypeRegistration) if svcErr != nil { if svcErr.Type == serviceerror.ServerErrorType { logger.Error("Error while retrieving registration flow definition by handle", @@ -899,7 +900,7 @@ func (as *applicationService) getDefaultAuthFlowID() (string, *serviceerror.Serv defaultAuthFlowHandle := config.GetThunderRuntime().Config.Flow.DefaultAuthFlowHandle defaultAuthFlow, svcErr := as.flowMgtService.GetFlowByHandle( - defaultAuthFlowHandle, flowcommon.FlowTypeAuthentication) + context.TODO(), defaultAuthFlowHandle, flowcommon.FlowTypeAuthentication) if svcErr != nil { if svcErr.Type == serviceerror.ServerErrorType { diff --git a/backend/internal/application/service_test.go b/backend/internal/application/service_test.go index e053df169..fa8eb27b4 100644 --- a/backend/internal/application/service_test.go +++ b/backend/internal/application/service_test.go @@ -653,7 +653,7 @@ func (suite *ServiceTestSuite) TestValidateAuthFlowID_WithValidFlowID() { AuthFlowID: "auth-flow-123", } - mockFlowMgtService.EXPECT().IsValidFlow("auth-flow-123").Return(true) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "auth-flow-123").Return(true) svcErr := service.validateAuthFlowID(app) @@ -668,7 +668,7 @@ func (suite *ServiceTestSuite) TestValidateAuthFlowID_WithInvalidFlowID() { AuthFlowID: "invalid-flow", } - mockFlowMgtService.EXPECT().IsValidFlow("invalid-flow").Return(false) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "invalid-flow").Return(false) svcErr := service.validateAuthFlowID(app) @@ -697,7 +697,7 @@ func (suite *ServiceTestSuite) TestValidateAuthFlowID_WithEmptyFlowID_SetsDefaul ID: "default-flow-id-123", Handle: "default_auth_flow", } - mockFlowMgtService.EXPECT().GetFlowByHandle("default_auth_flow", flowcommon.FlowTypeAuthentication). + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "default_auth_flow", flowcommon.FlowTypeAuthentication). Return(defaultFlow, nil) svcErr := service.validateAuthFlowID(app) @@ -723,7 +723,7 @@ func (suite *ServiceTestSuite) TestValidateAuthFlowID_WithEmptyFlowID_ErrorRetri AuthFlowID: "", } - mockFlowMgtService.EXPECT().GetFlowByHandle("default_auth_flow", flowcommon.FlowTypeAuthentication). + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "default_auth_flow", flowcommon.FlowTypeAuthentication). Return(nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType}) svcErr := service.validateAuthFlowID(app) @@ -739,7 +739,7 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_WithValidFlowID() RegistrationFlowID: "reg-flow-123", } - mockFlowMgtService.EXPECT().IsValidFlow("reg-flow-123").Return(true) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "reg-flow-123").Return(true) svcErr := service.validateRegistrationFlowID(app) @@ -754,7 +754,7 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_WithInvalidFlowID( RegistrationFlowID: "invalid-reg-flow", } - mockFlowMgtService.EXPECT().IsValidFlow("invalid-reg-flow").Return(false) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "invalid-reg-flow").Return(false) svcErr := service.validateRegistrationFlowID(app) @@ -779,8 +779,8 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_WithEmptyFlowID_In Handle: "basic_auth", } - mockFlowMgtService.EXPECT().GetFlow("auth-flow-123").Return(authFlow, nil) - mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration). + mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123").Return(authFlow, nil) + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration). Return(regFlow, nil) svcErr := service.validateRegistrationFlowID(app) @@ -797,7 +797,7 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_ErrorRetrievingAut RegistrationFlowID: "", } - mockFlowMgtService.EXPECT().GetFlow("auth-flow-123"). + mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123"). Return(nil, &serviceerror.ServiceError{Type: serviceerror.ServerErrorType}) svcErr := service.validateRegistrationFlowID(app) @@ -819,8 +819,8 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_ErrorRetrievingReg Handle: "basic_auth", } - mockFlowMgtService.EXPECT().GetFlow("auth-flow-123").Return(authFlow, nil) - mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration). + mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123").Return(authFlow, nil) + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration). Return(nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType}) svcErr := service.validateRegistrationFlowID(app) @@ -837,7 +837,7 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_ClientErrorRetriev RegistrationFlowID: "", } - mockFlowMgtService.EXPECT().GetFlow("auth-flow-123"). + mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123"). Return(nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType}) svcErr := service.validateRegistrationFlowID(app) @@ -859,8 +859,8 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_ServerErrorRetriev Handle: "basic_auth", } - mockFlowMgtService.EXPECT().GetFlow("auth-flow-123").Return(authFlow, nil) - mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration). + mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123").Return(authFlow, nil) + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration). Return(nil, &serviceerror.ServiceError{Type: serviceerror.ServerErrorType}) svcErr := service.validateRegistrationFlowID(app) @@ -886,7 +886,7 @@ func (suite *ServiceTestSuite) TestGetDefaultAuthFlowID_Success() { ID: "flow-id-789", Handle: "custom_auth_flow", } - mockFlowMgtService.EXPECT().GetFlowByHandle("custom_auth_flow", flowcommon.FlowTypeAuthentication). + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "custom_auth_flow", flowcommon.FlowTypeAuthentication). Return(defaultFlow, nil) result, svcErr := service.getDefaultAuthFlowID() @@ -908,7 +908,7 @@ func (suite *ServiceTestSuite) TestGetDefaultAuthFlowID_ErrorRetrieving() { service, _, _, mockFlowMgtService := suite.setupTestService() - mockFlowMgtService.EXPECT().GetFlowByHandle("custom_auth_flow", flowcommon.FlowTypeAuthentication). + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "custom_auth_flow", flowcommon.FlowTypeAuthentication). Return(nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType}) result, svcErr := service.getDefaultAuthFlowID() @@ -931,7 +931,7 @@ func (suite *ServiceTestSuite) TestGetDefaultAuthFlowID_ServerError() { service, _, _, mockFlowMgtService := suite.setupTestService() - mockFlowMgtService.EXPECT().GetFlowByHandle("custom_auth_flow", flowcommon.FlowTypeAuthentication). + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "custom_auth_flow", flowcommon.FlowTypeAuthentication). Return(nil, &serviceerror.ServiceError{Type: serviceerror.ServerErrorType}) result, svcErr := service.getDefaultAuthFlowID() @@ -2430,19 +2430,20 @@ func (suite *ServiceTestSuite) TestValidateApplication_InvalidURL() { } mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError) - mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) - mockFlowMgtService.EXPECT().GetFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(&flowmgt.CompleteFlowDefinition{ - ID: "edc013d0-e893-4dc0-990c-3e1d203e005b", - Handle: "basic_auth", - }, nil).Maybe() + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) + mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b"). + Return(&flowmgt.CompleteFlowDefinition{ + ID: "edc013d0-e893-4dc0-990c-3e1d203e005b", + Handle: "basic_auth", + }, nil).Maybe() // Return success for registration flow so URL validation runs - mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration).Return( + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration).Return( &flowmgt.CompleteFlowDefinition{ ID: "reg_flow_basic", Handle: "basic_auth", }, nil).Maybe() - mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything).Return(true).Maybe() + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, mock.Anything).Return(true).Maybe() result, inboundAuth, svcErr := service.ValidateApplication(app) @@ -2473,19 +2474,20 @@ func (suite *ServiceTestSuite) TestValidateApplication_InvalidLogoURL() { } mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError) - mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) - mockFlowMgtService.EXPECT().GetFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(&flowmgt.CompleteFlowDefinition{ - ID: "edc013d0-e893-4dc0-990c-3e1d203e005b", - Handle: "basic_auth", - }, nil).Maybe() + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) + mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b"). + Return(&flowmgt.CompleteFlowDefinition{ + ID: "edc013d0-e893-4dc0-990c-3e1d203e005b", + Handle: "basic_auth", + }, nil).Maybe() // Return success for registration flow so URL validation runs - mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration).Return( + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration).Return( &flowmgt.CompleteFlowDefinition{ ID: "reg_flow_basic", Handle: "basic_auth", }, nil).Maybe() - mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything).Return(true).Maybe() + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, mock.Anything).Return(true).Maybe() result, inboundAuth, svcErr := service.ValidateApplication(app) @@ -2522,8 +2524,8 @@ func (suite *ServiceTestSuite) TestCreateApplication_StoreErrorWithRollback() { } mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError) - mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) - mockFlowMgtService.EXPECT().IsValidFlow("80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true) mockCertService.EXPECT().CreateCertificate(mock.Anything).Return(&cert.Certificate{Type: "JWKS"}, nil) mockStore.On("CreateApplication", mock.Anything).Return(errors.New("store error")) mockCertService.EXPECT(). @@ -2564,8 +2566,8 @@ func (suite *ServiceTestSuite) TestCreateApplication_StoreErrorWithRollbackFailu } mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError) - mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) - mockFlowMgtService.EXPECT().IsValidFlow("80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true) mockCertService.EXPECT().CreateCertificate(mock.Anything).Return(&cert.Certificate{Type: "JWKS"}, nil) mockStore.On("CreateApplication", mock.Anything).Return(errors.New("store error")) rollbackErr := &serviceerror.ServiceError{ @@ -2689,7 +2691,7 @@ func (suite *ServiceTestSuite) TestUpdateApplication_StoreErrorWhenCheckingClien } mockStore.On("GetApplicationByID", "app123").Return(existingApp, nil) - mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything).Return(true).Maybe() + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, mock.Anything).Return(true).Maybe() // Return an error that's not ApplicationNotFoundError when checking client ID mockStore.On("GetOAuthApplication", "new-client-id").Return(nil, errors.New("database connection error")) @@ -2733,8 +2735,8 @@ func (suite *ServiceTestSuite) TestUpdateApplication_StoreErrorWithRollback() { } mockStore.On("GetApplicationByID", "app123").Return(existingApp, nil) - mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) - mockFlowMgtService.EXPECT().IsValidFlow("80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true) + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true) mockCertService.EXPECT(). GetCertificateByReference(cert.CertificateReferenceTypeApplication, "app123"). Return(nil, &cert.ErrorCertificateNotFound) @@ -3452,12 +3454,12 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_NoPrefix() { } mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError) - mockFlowMgtService.EXPECT().IsValidFlow("invalid_flow_id").Return(true) - mockFlowMgtService.EXPECT().GetFlow("invalid_flow_id").Return(&flowmgt.CompleteFlowDefinition{ + mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "invalid_flow_id").Return(true) + mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "invalid_flow_id").Return(&flowmgt.CompleteFlowDefinition{ ID: "invalid_flow_id", Handle: "test_flow", }, nil).Maybe() - mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, flowcommon.FlowTypeRegistration).Return( + mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, mock.Anything, flowcommon.FlowTypeRegistration).Return( nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType}).Maybe() result, inboundAuth, svcErr := service.ValidateApplication(app) diff --git a/backend/internal/flow/flowexec/service.go b/backend/internal/flow/flowexec/service.go index f65d2638e..84c7d8f6d 100644 --- a/backend/internal/flow/flowexec/service.go +++ b/backend/internal/flow/flowexec/service.go @@ -80,7 +80,7 @@ func (s *flowExecService) Execute(ctx context.Context, var loadErr *serviceerror.ServiceError if isNewFlow(flowID) { - context, loadErr = s.loadNewContext(appID, flowType, verbose, action, inputs, logger) + context, loadErr = s.loadNewContext(ctx, appID, flowType, verbose, action, inputs, logger) if loadErr != nil { logger.Error("Failed to load new flow context", log.String("appID", appID), @@ -108,7 +108,7 @@ func (s *flowExecService) Execute(ctx context.Context, return nil, loadErr } } else { - context, loadErr = s.loadPrevContext(flowID, action, inputs, logger) + context, loadErr = s.loadPrevContext(ctx, flowID, action, inputs, logger) if loadErr != nil { logger.Error("Failed to load previous flow context", log.String("flowID", flowID), @@ -162,7 +162,7 @@ func (s *flowExecService) Execute(ctx context.Context, } // initContext initializes a new flow context with the given details. -func (s *flowExecService) loadNewContext(appID, flowTypeStr string, verbose bool, +func (s *flowExecService) loadNewContext(ctx context.Context, appID, flowTypeStr string, verbose bool, action string, inputs map[string]string, logger *log.Logger) ( *EngineContext, *serviceerror.ServiceError) { flowType, err := validateFlowType(flowTypeStr) @@ -170,65 +170,65 @@ func (s *flowExecService) loadNewContext(appID, flowTypeStr string, verbose bool return nil, err } - ctx, err := s.initContext(appID, flowType, verbose, logger) + flowCtx, err := s.initContext(ctx, appID, flowType, verbose, logger) if err != nil { return nil, err } - prepareContext(ctx, action, inputs) - return ctx, nil + prepareContext(flowCtx, action, inputs) + return flowCtx, nil } // initContext initializes a new flow context with the given details. -func (s *flowExecService) initContext(appID string, flowType common.FlowType, +func (s *flowExecService) initContext(ctx context.Context, appID string, flowType common.FlowType, verbose bool, logger *log.Logger) (*EngineContext, *serviceerror.ServiceError) { - graphID, svcErr := s.getFlowGraph(appID, flowType, logger) + graphID, svcErr := s.getFlowGraph(ctx, appID, flowType, logger) if svcErr != nil { return nil, svcErr } - ctx := EngineContext{} + flowCtx := EngineContext{} flowID, err := sysutils.GenerateUUIDv7() if err != nil { logger.Error("Failed to generate UUID", log.Error(err)) return nil, &serviceerror.InternalServerError } - ctx.FlowID = flowID + flowCtx.FlowID = flowID - graph, svcErr := s.flowMgtService.GetGraph(graphID) + graph, svcErr := s.flowMgtService.GetGraph(ctx, graphID) if svcErr != nil { logger.Error("Error retrieving flow graph from flow management service", log.String("graphID", graphID), log.String("error", svcErr.Error)) return nil, &serviceerror.InternalServerError } - ctx.FlowType = graph.GetType() - ctx.Graph = graph - ctx.AppID = appID - ctx.Verbose = verbose + flowCtx.FlowType = graph.GetType() + flowCtx.Graph = graph + flowCtx.AppID = appID + flowCtx.Verbose = verbose // Set application context if required - if err := s.setApplicationToContext(&ctx, logger); err != nil { + if err := s.setApplicationToContext(&flowCtx, logger); err != nil { return nil, err } - return &ctx, nil + return &flowCtx, nil } // loadPrevContext retrieves the flow context from the store based on the given details. -func (s *flowExecService) loadPrevContext(flowID, action string, inputs map[string]string, +func (s *flowExecService) loadPrevContext(ctx context.Context, flowID, action string, inputs map[string]string, logger *log.Logger) (*EngineContext, *serviceerror.ServiceError) { - ctx, err := s.loadContextFromStore(flowID, logger) + flowCtx, err := s.loadContextFromStore(ctx, flowID, logger) if err != nil { return nil, err } - prepareContext(ctx, action, inputs) - return ctx, nil + prepareContext(flowCtx, action, inputs) + return flowCtx, nil } // loadContextFromStore retrieves the flow context from the store based on the given details. -func (s *flowExecService) loadContextFromStore(flowID string, logger *log.Logger) ( +func (s *flowExecService) loadContextFromStore(ctx context.Context, flowID string, logger *log.Logger) ( *EngineContext, *serviceerror.ServiceError) { if flowID == "" { return nil, &ErrorInvalidFlowID @@ -245,7 +245,7 @@ func (s *flowExecService) loadContextFromStore(flowID string, logger *log.Logger return nil, &ErrorInvalidFlowID } - graph, svcErr := s.flowMgtService.GetGraph(dbModel.GraphID) + graph, svcErr := s.flowMgtService.GetGraph(ctx, dbModel.GraphID) if svcErr != nil { logger.Error("Error retrieving flow graph from flow management service", log.String("graphID", dbModel.GraphID), log.String("error", svcErr.Error)) @@ -351,11 +351,11 @@ func (s *flowExecService) storeContext(ctx *EngineContext, logger *log.Logger) e } // getFlowGraph checks if the provided application ID is valid and returns the associated flow ID. -func (s *flowExecService) getFlowGraph(appID string, flowType common.FlowType, +func (s *flowExecService) getFlowGraph(ctx context.Context, appID string, flowType common.FlowType, logger *log.Logger) (string, *serviceerror.ServiceError) { // Handle app-independent system flows if flowType == common.FlowTypeUserOnboarding { - return s.getSystemFlowGraph(flowType, logger) + return s.getSystemFlowGraph(ctx, flowType, logger) } if appID == "" { @@ -416,7 +416,7 @@ func isNewFlow(flowID string) bool { } // getSystemFlowGraph retrieves the flow graph for system flows by handle. -func (s *flowExecService) getSystemFlowGraph(flowType common.FlowType, +func (s *flowExecService) getSystemFlowGraph(ctx context.Context, flowType common.FlowType, logger *log.Logger) (string, *serviceerror.ServiceError) { handle := "" switch flowType { @@ -426,7 +426,7 @@ func (s *flowExecService) getSystemFlowGraph(flowType common.FlowType, return "", &ErrorInvalidFlowType } - flow, err := s.flowMgtService.GetFlowByHandle(handle, flowType) + flow, err := s.flowMgtService.GetFlowByHandle(ctx, handle, flowType) if err != nil { logger.Error("Failed to get system flow by handle", log.String("handle", handle), log.String("flowType", string(flowType))) @@ -482,7 +482,7 @@ func (s *flowExecService) InitiateFlow(initContext *FlowInitContext) (string, *s // Initialize the engine context // This uses verbose true to ensure step layouts are returned during execution - ctx, err := s.initContext(initContext.ApplicationID, flowType, true, logger) + flowCtx, err := s.initContext(context.Background(), initContext.ApplicationID, flowType, true, logger) if err != nil { logger.Error("Failed to initialize flow context", log.String("appID", initContext.ApplicationID), @@ -492,16 +492,16 @@ func (s *flowExecService) InitiateFlow(initContext *FlowInitContext) (string, *s } // Replace the RuntimeData with initContext RuntimeData - ctx.RuntimeData = initContext.RuntimeData + flowCtx.RuntimeData = initContext.RuntimeData // Store the context without executing the flow - if storeErr := s.storeContext(ctx, logger); storeErr != nil { + if storeErr := s.storeContext(flowCtx, logger); storeErr != nil { logger.Error("Failed to store initial flow context", - log.String("flowID", ctx.FlowID), + log.String("flowID", flowCtx.FlowID), log.Error(storeErr)) return "", &serviceerror.InternalServerError } - logger.Debug("Flow initiated successfully", log.String("flowID", ctx.FlowID)) - return ctx.FlowID, nil + logger.Debug("Flow initiated successfully", log.String("flowID", flowCtx.FlowID)) + return flowCtx.FlowID, nil } diff --git a/backend/internal/flow/flowexec/service_test.go b/backend/internal/flow/flowexec/service_test.go index 38c23bc3d..06464005d 100644 --- a/backend/internal/flow/flowexec/service_test.go +++ b/backend/internal/flow/flowexec/service_test.go @@ -212,12 +212,12 @@ func TestInitiateFlowSuccessScenarios(t *testing.T) { // Mock flow management service to return flow by handle mockFlow := &flowmgt.CompleteFlowDefinition{ID: "onboarding-flow-123"} - mockFlowMgtSvc.EXPECT().GetFlowByHandle(mock.Anything, + mockFlowMgtSvc.EXPECT().GetFlowByHandle(mock.Anything, mock.Anything, common.FlowTypeUserOnboarding).Return(mockFlow, nil) // Mock GetGraph call which is made during initContext inviteGraph := flowFactory.CreateGraph("onboarding-flow-123", common.FlowTypeUserOnboarding) - mockFlowMgtSvc.EXPECT().GetGraph("onboarding-flow-123").Return(inviteGraph, nil) + mockFlowMgtSvc.EXPECT().GetGraph(mock.Anything, "onboarding-flow-123").Return(inviteGraph, nil) // For system flows, StoreFlowContext is called with empty AppID mockStore.EXPECT().StoreFlowContext(mock.MatchedBy(func(ctx EngineContext) bool { @@ -240,7 +240,7 @@ func TestInitiateFlowSuccessScenarios(t *testing.T) { })).Return(nil) } else { mockAppService.EXPECT().GetApplication(appID).Return(mockApp, nil) - mockFlowMgtSvc.EXPECT().GetGraph("auth-graph-1").Return(testGraph, nil) + mockFlowMgtSvc.EXPECT().GetGraph(mock.Anything, "auth-graph-1").Return(testGraph, nil) mockStore.EXPECT().StoreFlowContext(mock.MatchedBy(func(ctx EngineContext) bool { // Verify flowID is generated if ctx.FlowID == "" { @@ -338,7 +338,8 @@ func TestInitiateFlowErrorScenarios(t *testing.T) { mockAppService.EXPECT().GetApplication(appID).Return(mockApp, nil) // Mock flow management service to return error (graph not found) - mockFlowMgtSvc.EXPECT().GetGraph("auth-graph-1").Return(nil, &serviceerror.InternalServerError) + mockFlowMgtSvc.EXPECT().GetGraph(mock.Anything, "auth-graph-1"). + Return(nil, &serviceerror.InternalServerError) // No store mock needed as it fails before storing }, expectedErrorCode: serviceerror.InternalServerError.Code, @@ -359,7 +360,7 @@ func TestInitiateFlowErrorScenarios(t *testing.T) { // Mock flow management service to return valid graph testGraph := flowFactory.CreateGraph("auth-graph-1", common.FlowTypeAuthentication) - mockFlowMgtSvc.EXPECT().GetGraph("auth-graph-1").Return(testGraph, nil) + mockFlowMgtSvc.EXPECT().GetGraph(mock.Anything, "auth-graph-1").Return(testGraph, nil) // Mock store to return error mockStore.EXPECT().StoreFlowContext(mock.AnythingOfType("EngineContext")).Return(assert.AnError) diff --git a/backend/internal/flow/mgt/FlowMgtServiceInterface_mock_test.go b/backend/internal/flow/mgt/FlowMgtServiceInterface_mock_test.go index 4592a36c0..3d0324635 100644 --- a/backend/internal/flow/mgt/FlowMgtServiceInterface_mock_test.go +++ b/backend/internal/flow/mgt/FlowMgtServiceInterface_mock_test.go @@ -5,6 +5,8 @@ package flowmgt import ( + "context" + "github.com/asgardeo/thunder/internal/flow/common" "github.com/asgardeo/thunder/internal/flow/core" "github.com/asgardeo/thunder/internal/system/error/serviceerror" @@ -39,8 +41,8 @@ func (_m *FlowMgtServiceInterfaceMock) EXPECT() *FlowMgtServiceInterfaceMock_Exp } // CreateFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) CreateFlow(flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(flowDef) +func (_mock *FlowMgtServiceInterfaceMock) CreateFlow(ctx context.Context, flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowDef) if len(ret) == 0 { panic("no return value specified for CreateFlow") @@ -48,18 +50,18 @@ func (_mock *FlowMgtServiceInterfaceMock) CreateFlow(flowDef *FlowDefinition) (* var r0 *CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(*FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(flowDef) + if returnFunc, ok := ret.Get(0).(func(context.Context, *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowDef) } - if returnFunc, ok := ret.Get(0).(func(*FlowDefinition) *CompleteFlowDefinition); ok { - r0 = returnFunc(flowDef) + if returnFunc, ok := ret.Get(0).(func(context.Context, *FlowDefinition) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowDef) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(*FlowDefinition) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowDef) + if returnFunc, ok := ret.Get(1).(func(context.Context, *FlowDefinition) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowDef) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -74,19 +76,25 @@ type FlowMgtServiceInterfaceMock_CreateFlow_Call struct { } // CreateFlow is a helper method to define mock.On call +// - ctx context.Context // - flowDef *FlowDefinition -func (_e *FlowMgtServiceInterfaceMock_Expecter) CreateFlow(flowDef interface{}) *FlowMgtServiceInterfaceMock_CreateFlow_Call { - return &FlowMgtServiceInterfaceMock_CreateFlow_Call{Call: _e.mock.On("CreateFlow", flowDef)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) CreateFlow(ctx interface{}, flowDef interface{}) *FlowMgtServiceInterfaceMock_CreateFlow_Call { + return &FlowMgtServiceInterfaceMock_CreateFlow_Call{Call: _e.mock.On("CreateFlow", ctx, flowDef)} } -func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) Run(run func(flowDef *FlowDefinition)) *FlowMgtServiceInterfaceMock_CreateFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) Run(run func(ctx context.Context, flowDef *FlowDefinition)) *FlowMgtServiceInterfaceMock_CreateFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 *FlowDefinition + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(*FlowDefinition) + arg0 = args[0].(context.Context) + } + var arg1 *FlowDefinition + if args[1] != nil { + arg1 = args[1].(*FlowDefinition) } run( arg0, + arg1, ) }) return _c @@ -97,22 +105,23 @@ func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) Return(completeFlowDefini return _c } -func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) RunAndReturn(run func(flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_CreateFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) RunAndReturn(run func(ctx context.Context, flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_CreateFlow_Call { _c.Call.Return(run) return _c } // DeleteFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) DeleteFlow(flowID string) *serviceerror.ServiceError { - ret := _mock.Called(flowID) +// DeleteFlow provides a mock function for the type FlowMgtServiceInterfaceMock +func (_mock *FlowMgtServiceInterfaceMock) DeleteFlow(ctx context.Context, flowID string) *serviceerror.ServiceError { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for DeleteFlow") } var r0 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string) *serviceerror.ServiceError); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *serviceerror.ServiceError); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*serviceerror.ServiceError) @@ -127,19 +136,25 @@ type FlowMgtServiceInterfaceMock_DeleteFlow_Call struct { } // DeleteFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) DeleteFlow(flowID interface{}) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { - return &FlowMgtServiceInterfaceMock_DeleteFlow_Call{Call: _e.mock.On("DeleteFlow", flowID)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) DeleteFlow(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { + return &FlowMgtServiceInterfaceMock_DeleteFlow_Call{Call: _e.mock.On("DeleteFlow", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -150,14 +165,14 @@ func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) Return(serviceError *serv return _c } -func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) RunAndReturn(run func(flowID string) *serviceerror.ServiceError) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string) *serviceerror.ServiceError) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { _c.Call.Return(run) return _c } // GetFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) GetFlow(flowID string) (*CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) GetFlow(ctx context.Context, flowID string) (*CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for GetFlow") @@ -165,18 +180,18 @@ func (_mock *FlowMgtServiceInterfaceMock) GetFlow(flowID string) (*CompleteFlowD var r0 *CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) *CompleteFlowDefinition); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -191,19 +206,25 @@ type FlowMgtServiceInterfaceMock_GetFlow_Call struct { } // GetFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlow(flowID interface{}) *FlowMgtServiceInterfaceMock_GetFlow_Call { - return &FlowMgtServiceInterfaceMock_GetFlow_Call{Call: _e.mock.On("GetFlow", flowID)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlow(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_GetFlow_Call { + return &FlowMgtServiceInterfaceMock_GetFlow_Call{Call: _e.mock.On("GetFlow", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_GetFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_GetFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -214,14 +235,14 @@ func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) Return(completeFlowDefinitio return _c } -func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) RunAndReturn(run func(flowID string) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlow_Call { _c.Call.Return(run) return _c } // GetFlowByHandle provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) GetFlowByHandle(handle string, flowType common.FlowType) (*CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(handle, flowType) +func (_mock *FlowMgtServiceInterfaceMock) GetFlowByHandle(ctx context.Context, handle string, flowType common.FlowType) (*CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, handle, flowType) if len(ret) == 0 { panic("no return value specified for GetFlowByHandle") @@ -229,18 +250,18 @@ func (_mock *FlowMgtServiceInterfaceMock) GetFlowByHandle(handle string, flowTyp var r0 *CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string, common.FlowType) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, common.FlowType) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, handle, flowType) } - if returnFunc, ok := ret.Get(0).(func(string, common.FlowType) *CompleteFlowDefinition); ok { - r0 = returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, common.FlowType) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, handle, flowType) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, common.FlowType) *serviceerror.ServiceError); ok { - r1 = returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, common.FlowType) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, handle, flowType) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -255,25 +276,31 @@ type FlowMgtServiceInterfaceMock_GetFlowByHandle_Call struct { } // GetFlowByHandle is a helper method to define mock.On call +// - ctx context.Context // - handle string // - flowType common.FlowType -func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlowByHandle(handle interface{}, flowType interface{}) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { - return &FlowMgtServiceInterfaceMock_GetFlowByHandle_Call{Call: _e.mock.On("GetFlowByHandle", handle, flowType)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlowByHandle(ctx interface{}, handle interface{}, flowType interface{}) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { + return &FlowMgtServiceInterfaceMock_GetFlowByHandle_Call{Call: _e.mock.On("GetFlowByHandle", ctx, handle, flowType)} } -func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) Run(run func(handle string, flowType common.FlowType)) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) Run(run func(ctx context.Context, handle string, flowType common.FlowType)) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 common.FlowType + var arg1 string if args[1] != nil { - arg1 = args[1].(common.FlowType) + arg1 = args[1].(string) + } + var arg2 common.FlowType + if args[2] != nil { + arg2 = args[2].(common.FlowType) } run( arg0, arg1, + arg2, ) }) return _c @@ -284,14 +311,14 @@ func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) Return(completeFlowD return _c } -func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) RunAndReturn(run func(handle string, flowType common.FlowType) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) RunAndReturn(run func(ctx context.Context, handle string, flowType common.FlowType) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { _c.Call.Return(run) return _c } // GetFlowVersion provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) GetFlowVersion(flowID string, version int) (*FlowVersion, *serviceerror.ServiceError) { - ret := _mock.Called(flowID, version) +func (_mock *FlowMgtServiceInterfaceMock) GetFlowVersion(ctx context.Context, flowID string, version int) (*FlowVersion, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID, version) if len(ret) == 0 { panic("no return value specified for GetFlowVersion") @@ -299,18 +326,18 @@ func (_mock *FlowMgtServiceInterfaceMock) GetFlowVersion(flowID string, version var r0 *FlowVersion var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string, int) (*FlowVersion, *serviceerror.ServiceError)); ok { - return returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) (*FlowVersion, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID, version) } - if returnFunc, ok := ret.Get(0).(func(string, int) *FlowVersion); ok { - r0 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) *FlowVersion); ok { + r0 = returnFunc(ctx, flowID, version) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*FlowVersion) } } - if returnFunc, ok := ret.Get(1).(func(string, int) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, int) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID, version) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -325,25 +352,31 @@ type FlowMgtServiceInterfaceMock_GetFlowVersion_Call struct { } // GetFlowVersion is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - version int -func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlowVersion(flowID interface{}, version interface{}) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { - return &FlowMgtServiceInterfaceMock_GetFlowVersion_Call{Call: _e.mock.On("GetFlowVersion", flowID, version)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlowVersion(ctx interface{}, flowID interface{}, version interface{}) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { + return &FlowMgtServiceInterfaceMock_GetFlowVersion_Call{Call: _e.mock.On("GetFlowVersion", ctx, flowID, version)} } -func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) Run(run func(flowID string, version int)) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) Run(run func(ctx context.Context, flowID string, version int)) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 int + var arg1 string if args[1] != nil { - arg1 = args[1].(int) + arg1 = args[1].(string) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) } run( arg0, arg1, + arg2, ) }) return _c @@ -354,14 +387,14 @@ func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) Return(flowVersion *F return _c } -func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) RunAndReturn(run func(flowID string, version int) (*FlowVersion, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) RunAndReturn(run func(ctx context.Context, flowID string, version int) (*FlowVersion, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { _c.Call.Return(run) return _c } // GetGraph provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) GetGraph(flowID string) (core.GraphInterface, *serviceerror.ServiceError) { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) GetGraph(ctx context.Context, flowID string) (core.GraphInterface, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for GetGraph") @@ -369,18 +402,18 @@ func (_mock *FlowMgtServiceInterfaceMock) GetGraph(flowID string) (core.GraphInt var r0 core.GraphInterface var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string) (core.GraphInterface, *serviceerror.ServiceError)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (core.GraphInterface, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) core.GraphInterface); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) core.GraphInterface); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(core.GraphInterface) } } - if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -395,19 +428,25 @@ type FlowMgtServiceInterfaceMock_GetGraph_Call struct { } // GetGraph is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) GetGraph(flowID interface{}) *FlowMgtServiceInterfaceMock_GetGraph_Call { - return &FlowMgtServiceInterfaceMock_GetGraph_Call{Call: _e.mock.On("GetGraph", flowID)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) GetGraph(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_GetGraph_Call { + return &FlowMgtServiceInterfaceMock_GetGraph_Call{Call: _e.mock.On("GetGraph", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_GetGraph_Call { +func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_GetGraph_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -418,22 +457,22 @@ func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) Return(graphInterface core. return _c } -func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) RunAndReturn(run func(flowID string) (core.GraphInterface, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetGraph_Call { +func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) RunAndReturn(run func(ctx context.Context, flowID string) (core.GraphInterface, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetGraph_Call { _c.Call.Return(run) return _c } // IsValidFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) IsValidFlow(flowID string) bool { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) IsValidFlow(ctx context.Context, flowID string) bool { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for IsValidFlow") } var r0 bool - if returnFunc, ok := ret.Get(0).(func(string) bool); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = returnFunc(ctx, flowID) } else { r0 = ret.Get(0).(bool) } @@ -446,19 +485,25 @@ type FlowMgtServiceInterfaceMock_IsValidFlow_Call struct { } // IsValidFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) IsValidFlow(flowID interface{}) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { - return &FlowMgtServiceInterfaceMock_IsValidFlow_Call{Call: _e.mock.On("IsValidFlow", flowID)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) IsValidFlow(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { + return &FlowMgtServiceInterfaceMock_IsValidFlow_Call{Call: _e.mock.On("IsValidFlow", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -469,14 +514,14 @@ func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) Return(b bool) *FlowMgtS return _c } -func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) RunAndReturn(run func(flowID string) bool) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string) bool) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { _c.Call.Return(run) return _c } // ListFlowVersions provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) ListFlowVersions(flowID string) (*FlowVersionListResponse, *serviceerror.ServiceError) { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) ListFlowVersions(ctx context.Context, flowID string) (*FlowVersionListResponse, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for ListFlowVersions") @@ -484,18 +529,18 @@ func (_mock *FlowMgtServiceInterfaceMock) ListFlowVersions(flowID string) (*Flow var r0 *FlowVersionListResponse var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string) (*FlowVersionListResponse, *serviceerror.ServiceError)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*FlowVersionListResponse, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) *FlowVersionListResponse); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *FlowVersionListResponse); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*FlowVersionListResponse) } } - if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -510,19 +555,25 @@ type FlowMgtServiceInterfaceMock_ListFlowVersions_Call struct { } // ListFlowVersions is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) ListFlowVersions(flowID interface{}) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { - return &FlowMgtServiceInterfaceMock_ListFlowVersions_Call{Call: _e.mock.On("ListFlowVersions", flowID)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) ListFlowVersions(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { + return &FlowMgtServiceInterfaceMock_ListFlowVersions_Call{Call: _e.mock.On("ListFlowVersions", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { +func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -533,14 +584,14 @@ func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) Return(flowVersionL return _c } -func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) RunAndReturn(run func(flowID string) (*FlowVersionListResponse, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { +func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) RunAndReturn(run func(ctx context.Context, flowID string) (*FlowVersionListResponse, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { _c.Call.Return(run) return _c } // ListFlows provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) ListFlows(limit int, offset int, flowType common.FlowType) (*FlowListResponse, *serviceerror.ServiceError) { - ret := _mock.Called(limit, offset, flowType) +func (_mock *FlowMgtServiceInterfaceMock) ListFlows(ctx context.Context, limit int, offset int, flowType common.FlowType) (*FlowListResponse, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, limit, offset, flowType) if len(ret) == 0 { panic("no return value specified for ListFlows") @@ -548,18 +599,18 @@ func (_mock *FlowMgtServiceInterfaceMock) ListFlows(limit int, offset int, flowT var r0 *FlowListResponse var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(int, int, common.FlowType) (*FlowListResponse, *serviceerror.ServiceError)); ok { - return returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, int, int, common.FlowType) (*FlowListResponse, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, limit, offset, flowType) } - if returnFunc, ok := ret.Get(0).(func(int, int, common.FlowType) *FlowListResponse); ok { - r0 = returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, int, int, common.FlowType) *FlowListResponse); ok { + r0 = returnFunc(ctx, limit, offset, flowType) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*FlowListResponse) } } - if returnFunc, ok := ret.Get(1).(func(int, int, common.FlowType) *serviceerror.ServiceError); ok { - r1 = returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(1).(func(context.Context, int, int, common.FlowType) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, limit, offset, flowType) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -574,31 +625,37 @@ type FlowMgtServiceInterfaceMock_ListFlows_Call struct { } // ListFlows is a helper method to define mock.On call +// - ctx context.Context // - limit int // - offset int // - flowType common.FlowType -func (_e *FlowMgtServiceInterfaceMock_Expecter) ListFlows(limit interface{}, offset interface{}, flowType interface{}) *FlowMgtServiceInterfaceMock_ListFlows_Call { - return &FlowMgtServiceInterfaceMock_ListFlows_Call{Call: _e.mock.On("ListFlows", limit, offset, flowType)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) ListFlows(ctx interface{}, limit interface{}, offset interface{}, flowType interface{}) *FlowMgtServiceInterfaceMock_ListFlows_Call { + return &FlowMgtServiceInterfaceMock_ListFlows_Call{Call: _e.mock.On("ListFlows", ctx, limit, offset, flowType)} } -func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) Run(run func(limit int, offset int, flowType common.FlowType)) *FlowMgtServiceInterfaceMock_ListFlows_Call { +func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) Run(run func(ctx context.Context, limit int, offset int, flowType common.FlowType)) *FlowMgtServiceInterfaceMock_ListFlows_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 int + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(int) + arg0 = args[0].(context.Context) } var arg1 int if args[1] != nil { arg1 = args[1].(int) } - var arg2 common.FlowType + var arg2 int if args[2] != nil { - arg2 = args[2].(common.FlowType) + arg2 = args[2].(int) + } + var arg3 common.FlowType + if args[3] != nil { + arg3 = args[3].(common.FlowType) } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -609,14 +666,14 @@ func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) Return(flowListResponse *F return _c } -func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) RunAndReturn(run func(limit int, offset int, flowType common.FlowType) (*FlowListResponse, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_ListFlows_Call { +func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) RunAndReturn(run func(ctx context.Context, limit int, offset int, flowType common.FlowType) (*FlowListResponse, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_ListFlows_Call { _c.Call.Return(run) return _c } // RestoreFlowVersion provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) RestoreFlowVersion(flowID string, version int) (*CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(flowID, version) +func (_mock *FlowMgtServiceInterfaceMock) RestoreFlowVersion(ctx context.Context, flowID string, version int) (*CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID, version) if len(ret) == 0 { panic("no return value specified for RestoreFlowVersion") @@ -624,18 +681,18 @@ func (_mock *FlowMgtServiceInterfaceMock) RestoreFlowVersion(flowID string, vers var r0 *CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string, int) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID, version) } - if returnFunc, ok := ret.Get(0).(func(string, int) *CompleteFlowDefinition); ok { - r0 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID, version) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, int) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, int) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID, version) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -650,25 +707,31 @@ type FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call struct { } // RestoreFlowVersion is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - version int -func (_e *FlowMgtServiceInterfaceMock_Expecter) RestoreFlowVersion(flowID interface{}, version interface{}) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { - return &FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call{Call: _e.mock.On("RestoreFlowVersion", flowID, version)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) RestoreFlowVersion(ctx interface{}, flowID interface{}, version interface{}) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { + return &FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call{Call: _e.mock.On("RestoreFlowVersion", ctx, flowID, version)} } -func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) Run(run func(flowID string, version int)) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { +func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) Run(run func(ctx context.Context, flowID string, version int)) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 int + var arg1 string if args[1] != nil { - arg1 = args[1].(int) + arg1 = args[1].(string) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) } run( arg0, arg1, + arg2, ) }) return _c @@ -679,14 +742,14 @@ func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) Return(completeFl return _c } -func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) RunAndReturn(run func(flowID string, version int) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { +func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) RunAndReturn(run func(ctx context.Context, flowID string, version int) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { _c.Call.Return(run) return _c } // UpdateFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) UpdateFlow(flowID string, flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(flowID, flowDef) +func (_mock *FlowMgtServiceInterfaceMock) UpdateFlow(ctx context.Context, flowID string, flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID, flowDef) if len(ret) == 0 { panic("no return value specified for UpdateFlow") @@ -694,18 +757,18 @@ func (_mock *FlowMgtServiceInterfaceMock) UpdateFlow(flowID string, flowDef *Flo var r0 *CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string, *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(flowID, flowDef) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID, flowDef) } - if returnFunc, ok := ret.Get(0).(func(string, *FlowDefinition) *CompleteFlowDefinition); ok { - r0 = returnFunc(flowID, flowDef) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *FlowDefinition) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID, flowDef) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, *FlowDefinition) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID, flowDef) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, *FlowDefinition) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID, flowDef) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -720,25 +783,31 @@ type FlowMgtServiceInterfaceMock_UpdateFlow_Call struct { } // UpdateFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - flowDef *FlowDefinition -func (_e *FlowMgtServiceInterfaceMock_Expecter) UpdateFlow(flowID interface{}, flowDef interface{}) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { - return &FlowMgtServiceInterfaceMock_UpdateFlow_Call{Call: _e.mock.On("UpdateFlow", flowID, flowDef)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) UpdateFlow(ctx interface{}, flowID interface{}, flowDef interface{}) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { + return &FlowMgtServiceInterfaceMock_UpdateFlow_Call{Call: _e.mock.On("UpdateFlow", ctx, flowID, flowDef)} } -func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) Run(run func(flowID string, flowDef *FlowDefinition)) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) Run(run func(ctx context.Context, flowID string, flowDef *FlowDefinition)) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 *FlowDefinition + var arg1 string if args[1] != nil { - arg1 = args[1].(*FlowDefinition) + arg1 = args[1].(string) + } + var arg2 *FlowDefinition + if args[2] != nil { + arg2 = args[2].(*FlowDefinition) } run( arg0, arg1, + arg2, ) }) return _c @@ -749,7 +818,7 @@ func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) Return(completeFlowDefini return _c } -func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) RunAndReturn(run func(flowID string, flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string, flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { _c.Call.Return(run) return _c } diff --git a/backend/internal/flow/mgt/cache_backed_store.go b/backend/internal/flow/mgt/cache_backed_store.go index f94c1f297..a58542ad9 100644 --- a/backend/internal/flow/mgt/cache_backed_store.go +++ b/backend/internal/flow/mgt/cache_backed_store.go @@ -19,6 +19,7 @@ package flowmgt import ( + "context" "errors" "github.com/asgardeo/thunder/internal/flow/common" @@ -49,15 +50,15 @@ func newCacheBackedFlowStore() flowStoreInterface { // ListFlows retrieves a paginated list of flow definitions. // Note: List operations are not cached as they can vary by parameters and change frequently. -func (s *cacheBackedFlowStore) ListFlows(limit, offset int, flowType string) ( +func (s *cacheBackedFlowStore) ListFlows(ctx context.Context, limit, offset int, flowType string) ( []BasicFlowDefinition, int, error) { - return s.store.ListFlows(limit, offset, flowType) + return s.store.ListFlows(ctx, limit, offset, flowType) } // CreateFlow creates a new flow definition and caches it. -func (s *cacheBackedFlowStore) CreateFlow(flowID string, flow *FlowDefinition) ( +func (s *cacheBackedFlowStore) CreateFlow(ctx context.Context, flowID string, flow *FlowDefinition) ( *CompleteFlowDefinition, error) { - createdFlow, err := s.store.CreateFlow(flowID, flow) + createdFlow, err := s.store.CreateFlow(ctx, flowID, flow) if err != nil { return nil, err } @@ -67,7 +68,7 @@ func (s *cacheBackedFlowStore) CreateFlow(flowID string, flow *FlowDefinition) ( } // GetFlowByID retrieves a flow definition by its ID, using cache if available. -func (s *cacheBackedFlowStore) GetFlowByID(flowID string) (*CompleteFlowDefinition, error) { +func (s *cacheBackedFlowStore) GetFlowByID(ctx context.Context, flowID string) (*CompleteFlowDefinition, error) { cacheKey := cache.CacheKey{ Key: flowID, } @@ -76,7 +77,7 @@ func (s *cacheBackedFlowStore) GetFlowByID(flowID string) (*CompleteFlowDefiniti return cachedFlow, nil } - flow, err := s.store.GetFlowByID(flowID) + flow, err := s.store.GetFlowByID(ctx, flowID) if err != nil || flow == nil { return flow, err } @@ -86,7 +87,7 @@ func (s *cacheBackedFlowStore) GetFlowByID(flowID string) (*CompleteFlowDefiniti } // GetFlowByHandle retrieves a flow definition by handle and flow type, using cache if available. -func (s *cacheBackedFlowStore) GetFlowByHandle(handle string, flowType common.FlowType) ( +func (s *cacheBackedFlowStore) GetFlowByHandle(ctx context.Context, handle string, flowType common.FlowType) ( *CompleteFlowDefinition, error) { cacheKey := getFlowByHandleCacheKey(handle, flowType) cachedFlow, ok := s.flowByHandleCache.Get(cacheKey) @@ -94,7 +95,7 @@ func (s *cacheBackedFlowStore) GetFlowByHandle(handle string, flowType common.Fl return cachedFlow, nil } - flow, err := s.store.GetFlowByHandle(handle, flowType) + flow, err := s.store.GetFlowByHandle(ctx, handle, flowType) if err != nil || flow == nil { return flow, err } @@ -105,9 +106,9 @@ func (s *cacheBackedFlowStore) GetFlowByHandle(handle string, flowType common.Fl } // UpdateFlow updates an existing flow definition and refreshes the cache. -func (s *cacheBackedFlowStore) UpdateFlow(flowID string, flow *FlowDefinition) ( +func (s *cacheBackedFlowStore) UpdateFlow(ctx context.Context, flowID string, flow *FlowDefinition) ( *CompleteFlowDefinition, error) { - updatedFlow, err := s.store.UpdateFlow(flowID, flow) + updatedFlow, err := s.store.UpdateFlow(ctx, flowID, flow) if err != nil { return nil, err } @@ -117,14 +118,14 @@ func (s *cacheBackedFlowStore) UpdateFlow(flowID string, flow *FlowDefinition) ( } // DeleteFlow deletes a flow definition by its ID and invalidates the cache. -func (s *cacheBackedFlowStore) DeleteFlow(flowID string) error { +func (s *cacheBackedFlowStore) DeleteFlow(ctx context.Context, flowID string) error { cacheKey := cache.CacheKey{ Key: flowID, } existingFlow, ok := s.flowByIDCache.Get(cacheKey) if !ok { var err error - existingFlow, err = s.store.GetFlowByID(flowID) + existingFlow, err = s.store.GetFlowByID(ctx, flowID) if err != nil { if errors.Is(err, errFlowNotFound) { return nil @@ -136,7 +137,7 @@ func (s *cacheBackedFlowStore) DeleteFlow(flowID string) error { return nil } - if err := s.store.DeleteFlow(flowID); err != nil { + if err := s.store.DeleteFlow(ctx, flowID); err != nil { return err } s.invalidateFlowCache(flowID) @@ -146,7 +147,7 @@ func (s *cacheBackedFlowStore) DeleteFlow(flowID string) error { } // IsFlowExists checks if a flow exists with a given flow ID, using cache if available. -func (s *cacheBackedFlowStore) IsFlowExists(flowID string) (bool, error) { +func (s *cacheBackedFlowStore) IsFlowExists(ctx context.Context, flowID string) (bool, error) { cacheKey := cache.CacheKey{ Key: flowID, } @@ -155,36 +156,37 @@ func (s *cacheBackedFlowStore) IsFlowExists(flowID string) (bool, error) { return true, nil } - return s.store.IsFlowExists(flowID) + return s.store.IsFlowExists(ctx, flowID) } // IsFlowExistsByHandle checks if a flow exists with a given handle and flow type, using cache if available. -func (s *cacheBackedFlowStore) IsFlowExistsByHandle(handle string, flowType common.FlowType) (bool, error) { +func (s *cacheBackedFlowStore) IsFlowExistsByHandle(ctx context.Context, handle string, flowType common.FlowType) (bool, + error) { cacheKey := getFlowByHandleCacheKey(handle, flowType) cachedFlow, ok := s.flowByHandleCache.Get(cacheKey) if ok && cachedFlow != nil { return true, nil } - return s.store.IsFlowExistsByHandle(handle, flowType) + return s.store.IsFlowExistsByHandle(ctx, handle, flowType) } // ListFlowVersions retrieves all versions of a flow. // Note: Version operations are not cached as they are less frequently accessed. -func (s *cacheBackedFlowStore) ListFlowVersions(flowID string) ([]BasicFlowVersion, error) { - return s.store.ListFlowVersions(flowID) +func (s *cacheBackedFlowStore) ListFlowVersions(ctx context.Context, flowID string) ([]BasicFlowVersion, error) { + return s.store.ListFlowVersions(ctx, flowID) } // GetFlowVersion retrieves a specific version of a flow. // Note: Version operations are not cached as they are less frequently accessed. -func (s *cacheBackedFlowStore) GetFlowVersion(flowID string, version int) (*FlowVersion, error) { - return s.store.GetFlowVersion(flowID, version) +func (s *cacheBackedFlowStore) GetFlowVersion(ctx context.Context, flowID string, version int) (*FlowVersion, error) { + return s.store.GetFlowVersion(ctx, flowID, version) } // RestoreFlowVersion restores a flow to a specific version and invalidates the cache. -func (s *cacheBackedFlowStore) RestoreFlowVersion(flowID string, version int) ( +func (s *cacheBackedFlowStore) RestoreFlowVersion(ctx context.Context, flowID string, version int) ( *CompleteFlowDefinition, error) { - restoredFlow, err := s.store.RestoreFlowVersion(flowID, version) + restoredFlow, err := s.store.RestoreFlowVersion(ctx, flowID, version) if err != nil { return nil, err } diff --git a/backend/internal/flow/mgt/cache_backed_store_test.go b/backend/internal/flow/mgt/cache_backed_store_test.go index 7a3af4298..8eaa64406 100644 --- a/backend/internal/flow/mgt/cache_backed_store_test.go +++ b/backend/internal/flow/mgt/cache_backed_store_test.go @@ -19,6 +19,7 @@ package flowmgt import ( + "context" "errors" "testing" @@ -160,9 +161,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestListFlows() { }, } - s.mockStore.EXPECT().ListFlows(10, 0, "").Return(flows, 2, nil) + s.mockStore.EXPECT().ListFlows(mock.Anything, 10, 0, "").Return(flows, 2, nil) - result, count, err := s.cachedStore.ListFlows(10, 0, "") + result, count, err := s.cachedStore.ListFlows(context.Background(), 10, 0, "") s.NoError(err) s.Len(result, 2) @@ -171,9 +172,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestListFlows() { } func (s *CacheBackedFlowStoreTestSuite) TestListFlowsError() { - s.mockStore.EXPECT().ListFlows(10, 0, "").Return(nil, 0, errors.New("list error")) + s.mockStore.EXPECT().ListFlows(mock.Anything, 10, 0, "").Return(nil, 0, errors.New("list error")) - result, count, err := s.cachedStore.ListFlows(10, 0, "") + result, count, err := s.cachedStore.ListFlows(context.Background(), 10, 0, "") s.Error(err) s.Nil(result) @@ -191,9 +192,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestCreateFlowSuccess() { } expected := s.createTestFlow() - s.mockStore.EXPECT().CreateFlow("flow-1", flowDef).Return(expected, nil) + s.mockStore.EXPECT().CreateFlow(mock.Anything, "flow-1", flowDef).Return(expected, nil) - result, err := s.cachedStore.CreateFlow("flow-1", flowDef) + result, err := s.cachedStore.CreateFlow(context.Background(), "flow-1", flowDef) s.NoError(err) s.NotNil(result) @@ -212,9 +213,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestCreateFlowError() { Nodes: []NodeDefinition{{ID: "node-1", Type: "basic-auth"}}, } - s.mockStore.EXPECT().CreateFlow("flow-1", flowDef).Return(nil, errors.New("create error")) + s.mockStore.EXPECT().CreateFlow(mock.Anything, "flow-1", flowDef).Return(nil, errors.New("create error")) - result, err := s.cachedStore.CreateFlow("flow-1", flowDef) + result, err := s.cachedStore.CreateFlow(context.Background(), "flow-1", flowDef) s.Error(err) s.Nil(result) @@ -227,7 +228,7 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByIDFromCache() { expected := s.createTestFlow() s.cacheData["flow-1"] = expected - result, err := s.cachedStore.GetFlowByID("flow-1") + result, err := s.cachedStore.GetFlowByID(context.Background(), "flow-1") s.NoError(err) s.NotNil(result) @@ -236,9 +237,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByIDFromCache() { func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByIDFromStoreAndCache() { expected := s.createTestFlow() - s.mockStore.EXPECT().GetFlowByID("flow-1").Return(expected, nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, "flow-1").Return(expected, nil) - result, err := s.cachedStore.GetFlowByID("flow-1") + result, err := s.cachedStore.GetFlowByID(context.Background(), "flow-1") s.NoError(err) s.NotNil(result) @@ -250,9 +251,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByIDFromStoreAndCache() { } func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByIDNotFound() { - s.mockStore.EXPECT().GetFlowByID("nonexistent").Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, "nonexistent").Return(nil, errFlowNotFound) - result, err := s.cachedStore.GetFlowByID("nonexistent") + result, err := s.cachedStore.GetFlowByID(context.Background(), "nonexistent") s.Error(err) s.Nil(result) @@ -263,9 +264,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByIDNotFound() { } func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByIDNilFlow() { - s.mockStore.EXPECT().GetFlowByID("flow-1").Return(nil, nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, "flow-1").Return(nil, nil) - result, err := s.cachedStore.GetFlowByID("flow-1") + result, err := s.cachedStore.GetFlowByID(context.Background(), "flow-1") s.NoError(err) s.Nil(result) @@ -280,7 +281,7 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByHandleFromCache() { flow := s.createTestFlow() s.handleCacheData[testAuthenticationHandleCacheKey] = flow - result, err := s.cachedStore.GetFlowByHandle("test-handle", common.FlowTypeAuthentication) + result, err := s.cachedStore.GetFlowByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.NoError(err) s.NotNil(result) @@ -290,9 +291,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByHandleFromCache() { func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByHandleFromStoreAndCache() { flow := s.createTestFlow() - s.mockStore.EXPECT().GetFlowByHandle("test-handle", common.FlowTypeAuthentication).Return(flow, nil) + s.mockStore.EXPECT().GetFlowByHandle(mock.Anything, "test-handle", common.FlowTypeAuthentication).Return(flow, nil) - result, err := s.cachedStore.GetFlowByHandle("test-handle", common.FlowTypeAuthentication) + result, err := s.cachedStore.GetFlowByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.NoError(err) s.NotNil(result) @@ -305,10 +306,10 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByHandleFromStoreAndCache() { } func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByHandleNotFound() { - s.mockStore.EXPECT().GetFlowByHandle("non-existent", common.FlowTypeAuthentication). + s.mockStore.EXPECT().GetFlowByHandle(mock.Anything, "non-existent", common.FlowTypeAuthentication). Return(nil, errFlowNotFound) - result, err := s.cachedStore.GetFlowByHandle("non-existent", common.FlowTypeAuthentication) + result, err := s.cachedStore.GetFlowByHandle(context.Background(), "non-existent", common.FlowTypeAuthentication) s.Error(err) s.ErrorIs(err, errFlowNotFound) @@ -320,9 +321,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByHandleNotFound() { } func (s *CacheBackedFlowStoreTestSuite) TestGetFlowByHandleNilFlow() { - s.mockStore.EXPECT().GetFlowByHandle("test-handle", common.FlowTypeAuthentication).Return(nil, nil) + s.mockStore.EXPECT().GetFlowByHandle(mock.Anything, "test-handle", common.FlowTypeAuthentication).Return(nil, nil) - result, err := s.cachedStore.GetFlowByHandle("test-handle", common.FlowTypeAuthentication) + result, err := s.cachedStore.GetFlowByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.NoError(err) s.Nil(result) @@ -343,9 +344,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestUpdateFlowSuccess() { updated.Name = "Updated Flow" updated.ActiveVersion = 2 - s.mockStore.EXPECT().UpdateFlow("flow-1", flowDef).Return(updated, nil) + s.mockStore.EXPECT().UpdateFlow(mock.Anything, "flow-1", flowDef).Return(updated, nil) - result, err := s.cachedStore.UpdateFlow("flow-1", flowDef) + result, err := s.cachedStore.UpdateFlow(context.Background(), "flow-1", flowDef) s.NoError(err) s.NotNil(result) @@ -365,9 +366,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestUpdateFlowError() { Nodes: []NodeDefinition{{ID: "node-1", Type: "basic-auth"}}, } - s.mockStore.EXPECT().UpdateFlow("flow-1", flowDef).Return(nil, errors.New("update error")) + s.mockStore.EXPECT().UpdateFlow(mock.Anything, "flow-1", flowDef).Return(nil, errors.New("update error")) - result, err := s.cachedStore.UpdateFlow("flow-1", flowDef) + result, err := s.cachedStore.UpdateFlow(context.Background(), "flow-1", flowDef) s.Error(err) s.Nil(result) @@ -377,9 +378,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestDeleteFlowFromCache() { flow := s.createTestFlow() s.cacheData["flow-1"] = flow - s.mockStore.EXPECT().DeleteFlow("flow-1").Return(nil) + s.mockStore.EXPECT().DeleteFlow(mock.Anything, "flow-1").Return(nil) - err := s.cachedStore.DeleteFlow("flow-1") + err := s.cachedStore.DeleteFlow(context.Background(), "flow-1") s.NoError(err) @@ -390,10 +391,10 @@ func (s *CacheBackedFlowStoreTestSuite) TestDeleteFlowFromCache() { func (s *CacheBackedFlowStoreTestSuite) TestDeleteFlowFromStore() { flow := s.createTestFlow() - s.mockStore.EXPECT().GetFlowByID("flow-1").Return(flow, nil) - s.mockStore.EXPECT().DeleteFlow("flow-1").Return(nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, "flow-1").Return(flow, nil) + s.mockStore.EXPECT().DeleteFlow(mock.Anything, "flow-1").Return(nil) - err := s.cachedStore.DeleteFlow("flow-1") + err := s.cachedStore.DeleteFlow(context.Background(), "flow-1") s.NoError(err) @@ -402,17 +403,17 @@ func (s *CacheBackedFlowStoreTestSuite) TestDeleteFlowFromStore() { } func (s *CacheBackedFlowStoreTestSuite) TestDeleteFlowNotFound() { - s.mockStore.EXPECT().GetFlowByID("nonexistent").Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, "nonexistent").Return(nil, errFlowNotFound) - err := s.cachedStore.DeleteFlow("nonexistent") + err := s.cachedStore.DeleteFlow(context.Background(), "nonexistent") s.NoError(err) } func (s *CacheBackedFlowStoreTestSuite) TestDeleteFlowGetError() { - s.mockStore.EXPECT().GetFlowByID("flow-1").Return(nil, errors.New("get error")) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, "flow-1").Return(nil, errors.New("get error")) - err := s.cachedStore.DeleteFlow("flow-1") + err := s.cachedStore.DeleteFlow(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "get error") @@ -422,18 +423,18 @@ func (s *CacheBackedFlowStoreTestSuite) TestDeleteFlowDeleteError() { flow := s.createTestFlow() s.cacheData["flow-1"] = flow - s.mockStore.EXPECT().DeleteFlow("flow-1").Return(errors.New("delete error")) + s.mockStore.EXPECT().DeleteFlow(mock.Anything, "flow-1").Return(errors.New("delete error")) - err := s.cachedStore.DeleteFlow("flow-1") + err := s.cachedStore.DeleteFlow(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "delete error") } func (s *CacheBackedFlowStoreTestSuite) TestDeleteFlowNilFromStore() { - s.mockStore.EXPECT().GetFlowByID("flow-1").Return(nil, nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, "flow-1").Return(nil, nil) - err := s.cachedStore.DeleteFlow("flow-1") + err := s.cachedStore.DeleteFlow(context.Background(), "flow-1") s.NoError(err) } @@ -445,7 +446,7 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsFromCache() { flow := s.createTestFlow() s.cacheData["flow-1"] = flow - exists, err := s.cachedStore.IsFlowExists("flow-1") + exists, err := s.cachedStore.IsFlowExists(context.Background(), "flow-1") s.NoError(err) s.True(exists) @@ -455,9 +456,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsFromCache() { func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsFromStore() { // Not in cache, should query store - s.mockStore.EXPECT().IsFlowExists("flow-2").Return(true, nil) + s.mockStore.EXPECT().IsFlowExists(mock.Anything, "flow-2").Return(true, nil) - exists, err := s.cachedStore.IsFlowExists("flow-2") + exists, err := s.cachedStore.IsFlowExists(context.Background(), "flow-2") s.NoError(err) s.True(exists) @@ -465,9 +466,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsFromStore() { func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsNotFound() { // Not in cache, store returns false - s.mockStore.EXPECT().IsFlowExists("non-existent").Return(false, nil) + s.mockStore.EXPECT().IsFlowExists(mock.Anything, "non-existent").Return(false, nil) - exists, err := s.cachedStore.IsFlowExists("non-existent") + exists, err := s.cachedStore.IsFlowExists(context.Background(), "non-existent") s.NoError(err) s.False(exists) @@ -476,9 +477,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsNotFound() { func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsCacheNil() { // Nil value in cache should query store s.cacheData["flow-3"] = nil - s.mockStore.EXPECT().IsFlowExists("flow-3").Return(true, nil) + s.mockStore.EXPECT().IsFlowExists(mock.Anything, "flow-3").Return(true, nil) - exists, err := s.cachedStore.IsFlowExists("flow-3") + exists, err := s.cachedStore.IsFlowExists(context.Background(), "flow-3") s.NoError(err) s.True(exists) @@ -486,9 +487,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsCacheNil() { func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsStoreError() { // Not in cache, store returns error - s.mockStore.EXPECT().IsFlowExists("flow-error").Return(false, errors.New("db connection error")) + s.mockStore.EXPECT().IsFlowExists(mock.Anything, "flow-error").Return(false, errors.New("db connection error")) - exists, err := s.cachedStore.IsFlowExists("flow-error") + exists, err := s.cachedStore.IsFlowExists(context.Background(), "flow-error") s.Error(err) s.Contains(err.Error(), "db connection error") @@ -505,7 +506,8 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsByHandleFromCache() { FlowType: common.FlowTypeAuthentication, } - exists, err := s.cachedStore.IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication) + exists, err := s.cachedStore.IsFlowExistsByHandle(context.Background(), "test-handle", + common.FlowTypeAuthentication) s.NoError(err) s.True(exists) @@ -515,9 +517,10 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsByHandleFromCache() { func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsByHandleFromStore() { // Not in cache, should query store - s.mockStore.EXPECT().IsFlowExistsByHandle("new-handle", common.FlowTypeAuthentication).Return(true, nil) + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "new-handle", common.FlowTypeAuthentication). + Return(true, nil) - exists, err := s.cachedStore.IsFlowExistsByHandle("new-handle", common.FlowTypeAuthentication) + exists, err := s.cachedStore.IsFlowExistsByHandle(context.Background(), "new-handle", common.FlowTypeAuthentication) s.NoError(err) s.True(exists) @@ -525,9 +528,10 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsByHandleFromStore() { func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsByHandleNotFound() { // Not in cache, store returns false - s.mockStore.EXPECT().IsFlowExistsByHandle("non-existent", common.FlowTypeRegistration).Return(false, nil) + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "non-existent", common.FlowTypeRegistration). + Return(false, nil) - exists, err := s.cachedStore.IsFlowExistsByHandle("non-existent", common.FlowTypeRegistration) + exists, err := s.cachedStore.IsFlowExistsByHandle(context.Background(), "non-existent", common.FlowTypeRegistration) s.NoError(err) s.False(exists) @@ -535,10 +539,11 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsByHandleNotFound() { func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsByHandleStoreError() { // Not in cache, store returns error - s.mockStore.EXPECT().IsFlowExistsByHandle("error-handle", common.FlowTypeAuthentication). + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "error-handle", common.FlowTypeAuthentication). Return(false, errors.New("db connection error")) - exists, err := s.cachedStore.IsFlowExistsByHandle("error-handle", common.FlowTypeAuthentication) + exists, err := s.cachedStore.IsFlowExistsByHandle(context.Background(), "error-handle", + common.FlowTypeAuthentication) s.Error(err) s.Contains(err.Error(), "db connection error") @@ -555,15 +560,18 @@ func (s *CacheBackedFlowStoreTestSuite) TestIsFlowExistsByHandleCompositeKey() { // Cache the auth flow s.handleCacheData["common-handle:AUTHENTICATION"] = authFlow // Registration not in cache, should query store - s.mockStore.EXPECT().IsFlowExistsByHandle("common-handle", common.FlowTypeRegistration).Return(false, nil) + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "common-handle", common.FlowTypeRegistration). + Return(false, nil) // First call - authentication exists (from cache) - exists1, err1 := s.cachedStore.IsFlowExistsByHandle("common-handle", common.FlowTypeAuthentication) + exists1, err1 := s.cachedStore.IsFlowExistsByHandle(context.Background(), "common-handle", + common.FlowTypeAuthentication) s.NoError(err1) s.True(exists1) // Second call - registration doesn't exist - exists2, err2 := s.cachedStore.IsFlowExistsByHandle("common-handle", common.FlowTypeRegistration) + exists2, err2 := s.cachedStore.IsFlowExistsByHandle(context.Background(), "common-handle", + common.FlowTypeRegistration) s.NoError(err2) s.False(exists2) @@ -583,9 +591,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestListFlowVersions() { {Version: 1, CreatedAt: "2025-01-01T00:00:00Z", IsActive: false}, } - s.mockStore.EXPECT().ListFlowVersions("flow-1").Return(versions, nil) + s.mockStore.EXPECT().ListFlowVersions(mock.Anything, "flow-1").Return(versions, nil) - result, err := s.cachedStore.ListFlowVersions("flow-1") + result, err := s.cachedStore.ListFlowVersions(context.Background(), "flow-1") s.NoError(err) s.Len(result, 3) @@ -594,9 +602,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestListFlowVersions() { } func (s *CacheBackedFlowStoreTestSuite) TestListFlowVersionsError() { - s.mockStore.EXPECT().ListFlowVersions("flow-1").Return(nil, errors.New("list versions error")) + s.mockStore.EXPECT().ListFlowVersions(mock.Anything, "flow-1").Return(nil, errors.New("list versions error")) - result, err := s.cachedStore.ListFlowVersions("flow-1") + result, err := s.cachedStore.ListFlowVersions(context.Background(), "flow-1") s.Error(err) s.Nil(result) @@ -614,9 +622,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowVersion() { CreatedAt: "2025-01-02T00:00:00Z", } - s.mockStore.EXPECT().GetFlowVersion("flow-1", 2).Return(version, nil) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, "flow-1", 2).Return(version, nil) - result, err := s.cachedStore.GetFlowVersion("flow-1", 2) + result, err := s.cachedStore.GetFlowVersion(context.Background(), "flow-1", 2) s.NoError(err) s.NotNil(result) @@ -625,9 +633,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestGetFlowVersion() { } func (s *CacheBackedFlowStoreTestSuite) TestGetFlowVersionError() { - s.mockStore.EXPECT().GetFlowVersion("flow-1", 999).Return(nil, errVersionNotFound) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, "flow-1", 999).Return(nil, errVersionNotFound) - result, err := s.cachedStore.GetFlowVersion("flow-1", 999) + result, err := s.cachedStore.GetFlowVersion(context.Background(), "flow-1", 999) s.Error(err) s.Nil(result) @@ -637,9 +645,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestRestoreFlowVersionSuccess() { restored := s.createTestFlow() restored.ActiveVersion = 4 - s.mockStore.EXPECT().RestoreFlowVersion("flow-1", 1).Return(restored, nil) + s.mockStore.EXPECT().RestoreFlowVersion(mock.Anything, "flow-1", 1).Return(restored, nil) - result, err := s.cachedStore.RestoreFlowVersion("flow-1", 1) + result, err := s.cachedStore.RestoreFlowVersion(context.Background(), "flow-1", 1) s.NoError(err) s.NotNil(result) @@ -651,9 +659,9 @@ func (s *CacheBackedFlowStoreTestSuite) TestRestoreFlowVersionSuccess() { } func (s *CacheBackedFlowStoreTestSuite) TestRestoreFlowVersionError() { - s.mockStore.EXPECT().RestoreFlowVersion("flow-1", 1).Return(nil, errors.New("restore error")) + s.mockStore.EXPECT().RestoreFlowVersion(mock.Anything, "flow-1", 1).Return(nil, errors.New("restore error")) - result, err := s.cachedStore.RestoreFlowVersion("flow-1", 1) + result, err := s.cachedStore.RestoreFlowVersion(context.Background(), "flow-1", 1) s.Error(err) s.Nil(result) diff --git a/backend/internal/flow/mgt/declarative_resource.go b/backend/internal/flow/mgt/declarative_resource.go index 2d8bd53c9..bd9adb9a9 100644 --- a/backend/internal/flow/mgt/declarative_resource.go +++ b/backend/internal/flow/mgt/declarative_resource.go @@ -19,6 +19,7 @@ package flowmgt import ( + "context" "fmt" "github.com/asgardeo/thunder/internal/flow/common" @@ -61,7 +62,7 @@ func (e *FlowGraphExporter) GetParameterizerType() string { // GetAllResourceIDs retrieves all flow graph IDs. func (e *FlowGraphExporter) GetAllResourceIDs() ([]string, *serviceerror.ServiceError) { - flows, err := e.service.ListFlows(10000, 0, common.FlowType("")) + flows, err := e.service.ListFlows(context.Background(), 10000, 0, common.FlowType("")) if err != nil { return nil, err } @@ -74,7 +75,7 @@ func (e *FlowGraphExporter) GetAllResourceIDs() ([]string, *serviceerror.Service // GetResourceByID retrieves a flow graph by its ID. func (e *FlowGraphExporter) GetResourceByID(id string) (interface{}, string, *serviceerror.ServiceError) { - flow, err := e.service.GetFlow(id) + flow, err := e.service.GetFlow(context.Background(), id) if err != nil { return nil, "", err } diff --git a/backend/internal/flow/mgt/declarative_resource_test.go b/backend/internal/flow/mgt/declarative_resource_test.go index af0da5971..be42d019c 100644 --- a/backend/internal/flow/mgt/declarative_resource_test.go +++ b/backend/internal/flow/mgt/declarative_resource_test.go @@ -19,10 +19,12 @@ package flowmgt import ( + "context" "fmt" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "gopkg.in/yaml.v3" @@ -182,7 +184,7 @@ func (s *DeclarativeResourceTestSuite) TestFlowGraphExporter_GetAllResourceIDs() } // Use common.FlowType to match the service interface type - mockService.EXPECT().ListFlows(10000, 0, common.FlowType("")).Return(listResponse, nil) + mockService.EXPECT().ListFlows(mock.Anything, 10000, 0, common.FlowType("")).Return(listResponse, nil) exporter := newFlowGraphExporter(mockService) ids, err := exporter.GetAllResourceIDs() @@ -202,7 +204,7 @@ func (s *DeclarativeResourceTestSuite) TestFlowGraphExporter_GetAllResourceIDs_E Error: "test error", } - mockService.EXPECT().ListFlows(10000, 0, common.FlowType("")).Return(nil, expectedError) + mockService.EXPECT().ListFlows(mock.Anything, 10000, 0, common.FlowType("")).Return(nil, expectedError) exporter := newFlowGraphExporter(mockService) ids, err := exporter.GetAllResourceIDs() @@ -220,7 +222,7 @@ func (s *DeclarativeResourceTestSuite) TestFlowGraphExporter_GetAllResourceIDs_E Count: 0, } - mockService.EXPECT().ListFlows(10000, 0, common.FlowType("")).Return(listResponse, nil) + mockService.EXPECT().ListFlows(mock.Anything, 10000, 0, common.FlowType("")).Return(listResponse, nil) exporter := newFlowGraphExporter(mockService) ids, err := exporter.GetAllResourceIDs() @@ -238,7 +240,7 @@ func (s *DeclarativeResourceTestSuite) TestFlowGraphExporter_GetResourceByID() { Name: "Auth Flow", } - mockService.EXPECT().GetFlow("flow-001").Return(flow, nil) + mockService.EXPECT().GetFlow(mock.Anything, "flow-001").Return(flow, nil) exporter := newFlowGraphExporter(mockService) resource, name, err := exporter.GetResourceByID("flow-001") @@ -257,7 +259,7 @@ func (s *DeclarativeResourceTestSuite) TestFlowGraphExporter_GetResourceByID_Err Error: "test error", } - mockService.EXPECT().GetFlow("flow-001").Return(nil, expectedError) + mockService.EXPECT().GetFlow(mock.Anything, "flow-001").Return(nil, expectedError) exporter := newFlowGraphExporter(mockService) resource, name, err := exporter.GetResourceByID("flow-001") @@ -335,7 +337,7 @@ func (s *DeclarativeResourceTestSuite) TestFileBasedStore_CreateFlow() { }, } - completeFlow, err := store.CreateFlow("flow-001", flowDef) + completeFlow, err := store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) assert.Equal(s.T(), "flow-001", completeFlow.ID) @@ -359,10 +361,10 @@ func (s *DeclarativeResourceTestSuite) TestFileBasedStore_GetFlowByID() { }, } - _, err := store.CreateFlow("flow-001", flowDef) + _, err := store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - retrieved, err := store.GetFlowByID("flow-001") + retrieved, err := store.GetFlowByID(context.Background(), "flow-001") require.NoError(s.T(), err) assert.Equal(s.T(), "flow-001", retrieved.ID) @@ -374,7 +376,7 @@ func (s *DeclarativeResourceTestSuite) TestFileBasedStore_GetFlowByID_NotFound() _ = entity.GetInstance().Clear() store := newFileBasedStore() - _, err := store.GetFlowByID("non-existent") + _, err := store.GetFlowByID(context.Background(), "non-existent") assert.Error(s.T(), err) } @@ -394,10 +396,10 @@ func (s *DeclarativeResourceTestSuite) TestFileBasedStore_GetFlowByHandle() { }, } - _, err := store.CreateFlow("flow-001", flowDef) + _, err := store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - retrieved, err := store.GetFlowByHandle("test-flow", "AUTHENTICATION") + retrieved, err := store.GetFlowByHandle(context.Background(), "test-flow", "AUTHENTICATION") require.NoError(s.T(), err) assert.Equal(s.T(), "flow-001", retrieved.ID) @@ -420,11 +422,11 @@ func (s *DeclarativeResourceTestSuite) TestFileBasedStore_ListFlows() { {ID: "end", Type: "END"}, }, } - _, err := store.CreateFlow(fmt.Sprintf("flow-%03d", i), flowDef) + _, err := store.CreateFlow(context.Background(), fmt.Sprintf("flow-%03d", i), flowDef) require.NoError(s.T(), err) } - flows, count, err := store.ListFlows(10, 0, "") + flows, count, err := store.ListFlows(context.Background(), 10, 0, "") require.NoError(s.T(), err) assert.Equal(s.T(), 3, count) @@ -447,14 +449,14 @@ func (s *DeclarativeResourceTestSuite) TestFileBasedStore_IsFlowExists() { }, } - _, err := store.CreateFlow("flow-001", flowDef) + _, err := store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - exists, err := store.IsFlowExists("flow-001") + exists, err := store.IsFlowExists(context.Background(), "flow-001") require.NoError(s.T(), err) assert.True(s.T(), exists) - exists, err = store.IsFlowExists("non-existent") + exists, err = store.IsFlowExists(context.Background(), "non-existent") require.NoError(s.T(), err) assert.False(s.T(), exists) } @@ -475,27 +477,27 @@ func (s *DeclarativeResourceTestSuite) TestFileBasedStore_UnsupportedOperations( }, } - _, err := store.CreateFlow("flow-001", flowDef) + _, err := store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) // UpdateFlow - _, err = store.UpdateFlow("flow-001", flowDef) + _, err = store.UpdateFlow(context.Background(), "flow-001", flowDef) assert.Error(s.T(), err) // DeleteFlow - err = store.DeleteFlow("flow-001") + err = store.DeleteFlow(context.Background(), "flow-001") assert.Error(s.T(), err) // ListFlowVersions - _, err = store.ListFlowVersions("flow-001") + _, err = store.ListFlowVersions(context.Background(), "flow-001") assert.Error(s.T(), err) // GetFlowVersion - _, err = store.GetFlowVersion("flow-001", 1) + _, err = store.GetFlowVersion(context.Background(), "flow-001", 1) assert.Error(s.T(), err) // RestoreFlowVersion - _, err = store.RestoreFlowVersion("flow-001", 1) + _, err = store.RestoreFlowVersion(context.Background(), "flow-001", 1) assert.Error(s.T(), err) } @@ -556,8 +558,8 @@ func (s *DeclarativeResourceTestSuite) TestFlowGraphExporterIntegration() { Count: 1, } - mockService.EXPECT().ListFlows(10000, 0, common.FlowType("")).Return(listResponse, nil) - mockService.EXPECT().GetFlow("flow-001").Return(flow, nil) + mockService.EXPECT().ListFlows(mock.Anything, 10000, 0, common.FlowType("")).Return(listResponse, nil) + mockService.EXPECT().GetFlow(mock.Anything, "flow-001").Return(flow, nil) exporter := newFlowGraphExporter(mockService) @@ -939,7 +941,7 @@ func (s *DeclarativeResourceTestSuite) TestFlowExport_WithComplexMeta() { }, } - mockService.EXPECT().GetFlow("test-flow-001").Return(complexFlow, nil) + mockService.EXPECT().GetFlow(mock.Anything, "test-flow-001").Return(complexFlow, nil) exporter := newFlowGraphExporter(mockService) resource, name, err := exporter.GetResourceByID("test-flow-001") diff --git a/backend/internal/flow/mgt/file_based_store.go b/backend/internal/flow/mgt/file_based_store.go index 667f0c7db..396bab4b5 100644 --- a/backend/internal/flow/mgt/file_based_store.go +++ b/backend/internal/flow/mgt/file_based_store.go @@ -19,6 +19,7 @@ package flowmgt import ( + "context" "errors" "github.com/asgardeo/thunder/internal/flow/common" @@ -37,7 +38,7 @@ func (f *fileBasedStore) Create(id string, data interface{}) error { declarativeresource.LogTypeAssertionError("flow", id) return errors.New("invalid flow data type") } - _, err := f.CreateFlow(flow.ID, &FlowDefinition{ + _, err := f.CreateFlow(context.Background(), flow.ID, &FlowDefinition{ Handle: flow.Handle, Name: flow.Name, FlowType: flow.FlowType, @@ -47,7 +48,8 @@ func (f *fileBasedStore) Create(id string, data interface{}) error { } // CreateFlow implements flowStoreInterface. -func (f *fileBasedStore) CreateFlow(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) { +func (f *fileBasedStore) CreateFlow(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, + error) { completeFlow := &CompleteFlowDefinition{ ID: flowID, Handle: flow.Handle, @@ -62,7 +64,8 @@ func (f *fileBasedStore) CreateFlow(flowID string, flow *FlowDefinition) (*Compl } // ListFlows implements flowStoreInterface. -func (f *fileBasedStore) ListFlows(limit, offset int, flowType string) ([]BasicFlowDefinition, int, error) { +func (f *fileBasedStore) ListFlows(ctx context.Context, limit, offset int, flowType string) ([]BasicFlowDefinition, int, + error) { list, err := f.GenericFileBasedStore.List() if err != nil { return nil, 0, err @@ -102,7 +105,7 @@ func (f *fileBasedStore) ListFlows(limit, offset int, flowType string) ([]BasicF } // GetFlowByID implements flowStoreInterface. -func (f *fileBasedStore) GetFlowByID(flowID string) (*CompleteFlowDefinition, error) { +func (f *fileBasedStore) GetFlowByID(ctx context.Context, flowID string) (*CompleteFlowDefinition, error) { data, err := f.GenericFileBasedStore.Get(flowID) if err != nil { return nil, errFlowNotFound @@ -116,7 +119,8 @@ func (f *fileBasedStore) GetFlowByID(flowID string) (*CompleteFlowDefinition, er } // GetFlowByHandle implements flowStoreInterface. -func (f *fileBasedStore) GetFlowByHandle(handle string, flowType common.FlowType) (*CompleteFlowDefinition, error) { +func (f *fileBasedStore) GetFlowByHandle(ctx context.Context, handle string, + flowType common.FlowType) (*CompleteFlowDefinition, error) { data, err := f.GenericFileBasedStore.GetByField(handle, func(d interface{}) string { if flow, ok := d.(*CompleteFlowDefinition); ok && flow.FlowType == flowType { return flow.Handle @@ -135,32 +139,34 @@ func (f *fileBasedStore) GetFlowByHandle(handle string, flowType common.FlowType } // UpdateFlow implements flowStoreInterface. -func (f *fileBasedStore) UpdateFlow(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) { +func (f *fileBasedStore) UpdateFlow(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, + error) { return nil, errors.New("UpdateFlow is not supported in file-based store") } // DeleteFlow implements flowStoreInterface. -func (f *fileBasedStore) DeleteFlow(flowID string) error { +func (f *fileBasedStore) DeleteFlow(ctx context.Context, flowID string) error { return errors.New("DeleteFlow is not supported in file-based store") } // ListFlowVersions implements flowStoreInterface. -func (f *fileBasedStore) ListFlowVersions(flowID string) ([]BasicFlowVersion, error) { +func (f *fileBasedStore) ListFlowVersions(ctx context.Context, flowID string) ([]BasicFlowVersion, error) { return nil, errors.New("ListFlowVersions is not supported in file-based store") } // GetFlowVersion implements flowStoreInterface. -func (f *fileBasedStore) GetFlowVersion(flowID string, version int) (*FlowVersion, error) { +func (f *fileBasedStore) GetFlowVersion(ctx context.Context, flowID string, version int) (*FlowVersion, error) { return nil, errors.New("GetFlowVersion is not supported in file-based store") } // RestoreFlowVersion implements flowStoreInterface. -func (f *fileBasedStore) RestoreFlowVersion(flowID string, version int) (*CompleteFlowDefinition, error) { +func (f *fileBasedStore) RestoreFlowVersion(ctx context.Context, flowID string, version int) (*CompleteFlowDefinition, + error) { return nil, errors.New("RestoreFlowVersion is not supported in file-based store") } // IsFlowExists implements flowStoreInterface. -func (f *fileBasedStore) IsFlowExists(flowID string) (bool, error) { +func (f *fileBasedStore) IsFlowExists(ctx context.Context, flowID string) (bool, error) { _, err := f.GenericFileBasedStore.Get(flowID) if err != nil { return false, nil @@ -169,7 +175,8 @@ func (f *fileBasedStore) IsFlowExists(flowID string) (bool, error) { } // IsFlowExistsByHandle implements flowStoreInterface. -func (f *fileBasedStore) IsFlowExistsByHandle(handle string, flowType common.FlowType) (bool, error) { +func (f *fileBasedStore) IsFlowExistsByHandle(ctx context.Context, handle string, flowType common.FlowType) (bool, + error) { list, err := f.GenericFileBasedStore.List() if err != nil { return false, err diff --git a/backend/internal/flow/mgt/file_based_store_test.go b/backend/internal/flow/mgt/file_based_store_test.go index 35986a05a..420302f8d 100644 --- a/backend/internal/flow/mgt/file_based_store_test.go +++ b/backend/internal/flow/mgt/file_based_store_test.go @@ -19,6 +19,7 @@ package flowmgt import ( + "context" "fmt" "testing" @@ -59,7 +60,7 @@ func (s *FileBasedStoreTestSuite) createTestFlow(handle string) *FlowDefinition func (s *FileBasedStoreTestSuite) TestCreateFlow_Success() { flowDef := s.createTestFlow("test-flow") - completeFlow, err := s.store.CreateFlow("flow-001", flowDef) + completeFlow, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) assert.Equal(s.T(), "flow-001", completeFlow.ID) @@ -72,10 +73,10 @@ func (s *FileBasedStoreTestSuite) TestCreateFlow_Success() { func (s *FileBasedStoreTestSuite) TestGetFlowByID_Success() { flowDef := s.createTestFlow("test-flow") - _, err := s.store.CreateFlow("flow-001", flowDef) + _, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - retrieved, err := s.store.GetFlowByID("flow-001") + retrieved, err := s.store.GetFlowByID(context.Background(), "flow-001") require.NoError(s.T(), err) assert.Equal(s.T(), "flow-001", retrieved.ID) @@ -84,7 +85,7 @@ func (s *FileBasedStoreTestSuite) TestGetFlowByID_Success() { } func (s *FileBasedStoreTestSuite) TestGetFlowByID_NotFound() { - _, err := s.store.GetFlowByID("non-existent") + _, err := s.store.GetFlowByID(context.Background(), "non-existent") assert.Error(s.T(), err) assert.Equal(s.T(), errFlowNotFound, err) @@ -92,10 +93,10 @@ func (s *FileBasedStoreTestSuite) TestGetFlowByID_NotFound() { func (s *FileBasedStoreTestSuite) TestGetFlowByHandle_Success() { flowDef := s.createTestFlow("test-flow") - _, err := s.store.CreateFlow("flow-001", flowDef) + _, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - retrieved, err := s.store.GetFlowByHandle("test-flow", testFlowTypeAuthentication) + retrieved, err := s.store.GetFlowByHandle(context.Background(), "test-flow", testFlowTypeAuthentication) require.NoError(s.T(), err) assert.Equal(s.T(), "flow-001", retrieved.ID) @@ -103,7 +104,7 @@ func (s *FileBasedStoreTestSuite) TestGetFlowByHandle_Success() { } func (s *FileBasedStoreTestSuite) TestGetFlowByHandle_NotFound() { - _, err := s.store.GetFlowByHandle("non-existent", testFlowTypeAuthentication) + _, err := s.store.GetFlowByHandle(context.Background(), "non-existent", testFlowTypeAuthentication) assert.Error(s.T(), err) assert.Equal(s.T(), errFlowNotFound, err) @@ -112,10 +113,10 @@ func (s *FileBasedStoreTestSuite) TestGetFlowByHandle_NotFound() { func (s *FileBasedStoreTestSuite) TestGetFlowByHandle_WrongFlowType() { flowDef := s.createTestFlow("test-flow") flowDef.FlowType = testFlowTypeAuthentication - _, err := s.store.CreateFlow("flow-001", flowDef) + _, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - _, err = s.store.GetFlowByHandle("test-flow", "REGISTRATION") + _, err = s.store.GetFlowByHandle(context.Background(), "test-flow", "REGISTRATION") assert.Error(s.T(), err) } @@ -124,11 +125,11 @@ func (s *FileBasedStoreTestSuite) TestListFlows_NoFilter() { // Create multiple flows for i := 0; i < 3; i++ { flowDef := s.createTestFlow(fmt.Sprintf("flow-%d", i)) - _, err := s.store.CreateFlow(fmt.Sprintf("flow-00%d", i), flowDef) + _, err := s.store.CreateFlow(context.Background(), fmt.Sprintf("flow-00%d", i), flowDef) require.NoError(s.T(), err) } - flows, count, err := s.store.ListFlows(10, 0, "") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "") require.NoError(s.T(), err) assert.Equal(s.T(), 3, count) @@ -139,16 +140,16 @@ func (s *FileBasedStoreTestSuite) TestListFlows_WithFlowTypeFilter() { // Create flows with different types authFlow := s.createTestFlow("auth-flow") authFlow.FlowType = testFlowTypeAuthentication - _, err := s.store.CreateFlow("flow-001", authFlow) + _, err := s.store.CreateFlow(context.Background(), "flow-001", authFlow) require.NoError(s.T(), err) regFlow := s.createTestFlow("reg-flow") regFlow.FlowType = "REGISTRATION" - _, err = s.store.CreateFlow("flow-002", regFlow) + _, err = s.store.CreateFlow(context.Background(), "flow-002", regFlow) require.NoError(s.T(), err) // List only AUTHENTICATION flows - flows, count, err := s.store.ListFlows(10, 0, testFlowTypeAuthentication) + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, testFlowTypeAuthentication) require.NoError(s.T(), err) assert.Equal(s.T(), 1, count) @@ -160,24 +161,24 @@ func (s *FileBasedStoreTestSuite) TestListFlows_Pagination() { // Create 5 flows for i := 0; i < 5; i++ { flowDef := s.createTestFlow(fmt.Sprintf("flow-%d", i)) - _, err := s.store.CreateFlow(fmt.Sprintf("flow-00%d", i), flowDef) + _, err := s.store.CreateFlow(context.Background(), fmt.Sprintf("flow-00%d", i), flowDef) require.NoError(s.T(), err) } // Test first page - flows, count, err := s.store.ListFlows(2, 0, "") + flows, count, err := s.store.ListFlows(context.Background(), 2, 0, "") require.NoError(s.T(), err) assert.Equal(s.T(), 5, count) assert.Len(s.T(), flows, 2) // Test second page - flows, count, err = s.store.ListFlows(2, 2, "") + flows, count, err = s.store.ListFlows(context.Background(), 2, 2, "") require.NoError(s.T(), err) assert.Equal(s.T(), 5, count) assert.Len(s.T(), flows, 2) // Test offset beyond total - flows, count, err = s.store.ListFlows(10, 10, "") + flows, count, err = s.store.ListFlows(context.Background(), 10, 10, "") require.NoError(s.T(), err) assert.Equal(s.T(), 5, count) assert.Len(s.T(), flows, 0) @@ -185,17 +186,17 @@ func (s *FileBasedStoreTestSuite) TestListFlows_Pagination() { func (s *FileBasedStoreTestSuite) TestIsFlowExists_Found() { flowDef := s.createTestFlow("test-flow") - _, err := s.store.CreateFlow("flow-001", flowDef) + _, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - exists, err := s.store.IsFlowExists("flow-001") + exists, err := s.store.IsFlowExists(context.Background(), "flow-001") require.NoError(s.T(), err) assert.True(s.T(), exists) } func (s *FileBasedStoreTestSuite) TestIsFlowExists_NotFound() { - exists, err := s.store.IsFlowExists("non-existent") + exists, err := s.store.IsFlowExists(context.Background(), "non-existent") require.NoError(s.T(), err) assert.False(s.T(), exists) @@ -203,17 +204,17 @@ func (s *FileBasedStoreTestSuite) TestIsFlowExists_NotFound() { func (s *FileBasedStoreTestSuite) TestIsFlowExistsByHandle_Found() { flowDef := s.createTestFlow("test-flow") - _, err := s.store.CreateFlow("flow-001", flowDef) + _, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - exists, err := s.store.IsFlowExistsByHandle("test-flow", testFlowTypeAuthentication) + exists, err := s.store.IsFlowExistsByHandle(context.Background(), "test-flow", testFlowTypeAuthentication) require.NoError(s.T(), err) assert.True(s.T(), exists) } func (s *FileBasedStoreTestSuite) TestIsFlowExistsByHandle_NotFound() { - exists, err := s.store.IsFlowExistsByHandle("non-existent", testFlowTypeAuthentication) + exists, err := s.store.IsFlowExistsByHandle(context.Background(), "non-existent", testFlowTypeAuthentication) require.NoError(s.T(), err) assert.False(s.T(), exists) @@ -222,10 +223,10 @@ func (s *FileBasedStoreTestSuite) TestIsFlowExistsByHandle_NotFound() { func (s *FileBasedStoreTestSuite) TestIsFlowExistsByHandle_WrongFlowType() { flowDef := s.createTestFlow("test-flow") flowDef.FlowType = testFlowTypeAuthentication - _, err := s.store.CreateFlow("flow-001", flowDef) + _, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) - exists, err := s.store.IsFlowExistsByHandle("test-flow", "REGISTRATION") + exists, err := s.store.IsFlowExistsByHandle(context.Background(), "test-flow", "REGISTRATION") require.NoError(s.T(), err) assert.False(s.T(), exists) @@ -234,35 +235,35 @@ func (s *FileBasedStoreTestSuite) TestIsFlowExistsByHandle_WrongFlowType() { func (s *FileBasedStoreTestSuite) TestUpdateFlow_NotSupported() { flowDef := s.createTestFlow("test-flow") - _, err := s.store.UpdateFlow("flow-001", flowDef) + _, err := s.store.UpdateFlow(context.Background(), "flow-001", flowDef) assert.Error(s.T(), err) assert.Contains(s.T(), err.Error(), "not supported in file-based store") } func (s *FileBasedStoreTestSuite) TestDeleteFlow_NotSupported() { - err := s.store.DeleteFlow("flow-001") + err := s.store.DeleteFlow(context.Background(), "flow-001") assert.Error(s.T(), err) assert.Contains(s.T(), err.Error(), "not supported in file-based store") } func (s *FileBasedStoreTestSuite) TestListFlowVersions_NotSupported() { - _, err := s.store.ListFlowVersions("flow-001") + _, err := s.store.ListFlowVersions(context.Background(), "flow-001") assert.Error(s.T(), err) assert.Contains(s.T(), err.Error(), "not supported in file-based store") } func (s *FileBasedStoreTestSuite) TestGetFlowVersion_NotSupported() { - _, err := s.store.GetFlowVersion("flow-001", 1) + _, err := s.store.GetFlowVersion(context.Background(), "flow-001", 1) assert.Error(s.T(), err) assert.Contains(s.T(), err.Error(), "not supported in file-based store") } func (s *FileBasedStoreTestSuite) TestRestoreFlowVersion_NotSupported() { - _, err := s.store.RestoreFlowVersion("flow-001", 1) + _, err := s.store.RestoreFlowVersion(context.Background(), "flow-001", 1) assert.Error(s.T(), err) assert.Contains(s.T(), err.Error(), "not supported in file-based store") @@ -287,13 +288,13 @@ func (s *FileBasedStoreTestSuite) TestCreate_ImplementsStorer() { require.NoError(s.T(), err) // Verify it was created - retrieved, err := s.store.GetFlowByID("flow-001") + retrieved, err := s.store.GetFlowByID(context.Background(), "flow-001") require.NoError(s.T(), err) assert.Equal(s.T(), "test-flow", retrieved.Handle) } func (s *FileBasedStoreTestSuite) TestListFlows_EmptyStore() { - flows, count, err := s.store.ListFlows(10, 0, "") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "") require.NoError(s.T(), err) assert.Equal(s.T(), 0, count) @@ -304,21 +305,21 @@ func (s *FileBasedStoreTestSuite) TestGetFlowByHandle_MultipleFlowsSameHandle() // Create two flows with different types but same handle authFlow := s.createTestFlow("common-handle") authFlow.FlowType = testFlowTypeAuthentication - _, err := s.store.CreateFlow("flow-001", authFlow) + _, err := s.store.CreateFlow(context.Background(), "flow-001", authFlow) require.NoError(s.T(), err) regFlow := s.createTestFlow("common-handle") regFlow.FlowType = "REGISTRATION" - _, err = s.store.CreateFlow("flow-002", regFlow) + _, err = s.store.CreateFlow(context.Background(), "flow-002", regFlow) require.NoError(s.T(), err) // Retrieve by handle and type should get the correct one - authRetrieved, err := s.store.GetFlowByHandle("common-handle", testFlowTypeAuthentication) + authRetrieved, err := s.store.GetFlowByHandle(context.Background(), "common-handle", testFlowTypeAuthentication) require.NoError(s.T(), err) assert.Equal(s.T(), "flow-001", authRetrieved.ID) // Retrieve REGISTRATION flow by handle and type should also work - regRetrieved, err := s.store.GetFlowByHandle("common-handle", "REGISTRATION") + regRetrieved, err := s.store.GetFlowByHandle(context.Background(), "common-handle", "REGISTRATION") require.NoError(s.T(), err) assert.Equal(s.T(), "flow-002", regRetrieved.ID) assert.Equal(s.T(), "common-handle", regRetrieved.Handle) @@ -328,7 +329,7 @@ func (s *FileBasedStoreTestSuite) TestGetFlowByID_TypeAssertionFailure() { // This test verifies the type assertion error path in GetFlowByID // Create a flow and then manually corrupt the store data flowDef := s.createTestFlow("test-flow") - _, err := s.store.CreateFlow("flow-001", flowDef) + _, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) // Access the underlying store to corrupt data @@ -338,7 +339,7 @@ func (s *FileBasedStoreTestSuite) TestGetFlowByID_TypeAssertionFailure() { require.NoError(s.T(), err) // Try to retrieve the corrupted flow - _, err = s.store.GetFlowByID("corrupted-flow") + _, err = s.store.GetFlowByID(context.Background(), "corrupted-flow") assert.Error(s.T(), err) assert.Equal(s.T(), errFlowNotFound, err) } @@ -347,7 +348,7 @@ func (s *FileBasedStoreTestSuite) TestListFlows_TypeAssertionSkip() { // Create valid flows for i := 0; i < 2; i++ { flowDef := s.createTestFlow(fmt.Sprintf("flow-%d", i)) - _, err := s.store.CreateFlow(fmt.Sprintf("flow-00%d", i), flowDef) + _, err := s.store.CreateFlow(context.Background(), fmt.Sprintf("flow-00%d", i), flowDef) require.NoError(s.T(), err) } @@ -357,7 +358,7 @@ func (s *FileBasedStoreTestSuite) TestListFlows_TypeAssertionSkip() { require.NoError(s.T(), err) // List should skip the corrupted entry and return valid flows - flows, count, err := s.store.ListFlows(10, 0, "") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "") require.NoError(s.T(), err) assert.Equal(s.T(), 2, count) assert.Len(s.T(), flows, 2) @@ -366,7 +367,7 @@ func (s *FileBasedStoreTestSuite) TestListFlows_TypeAssertionSkip() { func (s *FileBasedStoreTestSuite) TestIsFlowExistsByHandle_TypeAssertionSkip() { // Create valid flow flowDef := s.createTestFlow("test-flow") - _, err := s.store.CreateFlow("flow-001", flowDef) + _, err := s.store.CreateFlow(context.Background(), "flow-001", flowDef) require.NoError(s.T(), err) // Add corrupted data to store @@ -375,7 +376,7 @@ func (s *FileBasedStoreTestSuite) TestIsFlowExistsByHandle_TypeAssertionSkip() { require.NoError(s.T(), err) // Should still find the valid flow - exists, err := s.store.IsFlowExistsByHandle("test-flow", testFlowTypeAuthentication) + exists, err := s.store.IsFlowExistsByHandle(context.Background(), "test-flow", testFlowTypeAuthentication) require.NoError(s.T(), err) assert.True(s.T(), exists) } @@ -386,7 +387,7 @@ func (s *FileBasedStoreTestSuite) TestIsFlowExistsByHandle_ListError() { // IsFlowExistsByHandle should handle list errors gracefully // In the current implementation, it returns the error from List() - exists, err := store.IsFlowExistsByHandle("test", testFlowTypeAuthentication) + exists, err := store.IsFlowExistsByHandle(context.Background(), "test", testFlowTypeAuthentication) // With empty store, should return false with no error require.NoError(s.T(), err) assert.False(s.T(), exists) @@ -412,7 +413,7 @@ func (s *FileBasedStoreTestSuite) TestCreate_WithCompleteFlow() { require.NoError(s.T(), err) // Verify it was created correctly - retrieved, err := s.store.GetFlowByID("flow-100") + retrieved, err := s.store.GetFlowByID(context.Background(), "flow-100") require.NoError(s.T(), err) assert.Equal(s.T(), "flow-100", retrieved.ID) assert.Equal(s.T(), "complete-flow", retrieved.Handle) diff --git a/backend/internal/flow/mgt/flowStoreInterface_mock_test.go b/backend/internal/flow/mgt/flowStoreInterface_mock_test.go index 64edc00ea..6e4730670 100644 --- a/backend/internal/flow/mgt/flowStoreInterface_mock_test.go +++ b/backend/internal/flow/mgt/flowStoreInterface_mock_test.go @@ -5,6 +5,8 @@ package flowmgt import ( + "context" + "github.com/asgardeo/thunder/internal/flow/common" mock "github.com/stretchr/testify/mock" ) @@ -37,8 +39,8 @@ func (_m *flowStoreInterfaceMock) EXPECT() *flowStoreInterfaceMock_Expecter { } // CreateFlow provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) CreateFlow(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) { - ret := _mock.Called(flowID, flow) +func (_mock *flowStoreInterfaceMock) CreateFlow(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) { + ret := _mock.Called(ctx, flowID, flow) if len(ret) == 0 { panic("no return value specified for CreateFlow") @@ -46,18 +48,18 @@ func (_mock *flowStoreInterfaceMock) CreateFlow(flowID string, flow *FlowDefinit var r0 *CompleteFlowDefinition var r1 error - if returnFunc, ok := ret.Get(0).(func(string, *FlowDefinition) (*CompleteFlowDefinition, error)); ok { - return returnFunc(flowID, flow) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *FlowDefinition) (*CompleteFlowDefinition, error)); ok { + return returnFunc(ctx, flowID, flow) } - if returnFunc, ok := ret.Get(0).(func(string, *FlowDefinition) *CompleteFlowDefinition); ok { - r0 = returnFunc(flowID, flow) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *FlowDefinition) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID, flow) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, *FlowDefinition) error); ok { - r1 = returnFunc(flowID, flow) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, *FlowDefinition) error); ok { + r1 = returnFunc(ctx, flowID, flow) } else { r1 = ret.Error(1) } @@ -70,25 +72,31 @@ type flowStoreInterfaceMock_CreateFlow_Call struct { } // CreateFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - flow *FlowDefinition -func (_e *flowStoreInterfaceMock_Expecter) CreateFlow(flowID interface{}, flow interface{}) *flowStoreInterfaceMock_CreateFlow_Call { - return &flowStoreInterfaceMock_CreateFlow_Call{Call: _e.mock.On("CreateFlow", flowID, flow)} +func (_e *flowStoreInterfaceMock_Expecter) CreateFlow(ctx interface{}, flowID interface{}, flow interface{}) *flowStoreInterfaceMock_CreateFlow_Call { + return &flowStoreInterfaceMock_CreateFlow_Call{Call: _e.mock.On("CreateFlow", ctx, flowID, flow)} } -func (_c *flowStoreInterfaceMock_CreateFlow_Call) Run(run func(flowID string, flow *FlowDefinition)) *flowStoreInterfaceMock_CreateFlow_Call { +func (_c *flowStoreInterfaceMock_CreateFlow_Call) Run(run func(ctx context.Context, flowID string, flow *FlowDefinition)) *flowStoreInterfaceMock_CreateFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 *FlowDefinition + var arg1 string if args[1] != nil { - arg1 = args[1].(*FlowDefinition) + arg1 = args[1].(string) + } + var arg2 *FlowDefinition + if args[2] != nil { + arg2 = args[2].(*FlowDefinition) } run( arg0, arg1, + arg2, ) }) return _c @@ -99,22 +107,22 @@ func (_c *flowStoreInterfaceMock_CreateFlow_Call) Return(completeFlowDefinition return _c } -func (_c *flowStoreInterfaceMock_CreateFlow_Call) RunAndReturn(run func(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_CreateFlow_Call { +func (_c *flowStoreInterfaceMock_CreateFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_CreateFlow_Call { _c.Call.Return(run) return _c } // DeleteFlow provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) DeleteFlow(flowID string) error { - ret := _mock.Called(flowID) +func (_mock *flowStoreInterfaceMock) DeleteFlow(ctx context.Context, flowID string) error { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for DeleteFlow") } var r0 error - if returnFunc, ok := ret.Get(0).(func(string) error); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, flowID) } else { r0 = ret.Error(0) } @@ -127,19 +135,25 @@ type flowStoreInterfaceMock_DeleteFlow_Call struct { } // DeleteFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *flowStoreInterfaceMock_Expecter) DeleteFlow(flowID interface{}) *flowStoreInterfaceMock_DeleteFlow_Call { - return &flowStoreInterfaceMock_DeleteFlow_Call{Call: _e.mock.On("DeleteFlow", flowID)} +func (_e *flowStoreInterfaceMock_Expecter) DeleteFlow(ctx interface{}, flowID interface{}) *flowStoreInterfaceMock_DeleteFlow_Call { + return &flowStoreInterfaceMock_DeleteFlow_Call{Call: _e.mock.On("DeleteFlow", ctx, flowID)} } -func (_c *flowStoreInterfaceMock_DeleteFlow_Call) Run(run func(flowID string)) *flowStoreInterfaceMock_DeleteFlow_Call { +func (_c *flowStoreInterfaceMock_DeleteFlow_Call) Run(run func(ctx context.Context, flowID string)) *flowStoreInterfaceMock_DeleteFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -150,14 +164,14 @@ func (_c *flowStoreInterfaceMock_DeleteFlow_Call) Return(err error) *flowStoreIn return _c } -func (_c *flowStoreInterfaceMock_DeleteFlow_Call) RunAndReturn(run func(flowID string) error) *flowStoreInterfaceMock_DeleteFlow_Call { +func (_c *flowStoreInterfaceMock_DeleteFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string) error) *flowStoreInterfaceMock_DeleteFlow_Call { _c.Call.Return(run) return _c } // GetFlowByHandle provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) GetFlowByHandle(handle string, flowType common.FlowType) (*CompleteFlowDefinition, error) { - ret := _mock.Called(handle, flowType) +func (_mock *flowStoreInterfaceMock) GetFlowByHandle(ctx context.Context, handle string, flowType common.FlowType) (*CompleteFlowDefinition, error) { + ret := _mock.Called(ctx, handle, flowType) if len(ret) == 0 { panic("no return value specified for GetFlowByHandle") @@ -165,18 +179,18 @@ func (_mock *flowStoreInterfaceMock) GetFlowByHandle(handle string, flowType com var r0 *CompleteFlowDefinition var r1 error - if returnFunc, ok := ret.Get(0).(func(string, common.FlowType) (*CompleteFlowDefinition, error)); ok { - return returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, common.FlowType) (*CompleteFlowDefinition, error)); ok { + return returnFunc(ctx, handle, flowType) } - if returnFunc, ok := ret.Get(0).(func(string, common.FlowType) *CompleteFlowDefinition); ok { - r0 = returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, common.FlowType) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, handle, flowType) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, common.FlowType) error); ok { - r1 = returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, common.FlowType) error); ok { + r1 = returnFunc(ctx, handle, flowType) } else { r1 = ret.Error(1) } @@ -189,25 +203,31 @@ type flowStoreInterfaceMock_GetFlowByHandle_Call struct { } // GetFlowByHandle is a helper method to define mock.On call +// - ctx context.Context // - handle string // - flowType common.FlowType -func (_e *flowStoreInterfaceMock_Expecter) GetFlowByHandle(handle interface{}, flowType interface{}) *flowStoreInterfaceMock_GetFlowByHandle_Call { - return &flowStoreInterfaceMock_GetFlowByHandle_Call{Call: _e.mock.On("GetFlowByHandle", handle, flowType)} +func (_e *flowStoreInterfaceMock_Expecter) GetFlowByHandle(ctx interface{}, handle interface{}, flowType interface{}) *flowStoreInterfaceMock_GetFlowByHandle_Call { + return &flowStoreInterfaceMock_GetFlowByHandle_Call{Call: _e.mock.On("GetFlowByHandle", ctx, handle, flowType)} } -func (_c *flowStoreInterfaceMock_GetFlowByHandle_Call) Run(run func(handle string, flowType common.FlowType)) *flowStoreInterfaceMock_GetFlowByHandle_Call { +func (_c *flowStoreInterfaceMock_GetFlowByHandle_Call) Run(run func(ctx context.Context, handle string, flowType common.FlowType)) *flowStoreInterfaceMock_GetFlowByHandle_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 common.FlowType + var arg1 string if args[1] != nil { - arg1 = args[1].(common.FlowType) + arg1 = args[1].(string) + } + var arg2 common.FlowType + if args[2] != nil { + arg2 = args[2].(common.FlowType) } run( arg0, arg1, + arg2, ) }) return _c @@ -218,14 +238,14 @@ func (_c *flowStoreInterfaceMock_GetFlowByHandle_Call) Return(completeFlowDefini return _c } -func (_c *flowStoreInterfaceMock_GetFlowByHandle_Call) RunAndReturn(run func(handle string, flowType common.FlowType) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_GetFlowByHandle_Call { +func (_c *flowStoreInterfaceMock_GetFlowByHandle_Call) RunAndReturn(run func(ctx context.Context, handle string, flowType common.FlowType) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_GetFlowByHandle_Call { _c.Call.Return(run) return _c } // GetFlowByID provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) GetFlowByID(flowID string) (*CompleteFlowDefinition, error) { - ret := _mock.Called(flowID) +func (_mock *flowStoreInterfaceMock) GetFlowByID(ctx context.Context, flowID string) (*CompleteFlowDefinition, error) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for GetFlowByID") @@ -233,18 +253,18 @@ func (_mock *flowStoreInterfaceMock) GetFlowByID(flowID string) (*CompleteFlowDe var r0 *CompleteFlowDefinition var r1 error - if returnFunc, ok := ret.Get(0).(func(string) (*CompleteFlowDefinition, error)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*CompleteFlowDefinition, error)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) *CompleteFlowDefinition); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string) error); ok { - r1 = returnFunc(flowID) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, flowID) } else { r1 = ret.Error(1) } @@ -257,19 +277,25 @@ type flowStoreInterfaceMock_GetFlowByID_Call struct { } // GetFlowByID is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *flowStoreInterfaceMock_Expecter) GetFlowByID(flowID interface{}) *flowStoreInterfaceMock_GetFlowByID_Call { - return &flowStoreInterfaceMock_GetFlowByID_Call{Call: _e.mock.On("GetFlowByID", flowID)} +func (_e *flowStoreInterfaceMock_Expecter) GetFlowByID(ctx interface{}, flowID interface{}) *flowStoreInterfaceMock_GetFlowByID_Call { + return &flowStoreInterfaceMock_GetFlowByID_Call{Call: _e.mock.On("GetFlowByID", ctx, flowID)} } -func (_c *flowStoreInterfaceMock_GetFlowByID_Call) Run(run func(flowID string)) *flowStoreInterfaceMock_GetFlowByID_Call { +func (_c *flowStoreInterfaceMock_GetFlowByID_Call) Run(run func(ctx context.Context, flowID string)) *flowStoreInterfaceMock_GetFlowByID_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -280,14 +306,14 @@ func (_c *flowStoreInterfaceMock_GetFlowByID_Call) Return(completeFlowDefinition return _c } -func (_c *flowStoreInterfaceMock_GetFlowByID_Call) RunAndReturn(run func(flowID string) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_GetFlowByID_Call { +func (_c *flowStoreInterfaceMock_GetFlowByID_Call) RunAndReturn(run func(ctx context.Context, flowID string) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_GetFlowByID_Call { _c.Call.Return(run) return _c } // GetFlowVersion provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) GetFlowVersion(flowID string, version int) (*FlowVersion, error) { - ret := _mock.Called(flowID, version) +func (_mock *flowStoreInterfaceMock) GetFlowVersion(ctx context.Context, flowID string, version int) (*FlowVersion, error) { + ret := _mock.Called(ctx, flowID, version) if len(ret) == 0 { panic("no return value specified for GetFlowVersion") @@ -295,18 +321,18 @@ func (_mock *flowStoreInterfaceMock) GetFlowVersion(flowID string, version int) var r0 *FlowVersion var r1 error - if returnFunc, ok := ret.Get(0).(func(string, int) (*FlowVersion, error)); ok { - return returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) (*FlowVersion, error)); ok { + return returnFunc(ctx, flowID, version) } - if returnFunc, ok := ret.Get(0).(func(string, int) *FlowVersion); ok { - r0 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) *FlowVersion); ok { + r0 = returnFunc(ctx, flowID, version) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*FlowVersion) } } - if returnFunc, ok := ret.Get(1).(func(string, int) error); ok { - r1 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, int) error); ok { + r1 = returnFunc(ctx, flowID, version) } else { r1 = ret.Error(1) } @@ -319,25 +345,31 @@ type flowStoreInterfaceMock_GetFlowVersion_Call struct { } // GetFlowVersion is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - version int -func (_e *flowStoreInterfaceMock_Expecter) GetFlowVersion(flowID interface{}, version interface{}) *flowStoreInterfaceMock_GetFlowVersion_Call { - return &flowStoreInterfaceMock_GetFlowVersion_Call{Call: _e.mock.On("GetFlowVersion", flowID, version)} +func (_e *flowStoreInterfaceMock_Expecter) GetFlowVersion(ctx interface{}, flowID interface{}, version interface{}) *flowStoreInterfaceMock_GetFlowVersion_Call { + return &flowStoreInterfaceMock_GetFlowVersion_Call{Call: _e.mock.On("GetFlowVersion", ctx, flowID, version)} } -func (_c *flowStoreInterfaceMock_GetFlowVersion_Call) Run(run func(flowID string, version int)) *flowStoreInterfaceMock_GetFlowVersion_Call { +func (_c *flowStoreInterfaceMock_GetFlowVersion_Call) Run(run func(ctx context.Context, flowID string, version int)) *flowStoreInterfaceMock_GetFlowVersion_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 int + var arg1 string if args[1] != nil { - arg1 = args[1].(int) + arg1 = args[1].(string) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) } run( arg0, arg1, + arg2, ) }) return _c @@ -348,14 +380,14 @@ func (_c *flowStoreInterfaceMock_GetFlowVersion_Call) Return(flowVersion *FlowVe return _c } -func (_c *flowStoreInterfaceMock_GetFlowVersion_Call) RunAndReturn(run func(flowID string, version int) (*FlowVersion, error)) *flowStoreInterfaceMock_GetFlowVersion_Call { +func (_c *flowStoreInterfaceMock_GetFlowVersion_Call) RunAndReturn(run func(ctx context.Context, flowID string, version int) (*FlowVersion, error)) *flowStoreInterfaceMock_GetFlowVersion_Call { _c.Call.Return(run) return _c } // IsFlowExists provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) IsFlowExists(flowID string) (bool, error) { - ret := _mock.Called(flowID) +func (_mock *flowStoreInterfaceMock) IsFlowExists(ctx context.Context, flowID string) (bool, error) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for IsFlowExists") @@ -363,16 +395,16 @@ func (_mock *flowStoreInterfaceMock) IsFlowExists(flowID string) (bool, error) { var r0 bool var r1 error - if returnFunc, ok := ret.Get(0).(func(string) (bool, error)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) bool); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = returnFunc(ctx, flowID) } else { r0 = ret.Get(0).(bool) } - if returnFunc, ok := ret.Get(1).(func(string) error); ok { - r1 = returnFunc(flowID) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, flowID) } else { r1 = ret.Error(1) } @@ -385,19 +417,25 @@ type flowStoreInterfaceMock_IsFlowExists_Call struct { } // IsFlowExists is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *flowStoreInterfaceMock_Expecter) IsFlowExists(flowID interface{}) *flowStoreInterfaceMock_IsFlowExists_Call { - return &flowStoreInterfaceMock_IsFlowExists_Call{Call: _e.mock.On("IsFlowExists", flowID)} +func (_e *flowStoreInterfaceMock_Expecter) IsFlowExists(ctx interface{}, flowID interface{}) *flowStoreInterfaceMock_IsFlowExists_Call { + return &flowStoreInterfaceMock_IsFlowExists_Call{Call: _e.mock.On("IsFlowExists", ctx, flowID)} } -func (_c *flowStoreInterfaceMock_IsFlowExists_Call) Run(run func(flowID string)) *flowStoreInterfaceMock_IsFlowExists_Call { +func (_c *flowStoreInterfaceMock_IsFlowExists_Call) Run(run func(ctx context.Context, flowID string)) *flowStoreInterfaceMock_IsFlowExists_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -408,14 +446,14 @@ func (_c *flowStoreInterfaceMock_IsFlowExists_Call) Return(b bool, err error) *f return _c } -func (_c *flowStoreInterfaceMock_IsFlowExists_Call) RunAndReturn(run func(flowID string) (bool, error)) *flowStoreInterfaceMock_IsFlowExists_Call { +func (_c *flowStoreInterfaceMock_IsFlowExists_Call) RunAndReturn(run func(ctx context.Context, flowID string) (bool, error)) *flowStoreInterfaceMock_IsFlowExists_Call { _c.Call.Return(run) return _c } // IsFlowExistsByHandle provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) IsFlowExistsByHandle(handle string, flowType common.FlowType) (bool, error) { - ret := _mock.Called(handle, flowType) +func (_mock *flowStoreInterfaceMock) IsFlowExistsByHandle(ctx context.Context, handle string, flowType common.FlowType) (bool, error) { + ret := _mock.Called(ctx, handle, flowType) if len(ret) == 0 { panic("no return value specified for IsFlowExistsByHandle") @@ -423,16 +461,16 @@ func (_mock *flowStoreInterfaceMock) IsFlowExistsByHandle(handle string, flowTyp var r0 bool var r1 error - if returnFunc, ok := ret.Get(0).(func(string, common.FlowType) (bool, error)); ok { - return returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, common.FlowType) (bool, error)); ok { + return returnFunc(ctx, handle, flowType) } - if returnFunc, ok := ret.Get(0).(func(string, common.FlowType) bool); ok { - r0 = returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, common.FlowType) bool); ok { + r0 = returnFunc(ctx, handle, flowType) } else { r0 = ret.Get(0).(bool) } - if returnFunc, ok := ret.Get(1).(func(string, common.FlowType) error); ok { - r1 = returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, common.FlowType) error); ok { + r1 = returnFunc(ctx, handle, flowType) } else { r1 = ret.Error(1) } @@ -445,25 +483,31 @@ type flowStoreInterfaceMock_IsFlowExistsByHandle_Call struct { } // IsFlowExistsByHandle is a helper method to define mock.On call +// - ctx context.Context // - handle string // - flowType common.FlowType -func (_e *flowStoreInterfaceMock_Expecter) IsFlowExistsByHandle(handle interface{}, flowType interface{}) *flowStoreInterfaceMock_IsFlowExistsByHandle_Call { - return &flowStoreInterfaceMock_IsFlowExistsByHandle_Call{Call: _e.mock.On("IsFlowExistsByHandle", handle, flowType)} +func (_e *flowStoreInterfaceMock_Expecter) IsFlowExistsByHandle(ctx interface{}, handle interface{}, flowType interface{}) *flowStoreInterfaceMock_IsFlowExistsByHandle_Call { + return &flowStoreInterfaceMock_IsFlowExistsByHandle_Call{Call: _e.mock.On("IsFlowExistsByHandle", ctx, handle, flowType)} } -func (_c *flowStoreInterfaceMock_IsFlowExistsByHandle_Call) Run(run func(handle string, flowType common.FlowType)) *flowStoreInterfaceMock_IsFlowExistsByHandle_Call { +func (_c *flowStoreInterfaceMock_IsFlowExistsByHandle_Call) Run(run func(ctx context.Context, handle string, flowType common.FlowType)) *flowStoreInterfaceMock_IsFlowExistsByHandle_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 common.FlowType + var arg1 string if args[1] != nil { - arg1 = args[1].(common.FlowType) + arg1 = args[1].(string) + } + var arg2 common.FlowType + if args[2] != nil { + arg2 = args[2].(common.FlowType) } run( arg0, arg1, + arg2, ) }) return _c @@ -474,14 +518,14 @@ func (_c *flowStoreInterfaceMock_IsFlowExistsByHandle_Call) Return(b bool, err e return _c } -func (_c *flowStoreInterfaceMock_IsFlowExistsByHandle_Call) RunAndReturn(run func(handle string, flowType common.FlowType) (bool, error)) *flowStoreInterfaceMock_IsFlowExistsByHandle_Call { +func (_c *flowStoreInterfaceMock_IsFlowExistsByHandle_Call) RunAndReturn(run func(ctx context.Context, handle string, flowType common.FlowType) (bool, error)) *flowStoreInterfaceMock_IsFlowExistsByHandle_Call { _c.Call.Return(run) return _c } // ListFlowVersions provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) ListFlowVersions(flowID string) ([]BasicFlowVersion, error) { - ret := _mock.Called(flowID) +func (_mock *flowStoreInterfaceMock) ListFlowVersions(ctx context.Context, flowID string) ([]BasicFlowVersion, error) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for ListFlowVersions") @@ -489,11 +533,11 @@ func (_mock *flowStoreInterfaceMock) ListFlowVersions(flowID string) ([]BasicFlo var r0 []BasicFlowVersion var r1 error - if returnFunc, ok := ret.Get(0).(func(string) ([]BasicFlowVersion, error)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]BasicFlowVersion, error)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) []BasicFlowVersion); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) []BasicFlowVersion); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]BasicFlowVersion) @@ -513,19 +557,25 @@ type flowStoreInterfaceMock_ListFlowVersions_Call struct { } // ListFlowVersions is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *flowStoreInterfaceMock_Expecter) ListFlowVersions(flowID interface{}) *flowStoreInterfaceMock_ListFlowVersions_Call { - return &flowStoreInterfaceMock_ListFlowVersions_Call{Call: _e.mock.On("ListFlowVersions", flowID)} +func (_e *flowStoreInterfaceMock_Expecter) ListFlowVersions(ctx interface{}, flowID interface{}) *flowStoreInterfaceMock_ListFlowVersions_Call { + return &flowStoreInterfaceMock_ListFlowVersions_Call{Call: _e.mock.On("ListFlowVersions", ctx, flowID)} } -func (_c *flowStoreInterfaceMock_ListFlowVersions_Call) Run(run func(flowID string)) *flowStoreInterfaceMock_ListFlowVersions_Call { +func (_c *flowStoreInterfaceMock_ListFlowVersions_Call) Run(run func(ctx context.Context, flowID string)) *flowStoreInterfaceMock_ListFlowVersions_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -536,14 +586,14 @@ func (_c *flowStoreInterfaceMock_ListFlowVersions_Call) Return(basicFlowVersions return _c } -func (_c *flowStoreInterfaceMock_ListFlowVersions_Call) RunAndReturn(run func(flowID string) ([]BasicFlowVersion, error)) *flowStoreInterfaceMock_ListFlowVersions_Call { +func (_c *flowStoreInterfaceMock_ListFlowVersions_Call) RunAndReturn(run func(ctx context.Context, flowID string) ([]BasicFlowVersion, error)) *flowStoreInterfaceMock_ListFlowVersions_Call { _c.Call.Return(run) return _c } // ListFlows provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) ListFlows(limit int, offset int, flowType string) ([]BasicFlowDefinition, int, error) { - ret := _mock.Called(limit, offset, flowType) +func (_mock *flowStoreInterfaceMock) ListFlows(ctx context.Context, limit int, offset int, flowType string) ([]BasicFlowDefinition, int, error) { + ret := _mock.Called(ctx, limit, offset, flowType) if len(ret) == 0 { panic("no return value specified for ListFlows") @@ -552,23 +602,23 @@ func (_mock *flowStoreInterfaceMock) ListFlows(limit int, offset int, flowType s var r0 []BasicFlowDefinition var r1 int var r2 error - if returnFunc, ok := ret.Get(0).(func(int, int, string) ([]BasicFlowDefinition, int, error)); ok { - return returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, int, int, string) ([]BasicFlowDefinition, int, error)); ok { + return returnFunc(ctx, limit, offset, flowType) } - if returnFunc, ok := ret.Get(0).(func(int, int, string) []BasicFlowDefinition); ok { - r0 = returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, int, int, string) []BasicFlowDefinition); ok { + r0 = returnFunc(ctx, limit, offset, flowType) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]BasicFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(int, int, string) int); ok { - r1 = returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(1).(func(context.Context, int, int, string) int); ok { + r1 = returnFunc(ctx, limit, offset, flowType) } else { r1 = ret.Get(1).(int) } - if returnFunc, ok := ret.Get(2).(func(int, int, string) error); ok { - r2 = returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(2).(func(context.Context, int, int, string) error); ok { + r2 = returnFunc(ctx, limit, offset, flowType) } else { r2 = ret.Error(2) } @@ -581,31 +631,37 @@ type flowStoreInterfaceMock_ListFlows_Call struct { } // ListFlows is a helper method to define mock.On call +// - ctx context.Context // - limit int // - offset int // - flowType string -func (_e *flowStoreInterfaceMock_Expecter) ListFlows(limit interface{}, offset interface{}, flowType interface{}) *flowStoreInterfaceMock_ListFlows_Call { - return &flowStoreInterfaceMock_ListFlows_Call{Call: _e.mock.On("ListFlows", limit, offset, flowType)} +func (_e *flowStoreInterfaceMock_Expecter) ListFlows(ctx interface{}, limit interface{}, offset interface{}, flowType interface{}) *flowStoreInterfaceMock_ListFlows_Call { + return &flowStoreInterfaceMock_ListFlows_Call{Call: _e.mock.On("ListFlows", ctx, limit, offset, flowType)} } -func (_c *flowStoreInterfaceMock_ListFlows_Call) Run(run func(limit int, offset int, flowType string)) *flowStoreInterfaceMock_ListFlows_Call { +func (_c *flowStoreInterfaceMock_ListFlows_Call) Run(run func(ctx context.Context, limit int, offset int, flowType string)) *flowStoreInterfaceMock_ListFlows_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 int + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(int) + arg0 = args[0].(context.Context) } var arg1 int if args[1] != nil { arg1 = args[1].(int) } - var arg2 string + var arg2 int if args[2] != nil { - arg2 = args[2].(string) + arg2 = args[2].(int) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -616,14 +672,14 @@ func (_c *flowStoreInterfaceMock_ListFlows_Call) Return(basicFlowDefinitions []B return _c } -func (_c *flowStoreInterfaceMock_ListFlows_Call) RunAndReturn(run func(limit int, offset int, flowType string) ([]BasicFlowDefinition, int, error)) *flowStoreInterfaceMock_ListFlows_Call { +func (_c *flowStoreInterfaceMock_ListFlows_Call) RunAndReturn(run func(ctx context.Context, limit int, offset int, flowType string) ([]BasicFlowDefinition, int, error)) *flowStoreInterfaceMock_ListFlows_Call { _c.Call.Return(run) return _c } // RestoreFlowVersion provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) RestoreFlowVersion(flowID string, version int) (*CompleteFlowDefinition, error) { - ret := _mock.Called(flowID, version) +func (_mock *flowStoreInterfaceMock) RestoreFlowVersion(ctx context.Context, flowID string, version int) (*CompleteFlowDefinition, error) { + ret := _mock.Called(ctx, flowID, version) if len(ret) == 0 { panic("no return value specified for RestoreFlowVersion") @@ -631,18 +687,18 @@ func (_mock *flowStoreInterfaceMock) RestoreFlowVersion(flowID string, version i var r0 *CompleteFlowDefinition var r1 error - if returnFunc, ok := ret.Get(0).(func(string, int) (*CompleteFlowDefinition, error)); ok { - return returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) (*CompleteFlowDefinition, error)); ok { + return returnFunc(ctx, flowID, version) } - if returnFunc, ok := ret.Get(0).(func(string, int) *CompleteFlowDefinition); ok { - r0 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID, version) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, int) error); ok { - r1 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, int) error); ok { + r1 = returnFunc(ctx, flowID, version) } else { r1 = ret.Error(1) } @@ -655,25 +711,31 @@ type flowStoreInterfaceMock_RestoreFlowVersion_Call struct { } // RestoreFlowVersion is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - version int -func (_e *flowStoreInterfaceMock_Expecter) RestoreFlowVersion(flowID interface{}, version interface{}) *flowStoreInterfaceMock_RestoreFlowVersion_Call { - return &flowStoreInterfaceMock_RestoreFlowVersion_Call{Call: _e.mock.On("RestoreFlowVersion", flowID, version)} +func (_e *flowStoreInterfaceMock_Expecter) RestoreFlowVersion(ctx interface{}, flowID interface{}, version interface{}) *flowStoreInterfaceMock_RestoreFlowVersion_Call { + return &flowStoreInterfaceMock_RestoreFlowVersion_Call{Call: _e.mock.On("RestoreFlowVersion", ctx, flowID, version)} } -func (_c *flowStoreInterfaceMock_RestoreFlowVersion_Call) Run(run func(flowID string, version int)) *flowStoreInterfaceMock_RestoreFlowVersion_Call { +func (_c *flowStoreInterfaceMock_RestoreFlowVersion_Call) Run(run func(ctx context.Context, flowID string, version int)) *flowStoreInterfaceMock_RestoreFlowVersion_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 int + var arg1 string if args[1] != nil { - arg1 = args[1].(int) + arg1 = args[1].(string) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) } run( arg0, arg1, + arg2, ) }) return _c @@ -684,14 +746,14 @@ func (_c *flowStoreInterfaceMock_RestoreFlowVersion_Call) Return(completeFlowDef return _c } -func (_c *flowStoreInterfaceMock_RestoreFlowVersion_Call) RunAndReturn(run func(flowID string, version int) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_RestoreFlowVersion_Call { +func (_c *flowStoreInterfaceMock_RestoreFlowVersion_Call) RunAndReturn(run func(ctx context.Context, flowID string, version int) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_RestoreFlowVersion_Call { _c.Call.Return(run) return _c } // UpdateFlow provides a mock function for the type flowStoreInterfaceMock -func (_mock *flowStoreInterfaceMock) UpdateFlow(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) { - ret := _mock.Called(flowID, flow) +func (_mock *flowStoreInterfaceMock) UpdateFlow(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) { + ret := _mock.Called(ctx, flowID, flow) if len(ret) == 0 { panic("no return value specified for UpdateFlow") @@ -699,18 +761,18 @@ func (_mock *flowStoreInterfaceMock) UpdateFlow(flowID string, flow *FlowDefinit var r0 *CompleteFlowDefinition var r1 error - if returnFunc, ok := ret.Get(0).(func(string, *FlowDefinition) (*CompleteFlowDefinition, error)); ok { - return returnFunc(flowID, flow) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *FlowDefinition) (*CompleteFlowDefinition, error)); ok { + return returnFunc(ctx, flowID, flow) } - if returnFunc, ok := ret.Get(0).(func(string, *FlowDefinition) *CompleteFlowDefinition); ok { - r0 = returnFunc(flowID, flow) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *FlowDefinition) *CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID, flow) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, *FlowDefinition) error); ok { - r1 = returnFunc(flowID, flow) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, *FlowDefinition) error); ok { + r1 = returnFunc(ctx, flowID, flow) } else { r1 = ret.Error(1) } @@ -723,25 +785,31 @@ type flowStoreInterfaceMock_UpdateFlow_Call struct { } // UpdateFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - flow *FlowDefinition -func (_e *flowStoreInterfaceMock_Expecter) UpdateFlow(flowID interface{}, flow interface{}) *flowStoreInterfaceMock_UpdateFlow_Call { - return &flowStoreInterfaceMock_UpdateFlow_Call{Call: _e.mock.On("UpdateFlow", flowID, flow)} +func (_e *flowStoreInterfaceMock_Expecter) UpdateFlow(ctx interface{}, flowID interface{}, flow interface{}) *flowStoreInterfaceMock_UpdateFlow_Call { + return &flowStoreInterfaceMock_UpdateFlow_Call{Call: _e.mock.On("UpdateFlow", ctx, flowID, flow)} } -func (_c *flowStoreInterfaceMock_UpdateFlow_Call) Run(run func(flowID string, flow *FlowDefinition)) *flowStoreInterfaceMock_UpdateFlow_Call { +func (_c *flowStoreInterfaceMock_UpdateFlow_Call) Run(run func(ctx context.Context, flowID string, flow *FlowDefinition)) *flowStoreInterfaceMock_UpdateFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 *FlowDefinition + var arg1 string if args[1] != nil { - arg1 = args[1].(*FlowDefinition) + arg1 = args[1].(string) + } + var arg2 *FlowDefinition + if args[2] != nil { + arg2 = args[2].(*FlowDefinition) } run( arg0, arg1, + arg2, ) }) return _c @@ -752,7 +820,7 @@ func (_c *flowStoreInterfaceMock_UpdateFlow_Call) Return(completeFlowDefinition return _c } -func (_c *flowStoreInterfaceMock_UpdateFlow_Call) RunAndReturn(run func(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_UpdateFlow_Call { +func (_c *flowStoreInterfaceMock_UpdateFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error)) *flowStoreInterfaceMock_UpdateFlow_Call { _c.Call.Return(run) return _c } diff --git a/backend/internal/flow/mgt/graphBuilderInterface_mock_test.go b/backend/internal/flow/mgt/graphBuilderInterface_mock_test.go index b97522305..7f7954fa9 100644 --- a/backend/internal/flow/mgt/graphBuilderInterface_mock_test.go +++ b/backend/internal/flow/mgt/graphBuilderInterface_mock_test.go @@ -96,7 +96,7 @@ func (_c *graphBuilderInterfaceMock_GetGraph_Call) Return(graphInterface core.Gr return _c } -func (_c *graphBuilderInterfaceMock_GetGraph_Call) RunAndReturn(run func(flow *CompleteFlowDefinition) (core.GraphInterface, *serviceerror.ServiceError)) *graphBuilderInterfaceMock_GetGraph_Call { +func (_c *graphBuilderInterfaceMock_GetGraph_Call) RunAndReturn(run func(*CompleteFlowDefinition) (core.GraphInterface, *serviceerror.ServiceError)) *graphBuilderInterfaceMock_GetGraph_Call { _c.Call.Return(run) return _c } @@ -136,7 +136,7 @@ func (_c *graphBuilderInterfaceMock_InvalidateCache_Call) Return() *graphBuilder return _c } -func (_c *graphBuilderInterfaceMock_InvalidateCache_Call) RunAndReturn(run func(flowID string)) *graphBuilderInterfaceMock_InvalidateCache_Call { +func (_c *graphBuilderInterfaceMock_InvalidateCache_Call) RunAndReturn(run func(string)) *graphBuilderInterfaceMock_InvalidateCache_Call { _c.Run(run) return _c } diff --git a/backend/internal/flow/mgt/handler.go b/backend/internal/flow/mgt/handler.go index 0235818ba..2be8fa242 100644 --- a/backend/internal/flow/mgt/handler.go +++ b/backend/internal/flow/mgt/handler.go @@ -65,6 +65,7 @@ func newFlowMgtHandler( // listFlows handles GET requests to list flow definitions with pagination and optional filtering. func (h *flowMgtHandler) listFlows(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() limit, offset, svcErr := parsePaginationParams(r) if svcErr != nil { handleError(w, svcErr) @@ -74,7 +75,7 @@ func (h *flowMgtHandler) listFlows(w http.ResponseWriter, r *http.Request) { flowTypeStr := r.URL.Query().Get(queryParamFlowType) flowType := common.FlowType(flowTypeStr) - flowList, svcErr := h.service.ListFlows(limit, offset, flowType) + flowList, svcErr := h.service.ListFlows(ctx, limit, offset, flowType) if svcErr != nil { handleError(w, svcErr) return @@ -86,6 +87,7 @@ func (h *flowMgtHandler) listFlows(w http.ResponseWriter, r *http.Request) { // createFlow handles POST requests to create a new flow definition. func (h *flowMgtHandler) createFlow(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() flowDefRequest, err := utils.DecodeJSONBody[FlowDefinition](r) if err != nil { handleInvalidRequestError(w) @@ -93,7 +95,7 @@ func (h *flowMgtHandler) createFlow(w http.ResponseWriter, r *http.Request) { } sanitized := sanitizeFlowDefinitionRequest(flowDefRequest) - createdFlow, svcErr := h.service.CreateFlow(sanitized) + createdFlow, svcErr := h.service.CreateFlow(ctx, sanitized) if svcErr != nil { handleError(w, svcErr) return @@ -105,13 +107,14 @@ func (h *flowMgtHandler) createFlow(w http.ResponseWriter, r *http.Request) { // getFlow handles GET requests to retrieve a flow definition by its ID. func (h *flowMgtHandler) getFlow(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() flowID := r.PathValue(pathParamFlowID) if flowID == "" { handleError(w, &ErrorMissingFlowID) return } - flow, svcErr := h.service.GetFlow(flowID) + flow, svcErr := h.service.GetFlow(ctx, flowID) if svcErr != nil { handleError(w, svcErr) return @@ -123,6 +126,7 @@ func (h *flowMgtHandler) getFlow(w http.ResponseWriter, r *http.Request) { // updateFlow handles PUT requests to update an existing flow definition. func (h *flowMgtHandler) updateFlow(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() flowID := r.PathValue(pathParamFlowID) if flowID == "" { handleError(w, &ErrorMissingFlowID) @@ -136,7 +140,7 @@ func (h *flowMgtHandler) updateFlow(w http.ResponseWriter, r *http.Request) { } sanitized := sanitizeFlowDefinitionRequest(flowDefRequest) - updatedFlow, svcErr := h.service.UpdateFlow(flowID, sanitized) + updatedFlow, svcErr := h.service.UpdateFlow(ctx, flowID, sanitized) if svcErr != nil { handleError(w, svcErr) return @@ -148,13 +152,14 @@ func (h *flowMgtHandler) updateFlow(w http.ResponseWriter, r *http.Request) { // deleteFlow handles DELETE requests to remove a flow definition by its ID. func (h *flowMgtHandler) deleteFlow(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() flowID := r.PathValue(pathParamFlowID) if flowID == "" { handleError(w, &ErrorMissingFlowID) return } - svcErr := h.service.DeleteFlow(flowID) + svcErr := h.service.DeleteFlow(ctx, flowID) if svcErr != nil { handleError(w, svcErr) return @@ -168,13 +173,14 @@ func (h *flowMgtHandler) deleteFlow(w http.ResponseWriter, r *http.Request) { // listFlowVersions handles GET requests to list all versions of a specific flow definition. func (h *flowMgtHandler) listFlowVersions(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() flowID := r.PathValue(pathParamFlowID) if flowID == "" { handleError(w, &ErrorMissingFlowID) return } - versionList, svcErr := h.service.ListFlowVersions(flowID) + versionList, svcErr := h.service.ListFlowVersions(ctx, flowID) if svcErr != nil { handleError(w, svcErr) return @@ -186,6 +192,7 @@ func (h *flowMgtHandler) listFlowVersions(w http.ResponseWriter, r *http.Request // getFlowVersion handles GET requests to retrieve a specific version of a flow definition. func (h *flowMgtHandler) getFlowVersion(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() flowID := r.PathValue(pathParamFlowID) versionStr := r.PathValue(pathParamVersion) @@ -200,7 +207,7 @@ func (h *flowMgtHandler) getFlowVersion(w http.ResponseWriter, r *http.Request) return } - flowVersion, svcErr := h.service.GetFlowVersion(flowID, version) + flowVersion, svcErr := h.service.GetFlowVersion(ctx, flowID, version) if svcErr != nil { handleError(w, svcErr) return @@ -213,6 +220,7 @@ func (h *flowMgtHandler) getFlowVersion(w http.ResponseWriter, r *http.Request) // restoreFlowVersion handles POST requests to restore a specific version of a flow definition. func (h *flowMgtHandler) restoreFlowVersion(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() flowID := r.PathValue(pathParamFlowID) if flowID == "" { handleError(w, &ErrorMissingFlowID) @@ -225,7 +233,7 @@ func (h *flowMgtHandler) restoreFlowVersion(w http.ResponseWriter, r *http.Reque return } - flow, svcErr := h.service.RestoreFlowVersion(flowID, request.Version) + flow, svcErr := h.service.RestoreFlowVersion(ctx, flowID, request.Version) if svcErr != nil { handleError(w, svcErr) return diff --git a/backend/internal/flow/mgt/handler_test.go b/backend/internal/flow/mgt/handler_test.go index 1d2c518bf..b3a87b013 100644 --- a/backend/internal/flow/mgt/handler_test.go +++ b/backend/internal/flow/mgt/handler_test.go @@ -25,6 +25,7 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "github.com/asgardeo/thunder/internal/flow/common" @@ -59,7 +60,7 @@ func (s *FlowMgtHandlerTestSuite) TestListFlows_Success() { Count: 2, } - s.mockService.EXPECT().ListFlows(30, 0, common.FlowType("")).Return(expectedList, nil) + s.mockService.EXPECT().ListFlows(mock.Anything, 30, 0, common.FlowType("")).Return(expectedList, nil) req := httptest.NewRequest(http.MethodGet, "/flows", nil) w := httptest.NewRecorder() @@ -77,7 +78,7 @@ func (s *FlowMgtHandlerTestSuite) TestListFlows_Success() { func (s *FlowMgtHandlerTestSuite) TestListFlows_WithPagination() { expectedList := &FlowListResponse{Flows: []BasicFlowDefinition{}, Count: 0} - s.mockService.EXPECT().ListFlows(20, 10, common.FlowType("")).Return(expectedList, nil) + s.mockService.EXPECT().ListFlows(mock.Anything, 20, 10, common.FlowType("")).Return(expectedList, nil) req := httptest.NewRequest(http.MethodGet, "/flows?limit=20&offset=10", nil) w := httptest.NewRecorder() @@ -90,7 +91,7 @@ func (s *FlowMgtHandlerTestSuite) TestListFlows_WithPagination() { func (s *FlowMgtHandlerTestSuite) TestListFlows_WithFlowType() { expectedList := &FlowListResponse{Flows: []BasicFlowDefinition{}, Count: 0} - s.mockService.EXPECT().ListFlows(30, 0, common.FlowTypeAuthentication).Return(expectedList, nil) + s.mockService.EXPECT().ListFlows(mock.Anything, 30, 0, common.FlowTypeAuthentication).Return(expectedList, nil) req := httptest.NewRequest(http.MethodGet, "/flows?flowType=AUTHENTICATION", nil) w := httptest.NewRecorder() @@ -128,7 +129,7 @@ func (s *FlowMgtHandlerTestSuite) TestListFlows_InvalidOffset() { } func (s *FlowMgtHandlerTestSuite) TestListFlows_ServiceError() { - s.mockService.EXPECT().ListFlows(30, 0, common.FlowType("")). + s.mockService.EXPECT().ListFlows(mock.Anything, 30, 0, common.FlowType("")). Return(nil, &serviceerror.InternalServerError) req := httptest.NewRequest(http.MethodGet, "/flows", nil) @@ -158,7 +159,7 @@ func (s *FlowMgtHandlerTestSuite) TestCreateFlow_Success() { Nodes: flowDef.Nodes, } - s.mockService.EXPECT().CreateFlow(flowDef).Return(createdFlow, nil) + s.mockService.EXPECT().CreateFlow(mock.Anything, flowDef).Return(createdFlow, nil) body, _ := json.Marshal(flowDef) req := httptest.NewRequest(http.MethodPost, "/flows", bytes.NewReader(body)) @@ -192,7 +193,7 @@ func (s *FlowMgtHandlerTestSuite) TestCreateFlow_ServiceError() { FlowType: common.FlowTypeAuthentication, } - s.mockService.EXPECT().CreateFlow(flowDef).Return(nil, &ErrorInvalidFlowData) + s.mockService.EXPECT().CreateFlow(mock.Anything, flowDef).Return(nil, &ErrorInvalidFlowData) body, _ := json.Marshal(flowDef) req := httptest.NewRequest(http.MethodPost, "/flows", bytes.NewReader(body)) @@ -214,7 +215,7 @@ func (s *FlowMgtHandlerTestSuite) TestGetFlow_Success() { FlowType: common.FlowTypeAuthentication, } - s.mockService.EXPECT().GetFlow(testFlowIDHandler).Return(expectedFlow, nil) + s.mockService.EXPECT().GetFlow(mock.Anything, testFlowIDHandler).Return(expectedFlow, nil) req := httptest.NewRequest(http.MethodGet, "/flows/"+testFlowIDHandler, nil) req.SetPathValue(pathParamFlowID, testFlowIDHandler) @@ -239,7 +240,7 @@ func (s *FlowMgtHandlerTestSuite) TestGetFlow_MissingFlowID() { } func (s *FlowMgtHandlerTestSuite) TestGetFlow_NotFound() { - s.mockService.EXPECT().GetFlow(testFlowIDHandler).Return(nil, &ErrorFlowNotFound) + s.mockService.EXPECT().GetFlow(mock.Anything, testFlowIDHandler).Return(nil, &ErrorFlowNotFound) req := httptest.NewRequest(http.MethodGet, "/flows/"+testFlowIDHandler, nil) req.SetPathValue(pathParamFlowID, testFlowIDHandler) @@ -269,7 +270,7 @@ func (s *FlowMgtHandlerTestSuite) TestUpdateFlow_Success() { Nodes: flowDef.Nodes, } - s.mockService.EXPECT().UpdateFlow(testFlowIDHandler, flowDef).Return(updatedFlow, nil) + s.mockService.EXPECT().UpdateFlow(mock.Anything, testFlowIDHandler, flowDef).Return(updatedFlow, nil) body, _ := json.Marshal(flowDef) req := httptest.NewRequest(http.MethodPut, "/flows/"+testFlowIDHandler, bytes.NewReader(body)) @@ -315,7 +316,7 @@ func (s *FlowMgtHandlerTestSuite) TestUpdateFlow_NotFound() { FlowType: common.FlowTypeAuthentication, } - s.mockService.EXPECT().UpdateFlow(testFlowIDHandler, flowDef).Return(nil, &ErrorFlowNotFound) + s.mockService.EXPECT().UpdateFlow(mock.Anything, testFlowIDHandler, flowDef).Return(nil, &ErrorFlowNotFound) body, _ := json.Marshal(flowDef) req := httptest.NewRequest(http.MethodPut, "/flows/"+testFlowIDHandler, bytes.NewReader(body)) @@ -331,7 +332,7 @@ func (s *FlowMgtHandlerTestSuite) TestUpdateFlow_NotFound() { // Test deleteFlow func (s *FlowMgtHandlerTestSuite) TestDeleteFlow_Success() { - s.mockService.EXPECT().DeleteFlow(testFlowIDHandler).Return(nil) + s.mockService.EXPECT().DeleteFlow(mock.Anything, testFlowIDHandler).Return(nil) req := httptest.NewRequest(http.MethodDelete, "/flows/"+testFlowIDHandler, nil) req.SetPathValue(pathParamFlowID, testFlowIDHandler) @@ -352,7 +353,7 @@ func (s *FlowMgtHandlerTestSuite) TestDeleteFlow_MissingFlowID() { } func (s *FlowMgtHandlerTestSuite) TestDeleteFlow_NotFound() { - s.mockService.EXPECT().DeleteFlow(testFlowIDHandler).Return(&ErrorFlowNotFound) + s.mockService.EXPECT().DeleteFlow(mock.Anything, testFlowIDHandler).Return(&ErrorFlowNotFound) req := httptest.NewRequest(http.MethodDelete, "/flows/"+testFlowIDHandler, nil) req.SetPathValue(pathParamFlowID, testFlowIDHandler) @@ -374,7 +375,7 @@ func (s *FlowMgtHandlerTestSuite) TestListFlowVersions_Success() { TotalVersions: 2, } - s.mockService.EXPECT().ListFlowVersions(testFlowIDHandler).Return(expectedList, nil) + s.mockService.EXPECT().ListFlowVersions(mock.Anything, testFlowIDHandler).Return(expectedList, nil) req := httptest.NewRequest(http.MethodGet, "/flows/"+testFlowIDHandler+"/versions", nil) req.SetPathValue(pathParamFlowID, testFlowIDHandler) @@ -399,7 +400,7 @@ func (s *FlowMgtHandlerTestSuite) TestListFlowVersions_MissingFlowID() { } func (s *FlowMgtHandlerTestSuite) TestListFlowVersions_NotFound() { - s.mockService.EXPECT().ListFlowVersions(testFlowIDHandler).Return(nil, &ErrorFlowNotFound) + s.mockService.EXPECT().ListFlowVersions(mock.Anything, testFlowIDHandler).Return(nil, &ErrorFlowNotFound) req := httptest.NewRequest(http.MethodGet, "/flows/"+testFlowIDHandler+"/versions", nil) req.SetPathValue(pathParamFlowID, testFlowIDHandler) @@ -420,7 +421,7 @@ func (s *FlowMgtHandlerTestSuite) TestGetFlowVersion_Success() { Name: "Test Flow", } - s.mockService.EXPECT().GetFlowVersion(testFlowIDHandler, 1).Return(expectedVersion, nil) + s.mockService.EXPECT().GetFlowVersion(mock.Anything, testFlowIDHandler, 1).Return(expectedVersion, nil) req := httptest.NewRequest(http.MethodGet, "/flows/"+testFlowIDHandler+"/versions/1", nil) req.SetPathValue(pathParamFlowID, testFlowIDHandler) @@ -479,7 +480,7 @@ func (s *FlowMgtHandlerTestSuite) TestGetFlowVersion_ZeroVersion() { } func (s *FlowMgtHandlerTestSuite) TestGetFlowVersion_NotFound() { - s.mockService.EXPECT().GetFlowVersion(testFlowIDHandler, 99).Return(nil, &ErrorVersionNotFound) + s.mockService.EXPECT().GetFlowVersion(mock.Anything, testFlowIDHandler, 99).Return(nil, &ErrorVersionNotFound) req := httptest.NewRequest(http.MethodGet, "/flows/"+testFlowIDHandler+"/versions/99", nil) req.SetPathValue(pathParamFlowID, testFlowIDHandler) @@ -502,7 +503,7 @@ func (s *FlowMgtHandlerTestSuite) TestRestoreFlowVersion_Success() { FlowType: common.FlowTypeAuthentication, } - s.mockService.EXPECT().RestoreFlowVersion(testFlowIDHandler, 1).Return(restoredFlow, nil) + s.mockService.EXPECT().RestoreFlowVersion(mock.Anything, testFlowIDHandler, 1).Return(restoredFlow, nil) body, _ := json.Marshal(request) req := httptest.NewRequest(http.MethodPost, "/flows/"+testFlowIDHandler+"/versions/restore", @@ -547,7 +548,7 @@ func (s *FlowMgtHandlerTestSuite) TestRestoreFlowVersion_InvalidJSON() { func (s *FlowMgtHandlerTestSuite) TestRestoreFlowVersion_NotFound() { request := &RestoreVersionRequest{Version: 99} - s.mockService.EXPECT().RestoreFlowVersion(testFlowIDHandler, 99).Return(nil, &ErrorVersionNotFound) + s.mockService.EXPECT().RestoreFlowVersion(mock.Anything, testFlowIDHandler, 99).Return(nil, &ErrorVersionNotFound) body, _ := json.Marshal(request) req := httptest.NewRequest(http.MethodPost, "/flows/"+testFlowIDHandler+"/versions/restore", diff --git a/backend/internal/flow/mgt/init.go b/backend/internal/flow/mgt/init.go index 3842bacc2..c268d2896 100644 --- a/backend/internal/flow/mgt/init.go +++ b/backend/internal/flow/mgt/init.go @@ -24,7 +24,9 @@ import ( "github.com/asgardeo/thunder/internal/flow/core" "github.com/asgardeo/thunder/internal/flow/executor" "github.com/asgardeo/thunder/internal/system/config" + "github.com/asgardeo/thunder/internal/system/database/provider" declarativeresource "github.com/asgardeo/thunder/internal/system/declarative_resource" + "github.com/asgardeo/thunder/internal/system/log" "github.com/asgardeo/thunder/internal/system/middleware" ) @@ -44,7 +46,11 @@ func Initialize( inferenceService := newFlowInferenceService() graphBuilder := newGraphBuilder(flowFactory, executorRegistry, graphCache) - service := newFlowMgtService(store, inferenceService, graphBuilder, executorRegistry) + txer, err := provider.GetDBProvider().GetConfigDBTransactioner() + if err != nil { + log.GetLogger().Fatal("Failed to get config DB transactioner", log.Error(err)) + } + service := newFlowMgtService(store, inferenceService, graphBuilder, executorRegistry, txer) if config.GetThunderRuntime().Config.DeclarativeResources.Enabled { if err := loadDeclarativeResources(store); err != nil { diff --git a/backend/internal/flow/mgt/service.go b/backend/internal/flow/mgt/service.go index 2507d5dd8..01f2ed8e4 100644 --- a/backend/internal/flow/mgt/service.go +++ b/backend/internal/flow/mgt/service.go @@ -20,6 +20,7 @@ package flowmgt import ( + "context" "errors" "fmt" "regexp" @@ -28,6 +29,7 @@ import ( "github.com/asgardeo/thunder/internal/flow/core" "github.com/asgardeo/thunder/internal/flow/executor" "github.com/asgardeo/thunder/internal/system/config" + "github.com/asgardeo/thunder/internal/system/database/transaction" declarativeresource "github.com/asgardeo/thunder/internal/system/declarative_resource" "github.com/asgardeo/thunder/internal/system/error/serviceerror" "github.com/asgardeo/thunder/internal/system/log" @@ -44,17 +46,21 @@ var handleFormatRegex = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]*[a-z0-9]$|^[a-z0 // FlowMgtServiceInterface defines the interface for the flow management service. type FlowMgtServiceInterface interface { - ListFlows(limit, offset int, flowType common.FlowType) (*FlowListResponse, *serviceerror.ServiceError) - CreateFlow(flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError) - GetFlow(flowID string) (*CompleteFlowDefinition, *serviceerror.ServiceError) - GetFlowByHandle(handle string, flowType common.FlowType) (*CompleteFlowDefinition, *serviceerror.ServiceError) - UpdateFlow(flowID string, flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError) - DeleteFlow(flowID string) *serviceerror.ServiceError - ListFlowVersions(flowID string) (*FlowVersionListResponse, *serviceerror.ServiceError) - GetFlowVersion(flowID string, version int) (*FlowVersion, *serviceerror.ServiceError) - RestoreFlowVersion(flowID string, version int) (*CompleteFlowDefinition, *serviceerror.ServiceError) - GetGraph(flowID string) (core.GraphInterface, *serviceerror.ServiceError) - IsValidFlow(flowID string) bool + ListFlows(ctx context.Context, limit, offset int, flowType common.FlowType) (*FlowListResponse, + *serviceerror.ServiceError) + CreateFlow(ctx context.Context, flowDef *FlowDefinition) (*CompleteFlowDefinition, *serviceerror.ServiceError) + GetFlow(ctx context.Context, flowID string) (*CompleteFlowDefinition, *serviceerror.ServiceError) + GetFlowByHandle(ctx context.Context, handle string, flowType common.FlowType) (*CompleteFlowDefinition, + *serviceerror.ServiceError) + UpdateFlow(ctx context.Context, flowID string, flowDef *FlowDefinition) (*CompleteFlowDefinition, + *serviceerror.ServiceError) + DeleteFlow(ctx context.Context, flowID string) *serviceerror.ServiceError + ListFlowVersions(ctx context.Context, flowID string) (*FlowVersionListResponse, *serviceerror.ServiceError) + GetFlowVersion(ctx context.Context, flowID string, version int) (*FlowVersion, *serviceerror.ServiceError) + RestoreFlowVersion(ctx context.Context, flowID string, version int) (*CompleteFlowDefinition, + *serviceerror.ServiceError) + GetGraph(ctx context.Context, flowID string) (core.GraphInterface, *serviceerror.ServiceError) + IsValidFlow(ctx context.Context, flowID string) bool } // flowMgtService is the default implementation of the FlowMgtServiceInterface. @@ -63,6 +69,7 @@ type flowMgtService struct { inferenceService flowInferenceServiceInterface graphBuilder graphBuilderInterface executorRegistry executor.ExecutorRegistryInterface + transactioner transaction.Transactioner logger *log.Logger } @@ -72,12 +79,14 @@ func newFlowMgtService( inferenceService flowInferenceServiceInterface, graphBuilder graphBuilderInterface, executorRegistry executor.ExecutorRegistryInterface, + transactioner transaction.Transactioner, ) FlowMgtServiceInterface { return &flowMgtService{ store: store, inferenceService: inferenceService, graphBuilder: graphBuilder, executorRegistry: executorRegistry, + transactioner: transactioner, logger: log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)), } } @@ -85,7 +94,7 @@ func newFlowMgtService( // Flow management methods // ListFlows retrieves a paginated list of flow definitions. Supports optional filtering by flow type. -func (s *flowMgtService) ListFlows(limit, offset int, flowType common.FlowType) ( +func (s *flowMgtService) ListFlows(ctx context.Context, limit, offset int, flowType common.FlowType) ( *FlowListResponse, *serviceerror.ServiceError) { if limit <= 0 { limit = defaultPageSize @@ -101,7 +110,7 @@ func (s *flowMgtService) ListFlows(limit, offset int, flowType common.FlowType) return nil, &ErrorInvalidFlowType } - flows, totalCount, err := s.store.ListFlows(limit, offset, string(flowType)) + flows, totalCount, err := s.store.ListFlows(ctx, limit, offset, string(flowType)) if err != nil { s.logger.Error("Failed to list flows", log.Error(err)) return nil, &serviceerror.InternalServerError @@ -119,7 +128,7 @@ func (s *flowMgtService) ListFlows(limit, offset int, flowType common.FlowType) } // CreateFlow creates a new flow definition with version 1. -func (s *flowMgtService) CreateFlow(flowDef *FlowDefinition) ( +func (s *flowMgtService) CreateFlow(ctx context.Context, flowDef *FlowDefinition) ( *CompleteFlowDefinition, *serviceerror.ServiceError) { if err := declarativeresource.CheckDeclarativeCreate(); err != nil { return nil, err @@ -129,47 +138,64 @@ func (s *flowMgtService) CreateFlow(flowDef *FlowDefinition) ( return nil, err } - // Check if a flow with the same handle and type already exists - exists, err := s.store.IsFlowExistsByHandle(flowDef.Handle, flowDef.FlowType) - if err != nil { - s.logger.Error("Failed to check flow existence by handle", log.Error(err)) - return nil, &serviceerror.InternalServerError - } - if exists { - return nil, &ErrorDuplicateFlowHandle - } + var createdFlow *CompleteFlowDefinition + var capturedSvcErr *serviceerror.ServiceError - svcErr := s.applyExecutorDefaultMeta(flowDef) - if svcErr != nil { - return nil, svcErr - } + err := s.transactioner.Transact(ctx, func(ctx context.Context) error { + // Check if a flow with the same handle and type already exists + exists, err := s.store.IsFlowExistsByHandle(ctx, flowDef.Handle, flowDef.FlowType) + if err != nil { + s.logger.Error("Failed to check flow existence by handle", log.Error(err)) + return err + } + if exists { + capturedSvcErr = &ErrorDuplicateFlowHandle + return errors.New("rollback") + } - flowID, genErr := utils.GenerateUUIDv7() - if genErr != nil { - s.logger.Error("Failed to generate UUID v7", log.Error(genErr)) - return nil, &serviceerror.InternalServerError + svcErr := s.applyExecutorDefaultMeta(flowDef) + if svcErr != nil { + capturedSvcErr = svcErr + return errors.New("rollback") + } + + flowID, genErr := utils.GenerateUUIDv7() + if genErr != nil { + s.logger.Error("Failed to generate UUID v7", log.Error(genErr)) + return genErr + } + + createdFlow, err = s.store.CreateFlow(ctx, flowID, flowDef) + if err != nil { + s.logger.Error("Failed to create flow", log.Error(err)) + return err + } + + s.logger.Debug("Flow created successfully", log.String(logKeyFlowID, flowID)) + + s.tryInferRegistrationFlow(ctx, flowID, flowDef) + return nil + }) + + if capturedSvcErr != nil { + return nil, capturedSvcErr } - createdFlow, storeErr := s.store.CreateFlow(flowID, flowDef) - if storeErr != nil { - s.logger.Error("Failed to create flow", log.Error(storeErr)) + if err != nil { return nil, &serviceerror.InternalServerError } - s.logger.Debug("Flow created successfully", log.String(logKeyFlowID, flowID)) - - s.tryInferRegistrationFlow(flowID, flowDef) - return createdFlow, nil } // GetFlow retrieves a flow definition by its ID. -func (s *flowMgtService) GetFlow(flowID string) (*CompleteFlowDefinition, *serviceerror.ServiceError) { +func (s *flowMgtService) GetFlow(ctx context.Context, flowID string) (*CompleteFlowDefinition, + *serviceerror.ServiceError) { if flowID == "" { return nil, &ErrorMissingFlowID } - flow, err := s.store.GetFlowByID(flowID) + flow, err := s.store.GetFlowByID(ctx, flowID) if err != nil { if errors.Is(err, errFlowNotFound) { return nil, &ErrorFlowNotFound @@ -182,7 +208,7 @@ func (s *flowMgtService) GetFlow(flowID string) (*CompleteFlowDefinition, *servi } // GetFlowByHandle retrieves a flow definition by its handle and type. -func (s *flowMgtService) GetFlowByHandle(handle string, flowType common.FlowType) ( +func (s *flowMgtService) GetFlowByHandle(ctx context.Context, handle string, flowType common.FlowType) ( *CompleteFlowDefinition, *serviceerror.ServiceError) { if handle == "" { return nil, &ErrorMissingFlowHandle @@ -191,7 +217,7 @@ func (s *flowMgtService) GetFlowByHandle(handle string, flowType common.FlowType return nil, &ErrorInvalidFlowType } - flow, err := s.store.GetFlowByHandle(handle, flowType) + flow, err := s.store.GetFlowByHandle(ctx, handle, flowType) if err != nil { if errors.Is(err, errFlowNotFound) { return nil, &ErrorFlowNotFound @@ -206,7 +232,7 @@ func (s *flowMgtService) GetFlowByHandle(handle string, flowType common.FlowType // UpdateFlow updates an existing flow definition with the incremented version. // Old versions are retained up to the configured max_version_history limit. -func (s *flowMgtService) UpdateFlow(flowID string, flowDef *FlowDefinition) ( +func (s *flowMgtService) UpdateFlow(ctx context.Context, flowID string, flowDef *FlowDefinition) ( *CompleteFlowDefinition, *serviceerror.ServiceError) { if err := declarativeresource.CheckDeclarativeUpdate(); err != nil { return nil, err @@ -221,47 +247,67 @@ func (s *flowMgtService) UpdateFlow(flowID string, flowDef *FlowDefinition) ( logger := s.logger.With(log.String(logKeyFlowID, flowID)) - // Verify the flow exists before updating - existingFlow, err := s.store.GetFlowByID(flowID) - if err != nil { - if errors.Is(err, errFlowNotFound) { - return nil, &ErrorFlowNotFound + var updatedFlow *CompleteFlowDefinition + var capturedSvcErr *serviceerror.ServiceError + + errTx := s.transactioner.Transact(ctx, func(ctx context.Context) error { + // Verify the flow exists before updating + existingFlow, err := s.store.GetFlowByID(ctx, flowID) + if err != nil { + if errors.Is(err, errFlowNotFound) { + capturedSvcErr = &ErrorFlowNotFound + return errors.New("rollback") + } + logger.Error("Failed to get existing flow", log.Error(err)) + return err } - logger.Error("Failed to get existing flow", log.Error(err)) - return nil, &serviceerror.InternalServerError - } - // Prevent changing the flow type - if existingFlow.FlowType != flowDef.FlowType { - return nil, &ErrorCannotUpdateFlowType - } + // Prevent changing the flow type + if existingFlow.FlowType != flowDef.FlowType { + capturedSvcErr = &ErrorCannotUpdateFlowType + return errors.New("rollback") + } - // Prevent changing the handle - if existingFlow.Handle != flowDef.Handle { - return nil, &ErrorHandleUpdateNotAllowed - } + // Prevent changing the handle + if existingFlow.Handle != flowDef.Handle { + capturedSvcErr = &ErrorHandleUpdateNotAllowed + return errors.New("rollback") + } + + svcErr := s.applyExecutorDefaultMeta(flowDef) + if svcErr != nil { + capturedSvcErr = svcErr + return errors.New("rollback") + } - svcErr := s.applyExecutorDefaultMeta(flowDef) - if svcErr != nil { - return nil, svcErr + var errStore error + updatedFlow, errStore = s.store.UpdateFlow(ctx, flowID, flowDef) + if errStore != nil { + logger.Error("Failed to update flow", log.Error(errStore)) + return errStore + } + + logger.Debug("Flow updated successfully") + + // Invalidate the cached graph since the flow has been updated + s.graphBuilder.InvalidateCache(flowID) + + return nil + }) + + if capturedSvcErr != nil { + return nil, capturedSvcErr } - updatedFlow, err := s.store.UpdateFlow(flowID, flowDef) - if err != nil { - logger.Error("Failed to update flow", log.Error(err)) + if errTx != nil { return nil, &serviceerror.InternalServerError } - logger.Debug("Flow updated successfully") - - // Invalidate the cached graph since the flow has been updated - s.graphBuilder.InvalidateCache(flowID) - return updatedFlow, nil } // DeleteFlow deletes a flow definition and all its version history. -func (s *flowMgtService) DeleteFlow(flowID string) *serviceerror.ServiceError { +func (s *flowMgtService) DeleteFlow(ctx context.Context, flowID string) *serviceerror.ServiceError { if err := declarativeresource.CheckDeclarativeDelete(); err != nil { return err } @@ -272,26 +318,33 @@ func (s *flowMgtService) DeleteFlow(flowID string) *serviceerror.ServiceError { logger := s.logger.With(log.String(logKeyFlowID, flowID)) - _, err := s.store.GetFlowByID(flowID) - if err != nil { - if errors.Is(err, errFlowNotFound) { - // Silently return if the flow does not exist - return nil + errTx := s.transactioner.Transact(ctx, func(ctx context.Context) error { + _, err := s.store.GetFlowByID(ctx, flowID) + if err != nil { + if errors.Is(err, errFlowNotFound) { + // Silently return if the flow does not exist + return nil + } + logger.Error("Failed to get existing flow", log.Error(err)) + return err } - logger.Error("Failed to get existing flow", log.Error(err)) - return &serviceerror.InternalServerError - } - err = s.store.DeleteFlow(flowID) - if err != nil { - logger.Error("Failed to delete flow", log.Error(err)) - return &serviceerror.InternalServerError - } + err = s.store.DeleteFlow(ctx, flowID) + if err != nil { + logger.Error("Failed to delete flow", log.Error(err)) + return err + } - logger.Debug("Flow deleted successfully") + logger.Debug("Flow deleted successfully") - // Invalidate the cached graph since the flow has been deleted - s.graphBuilder.InvalidateCache(flowID) + // Invalidate the cached graph since the flow has been deleted + s.graphBuilder.InvalidateCache(flowID) + return nil + }) + + if errTx != nil { + return &serviceerror.InternalServerError + } return nil } @@ -299,7 +352,7 @@ func (s *flowMgtService) DeleteFlow(flowID string) *serviceerror.ServiceError { // Flow version management methods // ListFlowVersions retrieves all versions of a flow definition. -func (s *flowMgtService) ListFlowVersions(flowID string) ( +func (s *flowMgtService) ListFlowVersions(ctx context.Context, flowID string) ( *FlowVersionListResponse, *serviceerror.ServiceError) { if flowID == "" { return nil, &ErrorMissingFlowID @@ -307,7 +360,7 @@ func (s *flowMgtService) ListFlowVersions(flowID string) ( logger := s.logger.With(log.String(logKeyFlowID, flowID)) - _, err := s.store.GetFlowByID(flowID) + _, err := s.store.GetFlowByID(ctx, flowID) if err != nil { if errors.Is(err, errFlowNotFound) { return nil, &ErrorFlowNotFound @@ -316,7 +369,7 @@ func (s *flowMgtService) ListFlowVersions(flowID string) ( return nil, &serviceerror.InternalServerError } - versions, err := s.store.ListFlowVersions(flowID) + versions, err := s.store.ListFlowVersions(ctx, flowID) if err != nil { logger.Error("Failed to list flow versions", log.Error(err)) return nil, &serviceerror.InternalServerError @@ -331,7 +384,7 @@ func (s *flowMgtService) ListFlowVersions(flowID string) ( } // GetFlowVersion retrieves a specific version of a flow definition. -func (s *flowMgtService) GetFlowVersion(flowID string, version int) ( +func (s *flowMgtService) GetFlowVersion(ctx context.Context, flowID string, version int) ( *FlowVersion, *serviceerror.ServiceError) { if flowID == "" { return nil, &ErrorMissingFlowID @@ -340,7 +393,7 @@ func (s *flowMgtService) GetFlowVersion(flowID string, version int) ( return nil, &ErrorInvalidVersion } - flowVersion, err := s.store.GetFlowVersion(flowID, version) + flowVersion, err := s.store.GetFlowVersion(ctx, flowID, version) if err != nil { if errors.Is(err, errFlowNotFound) { return nil, &ErrorFlowNotFound @@ -358,7 +411,7 @@ func (s *flowMgtService) GetFlowVersion(flowID string, version int) ( // RestoreFlowVersion restores a specific version as the active version. // Creates a new version by copying the configuration from the specified version. -func (s *flowMgtService) RestoreFlowVersion(flowID string, version int) ( +func (s *flowMgtService) RestoreFlowVersion(ctx context.Context, flowID string, version int) ( *CompleteFlowDefinition, *serviceerror.ServiceError) { if flowID == "" { return nil, &ErrorMissingFlowID @@ -369,42 +422,60 @@ func (s *flowMgtService) RestoreFlowVersion(flowID string, version int) ( logger := s.logger.With(log.String(logKeyFlowID, flowID), log.Int(logKeyVersion, version)) - _, err := s.store.GetFlowVersion(flowID, version) - if err != nil { - if errors.Is(err, errFlowNotFound) { - return nil, &ErrorFlowNotFound + var restoredFlow *CompleteFlowDefinition + var capturedSvcErr *serviceerror.ServiceError + + errTx := s.transactioner.Transact(ctx, func(ctx context.Context) error { + _, err := s.store.GetFlowVersion(ctx, flowID, version) + if err != nil { + if errors.Is(err, errFlowNotFound) { + capturedSvcErr = &ErrorFlowNotFound + return errors.New("rollback") + } + if errors.Is(err, errVersionNotFound) { + capturedSvcErr = &ErrorVersionNotFound + return errors.New("rollback") + } + logger.Error("Failed to get flow version for restore", log.Error(err)) + return err } - if errors.Is(err, errVersionNotFound) { - return nil, &ErrorVersionNotFound + + var errStore error + restoredFlow, errStore = s.store.RestoreFlowVersion(ctx, flowID, version) + if errStore != nil { + logger.Error("Failed to restore flow version", log.Error(errStore)) + return errStore } - logger.Error("Failed to get flow version for restore", log.Error(err)) - return nil, &serviceerror.InternalServerError + + logger.Debug("Flow version restored successfully") + + // Invalidate the cached graph since a version has been restored + s.graphBuilder.InvalidateCache(flowID) + return nil + }) + + if capturedSvcErr != nil { + return nil, capturedSvcErr } - restoredFlow, err := s.store.RestoreFlowVersion(flowID, version) - if err != nil { - logger.Error("Failed to restore flow version", log.Error(err)) + if errTx != nil { return nil, &serviceerror.InternalServerError } - logger.Debug("Flow version restored successfully") - - // Invalidate the cached graph since a version has been restored - s.graphBuilder.InvalidateCache(flowID) - return restoredFlow, nil } // Graph building methods // GetGraph retrieves or builds a graph for the given flow ID. -func (s *flowMgtService) GetGraph(flowID string) (core.GraphInterface, *serviceerror.ServiceError) { +func (s *flowMgtService) GetGraph(ctx context.Context, flowID string) (core.GraphInterface, + *serviceerror.ServiceError) { if flowID == "" { return nil, &ErrorMissingFlowID } // Fetch flow definition from store - flow, err := s.store.GetFlowByID(flowID) + flow, err := s.store.GetFlowByID(ctx, flowID) if err != nil { if errors.Is(err, errFlowNotFound) { return nil, &ErrorFlowNotFound @@ -418,12 +489,12 @@ func (s *flowMgtService) GetGraph(flowID string) (core.GraphInterface, *servicee } // IsValidFlow checks if a valid flow exists for the given flow ID. -func (s *flowMgtService) IsValidFlow(flowID string) bool { +func (s *flowMgtService) IsValidFlow(ctx context.Context, flowID string) bool { if flowID == "" { return false } - exists, err := s.store.IsFlowExists(flowID) + exists, err := s.store.IsFlowExists(ctx, flowID) if err != nil { s.logger.Error("Failed to check flow existence", log.String(logKeyFlowID, flowID), log.Error(err)) return false @@ -522,7 +593,7 @@ func isValidHandleFormat(handle string) bool { } // tryInferRegistrationFlow attempts to infer and create a registration flow from an authentication flow -func (s *flowMgtService) tryInferRegistrationFlow(authFlowID string, authFlowDef *FlowDefinition) { +func (s *flowMgtService) tryInferRegistrationFlow(ctx context.Context, authFlowID string, authFlowDef *FlowDefinition) { logger := s.logger.With(log.String("authFlowID", authFlowID)) if !config.GetThunderRuntime().Config.Flow.AutoInferRegistration { @@ -558,7 +629,7 @@ func (s *flowMgtService) tryInferRegistrationFlow(authFlowID string, authFlowDef return } - _, storeErr := s.store.CreateFlow(regFlowID, regFlowDef) + _, storeErr := s.store.CreateFlow(ctx, regFlowID, regFlowDef) if storeErr != nil { logger.Error("Failed to create inferred registration flow", log.Error(storeErr)) return diff --git a/backend/internal/flow/mgt/service_test.go b/backend/internal/flow/mgt/service_test.go index a026c5704..58029e1b1 100644 --- a/backend/internal/flow/mgt/service_test.go +++ b/backend/internal/flow/mgt/service_test.go @@ -19,6 +19,7 @@ package flowmgt import ( + "context" "errors" "testing" @@ -42,6 +43,13 @@ type FlowMgtServiceTestSuite struct { mockInference *flowInferenceServiceInterfaceMock mockGraphBuilder *graphBuilderInterfaceMock mockExecutorRegistry *executormock.ExecutorRegistryInterfaceMock + stubTransactioner *stubTransactioner +} + +type stubTransactioner struct{} + +func (s *stubTransactioner) Transact(ctx context.Context, fn func(context.Context) error) error { + return fn(ctx) } func TestFlowMgtServiceTestSuite(t *testing.T) { @@ -53,7 +61,9 @@ func (s *FlowMgtServiceTestSuite) SetupTest() { s.mockInference = newFlowInferenceServiceInterfaceMock(s.T()) s.mockGraphBuilder = newGraphBuilderInterfaceMock(s.T()) s.mockExecutorRegistry = executormock.NewExecutorRegistryInterfaceMock(s.T()) - s.service = newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, s.mockExecutorRegistry) + s.stubTransactioner = &stubTransactioner{} + s.service = newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, s.mockExecutorRegistry, + s.stubTransactioner) testConfig := &config.Config{ Flow: config.FlowConfig{ @@ -74,9 +84,9 @@ func (s *FlowMgtServiceTestSuite) TestListFlows_Success() { expectedFlows := []BasicFlowDefinition{ {ID: "flow1", Handle: "test-handle", Name: "Flow 1", FlowType: common.FlowTypeAuthentication}, } - s.mockStore.EXPECT().ListFlows(30, 0, "").Return(expectedFlows, 1, nil) + s.mockStore.EXPECT().ListFlows(mock.Anything, 30, 0, "").Return(expectedFlows, 1, nil) - result, err := s.service.ListFlows(30, 0, "") + result, err := s.service.ListFlows(context.Background(), 30, 0, "") s.Nil(err) s.NotNil(result) @@ -86,62 +96,62 @@ func (s *FlowMgtServiceTestSuite) TestListFlows_Success() { } func (s *FlowMgtServiceTestSuite) TestListFlows_DefaultLimit() { - s.mockStore.EXPECT().ListFlows(defaultPageSize, 0, "").Return([]BasicFlowDefinition{}, 0, nil) + s.mockStore.EXPECT().ListFlows(mock.Anything, defaultPageSize, 0, "").Return([]BasicFlowDefinition{}, 0, nil) - result, err := s.service.ListFlows(0, 0, "") + result, err := s.service.ListFlows(context.Background(), 0, 0, "") s.Nil(err) s.NotNil(result) } func (s *FlowMgtServiceTestSuite) TestListFlows_MaxLimitExceeded() { - s.mockStore.EXPECT().ListFlows(maxPageSize, 0, "").Return([]BasicFlowDefinition{}, 0, nil) + s.mockStore.EXPECT().ListFlows(mock.Anything, maxPageSize, 0, "").Return([]BasicFlowDefinition{}, 0, nil) - result, err := s.service.ListFlows(1000, 0, "") + result, err := s.service.ListFlows(context.Background(), 1000, 0, "") s.Nil(err) s.NotNil(result) } func (s *FlowMgtServiceTestSuite) TestListFlows_NegativeOffset() { - s.mockStore.EXPECT().ListFlows(30, 0, "").Return([]BasicFlowDefinition{}, 0, nil) + s.mockStore.EXPECT().ListFlows(mock.Anything, 30, 0, "").Return([]BasicFlowDefinition{}, 0, nil) - result, err := s.service.ListFlows(30, -10, "") + result, err := s.service.ListFlows(context.Background(), 30, -10, "") s.Nil(err) s.NotNil(result) } func (s *FlowMgtServiceTestSuite) TestListFlows_WithFlowType() { - s.mockStore.EXPECT().ListFlows(30, 0, string(common.FlowTypeAuthentication)). + s.mockStore.EXPECT().ListFlows(mock.Anything, 30, 0, string(common.FlowTypeAuthentication)). Return([]BasicFlowDefinition{}, 0, nil) - result, err := s.service.ListFlows(30, 0, common.FlowTypeAuthentication) + result, err := s.service.ListFlows(context.Background(), 30, 0, common.FlowTypeAuthentication) s.Nil(err) s.NotNil(result) } func (s *FlowMgtServiceTestSuite) TestListFlows_InvalidFlowType() { - result, err := s.service.ListFlows(30, 0, "invalid") + result, err := s.service.ListFlows(context.Background(), 30, 0, "invalid") s.Nil(result) s.Equal(&ErrorInvalidFlowType, err) } func (s *FlowMgtServiceTestSuite) TestListFlows_StoreError() { - s.mockStore.EXPECT().ListFlows(30, 0, "").Return(nil, 0, errors.New("db error")) + s.mockStore.EXPECT().ListFlows(mock.Anything, 30, 0, "").Return(nil, 0, errors.New("db error")) - result, err := s.service.ListFlows(30, 0, "") + result, err := s.service.ListFlows(context.Background(), 30, 0, "") s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) } func (s *FlowMgtServiceTestSuite) TestListFlows_PaginationLinks() { - s.mockStore.EXPECT().ListFlows(10, 20, "").Return([]BasicFlowDefinition{}, 100, nil) + s.mockStore.EXPECT().ListFlows(mock.Anything, 10, 20, "").Return([]BasicFlowDefinition{}, 100, nil) - result, err := s.service.ListFlows(10, 20, "") + result, err := s.service.ListFlows(context.Background(), 10, 20, "") s.Nil(err) s.NotNil(result) @@ -150,9 +160,9 @@ func (s *FlowMgtServiceTestSuite) TestListFlows_PaginationLinks() { } func (s *FlowMgtServiceTestSuite) TestListFlows_PaginationLinksFirstPage() { - s.mockStore.EXPECT().ListFlows(10, 0, "").Return([]BasicFlowDefinition{}, 100, nil) + s.mockStore.EXPECT().ListFlows(mock.Anything, 10, 0, "").Return([]BasicFlowDefinition{}, 100, nil) - result, err := s.service.ListFlows(10, 0, "") + result, err := s.service.ListFlows(context.Background(), 10, 0, "") s.Nil(err) s.NotNil(result) @@ -161,9 +171,9 @@ func (s *FlowMgtServiceTestSuite) TestListFlows_PaginationLinksFirstPage() { } func (s *FlowMgtServiceTestSuite) TestListFlows_PaginationLinksLastPage() { - s.mockStore.EXPECT().ListFlows(10, 90, "").Return([]BasicFlowDefinition{}, 100, nil) + s.mockStore.EXPECT().ListFlows(mock.Anything, 10, 90, "").Return([]BasicFlowDefinition{}, 100, nil) - result, err := s.service.ListFlows(10, 90, "") + result, err := s.service.ListFlows(context.Background(), 10, 90, "") s.Nil(err) s.NotNil(result) @@ -190,10 +200,11 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_Success() { FlowType: common.FlowTypeAuthentication, ActiveVersion: 1, } - s.mockStore.EXPECT().IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication).Return(false, nil) - s.mockStore.EXPECT().CreateFlow(mock.Anything, flowDef).Return(expectedFlow, nil) + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "test-handle", common.FlowTypeAuthentication). + Return(false, nil) + s.mockStore.EXPECT().CreateFlow(mock.Anything, mock.Anything, flowDef).Return(expectedFlow, nil) - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(err) s.NotNil(result) @@ -208,7 +219,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_ValidationError() { Nodes: []NodeDefinition{{Type: "start"}, {Type: "end"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&ErrorMissingFlowHandle, err) @@ -222,7 +233,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_InvalidHandleFormat_Uppercase() Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&ErrorInvalidFlowHandleFormat, err) @@ -236,7 +247,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_InvalidHandleFormat_Spaces() { Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&ErrorInvalidFlowHandleFormat, err) @@ -250,7 +261,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_InvalidHandleFormat_SpecialChar Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&ErrorInvalidFlowHandleFormat, err) @@ -264,7 +275,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_InvalidHandleFormat_StartsWithD Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&ErrorInvalidFlowHandleFormat, err) @@ -278,7 +289,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_InvalidHandleFormat_EndsWithUnd Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&ErrorInvalidFlowHandleFormat, err) @@ -318,10 +329,11 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_ValidHandleFormats() { Nodes: flowDef.Nodes, } - s.mockStore.EXPECT().IsFlowExistsByHandle(tc.handle, common.FlowTypeAuthentication).Return(false, nil) - s.mockStore.EXPECT().CreateFlow(mock.Anything, flowDef).Return(expectedFlow, nil) + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, tc.handle, common.FlowTypeAuthentication). + Return(false, nil) + s.mockStore.EXPECT().CreateFlow(mock.Anything, mock.Anything, flowDef).Return(expectedFlow, nil) - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(err) s.NotNil(result) @@ -338,7 +350,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_InvalidFlowType() { Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&ErrorInvalidFlowType, err) @@ -352,7 +364,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_InsufficientNodes() { Nodes: []NodeDefinition{{Type: "start"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(ErrorInvalidFlowData.Code, err.Code) @@ -366,7 +378,7 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_OnlyStartAndEnd() { Nodes: []NodeDefinition{{Type: "start"}, {Type: "end"}}, } - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(ErrorInvalidFlowData.Code, err.Code) @@ -379,10 +391,11 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_StoreError() { FlowType: common.FlowTypeAuthentication, Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - s.mockStore.EXPECT().IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication).Return(false, nil) - s.mockStore.EXPECT().CreateFlow(mock.Anything, flowDef).Return(nil, errors.New("db error")) + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "test-handle", common.FlowTypeAuthentication). + Return(false, nil) + s.mockStore.EXPECT().CreateFlow(mock.Anything, mock.Anything, flowDef).Return(nil, errors.New("db error")) - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -418,12 +431,13 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_WithAutoInference() { Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - s.mockStore.EXPECT().IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication).Return(false, nil) - s.mockStore.EXPECT().CreateFlow(mock.Anything, flowDef).Return(expectedFlow, nil) + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "test-handle", common.FlowTypeAuthentication). + Return(false, nil) + s.mockStore.EXPECT().CreateFlow(mock.Anything, mock.Anything, flowDef).Return(expectedFlow, nil) s.mockInference.EXPECT().InferRegistrationFlow(flowDef).Return(inferredRegFlow, nil) - s.mockStore.EXPECT().CreateFlow(mock.Anything, inferredRegFlow).Return(nil, nil) + s.mockStore.EXPECT().CreateFlow(mock.Anything, mock.Anything, inferredRegFlow).Return(nil, nil) - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(err) s.NotNil(result) @@ -454,12 +468,13 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_AutoInferenceFailure() { } // Mock expectations in the correct order of execution - s.mockStore.EXPECT().IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication).Return(false, nil) - s.mockStore.EXPECT().CreateFlow(mock.Anything, flowDef).Return(expectedFlow, nil) + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "test-handle", common.FlowTypeAuthentication). + Return(false, nil) + s.mockStore.EXPECT().CreateFlow(mock.Anything, mock.Anything, flowDef).Return(expectedFlow, nil) s.mockInference.EXPECT().InferRegistrationFlow(flowDef).Return(nil, errors.New("inference error")) // Should still succeed even if inference fails - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(err) s.NotNil(result) @@ -472,10 +487,10 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_DuplicateHandle() { FlowType: common.FlowTypeAuthentication, Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - s.mockStore.EXPECT().IsFlowExistsByHandle("existing-handle", common.FlowTypeAuthentication).Return( + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "existing-handle", common.FlowTypeAuthentication).Return( true, nil) - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&ErrorDuplicateFlowHandle, err) @@ -488,10 +503,10 @@ func (s *FlowMgtServiceTestSuite) TestCreateFlow_DuplicateHandleCheckError() { FlowType: common.FlowTypeAuthentication, Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - s.mockStore.EXPECT().IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication).Return( + s.mockStore.EXPECT().IsFlowExistsByHandle(mock.Anything, "test-handle", common.FlowTypeAuthentication).Return( false, errors.New("db error")) - result, err := s.service.CreateFlow(flowDef) + result, err := s.service.CreateFlow(context.Background(), flowDef) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -505,34 +520,34 @@ func (s *FlowMgtServiceTestSuite) TestGetFlow_Success() { Handle: "test-handle", Name: "Test", } - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(expectedFlow, nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(expectedFlow, nil) - result, err := s.service.GetFlow(testFlowIDService) + result, err := s.service.GetFlow(context.Background(), testFlowIDService) s.Nil(err) s.Equal(expectedFlow, result) } func (s *FlowMgtServiceTestSuite) TestGetFlow_EmptyID() { - result, err := s.service.GetFlow("") + result, err := s.service.GetFlow(context.Background(), "") s.Nil(result) s.Equal(&ErrorMissingFlowID, err) } func (s *FlowMgtServiceTestSuite) TestGetFlow_NotFound() { - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(nil, errFlowNotFound) - result, err := s.service.GetFlow(testFlowIDService) + result, err := s.service.GetFlow(context.Background(), testFlowIDService) s.Nil(result) s.Equal(&ErrorFlowNotFound, err) } func (s *FlowMgtServiceTestSuite) TestGetFlow_StoreError() { - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(nil, errors.New("db error")) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(nil, errors.New("db error")) - result, err := s.service.GetFlow(testFlowIDService) + result, err := s.service.GetFlow(context.Background(), testFlowIDService) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -547,10 +562,10 @@ func (s *FlowMgtServiceTestSuite) TestGetFlowByHandle_Success() { Name: "Test Auth Flow", FlowType: common.FlowTypeAuthentication, } - s.mockStore.EXPECT().GetFlowByHandle("test-auth-flow", common.FlowTypeAuthentication). + s.mockStore.EXPECT().GetFlowByHandle(mock.Anything, "test-auth-flow", common.FlowTypeAuthentication). Return(expectedFlow, nil) - result, err := s.service.GetFlowByHandle("test-auth-flow", common.FlowTypeAuthentication) + result, err := s.service.GetFlowByHandle(context.Background(), "test-auth-flow", common.FlowTypeAuthentication) s.Nil(err) s.Equal(expectedFlow, result) @@ -565,10 +580,10 @@ func (s *FlowMgtServiceTestSuite) TestGetFlowByHandle_SuccessRegistrationFlow() Name: "Test Registration Flow", FlowType: common.FlowTypeRegistration, } - s.mockStore.EXPECT().GetFlowByHandle("test-reg-flow", common.FlowTypeRegistration). + s.mockStore.EXPECT().GetFlowByHandle(mock.Anything, "test-reg-flow", common.FlowTypeRegistration). Return(expectedFlow, nil) - result, err := s.service.GetFlowByHandle("test-reg-flow", common.FlowTypeRegistration) + result, err := s.service.GetFlowByHandle(context.Background(), "test-reg-flow", common.FlowTypeRegistration) s.Nil(err) s.Equal(expectedFlow, result) @@ -577,41 +592,41 @@ func (s *FlowMgtServiceTestSuite) TestGetFlowByHandle_SuccessRegistrationFlow() } func (s *FlowMgtServiceTestSuite) TestGetFlowByHandle_EmptyHandle() { - result, err := s.service.GetFlowByHandle("", common.FlowTypeAuthentication) + result, err := s.service.GetFlowByHandle(context.Background(), "", common.FlowTypeAuthentication) s.Nil(result) s.Equal(&ErrorMissingFlowHandle, err) } func (s *FlowMgtServiceTestSuite) TestGetFlowByHandle_InvalidFlowType() { - result, err := s.service.GetFlowByHandle("test-handle", "INVALID_TYPE") + result, err := s.service.GetFlowByHandle(context.Background(), "test-handle", "INVALID_TYPE") s.Nil(result) s.Equal(&ErrorInvalidFlowType, err) } func (s *FlowMgtServiceTestSuite) TestGetFlowByHandle_EmptyFlowType() { - result, err := s.service.GetFlowByHandle("test-handle", "") + result, err := s.service.GetFlowByHandle(context.Background(), "test-handle", "") s.Nil(result) s.Equal(&ErrorInvalidFlowType, err) } func (s *FlowMgtServiceTestSuite) TestGetFlowByHandle_NotFound() { - s.mockStore.EXPECT().GetFlowByHandle("non-existent-handle", common.FlowTypeAuthentication). + s.mockStore.EXPECT().GetFlowByHandle(mock.Anything, "non-existent-handle", common.FlowTypeAuthentication). Return(nil, errFlowNotFound) - result, err := s.service.GetFlowByHandle("non-existent-handle", common.FlowTypeAuthentication) + result, err := s.service.GetFlowByHandle(context.Background(), "non-existent-handle", common.FlowTypeAuthentication) s.Nil(result) s.Equal(&ErrorFlowNotFound, err) } func (s *FlowMgtServiceTestSuite) TestGetFlowByHandle_StoreError() { - s.mockStore.EXPECT().GetFlowByHandle("test-handle", common.FlowTypeAuthentication). + s.mockStore.EXPECT().GetFlowByHandle(mock.Anything, "test-handle", common.FlowTypeAuthentication). Return(nil, errors.New("database connection error")) - result, err := s.service.GetFlowByHandle("test-handle", common.FlowTypeAuthentication) + result, err := s.service.GetFlowByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -636,11 +651,11 @@ func (s *FlowMgtServiceTestSuite) TestUpdateFlow_Success() { Name: "Updated", ActiveVersion: 2, } - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(existingFlow, nil) - s.mockStore.EXPECT().UpdateFlow(testFlowIDService, flowDef).Return(updatedFlow, nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(existingFlow, nil) + s.mockStore.EXPECT().UpdateFlow(mock.Anything, testFlowIDService, flowDef).Return(updatedFlow, nil) s.mockGraphBuilder.EXPECT().InvalidateCache(testFlowIDService) - result, err := s.service.UpdateFlow(testFlowIDService, flowDef) + result, err := s.service.UpdateFlow(context.Background(), testFlowIDService, flowDef) s.Nil(err) s.Equal(updatedFlow, result) @@ -649,7 +664,7 @@ func (s *FlowMgtServiceTestSuite) TestUpdateFlow_Success() { func (s *FlowMgtServiceTestSuite) TestUpdateFlow_EmptyID() { flowDef := &FlowDefinition{Name: "Test", FlowType: common.FlowTypeAuthentication} - result, err := s.service.UpdateFlow("", flowDef) + result, err := s.service.UpdateFlow(context.Background(), "", flowDef) s.Nil(result) s.Equal(&ErrorMissingFlowID, err) @@ -658,7 +673,7 @@ func (s *FlowMgtServiceTestSuite) TestUpdateFlow_EmptyID() { func (s *FlowMgtServiceTestSuite) TestUpdateFlow_ValidationError() { flowDef := &FlowDefinition{Handle: "", Name: "", FlowType: common.FlowTypeAuthentication} - result, err := s.service.UpdateFlow(testFlowIDService, flowDef) + result, err := s.service.UpdateFlow(context.Background(), testFlowIDService, flowDef) s.Nil(result) s.Equal(&ErrorMissingFlowHandle, err) @@ -671,9 +686,9 @@ func (s *FlowMgtServiceTestSuite) TestUpdateFlow_FlowNotFound() { FlowType: common.FlowTypeAuthentication, Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(nil, errFlowNotFound) - result, err := s.service.UpdateFlow(testFlowIDService, flowDef) + result, err := s.service.UpdateFlow(context.Background(), testFlowIDService, flowDef) s.Nil(result) s.Equal(&ErrorFlowNotFound, err) @@ -691,9 +706,9 @@ func (s *FlowMgtServiceTestSuite) TestUpdateFlow_CannotChangeFlowType() { FlowType: common.FlowTypeRegistration, Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(existingFlow, nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(existingFlow, nil) - result, err := s.service.UpdateFlow(testFlowIDService, flowDef) + result, err := s.service.UpdateFlow(context.Background(), testFlowIDService, flowDef) s.Nil(result) s.Equal(&ErrorCannotUpdateFlowType, err) @@ -711,9 +726,9 @@ func (s *FlowMgtServiceTestSuite) TestUpdateFlow_CannotChangeHandle() { FlowType: common.FlowTypeAuthentication, Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(existingFlow, nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(existingFlow, nil) - result, err := s.service.UpdateFlow(testFlowIDService, flowDef) + result, err := s.service.UpdateFlow(context.Background(), testFlowIDService, flowDef) s.Nil(result) s.Equal(&ErrorHandleUpdateNotAllowed, err) @@ -731,10 +746,10 @@ func (s *FlowMgtServiceTestSuite) TestUpdateFlow_StoreError() { FlowType: common.FlowTypeAuthentication, Nodes: []NodeDefinition{{Type: "start"}, {Type: "action"}, {Type: "end"}}, } - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(existingFlow, nil) - s.mockStore.EXPECT().UpdateFlow(testFlowIDService, flowDef).Return(nil, errors.New("db error")) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(existingFlow, nil) + s.mockStore.EXPECT().UpdateFlow(mock.Anything, testFlowIDService, flowDef).Return(nil, errors.New("db error")) - result, err := s.service.UpdateFlow(testFlowIDService, flowDef) + result, err := s.service.UpdateFlow(context.Background(), testFlowIDService, flowDef) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -744,43 +759,43 @@ func (s *FlowMgtServiceTestSuite) TestUpdateFlow_StoreError() { func (s *FlowMgtServiceTestSuite) TestDeleteFlow_Success() { existingFlow := &CompleteFlowDefinition{ID: testFlowIDService, Handle: "test-handle"} - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(existingFlow, nil) - s.mockStore.EXPECT().DeleteFlow(testFlowIDService).Return(nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(existingFlow, nil) + s.mockStore.EXPECT().DeleteFlow(mock.Anything, testFlowIDService).Return(nil) s.mockGraphBuilder.EXPECT().InvalidateCache(testFlowIDService) - err := s.service.DeleteFlow(testFlowIDService) + err := s.service.DeleteFlow(context.Background(), testFlowIDService) s.Nil(err) } func (s *FlowMgtServiceTestSuite) TestDeleteFlow_EmptyID() { - err := s.service.DeleteFlow("") + err := s.service.DeleteFlow(context.Background(), "") s.Equal(&ErrorMissingFlowID, err) } func (s *FlowMgtServiceTestSuite) TestDeleteFlow_NotFound() { - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(nil, errFlowNotFound) - err := s.service.DeleteFlow(testFlowIDService) + err := s.service.DeleteFlow(context.Background(), testFlowIDService) s.Nil(err) } func (s *FlowMgtServiceTestSuite) TestDeleteFlow_GetError() { - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(nil, errors.New("db error")) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(nil, errors.New("db error")) - err := s.service.DeleteFlow(testFlowIDService) + err := s.service.DeleteFlow(context.Background(), testFlowIDService) s.Equal(&serviceerror.InternalServerError, err) } func (s *FlowMgtServiceTestSuite) TestDeleteFlow_StoreError() { existingFlow := &CompleteFlowDefinition{ID: testFlowIDService, Handle: "test-handle"} - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(existingFlow, nil) - s.mockStore.EXPECT().DeleteFlow(testFlowIDService).Return(errors.New("db error")) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(existingFlow, nil) + s.mockStore.EXPECT().DeleteFlow(mock.Anything, testFlowIDService).Return(errors.New("db error")) - err := s.service.DeleteFlow(testFlowIDService) + err := s.service.DeleteFlow(context.Background(), testFlowIDService) s.Equal(&serviceerror.InternalServerError, err) } @@ -790,10 +805,10 @@ func (s *FlowMgtServiceTestSuite) TestDeleteFlow_StoreError() { func (s *FlowMgtServiceTestSuite) TestListFlowVersions_Success() { existingFlow := &CompleteFlowDefinition{ID: testFlowIDService, Handle: "test-handle"} versions := []BasicFlowVersion{{Version: 1}, {Version: 2}} - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(existingFlow, nil) - s.mockStore.EXPECT().ListFlowVersions(testFlowIDService).Return(versions, nil) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(existingFlow, nil) + s.mockStore.EXPECT().ListFlowVersions(mock.Anything, testFlowIDService).Return(versions, nil) - result, err := s.service.ListFlowVersions(testFlowIDService) + result, err := s.service.ListFlowVersions(context.Background(), testFlowIDService) s.Nil(err) s.NotNil(result) @@ -802,16 +817,16 @@ func (s *FlowMgtServiceTestSuite) TestListFlowVersions_Success() { } func (s *FlowMgtServiceTestSuite) TestListFlowVersions_EmptyID() { - result, err := s.service.ListFlowVersions("") + result, err := s.service.ListFlowVersions(context.Background(), "") s.Nil(result) s.Equal(&ErrorMissingFlowID, err) } func (s *FlowMgtServiceTestSuite) TestListFlowVersions_FlowNotFound() { - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(nil, errFlowNotFound) - result, err := s.service.ListFlowVersions(testFlowIDService) + result, err := s.service.ListFlowVersions(context.Background(), testFlowIDService) s.Nil(result) s.Equal(&ErrorFlowNotFound, err) @@ -819,10 +834,10 @@ func (s *FlowMgtServiceTestSuite) TestListFlowVersions_FlowNotFound() { func (s *FlowMgtServiceTestSuite) TestListFlowVersions_StoreError() { existingFlow := &CompleteFlowDefinition{ID: testFlowIDService, Handle: "test-handle"} - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(existingFlow, nil) - s.mockStore.EXPECT().ListFlowVersions(testFlowIDService).Return(nil, errors.New("db error")) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(existingFlow, nil) + s.mockStore.EXPECT().ListFlowVersions(mock.Anything, testFlowIDService).Return(nil, errors.New("db error")) - result, err := s.service.ListFlowVersions(testFlowIDService) + result, err := s.service.ListFlowVersions(context.Background(), testFlowIDService) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -832,50 +847,50 @@ func (s *FlowMgtServiceTestSuite) TestListFlowVersions_StoreError() { func (s *FlowMgtServiceTestSuite) TestGetFlowVersion_Success() { expectedVersion := &FlowVersion{Version: 1} - s.mockStore.EXPECT().GetFlowVersion(testFlowIDService, 1).Return(expectedVersion, nil) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, testFlowIDService, 1).Return(expectedVersion, nil) - result, err := s.service.GetFlowVersion(testFlowIDService, 1) + result, err := s.service.GetFlowVersion(context.Background(), testFlowIDService, 1) s.Nil(err) s.Equal(expectedVersion, result) } func (s *FlowMgtServiceTestSuite) TestGetFlowVersion_EmptyID() { - result, err := s.service.GetFlowVersion("", 1) + result, err := s.service.GetFlowVersion(context.Background(), "", 1) s.Nil(result) s.Equal(&ErrorMissingFlowID, err) } func (s *FlowMgtServiceTestSuite) TestGetFlowVersion_InvalidVersion() { - result, err := s.service.GetFlowVersion(testFlowIDService, 0) + result, err := s.service.GetFlowVersion(context.Background(), testFlowIDService, 0) s.Nil(result) s.Equal(&ErrorInvalidVersion, err) } func (s *FlowMgtServiceTestSuite) TestGetFlowVersion_FlowNotFound() { - s.mockStore.EXPECT().GetFlowVersion(testFlowIDService, 1).Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, testFlowIDService, 1).Return(nil, errFlowNotFound) - result, err := s.service.GetFlowVersion(testFlowIDService, 1) + result, err := s.service.GetFlowVersion(context.Background(), testFlowIDService, 1) s.Nil(result) s.Equal(&ErrorFlowNotFound, err) } func (s *FlowMgtServiceTestSuite) TestGetFlowVersion_VersionNotFound() { - s.mockStore.EXPECT().GetFlowVersion(testFlowIDService, 1).Return(nil, errVersionNotFound) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, testFlowIDService, 1).Return(nil, errVersionNotFound) - result, err := s.service.GetFlowVersion(testFlowIDService, 1) + result, err := s.service.GetFlowVersion(context.Background(), testFlowIDService, 1) s.Nil(result) s.Equal(&ErrorVersionNotFound, err) } func (s *FlowMgtServiceTestSuite) TestGetFlowVersion_StoreError() { - s.mockStore.EXPECT().GetFlowVersion(testFlowIDService, 1).Return(nil, errors.New("db error")) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, testFlowIDService, 1).Return(nil, errors.New("db error")) - result, err := s.service.GetFlowVersion(testFlowIDService, 1) + result, err := s.service.GetFlowVersion(context.Background(), testFlowIDService, 1) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -886,43 +901,43 @@ func (s *FlowMgtServiceTestSuite) TestGetFlowVersion_StoreError() { func (s *FlowMgtServiceTestSuite) TestRestoreFlowVersion_Success() { version := &FlowVersion{Version: 1} restoredFlow := &CompleteFlowDefinition{ActiveVersion: 2} - s.mockStore.EXPECT().GetFlowVersion(testFlowIDService, 1).Return(version, nil) - s.mockStore.EXPECT().RestoreFlowVersion(testFlowIDService, 1).Return(restoredFlow, nil) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, testFlowIDService, 1).Return(version, nil) + s.mockStore.EXPECT().RestoreFlowVersion(mock.Anything, testFlowIDService, 1).Return(restoredFlow, nil) s.mockGraphBuilder.EXPECT().InvalidateCache(testFlowIDService) - result, err := s.service.RestoreFlowVersion(testFlowIDService, 1) + result, err := s.service.RestoreFlowVersion(context.Background(), testFlowIDService, 1) s.Nil(err) s.Equal(restoredFlow, result) } func (s *FlowMgtServiceTestSuite) TestRestoreFlowVersion_EmptyID() { - result, err := s.service.RestoreFlowVersion("", 1) + result, err := s.service.RestoreFlowVersion(context.Background(), "", 1) s.Nil(result) s.Equal(&ErrorMissingFlowID, err) } func (s *FlowMgtServiceTestSuite) TestRestoreFlowVersion_InvalidVersion() { - result, err := s.service.RestoreFlowVersion(testFlowIDService, 0) + result, err := s.service.RestoreFlowVersion(context.Background(), testFlowIDService, 0) s.Nil(result) s.Equal(&ErrorInvalidVersion, err) } func (s *FlowMgtServiceTestSuite) TestRestoreFlowVersion_FlowNotFound() { - s.mockStore.EXPECT().GetFlowVersion(testFlowIDService, 1).Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, testFlowIDService, 1).Return(nil, errFlowNotFound) - result, err := s.service.RestoreFlowVersion(testFlowIDService, 1) + result, err := s.service.RestoreFlowVersion(context.Background(), testFlowIDService, 1) s.Nil(result) s.Equal(&ErrorFlowNotFound, err) } func (s *FlowMgtServiceTestSuite) TestRestoreFlowVersion_VersionNotFound() { - s.mockStore.EXPECT().GetFlowVersion(testFlowIDService, 1).Return(nil, errVersionNotFound) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, testFlowIDService, 1).Return(nil, errVersionNotFound) - result, err := s.service.RestoreFlowVersion(testFlowIDService, 1) + result, err := s.service.RestoreFlowVersion(context.Background(), testFlowIDService, 1) s.Nil(result) s.Equal(&ErrorVersionNotFound, err) @@ -930,10 +945,10 @@ func (s *FlowMgtServiceTestSuite) TestRestoreFlowVersion_VersionNotFound() { func (s *FlowMgtServiceTestSuite) TestRestoreFlowVersion_StoreError() { version := &FlowVersion{Version: 1} - s.mockStore.EXPECT().GetFlowVersion(testFlowIDService, 1).Return(version, nil) - s.mockStore.EXPECT().RestoreFlowVersion(testFlowIDService, 1).Return(nil, errors.New("db error")) + s.mockStore.EXPECT().GetFlowVersion(mock.Anything, testFlowIDService, 1).Return(version, nil) + s.mockStore.EXPECT().RestoreFlowVersion(mock.Anything, testFlowIDService, 1).Return(nil, errors.New("db error")) - result, err := s.service.RestoreFlowVersion(testFlowIDService, 1) + result, err := s.service.RestoreFlowVersion(context.Background(), testFlowIDService, 1) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -942,36 +957,37 @@ func (s *FlowMgtServiceTestSuite) TestRestoreFlowVersion_StoreError() { // GetGraph tests func (s *FlowMgtServiceTestSuite) TestGetGraph_Success() { - flow := &CompleteFlowDefinition{ID: testFlowIDService} - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(flow, nil) - s.mockGraphBuilder.EXPECT().GetGraph(flow).Return(nil, nil) + expectedFlow := &CompleteFlowDefinition{ID: testFlowIDService} + mockGraph := coremock.NewGraphInterfaceMock(s.T()) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(expectedFlow, nil) + s.mockGraphBuilder.EXPECT().GetGraph(expectedFlow).Return(mockGraph, nil) - result, err := s.service.GetGraph(testFlowIDService) + result, err := s.service.GetGraph(context.Background(), testFlowIDService) s.Nil(err) - s.Nil(result) + s.Equal(mockGraph, result) } func (s *FlowMgtServiceTestSuite) TestGetGraph_EmptyID() { - result, err := s.service.GetGraph("") + result, err := s.service.GetGraph(context.Background(), "") s.Nil(result) s.Equal(&ErrorMissingFlowID, err) } func (s *FlowMgtServiceTestSuite) TestGetGraph_FlowNotFound() { - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(nil, errFlowNotFound) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(nil, errFlowNotFound) - result, err := s.service.GetGraph(testFlowIDService) + result, err := s.service.GetGraph(context.Background(), testFlowIDService) s.Nil(result) s.Equal(&ErrorFlowNotFound, err) } func (s *FlowMgtServiceTestSuite) TestGetGraph_StoreError() { - s.mockStore.EXPECT().GetFlowByID(testFlowIDService).Return(nil, errors.New("db error")) + s.mockStore.EXPECT().GetFlowByID(mock.Anything, testFlowIDService).Return(nil, errors.New("db error")) - result, err := s.service.GetGraph(testFlowIDService) + result, err := s.service.GetGraph(context.Background(), testFlowIDService) s.Nil(result) s.Equal(&serviceerror.InternalServerError, err) @@ -980,31 +996,31 @@ func (s *FlowMgtServiceTestSuite) TestGetGraph_StoreError() { // IsValidFlow tests func (s *FlowMgtServiceTestSuite) TestIsValidFlow_Success() { - s.mockStore.EXPECT().IsFlowExists(testFlowIDService).Return(true, nil) + s.mockStore.EXPECT().IsFlowExists(mock.Anything, testFlowIDService).Return(true, nil) - result := s.service.IsValidFlow(testFlowIDService) + result := s.service.IsValidFlow(context.Background(), testFlowIDService) s.True(result) } func (s *FlowMgtServiceTestSuite) TestIsValidFlow_NotFound() { - s.mockStore.EXPECT().IsFlowExists(testFlowIDService).Return(false, nil) + s.mockStore.EXPECT().IsFlowExists(mock.Anything, testFlowIDService).Return(false, nil) - result := s.service.IsValidFlow(testFlowIDService) + result := s.service.IsValidFlow(context.Background(), testFlowIDService) s.False(result) } func (s *FlowMgtServiceTestSuite) TestIsValidFlow_EmptyID() { - result := s.service.IsValidFlow("") + result := s.service.IsValidFlow(context.Background(), "") s.False(result) } func (s *FlowMgtServiceTestSuite) TestIsValidFlow_StoreError() { - s.mockStore.EXPECT().IsFlowExists(testFlowIDService).Return(false, errors.New("db error")) + s.mockStore.EXPECT().IsFlowExists(mock.Anything, testFlowIDService).Return(false, errors.New("db error")) - result := s.service.IsValidFlow(testFlowIDService) + result := s.service.IsValidFlow(context.Background(), testFlowIDService) s.False(result) } @@ -1016,7 +1032,14 @@ func (s *FlowMgtServiceTestSuite) TestApplyExecutorDefaultMeta_Success() { mockExecutor := coremock.NewExecutorInterfaceMock(s.T()) // Create service with executor registry - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + // Create service with executor registry + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) defaultMeta := map[string]interface{}{ "components": []interface{}{ @@ -1066,7 +1089,14 @@ func (s *FlowMgtServiceTestSuite) TestApplyExecutorDefaultMeta_PreservesExisting mockExecutorRegistry := executormock.NewExecutorRegistryInterfaceMock(s.T()) // Create service with executor registry - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + // Create service with executor registry + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) customMeta := map[string]interface{}{ "components": []interface{}{ @@ -1110,7 +1140,14 @@ func (s *FlowMgtServiceTestSuite) TestApplyExecutorDefaultMeta_SkipsNonTaskExecu mockExecutorRegistry := executormock.NewExecutorRegistryInterfaceMock(s.T()) // Create service with executor registry - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + // Create service with executor registry + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) flowDef := &FlowDefinition{ Handle: "test-flow", @@ -1134,7 +1171,14 @@ func (s *FlowMgtServiceTestSuite) TestApplyExecutorDefaultMeta_HandlesGetExecuto mockExecutorRegistry := executormock.NewExecutorRegistryInterfaceMock(s.T()) // Create service with executor registry - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + // Create service with executor registry + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) mockExecutorRegistry.On("GetExecutor", "UnknownExecutor"). Return(nil, errors.New("executor not found")) @@ -1166,7 +1210,14 @@ func (s *FlowMgtServiceTestSuite) TestApplyExecutorDefaultMeta_HandlesGetExecuto func (s *FlowMgtServiceTestSuite) TestApplyExecutorDefaultMeta_NilExecutorRegistry() { // Create service WITHOUT executor registry - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, nil) + // Create service WITHOUT executor registry + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + nil, + &stubTransactioner{}, + ) flowDef := &FlowDefinition{ Handle: "test-flow", @@ -1197,7 +1248,14 @@ func (s *FlowMgtServiceTestSuite) TestApplyExecutorDefaultMeta_ExecutorReturnsNi mockExecutor := coremock.NewExecutorInterfaceMock(s.T()) // Create service with executor registry - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + // Create service with executor registry + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) // Executor returns nil meta mockExecutor.On("GetDefaultMeta").Return(nil) @@ -1240,7 +1298,14 @@ func (s *FlowMgtServiceTestSuite) TestApplyExecutorDefaultMeta_MultipleTaskExecu mockExecutor2 := coremock.NewExecutorInterfaceMock(s.T()) // Create service with executor registry - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + // Create service with executor registry + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) meta1 := map[string]interface{}{"executor": "Executor1"} meta2 := map[string]interface{}{"executor": "Executor2"} @@ -1302,7 +1367,13 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_Success_WithMetaG mockExecutorRegistry := executormock.NewExecutorRegistryInterfaceMock(s.T()) mockExecutor := coremock.NewExecutorInterfaceMock(s.T()) - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) authFlowDef := &FlowDefinition{ Handle: "auth-flow", @@ -1348,9 +1419,10 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_Success_WithMetaG FlowType: inferredRegFlow.FlowType, Nodes: inferredRegFlow.Nodes, } - s.mockStore.On("CreateFlow", mock.AnythingOfType("string"), inferredRegFlow).Return(completeFlow, nil) + s.mockStore.On("CreateFlow", mock.Anything, mock.AnythingOfType("string"), inferredRegFlow). + Return(completeFlow, nil) - service.(*flowMgtService).tryInferRegistrationFlow("auth-flow-id", authFlowDef) + service.(*flowMgtService).tryInferRegistrationFlow(context.Background(), "auth-flow-id", authFlowDef) s.Equal(defaultMeta, inferredRegFlow.Nodes[1].Meta) s.mockInference.AssertExpectations(s.T()) @@ -1370,7 +1442,13 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_SkipsNonAuthFlow( _ = config.InitializeThunderRuntime("test", testConfig) mockExecutorRegistry := executormock.NewExecutorRegistryInterfaceMock(s.T()) - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) regFlowDef := &FlowDefinition{ Handle: "reg-flow", @@ -1379,7 +1457,7 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_SkipsNonAuthFlow( Nodes: []NodeDefinition{}, } - service.(*flowMgtService).tryInferRegistrationFlow("reg-flow-id", regFlowDef) + service.(*flowMgtService).tryInferRegistrationFlow(context.Background(), "reg-flow-id", regFlowDef) s.mockInference.AssertNotCalled(s.T(), "InferRegistrationFlow") s.mockStore.AssertNotCalled(s.T(), "CreateFlow") @@ -1397,7 +1475,13 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_HandlesInferenceE _ = config.InitializeThunderRuntime("test", testConfig) mockExecutorRegistry := executormock.NewExecutorRegistryInterfaceMock(s.T()) - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) authFlowDef := &FlowDefinition{ Handle: "auth-flow", @@ -1408,7 +1492,7 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_HandlesInferenceE s.mockInference.On("InferRegistrationFlow", authFlowDef).Return(nil, errors.New("inference failed")) - service.(*flowMgtService).tryInferRegistrationFlow("auth-flow-id", authFlowDef) + service.(*flowMgtService).tryInferRegistrationFlow(context.Background(), "auth-flow-id", authFlowDef) s.mockInference.AssertExpectations(s.T()) s.mockStore.AssertNotCalled(s.T(), "CreateFlow") @@ -1426,7 +1510,13 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_HandlesMetaApplic _ = config.InitializeThunderRuntime("test", testConfig) mockExecutorRegistry := executormock.NewExecutorRegistryInterfaceMock(s.T()) - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) authFlowDef := &FlowDefinition{ Handle: "auth-flow", @@ -1456,7 +1546,7 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_HandlesMetaApplic mockExecutorRegistry.On("GetExecutor", "UnknownExecutor").Return(nil, errors.New("executor not found")) s.mockInference.On("InferRegistrationFlow", authFlowDef).Return(inferredRegFlow, nil) - service.(*flowMgtService).tryInferRegistrationFlow("auth-flow-id", authFlowDef) + service.(*flowMgtService).tryInferRegistrationFlow(context.Background(), "auth-flow-id", authFlowDef) s.mockInference.AssertExpectations(s.T()) mockExecutorRegistry.AssertExpectations(s.T()) @@ -1466,7 +1556,13 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_HandlesMetaApplic func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_DisabledAutoInference() { // Auto-inference is disabled in SetupTest, so just verify early return mockExecutorRegistry := executormock.NewExecutorRegistryInterfaceMock(s.T()) - service := newFlowMgtService(s.mockStore, s.mockInference, s.mockGraphBuilder, mockExecutorRegistry) + service := newFlowMgtService( + s.mockStore, + s.mockInference, + s.mockGraphBuilder, + mockExecutorRegistry, + &stubTransactioner{}, + ) authFlowDef := &FlowDefinition{ Handle: "auth-flow", @@ -1475,7 +1571,7 @@ func (s *FlowMgtServiceTestSuite) TestTryInferRegistrationFlow_DisabledAutoInfer Nodes: []NodeDefinition{}, } - service.(*flowMgtService).tryInferRegistrationFlow("auth-flow-id", authFlowDef) + service.(*flowMgtService).tryInferRegistrationFlow(context.Background(), "auth-flow-id", authFlowDef) s.mockInference.AssertNotCalled(s.T(), "InferRegistrationFlow") s.mockStore.AssertNotCalled(s.T(), "CreateFlow") diff --git a/backend/internal/flow/mgt/store.go b/backend/internal/flow/mgt/store.go index b72cd88f0..be3a132dd 100644 --- a/backend/internal/flow/mgt/store.go +++ b/backend/internal/flow/mgt/store.go @@ -19,16 +19,15 @@ package flowmgt import ( - "database/sql" + "context" "encoding/json" - "errors" "fmt" "time" "github.com/asgardeo/thunder/internal/flow/common" "github.com/asgardeo/thunder/internal/system/config" - "github.com/asgardeo/thunder/internal/system/database/model" "github.com/asgardeo/thunder/internal/system/database/provider" + "github.com/asgardeo/thunder/internal/system/database/transaction" "github.com/asgardeo/thunder/internal/system/log" ) @@ -48,22 +47,23 @@ const ( // flowStoreInterface defines the interface for flow store operations. type flowStoreInterface interface { - ListFlows(limit, offset int, flowType string) ([]BasicFlowDefinition, int, error) - CreateFlow(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) - GetFlowByID(flowID string) (*CompleteFlowDefinition, error) - GetFlowByHandle(handle string, flowType common.FlowType) (*CompleteFlowDefinition, error) - UpdateFlow(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) - DeleteFlow(flowID string) error - ListFlowVersions(flowID string) ([]BasicFlowVersion, error) - GetFlowVersion(flowID string, version int) (*FlowVersion, error) - RestoreFlowVersion(flowID string, version int) (*CompleteFlowDefinition, error) - IsFlowExists(flowID string) (bool, error) - IsFlowExistsByHandle(handle string, flowType common.FlowType) (bool, error) + ListFlows(ctx context.Context, limit, offset int, flowType string) ([]BasicFlowDefinition, int, error) + CreateFlow(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) + GetFlowByID(ctx context.Context, flowID string) (*CompleteFlowDefinition, error) + GetFlowByHandle(ctx context.Context, handle string, flowType common.FlowType) (*CompleteFlowDefinition, error) + UpdateFlow(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) + DeleteFlow(ctx context.Context, flowID string) error + ListFlowVersions(ctx context.Context, flowID string) ([]BasicFlowVersion, error) + GetFlowVersion(ctx context.Context, flowID string, version int) (*FlowVersion, error) + RestoreFlowVersion(ctx context.Context, flowID string, version int) (*CompleteFlowDefinition, error) + IsFlowExists(ctx context.Context, flowID string) (bool, error) + IsFlowExistsByHandle(ctx context.Context, handle string, flowType common.FlowType) (bool, error) } // flowStore is the default implementation of flowStoreInterface. type flowStore struct { dbProvider provider.DBProviderInterface + transactioner transaction.Transactioner deploymentID string maxVersionHistory int logger *log.Logger @@ -71,8 +71,14 @@ type flowStore struct { // newFlowStore creates a new instance of flowStore. func newFlowStore() flowStoreInterface { + txer, err := provider.GetDBProvider().GetConfigDBTransactioner() + if err != nil { + log.GetLogger().Fatal("Failed to get config DB transactioner", log.Error(err)) + } + return &flowStore{ dbProvider: provider.GetDBProvider(), + transactioner: txer, deploymentID: config.GetThunderRuntime().Config.Server.Identifier, maxVersionHistory: getMaxVersionHistory(), logger: log.GetLogger().With(log.String(log.LoggerKeyComponentName, "FlowStore")), @@ -80,31 +86,32 @@ func newFlowStore() flowStoreInterface { } // ListFlows retrieves a paginated list of flow definitions with optional filtering by flow type. -func (s *flowStore) ListFlows(limit, offset int, flowType string) ([]BasicFlowDefinition, int, error) { +func (s *flowStore) ListFlows(ctx context.Context, limit, offset int, flowType string) ([]BasicFlowDefinition, int, + error) { var flows []BasicFlowDefinition var totalCount int - err := s.withDBClient(func(dbClient provider.DBClientInterface) error { + err := s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { var countResults, results []map[string]interface{} var err error if flowType != "" { - countResults, err = dbClient.Query(queryCountFlowsWithType, flowType, s.deploymentID) + countResults, err = dbClient.QueryContext(ctx, queryCountFlowsWithType, flowType, s.deploymentID) if err != nil { return fmt.Errorf("failed to count flows: %w", err) } - results, err = dbClient.Query(queryListFlowsWithType, flowType, s.deploymentID, limit, offset) + results, err = dbClient.QueryContext(ctx, queryListFlowsWithType, flowType, s.deploymentID, limit, offset) if err != nil { return fmt.Errorf("failed to list flows: %w", err) } } else { - countResults, err = dbClient.Query(queryCountFlows, s.deploymentID) + countResults, err = dbClient.QueryContext(ctx, queryCountFlows, s.deploymentID) if err != nil { return fmt.Errorf("failed to count flows: %w", err) } - results, err = dbClient.Query(queryListFlows, s.deploymentID, limit, offset) + results, err = dbClient.QueryContext(ctx, queryListFlows, s.deploymentID, limit, offset) if err != nil { return fmt.Errorf("failed to list flows: %w", err) } @@ -135,42 +142,47 @@ func (s *flowStore) ListFlows(limit, offset int, flowType string) ([]BasicFlowDe } // CreateFlow creates a new flow definition with version 1. -func (s *flowStore) CreateFlow(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) { +func (s *flowStore) CreateFlow(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, + error) { nodesJSON, err := json.Marshal(flow.Nodes) if err != nil { return nil, fmt.Errorf("failed to marshal nodes: %w", err) } - err = s.withTransaction(func(tx model.TxInterface) error { - _, err := tx.Exec(queryCreateFlow, flowID, flow.Handle, flow.Name, flow.FlowType, int64(1), s.deploymentID) - if err != nil { - return fmt.Errorf("failed to create flow: %w", err) - } + err = s.transactioner.Transact(ctx, func(ctx context.Context) error { + return s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + _, err := dbClient.ExecuteContext(ctx, queryCreateFlow, flowID, flow.Handle, flow.Name, flow.FlowType, + int64(1), s.deploymentID) + if err != nil { + return fmt.Errorf("failed to create flow: %w", err) + } - internalID, err := s.getFlowInternalIDWithTx(tx, flowID) - if err != nil { - return err - } + internalID, err := s.getFlowInternalID(ctx, dbClient, flowID) + if err != nil { + return err + } - _, err = tx.Exec(queryInsertFlowVersion, internalID, 1, string(nodesJSON), s.deploymentID) - if err != nil { - return fmt.Errorf("failed to create flow version: %w", err) - } + _, err = dbClient.ExecuteContext(ctx, queryInsertFlowVersion, internalID, 1, string(nodesJSON), + s.deploymentID) + if err != nil { + return fmt.Errorf("failed to create flow version: %w", err) + } - return nil + return nil + }) }) if err != nil { return nil, err } - return s.GetFlowByID(flowID) + return s.GetFlowByID(ctx, flowID) } // GetFlowByID retrieves the active version of a flow definition by its ID. -func (s *flowStore) GetFlowByID(flowID string) (*CompleteFlowDefinition, error) { +func (s *flowStore) GetFlowByID(ctx context.Context, flowID string) (*CompleteFlowDefinition, error) { var flow *CompleteFlowDefinition - err := s.withDBClient(func(dbClient provider.DBClientInterface) error { - results, err := dbClient.Query(queryGetFlow, flowID, s.deploymentID) + err := s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + results, err := dbClient.QueryContext(ctx, queryGetFlow, flowID, s.deploymentID) if err != nil { return fmt.Errorf("failed to get flow: %w", err) } @@ -187,10 +199,11 @@ func (s *flowStore) GetFlowByID(flowID string) (*CompleteFlowDefinition, error) } // GetFlowByHandle retrieves a flow definition by handle and flow type. -func (s *flowStore) GetFlowByHandle(handle string, flowType common.FlowType) (*CompleteFlowDefinition, error) { +func (s *flowStore) GetFlowByHandle(ctx context.Context, handle string, + flowType common.FlowType) (*CompleteFlowDefinition, error) { var flow *CompleteFlowDefinition - err := s.withDBClient(func(dbClient provider.DBClientInterface) error { - results, err := dbClient.Query(queryGetFlowByHandle, handle, string(flowType), s.deploymentID) + err := s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + results, err := dbClient.QueryContext(ctx, queryGetFlowByHandle, handle, string(flowType), s.deploymentID) if err != nil { return fmt.Errorf("failed to get flow by handle: %w", err) } @@ -208,56 +221,60 @@ func (s *flowStore) GetFlowByHandle(handle string, flowType common.FlowType) (*C // UpdateFlow updates a flow definition by creating a new version. // Automatically deletes oldest versions if the count exceeds max_version_history. -func (s *flowStore) UpdateFlow(flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, error) { +func (s *flowStore) UpdateFlow(ctx context.Context, flowID string, flow *FlowDefinition) (*CompleteFlowDefinition, + error) { nodesJSON, err := json.Marshal(flow.Nodes) if err != nil { return nil, fmt.Errorf("failed to marshal nodes: %w", err) } - err = s.withTransaction(func(tx model.TxInterface) error { - flowResults, err := tx.Query(queryGetFlow, flowID, s.deploymentID) - if err != nil { - return fmt.Errorf("failed to get flow metadata: %w", err) - } + err = s.transactioner.Transact(ctx, func(ctx context.Context) error { + return s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + // Retrieve current flow metadata to determine active version + results, err := dbClient.QueryContext(ctx, queryGetFlow, flowID, s.deploymentID) + if err != nil { + return fmt.Errorf("failed to get flow metadata: %w", err) + } + if len(results) == 0 { + return errFlowNotFound + } - _, currentVersion, err := s.scanFlowMetadata(flowResults) - if closeErr := flowResults.Close(); closeErr != nil { - s.logger.Error("Failed to close flow results", log.Error(closeErr)) - } - if err != nil { - return errFlowNotFound - } + currentFlow, err := s.buildBasicFlowDefinitionFromRow(results[0]) + if err != nil { + return fmt.Errorf("failed to parse flow metadata: %w", err) + } - newVersion := int(currentVersion) + 1 + newVersion := currentFlow.ActiveVersion + 1 - internalID, err := s.getFlowInternalIDWithTx(tx, flowID) - if err != nil { - return err - } + internalID, err := s.getFlowInternalID(ctx, dbClient, flowID) + if err != nil { + return err + } - // Insert the new version first to ensure it succeeds before updating the flow - if err := s.pushToVersionStack(tx, internalID, newVersion, string(nodesJSON)); err != nil { - return err - } + // Insert the new version first to ensure it succeeds before updating the flow + if err := s.pushToVersionStack(ctx, dbClient, internalID, newVersion, string(nodesJSON)); err != nil { + return err + } - _, err = tx.Exec(queryUpdateFlow, flowID, flow.Name, newVersion, s.deploymentID) - if err != nil { - return fmt.Errorf("failed to update flow: %w", err) - } + _, err = dbClient.ExecuteContext(ctx, queryUpdateFlow, flowID, flow.Name, newVersion, s.deploymentID) + if err != nil { + return fmt.Errorf("failed to update flow: %w", err) + } - return nil + return nil + }) }) if err != nil { return nil, err } - return s.GetFlowByID(flowID) + return s.GetFlowByID(ctx, flowID) } // DeleteFlow deletes a flow definition and all its version history. -func (s *flowStore) DeleteFlow(flowID string) error { - return s.withDBClient(func(dbClient provider.DBClientInterface) error { - _, err := dbClient.Execute(queryDeleteFlow, flowID, s.deploymentID) +func (s *flowStore) DeleteFlow(ctx context.Context, flowID string) error { + return s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + _, err := dbClient.ExecuteContext(ctx, queryDeleteFlow, flowID, s.deploymentID) if err != nil { return fmt.Errorf("failed to delete flow: %w", err) } @@ -266,10 +283,10 @@ func (s *flowStore) DeleteFlow(flowID string) error { } // IsFlowExists checks if a flow exists with a given flow ID. -func (s *flowStore) IsFlowExists(flowID string) (bool, error) { +func (s *flowStore) IsFlowExists(ctx context.Context, flowID string) (bool, error) { var exists bool - err := s.withDBClient(func(dbClient provider.DBClientInterface) error { - results, err := dbClient.Query(queryCheckFlowExistsByID, flowID, s.deploymentID) + err := s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + results, err := dbClient.QueryContext(ctx, queryCheckFlowExistsByID, flowID, s.deploymentID) if err != nil { return fmt.Errorf("failed to check flow existence: %w", err) } @@ -282,10 +299,11 @@ func (s *flowStore) IsFlowExists(flowID string) (bool, error) { } // IsFlowExistsByHandle checks if a flow exists with the given handle and flow type. -func (s *flowStore) IsFlowExistsByHandle(handle string, flowType common.FlowType) (bool, error) { +func (s *flowStore) IsFlowExistsByHandle(ctx context.Context, handle string, flowType common.FlowType) (bool, error) { var exists bool - err := s.withDBClient(func(dbClient provider.DBClientInterface) error { - results, err := dbClient.Query(queryCheckFlowExistsByHandle, handle, string(flowType), s.deploymentID) + err := s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + results, err := dbClient.QueryContext(ctx, queryCheckFlowExistsByHandle, handle, string(flowType), + s.deploymentID) if err != nil { return fmt.Errorf("failed to check flow existence by handle: %w", err) } @@ -298,16 +316,16 @@ func (s *flowStore) IsFlowExistsByHandle(handle string, flowType common.FlowType } // ListFlowVersions retrieves all versions of a flow definition. -func (s *flowStore) ListFlowVersions(flowID string) ([]BasicFlowVersion, error) { +func (s *flowStore) ListFlowVersions(ctx context.Context, flowID string) ([]BasicFlowVersion, error) { var versions []BasicFlowVersion - err := s.withDBClient(func(dbClient provider.DBClientInterface) error { - internalID, err := s.getFlowInternalID(dbClient, flowID) + err := s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + internalID, err := s.getFlowInternalID(ctx, dbClient, flowID) if err != nil { return err } - results, err := dbClient.Query(queryListFlowVersions, internalID, s.deploymentID) + results, err := dbClient.QueryContext(ctx, queryListFlowVersions, internalID, s.deploymentID) if err != nil { return fmt.Errorf("failed to list flow versions: %w", err) } @@ -328,11 +346,11 @@ func (s *flowStore) ListFlowVersions(flowID string) ([]BasicFlowVersion, error) } // GetFlowVersion retrieves a specific version of a flow definition. -func (s *flowStore) GetFlowVersion(flowID string, version int) (*FlowVersion, error) { +func (s *flowStore) GetFlowVersion(ctx context.Context, flowID string, version int) (*FlowVersion, error) { var flowVersion *FlowVersion - err := s.withDBClient(func(dbClient provider.DBClientInterface) error { - results, err := dbClient.Query(queryGetFlowVersionWithMetadata, flowID, version, s.deploymentID) + err := s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + results, err := dbClient.QueryContext(ctx, queryGetFlowVersionWithMetadata, flowID, version, s.deploymentID) if err != nil { return fmt.Errorf("failed to get flow version: %w", err) } @@ -350,84 +368,87 @@ func (s *flowStore) GetFlowVersion(flowID string, version int) (*FlowVersion, er // RestoreFlowVersion restores a specified version as the active version. // This creates a new version by copying the configuration from the specified version. // Automatically deletes oldest versions if the count exceeds max_version_history. -func (s *flowStore) RestoreFlowVersion(flowID string, version int) (*CompleteFlowDefinition, error) { - err := s.withTransaction(func(tx model.TxInterface) error { - flowResults, err := tx.Query(queryGetFlow, flowID, s.deploymentID) - if err != nil { - return fmt.Errorf("failed to get flow metadata: %w", err) - } +func (s *flowStore) RestoreFlowVersion(ctx context.Context, flowID string, version int) (*CompleteFlowDefinition, + error) { + err := s.transactioner.Transact(ctx, func(ctx context.Context) error { + return s.withDBClient(ctx, func(ctx context.Context, dbClient provider.DBClientInterface) error { + // Get current flow metadata + flowResults, err := dbClient.QueryContext(ctx, queryGetFlow, flowID, s.deploymentID) + if err != nil { + return fmt.Errorf("failed to get flow metadata: %w", err) + } + if len(flowResults) == 0 { + return errFlowNotFound + } - flowName, currentVersion, err := s.scanFlowMetadata(flowResults) - if closeErr := flowResults.Close(); closeErr != nil { - s.logger.Error("Failed to close flow results", log.Error(closeErr)) - } - if err != nil { - return errFlowNotFound - } + currentFlow, err := s.buildBasicFlowDefinitionFromRow(flowResults[0]) + if err != nil { + return fmt.Errorf("failed to parse flow metadata: %w", err) + } - internalID, err := s.getFlowInternalIDWithTx(tx, flowID) - if err != nil { - return err - } + internalID, err := s.getFlowInternalID(ctx, dbClient, flowID) + if err != nil { + return err + } - versionResults, err := tx.Query(queryGetFlowVersion, internalID, version, s.deploymentID) - if err != nil { - return fmt.Errorf("failed to get version to restore: %w", err) - } + // Get version to restore + versionResults, err := dbClient.QueryContext(ctx, queryGetFlowVersion, internalID, version, s.deploymentID) + if err != nil { + return fmt.Errorf("failed to get version to restore: %w", err) + } + if len(versionResults) == 0 { + return errVersionNotFound + } - nodesJSON, err := s.scanFlowVersion(versionResults) - if closeErr := versionResults.Close(); closeErr != nil { - s.logger.Error("Failed to close version results", log.Error(closeErr)) - } - if err != nil { - return errVersionNotFound - } + nodesJSON, err := s.getString(versionResults[0], colNodes) + if err != nil { + return fmt.Errorf("failed to parse nodes from version: %w", err) + } - newVersion := int(currentVersion) + 1 + newVersion := currentFlow.ActiveVersion + 1 - // Insert the new version first to ensure it succeeds before updating the flow - if err := s.pushToVersionStack(tx, internalID, newVersion, nodesJSON); err != nil { - return err - } + // Insert the new version first to ensure it succeeds before updating the flow + if err := s.pushToVersionStack(ctx, dbClient, internalID, newVersion, nodesJSON); err != nil { + return err + } - _, err = tx.Exec(queryUpdateFlow, flowID, flowName, newVersion, s.deploymentID) - if err != nil { - return fmt.Errorf("failed to update flow: %w", err) - } + _, err = dbClient.ExecuteContext(ctx, queryUpdateFlow, flowID, currentFlow.Name, newVersion, s.deploymentID) + if err != nil { + return fmt.Errorf("failed to update flow: %w", err) + } - return nil + return nil + }) }) if err != nil { return nil, err } - return s.GetFlowByID(flowID) + return s.GetFlowByID(ctx, flowID) } // pushToVersionStack adds a new version to the version history and removes the oldest version // if the count exceeds max_version_history. -func (s *flowStore) pushToVersionStack(tx model.TxInterface, +func (s *flowStore) pushToVersionStack(ctx context.Context, dbClient provider.DBClientInterface, flowInternalID int64, version int, nodesJSON string) error { - _, err := tx.Exec(queryInsertFlowVersion, flowInternalID, version, nodesJSON, s.deploymentID) + _, err := dbClient.ExecuteContext(ctx, queryInsertFlowVersion, flowInternalID, version, nodesJSON, s.deploymentID) if err != nil { return fmt.Errorf("failed to insert flow version: %w", err) } - countResults, err := tx.Query(queryCountFlowVersions, flowInternalID, s.deploymentID) + countResults, err := dbClient.QueryContext(ctx, queryCountFlowVersions, flowInternalID, s.deploymentID) if err != nil { return fmt.Errorf("failed to count versions: %w", err) } - versionCount, err := s.parseCountFromRows(countResults) - if closeErr := countResults.Close(); closeErr != nil { - s.logger.Error("Failed to close count results", log.Error(closeErr)) - } + versionCount, err := s.parseCountResult(countResults) if err != nil { return err } if versionCount > s.maxVersionHistory { - if _, err := tx.Exec(queryDeleteOldestVersion, flowInternalID, s.deploymentID); err != nil { + _, err = dbClient.ExecuteContext(ctx, queryDeleteOldestVersion, flowInternalID, s.deploymentID) + if err != nil { return fmt.Errorf("failed to delete oldest version: %w", err) } } @@ -435,33 +456,10 @@ func (s *flowStore) pushToVersionStack(tx model.TxInterface, return nil } -// getFlowInternalIDWithTx retrieves the internal ID of a flow by its flow ID within a transaction. -func (s *flowStore) getFlowInternalIDWithTx(tx model.TxInterface, flowID string) (int64, error) { - results, err := tx.Query(queryGetFlowInternalID, flowID, s.deploymentID) - if err != nil { - return 0, fmt.Errorf("failed to get flow internal ID: %w", err) - } - - if !results.Next() { - _ = results.Close() - return 0, errFlowNotFound - } - - var internalID int64 - if err := results.Scan(&internalID); err != nil { - _ = results.Close() - return 0, fmt.Errorf("failed to scan internal ID: %w", err) - } - if closeErr := results.Close(); closeErr != nil { - s.logger.Error("Failed to close internal ID results", log.Error(closeErr)) - } - - return internalID, nil -} - // getFlowInternalID retrieves the internal ID of a flow by its flow ID. -func (s *flowStore) getFlowInternalID(dbClient provider.DBClientInterface, flowID string) (int64, error) { - results, err := dbClient.Query(queryGetFlowInternalID, flowID, s.deploymentID) +func (s *flowStore) getFlowInternalID(ctx context.Context, dbClient provider.DBClientInterface, flowID string) (int64, + error) { + results, err := dbClient.QueryContext(ctx, queryGetFlowInternalID, flowID, s.deploymentID) if err != nil { return 0, fmt.Errorf("failed to get flow internal ID: %w", err) } @@ -493,35 +491,13 @@ func (s *flowStore) getConfigDBClient() (provider.DBClientInterface, error) { } // withDBClient executes a function with a DB client, handling client retrieval errors. -func (s *flowStore) withDBClient(fn func(provider.DBClientInterface) error) error { +func (s *flowStore) withDBClient(ctx context.Context, + fn func(context.Context, provider.DBClientInterface) error) error { dbClient, err := s.getConfigDBClient() if err != nil { return err } - return fn(dbClient) -} - -// withTransaction executes a function within a database transaction. -func (s *flowStore) withTransaction(fn func(model.TxInterface) error) error { - return s.withDBClient(func(dbClient provider.DBClientInterface) error { - tx, err := dbClient.BeginTx() - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - - if err := fn(tx); err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - err = errors.Join(err, fmt.Errorf("failed to rollback transaction: %w", rollbackErr)) - } - return err - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("failed to commit transaction: %w", err) - } - - return nil - }) + return fn(ctx, dbClient) } // parseCountResult parses a count result from database query. @@ -547,52 +523,6 @@ func (s *flowStore) parseCountResult(results []map[string]interface{}) (int, err } } -// parseCountFromRows parses the count result from *sql.Rows. -func (s *flowStore) parseCountFromRows(rows *sql.Rows) (int, error) { - if !rows.Next() { - return 0, fmt.Errorf("no count result returned") - } - - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, fmt.Errorf("failed to scan count: %w", err) - } - - return int(count), nil -} - -// scanFlowMetadata scans a single row from FLOW table into individual fields. -func (s *flowStore) scanFlowMetadata(rows *sql.Rows) (flowName string, activeVersion int64, err error) { - if !rows.Next() { - return "", 0, fmt.Errorf("no flow found") - } - - var flowID, handle, flowType, nodes, createdAt, updatedAt string - err = rows.Scan(&flowID, &handle, &flowName, &flowType, &activeVersion, &nodes, &createdAt, &updatedAt) - if err != nil { - return "", 0, fmt.Errorf("failed to scan flow metadata: %w", err) - } - - return flowName, activeVersion, nil -} - -// scanFlowVersion scans a single row from FLOW_VERSION table into individual fields. -func (s *flowStore) scanFlowVersion(rows *sql.Rows) (nodes string, err error) { - if !rows.Next() { - return "", fmt.Errorf("no version found") - } - - var version int64 - var createdAt string - err = rows.Scan(&version, &nodes, &createdAt) - if err != nil { - return "", fmt.Errorf("failed to scan version data: %w", err) - } - - return nodes, nil -} - -// getString safely extracts a string value from a database row. // Handles both string (SQLite) and []byte (PostgreSQL) types. func (s *flowStore) getString(row map[string]interface{}, key string) (string, error) { val := row[key] diff --git a/backend/internal/flow/mgt/store_test.go b/backend/internal/flow/mgt/store_test.go index 0c8ac826c..ff503e193 100644 --- a/backend/internal/flow/mgt/store_test.go +++ b/backend/internal/flow/mgt/store_test.go @@ -19,24 +19,42 @@ package flowmgt import ( + "context" "errors" "testing" "time" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "github.com/asgardeo/thunder/internal/flow/common" "github.com/asgardeo/thunder/internal/system/config" "github.com/asgardeo/thunder/internal/system/log" - "github.com/asgardeo/thunder/tests/mocks/database/modelmock" "github.com/asgardeo/thunder/tests/mocks/database/providermock" ) +type mockTransactioner struct { + mock.Mock +} + +func (m *mockTransactioner) Transact(ctx context.Context, fn func(context.Context) error) error { + args := m.Called(ctx, fn) + err := args.Error(0) + if err != nil { + return err + } + if fn != nil { + return fn(ctx) + } + return nil +} + type FlowStoreTestSuite struct { suite.Suite - mockDBProvider *providermock.DBProviderInterfaceMock - mockDBClient *providermock.DBClientInterfaceMock - store *flowStore + mockDBProvider *providermock.DBProviderInterfaceMock + mockDBClient *providermock.DBClientInterfaceMock + mockTransactioner *mockTransactioner + store *flowStore } func TestFlowStoreTestSuite(t *testing.T) { @@ -51,8 +69,10 @@ func (s *FlowStoreTestSuite) SetupTest() { s.mockDBProvider = providermock.NewDBProviderInterfaceMock(s.T()) s.mockDBClient = providermock.NewDBClientInterfaceMock(s.T()) + s.mockTransactioner = &mockTransactioner{} s.store = &flowStore{ dbProvider: s.mockDBProvider, + transactioner: s.mockTransactioner, deploymentID: "test-deployment", maxVersionHistory: 5, logger: log.GetLogger().With(log.String(log.LoggerKeyComponentName, "FlowStore")), @@ -64,7 +84,7 @@ func (s *FlowStoreTestSuite) SetupTest() { func (s *FlowStoreTestSuite) TestListFlowsDBClientError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(nil, errors.New("connection error")) - flows, count, err := s.store.ListFlows(10, 0, "") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "") s.Error(err) s.Contains(err.Error(), "failed to get database client") @@ -74,10 +94,10 @@ func (s *FlowStoreTestSuite) TestListFlowsDBClientError() { func (s *FlowStoreTestSuite) TestListFlowsCountQueryError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryCountFlows, "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCountFlows, "test-deployment"). Return(nil, errors.New("query error")).Once() - flows, count, err := s.store.ListFlows(10, 0, "") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "") s.Error(err) s.Contains(err.Error(), "failed to count flows") @@ -87,12 +107,12 @@ func (s *FlowStoreTestSuite) TestListFlowsCountQueryError() { func (s *FlowStoreTestSuite) TestListFlowsQueryError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryCountFlows, "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCountFlows, "test-deployment"). Return([]map[string]interface{}{{colCount: int64(1)}}, nil).Once() - s.mockDBClient.EXPECT().Query(queryListFlows, "test-deployment", 10, 0). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryListFlows, "test-deployment", 10, 0). Return(nil, errors.New("query error")).Once() - flows, count, err := s.store.ListFlows(10, 0, "") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "") s.Error(err) s.Contains(err.Error(), "failed to list flows") @@ -104,10 +124,10 @@ func (s *FlowStoreTestSuite) TestListFlowsQueryError() { func (s *FlowStoreTestSuite) TestGetFlowByIDNotFound() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryGetFlow, "non-existent", "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlow, "non-existent", "test-deployment"). Return([]map[string]interface{}{}, nil).Once() - flow, err := s.store.GetFlowByID("non-existent") + flow, err := s.store.GetFlowByID(context.Background(), "non-existent") s.Error(err) s.ErrorIs(err, errFlowNotFound) @@ -117,7 +137,7 @@ func (s *FlowStoreTestSuite) TestGetFlowByIDNotFound() { func (s *FlowStoreTestSuite) TestGetFlowByIDDBError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(nil, errors.New("connection error")) - flow, err := s.store.GetFlowByID("flow-1") + flow, err := s.store.GetFlowByID(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "failed to get database client") @@ -126,10 +146,10 @@ func (s *FlowStoreTestSuite) TestGetFlowByIDDBError() { func (s *FlowStoreTestSuite) TestGetFlowByIDQueryError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryGetFlow, "flow-1", "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlow, "flow-1", "test-deployment"). Return(nil, errors.New("query error")).Once() - flow, err := s.store.GetFlowByID("flow-1") + flow, err := s.store.GetFlowByID(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "failed to get flow") @@ -140,10 +160,10 @@ func (s *FlowStoreTestSuite) TestGetFlowByIDQueryError() { func (s *FlowStoreTestSuite) TestDeleteFlowSuccess() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Execute(queryDeleteFlow, "flow-1", "test-deployment"). + s.mockDBClient.EXPECT().ExecuteContext(mock.Anything, queryDeleteFlow, "flow-1", "test-deployment"). Return(int64(1), nil).Once() - err := s.store.DeleteFlow("flow-1") + err := s.store.DeleteFlow(context.Background(), "flow-1") s.NoError(err) } @@ -151,7 +171,7 @@ func (s *FlowStoreTestSuite) TestDeleteFlowSuccess() { func (s *FlowStoreTestSuite) TestDeleteFlowDBError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(nil, errors.New("connection error")) - err := s.store.DeleteFlow("flow-1") + err := s.store.DeleteFlow(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "failed to get database client") @@ -159,10 +179,10 @@ func (s *FlowStoreTestSuite) TestDeleteFlowDBError() { func (s *FlowStoreTestSuite) TestDeleteFlowExecuteError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Execute(queryDeleteFlow, "flow-1", "test-deployment"). + s.mockDBClient.EXPECT().ExecuteContext(mock.Anything, queryDeleteFlow, "flow-1", "test-deployment"). Return(int64(0), errors.New("delete failed")).Once() - err := s.store.DeleteFlow("flow-1") + err := s.store.DeleteFlow(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "failed to delete flow") @@ -172,10 +192,10 @@ func (s *FlowStoreTestSuite) TestDeleteFlowExecuteError() { func (s *FlowStoreTestSuite) TestIsFlowExistsSuccess() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryCheckFlowExistsByID, "flow-1", "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCheckFlowExistsByID, "flow-1", "test-deployment"). Return([]map[string]interface{}{{"exists": 1}}, nil).Once() - exists, err := s.store.IsFlowExists("flow-1") + exists, err := s.store.IsFlowExists(context.Background(), "flow-1") s.NoError(err) s.True(exists) @@ -183,10 +203,10 @@ func (s *FlowStoreTestSuite) TestIsFlowExistsSuccess() { func (s *FlowStoreTestSuite) TestIsFlowExistsNotFound() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryCheckFlowExistsByID, "non-existent", "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCheckFlowExistsByID, "non-existent", "test-deployment"). Return([]map[string]interface{}{}, nil).Once() - exists, err := s.store.IsFlowExists("non-existent") + exists, err := s.store.IsFlowExists(context.Background(), "non-existent") s.NoError(err) s.False(exists) @@ -195,7 +215,7 @@ func (s *FlowStoreTestSuite) TestIsFlowExistsNotFound() { func (s *FlowStoreTestSuite) TestIsFlowExistsDBError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(nil, errors.New("connection error")) - exists, err := s.store.IsFlowExists("flow-1") + exists, err := s.store.IsFlowExists(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "failed to get database client") @@ -204,10 +224,10 @@ func (s *FlowStoreTestSuite) TestIsFlowExistsDBError() { func (s *FlowStoreTestSuite) TestIsFlowExistsQueryError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryCheckFlowExistsByID, "flow-1", "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCheckFlowExistsByID, "flow-1", "test-deployment"). Return(nil, errors.New("query error")).Once() - exists, err := s.store.IsFlowExists("flow-1") + exists, err := s.store.IsFlowExists(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "failed to check flow existence") @@ -229,11 +249,11 @@ func (s *FlowStoreTestSuite) TestGetFlowByHandleSuccess() { } s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryGetFlowByHandle, "test-handle", + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowByHandle, "test-handle", string(common.FlowTypeAuthentication), "test-deployment").Return( []map[string]interface{}{flowData}, nil).Once() - flow, err := s.store.GetFlowByHandle("test-handle", common.FlowTypeAuthentication) + flow, err := s.store.GetFlowByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.NoError(err) s.NotNil(flow) @@ -244,11 +264,11 @@ func (s *FlowStoreTestSuite) TestGetFlowByHandleSuccess() { func (s *FlowStoreTestSuite) TestGetFlowByHandleNotFound() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryGetFlowByHandle, "non-existent", + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowByHandle, "non-existent", string(common.FlowTypeAuthentication), "test-deployment").Return( []map[string]interface{}{}, nil).Once() - flow, err := s.store.GetFlowByHandle("non-existent", common.FlowTypeAuthentication) + flow, err := s.store.GetFlowByHandle(context.Background(), "non-existent", common.FlowTypeAuthentication) s.Error(err) s.ErrorIs(err, errFlowNotFound) @@ -258,7 +278,7 @@ func (s *FlowStoreTestSuite) TestGetFlowByHandleNotFound() { func (s *FlowStoreTestSuite) TestGetFlowByHandleDBError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(nil, errors.New("connection error")) - flow, err := s.store.GetFlowByHandle("test-handle", common.FlowTypeAuthentication) + flow, err := s.store.GetFlowByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.Error(err) s.Contains(err.Error(), "failed to get database client") @@ -267,11 +287,11 @@ func (s *FlowStoreTestSuite) TestGetFlowByHandleDBError() { func (s *FlowStoreTestSuite) TestGetFlowByHandleQueryError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryGetFlowByHandle, "test-handle", + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowByHandle, "test-handle", string(common.FlowTypeAuthentication), "test-deployment").Return( nil, errors.New("query error")).Once() - flow, err := s.store.GetFlowByHandle("test-handle", common.FlowTypeAuthentication) + flow, err := s.store.GetFlowByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.Error(err) s.Contains(err.Error(), "failed to get flow by handle") @@ -282,11 +302,11 @@ func (s *FlowStoreTestSuite) TestGetFlowByHandleQueryError() { func (s *FlowStoreTestSuite) TestIsFlowExistsByHandleSuccess() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryCheckFlowExistsByHandle, "test-handle", + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCheckFlowExistsByHandle, "test-handle", string(common.FlowTypeAuthentication), "test-deployment").Return( []map[string]interface{}{{"exists": 1}}, nil).Once() - exists, err := s.store.IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication) + exists, err := s.store.IsFlowExistsByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.NoError(err) s.True(exists) @@ -294,11 +314,11 @@ func (s *FlowStoreTestSuite) TestIsFlowExistsByHandleSuccess() { func (s *FlowStoreTestSuite) TestIsFlowExistsByHandleNotFound() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryCheckFlowExistsByHandle, "non-existent", + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCheckFlowExistsByHandle, "non-existent", string(common.FlowTypeAuthentication), "test-deployment").Return( []map[string]interface{}{}, nil).Once() - exists, err := s.store.IsFlowExistsByHandle("non-existent", common.FlowTypeAuthentication) + exists, err := s.store.IsFlowExistsByHandle(context.Background(), "non-existent", common.FlowTypeAuthentication) s.NoError(err) s.False(exists) @@ -307,7 +327,7 @@ func (s *FlowStoreTestSuite) TestIsFlowExistsByHandleNotFound() { func (s *FlowStoreTestSuite) TestIsFlowExistsByHandleDBError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(nil, errors.New("connection error")) - exists, err := s.store.IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication) + exists, err := s.store.IsFlowExistsByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.Error(err) s.Contains(err.Error(), "failed to get database client") @@ -316,11 +336,11 @@ func (s *FlowStoreTestSuite) TestIsFlowExistsByHandleDBError() { func (s *FlowStoreTestSuite) TestIsFlowExistsByHandleQueryError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryCheckFlowExistsByHandle, "test-handle", + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCheckFlowExistsByHandle, "test-handle", string(common.FlowTypeAuthentication), "test-deployment").Return( nil, errors.New("query error")).Once() - exists, err := s.store.IsFlowExistsByHandle("test-handle", common.FlowTypeAuthentication) + exists, err := s.store.IsFlowExistsByHandle(context.Background(), "test-handle", common.FlowTypeAuthentication) s.Error(err) s.Contains(err.Error(), "failed to check flow existence by handle") @@ -332,7 +352,7 @@ func (s *FlowStoreTestSuite) TestIsFlowExistsByHandleQueryError() { func (s *FlowStoreTestSuite) TestListFlowVersionsDBError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(nil, errors.New("connection error")) - versions, err := s.store.ListFlowVersions("flow-1") + versions, err := s.store.ListFlowVersions(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "failed to get database client") @@ -341,10 +361,10 @@ func (s *FlowStoreTestSuite) TestListFlowVersionsDBError() { func (s *FlowStoreTestSuite) TestListFlowVersionsFlowNotFound() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryGetFlowInternalID, "flow-1", "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowInternalID, "flow-1", "test-deployment"). Return([]map[string]interface{}{}, nil).Once() - versions, err := s.store.ListFlowVersions("flow-1") + versions, err := s.store.ListFlowVersions(context.Background(), "flow-1") s.Error(err) s.Contains(err.Error(), "flow not found") @@ -355,10 +375,11 @@ func (s *FlowStoreTestSuite) TestListFlowVersionsFlowNotFound() { func (s *FlowStoreTestSuite) TestGetFlowVersionNotFound() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().Query(queryGetFlowVersionWithMetadata, "flow-1", 99, "test-deployment"). + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowVersionWithMetadata, "flow-1", 99, + "test-deployment"). Return([]map[string]interface{}{}, nil).Once() - version, err := s.store.GetFlowVersion("flow-1", 99) + version, err := s.store.GetFlowVersion(context.Background(), "flow-1", 99) s.Error(err) s.ErrorIs(err, errVersionNotFound) @@ -368,7 +389,7 @@ func (s *FlowStoreTestSuite) TestGetFlowVersionNotFound() { func (s *FlowStoreTestSuite) TestGetFlowVersionDBError() { s.mockDBProvider.EXPECT().GetConfigDBClient().Return(nil, errors.New("connection error")) - version, err := s.store.GetFlowVersion("flow-1", 1) + version, err := s.store.GetFlowVersion(context.Background(), "flow-1", 1) s.Error(err) s.Contains(err.Error(), "failed to get database client") @@ -377,11 +398,12 @@ func (s *FlowStoreTestSuite) TestGetFlowVersionDBError() { func (s *FlowStoreTestSuite) TestListFlowsWithTypeCountQueryError() { expectedError := errors.New("count query failed") - s.mockDBClient.EXPECT().Query(queryCountFlowsWithType, "authentication", s.store.deploymentID).Return( - nil, expectedError) s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCountFlowsWithType, "authentication", + s.store.deploymentID).Return( + nil, expectedError) - flows, count, err := s.store.ListFlows(10, 0, "authentication") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "authentication") s.Error(err) s.Nil(flows) @@ -390,14 +412,17 @@ func (s *FlowStoreTestSuite) TestListFlowsWithTypeCountQueryError() { } func (s *FlowStoreTestSuite) TestListFlowsWithTypeQueryError() { - s.mockDBClient.EXPECT().Query(queryCountFlowsWithType, "authentication", s.store.deploymentID).Return( + s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCountFlowsWithType, "authentication", + s.store.deploymentID).Return( []map[string]interface{}{{colCount: int64(5)}}, nil) expectedError := errors.New("list query failed") - s.mockDBClient.EXPECT().Query(queryListFlowsWithType, "authentication", s.store.deploymentID, 10, 0).Return( + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryListFlowsWithType, "authentication", + s.store.deploymentID, 10, + 0).Return( nil, expectedError) - s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - flows, count, err := s.store.ListFlows(10, 0, "authentication") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "authentication") s.Error(err) s.Nil(flows) @@ -406,15 +431,15 @@ func (s *FlowStoreTestSuite) TestListFlowsWithTypeQueryError() { } func (s *FlowStoreTestSuite) TestListFlowsBuildFlowError() { - s.mockDBClient.EXPECT().Query(queryCountFlows, s.store.deploymentID).Return( + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryCountFlows, s.store.deploymentID).Return( []map[string]interface{}{{colCount: int64(1)}}, nil) - s.mockDBClient.EXPECT().Query(queryListFlows, s.store.deploymentID, 10, 0).Return( + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryListFlows, s.store.deploymentID, 10, 0).Return( []map[string]interface{}{ {colFlowID: "flow-1"}, // Missing name field }, nil) s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - flows, count, err := s.store.ListFlows(10, 0, "") + flows, count, err := s.store.ListFlows(context.Background(), 10, 0, "") s.Error(err) s.Nil(flows) @@ -425,16 +450,17 @@ func (s *FlowStoreTestSuite) TestListFlowsBuildFlowError() { func (s *FlowStoreTestSuite) TestListFlowVersionsQueryError() { expectedError := errors.New("query failed") // First mock getFlowInternalID call - s.mockDBClient.EXPECT().Query(queryGetFlowInternalID, "flow-123", s.store.deploymentID).Return( - []map[string]interface{}{ - {"id": int64(1)}, - }, nil) + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowInternalID, "flow-123", s.store.deploymentID). + Return( + []map[string]interface{}{ + {"id": int64(1)}, + }, nil) // Then mock the list query that fails - s.mockDBClient.EXPECT().Query(queryListFlowVersions, int64(1), s.store.deploymentID).Return( + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryListFlowVersions, int64(1), s.store.deploymentID).Return( nil, expectedError) s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - versions, err := s.store.ListFlowVersions("flow-123") + versions, err := s.store.ListFlowVersions(context.Background(), "flow-123") s.Error(err) s.Nil(versions) @@ -443,18 +469,19 @@ func (s *FlowStoreTestSuite) TestListFlowVersionsQueryError() { func (s *FlowStoreTestSuite) TestListFlowVersionsBuildVersionError() { // First mock getFlowInternalID call - s.mockDBClient.EXPECT().Query(queryGetFlowInternalID, "flow-123", s.store.deploymentID).Return( - []map[string]interface{}{ - {"id": int64(1)}, - }, nil) + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowInternalID, "flow-123", s.store.deploymentID). + Return( + []map[string]interface{}{ + {"id": int64(1)}, + }, nil) // Then mock the list query with invalid data - s.mockDBClient.EXPECT().Query(queryListFlowVersions, int64(1), s.store.deploymentID).Return( + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryListFlowVersions, int64(1), s.store.deploymentID).Return( []map[string]interface{}{ {colVersion: "invalid"}, // Invalid version type }, nil) s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - versions, err := s.store.ListFlowVersions("flow-123") + versions, err := s.store.ListFlowVersions(context.Background(), "flow-123") s.Error(err) s.Empty(versions) // Returns empty slice on error, not nil @@ -463,11 +490,12 @@ func (s *FlowStoreTestSuite) TestListFlowVersionsBuildVersionError() { func (s *FlowStoreTestSuite) TestGetFlowVersionQueryError() { expectedError := errors.New("query failed") - s.mockDBClient.EXPECT().Query(queryGetFlowVersionWithMetadata, "flow-123", 5, s.store.deploymentID).Return( + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowVersionWithMetadata, "flow-123", 5, + s.store.deploymentID).Return( nil, expectedError) s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - version, err := s.store.GetFlowVersion("flow-123", 5) + version, err := s.store.GetFlowVersion(context.Background(), "flow-123", 5) s.Error(err) s.Nil(version) @@ -475,13 +503,14 @@ func (s *FlowStoreTestSuite) TestGetFlowVersionQueryError() { } func (s *FlowStoreTestSuite) TestGetFlowVersionBuildError() { - s.mockDBClient.EXPECT().Query(queryGetFlowVersionWithMetadata, "flow-123", 5, s.store.deploymentID).Return( + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowVersionWithMetadata, "flow-123", 5, + s.store.deploymentID).Return( []map[string]interface{}{ {colFlowID: 123}, // Invalid type - should be string }, nil) s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - version, err := s.store.GetFlowVersion("flow-123", 5) + version, err := s.store.GetFlowVersion(context.Background(), "flow-123", 5) s.Error(err) s.Nil(version) @@ -489,12 +518,13 @@ func (s *FlowStoreTestSuite) TestGetFlowVersionBuildError() { } func (s *FlowStoreTestSuite) TestGetFlowInternalIDMissingField() { - s.mockDBClient.EXPECT().Query(queryGetFlowInternalID, "flow-123", s.store.deploymentID).Return( - []map[string]interface{}{ - {"wrong_field": int64(1)}, - }, nil) + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowInternalID, "flow-123", s.store.deploymentID). + Return( + []map[string]interface{}{ + {"wrong_field": int64(1)}, + }, nil) - internalID, err := s.store.getFlowInternalID(s.mockDBClient, "flow-123") + internalID, err := s.store.getFlowInternalID(context.Background(), s.mockDBClient, "flow-123") s.Error(err) s.Equal(int64(0), internalID) @@ -502,12 +532,13 @@ func (s *FlowStoreTestSuite) TestGetFlowInternalIDMissingField() { } func (s *FlowStoreTestSuite) TestGetFlowInternalIDInvalidType() { - s.mockDBClient.EXPECT().Query(queryGetFlowInternalID, "flow-123", s.store.deploymentID).Return( - []map[string]interface{}{ - {"id": "not-an-int"}, // Wrong type - }, nil) + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowInternalID, "flow-123", s.store.deploymentID). + Return( + []map[string]interface{}{ + {"id": "not-an-int"}, // Wrong type + }, nil) - internalID, err := s.store.getFlowInternalID(s.mockDBClient, "flow-123") + internalID, err := s.store.getFlowInternalID(context.Background(), s.mockDBClient, "flow-123") s.Error(err) s.Equal(int64(0), internalID) @@ -516,10 +547,11 @@ func (s *FlowStoreTestSuite) TestGetFlowInternalIDInvalidType() { func (s *FlowStoreTestSuite) TestGetFlowInternalIDQueryError() { expectedError := errors.New("query failed") - s.mockDBClient.EXPECT().Query(queryGetFlowInternalID, "flow-123", s.store.deploymentID).Return( - nil, expectedError) + s.mockDBClient.EXPECT().QueryContext(mock.Anything, queryGetFlowInternalID, "flow-123", s.store.deploymentID). + Return( + nil, expectedError) - internalID, err := s.store.getFlowInternalID(s.mockDBClient, "flow-123") + internalID, err := s.store.getFlowInternalID(context.Background(), s.mockDBClient, "flow-123") s.Error(err) s.Equal(int64(0), internalID) @@ -636,14 +668,13 @@ func (s *FlowStoreTestSuite) TestCreateFlow_BeginTxError() { Nodes: []NodeDefinition{{Type: "start", ID: "node1"}}, } - s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().BeginTx().Return(nil, errors.New("tx error")) + s.mockTransactioner.On("Transact", mock.Anything, mock.Anything).Return(errors.New("tx error")) - result, err := s.store.CreateFlow("flow-1", flowDef) + result, err := s.store.CreateFlow(context.Background(), "flow-1", flowDef) s.Error(err) s.Nil(result) - s.Contains(err.Error(), "failed to begin transaction") + s.Contains(err.Error(), "tx error") } func (s *FlowStoreTestSuite) TestCreateFlow_ExecError() { @@ -654,14 +685,20 @@ func (s *FlowStoreTestSuite) TestCreateFlow_ExecError() { Nodes: []NodeDefinition{{Type: "start", ID: "node1"}}, } - mockTx := modelmock.NewTxInterfaceMock(s.T()) + s.mockTransactioner.On("Transact", mock.Anything, mock.Anything).Return(nil) s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().BeginTx().Return(mockTx, nil) - mockTx.EXPECT().Exec(queryCreateFlow, "flow-1", "login-handle", "Login Flow", common.FlowTypeAuthentication, - int64(1), s.store.deploymentID).Return(nil, errors.New("insert error")) - mockTx.EXPECT().Rollback().Return(nil) + s.mockDBClient.EXPECT().ExecuteContext( + mock.Anything, + queryCreateFlow, + "flow-1", + "login-handle", + "Login Flow", + common.FlowTypeAuthentication, + int64(1), + s.store.deploymentID, + ).Return(int64(0), errors.New("insert error")) - result, err := s.store.CreateFlow("flow-1", flowDef) + result, err := s.store.CreateFlow(context.Background(), "flow-1", flowDef) s.Error(err) s.Nil(result) @@ -676,25 +713,23 @@ func (s *FlowStoreTestSuite) TestUpdateFlow_BeginTxError() { Nodes: []NodeDefinition{}, } - s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().BeginTx().Return(nil, errors.New("tx error")) + s.mockTransactioner.On("Transact", mock.Anything, mock.Anything).Return(errors.New("tx error")) - result, err := s.store.UpdateFlow("flow-1", flowDef) + result, err := s.store.UpdateFlow(context.Background(), "flow-1", flowDef) s.Error(err) s.Nil(result) - s.Contains(err.Error(), "failed to begin transaction") + s.Contains(err.Error(), "tx error") } func (s *FlowStoreTestSuite) TestRestoreFlowVersion_BeginTxError() { - s.mockDBProvider.EXPECT().GetConfigDBClient().Return(s.mockDBClient, nil) - s.mockDBClient.EXPECT().BeginTx().Return(nil, errors.New("tx error")) + s.mockTransactioner.On("Transact", mock.Anything, mock.Anything).Return(errors.New("tx error")) - result, err := s.store.RestoreFlowVersion("flow-1", 1) + result, err := s.store.RestoreFlowVersion(context.Background(), "flow-1", 1) s.Error(err) s.Nil(result) - s.Contains(err.Error(), "failed to begin transaction") + s.Contains(err.Error(), "tx error") } // Helper Function Tests diff --git a/backend/internal/mcp/tools/flow/tool.go b/backend/internal/mcp/tools/flow/tool.go index 4778ddabf..8dd16b800 100644 --- a/backend/internal/mcp/tools/flow/tool.go +++ b/backend/internal/mcp/tools/flow/tool.go @@ -165,7 +165,7 @@ func (t *flowTools) listFlows( flowType := flowCommon.FlowType(input.FlowType) - listResponse, svcErr := t.flowService.ListFlows(limit, input.Offset, flowType) + listResponse, svcErr := t.flowService.ListFlows(ctx, limit, input.Offset, flowType) if svcErr != nil { return nil, flowListOutput{}, fmt.Errorf("failed to list flows: %s", svcErr.ErrorDescription) } @@ -184,7 +184,7 @@ func (t *flowTools) getFlowByHandle( ) (*mcp.CallToolResult, *flowmgt.CompleteFlowDefinition, error) { flowType := flowCommon.FlowType(input.FlowType) - flow, svcErr := t.flowService.GetFlowByHandle(input.Handle, flowType) + flow, svcErr := t.flowService.GetFlowByHandle(ctx, input.Handle, flowType) if svcErr != nil { return nil, nil, fmt.Errorf("failed to get flow by handle: %s", svcErr.ErrorDescription) } @@ -198,7 +198,7 @@ func (t *flowTools) getFlowByID( req *mcp.CallToolRequest, input common.IDInput, ) (*mcp.CallToolResult, *flowmgt.CompleteFlowDefinition, error) { - flow, svcErr := t.flowService.GetFlow(input.ID) + flow, svcErr := t.flowService.GetFlow(ctx, input.ID) if svcErr != nil { return nil, nil, fmt.Errorf("failed to get flow: %s", svcErr.ErrorDescription) } @@ -212,7 +212,7 @@ func (t *flowTools) createFlow( req *mcp.CallToolRequest, input flowmgt.FlowDefinition, ) (*mcp.CallToolResult, *flowmgt.CompleteFlowDefinition, error) { - createdFlow, svcErr := t.flowService.CreateFlow(&input) + createdFlow, svcErr := t.flowService.CreateFlow(ctx, &input) if svcErr != nil { return nil, nil, fmt.Errorf("failed to create flow: %s", svcErr.ErrorDescription) } @@ -227,7 +227,7 @@ func (t *flowTools) updateFlow( input updateFlowInput, ) (*mcp.CallToolResult, *flowmgt.CompleteFlowDefinition, error) { // Get current flow to retrieve immutable fields (handle, flowType) - currentFlow, svcErr := t.flowService.GetFlow(input.ID) + currentFlow, svcErr := t.flowService.GetFlow(ctx, input.ID) if svcErr != nil { return nil, nil, fmt.Errorf("failed to get flow: %s", svcErr.ErrorDescription) } @@ -240,7 +240,7 @@ func (t *flowTools) updateFlow( Nodes: input.Nodes, } - updatedFlow, svcErr := t.flowService.UpdateFlow(input.ID, updateDef) + updatedFlow, svcErr := t.flowService.UpdateFlow(ctx, input.ID, updateDef) if svcErr != nil { return nil, nil, fmt.Errorf("failed to update flow: %s", svcErr.ErrorDescription) } diff --git a/backend/internal/mcp/tools/flow/tool_test.go b/backend/internal/mcp/tools/flow/tool_test.go index 8da9ff017..6a7e1d647 100644 --- a/backend/internal/mcp/tools/flow/tool_test.go +++ b/backend/internal/mcp/tools/flow/tool_test.go @@ -63,7 +63,7 @@ func (suite *FlowToolsTestSuite) TestListFlows() { TotalResults: 1, Flows: mockFlows, } - suite.mockFlowService.EXPECT().ListFlows(100, 0, flowCommon.FlowType("")).Return(mockResponse, nil) + suite.mockFlowService.EXPECT().ListFlows(mock.Anything, 100, 0, flowCommon.FlowType("")).Return(mockResponse, nil) input := listFlowsInput{ PaginationInput: common.PaginationInput{ @@ -83,7 +83,7 @@ func (suite *FlowToolsTestSuite) TestListFlows_Error() { Code: "ERR_LIST", Error: "Failed to list flows", } - suite.mockFlowService.EXPECT().ListFlows(10, 0, flowCommon.FlowType("")).Return(nil, expectedErr) + suite.mockFlowService.EXPECT().ListFlows(mock.Anything, 10, 0, flowCommon.FlowType("")).Return(nil, expectedErr) input := listFlowsInput{ PaginationInput: common.PaginationInput{ @@ -103,7 +103,8 @@ func (suite *FlowToolsTestSuite) TestGetFlowByHandle() { Handle: "handle-1", FlowType: flowCommon.FlowTypeAuthentication, } - suite.mockFlowService.EXPECT().GetFlowByHandle("handle-1", flowCommon.FlowTypeAuthentication).Return(mockFlow, nil) + suite.mockFlowService.EXPECT().GetFlowByHandle(mock.Anything, "handle-1", flowCommon.FlowTypeAuthentication). + Return(mockFlow, nil) input := getFlowByHandleInput{ Handle: "handle-1", @@ -122,7 +123,7 @@ func (suite *FlowToolsTestSuite) TestGetFlowByID() { Handle: "handle-1", FlowType: flowCommon.FlowTypeAuthentication, } - suite.mockFlowService.EXPECT().GetFlow("flow-1").Return(mockFlow, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, "flow-1").Return(mockFlow, nil) input := common.IDInput{ ID: "flow-1", @@ -144,7 +145,7 @@ func (suite *FlowToolsTestSuite) TestCreateFlow() { Handle: "new-flow", FlowType: flowCommon.FlowTypeRegistration, } - suite.mockFlowService.EXPECT().CreateFlow(&input).Return(createdFlow, nil) + suite.mockFlowService.EXPECT().CreateFlow(mock.Anything, &input).Return(createdFlow, nil) result, output, err := suite.tools.createFlow(ctx(), nil, input) @@ -162,7 +163,7 @@ func (suite *FlowToolsTestSuite) TestCreateFlow_Error() { Code: "ERR_CREATE", Error: "Failed to create flow", } - suite.mockFlowService.EXPECT().CreateFlow(&input).Return(nil, expectedErr) + suite.mockFlowService.EXPECT().CreateFlow(mock.Anything, &input).Return(nil, expectedErr) result, output, err := suite.tools.createFlow(ctx(), nil, input) @@ -193,12 +194,13 @@ func (suite *FlowToolsTestSuite) TestUpdateFlow() { Handle: "updated-handle", FlowType: flowCommon.FlowTypeAuthentication, } - suite.mockFlowService.EXPECT().GetFlow("flow-1").Return(currentFlow, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, "flow-1").Return(currentFlow, nil) // Expect UpdateFlow with correct definition construction - suite.mockFlowService.EXPECT().UpdateFlow("flow-1", mock.MatchedBy(func(def *flowmgt.FlowDefinition) bool { - return def.Handle == "updated-handle" && def.Name == "Updated Flow" - })).Return(updatedFlow, nil) + suite.mockFlowService.EXPECT().UpdateFlow(mock.Anything, "flow-1", mock.MatchedBy( + func(def *flowmgt.FlowDefinition) bool { + return def.Handle == "updated-handle" && def.Name == "Updated Flow" + })).Return(updatedFlow, nil) result, output, err := suite.tools.updateFlow(ctx(), nil, input) diff --git a/backend/internal/system/export/service_test.go b/backend/internal/system/export/service_test.go index ddc4d021d..e7a033935 100644 --- a/backend/internal/system/export/service_test.go +++ b/backend/internal/system/export/service_test.go @@ -43,6 +43,7 @@ import ( "github.com/asgardeo/thunder/tests/mocks/userschemamock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -2156,7 +2157,7 @@ func (suite *ExportServiceTestSuite) TestExportResourcesWithExporter_Flow() { UpdatedAt: "2025-12-22 10:00:00", } - suite.mockFlowService.EXPECT().GetFlow(flowID).Return(mockFlow, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, flowID).Return(mockFlow, nil) exporter, exists := suite.exportService.(*exportService).registry.Get("flow") assert.True(suite.T(), exists, "Flow exporter should be registered") @@ -2250,7 +2251,7 @@ func (suite *ExportServiceTestSuite) TestExportResourcesWithExporter_FlowWithCom UpdatedAt: "2025-12-22 10:00:00", } - suite.mockFlowService.EXPECT().GetFlow(flowID).Return(mockFlow, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, flowID).Return(mockFlow, nil) exporter, exists := suite.exportService.(*exportService).registry.Get("flow") assert.True(suite.T(), exists) @@ -2296,8 +2297,8 @@ func (suite *ExportServiceTestSuite) TestExportResourcesWithExporter_MultipleFlo }, } - suite.mockFlowService.EXPECT().GetFlow(testFlow1ID).Return(flow1, nil) - suite.mockFlowService.EXPECT().GetFlow(testFlow2ID).Return(flow2, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, testFlow1ID).Return(flow1, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, testFlow2ID).Return(flow2, nil) exporter, _ := suite.exportService.(*exportService).registry.Get("flow") options := &ExportOptions{Format: formatYAML} @@ -2318,7 +2319,7 @@ func (suite *ExportServiceTestSuite) TestExportResourcesWithExporter_FlowNotFoun Code: "FLOW_NOT_FOUND", Error: "Flow not found", } - suite.mockFlowService.EXPECT().GetFlow(flowID).Return(nil, flowError) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, flowID).Return(nil, flowError) exporter, _ := suite.exportService.(*exportService).registry.Get("flow") options := &ExportOptions{Format: formatYAML} @@ -2381,9 +2382,9 @@ func (suite *ExportServiceTestSuite) TestExportResourcesWithExporter_WildcardFlo }, } - suite.mockFlowService.EXPECT().ListFlows(10000, 0, flowcommon.FlowType("")).Return(flowList, nil) - suite.mockFlowService.EXPECT().GetFlow(testFlow1ID).Return(flow1Complete, nil) - suite.mockFlowService.EXPECT().GetFlow(testFlow2ID).Return(flow2Complete, nil) + suite.mockFlowService.EXPECT().ListFlows(mock.Anything, 10000, 0, flowcommon.FlowType("")).Return(flowList, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, testFlow1ID).Return(flow1Complete, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, testFlow2ID).Return(flow2Complete, nil) exporter, _ := suite.exportService.(*exportService).registry.Get("flow") options := &ExportOptions{Format: formatYAML} @@ -2401,7 +2402,7 @@ func (suite *ExportServiceTestSuite) TestExportResourcesWithExporter_WildcardFlo Code: "DB_ERROR", Error: "Database error", } - suite.mockFlowService.EXPECT().ListFlows(10000, 0, flowcommon.FlowType("")).Return(nil, dbError) + suite.mockFlowService.EXPECT().ListFlows(mock.Anything, 10000, 0, flowcommon.FlowType("")).Return(nil, dbError) exporter, _ := suite.exportService.(*exportService).registry.Get("flow") options := &ExportOptions{Format: formatYAML} @@ -2430,7 +2431,7 @@ func (suite *ExportServiceTestSuite) TestExportResources_FlowOnly() { UpdatedAt: "2025-12-22 10:00:00", } - suite.mockFlowService.EXPECT().GetFlow(flowID).Return(mockFlow, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, flowID).Return(mockFlow, nil) request := &ExportRequest{ Flows: []string{flowID}, @@ -2474,7 +2475,7 @@ func (suite *ExportServiceTestSuite) TestExportResources_MixedWithFlows() { } suite.appServiceMock.EXPECT().GetApplication(appID).Return(mockApp, nil) - suite.mockFlowService.EXPECT().GetFlow(flowID).Return(mockFlow, nil) + suite.mockFlowService.EXPECT().GetFlow(mock.Anything, flowID).Return(mockFlow, nil) request := &ExportRequest{ Applications: []string{appID}, diff --git a/backend/tests/mocks/flow/flowmgtmock/FlowMgtServiceInterface_mock.go b/backend/tests/mocks/flow/flowmgtmock/FlowMgtServiceInterface_mock.go index f3465e4c5..f802dadb0 100644 --- a/backend/tests/mocks/flow/flowmgtmock/FlowMgtServiceInterface_mock.go +++ b/backend/tests/mocks/flow/flowmgtmock/FlowMgtServiceInterface_mock.go @@ -5,9 +5,11 @@ package flowmgtmock import ( + "context" + "github.com/asgardeo/thunder/internal/flow/common" "github.com/asgardeo/thunder/internal/flow/core" - "github.com/asgardeo/thunder/internal/flow/mgt" + flowmgt "github.com/asgardeo/thunder/internal/flow/mgt" "github.com/asgardeo/thunder/internal/system/error/serviceerror" mock "github.com/stretchr/testify/mock" ) @@ -40,8 +42,8 @@ func (_m *FlowMgtServiceInterfaceMock) EXPECT() *FlowMgtServiceInterfaceMock_Exp } // CreateFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) CreateFlow(flowDef *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(flowDef) +func (_mock *FlowMgtServiceInterfaceMock) CreateFlow(ctx context.Context, flowDef *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowDef) if len(ret) == 0 { panic("no return value specified for CreateFlow") @@ -49,18 +51,18 @@ func (_mock *FlowMgtServiceInterfaceMock) CreateFlow(flowDef *flowmgt.FlowDefini var r0 *flowmgt.CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(*flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(flowDef) + if returnFunc, ok := ret.Get(0).(func(context.Context, *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowDef) } - if returnFunc, ok := ret.Get(0).(func(*flowmgt.FlowDefinition) *flowmgt.CompleteFlowDefinition); ok { - r0 = returnFunc(flowDef) + if returnFunc, ok := ret.Get(0).(func(context.Context, *flowmgt.FlowDefinition) *flowmgt.CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowDef) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*flowmgt.CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(*flowmgt.FlowDefinition) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowDef) + if returnFunc, ok := ret.Get(1).(func(context.Context, *flowmgt.FlowDefinition) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowDef) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -76,18 +78,25 @@ type FlowMgtServiceInterfaceMock_CreateFlow_Call struct { // CreateFlow is a helper method to define mock.On call // - flowDef *flowmgt.FlowDefinition -func (_e *FlowMgtServiceInterfaceMock_Expecter) CreateFlow(flowDef interface{}) *FlowMgtServiceInterfaceMock_CreateFlow_Call { - return &FlowMgtServiceInterfaceMock_CreateFlow_Call{Call: _e.mock.On("CreateFlow", flowDef)} +// - ctx context.Context +// - flowDef *flowmgt.FlowDefinition +func (_e *FlowMgtServiceInterfaceMock_Expecter) CreateFlow(ctx interface{}, flowDef interface{}) *FlowMgtServiceInterfaceMock_CreateFlow_Call { + return &FlowMgtServiceInterfaceMock_CreateFlow_Call{Call: _e.mock.On("CreateFlow", ctx, flowDef)} } -func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) Run(run func(flowDef *flowmgt.FlowDefinition)) *FlowMgtServiceInterfaceMock_CreateFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) Run(run func(ctx context.Context, flowDef *flowmgt.FlowDefinition)) *FlowMgtServiceInterfaceMock_CreateFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 *flowmgt.FlowDefinition + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(*flowmgt.FlowDefinition) + arg0 = args[0].(context.Context) + } + var arg1 *flowmgt.FlowDefinition + if args[1] != nil { + arg1 = args[1].(*flowmgt.FlowDefinition) } run( arg0, + arg1, ) }) return _c @@ -98,22 +107,22 @@ func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) Return(completeFlowDefini return _c } -func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) RunAndReturn(run func(flowDef *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_CreateFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_CreateFlow_Call) RunAndReturn(run func(ctx context.Context, flowDef *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_CreateFlow_Call { _c.Call.Return(run) return _c } // DeleteFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) DeleteFlow(flowID string) *serviceerror.ServiceError { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) DeleteFlow(ctx context.Context, flowID string) *serviceerror.ServiceError { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for DeleteFlow") } var r0 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string) *serviceerror.ServiceError); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *serviceerror.ServiceError); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*serviceerror.ServiceError) @@ -129,18 +138,25 @@ type FlowMgtServiceInterfaceMock_DeleteFlow_Call struct { // DeleteFlow is a helper method to define mock.On call // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) DeleteFlow(flowID interface{}) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { - return &FlowMgtServiceInterfaceMock_DeleteFlow_Call{Call: _e.mock.On("DeleteFlow", flowID)} +// - ctx context.Context +// - flowID string +func (_e *FlowMgtServiceInterfaceMock_Expecter) DeleteFlow(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { + return &FlowMgtServiceInterfaceMock_DeleteFlow_Call{Call: _e.mock.On("DeleteFlow", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -151,14 +167,14 @@ func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) Return(serviceError *serv return _c } -func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) RunAndReturn(run func(flowID string) *serviceerror.ServiceError) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_DeleteFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string) *serviceerror.ServiceError) *FlowMgtServiceInterfaceMock_DeleteFlow_Call { _c.Call.Return(run) return _c } // GetFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) GetFlow(flowID string) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) GetFlow(ctx context.Context, flowID string) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for GetFlow") @@ -166,18 +182,18 @@ func (_mock *FlowMgtServiceInterfaceMock) GetFlow(flowID string) (*flowmgt.Compl var r0 *flowmgt.CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) *flowmgt.CompleteFlowDefinition); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *flowmgt.CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*flowmgt.CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -193,18 +209,25 @@ type FlowMgtServiceInterfaceMock_GetFlow_Call struct { // GetFlow is a helper method to define mock.On call // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlow(flowID interface{}) *FlowMgtServiceInterfaceMock_GetFlow_Call { - return &FlowMgtServiceInterfaceMock_GetFlow_Call{Call: _e.mock.On("GetFlow", flowID)} +// - ctx context.Context +// - flowID string +func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlow(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_GetFlow_Call { + return &FlowMgtServiceInterfaceMock_GetFlow_Call{Call: _e.mock.On("GetFlow", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_GetFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_GetFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -215,14 +238,14 @@ func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) Return(completeFlowDefinitio return _c } -func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) RunAndReturn(run func(flowID string) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlow_Call { _c.Call.Return(run) return _c } // GetFlowByHandle provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) GetFlowByHandle(handle string, flowType common.FlowType) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(handle, flowType) +func (_mock *FlowMgtServiceInterfaceMock) GetFlowByHandle(ctx context.Context, handle string, flowType common.FlowType) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, handle, flowType) if len(ret) == 0 { panic("no return value specified for GetFlowByHandle") @@ -230,18 +253,18 @@ func (_mock *FlowMgtServiceInterfaceMock) GetFlowByHandle(handle string, flowTyp var r0 *flowmgt.CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string, common.FlowType) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, common.FlowType) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, handle, flowType) } - if returnFunc, ok := ret.Get(0).(func(string, common.FlowType) *flowmgt.CompleteFlowDefinition); ok { - r0 = returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, common.FlowType) *flowmgt.CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, handle, flowType) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*flowmgt.CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, common.FlowType) *serviceerror.ServiceError); ok { - r1 = returnFunc(handle, flowType) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, common.FlowType) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, handle, flowType) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -256,25 +279,31 @@ type FlowMgtServiceInterfaceMock_GetFlowByHandle_Call struct { } // GetFlowByHandle is a helper method to define mock.On call +// - ctx context.Context // - handle string // - flowType common.FlowType -func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlowByHandle(handle interface{}, flowType interface{}) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { - return &FlowMgtServiceInterfaceMock_GetFlowByHandle_Call{Call: _e.mock.On("GetFlowByHandle", handle, flowType)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlowByHandle(ctx interface{}, handle interface{}, flowType interface{}) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { + return &FlowMgtServiceInterfaceMock_GetFlowByHandle_Call{Call: _e.mock.On("GetFlowByHandle", ctx, handle, flowType)} } -func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) Run(run func(handle string, flowType common.FlowType)) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) Run(run func(ctx context.Context, handle string, flowType common.FlowType)) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 common.FlowType + var arg1 string if args[1] != nil { - arg1 = args[1].(common.FlowType) + arg1 = args[1].(string) + } + var arg2 common.FlowType + if args[2] != nil { + arg2 = args[2].(common.FlowType) } run( arg0, arg1, + arg2, ) }) return _c @@ -285,14 +314,14 @@ func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) Return(completeFlowD return _c } -func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) RunAndReturn(run func(handle string, flowType common.FlowType) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call) RunAndReturn(run func(ctx context.Context, handle string, flowType common.FlowType) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlowByHandle_Call { _c.Call.Return(run) return _c } // GetFlowVersion provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) GetFlowVersion(flowID string, version int) (*flowmgt.FlowVersion, *serviceerror.ServiceError) { - ret := _mock.Called(flowID, version) +func (_mock *FlowMgtServiceInterfaceMock) GetFlowVersion(ctx context.Context, flowID string, version int) (*flowmgt.FlowVersion, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID, version) if len(ret) == 0 { panic("no return value specified for GetFlowVersion") @@ -300,18 +329,18 @@ func (_mock *FlowMgtServiceInterfaceMock) GetFlowVersion(flowID string, version var r0 *flowmgt.FlowVersion var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string, int) (*flowmgt.FlowVersion, *serviceerror.ServiceError)); ok { - return returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) (*flowmgt.FlowVersion, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID, version) } - if returnFunc, ok := ret.Get(0).(func(string, int) *flowmgt.FlowVersion); ok { - r0 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) *flowmgt.FlowVersion); ok { + r0 = returnFunc(ctx, flowID, version) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*flowmgt.FlowVersion) } } - if returnFunc, ok := ret.Get(1).(func(string, int) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, int) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID, version) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -326,25 +355,31 @@ type FlowMgtServiceInterfaceMock_GetFlowVersion_Call struct { } // GetFlowVersion is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - version int -func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlowVersion(flowID interface{}, version interface{}) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { - return &FlowMgtServiceInterfaceMock_GetFlowVersion_Call{Call: _e.mock.On("GetFlowVersion", flowID, version)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) GetFlowVersion(ctx interface{}, flowID interface{}, version interface{}) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { + return &FlowMgtServiceInterfaceMock_GetFlowVersion_Call{Call: _e.mock.On("GetFlowVersion", ctx, flowID, version)} } -func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) Run(run func(flowID string, version int)) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) Run(run func(ctx context.Context, flowID string, version int)) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 int + var arg1 string if args[1] != nil { - arg1 = args[1].(int) + arg1 = args[1].(string) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) } run( arg0, arg1, + arg2, ) }) return _c @@ -355,14 +390,14 @@ func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) Return(flowVersion *f return _c } -func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) RunAndReturn(run func(flowID string, version int) (*flowmgt.FlowVersion, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { +func (_c *FlowMgtServiceInterfaceMock_GetFlowVersion_Call) RunAndReturn(run func(ctx context.Context, flowID string, version int) (*flowmgt.FlowVersion, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetFlowVersion_Call { _c.Call.Return(run) return _c } // GetGraph provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) GetGraph(flowID string) (core.GraphInterface, *serviceerror.ServiceError) { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) GetGraph(ctx context.Context, flowID string) (core.GraphInterface, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for GetGraph") @@ -370,18 +405,18 @@ func (_mock *FlowMgtServiceInterfaceMock) GetGraph(flowID string) (core.GraphInt var r0 core.GraphInterface var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string) (core.GraphInterface, *serviceerror.ServiceError)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (core.GraphInterface, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) core.GraphInterface); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) core.GraphInterface); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(core.GraphInterface) } } - if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -396,19 +431,25 @@ type FlowMgtServiceInterfaceMock_GetGraph_Call struct { } // GetGraph is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) GetGraph(flowID interface{}) *FlowMgtServiceInterfaceMock_GetGraph_Call { - return &FlowMgtServiceInterfaceMock_GetGraph_Call{Call: _e.mock.On("GetGraph", flowID)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) GetGraph(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_GetGraph_Call { + return &FlowMgtServiceInterfaceMock_GetGraph_Call{Call: _e.mock.On("GetGraph", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_GetGraph_Call { +func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_GetGraph_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -419,22 +460,22 @@ func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) Return(graphInterface core. return _c } -func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) RunAndReturn(run func(flowID string) (core.GraphInterface, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetGraph_Call { +func (_c *FlowMgtServiceInterfaceMock_GetGraph_Call) RunAndReturn(run func(ctx context.Context, flowID string) (core.GraphInterface, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_GetGraph_Call { _c.Call.Return(run) return _c } // IsValidFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) IsValidFlow(flowID string) bool { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) IsValidFlow(ctx context.Context, flowID string) bool { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for IsValidFlow") } var r0 bool - if returnFunc, ok := ret.Get(0).(func(string) bool); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = returnFunc(ctx, flowID) } else { r0 = ret.Get(0).(bool) } @@ -447,19 +488,25 @@ type FlowMgtServiceInterfaceMock_IsValidFlow_Call struct { } // IsValidFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) IsValidFlow(flowID interface{}) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { - return &FlowMgtServiceInterfaceMock_IsValidFlow_Call{Call: _e.mock.On("IsValidFlow", flowID)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) IsValidFlow(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { + return &FlowMgtServiceInterfaceMock_IsValidFlow_Call{Call: _e.mock.On("IsValidFlow", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -470,14 +517,14 @@ func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) Return(b bool) *FlowMgtS return _c } -func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) RunAndReturn(run func(flowID string) bool) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_IsValidFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string) bool) *FlowMgtServiceInterfaceMock_IsValidFlow_Call { _c.Call.Return(run) return _c } // ListFlowVersions provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) ListFlowVersions(flowID string) (*flowmgt.FlowVersionListResponse, *serviceerror.ServiceError) { - ret := _mock.Called(flowID) +func (_mock *FlowMgtServiceInterfaceMock) ListFlowVersions(ctx context.Context, flowID string) (*flowmgt.FlowVersionListResponse, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID) if len(ret) == 0 { panic("no return value specified for ListFlowVersions") @@ -485,11 +532,11 @@ func (_mock *FlowMgtServiceInterfaceMock) ListFlowVersions(flowID string) (*flow var r0 *flowmgt.FlowVersionListResponse var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string) (*flowmgt.FlowVersionListResponse, *serviceerror.ServiceError)); ok { - return returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*flowmgt.FlowVersionListResponse, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID) } - if returnFunc, ok := ret.Get(0).(func(string) *flowmgt.FlowVersionListResponse); ok { - r0 = returnFunc(flowID) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) *flowmgt.FlowVersionListResponse); ok { + r0 = returnFunc(ctx, flowID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*flowmgt.FlowVersionListResponse) @@ -511,19 +558,25 @@ type FlowMgtServiceInterfaceMock_ListFlowVersions_Call struct { } // ListFlowVersions is a helper method to define mock.On call +// - ctx context.Context // - flowID string -func (_e *FlowMgtServiceInterfaceMock_Expecter) ListFlowVersions(flowID interface{}) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { - return &FlowMgtServiceInterfaceMock_ListFlowVersions_Call{Call: _e.mock.On("ListFlowVersions", flowID)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) ListFlowVersions(ctx interface{}, flowID interface{}) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { + return &FlowMgtServiceInterfaceMock_ListFlowVersions_Call{Call: _e.mock.On("ListFlowVersions", ctx, flowID)} } -func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) Run(run func(flowID string)) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { +func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) Run(run func(ctx context.Context, flowID string)) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) } run( arg0, + arg1, ) }) return _c @@ -534,14 +587,14 @@ func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) Return(flowVersionL return _c } -func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) RunAndReturn(run func(flowID string) (*flowmgt.FlowVersionListResponse, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { +func (_c *FlowMgtServiceInterfaceMock_ListFlowVersions_Call) RunAndReturn(run func(ctx context.Context, flowID string) (*flowmgt.FlowVersionListResponse, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_ListFlowVersions_Call { _c.Call.Return(run) return _c } // ListFlows provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) ListFlows(limit int, offset int, flowType common.FlowType) (*flowmgt.FlowListResponse, *serviceerror.ServiceError) { - ret := _mock.Called(limit, offset, flowType) +func (_mock *FlowMgtServiceInterfaceMock) ListFlows(ctx context.Context, limit int, offset int, flowType common.FlowType) (*flowmgt.FlowListResponse, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, limit, offset, flowType) if len(ret) == 0 { panic("no return value specified for ListFlows") @@ -549,18 +602,18 @@ func (_mock *FlowMgtServiceInterfaceMock) ListFlows(limit int, offset int, flowT var r0 *flowmgt.FlowListResponse var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(int, int, common.FlowType) (*flowmgt.FlowListResponse, *serviceerror.ServiceError)); ok { - return returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, int, int, common.FlowType) (*flowmgt.FlowListResponse, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, limit, offset, flowType) } - if returnFunc, ok := ret.Get(0).(func(int, int, common.FlowType) *flowmgt.FlowListResponse); ok { - r0 = returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(0).(func(context.Context, int, int, common.FlowType) *flowmgt.FlowListResponse); ok { + r0 = returnFunc(ctx, limit, offset, flowType) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*flowmgt.FlowListResponse) } } - if returnFunc, ok := ret.Get(1).(func(int, int, common.FlowType) *serviceerror.ServiceError); ok { - r1 = returnFunc(limit, offset, flowType) + if returnFunc, ok := ret.Get(1).(func(context.Context, int, int, common.FlowType) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, limit, offset, flowType) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -575,31 +628,37 @@ type FlowMgtServiceInterfaceMock_ListFlows_Call struct { } // ListFlows is a helper method to define mock.On call +// - ctx context.Context // - limit int // - offset int // - flowType common.FlowType -func (_e *FlowMgtServiceInterfaceMock_Expecter) ListFlows(limit interface{}, offset interface{}, flowType interface{}) *FlowMgtServiceInterfaceMock_ListFlows_Call { - return &FlowMgtServiceInterfaceMock_ListFlows_Call{Call: _e.mock.On("ListFlows", limit, offset, flowType)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) ListFlows(ctx interface{}, limit interface{}, offset interface{}, flowType interface{}) *FlowMgtServiceInterfaceMock_ListFlows_Call { + return &FlowMgtServiceInterfaceMock_ListFlows_Call{Call: _e.mock.On("ListFlows", ctx, limit, offset, flowType)} } -func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) Run(run func(limit int, offset int, flowType common.FlowType)) *FlowMgtServiceInterfaceMock_ListFlows_Call { +func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) Run(run func(ctx context.Context, limit int, offset int, flowType common.FlowType)) *FlowMgtServiceInterfaceMock_ListFlows_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 int + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(int) + arg0 = args[0].(context.Context) } var arg1 int if args[1] != nil { arg1 = args[1].(int) } - var arg2 common.FlowType + var arg2 int if args[2] != nil { - arg2 = args[2].(common.FlowType) + arg2 = args[2].(int) + } + var arg3 common.FlowType + if args[3] != nil { + arg3 = args[3].(common.FlowType) } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -610,14 +669,14 @@ func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) Return(flowListResponse *f return _c } -func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) RunAndReturn(run func(limit int, offset int, flowType common.FlowType) (*flowmgt.FlowListResponse, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_ListFlows_Call { +func (_c *FlowMgtServiceInterfaceMock_ListFlows_Call) RunAndReturn(run func(ctx context.Context, limit int, offset int, flowType common.FlowType) (*flowmgt.FlowListResponse, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_ListFlows_Call { _c.Call.Return(run) return _c } // RestoreFlowVersion provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) RestoreFlowVersion(flowID string, version int) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(flowID, version) +func (_mock *FlowMgtServiceInterfaceMock) RestoreFlowVersion(ctx context.Context, flowID string, version int) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID, version) if len(ret) == 0 { panic("no return value specified for RestoreFlowVersion") @@ -625,18 +684,18 @@ func (_mock *FlowMgtServiceInterfaceMock) RestoreFlowVersion(flowID string, vers var r0 *flowmgt.CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string, int) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID, version) } - if returnFunc, ok := ret.Get(0).(func(string, int) *flowmgt.CompleteFlowDefinition); ok { - r0 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int) *flowmgt.CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID, version) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*flowmgt.CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, int) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID, version) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, int) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID, version) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -651,25 +710,31 @@ type FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call struct { } // RestoreFlowVersion is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - version int -func (_e *FlowMgtServiceInterfaceMock_Expecter) RestoreFlowVersion(flowID interface{}, version interface{}) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { - return &FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call{Call: _e.mock.On("RestoreFlowVersion", flowID, version)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) RestoreFlowVersion(ctx interface{}, flowID interface{}, version interface{}) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { + return &FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call{Call: _e.mock.On("RestoreFlowVersion", ctx, flowID, version)} } -func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) Run(run func(flowID string, version int)) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { +func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) Run(run func(ctx context.Context, flowID string, version int)) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 int + var arg1 string if args[1] != nil { - arg1 = args[1].(int) + arg1 = args[1].(string) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) } run( arg0, arg1, + arg2, ) }) return _c @@ -680,14 +745,14 @@ func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) Return(completeFl return _c } -func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) RunAndReturn(run func(flowID string, version int) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { +func (_c *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call) RunAndReturn(run func(ctx context.Context, flowID string, version int) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_RestoreFlowVersion_Call { _c.Call.Return(run) return _c } // UpdateFlow provides a mock function for the type FlowMgtServiceInterfaceMock -func (_mock *FlowMgtServiceInterfaceMock) UpdateFlow(flowID string, flowDef *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { - ret := _mock.Called(flowID, flowDef) +func (_mock *FlowMgtServiceInterfaceMock) UpdateFlow(ctx context.Context, flowID string, flowDef *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError) { + ret := _mock.Called(ctx, flowID, flowDef) if len(ret) == 0 { panic("no return value specified for UpdateFlow") @@ -695,18 +760,18 @@ func (_mock *FlowMgtServiceInterfaceMock) UpdateFlow(flowID string, flowDef *flo var r0 *flowmgt.CompleteFlowDefinition var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func(string, *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { - return returnFunc(flowID, flowDef) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)); ok { + return returnFunc(ctx, flowID, flowDef) } - if returnFunc, ok := ret.Get(0).(func(string, *flowmgt.FlowDefinition) *flowmgt.CompleteFlowDefinition); ok { - r0 = returnFunc(flowID, flowDef) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *flowmgt.FlowDefinition) *flowmgt.CompleteFlowDefinition); ok { + r0 = returnFunc(ctx, flowID, flowDef) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*flowmgt.CompleteFlowDefinition) } } - if returnFunc, ok := ret.Get(1).(func(string, *flowmgt.FlowDefinition) *serviceerror.ServiceError); ok { - r1 = returnFunc(flowID, flowDef) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, *flowmgt.FlowDefinition) *serviceerror.ServiceError); ok { + r1 = returnFunc(ctx, flowID, flowDef) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(*serviceerror.ServiceError) @@ -721,25 +786,31 @@ type FlowMgtServiceInterfaceMock_UpdateFlow_Call struct { } // UpdateFlow is a helper method to define mock.On call +// - ctx context.Context // - flowID string // - flowDef *flowmgt.FlowDefinition -func (_e *FlowMgtServiceInterfaceMock_Expecter) UpdateFlow(flowID interface{}, flowDef interface{}) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { - return &FlowMgtServiceInterfaceMock_UpdateFlow_Call{Call: _e.mock.On("UpdateFlow", flowID, flowDef)} +func (_e *FlowMgtServiceInterfaceMock_Expecter) UpdateFlow(ctx interface{}, flowID interface{}, flowDef interface{}) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { + return &FlowMgtServiceInterfaceMock_UpdateFlow_Call{Call: _e.mock.On("UpdateFlow", ctx, flowID, flowDef)} } -func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) Run(run func(flowID string, flowDef *flowmgt.FlowDefinition)) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) Run(run func(ctx context.Context, flowID string, flowDef *flowmgt.FlowDefinition)) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 string + var arg0 context.Context if args[0] != nil { - arg0 = args[0].(string) + arg0 = args[0].(context.Context) } - var arg1 *flowmgt.FlowDefinition + var arg1 string if args[1] != nil { - arg1 = args[1].(*flowmgt.FlowDefinition) + arg1 = args[1].(string) + } + var arg2 *flowmgt.FlowDefinition + if args[2] != nil { + arg2 = args[2].(*flowmgt.FlowDefinition) } run( arg0, arg1, + arg2, ) }) return _c @@ -750,7 +821,7 @@ func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) Return(completeFlowDefini return _c } -func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) RunAndReturn(run func(flowID string, flowDef *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { +func (_c *FlowMgtServiceInterfaceMock_UpdateFlow_Call) RunAndReturn(run func(ctx context.Context, flowID string, flowDef *flowmgt.FlowDefinition) (*flowmgt.CompleteFlowDefinition, *serviceerror.ServiceError)) *FlowMgtServiceInterfaceMock_UpdateFlow_Call { _c.Call.Return(run) return _c }