210 changes: 194 additions & 16 deletions internal/migration/service.go
@@ -12,6 +12,8 @@ import (
"log/slog"
"math"
"os"
"sync"
"time"

"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
@@ -27,15 +29,33 @@ import (
// MigrateProtocol is the libp2p protocol ID for agent migration.
const MigrateProtocol protocol.ID = "/igor/migrate/1.0.0"

type managedAgent struct {
	instance  *agent.Instance
	cancel    context.CancelFunc
	done      chan struct{}
	closeOnce sync.Once
}

func (m *managedAgent) close(ctx context.Context) error {
	var closeErr error
	m.closeOnce.Do(func() {
		if m.instance != nil {
			closeErr = m.instance.Close(ctx)
		}
	})
	return closeErr
}

// Service coordinates agent migration between nodes.
type Service struct {
	host            host.Host
	runtimeEngine   *runtime.Engine
	storageProvider storage.Provider
	logger          *slog.Logger

	mu sync.RWMutex
	// Active agents running on this node
	activeAgents map[string]*agent.Instance
	activeAgents map[string]*managedAgent
}

// NewService creates a new migration service.
@@ -50,7 +70,7 @@ func NewService(
		runtimeEngine:   engine,
		storageProvider: storage,
		logger:          logger,
		activeAgents:    make(map[string]*agent.Instance),
		activeAgents:    make(map[string]*managedAgent),
	}

	// Register migration protocol handler
@@ -163,12 +183,9 @@ func (s *Service) MigrateAgent(
"target_node", started.NodeID,
)

	// Terminate local instance if exists
	if instance, exists := s.activeAgents[agentID]; exists {
		if err := instance.Close(ctx); err != nil {
			s.logger.Error("Failed to close local instance", "error", err)
		}
		delete(s.activeAgents, agentID)
	// Terminate local instance if this process currently runs the agent.
	if managed, exists := s.getManagedAgent(agentID); exists {
		s.stopManagedAgent(ctx, agentID, managed)
		s.logger.Info("Local agent instance terminated", "agent_id", agentID)
	}

@@ -202,6 +219,11 @@ func (s *Service) handleIncomingMigration(stream network.Stream) {
	}

	pkg := transfer.Package
	if pkg.AgentID == "" {
		s.sendStartConfirmation(stream, "", false, "agent_id is required")
		return
	}

	s.logger.Info("Agent package received",
		"agent_id", pkg.AgentID,
		"wasm_size", len(pkg.WASMBinary),
@@ -217,13 +239,27 @@ func (s *Service) handleIncomingMigration(stream network.Stream) {
		return
	}

	// Write WASM binary to temporary file
	wasmPath := fmt.Sprintf("/tmp/igor-agent-%s.wasm", pkg.AgentID)
	if err := os.WriteFile(wasmPath, pkg.WASMBinary, 0644); err != nil {
	// Write WASM binary to a secure temporary file.
	tmpFile, err := os.CreateTemp("", "igor-agent-*.wasm")
	if err != nil {
		s.logger.Error("Failed to create temp WASM file", "error", err)
		s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error())
		return
	}
	wasmPath := tmpFile.Name()
	defer os.Remove(wasmPath)

	if _, err := tmpFile.Write(pkg.WASMBinary); err != nil {
		_ = tmpFile.Close()
		s.logger.Error("Failed to write WASM binary", "error", err)
		s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error())
		return
	}
	if err := tmpFile.Close(); err != nil {
		s.logger.Error("Failed to close temp WASM file", "error", err)
		s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error())
		return
	}

	// Load agent with budget from package
	instance, err := agent.LoadAgent(
@@ -245,21 +281,26 @@ func (s *Service) handleIncomingMigration(stream network.Stream) {
	// Initialize agent
	if err := instance.Init(ctx); err != nil {
		s.logger.Error("Failed to initialize agent", "error", err)
		instance.Close(ctx)
		_ = instance.Close(ctx)
		s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error())
		return
	}

	// Resume from checkpoint
	if err := instance.LoadCheckpointFromStorage(ctx); err != nil {
		s.logger.Error("Failed to resume agent", "error", err)
		instance.Close(ctx)
		_ = instance.Close(ctx)
		s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error())
		return
	}

	// Store as active agent
	s.activeAgents[pkg.AgentID] = instance
	// Start target-side execution loop and register as active.
	if err := s.startManagedAgentLoop(pkg.AgentID, instance); err != nil {
		s.logger.Error("Failed to start migrated agent", "error", err)
		_ = instance.Close(ctx)
		s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error())
		return
Comment on lines +298 to +302

P2: Clean up checkpoint when managed agent registration fails

If startManagedAgentLoop fails here (for example because registerManagedAgent rejects a duplicate agentID), the handler returns an error but leaves the just-written checkpoint file in place. In this path, the existing local agent keeps running while its persisted checkpoint has been overwritten by the incoming transfer, so a later restart/load can resume from the wrong state even though migration was rejected.

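One way to address this (a sketch only, not part of this PR; it assumes the handler can reach s.storageProvider at this point and that storage.Provider exposes the DeleteCheckpoint method implemented in fs_provider.go):

	if err := s.startManagedAgentLoop(pkg.AgentID, instance); err != nil {
		s.logger.Error("Failed to start migrated agent", "error", err)
		_ = instance.Close(ctx)
		// Drop the checkpoint written by the rejected transfer so a later
		// restart cannot resume the still-running local agent from foreign state.
		if delErr := s.storageProvider.DeleteCheckpoint(ctx, pkg.AgentID); delErr != nil {
			s.logger.Error("Failed to delete orphaned checkpoint", "agent_id", pkg.AgentID, "error", delErr)
		}
		s.sendStartConfirmation(stream, pkg.AgentID, false, err.Error())
		return
	}

If the resident agent checkpoints periodically, its next save would then repopulate its own state.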

	}

s.logger.Info("Agent migration accepted and started",
"agent_id", pkg.AgentID,
@@ -270,6 +311,132 @@ func (s *Service) handleIncomingMigration(stream network.Stream) {
	s.sendStartConfirmation(stream, pkg.AgentID, true, "")
}

func (s *Service) startManagedAgentLoop(agentID string, instance *agent.Instance) error {
	agentCtx, cancel := context.WithCancel(context.Background())
	managed := &managedAgent{
		instance: instance,
		cancel:   cancel,
		done:     make(chan struct{}),
	}

	if err := s.registerManagedAgent(agentID, managed); err != nil {
		cancel()
		close(managed.done)
		return err
	}

	go s.runManagedAgentLoop(agentCtx, agentID, managed)
	return nil
}

func (s *Service) runManagedAgentLoop(ctx context.Context, agentID string, managed *managedAgent) {
	defer close(managed.done)
	defer s.unregisterManagedAgent(agentID, managed)

	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	checkpointTicker := time.NewTicker(5 * time.Second)
	defer checkpointTicker.Stop()

	s.logger.Info("Starting migrated agent tick loop", "agent_id", agentID)

	for {
		select {
		case <-ctx.Done():
			if err := managed.instance.SaveCheckpointToStorage(context.Background()); err != nil {
				s.logger.Error("Failed to save checkpoint on agent stop", "agent_id", agentID, "error", err)
			}
			if err := managed.close(context.Background()); err != nil {
				s.logger.Error("Failed to close agent instance", "agent_id", agentID, "error", err)
			}
			s.logger.Info("Stopped migrated agent tick loop", "agent_id", agentID)
			return

		case <-ticker.C:
			if err := managed.instance.Tick(ctx); err != nil {
				if managed.instance.Budget <= 0 {
					s.logger.Info("Migrated agent budget exhausted, terminating",
						"agent_id", agentID,
						"reason", "budget_exhausted",
					)
				} else {
					s.logger.Error("Migrated agent tick failed", "agent_id", agentID, "error", err)
				}

				if saveErr := managed.instance.SaveCheckpointToStorage(context.Background()); saveErr != nil {
					s.logger.Error("Failed to save checkpoint on agent termination", "agent_id", agentID, "error", saveErr)
				}
				if closeErr := managed.close(context.Background()); closeErr != nil {
					s.logger.Error("Failed to close agent instance", "agent_id", agentID, "error", closeErr)
				}
				return
			}

		case <-checkpointTicker.C:
			if err := managed.instance.SaveCheckpointToStorage(ctx); err != nil {
				s.logger.Error("Failed to save periodic checkpoint", "agent_id", agentID, "error", err)
			}
		}
	}
}

func (s *Service) stopManagedAgent(ctx context.Context, agentID string, managed *managedAgent) {
	if managed.cancel != nil {
		managed.cancel()
	}

	if managed.done != nil {
		select {
		case <-managed.done:
		case <-ctx.Done():
		case <-time.After(2 * time.Second):
			s.logger.Warn("Timed out waiting for agent loop shutdown", "agent_id", agentID)
		}
	}

	if err := managed.close(context.Background()); err != nil {
		s.logger.Error("Failed to close local instance", "agent_id", agentID, "error", err)
	}

	s.unregisterManagedAgent(agentID, managed)
}

func (s *Service) registerManagedAgent(agentID string, managed *managedAgent) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, exists := s.activeAgents[agentID]; exists {
		return fmt.Errorf("agent %s is already active on this node", agentID)
	}

	s.activeAgents[agentID] = managed
	return nil
}

func (s *Service) getManagedAgent(agentID string) (*managedAgent, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	managed, exists := s.activeAgents[agentID]
	return managed, exists
}

func (s *Service) unregisterManagedAgent(agentID string, expected *managedAgent) {
	s.mu.Lock()
	defer s.mu.Unlock()

	current, exists := s.activeAgents[agentID]
	if !exists {
		return
	}
	if expected != nil && current != expected {
		return
	}

	delete(s.activeAgents, agentID)
}

// sendStartConfirmation sends an AgentStarted message.
func (s *Service) sendStartConfirmation(
	stream io.Writer,
@@ -292,12 +459,23 @@ func (s *Service) sendStartConfirmation(

// RegisterAgent registers an actively running agent with the migration service.
func (s *Service) RegisterAgent(agentID string, instance *agent.Instance) {
	s.activeAgents[agentID] = instance
	managed := &managedAgent{instance: instance}
	if err := s.registerManagedAgent(agentID, managed); err != nil {
		s.logger.Error("Failed to register agent with migration service",
			"agent_id", agentID,
			"error", err,
		)
		return
	}

	s.logger.Info("Agent registered with migration service", "agent_id", agentID)
}

// GetActiveAgents returns the list of active agent IDs.
func (s *Service) GetActiveAgents() []string {
	s.mu.RLock()
	defer s.mu.RUnlock()

	agents := make([]string, 0, len(s.activeAgents))
	for id := range s.activeAgents {
		agents = append(agents, id)
33 changes: 28 additions & 5 deletions internal/storage/fs_provider.go
@@ -6,6 +6,7 @@ import (
"log/slog"
"os"
"path/filepath"
"regexp"
)

// FSProvider implements Provider using the local filesystem.
@@ -14,6 +15,8 @@ type FSProvider struct {
	logger *slog.Logger
}

var validAgentIDPattern = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$`)

// NewFSProvider creates a new filesystem-based storage provider.
// The baseDir will be created if it doesn't exist.
func NewFSProvider(baseDir string, logger *slog.Logger) (*FSProvider, error) {
@@ -37,7 +40,10 @@ func (p *FSProvider) SaveCheckpoint(
	agentID string,
	state []byte,
) error {
	checkpointPath := p.checkpointPath(agentID)
	checkpointPath, pathErr := p.checkpointPath(agentID)
	if pathErr != nil {
		return pathErr
	}
	tempPath := checkpointPath + ".tmp"

	// Write to temporary file
@@ -89,7 +95,10 @@ func (p *FSProvider) LoadCheckpoint(
	ctx context.Context,
	agentID string,
) ([]byte, error) {
	checkpointPath := p.checkpointPath(agentID)
	checkpointPath, pathErr := p.checkpointPath(agentID)
	if pathErr != nil {
		return nil, pathErr
	}

	data, err := os.ReadFile(checkpointPath)
	if err != nil {
Expand All @@ -113,7 +122,10 @@ func (p *FSProvider) DeleteCheckpoint(
	ctx context.Context,
	agentID string,
) error {
	checkpointPath := p.checkpointPath(agentID)
	checkpointPath, pathErr := p.checkpointPath(agentID)
	if pathErr != nil {
		return pathErr
	}

	err := os.Remove(checkpointPath)
	if err != nil && !os.IsNotExist(err) {
@@ -124,7 +136,18 @@ func (p *FSProvider) DeleteCheckpoint(
	return nil
}

func validateAgentID(agentID string) error {
	if !validAgentIDPattern.MatchString(agentID) {
		return fmt.Errorf("invalid agent_id %q", agentID)
	}
	return nil
}

// checkpointPath returns the filesystem path for an agent's checkpoint.
func (p *FSProvider) checkpointPath(agentID string) string {
	return filepath.Join(p.baseDir, agentID+".checkpoint")
func (p *FSProvider) checkpointPath(agentID string) (string, error) {
	if err := validateAgentID(agentID); err != nil {
		return "", err
	}

	return filepath.Join(p.baseDir, agentID+".checkpoint"), nil
}
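The validation closes a path-traversal hole: checkpointPath feeds agentID straight into filepath.Join, so before this change an ID like ../../tmp/evil could map outside baseDir. A regression-test sketch (hypothetical, not part of this PR; it assumes the package is named storage and that SaveCheckpoint takes a leading context.Context like the other Provider methods):

package storage

import (
	"context"
	"io"
	"log/slog"
	"testing"
)

func TestCheckpointPathRejectsBadAgentIDs(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	p, err := NewFSProvider(t.TempDir(), logger)
	if err != nil {
		t.Fatal(err)
	}

	// Path separators, a leading dot, and empty IDs must all be rejected
	// before any filesystem path is built.
	for _, id := range []string{"../../etc/passwd", "a/b", ".hidden", ""} {
		if err := p.SaveCheckpoint(context.Background(), id, []byte("state")); err == nil {
			t.Errorf("expected SaveCheckpoint to reject agent_id %q", id)
		}
	}

	// A well-formed ID still round-trips through validation.
	if err := p.SaveCheckpoint(context.Background(), "agent-42", []byte("state")); err != nil {
		t.Errorf("valid agent_id rejected: %v", err)
	}
}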