Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions benchmarks/src/main/scala/com/devsisters/shardcake/Server.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,26 @@ object Server {
pod = PodAddress("localhost", config.shardingPort)
shards = (1 to config.numberOfShards).map(_ -> Some(pod)).toMap
} yield new ShardManagerClient {
def register(podAddress: PodAddress): Task[Unit] = ZIO.unit
def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit
def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit
def getAssignments: Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards)
def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit
def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit
def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit
def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = ZIO.succeed(shards)
}
}

private val memory: ULayer[Storage] =
ZLayer {
for {
assignmentsRef <- Ref.make(Map.empty[ShardId, Option[PodAddress]])
assignmentsRef <- Ref.make(Map.empty[Role, Map[ShardId, Option[PodAddress]]])
podsRef <- Ref.make(Map.empty[PodAddress, Pod])
} yield new Storage {
def getAssignments: Task[Map[ShardId, Option[PodAddress]]] = assignmentsRef.get
def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = assignmentsRef.set(assignments)
def assignmentsStream: ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = ZStream.never
def getPods: Task[Map[PodAddress, Pod]] = podsRef.get
def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = podsRef.set(pods)
def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] =
assignmentsRef.get.map(_.getOrElse(role, Map.empty))
def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] =
assignmentsRef.update(_.updated(role, assignments))
def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = ZStream.never
def getPods: Task[Map[PodAddress, Pod]] = podsRef.get
def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = podsRef.set(pods)
}
}

Expand Down
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ val grpcNettyVersion = "1.71.0"
val zioK8sVersion = "3.1.0"
val zioCacheVersion = "0.2.4"
val zioCatsInteropVersion = "23.1.0.5"
val zioJsonVersion = "0.7.39"
val sttpVersion = "3.10.3"
val calibanVersion = "2.10.0"
val redis4catsVersion = "1.7.2"
Expand Down Expand Up @@ -73,6 +74,7 @@ lazy val core = project
Seq(
"dev.zio" %% "zio" % zioVersion,
"dev.zio" %% "zio-streams" % zioVersion,
"dev.zio" %% "zio-json" % zioJsonVersion,
"org.scala-lang.modules" %% "scala-collection-compat" % scalaCompatVersion
)
)
Expand Down
9 changes: 8 additions & 1 deletion core/src/main/scala/com/devsisters/shardcake/Pod.scala
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
package com.devsisters.shardcake

case class Pod(address: PodAddress, version: String)
import zio.json._

case class Pod(address: PodAddress, version: String, role: Role)

object Pod {
implicit val encoder: JsonEncoder[Pod] = DeriveJsonEncoder.gen[Pod]
implicit val decoder: JsonDecoder[Pod] = DeriveJsonDecoder.gen[Pod]
}
5 changes: 5 additions & 0 deletions core/src/main/scala/com/devsisters/shardcake/PodAddress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package com.devsisters.shardcake

import scala.collection.compat._

import zio.json._

case class PodAddress(host: String, port: Int) {
override def toString: String = s"$host:$port"
}
Expand All @@ -12,4 +14,7 @@ object PodAddress {
case host :: port :: Nil => port.toIntOption.map(port => PodAddress(host, port))
case _ => None
}

implicit val encoder: JsonEncoder[PodAddress] = DeriveJsonEncoder.gen[PodAddress]
implicit val decoder: JsonDecoder[PodAddress] = DeriveJsonDecoder.gen[PodAddress]
}
12 changes: 12 additions & 0 deletions core/src/main/scala/com/devsisters/shardcake/Role.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.devsisters.shardcake

import zio.json._

case class Role(name: String)

object Role {
val default: Role = Role("default")

implicit val encoder: JsonEncoder[Role] = DeriveJsonEncoder.gen[Role]
implicit val decoder: JsonDecoder[Role] = DeriveJsonDecoder.gen[Role]
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.devsisters.shardcake.interfaces

import com.devsisters.shardcake.PodAddress
import com.devsisters.shardcake.Pod
import zio.{ UIO, ULayer, ZIO, ZLayer }

/**
Expand All @@ -15,7 +15,7 @@ trait PodsHealth {
/**
* Check if a pod is still alive.
*/
def isAlive(podAddress: PodAddress): UIO[Boolean]
def isAlive(pod: Pod): UIO[Boolean]
}

