Merge pull request #61 from databricks-industry-solutions/finetune-moirai

added-moirai-finetuning
ryuta-yoshimatsu authored Jun 21, 2024
2 parents f350bd5 + d35bc3c commit a9fa758
Showing 16 changed files with 704 additions and 8 deletions.
@@ -22,7 +22,7 @@
catalog = "mmf" # Name of the catalog we use to manage our assets
db = "m4" # Name of the schema we use to manage our assets (e.g. datasets)
volume = "chronos_fine_tune" # Name of the volume we store the data and the weigts
chronos_model = "chronos-t5-tiny" # Chronos model to finetune. Alternatives: -mini, -small, -base, -large
model = "chronos-t5-tiny" # Chronos model to finetune. Alternatives: -mini, -small, -base, -large
n = 1000 # Number of time series to sample

# COMMAND ----------
@@ -89,7 +89,7 @@ def convert_to_arrow(
start_times = list(df["ds"].apply(lambda x: x.min().to_numpy()))

# Make sure that the volume exists. We store the fine-tuned weights here.
-_ = spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{db}.chronos_fine_tune")
+_ = spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{db}.{volume}")

# Convert to GluonTS arrow format and save it in UC Volume
convert_to_arrow(
@@ -105,7 +105,7 @@ def convert_to_arrow(
# MAGIC
# MAGIC In this example, we will fine-tune `amazon/chronos-t5-tiny` for 1000 steps with an initial learning rate of 1e-3.
# MAGIC
-# MAGIC Make sure that you have the configuration yaml files placed inside the `configs` folder and the `train.py` script in the same directory. These two assets are taken directly from [chronos-forecasting/scripts/training](https://github.com/amazon-science/chronos-forecasting/tree/main/scripts/training). They are subject to change as the Chronos team develops the framework further. Keep an eye on the latest changes (we will try to as well) and use the latest versions if needed. We have made a small change to our `train.py` script and set the frequency of the time series to daily ("D").
+# MAGIC Make sure that you have the configuration yaml files placed inside the `configs` folder and the `train.py` script in the same directory. These two assets are taken directly from [chronos-forecasting/scripts/training](https://github.com/amazon-science/chronos-forecasting/tree/main/scripts/training). They are subject to change as the Chronos team develops the framework further. Keep an eye on the latest changes (we will try to as well) and use the latest versions as needed. We have made a small change to our `train.py` script and set the frequency of the time series to daily ("D").
# MAGIC
# MAGIC Inside the configuration yaml (for this example, `configs/chronos-t5-tiny.yaml`), make sure to set the parameters:
# MAGIC - `training_data_paths` to `/Volumes/mmf/m4/chronos_fine_tune/data.arrow`, where your Arrow-converted file is stored (see the sketch below)
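A minimal sketch (not part of this diff) of applying that setting programmatically, assuming PyYAML is available and that `configs/chronos-t5-tiny.yaml` uses the `training_data_paths` key named above:

import yaml

config_path = "configs/chronos-t5-tiny.yaml"
with open(config_path) as f:
    config = yaml.safe_load(f)

# Point training at the Arrow file written to the UC Volume above
config["training_data_paths"] = ["/Volumes/mmf/m4/chronos_fine_tune/data.arrow"]

with open(config_path, "w") as f:
    yaml.safe_dump(config, f)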
@@ -168,7 +168,7 @@ def predict(self, context, input_data, params=None):
files = os.listdir(f"/Volumes/{catalog}/{db}/{volume}/")
runs = [int(file[4:]) for file in files if "run-" in file]
latest_run = max(runs)
registered_model_name=f"{catalog}.{db}.{chronos_model}_finetuned"
registered_model_name=f"{catalog}.{db}.{model}_finetuned"
weights = f"/Volumes/{catalog}/{db}/{volume}/run-{latest_run}/checkpoint-final/"

# Get the model signature for registry
@@ -195,7 +195,7 @@ def predict(self, context, input_data, params=None):

# MAGIC %md
# MAGIC ##Reload Model
-# MAGIC We reload the model from the registry and perform forecasting on the in-training time series (for testing purposes). You can also go ahead and deploy this model behind a Model Serving real-time endpoint. See the previous notebook: `01_chronos_load_inference` for more information.
+# MAGIC We reload the model from the registry and perform forecasting on the in-training time series (for testing purposes). You can also go ahead and deploy this model behind a Model Serving real-time endpoint. See the previous notebook: [`01_chronos_load_inference`](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/foundation-model-examples/chronos/01_chronos_load_inference.py) for more information.
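A minimal reload-and-forecast sketch (not part of this diff); it reuses `registered_model_name` from the cell above and assumes model version 1, so in practice resolve the latest version first:

import mlflow
import numpy as np

mlflow.set_registry_uri("databricks-uc")
loaded_model = mlflow.pyfunc.load_model(f"models:/{registered_model_name}/1")
loaded_model.predict(np.random.rand(52))  # 52 context points, as in the Moirai example below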

# COMMAND ----------

@@ -117,7 +117,7 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:

# COMMAND ----------

moirai_model = "moirai-1.0-R-small" # Alternatibely moirai-1.0-R-base, moirai-1.0-R-large
model = "moirai-1.0-R-small" # Alternatibely moirai-1.0-R-base, moirai-1.0-R-large
prediction_length = 10 # Time horizon for forecasting
num_samples = 10 # Number of forecasts to generate. We will take the median as our final forecast.
patch_size = 32 # Patch size: choose from {"auto", 8, 16, 32, 64, 128}
@@ -129,7 +129,7 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
get_horizon_timestamps = create_get_horizon_timestamps(freq=freq, prediction_length=prediction_length)

forecast_udf = create_forecast_udf(
-    repository=f"Salesforce/{moirai_model}",
+    repository=f"Salesforce/{model}",
    prediction_length=prediction_length,
    patch_size=patch_size,
    num_samples=num_samples,
@@ -192,7 +192,7 @@ def predict(self, context, input_data, params=None):
)
return np.median(forecast[0], axis=0)

pipeline = MoiraiModel(f"Salesforce/{moirai_model}")
pipeline = MoiraiModel(f"Salesforce/{model}")
input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))])
output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1,))])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
224 changes: 224 additions & 0 deletions examples/foundation-model-examples/moirai/02_moirai_fine_tune.py
@@ -0,0 +1,224 @@
# Databricks notebook source
# MAGIC %md
# MAGIC This is an example notebook that shows how to use [Moirai](https://github.com/SalesforceAIResearch/uni2ts) models on Databricks.
# MAGIC
# MAGIC The notebook loads, fine-tunes, and registers the model.

# COMMAND ----------

# MAGIC %pip install git+https://github.com/SalesforceAIResearch/uni2ts.git --quiet
# MAGIC dbutils.library.restartPython()

# COMMAND ----------

# MAGIC %md
# MAGIC ## Prepare Data
# MAGIC Make sure that the catalog and the schema already exist.

# COMMAND ----------

catalog = "mmf" # Name of the catalog we use to manage our assets
db = "random" # Name of the schema we use to manage our assets (e.g. datasets)
volume = "moirai_fine_tune" # Name of the volume we store the data and the weigts
model = "moirai-1.0-R-small" # Alternatibely: moirai-1.0-R-base, moirai-1.0-R-large
n = 100 # Number of time series to sample

# COMMAND ----------

# Make sure that the database exists.
_ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{db}")

# Make sure that the volume exists. We store the fine-tuned weights here.
_ = spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{db}.{volume}")

# COMMAND ----------

# MAGIC %md
# MAGIC We synthesize `n` time series of daily resolution (with randomly sampled values) and store them as a csv file in a UC Volume.

# COMMAND ----------

import pandas as pd
import numpy as np

df_dict = {}

for i in range(n):

    # Create a date range for the index
    date_range = pd.date_range(start='2021-01-01', end='2023-12-31', freq='D')

    # Create a DataFrame with a date range index and two columns: 'item_id' and 'target'
    df = pd.DataFrame({
        'item_id': f"item_{i}",
        'target': np.random.randn(len(date_range))
    }, index=date_range)

    # Set 'item_id' as the second level of the MultiIndex
    df.set_index('item_id', append=True, inplace=True)

    # Sort the index
    df.sort_index(inplace=True)

    df_dict[i] = df


pdf = pd.concat([df_dict[i] for i in range(n)])
pdf.to_csv(f"/Volumes/{catalog}/{db}/{volume}/random.csv", index=True)
pdf

# COMMAND ----------

# MAGIC %md
# MAGIC This dotenv file is needed by the `uni2ts.data.builder.simple` module from the `uni2ts` library, which we use below to build a dataset.

# COMMAND ----------

import os
import site

uni2ts = os.path.join(site.getsitepackages()[0], "uni2ts")
dotenv = os.path.join(uni2ts, ".env")
os.environ['DOTENV'] = dotenv
os.environ['CUSTOM_DATA_PATH'] = f"/Volumes/{catalog}/{db}/{volume}"

# COMMAND ----------

# MAGIC %sh
# MAGIC rm -f $DOTENV
# MAGIC touch $DOTENV
# MAGIC echo "CUSTOM_DATA_PATH=$CUSTOM_DATA_PATH" >> $DOTENV

# COMMAND ----------

# MAGIC %md
# MAGIC We convert the dataset into the Uni2TS format. `random` is the name of the dataset to use, which we load from our volume's location. The `--offset 640` flag uses the first 640 time steps of each series for the training split. See the [README](https://github.com/SalesforceAIResearch/uni2ts/tree/main?tab=readme-ov-file#fine-tuning) of Uni2TS for more information on the parameters.

# COMMAND ----------

# MAGIC %sh python -m uni2ts.data.builder.simple random /Volumes/mmf/random/moirai_fine_tune/random.csv \
# MAGIC --dataset_type long \
# MAGIC --offset 640

# COMMAND ----------

# MAGIC %md
# MAGIC ##Run Fine-tuning
# MAGIC
# MAGIC In this example, we will fine-tune `moirai-1.0-R-small` for at most 100 epochs with early stopping (both can be specified in `examples/foundation-model-examples/moirai/conf/finetune/default.yaml`). The learning rate is set to 1e-3, which you can modify in the same file.
# MAGIC
# MAGIC Make sure that you have the configuration yaml files placed inside the `conf` folder and the `train.py` script in the same directory. These two assets are taken directly from [cli/conf](https://github.com/SalesforceAIResearch/uni2ts/tree/main/cli/conf) and [cli/train.py](https://github.com/SalesforceAIResearch/uni2ts/blob/main/cli/train.py). They are subject to change as the Moirai team develops the framework further. Keep an eye on the latest changes (we will try to as well) and use the latest versions as needed.
# MAGIC
# MAGIC The key configuration files to customize for your use case are `examples/foundation-model-examples/moirai/conf/finetune/default.yaml`, `examples/foundation-model-examples/moirai/conf/finetune/data/random.yaml`, and `examples/foundation-model-examples/moirai/conf/finetune/val_data/random.yaml`. Refer to the Moirai [documentation](https://github.com/SalesforceAIResearch/uni2ts) for more details. A sketch of tweaking these defaults programmatically follows below.
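A minimal sketch (not part of this diff) of adjusting the trainer defaults before launching `train.py`, assuming `omegaconf` is available (it ships with `hydra-core`, a uni2ts dependency) and the `default.yaml` shown later in this commit:

from omegaconf import OmegaConf

conf = OmegaConf.load("conf/finetune/default.yaml")
conf.trainer.max_epochs = 50            # shorten the run
conf.trainer.callbacks[2].patience = 5  # relax early stopping (third callback in default.yaml)
OmegaConf.save(conf, "conf/finetune/default.yaml")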

# COMMAND ----------

# MAGIC %sh python train.py \
# MAGIC -cp conf/finetune \
# MAGIC run_name=random_run \
# MAGIC model=moirai_1.0_R_small \
# MAGIC data=random \
# MAGIC val_data=random

# COMMAND ----------

# MAGIC %md
# MAGIC ##Register Model
# MAGIC We take the fine-tuned weights for the run from the UC Volume, wrap the pipeline in `mlflow.pyfunc.PythonModel`, and register it to Unity Catalog.

# COMMAND ----------

import mlflow
import torch
import numpy as np
from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, TensorSpec
mlflow.set_registry_uri("databricks-uc")


class FineTunedMoiraiModel(mlflow.pyfunc.PythonModel):
    def predict(self, context, input_data, params=None):
        from einops import rearrange
        from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        model = MoiraiForecast.load_from_checkpoint(
            prediction_length=10,
            context_length=len(input_data),
            patch_size=32,
            num_samples=10,
            target_dim=1,
            feat_dynamic_real_dim=0,
            past_feat_dynamic_real_dim=0,
            checkpoint_path=context.artifacts["weights"],
        ).to(device)

        # Time series values. Shape: (batch, time, variate)
        past_target = rearrange(
            torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1"
        )
        # 1s if the value is observed, 0s otherwise. Shape: (batch, time, variate)
        past_observed_target = torch.ones_like(past_target, dtype=torch.bool)
        # 1s if the value is padding, 0s otherwise. Shape: (batch, time)
        past_is_pad = torch.zeros_like(past_target, dtype=torch.bool).squeeze(-1)
        forecast = model(
            past_target=past_target.to(device),
            past_observed_target=past_observed_target.to(device),
            past_is_pad=past_is_pad.to(device),
        )
        return np.median(forecast.cpu()[0], axis=0)

input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))])
output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1,))])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
input_example = np.random.rand(52)
registered_model_name=f"{catalog}.{db}.moirai-1-r-small_finetuned"
weights = f"/Volumes/{catalog}/{db}/{volume}/outputs/moirai_1.0_R_small/random/random_run/checkpoints/epoch=0-step=100.ckpt"


with mlflow.start_run() as run:
    mlflow.pyfunc.log_model(
        "model",
        python_model=FineTunedMoiraiModel(),
        registered_model_name=registered_model_name,
        artifacts={"weights": weights},
        signature=signature,
        input_example=input_example,
        pip_requirements=[
            "git+https://github.com/SalesforceAIResearch/uni2ts.git",
        ],
    )

# COMMAND ----------

# MAGIC %md
# MAGIC ##Reload Model
# MAGIC We reload the model from the registry and perform forecasting on a randomly generated time series (for testing purposes). You can also go ahead and deploy this model behind a Model Serving real-time endpoint. See the previous notebook: [`01_moirai_load_inference`](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/foundation-model-examples/moirai/01_moirai_load_inference.py) for more information.

# COMMAND ----------

from mlflow import MlflowClient
client = MlflowClient()

def get_latest_model_version(client, registered_model_name):
    latest_version = 1
    for mv in client.search_model_versions(f"name='{registered_model_name}'"):
        version_int = int(mv.version)
        if version_int > latest_version:
            latest_version = version_int
    return latest_version
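# Note: depending on the MLflow version and registry backend, an equivalent
# one-liner may work (order_by support varies, so verify before relying on it):
# mv = client.search_model_versions(f"name='{registered_model_name}'",
#                                   order_by=["version_number DESC"], max_results=1)[0]
# model_version = int(mv.version)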

model_version = get_latest_model_version(client, registered_model_name)
logged_model = f"models:/{registered_model_name}/{model_version}"

# Load model as a PyFuncModel
loaded_model = mlflow.pyfunc.load_model(logged_model)

# Create input data
input_data = np.random.rand(52)

# Generate forecasts
loaded_model.predict(input_data)
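Once deployed behind a Model Serving endpoint, the same forecast can be requested over REST. A sketch (not part of this diff) with placeholder workspace URL, endpoint name, and token, none of which are defined in this notebook:

import requests

response = requests.post(
    "https://<workspace-url>/serving-endpoints/<endpoint-name>/invocations",
    headers={"Authorization": "Bearer <token>"},
    json={"inputs": input_data.tolist()},
)
print(response.json())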

# COMMAND ----------


@@ -0,0 +1,3 @@
_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder
dataset: ETTh1
weight: 1000
@@ -0,0 +1,3 @@
_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder
dataset: random
weight: 1000
@@ -0,0 +1,83 @@
hydra:
  run:
    dir: /Volumes/mmf/random/moirai_fine_tune/outputs/${hydra:runtime.choices.model}/${hydra:runtime.choices.data}/${run_name}
defaults:
  - model: ???
  - data: ???
  - val_data: null
  - _self_
run_name: ???
seed: 0
tf32: true
compile: false # set to mode: default, reduce-overhead, max-autotune
trainer:
  _target_: lightning.Trainer
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 32
  logger:
    _target_: lightning.pytorch.loggers.TensorBoardLogger
    save_dir: ${hydra:runtime.output_dir}
    name: logs
  callbacks:
    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: epoch
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      dirpath: ${hydra:runtime.output_dir}/checkpoints
      monitor: val/PackedNLLLoss
      save_weights_only: true
      mode: min
      save_top_k: 1
      every_n_epochs: 1
    - _target_: lightning.pytorch.callbacks.EarlyStopping
      monitor: val/PackedNLLLoss
      min_delta: 0.0
      patience: 3
      mode: min
      strict: false
      verbose: true
  max_epochs: 100
  enable_progress_bar: true
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: norm
train_dataloader:
  _target_: uni2ts.data.loader.DataLoader
  batch_size: 128
  batch_size_factor: 2.0
  cycle: true
  num_batches_per_epoch: 100
  shuffle: true
  num_workers: 11
  collate_fn:
    _target_: uni2ts.data.loader.PackCollate
    max_length: ${model.module_kwargs.max_seq_len}
    seq_fields: ${cls_getattr:${model._target_},seq_fields}
    pad_func_map: ${cls_getattr:${model._target_},pad_func_map}
  pin_memory: true
  drop_last: false
  fill_last: false
  worker_init_fn: null
  prefetch_factor: 2
  persistent_workers: true
val_dataloader:
  _target_: uni2ts.data.loader.DataLoader
  batch_size: 128
  batch_size_factor: 2.0
  cycle: false
  num_batches_per_epoch: null
  shuffle: false
  num_workers: 11
  collate_fn:
    _target_: uni2ts.data.loader.PackCollate
    max_length: ${model.module_kwargs.max_seq_len}
    seq_fields: ${cls_getattr:${model._target_},seq_fields}
    pad_func_map: ${cls_getattr:${model._target_},pad_func_map}
  pin_memory: false
  drop_last: false
  fill_last: true
  worker_init_fn: null
  prefetch_factor: 2
  persistent_workers: true
