diff --git a/helm-chart/kuberay-operator/crds/ray.io_rayjobs.yaml b/helm-chart/kuberay-operator/crds/ray.io_rayjobs.yaml index 2e65eead13d..24e732ef58f 100644 --- a/helm-chart/kuberay-operator/crds/ray.io_rayjobs.yaml +++ b/helm-chart/kuberay-operator/crds/ray.io_rayjobs.yaml @@ -12060,6 +12060,10 @@ spec: description: ShutdownAfterJobFinishes will determine whether to delete the ray cluster once rayJob succeed or fai type: boolean + suspend: + description: suspend specifies whether the RayJob controller should + create a RayCluster instance If a job is appl + type: boolean ttlSecondsAfterFinished: description: TTLSecondsAfterFinished is the TTL to clean up RayCluster. format: int32 diff --git a/ray-operator/apis/ray/v1alpha1/rayjob_types.go b/ray-operator/apis/ray/v1alpha1/rayjob_types.go index fe6c9cef015..2d94f4327d3 100644 --- a/ray-operator/apis/ray/v1alpha1/rayjob_types.go +++ b/ray-operator/apis/ray/v1alpha1/rayjob_types.go @@ -39,6 +39,7 @@ const ( JobDeploymentStatusRunning JobDeploymentStatus = "Running" JobDeploymentStatusFailedToGetJobStatus JobDeploymentStatus = "FailedToGetJobStatus" JobDeploymentStatusComplete JobDeploymentStatus = "Complete" + JobDeploymentStatusSuspended JobDeploymentStatus = "Suspended" ) // RayJobSpec defines the desired state of RayJob @@ -61,6 +62,12 @@ type RayJobSpec struct { RayClusterSpec *RayClusterSpec `json:"rayClusterSpec,omitempty"` // clusterSelector is used to select running rayclusters by labels ClusterSelector map[string]string `json:"clusterSelector,omitempty"` + // suspend specifies whether the RayJob controller should create a RayCluster instance + // If a job is applied with the suspend field set to true, + // the RayCluster will not be created and will wait for the transition to false. + // If the RayCluster is already created, it will be deleted. + // In case of transition to false a new RayCluster will be created. + Suspend bool `json:"suspend,omitempty"` } // RayJobStatus defines the observed state of RayJob diff --git a/ray-operator/config/crd/bases/ray.io_rayjobs.yaml b/ray-operator/config/crd/bases/ray.io_rayjobs.yaml index 2e65eead13d..24e732ef58f 100644 --- a/ray-operator/config/crd/bases/ray.io_rayjobs.yaml +++ b/ray-operator/config/crd/bases/ray.io_rayjobs.yaml @@ -12060,6 +12060,10 @@ spec: description: ShutdownAfterJobFinishes will determine whether to delete the ray cluster once rayJob succeed or fai type: boolean + suspend: + description: suspend specifies whether the RayJob controller should + create a RayCluster instance If a job is appl + type: boolean ttlSecondsAfterFinished: description: TTLSecondsAfterFinished is the TTL to clean up RayCluster. format: int32 diff --git a/ray-operator/config/samples/ray_v1alpha1_rayjob.yaml b/ray-operator/config/samples/ray_v1alpha1_rayjob.yaml index 0e940a2d294..789b00c506e 100644 --- a/ray-operator/config/samples/ray_v1alpha1_rayjob.yaml +++ b/ray-operator/config/samples/ray_v1alpha1_rayjob.yaml @@ -3,6 +3,7 @@ kind: RayJob metadata: name: rayjob-sample spec: + suspend: false entrypoint: python /home/ray/samples/sample_code.py # runtimeEnv decoded to '{ # "pip": [ diff --git a/ray-operator/controllers/ray/rayjob_controller.go b/ray-operator/controllers/ray/rayjob_controller.go index a6f1cbc8630..510f68e70ab 100644 --- a/ray-operator/controllers/ray/rayjob_controller.go +++ b/ray-operator/controllers/ray/rayjob_controller.go @@ -15,6 +15,7 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/reconcile" "github.com/ray-project/kuberay/ray-operator/controllers/ray/common" "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" @@ -92,7 +93,7 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request) if isJobPendingOrRunning(rayJobInstance.Status.JobStatus) { rayDashboardClient := utils.GetRayDashboardClientFunc() rayDashboardClient.InitClient(rayJobInstance.Status.DashboardURL) - err := rayDashboardClient.StopJob(rayJobInstance.Status.JobId, &r.Log) + err := rayDashboardClient.StopJob(ctx, rayJobInstance.Status.JobId, &r.Log) if err != nil { r.Log.Info("Failed to stop job", "error", err) } @@ -150,6 +151,20 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request) err = r.updateState(ctx, rayJobInstance, nil, rayJobInstance.Status.JobStatus, rayv1alpha1.JobDeploymentStatusFailedToGetOrCreateRayCluster, err) return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err } + // If there is no cluster instance and no error suspend the job deployment + if rayClusterInstance == nil { + // Already suspended? + if rayJobInstance.Status.JobDeploymentStatus == rayv1alpha1.JobDeploymentStatusSuspended { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } + err = r.updateState(ctx, rayJobInstance, nil, rayJobInstance.Status.JobStatus, rayv1alpha1.JobDeploymentStatusSuspended, err) + if err != nil { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } + r.Log.Info("rayJob suspended", "RayJob", rayJobInstance.Name) + r.Recorder.Eventf(rayJobInstance, corev1.EventTypeNormal, "Suspended", "Suspended RayJob %s", rayJobInstance.Name) + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } // Always update RayClusterStatus along with jobStatus and jobDeploymentStatus updates. rayJobInstance.Status.RayClusterStatus = rayClusterInstance.Status @@ -178,7 +193,7 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request) } // Check the current status of ray jobs before submitting. - jobInfo, err := rayDashboardClient.GetJobInfo(rayJobInstance.Status.JobId) + jobInfo, err := rayDashboardClient.GetJobInfo(ctx, rayJobInstance.Status.JobId) if err != nil { err = r.updateState(ctx, rayJobInstance, jobInfo, rayJobInstance.Status.JobStatus, rayv1alpha1.JobDeploymentStatusFailedToGetJobStatus, err) // Dashboard service in head pod takes time to start, it's possible we get connection refused error. @@ -189,7 +204,7 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request) r.Log.V(1).Info("RayJob information", "RayJob", rayJobInstance.Name, "jobInfo", jobInfo, "rayJobInstance", rayJobInstance.Status.JobStatus) if jobInfo == nil { // Submit the job if no id set - jobId, err := rayDashboardClient.SubmitJob(rayJobInstance, &r.Log) + jobId, err := rayDashboardClient.SubmitJob(ctx, rayJobInstance, &r.Log) if err != nil { r.Log.Error(err, "failed to submit job") err = r.updateState(ctx, rayJobInstance, jobInfo, rayJobInstance.Status.JobStatus, rayv1alpha1.JobDeploymentStatusFailedJobDeploy, err) @@ -213,9 +228,48 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request) return ctrl.Result{}, err } - // Job may takes long time to start and finish, let's just periodically requeue the job and check status. - if isJobPendingOrRunning(jobInfo.JobStatus) && rayJobInstance.Status.JobDeploymentStatus == rayv1alpha1.JobDeploymentStatusRunning { - return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, nil + if rayJobInstance.Status.JobDeploymentStatus == rayv1alpha1.JobDeploymentStatusRunning { + // If suspend flag is set AND + // the RayJob is submitted against the RayCluster created by THIS job, then + // try to gracefully stop the Ray job and delete (suspend) the cluster + if rayJobInstance.Spec.Suspend && len(rayJobInstance.Spec.ClusterSelector) == 0 { + info, err := rayDashboardClient.GetJobInfo(ctx, rayJobInstance.Status.JobId) + if err != nil { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } + if !rayv1alpha1.IsJobTerminal(info.JobStatus) { + err := rayDashboardClient.StopJob(ctx, rayJobInstance.Status.JobId, &r.Log) + if err != nil { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } + } + if info.JobStatus != rayv1alpha1.JobStatusStopped { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, nil + } + + _, err = r.deleteCluster(ctx, rayJobInstance) + if err != nil && !errors.IsNotFound(err) { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, nil + } + // Since RayCluster instance is gone, remove it status also + // on RayJob resource + rayJobInstance.Status.RayClusterStatus = rayv1alpha1.RayClusterStatus{} + rayJobInstance.Status.RayClusterName = "" + rayJobInstance.Status.DashboardURL = "" + rayJobInstance.Status.JobId = "" + rayJobInstance.Status.Message = "" + err = r.updateState(ctx, rayJobInstance, jobInfo, rayv1alpha1.JobStatusStopped, rayv1alpha1.JobDeploymentStatusSuspended, nil) + if err != nil { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } + r.Log.Info("rayJob suspended", "RayJob", rayJobInstance.Name) + r.Recorder.Eventf(rayJobInstance, corev1.EventTypeNormal, "Suspended", "Suspended RayJob %s", rayJobInstance.Name) + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, nil + // Job may takes long time to start and finish, let's just periodically requeue the job and check status. + } + if isJobPendingOrRunning(jobInfo.JobStatus) { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, nil + } } // Let's use rayJobInstance.Status.JobStatus to make sure we only delete cluster after the CR is updated. @@ -231,34 +285,38 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request) return ctrl.Result{RequeueAfter: time.Duration(delta) * time.Second}, nil } } - r.Log.Info("shutdownAfterJobFinishes set to true, we will delete cluster", "RayJob", rayJobInstance.Name, "clusterName", fmt.Sprintf("%s/%s", rayJobInstance.Namespace, rayJobInstance.Status.RayClusterName)) - clusterIdentifier := types.NamespacedName{ - Name: rayJobInstance.Status.RayClusterName, - Namespace: rayJobInstance.Namespace, - } - cluster := rayv1alpha1.RayCluster{} - if err := r.Get(ctx, clusterIdentifier, &cluster); err != nil { - if !errors.IsNotFound(err) { - return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err - } - r.Log.Info("The associated cluster has been already deleted and it can not be found", "RayCluster", clusterIdentifier) - } else { - if cluster.DeletionTimestamp != nil { - r.Log.Info("The cluster deletion is ongoing.", "rayjob", rayJobInstance.Name, "raycluster", cluster.Name) - } else { - if err := r.Delete(ctx, &cluster); err != nil { - return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err - } - r.Log.Info("The associated cluster is deleted", "RayCluster", clusterIdentifier) - r.Recorder.Eventf(rayJobInstance, corev1.EventTypeNormal, "Deleted", "Deleted cluster %s", rayJobInstance.Status.RayClusterName) - return ctrl.Result{Requeue: true}, nil - } - } + return r.deleteCluster(ctx, rayJobInstance) } } + return ctrl.Result{}, nil +} +func (r *RayJobReconciler) deleteCluster(ctx context.Context, rayJobInstance *rayv1alpha1.RayJob) (reconcile.Result, error) { + clusterIdentifier := types.NamespacedName{ + Name: rayJobInstance.Status.RayClusterName, + Namespace: rayJobInstance.Namespace, + } + cluster := rayv1alpha1.RayCluster{} + if err := r.Get(ctx, clusterIdentifier, &cluster); err != nil { + if !errors.IsNotFound(err) { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } + r.Log.Info("The associated cluster has been already deleted and it can not be found", "RayCluster", clusterIdentifier) + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } else { + if cluster.DeletionTimestamp != nil { + r.Log.Info("The cluster deletion is ongoing.", "rayjob", rayJobInstance.Name, "raycluster", cluster.Name) + } else { + if err := r.Delete(ctx, &cluster); err != nil { + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err + } + r.Log.Info("The associated cluster is deleted", "RayCluster", clusterIdentifier) + r.Recorder.Eventf(rayJobInstance, corev1.EventTypeNormal, "Deleted", "Deleted cluster %s", rayJobInstance.Status.RayClusterName) + return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, nil + } + } return ctrl.Result{}, nil } @@ -343,7 +401,11 @@ func (r *RayJobReconciler) updateState(ctx context.Context, rayJob *rayv1alpha1. if jobInfo != nil { rayJob.Status.Message = jobInfo.Message rayJob.Status.StartTime = utils.ConvertUnixTimeToMetav1Time(jobInfo.StartTime) - rayJob.Status.EndTime = utils.ConvertUnixTimeToMetav1Time(jobInfo.EndTime) + if jobInfo.StartTime >= jobInfo.EndTime { + rayJob.Status.EndTime = nil + } else { + rayJob.Status.EndTime = utils.ConvertUnixTimeToMetav1Time(jobInfo.EndTime) + } } // TODO (kevin85421): ObservedGeneration should be used to determine whether update this CR or not. @@ -391,11 +453,15 @@ func (r *RayJobReconciler) getOrCreateRayClusterInstance(ctx context.Context, ra return nil, err } - // one special case is the job is complete status and cluster has been recycled. + // special case: is the job is complete status and cluster has been recycled. if isJobSucceedOrFailed(rayJobInstance.Status.JobStatus) && rayJobInstance.Status.JobDeploymentStatus == rayv1alpha1.JobDeploymentStatusComplete { r.Log.Info("The cluster has been recycled for the job, skip duplicate creation", "rayjob", rayJobInstance.Name) return nil, err } + // special case: don't create a cluster instance and don't return an error if the suspend flag of the job is true + if rayJobInstance.Spec.Suspend { + return nil, nil + } r.Log.Info("RayCluster not found, creating rayCluster!", "raycluster", rayClusterNamespacedName) rayClusterInstance, err = r.constructRayClusterForRayJob(rayJobInstance, rayClusterInstanceName) diff --git a/ray-operator/controllers/ray/rayjob_controller_suspended_test.go b/ray-operator/controllers/ray/rayjob_controller_suspended_test.go new file mode 100644 index 00000000000..689fc4b33a7 --- /dev/null +++ b/ray-operator/controllers/ray/rayjob_controller_suspended_test.go @@ -0,0 +1,312 @@ +/* + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ray + +import ( + "context" + "fmt" + "time" + + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/ray-project/kuberay/ray-operator/controllers/ray/common" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + rayiov1alpha1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1" + + corev1 "k8s.io/api/core/v1" + "k8s.io/utils/pointer" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +var _ = Context("Inside the default namespace", func() { + ctx := context.TODO() + var workerPods corev1.PodList + var headPods corev1.PodList + mySuspendedRayCluster := &rayiov1alpha1.RayCluster{} + + mySuspendedRayJob := &rayiov1alpha1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Name: "rayjob-test-suspend", + Namespace: "default", + }, + Spec: rayiov1alpha1.RayJobSpec{ + Suspend: true, + Entrypoint: "sleep 999", + RayClusterSpec: &rayiov1alpha1.RayClusterSpec{ + RayVersion: "2.4.0", + HeadGroupSpec: rayiov1alpha1.HeadGroupSpec{ + ServiceType: corev1.ServiceTypeClusterIP, + Replicas: pointer.Int32(1), + RayStartParams: map[string]string{ + "port": "6379", + "object-store-memory": "100000000", + "dashboard-host": "0.0.0.0", + "num-cpus": "1", + "node-ip-address": "127.0.0.1", + "block": "true", + "dashboard-agent-listen-port": "52365", + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "rayCluster": "raycluster-sample", + "groupName": "headgroup", + }, + Annotations: map[string]string{ + "key": "value", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-head", + Image: "rayproject/ray:2.2.0", + Env: []corev1.EnvVar{ + { + Name: "MY_POD_IP", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "status.podIP", + }, + }, + }, + }, + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, + Ports: []corev1.ContainerPort{ + { + Name: "gcs-server", + ContainerPort: 6379, + }, + { + Name: "dashboard", + ContainerPort: 8265, + }, + { + Name: "head", + ContainerPort: 10001, + }, + { + Name: "dashboard-agent", + ContainerPort: 52365, + }, + }, + }, + }, + }, + }, + }, + WorkerGroupSpecs: []rayiov1alpha1.WorkerGroupSpec{ + { + Replicas: pointer.Int32(3), + MinReplicas: pointer.Int32(0), + MaxReplicas: pointer.Int32(10000), + GroupName: "small-group", + RayStartParams: map[string]string{ + "port": "6379", + "num-cpus": "1", + "dashboard-agent-listen-port": "52365", + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "default", + Labels: map[string]string{ + "rayCluster": "raycluster-sample", + "groupName": "small-group", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "ray-worker", + Image: "rayproject/ray:2.2.0", + Command: []string{"echo"}, + Args: []string{"Hello Ray"}, + Env: []corev1.EnvVar{ + { + Name: "MY_POD_IP", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "status.podIP", + }, + }, + }, + }, + Ports: []corev1.ContainerPort{ + { + Name: "client", + ContainerPort: 80, + }, + { + Name: "dashboard-agent", + ContainerPort: 52365, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + Describe("When creating a rayjob with suspend == true", func() { + It("should create a rayjob object", func() { + err := k8sClient.Create(ctx, mySuspendedRayJob) + Expect(err).NotTo(HaveOccurred(), "failed to create test RayJob resource") + }) + + It("should see a rayjob object", func() { + Eventually( + getResourceFunc(ctx, client.ObjectKey{Name: mySuspendedRayJob.Name, Namespace: "default"}, mySuspendedRayJob), + time.Second*3, time.Millisecond*500).Should(BeNil(), "My myRayJob = %v", mySuspendedRayJob.Name) + }) + + It("should have deployment status suspended", func() { + Eventually( + getRayJobDeploymentStatus(ctx, mySuspendedRayJob), + time.Second*5, time.Millisecond*500).Should(Equal(rayiov1alpha1.JobDeploymentStatusSuspended)) + }) + + It("should NOT create a raycluster object", func() { + // Ray Cluster name can be present on RayJob's CRD + Eventually( + getRayClusterNameForRayJob(ctx, mySuspendedRayJob), + time.Second*15, time.Millisecond*500).Should(Not(BeEmpty())) + // However the actual cluster instance and underlying resources should not be created while suspend == true + Eventually( + // k8sClient client throws error if resource not found + func() bool { + err := getResourceFunc(ctx, client.ObjectKey{Name: mySuspendedRayJob.Status.RayClusterName, Namespace: "default"}, mySuspendedRayCluster)() + return errors.IsNotFound(err) + }, + time.Second*10, time.Millisecond*500).Should(BeTrue()) + }) + + It("should unsuspend a rayjob object", func() { + mySuspendedRayJob.Spec.Suspend = false + err := k8sClient.Update(ctx, mySuspendedRayJob) + Expect(err).NotTo(HaveOccurred(), "failed to update test RayJob resource") + }) + + It("should create a raycluster object", func() { + // Ray Cluster name can be present on RayJob's CRD + Eventually( + getRayClusterNameForRayJob(ctx, mySuspendedRayJob), + time.Second*15, time.Millisecond*500).Should(Not(BeEmpty())) + // The actual cluster instance and underlying resources SHOULD be created when suspend == false + Eventually( + // k8sClient client does not throw error if cluster IS found + getResourceFunc(ctx, client.ObjectKey{Name: mySuspendedRayJob.Status.RayClusterName, Namespace: "default"}, mySuspendedRayCluster), + time.Second*3, time.Millisecond*500).Should(BeNil()) + }) + + It("should create 3 workers", func() { + Eventually( + listResourceFunc(ctx, &workerPods, client.MatchingLabels{ + common.RayClusterLabelKey: mySuspendedRayCluster.Name, + common.RayNodeGroupLabelKey: "small-group", + }, + &client.ListOptions{Namespace: "default"}), + time.Second*15, time.Millisecond*500).Should(Equal(3), fmt.Sprintf("workerGroup %v", workerPods.Items)) + if len(workerPods.Items) > 0 { + Expect(workerPods.Items[0].Status.Phase).Should(Or(Equal(corev1.PodRunning), Equal(corev1.PodPending))) + } + }) + + It("should create a head pod resource", func() { + err := k8sClient.List(ctx, &headPods, + client.MatchingLabels{ + common.RayClusterLabelKey: mySuspendedRayCluster.Name, + common.RayNodeGroupLabelKey: "headgroup", + }, + &client.ListOptions{Namespace: "default"}, + client.InNamespace(mySuspendedRayCluster.Namespace)) + + Expect(err).NotTo(HaveOccurred(), "failed list head pods") + Expect(len(headPods.Items)).Should(BeNumerically("==", 1), "My head pod list= %v", headPods.Items) + + pod := &corev1.Pod{} + if len(headPods.Items) > 0 { + pod = &headPods.Items[0] + } + Eventually( + getResourceFunc(ctx, client.ObjectKey{Name: pod.Name, Namespace: "default"}, pod), + time.Second*3, time.Millisecond*500).Should(BeNil(), "My head pod = %v", pod) + Expect(pod.Status.Phase).Should(Or(Equal(corev1.PodPending))) + }) + + It("should be able to update all Pods to Running", func() { + // We need to manually update Pod statuses otherwise they'll always be Pending. + // envtest doesn't create a full K8s cluster. It's only the control plane. + // There's no container runtime or any other K8s controllers. + // So Pods are created, but no controller updates them from Pending to Running. + // See https://book.kubebuilder.io/reference/envtest.html + + for _, headPod := range headPods.Items { + headPod.Status.Phase = corev1.PodRunning + Expect(k8sClient.Status().Update(ctx, &headPod)).Should(BeNil()) + } + + Eventually( + isAllPodsRunning(ctx, headPods, client.MatchingLabels{ + common.RayClusterLabelKey: mySuspendedRayCluster.Name, + common.RayNodeGroupLabelKey: "headgroup", + }, "default"), + time.Second*15, time.Millisecond*500).Should(Equal(true), "Head Pod should be running.") + + for _, workerPod := range workerPods.Items { + workerPod.Status.Phase = corev1.PodRunning + Expect(k8sClient.Status().Update(ctx, &workerPod)).Should(BeNil()) + } + + Eventually( + isAllPodsRunning(ctx, workerPods, client.MatchingLabels{common.RayClusterLabelKey: mySuspendedRayCluster.Name, common.RayNodeGroupLabelKey: "small-group"}, "default"), + time.Second*15, time.Millisecond*500).Should(Equal(true), "All worker Pods should be running.") + }) + + It("Dashboard URL should be set", func() { + Eventually( + getDashboardURLForRayJob(ctx, mySuspendedRayJob), + time.Second*3, time.Millisecond*500).Should(HavePrefix(mySuspendedRayJob.Name), "Dashboard URL = %v", mySuspendedRayJob.Status.DashboardURL) + }) + }) +}) + +func getRayJobDeploymentStatus(ctx context.Context, rayJob *rayiov1alpha1.RayJob) func() (rayiov1alpha1.JobDeploymentStatus, error) { + return func() (rayiov1alpha1.JobDeploymentStatus, error) { + if err := k8sClient.Get(ctx, client.ObjectKey{Name: rayJob.Name, Namespace: "default"}, rayJob); err != nil { + return "", err + } + return rayJob.Status.JobDeploymentStatus, nil + } +} diff --git a/ray-operator/controllers/ray/rayservice_controller.go b/ray-operator/controllers/ray/rayservice_controller.go index 9ab5bf6a488..b4149825ee2 100644 --- a/ray-operator/controllers/ray/rayservice_controller.go +++ b/ray-operator/controllers/ray/rayservice_controller.go @@ -620,7 +620,7 @@ func (r *RayServiceReconciler) checkIfNeedSubmitServeDeployment(rayServiceInstan return shouldUpdate } -func (r *RayServiceReconciler) updateServeDeployment(rayServiceInstance *rayv1alpha1.RayService, rayDashboardClient utils.RayDashboardClientInterface, clusterName string) error { +func (r *RayServiceReconciler) updateServeDeployment(ctx context.Context, rayServiceInstance *rayv1alpha1.RayService, rayDashboardClient utils.RayDashboardClientInterface, clusterName string) error { r.Log.V(1).Info("updateServeDeployment", "config", rayServiceInstance.Spec.ServeDeploymentGraphSpec) runtimeEnv := make(map[string]interface{}) _ = yaml.Unmarshal([]byte(rayServiceInstance.Spec.ServeDeploymentGraphSpec.RuntimeEnv), &runtimeEnv) @@ -632,7 +632,7 @@ func (r *RayServiceReconciler) updateServeDeployment(rayServiceInstance *rayv1al deploymentJson, _ := json.Marshal(servingClusterDeployments) r.Log.V(1).Info("updateServeDeployment", "json config", string(deploymentJson)) - if err := rayDashboardClient.UpdateDeployments(rayServiceInstance.Spec.ServeDeploymentGraphSpec); err != nil { + if err := rayDashboardClient.UpdateDeployments(ctx, rayServiceInstance.Spec.ServeDeploymentGraphSpec); err != nil { r.Log.Error(err, "fail to update deployment") return err } @@ -648,7 +648,7 @@ func (r *RayServiceReconciler) updateServeDeployment(rayServiceInstance *rayv1al // updates health timestamps, and checks if the RayCluster is overall healthy. // It's return values should be interpreted as // (Serve app healthy?, Serve app ready?, error if any) -func (r *RayServiceReconciler) getAndCheckServeStatus(dashboardClient utils.RayDashboardClientInterface, rayServiceServeStatus *rayv1alpha1.RayServiceStatus, unhealthySecondThreshold *int32) (bool, bool, error) { +func (r *RayServiceReconciler) getAndCheckServeStatus(ctx context.Context, dashboardClient utils.RayDashboardClientInterface, rayServiceServeStatus *rayv1alpha1.RayServiceStatus, unhealthySecondThreshold *int32) (bool, bool, error) { serviceUnhealthySecondThreshold := ServiceUnhealthySecondThreshold if unhealthySecondThreshold != nil { serviceUnhealthySecondThreshold = float64(*unhealthySecondThreshold) @@ -656,7 +656,7 @@ func (r *RayServiceReconciler) getAndCheckServeStatus(dashboardClient utils.RayD var serveStatuses *utils.ServeDeploymentStatuses var err error - if serveStatuses, err = dashboardClient.GetDeploymentsStatus(); err != nil { + if serveStatuses, err = dashboardClient.GetDeploymentsStatus(ctx); err != nil { r.Log.Error(err, "Failed to get Serve deployment statuses from dashboard!") return false, false, err } @@ -866,7 +866,7 @@ func (r *RayServiceReconciler) updateStatusForActiveCluster(ctx context.Context, rayDashboardClient.InitClient(clientURL) var isHealthy, isReady bool - if isHealthy, isReady, err = r.getAndCheckServeStatus(rayDashboardClient, rayServiceStatus, rayServiceInstance.Spec.ServiceUnhealthySecondThreshold); err != nil { + if isHealthy, isReady, err = r.getAndCheckServeStatus(ctx, rayDashboardClient, rayServiceStatus, rayServiceInstance.Spec.ServiceUnhealthySecondThreshold); err != nil { r.updateAndCheckDashboardStatus(rayServiceStatus, false, rayServiceInstance.Spec.DeploymentUnhealthySecondThreshold) return err } @@ -927,7 +927,7 @@ func (r *RayServiceReconciler) reconcileServe(ctx context.Context, rayServiceIns shouldUpdate := r.checkIfNeedSubmitServeDeployment(rayServiceInstance, rayClusterInstance, rayServiceStatus) if shouldUpdate { - if err = r.updateServeDeployment(rayServiceInstance, rayDashboardClient, rayClusterInstance.Name); err != nil { + if err = r.updateServeDeployment(ctx, rayServiceInstance, rayDashboardClient, rayClusterInstance.Name); err != nil { if !r.updateAndCheckDashboardStatus(rayServiceStatus, false, rayServiceInstance.Spec.DeploymentUnhealthySecondThreshold) { logger.Info("Dashboard is unhealthy, restart the cluster.") r.markRestart(rayServiceInstance) @@ -941,7 +941,7 @@ func (r *RayServiceReconciler) reconcileServe(ctx context.Context, rayServiceIns } var isHealthy, isReady bool - if isHealthy, isReady, err = r.getAndCheckServeStatus(rayDashboardClient, rayServiceStatus, rayServiceInstance.Spec.DeploymentUnhealthySecondThreshold); err != nil { + if isHealthy, isReady, err = r.getAndCheckServeStatus(ctx, rayDashboardClient, rayServiceStatus, rayServiceInstance.Spec.DeploymentUnhealthySecondThreshold); err != nil { if !r.updateAndCheckDashboardStatus(rayServiceStatus, false, rayServiceInstance.Spec.DeploymentUnhealthySecondThreshold) { logger.Info("Dashboard is unhealthy, restart the cluster.") r.markRestart(rayServiceInstance) diff --git a/ray-operator/controllers/ray/utils/dashboard_httpclient.go b/ray-operator/controllers/ray/utils/dashboard_httpclient.go index f54f363823d..415b6f42ccd 100644 --- a/ray-operator/controllers/ray/utils/dashboard_httpclient.go +++ b/ray-operator/controllers/ray/utils/dashboard_httpclient.go @@ -77,13 +77,13 @@ type ServingClusterDeployments struct { type RayDashboardClientInterface interface { InitClient(url string) - GetDeployments() (string, error) - UpdateDeployments(specs rayv1alpha1.ServeDeploymentGraphSpec) error - GetDeploymentsStatus() (*ServeDeploymentStatuses, error) + GetDeployments(context.Context) (string, error) + UpdateDeployments(ctx context.Context, spec rayv1alpha1.ServeDeploymentGraphSpec) error + GetDeploymentsStatus(context.Context) (*ServeDeploymentStatuses, error) ConvertServeConfig(specs []rayv1alpha1.ServeConfigSpec) []ServeConfigSpec - GetJobInfo(jobId string) (*RayJobInfo, error) - SubmitJob(rayJob *rayv1alpha1.RayJob, log *logr.Logger) (jobId string, err error) - StopJob(jobName string, log *logr.Logger) (err error) + GetJobInfo(ctx context.Context, jobId string) (*RayJobInfo, error) + SubmitJob(ctx context.Context, rayJob *rayv1alpha1.RayJob, log *logr.Logger) (jobId string, err error) + StopJob(ctx context.Context, jobName string, log *logr.Logger) (err error) } // GetRayDashboardClientFunc Used for unit tests. @@ -172,8 +172,8 @@ func (r *RayDashboardClient) InitClient(url string) { } // GetDeployments get the current deployments in the Ray cluster. -func (r *RayDashboardClient) GetDeployments() (string, error) { - req, err := http.NewRequest("GET", r.dashboardURL+DeployPath, nil) +func (r *RayDashboardClient) GetDeployments(ctx context.Context) (string, error) { + req, err := http.NewRequestWithContext(ctx, "GET", r.dashboardURL+DeployPath, nil) if err != nil { return "", err } @@ -193,7 +193,7 @@ func (r *RayDashboardClient) GetDeployments() (string, error) { } // UpdateDeployments update the deployments in the Ray cluster. -func (r *RayDashboardClient) UpdateDeployments(spec rayv1alpha1.ServeDeploymentGraphSpec) error { +func (r *RayDashboardClient) UpdateDeployments(ctx context.Context, spec rayv1alpha1.ServeDeploymentGraphSpec) error { runtimeEnv := make(map[string]interface{}) _ = yaml.Unmarshal([]byte(spec.RuntimeEnv), &runtimeEnv) @@ -209,7 +209,7 @@ func (r *RayDashboardClient) UpdateDeployments(spec rayv1alpha1.ServeDeploymentG return err } - req, err := http.NewRequest(http.MethodPut, r.dashboardURL+DeployPath, bytes.NewBuffer(deploymentJson)) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, r.dashboardURL+DeployPath, bytes.NewBuffer(deploymentJson)) if err != nil { return err } @@ -230,8 +230,8 @@ func (r *RayDashboardClient) UpdateDeployments(spec rayv1alpha1.ServeDeploymentG } // GetDeploymentsStatus get the current deployment statuses in the Ray cluster. -func (r *RayDashboardClient) GetDeploymentsStatus() (*ServeDeploymentStatuses, error) { - req, err := http.NewRequest("GET", r.dashboardURL+StatusPath, nil) +func (r *RayDashboardClient) GetDeploymentsStatus(ctx context.Context) (*ServeDeploymentStatuses, error) { + req, err := http.NewRequestWithContext(ctx, "GET", r.dashboardURL+StatusPath, nil) if err != nil { return nil, err } @@ -327,8 +327,8 @@ type RayJobStopResponse struct { Stopped bool `json:"stopped"` } -func (r *RayDashboardClient) GetJobInfo(jobId string) (*RayJobInfo, error) { - req, err := http.NewRequest("GET", r.dashboardURL+JobPath+jobId, nil) +func (r *RayDashboardClient) GetJobInfo(ctx context.Context, jobId string) (*RayJobInfo, error) { + req, err := http.NewRequestWithContext(ctx, "GET", r.dashboardURL+JobPath+jobId, nil) if err != nil { return nil, err } @@ -357,7 +357,7 @@ func (r *RayDashboardClient) GetJobInfo(jobId string) (*RayJobInfo, error) { return &jobInfo, nil } -func (r *RayDashboardClient) SubmitJob(rayJob *rayv1alpha1.RayJob, log *logr.Logger) (jobId string, err error) { +func (r *RayDashboardClient) SubmitJob(ctx context.Context, rayJob *rayv1alpha1.RayJob, log *logr.Logger) (jobId string, err error) { request, err := ConvertRayJobToReq(rayJob) if err != nil { return "", err @@ -368,7 +368,7 @@ func (r *RayDashboardClient) SubmitJob(rayJob *rayv1alpha1.RayJob, log *logr.Log } log.Info("Submit a ray job", "rayJob", rayJob.Name, "jobInfo", string(rayJobJson)) - req, err := http.NewRequest(http.MethodPost, r.dashboardURL+JobPath, bytes.NewBuffer(rayJobJson)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, r.dashboardURL+JobPath, bytes.NewBuffer(rayJobJson)) if err != nil { return } @@ -391,10 +391,10 @@ func (r *RayDashboardClient) SubmitJob(rayJob *rayv1alpha1.RayJob, log *logr.Log return jobResp.JobId, nil } -func (r *RayDashboardClient) StopJob(jobName string, log *logr.Logger) (err error) { +func (r *RayDashboardClient) StopJob(ctx context.Context, jobName string, log *logr.Logger) (err error) { log.Info("Stop a ray job", "rayJob", jobName) - req, err := http.NewRequest(http.MethodPost, r.dashboardURL+JobPath+jobName+"/stop", nil) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, r.dashboardURL+JobPath+jobName+"/stop", nil) if err != nil { return err } @@ -414,7 +414,7 @@ func (r *RayDashboardClient) StopJob(jobName string, log *logr.Logger) (err erro } if !jobStopResp.Stopped { - jobInfo, err := r.GetJobInfo(jobName) + jobInfo, err := r.GetJobInfo(ctx, jobName) if err != nil { return err } diff --git a/ray-operator/controllers/ray/utils/dashboard_httpclient_test.go b/ray-operator/controllers/ray/utils/dashboard_httpclient_test.go index c1c0ec2b2b7..9ec14b67642 100644 --- a/ray-operator/controllers/ray/utils/dashboard_httpclient_test.go +++ b/ray-operator/controllers/ray/utils/dashboard_httpclient_test.go @@ -1,6 +1,7 @@ package utils import ( + "context" "encoding/base64" "encoding/json" "net/http" @@ -80,16 +81,16 @@ var _ = Describe("RayFrameworkGenerator", func() { return httpmock.NewStringResponse(200, "Ray misbehaved and sent string, not JSON"), nil }) - jobId, err := rayDashboardClient.SubmitJob(rayJob, &ctrl.Log) + jobId, err := rayDashboardClient.SubmitJob(context.TODO(), rayJob, &ctrl.Log) Expect(err).To(BeNil()) Expect(jobId).To(Equal(expectJobId)) - rayJobInfo, err := rayDashboardClient.GetJobInfo(jobId) + rayJobInfo, err := rayDashboardClient.GetJobInfo(context.TODO(), jobId) Expect(err).To(BeNil()) Expect(rayJobInfo.Entrypoint).To(Equal(rayJob.Spec.Entrypoint)) Expect(rayJobInfo.JobStatus).To(Equal(rayv1alpha1.JobStatusRunning)) - _, err = rayDashboardClient.GetJobInfo(errorJobId) + _, err = rayDashboardClient.GetJobInfo(context.TODO(), errorJobId) Expect(err).NotTo(BeNil()) Expect(err.Error()).To(ContainSubstring("GetJobInfo fail")) Expect(err.Error()).To(ContainSubstring("Ray misbehaved")) @@ -107,7 +108,7 @@ var _ = Describe("RayFrameworkGenerator", func() { return httpmock.NewBytesResponse(200, bodyBytes), nil }) - err := rayDashboardClient.StopJob("stop-job-1", &ctrl.Log) + err := rayDashboardClient.StopJob(context.TODO(), "stop-job-1", &ctrl.Log) Expect(err).To(BeNil()) }) @@ -134,7 +135,7 @@ var _ = Describe("RayFrameworkGenerator", func() { return httpmock.NewBytesResponse(200, bodyBytes), nil }) - err := rayDashboardClient.StopJob("stop-job-1", &ctrl.Log) + err := rayDashboardClient.StopJob(context.TODO(), "stop-job-1", &ctrl.Log) Expect(err).To(BeNil()) }) }) diff --git a/ray-operator/controllers/ray/utils/fake_serve_httpclient.go b/ray-operator/controllers/ray/utils/fake_serve_httpclient.go index 07161f03068..e131b0e1973 100644 --- a/ray-operator/controllers/ray/utils/fake_serve_httpclient.go +++ b/ray-operator/controllers/ray/utils/fake_serve_httpclient.go @@ -1,6 +1,7 @@ package utils import ( + "context" "fmt" "net/http" @@ -16,21 +17,23 @@ type FakeRayDashboardClient struct { serveStatuses ServeDeploymentStatuses } +var _ RayDashboardClientInterface = (*FakeRayDashboardClient)(nil) + func (r *FakeRayDashboardClient) InitClient(url string) { r.client = http.Client{} r.dashboardURL = "http://" + url } -func (r *FakeRayDashboardClient) GetDeployments() (string, error) { +func (r *FakeRayDashboardClient) GetDeployments(_ context.Context) (string, error) { panic("Fake GetDeployments not implemented") } -func (r *FakeRayDashboardClient) UpdateDeployments(specs rayv1alpha1.ServeDeploymentGraphSpec) error { +func (r *FakeRayDashboardClient) UpdateDeployments(_ context.Context, specs rayv1alpha1.ServeDeploymentGraphSpec) error { fmt.Print("UpdateDeployments fake succeeds.") return nil } -func (r *FakeRayDashboardClient) GetDeploymentsStatus() (*ServeDeploymentStatuses, error) { +func (r *FakeRayDashboardClient) GetDeploymentsStatus(_ context.Context) (*ServeDeploymentStatuses, error) { return &r.serveStatuses, nil } @@ -80,14 +83,14 @@ func (r *FakeRayDashboardClient) SetServeStatus(status ServeDeploymentStatuses) r.serveStatuses = status } -func (r *FakeRayDashboardClient) GetJobInfo(jobId string) (*RayJobInfo, error) { +func (r *FakeRayDashboardClient) GetJobInfo(_ context.Context, jobId string) (*RayJobInfo, error) { return nil, nil } -func (r *FakeRayDashboardClient) SubmitJob(rayJob *rayv1alpha1.RayJob, log *logr.Logger) (jobId string, err error) { +func (r *FakeRayDashboardClient) SubmitJob(_ context.Context, rayJob *rayv1alpha1.RayJob, log *logr.Logger) (jobId string, err error) { return "", nil } -func (r *FakeRayDashboardClient) StopJob(jobName string, log *logr.Logger) (err error) { +func (r *FakeRayDashboardClient) StopJob(_ context.Context, jobName string, log *logr.Logger) (err error) { return nil }