Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions backend/internal/application/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package application

import (
"context"
"errors"
"slices"

Expand Down Expand Up @@ -650,7 +651,7 @@ func (as *applicationService) DeleteApplication(appID string) *serviceerror.Serv
// If the flow ID is not provided, it sets the default authentication flow ID.
func (as *applicationService) validateAuthFlowID(app *model.ApplicationDTO) *serviceerror.ServiceError {
if app.AuthFlowID != "" {
isValidFlow := as.flowMgtService.IsValidFlow(app.AuthFlowID)
isValidFlow := as.flowMgtService.IsValidFlow(context.TODO(), app.AuthFlowID)
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using context.TODO() is not recommended for production code. Since this function (validateAuthFlowID) is called from within request handlers that have access to request context, the context should be passed as a parameter through the call chain instead of using context.TODO(). This ensures proper request tracing, cancellation, and timeout propagation.

Copilot uses AI. Check for mistakes.
if !isValidFlow {
return &ErrorInvalidAuthFlowID
}
Expand All @@ -671,13 +672,13 @@ func (as *applicationService) validateRegistrationFlowID(app *model.ApplicationD
logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "ApplicationService"))

if app.RegistrationFlowID != "" {
isValidFlow := as.flowMgtService.IsValidFlow(app.RegistrationFlowID)
isValidFlow := as.flowMgtService.IsValidFlow(context.TODO(), app.RegistrationFlowID)
if !isValidFlow {
return &ErrorInvalidRegistrationFlowID
}
} else {
// Try to get the equivalent registration flow for the auth flow
authFlow, svcErr := as.flowMgtService.GetFlow(app.AuthFlowID)
authFlow, svcErr := as.flowMgtService.GetFlow(context.TODO(), app.AuthFlowID)
if svcErr != nil {
if svcErr.Type == serviceerror.ServerErrorType {
logger.Error("Error while retrieving auth flow definition",
Expand All @@ -688,7 +689,7 @@ func (as *applicationService) validateRegistrationFlowID(app *model.ApplicationD
}

registrationFlow, svcErr := as.flowMgtService.GetFlowByHandle(
authFlow.Handle, flowcommon.FlowTypeRegistration)
context.TODO(), authFlow.Handle, flowcommon.FlowTypeRegistration)
Comment on lines 675 to 692
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using context.TODO() is not recommended. Since these validations are part of the application creation/update flow that originates from HTTP handlers, the context should be passed down through the call chain from the handler. Consider updating the validateRegistrationFlowID method signature to accept a context parameter.

Copilot uses AI. Check for mistakes.
if svcErr != nil {
if svcErr.Type == serviceerror.ServerErrorType {
logger.Error("Error while retrieving registration flow definition by handle",
Expand Down Expand Up @@ -913,7 +914,7 @@ func (as *applicationService) getDefaultAuthFlowID() (string, *serviceerror.Serv

defaultAuthFlowHandle := config.GetThunderRuntime().Config.Flow.DefaultAuthFlowHandle
defaultAuthFlow, svcErr := as.flowMgtService.GetFlowByHandle(
defaultAuthFlowHandle, flowcommon.FlowTypeAuthentication)
context.TODO(), defaultAuthFlowHandle, flowcommon.FlowTypeAuthentication)
Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using context.TODO() is not recommended. The getDefaultAuthFlowID method is called from validateAuthFlowID which is part of the validation flow. The context should be propagated from the original HTTP request through the entire call chain.

Suggested change
context.TODO(), defaultAuthFlowHandle, flowcommon.FlowTypeAuthentication)
context.Background(), defaultAuthFlowHandle, flowcommon.FlowTypeAuthentication)

Copilot uses AI. Check for mistakes.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its intentional as context propagation is being addressed service by service. In this PR we are focusing on the flow service.

Copy link

Copilot AI Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The usage of context.TODO() is not recommended here. Since this is being called from application service methods which don't currently have a context parameter, consider using context.Background() instead, or better yet, propagate context through the application service methods. context.TODO() is typically used as a placeholder during development when context propagation is not yet implemented, but this PR is specifically about adding context propagation.

Suggested change
context.TODO(), defaultAuthFlowHandle, flowcommon.FlowTypeAuthentication)
context.Background(), defaultAuthFlowHandle, flowcommon.FlowTypeAuthentication)

Copilot uses AI. Check for mistakes.

if svcErr != nil {
if svcErr.Type == serviceerror.ServerErrorType {
Expand Down
84 changes: 43 additions & 41 deletions backend/internal/application/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ func (suite *ServiceTestSuite) TestValidateAuthFlowID_WithValidFlowID() {
AuthFlowID: "auth-flow-123",
}

mockFlowMgtService.EXPECT().IsValidFlow("auth-flow-123").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "auth-flow-123").Return(true)

svcErr := service.validateAuthFlowID(app)

Expand All @@ -667,7 +667,7 @@ func (suite *ServiceTestSuite) TestValidateAuthFlowID_WithInvalidFlowID() {
AuthFlowID: "invalid-flow",
}

mockFlowMgtService.EXPECT().IsValidFlow("invalid-flow").Return(false)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "invalid-flow").Return(false)

svcErr := service.validateAuthFlowID(app)

Expand Down Expand Up @@ -696,7 +696,7 @@ func (suite *ServiceTestSuite) TestValidateAuthFlowID_WithEmptyFlowID_SetsDefaul
ID: "default-flow-id-123",
Handle: "default_auth_flow",
}
mockFlowMgtService.EXPECT().GetFlowByHandle("default_auth_flow", flowcommon.FlowTypeAuthentication).
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "default_auth_flow", flowcommon.FlowTypeAuthentication).
Return(defaultFlow, nil)

svcErr := service.validateAuthFlowID(app)
Expand All @@ -722,7 +722,7 @@ func (suite *ServiceTestSuite) TestValidateAuthFlowID_WithEmptyFlowID_ErrorRetri
AuthFlowID: "",
}

mockFlowMgtService.EXPECT().GetFlowByHandle("default_auth_flow", flowcommon.FlowTypeAuthentication).
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "default_auth_flow", flowcommon.FlowTypeAuthentication).
Return(nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType})

