From 8c2c259f18af96420e51699d3c63d3f610a66d72 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 28 Mar 2025 17:18:42 +0900 Subject: [PATCH 01/23] Role support --- build.sbt | 2 + .../scala/com/devsisters/shardcake/Pod.scala | 9 +- .../com/devsisters/shardcake/PodAddress.scala | 5 + .../scala/com/devsisters/shardcake/Role.scala | 10 + .../shardcake/interfaces/Pods.scala | 10 +- .../shardcake/interfaces/Storage.scala | 20 +- .../com/devsisters/shardcake/GraphQLApi.scala | 7 +- .../devsisters/shardcake/ManagerConfig.scala | 8 +- .../devsisters/shardcake/ShardManager.scala | 423 +++++++++++------- .../shardcake/ShardManagerSpec.scala | 177 +++++--- .../devsisters/shardcake/RedisConfig.scala | 9 +- .../devsisters/shardcake/StorageRedis.scala | 24 +- .../shardcake/StorageRedisSpec.scala | 23 +- .../devsisters/shardcake/RedisConfig.scala | 9 +- .../devsisters/shardcake/StorageRedis.scala | 60 +-- .../shardcake/StorageRedisSpec.scala | 17 +- 16 files changed, 518 insertions(+), 295 deletions(-) create mode 100644 core/src/main/scala/com/devsisters/shardcake/Role.scala diff --git a/build.sbt b/build.sbt index 4fccc86..7edb632 100644 --- a/build.sbt +++ b/build.sbt @@ -9,6 +9,7 @@ val grpcNettyVersion = "1.71.0" val zioK8sVersion = "3.1.0" val zioCacheVersion = "0.2.4" val zioCatsInteropVersion = "23.1.0.5" +val zioJsonVersion = "0.7.39" val sttpVersion = "3.10.3" val calibanVersion = "2.10.0" val redis4catsVersion = "1.7.2" @@ -73,6 +74,7 @@ lazy val core = project Seq( "dev.zio" %% "zio" % zioVersion, "dev.zio" %% "zio-streams" % zioVersion, + "dev.zio" %% "zio-json" % zioJsonVersion, "org.scala-lang.modules" %% "scala-collection-compat" % scalaCompatVersion ) ) diff --git a/core/src/main/scala/com/devsisters/shardcake/Pod.scala b/core/src/main/scala/com/devsisters/shardcake/Pod.scala index 6239f95..1b36fcb 100644 --- a/core/src/main/scala/com/devsisters/shardcake/Pod.scala +++ b/core/src/main/scala/com/devsisters/shardcake/Pod.scala @@ -1,3 +1,10 @@ package com.devsisters.shardcake -case class Pod(address: PodAddress, version: String) +import zio.json._ + +case class Pod(address: PodAddress, version: String, roles: Set[Role]) + +object Pod { + implicit val encoder: JsonEncoder[Pod] = DeriveJsonEncoder.gen[Pod] + implicit val decoder: JsonDecoder[Pod] = DeriveJsonDecoder.gen[Pod] +} diff --git a/core/src/main/scala/com/devsisters/shardcake/PodAddress.scala b/core/src/main/scala/com/devsisters/shardcake/PodAddress.scala index 4f4c3d8..724cf34 100644 --- a/core/src/main/scala/com/devsisters/shardcake/PodAddress.scala +++ b/core/src/main/scala/com/devsisters/shardcake/PodAddress.scala @@ -2,6 +2,8 @@ package com.devsisters.shardcake import scala.collection.compat._ +import zio.json._ + case class PodAddress(host: String, port: Int) { override def toString: String = s"$host:$port" } @@ -12,4 +14,7 @@ object PodAddress { case host :: port :: Nil => port.toIntOption.map(port => PodAddress(host, port)) case _ => None } + + implicit val encoder: JsonEncoder[PodAddress] = DeriveJsonEncoder.gen[PodAddress] + implicit val decoder: JsonDecoder[PodAddress] = DeriveJsonDecoder.gen[PodAddress] } diff --git a/core/src/main/scala/com/devsisters/shardcake/Role.scala b/core/src/main/scala/com/devsisters/shardcake/Role.scala new file mode 100644 index 0000000..39c24f7 --- /dev/null +++ b/core/src/main/scala/com/devsisters/shardcake/Role.scala @@ -0,0 +1,10 @@ +package com.devsisters.shardcake + +import zio.json._ + +case class Role(name: String) + +object Role { + implicit val encoder: JsonEncoder[Role] = 
DeriveJsonEncoder.gen[Role] + implicit val decoder: JsonDecoder[Role] = DeriveJsonDecoder.gen[Role] +} diff --git a/core/src/main/scala/com/devsisters/shardcake/interfaces/Pods.scala b/core/src/main/scala/com/devsisters/shardcake/interfaces/Pods.scala index dc288eb..8f8e6a6 100644 --- a/core/src/main/scala/com/devsisters/shardcake/interfaces/Pods.scala +++ b/core/src/main/scala/com/devsisters/shardcake/interfaces/Pods.scala @@ -1,7 +1,7 @@ package com.devsisters.shardcake.interfaces import com.devsisters.shardcake.interfaces.Pods.BinaryMessage -import com.devsisters.shardcake.{ PodAddress, ShardId } +import com.devsisters.shardcake.{ PodAddress, Role, ShardId } import zio.stream.ZStream import zio.{ Task, ULayer, ZIO, ZLayer } @@ -15,12 +15,12 @@ trait Pods { /** * Notify a pod that it was assigned a list of shards */ - def assignShards(pod: PodAddress, shards: Set[ShardId]): Task[Unit] + def assignShards(pod: PodAddress, shards: Set[ShardId], role: Role): Task[Unit] /** * Notify a pod that it was unassigned a list of shards */ - def unassignShards(pod: PodAddress, shards: Set[ShardId]): Task[Unit] + def unassignShards(pod: PodAddress, shards: Set[ShardId], role: Role): Task[Unit] /** * Check that a pod is responsive @@ -64,8 +64,8 @@ object Pods { */ val noop: ULayer[Pods] = ZLayer.succeed(new Pods { - def assignShards(pod: PodAddress, shards: Set[ShardId]): Task[Unit] = ZIO.unit - def unassignShards(pod: PodAddress, shards: Set[ShardId]): Task[Unit] = ZIO.unit + def assignShards(pod: PodAddress, shards: Set[ShardId], role: Role): Task[Unit] = ZIO.unit + def unassignShards(pod: PodAddress, shards: Set[ShardId], role: Role): Task[Unit] = ZIO.unit def ping(pod: PodAddress): Task[Unit] = ZIO.unit def sendMessage(pod: PodAddress, message: BinaryMessage): Task[Option[Array[Byte]]] = ZIO.none def sendStream( diff --git a/core/src/main/scala/com/devsisters/shardcake/interfaces/Storage.scala b/core/src/main/scala/com/devsisters/shardcake/interfaces/Storage.scala index 7475c47..aa832a6 100644 --- a/core/src/main/scala/com/devsisters/shardcake/interfaces/Storage.scala +++ b/core/src/main/scala/com/devsisters/shardcake/interfaces/Storage.scala @@ -1,6 +1,6 @@ package com.devsisters.shardcake.interfaces -import com.devsisters.shardcake.{ Pod, PodAddress, ShardId } +import com.devsisters.shardcake.{ Pod, PodAddress, Role, ShardId } import zio.{ Ref, Task, ZLayer } import zio.stream.{ SubscriptionRef, ZStream } @@ -12,17 +12,17 @@ trait Storage { /** * Get the current state of shard assignments to pods */ - def getAssignments: Task[Map[ShardId, Option[PodAddress]]] + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] /** * Save the current state of shard assignments to pods */ - def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] + def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] /** * A stream that will emit the state of shard assignments whenever it changes */ - def assignmentsStream: ZStream[Any, Throwable, Map[Int, Option[PodAddress]]] + def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[Int, Option[PodAddress]]] /** * Get the list of existing pods @@ -47,11 +47,13 @@ object Storage { assignmentsRef <- SubscriptionRef.make(Map.empty[ShardId, Option[PodAddress]]) podsRef <- Ref.make(Map.empty[PodAddress, Pod]) } yield new Storage { - def getAssignments: Task[Map[ShardId, Option[PodAddress]]] = assignmentsRef.get - def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = 
assignmentsRef.set(assignments) - def assignmentsStream: ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = assignmentsRef.changes - def getPods: Task[Map[PodAddress, Pod]] = podsRef.get - def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = podsRef.set(pods) + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = assignmentsRef.get + def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = + assignmentsRef.set(assignments) + def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = + assignmentsRef.changes + def getPods: Task[Map[PodAddress, Pod]] = podsRef.get + def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = podsRef.set(pods) } } } diff --git a/manager/src/main/scala/com/devsisters/shardcake/GraphQLApi.scala b/manager/src/main/scala/com/devsisters/shardcake/GraphQLApi.scala index 1992d53..8b4ab4a 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/GraphQLApi.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/GraphQLApi.scala @@ -11,7 +11,8 @@ object GraphQLApi extends GenericSchema[ShardManager] { import auto._ case class Assignment(shardId: ShardId, pod: Option[PodAddress]) - case class Queries(getAssignments: URIO[ShardManager, List[Assignment]]) + case class RoleArgs(role: String) + case class Queries(getAssignments: RoleArgs => URIO[ShardManager, List[Assignment]]) case class PodAddressArgs(podAddress: PodAddress) case class Mutations( register: Pod => RIO[ShardManager, Unit], @@ -24,9 +25,9 @@ object GraphQLApi extends GenericSchema[ShardManager] { val api: GraphQL[ShardManager] = graphQL[ShardManager, Queries, Mutations, Subscriptions]( RootResolver( - Queries( + Queries(args => ZIO.serviceWithZIO( - _.getAssignments.map(_.map { case (k, v) => Assignment(k, v) }.toList.sortBy(_.shardId)) + _.getAssignments(Role(args.role)).map(_.map { case (k, v) => Assignment(k, v) }.toList.sortBy(_.shardId)) ) ), Mutations( diff --git a/manager/src/main/scala/com/devsisters/shardcake/ManagerConfig.scala b/manager/src/main/scala/com/devsisters/shardcake/ManagerConfig.scala index 8917de5..bf32ff6 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ManagerConfig.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ManagerConfig.scala @@ -5,6 +5,7 @@ import zio._ /** * Shard Manager configuration * @param numberOfShards number of shards (see documentation on how to choose this), should be same on all nodes + * @param numberOfShardsPerRole overrides of the number of shards per role * @param apiPort port to expose the GraphQL API * @param rebalanceInterval interval for regular rebalancing of shards * @param rebalanceRetryInterval retry interval for rebalancing when some shards failed to be rebalanced @@ -16,6 +17,7 @@ import zio._ */ case class ManagerConfig( numberOfShards: Int, + numberOfShardsPerRole: Map[Role, Int], apiPort: Int, rebalanceInterval: Duration, rebalanceRetryInterval: Duration, @@ -24,12 +26,16 @@ case class ManagerConfig( persistRetryCount: Int, rebalanceRate: Double, podHealthCheckInterval: Duration -) +) { + def getNumberOfShards(role: Role): Int = + numberOfShardsPerRole.getOrElse(role, numberOfShards) +} object ManagerConfig { val default: ManagerConfig = ManagerConfig( numberOfShards = 300, + numberOfShardsPerRole = Map.empty, apiPort = 8080, rebalanceInterval = 20 seconds, rebalanceRetryInterval = 10 seconds, diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala 
b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index e4e6201..150293e 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -16,7 +16,7 @@ import scala.collection.compat._ */ class ShardManager( stateRef: Ref.Synchronized[ShardManagerState], - rebalanceSemaphore: Semaphore, + rebalanceSemaphores: Ref.Synchronized[Map[Role, Semaphore]], eventsHub: Hub[ShardingEvent], healthApi: PodsHealth, podApi: Pods, @@ -24,8 +24,8 @@ class ShardManager( config: ManagerConfig ) { - def getAssignments: UIO[Map[ShardId, Option[PodAddress]]] = - stateRef.get.map(_.shards) + def getAssignments(role: Role): UIO[Map[ShardId, Option[PodAddress]]] = + stateRef.get.map(_.shards(role)) def getShardingEvents: ZStream[Any, Nothing, ShardingEvent] = ZStream.fromHub(eventsHub) @@ -40,8 +40,10 @@ class ShardManager( .map(cdt => state.copy(pods = state.pods.updated(pod.address, PodWithMetadata(pod, cdt)))) ) _ <- ManagerMetrics.pods.increment - _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address)) - _ <- ZIO.when(state.unassignedShards.nonEmpty)(rebalance(rebalanceImmediately = false)) + _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address, pod.roles)) + _ <- ZIO.foreachDiscard(pod.roles) { role => + ZIO.when(state.unassignedShards(role).nonEmpty)(rebalance(role, rebalanceImmediately = false).forkDaemon) + } _ <- persistPods.forkDaemon } yield (), onFalse = ZIO.logWarning(s"Pod $pod requested to register but is not alive, ignoring") *> @@ -65,44 +67,68 @@ class ShardManager( } yield () def unregister(podAddress: PodAddress): UIO[Unit] = - ZIO - .whenZIO(stateRef.get.map(_.pods.contains(podAddress))) { - for { - _ <- ZIO.logInfo(s"Unregistering $podAddress") - unassignments <- stateRef.modify { state => - ( - state.shards.collect { case (shard, Some(p)) if p == podAddress => shard }.toSet, - state.copy( - pods = state.pods - podAddress, - shards = - state.shards.map { case (k, v) => k -> (if (v.contains(podAddress)) None else v) } - ) + ZIO.whenZIODiscard(stateRef.get.map(_.pods.contains(podAddress))) { + for { + _ <- ZIO.logInfo(s"Unregistering $podAddress") + unassignments <- stateRef.modify { state => + ( + state.assignments.map { case (role, assignments) => + role -> assignments.shards.collect { + case (shard, Some(p)) if p == podAddress => shard + }.toSet + }.filter { case (_, shards) => shards.nonEmpty }, + state.copy( + pods = state.pods - podAddress, + assignments = state.assignments.map { case (role, assignments) => + role -> assignments.copy( + shards = assignments.shards.map { case (k, v) => + k -> (if (v.contains(podAddress)) None else v) + } + ) + } ) - } - _ <- ManagerMetrics.pods.decrement - _ <- ManagerMetrics.assignedShards.tagged("pod_address", podAddress.toString).decrementBy(unassignments.size) - _ <- ManagerMetrics.unassignedShards.incrementBy(unassignments.size) - _ <- eventsHub.publish(ShardingEvent.PodUnregistered(podAddress)) - _ <- eventsHub - .publish(ShardingEvent.ShardsUnassigned(podAddress, unassignments)) - .when(unassignments.nonEmpty) - _ <- persistPods.forkDaemon - _ <- rebalance(rebalanceImmediately = true).forkDaemon - } yield () + ) + } + _ <- ZIO.foreachDiscard(unassignments) { case (role, shards) => + for { + _ <- ManagerMetrics.assignedShards + .tagged("role", role.name) + .tagged("pod_address", podAddress.toString) + .decrementBy(shards.size) + _ <- ManagerMetrics.unassignedShards + .tagged("role", role.name) + .incrementBy(shards.size) + 
_ <- eventsHub + .publish(ShardingEvent.ShardsUnassigned(podAddress, role, shards)) + .when(shards.nonEmpty) + _ <- rebalance(role, rebalanceImmediately = true).forkDaemon + } yield () + } + _ <- ManagerMetrics.pods.decrement + _ <- eventsHub.publish(ShardingEvent.PodUnregistered(podAddress)) + _ <- persistPods.forkDaemon + } yield () + } + + private def getSemaphore(role: Role): UIO[Semaphore] = + rebalanceSemaphores.modifyZIO(map => + map.get(role) match { + case Some(s) => ZIO.succeed((s, map)) + case None => Semaphore.make(1).map(s => (s, map.updated(role, s))) } - .unit + ) - private def rebalance(rebalanceImmediately: Boolean): UIO[Unit] = - rebalanceSemaphore.withPermit { + private def rebalance(role: Role, rebalanceImmediately: Boolean): UIO[Unit] = + getSemaphore(role).flatMap(_.withPermit { for { state <- stateRef.get // find which shards to assign and unassign - (assignments, unassignments) = if (rebalanceImmediately || state.unassignedShards.nonEmpty) - decideAssignmentsForUnassignedShards(state) - else decideAssignmentsForUnbalancedShards(state, config.rebalanceRate) + (assignments, unassignments) = if (rebalanceImmediately || state.unassignedShards(role).nonEmpty) + decideAssignmentsForUnassignedShards(role, state) + else decideAssignmentsForUnbalancedShards(role, state, config.rebalanceRate) areChanges = assignments.nonEmpty || unassignments.nonEmpty - _ <- (ZIO.logDebug(s"Rebalancing (rebalanceImmediately=$rebalanceImmediately)") *> - ManagerMetrics.rebalances.increment).when(areChanges) + _ <- (ZIO.logDebug(s"Rebalancing role ${role.name} (rebalanceImmediately=$rebalanceImmediately)") *> + ManagerMetrics.rebalances.tagged("role", role.name).increment).when(areChanges) // ping pods first to make sure they are ready and remove those who aren't failedPingedPods <- ZIO .foreachPar(assignments.keySet ++ unassignments.keySet)(pod => @@ -121,13 +147,18 @@ class ShardManager( // do the unassignments first failed <- ZIO .foreachPar(readyUnassignments.toList) { case (pod, shards) => - (podApi.unassignShards(pod, shards) *> updateShardsState(shards, None)).foldZIO( + (podApi.unassignShards(pod, shards, role) *> updateShardsState(role, shards, None)).foldZIO( _ => ZIO.succeed((Set(pod), shards)), _ => - ManagerMetrics.assignedShards.tagged("pod_address", pod.toString).decrementBy(shards.size) *> - ManagerMetrics.unassignedShards.incrementBy(shards.size) *> + ManagerMetrics.assignedShards + .tagged("role", role.name) + .tagged("pod_address", pod.toString) + .decrementBy(shards.size) *> + ManagerMetrics.unassignedShards + .tagged("role", role.name) + .incrementBy(shards.size) *> eventsHub - .publish(ShardingEvent.ShardsUnassigned(pod, shards)) + .publish(ShardingEvent.ShardsUnassigned(pod, role, shards)) .as((Set.empty, Set.empty)) ) } @@ -141,37 +172,52 @@ class ShardManager( // then do the assignments failedAssignedPods <- ZIO .foreachPar(filteredAssignments.toList) { case (pod, shards) => - (podApi.assignShards(pod, shards) *> updateShardsState(shards, Some(pod))).foldZIO( - _ => ZIO.succeed(Set(pod)), - _ => - ManagerMetrics.assignedShards - .tagged("pod_address", pod.toString) - .incrementBy(shards.size) *> - ManagerMetrics.unassignedShards.decrementBy(shards.size) *> - eventsHub.publish(ShardingEvent.ShardsAssigned(pod, shards)).as(Set.empty) - ) + (podApi.assignShards(pod, shards, role) *> updateShardsState(role, shards, Some(pod))) + .foldZIO( + _ => ZIO.succeed(Set(pod)), + _ => + ManagerMetrics.assignedShards + .tagged("role", role.name) + .tagged("pod_address", pod.toString) 
+ .incrementBy(shards.size) *> + ManagerMetrics.unassignedShards + .tagged("role", role.name) + .decrementBy(shards.size) *> + eventsHub + .publish(ShardingEvent.ShardsAssigned(pod, role, shards)) + .as(Set.empty) + ) } .map(_.flatten[PodAddress].toSet) failedPods = failedPingedPods ++ failedUnassignedPods ++ failedAssignedPods // check if failing pods are still up _ <- ZIO.foreachDiscard(failedPods)(notifyUnhealthyPod(_)).forkDaemon - _ <- ZIO.logWarning(s"Failed to rebalance pods: $failedPods").when(failedPods.nonEmpty) + _ <- ZIO.logWarning(s"Failed to rebalance pods for role ${role.name}: $failedPods").when(failedPods.nonEmpty) // retry rebalancing later if there was any failure - _ <- (Clock.sleep(config.rebalanceRetryInterval) *> rebalance(rebalanceImmediately)).forkDaemon + _ <- (Clock.sleep(config.rebalanceRetryInterval) *> rebalance(role, rebalanceImmediately)).forkDaemon .when(failedPods.nonEmpty && rebalanceImmediately) // persist state changes to Redis - _ <- persistAssignments.forkDaemon.when(areChanges) + _ <- persistAssignments(role).forkDaemon.when(areChanges) } yield () - } + }) private def withRetry[E, A](zio: IO[E, A]): UIO[Unit] = zio .retry[Any, Any](Schedule.spaced(config.persistRetryInterval) && Schedule.recurs(config.persistRetryCount)) .ignore - private def persistAssignments: UIO[Unit] = + private def persistAssignments(role: Role): UIO[Unit] = withRetry( - stateRef.get.flatMap(state => stateRepository.saveAssignments(state.shards)) + stateRef.get.flatMap(state => stateRepository.saveAssignments(role, state.shards(role))) + ) + + private val persistAllAssignments: UIO[Unit] = + withRetry( + stateRef.get.flatMap(state => + ZIO.foreachDiscard(state.assignments) { case (role, assignments) => + stateRepository.saveAssignments(role, assignments.shards) + } + ) ) private def persistPods: UIO[Unit] = @@ -179,16 +225,23 @@ class ShardManager( stateRef.get.flatMap(state => stateRepository.savePods(state.pods.map { case (k, v) => (k, v.pod) })) ) - private def updateShardsState(shards: Set[ShardId], pod: Option[PodAddress]): Task[Unit] = + private def updateShardsState(role: Role, shards: Set[ShardId], pod: Option[PodAddress]): Task[Unit] = stateRef.updateZIO(state => ZIO .whenCase(pod) { case Some(pod) if !state.pods.contains(pod) => ZIO.fail(new Exception(s"Pod $pod is no longer registered")) } .as( - state.copy(shards = state.shards.map { case (shard, assignment) => - shard -> (if (shards.contains(shard)) pod else assignment) - }) + state.copy(assignments = + state.assignments.updated( + role, + ShardAssignments( + state.shards(role).map { case (shard, assignment) => + shard -> (if (shards.contains(shard)) pod else assignment) + } + ) + ) + ) ) ) } @@ -201,76 +254,118 @@ object ShardManager { val live: ZLayer[PodsHealth with Pods with Storage with ManagerConfig, Throwable, ShardManager] = ZLayer.scoped { for { - config <- ZIO.service[ManagerConfig] - stateRepository <- ZIO.service[Storage] - healthApi <- ZIO.service[PodsHealth] - podApi <- ZIO.service[Pods] - pods <- stateRepository.getPods - assignments <- stateRepository.getAssignments + config <- ZIO.service[ManagerConfig] + stateRepository <- ZIO.service[Storage] + healthApi <- ZIO.service[PodsHealth] + podApi <- ZIO.service[Pods] + pods <- stateRepository.getPods // remove unhealthy pods on startup - failedFilteredPods <- + failedFilteredPods <- ZIO.partitionPar(pods) { addrPod => ZIO.ifZIO(healthApi.isAlive(addrPod._1))(ZIO.succeed(addrPod), ZIO.fail(addrPod._2)) } - (failedPods, filtered) = failedFilteredPods - _ 
<- ZIO.when(failedPods.nonEmpty)( - ZIO.logInfo(s"Ignoring pods that are no longer alive ${failedPods.mkString("[", ", ", "]")}") - ) - filteredPods = filtered.toMap - failedFilteredAssignments = partitionMap(assignments) { - case assignment @ (_, Some(address)) if filteredPods.contains(address) => - Right(assignment) - case assignment => Left(assignment) - } - (failed, filteredAssignments) = failedFilteredAssignments - failedAssignments = failed.collect { case (shard, Some(addr)) => shard -> addr } - _ <- ZIO.when(failedAssignments.nonEmpty)( - ZIO.logWarning( - s"Ignoring assignments for pods that are no longer alive ${failedAssignments.mkString("[", ", ", "]")}" - ) - ) - cdt <- ZIO.succeed(OffsetDateTime.now()) - initialState = ShardManagerState( - filteredPods.map { case (k, v) => k -> PodWithMetadata(v, cdt) }, - (1 to config.numberOfShards).map(_ -> None).toMap ++ filteredAssignments - ) - _ <- ZIO.logInfo( - s"Recovered pods ${filteredPods - .mkString("[", ", ", "]")} and assignments ${filteredAssignments.mkString("[", ", ", "]")}" - ) - _ <- ManagerMetrics.pods.incrementBy(initialState.pods.size) - _ <- ZIO.foreachDiscard(initialState.shards) { case (_, podAddressOpt) => - podAddressOpt match { - case Some(podAddress) => - ManagerMetrics.assignedShards.tagged("pod_address", podAddress.toString).increment - case None => - ManagerMetrics.unassignedShards.increment - } - } - state <- Ref.Synchronized.make(initialState) - rebalanceSemaphore <- Semaphore.make(1) - eventsHub <- Hub.unbounded[ShardingEvent] - shardManager = - new ShardManager(state, rebalanceSemaphore, eventsHub, healthApi, podApi, stateRepository, config) - _ <- ZIO.addFinalizer { - shardManager.persistAssignments.catchAllCause(cause => - ZIO.logWarningCause("Failed to persist assignments on shutdown", cause) - ) *> - shardManager.persistPods.catchAllCause(cause => - ZIO.logWarningCause("Failed to persist pods on shutdown", cause) - ) - } - _ <- shardManager.persistPods.forkDaemon + (failedPods, filtered) = failedFilteredPods + _ <- ZIO.when(failedPods.nonEmpty)( + ZIO.logInfo(s"Ignoring pods that are no longer alive ${failedPods.mkString("[", ", ", "]")}") + ) + filteredPods = filtered.toMap + roles = filteredPods.flatMap(_._2.roles).toSet + _ <- ZIO.when(filteredPods.nonEmpty)(ZIO.logInfo(s"Recovered pods ${filteredPods.mkString("[", ", ", "]")}")) + roleAssignments <- ZIO + .foreach(roles) { role => + for { + assignments <- stateRepository.getAssignments(role) + failedFilteredAssignments = partitionMap(assignments) { + case assignment @ (_, Some(address)) + if filteredPods.contains(address) => + Right(assignment) + case assignment => Left(assignment) + } + (failed, filteredAssignments) = failedFilteredAssignments + failedAssignments = failed.collect { case (shard, Some(addr)) => shard -> addr } + _ <- + ZIO.when(failedAssignments.nonEmpty)( + ZIO.logWarning( + s"Ignoring assignments for pods that are no longer alive for role ${role.name}: ${failedAssignments + .mkString("[", ", ", "]")}" + ) + ) + _ <- + ZIO.when(filteredAssignments.nonEmpty)( + ZIO.logInfo( + s"Recovered assignments for role ${role.name}: ${filteredAssignments + .mkString("[", ", ", "]")}" + ) + ) + } yield role -> filteredAssignments + } + .map(_.toMap) + cdt <- ZIO.succeed(OffsetDateTime.now()) + initialState = ShardManagerState( + filteredPods.map { case (k, v) => k -> PodWithMetadata(v, cdt) }, + roles + .map(role => + role -> ShardAssignments( + (1 to config.getNumberOfShards(role)).map(_ -> None).toMap ++ + roleAssignments.getOrElse(role, 
Map.empty) + ) + ) + .toMap, + config.getNumberOfShards + ) + _ <- ManagerMetrics.pods.incrementBy(initialState.pods.size) + _ <- ZIO + .foreachDiscard(initialState.roles) { role => + ZIO.foreachDiscard(initialState.shards(role)) { case (_, podAddressOpt) => + podAddressOpt match { + case Some(podAddress) => + ManagerMetrics.assignedShards + .tagged("role", role.name) + .tagged("pod_address", podAddress.toString) + .increment + case None => + ManagerMetrics.unassignedShards + .tagged("role", role.name) + .increment + } + } + } + state <- Ref.Synchronized.make(initialState) + rebalanceSemaphores <- Ref.Synchronized.make(Map.empty[Role, Semaphore]) + eventsHub <- Hub.unbounded[ShardingEvent] + shardManager = new ShardManager( + stateRef = state, + rebalanceSemaphores = rebalanceSemaphores, + eventsHub = eventsHub, + healthApi = healthApi, + podApi = podApi, + stateRepository = stateRepository, + config = config + ) + _ <- ZIO.addFinalizer { + shardManager.persistAllAssignments.catchAllCause(cause => + ZIO.logWarningCause("Failed to persist assignments on shutdown", cause) + ) *> + shardManager.persistPods.catchAllCause(cause => + ZIO.logWarningCause("Failed to persist pods on shutdown", cause) + ) + } + _ <- shardManager.persistPods.forkDaemon // rebalance immediately if there are unassigned shards - _ <- shardManager.rebalance(rebalanceImmediately = initialState.unassignedShards.nonEmpty).forkDaemon + _ <- + ZIO.foreachDiscard(roles)(role => + shardManager.rebalance(role, rebalanceImmediately = initialState.unassignedShards(role).nonEmpty).forkDaemon + ) // start a regular rebalance at the given interval - _ <- shardManager - .rebalance(rebalanceImmediately = false) - .repeat(Schedule.spaced(config.rebalanceInterval)) - .forkDaemon - _ <- shardManager.getShardingEvents.mapZIO(event => ZIO.logInfo(event.toString)).runDrain.forkDaemon - _ <- shardManager.checkAllPodsHealth.repeat(Schedule.spaced(config.podHealthCheckInterval)).forkDaemon - _ <- ZIO.logInfo("Shard Manager loaded") + _ <- state.get + .flatMap(state => + ZIO.foreachParDiscard(state.roles)(shardManager.rebalance(_, rebalanceImmediately = false)) + ) + .repeat(Schedule.spaced(config.rebalanceInterval)) + .forkDaemon + _ <- shardManager.getShardingEvents.mapZIO(event => ZIO.logInfo(event.toString)).runDrain.forkDaemon + _ <- shardManager.checkAllPodsHealth.repeat(Schedule.spaced(config.podHealthCheckInterval)).forkDaemon + _ <- ZIO.logInfo("Shard Manager loaded") } yield shardManager } @@ -306,70 +401,95 @@ object ShardManager { if (xs eq ys) 0 else loop(xs, ys) } - case class ShardManagerState(pods: Map[PodAddress, PodWithMetadata], shards: Map[ShardId, Option[PodAddress]]) { + case class ShardAssignments(shards: Map[ShardId, Option[PodAddress]]) { lazy val unassignedShards: Set[ShardId] = shards.collect { case (k, None) => k }.toSet - lazy val averageShardsPerPod: ShardId = if (pods.nonEmpty) shards.size / pods.size else 0 - private lazy val podVersions = pods.values.toList.map(extractVersion) - lazy val maxVersion: Option[List[ShardId]] = podVersions.maxOption - lazy val allPodsHaveMaxVersion: Boolean = podVersions.forall(maxVersion.contains) lazy val shardsPerPod: Map[PodAddress, Set[ShardId]] = - pods.map { case (k, _) => k -> Set.empty[ShardId] } ++ - shards.groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet } + shards.groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet } + } + + case class ShardManagerState( + pods: Map[PodAddress, PodWithMetadata], + assignments: 
Map[Role, ShardAssignments], + getNumberOfShards: Role => Int + ) { + private lazy val podVersions = pods.values.toList.map(extractVersion) + lazy val maxVersion: Option[List[ShardId]] = podVersions.maxOption + lazy val allPodsHaveMaxVersion: Boolean = podVersions.forall(maxVersion.contains) + lazy val roles: Set[Role] = pods.values.flatMap(_.pod.roles).toSet ++ assignments.keySet + + private lazy val emptyShardsPerPod: Map[PodAddress, Set[ShardId]] = + pods.map { case (k, _) => k -> Set.empty[ShardId] } + + private def assignmentsForRole(role: Role): ShardAssignments = + assignments.getOrElse(role, ShardAssignments((1 to getNumberOfShards(role)).map(_ -> None).toMap)) + + def shards(role: Role): Map[ShardId, Option[PodAddress]] = assignmentsForRole(role).shards + def unassignedShards(role: Role): Set[ShardId] = assignmentsForRole(role).unassignedShards + def averageShardsPerPod(role: Role): ShardId = if (pods.nonEmpty) shards(role).size / pods.size else 0 + + def shardsPerPod(role: Role): Map[PodAddress, Set[ShardId]] = + emptyShardsPerPod ++ assignmentsForRole(role).shardsPerPod } case class PodWithMetadata(pod: Pod, registered: OffsetDateTime) sealed trait ShardingEvent object ShardingEvent { - case class ShardsAssigned(pod: PodAddress, shards: Set[ShardId]) extends ShardingEvent { - override def toString: String = s"ShardsAssigned(pod=$pod, shards=${renderShardIds(shards)})" + case class ShardsAssigned(pod: PodAddress, role: Role, shards: Set[ShardId]) extends ShardingEvent { + override def toString: String = s"ShardsAssigned(pod=$pod, role=${role.name}, shards=${renderShardIds(shards)})" } - case class ShardsUnassigned(pod: PodAddress, shards: Set[ShardId]) extends ShardingEvent { - override def toString: String = s"ShardsUnassigned(pod=$pod, shards=${renderShardIds(shards)})" + case class ShardsUnassigned(pod: PodAddress, role: Role, shards: Set[ShardId]) extends ShardingEvent { + override def toString: String = s"ShardsUnassigned(pod=$pod, role=${role.name}, shards=${renderShardIds(shards)})" } - case class PodRegistered(pod: PodAddress) extends ShardingEvent - case class PodUnregistered(pod: PodAddress) extends ShardingEvent - case class PodHealthChecked(pod: PodAddress) extends ShardingEvent + case class PodRegistered(pod: PodAddress, roles: Set[Role]) extends ShardingEvent + case class PodUnregistered(pod: PodAddress) extends ShardingEvent + case class PodHealthChecked(pod: PodAddress) extends ShardingEvent } def decideAssignmentsForUnassignedShards( + role: Role, state: ShardManagerState ): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = - pickNewPods(state.unassignedShards.toList, state, rebalanceImmediately = true, 1.0) + pickNewPods(state.unassignedShards(role).toList, role, state, rebalanceImmediately = true, 1.0) def decideAssignmentsForUnbalancedShards( + role: Role, state: ShardManagerState, rebalanceRate: Double ): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = { val extraShardsToAllocate = if (state.allPodsHaveMaxVersion) { // don't do regular rebalance in the middle of a rolling update - state.shardsPerPod.flatMap { case (_, shards) => - // count how many extra shards compared to the average - val extraShards = (shards.size - state.averageShardsPerPod).max(0) - Random.shuffle(shards).take(extraShards) - }.toSet + state + .shardsPerPod(role) + .flatMap { case (_, shards) => + // count how many extra shards compared to the average + val extraShards = (shards.size - state.averageShardsPerPod(role)).max(0) + 
Random.shuffle(shards).take(extraShards) + } + .toSet } else Set.empty val sortedShardsToRebalance = extraShardsToAllocate.toList.sortBy { shard => // handle unassigned shards first, then shards on the pods with most shards, then shards on old pods - state.shards.get(shard).flatten.fold((Int.MinValue, OffsetDateTime.MIN)) { pod => + state.shards(role).get(shard).flatten.fold((Int.MinValue, OffsetDateTime.MIN)) { pod => ( - state.shardsPerPod.get(pod).fold(Int.MinValue)(-_.size), + state.shardsPerPod(role).get(pod).fold(Int.MinValue)(-_.size), state.pods.get(pod).fold(OffsetDateTime.MIN)(_.registered) ) } } - pickNewPods(sortedShardsToRebalance, state, rebalanceImmediately = false, rebalanceRate) + pickNewPods(sortedShardsToRebalance, role, state, rebalanceImmediately = false, rebalanceRate) } private def pickNewPods( shardsToRebalance: List[ShardId], + role: Role, state: ShardManagerState, rebalanceImmediately: Boolean, rebalanceRate: Double ): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = { - val (_, assignments) = shardsToRebalance.foldLeft((state.shardsPerPod, List.empty[(ShardId, PodAddress)])) { + val (_, assignments) = shardsToRebalance.foldLeft((state.shardsPerPod(role), List.empty[(ShardId, PodAddress)])) { case ((shardsPerPod, assignments), shard) => val unassignedPods = assignments.flatMap { case (shard, _) => - state.shards.get(shard).flatten[PodAddress] + state.shards(role).get(shard).flatten[PodAddress] }.toSet // find pod with least amount of shards shardsPerPod @@ -379,13 +499,14 @@ object ShardManager { } // don't assign too many shards to the same pods, unless we need rebalance immediately .filter { case (pod, _) => - rebalanceImmediately || assignments.count { case (_, p) => p == pod } < state.shards.size * rebalanceRate + rebalanceImmediately || + assignments.count { case (_, p) => p == pod } < state.shards(role).size * rebalanceRate } // don't assign to a pod that was unassigned in the same rebalance .filterNot { case (pod, _) => unassignedPods.contains(pod) } .minByOption(_._2.size) match { case Some((pod, shards)) => - val oldPod = state.shards.get(shard).flatten + val oldPod = state.shards(role).get(shard).flatten // if old pod is same as new pod, don't change anything if (oldPod.contains(pod)) (shardsPerPod, assignments) @@ -404,7 +525,7 @@ object ShardManager { case None => (shardsPerPod, assignments) } } - val unassignments = assignments.flatMap { case (shard, _) => state.shards.get(shard).flatten.map(shard -> _) } + val unassignments = assignments.flatMap { case (shard, _) => state.shards(role).get(shard).flatten.map(shard -> _) } val assignmentsPerPod = assignments.groupBy(_._2).map { case (k, v) => k -> v.map(_._1).toSet } val unassignmentsPerPod = unassignments.groupBy(_._2).map { case (k, v) => k -> v.map(_._1).toSet } (assignmentsPerPod, unassignmentsPerPod) diff --git a/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala b/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala index 00c6e43..38d1736 100644 --- a/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala +++ b/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala @@ -1,6 +1,6 @@ package com.devsisters.shardcake -import com.devsisters.shardcake.ShardManager.{ PodWithMetadata, ShardManagerState } +import com.devsisters.shardcake.ShardManager.{ PodWithMetadata, ShardAssignments, ShardManagerState } import com.devsisters.shardcake.interfaces.{ Pods, PodsHealth, Storage } import zio._ import zio.stream.ZStream @@ 
-9,9 +9,10 @@ import zio.test._ import java.time.OffsetDateTime object ShardManagerSpec extends ZIOSpecDefault { - private val pod1 = PodWithMetadata(Pod(PodAddress("1", 1), "1.0.0"), OffsetDateTime.MIN) - private val pod2 = PodWithMetadata(Pod(PodAddress("2", 2), "1.0.0"), OffsetDateTime.MIN) - private val pod3 = PodWithMetadata(Pod(PodAddress("3", 3), "1.0.0"), OffsetDateTime.MIN) + private val role = Role("default") + private val pod1 = PodWithMetadata(Pod(PodAddress("1", 1), "1.0.0", Set(role)), OffsetDateTime.MIN) + private val pod2 = PodWithMetadata(Pod(PodAddress("2", 2), "1.0.0", Set(role)), OffsetDateTime.MIN) + private val pod3 = PodWithMetadata(Pod(PodAddress("3", 3), "1.0.0", Set(role)), OffsetDateTime.MIN) override def spec: Spec[Any, Throwable] = suite("ShardManagerSpec")( @@ -20,13 +21,17 @@ object ShardManagerSpec extends ZIOSpecDefault { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), - shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)) + assignments = + Map(role -> ShardAssignments(Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)))), + getNumberOfShards = ManagerConfig.default.getNumberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) - assertTrue(assignments.contains(pod2.pod.address)) && - assertTrue(assignments.size == 1) && - assertTrue(unassignments.contains(pod1.pod.address)) && - assertTrue(unassignments.size == 1) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + assertTrue( + assignments.contains(pod2.pod.address), + assignments.size == 1, + unassignments.contains(pod1.pod.address), + unassignments.size == 1 + ) }, test("Don't rebalance to pod with older version") { val state = @@ -35,94 +40,123 @@ object ShardManagerSpec extends ZIOSpecDefault { pod1.pod.address -> pod1, pod2.pod.address -> pod2.copy(pod = pod2.pod.copy(version = "0.1.2")) ), // older version - shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)) + assignments = + Map(role -> ShardAssignments(Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)))), + getNumberOfShards = ManagerConfig.default.getNumberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) - assertTrue(assignments.isEmpty) && assertTrue(unassignments.isEmpty) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + assertTrue(assignments.isEmpty, unassignments.isEmpty) }, test("Don't rebalance when already well balanced") { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), - shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod2.pod.address)) + assignments = + Map(role -> ShardAssignments(Map(1 -> Some(pod1.pod.address), 2 -> Some(pod2.pod.address)))), + getNumberOfShards = ManagerConfig.default.getNumberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) - assertTrue(assignments.isEmpty) && assertTrue(unassignments.isEmpty) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + assertTrue(assignments.isEmpty, unassignments.isEmpty) }, test("Don't rebalance when only 1 shard difference") { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), - shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> 
Some(pod2.pod.address)) + assignments = Map( + role -> ShardAssignments( + Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)) + ) + ), + getNumberOfShards = ManagerConfig.default.getNumberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) - assertTrue(assignments.isEmpty) && assertTrue(unassignments.isEmpty) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + assertTrue(assignments.isEmpty, unassignments.isEmpty) }, test("Rebalance when 2 shard difference") { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), - shards = Map( - 1 -> Some(pod1.pod.address), - 2 -> Some(pod1.pod.address), - 3 -> Some(pod1.pod.address), - 4 -> Some(pod2.pod.address) - ) + assignments = Map( + role -> ShardAssignments( + Map( + 1 -> Some(pod1.pod.address), + 2 -> Some(pod1.pod.address), + 3 -> Some(pod1.pod.address), + 4 -> Some(pod2.pod.address) + ) + ) + ), + getNumberOfShards = ManagerConfig.default.getNumberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) - assertTrue(assignments.contains(pod2.pod.address)) && - assertTrue(assignments.size == 1) && - assertTrue(unassignments.contains(pod1.pod.address)) && - assertTrue(unassignments.size == 1) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + assertTrue( + assignments.contains(pod2.pod.address), + assignments.size == 1, + unassignments.contains(pod1.pod.address), + unassignments.size == 1 + ) }, test("Pick the pod with less shards") { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2, pod3.pod.address -> pod3), - shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)) + assignments = Map( + role -> ShardAssignments( + Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)) + ) + ), + getNumberOfShards = ManagerConfig.default.getNumberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) - assertTrue(assignments.contains(pod3.pod.address)) && - assertTrue(assignments.size == 1) && - assertTrue(unassignments.contains(pod1.pod.address)) && - assertTrue(unassignments.size == 1) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + assertTrue( + assignments.contains(pod3.pod.address), + assignments.size == 1, + unassignments.contains(pod1.pod.address), + unassignments.size == 1 + ) }, test("Don't rebalance if pod list is empty") { val state = ShardManagerState( pods = Map(), - shards = Map(1 -> Some(pod1.pod.address)) + assignments = Map(role -> ShardAssignments(Map(1 -> Some(pod1.pod.address)))), + getNumberOfShards = ManagerConfig.default.getNumberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) - assertTrue(assignments.isEmpty) && assertTrue(unassignments.isEmpty) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + assertTrue(assignments.isEmpty, unassignments.isEmpty) }, test("Balance well when 30 nodes are starting one by one") { val state = ShardManagerState( pods = Map(), - shards = (1 to 300).map(_ -> None).toMap + assignments = Map(role -> ShardAssignments((1 to 300).map(_ -> None).toMap)), + getNumberOfShards = 
ManagerConfig.default.getNumberOfShards ) val result = (1 to 30).foldLeft(state) { case (state, podNumber) => val podAddress = PodAddress("", podNumber) val s1 = state.copy(pods = - state.pods.updated(podAddress, PodWithMetadata(Pod(podAddress, "v1"), OffsetDateTime.now())) + state.pods.updated(podAddress, PodWithMetadata(Pod(podAddress, "v1", Set(role)), OffsetDateTime.now())) ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(s1, 1d) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, s1, 1d) val s2 = unassignments.foldLeft(s1) { case (state, (_, shards)) => - shards.foldLeft(state) { case (state, shard) => state.copy(shards = state.shards.updated(shard, None)) } + shards.foldLeft(state) { case (state, shard) => + state.copy(assignments = Map(role -> ShardAssignments(state.shards(role).updated(shard, None)))) + } } val s3 = assignments.foldLeft(s2) { case (state, (address, shards)) => shards.foldLeft(state) { case (state, shard) => - state.copy(shards = state.shards.updated(shard, Some(address))) + state.copy(assignments = + Map(role -> ShardAssignments(state.shards(role).updated(shard, Some(address)))) + ) } } s3 } val shardsPerPod = - result.shards.groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet } + result.shards(role).groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet } assertTrue(shardsPerPod.values.forall(_.size == 10)) } ), @@ -130,23 +164,28 @@ object ShardManagerSpec extends ZIOSpecDefault { test("Simulate scaling out scenario") { (for { // setup 20 pods first - _ <- simulate((1 to 20).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1")))) + _ <- simulate( + (1 to 20).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", Set(role)))) + ) _ <- TestClock.adjust(10 minutes) - assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments) + assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments(role)) // check that all pods are assigned and that all pods have 15 shards each - assert1 = assertTrue(assignments.values.forall(_.isDefined)) && assertTrue( + assert1 = assertTrue( + assignments.values.forall(_.isDefined), assignments.groupBy(_._2).forall(_._2.size == 15) ) // bring 5 new pods - _ <- simulate((21 to 25).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1")))) + _ <- simulate( + (21 to 25).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", Set(role)))) + ) _ <- TestClock.adjust(20 seconds) - assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments) + assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments(role)) // check that new pods received some shards but only 6 assert2 = assertTrue(assignments.groupBy(_._2).filter(_._1.exists(_.port > 20)).forall(_._2.size == 6)) _ <- TestClock.adjust(1 minute) - assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments) + assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments(role)) // check that all pods now have 12 shards each assert3 = assertTrue(assignments.groupBy(_._2).forall(_._2.size == 12)) @@ -155,20 +194,24 @@ object ShardManagerSpec extends ZIOSpecDefault { test("Simulate scaling down scenario") { (for { // setup 25 pods first - _ <- simulate((1 to 25).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1")))) + _ <- simulate( + (1 to 25).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", 
Set(role)))) + ) _ <- TestClock.adjust(10 minutes) - assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments) + assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments(role)) // check that all pods are assigned and that all pods have 12 shards each - assert1 = assertTrue(assignments.values.forall(_.isDefined)) && assertTrue( + assert1 = assertTrue( + assignments.values.forall(_.isDefined), assignments.groupBy(_._2).forall(_._2.size == 12) ) // remove 5 pods _ <- simulate((21 to 25).toList.map(i => SimulationEvent.PodUnregister(PodAddress("server", i)))) _ <- TestClock.adjust(1 second) - assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments) + assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments(role)) // check that all shards have been rebalanced already - assert2 = assertTrue(assignments.values.forall(_.exists(_.port <= 20))) && assertTrue( + assert2 = assertTrue( + assignments.values.forall(_.exists(_.port <= 20)), assignments.groupBy(_._2).forall(_._2.size == 15) ) @@ -177,21 +220,25 @@ object ShardManagerSpec extends ZIOSpecDefault { test("Simulate temporary storage restart followed by manager restart") { { val setup = (for { - _ <- simulate((1 to 10).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1")))) + _ <- simulate { + (1 to 10).toList.map(i => + SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", Set(role))) + ) + } _ <- TestClock.adjust(10 minutes) // busy wait for the forked daemon fibers to do their job _ <- ZIO.iterate(Map.empty[ShardId, Option[PodAddress]])(_.isEmpty)(_ => - ZIO.serviceWithZIO[Storage](_.getAssignments) + ZIO.serviceWithZIO[Storage](_.getAssignments(role)) ) _ <- ZIO.iterate(Map.empty[PodAddress, Pod])(_.isEmpty)(_ => ZIO.serviceWithZIO[Storage](_.getPods)) // simulate non-persistent storage restart - _ <- ZIO.serviceWithZIO[Storage](s => s.saveAssignments(Map.empty) *> s.savePods(Map.empty)) + _ <- ZIO.serviceWithZIO[Storage](s => s.saveAssignments(role, Map.empty) *> s.savePods(Map.empty)) } yield {}).provideSome[Storage]( ZLayer.makeSome[Storage, ShardManager](config, Pods.noop, PodsHealth.local, ShardManager.live) ) val test = for { - shutdownAssignments <- ZIO.serviceWithZIO[Storage](_.getAssignments) + shutdownAssignments <- ZIO.serviceWithZIO[Storage](_.getAssignments(role)) shutdownPods <- ZIO.serviceWithZIO[Storage](_.getPods) } yield // manager should have saved its state to storage when it shut down @@ -212,11 +259,11 @@ object ShardManagerSpec extends ZIOSpecDefault { val config: ULayer[ManagerConfig] = ZLayer.succeed(ManagerConfig.default) val storage: ULayer[Storage] = ZLayer.succeed(new Storage { - def getAssignments: Task[Map[ShardId, Option[PodAddress]]] = ZIO.succeed(Map.empty) - def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = ZIO.unit - def assignmentsStream: ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = ZStream.empty - def getPods: Task[Map[PodAddress, Pod]] = ZIO.succeed(Map.empty) - def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = ZIO.unit + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = ZIO.succeed(Map.empty) + def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = ZIO.unit + def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = ZStream.empty + def getPods: Task[Map[PodAddress, Pod]] = ZIO.succeed(Map.empty) + def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = ZIO.unit }) val shardManager: 
ZLayer[Any, Throwable, ShardManager] = diff --git a/storage-redis/src/main/scala/com/devsisters/shardcake/RedisConfig.scala b/storage-redis/src/main/scala/com/devsisters/shardcake/RedisConfig.scala index 3e64c3b..73d0a53 100644 --- a/storage-redis/src/main/scala/com/devsisters/shardcake/RedisConfig.scala +++ b/storage-redis/src/main/scala/com/devsisters/shardcake/RedisConfig.scala @@ -2,11 +2,12 @@ package com.devsisters.shardcake /** * The configuration for the Redis storage implementation. - * @param assignmentsKey the key to store shard assignments - * @param podsKey the key to store registered pods + * @param assignmentsKey a function from a role to the key to use to store shard assignments + * @param podsKey the key to use to store registered pods */ -case class RedisConfig(assignmentsKey: String, podsKey: String) +case class RedisConfig(assignmentsKey: Role => String, podsKey: String) object RedisConfig { - val default: RedisConfig = RedisConfig(assignmentsKey = "shard_assignments", podsKey = "pods") + val default: RedisConfig = + RedisConfig(assignmentsKey = role => s"shard_assignments:${role.name}", podsKey = "pods") } diff --git a/storage-redis/src/main/scala/com/devsisters/shardcake/StorageRedis.scala b/storage-redis/src/main/scala/com/devsisters/shardcake/StorageRedis.scala index 39536bf..eceebd2 100644 --- a/storage-redis/src/main/scala/com/devsisters/shardcake/StorageRedis.scala +++ b/storage-redis/src/main/scala/com/devsisters/shardcake/StorageRedis.scala @@ -4,6 +4,7 @@ import com.devsisters.shardcake.interfaces.Storage import dev.profunktor.redis4cats.RedisCommands import dev.profunktor.redis4cats.data.RedisChannel import dev.profunktor.redis4cats.pubsub.PubSubCommands +import zio.json._ import zio.stream.ZStream import zio.stream.interop.fs2z._ import zio.{ Task, ZIO, ZLayer } @@ -24,36 +25,41 @@ object StorageRedis { stringClient <- ZIO.service[RedisCommands[Task, String, String]] pubSubClient <- ZIO.service[PubSubCommands[fs2Stream, String, String]] } yield new Storage { - def getAssignments: Task[Map[ShardId, Option[PodAddress]]] = + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = stringClient - .hGetAll(config.assignmentsKey) + .hGetAll(config.assignmentsKey(role)) .map(_.flatMap { case (k, v) => val pod = if (v.isEmpty) None else PodAddress(v) k.toIntOption.map(_ -> pod) }) - def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = + def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = stringClient.hSet( - config.assignmentsKey, + config.assignmentsKey(role), assignments.map { case (k, v) => k.toString -> v.fold("")(_.toString) } ) *> pubSubClient - .publish(RedisChannel(config.assignmentsKey))(fs2.Stream.eval[Task, String](ZIO.succeed("ping"))) + .publish(RedisChannel(config.assignmentsKey(role)))(fs2.Stream.eval[Task, String](ZIO.succeed("ping"))) .toZStream(1) .runDrain - def assignmentsStream: ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = - pubSubClient.subscribe(RedisChannel(config.assignmentsKey)).toZStream(1).mapZIO(_ => getAssignments) + def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = + pubSubClient + .subscribe(RedisChannel(config.assignmentsKey(role))) + .toZStream(1) + .mapZIO(_ => getAssignments(role)) def getPods: Task[Map[PodAddress, Pod]] = stringClient .hGetAll(config.podsKey) - .map(_.toList.flatMap { case (k, v) => PodAddress(k).map(address => address -> Pod(address, v)) }.toMap) + .map(_.toList.flatMap 
{ case (k, v) => + PodAddress(k).flatMap(address => v.fromJson[Pod].toOption.map(address -> _)) + }.toMap) def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = stringClient.del(config.podsKey) *> stringClient - .hSet(config.podsKey, pods.map { case (k, v) => k.toString -> v.version }) + .hSet(config.podsKey, pods.map { case (k, v) => k.toString -> v.toJson }) .when(pods.nonEmpty) .unit } diff --git a/storage-redis/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala b/storage-redis/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala index 62dd39e..98d8bd6 100644 --- a/storage-redis/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala +++ b/storage-redis/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala @@ -1,13 +1,13 @@ package com.devsisters.shardcake -import com.devsisters.shardcake.StorageRedis.{ fs2Stream, Redis } +import com.devsisters.shardcake.StorageRedis.Redis import com.devsisters.shardcake.interfaces.Storage import com.dimafeng.testcontainers.GenericContainer +import dev.profunktor.redis4cats.Redis import dev.profunktor.redis4cats.connection.RedisClient import dev.profunktor.redis4cats.data.RedisCodec import dev.profunktor.redis4cats.effect.Log -import dev.profunktor.redis4cats.pubsub.{ PubSub, PubSubCommands } -import dev.profunktor.redis4cats.{ Redis, RedisCommands } +import dev.profunktor.redis4cats.pubsub.PubSub import zio.Clock.ClockLive import zio._ import zio.interop.catz._ @@ -49,12 +49,15 @@ object StorageRedisSpec extends ZIOSpecDefault { ) } + private val role = Role("default") + def spec: Spec[TestEnvironment with Scope, Any] = suite("StorageRedisSpec")( test("save and get pods") { - val expected = List(Pod(PodAddress("host1", 1), "1.0.0"), Pod(PodAddress("host2", 2), "2.0.0")) - .map(p => p.address -> p) - .toMap + val expected = + List(Pod(PodAddress("host1", 1), "1.0.0", Set(role)), Pod(PodAddress("host2", 2), "2.0.0", Set(role))) + .map(p => p.address -> p) + .toMap for { _ <- ZIO.serviceWithZIO[Storage](_.savePods(expected)) actual <- ZIO.serviceWithZIO[Storage](_.getPods) @@ -63,17 +66,17 @@ object StorageRedisSpec extends ZIOSpecDefault { test("save and get assignments") { val expected = Map(1 -> Some(PodAddress("host1", 1)), 2 -> None) for { - _ <- ZIO.serviceWithZIO[Storage](_.saveAssignments(expected)) - actual <- ZIO.serviceWithZIO[Storage](_.getAssignments) + _ <- ZIO.serviceWithZIO[Storage](_.saveAssignments(role, expected)) + actual <- ZIO.serviceWithZIO[Storage](_.getAssignments(role)) } yield assertTrue(expected == actual) }, test("assignments stream") { val expected = Map(1 -> Some(PodAddress("host1", 1)), 2 -> None) for { p <- Promise.make[Nothing, Map[Int, Option[PodAddress]]] - _ <- ZStream.serviceWithStream[Storage](_.assignmentsStream).runForeach(p.succeed(_)).fork + _ <- ZStream.serviceWithStream[Storage](_.assignmentsStream(role)).runForeach(p.succeed(_)).fork _ <- ClockLive.sleep(1 second) - _ <- ZIO.serviceWithZIO[Storage](_.saveAssignments(expected)) + _ <- ZIO.serviceWithZIO[Storage](_.saveAssignments(role, expected)) actual <- p.await } yield assertTrue(expected == actual) } diff --git a/storage-redisson/src/main/scala/com/devsisters/shardcake/RedisConfig.scala b/storage-redisson/src/main/scala/com/devsisters/shardcake/RedisConfig.scala index 3e64c3b..73d0a53 100644 --- a/storage-redisson/src/main/scala/com/devsisters/shardcake/RedisConfig.scala +++ b/storage-redisson/src/main/scala/com/devsisters/shardcake/RedisConfig.scala @@ -2,11 +2,12 @@ package com.devsisters.shardcake /** * The 
configuration for the Redis storage implementation. - * @param assignmentsKey the key to store shard assignments - * @param podsKey the key to store registered pods + * @param assignmentsKey a function from a role to the key to use to store shard assignments + * @param podsKey the key to use to store registered pods */ -case class RedisConfig(assignmentsKey: String, podsKey: String) +case class RedisConfig(assignmentsKey: Role => String, podsKey: String) object RedisConfig { - val default: RedisConfig = RedisConfig(assignmentsKey = "shard_assignments", podsKey = "pods") + val default: RedisConfig = + RedisConfig(assignmentsKey = role => s"shard_assignments:${role.name}", podsKey = "pods") } diff --git a/storage-redisson/src/main/scala/com/devsisters/shardcake/StorageRedis.scala b/storage-redisson/src/main/scala/com/devsisters/shardcake/StorageRedis.scala index 7859653..6f230c5 100644 --- a/storage-redisson/src/main/scala/com/devsisters/shardcake/StorageRedis.scala +++ b/storage-redisson/src/main/scala/com/devsisters/shardcake/StorageRedis.scala @@ -1,16 +1,16 @@ package com.devsisters.shardcake +import scala.collection.compat._ import scala.jdk.CollectionConverters._ import com.devsisters.shardcake.interfaces.Storage import org.redisson.api.RedissonClient import org.redisson.api.listener.MessageListener import org.redisson.client.codec.StringCodec +import zio.json._ import zio.stream.ZStream import zio.{ Queue, Task, Unsafe, ZIO, ZLayer } -import scala.collection.compat._ - object StorageRedis { /** @@ -19,13 +19,12 @@ object StorageRedis { val live: ZLayer[RedissonClient with RedisConfig, Nothing, Storage] = ZLayer { for { - config <- ZIO.service[RedisConfig] - redisClient <- ZIO.service[RedissonClient] - assignmentsMap = redisClient.getMap[String, String](config.assignmentsKey) - podsMap = redisClient.getMap[String, String](config.podsKey) - assignmentsTopic = redisClient.getTopic(config.assignmentsKey, StringCodec.INSTANCE) + config <- ZIO.service[RedisConfig] + redisClient <- ZIO.service[RedissonClient] + podsMap = redisClient.getMap[String, String](config.podsKey) } yield new Storage { - def getAssignments: Task[Map[ShardId, Option[PodAddress]]] = + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = { + val assignmentsMap = redisClient.getMap[String, String](config.assignmentsKey(role)) ZIO .fromCompletionStage(assignmentsMap.readAllEntrySetAsync()) .map( @@ -38,37 +37,46 @@ object StorageRedis { ) .toMap ) - def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = + } + + def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = { + val assignmentsMap = redisClient.getMap[String, String](config.assignmentsKey(role)) + val assignmentsTopic = redisClient.getTopic(config.assignmentsKey(role), StringCodec.INSTANCE) ZIO.fromCompletionStage(assignmentsMap.putAllAsync(assignments.map { case (k, v) => k.toString -> v.fold("")(_.toString) }.asJava)) *> ZIO.fromCompletionStage(assignmentsTopic.publishAsync("ping")).unit - def assignmentsStream: ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = + } + + def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = ZStream.unwrap { for { - queue <- Queue.unbounded[String] - runtime <- ZIO.runtime[Any] - _ <- ZIO.fromCompletionStage( - assignmentsTopic.addListenerAsync( - classOf[String], - new MessageListener[String] { - def onMessage(channel: CharSequence, msg: String): Unit = - Unsafe.unsafe(implicit unsafe => 
runtime.unsafe.run(queue.offer(msg))) - } - ) - ) - } yield ZStream.fromQueueWithShutdown(queue).mapZIO(_ => getAssignments) + queue <- Queue.unbounded[String] + runtime <- ZIO.runtime[Any] + assignmentsTopic = redisClient.getTopic(config.assignmentsKey(role), StringCodec.INSTANCE) + _ <- ZIO.fromCompletionStage( + assignmentsTopic.addListenerAsync( + classOf[String], + new MessageListener[String] { + def onMessage(channel: CharSequence, msg: String): Unit = + Unsafe.unsafe(implicit unsafe => runtime.unsafe.run(queue.offer(msg))) + } + ) + ) + } yield ZStream.fromQueueWithShutdown(queue).mapZIO(_ => getAssignments(role)) } - def getPods: Task[Map[PodAddress, Pod]] = + def getPods: Task[Map[PodAddress, Pod]] = ZIO .fromCompletionStage(podsMap.readAllEntrySetAsync()) .map( _.asScala - .flatMap(entry => PodAddress(entry.getKey).map(address => address -> Pod(address, entry.getValue))) + .flatMap(entry => + PodAddress(entry.getKey).flatMap(address => entry.getValue.fromJson[Pod].toOption.map(address -> _)) + ) .toMap ) - def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = - ZIO.fromCompletionStage(podsMap.putAllAsync(pods.map { case (k, v) => k.toString -> v.version }.asJava)).unit + def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = + ZIO.fromCompletionStage(podsMap.putAllAsync(pods.map { case (k, v) => k.toString -> v.toJson }.asJava)).unit } } } diff --git a/storage-redisson/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala b/storage-redisson/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala index 9a71b2f..bc8310d 100644 --- a/storage-redisson/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala +++ b/storage-redisson/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala @@ -34,12 +34,15 @@ object StorageRedisSpec extends ZIOSpecDefault { } yield client } + private val role = Role("default") + def spec: Spec[TestEnvironment with Scope, Any] = suite("StorageRedisSpec")( test("save and get pods") { - val expected = List(Pod(PodAddress("host1", 1), "1.0.0"), Pod(PodAddress("host2", 2), "2.0.0")) - .map(p => p.address -> p) - .toMap + val expected = + List(Pod(PodAddress("host1", 1), "1.0.0", Set(role)), Pod(PodAddress("host2", 2), "2.0.0", Set(role))) + .map(p => p.address -> p) + .toMap for { _ <- ZIO.serviceWithZIO[Storage](_.savePods(expected)) actual <- ZIO.serviceWithZIO[Storage](_.getPods) @@ -48,17 +51,17 @@ object StorageRedisSpec extends ZIOSpecDefault { test("save and get assignments") { val expected = Map(1 -> Some(PodAddress("host1", 1)), 2 -> None) for { - _ <- ZIO.serviceWithZIO[Storage](_.saveAssignments(expected)) - actual <- ZIO.serviceWithZIO[Storage](_.getAssignments) + _ <- ZIO.serviceWithZIO[Storage](_.saveAssignments(role, expected)) + actual <- ZIO.serviceWithZIO[Storage](_.getAssignments(role)) } yield assertTrue(expected == actual) }, test("assignments stream") { val expected = Map(1 -> Some(PodAddress("host1", 1)), 2 -> None) for { p <- Promise.make[Nothing, Map[Int, Option[PodAddress]]] - _ <- ZStream.serviceWithStream[Storage](_.assignmentsStream).runForeach(p.succeed(_)).fork + _ <- ZStream.serviceWithStream[Storage](_.assignmentsStream(role)).runForeach(p.succeed(_)).fork _ <- ClockLive.sleep(1 second) - _ <- ZIO.serviceWithZIO[Storage](_.saveAssignments(expected)) + _ <- ZIO.serviceWithZIO[Storage](_.saveAssignments(role, expected)) actual <- p.await } yield assertTrue(expected == actual) } From 8e87cc3d346d5a7bb7364785d511f858a5f1401e Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Mon, 31 Mar 2025 16:29:55 
+0900 Subject: [PATCH 02/23] Limit to 1 role --- .../scala/com/devsisters/shardcake/Pod.scala | 2 +- .../scala/com/devsisters/shardcake/Role.scala | 2 + .../shardcake/interfaces/Pods.scala | 10 +- .../devsisters/shardcake/ShardManager.scala | 246 +++++++++--------- .../shardcake/ShardManagerSpec.scala | 101 +++---- 5 files changed, 165 insertions(+), 196 deletions(-) diff --git a/core/src/main/scala/com/devsisters/shardcake/Pod.scala b/core/src/main/scala/com/devsisters/shardcake/Pod.scala index 1b36fcb..e42d91f 100644 --- a/core/src/main/scala/com/devsisters/shardcake/Pod.scala +++ b/core/src/main/scala/com/devsisters/shardcake/Pod.scala @@ -2,7 +2,7 @@ package com.devsisters.shardcake import zio.json._ -case class Pod(address: PodAddress, version: String, roles: Set[Role]) +case class Pod(address: PodAddress, version: String, role: Role) object Pod { implicit val encoder: JsonEncoder[Pod] = DeriveJsonEncoder.gen[Pod] diff --git a/core/src/main/scala/com/devsisters/shardcake/Role.scala b/core/src/main/scala/com/devsisters/shardcake/Role.scala index 39c24f7..2a8d44b 100644 --- a/core/src/main/scala/com/devsisters/shardcake/Role.scala +++ b/core/src/main/scala/com/devsisters/shardcake/Role.scala @@ -5,6 +5,8 @@ import zio.json._ case class Role(name: String) object Role { + val default: Role = Role("default") + implicit val encoder: JsonEncoder[Role] = DeriveJsonEncoder.gen[Role] implicit val decoder: JsonDecoder[Role] = DeriveJsonDecoder.gen[Role] } diff --git a/core/src/main/scala/com/devsisters/shardcake/interfaces/Pods.scala b/core/src/main/scala/com/devsisters/shardcake/interfaces/Pods.scala index 8f8e6a6..dc288eb 100644 --- a/core/src/main/scala/com/devsisters/shardcake/interfaces/Pods.scala +++ b/core/src/main/scala/com/devsisters/shardcake/interfaces/Pods.scala @@ -1,7 +1,7 @@ package com.devsisters.shardcake.interfaces import com.devsisters.shardcake.interfaces.Pods.BinaryMessage -import com.devsisters.shardcake.{ PodAddress, Role, ShardId } +import com.devsisters.shardcake.{ PodAddress, ShardId } import zio.stream.ZStream import zio.{ Task, ULayer, ZIO, ZLayer } @@ -15,12 +15,12 @@ trait Pods { /** * Notify a pod that it was assigned a list of shards */ - def assignShards(pod: PodAddress, shards: Set[ShardId], role: Role): Task[Unit] + def assignShards(pod: PodAddress, shards: Set[ShardId]): Task[Unit] /** * Notify a pod that it was unassigned a list of shards */ - def unassignShards(pod: PodAddress, shards: Set[ShardId], role: Role): Task[Unit] + def unassignShards(pod: PodAddress, shards: Set[ShardId]): Task[Unit] /** * Check that a pod is responsive @@ -64,8 +64,8 @@ object Pods { */ val noop: ULayer[Pods] = ZLayer.succeed(new Pods { - def assignShards(pod: PodAddress, shards: Set[ShardId], role: Role): Task[Unit] = ZIO.unit - def unassignShards(pod: PodAddress, shards: Set[ShardId], role: Role): Task[Unit] = ZIO.unit + def assignShards(pod: PodAddress, shards: Set[ShardId]): Task[Unit] = ZIO.unit + def unassignShards(pod: PodAddress, shards: Set[ShardId]): Task[Unit] = ZIO.unit def ping(pod: PodAddress): Task[Unit] = ZIO.unit def sendMessage(pod: PodAddress, message: BinaryMessage): Task[Option[Array[Byte]]] = ZIO.none def sendStream( diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index 150293e..fde2c40 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -15,7 +15,7 @@ 
import scala.collection.compat._ * A component in charge of assigning and unassigning shards to/from pods */ class ShardManager( - stateRef: Ref.Synchronized[ShardManagerState], + stateRef: Ref.Synchronized[Map[Role, ShardManagerState]], rebalanceSemaphores: Ref.Synchronized[Map[Role, Semaphore]], eventsHub: Hub[ShardingEvent], healthApi: PodsHealth, @@ -25,7 +25,7 @@ class ShardManager( ) { def getAssignments(role: Role): UIO[Map[ShardId, Option[PodAddress]]] = - stateRef.get.map(_.shards(role)) + stateRef.get.map(_.get(role).fold(Map.empty[ShardId, Option[PodAddress]])(_.shards)) def getShardingEvents: ZStream[Any, Nothing, ShardingEvent] = ZStream.fromHub(eventsHub) @@ -34,25 +34,35 @@ class ShardManager( ZIO.ifZIO(healthApi.isAlive(pod.address))( onTrue = for { _ <- ZIO.logInfo(s"Registering $pod") - state <- stateRef.updateAndGetZIO(state => - ZIO - .succeed(OffsetDateTime.now()) - .map(cdt => state.copy(pods = state.pods.updated(pod.address, PodWithMetadata(pod, cdt)))) - ) - _ <- ManagerMetrics.pods.increment - _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address, pod.roles)) - _ <- ZIO.foreachDiscard(pod.roles) { role => - ZIO.when(state.unassignedShards(role).nonEmpty)(rebalance(role, rebalanceImmediately = false).forkDaemon) + _ <- ZIO.whenZIO(stateRef.get.map(_.exists { case (role, state) => + state.pods.get(pod.address).exists(_ => role != pod.role) + }))(ZIO.fail(new RuntimeException(s"Pod $pod is already registered with a different role"))) + cdt <- ZIO.succeed(OffsetDateTime.now()) + state <- stateRef.modify { states => + val previous = states.getOrElse(pod.role, ShardManagerState(config.getNumberOfShards(pod.role))) + val state = previous.copy(pods = previous.pods.updated(pod.address, PodWithMetadata(pod, cdt))) + (state, states.updated(pod.role, state)) } + _ <- ManagerMetrics.pods.increment + _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address, pod.role)) + _ <- ZIO.when(state.unassignedShards.nonEmpty)( + rebalance(pod.role, rebalanceImmediately = false).forkDaemon + ) _ <- persistPods.forkDaemon } yield (), onFalse = ZIO.logWarning(s"Pod $pod requested to register but is not alive, ignoring") *> ZIO.fail(new RuntimeException(s"Pod $pod is not healthy, refusing to register")) ) + private def podExists(podAddress: PodAddress): UIO[Boolean] = + podRole(podAddress).map(_.isDefined) + + private def podRole(podAddress: PodAddress): UIO[Option[Role]] = + stateRef.get.map(_.collectFirst { case (role, state) if state.pods.contains(podAddress) => role }) + def notifyUnhealthyPod(podAddress: PodAddress, ignoreMetric: Boolean = false): UIO[Unit] = ZIO - .whenZIODiscard(stateRef.get.map(_.pods.contains(podAddress))) { + .whenZIODiscard(podExists(podAddress)) { ManagerMetrics.podHealthChecked.tagged("pod_address", podAddress.toString).increment.unless(ignoreMetric) *> eventsHub.publish(ShardingEvent.PodHealthChecked(podAddress)) *> ZIO.unlessZIO(healthApi.isAlive(podAddress))( @@ -62,51 +72,43 @@ class ShardManager( def checkAllPodsHealth: UIO[Unit] = for { - pods <- stateRef.get.map(_.pods.keySet) + pods <- stateRef.get.map(_.values.flatMap(_.pods.keySet)) _ <- ZIO.foreachParDiscard(pods)(notifyUnhealthyPod(_, ignoreMetric = true)).withParallelism(4) } yield () def unregister(podAddress: PodAddress): UIO[Unit] = - ZIO.whenZIODiscard(stateRef.get.map(_.pods.contains(podAddress))) { + ZIO.whenCaseZIODiscard(podRole(podAddress)) { case Some(role) => for { _ <- ZIO.logInfo(s"Unregistering $podAddress") - unassignments <- stateRef.modify { state => + unassignments <- 
stateRef.modify { states => + val previous = states.get(role) ( - state.assignments.map { case (role, assignments) => - role -> assignments.shards.collect { - case (shard, Some(p)) if p == podAddress => shard - }.toSet - }.filter { case (_, shards) => shards.nonEmpty }, - state.copy( - pods = state.pods - podAddress, - assignments = state.assignments.map { case (role, assignments) => - role -> assignments.copy( - shards = assignments.shards.map { case (k, v) => - k -> (if (v.contains(podAddress)) None else v) - } + previous + .map(_.shards.collect { case (shard, Some(p)) if p == podAddress => shard }.toSet) + .getOrElse(Set.empty), + previous + .map(p => + p.copy( + pods = p.pods - podAddress, + shards = + p.shards.map { case (k, v) => k -> (if (v.contains(podAddress)) None else v) } ) - } - ) + ) + .fold(states)(states.updated(role, _)) ) } - _ <- ZIO.foreachDiscard(unassignments) { case (role, shards) => - for { - _ <- ManagerMetrics.assignedShards - .tagged("role", role.name) - .tagged("pod_address", podAddress.toString) - .decrementBy(shards.size) - _ <- ManagerMetrics.unassignedShards - .tagged("role", role.name) - .incrementBy(shards.size) - _ <- eventsHub - .publish(ShardingEvent.ShardsUnassigned(podAddress, role, shards)) - .when(shards.nonEmpty) - _ <- rebalance(role, rebalanceImmediately = true).forkDaemon - } yield () - } _ <- ManagerMetrics.pods.decrement + _ <- ManagerMetrics.assignedShards + .tagged("role", role.name) + .tagged("pod_address", podAddress.toString) + .decrementBy(unassignments.size) + _ <- ManagerMetrics.unassignedShards.tagged("role", role.name).incrementBy(unassignments.size) _ <- eventsHub.publish(ShardingEvent.PodUnregistered(podAddress)) + _ <- eventsHub + .publish(ShardingEvent.ShardsUnassigned(podAddress, role, unassignments)) + .when(unassignments.nonEmpty) _ <- persistPods.forkDaemon + _ <- rebalance(role, rebalanceImmediately = true).forkDaemon } yield () } @@ -121,11 +123,11 @@ class ShardManager( private def rebalance(role: Role, rebalanceImmediately: Boolean): UIO[Unit] = getSemaphore(role).flatMap(_.withPermit { for { - state <- stateRef.get + state <- stateRef.get.map(_.getOrElse(role, ShardManagerState(config.getNumberOfShards(role)))) // find which shards to assign and unassign - (assignments, unassignments) = if (rebalanceImmediately || state.unassignedShards(role).nonEmpty) - decideAssignmentsForUnassignedShards(role, state) - else decideAssignmentsForUnbalancedShards(role, state, config.rebalanceRate) + (assignments, unassignments) = if (rebalanceImmediately || state.unassignedShards.nonEmpty) + decideAssignmentsForUnassignedShards(state) + else decideAssignmentsForUnbalancedShards(state, config.rebalanceRate) areChanges = assignments.nonEmpty || unassignments.nonEmpty _ <- (ZIO.logDebug(s"Rebalancing role ${role.name} (rebalanceImmediately=$rebalanceImmediately)") *> ManagerMetrics.rebalances.tagged("role", role.name).increment).when(areChanges) @@ -147,7 +149,7 @@ class ShardManager( // do the unassignments first failed <- ZIO .foreachPar(readyUnassignments.toList) { case (pod, shards) => - (podApi.unassignShards(pod, shards, role) *> updateShardsState(role, shards, None)).foldZIO( + (podApi.unassignShards(pod, shards) *> updateShardsState(role, shards, None)).foldZIO( _ => ZIO.succeed((Set(pod), shards)), _ => ManagerMetrics.assignedShards @@ -172,7 +174,7 @@ class ShardManager( // then do the assignments failedAssignedPods <- ZIO .foreachPar(filteredAssignments.toList) { case (pod, shards) => - (podApi.assignShards(pod, shards, role) *> 
updateShardsState(role, shards, Some(pod))) + (podApi.assignShards(pod, shards) *> updateShardsState(role, shards, Some(pod))) .foldZIO( _ => ZIO.succeed(Set(pod)), _ => @@ -208,13 +210,15 @@ class ShardManager( private def persistAssignments(role: Role): UIO[Unit] = withRetry( - stateRef.get.flatMap(state => stateRepository.saveAssignments(role, state.shards(role))) + stateRef.get.flatMap(states => + stateRepository.saveAssignments(role, states.get(role).map(_.shards).getOrElse(Map.empty)) + ) ) private val persistAllAssignments: UIO[Unit] = withRetry( - stateRef.get.flatMap(state => - ZIO.foreachDiscard(state.assignments) { case (role, assignments) => + stateRef.get.flatMap(states => + ZIO.foreachDiscard(states) { case (role, assignments) => stateRepository.saveAssignments(role, assignments.shards) } ) @@ -222,28 +226,29 @@ class ShardManager( private def persistPods: UIO[Unit] = withRetry( - stateRef.get.flatMap(state => stateRepository.savePods(state.pods.map { case (k, v) => (k, v.pod) })) + stateRef.get.flatMap(states => + stateRepository.savePods(states.values.flatMap(_.pods.map { case (k, v) => (k, v.pod) }).toMap) + ) ) private def updateShardsState(role: Role, shards: Set[ShardId], pod: Option[PodAddress]): Task[Unit] = - stateRef.updateZIO(state => + stateRef.updateZIO { states => + val previous = states.get(role) ZIO - .whenCase(pod) { - case Some(pod) if !state.pods.contains(pod) => ZIO.fail(new Exception(s"Pod $pod is no longer registered")) + .whenCase((previous, pod)) { + case (Some(p), Some(pod)) if !p.pods.contains(pod) => ZIO.fail(new Exception(s"Pod $pod is not registered")) } .as( - state.copy(assignments = - state.assignments.updated( + previous.fold(states)(state => + states.updated( role, - ShardAssignments( - state.shards(role).map { case (shard, assignment) => - shard -> (if (shards.contains(shard)) pod else assignment) - } - ) + state.copy(shards = state.shards.map { case (shard, assignment) => + shard -> (if (shards.contains(shard)) pod else assignment) + }) ) ) ) - ) + } } object ShardManager { @@ -269,8 +274,9 @@ object ShardManager { ZIO.logInfo(s"Ignoring pods that are no longer alive ${failedPods.mkString("[", ", ", "]")}") ) filteredPods = filtered.toMap - roles = filteredPods.flatMap(_._2.roles).toSet + roles = filteredPods.map(_._2.role).toSet _ <- ZIO.when(filteredPods.nonEmpty)(ZIO.logInfo(s"Recovered pods ${filteredPods.mkString("[", ", ", "]")}")) + rolePods = filteredPods.groupBy { case (_, pod) => pod.role }.map { case (role, pods) => role -> pods.values } roleAssignments <- ZIO .foreach(roles) { role => for { @@ -301,22 +307,18 @@ object ShardManager { } .map(_.toMap) cdt <- ZIO.succeed(OffsetDateTime.now()) - initialState = ShardManagerState( - filteredPods.map { case (k, v) => k -> PodWithMetadata(v, cdt) }, - roles - .map(role => - role -> ShardAssignments( - (1 to config.getNumberOfShards(role)).map(_ -> None).toMap ++ - roleAssignments.getOrElse(role, Map.empty) - ) - ) - .toMap, - config.getNumberOfShards - ) - _ <- ManagerMetrics.pods.incrementBy(initialState.pods.size) + initialStates = rolePods.map { case (role, pods) => + role -> ShardManagerState( + pods.map(pod => pod.address -> PodWithMetadata(pod, cdt)).toMap, + (1 to config.getNumberOfShards(role)).map(_ -> None).toMap ++ + roleAssignments.getOrElse(role, Map.empty), + config.getNumberOfShards(role) + ) + } + _ <- ManagerMetrics.pods.incrementBy(filteredPods.size) _ <- ZIO - .foreachDiscard(initialState.roles) { role => - ZIO.foreachDiscard(initialState.shards(role)) { case (_, 
podAddressOpt) => + .foreachDiscard(initialStates) { case (role, state) => + ZIO.foreachDiscard(state.shards) { case (_, podAddressOpt) => podAddressOpt match { case Some(podAddress) => ManagerMetrics.assignedShards @@ -330,7 +332,7 @@ object ShardManager { } } } - state <- Ref.Synchronized.make(initialState) + state <- Ref.Synchronized.make(initialStates) rebalanceSemaphores <- Ref.Synchronized.make(Map.empty[Role, Semaphore]) eventsHub <- Hub.unbounded[ShardingEvent] shardManager = new ShardManager( @@ -352,14 +354,13 @@ object ShardManager { } _ <- shardManager.persistPods.forkDaemon // rebalance immediately if there are unassigned shards - _ <- - ZIO.foreachDiscard(roles)(role => - shardManager.rebalance(role, rebalanceImmediately = initialState.unassignedShards(role).nonEmpty).forkDaemon - ) + _ <- ZIO.foreachDiscard(initialStates) { case (role, state) => + shardManager.rebalance(role, rebalanceImmediately = state.unassignedShards.nonEmpty).forkDaemon + } // start a regular rebalance at the given interval _ <- state.get - .flatMap(state => - ZIO.foreachParDiscard(state.roles)(shardManager.rebalance(_, rebalanceImmediately = false)) + .flatMap(states => + ZIO.foreachParDiscard(states.keySet)(shardManager.rebalance(_, rebalanceImmediately = false)) ) .repeat(Schedule.spaced(config.rebalanceInterval)) .forkDaemon @@ -401,35 +402,26 @@ object ShardManager { if (xs eq ys) 0 else loop(xs, ys) } - case class ShardAssignments(shards: Map[ShardId, Option[PodAddress]]) { - lazy val unassignedShards: Set[ShardId] = shards.collect { case (k, None) => k }.toSet - lazy val shardsPerPod: Map[PodAddress, Set[ShardId]] = - shards.groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet } - } - case class ShardManagerState( pods: Map[PodAddress, PodWithMetadata], - assignments: Map[Role, ShardAssignments], - getNumberOfShards: Role => Int + shards: Map[ShardId, Option[PodAddress]], + numberOfShards: Int ) { + lazy val unassignedShards: Set[ShardId] = shards.collect { case (k, None) => k }.toSet + lazy val averageShardsPerPod: ShardId = if (pods.nonEmpty) shards.size / pods.size else 0 private lazy val podVersions = pods.values.toList.map(extractVersion) lazy val maxVersion: Option[List[ShardId]] = podVersions.maxOption lazy val allPodsHaveMaxVersion: Boolean = podVersions.forall(maxVersion.contains) - lazy val roles: Set[Role] = pods.values.flatMap(_.pod.roles).toSet ++ assignments.keySet - - private lazy val emptyShardsPerPod: Map[PodAddress, Set[ShardId]] = - pods.map { case (k, _) => k -> Set.empty[ShardId] } - private def assignmentsForRole(role: Role): ShardAssignments = - assignments.getOrElse(role, ShardAssignments((1 to getNumberOfShards(role)).map(_ -> None).toMap)) - - def shards(role: Role): Map[ShardId, Option[PodAddress]] = assignmentsForRole(role).shards - def unassignedShards(role: Role): Set[ShardId] = assignmentsForRole(role).unassignedShards - def averageShardsPerPod(role: Role): ShardId = if (pods.nonEmpty) shards(role).size / pods.size else 0 - - def shardsPerPod(role: Role): Map[PodAddress, Set[ShardId]] = - emptyShardsPerPod ++ assignmentsForRole(role).shardsPerPod + lazy val shardsPerPod: Map[PodAddress, Set[ShardId]] = + pods.map { case (k, _) => k -> Set.empty[ShardId] } ++ + shards.groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet } + } + object ShardManagerState { + def apply(numberOfShards: Int): ShardManagerState = + ShardManagerState(Map.empty, (1 to numberOfShards).map(_ -> None).toMap, numberOfShards) } + case class 
PodWithMetadata(pod: Pod, registered: OffsetDateTime) sealed trait ShardingEvent @@ -440,56 +432,50 @@ object ShardManager { case class ShardsUnassigned(pod: PodAddress, role: Role, shards: Set[ShardId]) extends ShardingEvent { override def toString: String = s"ShardsUnassigned(pod=$pod, role=${role.name}, shards=${renderShardIds(shards)})" } - case class PodRegistered(pod: PodAddress, roles: Set[Role]) extends ShardingEvent + case class PodRegistered(pod: PodAddress, role: Role) extends ShardingEvent case class PodUnregistered(pod: PodAddress) extends ShardingEvent case class PodHealthChecked(pod: PodAddress) extends ShardingEvent } def decideAssignmentsForUnassignedShards( - role: Role, state: ShardManagerState ): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = - pickNewPods(state.unassignedShards(role).toList, role, state, rebalanceImmediately = true, 1.0) + pickNewPods(state.unassignedShards.toList, state, rebalanceImmediately = true, 1.0) def decideAssignmentsForUnbalancedShards( - role: Role, state: ShardManagerState, rebalanceRate: Double ): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = { val extraShardsToAllocate = if (state.allPodsHaveMaxVersion) { // don't do regular rebalance in the middle of a rolling update - state - .shardsPerPod(role) - .flatMap { case (_, shards) => - // count how many extra shards compared to the average - val extraShards = (shards.size - state.averageShardsPerPod(role)).max(0) - Random.shuffle(shards).take(extraShards) - } - .toSet + state.shardsPerPod.flatMap { case (_, shards) => + // count how many extra shards compared to the average + val extraShards = (shards.size - state.averageShardsPerPod).max(0) + Random.shuffle(shards).take(extraShards) + }.toSet } else Set.empty val sortedShardsToRebalance = extraShardsToAllocate.toList.sortBy { shard => // handle unassigned shards first, then shards on the pods with most shards, then shards on old pods - state.shards(role).get(shard).flatten.fold((Int.MinValue, OffsetDateTime.MIN)) { pod => + state.shards.get(shard).flatten.fold((Int.MinValue, OffsetDateTime.MIN)) { pod => ( - state.shardsPerPod(role).get(pod).fold(Int.MinValue)(-_.size), + state.shardsPerPod.get(pod).fold(Int.MinValue)(-_.size), state.pods.get(pod).fold(OffsetDateTime.MIN)(_.registered) ) } } - pickNewPods(sortedShardsToRebalance, role, state, rebalanceImmediately = false, rebalanceRate) + pickNewPods(sortedShardsToRebalance, state, rebalanceImmediately = false, rebalanceRate) } private def pickNewPods( shardsToRebalance: List[ShardId], - role: Role, state: ShardManagerState, rebalanceImmediately: Boolean, rebalanceRate: Double ): (Map[PodAddress, Set[ShardId]], Map[PodAddress, Set[ShardId]]) = { - val (_, assignments) = shardsToRebalance.foldLeft((state.shardsPerPod(role), List.empty[(ShardId, PodAddress)])) { + val (_, assignments) = shardsToRebalance.foldLeft((state.shardsPerPod, List.empty[(ShardId, PodAddress)])) { case ((shardsPerPod, assignments), shard) => val unassignedPods = assignments.flatMap { case (shard, _) => - state.shards(role).get(shard).flatten[PodAddress] + state.shards.get(shard).flatten[PodAddress] }.toSet // find pod with least amount of shards shardsPerPod @@ -500,13 +486,13 @@ object ShardManager { // don't assign too many shards to the same pods, unless we need rebalance immediately .filter { case (pod, _) => rebalanceImmediately || - assignments.count { case (_, p) => p == pod } < state.shards(role).size * rebalanceRate + assignments.count { case (_, p) => p == pod } < 
state.shards.size * rebalanceRate } // don't assign to a pod that was unassigned in the same rebalance .filterNot { case (pod, _) => unassignedPods.contains(pod) } .minByOption(_._2.size) match { case Some((pod, shards)) => - val oldPod = state.shards(role).get(shard).flatten + val oldPod = state.shards.get(shard).flatten // if old pod is same as new pod, don't change anything if (oldPod.contains(pod)) (shardsPerPod, assignments) @@ -525,7 +511,7 @@ object ShardManager { case None => (shardsPerPod, assignments) } } - val unassignments = assignments.flatMap { case (shard, _) => state.shards(role).get(shard).flatten.map(shard -> _) } + val unassignments = assignments.flatMap { case (shard, _) => state.shards.get(shard).flatten.map(shard -> _) } val assignmentsPerPod = assignments.groupBy(_._2).map { case (k, v) => k -> v.map(_._1).toSet } val unassignmentsPerPod = unassignments.groupBy(_._2).map { case (k, v) => k -> v.map(_._1).toSet } (assignmentsPerPod, unassignmentsPerPod) diff --git a/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala b/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala index 38d1736..9e45109 100644 --- a/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala +++ b/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala @@ -1,6 +1,6 @@ package com.devsisters.shardcake -import com.devsisters.shardcake.ShardManager.{ PodWithMetadata, ShardAssignments, ShardManagerState } +import com.devsisters.shardcake.ShardManager.{ PodWithMetadata, ShardManagerState } import com.devsisters.shardcake.interfaces.{ Pods, PodsHealth, Storage } import zio._ import zio.stream.ZStream @@ -9,10 +9,10 @@ import zio.test._ import java.time.OffsetDateTime object ShardManagerSpec extends ZIOSpecDefault { - private val role = Role("default") - private val pod1 = PodWithMetadata(Pod(PodAddress("1", 1), "1.0.0", Set(role)), OffsetDateTime.MIN) - private val pod2 = PodWithMetadata(Pod(PodAddress("2", 2), "1.0.0", Set(role)), OffsetDateTime.MIN) - private val pod3 = PodWithMetadata(Pod(PodAddress("3", 3), "1.0.0", Set(role)), OffsetDateTime.MIN) + private val role = Role.default + private val pod1 = PodWithMetadata(Pod(PodAddress("1", 1), "1.0.0", role), OffsetDateTime.MIN) + private val pod2 = PodWithMetadata(Pod(PodAddress("2", 2), "1.0.0", role), OffsetDateTime.MIN) + private val pod3 = PodWithMetadata(Pod(PodAddress("3", 3), "1.0.0", role), OffsetDateTime.MIN) override def spec: Spec[Any, Throwable] = suite("ShardManagerSpec")( @@ -21,11 +21,10 @@ object ShardManagerSpec extends ZIOSpecDefault { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), - assignments = - Map(role -> ShardAssignments(Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)))), - getNumberOfShards = ManagerConfig.default.getNumberOfShards + shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)), + numberOfShards = ManagerConfig.default.numberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue( assignments.contains(pod2.pod.address), assignments.size == 1, @@ -40,55 +39,45 @@ object ShardManagerSpec extends ZIOSpecDefault { pod1.pod.address -> pod1, pod2.pod.address -> pod2.copy(pod = pod2.pod.copy(version = "0.1.2")) ), // older version - assignments = - Map(role -> ShardAssignments(Map(1 -> Some(pod1.pod.address), 2 -> 
Some(pod1.pod.address)))), - getNumberOfShards = ManagerConfig.default.getNumberOfShards + shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)), + numberOfShards = ManagerConfig.default.numberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue(assignments.isEmpty, unassignments.isEmpty) }, test("Don't rebalance when already well balanced") { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), - assignments = - Map(role -> ShardAssignments(Map(1 -> Some(pod1.pod.address), 2 -> Some(pod2.pod.address)))), - getNumberOfShards = ManagerConfig.default.getNumberOfShards + shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod2.pod.address)), + numberOfShards = ManagerConfig.default.numberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue(assignments.isEmpty, unassignments.isEmpty) }, test("Don't rebalance when only 1 shard difference") { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), - assignments = Map( - role -> ShardAssignments( - Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)) - ) - ), - getNumberOfShards = ManagerConfig.default.getNumberOfShards + shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)), + numberOfShards = ManagerConfig.default.numberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue(assignments.isEmpty, unassignments.isEmpty) }, test("Rebalance when 2 shard difference") { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), - assignments = Map( - role -> ShardAssignments( - Map( - 1 -> Some(pod1.pod.address), - 2 -> Some(pod1.pod.address), - 3 -> Some(pod1.pod.address), - 4 -> Some(pod2.pod.address) - ) - ) + shards = Map( + 1 -> Some(pod1.pod.address), + 2 -> Some(pod1.pod.address), + 3 -> Some(pod1.pod.address), + 4 -> Some(pod2.pod.address) ), - getNumberOfShards = ManagerConfig.default.getNumberOfShards + numberOfShards = ManagerConfig.default.numberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue( assignments.contains(pod2.pod.address), assignments.size == 1, @@ -100,14 +89,10 @@ object ShardManagerSpec extends ZIOSpecDefault { val state = ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2, pod3.pod.address -> pod3), - assignments = Map( - role -> ShardAssignments( - Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)) - ) - ), - getNumberOfShards = ManagerConfig.default.getNumberOfShards + shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)), + numberOfShards = ManagerConfig.default.numberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + val (assignments, unassignments) = 
ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue( assignments.contains(pod3.pod.address), assignments.size == 1, @@ -119,44 +104,42 @@ object ShardManagerSpec extends ZIOSpecDefault { val state = ShardManagerState( pods = Map(), - assignments = Map(role -> ShardAssignments(Map(1 -> Some(pod1.pod.address)))), - getNumberOfShards = ManagerConfig.default.getNumberOfShards + shards = Map(1 -> Some(pod1.pod.address)), + numberOfShards = ManagerConfig.default.numberOfShards ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, state, 1d) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue(assignments.isEmpty, unassignments.isEmpty) }, test("Balance well when 30 nodes are starting one by one") { val state = ShardManagerState( pods = Map(), - assignments = Map(role -> ShardAssignments((1 to 300).map(_ -> None).toMap)), - getNumberOfShards = ManagerConfig.default.getNumberOfShards + shards = (1 to 300).map(_ -> None).toMap, + numberOfShards = ManagerConfig.default.numberOfShards ) val result = (1 to 30).foldLeft(state) { case (state, podNumber) => val podAddress = PodAddress("", podNumber) val s1 = state.copy(pods = - state.pods.updated(podAddress, PodWithMetadata(Pod(podAddress, "v1", Set(role)), OffsetDateTime.now())) + state.pods.updated(podAddress, PodWithMetadata(Pod(podAddress, "v1", role), OffsetDateTime.now())) ) - val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(role, s1, 1d) + val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(s1, 1d) val s2 = unassignments.foldLeft(s1) { case (state, (_, shards)) => shards.foldLeft(state) { case (state, shard) => - state.copy(assignments = Map(role -> ShardAssignments(state.shards(role).updated(shard, None)))) + state.copy(shards = state.shards.updated(shard, None)) } } val s3 = assignments.foldLeft(s2) { case (state, (address, shards)) => shards.foldLeft(state) { case (state, shard) => - state.copy(assignments = - Map(role -> ShardAssignments(state.shards(role).updated(shard, Some(address)))) - ) + state.copy(shards = state.shards.updated(shard, Some(address))) } } s3 } val shardsPerPod = - result.shards(role).groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet } + result.shards.groupBy(_._2).collect { case (Some(address), shards) => address -> shards.keySet } assertTrue(shardsPerPod.values.forall(_.size == 10)) } ), @@ -165,7 +148,7 @@ object ShardManagerSpec extends ZIOSpecDefault { (for { // setup 20 pods first _ <- simulate( - (1 to 20).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", Set(role)))) + (1 to 20).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", role))) ) _ <- TestClock.adjust(10 minutes) assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments(role)) @@ -177,7 +160,7 @@ object ShardManagerSpec extends ZIOSpecDefault { // bring 5 new pods _ <- simulate( - (21 to 25).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", Set(role)))) + (21 to 25).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", role))) ) _ <- TestClock.adjust(20 seconds) assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments(role)) @@ -195,7 +178,7 @@ object ShardManagerSpec extends ZIOSpecDefault { (for { // setup 25 pods first _ <- simulate( - (1 to 25).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), 
"1", Set(role)))) + (1 to 25).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", role))) ) _ <- TestClock.adjust(10 minutes) assignments <- ZIO.serviceWithZIO[ShardManager](_.getAssignments(role)) @@ -221,9 +204,7 @@ object ShardManagerSpec extends ZIOSpecDefault { { val setup = (for { _ <- simulate { - (1 to 10).toList.map(i => - SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", Set(role))) - ) + (1 to 10).toList.map(i => SimulationEvent.PodRegister(Pod(PodAddress("server", i), "1", role))) } _ <- TestClock.adjust(10 minutes) // busy wait for the forked daemon fibers to do their job From 1c6565cf8f9c8f5662871d12426454b26f5577d4 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Mon, 31 Mar 2025 17:01:12 +0900 Subject: [PATCH 03/23] Polish --- .../devsisters/shardcake/ShardManager.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index fde2c40..ca9cdba 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -81,20 +81,22 @@ class ShardManager( for { _ <- ZIO.logInfo(s"Unregistering $podAddress") unassignments <- stateRef.modify { states => - val previous = states.get(role) + val stateOpt = states.get(role) ( - previous + stateOpt .map(_.shards.collect { case (shard, Some(p)) if p == podAddress => shard }.toSet) .getOrElse(Set.empty), - previous - .map(p => - p.copy( - pods = p.pods - podAddress, - shards = - p.shards.map { case (k, v) => k -> (if (v.contains(podAddress)) None else v) } + stateOpt.fold(states)(state => + states.updated( + role, + state.copy( + pods = state.pods - podAddress, + shards = state.shards.map { case (k, v) => + k -> (if (v.contains(podAddress)) None else v) + } ) ) - .fold(states)(states.updated(role, _)) + ) ) } _ <- ManagerMetrics.pods.decrement From 1fa77fd74a889deac0434622bb2ada308335c101 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Mon, 31 Mar 2025 17:07:32 +0900 Subject: [PATCH 04/23] Metrics --- .../devsisters/shardcake/ShardManager.scala | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index ca9cdba..e73776c 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -43,7 +43,7 @@ class ShardManager( val state = previous.copy(pods = previous.pods.updated(pod.address, PodWithMetadata(pod, cdt))) (state, states.updated(pod.role, state)) } - _ <- ManagerMetrics.pods.increment + _ <- ManagerMetrics.pods.tagged("role", pod.role.name).increment _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address, pod.role)) _ <- ZIO.when(state.unassignedShards.nonEmpty)( rebalance(pod.role, rebalanceImmediately = false).forkDaemon @@ -99,7 +99,7 @@ class ShardManager( ) ) } - _ <- ManagerMetrics.pods.decrement + _ <- ManagerMetrics.pods.tagged("role", role.name).decrement _ <- ManagerMetrics.assignedShards .tagged("role", role.name) .tagged("pod_address", podAddress.toString) @@ -317,22 +317,22 @@ object ShardManager { config.getNumberOfShards(role) ) } - _ <- ManagerMetrics.pods.incrementBy(filteredPods.size) _ <- ZIO .foreachDiscard(initialStates) { case (role, state) => - 
ZIO.foreachDiscard(state.shards) { case (_, podAddressOpt) => - podAddressOpt match { - case Some(podAddress) => - ManagerMetrics.assignedShards - .tagged("role", role.name) - .tagged("pod_address", podAddress.toString) - .increment - case None => - ManagerMetrics.unassignedShards - .tagged("role", role.name) - .increment + ManagerMetrics.pods.tagged("role", role.name).incrementBy(state.pods.size) *> + ZIO.foreachDiscard(state.shards) { case (_, podAddressOpt) => + podAddressOpt match { + case Some(podAddress) => + ManagerMetrics.assignedShards + .tagged("role", role.name) + .tagged("pod_address", podAddress.toString) + .increment + case None => + ManagerMetrics.unassignedShards + .tagged("role", role.name) + .increment + } } - } } state <- Ref.Synchronized.make(initialStates) rebalanceSemaphores <- Ref.Synchronized.make(Map.empty[Role, Semaphore]) From 95189d59cbb1feb76492e82e8ff5aa243ccbec59 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Tue, 1 Apr 2025 12:49:49 +0900 Subject: [PATCH 05/23] Fix specs --- .../scala/com/devsisters/shardcake/StorageRedisSpec.scala | 4 ++-- .../scala/com/devsisters/shardcake/StorageRedisSpec.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/storage-redis/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala b/storage-redis/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala index 98d8bd6..ec782ed 100644 --- a/storage-redis/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala +++ b/storage-redis/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala @@ -49,13 +49,13 @@ object StorageRedisSpec extends ZIOSpecDefault { ) } - private val role = Role("default") + private val role = Role.default def spec: Spec[TestEnvironment with Scope, Any] = suite("StorageRedisSpec")( test("save and get pods") { val expected = - List(Pod(PodAddress("host1", 1), "1.0.0", Set(role)), Pod(PodAddress("host2", 2), "2.0.0", Set(role))) + List(Pod(PodAddress("host1", 1), "1.0.0", role), Pod(PodAddress("host2", 2), "2.0.0", role)) .map(p => p.address -> p) .toMap for { diff --git a/storage-redisson/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala b/storage-redisson/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala index bc8310d..0acd99a 100644 --- a/storage-redisson/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala +++ b/storage-redisson/src/test/scala/com/devsisters/shardcake/StorageRedisSpec.scala @@ -34,13 +34,13 @@ object StorageRedisSpec extends ZIOSpecDefault { } yield client } - private val role = Role("default") + private val role = Role.default def spec: Spec[TestEnvironment with Scope, Any] = suite("StorageRedisSpec")( test("save and get pods") { val expected = - List(Pod(PodAddress("host1", 1), "1.0.0", Set(role)), Pod(PodAddress("host2", 2), "2.0.0", Set(role))) + List(Pod(PodAddress("host1", 1), "1.0.0", role), Pod(PodAddress("host2", 2), "2.0.0", role)) .map(p => p.address -> p) .toMap for { From 1d21a118ac6b9bc56e4c8e356e9691ad2873da7e Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Wed, 2 Apr 2025 17:55:37 +0900 Subject: [PATCH 06/23] WIP --- .../com/devsisters/shardcake/Server.scala | 22 +++-- .../com/devsisters/shardcake/Config.scala | 7 +- .../shardcake/ShardManagerClient.scala | 36 ++++--- .../com/devsisters/shardcake/Sharding.scala | 38 ++++---- .../shardcake/internal/GraphQLClient.scala | 93 +++++++++++++++---- 5 files changed, 134 insertions(+), 62 deletions(-) diff --git a/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala 
b/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala index 9d48bb4..63b5d78 100644 --- a/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala +++ b/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala @@ -25,24 +25,26 @@ object Server { pod = PodAddress("localhost", config.shardingPort) shards = (1 to config.numberOfShards).map(_ -> Some(pod)).toMap } yield new ShardManagerClient { - def register(podAddress: PodAddress): Task[Unit] = ZIO.unit - def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit - def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit - def getAssignments: Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards) + def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit + def unregister(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit + def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit + def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards) } } private val memory: ULayer[Storage] = ZLayer { for { - assignmentsRef <- Ref.make(Map.empty[ShardId, Option[PodAddress]]) + assignmentsRef <- Ref.make(Map.empty[Role, Map[ShardId, Option[PodAddress]]]) podsRef <- Ref.make(Map.empty[PodAddress, Pod]) } yield new Storage { - def getAssignments: Task[Map[ShardId, Option[PodAddress]]] = assignmentsRef.get - def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = assignmentsRef.set(assignments) - def assignmentsStream: ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = ZStream.never - def getPods: Task[Map[PodAddress, Pod]] = podsRef.get - def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = podsRef.set(pods) + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = + assignmentsRef.get.map(_.getOrElse(role, Map.empty)) + def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = + assignmentsRef.update(_.updated(role, assignments)) + def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = ZStream.never + def getPods: Task[Map[PodAddress, Pod]] = podsRef.get + def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = podsRef.set(pods) } } diff --git a/entities/src/main/scala/com/devsisters/shardcake/Config.scala b/entities/src/main/scala/com/devsisters/shardcake/Config.scala index 1f2d98f..ba128e5 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/Config.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/Config.scala @@ -17,6 +17,7 @@ import zio._ * @param refreshAssignmentsRetryInterval retry interval in case of failure getting shard assignments from storage * @param unhealthyPodReportInterval interval to report unhealthy pods to the Shard Manager (this exists to prevent calling the Shard Manager for each failed message) * @param simulateRemotePods disable optimizations when sending a message to an entity hosted on the local shards (this will force serialization of all messages) + * @param role role of the current pod */ case class Config( numberOfShards: Int, @@ -29,7 +30,8 @@ case class Config( sendTimeout: Duration, refreshAssignmentsRetryInterval: Duration, unhealthyPodReportInterval: Duration, - simulateRemotePods: Boolean + simulateRemotePods: Boolean, + role: Role ) object Config { @@ -44,6 +46,7 @@ object Config { sendTimeout = 10 seconds, refreshAssignmentsRetryInterval = 5 seconds, unhealthyPodReportInterval = 5 seconds, - simulateRemotePods = false + simulateRemotePods = false, + role = 
Role.default ) } diff --git a/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala b/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala index 4e4bfbe..280df48 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala @@ -3,7 +3,7 @@ package com.devsisters.shardcake import caliban.client.Operations.IsOperation import caliban.client.SelectionBuilder import com.devsisters.shardcake.internal.GraphQLClient -import com.devsisters.shardcake.internal.GraphQLClient.PodAddressInput +import com.devsisters.shardcake.internal.GraphQLClient.{ PodAddressInput, RoleInput } import sttp.client3.SttpBackend import sttp.client3.asynchttpclient.zio.AsyncHttpClientZioBackend import zio.{ Config => _, _ } @@ -12,10 +12,10 @@ import zio.{ Config => _, _ } * An interface to communicate with the Shard Manager API */ trait ShardManagerClient { - def register(podAddress: PodAddress): Task[Unit] - def unregister(podAddress: PodAddress): Task[Unit] + def register(podAddress: PodAddress, role: Role): Task[Unit] + def unregister(podAddress: PodAddress, role: Role): Task[Unit] def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] - def getAssignments: Task[Map[Int, Option[PodAddress]]] + def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] } object ShardManagerClient { @@ -49,10 +49,10 @@ object ShardManagerClient { pod = PodAddress(config.selfHost, config.shardingPort) shards = (1 to config.numberOfShards).map(_ -> Some(pod)).toMap } yield new ShardManagerClient { - def register(podAddress: PodAddress): Task[Unit] = ZIO.unit - def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit - def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit - def getAssignments: Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards) + def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit + def unregister(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit + def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit + def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards) } } @@ -60,24 +60,32 @@ object ShardManagerClient { private def send[Origin: IsOperation, A](query: SelectionBuilder[Origin, A]): Task[A] = sttp.send(query.toRequest(config.shardManagerUri)).map(_.body).absolve - def register(podAddress: PodAddress): Task[Unit] = + def register(podAddress: PodAddress, role: Role): Task[Unit] = send( - GraphQLClient.Mutations.register(PodAddressInput(podAddress.host, podAddress.port), config.serverVersion) + GraphQLClient.Mutations.register( + PodAddressInput(podAddress.host, podAddress.port), + config.serverVersion, + RoleInput(role.name) + ) ).unit - def unregister(podAddress: PodAddress): Task[Unit] = + def unregister(podAddress: PodAddress, role: Role): Task[Unit] = send( - GraphQLClient.Mutations.unregister(PodAddressInput(podAddress.host, podAddress.port), config.serverVersion) + GraphQLClient.Mutations.unregister( + PodAddressInput(podAddress.host, podAddress.port), + config.serverVersion, + RoleInput(role.name) + ) ).unit def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.logWarning(s"Notifying Shard Manager about unhealthy pod $podAddress") *> send(GraphQLClient.Mutations.notifyUnhealthyPod(PodAddressInput(podAddress.host, podAddress.port))) - def getAssignments: Task[Map[Int, Option[PodAddress]]] = + def getAssignments(role: Role): Task[Map[Int, 
Option[PodAddress]]] = send( GraphQLClient.Queries - .getAssignments( + .getAssignments(role.name)( GraphQLClient.Assignment.shardId ~ GraphQLClient.Assignment .pod((GraphQLClient.PodAddress.host ~ GraphQLClient.PodAddress.port).map { case (host, port) => PodAddress(host, port) diff --git a/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala b/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala index 8527f5b..79ea52a 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala @@ -37,25 +37,27 @@ class Sharding private ( val register: Task[Unit] = ZIO.logDebug(s"Registering pod $address to Shard Manager") *> isShuttingDownRef.set(false) *> - shardManager.register(address) + shardManager.register(address, config.role) val unregister: UIO[Unit] = // ping the shard manager first to stop if it's not available - shardManager.getAssignments.foldCauseZIO( - ZIO.logWarningCause("Shard Manager not available. Can't unregister cleanly", _), - _ => - ZIO.logDebug(s"Stopping local entities") *> - isShuttingDownRef.set(true) *> - entityStates.get.flatMap( - ZIO.foreachDiscard(_) { case (name, entity) => - entity.entityManager.terminateAllEntities.forkDaemon // run in a daemon fiber to make sure it doesn't get interrupted - .flatMap(_.join) - .catchAllCause(ZIO.logErrorCause(s"Error during stop of entity $name", _)) - } - ) *> - ZIO.logDebug(s"Unregistering pod $address to Shard Manager") *> - shardManager.unregister(address).catchAllCause(ZIO.logErrorCause("Error during unregister", _)) - ) + shardManager + .getAssignments(config.role) + .foldCauseZIO( + ZIO.logWarningCause("Shard Manager not available. Can't unregister cleanly", _), + _ => + ZIO.logDebug(s"Stopping local entities") *> + isShuttingDownRef.set(true) *> + entityStates.get.flatMap( + ZIO.foreachDiscard(_) { case (name, entity) => + entity.entityManager.terminateAllEntities.forkDaemon // run in a daemon fiber to make sure it doesn't get interrupted + .flatMap(_.join) + .catchAllCause(ZIO.logErrorCause(s"Error during stop of entity $name", _)) + } + ) *> + ZIO.logDebug(s"Unregistering pod $address to Shard Manager") *> + shardManager.unregister(address, config.role).catchAllCause(ZIO.logErrorCause("Error during unregister", _)) + ) val isSingletonNode: UIO[Boolean] = // Start singletons on the pod hosting shard 1. 
@@ -162,10 +164,10 @@ class Sharding private ( latch <- Promise.make[Nothing, Unit] assignmentStream = ZStream.fromZIO( // first, get the assignments from the shard manager directly - shardManager.getAssignments.map(_ -> true) + shardManager.getAssignments(config.role).map(_ -> true) ) ++ // then, get assignments changes from Redis - storage.assignmentsStream.map(_ -> false) + storage.assignmentsStream(config.role).map(_ -> false) _ <- assignmentStream.mapZIO { case (assignmentsOpt, replaceAllAssignments) => updateAssignments(assignmentsOpt, replaceAllAssignments) *> latch.succeed(()).when(replaceAllAssignments) }.runDrain diff --git a/entities/src/main/scala/com/devsisters/shardcake/internal/GraphQLClient.scala b/entities/src/main/scala/com/devsisters/shardcake/internal/GraphQLClient.scala index 3df852f..b8e9138 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/internal/GraphQLClient.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/internal/GraphQLClient.scala @@ -8,9 +8,9 @@ private[shardcake] object GraphQLClient { type Assignment object Assignment { - def shardId: SelectionBuilder[Assignment, Int] = + def shardId: SelectionBuilder[Assignment, Int] = _root_.caliban.client.SelectionBuilder.Field("shardId", Scalar()) - def pod[A](innerSelection: SelectionBuilder[PodAddress, A]): SelectionBuilder[Assignment, Option[A]] = + def pod[A](innerSelection: SelectionBuilder[PodAddress, A]): SelectionBuilder[Assignment, scala.Option[A]] = _root_.caliban.client.SelectionBuilder.Field("pod", OptionOf(Obj(innerSelection))) } @@ -20,10 +20,37 @@ private[shardcake] object GraphQLClient { def port: SelectionBuilder[PodAddress, Int] = _root_.caliban.client.SelectionBuilder.Field("port", Scalar()) } + type PodHealthChecked + object PodHealthChecked { + def pod[A](innerSelection: SelectionBuilder[PodAddress, A]): SelectionBuilder[PodHealthChecked, A] = + _root_.caliban.client.SelectionBuilder.Field("pod", Obj(innerSelection)) + } + + type PodRegistered + object PodRegistered { + def pod[A](innerSelection: SelectionBuilder[PodAddress, A]): SelectionBuilder[PodRegistered, A] = + _root_.caliban.client.SelectionBuilder.Field("pod", Obj(innerSelection)) + def role[A](innerSelection: SelectionBuilder[Role, A]): SelectionBuilder[PodRegistered, A] = + _root_.caliban.client.SelectionBuilder.Field("role", Obj(innerSelection)) + } + + type PodUnregistered + object PodUnregistered { + def pod[A](innerSelection: SelectionBuilder[PodAddress, A]): SelectionBuilder[PodUnregistered, A] = + _root_.caliban.client.SelectionBuilder.Field("pod", Obj(innerSelection)) + } + + type Role + object Role { + def name: SelectionBuilder[Role, String] = _root_.caliban.client.SelectionBuilder.Field("name", Scalar()) + } + type ShardsAssigned object ShardsAssigned { def pod[A](innerSelection: SelectionBuilder[PodAddress, A]): SelectionBuilder[ShardsAssigned, A] = _root_.caliban.client.SelectionBuilder.Field("pod", Obj(innerSelection)) + def role[A](innerSelection: SelectionBuilder[Role, A]): SelectionBuilder[ShardsAssigned, A] = + _root_.caliban.client.SelectionBuilder.Field("role", Obj(innerSelection)) def shards: SelectionBuilder[ShardsAssigned, List[Int]] = _root_.caliban.client.SelectionBuilder.Field("shards", ListOf(Scalar())) } @@ -32,6 +59,8 @@ private[shardcake] object GraphQLClient { object ShardsUnassigned { def pod[A](innerSelection: SelectionBuilder[PodAddress, A]): SelectionBuilder[ShardsUnassigned, A] = _root_.caliban.client.SelectionBuilder.Field("pod", Obj(innerSelection)) + def 
role[A](innerSelection: SelectionBuilder[Role, A]): SelectionBuilder[ShardsUnassigned, A] = + _root_.caliban.client.SelectionBuilder.Field("role", Obj(innerSelection)) def shards: SelectionBuilder[ShardsUnassigned, List[Int]] = _root_.caliban.client.SelectionBuilder.Field("shards", ListOf(Scalar())) } @@ -48,59 +77,87 @@ private[shardcake] object GraphQLClient { ) } } + final case class RoleInput(name: String) + object RoleInput { + implicit val encoder: ArgEncoder[RoleInput] = new ArgEncoder[RoleInput] { + override def encode(value: RoleInput): __Value = + __ObjectValue(List("name" -> implicitly[ArgEncoder[String]].encode(value.name))) + } + } type Queries = _root_.caliban.client.Operations.RootQuery object Queries { - def getAssignments[A]( + def getAssignments[A](role: String)( innerSelection: SelectionBuilder[Assignment, A] - ): SelectionBuilder[_root_.caliban.client.Operations.RootQuery, List[A]] = - _root_.caliban.client.SelectionBuilder.Field("getAssignments", ListOf(Obj(innerSelection))) + )(implicit encoder0: ArgEncoder[String]): SelectionBuilder[_root_.caliban.client.Operations.RootQuery, List[A]] = + _root_.caliban.client.SelectionBuilder.Field( + "getAssignments", + ListOf(Obj(innerSelection)), + arguments = List(Argument("role", role, "String!")(encoder0)) + ) } type Mutations = _root_.caliban.client.Operations.RootMutation object Mutations { - def register(address: PodAddressInput, version: String)(implicit + def register(address: PodAddressInput, version: String, role: RoleInput)(implicit encoder0: ArgEncoder[PodAddressInput], - encoder1: ArgEncoder[String] - ): SelectionBuilder[_root_.caliban.client.Operations.RootMutation, Option[Unit]] = + encoder1: ArgEncoder[String], + encoder2: ArgEncoder[RoleInput] + ): SelectionBuilder[_root_.caliban.client.Operations.RootMutation, scala.Option[Unit]] = _root_.caliban.client.SelectionBuilder.Field( "register", OptionOf(Scalar()), arguments = List( Argument("address", address, "PodAddressInput!")(encoder0), - Argument("version", version, "String!")(encoder1) + Argument("version", version, "String!")(encoder1), + Argument("role", role, "RoleInput!")(encoder2) ) ) - def unregister(address: PodAddressInput, version: String)(implicit + def unregister(address: PodAddressInput, version: String, role: RoleInput)(implicit encoder0: ArgEncoder[PodAddressInput], - encoder1: ArgEncoder[String] - ): SelectionBuilder[_root_.caliban.client.Operations.RootMutation, Option[Unit]] = + encoder1: ArgEncoder[String], + encoder2: ArgEncoder[RoleInput] + ): SelectionBuilder[_root_.caliban.client.Operations.RootMutation, scala.Option[Unit]] = _root_.caliban.client.SelectionBuilder.Field( "unregister", OptionOf(Scalar()), arguments = List( Argument("address", address, "PodAddressInput!")(encoder0), - Argument("version", version, "String!")(encoder1) + Argument("version", version, "String!")(encoder1), + Argument("role", role, "RoleInput!")(encoder2) ) ) def notifyUnhealthyPod(podAddress: PodAddressInput)(implicit encoder0: ArgEncoder[PodAddressInput] - ): SelectionBuilder[_root_.caliban.client.Operations.RootMutation, Unit] = _root_.caliban.client.SelectionBuilder - .Field( + ): SelectionBuilder[_root_.caliban.client.Operations.RootMutation, Unit] = + _root_.caliban.client.SelectionBuilder.Field( "notifyUnhealthyPod", Scalar(), arguments = List(Argument("podAddress", podAddress, "PodAddressInput!")(encoder0)) ) + def checkAllPodsHealth: SelectionBuilder[_root_.caliban.client.Operations.RootMutation, Unit] = + 
_root_.caliban.client.SelectionBuilder.Field("checkAllPodsHealth", Scalar()) } type Subscriptions = _root_.caliban.client.Operations.RootSubscription object Subscriptions { def events[A]( + onPodHealthChecked: SelectionBuilder[PodHealthChecked, A], + onPodRegistered: SelectionBuilder[PodRegistered, A], + onPodUnregistered: SelectionBuilder[PodUnregistered, A], onShardsAssigned: SelectionBuilder[ShardsAssigned, A], onShardsUnassigned: SelectionBuilder[ShardsUnassigned, A] - ): SelectionBuilder[_root_.caliban.client.Operations.RootSubscription, A] = _root_.caliban.client.SelectionBuilder - .Field( + ): SelectionBuilder[_root_.caliban.client.Operations.RootSubscription, A] = + _root_.caliban.client.SelectionBuilder.Field( "events", - ChoiceOf(Map("ShardsAssigned" -> Obj(onShardsAssigned), "ShardsUnassigned" -> Obj(onShardsUnassigned))) + ChoiceOf( + Map( + "PodHealthChecked" -> Obj(onPodHealthChecked), + "PodRegistered" -> Obj(onPodRegistered), + "PodUnregistered" -> Obj(onPodUnregistered), + "ShardsAssigned" -> Obj(onShardsAssigned), + "ShardsUnassigned" -> Obj(onShardsUnassigned) + ) + ) ) } From 6df86fd865ffd08541f1c9386c5139dc04351640 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Thu, 3 Apr 2025 12:38:59 +0900 Subject: [PATCH 07/23] Tag metrics --- .../main/scala/com/devsisters/shardcake/Sharding.scala | 9 +++++---- .../devsisters/shardcake/internal/EntityManager.scala | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala b/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala index 79ea52a..bfbc5b3 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala @@ -70,7 +70,7 @@ class Sharding private ( ZIO.foreach(singletons) { case (name, run, None) => ZIO.logDebug(s"Starting singleton $name") *> - Metrics.singletons.tagged("singleton_name", name).increment *> + Metrics.singletons.tagged("role", config.role.name).tagged("singleton_name", name).increment *> run.forkDaemon.map(fiber => (name, run, Some(fiber))) case other => ZIO.succeed(other) } @@ -85,7 +85,7 @@ class Sharding private ( ZIO.foreach(singletons) { case (name, run, Some(fiber)) => ZIO.logDebug(s"Stopping singleton $name") *> - Metrics.singletons.tagged("singleton_name", name).decrement *> + Metrics.singletons.tagged("role", config.role.name).tagged("singleton_name", name).decrement *> fiber.interrupt.as((name, run, None)) case other => ZIO.succeed(other) } @@ -102,7 +102,7 @@ class Sharding private ( ZIO .unlessZIO(isShuttingDown) { shardAssignments.update(shards.foldLeft(_) { case (map, shard) => map.updated(shard, address) }) *> - Metrics.shards.incrementBy(shards.size) *> + Metrics.shards.tagged("role", config.role.name).incrementBy(shards.size) *> startSingletonsIfNeeded *> ZIO.logDebug(s"Assigned shards: ${renderShardIds(shards)}") } @@ -118,7 +118,7 @@ class Sharding private ( _.entityManager.terminateEntitiesOnShards(shards) // this will return once all shards are terminated ) ) *> - Metrics.shards.decrementBy(shards.size) *> + Metrics.shards.tagged("role", config.role.name).decrementBy(shards.size) *> stopSingletonsIfNeeded *> ZIO.logDebug(s"Unassigned shards: ${renderShardIds(shards)}") @@ -147,6 +147,7 @@ class Sharding private ( val assignments = assignmentsOpt.flatMap { case (k, v) => v.map(k -> _) } ZIO.logDebug("Received new shard assignments") *> Metrics.shards + .tagged("role", config.role.name) .set(assignmentsOpt.count { case (_, 
podOpt) => podOpt.contains(address) }) .when(replaceAllAssignments) *> (if (replaceAllAssignments) shardAssignments.set(assignments) diff --git a/entities/src/main/scala/com/devsisters/shardcake/internal/EntityManager.scala b/entities/src/main/scala/com/devsisters/shardcake/internal/EntityManager.scala index 18888f5..2c04838 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/internal/EntityManager.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/internal/EntityManager.scala @@ -57,7 +57,7 @@ private[shardcake] object EntityManager { config: Config, entityMaxIdleTime: Option[Duration] ) extends EntityManager[Req] { - private val gauge = Metrics.entities.tagged("type", recipientType.name) + private val gauge = Metrics.entities.tagged("role", config.role.name).tagged("type", recipientType.name) private def startExpirationFiber(entityId: String): UIO[Fiber[Nothing, Unit]] = { val maxIdleTime = entityMaxIdleTime getOrElse config.entityMaxIdleTime From b38bbc5ab11f8c6c9677bfb2099b6b491e697edc Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Thu, 3 Apr 2025 12:49:48 +0900 Subject: [PATCH 08/23] Simplify unregister --- .../scala/com/devsisters/shardcake/Server.scala | 2 +- .../devsisters/shardcake/ShardManagerClient.scala | 14 ++++---------- .../scala/com/devsisters/shardcake/Sharding.scala | 2 +- .../shardcake/internal/GraphQLClient.scala | 15 ++++----------- .../com/devsisters/shardcake/GraphQLApi.scala | 4 ++-- 5 files changed, 12 insertions(+), 25 deletions(-) diff --git a/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala b/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala index 63b5d78..b673afa 100644 --- a/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala +++ b/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala @@ -26,7 +26,7 @@ object Server { shards = (1 to config.numberOfShards).map(_ -> Some(pod)).toMap } yield new ShardManagerClient { def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit - def unregister(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit + def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards) } diff --git a/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala b/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala index 280df48..e65e2b7 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala @@ -13,7 +13,7 @@ import zio.{ Config => _, _ } */ trait ShardManagerClient { def register(podAddress: PodAddress, role: Role): Task[Unit] - def unregister(podAddress: PodAddress, role: Role): Task[Unit] + def unregister(podAddress: PodAddress): Task[Unit] def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] } @@ -50,7 +50,7 @@ object ShardManagerClient { shards = (1 to config.numberOfShards).map(_ -> Some(pod)).toMap } yield new ShardManagerClient { def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit - def unregister(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit + def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] 
= ZIO.succeed(shards) } @@ -69,14 +69,8 @@ object ShardManagerClient { ) ).unit - def unregister(podAddress: PodAddress, role: Role): Task[Unit] = - send( - GraphQLClient.Mutations.unregister( - PodAddressInput(podAddress.host, podAddress.port), - config.serverVersion, - RoleInput(role.name) - ) - ).unit + def unregister(podAddress: PodAddress): Task[Unit] = + send(GraphQLClient.Mutations.unregister(PodAddressInput(podAddress.host, podAddress.port))).unit def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.logWarning(s"Notifying Shard Manager about unhealthy pod $podAddress") *> diff --git a/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala b/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala index bfbc5b3..82c4d8e 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala @@ -56,7 +56,7 @@ class Sharding private ( } ) *> ZIO.logDebug(s"Unregistering pod $address to Shard Manager") *> - shardManager.unregister(address, config.role).catchAllCause(ZIO.logErrorCause("Error during unregister", _)) + shardManager.unregister(address).catchAllCause(ZIO.logErrorCause("Error during unregister", _)) ) val isSingletonNode: UIO[Boolean] = diff --git a/entities/src/main/scala/com/devsisters/shardcake/internal/GraphQLClient.scala b/entities/src/main/scala/com/devsisters/shardcake/internal/GraphQLClient.scala index b8e9138..4508b40 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/internal/GraphQLClient.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/internal/GraphQLClient.scala @@ -8,8 +8,7 @@ private[shardcake] object GraphQLClient { type Assignment object Assignment { - def shardId: SelectionBuilder[Assignment, Int] = - _root_.caliban.client.SelectionBuilder.Field("shardId", Scalar()) + def shardId: SelectionBuilder[Assignment, Int] = _root_.caliban.client.SelectionBuilder.Field("shardId", Scalar()) def pod[A](innerSelection: SelectionBuilder[PodAddress, A]): SelectionBuilder[Assignment, scala.Option[A]] = _root_.caliban.client.SelectionBuilder.Field("pod", OptionOf(Obj(innerSelection))) } @@ -112,19 +111,13 @@ private[shardcake] object GraphQLClient { Argument("role", role, "RoleInput!")(encoder2) ) ) - def unregister(address: PodAddressInput, version: String, role: RoleInput)(implicit - encoder0: ArgEncoder[PodAddressInput], - encoder1: ArgEncoder[String], - encoder2: ArgEncoder[RoleInput] + def unregister(podAddress: PodAddressInput)(implicit + encoder0: ArgEncoder[PodAddressInput] ): SelectionBuilder[_root_.caliban.client.Operations.RootMutation, scala.Option[Unit]] = _root_.caliban.client.SelectionBuilder.Field( "unregister", OptionOf(Scalar()), - arguments = List( - Argument("address", address, "PodAddressInput!")(encoder0), - Argument("version", version, "String!")(encoder1), - Argument("role", role, "RoleInput!")(encoder2) - ) + arguments = List(Argument("podAddress", podAddress, "PodAddressInput!")(encoder0)) ) def notifyUnhealthyPod(podAddress: PodAddressInput)(implicit encoder0: ArgEncoder[PodAddressInput] diff --git a/manager/src/main/scala/com/devsisters/shardcake/GraphQLApi.scala b/manager/src/main/scala/com/devsisters/shardcake/GraphQLApi.scala index 8b4ab4a..bdde54b 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/GraphQLApi.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/GraphQLApi.scala @@ -16,7 +16,7 @@ object GraphQLApi extends GenericSchema[ShardManager] { case class PodAddressArgs(podAddress: PodAddress) 
case class Mutations( register: Pod => RIO[ShardManager, Unit], - unregister: Pod => RIO[ShardManager, Unit], + unregister: PodAddressArgs => RIO[ShardManager, Unit], notifyUnhealthyPod: PodAddressArgs => URIO[ShardManager, Unit], checkAllPodsHealth: URIO[ShardManager, Unit] ) @@ -32,7 +32,7 @@ object GraphQLApi extends GenericSchema[ShardManager] { ), Mutations( pod => ZIO.serviceWithZIO(_.register(pod)), - pod => ZIO.serviceWithZIO(_.unregister(pod.address)), + args => ZIO.serviceWithZIO(_.unregister(args.podAddress)), args => ZIO.serviceWithZIO(_.notifyUnhealthyPod(args.podAddress)), ZIO.serviceWithZIO(_.checkAllPodsHealth) ), From cb2bd0b478df6f9cd54affd3b30bbb865ca1f9e9 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Thu, 3 Apr 2025 16:29:44 +0900 Subject: [PATCH 09/23] Polish --- .../devsisters/shardcake/ShardManager.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index e73776c..b74c946 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -33,22 +33,22 @@ class ShardManager( def register(pod: Pod): Task[Unit] = ZIO.ifZIO(healthApi.isAlive(pod.address))( onTrue = for { - _ <- ZIO.logInfo(s"Registering $pod") - _ <- ZIO.whenZIO(stateRef.get.map(_.exists { case (role, state) => - state.pods.get(pod.address).exists(_ => role != pod.role) - }))(ZIO.fail(new RuntimeException(s"Pod $pod is already registered with a different role"))) - cdt <- ZIO.succeed(OffsetDateTime.now()) - state <- stateRef.modify { states => - val previous = states.getOrElse(pod.role, ShardManagerState(config.getNumberOfShards(pod.role))) - val state = previous.copy(pods = previous.pods.updated(pod.address, PodWithMetadata(pod, cdt))) - (state, states.updated(pod.role, state)) - } - _ <- ManagerMetrics.pods.tagged("role", pod.role.name).increment - _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address, pod.role)) - _ <- ZIO.when(state.unassignedShards.nonEmpty)( - rebalance(pod.role, rebalanceImmediately = false).forkDaemon - ) - _ <- persistPods.forkDaemon + _ <- ZIO.logInfo(s"Registering $pod") + _ <- ZIO.whenZIO(stateRef.get.map(_.exists { case (role, state) => + state.pods.get(pod.address).exists(_ => role != pod.role) + }))(ZIO.fail(new RuntimeException(s"Pod $pod is already registered with a different role"))) + cdt <- ZIO.succeed(OffsetDateTime.now()) + triggerRebalance <- stateRef.modify { states => + val previous = + states.getOrElse(pod.role, ShardManagerState(config.getNumberOfShards(pod.role))) + val state = + previous.copy(pods = previous.pods.updated(pod.address, PodWithMetadata(pod, cdt))) + (state.unassignedShards.nonEmpty, states.updated(pod.role, state)) + } + _ <- ManagerMetrics.pods.tagged("role", pod.role.name).increment + _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address, pod.role)) + _ <- ZIO.when(triggerRebalance)(rebalance(pod.role, rebalanceImmediately = false).forkDaemon) + _ <- persistPods.forkDaemon } yield (), onFalse = ZIO.logWarning(s"Pod $pod requested to register but is not alive, ignoring") *> ZIO.fail(new RuntimeException(s"Pod $pod is not healthy, refusing to register")) From 33db721ecc1b58220182c96a1fa8c2ab4f85a7c2 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Thu, 3 Apr 2025 16:34:02 +0900 Subject: [PATCH 10/23] Polish --- 
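[Note, not part of the patch] This commit makes persistAllAssignments delegate to the existing per-role persistAssignments helper instead of re-implementing the retried save for every role, which presumably keeps the retry handling in one place. Below is a minimal, self-contained sketch of the resulting shape; the types are simplified and the names Storage, withRetry and Manager are stand-ins for illustration, not the exact Shardcake code:

    import zio._

    object PersistAllSketch {
      final case class Role(name: String)
      final case class State(shards: Map[Int, Option[String]])

      // Simplified stand-in for the storage interface used by the manager.
      trait Storage { def saveAssignments(role: Role, shards: Map[Int, Option[String]]): Task[Unit] }

      // Hypothetical retry wrapper: retry a few times, then give up silently.
      def withRetry[A](effect: Task[A]): UIO[Unit] =
        effect.retry(Schedule.recurs(3)).ignore

      final class Manager(stateRef: Ref[Map[Role, State]], storage: Storage) {
        // Per-role persistence defined once, with the retry policy attached here.
        def persistAssignments(role: Role): UIO[Unit] =
          withRetry(
            stateRef.get.flatMap(states =>
              ZIO.foreachDiscard(states.get(role).toList)(state => storage.saveAssignments(role, state.shards))
            )
          )

        // After the change: iterate over the known roles and reuse the helper above.
        def persistAllAssignments: UIO[Unit] =
          stateRef.get.flatMap(states => ZIO.foreachDiscard(states.keys)(persistAssignments))
      }
    }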
.../scala/com/devsisters/shardcake/ShardManager.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index b74c946..d514dab 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -217,14 +217,8 @@ class ShardManager( ) ) - private val persistAllAssignments: UIO[Unit] = - withRetry( - stateRef.get.flatMap(states => - ZIO.foreachDiscard(states) { case (role, assignments) => - stateRepository.saveAssignments(role, assignments.shards) - } - ) - ) + private def persistAllAssignments: UIO[Unit] = + stateRef.get.flatMap(states => ZIO.foreachDiscard(states.keys)(persistAssignments)) private def persistPods: UIO[Unit] = withRetry( From fe9692124e8a0a26d44a9f5df472e7d07e6b8bd4 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Thu, 3 Apr 2025 16:36:41 +0900 Subject: [PATCH 11/23] Polish --- .../src/main/scala/com/devsisters/shardcake/ShardManager.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index d514dab..c80a1d2 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -481,8 +481,7 @@ object ShardManager { } // don't assign too many shards to the same pods, unless we need rebalance immediately .filter { case (pod, _) => - rebalanceImmediately || - assignments.count { case (_, p) => p == pod } < state.shards.size * rebalanceRate + rebalanceImmediately || assignments.count { case (_, p) => p == pod } < state.shards.size * rebalanceRate } // don't assign to a pod that was unassigned in the same rebalance .filterNot { case (pod, _) => unassignedPods.contains(pod) } From dcfef0a7dac3f80546b1dd2e299156be4475daad Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 09:49:59 +0900 Subject: [PATCH 12/23] Config by role --- .../shardcake/interfaces/PodsHealth.scala | 8 ++-- .../com/devsisters/shardcake/Config.scala | 10 ++--- .../com/devsisters/shardcake/K8sConfig.scala | 6 +-- .../devsisters/shardcake/K8sPodsHealth.scala | 12 ++--- .../devsisters/shardcake/ShardManager.scala | 44 +++++++++---------- 5 files changed, 40 insertions(+), 40 deletions(-) diff --git a/core/src/main/scala/com/devsisters/shardcake/interfaces/PodsHealth.scala b/core/src/main/scala/com/devsisters/shardcake/interfaces/PodsHealth.scala index d72a081..831fae5 100644 --- a/core/src/main/scala/com/devsisters/shardcake/interfaces/PodsHealth.scala +++ b/core/src/main/scala/com/devsisters/shardcake/interfaces/PodsHealth.scala @@ -1,6 +1,6 @@ package com.devsisters.shardcake.interfaces -import com.devsisters.shardcake.PodAddress +import com.devsisters.shardcake.{ Pod, PodAddress, Role } import zio.{ UIO, ULayer, ZIO, ZLayer } /** @@ -15,7 +15,7 @@ trait PodsHealth { /** * Check if a pod is still alive. 
*/ - def isAlive(podAddress: PodAddress): UIO[Boolean] + def isAlive(pod: Pod): UIO[Boolean] } object PodsHealth { @@ -26,7 +26,7 @@ object PodsHealth { */ val noop: ULayer[PodsHealth] = ZLayer.succeed(new PodsHealth { - def isAlive(podAddress: PodAddress): UIO[Boolean] = ZIO.succeed(true) + def isAlive(pod: Pod): UIO[Boolean] = ZIO.succeed(true) }) /** @@ -35,6 +35,6 @@ object PodsHealth { */ val local: ZLayer[Pods, Nothing, PodsHealth] = ZLayer { - ZIO.serviceWith[Pods](podApi => (podAddress: PodAddress) => podApi.ping(podAddress).option.map(_.isDefined)) + ZIO.serviceWith[Pods](podApi => (pod: Pod) => podApi.ping(pod.address).option.map(_.isDefined)) } } diff --git a/entities/src/main/scala/com/devsisters/shardcake/Config.scala b/entities/src/main/scala/com/devsisters/shardcake/Config.scala index ba128e5..b47331e 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/Config.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/Config.scala @@ -6,6 +6,7 @@ import zio._ /** * Sharding configuration + * @param role role of the current pod * @param numberOfShards number of shards (see documentation on how to choose this), should be same on all nodes * @param selfHost hostname or IP address of the current pod * @param shardingPort port used for pods to communicate together @@ -17,9 +18,9 @@ import zio._ * @param refreshAssignmentsRetryInterval retry interval in case of failure getting shard assignments from storage * @param unhealthyPodReportInterval interval to report unhealthy pods to the Shard Manager (this exists to prevent calling the Shard Manager for each failed message) * @param simulateRemotePods disable optimizations when sending a message to an entity hosted on the local shards (this will force serialization of all messages) - * @param role role of the current pod */ case class Config( + role: Role, numberOfShards: Int, selfHost: String, shardingPort: Int, @@ -30,12 +31,12 @@ case class Config( sendTimeout: Duration, refreshAssignmentsRetryInterval: Duration, unhealthyPodReportInterval: Duration, - simulateRemotePods: Boolean, - role: Role + simulateRemotePods: Boolean ) object Config { val default: Config = Config( + role = Role.default, numberOfShards = 300, selfHost = "localhost", shardingPort = 54321, @@ -46,7 +47,6 @@ object Config { sendTimeout = 10 seconds, refreshAssignmentsRetryInterval = 5 seconds, unhealthyPodReportInterval = 5 seconds, - simulateRemotePods = false, - role = Role.default + simulateRemotePods = false ) } diff --git a/health-k8s/src/main/scala/com/devsisters/shardcake/K8sConfig.scala b/health-k8s/src/main/scala/com/devsisters/shardcake/K8sConfig.scala index c84ed06..ba7d95c 100644 --- a/health-k8s/src/main/scala/com/devsisters/shardcake/K8sConfig.scala +++ b/health-k8s/src/main/scala/com/devsisters/shardcake/K8sConfig.scala @@ -13,10 +13,10 @@ import zio._ case class K8sConfig( cacheSize: Int, cacheDuration: Duration, - namespace: Option[K8sNamespace], - labelSelector: Option[LabelSelector] + namespace: Role => Option[K8sNamespace], + labelSelector: Role => Option[LabelSelector] ) object K8sConfig { - val default: K8sConfig = K8sConfig(500, 3 seconds, None, None) + val default: K8sConfig = K8sConfig(500, 3 seconds, _ => None, _ => None) } diff --git a/health-k8s/src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala b/health-k8s/src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala index 7f162ab..77014d4 100644 --- a/health-k8s/src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala +++ 
b/health-k8s/src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala @@ -24,24 +24,24 @@ object K8sPodsHealth { .make( config.cacheSize, config.cacheDuration, - Lookup { (podAddress: PodAddress) => + Lookup { pod: Pod => pods .getAll( - config.namespace, + config.namespace(pod.role), 1, - Some(FieldSelector.FieldEquals(Chunk("status", "podIP"), podAddress.host)), - config.labelSelector + Some(FieldSelector.FieldEquals(Chunk("status", "podIP"), pod.address.host)), + config.labelSelector(pod.role) ) .runHead .map(_.isDefined) - .tap(ZIO.unless(_)(ZIO.logWarning(s"$podAddress is not found in k8s"))) + .tap(ZIO.unless(_)(ZIO.logWarning(s"${pod.address} is not found in k8s"))) .catchAllCause(cause => ZIO.logErrorCause(s"Error communicating with k8s", cause.map(asException)).as(true) ) } ) } yield new PodsHealth { - def isAlive(podAddress: PodAddress): UIO[Boolean] = cache.get(podAddress) + def isAlive(pod: Pod): UIO[Boolean] = cache.get(pod) } } diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index c80a1d2..bc0cc4b 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -31,7 +31,7 @@ class ShardManager( ZStream.fromHub(eventsHub) def register(pod: Pod): Task[Unit] = - ZIO.ifZIO(healthApi.isAlive(pod.address))( + ZIO.ifZIO(healthApi.isAlive(pod))( onTrue = for { _ <- ZIO.logInfo(s"Registering $pod") _ <- ZIO.whenZIO(stateRef.get.map(_.exists { case (role, state) => @@ -54,18 +54,19 @@ class ShardManager( ZIO.fail(new RuntimeException(s"Pod $pod is not healthy, refusing to register")) ) - private def podExists(podAddress: PodAddress): UIO[Boolean] = - podRole(podAddress).map(_.isDefined) - - private def podRole(podAddress: PodAddress): UIO[Option[Role]] = - stateRef.get.map(_.collectFirst { case (role, state) if state.pods.contains(podAddress) => role }) + private def findPod(podAddress: PodAddress): UIO[Option[Pod]] = + stateRef.get + .map(_.values.collectFirst { + case state if state.pods.contains(podAddress) => state.pods.get(podAddress).map(_.pod) + }) + .map(_.flatten) def notifyUnhealthyPod(podAddress: PodAddress, ignoreMetric: Boolean = false): UIO[Unit] = ZIO - .whenZIODiscard(podExists(podAddress)) { + .whenCaseZIODiscard(findPod(podAddress)) { case Some(pod) => ManagerMetrics.podHealthChecked.tagged("pod_address", podAddress.toString).increment.unless(ignoreMetric) *> eventsHub.publish(ShardingEvent.PodHealthChecked(podAddress)) *> - ZIO.unlessZIO(healthApi.isAlive(podAddress))( + ZIO.unlessZIO(healthApi.isAlive(pod))( ZIO.logWarning(s"Pod $podAddress is not alive, unregistering") *> unregister(podAddress) ) } @@ -77,18 +78,18 @@ class ShardManager( } yield () def unregister(podAddress: PodAddress): UIO[Unit] = - ZIO.whenCaseZIODiscard(podRole(podAddress)) { case Some(role) => + ZIO.whenCaseZIODiscard(findPod(podAddress)) { case Some(pod) => for { _ <- ZIO.logInfo(s"Unregistering $podAddress") unassignments <- stateRef.modify { states => - val stateOpt = states.get(role) + val stateOpt = states.get(pod.role) ( stateOpt .map(_.shards.collect { case (shard, Some(p)) if p == podAddress => shard }.toSet) .getOrElse(Set.empty), stateOpt.fold(states)(state => states.updated( - role, + pod.role, state.copy( pods = state.pods - podAddress, shards = state.shards.map { case (k, v) => @@ -99,18 +100,18 @@ class ShardManager( ) ) } - _ <- ManagerMetrics.pods.tagged("role", role.name).decrement + 
_ <- ManagerMetrics.pods.tagged("role", pod.role.name).decrement _ <- ManagerMetrics.assignedShards - .tagged("role", role.name) + .tagged("role", pod.role.name) .tagged("pod_address", podAddress.toString) .decrementBy(unassignments.size) - _ <- ManagerMetrics.unassignedShards.tagged("role", role.name).incrementBy(unassignments.size) + _ <- ManagerMetrics.unassignedShards.tagged("role", pod.role.name).incrementBy(unassignments.size) _ <- eventsHub.publish(ShardingEvent.PodUnregistered(podAddress)) _ <- eventsHub - .publish(ShardingEvent.ShardsUnassigned(podAddress, role, unassignments)) + .publish(ShardingEvent.ShardsUnassigned(podAddress, pod.role, unassignments)) .when(unassignments.nonEmpty) _ <- persistPods.forkDaemon - _ <- rebalance(role, rebalanceImmediately = true).forkDaemon + _ <- rebalance(pod.role, rebalanceImmediately = true).forkDaemon } yield () } @@ -261,16 +262,15 @@ object ShardManager { podApi <- ZIO.service[Pods] pods <- stateRepository.getPods // remove unhealthy pods on startup - failedFilteredPods <- - ZIO.partitionPar(pods) { addrPod => - ZIO.ifZIO(healthApi.isAlive(addrPod._1))(ZIO.succeed(addrPod), ZIO.fail(addrPod._2)) - } + failedFilteredPods <- ZIO.partitionPar(pods.values) { pod => + ZIO.ifZIO(healthApi.isAlive(pod))(ZIO.succeed(pod), ZIO.fail(pod)) + } (failedPods, filtered) = failedFilteredPods _ <- ZIO.when(failedPods.nonEmpty)( ZIO.logInfo(s"Ignoring pods that are no longer alive ${failedPods.mkString("[", ", ", "]")}") ) - filteredPods = filtered.toMap - roles = filteredPods.map(_._2.role).toSet + filteredPods = filtered.map(p => p.address -> p).toMap + roles = filtered.map(_.role).toSet _ <- ZIO.when(filteredPods.nonEmpty)(ZIO.logInfo(s"Recovered pods ${filteredPods.mkString("[", ", ", "]")}")) rolePods = filteredPods.groupBy { case (_, pod) => pod.role }.map { case (role, pods) => role -> pods.values } roleAssignments <- ZIO From 083407e51562135537d9e18a03efc6ad83426b01 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 09:55:18 +0900 Subject: [PATCH 13/23] Polish --- .../com/devsisters/shardcake/ManagerConfig.scala | 14 ++++---------- .../com/devsisters/shardcake/ShardManager.scala | 8 ++++---- .../devsisters/shardcake/ShardManagerSpec.scala | 16 ++++++++-------- 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ManagerConfig.scala b/manager/src/main/scala/com/devsisters/shardcake/ManagerConfig.scala index bf32ff6..298cdff 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ManagerConfig.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ManagerConfig.scala @@ -4,8 +4,7 @@ import zio._ /** * Shard Manager configuration - * @param numberOfShards number of shards (see documentation on how to choose this), should be same on all nodes - * @param numberOfShardsPerRole overrides of the number of shards per role + * @param numberOfShards number of shards (see documentation on how to choose this), should be same on all pods * @param apiPort port to expose the GraphQL API * @param rebalanceInterval interval for regular rebalancing of shards * @param rebalanceRetryInterval retry interval for rebalancing when some shards failed to be rebalanced @@ -16,8 +15,7 @@ import zio._ * @param podHealthCheckInterval interval for checking pod health */ case class ManagerConfig( - numberOfShards: Int, - numberOfShardsPerRole: Map[Role, Int], + numberOfShards: Role => Int, apiPort: Int, rebalanceInterval: Duration, rebalanceRetryInterval: Duration, @@ -26,16 +24,12 @@ case class 
ManagerConfig( persistRetryCount: Int, rebalanceRate: Double, podHealthCheckInterval: Duration -) { - def getNumberOfShards(role: Role): Int = - numberOfShardsPerRole.getOrElse(role, numberOfShards) -} +) object ManagerConfig { val default: ManagerConfig = ManagerConfig( - numberOfShards = 300, - numberOfShardsPerRole = Map.empty, + numberOfShards = _ => 300, apiPort = 8080, rebalanceInterval = 20 seconds, rebalanceRetryInterval = 10 seconds, diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index bc0cc4b..7d40028 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -40,7 +40,7 @@ class ShardManager( cdt <- ZIO.succeed(OffsetDateTime.now()) triggerRebalance <- stateRef.modify { states => val previous = - states.getOrElse(pod.role, ShardManagerState(config.getNumberOfShards(pod.role))) + states.getOrElse(pod.role, ShardManagerState(config.numberOfShards(pod.role))) val state = previous.copy(pods = previous.pods.updated(pod.address, PodWithMetadata(pod, cdt))) (state.unassignedShards.nonEmpty, states.updated(pod.role, state)) @@ -126,7 +126,7 @@ class ShardManager( private def rebalance(role: Role, rebalanceImmediately: Boolean): UIO[Unit] = getSemaphore(role).flatMap(_.withPermit { for { - state <- stateRef.get.map(_.getOrElse(role, ShardManagerState(config.getNumberOfShards(role)))) + state <- stateRef.get.map(_.getOrElse(role, ShardManagerState(config.numberOfShards(role)))) // find which shards to assign and unassign (assignments, unassignments) = if (rebalanceImmediately || state.unassignedShards.nonEmpty) decideAssignmentsForUnassignedShards(state) @@ -306,9 +306,9 @@ object ShardManager { initialStates = rolePods.map { case (role, pods) => role -> ShardManagerState( pods.map(pod => pod.address -> PodWithMetadata(pod, cdt)).toMap, - (1 to config.getNumberOfShards(role)).map(_ -> None).toMap ++ + (1 to config.numberOfShards(role)).map(_ -> None).toMap ++ roleAssignments.getOrElse(role, Map.empty), - config.getNumberOfShards(role) + config.numberOfShards(role) ) } _ <- ZIO diff --git a/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala b/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala index 9e45109..941b5a0 100644 --- a/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala +++ b/manager/src/test/scala/com/devsisters/shardcake/ShardManagerSpec.scala @@ -22,7 +22,7 @@ object ShardManagerSpec extends ZIOSpecDefault { ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)), - numberOfShards = ManagerConfig.default.numberOfShards + numberOfShards = ManagerConfig.default.numberOfShards(role) ) val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue( @@ -40,7 +40,7 @@ object ShardManagerSpec extends ZIOSpecDefault { pod2.pod.address -> pod2.copy(pod = pod2.pod.copy(version = "0.1.2")) ), // older version shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address)), - numberOfShards = ManagerConfig.default.numberOfShards + numberOfShards = ManagerConfig.default.numberOfShards(role) ) val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue(assignments.isEmpty, unassignments.isEmpty) @@ -50,7 +50,7 @@ object ShardManagerSpec extends 
ZIOSpecDefault { ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod2.pod.address)), - numberOfShards = ManagerConfig.default.numberOfShards + numberOfShards = ManagerConfig.default.numberOfShards(role) ) val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue(assignments.isEmpty, unassignments.isEmpty) @@ -60,7 +60,7 @@ object ShardManagerSpec extends ZIOSpecDefault { ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2), shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)), - numberOfShards = ManagerConfig.default.numberOfShards + numberOfShards = ManagerConfig.default.numberOfShards(role) ) val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue(assignments.isEmpty, unassignments.isEmpty) @@ -75,7 +75,7 @@ object ShardManagerSpec extends ZIOSpecDefault { 3 -> Some(pod1.pod.address), 4 -> Some(pod2.pod.address) ), - numberOfShards = ManagerConfig.default.numberOfShards + numberOfShards = ManagerConfig.default.numberOfShards(role) ) val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue( @@ -90,7 +90,7 @@ object ShardManagerSpec extends ZIOSpecDefault { ShardManagerState( pods = Map(pod1.pod.address -> pod1, pod2.pod.address -> pod2, pod3.pod.address -> pod3), shards = Map(1 -> Some(pod1.pod.address), 2 -> Some(pod1.pod.address), 3 -> Some(pod2.pod.address)), - numberOfShards = ManagerConfig.default.numberOfShards + numberOfShards = ManagerConfig.default.numberOfShards(role) ) val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue( @@ -105,7 +105,7 @@ object ShardManagerSpec extends ZIOSpecDefault { ShardManagerState( pods = Map(), shards = Map(1 -> Some(pod1.pod.address)), - numberOfShards = ManagerConfig.default.numberOfShards + numberOfShards = ManagerConfig.default.numberOfShards(role) ) val (assignments, unassignments) = ShardManager.decideAssignmentsForUnbalancedShards(state, 1d) assertTrue(assignments.isEmpty, unassignments.isEmpty) @@ -115,7 +115,7 @@ object ShardManagerSpec extends ZIOSpecDefault { ShardManagerState( pods = Map(), shards = (1 to 300).map(_ -> None).toMap, - numberOfShards = ManagerConfig.default.numberOfShards + numberOfShards = ManagerConfig.default.numberOfShards(role) ) val result = From 466c65dae62be658a97b8b6a50be5083a2e8182b Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 09:56:06 +0900 Subject: [PATCH 14/23] Fix --- .../src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/health-k8s/src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala b/health-k8s/src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala index 77014d4..267b16d 100644 --- a/health-k8s/src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala +++ b/health-k8s/src/main/scala/com/devsisters/shardcake/K8sPodsHealth.scala @@ -24,7 +24,7 @@ object K8sPodsHealth { .make( config.cacheSize, config.cacheDuration, - Lookup { pod: Pod => + Lookup { (pod: Pod) => pods .getAll( config.namespace(pod.role), From ea8123b2fe07d7e5045d9e2c2bcb5edb6f846b5a Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 10:05:02 +0900 Subject: [PATCH 15/23] Use alias --- .../main/scala/com/devsisters/shardcake/Server.scala | 8 ++++---- 
.../devsisters/shardcake/ShardManagerClient.scala | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala b/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala index b673afa..3c985f3 100644 --- a/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala +++ b/benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala @@ -25,10 +25,10 @@ object Server { pod = PodAddress("localhost", config.shardingPort) shards = (1 to config.numberOfShards).map(_ -> Some(pod)).toMap } yield new ShardManagerClient { - def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit - def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit - def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit - def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards) + def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit + def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit + def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = ZIO.succeed(shards) } } diff --git a/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala b/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala index e65e2b7..eb6c98e 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/ShardManagerClient.scala @@ -15,7 +15,7 @@ trait ShardManagerClient { def register(podAddress: PodAddress, role: Role): Task[Unit] def unregister(podAddress: PodAddress): Task[Unit] def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] - def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] } object ShardManagerClient { @@ -49,10 +49,10 @@ object ShardManagerClient { pod = PodAddress(config.selfHost, config.shardingPort) shards = (1 to config.numberOfShards).map(_ -> Some(pod)).toMap } yield new ShardManagerClient { - def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit - def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit - def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit - def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards) + def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit + def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit + def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = ZIO.succeed(shards) } } @@ -76,7 +76,7 @@ object ShardManagerClient { ZIO.logWarning(s"Notifying Shard Manager about unhealthy pod $podAddress") *> send(GraphQLClient.Mutations.notifyUnhealthyPod(PodAddressInput(podAddress.host, podAddress.port))) - def getAssignments(role: Role): Task[Map[Int, Option[PodAddress]]] = + def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = send( GraphQLClient.Queries .getAssignments(role.name)( From 8339226c4387cb62e6480f86cca31672640bbd4c Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 10:13:49 +0900 Subject: [PATCH 16/23] Polish --- .../devsisters/shardcake/ShardManager.scala | 201 +++++++++--------- 1 file changed, 101 insertions(+), 100 deletions(-) diff --git 
a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index 7d40028..8e78e8e 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -47,7 +47,7 @@ class ShardManager( } _ <- ManagerMetrics.pods.tagged("role", pod.role.name).increment _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address, pod.role)) - _ <- ZIO.when(triggerRebalance)(rebalance(pod.role, rebalanceImmediately = false).forkDaemon) + _ <- ZIO.whenDiscard(triggerRebalance)(rebalance(pod.role, rebalanceImmediately = false).forkDaemon) _ <- persistPods.forkDaemon } yield (), onFalse = ZIO.logWarning(s"Pod $pod requested to register but is not alive, ignoring") *> @@ -256,113 +256,114 @@ object ShardManager { val live: ZLayer[PodsHealth with Pods with Storage with ManagerConfig, Throwable, ShardManager] = ZLayer.scoped { for { - config <- ZIO.service[ManagerConfig] - stateRepository <- ZIO.service[Storage] - healthApi <- ZIO.service[PodsHealth] - podApi <- ZIO.service[Pods] - pods <- stateRepository.getPods + config <- ZIO.service[ManagerConfig] + stateRepository <- ZIO.service[Storage] + healthApi <- ZIO.service[PodsHealth] + podApi <- ZIO.service[Pods] + oldPods <- stateRepository.getPods // remove unhealthy pods on startup - failedFilteredPods <- ZIO.partitionPar(pods.values) { pod => - ZIO.ifZIO(healthApi.isAlive(pod))(ZIO.succeed(pod), ZIO.fail(pod)) - } - (failedPods, filtered) = failedFilteredPods - _ <- ZIO.when(failedPods.nonEmpty)( - ZIO.logInfo(s"Ignoring pods that are no longer alive ${failedPods.mkString("[", ", ", "]")}") - ) - filteredPods = filtered.map(p => p.address -> p).toMap - roles = filtered.map(_.role).toSet - _ <- ZIO.when(filteredPods.nonEmpty)(ZIO.logInfo(s"Recovered pods ${filteredPods.mkString("[", ", ", "]")}")) - rolePods = filteredPods.groupBy { case (_, pod) => pod.role }.map { case (role, pods) => role -> pods.values } - roleAssignments <- ZIO - .foreach(roles) { role => - for { - assignments <- stateRepository.getAssignments(role) - failedFilteredAssignments = partitionMap(assignments) { - case assignment @ (_, Some(address)) - if filteredPods.contains(address) => - Right(assignment) - case assignment => Left(assignment) - } - (failed, filteredAssignments) = failedFilteredAssignments - failedAssignments = failed.collect { case (shard, Some(addr)) => shard -> addr } - _ <- - ZIO.when(failedAssignments.nonEmpty)( - ZIO.logWarning( - s"Ignoring assignments for pods that are no longer alive for role ${role.name}: ${failedAssignments - .mkString("[", ", ", "]")}" - ) + failedFilteredPods <- ZIO.partitionPar(oldPods.values) { pod => + ZIO.ifZIO(healthApi.isAlive(pod))(ZIO.succeed(pod), ZIO.fail(pod)) + } + (failedPods, pods) = failedFilteredPods + _ <- ZIO.whenDiscard(failedPods.nonEmpty)( + ZIO.logInfo(s"Ignoring pods that are no longer alive ${failedPods.mkString("[", ", ", "]")}") + ) + _ <- ZIO.whenDiscard(pods.nonEmpty)( + ZIO.logInfo(s"Recovered pods ${pods.mkString("[", ", ", "]")}") + ) + podsByAddress = pods.map(p => p.address -> p).toMap + podsByRole = pods.groupBy(_.role) + roleAssignments <- ZIO + .foreach(podsByRole.keySet) { role => + for { + assignments <- stateRepository.getAssignments(role) + failedFilteredAssignments = partitionMap(assignments) { + case assignment @ (_, Some(address)) + if podsByAddress.contains(address) => + Right(assignment) + case assignment => Left(assignment) + } + (failed, 
filteredAssignments) = failedFilteredAssignments + failedAssignments = failed.collect { case (shard, Some(addr)) => shard -> addr } + _ <- + ZIO.whenDiscard(failedAssignments.nonEmpty)( + ZIO.logWarning( + s"Ignoring assignments for pods that are no longer alive for role ${role.name}: ${failedAssignments + .mkString("[", ", ", "]")}" ) - _ <- - ZIO.when(filteredAssignments.nonEmpty)( - ZIO.logInfo( - s"Recovered assignments for role ${role.name}: ${filteredAssignments - .mkString("[", ", ", "]")}" - ) + ) + _ <- + ZIO.whenDiscard(filteredAssignments.nonEmpty)( + ZIO.logInfo( + s"Recovered assignments for role ${role.name}: ${filteredAssignments + .mkString("[", ", ", "]")}" ) - } yield role -> filteredAssignments - } - .map(_.toMap) - cdt <- ZIO.succeed(OffsetDateTime.now()) - initialStates = rolePods.map { case (role, pods) => - role -> ShardManagerState( - pods.map(pod => pod.address -> PodWithMetadata(pod, cdt)).toMap, - (1 to config.numberOfShards(role)).map(_ -> None).toMap ++ - roleAssignments.getOrElse(role, Map.empty), - config.numberOfShards(role) - ) + ) + } yield role -> filteredAssignments } - _ <- ZIO - .foreachDiscard(initialStates) { case (role, state) => - ManagerMetrics.pods.tagged("role", role.name).incrementBy(state.pods.size) *> - ZIO.foreachDiscard(state.shards) { case (_, podAddressOpt) => - podAddressOpt match { - case Some(podAddress) => - ManagerMetrics.assignedShards - .tagged("role", role.name) - .tagged("pod_address", podAddress.toString) - .increment - case None => - ManagerMetrics.unassignedShards - .tagged("role", role.name) - .increment - } - } - } - state <- Ref.Synchronized.make(initialStates) - rebalanceSemaphores <- Ref.Synchronized.make(Map.empty[Role, Semaphore]) - eventsHub <- Hub.unbounded[ShardingEvent] - shardManager = new ShardManager( - stateRef = state, - rebalanceSemaphores = rebalanceSemaphores, - eventsHub = eventsHub, - healthApi = healthApi, - podApi = podApi, - stateRepository = stateRepository, - config = config + .map(_.toMap) + cdt <- ZIO.succeed(OffsetDateTime.now()) + initialStates = podsByRole.map { case (role, pods) => + role -> ShardManagerState( + pods.map(pod => pod.address -> PodWithMetadata(pod, cdt)).toMap, + (1 to config.numberOfShards(role)).map(_ -> None).toMap ++ + roleAssignments.getOrElse(role, Map.empty), + config.numberOfShards(role) ) - _ <- ZIO.addFinalizer { - shardManager.persistAllAssignments.catchAllCause(cause => - ZIO.logWarningCause("Failed to persist assignments on shutdown", cause) - ) *> - shardManager.persistPods.catchAllCause(cause => - ZIO.logWarningCause("Failed to persist pods on shutdown", cause) - ) + } + _ <- ZIO + .foreachDiscard(initialStates) { case (role, state) => + ManagerMetrics.pods.tagged("role", role.name).incrementBy(state.pods.size) *> + ZIO.foreachDiscard(state.shards) { case (_, podAddressOpt) => + podAddressOpt match { + case Some(podAddress) => + ManagerMetrics.assignedShards + .tagged("role", role.name) + .tagged("pod_address", podAddress.toString) + .increment + case None => + ManagerMetrics.unassignedShards + .tagged("role", role.name) + .increment + } + } } - _ <- shardManager.persistPods.forkDaemon + state <- Ref.Synchronized.make(initialStates) + rebalanceSemaphores <- Ref.Synchronized.make(Map.empty[Role, Semaphore]) + eventsHub <- Hub.unbounded[ShardingEvent] + shardManager = new ShardManager( + stateRef = state, + rebalanceSemaphores = rebalanceSemaphores, + eventsHub = eventsHub, + healthApi = healthApi, + podApi = podApi, + stateRepository = stateRepository, + config = 
config + ) + _ <- ZIO.addFinalizer { + shardManager.persistAllAssignments.catchAllCause(cause => + ZIO.logWarningCause("Failed to persist assignments on shutdown", cause) + ) *> + shardManager.persistPods.catchAllCause(cause => + ZIO.logWarningCause("Failed to persist pods on shutdown", cause) + ) + } + _ <- shardManager.persistPods.forkDaemon // rebalance immediately if there are unassigned shards - _ <- ZIO.foreachDiscard(initialStates) { case (role, state) => - shardManager.rebalance(role, rebalanceImmediately = state.unassignedShards.nonEmpty).forkDaemon - } + _ <- ZIO.foreachDiscard(initialStates) { case (role, state) => + shardManager.rebalance(role, rebalanceImmediately = state.unassignedShards.nonEmpty).forkDaemon + } // start a regular rebalance at the given interval - _ <- state.get - .flatMap(states => - ZIO.foreachParDiscard(states.keySet)(shardManager.rebalance(_, rebalanceImmediately = false)) - ) - .repeat(Schedule.spaced(config.rebalanceInterval)) - .forkDaemon - _ <- shardManager.getShardingEvents.mapZIO(event => ZIO.logInfo(event.toString)).runDrain.forkDaemon - _ <- shardManager.checkAllPodsHealth.repeat(Schedule.spaced(config.podHealthCheckInterval)).forkDaemon - _ <- ZIO.logInfo("Shard Manager loaded") + _ <- state.get + .flatMap(states => + ZIO.foreachParDiscard(states.keySet)(shardManager.rebalance(_, rebalanceImmediately = false)) + ) + .repeat(Schedule.spaced(config.rebalanceInterval)) + .forkDaemon + _ <- shardManager.getShardingEvents.mapZIO(event => ZIO.logInfo(event.toString)).runDrain.forkDaemon + _ <- shardManager.checkAllPodsHealth.repeat(Schedule.spaced(config.podHealthCheckInterval)).forkDaemon + _ <- ZIO.logInfo("Shard Manager loaded") } yield shardManager } From 97679ac7cc4e527cabade5b12f06ce26e585827b Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 10:14:55 +0900 Subject: [PATCH 17/23] Use alias --- .../scala/com/devsisters/shardcake/interfaces/Storage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/com/devsisters/shardcake/interfaces/Storage.scala b/core/src/main/scala/com/devsisters/shardcake/interfaces/Storage.scala index aa832a6..ad63efe 100644 --- a/core/src/main/scala/com/devsisters/shardcake/interfaces/Storage.scala +++ b/core/src/main/scala/com/devsisters/shardcake/interfaces/Storage.scala @@ -22,7 +22,7 @@ trait Storage { /** * A stream that will emit the state of shard assignments whenever it changes */ - def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[Int, Option[PodAddress]]] + def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] /** * Get the list of existing pods From d519eea3d90ccc0641f258d9883cbcb36bf21886 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 10:22:44 +0900 Subject: [PATCH 18/23] Polish --- .../devsisters/shardcake/ShardManager.scala | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index 8e78e8e..7752665 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -262,10 +262,11 @@ object ShardManager { podApi <- ZIO.service[Pods] oldPods <- stateRepository.getPods // remove unhealthy pods on startup - failedFilteredPods <- ZIO.partitionPar(oldPods.values) { pod => - ZIO.ifZIO(healthApi.isAlive(pod))(ZIO.succeed(pod), 
ZIO.fail(pod)) + aliveStatuses <- ZIO.foreachPar(oldPods.values)(pod => healthApi.isAlive(pod).map(pod -> _)) + (failedPods, pods) = aliveStatuses.partitionMap { + case (pod, false) => Left(pod) + case (pod, true) => Right(pod) } - (failedPods, pods) = failedFilteredPods _ <- ZIO.whenDiscard(failedPods.nonEmpty)( ZIO.logInfo(s"Ignoring pods that are no longer alive ${failedPods.mkString("[", ", ", "]")}") ) @@ -274,33 +275,31 @@ object ShardManager { ) podsByAddress = pods.map(p => p.address -> p).toMap podsByRole = pods.groupBy(_.role) - roleAssignments <- ZIO + assignments <- ZIO .foreach(podsByRole.keySet) { role => for { - assignments <- stateRepository.getAssignments(role) - failedFilteredAssignments = partitionMap(assignments) { - case assignment @ (_, Some(address)) - if podsByAddress.contains(address) => - Right(assignment) - case assignment => Left(assignment) - } - (failed, filteredAssignments) = failedFilteredAssignments - failedAssignments = failed.collect { case (shard, Some(addr)) => shard -> addr } - _ <- + oldAssignments <- stateRepository.getAssignments(role) + (failed, assignments) = partitionMap(oldAssignments) { + case assignment @ (_, Some(address)) + if podsByAddress.contains(address) => + Right(assignment) + case assignment => Left(assignment) + } + failedAssignments = failed.collect { case (shard, Some(addr)) => shard -> addr } + _ <- ZIO.whenDiscard(failedAssignments.nonEmpty)( ZIO.logWarning( s"Ignoring assignments for pods that are no longer alive for role ${role.name}: ${failedAssignments .mkString("[", ", ", "]")}" ) ) - _ <- - ZIO.whenDiscard(filteredAssignments.nonEmpty)( + _ <- + ZIO.whenDiscard(assignments.nonEmpty)( ZIO.logInfo( - s"Recovered assignments for role ${role.name}: ${filteredAssignments - .mkString("[", ", ", "]")}" + s"Recovered assignments for role ${role.name}: ${assignments.mkString("[", ", ", "]")}" ) ) - } yield role -> filteredAssignments + } yield role -> assignments } .map(_.toMap) cdt <- ZIO.succeed(OffsetDateTime.now()) @@ -308,7 +307,7 @@ object ShardManager { role -> ShardManagerState( pods.map(pod => pod.address -> PodWithMetadata(pod, cdt)).toMap, (1 to config.numberOfShards(role)).map(_ -> None).toMap ++ - roleAssignments.getOrElse(role, Map.empty), + assignments.getOrElse(role, Map.empty), config.numberOfShards(role) ) } From d04eab1c6d5d6c03b10bde78c3727ced3ad51218 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 10:25:43 +0900 Subject: [PATCH 19/23] Clean imports --- .../scala/com/devsisters/shardcake/interfaces/PodsHealth.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/com/devsisters/shardcake/interfaces/PodsHealth.scala b/core/src/main/scala/com/devsisters/shardcake/interfaces/PodsHealth.scala index 831fae5..39b90a0 100644 --- a/core/src/main/scala/com/devsisters/shardcake/interfaces/PodsHealth.scala +++ b/core/src/main/scala/com/devsisters/shardcake/interfaces/PodsHealth.scala @@ -1,6 +1,6 @@ package com.devsisters.shardcake.interfaces -import com.devsisters.shardcake.{ Pod, PodAddress, Role } +import com.devsisters.shardcake.Pod import zio.{ UIO, ULayer, ZIO, ZLayer } /** From 09de2824c348b2da168f197d8120eee961be00d5 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 10:31:45 +0900 Subject: [PATCH 20/23] Polish --- .../com/devsisters/shardcake/Sharding.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala 
b/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala index 82c4d8e..5126984 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/Sharding.scala @@ -65,7 +65,7 @@ class Sharding private ( private def startSingletonsIfNeeded: UIO[Unit] = ZIO - .whenZIO(isSingletonNode) { + .whenZIODiscard(isSingletonNode) { singletons.updateZIO { singletons => ZIO.foreach(singletons) { case (name, run, None) => @@ -76,11 +76,10 @@ class Sharding private ( } } } - .unit private def stopSingletonsIfNeeded: UIO[Unit] = ZIO - .unlessZIO(isSingletonNode) { + .unlessZIODiscard(isSingletonNode) { singletons.updateZIO { singletons => ZIO.foreach(singletons) { case (name, run, Some(fiber)) => @@ -91,7 +90,6 @@ class Sharding private ( } } } - .unit def registerSingleton[R](name: String, run: URIO[R, Nothing]): URIO[R, Unit] = ZIO.environment[R].flatMap(env => singletons.update(list => (name, run.provideEnvironment(env), None) :: list)) <* @@ -149,7 +147,7 @@ class Sharding private ( Metrics.shards .tagged("role", config.role.name) .set(assignmentsOpt.count { case (_, podOpt) => podOpt.contains(address) }) - .when(replaceAllAssignments) *> + .whenDiscard(replaceAllAssignments) *> (if (replaceAllAssignments) shardAssignments.set(assignments) else shardAssignments.update(map => @@ -170,7 +168,9 @@ class Sharding private ( // then, get assignments changes from Redis storage.assignmentsStream(config.role).map(_ -> false) _ <- assignmentStream.mapZIO { case (assignmentsOpt, replaceAllAssignments) => - updateAssignments(assignmentsOpt, replaceAllAssignments) *> latch.succeed(()).when(replaceAllAssignments) + updateAssignments(assignmentsOpt, replaceAllAssignments) *> latch + .succeed(()) + .whenDiscard(replaceAllAssignments) }.runDrain .retry(Schedule.fixed(config.refreshAssignmentsRetryInterval)) .interruptible @@ -242,9 +242,9 @@ class Sharding private ( .modify(repliers => (repliers.get(replier.id), repliers - replier.id)) .flatMap(ZIO.foreachDiscard(_)(_.asInstanceOf[ReplyChannel[Reply]].replyStream(replies))) - private def handleError(ex: Throwable): ZIO[Any, Nothing, Any] = + private def handleError(ex: Throwable): ZIO[Any, Nothing, Unit] = ZIO - .whenCase(ex) { case PodUnavailable(pod) => + .whenCaseDiscard(ex) { case PodUnavailable(pod) => val notify = Clock.currentDateTime.flatMap(cdt => lastUnhealthyNodeReported .updateAndGet(old => @@ -253,7 +253,7 @@ class Sharding private ( ) .map(_ isEqual cdt) ) - ZIO.whenZIO(notify)(shardManager.notifyUnhealthyPod(pod).forkDaemon) + ZIO.whenZIODiscard(notify)(shardManager.notifyUnhealthyPod(pod).forkDaemon) } private def sendToSelf[Msg, Res]( From dfe58c6992f95446d09b37d021f89ef821d612a4 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 10:46:25 +0900 Subject: [PATCH 21/23] Renaming --- .../devsisters/shardcake/ShardManager.scala | 108 ++++++++++-------- 1 file changed, 60 insertions(+), 48 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index 7752665..6e6eb98 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -46,7 +46,7 @@ class ShardManager( (state.unassignedShards.nonEmpty, states.updated(pod.role, state)) } _ <- ManagerMetrics.pods.tagged("role", pod.role.name).increment - _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod.address, pod.role)) 
+ _ <- eventsHub.publish(ShardingEvent.PodRegistered(pod)) _ <- ZIO.whenDiscard(triggerRebalance)(rebalance(pod.role, rebalanceImmediately = false).forkDaemon) _ <- persistPods.forkDaemon } yield (), @@ -65,7 +65,7 @@ class ShardManager( ZIO .whenCaseZIODiscard(findPod(podAddress)) { case Some(pod) => ManagerMetrics.podHealthChecked.tagged("pod_address", podAddress.toString).increment.unless(ignoreMetric) *> - eventsHub.publish(ShardingEvent.PodHealthChecked(podAddress)) *> + eventsHub.publish(ShardingEvent.PodHealthChecked(pod)) *> ZIO.unlessZIO(healthApi.isAlive(pod))( ZIO.logWarning(s"Pod $podAddress is not alive, unregistering") *> unregister(podAddress) ) @@ -106,9 +106,9 @@ class ShardManager( .tagged("pod_address", podAddress.toString) .decrementBy(unassignments.size) _ <- ManagerMetrics.unassignedShards.tagged("role", pod.role.name).incrementBy(unassignments.size) - _ <- eventsHub.publish(ShardingEvent.PodUnregistered(podAddress)) + _ <- eventsHub.publish(ShardingEvent.PodUnregistered(pod)) _ <- eventsHub - .publish(ShardingEvent.ShardsUnassigned(podAddress, pod.role, unassignments)) + .publish(ShardingEvent.ShardsUnassigned(pod.address, pod.role, unassignments)) .when(unassignments.nonEmpty) _ <- persistPods.forkDaemon _ <- rebalance(pod.role, rebalanceImmediately = true).forkDaemon @@ -136,34 +136,38 @@ class ShardManager( ManagerMetrics.rebalances.tagged("role", role.name).increment).when(areChanges) // ping pods first to make sure they are ready and remove those who aren't failedPingedPods <- ZIO - .foreachPar(assignments.keySet ++ unassignments.keySet)(pod => + .foreachPar(assignments.keySet ++ unassignments.keySet)(podAddress => podApi - .ping(pod) + .ping(podAddress) .timeout(config.pingTimeout) .someOrFailException - .fold(_ => Set(pod), _ => Set.empty[PodAddress]) + .fold(_ => Set(podAddress), _ => Set.empty[PodAddress]) ) .map(_.flatten) - shardsToRemove = - assignments.collect { case (pod, shards) if failedPingedPods.contains(pod) => shards }.toSet.flatten ++ - unassignments.collect { case (pod, shards) if failedPingedPods.contains(pod) => shards }.toSet.flatten + shardsToRemove = assignments.collect { + case (podAddress, shards) if failedPingedPods.contains(podAddress) => shards + }.toSet.flatten ++ + unassignments.collect { + case (podAddress, shards) if failedPingedPods.contains(podAddress) => shards + }.toSet.flatten readyAssignments = assignments.view.mapValues(_ diff shardsToRemove).filterNot(_._2.isEmpty).toMap readyUnassignments = unassignments.view.mapValues(_ diff shardsToRemove).filterNot(_._2.isEmpty).toMap // do the unassignments first failed <- ZIO - .foreachPar(readyUnassignments.toList) { case (pod, shards) => - (podApi.unassignShards(pod, shards) *> updateShardsState(role, shards, None)).foldZIO( - _ => ZIO.succeed((Set(pod), shards)), + .foreachPar(readyUnassignments.toList) { case (podAddress, shards) => + (podApi.unassignShards(podAddress, shards) *> + updateShardsState(role, shards, None)).foldZIO( + _ => ZIO.succeed((Set(podAddress), shards)), _ => ManagerMetrics.assignedShards .tagged("role", role.name) - .tagged("pod_address", pod.toString) + .tagged("pod_address", podAddress.toString) .decrementBy(shards.size) *> ManagerMetrics.unassignedShards .tagged("role", role.name) .incrementBy(shards.size) *> eventsHub - .publish(ShardingEvent.ShardsUnassigned(pod, role, shards)) + .publish(ShardingEvent.ShardsUnassigned(podAddress, role, shards)) .as((Set.empty, Set.empty)) ) } @@ -171,25 +175,26 @@ class ShardManager( .map { case (pods, shards) => 
(pods.flatten[PodAddress].toSet, shards.flatten[ShardId].toSet) } (failedUnassignedPods, failedUnassignedShards) = failed // remove assignments of shards that couldn't be unassigned, as well as faulty pods - filteredAssignments = (readyAssignments -- failedUnassignedPods).map { case (pod, shards) => - pod -> (shards diff failedUnassignedShards) + filteredAssignments = (readyAssignments -- failedUnassignedPods).map { case (podAddress, shards) => + podAddress -> (shards diff failedUnassignedShards) } // then do the assignments failedAssignedPods <- ZIO - .foreachPar(filteredAssignments.toList) { case (pod, shards) => - (podApi.assignShards(pod, shards) *> updateShardsState(role, shards, Some(pod))) + .foreachPar(filteredAssignments.toList) { case (podAddress, shards) => + (podApi.assignShards(podAddress, shards) *> + updateShardsState(role, shards, Some(podAddress))) .foldZIO( - _ => ZIO.succeed(Set(pod)), + _ => ZIO.succeed(Set(podAddress)), _ => ManagerMetrics.assignedShards .tagged("role", role.name) - .tagged("pod_address", pod.toString) + .tagged("pod_address", podAddress.toString) .incrementBy(shards.size) *> ManagerMetrics.unassignedShards .tagged("role", role.name) .decrementBy(shards.size) *> eventsHub - .publish(ShardingEvent.ShardsAssigned(pod, role, shards)) + .publish(ShardingEvent.ShardsAssigned(podAddress, role, shards)) .as(Set.empty) ) } @@ -228,19 +233,20 @@ class ShardManager( ) ) - private def updateShardsState(role: Role, shards: Set[ShardId], pod: Option[PodAddress]): Task[Unit] = + private def updateShardsState(role: Role, shards: Set[ShardId], podAddress: Option[PodAddress]): Task[Unit] = stateRef.updateZIO { states => val previous = states.get(role) ZIO - .whenCase((previous, pod)) { - case (Some(p), Some(pod)) if !p.pods.contains(pod) => ZIO.fail(new Exception(s"Pod $pod is not registered")) + .whenCase((previous, podAddress)) { + case (Some(state), Some(podAddress)) if !state.pods.contains(podAddress) => + ZIO.fail(new Exception(s"Pod $podAddress is not registered")) } .as( previous.fold(states)(state => states.updated( role, state.copy(shards = state.shards.map { case (shard, assignment) => - shard -> (if (shards.contains(shard)) pod else assignment) + shard -> (if (shards.contains(shard)) podAddress else assignment) }) ) ) @@ -422,15 +428,17 @@ object ShardManager { sealed trait ShardingEvent object ShardingEvent { - case class ShardsAssigned(pod: PodAddress, role: Role, shards: Set[ShardId]) extends ShardingEvent { - override def toString: String = s"ShardsAssigned(pod=$pod, role=${role.name}, shards=${renderShardIds(shards)})" + case class ShardsAssigned(podAddress: PodAddress, role: Role, shards: Set[ShardId]) extends ShardingEvent { + override def toString: String = + s"ShardsAssigned(pod=$podAddress, role=${role.name}, shards=${renderShardIds(shards)})" } - case class ShardsUnassigned(pod: PodAddress, role: Role, shards: Set[ShardId]) extends ShardingEvent { - override def toString: String = s"ShardsUnassigned(pod=$pod, role=${role.name}, shards=${renderShardIds(shards)})" + case class ShardsUnassigned(podAddress: PodAddress, role: Role, shards: Set[ShardId]) extends ShardingEvent { + override def toString: String = + s"ShardsUnassigned(pod=$podAddress, role=${role.name}, shards=${renderShardIds(shards)})" } - case class PodRegistered(pod: PodAddress, role: Role) extends ShardingEvent - case class PodUnregistered(pod: PodAddress) extends ShardingEvent - case class PodHealthChecked(pod: PodAddress) extends ShardingEvent + case class PodRegistered(pod: Pod) 
extends ShardingEvent + case class PodUnregistered(pod: Pod) extends ShardingEvent + case class PodHealthChecked(pod: Pod) extends ShardingEvent } def decideAssignmentsForUnassignedShards( @@ -452,10 +460,10 @@ object ShardManager { } else Set.empty val sortedShardsToRebalance = extraShardsToAllocate.toList.sortBy { shard => // handle unassigned shards first, then shards on the pods with most shards, then shards on old pods - state.shards.get(shard).flatten.fold((Int.MinValue, OffsetDateTime.MIN)) { pod => + state.shards.get(shard).flatten.fold((Int.MinValue, OffsetDateTime.MIN)) { podAddress => ( - state.shardsPerPod.get(pod).fold(Int.MinValue)(-_.size), - state.pods.get(pod).fold(OffsetDateTime.MIN)(_.registered) + state.shardsPerPod.get(podAddress).fold(Int.MinValue)(-_.size), + state.pods.get(podAddress).fold(OffsetDateTime.MIN)(_.registered) ) } } @@ -476,34 +484,38 @@ object ShardManager { // find pod with least amount of shards shardsPerPod // keep only pods with the max version - .filter { case (pod, _) => - state.maxVersion.forall(max => state.pods.get(pod).map(extractVersion).forall(_ == max)) + .filter { case (podAddress, _) => + state.maxVersion.forall(max => state.pods.get(podAddress).map(extractVersion).forall(_ == max)) } // don't assign too many shards to the same pods, unless we need rebalance immediately - .filter { case (pod, _) => - rebalanceImmediately || assignments.count { case (_, p) => p == pod } < state.shards.size * rebalanceRate + .filter { case (podAddress, _) => + rebalanceImmediately || + assignments.count { case (_, p) => p == podAddress } < state.shards.size * rebalanceRate } // don't assign to a pod that was unassigned in the same rebalance - .filterNot { case (pod, _) => unassignedPods.contains(pod) } + .filterNot { case (podAddress, _) => unassignedPods.contains(podAddress) } .minByOption(_._2.size) match { - case Some((pod, shards)) => - val oldPod = state.shards.get(shard).flatten + case Some((podAddress, shards)) => + val oldPodAddress = state.shards.get(shard).flatten // if old pod is same as new pod, don't change anything - if (oldPod.contains(pod)) + if (oldPodAddress.contains(podAddress)) (shardsPerPod, assignments) // if the new pod has more, as much, or only 1 less shard than the old pod, don't change anything else if ( - shardsPerPod.get(pod).fold(0)(_.size) + 1 >= oldPod.fold(Int.MaxValue)( + shardsPerPod.get(podAddress).fold(0)(_.size) + 1 >= oldPodAddress.fold(Int.MaxValue)( shardsPerPod.getOrElse(_, Nil).size ) ) (shardsPerPod, assignments) // otherwise, create a new assignment else { - val unassigned = oldPod.fold(shardsPerPod)(oldPod => shardsPerPod.updatedWith(oldPod)(_.map(_ - shard))) - (unassigned.updated(pod, shards + shard), (shard, pod) :: assignments) + val unassigned = + oldPodAddress.fold(shardsPerPod)(oldPodAddress => + shardsPerPod.updatedWith(oldPodAddress)(_.map(_ - shard)) + ) + (unassigned.updated(podAddress, shards + shard), (shard, podAddress) :: assignments) } - case None => (shardsPerPod, assignments) + case None => (shardsPerPod, assignments) } } val unassignments = assignments.flatMap { case (shard, _) => state.shards.get(shard).flatten.map(shard -> _) } From f7bc6f1e6d26fc1b5590b95b39d7b2b99ad7c311 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Fri, 4 Apr 2025 10:49:08 +0900 Subject: [PATCH 22/23] Rename --- .../main/scala/com/devsisters/shardcake/ShardManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala 
b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala index 6e6eb98..c3065c5 100644 --- a/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala +++ b/manager/src/main/scala/com/devsisters/shardcake/ShardManager.scala @@ -424,7 +424,7 @@ object ShardManager { ShardManagerState(Map.empty, (1 to numberOfShards).map(_ -> None).toMap, numberOfShards) } - case class PodWithMetadata(pod: Pod, registered: OffsetDateTime) + case class PodWithMetadata(pod: Pod, registeredAt: OffsetDateTime) sealed trait ShardingEvent object ShardingEvent { @@ -463,7 +463,7 @@ object ShardManager { state.shards.get(shard).flatten.fold((Int.MinValue, OffsetDateTime.MIN)) { podAddress => ( state.shardsPerPod.get(podAddress).fold(Int.MinValue)(-_.size), - state.pods.get(podAddress).fold(OffsetDateTime.MIN)(_.registered) + state.pods.get(podAddress).fold(OffsetDateTime.MIN)(_.registeredAt) ) } } From 017a5a12ecb1828d8920360b6d74f814800d6a69 Mon Sep 17 00:00:00 2001 From: Pierre Ricadat Date: Wed, 16 Apr 2025 11:41:20 +0900 Subject: [PATCH 23/23] Fix test --- .../test/scala/example/ShardManagerAuthExampleSpec.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/test/scala/example/ShardManagerAuthExampleSpec.scala b/examples/src/test/scala/example/ShardManagerAuthExampleSpec.scala index ecb9c2e..c893f31 100644 --- a/examples/src/test/scala/example/ShardManagerAuthExampleSpec.scala +++ b/examples/src/test/scala/example/ShardManagerAuthExampleSpec.scala @@ -1,6 +1,6 @@ package example -import com.devsisters.shardcake.{ Config, ManagerConfig, Server, ShardManager, ShardManagerClient } +import com.devsisters.shardcake.{ Config, ManagerConfig, Role, Server, ShardManager, ShardManagerClient } import com.devsisters.shardcake.interfaces.{ Pods, PodsHealth, Storage } import sttp.client3.SttpBackend import sttp.client3.asynchttpclient.zio.AsyncHttpClientZioBackend @@ -52,8 +52,8 @@ object ShardManagerAuthExampleSpec extends ZIOSpecDefault { sttpBackendWithAuthTokenLayer("invalid"), ShardManagerClient.live ) - validRequest <- validClient.getAssignments.exit - invalidRequest <- invalidClient.getAssignments.exit + validRequest <- validClient.getAssignments(Role.default).exit + invalidRequest <- invalidClient.getAssignments(Role.default).exit } yield assertTrue(validRequest.isSuccess, invalidRequest.isFailure) } }
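
Note on the Sharding.scala changes earlier in this series: the trailing .unit calls disappear because the code moves to ZIO's *Discard combinators (whenZIODiscard, unlessZIODiscard, whenDiscard, whenCaseDiscard). A minimal sketch of the equivalence follows; the condition and effect are placeholders, only the combinators themselves come from the diff.

import zio._

// Before the patch: whenZIO keeps its result in an Option, so a trailing .unit is needed.
def startIfSingleton(isSingleton: UIO[Boolean], startAll: UIO[Unit]): UIO[Unit] =
  ZIO.whenZIO(isSingleton)(startAll).unit

// After the patch: whenZIODiscard drops the result and returns Unit directly.
def startIfSingletonDiscard(isSingleton: UIO[Boolean], startAll: UIO[Unit]): UIO[Unit] =
  ZIO.whenZIODiscard(isSingleton)(startAll)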
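The rebalancing code renamed in PATCH 21 picks, for each shard that needs a home, the candidate pod with the fewest shards, after filtering out pods that are not on the newest version, pods that already took their share in this pass (unless the rebalance is immediate), and pods that just had shards taken away. The standalone sketch below restates that selection step; the helper and its parameter names are illustrative, not the library's API.

import com.devsisters.shardcake.{ PodAddress, ShardId }

// Illustrative reduction of the per-shard target choice in decideAssignmentsForUnassignedShards.
def pickTargetPod(
  shardsPerPod: Map[PodAddress, Set[ShardId]], // current shard load per pod for this role
  onMaxVersion: PodAddress => Boolean,         // stands in for the state.maxVersion / extractVersion check
  newlyAssignedTo: PodAddress => Int,          // shards already given to this pod in the current pass
  maxNewPerPod: Int,                           // cap derived from the rebalance rate
  unassignedPods: Set[PodAddress],             // pods that just lost shards in the same rebalance
  rebalanceImmediately: Boolean
): Option[PodAddress] =
  shardsPerPod
    .filter { case (podAddress, _) => onMaxVersion(podAddress) }
    .filter { case (podAddress, _) => rebalanceImmediately || newlyAssignedTo(podAddress) < maxNewPerPod }
    .filterNot { case (podAddress, _) => unassignedPods.contains(podAddress) }
    .minByOption { case (_, shards) => shards.size }
    .map { case (podAddress, _) => podAddress }

In the real code the cap corresponds to state.shards.size * rebalanceRate and the version check goes through state.maxVersion and extractVersion; this sketch only keeps the shape of the three filters and the minByOption choice.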
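PATCH 21 also changes PodRegistered, PodUnregistered and PodHealthChecked to carry the full Pod rather than just its address, so subscribers see the pod's role directly. Below is a sketch of a consumer that relies only on the event shapes defined in the diff; how the stream of events is obtained from the manager is left abstract here and is an assumption.

import zio._
import zio.stream.ZStream
import com.devsisters.shardcake.ShardManager.ShardingEvent

// Log every sharding event; pod-level events expose the registered Pod, shard-level
// events expose the pod address, the role and the shard ids.
def logShardingEvents(events: ZStream[Any, Nothing, ShardingEvent]): UIO[Unit] =
  events.mapZIO {
    case ShardingEvent.PodRegistered(pod)                        =>
      ZIO.logInfo(s"pod ${pod.address} registered with role ${pod.role.name}")
    case ShardingEvent.PodUnregistered(pod)                      =>
      ZIO.logInfo(s"pod ${pod.address} unregistered")
    case ShardingEvent.PodHealthChecked(pod)                     =>
      ZIO.logDebug(s"health-checked pod ${pod.address}")
    case ShardingEvent.ShardsAssigned(podAddress, role, shards)  =>
      ZIO.logInfo(s"assigned ${shards.size} shard(s) of role ${role.name} to $podAddress")
    case ShardingEvent.ShardsUnassigned(podAddress, role, shards) =>
      ZIO.logInfo(s"unassigned ${shards.size} shard(s) of role ${role.name} from $podAddress")
  }.runDrain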
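Finally, the test fix in PATCH 23 shows that ShardManagerClient.getAssignments now takes the role to query (Role.default in the example). A hypothetical helper built on that call is sketched below; it assumes the result still maps each shard to an optional pod address, which is not spelled out in this diff, and the helper name is made up for illustration.

import zio._
import com.devsisters.shardcake.{ Role, ShardManagerClient }

// Count the shards of a role that currently have no pod assigned,
// assuming getAssignments answers with one entry per shard.
def unassignedShardCount(role: Role): ZIO[ShardManagerClient, Throwable, Int] =
  ZIO.serviceWithZIO[ShardManagerClient](_.getAssignments(role))
    .map(_.count { case (_, podOpt) => podOpt.isEmpty })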