Commit a0360e5
Watcher attends to pod events, not job events, to more accurately capture running time (#4560)
1 parent c5652d7 commit a0360e5
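
Background for this change (an illustrative sketch, not code from this commit): a k8s Job's status.active count goes positive as soon as a pod exists for it, even while that pod is still Pending or pulling images, so a job-level watch marks work "running" too early. A pod-level watch can instead wait until the pod reaches phase "Running" with a container actually in the running state, which is the signal the new watcher uses.

# Sketch only, assuming the official kubernetes Python client; these helper
# functions are not part of the repository.
from kubernetes import client


def job_looks_running(job: client.V1Job) -> bool:
    # Old-style signal: any active pod counts, including pods still Pending.
    return bool(job.status and job.status.active and job.status.active > 0)


def pod_is_running(pod: client.V1Pod) -> bool:
    # New-style signal: the pod is in phase "Running" and at least one container
    # reports a running state, so the recorded running time starts when the
    # workload actually starts.
    if not pod.status or pod.status.phase != "Running":
        return False
    return any(cs.state and cs.state.running for cs in (pod.status.container_statuses or []))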

1 file changed: 117 additions, 90 deletions

app_backend/src/metta/app_backend/job_runner/watcher.py

Lines changed: 117 additions & 90 deletions
@@ -1,3 +1,4 @@
+import functools
 import logging
 import threading
 import time
@@ -9,6 +10,7 @@
     client,
     watch,  # type: ignore[attr-defined]
 )
+from kubernetes import config as kubernetes_config
 
 from metta.app_backend.clients.stats_client import StatsClient
 from metta.app_backend.job_runner.config import (
@@ -18,7 +20,6 @@
     LABEL_JOB_ID,
     get_dispatch_config,
 )
-from metta.app_backend.job_runner.dispatcher import get_k8s_client
 from metta.app_backend.models.job_request import JobRequestUpdate, JobStatus
 from metta.common.util.log_config import init_logging, suppress_noisy_logs
 
@@ -45,15 +46,11 @@ def _is_healthy() -> bool:
 
 class HealthHandler(BaseHTTPRequestHandler):
     def do_GET(self):
-        if self.path == "/health" or self.path == "/healthz":
-            if _is_healthy():
-                self.send_response(200)
-                self.end_headers()
-                self.wfile.write(b"ok")
-            else:
-                self.send_response(503)
-                self.end_headers()
-                self.wfile.write(b"unhealthy: watch loop stale")
+        if self.path in ("/health", "/healthz"):
+            status, body = (200, b"ok") if _is_healthy() else (503, b"unhealthy: watch loop stale")
+            self.send_response(status)
+            self.end_headers()
+            self.wfile.write(body)
         else:
             self.send_response(404)
             self.end_headers()
@@ -64,123 +61,167 @@ def log_message(self, format, *args):
 
 def _start_health_server():
     server = HTTPServer(("0.0.0.0", HEALTH_PORT), HealthHandler)
-    thread = threading.Thread(target=server.serve_forever, daemon=True)
-    thread.start()
+    threading.Thread(target=server.serve_forever, daemon=True).start()
     logger.info(f"Health server started on port {HEALTH_PORT}")
 
 
-# kubernetes-stubs V1WatchEventDict uses Any for object; we define our own for V1Job typing
-class K8sJobWatchEvent(TypedDict):
-    type: Literal["ADDED", "MODIFIED", "DELETED"]
-    object: client.V1Job
+@functools.cache
+def _get_k8s_clients() -> tuple[client.CoreV1Api, client.BatchV1Api]:
+    cfg = get_dispatch_config()
+    if cfg.LOCAL_DEV:
+        if not cfg.LOCAL_DEV_K8S_CONTEXT:
+            raise ValueError("LOCAL_DEV=true requires LOCAL_DEV_K8S_CONTEXT to be set")
+        kubernetes_config.load_kube_config(context=cfg.LOCAL_DEV_K8S_CONTEXT)
+    else:
+        kubernetes_config.load_incluster_config()
+    return client.CoreV1Api(), client.BatchV1Api()
+
+
+# ADDED: Pod created (usually starts in Pending phase)
+# MODIFIED: Pod state changed (phase transitions, container status updates)
+# DELETED: Pod removed from cluster
+# BOOKMARK: Internal watch checkpoint (no actual change, just resourceVersion update)
+# ERROR: Watch stream error
+K8sPodWatchEventType = Literal["ADDED", "MODIFIED", "DELETED", "BOOKMARK", "ERROR"]
+
+
+class K8sPodWatchEvent(TypedDict):
+    type: K8sPodWatchEventType
+    object: client.V1Pod
 
 
 def run_watcher():
     cfg = get_dispatch_config()
-    get_k8s_client()  # initialize cached client
+    _get_k8s_clients()
 
     _start_health_server()
-    _update_heartbeat()  # mark healthy before blocking on auth
+    _update_heartbeat()
 
-    # Pass token directly instead of writing to config file
     stats_client = StatsClient(backend_url=cfg.STATS_SERVER_URI, machine_token=cfg.MACHINE_TOKEN)
     stats_client._validate_authenticated()
-
     logger.info(f"Watcher started: stats_server_uri={cfg.STATS_SERVER_URI}, namespace={JOB_NAMESPACE}")
 
     try:
         while True:
             try:
-                watch_jobs(stats_client)
+                _watch_pods(stats_client)
             except Exception as e:
                 logger.error(f"Watch error, restarting: {e}", exc_info=True)
                 time.sleep(1)
     finally:
         stats_client.close()
 
 
-def watch_jobs(stats_client: StatsClient):
+def _watch_pods(stats_client: StatsClient):
    label_selector = f"{LABEL_APP}={LABEL_APP_VALUE}"
-    batch_v1 = get_k8s_client()
+    core_v1, _ = _get_k8s_clients()
 
-    job_list = batch_v1.list_namespaced_job(namespace=JOB_NAMESPACE, label_selector=label_selector)
-    if not job_list.metadata or not job_list.metadata.resource_version:
-        logger.error(f"Invalid job list: {job_list}")
+    pod_list = core_v1.list_namespaced_pod(namespace=JOB_NAMESPACE, label_selector=label_selector)
+    if not pod_list.metadata or not pod_list.metadata.resource_version:
+        logger.error(f"Invalid pod list: {pod_list}")
        return
-    resource_version = job_list.metadata.resource_version
 
-    # Process any existing jobs that may have completed before watcher started
-    for k8s_job in job_list.items:
-        if not k8s_job.metadata or not k8s_job.metadata.labels:
-            continue
-        job_id_str = k8s_job.metadata.labels.get(LABEL_JOB_ID)
-        if job_id_str:
-            handle_job_state(stats_client, UUID(job_id_str), k8s_job)
+    for pod in pod_list.items:
+        _handle_pod_state(stats_client, pod)
 
-    logger.info(f"Starting watch from resourceVersion={resource_version}")
+    resource_version = pod_list.metadata.resource_version
+    logger.info(f"Starting pod watch from resourceVersion={resource_version}")
    _update_heartbeat()
 
    w = watch.Watch()
-    event: K8sJobWatchEvent
+    event: K8sPodWatchEvent
    for event in w.stream(  # type: ignore[assignment]
-        batch_v1.list_namespaced_job,
+        core_v1.list_namespaced_pod,
        namespace=JOB_NAMESPACE,
        label_selector=label_selector,
        resource_version=resource_version,
        timeout_seconds=WATCH_TIMEOUT_SECONDS,
    ):
        _update_heartbeat()
-        event_type = event["type"]
-        k8s_job = event["object"]
-        if not k8s_job.metadata or not k8s_job.metadata.name or not k8s_job.metadata.labels:
-            logger.error(f"Invalid k8s job: {k8s_job}")
-            continue
+        event_type, pod = event["type"], event["object"]
+        if event_type in ("ADDED", "MODIFIED"):
+            _handle_pod_state(stats_client, pod)
+        elif event_type == "DELETED":
+            _handle_pod_deleted(stats_client, pod)
 
-        job_id_str = k8s_job.metadata.labels.get(LABEL_JOB_ID)
-        if not job_id_str:
-            logger.error(f"Job {k8s_job.metadata.name} has no job ID label")
-            continue
 
-        job_id = UUID(job_id_str)
-        job_name = k8s_job.metadata.name
+def _get_job_info(pod: client.V1Pod) -> tuple[UUID, str] | None:
+    if not pod.metadata or not pod.metadata.labels:
+        return None
+    job_id_str = pod.metadata.labels.get(LABEL_JOB_ID)
+    if not job_id_str:
+        return None
+    return UUID(job_id_str), pod.metadata.name or "unknown"
 
-        logger.debug(f"Event {event_type} for job {job_name} (id={job_id})")
 
-        if event_type in ("ADDED", "MODIFIED"):
-            handle_job_state(stats_client, job_id, k8s_job)
-        elif event_type == "DELETED":
-            logger.info(f"Job {job_name} deleted")
+def _handle_pod_state(stats_client: StatsClient, pod: client.V1Pod):
+    info = _get_job_info(pod)
+    if not info or not pod.status:
+        return
 
+    job_id, pod_name = info
+    phase = pod.status.phase
+
+    if phase == "Succeeded":
+        _update_job_status(stats_client, job_id, JobStatus.completed)
+        _delete_k8s_job_for_pod(pod)
+        logger.info(f"Job {job_id} completed (pod {pod_name})")
+    elif phase == "Failed":
+        error = _get_pod_error(pod)
+        _update_job_status(stats_client, job_id, JobStatus.failed, error=error)
+        _delete_k8s_job_for_pod(pod)
+        logger.info(f"Job {job_id} failed (pod {pod_name}): {error}")
+    elif phase == "Running" and _is_container_running(pod):
+        _update_job_status(stats_client, job_id, JobStatus.running, worker=pod_name)
+        logger.debug(f"Job {job_id} running (pod {pod_name})")
+
+
+def _handle_pod_deleted(stats_client: StatsClient, pod: client.V1Pod):
+    info = _get_job_info(pod)
+    if not info:
+        return
 
-def handle_job_state(
-    stats_client: StatsClient,
-    job_id: UUID,
-    k8s_job: client.V1Job,
-):
-    if not (k8s_job.metadata and k8s_job.spec and k8s_job.status and k8s_job.metadata.name):
-        logger.error(f"Invalid k8s job: {k8s_job}")
+    phase = pod.status.phase if pod.status else None
+    if phase in ("Succeeded", "Failed"):
        return
 
-    status = k8s_job.status
-    job_name = k8s_job.metadata.name
-    backoff_limit = k8s_job.spec.backoff_limit or 0
+    job_id, pod_name = info
+    _update_job_status(stats_client, job_id, JobStatus.failed, error="Pod deleted unexpectedly")
+    logger.warning(f"Job {job_id} failed: pod {pod_name} deleted unexpectedly (phase={phase})")
 
-    if status.succeeded and status.succeeded > 0:
-        update_job_status(stats_client, job_id, JobStatus.completed)
-        delete_k8s_job(job_name)
-        logger.info(f"Job {job_id} completed")
 
-    elif status.failed and status.failed >= backoff_limit:
-        update_job_status(stats_client, job_id, JobStatus.failed, error="k8s job failed")
-        delete_k8s_job(job_name)
-        logger.info(f"Job {job_id} failed")
+def _is_container_running(pod: client.V1Pod) -> bool:
+    if not pod.status or not pod.status.container_statuses:
+        return False
+    return any(cs.state and cs.state.running for cs in pod.status.container_statuses)
 
-    elif status.active and status.active > 0:
-        update_job_status(stats_client, job_id, JobStatus.running, worker=job_name)
-        logger.debug(f"Job {job_id} running")
 
+def _get_pod_error(pod: client.V1Pod) -> str:
+    if pod.status and pod.status.container_statuses:
+        for cs in pod.status.container_statuses:
+            if cs.state and cs.state.terminated and cs.state.terminated.reason:
+                return cs.state.terminated.reason
+    return (pod.status.message if pod.status else None) or "Pod failed"
 
-def update_job_status(
+
+def _get_job_name_for_pod(pod: client.V1Pod) -> str | None:
+    if not pod.metadata or not pod.metadata.owner_references:
+        return None
+    return next((ref.name for ref in pod.metadata.owner_references if ref.kind == "Job"), None)
+
+
+def _delete_k8s_job_for_pod(pod: client.V1Pod):
+    job_name = _get_job_name_for_pod(pod)
+    if not job_name:
+        return
+    try:
+        _, batch_v1 = _get_k8s_clients()
+        batch_v1.delete_namespaced_job(name=job_name, namespace=JOB_NAMESPACE, propagation_policy="Background")
+    except Exception as e:
+        logger.error(f"Failed to delete k8s job {job_name}: {e}")
+
+
+def _update_job_status(
    stats_client: StatsClient,
    job_id: UUID,
    status: JobStatus,
@@ -189,27 +230,13 @@ def update_job_status(
 ):
    try:
        current = stats_client.get_job(job_id)
-        if current.status == status:
-            return
-        if current.status in (JobStatus.completed, JobStatus.failed):
+        if current.status == status or current.status in (JobStatus.completed, JobStatus.failed):
            return
-
        stats_client.update_job(job_id, JobRequestUpdate(status=status, error=error, worker=worker))
    except Exception as e:
        logger.error(f"Failed to update job {job_id} status to {status}: {e}")
 
 
-def delete_k8s_job(job_name: str):
-    try:
-        get_k8s_client().delete_namespaced_job(
-            name=job_name,
-            namespace=JOB_NAMESPACE,
-            propagation_policy="Background",
-        )
-    except Exception as e:
-        logger.error(f"Failed to delete k8s job {job_name}: {e}")
-
-
 if __name__ == "__main__":
    init_logging()
    suppress_noisy_logs()
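
For reference, the list-then-watch pattern that _watch_pods follows, shown standalone (a minimal sketch assuming the official kubernetes Python client; the namespace and label selector below are placeholders, not values from this repository): list once to process anything that changed before the watcher started and to capture a resourceVersion, then stream events from that version with a timeout so the loop wakes up periodically.

# Minimal standalone sketch of the list-then-watch pattern; NAMESPACE and
# SELECTOR are placeholders.
from kubernetes import client, config, watch

config.load_incluster_config()  # use config.load_kube_config() outside a cluster
core_v1 = client.CoreV1Api()

NAMESPACE = "default"
SELECTOR = "app=example"

# 1) List once: handle pods that changed before the watch began and capture
#    the resourceVersion to start watching from.
pod_list = core_v1.list_namespaced_pod(namespace=NAMESPACE, label_selector=SELECTOR)
resource_version = pod_list.metadata.resource_version

# 2) Watch from that resourceVersion; timeout_seconds makes the stream end
#    periodically so the caller can refresh heartbeats and re-list.
for event in watch.Watch().stream(
    core_v1.list_namespaced_pod,
    namespace=NAMESPACE,
    label_selector=SELECTOR,
    resource_version=resource_version,
    timeout_seconds=300,
):
    print(event["type"], event["object"].metadata.name)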
