diff --git a/pkg/config/config.go b/pkg/config/config.go index 8e2dc354..d2f07a6e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -56,6 +56,13 @@ type TLSConfig struct { KeyLog string `yaml:"key_log"` } +type SessionTimerConfig struct { + DefaultExpires int `yaml:"default_expires"` // Default session interval in seconds (default: 1800) + MinSE int `yaml:"min_se"` // Minimum acceptable session interval (default: 90) + PreferRefresher string `yaml:"prefer_refresher"` // Preferred refresher role: "uac" or "uas" (default: "uac") + UseUpdate bool `yaml:"use_update"` // Use UPDATE instead of re-INVITE for refresh (default: false) +} + type Config struct { Redis *redis.RedisConfig `yaml:"redis"` // required ApiKey string `yaml:"api_key"` // required (env LIVEKIT_API_KEY) @@ -100,6 +107,9 @@ type Config struct { EnableJitterBuffer bool `yaml:"enable_jitter_buffer"` EnableJitterBufferProb float64 `yaml:"enable_jitter_buffer_prob"` + // SessionTimer configures RFC 4028 session timer support + SessionTimer SessionTimerConfig `yaml:"session_timer"` + // internal ServiceName string `yaml:"-"` NodeID string // Do not provide, will be overwritten @@ -158,6 +168,17 @@ func (c *Config) Init() error { c.MaxCpuUtilization = 0.9 } + // Initialize session timer defaults + if c.SessionTimer.DefaultExpires == 0 { + c.SessionTimer.DefaultExpires = 1800 // 30 minutes + } + if c.SessionTimer.MinSE == 0 { + c.SessionTimer.MinSE = 90 // RFC 4028 minimum + } + if c.SessionTimer.PreferRefresher == "" { + c.SessionTimer.PreferRefresher = "uac" + } + if err := c.InitLogger(); err != nil { return err } diff --git a/pkg/sip/inbound.go b/pkg/sip/inbound.go index 800a6dda..055fc738 100644 --- a/pkg/sip/inbound.go +++ b/pkg/sip/inbound.go @@ -553,28 +553,30 @@ func (s *Server) onNotify(log *slog.Logger, req *sip.Request, tx sip.ServerTrans } type inboundCall struct { - s *Server - log logger.Logger - cc *sipInbound - mon *stats.CallMonitor - state *CallState - extraAttrs map[string]string - attrsToHdr map[string]string - ctx context.Context - cancel func() - closeReason atomic.Pointer[ReasonHeader] - call *rpc.SIPCall - media *MediaPort - dtmf chan dtmf.Event // buffered - lkRoom *Room // LiveKit room; only active after correct pin is entered - callDur func() time.Duration - joinDur func() time.Duration - forwardDTMF atomic.Bool - done atomic.Bool - started core.Fuse - stats Stats - jitterBuf bool - projectID string + s *Server + log logger.Logger + cc *sipInbound + mon *stats.CallMonitor + state *CallState + extraAttrs map[string]string + attrsToHdr map[string]string + ctx context.Context + cancel func() + closeReason atomic.Pointer[ReasonHeader] + call *rpc.SIPCall + media *MediaPort + dtmf chan dtmf.Event // buffered + lkRoom *Room // LiveKit room; only active after correct pin is entered + callDur func() time.Duration + joinDur func() time.Duration + forwardDTMF atomic.Bool + done atomic.Bool + started core.Fuse + stats Stats + jitterBuf bool + projectID string + sessionTimer *SessionTimer // RFC 4028 session timer + lastSDP []byte // Last SDP answer sent (for session refresh) } func (s *Server) newInboundCall( @@ -622,6 +624,9 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI c.call.SipCallId = h.Value() } + // Initialize session timer (RFC 4028) + c.initSessionTimer(req, conf) + c.cc.StartRinging() // Send initial request. In the best case scenario, we will immediately get a room name to join. // Otherwise, we could even learn that this number is not allowed and reject the call, or ask for pin if required. @@ -713,6 +718,10 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI headers = AttrsToHeaders(r.LocalParticipant.Attributes(), c.attrsToHdr, headers) } c.log.Infow("Accepting the call", "headers", headers) + + // Store SDP for session refresh + c.lastSDP = answerData + err := c.cc.Accept(ctx, answerData, headers) if errors.Is(err, errNoACK) { c.log.Errorw("Call accepted, but no ACK received", err) @@ -812,6 +821,11 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI c.started.Break() + // Start session timer after call is established + if c.sessionTimer != nil { + c.sessionTimer.Start() + } + var noAck = false // Wait for the caller to terminate the call. Send regular keep alives ticker := time.NewTicker(stateUpdateTick) @@ -1038,6 +1052,58 @@ func (c *inboundCall) printStats(log logger.Logger) { log.Infow("call statistics", "stats", c.stats.Load()) } +// initSessionTimer initializes the session timer from the incoming INVITE +func (c *inboundCall) initSessionTimer(req *sip.Request, conf *config.Config) { + // Convert config format to session timer config + stConfig := SessionTimerConfig{ + DefaultExpires: conf.SessionTimer.DefaultExpires, + MinSE: conf.SessionTimer.MinSE, + UseUpdate: conf.SessionTimer.UseUpdate, + } + + // Parse prefer refresher string + switch conf.SessionTimer.PreferRefresher { + case "uac": + stConfig.PreferRefresher = RefresherUAC + case "uas": + stConfig.PreferRefresher = RefresherUAS + default: + stConfig.PreferRefresher = RefresherUAC + } + + c.sessionTimer = NewSessionTimer(stConfig, false, c.log) // isUAC=false for inbound + c.sessionTimer.SetContext(c.ctx) + + // Set up callbacks + c.sessionTimer.SetCallbacks( + func(ctx context.Context) error { + return c.sendSessionRefresh(ctx) + }, + func(ctx context.Context) error { + c.log.Warnw("Session timer expired, terminating call", nil) + c.closeWithTimeout() + return nil + }, + ) + + // Share timer with sipInbound for response generation + c.cc.sessionTimer = c.sessionTimer + + // Negotiate session timer parameters from INVITE + _, _, _, err := c.sessionTimer.NegotiateInvite(req) + if err != nil { + c.log.Warnw("Session timer negotiation failed, timer disabled", err) + } +} + +// sendSessionRefresh sends a session refresh (re-INVITE or UPDATE) +func (c *inboundCall) sendSessionRefresh(ctx context.Context) error { + c.log.Infow("Sending session refresh") + + // Use the sipInbound layer to send the refresh with the same SDP + return c.cc.sendSessionRefresh(ctx, c.lastSDP) +} + // close should only be called from handleInvite. func (c *inboundCall) close(error bool, status CallStatus, reason string) { if !c.done.CompareAndSwap(false, true) { @@ -1060,6 +1126,12 @@ func (c *inboundCall) close(error bool, status CallStatus, reason string) { c.closeMedia() c.cc.CloseWithStatus(sipCode, sipStatus) + + // Stop session timer if active + if c.sessionTimer != nil { + c.sessionTimer.Stop() + } + if c.callDur != nil { c.callDur() } @@ -1352,6 +1424,7 @@ type sipInbound struct { ringing chan struct{} acked core.Fuse setHeaders setHeadersFunc + sessionTimer *SessionTimer // Session timer reference } func (c *sipInbound) ValidateInvite() error { @@ -1550,6 +1623,15 @@ func (c *sipInbound) Accept(ctx context.Context, sdpData []byte, headers map[str c.addExtraHeaders(r) + // Add session timer headers if negotiated + if c.sessionTimer != nil { + sessionExpires := c.sessionTimer.GetSessionExpires() + refresher := c.sessionTimer.GetRefresher() + if sessionExpires > 0 { + c.sessionTimer.AddHeadersToResponse(r, sessionExpires, refresher) + } + } + r.AppendHeader(&contentTypeHeaderSDP) for k, v := range headers { r.AppendHeader(sip.NewHeader(k, v)) @@ -1665,6 +1747,85 @@ func (c *sipInbound) setCSeq(req *sip.Request) { c.nextRequestCSeq++ } +// sendSessionRefresh sends a mid-dialog re-INVITE to refresh the session +func (c *sipInbound) sendSessionRefresh(ctx context.Context, sdpOffer []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.inviteOk == nil || c.invite == nil { + return errors.New("call not established") + } + + ctx, span := tracer.Start(ctx, "sipInbound.sendSessionRefresh") + defer span.End() + + // Create re-INVITE request with the same dialog parameters + req := sip.NewRequest(sip.INVITE, c.invite.Recipient) + + // Copy essential headers from original INVITE + req.RemoveHeader("Call-ID") + if callID := c.invite.CallID(); callID != nil { + req.AppendHeader(callID) + } + + // From and To headers (maintaining tags) + req.AppendHeader(c.from) + req.AppendHeader(c.to) + + // Contact + if c.contact != nil { + req.AppendHeader(c.contact) + } + + // Set new CSeq + c.setCSeq(req) + + // Add SDP body + if sdpOffer != nil && len(sdpOffer) > 0 { + req.SetBody(sdpOffer) + req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp")) + } + + // Add session timer headers if active + if c.sessionTimer != nil { + c.sessionTimer.AddHeadersToRequest(req) + } + + // Add custom headers + if c.setHeaders != nil { + for k, v := range c.setHeaders(nil) { + req.AppendHeader(sip.NewHeader(k, v)) + } + } + + // Swap src/dst for client-like behavior + c.swapSrcDst(req) + + // Send the request and wait for response + tx, err := c.s.sipSrv.TransactionLayer().Request(req) + if err != nil { + return fmt.Errorf("failed to send session refresh: %w", err) + } + defer tx.Terminate() + + // Wait for response + resp, err := sipResponse(ctx, tx, nil, nil) + if err != nil { + return fmt.Errorf("session refresh failed: %w", err) + } + + if resp.StatusCode != sip.StatusOK { + return fmt.Errorf("session refresh rejected: %d %s", resp.StatusCode, resp.Reason) + } + + c.log.Infow("Session refresh successful") + + // Send ACK + ack := sip.NewAckRequest(req, resp, nil) + c.swapSrcDst(ack) + return c.s.sipSrv.TransportLayer().WriteMsg(ack) +} + func (c *sipInbound) sendBye() { if c.inviteOk == nil { return // call wasn't established diff --git a/pkg/sip/outbound.go b/pkg/sip/outbound.go index 7acb6579..ed2bec0f 100644 --- a/pkg/sip/outbound.go +++ b/pkg/sip/outbound.go @@ -66,17 +66,19 @@ type sipOutboundConfig struct { } type outboundCall struct { - c *Client - log logger.Logger - state *CallState - cc *sipOutbound - media *MediaPort - started core.Fuse - stopped core.Fuse - closing core.Fuse - stats Stats - jitterBuf bool - projectID string + c *Client + log logger.Logger + state *CallState + cc *sipOutbound + media *MediaPort + started core.Fuse + stopped core.Fuse + closing core.Fuse + stats Stats + jitterBuf bool + projectID string + sessionTimer *SessionTimer // RFC 4028 session timer + lastSDP []byte // Last SDP offer sent (for session refresh) mu sync.RWMutex mon *stats.CallMonitor @@ -149,6 +151,9 @@ func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Lo return nil, fmt.Errorf("update room failed: %w", err) } + // Initialize session timer (RFC 4028) + call.initSessionTimer(ctx, conf) + c.cmu.Lock() defer c.cmu.Unlock() c.activeCalls[id] = call @@ -203,6 +208,12 @@ func (c *outboundCall) Dial(ctx context.Context) error { info.StartedAtNs = time.Now().UnixNano() info.CallStatus = livekit.SIPCallStatus_SCS_ACTIVE }) + + // Start session timer after call is established + if c.sessionTimer != nil { + c.sessionTimer.Start() + } + return nil } @@ -270,6 +281,52 @@ func (c *outboundCall) closeWithTimeout() { c.close(psrpc.NewErrorf(psrpc.DeadlineExceeded, "media-timeout"), callDropped, "media-timeout", livekit.DisconnectReason_UNKNOWN_REASON) } +// initSessionTimer initializes the session timer for outbound calls +func (c *outboundCall) initSessionTimer(ctx context.Context, conf *config.Config) { + // Convert config format to session timer config + stConfig := SessionTimerConfig{ + DefaultExpires: conf.SessionTimer.DefaultExpires, + MinSE: conf.SessionTimer.MinSE, + UseUpdate: conf.SessionTimer.UseUpdate, + } + + // Parse prefer refresher string + switch conf.SessionTimer.PreferRefresher { + case "uac": + stConfig.PreferRefresher = RefresherUAC + case "uas": + stConfig.PreferRefresher = RefresherUAS + default: + stConfig.PreferRefresher = RefresherUAC + } + + c.sessionTimer = NewSessionTimer(stConfig, true, c.log) // isUAC=true for outbound + c.sessionTimer.SetContext(ctx) + + // Set up callbacks + c.sessionTimer.SetCallbacks( + func(ctx context.Context) error { + return c.sendSessionRefresh(ctx) + }, + func(ctx context.Context) error { + c.log.Warnw("Session timer expired, terminating call", nil) + c.closeWithTimeout() + return nil + }, + ) + + // Share timer with sipOutbound for request generation + c.cc.sessionTimer = c.sessionTimer +} + +// sendSessionRefresh sends a session refresh (re-INVITE or UPDATE) +func (c *outboundCall) sendSessionRefresh(ctx context.Context) error { + c.log.Infow("Sending session refresh") + + // Use the sipOutbound layer to send the refresh with the same SDP + return c.cc.sendSessionRefresh(ctx, c.lastSDP) +} + func (c *outboundCall) printStats() { c.log.Infow("call statistics", "stats", c.stats.Load()) } @@ -298,6 +355,11 @@ func (c *outboundCall) close(err error, status CallStatus, description string, r _ = c.lkRoom.CloseWithReason(status.DisconnectReason()) c.lkRoomIn = nil + // Stop session timer if active + if c.sessionTimer != nil { + c.sessionTimer.Stop() + } + c.stopSIP(description) c.c.cmu.Lock() @@ -528,6 +590,10 @@ func (c *outboundCall) sipSignal(ctx context.Context) error { if err != nil { return err } + + // Store SDP offer for session refresh + c.lastSDP = sdpOfferData + c.mon.SDPSize(len(sdpOfferData), true) c.log.Debugw("SDP offer", "sdp", string(sdpOfferData)) joinDur := c.mon.JoinDur() @@ -683,11 +749,12 @@ func (c *Client) newOutbound(log logger.Logger, id LocalTag, from, contact URI, } type sipOutbound struct { - log logger.Logger - c *Client - id LocalTag - from *sip.FromHeader - contact *sip.ContactHeader + log logger.Logger + c *Client + id LocalTag + from *sip.FromHeader + contact *sip.ContactHeader + sessionTimer *SessionTimer // Session timer reference mu sync.RWMutex tag RemoteTag @@ -867,6 +934,13 @@ authLoop: req.PrependHeader(&sip.RouteHeader{Address: hdr.(*sip.RecordRouteHeader).Address}) } + // Negotiate session timer from response + if c.sessionTimer != nil { + if err := c.sessionTimer.NegotiateResponse(resp); err != nil { + c.log.Warnw("Failed to negotiate session timer from response", err) + } + } + return c.inviteOk.Body(), nil } @@ -904,6 +978,11 @@ func (c *sipOutbound) attemptInvite(ctx context.Context, callID sip.CallIDHeader req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp")) req.AppendHeader(sip.NewHeader("Allow", "INVITE, ACK, CANCEL, BYE, NOTIFY, REFER, MESSAGE, OPTIONS, INFO, SUBSCRIBE")) + // Add session timer headers if configured + if c.sessionTimer != nil { + c.sessionTimer.AddHeadersToRequest(req) + } + if authHeader != "" { req.AppendHeader(sip.NewHeader(authHeaderName, authHeader)) } @@ -935,6 +1014,84 @@ func (c *sipOutbound) setCSeq(req *sip.Request) { c.nextCSeq++ } +// sendSessionRefresh sends a mid-dialog re-INVITE to refresh the session +func (c *sipOutbound) sendSessionRefresh(ctx context.Context, sdpOffer []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.invite == nil || c.inviteOk == nil { + return errors.New("call not established") + } + + ctx, span := tracer.Start(ctx, "sipOutbound.sendSessionRefresh") + defer span.End() + + // Create re-INVITE request using the established dialog + req := sip.NewRequest(sip.INVITE, c.invite.Recipient) + + // Copy essential headers from original INVITE + req.RemoveHeader("Call-ID") + if callID := c.invite.CallID(); callID != nil { + req.AppendHeader(callID) + } + + // From and To headers (maintaining tags) + req.AppendHeader(c.from) + req.AppendHeader(c.to) + + // Contact + if c.contact != nil { + req.AppendHeader(c.contact) + } + + // Set new CSeq + c.setCSeq(req) + + // Add SDP body + if sdpOffer != nil && len(sdpOffer) > 0 { + req.SetBody(sdpOffer) + req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp")) + } + + // Add session timer headers if active + if c.sessionTimer != nil { + c.sessionTimer.AddHeadersToRequest(req) + } + + // Add User-Agent + req.AppendHeader(sip.NewHeader("User-Agent", "LiveKit")) + + // Add custom headers + if c.getHeaders != nil { + for k, v := range c.getHeaders(nil) { + req.AppendHeader(sip.NewHeader(k, v)) + } + } + + // Send the request and wait for response + tx, err := c.c.sipCli.TransactionRequest(req) + if err != nil { + return fmt.Errorf("failed to send session refresh: %w", err) + } + defer tx.Terminate() + + // Wait for response + resp, err := sipResponse(ctx, tx, c.c.closing.Watch(), nil) + if err != nil { + return fmt.Errorf("session refresh failed: %w", err) + } + + if resp.StatusCode != sip.StatusOK { + return fmt.Errorf("session refresh rejected: %d %s", resp.StatusCode, resp.Reason) + } + + c.log.Infow("Session refresh successful") + + // Send ACK + ack := sip.NewAckRequest(req, resp, nil) + return c.c.sipCli.WriteRequest(ack) +} + func (c *sipOutbound) sendBye() { if c.invite == nil || c.inviteOk == nil { return // call wasn't established diff --git a/pkg/sip/session_timer.go b/pkg/sip/session_timer.go new file mode 100644 index 00000000..bf39518e --- /dev/null +++ b/pkg/sip/session_timer.go @@ -0,0 +1,515 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sip + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "github.com/emiago/sipgo/sip" + "github.com/livekit/protocol/logger" +) + +const ( + // RFC 4028 minimum session interval + minSessionExpiresRFC = 90 + + // Default session interval (30 minutes) + defaultSessionExpires = 1800 + + // Extension name for Supported/Require headers + timerExtension = "timer" +) + +// RefresherRole indicates which party is responsible for refreshing the session +type RefresherRole int + +const ( + RefresherNone RefresherRole = iota + RefresherUAC + RefresherUAS +) + +func (r RefresherRole) String() string { + switch r { + case RefresherUAC: + return "uac" + case RefresherUAS: + return "uas" + default: + return "none" + } +} + +func parseRefresherRole(s string) RefresherRole { + switch strings.ToLower(s) { + case "uac": + return RefresherUAC + case "uas": + return RefresherUAS + default: + return RefresherNone + } +} + +// SessionTimerConfig holds configuration for session timers +type SessionTimerConfig struct { + DefaultExpires int // Default session interval in seconds + MinSE int // Minimum session interval in seconds + PreferRefresher RefresherRole // Preferred refresher role + UseUpdate bool // Use UPDATE instead of re-INVITE for refresh +} + +// DefaultSessionTimerConfig returns the default session timer configuration +func DefaultSessionTimerConfig() SessionTimerConfig { + return SessionTimerConfig{ + DefaultExpires: defaultSessionExpires, + MinSE: minSessionExpiresRFC, + PreferRefresher: RefresherUAC, + UseUpdate: false, + } +} + +// SessionTimer manages RFC 4028 session timers for a SIP dialog +type SessionTimer struct { + mu sync.Mutex + + config SessionTimerConfig + log logger.Logger + + // Negotiated parameters + sessionExpires int // Negotiated session interval in seconds + refresher RefresherRole // Who is responsible for refresh + isUAC bool // Are we the UAC in this dialog? + + // Timers + refreshTimer *time.Timer // Timer for sending refresh + expiryTimer *time.Timer // Timer for session expiry + expiryGeneration uint64 // Generation counter to invalidate old expiry timers + lastRefresh time.Time // Timestamp of last refresh + + // Callbacks + onRefresh func(ctx context.Context) error // Callback to send refresh request + onExpiry func(ctx context.Context) error // Callback to handle session expiry + + // State + started bool + stopped bool + ctx context.Context +} + +// NewSessionTimer creates a new session timer +func NewSessionTimer(config SessionTimerConfig, isUAC bool, log logger.Logger) *SessionTimer { + if log == nil { + log = logger.GetLogger() + } + + return &SessionTimer{ + config: config, + log: log, + isUAC: isUAC, + sessionExpires: config.DefaultExpires, + refresher: RefresherNone, + } +} + +// SetContext sets the context for the session timer +func (st *SessionTimer) SetContext(ctx context.Context) { + st.mu.Lock() + defer st.mu.Unlock() + st.ctx = ctx +} + +// SetCallbacks sets the refresh and expiry callbacks +func (st *SessionTimer) SetCallbacks(onRefresh, onExpiry func(ctx context.Context) error) { + st.mu.Lock() + defer st.mu.Unlock() + st.onRefresh = onRefresh + st.onExpiry = onExpiry +} + +// NegotiateInvite negotiates session timer parameters from an incoming INVITE request +// Returns the negotiated values and any error (including 422 rejection) +func (st *SessionTimer) NegotiateInvite(req *sip.Request) (sessionExpires int, minSE int, refresher RefresherRole, err error) { + st.mu.Lock() + defer st.mu.Unlock() + + // Check for Session-Expires header + seHeader := req.GetHeader("Session-Expires") + if seHeader == nil { + // No session timer requested + return 0, 0, RefresherNone, nil + } + + // Parse Session-Expires header: "1800;refresher=uac" + parts := strings.Split(seHeader.Value(), ";") + requestedExpires, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + st.log.Warnw("Invalid Session-Expires header", err, "value", seHeader.Value()) + return 0, 0, RefresherNone, fmt.Errorf("invalid Session-Expires value") + } + + // Parse refresher parameter if present + requestedRefresher := RefresherNone + for _, part := range parts[1:] { + kv := strings.Split(part, "=") + if len(kv) == 2 && strings.TrimSpace(strings.ToLower(kv[0])) == "refresher" { + requestedRefresher = parseRefresherRole(strings.TrimSpace(kv[1])) + } + } + + // Check Min-SE header + minSEHeader := req.GetHeader("Min-SE") + requestedMinSE := st.config.MinSE + if minSEHeader != nil { + if parsed, err := strconv.Atoi(strings.TrimSpace(minSEHeader.Value())); err == nil { + if parsed > requestedMinSE { + requestedMinSE = parsed + } + } + } + + // Enforce our minimum + if requestedExpires < st.config.MinSE { + st.log.Infow("Session interval too small, rejecting with 422", + "requested", requestedExpires, + "minSE", st.config.MinSE) + return 0, st.config.MinSE, RefresherNone, fmt.Errorf("session interval too small: %d < %d", requestedExpires, st.config.MinSE) + } + + // Accept the requested interval + negotiatedExpires := requestedExpires + + // Determine refresher role + // UAS (us) decides the final refresher role + negotiatedRefresher := requestedRefresher + if negotiatedRefresher == RefresherNone { + // If not specified, use our preference + negotiatedRefresher = st.config.PreferRefresher + if negotiatedRefresher == RefresherNone { + // Default to UAC if still unspecified + negotiatedRefresher = RefresherUAC + } + } + + st.sessionExpires = negotiatedExpires + st.refresher = negotiatedRefresher + + st.log.Infow("Negotiated session timer from INVITE", + "sessionExpires", negotiatedExpires, + "minSE", requestedMinSE, + "refresher", negotiatedRefresher.String()) + + return negotiatedExpires, requestedMinSE, negotiatedRefresher, nil +} + +// NegotiateResponse negotiates session timer parameters from a response (for UAC) +// This is called when we receive a 2xx response to our INVITE +func (st *SessionTimer) NegotiateResponse(res *sip.Response) error { + st.mu.Lock() + defer st.mu.Unlock() + + // Check for Session-Expires header in response + seHeader := res.GetHeader("Session-Expires") + if seHeader == nil { + // UAS doesn't support session timers + st.log.Infow("UAS doesn't support session timers (no Session-Expires in response)") + return nil + } + + // Parse Session-Expires header + parts := strings.Split(seHeader.Value(), ";") + negotiatedExpires, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + st.log.Warnw("Invalid Session-Expires in response", err, "value", seHeader.Value()) + return fmt.Errorf("invalid Session-Expires value") + } + + // Parse refresher parameter + negotiatedRefresher := RefresherNone + for _, part := range parts[1:] { + kv := strings.Split(part, "=") + if len(kv) == 2 && strings.TrimSpace(strings.ToLower(kv[0])) == "refresher" { + negotiatedRefresher = parseRefresherRole(strings.TrimSpace(kv[1])) + } + } + + if negotiatedRefresher == RefresherNone { + // If not specified, default to UAC (us) + negotiatedRefresher = RefresherUAC + } + + st.sessionExpires = negotiatedExpires + st.refresher = negotiatedRefresher + + st.log.Infow("Negotiated session timer from response", + "sessionExpires", negotiatedExpires, + "refresher", negotiatedRefresher.String()) + + return nil +} + +// AddHeadersToRequest adds session timer headers to an outgoing INVITE request +func (st *SessionTimer) AddHeadersToRequest(req *sip.Request) { + st.mu.Lock() + defer st.mu.Unlock() + + // Add Supported: timer + req.AppendHeader(sip.NewHeader("Supported", timerExtension)) + + // Add Session-Expires header + sessionExpires := st.config.DefaultExpires + refresher := st.config.PreferRefresher + if refresher == RefresherNone { + refresher = RefresherUAC // Default to UAC + } + + seValue := fmt.Sprintf("%d;refresher=%s", sessionExpires, refresher.String()) + req.AppendHeader(sip.NewHeader("Session-Expires", seValue)) + + // Add Min-SE header + req.AppendHeader(sip.NewHeader("Min-SE", strconv.Itoa(st.config.MinSE))) + + st.log.Debugw("Added session timer headers to INVITE", + "sessionExpires", sessionExpires, + "minSE", st.config.MinSE, + "refresher", refresher.String()) +} + +// AddHeadersToResponse adds session timer headers to a response +func (st *SessionTimer) AddHeadersToResponse(res *sip.Response, sessionExpires int, refresher RefresherRole) { + if sessionExpires == 0 { + return + } + + // Add Require: timer (to indicate timer support is required) + res.AppendHeader(sip.NewHeader("Require", timerExtension)) + + // Add Session-Expires header + seValue := fmt.Sprintf("%d;refresher=%s", sessionExpires, refresher.String()) + res.AppendHeader(sip.NewHeader("Session-Expires", seValue)) + + st.log.Debugw("Added session timer headers to response", + "sessionExpires", sessionExpires, + "refresher", refresher.String()) +} + +// Start starts the session timer +func (st *SessionTimer) Start() { + st.mu.Lock() + defer st.mu.Unlock() + + if st.started || st.sessionExpires == 0 { + return + } + + st.started = true + st.lastRefresh = time.Now() + + if st.ctx == nil { + st.log.Warnw("Session timer started without context", nil) + return + } + + // Determine if we are the refresher + weAreRefresher := (st.isUAC && st.refresher == RefresherUAC) || (!st.isUAC && st.refresher == RefresherUAS) + + if weAreRefresher { + // We are responsible for refreshing + // Refresh at half the session interval (per RFC 4028) + refreshInterval := time.Duration(st.sessionExpires/2) * time.Second + st.refreshTimer = time.AfterFunc(refreshInterval, func() { + st.handleRefresh() + }) + + st.log.Infow("Started session timer as refresher", + "sessionExpires", st.sessionExpires, + "refreshIn", refreshInterval) + } + + // Always set expiry timer (both refresher and non-refresher) + // Expiry warning at: expires - min(32, expires/3) seconds + st.expiryGeneration++ + currentGen := st.expiryGeneration + expiryWarning := st.sessionExpires - min(32, st.sessionExpires/3) + expiryDuration := time.Duration(expiryWarning) * time.Second + st.expiryTimer = time.AfterFunc(expiryDuration, func() { + st.handleExpiry(currentGen) + }) + + st.log.Infow("Started session timer", + "sessionExpires", st.sessionExpires, + "expiryWarning", expiryDuration, + "weAreRefresher", weAreRefresher) +} + +// Stop stops the session timer +func (st *SessionTimer) Stop() { + st.mu.Lock() + defer st.mu.Unlock() + + if st.stopped { + return + } + + st.stopped = true + + if st.refreshTimer != nil { + st.refreshTimer.Stop() + st.refreshTimer = nil + } + + if st.expiryTimer != nil { + st.expiryTimer.Stop() + st.expiryTimer = nil + } + + st.log.Infow("Stopped session timer") +} + +// OnRefreshReceived should be called when a session refresh request is received +// This resets the expiry timer +func (st *SessionTimer) OnRefreshReceived() { + st.mu.Lock() + defer st.mu.Unlock() + + if !st.started || st.stopped { + return + } + + st.lastRefresh = time.Now() + + // Reset expiry timer + if st.expiryTimer != nil { + st.expiryTimer.Stop() + } + + st.expiryGeneration++ + currentGen := st.expiryGeneration + expiryWarning := st.sessionExpires - min(32, st.sessionExpires/3) + expiryDuration := time.Duration(expiryWarning) * time.Second + st.expiryTimer = time.AfterFunc(expiryDuration, func() { + st.handleExpiry(currentGen) + }) + + st.log.Infow("Session refresh received, reset expiry timer", + "sessionExpires", st.sessionExpires, + "nextExpiry", expiryDuration) +} + +// handleRefresh is called when it's time to send a session refresh +func (st *SessionTimer) handleRefresh() { + st.mu.Lock() + if st.stopped || st.ctx == nil { + st.mu.Unlock() + return + } + + ctx := st.ctx + onRefresh := st.onRefresh + st.mu.Unlock() + + if onRefresh == nil { + st.log.Warnw("No refresh callback registered", nil) + return + } + + st.log.Infow("Sending session refresh") + + err := onRefresh(ctx) + if err != nil { + st.log.Errorw("Failed to send session refresh", err) + // Don't reschedule on error - let expiry timer handle it + return + } + + // Reschedule next refresh + st.mu.Lock() + defer st.mu.Unlock() + + if st.stopped { + return + } + + st.lastRefresh = time.Now() + + refreshInterval := time.Duration(st.sessionExpires/2) * time.Second + st.refreshTimer = time.AfterFunc(refreshInterval, func() { + st.handleRefresh() + }) + + st.log.Infow("Session refresh sent, scheduled next refresh", + "nextRefresh", refreshInterval) +} + +// handleExpiry is called when the session expires without refresh +func (st *SessionTimer) handleExpiry(generation uint64) { + st.mu.Lock() + // Check if this timer is stale (a newer timer was created) + if generation != st.expiryGeneration { + st.mu.Unlock() + return + } + if st.stopped || st.ctx == nil { + st.mu.Unlock() + return + } + + ctx := st.ctx + onExpiry := st.onExpiry + st.mu.Unlock() + + if onExpiry == nil { + st.log.Warnw("No expiry callback registered", nil) + return + } + + st.log.Warnw("Session timer expired, terminating call", nil, + "sessionExpires", st.sessionExpires, + "lastRefresh", st.lastRefresh) + + err := onExpiry(ctx) + if err != nil { + st.log.Errorw("Failed to handle session expiry", err) + } + + st.Stop() +} + +// GetSessionExpires returns the negotiated session expires value +func (st *SessionTimer) GetSessionExpires() int { + st.mu.Lock() + defer st.mu.Unlock() + return st.sessionExpires +} + +// GetRefresher returns the negotiated refresher role +func (st *SessionTimer) GetRefresher() RefresherRole { + st.mu.Lock() + defer st.mu.Unlock() + return st.refresher +} + +// IsStarted returns whether the timer is started +func (st *SessionTimer) IsStarted() bool { + st.mu.Lock() + defer st.mu.Unlock() + return st.started +} + diff --git a/pkg/sip/session_timer_test.go b/pkg/sip/session_timer_test.go new file mode 100644 index 00000000..62793e0d --- /dev/null +++ b/pkg/sip/session_timer_test.go @@ -0,0 +1,442 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sip + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/livekit/sipgo/sip" + "github.com/livekit/protocol/logger" +) + +// Helper function to create a properly formatted SIP request for testing +func createTestRequest() *sip.Request { + req := sip.NewRequest(sip.INVITE, sip.Uri{User: "test", Host: "example.com"}) + // Add required headers for NewResponseFromRequest to work + req.AppendHeader(sip.NewHeader("Via", "SIP/2.0/UDP test:5060;branch=z9hG4bK123")) + req.AppendHeader(sip.NewHeader("From", ";tag=abc123")) + req.AppendHeader(sip.NewHeader("To", "")) + req.AppendHeader(sip.NewHeader("Call-ID", "test-call-id")) + req.AppendHeader(sip.NewHeader("CSeq", "1 INVITE")) + return req +} + +func TestSessionTimerNegotiateInvite(t *testing.T) { + tests := []struct { + name string + config SessionTimerConfig + sessionExpiresValue string + minSEValue string + expectError bool + expectedExpires int + expectedRefresher RefresherRole + }{ + { + name: "valid session timer with refresher=uac", + config: SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAC, + }, + sessionExpiresValue: "1800;refresher=uac", + minSEValue: "90", + expectError: false, + expectedExpires: 1800, + expectedRefresher: RefresherUAC, + }, + { + name: "valid session timer with refresher=uas", + config: SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAS, + }, + sessionExpiresValue: "1800;refresher=uas", + minSEValue: "90", + expectError: false, + expectedExpires: 1800, + expectedRefresher: RefresherUAS, + }, + { + name: "session interval too small", + config: SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAC, + }, + sessionExpiresValue: "60", + minSEValue: "90", + expectError: true, + }, + { + name: "no session expires header", + config: SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAC, + }, + sessionExpiresValue: "", + expectError: false, + expectedExpires: 0, + expectedRefresher: RefresherNone, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + log := logger.GetLogger() + st := NewSessionTimer(tt.config, false, log) + + // Create a mock INVITE request + req := sip.NewRequest(sip.INVITE, sip.Uri{User: "test", Host: "example.com"}) + if tt.sessionExpiresValue != "" { + req.AppendHeader(sip.NewHeader("Session-Expires", tt.sessionExpiresValue)) + } + if tt.minSEValue != "" { + req.AppendHeader(sip.NewHeader("Min-SE", tt.minSEValue)) + } + + sessionExpires, _, refresher, err := st.NegotiateInvite(req) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got nil") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if sessionExpires != tt.expectedExpires { + t.Errorf("Expected sessionExpires=%d, got %d", tt.expectedExpires, sessionExpires) + } + if refresher != tt.expectedRefresher { + t.Errorf("Expected refresher=%v, got %v", tt.expectedRefresher, refresher) + } + } + }) + } +} + +func TestSessionTimerNegotiateResponse(t *testing.T) { + tests := []struct { + name string + config SessionTimerConfig + sessionExpiresValue string + expectedExpires int + expectedRefresher RefresherRole + }{ + { + name: "valid response with refresher=uac", + config: SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAC, + }, + sessionExpiresValue: "1800;refresher=uac", + expectedExpires: 1800, + expectedRefresher: RefresherUAC, + }, + { + name: "valid response with refresher=uas", + config: SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAS, + }, + sessionExpiresValue: "1800;refresher=uas", + expectedExpires: 1800, + expectedRefresher: RefresherUAS, + }, + { + name: "response without refresher defaults to uac", + config: SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAC, + }, + sessionExpiresValue: "1800", + expectedExpires: 1800, + expectedRefresher: RefresherUAC, + }, + { + name: "no session expires in response", + config: SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAC, + }, + sessionExpiresValue: "", + expectedExpires: 1800, // Should remain at default + expectedRefresher: RefresherNone, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + log := logger.GetLogger() + st := NewSessionTimer(tt.config, true, log) + + // Create a mock 200 OK response + req := createTestRequest() + res := sip.NewResponseFromRequest(req, sip.StatusOK, "OK", nil) + if tt.sessionExpiresValue != "" { + res.AppendHeader(sip.NewHeader("Session-Expires", tt.sessionExpiresValue)) + } + + err := st.NegotiateResponse(res) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + actualExpires := st.GetSessionExpires() + actualRefresher := st.GetRefresher() + + if actualExpires != tt.expectedExpires { + t.Errorf("Expected sessionExpires=%d, got %d", tt.expectedExpires, actualExpires) + } + if actualRefresher != tt.expectedRefresher { + t.Errorf("Expected refresher=%v, got %v", tt.expectedRefresher, actualRefresher) + } + }) + } +} + +func TestSessionTimerAddHeadersToRequest(t *testing.T) { + config := SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAC, + } + + log := logger.GetLogger() + st := NewSessionTimer(config, true, log) + + req := sip.NewRequest(sip.INVITE, sip.Uri{User: "test", Host: "example.com"}) + st.AddHeadersToRequest(req) + + // Check for Supported header + supportedHeader := req.GetHeader("Supported") + if supportedHeader == nil || supportedHeader.Value() != "timer" { + t.Errorf("Expected Supported: timer header") + } + + // Check for Session-Expires header + sessionExpiresHeader := req.GetHeader("Session-Expires") + if sessionExpiresHeader == nil { + t.Errorf("Expected Session-Expires header") + } + + // Check for Min-SE header + minSEHeader := req.GetHeader("Min-SE") + if minSEHeader == nil || minSEHeader.Value() != "90" { + t.Errorf("Expected Min-SE: 90 header") + } +} + +func TestSessionTimerAddHeadersToResponse(t *testing.T) { + config := SessionTimerConfig{ + DefaultExpires: 1800, + MinSE: 90, + PreferRefresher: RefresherUAS, + } + + log := logger.GetLogger() + st := NewSessionTimer(config, false, log) + + req := createTestRequest() + res := sip.NewResponseFromRequest(req, sip.StatusOK, "OK", nil) + + st.AddHeadersToResponse(res, 1800, RefresherUAS) + + // Check for Require header + requireHeader := res.GetHeader("Require") + if requireHeader == nil || requireHeader.Value() != "timer" { + t.Errorf("Expected Require: timer header") + } + + // Check for Session-Expires header + sessionExpiresHeader := res.GetHeader("Session-Expires") + if sessionExpiresHeader == nil { + t.Errorf("Expected Session-Expires header") + } +} + +func TestSessionTimerRefreshCallback(t *testing.T) { + config := SessionTimerConfig{ + DefaultExpires: 1, // 1 second for fast testing + MinSE: 1, + PreferRefresher: RefresherUAC, + } + + log := logger.GetLogger() + st := NewSessionTimer(config, true, log) + st.sessionExpires = 1 + st.refresher = RefresherUAC + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + st.SetContext(ctx) + + var refreshCalled atomic.Bool + st.SetCallbacks( + func(ctx context.Context) error { + refreshCalled.Store(true) + return nil + }, + func(ctx context.Context) error { + return nil + }, + ) + + st.Start() + + // Wait for refresh callback to be called (should happen at half interval = 0.5s) + time.Sleep(750 * time.Millisecond) + + if !refreshCalled.Load() { + t.Errorf("Refresh callback was not called") + } + + st.Stop() +} + +func TestSessionTimerExpiryCallback(t *testing.T) { + config := SessionTimerConfig{ + DefaultExpires: 2, // 2 seconds for testing + MinSE: 1, + PreferRefresher: RefresherNone, // We are not the refresher + } + + log := logger.GetLogger() + st := NewSessionTimer(config, false, log) + st.sessionExpires = 2 + st.refresher = RefresherUAC // Remote is refresher, but they won't refresh + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + st.SetContext(ctx) + + var expiryCalled atomic.Bool + st.SetCallbacks( + func(ctx context.Context) error { + return nil + }, + func(ctx context.Context) error { + expiryCalled.Store(true) + return nil + }, + ) + + st.Start() + + // Wait for expiry callback to be called + // Expiry happens at: expires - min(32, expires/3) = 2 - min(32, 0) = 2 seconds + time.Sleep(2500 * time.Millisecond) + + if !expiryCalled.Load() { + t.Errorf("Expiry callback was not called") + } + + st.Stop() +} + +func TestSessionTimerOnRefreshReceived(t *testing.T) { + config := SessionTimerConfig{ + DefaultExpires: 2, + MinSE: 1, + PreferRefresher: RefresherNone, + } + + log := logger.GetLogger() + st := NewSessionTimer(config, false, log) + st.sessionExpires = 2 + st.refresher = RefresherUAC + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + st.SetContext(ctx) + + var expiryCalled atomic.Bool + st.SetCallbacks( + func(ctx context.Context) error { + return nil + }, + func(ctx context.Context) error { + expiryCalled.Store(true) + return nil + }, + ) + + st.Start() + + // Wait a bit (but less than expiry time) + time.Sleep(500 * time.Millisecond) + + // Receive a refresh - this should reset the expiry timer to 2s from now (t=2.5s) + st.OnRefreshReceived() + + // Wait for the original expiry time (t=2.0s) - should not expire because we refreshed + // We're now at t=1.5s, which is past the original expiry of t=2s but before the new expiry of t=2.5s + time.Sleep(1000 * time.Millisecond) + + if expiryCalled.Load() { + t.Errorf("Expiry callback was called despite receiving refresh") + } + + st.Stop() +} + +func TestSessionTimerStop(t *testing.T) { + config := SessionTimerConfig{ + DefaultExpires: 1, + MinSE: 1, + PreferRefresher: RefresherUAC, + } + + log := logger.GetLogger() + st := NewSessionTimer(config, true, log) + st.sessionExpires = 1 + st.refresher = RefresherUAC + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + st.SetContext(ctx) + + var refreshCalled atomic.Bool + st.SetCallbacks( + func(ctx context.Context) error { + refreshCalled.Store(true) + return nil + }, + func(ctx context.Context) error { + return nil + }, + ) + + st.Start() + + // Stop immediately + st.Stop() + + // Wait to ensure callbacks are not called + time.Sleep(1500 * time.Millisecond) + + if refreshCalled.Load() { + t.Errorf("Refresh callback was called after Stop()") + } +}