@@ -1,14 +1,14 @@
# Databricks notebook source
# MAGIC %md
# MAGIC # Stateful Streaming
# MAGIC It is very common for a Charge Point Operator to want to understand current status of their fleet of chargers, which might contain information if the Charger needs repairing. In this exercise, we'll use our knowledge of Stateful streaming to report on the **Status of Chargers in the last 5 minutes**. In this exercise, we'll put our Stateful Streaming knowledge to the test by:
# MAGIC # Structured Streaming
# MAGIC It is very common for a Charge Point Operator to want to understand the current status of their fleet of chargers, which might indicate whether a charger needs repairing. In this exercise, our goal is to continuously report on the **Status of Chargers in 5-minute intervals**. To achieve this, we'll put our knowledge of Structured Streaming to the test by:
# MAGIC
# MAGIC * Ingesting a Data Stream (StatusNotification Requests)
# MAGIC * Unpacking the JSON string so that we can extract the status
# MAGIC * Ignore late data past 10 minutes and aggregate data over the previous 5 minutes
# MAGIC * Ignoring late data past 10 minutes and aggregating data over 5-minute tumbling windows
# MAGIC * Writing to storage
# MAGIC
# MAGIC Sample OCPP StatusNotification Request:
# MAGIC Sample [OCPP](https://en.wikipedia.org/wiki/Open_Charge_Point_Protocol) StatusNotification Request:
# MAGIC ```json
# MAGIC {
# MAGIC "connector_id": 1,
@@ -34,7 +34,7 @@

# COMMAND ----------

exercise_name = "stateful_streaming"
exercise_name = "structured_streaming"

# COMMAND ----------

@@ -61,13 +61,17 @@
# COMMAND ----------

import pandas as pd
from pyspark.sql.functions import col,lit
from pyspark.sql.functions import col, lit

# In a real streaming workload, set this to the total number of CPU cores
# across your cluster's worker nodes to avoid creating excessive shuffle partitions for streaming.
spark.conf.set("spark.sql.shuffle.partitions", 8)
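# Illustration only (not part of the exercise): one way to derive that number
# programmatically is Spark's default parallelism, which usually reflects the total
# number of cores across the cluster's executors, e.g.
#   spark.conf.set("spark.sql.shuffle.partitions", spark.sparkContext.defaultParallelism)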

url = "https://raw.githubusercontent.com/kelseymok/charge-point-live-status/main/ocpp_producer/data/1683036538.json"
pandas_df = pd.read_json(url, orient='records')
pandas_df["index"] = pandas_df.index
mock_data_df = spark.createDataFrame(pandas_df)

mock_data_df = spark.createDataFrame(pandas_df)

# COMMAND ----------

@@ -90,30 +94,6 @@
# COMMAND ----------


from pyspark.sql import DataFrame

def read_from_stream(input_df: DataFrame) -> DataFrame:
### YOUR CODE HERE
raw_stream_data = (
spark.readStream.format(None)
.option("rowsPerSecond", None)
.load()
)
###


# This is just data setup, not part of the exercise
return raw_stream_data.\
join(mock_data_df, raw_stream_data.value == mock_data_df.index, 'left').\
drop("timestamp").\
drop("index")


df = read_from_stream(mock_data_df)

# COMMAND ----------


############ SOLUTION ##############
from pyspark.sql import DataFrame

@@ -228,21 +208,6 @@ def test_read_from_stream(spark, f: Callable):
# MAGIC | |-- vendor_error_code: integer (nullable = true)
# MAGIC ```
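
# COMMAND ----------

# Illustration only (an assumption, not the complete schema required by the exercise):
# a nested schema for from_json is declared with StructType/StructField.
# Only a few of the fields listed above are shown here.
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

example_body_schema = StructType([
    StructField("connector_id", IntegerType(), True),
    StructField("status", StringType(), True),
    StructField("vendor_error_code", IntegerType(), True),
])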

# COMMAND ----------

from pyspark.sql import DataFrame
from pyspark.sql.functions import from_json, col
from pyspark.sql.types import StringType, IntegerType, StructField, StructType


def unpack_json_from_status_notification_request(input_df: DataFrame) -> DataFrame:
### YOUR CODE HERE
body_schema = None

return input_df.withColumn("new_body", from_json(col("body"), schema=body_schema))
###


# COMMAND ----------

############## SOLUTION #################
@@ -366,16 +331,6 @@ def test_unpack_json_from_status_notification_request_unit(spark, f: Callable):

# COMMAND ----------

from pyspark.sql.functions import col, expr, to_timestamp
import pyspark.sql.functions as F

def select_columns(input_df: DataFrame) -> DataFrame:
### YOUR CODE HERE
return input_df
###

# COMMAND ----------

########### SOLUTION ###########
from pyspark.sql import DataFrame

@@ -394,7 +349,6 @@ def select_columns(input_df: DataFrame) -> DataFrame:
# To display data in the resulting dataframe
display(df.transform(read_from_stream).transform(unpack_json_from_status_notification_request).transform(select_columns))


# df.transform(read_from_stream).transform(unpack_json_from_status_notification_request).transform(select_columns).printSchema()

# COMMAND ----------
@@ -471,7 +425,14 @@ def test_select_columns_unit(spark, f: Callable):

# MAGIC %md
# MAGIC ## EXERCISE: Aggregate, set up window and watermark
# MAGIC In this exercise, we'll want to create a tumbling window of 5 minutes and expire events that are older than 10 minutes. We also want to group by `charge_point_id` and `status`, and sum the number of status updates in the window period.
# MAGIC In this exercise, we'll create a 5-minute tumbling window with a 10-minute lateness threshold to control our watermark. We also want to group by `charge_point_id` and `status`, then compute the following statistics in each 5-minute window:
# MAGIC - count of status updates, alias this as `count_status_updates`
# MAGIC - min timestamp detected, alias this as `min_timestamp_detected`
# MAGIC - max timestamp detected, alias this as `max_timestamp_detected`
# MAGIC
# MAGIC To understand watermarks better, see this [section](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#handling-late-data-and-watermarking) of the Structured Streaming programming guide.
# MAGIC
# MAGIC *"You can define the watermark of a query by specifying the **event time column** and the **threshold on how late the data is expected to be** in terms of event time. For a specific window ending at time T, the engine will maintain state and allow late data to update the state until (max event time seen by the engine - late threshold > T). In other words, late data within the threshold will be aggregated, but data later than the threshold will start getting dropped (see later in the section for the exact guarantees)."*

# COMMAND ----------

@@ -485,19 +446,17 @@ def aggregate_window_watermark(input_df: DataFrame) -> DataFrame:
.groupBy(col("charge_point_id"),
col("status"),
window(col("timestamp"), "5 minutes"))\
.agg(F.count(col("status")))
.agg(
F.count(col("status")).alias("count_status_updates"),
# [PAWARIT] - add a few columns to help sanity check
F.min(col("timestamp")).alias("min_timestamp_detected"),
F.max(col("timestamp")).alias("max_timestamp_detected"),
)
###



# df.transform(read_from_stream).transform(unpack_json_from_status_notification_request).transform(select_columns).transform(aggregate_window_watermark).printSchema()

display(df.transform(read_from_stream).transform(unpack_json_from_status_notification_request).transform(select_columns).transform(aggregate_window_watermark), outputMode="update")


# COMMAND ----------


display(df.transform(read_from_stream).transform(unpack_json_from_status_notification_request).transform(select_columns).transform(aggregate_window_watermark), outputMode="append")

# COMMAND ----------

@@ -506,6 +465,10 @@ def aggregate_window_watermark(input_df: DataFrame) -> DataFrame:
# MAGIC @Syed: not sure if this is possible but have a think about it, maybe do some research - don't spend too much time on this if it's too complicated. But we should say why we don't have a unit test if we don't do it here. I THINK we should be able to test for the schema shape which should be a good start. Is there a way to get some metadata about the watermarks and the window time?
# MAGIC
# MAGIC @kelsey: Watermarks I'm not sure, will do some little research. But windows, as you see, we get the start and end time. Any other metadata you are looking for ?
# MAGIC
# MAGIC [Pawarit]: Yes, definitely get the window as an extra column to help people understand + sanity check. The unit test below was not working correctly for two reasons:
# MAGIC 1. If you use the batch DataFrame APIs, watermarks are irrelevant and simply ignored
# MAGIC 2. To disregard late-arriving data, you actually depend on *newly arriving* records to close past windows. For example, if you have data from 09:01 (which would belong to the 09:00-09:05 window) and your lateness threshold is 10 minutes, you'll need to wait for a record later than 09:15 in order to finalize and send off the 09:00-09:05 window.
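# MAGIC
# MAGIC As a worked check against the test data below (the timestamps come from that test; the rule is the one quoted from the Spark docs above): once the 09:16:00 record has been seen, the watermark advances to 09:16:00 - 10 minutes = 09:06:00. That is past the end of the 09:00-09:05 window, so that window is finalized and emitted on the next trigger, while the late 09:04:59 record arriving afterwards falls behind the watermark and is dropped.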

# COMMAND ----------

@@ -516,74 +479,103 @@ def aggregate_window_watermark(input_df: DataFrame) -> DataFrame:

def test_aggregate_window_watermark_unit(spark, f: Callable):

input_pandas = pd.DataFrame([
{
"charge_point_id": "444984d5-0b9c-474e-a972-71347941ae0e",
"status": "Reserved",
# "window": json.dumps({"start": "2023-01-01T09:25:00.000+0000", "end": "2023-01-01T09:30:00.000+0000"}),
"timestamp": "2023-01-01T09:25:00.000+0000",
# "count(status)": 1
},
{
"charge_point_id": "444984d5-0b9c-474e-a972-71347941ae0e",
"status": "Reserved",
# "window": json.dumps({"start": "2023-01-01T09:25:00.000+0000", "end": "2023-01-01T09:30:00.000+0000"}),
"timestamp": "2022-01-01T09:25:00.000+0000",
# "count(status)": 1
},
])

# window_schema = StructType([
# StructField("start", StringType(),True),
# StructField("end", StringType(),True)
# ])

input_df = spark.createDataFrame(
input_pandas,
StructType([
sequence = [
{
"charge_point_id": "444984d5-0b9c-474e-a972-71347941ae0e",
"status": "Reserved",
"timestamp": "2023-01-01T09:00:00.000+0000",
},
{
"charge_point_id": "444984d5-0b9c-474e-a972-71347941ae0e",
"status": "Reserved",
"timestamp": "2023-01-01T09:16:00.000+0000", # this record shouldn't show up because there's no later record to close this window yet. We would need something later than 09:30 to close this record's window
},
{
"charge_point_id": "444984d5-0b9c-474e-a972-71347941ae0e",
"status": "Reserved",
"timestamp": "2023-01-01T09:04:59.000+0000", # this record should be ignored due to lateness
}
]
input_schema = StructType([
StructField("charge_point_id", StringType()),
StructField("status", StringType()),
StructField("timestamp", StringType())
# StructField("count(status)", LongType())
])
).withColumn("timestamp", to_timestamp("timestamp"))

print("-->Input Schema")
input_df.printSchema()

result = input_df.transform(f)
# result = (input_df.transform(f).output("update")) #I have also seen this weird parenthesis execution for streaming
# result = input_df.transform(f).output("update") #Syed can you check this?
print("Transformed DF")

StructField("timestamp", StringType())]
)

exercise_directory = f"{working_directory}/watermark_exercise/"
dbutils.fs.rm(f"dbfs:{exercise_directory}", True)

input_location = f"{exercise_directory}/input_stream/"
output_location = f"{exercise_directory}/output_stream/"
checkpoint_location = f"{exercise_directory}/checkpoint/"

# To properly simulate the effect of watermarking,
# we need to run across multiple triggers.
# Spark will first check for the latest timestamp in the current trigger,
# then use that as the criteria for dropping late events in the *next* trigger.
# The easiest way to achieve this control is to use:
# - Delta Lake as a streaming sink and source with trigger(availableNow=True)
# - A simple Python for loop to move to the next trigger
for record in sequence:

record_pandas = pd.DataFrame([record])
record_df = (spark.createDataFrame(record_pandas, input_schema)
.withColumn("timestamp", to_timestamp("timestamp"))
)
record_df.write.mode("append").format("delta").save(input_location)

fancy_streaming_df = (
spark
.readStream.format("delta")
.load(input_location)
.transform(f)
)

streaming_query = (
fancy_streaming_df
.writeStream
.format("delta")
.option("path", output_location)
.option("checkpointLocation", checkpoint_location)
.option("outputMode", "append")
.trigger(availableNow=True)
.start()
.awaitTermination()
)

result = spark.read.format("delta").load(output_location)
print("-->Result Schema")
result.printSchema()

# Schema Shape Test
result_schema = result.schema
expected_schema = StructType(
[
StructField("charge_point_id", StringType(),True),
StructField("status", StringType(),True),
StructField("window", StructType([
StructField("start", TimestampType(),True),
StructField("end", TimestampType(),True)
]),False),
StructField("count(status)", LongType(),False),
])
assert result_schema == expected_schema, f"Expected {expected_schema}, but got {result_schema}"
# Test schema internals:
# This case is a little trickier because the timeWindow field
# contains more metadata such as watermarkDelayMs
result_schema_json_string = result.schema.json()
expected_schema_json_string = '{"fields":[{"metadata":{},"name":"charge_point_id","nullable":true,"type":"string"},{"metadata":{},"name":"status","nullable":true,"type":"string"},{"metadata":{"spark.timeWindow":true,"spark.watermarkDelayMs":600000},"name":"window","nullable":true,"type":{"fields":[{"metadata":{},"name":"start","nullable":true,"type":"timestamp"},{"metadata":{},"name":"end","nullable":true,"type":"timestamp"}],"type":"struct"}},{"metadata":{},"name":"count_status_updates","nullable":true,"type":"long"},{"metadata":{},"name":"min_timestamp_detected","nullable":true,"type":"timestamp"},{"metadata":{},"name":"max_timestamp_detected","nullable":true,"type":"timestamp"}],"type":"struct"}'

result_records = [(x.charge_point_id, x.status, x.window.start, x.window.end) for x in result.collect()]
assert result_schema_json_string == expected_schema_json_string, f"""Expected {expected_schema_json_string}, but got {result_schema_json_string}"""

result_records = [x.asDict(True) for x in result.collect()]
expected_records = [
("444984d5-0b9c-474e-a972-71347941ae0e", "Reserved", datetime.datetime(2023, 1, 1, 9, 25), datetime.datetime(2023, 1, 1, 9, 30))
{
'charge_point_id': '444984d5-0b9c-474e-a972-71347941ae0e',
'status': 'Reserved',
'window': {
'start': datetime.datetime(2023, 1, 1, 9, 0),
'end': datetime.datetime(2023, 1, 1, 9, 5)
},
'count_status_updates': 1,
'min_timestamp_detected': datetime.datetime(2023, 1, 1, 9, 0),
'max_timestamp_detected': datetime.datetime(2023, 1, 1, 9, 0)
}
]
assert result_records == expected_records, f"Expected {expected_records}, but got {result_records}"

result_count = result.count()
expected_count = 1
assert result_count == expected_count, f"Expected {expected_count}, but got {result_count}"


dbutils.fs.rm(f"dbfs:{exercise_directory}", True)
print("All tests pass! :)")

test_aggregate_window_watermark_unit(spark, aggregate_window_watermark)
@@ -635,11 +627,23 @@ def test_aggregate_window_watermark_unit(spark, f: Callable):

# COMMAND ----------

# MAGIC %sql select `Time Window`, `Charge Point Id`, `Status Count` from counts;
# MAGIC %sql
# MAGIC select
# MAGIC   `Time Window`,
# MAGIC   `Charge Point Id`,
# MAGIC   `Status Count`
# MAGIC from
# MAGIC   counts;

# COMMAND ----------

# MAGIC %sql select date_format(`Time Window`.end, "MMM-dd HH:mm") as time from counts
# MAGIC %sql
# MAGIC select
# MAGIC   date_format(`Time Window`.end, "MMM-dd HH:mm") as time
# MAGIC from
# MAGIC   counts

# COMMAND ----------

@@ -649,3 +653,21 @@ def test_aggregate_window_watermark_unit(spark, f: Callable):
# COMMAND ----------

# MAGIC %md

# COMMAND ----------

# MAGIC %md
# MAGIC
# MAGIC ### Clean-up
# MAGIC - Shut down all streams

# COMMAND ----------

import time

# Give the streams started above a few minutes to run, then shut them all down.
time.sleep(300)

for s in spark.streams.active:
    try:
        s.stop()
    except Exception:
        pass