From a12d9ef1151f2ae6a4c4b26d91c01f1a9659b095 Mon Sep 17 00:00:00 2001 From: "ryuta.yoshimatsu@databricks.com" Date: Wed, 5 Jun 2024 06:45:22 +0000 Subject: [PATCH] Adding foundation model example notebooks --- .../chronos-example.py | 108 +++-- .../data_preparation.py | 144 +++++++ .../moirai-example.py | 293 ++++++++++++-- .../moment-example.py | 374 +++++++++++++++--- .../timesfm-example.py | 358 ++++++++++++++++- 5 files changed, 1128 insertions(+), 149 deletions(-) create mode 100644 examples/foundation-model-examples/data_preparation.py diff --git a/examples/foundation-model-examples/chronos-example.py b/examples/foundation-model-examples/chronos-example.py index 78a3920..63fd816 100644 --- a/examples/foundation-model-examples/chronos-example.py +++ b/examples/foundation-model-examples/chronos-example.py @@ -1,11 +1,49 @@ # Databricks notebook source +# MAGIC %md +# MAGIC This is an example notebook that shows how to use [chronos](https://github.com/amazon-science/chronos-forecasting/tree/main) models on Databricks. +# MAGIC +# MAGIC The notebook loads the model, distributes the inference, registers the model, deploys the model and makes online forecasts. + +# COMMAND ---------- + # MAGIC %pip install git+https://github.com/amazon-science/chronos-forecasting.git --quiet # MAGIC dbutils.library.restartPython() # COMMAND ---------- # MAGIC %md -# MAGIC ## Distributed Inference +# MAGIC ## Prepare Data + +# COMMAND ---------- + +catalog = "solacc_uc" # Name of the catalog we use to manage our assets +db = "mmf" # Name of the schema we use to manage our assets (e.g. datasets) +n = 100 # Number of time series to sample + +# COMMAND ---------- + +# This cell will create tables: {catalog}.{db}.m4_daily_train, {catalog}.{db}.m4_monthly_train, {catalog}.{db}.rossmann_daily_train, {catalog}.{db}.rossmann_daily_test + +dbutils.notebook.run("data_preparation", timeout_seconds=0, arguments={"catalog": catalog, "db": db, "n": n}) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC + +# COMMAND ---------- + +from pyspark.sql.functions import collect_list + +# Make sure that the data exists +df = spark.table(f'{catalog}.{db}.m4_daily_train') +df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y')) +display(df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Distribute Inference # COMMAND ---------- @@ -13,7 +51,7 @@ import numpy as np import torch from typing import Iterator -from pyspark.sql.functions import collect_list, pandas_udf +from pyspark.sql.functions import pandas_udf def create_get_horizon_timestamps(freq, prediction_length): @@ -66,22 +104,23 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: # COMMAND ---------- -# Make sure that the data exists -df = spark.table('solacc_uc.mmf.m4_daily_train') -df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y')) +chronos_model = "chronos-t5-tiny" # Alternatibely chronos-t5-mini, chronos-t5-small, chronos-t5-base, chronos-t5-large +prediction_length = 10 # Time horizon for forecasting +num_samples = 10 # Number of forecast to generate. We will take median as our final forecast. +batch_size = 4 # Number of time series to process simultaneously +freq = "D" # Frequency of the time series +device_count = torch.cuda.device_count() # Number of GPUs available # COMMAND ---------- -get_horizon_timestamps = create_get_horizon_timestamps(freq="D", prediction_length=10) +get_horizon_timestamps = create_get_horizon_timestamps(freq=freq, prediction_length=prediction_length) forecast_udf = create_forecast_udf( - repository="amazon/chronos-t5-tiny", - prediction_length=10, - num_samples=10, - batch_size=4, + repository=f"amazon/{chronos_model}", + prediction_length=prediction_length, + num_samples=num_samples, + batch_size=batch_size, ) -device_count = torch.cuda.device_count() - forecasts = df.repartition(device_count).select( df.unique_id, get_horizon_timestamps(df.ds).alias("ds"), @@ -100,7 +139,6 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: import mlflow import torch import numpy as np -from mlflow.models import infer_signature from mlflow.models.signature import ModelSignature from mlflow.types import DataType, Schema, TensorSpec mlflow.set_registry_uri("databricks-uc") @@ -125,17 +163,18 @@ def predict(self, context, input_data, params=None): ) return forecast.numpy() -pipeline = ChronosModel("amazon/chronos-t5-tiny") +pipeline = ChronosModel(f"amazon/{chronos_model}") input_schema = Schema([TensorSpec(np.dtype(np.double), (-1, -1))]) output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1, -1, -1))]) signature = ModelSignature(inputs=input_schema, outputs=output_schema) input_example = np.random.rand(1, 52) +registered_model_name=f"{catalog}.{db}.{chronos_model}" with mlflow.start_run() as run: mlflow.pyfunc.log_model( "model", python_model=pipeline, - registered_model_name="solacc_uc.mmf.chronos-t5-tiny", + registered_model_name=registered_model_name, signature=signature, input_example=input_example, pip_requirements=[ @@ -145,26 +184,19 @@ def predict(self, context, input_data, params=None): # COMMAND ---------- -import mlflow from mlflow import MlflowClient -import pandas as pd -import numpy as np - client = MlflowClient() -mlflow.set_registry_uri("databricks-uc") - -def get_latest_model_version(client, registered_name): +def get_latest_model_version(client, registered_model_name): latest_version = 1 - for mv in client.search_model_versions(f"name='{registered_name}'"): + for mv in client.search_model_versions(f"name='{registered_model_name}'"): version_int = int(mv.version) if version_int > latest_version: latest_version = version_int return latest_version -registered_name = f"solacc_uc.mmf.chronos-t5-tiny" -model_version = get_latest_model_version(client, registered_name) -logged_model = f"models:/{registered_name}/{model_version}" +model_version = get_latest_model_version(client, registered_model_name) +logged_model = f"models:/{registered_model_name}/{model_version}" # Load model as a PyFuncModel loaded_model = mlflow.pyfunc.load_model(logged_model) @@ -172,7 +204,7 @@ def get_latest_model_version(client, registered_name): # Create random input data input_data = np.random.rand(5, 52) # (batch, series) -# Generate forecast +# Generate forecasts loaded_model.predict(input_data) # COMMAND ---------- @@ -182,13 +214,8 @@ def get_latest_model_version(client, registered_name): # COMMAND ---------- -catalog = "solacc_uc" -log_schema = "mmf" # A schema within the catalog where the inferece log is going to be stored -model_serving_endpoint_name = "chronos-t5-tiny" - -token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) - # With the token, you can create our authorization header for our subsequent REST calls +token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} # Next you need an endpoint at which to execute your request which you can get from the notebook's tags collection @@ -204,12 +231,14 @@ def get_latest_model_version(client, registered_name): import requests +model_serving_endpoint_name = chronos_model + my_json = { "name": model_serving_endpoint_name, "config": { "served_models": [ { - "model_name": registered_name, + "model_name": registered_model_name, "model_version": model_version, "workload_type": "GPU_SMALL", "workload_size": "Small", @@ -218,20 +247,15 @@ def get_latest_model_version(client, registered_name): ], "auto_capture_config": { "catalog_name": catalog, - "schema_name": log_schema, + "schema_name": db, "table_name_prefix": model_serving_endpoint_name, }, }, } -# Make sure to the schema for the inference table exists -_ = spark.sql( - f"CREATE SCHEMA IF NOT EXISTS {catalog}.{log_schema}" -) - # Make sure to drop the inference table of it exists _ = spark.sql( - f"DROP TABLE IF EXISTS {catalog}.{log_schema}.`{model_serving_endpoint_name}_payload`" + f"DROP TABLE IF EXISTS {catalog}.{db}.`{model_serving_endpoint_name}_payload`" ) # COMMAND ---------- @@ -338,7 +362,7 @@ def wait_for_endpoint(): # COMMAND ---------- # MAGIC %md -# MAGIC ## Online Inference +# MAGIC ## Online Forecast # COMMAND ---------- diff --git a/examples/foundation-model-examples/data_preparation.py b/examples/foundation-model-examples/data_preparation.py new file mode 100644 index 0000000..0aa501a --- /dev/null +++ b/examples/foundation-model-examples/data_preparation.py @@ -0,0 +1,144 @@ +# Databricks notebook source +# MAGIC %pip install datasetsforecast==0.0.8 --quiet +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +import pathlib +import pandas as pd +from datasetsforecast.m4 import M4 +import logging +logger = spark._jvm.org.apache.log4j +logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR) +logging.getLogger("py4j.clientserver").setLevel(logging.ERROR) + +# COMMAND ---------- + +dbutils.widgets.text("catalog", "") +dbutils.widgets.text("db", "") +dbutils.widgets.text("n", "") +dbutils.widgets.text("volume", "rossmann") + +catalog = dbutils.widgets.get("catalog") # Name of the catalog we use to manage our assets +db = dbutils.widgets.get("db") # Name of the schema we use to manage our assets (e.g. datasets) +volume = dbutils.widgets.get("volume") # Name of the schema where you have your rossmann dataset csv sotred +n = int(dbutils.widgets.get("n")) # Number of time series to sample + +# Make sure the catalog, schema and volume exist +_ = spark.sql(f"CREATE CATALOG IF NOT EXISTS {catalog}") +_ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{db}") +_ = spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{db}.{volume}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Daily M4 Data + +# COMMAND ---------- + +def create_m4_daily(): + y_df, _, _ = M4.load(directory=str(pathlib.Path.home()), group="Daily") + _ids = [f"D{i}" for i in range(1, n)] + y_df = ( + y_df.groupby("unique_id") + .filter(lambda x: x.unique_id.iloc[0] in _ids) + .groupby("unique_id") + .apply(transform_group_daily) + .reset_index(drop=True) + ) + return y_df + + +def transform_group_daily(df): + unique_id = df.unique_id.iloc[0] + _start = pd.Timestamp("2020-01-01") + _end = _start + pd.DateOffset(days=int(df.count()[0]) - 1) + date_idx = pd.date_range(start=_start, end=_end, freq="D", name="ds") + res_df = pd.DataFrame(data=[], index=date_idx).reset_index() + res_df["unique_id"] = unique_id + res_df["y"] = df.y.values + return res_df + + +( + spark.createDataFrame(create_m4_daily()) + .write.format("delta").mode("overwrite") + .saveAsTable(f"{catalog}.{db}.m4_daily_train") +) + +print(f"Saved data to {catalog}.{db}.m4_daily_train") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Monthly M4 Data + +# COMMAND ---------- + +def create_m4_monthly(): + y_df, _, _ = M4.load(directory=str(pathlib.Path.home()), group="Monthly") + _ids = [f"M{i}" for i in range(1, n + 1)] + y_df = ( + y_df.groupby("unique_id") + .filter(lambda x: x.unique_id.iloc[0] in _ids) + .groupby("unique_id") + .apply(transform_group_monthly) + .reset_index(drop=True) + ) + return y_df + + +def transform_group_monthly(df): + unique_id = df.unique_id.iloc[0] + _cnt = 60 # df.count()[0] + _start = pd.Timestamp("2018-01-01") + _end = _start + pd.DateOffset(months=_cnt) + date_idx = pd.date_range(start=_start, end=_end, freq="M", name="date") + _df = ( + pd.DataFrame(data=[], index=date_idx) + .reset_index() + .rename(columns={"index": "date"}) + ) + _df["unique_id"] = unique_id + _df["y"] = df[:60].y.values + return _df + + +( + spark.createDataFrame(create_m4_monthly()) + .write.format("delta").mode("overwrite") + .saveAsTable(f"{catalog}.{db}.m4_monthly_train") +) + +print(f"Saved data to {catalog}.{db}.m4_monthly_train") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Daily Rossmann with Exogenous Regressors + +# COMMAND ---------- + +# MAGIC %md Download the dataset from [Kaggle](kaggle.com/competitions/rossmann-store-sales/data) and store them in the volume. + +# COMMAND ---------- + +# Randomly select 100 stores to forecast +import random +random.seed(7) + +# Number of time series to sample +sample = True +stores = sorted(random.sample(range(0, 1000), n)) + +train = spark.read.csv(f"/Volumes/{catalog}/{db}/{volume}/train.csv", header=True, inferSchema=True) +test = spark.read.csv(f"/Volumes/{catalog}/{db}/{volume}/test.csv", header=True, inferSchema=True) + +if sample: + train = train.filter(train.Store.isin(stores)) + test = test.filter(test.Store.isin(stores)) + +train.write.mode("overwrite").option("mergeSchema", "true").saveAsTable(f"{catalog}.{db}.rossmann_daily_train") +test.write.mode("overwrite").option("mergeSchema", "true").saveAsTable(f"{catalog}.{db}.rossmann_daily_test") + +print(f"Saved data to {catalog}.{db}.rossmann_daily_train and {catalog}.{db}.rossmann_daily_test") diff --git a/examples/foundation-model-examples/moirai-example.py b/examples/foundation-model-examples/moirai-example.py index 54402a6..87134be 100644 --- a/examples/foundation-model-examples/moirai-example.py +++ b/examples/foundation-model-examples/moirai-example.py @@ -1,11 +1,44 @@ # Databricks notebook source +# MAGIC %md +# MAGIC This is an example notebook that shows how to use [Moirai](https://github.com/SalesforceAIResearch/uni2ts) models on Databricks. +# MAGIC +# MAGIC The notebook loads the model, distributes the inference, registers the model, deploys the model and makes online forecasts. + +# COMMAND ---------- + # MAGIC %pip install git+https://github.com/SalesforceAIResearch/uni2ts.git --quiet # MAGIC dbutils.library.restartPython() # COMMAND ---------- # MAGIC %md -# MAGIC ## Distributed Inference +# MAGIC ## Prepare Data + +# COMMAND ---------- + +catalog = "solacc_uc" # Name of the catalog we use to manage our assets +db = "mmf" # Name of the schema we use to manage our assets (e.g. datasets) +n = 100 # Number of time series to sample + +# COMMAND ---------- + +# This cell will create tables: {catalog}.{db}.m4_daily_train, {catalog}.{db}.m4_monthly_train, {catalog}.{db}.rossmann_daily_train, {catalog}.{db}.rossmann_daily_test + +dbutils.notebook.run("data_preparation", timeout_seconds=0, arguments={"catalog": catalog, "db": db, "n": n}) + +# COMMAND ---------- + +from pyspark.sql.functions import collect_list + +# Make sure that the data exists +df = spark.table(f'{catalog}.{db}.m4_daily_train') +df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y')) +display(df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Distribute Inference # COMMAND ---------- @@ -14,10 +47,30 @@ import torch from einops import rearrange from typing import Iterator -from pyspark.sql.functions import collect_list, pandas_udf +from pyspark.sql.functions import pandas_udf + + +def create_get_horizon_timestamps(freq, prediction_length): + + @pandas_udf('array') + def get_horizon_timestamps(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + one_ts_offset = pd.offsets.MonthEnd(1) if freq == "M" else pd.DateOffset(days=1) + barch_horizon_timestamps = [] + for batch in batch_iterator: + for series in batch: + timestamp = last = series.max() + horizon_timestamps = [] + for i in range(prediction_length): + timestamp = timestamp + one_ts_offset + horizon_timestamps.append(timestamp) + barch_horizon_timestamps.append(np.array(horizon_timestamps)) + yield pd.Series(barch_horizon_timestamps) + return get_horizon_timestamps + + +def create_forecast_udf(repository, prediction_length, patch_size, num_samples): -def create_forecast_udf(): @pandas_udf('array') def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: @@ -26,7 +79,7 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: import numpy as np import pandas as pd from uni2ts.model.moirai import MoiraiForecast, MoiraiModule - module = MoiraiModule.from_pretrained("Salesforce/moirai-1.0-R-small") + module = MoiraiModule.from_pretrained(repository) ## inference for bulk in bulk_iterator: @@ -34,10 +87,10 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: for series in bulk: model = MoiraiForecast( module=module, - prediction_length=10, + prediction_length=prediction_length, context_length=len(series), - patch_size=32, - num_samples=100, + patch_size=patch_size, + num_samples=num_samples, target_dim=1, feat_dynamic_real_dim=0, past_feat_dynamic_real_dim=0, @@ -55,16 +108,35 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: past_observed_target=past_observed_target, past_is_pad=past_is_pad, ) - #print(f"median.append: {np.median(forecast[0], axis=0)}") median.append(np.median(forecast[0], axis=0)) yield pd.Series(median) return forecast_udf -forecast_udf = create_forecast_udf() -df = spark.table('solacc_uc.mmf.m4_daily_train') -df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y')) -device_count = torch.cuda.device_count() -forecasts = df.repartition(device_count).select(df.unique_id, forecast_udf(df.y).alias("forecast")) +# COMMAND ---------- + +moirai_model = "moirai-1.0-R-small" # Alternatibely moirai-1.0-R-base, moirai-1.0-R-large +prediction_length = 10 # Time horizon for forecasting +num_samples = 10 # Number of forecast to generate. We will take median as our final forecast. +patch_size = 32 # Patch size: choose from {"auto", 8, 16, 32, 64, 128} +freq = "D" # Frequency of the time series +device_count = torch.cuda.device_count() # Number of GPUs available + +# COMMAND ---------- + +get_horizon_timestamps = create_get_horizon_timestamps(freq=freq, prediction_length=prediction_length) + +forecast_udf = create_forecast_udf( + repository=f"Salesforce/{moirai_model}", + prediction_length=prediction_length, + patch_size=patch_size, + num_samples=num_samples, + ) + +forecasts = df.repartition(device_count).select( + df.unique_id, + get_horizon_timestamps(df.ds).alias("ds"), + forecast_udf(df.y).alias("forecast"), + ) display(forecasts) @@ -78,9 +150,10 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: import mlflow import torch import numpy as np -from mlflow.models import infer_signature from mlflow.models.signature import ModelSignature from mlflow.types import DataType, Schema, TensorSpec +mlflow.set_registry_uri("databricks-uc") + class MoiraiModel(mlflow.pyfunc.PythonModel): def __init__(self, repository): @@ -95,7 +168,7 @@ def predict(self, context, input_data, params=None): prediction_length=10, context_length=len(input_data), patch_size=32, - num_samples=100, + num_samples=10, target_dim=1, feat_dynamic_real_dim=0, past_feat_dynamic_real_dim=0, @@ -114,20 +187,20 @@ def predict(self, context, input_data, params=None): past_observed_target=past_observed_target, past_is_pad=past_is_pad, ) - #print(f"median.append: {np.median(forecast[0], axis=0)}") return np.median(forecast[0], axis=0) -pipeline = MoiraiModel("Salesforce/moirai-1.0-R-small") +pipeline = MoiraiModel(f"Salesforce/{moirai_model}") input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))]) output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1,))]) signature = ModelSignature(inputs=input_schema, outputs=output_schema) input_example = np.random.rand(52) +registered_model_name=f"{catalog}.{db}.moirai-1-r-small" with mlflow.start_run() as run: mlflow.pyfunc.log_model( "model", python_model=pipeline, - registered_model_name="solacc_uc.mmf.moirai_test", + registered_model_name=registered_model_name, signature=signature, input_example=input_example, pip_requirements=[ @@ -137,43 +210,184 @@ def predict(self, context, input_data, params=None): # COMMAND ---------- -pipeline.predict(None, input_example) - -# COMMAND ---------- - -import mlflow from mlflow import MlflowClient -import pandas as pd - -mlflow.set_registry_uri("databricks-uc") mlflow_client = MlflowClient() -def get_latest_model_version(mlflow_client, registered_name): +def get_latest_model_version(mlflow_client, registered_model_name): latest_version = 1 - for mv in mlflow_client.search_model_versions(f"name='{registered_name}'"): + for mv in mlflow_client.search_model_versions(f"name='{registered_model_name}'"): version_int = int(mv.version) if version_int > latest_version: latest_version = version_int return latest_version -model_name = "moirai_test" -registered_name = f"solacc_uc.mmf.{model_name}" -model_version = get_latest_model_version(mlflow_client, registered_name) -logged_model = f"models:/{registered_name}/{model_version}" +model_version = get_latest_model_version(mlflow_client, registered_model_name) +logged_model = f"models:/{registered_model_name}/{model_version}" # Load model as a PyFuncModel loaded_model = mlflow.pyfunc.load_model(logged_model) # COMMAND ---------- -import numpy as np input_data = np.random.rand(52) loaded_model.predict(input_data) # COMMAND ---------- # MAGIC %md -# MAGIC ## Online Inference +# MAGIC ## Deploy Model for Online Forecast + +# COMMAND ---------- + +# With the token, you can create our authorization header for our subsequent REST calls +token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) +headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + +# Next you need an endpoint at which to execute your request which you can get from the notebook's tags collection +java_tags = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags() + +# This object comes from the Java CM - Convert the Java Map opject to a Python dictionary +tags = sc._jvm.scala.collection.JavaConversions.mapAsJavaMap(java_tags) + +# Lastly, extract the Databricks instance (domain name) from the dictionary +instance = tags["browserHostName"] + +# COMMAND ---------- + +import requests + +model_serving_endpoint_name = "moirai-1-r-small" + +my_json = { + "name": model_serving_endpoint_name, + "config": { + "served_models": [ + { + "model_name": registered_model_name, + "model_version": model_version, + "workload_type": "GPU_SMALL", + "workload_size": "Small", + "scale_to_zero_enabled": "true", + } + ], + "auto_capture_config": { + "catalog_name": catalog, + "schema_name": db, + "table_name_prefix": model_serving_endpoint_name, + }, + }, +} + +# Make sure to drop the inference table of it exists +_ = spark.sql( + f"DROP TABLE IF EXISTS {catalog}.{db}.`{model_serving_endpoint_name}_payload`" +) + +# COMMAND ---------- + +def func_create_endpoint(model_serving_endpoint_name): + # get endpoint status + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + url = f"{endpoint_url}/{model_serving_endpoint_name}" + r = requests.get(url, headers=headers) + if "RESOURCE_DOES_NOT_EXIST" in r.text: + print( + "Creating this new endpoint: ", + f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations", + ) + re = requests.post(endpoint_url, headers=headers, json=my_json) + else: + new_model_version = (my_json["config"])["served_models"][0]["model_version"] + print( + "This endpoint existed previously! We are updating it to a new config with new model version: ", + new_model_version, + ) + # update config + url = f"{endpoint_url}/{model_serving_endpoint_name}/config" + re = requests.put(url, headers=headers, json=my_json["config"]) + # wait till new config file in place + import time, json + + # get endpoint status + url = f"https://{instance}/api/2.0/serving-endpoints/{model_serving_endpoint_name}" + retry = True + total_wait = 0 + while retry: + r = requests.get(url, headers=headers) + assert ( + r.status_code == 200 + ), f"Expected an HTTP 200 response when accessing endpoint info, received {r.status_code}" + endpoint = json.loads(r.text) + if "pending_config" in endpoint.keys(): + seconds = 10 + print("New config still pending") + if total_wait < 6000: + # if less the 10 mins waiting, keep waiting + print(f"Wait for {seconds} seconds") + print(f"Total waiting time so far: {total_wait} seconds") + time.sleep(10) + total_wait += seconds + else: + print(f"Stopping, waited for {total_wait} seconds") + retry = False + else: + print("New config in place now!") + retry = False + + assert ( + re.status_code == 200 + ), f"Expected an HTTP 200 response, received {re.status_code}" + + +def func_delete_model_serving_endpoint(model_serving_endpoint_name): + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + url = f"{endpoint_url}/{model_serving_endpoint_name}" + response = requests.delete(url, headers=headers) + if response.status_code != 200: + raise Exception( + f"Request failed with status {response.status_code}, {response.text}" + ) + else: + print(model_serving_endpoint_name, "endpoint is deleted!") + return response.json() + +# COMMAND ---------- + +func_create_endpoint(model_serving_endpoint_name) + +# COMMAND ---------- + +import time, mlflow + + +def wait_for_endpoint(): + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + while True: + url = f"{endpoint_url}/{model_serving_endpoint_name}" + response = requests.get(url, headers=headers) + assert ( + response.status_code == 200 + ), f"Expected an HTTP 200 response, received {response.status_code}\n{response.text}" + + status = response.json().get("state", {}).get("ready", {}) + # print("status",status) + if status == "READY": + print(status) + print("-" * 80) + return + else: + print(f"Endpoint not ready ({status}), waiting 5 miutes") + time.sleep(300) # Wait 300 seconds + + +api_url = mlflow.utils.databricks_utils.get_webapp_url() + +wait_for_endpoint() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Online Forecast # COMMAND ---------- @@ -184,7 +398,7 @@ def get_latest_model_version(mlflow_client, registered_name): import matplotlib.pyplot as plt # Replace URL with the end point invocation url you get from Model Seriving page. -endpoint_url = "https://e2-demo-field-eng.cloud.databricks.com/serving-endpoints/moirai-test/invocations" +endpoint_url = f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations" token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get() def forecast(input_data, url=endpoint_url, databricks_token=token): headers = { @@ -202,4 +416,13 @@ def forecast(input_data, url=endpoint_url, databricks_token=token): # COMMAND ---------- +input_data = np.random.rand(52) forecast(input_data) + +# COMMAND ---------- + +func_delete_model_serving_endpoint(model_serving_endpoint_name) + +# COMMAND ---------- + + diff --git a/examples/foundation-model-examples/moment-example.py b/examples/foundation-model-examples/moment-example.py index e4233a6..cfa41aa 100644 --- a/examples/foundation-model-examples/moment-example.py +++ b/examples/foundation-model-examples/moment-example.py @@ -1,9 +1,42 @@ # Databricks notebook source +# MAGIC %md +# MAGIC This is an example notebook that shows how to use [moment](https://github.com/moment-timeseries-foundation-model/moment) model on Databricks. +# MAGIC +# MAGIC The notebook loads the model, distributes the inference, registers the model, deploys the model and makes online forecasts. + +# COMMAND ---------- + # MAGIC %pip install git+https://github.com/moment-timeseries-foundation-model/moment.git --quiet # MAGIC dbutils.library.restartPython() # COMMAND ---------- +# MAGIC %md +# MAGIC ## Prepare Data + +# COMMAND ---------- + +catalog = "solacc_uc" # Name of the catalog we use to manage our assets +db = "mmf" # Name of the schema we use to manage our assets (e.g. datasets) +n = 100 # Number of time series to sample + +# COMMAND ---------- + +# This cell will create tables: {catalog}.{db}.m4_daily_train, {catalog}.{db}.m4_monthly_train, {catalog}.{db}.rossmann_daily_train, {catalog}.{db}.rossmann_daily_test + +dbutils.notebook.run("data_preparation", timeout_seconds=0, arguments={"catalog": catalog, "db": db, "n": n}) + +# COMMAND ---------- + +from pyspark.sql.functions import collect_list + +# Make sure that the data exists +df = spark.table(f'{catalog}.{db}.m4_daily_train') +df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y')) +display(df) + +# COMMAND ---------- + # MAGIC %md # MAGIC ## Distributed Inference @@ -13,49 +46,88 @@ import numpy as np import torch from typing import Iterator -from pyspark.sql.functions import collect_list, pandas_udf - -@pandas_udf('array') -def forecast_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: - ## initialization step - from momentfm import MOMENTPipeline - model = MOMENTPipeline.from_pretrained( - "AutonLab/MOMENT-1-large", - device_map="cuda", - model_kwargs={ - "task_name": "forecasting", - "forecast_horizon": 10}, - ) - model.init() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device) - - ## inference - for batch in batch_iterator: - batch_forecast = [] - for series in batch: - # takes in tensor of shape [batchsize, n_channels, context_length] - context = list(series) - if len(context) < 512: - input_mask = [1] * len(context) + [0] * (512 - len(context)) - context = context + [0] * (512 - len(context)) - else: - input_mask = [1] * 512 - context = context[-512:] - - input_mask = torch.reshape(torch.tensor(input_mask),(1, 512)).to(device) - context = torch.reshape(torch.tensor(context),(1, 1, 512)).to(dtype=torch.float32).to(device) - output = model(context, input_mask=input_mask) - - forecast = output.forecast.squeeze().tolist() - batch_forecast.append(forecast) - - yield pd.Series(batch_forecast) - -df = spark.table('solacc_uc.mmf.m4_daily_train') -df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y')) -device_count = torch.cuda.device_count() -forecasts = df.repartition(device_count).select(df.unique_id, df.ds, forecast_udf(df.y).alias("forecast")) +from pyspark.sql.functions import pandas_udf + + +def create_get_horizon_timestamps(freq, prediction_length): + + @pandas_udf('array') + def get_horizon_timestamps(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + one_ts_offset = pd.offsets.MonthEnd(1) if freq == "M" else pd.DateOffset(days=1) + barch_horizon_timestamps = [] + for batch in batch_iterator: + for series in batch: + timestamp = last = series.max() + horizon_timestamps = [] + for i in range(prediction_length): + timestamp = timestamp + one_ts_offset + horizon_timestamps.append(timestamp) + barch_horizon_timestamps.append(np.array(horizon_timestamps)) + yield pd.Series(barch_horizon_timestamps) + + return get_horizon_timestamps + + +def create_forecast_udf(repository, prediction_length): + + @pandas_udf('array') + def forecast_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + ## initialization step + from momentfm import MOMENTPipeline + model = MOMENTPipeline.from_pretrained( + repository, + model_kwargs={ + "task_name": "forecasting", + "forecast_horizon": prediction_length}, + ) + model.init() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + ## inference + for batch in batch_iterator: + batch_forecast = [] + for series in batch: + # takes in tensor of shape [batchsize, n_channels, context_length] + context = list(series) + if len(context) < 512: + input_mask = [1] * len(context) + [0] * (512 - len(context)) + context = context + [0] * (512 - len(context)) + else: + input_mask = [1] * 512 + context = context[-512:] + + input_mask = torch.reshape(torch.tensor(input_mask),(1, 512)).to(device) + context = torch.reshape(torch.tensor(context),(1, 1, 512)).to(dtype=torch.float32).to(device) + output = model(context, input_mask=input_mask) + forecast = output.forecast.squeeze().tolist() + batch_forecast.append(forecast) + + yield pd.Series(batch_forecast) + + return forecast_udf + +# COMMAND ---------- + +moment_model = "MOMENT-1-large" +prediction_length = 10 # Time horizon for forecasting +freq = "D" # Frequency of the time series +device_count = torch.cuda.device_count() # Number of GPUs available + +# COMMAND ---------- + +get_horizon_timestamps = create_get_horizon_timestamps(freq=freq, prediction_length=prediction_length) + +forecast_udf = create_forecast_udf( + repository=f"AutonLab/{moment_model}", + prediction_length=prediction_length, + ) + +forecasts = df.repartition(device_count).select( + df.unique_id, + get_horizon_timestamps(df.ds).alias("ds"), + forecast_udf(df.y).alias("forecast"), + ) display(forecasts) @@ -78,7 +150,6 @@ def __init__(self, repository): from momentfm import MOMENTPipeline self.model = MOMENTPipeline.from_pretrained( repository, - device_map="cuda", model_kwargs={ "task_name": "forecasting", "forecast_horizon": 10}, @@ -101,17 +172,18 @@ def predict(self, context, input_data, params=None): forecast = output.forecast.squeeze().tolist() return forecast -pipeline = MomentModel("AutonLab/MOMENT-1-large") +pipeline = MomentModel(f"AutonLab/{moment_model}") input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))]) output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1,))]) signature = ModelSignature(inputs=input_schema, outputs=output_schema) input_example = np.random.rand(52) +registered_model_name=f"{catalog}.{db}.{moment_model}" with mlflow.start_run() as run: mlflow.pyfunc.log_model( "model", python_model=pipeline, - registered_model_name="solacc_uc.mmf.moment_test", + registered_model_name=registered_model_name, signature=signature, input_example=input_example, pip_requirements=[ @@ -121,39 +193,219 @@ def predict(self, context, input_data, params=None): # COMMAND ---------- -pipeline.predict(None, input_example) - -# COMMAND ---------- - -import mlflow from mlflow import MlflowClient -import pandas as pd - -mlflow.set_registry_uri("databricks-uc") mlflow_client = MlflowClient() -def get_latest_model_version(mlflow_client, registered_name): +def get_latest_model_version(mlflow_client, registered_model_name): latest_version = 1 - for mv in mlflow_client.search_model_versions(f"name='{registered_name}'"): + for mv in mlflow_client.search_model_versions(f"name='{registered_model_name}'"): version_int = int(mv.version) if version_int > latest_version: latest_version = version_int return latest_version -model_name = "moment_test" -registered_name = f"solacc_uc.mmf.{model_name}" -model_version = get_latest_model_version(mlflow_client, registered_name) -logged_model = f"models:/{registered_name}/{model_version}" +model_version = get_latest_model_version(mlflow_client, registered_model_name) +logged_model = f"models:/{registered_model_name}/{model_version}" # Load model as a PyFuncModel loaded_model = mlflow.pyfunc.load_model(logged_model) # COMMAND ---------- -import numpy as np input_data = np.random.rand(52) loaded_model.predict(input_data) # COMMAND ---------- +# MAGIC %md +# MAGIC ## Deploy Model for Online Forecast + +# COMMAND ---------- + +# With the token, you can create our authorization header for our subsequent REST calls +token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) +headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + +# Next you need an endpoint at which to execute your request which you can get from the notebook's tags collection +java_tags = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags() + +# This object comes from the Java CM - Convert the Java Map opject to a Python dictionary +tags = sc._jvm.scala.collection.JavaConversions.mapAsJavaMap(java_tags) + +# Lastly, extract the Databricks instance (domain name) from the dictionary +instance = tags["browserHostName"] + +# COMMAND ---------- + +import requests + +model_serving_endpoint_name = moment_model + +my_json = { + "name": model_serving_endpoint_name, + "config": { + "served_models": [ + { + "model_name": registered_model_name, + "model_version": model_version, + "workload_type": "GPU_SMALL", + "workload_size": "Small", + "scale_to_zero_enabled": "true", + } + ], + "auto_capture_config": { + "catalog_name": catalog, + "schema_name": db, + "table_name_prefix": model_serving_endpoint_name, + }, + }, +} + +# Make sure to drop the inference table of it exists +_ = spark.sql( + f"DROP TABLE IF EXISTS {catalog}.{db}.`{model_serving_endpoint_name}_payload`" +) + +# COMMAND ---------- + +def func_create_endpoint(model_serving_endpoint_name): + # get endpoint status + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + url = f"{endpoint_url}/{model_serving_endpoint_name}" + r = requests.get(url, headers=headers) + if "RESOURCE_DOES_NOT_EXIST" in r.text: + print( + "Creating this new endpoint: ", + f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations", + ) + re = requests.post(endpoint_url, headers=headers, json=my_json) + else: + new_model_version = (my_json["config"])["served_models"][0]["model_version"] + print( + "This endpoint existed previously! We are updating it to a new config with new model version: ", + new_model_version, + ) + # update config + url = f"{endpoint_url}/{model_serving_endpoint_name}/config" + re = requests.put(url, headers=headers, json=my_json["config"]) + # wait till new config file in place + import time, json + + # get endpoint status + url = f"https://{instance}/api/2.0/serving-endpoints/{model_serving_endpoint_name}" + retry = True + total_wait = 0 + while retry: + r = requests.get(url, headers=headers) + assert ( + r.status_code == 200 + ), f"Expected an HTTP 200 response when accessing endpoint info, received {r.status_code}" + endpoint = json.loads(r.text) + if "pending_config" in endpoint.keys(): + seconds = 10 + print("New config still pending") + if total_wait < 6000: + # if less the 10 mins waiting, keep waiting + print(f"Wait for {seconds} seconds") + print(f"Total waiting time so far: {total_wait} seconds") + time.sleep(10) + total_wait += seconds + else: + print(f"Stopping, waited for {total_wait} seconds") + retry = False + else: + print("New config in place now!") + retry = False + + assert ( + re.status_code == 200 + ), f"Expected an HTTP 200 response, received {re.status_code}" + + +def func_delete_model_serving_endpoint(model_serving_endpoint_name): + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + url = f"{endpoint_url}/{model_serving_endpoint_name}" + response = requests.delete(url, headers=headers) + if response.status_code != 200: + raise Exception( + f"Request failed with status {response.status_code}, {response.text}" + ) + else: + print(model_serving_endpoint_name, "endpoint is deleted!") + return response.json() + +# COMMAND ---------- + +func_create_endpoint(model_serving_endpoint_name) + +# COMMAND ---------- + +import time, mlflow + + +def wait_for_endpoint(): + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + while True: + url = f"{endpoint_url}/{model_serving_endpoint_name}" + response = requests.get(url, headers=headers) + assert ( + response.status_code == 200 + ), f"Expected an HTTP 200 response, received {response.status_code}\n{response.text}" + + status = response.json().get("state", {}).get("ready", {}) + # print("status",status) + if status == "READY": + print(status) + print("-" * 80) + return + else: + print(f"Endpoint not ready ({status}), waiting 5 miutes") + time.sleep(300) # Wait 300 seconds + + +api_url = mlflow.utils.databricks_utils.get_webapp_url() + +wait_for_endpoint() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Online Forecast + +# COMMAND ---------- + +import os +import requests +import pandas as pd +import json +import matplotlib.pyplot as plt + +# Replace URL with the end point invocation url you get from Model Seriving page. +endpoint_url = f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations" +token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get() +def forecast(input_data, url=endpoint_url, databricks_token=token): + headers = { + "Authorization": f"Bearer {databricks_token}", + "Content-Type": "application/json", + } + body = {"inputs": input_data.tolist()} + data = json.dumps(body) + response = requests.request(method="POST", headers=headers, url=url, data=data) + if response.status_code != 200: + raise Exception( + f"Request failed with status {response.status_code}, {response.text}" + ) + return response.json() + +# COMMAND ---------- + +input_data = np.random.rand(52) +forecast(input_data) + +# COMMAND ---------- + +func_delete_model_serving_endpoint(model_serving_endpoint_name) + +# COMMAND ---------- + diff --git a/examples/foundation-model-examples/timesfm-example.py b/examples/foundation-model-examples/timesfm-example.py index cf61b50..d0b70c8 100644 --- a/examples/foundation-model-examples/timesfm-example.py +++ b/examples/foundation-model-examples/timesfm-example.py @@ -1,5 +1,16 @@ # Databricks notebook source -# MAGIC %pip install -r ../../requirements.txt --quiet +# MAGIC %md +# MAGIC This is an example notebook that shows how to use [TimesFM](https://github.com/google-research/timesfm) models on Databricks. +# MAGIC +# MAGIC The notebook loads the model, distributes the inference, registers the model, deploys the model and makes online forecasts. +# MAGIC +# MAGIC As of today (June 5, 2024), TimesFM supports python version below [3.10](https://github.com/google-research/timesfm/issues/60). So make sure your cluster is below DBR ML 14.3. + +# COMMAND ---------- + +# MAGIC %pip install jax[cuda12]==0.4.26 --quiet +# MAGIC %pip install protobuf==3.20.* --quiet +# MAGIC %pip install utilsforecast --quiet # MAGIC dbutils.library.restartPython() # COMMAND ---------- @@ -11,10 +22,44 @@ # COMMAND ---------- +# MAGIC %md +# MAGIC ## Prepare Data + +# COMMAND ---------- + +catalog = "solacc_uc" # Name of the catalog we use to manage our assets +db = "mmf" # Name of the schema we use to manage our assets (e.g. datasets) +n = 100 # Number of time series to sample + +# COMMAND ---------- + +# This cell will create tables: {catalog}.{db}.m4_daily_train, {catalog}.{db}.m4_monthly_train, {catalog}.{db}.rossmann_daily_train, {catalog}.{db}.rossmann_daily_test + +dbutils.notebook.run("data_preparation", timeout_seconds=0, arguments={"catalog": catalog, "db": db, "n": n}) + +# COMMAND ---------- + +# Make sure that the data exists +df = spark.table(f'{catalog}.{db}.m4_daily_train').toPandas() +display(df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Distribute Inference + +# COMMAND ---------- + +# MAGIC %md +# MAGIC See the [github repository](https://github.com/google-research/timesfm/tree/master?tab=readme-ov-file#initialize-the-model-and-load-a-checkpoint) of TimesFM for detailed description of the input parameters. + +# COMMAND ---------- + import timesfm + tfm = timesfm.TimesFm( - context_len=512, - horizon_len=10, + context_len=512, # Max context length of the model. It needs to be a multiplier of input_patch_len, i.e. a multiplier of 32. + horizon_len=10, # Horizon length can be set to anything. We recommend setting it to the largest horizon length. input_patch_len=32, output_patch_len=128, num_layers=20, @@ -22,21 +67,312 @@ backend="gpu", ) -# COMMAND ---------- - tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m") -# COMMAND ---------- - -import pandas as pd -df = spark.table('solacc_uc.mmf.m4_daily_train').toPandas() forecast_df = tfm.forecast_on_df( inputs=df, - freq="D", # monthly + freq="D", value_name="y", num_jobs=-1, ) +display(forecast_df) + # COMMAND ---------- -display(forecast_df) +# MAGIC %md +# MAGIC ##Register Model + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We should ensure that any non-serializable attributes (like the timesfm model in TimesFMModel class) are not included in the serialization process. One common approach is to override the __getstate__ and __setstate__ methods in the class to manage what gets pickled. This modification ensures that the timesfm model is not included in the serialization process, thus avoiding the error. The load_model method is called to load the model when needed, such as during prediction or after deserialization. + +# COMMAND ---------- + +import mlflow +import torch +import numpy as np +from mlflow.models import infer_signature +from mlflow.models.signature import ModelSignature +from mlflow.types import DataType, Schema, TensorSpec +mlflow.set_registry_uri("databricks-uc") + + +class TimesFMModel(mlflow.pyfunc.PythonModel): + def __init__(self, repository): + self.repository = repository + self.tfm = None + + def load_model(self): + import timesfm + self.tfm = timesfm.TimesFm( + context_len=512, + horizon_len=10, + input_patch_len=32, + output_patch_len=128, + num_layers=20, + model_dims=1280, + backend="gpu", + ) + self.tfm.load_from_checkpoint(repo_id=self.repository) + + def predict(self, context, input_df, params=None): + if self.tfm is None: + self.load_model() + + forecast_df = self.tfm.forecast_on_df( + inputs=input_df, + freq="D", + value_name="y", + num_jobs=-1, + ) + return forecast_df + + def __getstate__(self): + state = self.__dict__.copy() + # Remove the tfm attribute from the state, as it's not serializable + del state['tfm'] + return state + + def __setstate__(self, state): + # Restore instance attributes + self.__dict__.update(state) + # Reload the model since it was not stored in the state + self.load_model() + + +pipeline = TimesFMModel("google/timesfm-1.0-200m") +signature = infer_signature( + model_input=df, + model_output=pipeline.predict(None, df), +) + +registered_model_name=f"{catalog}.{db}.timesfm-1-200m" + +with mlflow.start_run() as run: + mlflow.pyfunc.log_model( + "model", + python_model=pipeline, + registered_model_name=registered_model_name, + signature=signature, + input_example=df, + pip_requirements=[ + "jax[cuda12]==0.4.26", + "protobuf==3.20.*", + "utilsforecast==0.1.10", + "git+https://github.com/google-research/timesfm.git", + ], + ) + +# COMMAND ---------- + +from mlflow import MlflowClient +client = MlflowClient() + +def get_latest_model_version(client, registered_model_name): + latest_version = 1 + for mv in client.search_model_versions(f"name='{registered_model_name}'"): + version_int = int(mv.version) + if version_int > latest_version: + latest_version = version_int + return latest_version + +model_version = get_latest_model_version(client, registered_model_name) +logged_model = f"models:/{registered_model_name}/{model_version}" + +# Load model as a PyFuncModel +loaded_model = mlflow.pyfunc.load_model(logged_model) + +# Generate forecasts +loaded_model.predict(df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Deploy Model for Online Forecast + +# COMMAND ---------- + +# With the token, you can create our authorization header for our subsequent REST calls +token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) +headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + +# Next you need an endpoint at which to execute your request which you can get from the notebook's tags collection +java_tags = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags() + +# This object comes from the Java CM - Convert the Java Map opject to a Python dictionary +tags = sc._jvm.scala.collection.JavaConversions.mapAsJavaMap(java_tags) + +# Lastly, extract the Databricks instance (domain name) from the dictionary +instance = tags["browserHostName"] + +# COMMAND ---------- + +import requests + +model_serving_endpoint_name = "timesfm-1-200m" + +my_json = { + "name": model_serving_endpoint_name, + "config": { + "served_models": [ + { + "model_name": registered_model_name, + "model_version": model_version, + "workload_type": "GPU_SMALL", + "workload_size": "Small", + "scale_to_zero_enabled": "true", + } + ], + "auto_capture_config": { + "catalog_name": catalog, + "schema_name": db, + "table_name_prefix": model_serving_endpoint_name, + }, + }, +} + +# Make sure to drop the inference table of it exists +_ = spark.sql( + f"DROP TABLE IF EXISTS {catalog}.{db}.`{model_serving_endpoint_name}_payload`" +) + +# COMMAND ---------- + +def func_create_endpoint(model_serving_endpoint_name): + # get endpoint status + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + url = f"{endpoint_url}/{model_serving_endpoint_name}" + r = requests.get(url, headers=headers) + if "RESOURCE_DOES_NOT_EXIST" in r.text: + print( + "Creating this new endpoint: ", + f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations", + ) + re = requests.post(endpoint_url, headers=headers, json=my_json) + else: + new_model_version = (my_json["config"])["served_models"][0]["model_version"] + print( + "This endpoint existed previously! We are updating it to a new config with new model version: ", + new_model_version, + ) + # update config + url = f"{endpoint_url}/{model_serving_endpoint_name}/config" + re = requests.put(url, headers=headers, json=my_json["config"]) + # wait till new config file in place + import time, json + + # get endpoint status + url = f"https://{instance}/api/2.0/serving-endpoints/{model_serving_endpoint_name}" + retry = True + total_wait = 0 + while retry: + r = requests.get(url, headers=headers) + assert ( + r.status_code == 200 + ), f"Expected an HTTP 200 response when accessing endpoint info, received {r.status_code}" + endpoint = json.loads(r.text) + if "pending_config" in endpoint.keys(): + seconds = 10 + print("New config still pending") + if total_wait < 6000: + # if less the 10 mins waiting, keep waiting + print(f"Wait for {seconds} seconds") + print(f"Total waiting time so far: {total_wait} seconds") + time.sleep(10) + total_wait += seconds + else: + print(f"Stopping, waited for {total_wait} seconds") + retry = False + else: + print("New config in place now!") + retry = False + + assert ( + re.status_code == 200 + ), f"Expected an HTTP 200 response, received {re.status_code}" + + +def func_delete_model_serving_endpoint(model_serving_endpoint_name): + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + url = f"{endpoint_url}/{model_serving_endpoint_name}" + response = requests.delete(url, headers=headers) + if response.status_code != 200: + raise Exception( + f"Request failed with status {response.status_code}, {response.text}" + ) + else: + print(model_serving_endpoint_name, "endpoint is deleted!") + return response.json() + +# COMMAND ---------- + +func_create_endpoint(model_serving_endpoint_name) + +# COMMAND ---------- + +import time, mlflow + + +def wait_for_endpoint(): + endpoint_url = f"https://{instance}/api/2.0/serving-endpoints" + while True: + url = f"{endpoint_url}/{model_serving_endpoint_name}" + response = requests.get(url, headers=headers) + assert ( + response.status_code == 200 + ), f"Expected an HTTP 200 response, received {response.status_code}\n{response.text}" + + status = response.json().get("state", {}).get("ready", {}) + # print("status",status) + if status == "READY": + print(status) + print("-" * 80) + return + else: + print(f"Endpoint not ready ({status}), waiting 5 miutes") + time.sleep(300) # Wait 300 seconds + + +api_url = mlflow.utils.databricks_utils.get_webapp_url() + +wait_for_endpoint() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Online Forecast + +# COMMAND ---------- + +import os +import requests +import pandas as pd +import json +import matplotlib.pyplot as plt + +# Replace URL with the end point invocation url you get from Model Seriving page. +endpoint_url = f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations" +token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get() +def forecast(input_data, url=endpoint_url, databricks_token=token): + headers = { + "Authorization": f"Bearer {databricks_token}", + "Content-Type": "application/json", + } + body = {"inputs": input_data.tolist()} + data = json.dumps(body) + response = requests.request(method="POST", headers=headers, url=url, data=data) + if response.status_code != 200: + raise Exception( + f"Request failed with status {response.status_code}, {response.text}" + ) + return response.json() + +# COMMAND ---------- + +forecast(df) + +# COMMAND ---------- + +func_delete_model_serving_endpoint(model_serving_endpoint_name)