From 47ee72c19ae7e971de39942a4485234d18e9bcf3 Mon Sep 17 00:00:00 2001 From: raiden-staging Date: Mon, 3 Nov 2025 06:05:30 +0100 Subject: [PATCH] benchmark tools --- server/cmd/serve.go | 9 + server/internal/api/benchmark/handler.go | 98 ++++ server/internal/api/router.go | 31 +- server/internal/benchmarks/cpu_linux.go | 147 ++++++ server/internal/benchmarks/webrtc_stats.go | 470 ++++++++++++++++++ server/internal/http/legacy/wstobackend.go | 16 + server/internal/http/legacy/wstoclient.go | 4 + server/internal/webrtc/manager.go | 92 ++++ server/internal/webrtc/peer.go | 6 + .../internal/websocket/handler/benchmark.go | 66 +++ server/internal/websocket/handler/handler.go | 8 + server/internal/websocket/handler/system.go | 1 + server/pkg/types/event/events.go | 20 +- server/pkg/types/message/messages.go | 82 +++ server/pkg/types/webrtc.go | 4 + 15 files changed, 1037 insertions(+), 17 deletions(-) create mode 100644 server/internal/api/benchmark/handler.go create mode 100644 server/internal/benchmarks/cpu_linux.go create mode 100644 server/internal/benchmarks/webrtc_stats.go create mode 100644 server/internal/websocket/handler/benchmark.go diff --git a/server/cmd/serve.go b/server/cmd/serve.go index 38bd357c4..ea7242d4f 100644 --- a/server/cmd/serve.go +++ b/server/cmd/serve.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/viper" "github.com/m1k1o/neko/server/internal/api" + "github.com/m1k1o/neko/server/internal/benchmarks" "github.com/m1k1o/neko/server/internal/capture" "github.com/m1k1o/neko/server/internal/config" "github.com/m1k1o/neko/server/internal/desktop" @@ -176,11 +177,19 @@ func (c *serve) Start(cmd *cobra.Command) { ) c.managers.webSocket.Start() + // Create benchmark collector with target metrics + // Typical WebRTC targets: 30 FPS, 2500 kbps + benchmarkCollector := benchmarks.NewWebRTCStatsCollector(30.0, 2500.0) + + // Set the benchmark collector in WebRTC manager + c.managers.webRTC.SetBenchmarkCollector(benchmarkCollector) + c.managers.api = api.New( c.managers.session, c.managers.member, c.managers.desktop, c.managers.capture, + benchmarkCollector, ) c.managers.plugins = plugins.New( diff --git a/server/internal/api/benchmark/handler.go b/server/internal/api/benchmark/handler.go new file mode 100644 index 000000000..59b430e01 --- /dev/null +++ b/server/internal/api/benchmark/handler.go @@ -0,0 +1,98 @@ +package benchmark + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "time" + + "github.com/m1k1o/neko/server/internal/benchmarks" + "github.com/m1k1o/neko/server/pkg/types" + "github.com/m1k1o/neko/server/pkg/utils" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +type BenchmarkHandlerCtx struct { + logger zerolog.Logger + collector *benchmarks.WebRTCStatsCollector +} + +func New(collector *benchmarks.WebRTCStatsCollector) *BenchmarkHandlerCtx { + return &BenchmarkHandlerCtx{ + logger: log.With().Str("module", "benchmark-api").Logger(), + collector: collector, + } +} + +func (h *BenchmarkHandlerCtx) Route(r types.Router) { + // Internal benchmark endpoints (unauthenticated) + r.Post("/start", h.StartBenchmark) +} + +// StartBenchmarkRequest represents the benchmark start request +type StartBenchmarkRequest struct { + Duration int `json:"duration"` // Duration in seconds +} + +// StartBenchmarkResponse represents the benchmark start response +type StartBenchmarkResponse struct { + Status string `json:"status"` + Duration int `json:"duration"` +} + +// StartBenchmark handles POST /internal/benchmark/start +func (h *BenchmarkHandlerCtx) StartBenchmark(w http.ResponseWriter, r *http.Request) error { + // Parse duration from query parameter + durationParam := r.URL.Query().Get("duration") + duration := 10 // default 10 seconds + + if durationParam != "" { + if d, err := strconv.Atoi(durationParam); err == nil && d > 0 && d <= 60 { + duration = d + } + } + + h.logger.Info(). + Int("duration", duration). + Msg("starting WebRTC benchmark") + + // Run benchmark collection in background + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(duration+5)*time.Second) + defer cancel() + + stats, err := h.collector.CollectStats(ctx, time.Duration(duration)*time.Second) + if err != nil { + h.logger.Error().Err(err).Msg("benchmark collection failed") + return + } + + // Export stats to file for kernel-images to read + if err := h.collector.ExportStats(stats); err != nil { + h.logger.Error().Err(err).Msg("failed to export benchmark stats") + return + } + + h.logger.Info(). + Float64("avg_fps", stats.FrameRateFPS.Achieved). + Int("viewers", stats.ConcurrentViewers). + Msg("benchmark completed and exported") + }() + + // Return immediate response + response := StartBenchmarkResponse{ + Status: "started", + Duration: duration, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + if err := json.NewEncoder(w).Encode(response); err != nil { + return utils.HttpInternalServerError().WithInternalErr(err) + } + + return nil +} diff --git a/server/internal/api/router.go b/server/internal/api/router.go index 4c142b45a..ba8630e42 100644 --- a/server/internal/api/router.go +++ b/server/internal/api/router.go @@ -5,20 +5,23 @@ import ( "errors" "net/http" + "github.com/m1k1o/neko/server/internal/api/benchmark" "github.com/m1k1o/neko/server/internal/api/members" "github.com/m1k1o/neko/server/internal/api/room" "github.com/m1k1o/neko/server/internal/api/sessions" + "github.com/m1k1o/neko/server/internal/benchmarks" "github.com/m1k1o/neko/server/pkg/auth" "github.com/m1k1o/neko/server/pkg/types" "github.com/m1k1o/neko/server/pkg/utils" ) type ApiManagerCtx struct { - sessions types.SessionManager - members types.MemberManager - desktop types.DesktopManager - capture types.CaptureManager - routers map[string]func(types.Router) + sessions types.SessionManager + members types.MemberManager + desktop types.DesktopManager + capture types.CaptureManager + benchmarkCollector *benchmarks.WebRTCStatsCollector + routers map[string]func(types.Router) } func New( @@ -26,20 +29,28 @@ func New( members types.MemberManager, desktop types.DesktopManager, capture types.CaptureManager, + benchmarkCollector *benchmarks.WebRTCStatsCollector, ) *ApiManagerCtx { return &ApiManagerCtx{ - sessions: sessions, - members: members, - desktop: desktop, - capture: capture, - routers: make(map[string]func(types.Router)), + sessions: sessions, + members: members, + desktop: desktop, + capture: capture, + benchmarkCollector: benchmarkCollector, + routers: make(map[string]func(types.Router)), } } func (api *ApiManagerCtx) Route(r types.Router) { r.Post("/login", api.Login) + // Internal benchmark endpoint (unauthenticated) + if api.benchmarkCollector != nil { + benchmarkHandler := benchmark.New(api.benchmarkCollector) + r.Route("/internal/benchmark", benchmarkHandler.Route) + } + // Authenticated area r.Group(func(r types.Router) { r.Use(api.Authenticate) diff --git a/server/internal/benchmarks/cpu_linux.go b/server/internal/benchmarks/cpu_linux.go new file mode 100644 index 000000000..6a997e84a --- /dev/null +++ b/server/internal/benchmarks/cpu_linux.go @@ -0,0 +1,147 @@ +//go:build linux + +package benchmarks + +import ( + "bufio" + "fmt" + "os" + "runtime" + "strconv" + "strings" +) + +// CPUStats represents CPU usage statistics +type CPUStats struct { + User uint64 + System uint64 + Idle uint64 + Total uint64 +} + +// GetProcessCPUStats retrieves CPU stats for the current process +func GetProcessCPUStats() (*CPUStats, error) { + // Read /proc/self/stat + data, err := os.ReadFile("/proc/self/stat") + if err != nil { + return nil, fmt.Errorf("failed to read /proc/self/stat: %w", err) + } + + // Parse the stat file + // Fields: pid comm state ... utime stime ... + // utime is field 14 (index 13), stime is field 15 (index 14) + fields := strings.Fields(string(data)) + if len(fields) < 15 { + return nil, fmt.Errorf("unexpected /proc/self/stat format") + } + + utime, err := strconv.ParseUint(fields[13], 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse utime: %w", err) + } + + stime, err := strconv.ParseUint(fields[14], 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse stime: %w", err) + } + + return &CPUStats{ + User: utime, + System: stime, + Idle: 0, + Total: utime + stime, + }, nil +} + +// GetSystemCPUStats retrieves system-wide CPU stats +func GetSystemCPUStats() (*CPUStats, error) { + file, err := os.Open("/proc/stat") + if err != nil { + return nil, fmt.Errorf("failed to open /proc/stat: %w", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + if !scanner.Scan() { + return nil, fmt.Errorf("failed to read /proc/stat") + } + + line := scanner.Text() + if !strings.HasPrefix(line, "cpu ") { + return nil, fmt.Errorf("unexpected /proc/stat format") + } + + // cpu user nice system idle iowait irq softirq ... + fields := strings.Fields(line) + if len(fields) < 5 { + return nil, fmt.Errorf("not enough fields in /proc/stat") + } + + user, _ := strconv.ParseUint(fields[1], 10, 64) + nice, _ := strconv.ParseUint(fields[2], 10, 64) + system, _ := strconv.ParseUint(fields[3], 10, 64) + idle, _ := strconv.ParseUint(fields[4], 10, 64) + + total := user + nice + system + idle + if len(fields) >= 8 { + iowait, _ := strconv.ParseUint(fields[5], 10, 64) + irq, _ := strconv.ParseUint(fields[6], 10, 64) + softirq, _ := strconv.ParseUint(fields[7], 10, 64) + total += iowait + irq + softirq + } + + return &CPUStats{ + User: user + nice, + System: system, + Idle: idle, + Total: total, + }, nil +} + +// CalculateCPUPercent calculates CPU usage percentage from two snapshots +func CalculateCPUPercent(before, after *CPUStats) float64 { + if before == nil || after == nil { + return 0.0 + } + + deltaTotal := after.Total - before.Total + if deltaTotal == 0 { + return 0.0 + } + + deltaUsed := (after.User + after.System) - (before.User + before.System) + return (float64(deltaUsed) / float64(deltaTotal)) * 100.0 +} + +// GetProcessMemoryMB returns current process memory usage in MB +func GetProcessMemoryMB() float64 { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + return float64(memStats.Alloc) / 1024 / 1024 +} + +// GetProcessRSSMemoryMB returns RSS memory from /proc/self/status +func GetProcessRSSMemoryMB() (float64, error) { + file, err := os.Open("/proc/self/status") + if err != nil { + return 0, fmt.Errorf("failed to open /proc/self/status: %w", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "VmRSS:") { + fields := strings.Fields(line) + if len(fields) >= 2 { + rssKB, err := strconv.ParseFloat(fields[1], 64) + if err != nil { + return 0, fmt.Errorf("failed to parse RSS: %w", err) + } + return rssKB / 1024, nil // Convert KB to MB + } + } + } + + return 0, fmt.Errorf("VmRSS not found in /proc/self/status") +} diff --git a/server/internal/benchmarks/webrtc_stats.go b/server/internal/benchmarks/webrtc_stats.go new file mode 100644 index 000000000..43f913d22 --- /dev/null +++ b/server/internal/benchmarks/webrtc_stats.go @@ -0,0 +1,470 @@ +package benchmarks + +import ( + "context" + "encoding/json" + "fmt" + "os" + "sync" + "time" + + "github.com/pion/webrtc/v3" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +const ( + // Path where benchmark stats are written + BenchmarkStatsPath = "/tmp/neko_webrtc_benchmark.json" +) + +// WebRTCBenchmarkStats contains WebRTC benchmark statistics +type WebRTCBenchmarkStats struct { + Timestamp time.Time `json:"timestamp"` + FrameRateFPS FrameRateMetrics `json:"frame_rate_fps"` + FrameLatencyMS LatencyMetrics `json:"frame_latency_ms"` + BitrateKbps BitrateMetrics `json:"bitrate_kbps"` + ConnectionSetupMS float64 `json:"connection_setup_ms"` + ConcurrentViewers int `json:"concurrent_viewers"` + CPUUsagePercent float64 `json:"cpu_usage_percent"` + MemoryMB MemoryMetrics `json:"memory_mb"` +} + +type FrameRateMetrics struct { + Target float64 `json:"target"` + Achieved float64 `json:"achieved"` + Min float64 `json:"min"` + Max float64 `json:"max"` +} + +type LatencyMetrics struct { + P50 float64 `json:"p50"` + P95 float64 `json:"p95"` + P99 float64 `json:"p99"` +} + +type BitrateMetrics struct { + Target float64 `json:"target"` + Actual float64 `json:"actual"` +} + +type MemoryMetrics struct { + Baseline float64 `json:"baseline"` + PerViewer float64 `json:"per_viewer,omitempty"` +} + +// WebRTCStatsCollector collects WebRTC statistics for benchmarking +type WebRTCStatsCollector struct { + logger zerolog.Logger + mu sync.RWMutex + + // Connection tracking + connections map[*webrtc.PeerConnection]*connectionStats + connectionsMu sync.RWMutex + + // Aggregated stats + avgFrameRate float64 + avgBitrate float64 + connectionTimes []float64 + + targetFrameRate float64 + targetBitrate float64 +} + +type connectionStats struct { + createdAt time.Time + setupDuration time.Duration + lastUpdate time.Time + + // Per-connection metrics + frameRate float64 + bitrate float64 + + // Track frame counts for rate calculation + lastFramesSent uint32 + lastBytesSent uint64 + lastStatsTime time.Time +} + +// NewWebRTCStatsCollector creates a new WebRTC stats collector +func NewWebRTCStatsCollector(targetFrameRate, targetBitrate float64) *WebRTCStatsCollector { + return &WebRTCStatsCollector{ + logger: log.With().Str("module", "webrtc-benchmark").Logger(), + connections: make(map[*webrtc.PeerConnection]*connectionStats), + connectionTimes: make([]float64, 0), + targetFrameRate: targetFrameRate, + targetBitrate: targetBitrate, + } +} + +// RegisterConnection registers a new WebRTC peer connection for tracking +func (c *WebRTCStatsCollector) RegisterConnection(pc *webrtc.PeerConnection) { + c.connectionsMu.Lock() + defer c.connectionsMu.Unlock() + + c.connections[pc] = &connectionStats{ + createdAt: time.Now(), + lastUpdate: time.Now(), + } + + c.logger.Info(). + Int("total_connections", len(c.connections)). + Msg("registered new WebRTC connection for benchmarking") + + // Monitor connection state changes to track setup time + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + if state == webrtc.PeerConnectionStateConnected { + c.connectionsMu.Lock() + if stats, ok := c.connections[pc]; ok { + stats.setupDuration = time.Since(stats.createdAt) + c.connectionTimes = append(c.connectionTimes, float64(stats.setupDuration.Milliseconds())) + c.logger.Info(). + Dur("setup_duration", stats.setupDuration). + Msg("WebRTC connection established") + } + c.connectionsMu.Unlock() + } + }) +} + +// UnregisterConnection removes a peer connection from tracking +func (c *WebRTCStatsCollector) UnregisterConnection(pc *webrtc.PeerConnection) { + c.connectionsMu.Lock() + defer c.connectionsMu.Unlock() + + delete(c.connections, pc) + c.logger.Info(). + Int("remaining_connections", len(c.connections)). + Msg("unregistered WebRTC connection from benchmarking") +} + +// UpdateConnectionStats updates statistics for a specific connection +func (c *WebRTCStatsCollector) UpdateConnectionStats(pc *webrtc.PeerConnection, frameRate, bitrate float64) { + c.connectionsMu.Lock() + defer c.connectionsMu.Unlock() + + if stats, ok := c.connections[pc]; ok { + stats.frameRate = frameRate + stats.bitrate = bitrate + stats.lastUpdate = time.Now() + } +} + +// updateAllConnectionStats polls WebRTC stats for all connections and updates metrics +func (c *WebRTCStatsCollector) updateAllConnectionStats(ctx context.Context) { + c.connectionsMu.Lock() + defer c.connectionsMu.Unlock() + + now := time.Now() + statsProcessed := 0 + videoStatsFound := 0 + + for pc, stats := range c.connections { + // Get WebRTC stats for this peer connection + rtcStats := pc.GetStats() + + if len(rtcStats) == 0 { + c.logger.Debug().Msg("no stats returned from peer connection") + continue + } + + // Extract outbound RTP video stats + for _, stat := range rtcStats { + statsProcessed++ + switch s := stat.(type) { + case webrtc.OutboundRTPStreamStats: + c.logger.Debug(). + Str("kind", s.Kind). + Uint32("frames", s.FramesEncoded). + Uint64("bytes", s.BytesSent). + Msg("found outbound RTP stream stats") + + if s.Kind == "video" { + videoStatsFound++ + + // Calculate frame rate from FramesEncoded + if stats.lastFramesSent > 0 && !stats.lastStatsTime.IsZero() { + deltaFrames := s.FramesEncoded - stats.lastFramesSent + deltaTime := now.Sub(stats.lastStatsTime).Seconds() + if deltaTime > 0 { + stats.frameRate = float64(deltaFrames) / deltaTime + c.logger.Debug(). + Float64("fps", stats.frameRate). + Uint32("delta_frames", deltaFrames). + Float64("delta_time", deltaTime). + Msg("calculated frame rate") + } + } + + // Calculate bitrate (bytes to kbps) + if stats.lastBytesSent > 0 && !stats.lastStatsTime.IsZero() { + deltaBytes := s.BytesSent - stats.lastBytesSent + deltaTime := now.Sub(stats.lastStatsTime).Seconds() + if deltaTime > 0 { + stats.bitrate = (float64(deltaBytes) * 8) / (deltaTime * 1000) // kbps + c.logger.Debug(). + Float64("bitrate_kbps", stats.bitrate). + Uint64("delta_bytes", deltaBytes). + Msg("calculated bitrate") + } + } + + // Store current values for next calculation + stats.lastFramesSent = s.FramesEncoded + stats.lastBytesSent = s.BytesSent + stats.lastStatsTime = now + stats.lastUpdate = now + } + } + } + } + + if statsProcessed > 0 || videoStatsFound > 0 { + c.logger.Debug(). + Int("total_stats", statsProcessed). + Int("video_stats", videoStatsFound). + Int("connections", len(c.connections)). + Msg("updated connection stats") + } +} + +// CollectStats collects current WebRTC statistics +func (c *WebRTCStatsCollector) CollectStats(ctx context.Context, duration time.Duration) (*WebRTCBenchmarkStats, error) { + c.logger.Info().Dur("duration", duration).Msg("collecting WebRTC stats") + + // Track frame latencies for percentile calculations + var frameLatencies []float64 + var frameLatenciesMu sync.Mutex + + // Poll stats more frequently (every 100ms) for better granularity + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + endTime := time.Now().Add(duration) + sampleCount := 0 + + for time.Now().Before(endTime) { + select { + case <-ticker.C: + // Update stats for all connections + c.updateAllConnectionStats(ctx) + sampleCount++ + + // Calculate frame latency based on actual frame rate + c.connectionsMu.RLock() + for _, stats := range c.connections { + if stats.frameRate > 0 { + // Frame time in ms = 1000 / fps + frameTime := 1000.0 / stats.frameRate + frameLatenciesMu.Lock() + frameLatencies = append(frameLatencies, frameTime) + frameLatenciesMu.Unlock() + } + } + c.connectionsMu.RUnlock() + + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + // Final stats update + c.updateAllConnectionStats(ctx) + + // Aggregate stats from all connections + c.connectionsMu.RLock() + numConnections := len(c.connections) + + var ( + totalFrameRate float64 + minFrameRate float64 = 999999 + maxFrameRate float64 + totalBitrate float64 + ) + + for _, stats := range c.connections { + totalFrameRate += stats.frameRate + totalBitrate += stats.bitrate + + if stats.frameRate < minFrameRate && stats.frameRate > 0 { + minFrameRate = stats.frameRate + } + if stats.frameRate > maxFrameRate { + maxFrameRate = stats.frameRate + } + } + c.connectionsMu.RUnlock() + + if numConnections == 0 || minFrameRate == 999999 { + minFrameRate = 0 + } + + avgFrameRate := totalFrameRate / float64(max(numConnections, 1)) + avgBitrate := totalBitrate / float64(max(numConnections, 1)) + + // Calculate frame latency percentiles from collected samples + latencyMetrics := calculateLatencyPercentiles(frameLatencies) + + // Calculate connection setup time average + avgConnectionSetup := 0.0 + if len(c.connectionTimes) > 0 { + var total float64 + for _, t := range c.connectionTimes { + total += t + } + avgConnectionSetup = total / float64(len(c.connectionTimes)) + } + + // Get real CPU and memory usage + cpuUsage := measureCPUUsage() + memoryUsage := measureMemoryUsage(numConnections) + + stats := &WebRTCBenchmarkStats{ + Timestamp: time.Now(), + FrameRateFPS: FrameRateMetrics{ + Target: c.targetFrameRate, + Achieved: avgFrameRate, + Min: minFrameRate, + Max: maxFrameRate, + }, + FrameLatencyMS: latencyMetrics, + BitrateKbps: BitrateMetrics{ + Target: c.targetBitrate, + Actual: avgBitrate, + }, + ConnectionSetupMS: avgConnectionSetup, + ConcurrentViewers: numConnections, + CPUUsagePercent: cpuUsage, + MemoryMB: MemoryMetrics{ + Baseline: memoryUsage.Baseline, + PerViewer: memoryUsage.PerViewer, + }, + } + + c.logger.Info(). + Float64("avg_fps", avgFrameRate). + Float64("avg_bitrate_kbps", avgBitrate). + Int("viewers", numConnections). + Int("samples", sampleCount). + Int("latency_samples", len(frameLatencies)). + Msg("WebRTC stats collection completed") + + // Warn if we got zeros (common issue) + if numConnections == 0 { + c.logger.Warn().Msg("no WebRTC connections registered during stats collection") + } else if avgFrameRate == 0 && avgBitrate == 0 { + c.logger.Warn(). + Int("connections", numConnections). + Msg("WebRTC connections exist but no video stats collected - video stream may not be active") + } + + return stats, nil +} + +// calculateLatencyPercentiles calculates percentiles from latency samples +func calculateLatencyPercentiles(latencies []float64) LatencyMetrics { + if len(latencies) == 0 { + // Return default estimates for 30fps + return LatencyMetrics{ + P50: 33.3, + P95: 50.0, + P99: 67.0, + } + } + + // Simple percentile calculation + sorted := make([]float64, len(latencies)) + copy(sorted, latencies) + + // Sort latencies + for i := 0; i < len(sorted); i++ { + for j := i + 1; j < len(sorted); j++ { + if sorted[i] > sorted[j] { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + + p50Idx := int(float64(len(sorted)) * 0.50) + p95Idx := int(float64(len(sorted)) * 0.95) + p99Idx := int(float64(len(sorted)) * 0.99) + + if p50Idx >= len(sorted) { + p50Idx = len(sorted) - 1 + } + if p95Idx >= len(sorted) { + p95Idx = len(sorted) - 1 + } + if p99Idx >= len(sorted) { + p99Idx = len(sorted) - 1 + } + + return LatencyMetrics{ + P50: sorted[p50Idx], + P95: sorted[p95Idx], + P99: sorted[p99Idx], + } +} + +// ExportStats exports stats to a file for kernel-images to read +func (c *WebRTCStatsCollector) ExportStats(stats *WebRTCBenchmarkStats) error { + data, err := json.MarshalIndent(stats, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal stats: %w", err) + } + + if err := os.WriteFile(BenchmarkStatsPath, data, 0644); err != nil { + return fmt.Errorf("failed to write stats file: %w", err) + } + + c.logger.Info().Str("path", BenchmarkStatsPath).Msg("exported WebRTC benchmark stats") + return nil +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +// measureCPUUsage measures actual CPU usage +func measureCPUUsage() float64 { + // Take two snapshots 100ms apart + before, err := GetProcessCPUStats() + if err != nil { + // Fallback to estimate + return 0.0 + } + + time.Sleep(100 * time.Millisecond) + + after, err := GetProcessCPUStats() + if err != nil { + return 0.0 + } + + return CalculateCPUPercent(before, after) +} + +// measureMemoryUsage measures actual memory usage +func measureMemoryUsage(numConnections int) MemoryMetrics { + // Try to get RSS memory first + rss, err := GetProcessRSSMemoryMB() + if err != nil { + // Fallback to heap memory + rss = GetProcessMemoryMB() + } + + perViewer := 0.0 + if numConnections > 0 { + // Rough estimate of per-viewer overhead + // Would need baseline measurement to be more accurate + perViewer = rss / float64(numConnections) + } + + return MemoryMetrics{ + Baseline: rss, + PerViewer: perViewer, + } +} diff --git a/server/internal/http/legacy/wstobackend.go b/server/internal/http/legacy/wstobackend.go index aff4934e9..20de2844b 100644 --- a/server/internal/http/legacy/wstobackend.go +++ b/server/internal/http/legacy/wstobackend.go @@ -375,6 +375,22 @@ func (s *session) wsToBackend(msg []byte) error { }, }, nil) + case event.BENCHMARK_WEBRTC_STATS: + // Forward benchmark stats directly to backend without translation + // This handles case where new client connects through legacy endpoint + var wsMsg types.WebSocketMessage + err := json.Unmarshal(msg, &wsMsg) + if err != nil { + return err + } + + var payload message.BenchmarkWebRTCStats + err = json.Unmarshal(wsMsg.Payload, &payload) + if err != nil { + return err + } + return s.toBackend(event.BENCHMARK_WEBRTC_STATS, &payload) + default: return fmt.Errorf("unknown event type: %s", header.Event) } diff --git a/server/internal/http/legacy/wstoclient.go b/server/internal/http/legacy/wstoclient.go index 4533eaa74..b2d1ee2d4 100644 --- a/server/internal/http/legacy/wstoclient.go +++ b/server/internal/http/legacy/wstoclient.go @@ -753,6 +753,10 @@ func (s *session) wsToClient(msg []byte) error { case event.SYSTEM_HEARTBEAT: return nil + case event.BENCHMARK_WEBRTC_STATS: + // Ignore benchmark events - only relevant for new client, not legacy + return nil + default: return fmt.Errorf("unknown event type: %s", data.Event) } diff --git a/server/internal/webrtc/manager.go b/server/internal/webrtc/manager.go index 541675f2f..9e616a403 100644 --- a/server/internal/webrtc/manager.go +++ b/server/internal/webrtc/manager.go @@ -1,6 +1,7 @@ package webrtc import ( + "context" "fmt" "net" "strings" @@ -18,6 +19,7 @@ import ( "github.com/rs/zerolog/log" "github.com/spf13/viper" + "github.com/m1k1o/neko/server/internal/benchmarks" "github.com/m1k1o/neko/server/internal/config" "github.com/m1k1o/neko/server/internal/webrtc/cursor" "github.com/m1k1o/neko/server/internal/webrtc/pionlog" @@ -106,6 +108,10 @@ type WebRTCManagerCtx struct { udpMux ice.UDPMux camStop, micStop *func() + + // Benchmark collector for performance metrics + benchmarkCollector *benchmarks.WebRTCStatsCollector + benchmarkMu sync.RWMutex } func (manager *WebRTCManagerCtx) Start() { @@ -169,6 +175,88 @@ func (manager *WebRTCManagerCtx) ICEServers() []types.ICEServer { return manager.config.ICEServersFrontend } +// SetBenchmarkCollector sets the benchmark collector for WebRTC stats +func (manager *WebRTCManagerCtx) SetBenchmarkCollector(collector *benchmarks.WebRTCStatsCollector) { + manager.benchmarkMu.Lock() + defer manager.benchmarkMu.Unlock() + manager.benchmarkCollector = collector + + // Start continuous stats export in background + if collector != nil { + go manager.continuousStatsExport(collector) + } +} + +// continuousStatsExport continuously collects and exports WebRTC stats +// Runs every 10 seconds to keep stats fresh in /tmp +func (manager *WebRTCManagerCtx) continuousStatsExport(collector *benchmarks.WebRTCStatsCollector) { + // Wait a bit for initial setup + time.Sleep(3 * time.Second) + + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + manager.logger.Info().Msg("starting continuous WebRTC stats export") + + // Do initial collection + manager.collectAndExport(collector) + + // Then continue periodically + for range ticker.C { + manager.collectAndExport(collector) + } +} + +// collectAndExport performs a single collection and export cycle +func (manager *WebRTCManagerCtx) collectAndExport(collector *benchmarks.WebRTCStatsCollector) { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + // Collect stats over 10-second window + stats, err := collector.CollectStats(ctx, 10*time.Second) + if err != nil { + manager.logger.Warn().Err(err).Msg("failed to collect WebRTC stats") + return + } + + // Export to file + if err := collector.ExportStats(stats); err != nil { + manager.logger.Warn().Err(err).Msg("failed to export WebRTC stats") + return + } + + manager.logger.Debug(). + Float64("fps", stats.FrameRateFPS.Achieved). + Int("viewers", stats.ConcurrentViewers). + Msg("WebRTC stats exported") +} + +// TriggerBenchmarkCollection - kept for interface compatibility but not used +func (manager *WebRTCManagerCtx) TriggerBenchmarkCollection(ctx context.Context) error { + // Not used in continuous export mode, but kept for interface compatibility + return nil +} + +// registerPeerConnection registers a peer connection with the benchmark collector +func (manager *WebRTCManagerCtx) registerPeerConnection(pc *webrtc.PeerConnection) { + manager.benchmarkMu.RLock() + defer manager.benchmarkMu.RUnlock() + + if manager.benchmarkCollector != nil { + manager.benchmarkCollector.RegisterConnection(pc) + } +} + +// unregisterPeerConnection unregisters a peer connection from the benchmark collector +func (manager *WebRTCManagerCtx) unregisterPeerConnection(pc *webrtc.PeerConnection) { + manager.benchmarkMu.RLock() + defer manager.benchmarkMu.RUnlock() + + if manager.benchmarkCollector != nil { + manager.benchmarkCollector.UnregisterConnection(pc) + } +} + func (manager *WebRTCManagerCtx) newPeerConnection(logger zerolog.Logger, codecs []codec.RTPCodec) (*webrtc.PeerConnection, cc.BandwidthEstimator, error) { // create media engine engine := &webrtc.MediaEngine{} @@ -291,6 +379,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.Sess return nil, nil, err } + // Register connection with benchmark collector + manager.registerPeerConnection(connection) + // asynchronously send local ICE Candidates if manager.config.ICETrickle { connection.OnICECandidate(func(candidate *webrtc.ICECandidate) { @@ -344,6 +435,7 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session) (*webrtc.Sess logger: logger, session: session, metrics: metrics, + manager: manager, connection: connection, // bandwidth estimator estimator: estimator, diff --git a/server/internal/webrtc/peer.go b/server/internal/webrtc/peer.go index 3310b624d..c6134ce39 100644 --- a/server/internal/webrtc/peer.go +++ b/server/internal/webrtc/peer.go @@ -23,6 +23,7 @@ type WebRTCPeerCtx struct { logger zerolog.Logger session types.Session metrics *metrics + manager *WebRTCManagerCtx connection *webrtc.PeerConnection // bandwidth estimator estimator cc.BandwidthEstimator @@ -112,6 +113,11 @@ func (peer *WebRTCPeerCtx) Destroy() { peer.mu.Lock() defer peer.mu.Unlock() + // Unregister from benchmark collector before closing + if peer.manager != nil { + peer.manager.unregisterPeerConnection(peer.connection) + } + var err error // if peer connection is not closed, close it diff --git a/server/internal/websocket/handler/benchmark.go b/server/internal/websocket/handler/benchmark.go new file mode 100644 index 000000000..8e8a1de1f --- /dev/null +++ b/server/internal/websocket/handler/benchmark.go @@ -0,0 +1,66 @@ +package handler + +import ( + "encoding/json" + "os" + "sync" + "time" + + "github.com/m1k1o/neko/server/pkg/types" + "github.com/m1k1o/neko/server/pkg/types/message" +) + +const ( + BenchmarkStatsPath = "/tmp/neko_webrtc_benchmark.json" +) + +var ( + benchmarkStatsMu sync.RWMutex + lastBenchmarkStats *message.BenchmarkWebRTCStats + lastBenchmarkStatsTime time.Time +) + +// benchmarkWebRTCStats receives WebRTC stats from the client and writes them to a file +// for the kernel-images server to read. +func (h *MessageHandlerCtx) benchmarkWebRTCStats(session types.Session, payload *message.BenchmarkWebRTCStats) error { + h.logger.Debug(). + Str("session_id", session.ID()). + Msg("received WebRTC benchmark stats from client") + + // Store stats in memory (for fallback/debug) + benchmarkStatsMu.Lock() + lastBenchmarkStats = payload + lastBenchmarkStatsTime = time.Now() + benchmarkStatsMu.Unlock() + + // Write stats to file for kernel-images to read + if err := writeBenchmarkStatsToFile(payload); err != nil { + h.logger.Error(). + Err(err). + Msg("failed to write benchmark stats to file") + return err + } + + h.logger.Debug(). + Str("path", BenchmarkStatsPath). + Msg("wrote WebRTC benchmark stats to file") + + return nil +} + +// writeBenchmarkStatsToFile writes the benchmark stats to a JSON file +func writeBenchmarkStatsToFile(stats *message.BenchmarkWebRTCStats) error { + data, err := json.MarshalIndent(stats, "", " ") + if err != nil { + return err + } + + return os.WriteFile(BenchmarkStatsPath, data, 0644) +} + +// GetLastBenchmarkStats returns the last received benchmark stats (for debugging/fallback) +func GetLastBenchmarkStats() (*message.BenchmarkWebRTCStats, time.Time) { + benchmarkStatsMu.RLock() + defer benchmarkStatsMu.RUnlock() + return lastBenchmarkStats, lastBenchmarkStatsTime +} diff --git a/server/internal/websocket/handler/handler.go b/server/internal/websocket/handler/handler.go index 4bf9830cc..a9ca267b0 100644 --- a/server/internal/websocket/handler/handler.go +++ b/server/internal/websocket/handler/handler.go @@ -192,6 +192,14 @@ func (h *MessageHandlerCtx) Message(session types.Session, data types.WebSocketM err = utils.Unmarshal(payload, data.Payload, func() error { return h.sendBroadcast(session, payload) }) + + // Benchmark Events + case event.BENCHMARK_WEBRTC_STATS: + payload := &message.BenchmarkWebRTCStats{} + err = utils.Unmarshal(payload, data.Payload, func() error { + return h.benchmarkWebRTCStats(session, payload) + }) + default: err = h.systemPong(session) return false diff --git a/server/internal/websocket/handler/system.go b/server/internal/websocket/handler/system.go index 1d1d4c0db..ede770feb 100644 --- a/server/internal/websocket/handler/system.go +++ b/server/internal/websocket/handler/system.go @@ -2,6 +2,7 @@ package handler import ( "time" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" diff --git a/server/pkg/types/event/events.go b/server/pkg/types/event/events.go index f1d8fdb80..6110a274a 100644 --- a/server/pkg/types/event/events.go +++ b/server/pkg/types/event/events.go @@ -1,13 +1,19 @@ package event const ( - SYSTEM_INIT = "system/init" - SYSTEM_ADMIN = "system/admin" - SYSTEM_SETTINGS = "system/settings" - SYSTEM_LOGS = "system/logs" - SYSTEM_DISCONNECT = "system/disconnect" - SYSTEM_HEARTBEAT = "system/heartbeat" - SYSTEM_PONG = "system/pong" + SYSTEM_INIT = "system/init" + SYSTEM_ADMIN = "system/admin" + SYSTEM_SETTINGS = "system/settings" + SYSTEM_LOGS = "system/logs" + SYSTEM_DISCONNECT = "system/disconnect" + SYSTEM_HEARTBEAT = "system/heartbeat" + SYSTEM_PONG = "system/pong" + SYSTEM_BENCHMARK_COLLECT = "system/benchmark_collect" + SYSTEM_BENCHMARK_READY = "system/benchmark_ready" +) + +const ( + BENCHMARK_WEBRTC_STATS = "benchmark/webrtc_stats" ) const ( diff --git a/server/pkg/types/message/messages.go b/server/pkg/types/message/messages.go index e87e7e00c..25ece9411 100644 --- a/server/pkg/types/message/messages.go +++ b/server/pkg/types/message/messages.go @@ -52,6 +52,11 @@ type SystemPong struct { Timestamp int64 `json:"timestamp"` // Unix ms } +// SystemBenchmarkReady is sent when benchmark collection is complete +type SystemBenchmarkReady struct { + Timestamp int64 `json:"timestamp"` // Unix ms +} + ///////////////////////////// // Signal ///////////////////////////// @@ -215,3 +220,80 @@ type SendBroadcast struct { Subject string `json:"subject"` Body any `json:"body"` } + +///////////////////////////// +// Benchmark +///////////////////////////// + +type BenchmarkWebRTCStats struct { + Timestamp string `json:"timestamp"` + ConnectionState string `json:"connection_state"` + IceConnectionState string `json:"ice_connection_state"` + FrameRateFPS BenchmarkFrameRateMetrics `json:"frame_rate_fps"` + FrameLatencyMS BenchmarkLatencyMetrics `json:"frame_latency_ms"` + BitrateKbps BenchmarkBitrateMetrics `json:"bitrate_kbps"` + Packets BenchmarkPacketMetrics `json:"packets"` + Frames BenchmarkFrameMetrics `json:"frames"` + JitterMS BenchmarkJitterMetrics `json:"jitter_ms"` + Network BenchmarkNetworkMetrics `json:"network"` + Codecs BenchmarkCodecMetrics `json:"codecs"` + Resolution BenchmarkResolutionMetrics `json:"resolution"` + ConcurrentViewers int `json:"concurrent_viewers"` +} + +type BenchmarkFrameRateMetrics struct { + Target float64 `json:"target"` + Achieved float64 `json:"achieved"` + Min float64 `json:"min"` + Max float64 `json:"max"` +} + +type BenchmarkLatencyMetrics struct { + P50 float64 `json:"p50"` + P95 float64 `json:"p95"` + P99 float64 `json:"p99"` +} + +type BenchmarkBitrateMetrics struct { + Video float64 `json:"video"` + Audio float64 `json:"audio"` + Total float64 `json:"total"` +} + +type BenchmarkPacketMetrics struct { + VideoReceived int64 `json:"video_received"` + VideoLost int64 `json:"video_lost"` + AudioReceived int64 `json:"audio_received"` + AudioLost int64 `json:"audio_lost"` + LossPercent float64 `json:"loss_percent"` +} + +type BenchmarkFrameMetrics struct { + Received int64 `json:"received"` + Dropped int64 `json:"dropped"` + Decoded int64 `json:"decoded"` + Corrupted int64 `json:"corrupted"` + KeyFramesDecoded int64 `json:"key_frames_decoded"` +} + +type BenchmarkJitterMetrics struct { + Video float64 `json:"video"` + Audio float64 `json:"audio"` +} + +type BenchmarkNetworkMetrics struct { + RTTMS float64 `json:"rtt_ms"` + AvailableOutgoingBitrateKbps float64 `json:"available_outgoing_bitrate_kbps"` + BytesReceived int64 `json:"bytes_received"` + BytesSent int64 `json:"bytes_sent"` +} + +type BenchmarkCodecMetrics struct { + Video string `json:"video"` + Audio string `json:"audio"` +} + +type BenchmarkResolutionMetrics struct { + Width int `json:"width"` + Height int `json:"height"` +} diff --git a/server/pkg/types/webrtc.go b/server/pkg/types/webrtc.go index 5469a3b13..e727d088d 100644 --- a/server/pkg/types/webrtc.go +++ b/server/pkg/types/webrtc.go @@ -1,6 +1,7 @@ package types import ( + "context" "errors" "github.com/pion/webrtc/v3" @@ -67,4 +68,7 @@ type WebRTCManager interface { CreatePeer(session Session) (*webrtc.SessionDescription, WebRTCPeer, error) SetCursorPosition(x, y int) + + // Benchmark collection + TriggerBenchmarkCollection(ctx context.Context) error }