svcErr := service.validateAuthFlowID(app)
Expand All @@ -738,7 +738,7 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_WithValidFlowID()
RegistrationFlowID: "reg-flow-123",
}

mockFlowMgtService.EXPECT().IsValidFlow("reg-flow-123").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "reg-flow-123").Return(true)

svcErr := service.validateRegistrationFlowID(app)

Expand All @@ -753,7 +753,7 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_WithInvalidFlowID(
RegistrationFlowID: "invalid-reg-flow",
}

mockFlowMgtService.EXPECT().IsValidFlow("invalid-reg-flow").Return(false)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "invalid-reg-flow").Return(false)

svcErr := service.validateRegistrationFlowID(app)

Expand All @@ -778,8 +778,8 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_WithEmptyFlowID_In
Handle: "basic_auth",
}

mockFlowMgtService.EXPECT().GetFlow("auth-flow-123").Return(authFlow, nil)
mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration).
mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123").Return(authFlow, nil)
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration).
Return(regFlow, nil)

svcErr := service.validateRegistrationFlowID(app)
Expand All @@ -796,7 +796,7 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_ErrorRetrievingAut
RegistrationFlowID: "",
}

mockFlowMgtService.EXPECT().GetFlow("auth-flow-123").
mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123").
Return(nil, &serviceerror.ServiceError{Type: serviceerror.ServerErrorType})

svcErr := service.validateRegistrationFlowID(app)
Expand All @@ -818,8 +818,8 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_ErrorRetrievingReg
Handle: "basic_auth",
}

mockFlowMgtService.EXPECT().GetFlow("auth-flow-123").Return(authFlow, nil)
mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration).
mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123").Return(authFlow, nil)
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration).
Return(nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType})

