diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index 776db5354..95847e50f 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -336,13 +336,43 @@ def get_job_logs( """Get the TrainJob logs""" # Get the TrainJob Pod name. pod_name = None - for c in self.get_job(name).steps: - if c.status != constants.POD_PENDING and c.name == step: + job = self.get_job(name) + + # First search if pod already exists + for c in job.steps: + if c.name == step and c.pod_name and c.status != constants.POD_PENDING: pod_name = c.pod_name break - if pod_name is None: + + # If follow=False → old behaviour + if pod_name is None and not follow: return + # If follow=True → wait for pod to be created & running + if pod_name is None and follow: + import time + + timeout = 120 # seconds + interval = 2 # seconds + waited = 0 + + while waited < timeout: + job = self.get_job(name) + for c in job.steps: + if c.name == step and c.pod_name and c.status != constants.POD_PENDING: + pod_name = c.pod_name + break + + if pod_name: + break + + time.sleep(interval) + waited += interval + + # Timeout → no pod found + if pod_name is None: + return + # Remove the number for the node step. container_name = re.sub(r"-\d+$", "", step) yield from self._read_pod_logs(