Skip to content

Commit

Permalink
Adding foundation model example notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuta-yoshimatsu committed Jun 5, 2024
1 parent d43969c commit a12d9ef
Show file tree
Hide file tree
Showing 5 changed files with 1,128 additions and 149 deletions.
108 changes: 66 additions & 42 deletions examples/foundation-model-examples/chronos-example.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,57 @@
# Databricks notebook source
# MAGIC %md
# MAGIC This is an example notebook that shows how to use [chronos](https://github.com/amazon-science/chronos-forecasting/tree/main) models on Databricks.
# MAGIC
# MAGIC The notebook loads the model, distributes the inference, registers the model, deploys the model and makes online forecasts.

# COMMAND ----------

# MAGIC %pip install git+https://github.com/amazon-science/chronos-forecasting.git --quiet
# MAGIC dbutils.library.restartPython()

# COMMAND ----------

# MAGIC %md
# MAGIC ## Distributed Inference
# MAGIC ## Prepare Data

# COMMAND ----------

catalog = "solacc_uc" # Name of the catalog we use to manage our assets
db = "mmf" # Name of the schema we use to manage our assets (e.g. datasets)
n = 100 # Number of time series to sample

# COMMAND ----------

# This cell will create tables: {catalog}.{db}.m4_daily_train, {catalog}.{db}.m4_monthly_train, {catalog}.{db}.rossmann_daily_train, {catalog}.{db}.rossmann_daily_test

dbutils.notebook.run("data_preparation", timeout_seconds=0, arguments={"catalog": catalog, "db": db, "n": n})

# COMMAND ----------

# MAGIC %md
# MAGIC

# COMMAND ----------

from pyspark.sql.functions import collect_list

# Make sure that the data exists
df = spark.table(f'{catalog}.{db}.m4_daily_train')
df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y'))
display(df)

# COMMAND ----------

# MAGIC %md
# MAGIC ## Distribute Inference

# COMMAND ----------

import pandas as pd
import numpy as np
import torch
from typing import Iterator
from pyspark.sql.functions import collect_list, pandas_udf
from pyspark.sql.functions import pandas_udf


def create_get_horizon_timestamps(freq, prediction_length):
Expand Down Expand Up @@ -66,22 +104,23 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:

# COMMAND ----------

# Make sure that the data exists
df = spark.table('solacc_uc.mmf.m4_daily_train')
df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y'))
chronos_model = "chronos-t5-tiny" # Alternatibely chronos-t5-mini, chronos-t5-small, chronos-t5-base, chronos-t5-large
prediction_length = 10 # Time horizon for forecasting
num_samples = 10 # Number of forecast to generate. We will take median as our final forecast.
batch_size = 4 # Number of time series to process simultaneously
freq = "D" # Frequency of the time series
device_count = torch.cuda.device_count() # Number of GPUs available

# COMMAND ----------

get_horizon_timestamps = create_get_horizon_timestamps(freq="D", prediction_length=10)
get_horizon_timestamps = create_get_horizon_timestamps(freq=freq, prediction_length=prediction_length)
forecast_udf = create_forecast_udf(
repository="amazon/chronos-t5-tiny",
prediction_length=10,
num_samples=10,
batch_size=4,
repository=f"amazon/{chronos_model}",
prediction_length=prediction_length,
num_samples=num_samples,
batch_size=batch_size,
)

device_count = torch.cuda.device_count()

forecasts = df.repartition(device_count).select(
df.unique_id,
get_horizon_timestamps(df.ds).alias("ds"),
Expand All @@ -100,7 +139,6 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
import mlflow
import torch
import numpy as np
from mlflow.models import infer_signature
from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, TensorSpec
mlflow.set_registry_uri("databricks-uc")
Expand All @@ -125,17 +163,18 @@ def predict(self, context, input_data, params=None):
)
return forecast.numpy()

pipeline = ChronosModel("amazon/chronos-t5-tiny")
pipeline = ChronosModel(f"amazon/{chronos_model}")
input_schema = Schema([TensorSpec(np.dtype(np.double), (-1, -1))])
output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1, -1, -1))])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
input_example = np.random.rand(1, 52)
registered_model_name=f"{catalog}.{db}.{chronos_model}"