svcErr := service.validateRegistrationFlowID(app)
Expand All @@ -836,7 +836,7 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_ClientErrorRetriev
RegistrationFlowID: "",
}

mockFlowMgtService.EXPECT().GetFlow("auth-flow-123").
mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123").
Return(nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType})

svcErr := service.validateRegistrationFlowID(app)
Expand All @@ -858,8 +858,8 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_ServerErrorRetriev
Handle: "basic_auth",
}

mockFlowMgtService.EXPECT().GetFlow("auth-flow-123").Return(authFlow, nil)
mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration).
mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "auth-flow-123").Return(authFlow, nil)
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration).
Return(nil, &serviceerror.ServiceError{Type: serviceerror.ServerErrorType})

svcErr := service.validateRegistrationFlowID(app)
Expand All @@ -885,7 +885,7 @@ func (suite *ServiceTestSuite) TestGetDefaultAuthFlowID_Success() {
ID: "flow-id-789",
Handle: "custom_auth_flow",
}
mockFlowMgtService.EXPECT().GetFlowByHandle("custom_auth_flow", flowcommon.FlowTypeAuthentication).
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "custom_auth_flow", flowcommon.FlowTypeAuthentication).
Return(defaultFlow, nil)

result, svcErr := service.getDefaultAuthFlowID()
Expand All @@ -907,7 +907,7 @@ func (suite *ServiceTestSuite) TestGetDefaultAuthFlowID_ErrorRetrieving() {

service, _, _, mockFlowMgtService := suite.setupTestService()

mockFlowMgtService.EXPECT().GetFlowByHandle("custom_auth_flow", flowcommon.FlowTypeAuthentication).
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "custom_auth_flow", flowcommon.FlowTypeAuthentication).
Return(nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType})

result, svcErr := service.getDefaultAuthFlowID()
Expand All @@ -930,7 +930,7 @@ func (suite *ServiceTestSuite) TestGetDefaultAuthFlowID_ServerError() {

service, _, _, mockFlowMgtService := suite.setupTestService()

mockFlowMgtService.EXPECT().GetFlowByHandle("custom_auth_flow", flowcommon.FlowTypeAuthentication).
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "custom_auth_flow", flowcommon.FlowTypeAuthentication).
Return(nil, &serviceerror.ServiceError{Type: serviceerror.ServerErrorType})

result, svcErr := service.getDefaultAuthFlowID()
Expand Down Expand Up @@ -2427,19 +2427,20 @@ func (suite *ServiceTestSuite) TestValidateApplication_InvalidURL() {
}

mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError)
mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().GetFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(&flowmgt.CompleteFlowDefinition{
ID: "edc013d0-e893-4dc0-990c-3e1d203e005b",
Handle: "basic_auth",
}, nil).Maybe()
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").
Return(&flowmgt.CompleteFlowDefinition{
ID: "edc013d0-e893-4dc0-990c-3e1d203e005b",
Handle: "basic_auth",
}, nil).Maybe()

// Return success for registration flow so URL validation runs
mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration).Return(
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration).Return(
&flowmgt.CompleteFlowDefinition{
ID: "reg_flow_basic",
Handle: "basic_auth",
}, nil).Maybe()
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything).Return(true).Maybe()
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, mock.Anything).Return(true).Maybe()

result, inboundAuth, svcErr := service.ValidateApplication(app)

Expand Down Expand Up @@ -2470,19 +2471,20 @@ func (suite *ServiceTestSuite) TestValidateApplication_InvalidLogoURL() {
}

mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError)
mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().GetFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(&flowmgt.CompleteFlowDefinition{
ID: "edc013d0-e893-4dc0-990c-3e1d203e005b",
Handle: "basic_auth",
}, nil).Maybe()
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").
Return(&flowmgt.CompleteFlowDefinition{
ID: "edc013d0-e893-4dc0-990c-3e1d203e005b",
Handle: "basic_auth",
}, nil).Maybe()