object PodsHealth {
Expand All @@ -26,7 +26,7 @@ object PodsHealth {
*/
val noop: ULayer[PodsHealth] =
ZLayer.succeed(new PodsHealth {
def isAlive(podAddress: PodAddress): UIO[Boolean] = ZIO.succeed(true)
def isAlive(pod: Pod): UIO[Boolean] = ZIO.succeed(true)
})

/**
Expand All @@ -35,6 +35,6 @@ object PodsHealth {
*/
val local: ZLayer[Pods, Nothing, PodsHealth] =
ZLayer {
ZIO.serviceWith[Pods](podApi => (podAddress: PodAddress) => podApi.ping(podAddress).option.map(_.isDefined))
ZIO.serviceWith[Pods](podApi => (pod: Pod) => podApi.ping(pod.address).option.map(_.isDefined))
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.devsisters.shardcake.interfaces

import com.devsisters.shardcake.{ Pod, PodAddress, ShardId }
import com.devsisters.shardcake.{ Pod, PodAddress, Role, ShardId }
import zio.{ Ref, Task, ZLayer }
import zio.stream.{ SubscriptionRef, ZStream }

Expand All @@ -12,17 +12,17 @@ trait Storage {
/**
* Get the current state of shard assignments to pods
*/
def getAssignments: Task[Map[ShardId, Option[PodAddress]]]
def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]]

/**
* Save the current state of shard assignments to pods
*/
def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit]
def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit]

/**
* A stream that will emit the state of shard assignments whenever it changes
*/
def assignmentsStream: ZStream[Any, Throwable, Map[Int, Option[PodAddress]]]
def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]]

/**
* Get the list of existing pods
Expand All @@ -47,11 +47,13 @@ object Storage {
assignmentsRef <- SubscriptionRef.make(Map.empty[ShardId, Option[PodAddress]])
podsRef <- Ref.make(Map.empty[PodAddress, Pod])
} yield new Storage {
def getAssignments: Task[Map[ShardId, Option[PodAddress]]] = assignmentsRef.get
def saveAssignments(assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] = assignmentsRef.set(assignments)
def assignmentsStream: ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] = assignmentsRef.changes
def getPods: Task[Map[PodAddress, Pod]] = podsRef.get
def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = podsRef.set(pods)
def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = assignmentsRef.get
def saveAssignments(role: Role, assignments: Map[ShardId, Option[PodAddress]]): Task[Unit] =
assignmentsRef.set(assignments)
def assignmentsStream(role: Role): ZStream[Any, Throwable, Map[ShardId, Option[PodAddress]]] =
assignmentsRef.changes
def getPods: Task[Map[PodAddress, Pod]] = podsRef.get
def savePods(pods: Map[PodAddress, Pod]): Task[Unit] = podsRef.set(pods)
}
}
}
3 changes: 3 additions & 0 deletions entities/src/main/scala/com/devsisters/shardcake/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import zio._

/**
* Sharding configuration
* @param role role of the current pod
* @param numberOfShards number of shards (see documentation on how to choose this), should be same on all nodes
* @param selfHost hostname or IP address of the current pod
* @param shardingPort port used for pods to communicate together
Expand All @@ -20,6 +21,7 @@ import zio._
* @param unregisterRetrySchedule retry schedule for unregistering the pod from the Shard Manager
*/
case class Config(
role: Role,
numberOfShards: Int,
selfHost: String,
shardingPort: Int,
Expand All @@ -36,6 +38,7 @@ case class Config(

object Config {
val default: Config = Config(
role = Role.default,
numberOfShards = 300,
selfHost = "localhost",
shardingPort = 54321,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.devsisters.shardcake
import caliban.client.Operations.IsOperation
import caliban.client.SelectionBuilder
import com.devsisters.shardcake.internal.GraphQLClient
import com.devsisters.shardcake.internal.GraphQLClient.PodAddressInput
import com.devsisters.shardcake.internal.GraphQLClient.{ PodAddressInput, RoleInput }
import sttp.client3.SttpBackend
import sttp.client3.asynchttpclient.zio.AsyncHttpClientZioBackend
import zio.{ Config => _, _ }
Expand All @@ -12,10 +12,10 @@ import zio.{ Config => _, _ }
* An interface to communicate with the Shard Manager API
*/
trait ShardManagerClient {
def register(podAddress: PodAddress): Task[Unit]
def register(podAddress: PodAddress, role: Role): Task[Unit]
def unregister(podAddress: PodAddress): Task[Unit]
def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit]
def getAssignments: Task[Map[Int, Option[PodAddress]]]
def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]]
}

object ShardManagerClient {
Expand Down Expand Up @@ -49,35 +49,37 @@ object ShardManagerClient {
pod = PodAddress(config.selfHost, config.shardingPort)
shards = (1 to config.numberOfShards).map(_ -> Some(pod)).toMap
} yield new ShardManagerClient {
def register(podAddress: PodAddress): Task[Unit] = ZIO.unit
def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit
def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit
def getAssignments: Task[Map[Int, Option[PodAddress]]] = ZIO.succeed(shards)
def register(podAddress: PodAddress, role: Role): Task[Unit] = ZIO.unit
def unregister(podAddress: PodAddress): Task[Unit] = ZIO.unit
def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] = ZIO.unit
def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] = ZIO.succeed(shards)
}
}