with mlflow.start_run() as run:
mlflow.pyfunc.log_model(
"model",
python_model=pipeline,
registered_model_name="solacc_uc.mmf.chronos-t5-tiny",
registered_model_name=registered_model_name,
signature=signature,
input_example=input_example,
pip_requirements=[
Expand All @@ -145,34 +184,27 @@ def predict(self, context, input_data, params=None):

# COMMAND ----------

import mlflow
from mlflow import MlflowClient
import pandas as pd
import numpy as np

client = MlflowClient()
mlflow.set_registry_uri("databricks-uc")


def get_latest_model_version(client, registered_name):
def get_latest_model_version(client, registered_model_name):
latest_version = 1
for mv in client.search_model_versions(f"name='{registered_name}'"):
for mv in client.search_model_versions(f"name='{registered_model_name}'"):
version_int = int(mv.version)
if version_int > latest_version:
latest_version = version_int
return latest_version

registered_name = f"solacc_uc.mmf.chronos-t5-tiny"
model_version = get_latest_model_version(client, registered_name)
logged_model = f"models:/{registered_name}/{model_version}"
model_version = get_latest_model_version(client, registered_model_name)
logged_model = f"models:/{registered_model_name}/{model_version}"

# Load model as a PyFuncModel
loaded_model = mlflow.pyfunc.load_model(logged_model)

# Create random input data
input_data = np.random.rand(5, 52) # (batch, series)

# Generate forecast
# Generate forecasts
loaded_model.predict(input_data)

# COMMAND ----------
Expand All @@ -182,13 +214,8 @@ def get_latest_model_version(client, registered_name):

# COMMAND ----------

catalog = "solacc_uc"
log_schema = "mmf" # A schema within the catalog where the inferece log is going to be stored
model_serving_endpoint_name = "chronos-t5-tiny"

token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)

# With the token, you can create our authorization header for our subsequent REST calls
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

# Next you need an endpoint at which to execute your request which you can get from the notebook's tags collection
Expand All @@ -204,12 +231,14 @@ def get_latest_model_version(client, registered_name):

import requests

model_serving_endpoint_name = chronos_model

my_json = {
"name": model_serving_endpoint_name,
"config": {
"served_models": [
{
"model_name": registered_name,
"model_name": registered_model_name,
"model_version": model_version,
"workload_type": "GPU_SMALL",
"workload_size": "Small",
Expand All @@ -218,20 +247,15 @@ def get_latest_model_version(client, registered_name):
],
"auto_capture_config": {
"catalog_name": catalog,
"schema_name": log_schema,
"schema_name": db,
"table_name_prefix": model_serving_endpoint_name,
},
},
}

# Make sure to the schema for the inference table exists
_ = spark.sql(
f"CREATE SCHEMA IF NOT EXISTS {catalog}.{log_schema}"
)

# Make sure to drop the inference table of it exists
_ = spark.sql(
f"DROP TABLE IF EXISTS {catalog}.{log_schema}.`{model_serving_endpoint_name}_payload`"
f"DROP TABLE IF EXISTS {catalog}.{db}.`{model_serving_endpoint_name}_payload`"
)

# COMMAND ----------
Expand Down Expand Up @@ -338,7 +362,7 @@ def wait_for_endpoint():
# COMMAND ----------

# MAGIC %md
# MAGIC ## Online Inference
# MAGIC ## Online Forecast

# COMMAND ----------

Expand Down
144 changes: 144 additions & 0 deletions examples/foundation-model-examples/data_preparation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Databricks notebook source
# MAGIC %pip install datasetsforecast==0.0.8 --quiet
# MAGIC dbutils.library.restartPython()

# COMMAND ----------

import pathlib
import pandas as pd
from datasetsforecast.m4 import M4
import logging
logger = spark._jvm.org.apache.log4j
logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR)
logging.getLogger("py4j.clientserver").setLevel(logging.ERROR)

# COMMAND ----------

dbutils.widgets.text("catalog", "")
dbutils.widgets.text("db", "")
dbutils.widgets.text("n", "")
dbutils.widgets.text("volume", "rossmann")

catalog = dbutils.widgets.get("catalog") # Name of the catalog we use to manage our assets
db = dbutils.widgets.get("db") # Name of the schema we use to manage our assets (e.g. datasets)
volume = dbutils.widgets.get("volume") # Name of the schema where you have your rossmann dataset csv sotred
n = int(dbutils.widgets.get("n")) # Number of time series to sample

# Make sure the catalog, schema and volume exist
_ = spark.sql(f"CREATE CATALOG IF NOT EXISTS {catalog}")
_ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{db}")
_ = spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{db}.{volume}")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Daily M4 Data

# COMMAND ----------

def create_m4_daily():
y_df, _, _ = M4.load(directory=str(pathlib.Path.home()), group="Daily")
_ids = [f"D{i}" for i in range(1, n)]
y_df = (
y_df.groupby("unique_id")
.filter(lambda x: x.unique_id.iloc[0] in _ids)
.groupby("unique_id")
.apply(transform_group_daily)
.reset_index(drop=True)
)
return y_df


def transform_group_daily(df):
unique_id = df.unique_id.iloc[0]
_start = pd.Timestamp("2020-01-01")
_end = _start + pd.DateOffset(days=int(df.count()[0]) - 1)
date_idx = pd.date_range(start=_start, end=_end, freq="D", name="ds")
res_df = pd.DataFrame(data=[], index=date_idx).reset_index()
res_df["unique_id"] = unique_id
res_df["y"] = df.y.values
return res_df


(
spark.createDataFrame(create_m4_daily())
.write.format("delta").mode("overwrite")
.saveAsTable(f"{catalog}.{db}.m4_daily_train")
)

print(f"Saved data to {catalog}.{db}.m4_daily_train")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Monthly M4 Data

# COMMAND ----------

def create_m4_monthly():
y_df, _, _ = M4.load(directory=str(pathlib.Path.home()), group="Monthly")
_ids = [f"M{i}" for i in range(1, n + 1)]
y_df = (
y_df.groupby("unique_id")
.filter(lambda x: x.unique_id.iloc[0] in _ids)
.groupby("unique_id")
.apply(transform_group_monthly)
.reset_index(drop=True)
)
return y_df


def transform_group_monthly(df):
unique_id = df.unique_id.iloc[0]
_cnt = 60 # df.count()[0]
_start = pd.Timestamp("2018-01-01")
_end = _start + pd.DateOffset(months=_cnt)
date_idx = pd.date_range(start=_start, end=_end, freq="M", name="date")
_df = (
pd.DataFrame(data=[], index=date_idx)
.reset_index()
.rename(columns={"index": "date"})
)
_df["unique_id"] = unique_id
_df["y"] = df[:60].y.values
return _df


(
spark.createDataFrame(create_m4_monthly())
.write.format("delta").mode("overwrite")
.saveAsTable(f"{catalog}.{db}.m4_monthly_train")
)

print(f"Saved data to {catalog}.{db}.m4_monthly_train")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Daily Rossmann with Exogenous Regressors

# COMMAND ----------

# MAGIC %md Download the dataset from [Kaggle](kaggle.com/competitions/rossmann-store-sales/data) and store them in the volume.

# COMMAND ----------

# Randomly select 100 stores to forecast
import random
random.seed(7)

# Number of time series to sample
sample = True
stores = sorted(random.sample(range(0, 1000), n))

train = spark.read.csv(f"/Volumes/{catalog}/{db}/{volume}/train.csv", header=True, inferSchema=True)
test = spark.read.csv(f"/Volumes/{catalog}/{db}/{volume}/test.csv", header=True, inferSchema=True)

if sample:
train = train.filter(train.Store.isin(stores))
test = test.filter(test.Store.isin(stores))

train.write.mode("overwrite").option("mergeSchema", "true").saveAsTable(f"{catalog}.{db}.rossmann_daily_train")
test.write.mode("overwrite").option("mergeSchema", "true").saveAsTable(f"{catalog}.{db}.rossmann_daily_test")

print(f"Saved data to {catalog}.{db}.rossmann_daily_train and {catalog}.{db}.rossmann_daily_test")
Loading

0 comments on commit a12d9ef

Please sign in to comment.