// Return success for registration flow so URL validation runs
mockFlowMgtService.EXPECT().GetFlowByHandle("basic_auth", flowcommon.FlowTypeRegistration).Return(
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, "basic_auth", flowcommon.FlowTypeRegistration).Return(
&flowmgt.CompleteFlowDefinition{
ID: "reg_flow_basic",
Handle: "basic_auth",
}, nil).Maybe()
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything).Return(true).Maybe()
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, mock.Anything).Return(true).Maybe()

result, inboundAuth, svcErr := service.ValidateApplication(app)

Expand Down Expand Up @@ -2519,8 +2521,8 @@ func (suite *ServiceTestSuite) TestCreateApplication_StoreErrorWithRollback() {
}

mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError)
mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow("80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true)
mockCertService.EXPECT().CreateCertificate(mock.Anything).Return(&cert.Certificate{Type: "JWKS"}, nil)
mockStore.On("CreateApplication", mock.Anything).Return(errors.New("store error"))
mockCertService.EXPECT().
Expand Down Expand Up @@ -2561,8 +2563,8 @@ func (suite *ServiceTestSuite) TestCreateApplication_StoreErrorWithRollbackFailu
}

mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError)
mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow("80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true)
mockCertService.EXPECT().CreateCertificate(mock.Anything).Return(&cert.Certificate{Type: "JWKS"}, nil)
mockStore.On("CreateApplication", mock.Anything).Return(errors.New("store error"))
rollbackErr := &serviceerror.ServiceError{
Expand Down Expand Up @@ -2686,7 +2688,7 @@ func (suite *ServiceTestSuite) TestUpdateApplication_StoreErrorWhenCheckingClien
}

mockStore.On("GetApplicationByID", "app123").Return(existingApp, nil)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything).Return(true).Maybe()
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, mock.Anything).Return(true).Maybe()
// Return an error that's not ApplicationNotFoundError when checking client ID
mockStore.On("GetOAuthApplication", "new-client-id").Return(nil, errors.New("database connection error"))

Expand Down Expand Up @@ -2730,8 +2732,8 @@ func (suite *ServiceTestSuite) TestUpdateApplication_StoreErrorWithRollback() {
}

mockStore.On("GetApplicationByID", "app123").Return(existingApp, nil)
mockFlowMgtService.EXPECT().IsValidFlow("edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow("80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "edc013d0-e893-4dc0-990c-3e1d203e005b").Return(true)
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "80024fb3-29ed-4c33-aa48-8aee5e96d522").Return(true)
mockCertService.EXPECT().
GetCertificateByReference(cert.CertificateReferenceTypeApplication, "app123").
Return(nil, &cert.ErrorCertificateNotFound)
Expand Down Expand Up @@ -3449,12 +3451,12 @@ func (suite *ServiceTestSuite) TestValidateRegistrationFlowID_NoPrefix() {
}

mockStore.On("GetApplicationByName", "Test App").Return(nil, model.ApplicationNotFoundError)
mockFlowMgtService.EXPECT().IsValidFlow("invalid_flow_id").Return(true)
mockFlowMgtService.EXPECT().GetFlow("invalid_flow_id").Return(&flowmgt.CompleteFlowDefinition{
mockFlowMgtService.EXPECT().IsValidFlow(mock.Anything, "invalid_flow_id").Return(true)
mockFlowMgtService.EXPECT().GetFlow(mock.Anything, "invalid_flow_id").Return(&flowmgt.CompleteFlowDefinition{
ID: "invalid_flow_id",
Handle: "test_flow",
}, nil).Maybe()
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, flowcommon.FlowTypeRegistration).Return(
mockFlowMgtService.EXPECT().GetFlowByHandle(mock.Anything, mock.Anything, flowcommon.FlowTypeRegistration).Return(
nil, &serviceerror.ServiceError{Type: serviceerror.ClientErrorType}).Maybe()

result, inboundAuth, svcErr := service.ValidateApplication(app)
Expand Down
Loading
Loading