class ShardManagerClientLive(sttp: SttpBackend[Task, Any], config: Config) extends ShardManagerClient {
private def send[Origin: IsOperation, A](query: SelectionBuilder[Origin, A]): Task[A] =
sttp.send(query.toRequest(config.shardManagerUri)).map(_.body).absolve

def register(podAddress: PodAddress): Task[Unit] =
def register(podAddress: PodAddress, role: Role): Task[Unit] =
send(
GraphQLClient.Mutations.register(PodAddressInput(podAddress.host, podAddress.port), config.serverVersion)
GraphQLClient.Mutations.register(
PodAddressInput(podAddress.host, podAddress.port),
config.serverVersion,
RoleInput(role.name)
)
).unit

def unregister(podAddress: PodAddress): Task[Unit] =
send(
GraphQLClient.Mutations.unregister(PodAddressInput(podAddress.host, podAddress.port), config.serverVersion)
).unit
send(GraphQLClient.Mutations.unregister(PodAddressInput(podAddress.host, podAddress.port))).unit

def notifyUnhealthyPod(podAddress: PodAddress): Task[Unit] =
ZIO.logWarning(s"Notifying Shard Manager about unhealthy pod $podAddress") *>
send(GraphQLClient.Mutations.notifyUnhealthyPod(PodAddressInput(podAddress.host, podAddress.port)))

def getAssignments: Task[Map[Int, Option[PodAddress]]] =
def getAssignments(role: Role): Task[Map[ShardId, Option[PodAddress]]] =
send(
GraphQLClient.Queries
.getAssignments(
.getAssignments(role.name)(
GraphQLClient.Assignment.shardId ~ GraphQLClient.Assignment
.pod((GraphQLClient.PodAddress.host ~ GraphQLClient.PodAddress.port).map { case (host, port) =>
PodAddress(host, port)
Expand Down
35 changes: 18 additions & 17 deletions entities/src/main/scala/com/devsisters/shardcake/Sharding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class Sharding private (
val register: Task[Unit] =
ZIO.logDebug(s"Registering pod $address to Shard Manager") *>
isShuttingDownRef.set(false) *>
shardManager.register(address)
shardManager.register(address, config.role)

val unregister: UIO[Unit] =
(
// ping the shard manager first to stop if it's not available
shardManager.getAssignments *>
shardManager.getAssignments(config.role) *>
ZIO.logDebug(s"Stopping local entities") *>
isShuttingDownRef.set(true) *>
entityStates.get.flatMap(
Expand All @@ -63,33 +63,31 @@ class Sharding private (

private def startSingletonsIfNeeded: UIO[Unit] =
ZIO
.whenZIO(isSingletonNode) {
.whenZIODiscard(isSingletonNode) {
singletons.updateZIO { singletons =>
ZIO.foreach(singletons) {
case (name, run, None) =>
ZIO.logDebug(s"Starting singleton $name") *>
Metrics.singletons.tagged("singleton_name", name).increment *>
Metrics.singletons.tagged("role", config.role.name).tagged("singleton_name", name).increment *>
run.forkDaemon.map(fiber => (name, run, Some(fiber)))
case other => ZIO.succeed(other)
}
}
}
.unit

private def stopSingletonsIfNeeded: UIO[Unit] =
ZIO
.unlessZIO(isSingletonNode) {
.unlessZIODiscard(isSingletonNode) {
singletons.updateZIO { singletons =>
ZIO.foreach(singletons) {
case (name, run, Some(fiber)) =>
ZIO.logDebug(s"Stopping singleton $name") *>
Metrics.singletons.tagged("singleton_name", name).decrement *>
Metrics.singletons.tagged("role", config.role.name).tagged("singleton_name", name).decrement *>
fiber.interrupt.as((name, run, None))
case other => ZIO.succeed(other)
}
}
}
.unit

def registerSingleton[R](name: String, run: URIO[R, Nothing]): URIO[R, Unit] =
ZIO.environment[R].flatMap(env => singletons.update(list => (name, run.provideEnvironment(env), None) :: list)) <*
Expand All @@ -100,7 +98,7 @@ class Sharding private (
ZIO
.unlessZIO(isShuttingDown) {
shardAssignments.update(shards.foldLeft(_) { case (map, shard) => map.updated(shard, address) }) *>
Metrics.shards.incrementBy(shards.size) *>
Metrics.shards.tagged("role", config.role.name).incrementBy(shards.size) *>
startSingletonsIfNeeded *>
ZIO.logDebug(s"Assigned shards: ${renderShardIds(shards)}")
}
Expand All @@ -116,7 +114,7 @@ class Sharding private (
_.entityManager.terminateEntitiesOnShards(shards) // this will return once all shards are terminated
)
) *>
Metrics.shards.decrementBy(shards.size) *>
Metrics.shards.tagged("role", config.role.name).decrementBy(shards.size) *>
stopSingletonsIfNeeded *>
ZIO.logDebug(s"Unassigned shards: ${renderShardIds(shards)}")

Expand Down Expand Up @@ -145,8 +143,9 @@ class Sharding private (
val assignments = assignmentsOpt.flatMap { case (k, v) => v.map(k -> _) }
ZIO.logDebug("Received new shard assignments") *>
Metrics.shards
.tagged("role", config.role.name)
.set(assignmentsOpt.count { case (_, podOpt) => podOpt.contains(address) })
.when(replaceAllAssignments) *>
.whenDiscard(replaceAllAssignments) *>
(if (replaceAllAssignments) shardAssignments.set(assignments)
else
shardAssignments.update(map =>
Expand All @@ -162,12 +161,14 @@ class Sharding private (
latch <- Promise.make[Nothing, Unit]
assignmentStream = ZStream.fromZIO(
// first, get the assignments from the shard manager directly
shardManager.getAssignments.map(_ -> true)
shardManager.getAssignments(config.role).map(_ -> true)
) ++
// then, get assignments changes from Redis
storage.assignmentsStream.map(_ -> false)
storage.assignmentsStream(config.role).map(_ -> false)
_ <- assignmentStream.mapZIO { case (assignmentsOpt, replaceAllAssignments) =>
updateAssignments(assignmentsOpt, replaceAllAssignments) *> latch.succeed(()).when(replaceAllAssignments)
updateAssignments(assignmentsOpt, replaceAllAssignments) *> latch
.succeed(())
.whenDiscard(replaceAllAssignments)
}.runDrain
.retry(Schedule.fixed(config.refreshAssignmentsRetryInterval))
.interruptible
Expand Down Expand Up @@ -239,9 +240,9 @@ class Sharding private (
.modify(repliers => (repliers.get(replier.id), repliers - replier.id))
.flatMap(ZIO.foreachDiscard(_)(_.asInstanceOf[ReplyChannel[Reply]].replyStream(replies)))

private def handleError(ex: Throwable): ZIO[Any, Nothing, Any] =
private def handleError(ex: Throwable): ZIO[Any, Nothing, Unit] =
ZIO
.whenCase(ex) { case PodUnavailable(pod) =>
.whenCaseDiscard(ex) { case PodUnavailable(pod) =>
val notify = Clock.currentDateTime.flatMap(cdt =>
lastUnhealthyNodeReported
.updateAndGet(old =>
Expand All @@ -250,7 +251,7 @@ class Sharding private (
)
.map(_ isEqual cdt)
)
ZIO.whenZIO(notify)(shardManager.notifyUnhealthyPod(pod).forkDaemon)
ZIO.whenZIODiscard(notify)(shardManager.notifyUnhealthyPod(pod).forkDaemon)
}

private def sendToSelf[Msg, Res](
Expand Down
Loading
Loading