+import functools
 import logging
 import threading
 import time
 from kubernetes import (
     client,
     watch,  # type: ignore[attr-defined]
 )
+from kubernetes import config as kubernetes_config
 
 from metta.app_backend.clients.stats_client import StatsClient
 from metta.app_backend.job_runner.config import (
     LABEL_JOB_ID,
     get_dispatch_config,
 )
-from metta.app_backend.job_runner.dispatcher import get_k8s_client
 from metta.app_backend.models.job_request import JobRequestUpdate, JobStatus
 from metta.common.util.log_config import init_logging, suppress_noisy_logs
 
@@ -45,15 +46,11 @@ def _is_healthy() -> bool:
 
 class HealthHandler(BaseHTTPRequestHandler):
     def do_GET(self):
-        if self.path == "/health" or self.path == "/healthz":
-            if _is_healthy():
-                self.send_response(200)
-                self.end_headers()
-                self.wfile.write(b"ok")
-            else:
-                self.send_response(503)
-                self.end_headers()
-                self.wfile.write(b"unhealthy: watch loop stale")
+        if self.path in ("/health", "/healthz"):
+            status, body = (200, b"ok") if _is_healthy() else (503, b"unhealthy: watch loop stale")
+            self.send_response(status)
+            self.end_headers()
+            self.wfile.write(body)
         else:
             self.send_response(404)
             self.end_headers()
@@ -64,123 +61,167 @@ def log_message(self, format, *args):
 
 def _start_health_server():
     server = HTTPServer(("0.0.0.0", HEALTH_PORT), HealthHandler)
-    thread = threading.Thread(target=server.serve_forever, daemon=True)
-    thread.start()
+    threading.Thread(target=server.serve_forever, daemon=True).start()
     logger.info(f"Health server started on port {HEALTH_PORT}")
 
 
-# kubernetes-stubs V1WatchEventDict uses Any for object; we define our own for V1Job typing
-class K8sJobWatchEvent(TypedDict):
-    type: Literal["ADDED", "MODIFIED", "DELETED"]
-    object: client.V1Job
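+# Cached accessor: kube config is loaded exactly once, and all callers share the same two API clients.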
+@functools.cache
+def _get_k8s_clients() -> tuple[client.CoreV1Api, client.BatchV1Api]:
+    cfg = get_dispatch_config()
+    if cfg.LOCAL_DEV:
+        if not cfg.LOCAL_DEV_K8S_CONTEXT:
+            raise ValueError("LOCAL_DEV=true requires LOCAL_DEV_K8S_CONTEXT to be set")
+        kubernetes_config.load_kube_config(context=cfg.LOCAL_DEV_K8S_CONTEXT)
+    else:
+        kubernetes_config.load_incluster_config()
+    return client.CoreV1Api(), client.BatchV1Api()
+
+
+# ADDED: Pod created (usually starts in Pending phase)
+# MODIFIED: Pod state changed (phase transitions, container status updates)
+# DELETED: Pod removed from cluster
+# BOOKMARK: Internal watch checkpoint (no actual change, just resourceVersion update)
+# ERROR: Watch stream error
+K8sPodWatchEventType = Literal["ADDED", "MODIFIED", "DELETED", "BOOKMARK", "ERROR"]
+
+
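+# kubernetes-stubs types the watch event's "object" as Any, so we narrow it to V1Pod ourselves.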
+class K8sPodWatchEvent(TypedDict):
+    type: K8sPodWatchEventType
+    object: client.V1Pod
 
 
 def run_watcher():
     cfg = get_dispatch_config()
-    get_k8s_client()  # initialize cached client
+    _get_k8s_clients()
 
     _start_health_server()
-    _update_heartbeat()  # mark healthy before blocking on auth
+    _update_heartbeat()
 
-    # Pass token directly instead of writing to config file
     stats_client = StatsClient(backend_url=cfg.STATS_SERVER_URI, machine_token=cfg.MACHINE_TOKEN)
     stats_client._validate_authenticated()
-
     logger.info(f"Watcher started: stats_server_uri={cfg.STATS_SERVER_URI}, namespace={JOB_NAMESPACE}")
 
     try:
         while True:
             try:
-                watch_jobs(stats_client)
+                _watch_pods(stats_client)
             except Exception as e:
                 logger.error(f"Watch error, restarting: {e}", exc_info=True)
                 time.sleep(1)
     finally:
         stats_client.close()
 
 
-def watch_jobs(stats_client: StatsClient):
+def _watch_pods(stats_client: StatsClient):
     label_selector = f"{LABEL_APP}={LABEL_APP_VALUE}"
-    batch_v1 = get_k8s_client()
+    core_v1, _ = _get_k8s_clients()
 
-    job_list = batch_v1.list_namespaced_job(namespace=JOB_NAMESPACE, label_selector=label_selector)
-    if not job_list.metadata or not job_list.metadata.resource_version:
-        logger.error(f"Invalid job list: {job_list}")
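+    # List before watching: the initial pass settles pods that finished while the watcher was down,
+    # and the list's resourceVersion gives the watch a consistent starting point.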
+    pod_list = core_v1.list_namespaced_pod(namespace=JOB_NAMESPACE, label_selector=label_selector)
+    if not pod_list.metadata or not pod_list.metadata.resource_version:
+        logger.error(f"Invalid pod list: {pod_list}")
         return
-    resource_version = job_list.metadata.resource_version
 
-    # Process any existing jobs that may have completed before watcher started
-    for k8s_job in job_list.items:
-        if not k8s_job.metadata or not k8s_job.metadata.labels:
-            continue
-        job_id_str = k8s_job.metadata.labels.get(LABEL_JOB_ID)
-        if job_id_str:
-            handle_job_state(stats_client, UUID(job_id_str), k8s_job)
+    for pod in pod_list.items:
+        _handle_pod_state(stats_client, pod)
 
-    logger.info(f"Starting watch from resourceVersion={resource_version}")
+    resource_version = pod_list.metadata.resource_version
+    logger.info(f"Starting pod watch from resourceVersion={resource_version}")
     _update_heartbeat()
 
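+    # timeout_seconds bounds each watch; when the stream expires this function returns,
+    # and run_watcher's loop re-lists and starts a fresh watch.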
     w = watch.Watch()
-    event: K8sJobWatchEvent
+    event: K8sPodWatchEvent
     for event in w.stream(  # type: ignore[assignment]
-        batch_v1.list_namespaced_job,
+        core_v1.list_namespaced_pod,
         namespace=JOB_NAMESPACE,
         label_selector=label_selector,
         resource_version=resource_version,
         timeout_seconds=WATCH_TIMEOUT_SECONDS,
     ):
         _update_heartbeat()
-        event_type = event["type"]
-        k8s_job = event["object"]
-        if not k8s_job.metadata or not k8s_job.metadata.name or not k8s_job.metadata.labels:
-            logger.error(f"Invalid k8s job: {k8s_job}")
-            continue
+        event_type, pod = event["type"], event["object"]
+        if event_type in ("ADDED", "MODIFIED"):
+            _handle_pod_state(stats_client, pod)
+        elif event_type == "DELETED":
+            _handle_pod_deleted(stats_client, pod)
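+        # BOOKMARK and ERROR events are intentionally ignored; the next list/watch cycle resynchronizes.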
 
-        job_id_str = k8s_job.metadata.labels.get(LABEL_JOB_ID)
-        if not job_id_str:
-            logger.error(f"Job {k8s_job.metadata.name} has no job ID label")
-            continue
 
-        job_id = UUID(job_id_str)
-        job_name = k8s_job.metadata.name
+def _get_job_info(pod: client.V1Pod) -> tuple[UUID, str] | None:
+    if not pod.metadata or not pod.metadata.labels:
+        return None
+    job_id_str = pod.metadata.labels.get(LABEL_JOB_ID)
+    if not job_id_str:
+        return None
+    return UUID(job_id_str), pod.metadata.name or "unknown"
 
-        logger.debug(f"Event {event_type} for job {job_name} (id={job_id})")
 
-        if event_type in ("ADDED", "MODIFIED"):
-            handle_job_state(stats_client, job_id, k8s_job)
-        elif event_type == "DELETED":
-            logger.info(f"Job {job_name} deleted")
+def _handle_pod_state(stats_client: StatsClient, pod: client.V1Pod):
+    info = _get_job_info(pod)
+    if not info or not pod.status:
+        return
 
+    job_id, pod_name = info
+    phase = pod.status.phase
+
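+    # Other phases (notably Pending) are left alone: the job record only advances once a container runs.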
+    if phase == "Succeeded":
+        _update_job_status(stats_client, job_id, JobStatus.completed)
+        _delete_k8s_job_for_pod(pod)
+        logger.info(f"Job {job_id} completed (pod {pod_name})")
+    elif phase == "Failed":
+        error = _get_pod_error(pod)
+        _update_job_status(stats_client, job_id, JobStatus.failed, error=error)
+        _delete_k8s_job_for_pod(pod)
+        logger.info(f"Job {job_id} failed (pod {pod_name}): {error}")
+    elif phase == "Running" and _is_container_running(pod):
+        _update_job_status(stats_client, job_id, JobStatus.running, worker=pod_name)
+        logger.debug(f"Job {job_id} running (pod {pod_name})")
+
+
+def _handle_pod_deleted(stats_client: StatsClient, pod: client.V1Pod):
+    info = _get_job_info(pod)
+    if not info:
+        return
 
-def handle_job_state(
-    stats_client: StatsClient,
-    job_id: UUID,
-    k8s_job: client.V1Job,
-):
-    if not (k8s_job.metadata and k8s_job.spec and k8s_job.status and k8s_job.metadata.name):
-        logger.error(f"Invalid k8s job: {k8s_job}")
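+    # A deletion in a terminal phase is expected: the watcher deletes the parent Job above,
+    # and that cascades to the pod. Only a pod deleted mid-run counts as a failure.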
+    phase = pod.status.phase if pod.status else None
+    if phase in ("Succeeded", "Failed"):
         return
 
-    status = k8s_job.status
-    job_name = k8s_job.metadata.name
-    backoff_limit = k8s_job.spec.backoff_limit or 0
+    job_id, pod_name = info
+    _update_job_status(stats_client, job_id, JobStatus.failed, error="Pod deleted unexpectedly")
+    logger.warning(f"Job {job_id} failed: pod {pod_name} deleted unexpectedly (phase={phase})")
 
-    if status.succeeded and status.succeeded > 0:
-        update_job_status(stats_client, job_id, JobStatus.completed)
-        delete_k8s_job(job_name)
-        logger.info(f"Job {job_id} completed")
 
-    elif status.failed and status.failed >= backoff_limit:
-        update_job_status(stats_client, job_id, JobStatus.failed, error="k8s job failed")
-        delete_k8s_job(job_name)
-        logger.info(f"Job {job_id} failed")
+def _is_container_running(pod: client.V1Pod) -> bool:
+    if not pod.status or not pod.status.container_statuses:
+        return False
+    return any(cs.state and cs.state.running for cs in pod.status.container_statuses)
 
-    elif status.active and status.active > 0:
-        update_job_status(stats_client, job_id, JobStatus.running, worker=job_name)
-        logger.debug(f"Job {job_id} running")
 
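+# Prefer the terminated container's reason (e.g. "OOMKilled", "Error"); fall back to the pod's status message.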
+def _get_pod_error(pod: client.V1Pod) -> str:
+    if pod.status and pod.status.container_statuses:
+        for cs in pod.status.container_statuses:
+            if cs.state and cs.state.terminated and cs.state.terminated.reason:
+                return cs.state.terminated.reason
+    return (pod.status.message if pod.status else None) or "Pod failed"
 
-def update_job_status(
+
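+# A pod created by a Job controller carries an ownerReference of kind "Job" back to its owning Job.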
+def _get_job_name_for_pod(pod: client.V1Pod) -> str | None:
+    if not pod.metadata or not pod.metadata.owner_references:
+        return None
+    return next((ref.name for ref in pod.metadata.owner_references if ref.kind == "Job"), None)
+
+
+def _delete_k8s_job_for_pod(pod: client.V1Pod):
+    job_name = _get_job_name_for_pod(pod)
+    if not job_name:
+        return
+    try:
+        _, batch_v1 = _get_k8s_clients()
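+        # Background propagation deletes the Job right away; its pods are garbage-collected asynchronously.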
+        batch_v1.delete_namespaced_job(name=job_name, namespace=JOB_NAMESPACE, propagation_policy="Background")
+    except Exception as e:
+        logger.error(f"Failed to delete k8s job {job_name}: {e}")
+
+
+def _update_job_status(
     stats_client: StatsClient,
     job_id: UUID,
     status: JobStatus,
@@ -189,27 +230,13 @@ def update_job_status(
 ):
     try:
         current = stats_client.get_job(job_id)
-        if current.status == status:
-            return
-        if current.status in (JobStatus.completed, JobStatus.failed):
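+        # Skip no-op updates and never overwrite a terminal state (watch events can arrive late or duplicated).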
+        if current.status == status or current.status in (JobStatus.completed, JobStatus.failed):
             return
-
         stats_client.update_job(job_id, JobRequestUpdate(status=status, error=error, worker=worker))
     except Exception as e:
         logger.error(f"Failed to update job {job_id} status to {status}: {e}")
 
 
-def delete_k8s_job(job_name: str):
-    try:
-        get_k8s_client().delete_namespaced_job(
-            name=job_name,
-            namespace=JOB_NAMESPACE,
-            propagation_policy="Background",
-        )
-    except Exception as e:
-        logger.error(f"Failed to delete k8s job {job_name}: {e}")
-
-
 if __name__ == "__main__":
     init_logging()
     suppress_noisy_logs()