diff --git a/examples/.env b/examples/.env index 5a52c16..b758308 100644 --- a/examples/.env +++ b/examples/.env @@ -72,3 +72,8 @@ LOG_LEVEL="debug" # [Optional] ORKA_VM_METADATA specifies the custom metadata passed to the VM. # Should be formatted as "key=value" comma separated list. ORKA_VM_METADATA="key1=value1,key2=value2" + +# [Optional] VM_TRACKER_INTERVAL specifies the interval at which the VM tracker will check for orphaned VMs. +# VMs are deleted if they do not have a corresponding GitHub runner for 2 consecutive checks. +# If not provided, it defaults to 300 seconds. +VM_TRACKER_INTERVAL="300s" diff --git a/main.go b/main.go index d56de74..2242e7a 100644 --- a/main.go +++ b/main.go @@ -118,7 +118,10 @@ func run(ctx context.Context, actionsClient *actions.ActionsClient, orkaClient * runnerProvisioner := provisioner.NewRunnerProvisioner(runnerScaleSet, actionsClient, orkaClient, envData) - runnerMessageProcessor := runners.NewRunnerMessageProcessor(ctx, runnerManager, runnerProvisioner, runnerScaleSet) + vmTracker := runners.NewVMTracker(orkaClient, actionsClient, logger) + go vmTracker.Start(ctx, envData.VMTrackerInterval) + + runnerMessageProcessor := runners.NewRunnerMessageProcessor(ctx, runnerManager, runnerProvisioner, vmTracker, runnerScaleSet) if err = runnerMessageProcessor.StartProcessingMessages(); err != nil { logger.Errorf("failed to start processing messages for runnerScaleSet %s: %w", runnerScaleSet.Name, err.Error()) diff --git a/pkg/env/constants.go b/pkg/env/constants.go index 635e38e..a63d238 100644 --- a/pkg/env/constants.go +++ b/pkg/env/constants.go @@ -27,5 +27,7 @@ const ( RunnerDeregistrationTimeoutEnvName = "RUNNER_DEREGISTRATION_TIMEOUT" RunnerDeregistrationPollIntervalEnvName = "RUNNER_DEREGISTRATION_POLL_INTERVAL" + VMTrackerIntervalEnvName = "VM_TRACKER_INTERVAL" + LogLevelEnvName = "LOG_LEVEL" ) diff --git a/pkg/env/env.go b/pkg/env/env.go index 1a5d36e..f426873 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -48,6 +48,8 @@ type Data struct { RunnerDeregistrationTimeout time.Duration RunnerDeregistrationPollInterval time.Duration + VMTrackerInterval time.Duration + LogLevel string } @@ -77,6 +79,8 @@ func ParseEnv() *Data { RunnerDeregistrationTimeout: getDurationEnv(RunnerDeregistrationTimeoutEnvName, 30*time.Second), RunnerDeregistrationPollInterval: getDurationEnv(RunnerDeregistrationPollIntervalEnvName, 2*time.Second), + VMTrackerInterval: getDurationEnv(VMTrackerIntervalEnvName, 300*time.Second), + LogLevel: getEnvWithDefault(LogLevelEnvName, logging.LogLevelInfo), } diff --git a/pkg/github/runners/message-processor.go b/pkg/github/runners/message-processor.go index 5745700..d2e93f3 100644 --- a/pkg/github/runners/message-processor.go +++ b/pkg/github/runners/message-processor.go @@ -26,7 +26,7 @@ const ( defaultJobId = "missing-job-id" ) -func NewRunnerMessageProcessor(ctx context.Context, runnerManager RunnerManagerInterface, runnerProvisioner RunnerProvisionerInterface, runnerScaleSet *types.RunnerScaleSet) *RunnerMessageProcessor { +func NewRunnerMessageProcessor(ctx context.Context, runnerManager RunnerManagerInterface, runnerProvisioner RunnerProvisionerInterface, vmTracker *VMTracker, runnerScaleSet *types.RunnerScaleSet) *RunnerMessageProcessor { return &RunnerMessageProcessor{ ctx: ctx, runnerManager: runnerManager, @@ -37,6 +37,7 @@ func NewRunnerMessageProcessor(ctx context.Context, runnerManager RunnerManagerI upstreamCanceledJobsMutex: sync.RWMutex{}, jobContextCancels: map[string]context.CancelFunc{}, jobContextCancelsMutex: sync.Mutex{}, + vmTracker: vmTracker, } } @@ -142,6 +143,7 @@ func (p *RunnerMessageProcessor) processRunnerMessage(message *types.RunnerScale context.AfterFunc(jobContext, func() { p.logger.Infof("cleaning up resources for %s after job context is canceled", executor.VMName) p.runnerProvisioner.CleanupResources(context.WithoutCancel(p.ctx), executor.VMName) + p.vmTracker.Untrack(executor.VMName) }) defer func() { @@ -168,6 +170,7 @@ func (p *RunnerMessageProcessor) processRunnerMessage(message *types.RunnerScale p.cancelJobContext(jobId, cancelReason) }() + p.vmTracker.Track(executor.VMName) executionErr = p.executeJobCommands(jobContext, jobId, executor, commands) }() } diff --git a/pkg/github/runners/types.go b/pkg/github/runners/types.go index 1ade3e7..75fe775 100644 --- a/pkg/github/runners/types.go +++ b/pkg/github/runners/types.go @@ -41,6 +41,7 @@ type RunnerMessageProcessor struct { logger *zap.SugaredLogger runnerManager RunnerManagerInterface runnerProvisioner RunnerProvisionerInterface + vmTracker *VMTracker runnerScaleSetName string upstreamCanceledJobs map[string]bool upstreamCanceledJobsMutex sync.RWMutex diff --git a/pkg/github/runners/vm_tracker.go b/pkg/github/runners/vm_tracker.go new file mode 100644 index 0000000..7e95717 --- /dev/null +++ b/pkg/github/runners/vm_tracker.go @@ -0,0 +1,110 @@ +package runners + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/macstadium/orka-github-actions-integration/pkg/github/actions" + "github.com/macstadium/orka-github-actions-integration/pkg/orka" + "go.uber.org/zap" +) + +type VMTracker struct { + orkaClient orka.OrkaService + actionsClient actions.ActionsService + logger *zap.SugaredLogger + + mu sync.Mutex + trackedVMs map[string]int +} + +func NewVMTracker(orkaClient orka.OrkaService, actionsClient actions.ActionsService, logger *zap.SugaredLogger) *VMTracker { + return &VMTracker{ + orkaClient: orkaClient, + actionsClient: actionsClient, + logger: logger.Named("vm-tracker"), + trackedVMs: make(map[string]int), + } +} + +func (tracker *VMTracker) Track(vmName string) { + tracker.mu.Lock() + defer tracker.mu.Unlock() + tracker.trackedVMs[vmName] = 0 + tracker.logger.Debugf("Now tracking VM %s for orphaned VM detection", vmName) +} + +func (tracker *VMTracker) Untrack(vmName string) { + tracker.logger.Debugf("Stopping tracking VM %s for orphaned VM detection", vmName) + tracker.mu.Lock() + defer tracker.mu.Unlock() + delete(tracker.trackedVMs, vmName) +} + +func (tracker *VMTracker) Start(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + tracker.checkaForOrphanedVMs(ctx) + } + } +} + +func (tracker *VMTracker) checkaForOrphanedVMs(ctx context.Context) { + tracker.logger.Debugf("Checking for orphaned VMs") + tracker.mu.Lock() + vmNames := make([]string, 0, len(tracker.trackedVMs)) + for name := range tracker.trackedVMs { + vmNames = append(vmNames, name) + } + tracker.mu.Unlock() + + if len(vmNames) == 0 { + tracker.logger.Debugf("No VMs to check for orphaned VMs") + return + } + + for _, name := range vmNames { + runner, err := tracker.actionsClient.GetRunner(ctx, name) + if err != nil { + tracker.logger.Warnf("failed to check GitHub for %s: %v", name, err) + continue + } + + tracker.mu.Lock() + if runner == nil { + tracker.trackedVMs[name]++ + strikes := tracker.trackedVMs[name] + tracker.mu.Unlock() + + tracker.logger.Warnf("VM %s has no GitHub runner (Strike %d/2)", name, strikes) + + if strikes >= 2 { + tracker.logger.Errorf("VM %s is orphaned. Forcing deletion.", name) + tracker.cleanupOrphanedVM(ctx, name) + } + } else { + tracker.trackedVMs[name] = 0 + tracker.mu.Unlock() + tracker.logger.Debugf("VM %s is healthy and registered", name) + } + } +} + +func (tracker *VMTracker) cleanupOrphanedVM(ctx context.Context, vmName string) { + err := tracker.orkaClient.DeleteVM(ctx, vmName) + if err != nil && !strings.Contains(err.Error(), "not found") { + tracker.logger.Errorf("Failed to delete orphaned VM %s: %v", vmName, err) + return + } + + tracker.Untrack(vmName) + tracker.logger.Infof("Successfully deleted orphaned VM %s", vmName) +} diff --git a/pkg/github/runners/vm_tracker_test.go b/pkg/github/runners/vm_tracker_test.go new file mode 100644 index 0000000..8f837e1 --- /dev/null +++ b/pkg/github/runners/vm_tracker_test.go @@ -0,0 +1,212 @@ +package runners + +import ( + "context" + "errors" + "testing" + + "github.com/google/uuid" + "github.com/macstadium/orka-github-actions-integration/pkg/github/types" + "github.com/macstadium/orka-github-actions-integration/pkg/orka" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "go.uber.org/zap" +) + +func TestVMTracker(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "VMTracker Suite") +} + +type MockOrkaClient struct { + DeleteVMFunc func(ctx context.Context, name string) error + DeployVMFunc func(ctx context.Context, namePrefix, vmConfig string) (*orka.OrkaVMDeployResponseModel, error) +} + +func (m *MockOrkaClient) DeleteVM(ctx context.Context, name string) error { + if m.DeleteVMFunc != nil { + return m.DeleteVMFunc(ctx, name) + } + return nil +} + +func (m *MockOrkaClient) DeployVM(ctx context.Context, namePrefix, vmConfig string) (*orka.OrkaVMDeployResponseModel, error) { + if m.DeployVMFunc != nil { + return m.DeployVMFunc(ctx, namePrefix, vmConfig) + } + return nil, nil +} + +type MockActionsClient struct { + GetRunnerFunc func(ctx context.Context, runnerName string) (*types.RunnerReference, error) +} + +func (m *MockActionsClient) GetRunner(ctx context.Context, runnerName string) (*types.RunnerReference, error) { + if m.GetRunnerFunc != nil { + return m.GetRunnerFunc(ctx, runnerName) + } + return nil, nil +} + +func (m *MockActionsClient) GetRunnerScaleSet(ctx context.Context, id int, name string) (*types.RunnerScaleSet, error) { + return nil, nil +} +func (m *MockActionsClient) CreateRunnerScaleSet(ctx context.Context, rs *types.RunnerScaleSet) (*types.RunnerScaleSet, error) { + return nil, nil +} +func (m *MockActionsClient) DeleteRunnerScaleSet(ctx context.Context, id int) error { return nil } +func (m *MockActionsClient) CreateRunner(ctx context.Context, id int, name string) (*types.RunnerScaleSetJitRunnerConfig, error) { + return nil, nil +} +func (m *MockActionsClient) DeleteRunner(ctx context.Context, id int) error { return nil } +func (m *MockActionsClient) CreateMessageSession(ctx context.Context, id int, owner string) (*types.RunnerScaleSetSession, error) { + return nil, nil +} +func (m *MockActionsClient) DeleteMessageSession(ctx context.Context, id int, sessionId *uuid.UUID) error { + return nil +} +func (m *MockActionsClient) RefreshMessageSession(ctx context.Context, id int, sessionId *uuid.UUID) (*types.RunnerScaleSetSession, error) { + return nil, nil +} +func (m *MockActionsClient) AcquireJobs(ctx context.Context, id int, token string, reqIds []int64) ([]int64, error) { + return nil, nil +} +func (m *MockActionsClient) GetAcquirableJobs(ctx context.Context, id int) (*types.AcquirableJobList, error) { + return nil, nil +} +func (m *MockActionsClient) GetMessage(ctx context.Context, url, token string, lastId int64) (*types.RunnerScaleSetMessage, error) { + return nil, nil +} +func (m *MockActionsClient) DeleteMessage(ctx context.Context, url, token string, id int64) error { + return nil +} + +var _ = Describe("VMTracker", func() { + var ( + tracker *VMTracker + mockOrka *MockOrkaClient + mockActions *MockActionsClient + ctx context.Context + vmName string + ) + + BeforeEach(func() { + mockOrka = &MockOrkaClient{} + mockActions = &MockActionsClient{} + logger := zap.NewNop().Sugar() + ctx = context.Background() + vmName = "orka-vm-test-1" + + tracker = NewVMTracker(mockOrka, mockActions, logger) + }) + + Describe("Tracking State", func() { + It("should verify a VM is tracked after calling Track", func() { + tracker.Track(vmName) + + count, exists := tracker.trackedVMs[vmName] + + Expect(exists).To(BeTrue(), "VM should exist in map") + Expect(count).To(Equal(0), "Initial strikes should be 0") + }) + + It("should stop tracking a VM after calling Untrack", func() { + tracker.Track(vmName) + tracker.Untrack(vmName) + + _, exists := tracker.trackedVMs[vmName] + + Expect(exists).To(BeFalse(), "VM should be removed from map") + }) + }) + + Describe("Check Cycle", func() { + + Context("When the VM is new (Strike 0)", func() { + BeforeEach(func() { + tracker.Track(vmName) + }) + + It("should remain healthy if GitHub returns the runner", func() { + mockActions.GetRunnerFunc = func(c context.Context, n string) (*types.RunnerReference, error) { + return &types.RunnerReference{Name: vmName, Id: 123}, nil + } + + tracker.checkaForOrphanedVMs(ctx) + + strikes := tracker.trackedVMs[vmName] + Expect(strikes).To(Equal(0), "Strikes should remain 0 for healthy runner") + }) + + It("should apply Strike 1 if Runner is missing", func() { + mockActions.GetRunnerFunc = func(c context.Context, n string) (*types.RunnerReference, error) { + return nil, nil + } + + mockOrka.DeleteVMFunc = func(c context.Context, n string) error { + Fail("DeleteVM should not be called on first strike") + return nil + } + + tracker.checkaForOrphanedVMs(ctx) + + strikes := tracker.trackedVMs[vmName] + Expect(strikes).To(Equal(1), "Strikes should increment to 1") + }) + }) + + Context("When the VM has 1 Strike", func() { + BeforeEach(func() { + tracker.Track(vmName) + tracker.trackedVMs[vmName] = 1 + }) + + It("should Reset strikes to 0 if Runner appears", func() { + mockActions.GetRunnerFunc = func(c context.Context, n string) (*types.RunnerReference, error) { + return &types.RunnerReference{Name: vmName, Id: 123}, nil + } + + tracker.checkaForOrphanedVMs(ctx) + + strikes := tracker.trackedVMs[vmName] + Expect(strikes).To(Equal(0), "Strikes should reset to 0 upon recovery") + }) + + It("should Delete the VM if Runner is still missing (Strike 2)", func() { + mockActions.GetRunnerFunc = func(c context.Context, n string) (*types.RunnerReference, error) { + return nil, nil + } + + deleteCalled := false + mockOrka.DeleteVMFunc = func(c context.Context, n string) error { + Expect(n).To(Equal(vmName)) + deleteCalled = true + return nil + } + + tracker.checkaForOrphanedVMs(ctx) + + Expect(deleteCalled).To(BeTrue(), "DeleteVM must be called on 2nd strike") + + _, exists := tracker.trackedVMs[vmName] + Expect(exists).To(BeFalse(), "VM should be untracked after deletion") + }) + }) + + Context("When GitHub API fails", func() { + It("should ignore API errors and NOT apply strikes", func() { + tracker.Track(vmName) + + mockActions.GetRunnerFunc = func(c context.Context, n string) (*types.RunnerReference, error) { + return nil, errors.New("500 Internal Server Error") + } + + tracker.checkaForOrphanedVMs(ctx) + tracker.checkaForOrphanedVMs(ctx) + + strikes := tracker.trackedVMs[vmName] + Expect(strikes).To(Equal(0), "Strikes should not increase on API errors") + }) + }) + }) +})