Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 172 additions & 5 deletions src/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
log2 "mist/multilogger"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"

dockerClient "github.com/docker/docker/client"
"github.com/redis/go-redis/v9"
)

Expand All @@ -26,38 +29,53 @@ type App struct {
wg sync.WaitGroup
log *slog.Logger
statusRegistry *StatusRegistry
dockerClient *dockerClient.Client
containerMgr *ContainerMgr
}

func NewApp(redisAddr, gpuType string, log *slog.Logger) *App {
client := redis.NewClient(&redis.Options{Addr: redisAddr})
func NewApp(redisAddr, gpuType string, log *slog.Logger) (*App, error) {
redisClient := redis.NewClient(&redis.Options{Addr: redisAddr})
scheduler := NewScheduler(redisAddr, log)
statusRegistry := NewStatusRegistry(client, log)

consumerID := fmt.Sprintf("worker_%d", os.Getpid())
supervisor := NewSupervisor(redisAddr, consumerID, gpuType, log)

// Initialize Docker client with explicit API version 1.41 for compatibility
// (Docker daemon supports up to 1.41, but client defaults to 1.50)
dockerClient, err := dockerClient.NewClientWithOpts(dockerClient.FromEnv, dockerClient.WithVersion("1.41"))
if err != nil {
return nil, fmt.Errorf("failed to create docker client: %w", err)
}

// Initialize container manager with reasonable defaults
containerMgr := NewContainerMgr(dockerClient, 100, 50)

mux := http.NewServeMux()
a := &App{
redisClient: client,
redisClient: redisClient,
scheduler: scheduler,
supervisor: supervisor,
httpServer: &http.Server{Addr: ":3000", Handler: mux},
log: log,
statusRegistry: statusRegistry,
dockerClient: dockerClient,
containerMgr: containerMgr,
}

mux.HandleFunc("/auth/login", a.login)
mux.HandleFunc("/auth/refresh", a.refresh)
mux.HandleFunc("/jobs", a.handleJobs)
mux.HandleFunc("/jobs/status", a.getJobStatus)
mux.HandleFunc("/containers/", a.handleContainerLogs)
mux.HandleFunc("/supervisors/status", a.getSupervisorStatus)
mux.HandleFunc("/supervisors/status/", a.getSupervisorStatusByID)
mux.HandleFunc("/supervisors", a.getAllSupervisors)

a.log.Info("new app initialized", "redis_address", redisAddr,
"gpu_type", gpuType, "http_address", a.httpServer.Addr)

return a
return a, nil
}

func (a *App) Start() error {
Expand Down Expand Up @@ -109,6 +127,14 @@ func (a *App) Shutdown(ctx context.Context) error {
a.log.Info("redis client closed successfully")
}

if a.dockerClient != nil {
if err := a.dockerClient.Close(); err != nil {
a.log.Error("error closing docker client", "err", err)
} else {
a.log.Info("docker client closed successfully")
}
}

a.log.Info("shutdown completed")

return nil
Expand All @@ -124,7 +150,11 @@ func main() {
fmt.Fprintf(os.Stderr, "failed to create logger: %v\n", err)
os.Exit(1)
}
app := NewApp("localhost:6379", "AMD", log)
app, err := NewApp("localhost:6379", "AMD", log)
if err != nil {
log.Error("failed to create app", "err", err)
os.Exit(1)
}

if err := app.Start(); err != nil {
log.Error("failed to start app", "err", err)
Expand Down Expand Up @@ -397,3 +427,140 @@ func (a *App) getAllSupervisors(w http.ResponseWriter, r *http.Request) {
return
}
}

// AssociateContainerWithUser stores the container-user association in Redis.
// This should be called when a container is created to track ownership for authorization.
func (a *App) AssociateContainerWithUser(ctx context.Context, containerID, userID string) error {
key := fmt.Sprintf("container:%s:owner", containerID)
return a.redisClient.Set(ctx, key, userID, 0).Err()
}

// getContainerOwner retrieves the owner user ID for a container from Redis
func (a *App) getContainerOwner(ctx context.Context, containerID string) (string, error) {
key := fmt.Sprintf("container:%s:owner", containerID)
userID, err := a.redisClient.Get(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", fmt.Errorf("container not found or not associated with any user")
}
return "", fmt.Errorf("failed to get container owner: %w", err)
}
return userID, nil
}

// getCurrentUser extracts the current user ID from the request
// This is a placeholder - in a real implementation, this would extract from JWT token, session, etc.
func (a *App) getCurrentUser(r *http.Request) (string, error) {
// For now, we'll use a simple Authorization header or user query parameter
// In a production system, this would validate JWT tokens, session cookies, etc.
authHeader := r.Header.Get("Authorization")
if authHeader != "" {
// Extract user from "Bearer <token>" or similar
parts := strings.Split(authHeader, " ")
if len(parts) == 2 && parts[0] == "Bearer" {
// In a real implementation, decode and validate the token
// For now, we'll use the token as a simple user identifier
return parts[1], nil
}
}

// Fallback: check for user query parameter (for testing)
userID := r.URL.Query().Get("user")
if userID != "" {
return userID, nil
}

return "", fmt.Errorf("authentication required")
}

// authorizeContainerAccess checks if the current user has access to the specified container
func (a *App) authorizeContainerAccess(ctx context.Context, containerID string, userID string) error {
ownerID, err := a.getContainerOwner(ctx, containerID)
if err != nil {
return err
}

if ownerID != userID {
return fmt.Errorf("unauthorized: user %s does not have access to container %s", userID, containerID)
}

return nil
}

// handleContainerLogs handles requests to /containers/{containerID}/logs
func (a *App) handleContainerLogs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

// Extract container ID from path
// Path format: /containers/{containerID}/logs
path := strings.TrimPrefix(r.URL.Path, "/containers/")
parts := strings.Split(path, "/")
if len(parts) < 2 || parts[1] != "logs" {
http.Error(w, "Invalid path. Expected /containers/{containerID}/logs", http.StatusBadRequest)
return
}

containerID := parts[0]
if containerID == "" {
http.Error(w, "Container ID is required", http.StatusBadRequest)
return
}

// Get current user
userID, err := a.getCurrentUser(r)
if err != nil {
a.log.Warn("authentication failed", "error", err, "remote_address", r.RemoteAddr)
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

// Authorize access to container
if err := a.authorizeContainerAccess(ctx, containerID, userID); err != nil {
a.log.Warn("authorization failed", "error", err, "user_id", userID, "container_id", containerID)
http.Error(w, "Unauthorized: "+err.Error(), http.StatusForbidden)
return
}

// Parse query parameters for log options
tailStr := r.URL.Query().Get("tail")
tail := 0
if tailStr != "" {
var err error
tail, err = strconv.Atoi(tailStr)
if err != nil || tail < 0 {
http.Error(w, "Invalid tail parameter. Must be a non-negative integer", http.StatusBadRequest)
return
}
}

followStr := r.URL.Query().Get("follow")
follow := followStr == "true" || followStr == "1"
since := r.URL.Query().Get("since")
until := r.URL.Query().Get("until")

// Fetch container logs
logsReader, err := a.containerMgr.GetContainerLogs(containerID, tail, follow, since, until)
if err != nil {
a.log.Error("failed to get container logs", "error", err, "container_id", containerID)
http.Error(w, fmt.Sprintf("Failed to fetch container logs: %v", err), http.StatusInternalServerError)
return
}
defer logsReader.Close()

// Set appropriate headers for streaming logs
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
if follow {
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
}

// Stream logs to response
_, err = io.Copy(w, logsReader)
if err != nil && !errors.Is(err, io.EOF) {
a.log.Error("error streaming logs", "error", err, "container_id", containerID)
// Don't send error to client if we've already started streaming
return
}

a.log.Info("container logs retrieved", "container_id", containerID, "user_id", userID)
}
84 changes: 84 additions & 0 deletions src/docker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package main

import (
"context"
"fmt"
"io"
"log/slog"
"sync"

"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
)

// ContainerMgr manages Docker containers and volumes, enforces resource limits, and tracks active resources.
type ContainerMgr struct {
ctx context.Context
cli *client.Client
containerLimit int
volumeLimit int
containers map[string]struct{}
volumes map[string]struct{}
mu sync.Mutex
}

// NewContainerMgr creates a new ContainerMgr with the specified Docker client and resource limits.
func NewContainerMgr(client *client.Client, containerLimit, volumeLimit int) *ContainerMgr {
return &ContainerMgr{
ctx: context.Background(),
cli: client,
containerLimit: containerLimit,
volumeLimit: volumeLimit,
containers: make(map[string]struct{}),
volumes: make(map[string]struct{}),
}
}

// GetContainerLogs fetches container logs from Docker for the specified container.
// Returns a ReadCloser that can be used to read the logs, or an error if the operation fails.
// Options:
// - tail: number of lines to return from the end of logs (0 = all)
// - follow: whether to follow log output (default: false)
// - since: return logs since this timestamp (RFC3339 format)
// - until: return logs before this timestamp (RFC3339 format)
func (mgr *ContainerMgr) GetContainerLogs(containerID string, tail int, follow bool, since, until string) (io.ReadCloser, error) {
ctx := mgr.ctx
cli := mgr.cli

opts := container.LogsOptions{
ShowStdout: true,
ShowStderr: true,
Follow: follow,
Timestamps: false, // Don't include timestamps in output
}

// Docker API: use "all" for all logs, or a number for tail
if tail > 0 {
opts.Tail = fmt.Sprintf("%d", tail)
} else {
opts.Tail = "all"
}

if since != "" {
opts.Since = since
}
if until != "" {
opts.Until = until
}

reader, err := cli.ContainerLogs(ctx, containerID, opts)
if err != nil {
return nil, fmt.Errorf("failed to get container logs: %w", err)
}

return reader, nil
}

// AssociateContainer associates a container ID with the ContainerMgr for tracking
func (mgr *ContainerMgr) AssociateContainer(containerID string) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
mgr.containers[containerID] = struct{}{}
slog.Info("container associated", "container_id", containerID)
}

32 changes: 31 additions & 1 deletion src/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,41 @@ module mist

go 1.24.3

require github.com/redis/go-redis/v9 v9.10.0
require (
github.com/docker/docker v28.5.1+incompatible
github.com/redis/go-redis/v9 v9.10.0
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/sys/atomicwriter v0.1.0 // indirect
github.com/moby/term v0.5.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/time v0.14.0 // indirect
gotest.tools/v3 v3.5.2 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading