From ec8a8e2babb593b784975f6de5496fbdc386884b Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Fri, 15 Mar 2024 11:35:30 -0700 Subject: [PATCH 01/23] feat: Part 1 - Add FineTuning API (#201) PEFT - LoRA for Fine-Tuning LLMs This PR is Part 1 in PRs that will integrate fine tuning into Kaito. This PR is base API code. Future PRs will allow you to specify custom dataset, load config from configmap, and upload training results as image to ACR. --------- Signed-off-by: Ishaan Sehgal --- .github/matrix-configs.json | 20 +-- .github/workflows/kind-cluster/main.py | 2 +- .../{ => inference}/llama-2/Dockerfile | 0 .../{ => inference}/tfs-onnx/Dockerfile | 0 .../tfs-onnx/convert_to_onnx.py | 0 docker/presets/{ => inference}/tfs/Dockerfile | 0 docker/presets/tuning/Dockerfile | 23 +++ pkg/controllers/workspace_controller.go | 2 +- presets/models/llama2/README.md | 2 +- presets/models/llama2chat/README.md | 2 +- presets/test/falcon-benchmark/README.md | 2 +- .../manifests/tuning/falcon/falcon-7b.yaml | 103 ++++++++++++++ presets/tuning/tfs/cli.py | 129 +++++++++++++++++ presets/tuning/tfs/fine_tuning_api.py | 133 ++++++++++++++++++ presets/tuning/tfs/requirements.txt | 14 ++ 15 files changed, 417 insertions(+), 15 deletions(-) rename docker/presets/{ => inference}/llama-2/Dockerfile (100%) rename docker/presets/{ => inference}/tfs-onnx/Dockerfile (100%) rename docker/presets/{ => inference}/tfs-onnx/convert_to_onnx.py (100%) rename docker/presets/{ => inference}/tfs/Dockerfile (100%) create mode 100644 docker/presets/tuning/Dockerfile create mode 100644 presets/test/manifests/tuning/falcon/falcon-7b.yaml create mode 100644 presets/tuning/tfs/cli.py create mode 100644 presets/tuning/tfs/fine_tuning_api.py create mode 100644 presets/tuning/tfs/requirements.txt diff --git a/.github/matrix-configs.json b/.github/matrix-configs.json index ef6dc8f95..0b3666742 100644 --- a/.github/matrix-configs.json +++ b/.github/matrix-configs.json @@ -3,7 +3,7 @@ "model": { "runs_on": "self-hosted", "name": "falcon-7b", - "dockerfile": "docker/presets/falcon/Dockerfile", + "dockerfile": "docker/presets/inference/falcon/Dockerfile", "build_args": "--build-arg FALCON_MODEL_NAME=tiiuae/falcon-7b" }, "shouldBuildFalcon": "true" @@ -12,7 +12,7 @@ "model": { "runs_on": "self-hosted", "name": "falcon-7b-instruct", - "dockerfile": "docker/presets/falcon/Dockerfile", + "dockerfile": "docker/presets/inference/falcon/Dockerfile", "build_args": "--build-arg FALCON_MODEL_NAME=tiiuae/falcon-7b-instruct" }, "shouldBuildFalcon": "true" @@ -22,7 +22,7 @@ "model": { "runs_on": "self-hosted", "name": "falcon-40b", - "dockerfile": "docker/presets/falcon/Dockerfile", + "dockerfile": "docker/presets/inference/falcon/Dockerfile", "build_args": "--build-arg FALCON_MODEL_NAME=tiiuae/falcon-40b" }, "shouldBuildFalcon": "true" @@ -32,7 +32,7 @@ "model": { "runs_on": "self-hosted", "name": "falcon-40b-instruct", - "dockerfile": "docker/presets/falcon/Dockerfile", + "dockerfile": "docker/presets/inference/falcon/Dockerfile", "build_args": "--build-arg FALCON_MODEL_NAME=tiiuae/falcon-40b-instruct" }, "shouldBuildFalcon": "true" @@ -42,7 +42,7 @@ "model": { "runs_on": "self-hosted", "name": "llama-2-7b", - "dockerfile": "docker/presets/llama-2/Dockerfile", + "dockerfile": "docker/presets/inference/llama-2/Dockerfile", "build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-7b --build-arg SRC_DIR=/home/presets/llama-2" }, "shouldBuildLlama2": "true" @@ -52,7 +52,7 @@ "model": { "runs_on": "self-hosted", "name": "llama-2-13b", - "dockerfile": 
"docker/presets/llama-2/Dockerfile", + "dockerfile": "docker/presets/inference/llama-2/Dockerfile", "build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-13b --build-arg SRC_DIR=/home/presets/llama-2" }, "shouldBuildLlama2": "true" @@ -62,7 +62,7 @@ "model": { "runs_on": "self-hosted", "name": "llama-2-70b", - "dockerfile": "docker/presets/llama-2/Dockerfile", + "dockerfile": "docker/presets/inference/llama-2/Dockerfile", "build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-70b --build-arg SRC_DIR=/home/presets/llama-2" }, "shouldBuildLlama2": "true" @@ -72,7 +72,7 @@ "model": { "runs_on": "self-hosted", "name": "llama-2-7b-chat", - "dockerfile": "docker/presets/llama-2/Dockerfile", + "dockerfile": "docker/presets/inference/llama-2/Dockerfile", "build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-7b-chat --build-arg SRC_DIR=/home/presets/llama-2-chat" }, "shouldBuildLlama2Chat": "true" @@ -82,7 +82,7 @@ "model": { "runs_on": "self-hosted", "name": "llama-2-13b-chat", - "dockerfile": "docker/presets/llama-2/Dockerfile", + "dockerfile": "docker/presets/inference/llama-2/Dockerfile", "build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-13b-chat --build-arg SRC_DIR=/home/presets/llama-2-chat" }, "shouldBuildLlama2Chat": "true" @@ -92,7 +92,7 @@ "model": { "runs_on": "self-hosted", "name": "llama-2-70b-chat", - "dockerfile": "docker/presets/llama-2/Dockerfile", + "dockerfile": "docker/presets/inference/llama-2/Dockerfile", "build_args": "--build-arg LLAMA_WEIGHTS=/llama/llama-2-70b-chat --build-arg SRC_DIR=/home/presets/llama-2-chat" }, "shouldBuildLlama2Chat": "true" diff --git a/.github/workflows/kind-cluster/main.py b/.github/workflows/kind-cluster/main.py index 02816ee8c..34f4205f9 100644 --- a/.github/workflows/kind-cluster/main.py +++ b/.github/workflows/kind-cluster/main.py @@ -12,7 +12,7 @@ def get_weights_path(model_name): return f"/datadrive/{model_name}/weights" def get_dockerfile_path(model_runtime): - return f"/kaito/docker/presets/{model_runtime}/Dockerfile" + return f"/kaito/docker/presets/inference/{model_runtime}/Dockerfile" def generate_unique_id(): """Generate a unique identifier for a job.""" diff --git a/docker/presets/llama-2/Dockerfile b/docker/presets/inference/llama-2/Dockerfile similarity index 100% rename from docker/presets/llama-2/Dockerfile rename to docker/presets/inference/llama-2/Dockerfile diff --git a/docker/presets/tfs-onnx/Dockerfile b/docker/presets/inference/tfs-onnx/Dockerfile similarity index 100% rename from docker/presets/tfs-onnx/Dockerfile rename to docker/presets/inference/tfs-onnx/Dockerfile diff --git a/docker/presets/tfs-onnx/convert_to_onnx.py b/docker/presets/inference/tfs-onnx/convert_to_onnx.py similarity index 100% rename from docker/presets/tfs-onnx/convert_to_onnx.py rename to docker/presets/inference/tfs-onnx/convert_to_onnx.py diff --git a/docker/presets/tfs/Dockerfile b/docker/presets/inference/tfs/Dockerfile similarity index 100% rename from docker/presets/tfs/Dockerfile rename to docker/presets/inference/tfs/Dockerfile diff --git a/docker/presets/tuning/Dockerfile b/docker/presets/tuning/Dockerfile new file mode 100644 index 000000000..896deb85a --- /dev/null +++ b/docker/presets/tuning/Dockerfile @@ -0,0 +1,23 @@ +FROM python:3.10-slim + +ARG WEIGHTS_PATH +ARG MODEL_TYPE +ARG VERSION + +# Set the working directory +WORKDIR /workspace/tfs + +# Write the version to a file +RUN echo $VERSION > /workspace/tfs/version.txt + +# First, copy just the preset files and install dependencies +# This is done before copying the code to 
utilize Docker's layer caching and +# avoid reinstalling dependencies unless the requirements file changes. +COPY kaito/presets/tuning/${MODEL_TYPE}/requirements.txt /workspace/tfs/requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +COPY kaito/presets/tuning/${MODEL_TYPE}/cli.py /workspace/tfs/cli.py +COPY kaito/presets/tuning/${MODEL_TYPE}/fine_tuning_api.py /workspace/tfs/tuning_api.py + +# Copy the entire model weights to the weights directory +COPY ${WEIGHTS_PATH} /workspace/tfs/weights diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index 1be44f4e5..a2e4fc18d 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -130,7 +130,7 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka func (c *WorkspaceReconciler) deleteWorkspace(ctx context.Context, wObj *kaitov1alpha1.Workspace) (reconcile.Result, error) { klog.InfoS("deleteWorkspace", "workspace", klog.KObj(wObj)) - // TODO delete workspace, machine(s), training and inference (deployment, service) obj ( ok to delete machines? which will delete nodes??) + // TODO delete workspace, machine(s), fine_tuning and inference (deployment, service) obj ( ok to delete machines? which will delete nodes??) err := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeDeleting, metav1.ConditionTrue, "workspaceDeleted", "workspace is being deleted") if err != nil { klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) diff --git a/presets/models/llama2/README.md b/presets/models/llama2/README.md index 4a711d8ff..e6a40563a 100644 --- a/presets/models/llama2/README.md +++ b/presets/models/llama2/README.md @@ -29,7 +29,7 @@ export LLAMA_WEIGHTS_PATH= Use the following command to build the llama2 inference service image from the root of the repo. ``` docker build \ - --file docker/presets/llama-2/Dockerfile \ + --file docker/presets/inference/llama-2/Dockerfile \ --build-arg WEIGHTS_PATH=$LLAMA_WEIGHTS_PATH \ --build-arg MODEL_TYPE=llama2-completion \ --build-arg VERSION=0.0.1 \ diff --git a/presets/models/llama2chat/README.md b/presets/models/llama2chat/README.md index 852a22fab..53e241fab 100644 --- a/presets/models/llama2chat/README.md +++ b/presets/models/llama2chat/README.md @@ -29,7 +29,7 @@ export LLAMA_WEIGHTS_PATH= Use the following command to build the llama2chat inference service image from the root of the repo. ``` docker build \ - --file docker/presets/llama-2/Dockerfile \ + --file docker/presets/inference/llama-2/Dockerfile \ --build-arg WEIGHTS_PATH=$LLAMA_WEIGHTS_PATH \ --build-arg MODEL_TYPE=llama2-chat \ --build-arg VERSION=0.0.1 \ diff --git a/presets/test/falcon-benchmark/README.md b/presets/test/falcon-benchmark/README.md index 4a2078fe7..9c43c99fa 100644 --- a/presets/test/falcon-benchmark/README.md +++ b/presets/test/falcon-benchmark/README.md @@ -23,7 +23,7 @@ Ensure your `accelerate` configuration aligns with the values provided during be - If you haven't already, you can use the Azure CLI or the Azure Portal to create and configure a GPU node pool in your AKS cluster. 2. 
Building and Pushing the Docker Image: - - First, you need to build a Docker image from the provided [Dockerfile](https://github.com/Azure/kaito/blob/main/docker/presets/tfs/Dockerfile) and push it to a container registry accessible by your AKS cluster + - First, you need to build a Docker image from the provided [Dockerfile](https://github.com/Azure/kaito/blob/main/docker/presets/inference/tfs/Dockerfile) and push it to a container registry accessible by your AKS cluster - Example: ``` diff --git a/presets/test/manifests/tuning/falcon/falcon-7b.yaml b/presets/test/manifests/tuning/falcon/falcon-7b.yaml new file mode 100644 index 000000000..7852bf01a --- /dev/null +++ b/presets/test/manifests/tuning/falcon/falcon-7b.yaml @@ -0,0 +1,103 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: falcon-7b-tuning +spec: + replicas: 1 + selector: + matchLabels: + app: falcon + template: + metadata: + labels: + app: falcon + spec: + containers: + - name: falcon-container + image: aimodelsregistrytest.azurecr.io/tuning-falcon-7b:0.0.1 + command: ["/bin/sh", "-c", "sleep infinity"] + resources: + requests: + nvidia.com/gpu: 2 + limits: + nvidia.com/gpu: 2 # Requesting 2 GPUs + volumeMounts: + - name: dshm + mountPath: /dev/shm + - name: workspace + mountPath: /workspace + + - name: docker-sidecar + image: docker:dind + securityContext: + privileged: true # Allows container to manage its own containers + volumeMounts: + - name: workspace + mountPath: /workspace + env: + - name: ACR_USERNAME + value: "{{ACR_USERNAME}}" + - name: ACR_PASSWORD + value: "{{ACR_PASSWORD}}" + - name: TAG + value: "{{TAG}}" + command: ["/bin/sh"] + args: + - -c + - | + # Start the Docker daemon in the background with specific options for DinD + dockerd & + # Wait for the Docker daemon to be ready + while ! docker info > /dev/null 2>&1; do + echo "Waiting for Docker daemon to start..." + sleep 1 + done + echo 'Docker daemon started' + + while true; do + FILE_PATH=$(find /workspace/tfs -name 'fine_tuning_completed.txt') + if [ ! -z "$FILE_PATH" ]; then + echo "FOUND TRAINING COMPLETED FILE at $FILE_PATH" + + PARENT_DIR=$(dirname "$FILE_PATH") + echo "Parent directory is $PARENT_DIR" + + TEMP_CONTEXT=$(mktemp -d) + cp "$PARENT_DIR/adapter_config.json" "$TEMP_CONTEXT/adapter_config.json" + cp -r "$PARENT_DIR/adapter_model.safetensors" "$TEMP_CONTEXT/adapter_model.safetensors" + + # Create a minimal Dockerfile + echo 'FROM scratch + ADD adapter_config.json / + ADD adapter_model.safetensors /' > "$TEMP_CONTEXT/Dockerfile" + + # Login to Docker registry + echo $ACR_PASSWORD | docker login $ACR_USERNAME.azurecr.io -u $ACR_USERNAME --password-stdin + + docker build -t $ACR_USERNAME.azurecr.io/adapter-falcon-7b:$TAG "$TEMP_CONTEXT" + docker push $ACR_USERNAME.azurecr.io/adapter-falcon-7b:$TAG + + # Cleanup: Remove the temporary directory + rm -rf "$TEMP_CONTEXT" + + # Remove the file to prevent repeated builds, or handle as needed + # rm "$FILE_PATH" + fi + sleep 10 # Check every 10 seconds + done + + volumes: + - name: dshm + emptyDir: + medium: Memory + - name: workspace + emptyDir: {} + + tolerations: + - effect: NoSchedule + key: sku + operator: Equal + value: gpu + - effect: NoSchedule + key: nvidia.com/gpu + operator: Exists diff --git a/presets/tuning/tfs/cli.py b/presets/tuning/tfs/cli.py new file mode 100644 index 000000000..cab13506a --- /dev/null +++ b/presets/tuning/tfs/cli.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
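+
+# This module defines the dataclass-based configuration surface for the
+# tuning workflow: data collation, LoRA, dataset, tokenizer, model,
+# quantization, and training-output settings. fine_tuning_api.py feeds
+# these dataclasses to transformers.HfArgumentParser to parse the
+# command-line arguments for a tuning job.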
+import os
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+import torch
+from peft import LoraConfig
+from transformers import (BitsAndBytesConfig, DataCollatorForLanguageModeling,
+                          PreTrainedTokenizer, TrainerCallback)
+
+
+@dataclass
+class ExtDataCollator(DataCollatorForLanguageModeling):
+    # Default to None rather than the PreTrainedTokenizer class itself; the
+    # actual tokenizer instance is assigned at runtime in fine_tuning_api.py.
+    tokenizer: Optional[PreTrainedTokenizer] = field(default=None, metadata={"help": "Tokenizer for DataCollatorForLanguageModeling"})
+
+@dataclass
+class ExtLoraConfig(LoraConfig):
+    """
+    Lora Config
+    """
+    init_lora_weights: bool = field(default=True, metadata={"help": "Enable initialization of LoRA weights"})
+    target_modules: Optional[List[str]] = field(default=None, metadata={"help": "List of module names to replace with LoRA."})
+    layers_to_transform: Optional[List[int]] = field(default=None, metadata={"help": "Layer indices to apply LoRA"})
+    layers_pattern: Optional[List[str]] = field(default=None, metadata={"help": "Pattern to match layers for LoRA"})
+    loftq_config: Dict[str, Any] = field(default_factory=dict, metadata={"help": "LoftQ configuration for quantization"})
+
+@dataclass
+class DatasetConfig:
+    """
+    Config for Dataset
+    """
+    dataset_name: str = field(metadata={"help": "Name of Dataset"})
+    shuffle_dataset: bool = field(default=True, metadata={"help": "Whether to shuffle dataset"})
+    shuffle_seed: int = field(default=42, metadata={"help": "Seed for shuffling data"})
+    context_column: str = field(default="Context", metadata={"help": "Example human input column in the dataset"})
+    response_column: str = field(default="Response", metadata={"help": "Example bot response output column in the dataset"})
+    train_test_split: float = field(default=0.8, metadata={"help": "Split between training and test data (e.g. 0.8 means an 80/20 train/test split)"})
+
+@dataclass
+class TokenizerParams:
+    """
+    Tokenizer params
+    """
+    add_special_tokens: bool = field(default=True, metadata={"help": ""})
+    padding: bool = field(default=False, metadata={"help": ""})
+    truncation: Optional[bool] = field(default=None, metadata={"help": ""})
+    max_length: Optional[int] = field(default=None, metadata={"help": ""})
+    stride: int = field(default=0, metadata={"help": ""})
+    is_split_into_words: bool = field(default=False, metadata={"help": ""})
+    tok_pad_to_multiple_of: Optional[int] = field(default=None, metadata={"help": ""})
+    tok_return_tensors: Optional[str] = field(default=None, metadata={"help": ""})
+    return_token_type_ids: Optional[bool] = field(default=None, metadata={"help": ""})
+    return_attention_mask: Optional[bool] = field(default=None, metadata={"help": ""})
+    return_overflowing_tokens: bool = field(default=False, metadata={"help": ""})
+    return_special_tokens_mask: bool = field(default=False, metadata={"help": ""})
+    return_offsets_mapping: bool = field(default=False, metadata={"help": ""})
+    return_length: bool = field(default=False, metadata={"help": ""})
+    verbose: bool = field(default=True, metadata={"help": ""})
+
+@dataclass
+class ModelConfig:
+    """
+    Transformers Model Configuration Parameters
+    """
+    pretrained_model_name_or_path: Optional[str] = field(default="/workspace/tfs/weights", metadata={"help": "Path to the pretrained model or model identifier from huggingface.co/models"})
+    state_dict: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "State dictionary for the model"})
+    cache_dir: Optional[str] = field(default=None, metadata={"help": "Cache directory for the model"})
+    from_tf: bool = field(default=False, metadata={"help": "Load model from a TensorFlow checkpoint"})
+    force_download: bool = field(default=False, metadata={"help": "Force the download of the model"})
+    resume_download: bool = field(default=False, metadata={"help": "Resume an interrupted download"})
+    proxies: Optional[str] = field(default=None, metadata={"help": "Proxy configuration for downloading the model"})
+    output_loading_info: bool = field(default=False, metadata={"help": "Output additional loading information"})
+    allow_remote_files: bool = field(default=False, metadata={"help": "Allow using remote files, default is local only"})
+    m_revision: str = field(default="main", metadata={"help": "Specific model version to use"})
+    trust_remote_code: bool = field(default=False, metadata={"help": "Enable trusting remote code when loading the model"})
+    m_load_in_4bit: bool = field(default=False, metadata={"help": "Load model in 4-bit mode"})
+    m_load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"})
+    torch_dtype: Optional[str] = field(default=None, metadata={"help": "The torch dtype for the pre-trained model"})
+    device_map: str = field(default="auto", metadata={"help": "The device map for the pre-trained model"})
+
+    def __post_init__(self):
+        """
+        Post-initialization to validate some ModelConfig values
+        """
+        if self.torch_dtype and not hasattr(torch, self.torch_dtype):
+            raise ValueError(f"Invalid torch dtype: {self.torch_dtype}")
+        self.torch_dtype = getattr(torch, self.torch_dtype) if self.torch_dtype else None
+
+@dataclass
+class QuantizationConfig(BitsAndBytesConfig):
+    """
+    Quantization Configuration
+    """
+    quant_method: str = field(default="bitsandbytes", metadata={"help": "Quantization Method {bitsandbytes,gptq,awq}"})
+    load_in_8bit: bool = field(default=False, 
metadata={"help": "Enable 8-bit quantization"}) + load_in_4bit: bool = field(default=False, metadata={"help": "Enable 4-bit quantization"}) + llm_int8_threshold: float = field(default=6.0, metadata={"help": "LLM.int8 threshold"}) + llm_int8_skip_modules: List[str] = field(default=None, metadata={"help": "Modules to skip for 8-bit conversion"}) + llm_int8_enable_fp32_cpu_offload: bool = field(default=False, metadata={"help": "Enable FP32 CPU offload for 8-bit"}) + llm_int8_has_fp16_weight: bool = field(default=False, metadata={"help": "Use FP16 weights for LLM.int8"}) + bnb_4bit_compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype for 4-bit quantization"}) + bnb_4bit_quant_type: str = field(default="fp4", metadata={"help": "Quantization type for 4-bit"}) + bnb_4bit_use_double_quant: bool = field(default=False, metadata={"help": "Use double quantization for 4-bit"}) + +@dataclass +class TrainingConfig: + """ + Configuration for fine_tuning process + """ + save_output_path: str = field(default=".", metadata={"help": "Path where fine_tuning output is saved"}) + # Other fine_tuning-related configurations can go here + +# class CheckpointCallback(TrainerCallback): +# def on_train_end(self, args, state, control, **kwargs): +# model_path = args.output_dir +# timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") +# img_tag = f"ghcr.io/YOUR_USERNAME/LoRA-Adapter:{timestamp}" + +# # Write a file to indicate fine_tuning completion +# completion_indicator_path = os.path.join(model_path, "training_completed.txt") +# with open(completion_indicator_path, 'w') as f: +# f.write(f"Training completed at {timestamp}\n") +# f.write(f"Image Tag: {img_tag}\n") + + # This method is called whenever a checkpoint is saved. + # def on_save(self, args, state, control, **kwargs): + # docker_build_and_push() \ No newline at end of file diff --git a/presets/tuning/tfs/fine_tuning_api.py b/presets/tuning/tfs/fine_tuning_api.py new file mode 100644 index 000000000..46d733c2c --- /dev/null +++ b/presets/tuning/tfs/fine_tuning_api.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import os
+from dataclasses import asdict
+from datetime import datetime
+
+import torch
+import transformers
+from accelerate import Accelerator
+from cli import (DatasetConfig, ExtDataCollator, ExtLoraConfig, ModelConfig,
+                 QuantizationConfig, TokenizerParams, TrainingConfig)
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig, HfArgumentParser,
+                          TrainingArguments)
+
+# Parsing
+parser = HfArgumentParser((ModelConfig, QuantizationConfig, ExtLoraConfig, TrainingConfig, TrainingArguments, ExtDataCollator, DatasetConfig, TokenizerParams))
+model_config, bnb_config, ext_lora_config, train_config, ta_args, dc_args, ds_config, tk_params, additional_args = parser.parse_args_into_dataclasses(
+    return_remaining_strings=True
+)
+
+print("Unmatched arguments:", additional_args)
+
+accelerator = Accelerator()
+
+# Load Model Args
+model_args = asdict(model_config)
+model_args["local_files_only"] = not model_args.pop('allow_remote_files')
+model_args["revision"] = model_args.pop('m_revision')
+model_args["load_in_4bit"] = model_args.pop('m_load_in_4bit')
+model_args["load_in_8bit"] = model_args.pop('m_load_in_8bit')
+if accelerator.distributed_type != "NO":  # Meaning we require distributed training
+    print("Setting device map for distributed training")
+    model_args["device_map"] = {"": Accelerator().process_index}
+
+# Load BitsAndBytesConfig
+bnb_config_args = asdict(bnb_config)
+bnb_config = BitsAndBytesConfig(**bnb_config_args)
+enable_qlora = bnb_config.is_quantizable()
+
+# Load Tokenizer Params
+tk_params = asdict(tk_params)
+tk_params["pad_to_multiple_of"] = tk_params.pop("tok_pad_to_multiple_of")
+tk_params["return_tensors"] = tk_params.pop("tok_return_tensors")
+
+# Load the Pre-Trained Tokenizer
+tokenizer = AutoTokenizer.from_pretrained(**model_args)
+if not tokenizer.pad_token:
+    tokenizer.pad_token = tokenizer.eos_token
+if dc_args.mlm and tokenizer.mask_token is None:
+    raise ValueError(
+        "This tokenizer does not have a mask token which is necessary for masked language modeling. "
+        "You should pass `mlm=False` to train on causal language modeling instead."
+    )
+dc_args.tokenizer = tokenizer
+
+# Load the Pre-Trained Model
+model = AutoModelForCausalLM.from_pretrained(
+    **model_args,
+    quantization_config=bnb_config if enable_qlora else None,
+)
+
+print("model loaded")
+
+if enable_qlora:
+    print("enable_qlora")
+    # Preparing the Model for QLoRA
+    model = prepare_model_for_kbit_training(model)
+
+assert ext_lora_config is not None, "LoraConfig must be specified"
+lora_config_args = asdict(ext_lora_config)
+lora_config = LoraConfig(**lora_config_args)
+
+model = get_peft_model(model, lora_config)
+# Cache is only used for generation, not for training
+model.config.use_cache = False
+model.print_trainable_parameters()
+
+# Loading and Preparing the Dataset
+# Data format: https://huggingface.co/docs/autotrain/en/llm_finetuning
+def preprocess_data(example):
+    prompt = f"human: {example[ds_config.context_column]}\n bot: {example[ds_config.response_column]}".strip()
+    return tokenizer(prompt, **tk_params)
+
+# Loading the dataset
+dataset = load_dataset(ds_config.dataset_name, split="train")
+
+# Shuffling the dataset (if needed)
+if ds_config.shuffle_dataset:
+    dataset = dataset.shuffle(seed=ds_config.shuffle_seed)
+
+# Preprocessing the data
+dataset = dataset.map(preprocess_data)
+
+assert 0 < ds_config.train_test_split <= 1, "Train/Test split needs to be between 0 and 1"
+
+# Initialize variables for train and eval datasets
+train_dataset, eval_dataset = dataset, None
+
+if ds_config.train_test_split < 1:
+    # Splitting the dataset into training and test sets
+    split_dataset = dataset.train_test_split(
+        test_size=1-ds_config.train_test_split,
+        seed=ds_config.shuffle_seed
+    )
+    train_dataset, eval_dataset = split_dataset['train'], split_dataset['test']
+    print("Training Dataset Dimensions: ", train_dataset.shape)
+    print("Test Dataset Dimensions: ", eval_dataset.shape)
+else:
+    print(f"Using full dataset for training. 
Dimensions: {train_dataset.shape}") + +# checkpoint_callback = CheckpointCallback() + +# Training the Model +trainer = accelerator.prepare(transformers.Trainer( + model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + args=ta_args, + data_collator=dc_args, + # callbacks=[checkpoint_callback] +)) +trainer.train() +os.makedirs(train_config.save_output_path, exist_ok=True) +trainer.save_model(train_config.save_output_path) + +# Write file to signify training completion +timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") +completion_indicator_path = os.path.join(train_config.save_output_path, "fine_tuning_completed.txt") +with open(completion_indicator_path, 'w') as f: + f.write(f"Fine-Tuning completed at {timestamp}\n") diff --git a/presets/tuning/tfs/requirements.txt b/presets/tuning/tfs/requirements.txt new file mode 100644 index 000000000..091e6d21f --- /dev/null +++ b/presets/tuning/tfs/requirements.txt @@ -0,0 +1,14 @@ +datasets==2.16.1 +peft==0.8.2 +transformers==4.38.2 +torch==2.2.0 +accelerate==0.27.2 +fastapi==0.103.2 +pydantic==1.10.9 +uvicorn[standard]==0.23.2 +bitsandbytes==0.42.0 +gputil==1.4.0 +loralib +einops +xformers +deepspeed \ No newline at end of file From ab3a633ce33b490afe1f560ba5025ef7ea367c46 Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Fri, 15 Mar 2024 12:56:16 -0700 Subject: [PATCH 02/23] fix: Protect secret with environment (#300) Signed-off-by: Ishaan Sehgal --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 29d14ac24..03be3669c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,6 +22,7 @@ env: jobs: unit-tests: runs-on: ubuntu-latest + environment: unit-tests steps: - name: Set up Go ${{ env.GO_VERSION }} uses: actions/setup-go@v5 From fb2aba3f1afd51a4d900b94fb48cc69d2fe42b67 Mon Sep 17 00:00:00 2001 From: Heba <31887807+helayoty@users.noreply.github.com> Date: Sun, 17 Mar 2024 07:27:08 -0700 Subject: [PATCH 03/23] ci: fix 1ES pool label name (#301) Signed-off-by: Heba <31887807+helayoty@users.noreply.github.com> --- .github/workflows/publish-image-acr.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-image-acr.yml b/.github/workflows/publish-image-acr.yml index aaf025093..a307bd942 100644 --- a/.github/workflows/publish-image-acr.yml +++ b/.github/workflows/publish-image-acr.yml @@ -18,7 +18,7 @@ env: jobs: check-tag: runs-on: - labels: [ self-hosted, "1ES.Pool=${{ matrix.runner }}" ] + labels: [ "self-hosted", "1ES.Pool=1es-aks-kaito-agent-pool-ubuntu" ] environment: publish-mcr outputs: tag: ${{ steps.get-tag.outputs.tag }} @@ -57,7 +57,7 @@ jobs: publish: runs-on: - labels: [ self-hosted, "1ES.Pool=${{ matrix.runner }}" ] + labels: [ "self-hosted", "1ES.Pool=1es-aks-kaito-agent-pool-ubuntu" ] environment: publish-mcr needs: - check-tag From 9894f3d89f50b236cda5f9ec6b48a2ff39e95808 Mon Sep 17 00:00:00 2001 From: Heba <31887807+helayoty@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:29:00 -0700 Subject: [PATCH 04/23] release: update manifest and helm charts for v0.2.1 (#302) - Fixes https://github.com/Azure/kaito/security/dependabot/13 - Release v0.2.1 --------- Signed-off-by: Heba Elayoty --- Makefile | 6 +- charts/kaito/gpu-provisioner/Chart.yaml | 4 +- charts/kaito/gpu-provisioner/README.md | 107 +++++++++--------- .../templates/configmap-logging.yaml | 2 +- .../gpu-provisioner/templates/deployment.yaml | 2 + charts/kaito/gpu-provisioner/values.yaml 
| 4 +- charts/kaito/workspace/Chart.yaml | 4 +- charts/kaito/workspace/README.md | 2 +- charts/kaito/workspace/values.yaml | 2 +- docker/presets/inference/llama-2/Dockerfile | 2 +- go.mod | 4 +- go.sum | 4 +- 12 files changed, 70 insertions(+), 73 deletions(-) diff --git a/Makefile b/Makefile index d3df4896c..c3bfe3385 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ # Image URL to use all building/pushing image targets REGISTRY ?= YOUR_REGISTRY IMG_NAME ?= workspace -VERSION ?= v0.2.0 +VERSION ?= v0.2.1 IMG_TAG ?= $(subst v,,$(VERSION)) ROOT_DIR := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) @@ -208,8 +208,6 @@ gpu-provisioner-helm: ## Update Azure client env vars and settings in helm valu $(eval AZURE_TENANT_ID=$(shell az account show | jq -r ".tenantId")) $(eval AZURE_SUBSCRIPTION_ID=$(shell az account show | jq -r ".id")) - yq -i '(.controller.image.repository) = "mcr.microsoft.com/aks/kaito/gpu-provisioner"' ./charts/kaito/gpu-provisioner/values.yaml - yq -i '(.controller.image.tag) = "0.1.0"' ./charts/kaito/gpu-provisioner/values.yaml yq -i '(.controller.env[] | select(.name=="ARM_SUBSCRIPTION_ID")) .value = "$(AZURE_SUBSCRIPTION_ID)"' ./charts/kaito/gpu-provisioner/values.yaml yq -i '(.controller.env[] | select(.name=="LOCATION")) .value = "$(AZURE_LOCATION)"' ./charts/kaito/gpu-provisioner/values.yaml yq -i '(.controller.env[] | select(.name=="ARM_RESOURCE_GROUP")) .value = "$(AZURE_RESOURCE_GROUP)"' ./charts/kaito/gpu-provisioner/values.yaml @@ -266,7 +264,7 @@ lint: $(GOLANGCI_LINT) .PHONY: release-manifest release-manifest: @sed -i -e 's/^VERSION ?= .*/VERSION ?= ${VERSION}/' ./Makefile - @sed -i -e "s/version: .*/version: ${IMG_TAG}/" ./charts/kaito/workspace/Chart.yaml + @sed -i -e "s/appVersion: .*/appVersion: ${IMG_TAG}/" ./charts/kaito/workspace/Chart.yaml @sed -i -e "s/tag: .*/tag: ${IMG_TAG}/" ./charts/kaito/workspace/values.yaml @sed -i -e 's/IMG_TAG=.*/IMG_TAG=${IMG_TAG}/' ./charts/kaito/workspace/README.md git checkout -b release-${VERSION} diff --git a/charts/kaito/gpu-provisioner/Chart.yaml b/charts/kaito/gpu-provisioner/Chart.yaml index 20f58d8cc..83889e614 100644 --- a/charts/kaito/gpu-provisioner/Chart.yaml +++ b/charts/kaito/gpu-provisioner/Chart.yaml @@ -2,8 +2,8 @@ apiVersion: v2 name: gpu-provisioner description: A Helm chart for gpu-provisioner type: application -version: 0.1.0 -appVersion: 0.1.0 +version: 0.2.0 +appVersion: 0.2.0 sources: - https://github.com/Azure/gpu-provisioner maintainers: diff --git a/charts/kaito/gpu-provisioner/README.md b/charts/kaito/gpu-provisioner/README.md index 952d2c976..d2dc949cd 100644 --- a/charts/kaito/gpu-provisioner/README.md +++ b/charts/kaito/gpu-provisioner/README.md @@ -1,6 +1,6 @@ # Karpenter Azure provider gpu-provisioner -![Version: 0.1.0](https://img.shields.io/badge/Version-0.1.0-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.1.0](https://img.shields.io/badge/AppVersion-0.1.0-informational?style=flat-square) +![Version: 0.2.0](https://img.shields.io/badge/Version-0.2.0-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.2.0](https://img.shields.io/badge/AppVersion-0.2.0-informational?style=flat-square) A Helm chart for gpu-provisioner @@ -9,63 +9,58 @@ A Helm chart for gpu-provisioner To install the chart with the release name `gpu-provisioner`: ```bash -helm install gpu-provisioner 
./charts/kaito/gpu-provisioner +helm install gpu-provisioner ./charts/gpu-provisioner ``` ## Values -| Key | Type | Default | Description | -|------------------------------------|--------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| additionalAnnotations | object | `{}` | Additional annotations to add into metadata. | -| additionalClusterRoleRules | list | `[]` | Specifies additional rules for the core ClusterRole. | -| additionalLabels | object | `{}` | Additional labels to add into metadata. | -| affinity | object | `{"nodeAffinity":{"requiredDuringSchedulingIgnoredDuringExecution":{"nodeSelectorTerms":[{"matchExpressions":[{"key":"kubernetes.azure.com/cluster","operator":"Exists"},{"key":"type","operator":"NotIn","values":["virtual-kubelet"]},{"key":"kubernetes.io/os","operator":"In","values":["linux"]}]},{"matchExpressions":[{"key":"karpenter.sh/provisioner-name","operator":"DoesNotExist"}]}]}},"podAntiAffinity":{"requiredDuringSchedulingIgnoredDuringExecution":[{"topologyKey":"kubernetes.io/hostname"}]}}` | Affinity rules for scheduling the pod. If an explicit label selector is not provided for pod affinity or pod anti-affinity one will be created from the pod selector labels. | -| controller.env | list | `[{"name":"ARM_SUBSCRIPTION_ID","value":null},{"name":"AZURE_CLUSTER_NAME","value":null},{"name":"AZURE_NODE_RESOURCE_GROUP","value":null},{"name":"ARM_RESOURCE_GROUP","value":null}]` | Additional environment variables for the controller pod. | -| controller.envFrom | list | `[]` | | -| controller.errorOutputPaths | list | `["stderr"]` | Controller errorOutputPaths - default to stderr only | -| controller.healthProbe.port | int | `8081` | The container port to use for http health probe. | -| controller.image.digest | string | `""` | SHA256 digest of the controller image. | -| controller.image.repository | string | `"ghcr.io/azure/gpu-provisioner"` | Repository path to the controller image. | -| controller.image.tag | string | `"0.1.0"` | Tag of the controller image. | -| controller.logEncoding | string | `""` | Controller log encoding, defaults to the global log encoding | -| controller.logLevel | string | `"debug"` | Controller log level, defaults to the global log level | -| controller.metrics.port | int | `8000` | The container port to use for metrics. | -| controller.outputPaths | list | `["stdout"]` | Controller outputPaths - default to stdout only | -| controller.resources | object | `{"limits":{"cpu":"500m"},"requests":{"cpu":"200m"}}` | Resources for the controller pod. | -| controller.securityContext | object | `{}` | SecurityContext for the controller container. | -| dnsConfig | object | `{}` | Configure DNS Config for the pod | -| dnsPolicy | string | `"Default"` | Configure the DNS Policy for the pod | -| extraVolumes | list | `[]` | Additional volumes for the pod. 
| -| fullnameOverride | string | `""` | Overrides the chart's computed fullname. | -| hostNetwork | bool | `false` | Bind the pod to the host network. This is required when using a custom CNI. | -| imagePullPolicy | string | `"IfNotPresent"` | Image pull policy for Docker images. | -| imagePullSecrets | list | `[]` | Image pull secrets for Docker images. | -| logEncoding | string | `"console"` | Global log encoding | -| logLevel | string | `"debug"` | Global log level | -| nameOverride | string | `""` | Overrides the chart's name. | -| namespace | string | `"gpu-provisioner"` | | -| nodeSelector | object | `{"kubernetes.io/os":"linux"}` | Node selectors to schedule the pod to nodes with labels. | -| podAnnotations | object | `{}` | Additional annotations for the pod. | -| podDisruptionBudget.maxUnavailable | int | `1` | | -| podDisruptionBudget.name | string | `"karpenter"` | | -| podLabels | object | `{}` | Additional labels for the pod. | -| podSecurityContext | object | `{"fsGroup":1000}` | SecurityContext for the pod. | -| priorityClassName | string | `"system-cluster-critical"` | PriorityClass name for the pod. | -| replicas | int | `1` | Number of replicas. | -| revisionHistoryLimit | int | `10` | The number of old ReplicaSets to retain to allow rollback. | -| serviceAccount.annotations | object | `{}` | Additional annotations for the ServiceAccount. | -| serviceAccount.create | bool | `true` | Specifies if a ServiceAccount should be created. | -| serviceAccount.name | string | `""` | The name of the ServiceAccount to use. If not set and create is true, a name is generated using the fullname template. | -| serviceMonitor.additionalLabels | object | `{}` | Additional labels for the ServiceMonitor. | -| serviceMonitor.enabled | bool | `false` | Specifies whether a ServiceMonitor should be created. | -| serviceMonitor.endpointConfig | object | `{}` | Endpoint configuration for the ServiceMonitor. | -| settings.azure | object | `{"clusterName":null}` | Azure-specific configuration values | -| settings.azure.clusterName | string | `nil` | Cluster name. | -| strategy | object | `{"rollingUpdate":{"maxUnavailable":1}}` | Strategy for updating the pod. | -| terminationGracePeriodSeconds | string | `nil` | Override the default termination grace period for the pod. | -| tolerations | list | `[{"key":"CriticalAddonsOnly","operator":"Exists"}]` | Tolerations to allow the pod to be scheduled to nodes with taints. | -| topologySpreadConstraints | list | `[{"maxSkew":1,"topologyKey":"topology.kubernetes.io/zone","whenUnsatisfiable":"ScheduleAnyway"}]` | Topology spread constraints to increase the controller resilience by distributing pods across the cluster zones. If an explicit label selector is not provided one will be created from the pod selector labels. | -| workloadIdentity | object | `{"clientId":"","tenantId":""}` | Global Settings to configure gpu-provisioner | +| Key | Type | Default | Description | +|------------------------------------|--------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------| +| additionalAnnotations | object | `{}` | Additional annotations to add into metadata. | +| additionalLabels | object | `{}` | Additional labels to add into metadata. 
| +| affinity | object | `{"nodeAffinity":{"requiredDuringSchedulingIgnoredDuringExecution":{"nodeSelectorTerms":[{"matchExpressions":[{"key":"karpenter.sh/provisioner-name","operator":"DoesNotExist"}]}]}}}` | Affinity rules for scheduling the pod. | +| controller.env | list | `[]` | Additional environment variables for the controller pod. | +| controller.errorOutputPaths | list | `["stderr"]` | Controller errorOutputPaths - default to stderr only | +| controller.extraVolumeMounts | list | `[]` | Additional volumeMounts for the controller pod. | +| controller.image.repository | string | `mcr.microsoft.com/aks/kaito/gpu-provisioner` | | +| controller.image.tag | string | `0.2.0` | | +| controller.logEncoding | string | `""` | Controller log encoding, defaults to the global log encoding | +| controller.logLevel | string | `""` | Controller log level, defaults to the global log level | +| controller.outputPaths | list | `["stdout"]` | Controller outputPaths - default to stdout only | +| controller.resources | object | `{"limits":{"cpu":1,"memory":"1Gi"},"requests":{"cpu":1,"memory":"1Gi"}}` | Resources for the controller pod. | +| controller.securityContext | object | `{}` | SecurityContext for the controller container. | +| controller.sidecarContainer | object | `{}` | Additional sideCarContainer config - this will also inherit volume mounts from deployment | +| dnsConfig | object | `{}` | Configure DNS Config for the pod | +| dnsPolicy | string | `"Default"` | Configure the DNS Policy for the pod | +| extraVolumes | list | `[]` | Additional volumes for the pod. | +| fullnameOverride | string | `""` | Overrides the chart's computed fullname. | +| hostNetwork | bool | `false` | Bind the pod to the host network. This is required when using a custom CNI. | +| imagePullPolicy | string | `"IfNotPresent"` | Image pull policy for Docker images. | +| imagePullSecrets | list | `[]` | Image pull secrets for Docker images. | +| logEncoding | string | `"console"` | Gloabl log encoding | +| logLevel | string | `"debug"` | Global log level | +| nameOverride | string | `""` | Overrides the chart's name. | +| nodeSelector | object | `{"kubernetes.io/os":"linux"}` | Node selectors to schedule the pod to nodes with labels. | +| podAnnotations | object | `{}` | Additional annotations for the pod. | +| podDisruptionBudget.maxUnavailable | int | `1` | | +| podDisruptionBudget.name | string | `"karpenter"` | | +| podLabels | object | `{}` | Additional labels for the pod. | +| podSecurityContext | object | `{"fsGroup":1000}` | SecurityContext for the pod. | +| priorityClassName | string | `"system-cluster-critical"` | PriorityClass name for the pod. | +| replicas | int | `2` | Number of replicas. | +| revisionHistoryLimit | int | `10` | The number of old ReplicaSets to retain to allow rollback. | +| serviceAccount.annotations | object | `{}` | Additional annotations for the ServiceAccount. | +| serviceAccount.create | bool | `true` | Specifies if a ServiceAccount should be created. | +| serviceAccount.name | string | `""` | The name of the ServiceAccount to use. If not set and create is true, a name is generated using the fullname template. | +| serviceMonitor.additionalLabels | object | `{}` | Additional labels for the ServiceMonitor. | +| serviceMonitor.enabled | bool | `false` | Specifies whether a ServiceMonitor should be created. | +| serviceMonitor.endpointConfig | object | `{}` | Endpoint configuration for the ServiceMonitor. 
| +| settings | object | `{"azure":{"clusterName":"","tags":null}}` | Global Settings to configure Karpenter | +| settings.azure | object | `{"clusterName":"","tags":null}` | Azure-specific configuration values | +| settings.azure.clusterName | string | `""` | Cluster name. | | +| settings.azure.tags | string | `nil` | The global tags to use on all Azure infrastructure resources (launch templates, instances, SQS queue, etc.) | +| strategy | object | `{"rollingUpdate":{"maxUnavailable":1}}` | Strategy for updating the pod. | +| terminationGracePeriodSeconds | string | `nil` | Override the default termination grace period for the pod. | +| tolerations | list | `[{"key":"CriticalAddonsOnly","operator":"Exists"}]` | Tolerations to allow the pod to be scheduled to nodes with taints. | +| topologySpreadConstraints | list | `[{"maxSkew":1,"topologyKey":"topology.kubernetes.io/zone","whenUnsatisfiable":"ScheduleAnyway"}]` | topologySpreadConstraints to increase the controller resilience | ----------------------------------------------- -Autogenerated from chart metadata using [helm-docs v1.11.3](https://github.com/norwoodj/helm-docs/releases/v1.11.3) diff --git a/charts/kaito/gpu-provisioner/templates/configmap-logging.yaml b/charts/kaito/gpu-provisioner/templates/configmap-logging.yaml index e6dcbd46e..ce293dd76 100644 --- a/charts/kaito/gpu-provisioner/templates/configmap-logging.yaml +++ b/charts/kaito/gpu-provisioner/templates/configmap-logging.yaml @@ -1,7 +1,7 @@ apiVersion: v1 kind: ConfigMap metadata: - name: config-logging + name: gpu-provisioner-config-logging namespace: {{ .Values.namespace }} labels: {{- include "gpu-provisioner.labels" . | nindent 4 }} diff --git a/charts/kaito/gpu-provisioner/templates/deployment.yaml b/charts/kaito/gpu-provisioner/templates/deployment.yaml index 9c6ab6ab2..4234e024f 100644 --- a/charts/kaito/gpu-provisioner/templates/deployment.yaml +++ b/charts/kaito/gpu-provisioner/templates/deployment.yaml @@ -75,6 +75,8 @@ spec: image: {{ include "gpu-provisioner.controller.image" . }} imagePullPolicy: {{ .Values.imagePullPolicy }} env: + - name: CONFIG_LOGGING_NAME + value: "gpu-provisioner-config-logging" - name: SYSTEM_NAMESPACE valueFrom: fieldRef: diff --git a/charts/kaito/gpu-provisioner/values.yaml b/charts/kaito/gpu-provisioner/values.yaml index 11a0103c0..8c40f9e8d 100644 --- a/charts/kaito/gpu-provisioner/values.yaml +++ b/charts/kaito/gpu-provisioner/values.yaml @@ -103,7 +103,7 @@ controller: # -- Repository path to the controller image. repository: mcr.microsoft.com/aks/kaito/gpu-provisioner # -- Tag of the controller image. - tag: 0.1.0 + tag: 0.2.0 # -- SHA256 digest of the controller image. digest: "" # -- SecurityContext for the controller container. @@ -122,6 +122,8 @@ controller: value: - name: LEADER_ELECT # disable leader election for better debugging experience value: "false" + - name: E2E_TEST_MODE + value: "false" envFrom: [] # -- Resources for the controller pod. resources: diff --git a/charts/kaito/workspace/Chart.yaml b/charts/kaito/workspace/Chart.yaml index d38144ea8..35c17d199 100644 --- a/charts/kaito/workspace/Chart.yaml +++ b/charts/kaito/workspace/Chart.yaml @@ -6,13 +6,13 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.0 +version: 0.2.1 # This is the version number of the application being deployed. 
This version number should be # incremented each time you make changes to the application. Versions are not expected to # follow Semantic Versioning. They should reflect the version the application is using. # It is recommended to use it with quotes. -appVersion: "0.2.0" +appVersion: "0.2.1" home: https://github.com/Azure/kaito sources: - https://github.com/Azure/kaito diff --git a/charts/kaito/workspace/README.md b/charts/kaito/workspace/README.md index dcfe202f5..497ecbfb5 100644 --- a/charts/kaito/workspace/README.md +++ b/charts/kaito/workspace/README.md @@ -5,7 +5,7 @@ ```bash export REGISTRY= export IMG_NAME=workspace -export IMG_TAG=0.2.0 +export IMG_TAG=0.2.1 helm install workspace ./charts/kaito/workspace --set image.repository=${REGISTRY}/$(IMG_NAME) --set image.tag=$(IMG_TAG) ``` diff --git a/charts/kaito/workspace/values.yaml b/charts/kaito/workspace/values.yaml index 77ff062b7..90ae02156 100644 --- a/charts/kaito/workspace/values.yaml +++ b/charts/kaito/workspace/values.yaml @@ -5,7 +5,7 @@ replicaCount: 1 image: repository: mcr.microsoft.com/aks/kaito/workspace pullPolicy: IfNotPresent - tag: 0.2.0 + tag: 0.2.1 imagePullSecrets: [] podAnnotations: {} podSecurityContext: diff --git a/docker/presets/inference/llama-2/Dockerfile b/docker/presets/inference/llama-2/Dockerfile index 641d158bc..285cb122a 100644 --- a/docker/presets/inference/llama-2/Dockerfile +++ b/docker/presets/inference/llama-2/Dockerfile @@ -20,7 +20,7 @@ WORKDIR /workspace/llama RUN ["/bin/bash", "-c", "sed -i $'/torch.distributed.init_process_group(\"nccl\")/c\\ import datetime\\\n torch.distributed.init_process_group(\"nccl\", timeout=datetime.timedelta(days=365*100))' /workspace/llama/llama/generation.py"] RUN pip install -e . -RUN pip install torch==2.1.0 fastapi==0.103.2 pydantic==1.10.9 gputil==1.4.0 +RUN pip install torch==2.1.0 fastapi==0.109.1 pydantic==1.10.9 gputil==1.4.0 RUN pip install 'uvicorn[standard]' ARG WEIGHTS_PATH diff --git a/go.mod b/go.mod index 33cee9053..7f4ee8996 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/onsi/gomega v1.27.8 github.com/samber/lo v1.38.1 github.com/stretchr/testify v1.8.4 + gopkg.in/yaml.v2 v2.4.0 gotest.tools v2.2.0+incompatible k8s.io/api v0.27.7 k8s.io/apimachinery v0.27.7 @@ -90,9 +91,8 @@ require ( google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect google.golang.org/grpc v1.56.3 // indirect - google.golang.org/protobuf v1.30.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/apiextensions-apiserver v0.27.2 // indirect k8s.io/component-base v0.27.7 // indirect diff --git a/go.sum b/go.sum index 7b101bc9b..b5eed9dcc 100644 --- a/go.sum +++ b/go.sum @@ -629,8 +629,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod 
h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

From d0498a0607db46c5019162a9307756a628128add Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Tue, 19 Mar 2024 14:00:47 -0700
Subject: [PATCH 05/23] fix: Upgrade FastAPI Version (#305)

Fixes https://github.com/Azure/kaito/security/dependabot/14
---
 presets/tuning/tfs/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/presets/tuning/tfs/requirements.txt b/presets/tuning/tfs/requirements.txt
index 091e6d21f..9848f3e67 100644
--- a/presets/tuning/tfs/requirements.txt
+++ b/presets/tuning/tfs/requirements.txt
@@ -3,7 +3,7 @@ peft==0.8.2
 transformers==4.38.2
 torch==2.2.0
 accelerate==0.27.2
-fastapi==0.103.2
+fastapi==0.109.1
 pydantic==1.10.9
 uvicorn[standard]==0.23.2
 bitsandbytes==0.42.0

From e1673aad3d7ef10bca6768ce94b3ee5e26cb79ea Mon Sep 17 00:00:00 2001
From: Heba <31887807+helayoty@users.noreply.github.com>
Date: Tue, 19 Mar 2024 14:23:33 -0700
Subject: [PATCH 06/23] ci: Update pipeline target

Signed-off-by: Heba <31887807+helayoty@users.noreply.github.com>
---
 .github/workflows/create-release.yml | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/.github/workflows/create-release.yml b/.github/workflows/create-release.yml
index bf87a90ae..935a88779 100644
--- a/.github/workflows/create-release.yml
+++ b/.github/workflows/create-release.yml
@@ -1,9 +1,5 @@
 name: Create release
 on:
-  workflow_run:
-    workflows: [ "Create, Scan and Publish KAITO image" ]
-    types: [completed]
-    branches: [release-**]
   repository_dispatch:
     types: [ release-tag ]
    
 branches: [ release-** ]
@@ -18,7 +14,6 @@ env:
 jobs:
   create-release:
-    if: ${{ github.event.workflow_run.conclusion == 'success' }}
     runs-on: ubuntu-20.04
     steps:
       - name: Harden Runner

From 2485c582c5efe043f6c21b6bd5f091b020c7faa1 Mon Sep 17 00:00:00 2001
From: Fei Guo
Date: Tue, 19 Mar 2024 18:10:39 -0700
Subject: [PATCH 07/23] docs: Update README.md for announcing v0.2.1 (#307)

Signed-off-by: Fei Guo
---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b95c05600..f022b3618 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@

 | ![notification](docs/img/bell.svg) What is NEW! |
 |-------------------------------------------------|
-| Latest Release: March 4th, 2024. Kaito v0.2.0. |
+| Latest Release: March 19th, 2024. Kaito v0.2.1. |
 | First Release: Nov 15th, 2023. Kaito v0.1.0. |

 Kaito is an operator that automates the AI/ML inference model deployment in a Kubernetes cluster.

From 4ba337d098a70eb66edc6a8774ce1b499e9d8204 Mon Sep 17 00:00:00 2001
From: Ishaan Sehgal
Date: Wed, 20 Mar 2024 13:08:08 -0700
Subject: [PATCH 08/23] feat: Part 2 - Add validation checks for TuningSpec, DataSource, DataDestination (#304)

Open to feedback. Here are the validation checks for the new CRD.
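As an illustration (not part of this diff), a minimal Workspace that passes these checks could look roughly like the following. The tuning, input, output, and urls keys follow the JSON tags in workspace_types.go; the method and preset shapes are inferred from the Go types, and every value is a placeholder:

    apiVersion: kaito.sh/v1alpha1
    kind: Workspace
    metadata:
      name: workspace-tuning-example
    tuning:
      method: lora
      preset:
        name: falcon-7b
      input:
        urls:
          - "https://example.com/data/train.jsonl"
      output:
        image: "myregistry.azurecr.io/adapter-falcon-7b:0.0.1"
        imagePushSecret: my-push-secret

validateCreate enforces exactly one of urls, hostPath, or image on the input source, at least one of hostPath or image on the output destination, a supported preset name, and a method of lora or qlora (case-insensitive).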
Code Coverage: 87.7% - workspace_validation.go [coverage.txt](https://github.com/Azure/kaito/files/14661460/coverage.txt) --- api/v1alpha1/workspace_types.go | 14 +- api/v1alpha1/workspace_validation.go | 179 ++++- api/v1alpha1/workspace_validation_test.go | 617 ++++++++++++++++++ api/v1alpha1/zz_generated.deepcopy.go | 12 +- .../workspace/crds/kaito.sh_workspaces.yaml | 147 ++++- config/crd/bases/kaito.sh_workspaces.yaml | 11 +- pkg/utils/testUtils.go | 6 +- test/e2e/preset_test.go | 19 +- test/e2e/utils/utils.go | 30 +- 9 files changed, 986 insertions(+), 49 deletions(-) diff --git a/api/v1alpha1/workspace_types.go b/api/v1alpha1/workspace_types.go index 2f966f647..4484b8250 100644 --- a/api/v1alpha1/workspace_types.go +++ b/api/v1alpha1/workspace_types.go @@ -13,7 +13,7 @@ const ( ModelImageAccessModePrivate ModelImageAccessMode = "private" ) -// ResourceSpec desicribes the resource requirement of running the workload. +// ResourceSpec describes the resource requirement of running the workload. // If the number of nodes in the cluster that meet the InstanceType and // LabelSelector requirements is small than the Count, controller // will provision new nodes before deploying the workload. @@ -51,7 +51,7 @@ type PresetMeta struct { // AccessMode specifies whether the containerized model image is accessible via public registry // or private registry. This field defaults to "public" if not specified. // If this field is "private", user needs to provide the private image information in PresetOptions. - // +bebuilder:default:="public" + // +kubebuilder:default:="public" // +optional AccessMode ModelImageAccessMode `json:"accessMode,omitempty"` } @@ -106,7 +106,7 @@ type DataSource struct { // URLs specifies the links to the public data sources. E.g., files in a public github repository. // +optional URLs []string `json:"urls,omitempty"` - // The directory in the hsot that contains the data. + // The directory in the host that contains the data. // +optional HostPath string `json:"hostPath,omitempty"` // The name of the image that contains the source data. The assumption is that the source data locates in the @@ -150,9 +150,9 @@ type TuningSpec struct { // +optional Config string `json:"config,omitempty"` // Input describes the input used by the tuning method. - Input *DataSource `json:"input,omitempty"` + Input *DataSource `json:"input"` // Output specified where to store the tuning output. 
- Output *DataDestination `json:"output,omitempty"` + Output *DataDestination `json:"output"` } // WorkspaceStatus defines the observed state of Workspace @@ -181,8 +181,8 @@ type Workspace struct { metav1.ObjectMeta `json:"metadata,omitempty"` Resource ResourceSpec `json:"resource,omitempty"` - Inference InferenceSpec `json:"inference,omitempty"` - Tuning TuningSpec `json:"tuning,omitempty"` + Inference *InferenceSpec `json:"inference,omitempty"` + Tuning *TuningSpec `json:"tuning,omitempty"` Status WorkspaceStatus `json:"status,omitempty"` } diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 16576f684..79b27e1b9 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "reflect" + "sort" "strings" "github.com/azure/kaito/pkg/utils/plugin" @@ -35,16 +36,184 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { if base == nil { klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) errs = errs.Also( - w.Inference.validateCreate().ViaField("inference"), - w.Resource.validateCreate(w.Inference).ViaField("resource"), + w.validateCreate().ViaField("spec"), + // TODO: Consider validate resource based on Tuning Spec + w.Resource.validateCreate(*w.Inference).ViaField("resource"), ) + if w.Inference != nil { + // TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter + errs = errs.Also(w.Inference.validateCreate().ViaField("inference")) + } + if w.Tuning != nil { + errs = errs.Also(w.Tuning.validateCreate().ViaField("tuning")) + } } else { klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) old := base.(*Workspace) errs = errs.Also( + w.validateUpdate(old).ViaField("spec"), w.Resource.validateUpdate(&old.Resource).ViaField("resource"), - w.Inference.validateUpdate(&old.Inference).ViaField("inference"), ) + if w.Inference != nil { + // TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter + errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference")) + } + if w.Tuning != nil { + errs = errs.Also(w.Tuning.validateUpdate(old.Tuning).ViaField("tuning")) + } + } + return errs +} + +func (w *Workspace) validateCreate() (errs *apis.FieldError) { + if w.Inference == nil && w.Tuning == nil { + errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", "")) + } + if w.Inference != nil && w.Tuning != nil { + errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", "")) + } + return errs +} + +func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) { + if (old.Inference == nil && w.Inference != nil) || (old.Inference != nil && w.Inference == nil) { + errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference")) + } + + if (old.Tuning == nil && w.Tuning != nil) || (old.Tuning != nil && w.Tuning == nil) { + errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning")) + } + return errs +} + +func (r *TuningSpec) validateCreate() (errs *apis.FieldError) { + if r.Input == nil { + errs = errs.Also(apis.ErrMissingField("Input")) + } else { + errs = errs.Also(r.Input.validateCreate().ViaField("Input")) + } + if r.Output == nil { + errs = errs.Also(apis.ErrMissingField("Output")) + } else { + errs = errs.Also(r.Output.validateCreate().ViaField("Output")) + } + // Currently require a preset to 
specified, in future we can consider defining a template + if r.Preset == nil { + errs = errs.Also(apis.ErrMissingField("Preset")) + } else if presetName := string(r.Preset.Name); !isValidPreset(presetName) { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName")) + } + methodLowerCase := strings.ToLower(string(r.Method)) + if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { + errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) + } + return errs +} + +func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) { + if r.Input == nil { + errs = errs.Also(apis.ErrMissingField("Input")) + } else { + errs = errs.Also(r.Input.validateUpdate(old.Input, true).ViaField("Input")) + } + if r.Output == nil { + errs = errs.Also(apis.ErrMissingField("Output")) + } else { + errs = errs.Also(r.Output.validateUpdate(old.Output).ViaField("Output")) + } + if !reflect.DeepEqual(old.Preset, r.Preset) { + errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset")) + } + oldMethod, newMethod := strings.ToLower(string(old.Method)), strings.ToLower(string(r.Method)) + if !reflect.DeepEqual(oldMethod, newMethod) { + errs = errs.Also(apis.ErrGeneric("Method cannot be changed", "Method")) + } + // Consider supporting config fields changing + return errs +} + +func (r *DataSource) validateCreate() (errs *apis.FieldError) { + sourcesSpecified := 0 + if len(r.URLs) > 0 { + sourcesSpecified++ + } + if r.HostPath != "" { + sourcesSpecified++ + } + if r.Image != "" { + sourcesSpecified++ + } + + // Ensure exactly one of URLs, HostPath, or Image is specified + if sourcesSpecified != 1 { + errs = errs.Also(apis.ErrGeneric("Exactly one of URLs, HostPath, or Image must be specified", "URLs", "HostPath", "Image")) + } + + return errs +} + +func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.FieldError) { + if isTuning && !reflect.DeepEqual(old.Name, r.Name) { + errs = errs.Also(apis.ErrInvalidValue("During tuning Name field cannot be changed once set", "Name")) + } + oldURLs := make([]string, len(old.URLs)) + copy(oldURLs, old.URLs) + sort.Strings(oldURLs) + + newURLs := make([]string, len(r.URLs)) + copy(newURLs, r.URLs) + sort.Strings(newURLs) + + if !reflect.DeepEqual(oldURLs, newURLs) { + errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs")) + } + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + } + + oldSecrets := make([]string, len(old.ImagePullSecrets)) + copy(oldSecrets, old.ImagePullSecrets) + sort.Strings(oldSecrets) + + newSecrets := make([]string, len(r.ImagePullSecrets)) + copy(newSecrets, r.ImagePullSecrets) + sort.Strings(newSecrets) + + if !reflect.DeepEqual(oldSecrets, newSecrets) { + errs = errs.Also(apis.ErrInvalidValue("ImagePullSecrets field cannot be changed once set", "ImagePullSecrets")) + } + return errs +} + +func (r *DataDestination) validateCreate() (errs *apis.FieldError) { + destinationsSpecified := 0 + if r.HostPath != "" { + destinationsSpecified++ + } + if r.Image != "" { + destinationsSpecified++ + } + + // If no destination is specified, return an error + if destinationsSpecified == 0 { + errs = errs.Also(apis.ErrMissingField("At least one of HostPath or Image must be specified")) + } + 
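Taken together, the create-time rules above require a tuning workspace to carry exactly one input source (urls, hostPath, or image), at least one output destination (hostPath or image), a registered preset, and a lora or qlora method. A minimal manifest sketch that would satisfy these checks; the preset name, instance type, and dataset URL are illustrative assumptions, not values mandated by this patch:

    apiVersion: kaito.sh/v1alpha1
    kind: Workspace
    metadata:
      name: workspace-tuning-example
    resource:
      instanceType: Standard_NC12s_v3      # illustrative GPU SKU
    tuning:
      method: lora
      input:
        name: tuning-data                  # exactly one source: urls here
        urls:
          - https://example.com/train.jsonl
      output:
        hostPath: /mnt/tuning-output       # at least one destination
      preset:
        name: falcon-7b                    # assumed to be a registered preset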
return errs +} + +func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) { + if old.HostPath != r.HostPath { + errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath")) + } + if old.Image != r.Image { + errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image")) + } + + if old.ImagePushSecret != r.ImagePushSecret { + errs = errs.Also(apis.ErrInvalidValue("ImagePushSecret field cannot be changed once set", "ImagePushSecret")) } return errs } @@ -131,7 +300,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) { presetName := string(i.Preset.Name) // Validate preset name if !isValidPreset(presetName) { - errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName")) + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName")) } // Validate private preset has private image specified if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" && @@ -151,7 +320,7 @@ func (i *InferenceSpec) validateUpdate(old *InferenceSpec) (errs *apis.FieldErro if !reflect.DeepEqual(i.Preset, old.Preset) { errs = errs.Also(apis.ErrGeneric("field is immutable", "preset")) } - //inference.template can be changed, but cannot be unset. + // inference.template can be changed, but cannot be set/unset. if (i.Template != nil && old.Template == nil) || (i.Template == nil && old.Template != nil) { errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template")) } diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 0a3fa2de1..11631e67b 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -488,6 +488,623 @@ func TestInferenceSpecValidateUpdate(t *testing.T) { } } +func TestWorkspaceValidateCreate(t *testing.T) { + tests := []struct { + name string + workspace *Workspace + wantErr bool + errField string + }{ + { + name: "Neither Inference nor Tuning specified", + workspace: &Workspace{}, + wantErr: true, + errField: "neither", + }, + { + name: "Both Inference and Tuning specified", + workspace: &Workspace{ + Inference: &InferenceSpec{}, + Tuning: &TuningSpec{}, + }, + wantErr: true, + errField: "both", + }, + { + name: "Only Inference specified", + workspace: &Workspace{ + Inference: &InferenceSpec{}, + }, + wantErr: false, + errField: "", + }, + { + name: "Only Tuning specified", + workspace: &Workspace{ + Tuning: &TuningSpec{Input: &DataSource{}}, + }, + wantErr: false, + errField: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.workspace.validateCreate() + if (errs != nil) != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + if errs != nil && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain field %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestWorkspaceValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldWorkspace *Workspace + newWorkspace *Workspace + expectErrs bool + errFields []string // Fields we expect to have errors + }{ + { + name: "Inference toggled on", + oldWorkspace: &Workspace{}, + newWorkspace: &Workspace{ + Inference: &InferenceSpec{}, + }, + expectErrs: true, + errFields: []string{"inference"}, + }, + { + name: "Inference 
toggled off", + oldWorkspace: &Workspace{ + Inference: &InferenceSpec{Preset: &PresetSpec{}}, + }, + newWorkspace: &Workspace{}, + expectErrs: true, + errFields: []string{"inference"}, + }, + { + name: "Tuning toggled on", + oldWorkspace: &Workspace{}, + newWorkspace: &Workspace{ + Tuning: &TuningSpec{Input: &DataSource{}}, + }, + expectErrs: true, + errFields: []string{"tuning"}, + }, + { + name: "Tuning toggled off", + oldWorkspace: &Workspace{ + Tuning: &TuningSpec{Input: &DataSource{}}, + }, + newWorkspace: &Workspace{}, + expectErrs: true, + errFields: []string{"tuning"}, + }, + { + name: "No toggling", + oldWorkspace: &Workspace{ + Tuning: &TuningSpec{Input: &DataSource{}}, + }, + newWorkspace: &Workspace{ + Tuning: &TuningSpec{Input: &DataSource{}}, + }, + expectErrs: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newWorkspace.validateUpdate(tt.oldWorkspace) + hasErrs := errs != nil + + if hasErrs != tt.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tt.expectErrs) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestTuningSpecValidateCreate(t *testing.T) { + RegisterValidationTestModels() + tests := []struct { + name string + tuningSpec *TuningSpec + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "All fields valid", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input", HostPath: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: false, + errFields: nil, + }, + { + name: "Missing Input", + tuningSpec: &TuningSpec{ + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Input"}, + }, + { + name: "Missing Output", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Output"}, + }, + { + name: "Missing Preset", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"Preset"}, + }, + { + name: "Invalid Preset", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}}, + Method: TuningMethodLora, + }, + wantErr: true, + errFields: []string{"presetName"}, + }, + { + name: "Invalid Method", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input"}, + Output: &DataDestination{HostPath: "valid-output"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: "invalid-method", + }, + wantErr: true, + errFields: []string{"Method"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.tuningSpec.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() errors = %v, wantErr %v", errs, tt.wantErr) + } + + 
if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateCreate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestTuningSpecValidateUpdate(t *testing.T) { + RegisterValidationTestModels() + tests := []struct { + name string + oldTuning *TuningSpec + newTuning *TuningSpec + expectErrs bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + Output: &DataDestination{HostPath: "path1"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + newTuning: &TuningSpec{ + Input: &DataSource{Name: "input1"}, + Output: &DataDestination{HostPath: "path1"}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodLora, + }, + expectErrs: false, + }, + { + name: "Input changed", + oldTuning: &TuningSpec{ + Input: &DataSource{Name: "input", HostPath: "inputpath"}, + Output: &DataDestination{HostPath: "outputpath"}, + }, + newTuning: &TuningSpec{ + Input: &DataSource{Name: "input", HostPath: "randompath"}, + Output: &DataDestination{HostPath: "outputpath"}, + }, + expectErrs: true, + errFields: []string{"HostPath"}, + }, + { + name: "Output changed", + oldTuning: &TuningSpec{ + Output: &DataDestination{HostPath: "path1"}, + }, + newTuning: &TuningSpec{ + Output: &DataDestination{HostPath: "path2"}, + }, + expectErrs: true, + errFields: []string{"Output"}, + }, + { + name: "Preset changed", + oldTuning: &TuningSpec{ + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + }, + newTuning: &TuningSpec{ + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("invalid-preset")}}, + }, + expectErrs: true, + errFields: []string{"Preset"}, + }, + { + name: "Method changed", + oldTuning: &TuningSpec{ + Method: TuningMethodLora, + }, + newTuning: &TuningSpec{ + Method: TuningMethodQLora, + }, + expectErrs: true, + errFields: []string{"Method"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newTuning.validateUpdate(tt.oldTuning) + hasErrs := errs != nil + + if hasErrs != tt.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tt.expectErrs) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestDataSourceValidateCreate(t *testing.T) { + tests := []struct { + name string + dataSource *DataSource + wantErr bool + errField string // The field we expect to have an error on + }{ + { + name: "URLs specified only", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + }, + wantErr: false, + }, + { + name: "HostPath specified only", + dataSource: &DataSource{ + HostPath: "/data/path", + }, + wantErr: false, + }, + { + name: "Image specified only", + dataSource: &DataSource{ + Image: "data-image:latest", + }, + wantErr: false, + }, + { + name: "None specified", + dataSource: &DataSource{}, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + { + name: "URLs and HostPath specified", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + HostPath: "/data/path", + }, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image 
must be specified", + }, + { + name: "All fields specified", + dataSource: &DataSource{ + URLs: []string{"http://example.com/data"}, + HostPath: "/data/path", + Image: "data-image:latest", + }, + wantErr: true, + errField: "Exactly one of URLs, HostPath, or Image must be specified", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.dataSource.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs && tt.errField != "" && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestDataSourceValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldSource *DataSource + newSource *DataSource + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldSource: &DataSource{ + URLs: []string{"http://example.com/data1", "http://example.com/data2"}, + HostPath: "/data/path", + Image: "data-image:latest", + ImagePullSecrets: []string{"secret1", "secret2"}, + }, + newSource: &DataSource{ + URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter + HostPath: "/data/path", + Image: "data-image:latest", + ImagePullSecrets: []string{"secret2", "secret1"}, // Note the different order, should not matter + }, + wantErr: false, + }, + { + name: "Name changed", + oldSource: &DataSource{ + Name: "original-dataset", + }, + newSource: &DataSource{ + Name: "new-dataset", + }, + wantErr: true, + errFields: []string{"Name"}, + }, + { + name: "URLs changed", + oldSource: &DataSource{ + URLs: []string{"http://example.com/old"}, + }, + newSource: &DataSource{ + URLs: []string{"http://example.com/new"}, + }, + wantErr: true, + errFields: []string{"URLs"}, + }, + { + name: "HostPath changed", + oldSource: &DataSource{ + HostPath: "/old/path", + }, + newSource: &DataSource{ + HostPath: "/new/path", + }, + wantErr: true, + errFields: []string{"HostPath"}, + }, + { + name: "Image changed", + oldSource: &DataSource{ + Image: "old-image:latest", + }, + newSource: &DataSource{ + Image: "new-image:latest", + }, + wantErr: true, + errFields: []string{"Image"}, + }, + { + name: "ImagePullSecrets changed", + oldSource: &DataSource{ + ImagePullSecrets: []string{"old-secret"}, + }, + newSource: &DataSource{ + ImagePullSecrets: []string{"new-secret"}, + }, + wantErr: true, + errFields: []string{"ImagePullSecrets"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newSource.validateUpdate(tt.oldSource, true) + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateUpdate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + +func TestDataDestinationValidateCreate(t *testing.T) { + tests := []struct { + name string + dataDestination *DataDestination + wantErr bool + errField string // The field we expect to have an error on + }{ + { + name: "No fields specified", + dataDestination: &DataDestination{}, + wantErr: true, + errField: "At least one of HostPath or Image must be specified", + }, + { + name: "HostPath specified only", + dataDestination: &DataDestination{ + HostPath: 
"/data/path", + }, + wantErr: false, + }, + { + name: "Image specified only", + dataDestination: &DataDestination{ + Image: "data-image:latest", + }, + wantErr: false, + }, + { + name: "Both fields specified", + dataDestination: &DataDestination{ + HostPath: "/data/path", + Image: "data-image:latest", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.dataDestination.validateCreate() + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateCreate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs && tt.errField != "" && !strings.Contains(errs.Error(), tt.errField) { + t.Errorf("validateCreate() expected error to contain %s, but got %s", tt.errField, errs.Error()) + } + }) + } +} + +func TestDataDestinationValidateUpdate(t *testing.T) { + tests := []struct { + name string + oldDest *DataDestination + newDest *DataDestination + wantErr bool + errFields []string // Fields we expect to have errors + }{ + { + name: "No changes", + oldDest: &DataDestination{ + HostPath: "/data/old", + Image: "old-image:latest", + ImagePushSecret: "old-secret", + }, + newDest: &DataDestination{ + HostPath: "/data/old", + Image: "old-image:latest", + ImagePushSecret: "old-secret", + }, + wantErr: false, + }, + { + name: "HostPath changed", + oldDest: &DataDestination{ + HostPath: "/data/old", + }, + newDest: &DataDestination{ + HostPath: "/data/new", + }, + wantErr: true, + errFields: []string{"HostPath"}, + }, + { + name: "Image changed", + oldDest: &DataDestination{ + Image: "old-image:latest", + }, + newDest: &DataDestination{ + Image: "new-image:latest", + }, + wantErr: true, + errFields: []string{"Image"}, + }, + { + name: "ImagePushSecret changed", + oldDest: &DataDestination{ + ImagePushSecret: "old-secret", + }, + newDest: &DataDestination{ + ImagePushSecret: "new-secret", + }, + wantErr: true, + errFields: []string{"ImagePushSecret"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.newDest.validateUpdate(tt.oldDest) + hasErrs := errs != nil + + if hasErrs != tt.wantErr { + t.Errorf("validateUpdate() error = %v, wantErr %v", errs, tt.wantErr) + } + + if hasErrs { + for _, field := range tt.errFields { + if !strings.Contains(errs.Error(), field) { + t.Errorf("validateUpdate() expected errors to contain field %s, but got %s", field, errs.Error()) + } + } + } + }) + } +} + func TestGetSupportedSKUs(t *testing.T) { tests := []struct { name string diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index a9d662c0f..6c3ee1eb8 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -249,8 +249,16 @@ func (in *Workspace) DeepCopyInto(out *Workspace) { out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) in.Resource.DeepCopyInto(&out.Resource) - in.Inference.DeepCopyInto(&out.Inference) - in.Tuning.DeepCopyInto(&out.Tuning) + if in.Inference != nil { + in, out := &in.Inference, &out.Inference + *out = new(InferenceSpec) + (*in).DeepCopyInto(*out) + } + if in.Tuning != nil { + in, out := &in.Tuning, &out.Tuning + *out = new(TuningSpec) + (*in).DeepCopyInto(*out) + } in.Status.DeepCopyInto(&out.Status) } diff --git a/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml b/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml index 40908f609..a4103a897 100644 --- a/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml +++ b/charts/kaito/workspace/crds/kaito.sh_workspaces.yaml @@ -47,11 
+47,56 @@ spec: type: string inference: properties: + adapters: + description: Adapters are integrated into the base model for inference. + Users can specify multiple adapters for the model and the respective + weight of using each of them. + items: + properties: + source: + description: Source describes where to obtain the adapter data. + properties: + hostPath: + description: The directory in the host that contains the + data. + type: string + image: + description: The name of the image that contains the source + data. The assumption is that the source data locates in + the `data` directory in the image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names + in the same namespace used for pulling the data image. + items: + type: string + type: array + name: + description: The name of the dataset. The same name will + be used as a container name. It must be a valid DNS subdomain + value, + type: string + urls: + description: URLs specifies the links to the public data + sources. E.g., files in a public github repository. + items: + type: string + type: array + type: object + strength: + description: Strength specifies the default multiplier for applying + the adapter weights to the raw model weights. It is usually + a float number between 0 and 1. It is defined as a string + type to be language agnostic. + type: string + type: object + type: array preset: - description: Preset describles the model that will be deployed with - preset configurations. + description: Preset describes the base model that will be deployed + with preset configurations. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -72,7 +117,7 @@ spec: type: string imagePullSecrets: description: ImagePullSecrets is a list of secret names in - the same namespace used for pulling the image. + the same namespace used for pulling the model image. items: type: string type: array @@ -95,7 +140,7 @@ spec: metadata: type: object resource: - description: ResourceSpec desicribes the resource requirement of running + description: ResourceSpec describes the resource requirement of running the workload. If the number of nodes in the cluster that meet the InstanceType and LabelSelector requirements is small than the Count, controller will provision new nodes before deploying the workload. The final list of @@ -245,6 +290,100 @@ spec: type: string type: array type: object + tuning: + properties: + config: + description: Config specifies the name of the configmap in the same + namespace that contains the arguments used by the tuning method. + If not specified, a default configmap is used based on the specified + method. + type: string + input: + description: Input describes the input used by the tuning method. + properties: + hostPath: + description: The directory in the host that contains the data. + type: string + image: + description: The name of the image that contains the source data. + The assumption is that the source data locates in the `data` + directory in the image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names in the + same namespace used for pulling the data image. + items: + type: string + type: array + name: + description: The name of the dataset. The same name will be used + as a container name. 
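The adapters array introduced above layers fine-tuned weights onto the base model at inference time, each adapter carrying its own source and blending strength. A minimal sketch of how the field is expected to be populated; the adapter name, image, and strength are illustrative assumptions:

    inference:
      preset:
        name: falcon-7b
      adapters:
        - source:
            name: falcon-lora-adapter        # also used as the container name
            image: ghcr.io/example/adapters/falcon-lora:latest
          strength: "0.5"                    # string-typed multiplier, usually between 0 and 1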
It must be a valid DNS subdomain value, + type: string + urls: + description: URLs specifies the links to the public data sources. + E.g., files in a public github repository. + items: + type: string + type: array + type: object + method: + description: Method specifies the Parameter-Efficient Fine-Tuning(PEFT) + method, such as lora, qlora, used for the tuning. + type: string + output: + description: Output specified where to store the tuning output. + properties: + hostPath: + description: The directory in the host that contains the output + data. + type: string + image: + description: Name of the image where the output data is pushed + to. + type: string + imagePushSecret: + description: ImagePushSecret is the name of the secret in the + same namespace that contains the authentication information + that is needed for running `docker push`. + type: string + type: object + preset: + description: Preset describes which model to load for tuning. + properties: + accessMode: + default: public + description: AccessMode specifies whether the containerized model + image is accessible via public registry or private registry. + This field defaults to "public" if not specified. If this field + is "private", user needs to provide the private image information + in PresetOptions. + enum: + - public + - private + type: string + name: + description: Name of the supported models with preset configurations. + type: string + presetOptions: + properties: + image: + description: Image is the name of the containerized model + image. + type: string + imagePullSecrets: + description: ImagePullSecrets is a list of secret names in + the same namespace used for pulling the model image. + items: + type: string + type: array + type: object + required: + - name + type: object + required: + - input + - output + type: object type: object served: true storage: true diff --git a/config/crd/bases/kaito.sh_workspaces.yaml b/config/crd/bases/kaito.sh_workspaces.yaml index b3af23a76..a4103a897 100644 --- a/config/crd/bases/kaito.sh_workspaces.yaml +++ b/config/crd/bases/kaito.sh_workspaces.yaml @@ -57,7 +57,7 @@ spec: description: Source describes where to obtain the adapter data. properties: hostPath: - description: The directory in the hsot that contains the + description: The directory in the host that contains the data. type: string image: @@ -96,6 +96,7 @@ spec: with preset configurations. properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -139,7 +140,7 @@ spec: metadata: type: object resource: - description: ResourceSpec desicribes the resource requirement of running + description: ResourceSpec describes the resource requirement of running the workload. If the number of nodes in the cluster that meet the InstanceType and LabelSelector requirements is small than the Count, controller will provision new nodes before deploying the workload. The final list of @@ -301,7 +302,7 @@ spec: description: Input describes the input used by the tuning method. properties: hostPath: - description: The directory in the hsot that contains the data. + description: The directory in the host that contains the data. type: string image: description: The name of the image that contains the source data. @@ -350,6 +351,7 @@ spec: description: Preset describes which model to load for tuning. 
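Mapping this schema back to a concrete object, a tuning block that pulls its dataset from an image, pushes results to a registry, and overrides the method defaults through a ConfigMap could look like the sketch below; every name here is an assumption for illustration:

    tuning:
      method: qlora
      config: falcon-qlora-params            # ConfigMap in the same namespace
      input:
        name: sample-dataset
        image: ghcr.io/example/datasets/sample:latest
        imagePullSecrets:
          - example-pull-secret
      output:
        image: ghcr.io/example/results/falcon-tuned:latest
        imagePushSecret: example-push-secret
      preset:
        name: falcon-7b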
properties: accessMode: + default: public description: AccessMode specifies whether the containerized model image is accessible via public registry or private registry. This field defaults to "public" if not specified. If this field @@ -378,6 +380,9 @@ spec: required: - name type: object + required: + - input + - output type: object type: object served: true diff --git a/pkg/utils/testUtils.go b/pkg/utils/testUtils.go index f88b35a4f..5ef34af1d 100644 --- a/pkg/utils/testUtils.go +++ b/pkg/utils/testUtils.go @@ -35,7 +35,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Preset: &v1alpha1.PresetSpec{ PresetMeta: v1alpha1.PresetMeta{ Name: "test-distributed-model", @@ -60,7 +60,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Preset: &v1alpha1.PresetSpec{ PresetMeta: v1alpha1.PresetMeta{ Name: "test-model", @@ -85,7 +85,7 @@ var ( }, }, }, - Inference: v1alpha1.InferenceSpec{ + Inference: &v1alpha1.InferenceSpec{ Template: &corev1.PodTemplateSpec{}, }, } diff --git a/test/e2e/preset_test.go b/test/e2e/preset_test.go index eb0333df4..e8f262ef0 100644 --- a/test/e2e/preset_test.go +++ b/test/e2e/preset_test.go @@ -26,13 +26,13 @@ import ( ) const ( - PresetLlama2AChat = "llama-2-7b-chat" - PresetLlama2BChat = "llama-2-13b-chat" - PresetFalcon7BModel = "falcon-7b" - PresetFalcon40BModel = "falcon-40b" - PresetMistral7BModel = "mistral-7b" + PresetLlama2AChat = "llama-2-7b-chat" + PresetLlama2BChat = "llama-2-13b-chat" + PresetFalcon7BModel = "falcon-7b" + PresetFalcon40BModel = "falcon-40b" + PresetMistral7BModel = "mistral-7b" PresetMistral7BInstructModel = "mistral-7b-instruct" - PresetPhi2Model = "phi-2" + PresetPhi2Model = "phi-2" ) func createFalconWorkspaceWithPresetPublicMode(numOfNode int) *kaitov1alpha1.Workspace { @@ -348,17 +348,17 @@ var _ = Describe("Workspace Preset", func() { fmt.Print("Error: RUN_LLAMA_13B ENV Variable not set") runLlama13B = false } - + aiModelsRegistry = utils.GetEnv("AI_MODELS_REGISTRY") aiModelsRegistrySecret = utils.GetEnv("AI_MODELS_REGISTRY_SECRET") - + // Load stable model versions configs, err := utils.GetModelConfigInfo("/home/runner/work/kaito/kaito/presets/models/supported_models.yaml") if err != nil { fmt.Printf("Failed to load model configs: %v\n", err) os.Exit(1) } - + modelInfo, err = utils.ExtractModelVersion(configs) if err != nil { fmt.Printf("Failed to extract stable model versions: %v\n", err) @@ -404,7 +404,6 @@ var _ = Describe("Workspace Preset", func() { validateWorkspaceReadiness(workspaceObj) }) - It("should create a Phi-2 workspace with preset public mode successfully", func() { numOfNode := 1 workspaceObj := createPhi2WorkspaceWithPresetPublicMode(numOfNode) diff --git a/test/e2e/utils/utils.go b/test/e2e/utils/utils.go index 3914f00eb..38388374f 100644 --- a/test/e2e/utils/utils.go +++ b/test/e2e/utils/utils.go @@ -60,23 +60,23 @@ func ExtractModelVersion(configs map[string]interface{}) (map[string]string, err } for _, modelItem := range models { - model, ok := modelItem.(map[interface{}]interface{}) - if !ok { - return nil, fmt.Errorf("model item is not a map") - } + model, ok := modelItem.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("model item is not a map") + } - modelName, ok := model["name"].(string) - if !ok { - return nil, fmt.Errorf("model name is not a string or not found") - } + modelName, ok := model["name"].(string) + if !ok { + return nil, fmt.Errorf("model name is not a string or not found") + } - modelTag, 
ok := model["tag"].(string) // Using 'tag' as the version - if !ok { - return nil, fmt.Errorf("model version for %s is not a string or not found", modelName) - } + modelTag, ok := model["tag"].(string) // Using 'tag' as the version + if !ok { + return nil, fmt.Errorf("model version for %s is not a string or not found", modelName) + } - modelsInfo[modelName] = modelTag - } + modelsInfo[modelName] = modelTag + } return modelsInfo, nil } @@ -117,7 +117,7 @@ func GenerateWorkspaceManifest(name, namespace, imageName string, resourceCount workspaceInference.Template = podTemplate } - workspace.Inference = workspaceInference + workspace.Inference = &workspaceInference return workspace } From e69f8bf64ffc02c0b55b9967afa072b1df5f3a4c Mon Sep 17 00:00:00 2001 From: Dave Fellows Date: Tue, 26 Mar 2024 09:39:30 +1300 Subject: [PATCH 09/23] fix: Fix typo in Makefile (#315) I'm assuming this should be `.PHONY` not `.PHONE` Signed-off-by: Dave Fellows --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index c3bfe3385..f02db04bf 100644 --- a/Makefile +++ b/Makefile @@ -187,7 +187,7 @@ ifndef ignore-not-found endif ##@ gpu-provider -.PHONE: gpu-provisioner-identity-perm +.PHONY: gpu-provisioner-identity-perm gpu-provisioner-identity-perm: ## Create identity for gpu-provisioner az identity create --name gpuIdentity --resource-group $(AZURE_RESOURCE_GROUP) From 10af81a5b64f272718e3f107f1dbcd717d914d4e Mon Sep 17 00:00:00 2001 From: Fei Guo Date: Mon, 25 Mar 2024 13:55:42 -0700 Subject: [PATCH 10/23] docs: Update README.md to correct the mail list (#317) Correct the mail list name. Fix issue #314 Signed-off-by: Fei Guo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f022b3618..6678ecd54 100644 --- a/README.md +++ b/README.md @@ -152,4 +152,4 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://ope ## Contact -"Kaito devs" +"Kaito devs" From 2f0323adfaa510f15b2299aed1f3f2226aa8b3bd Mon Sep 17 00:00:00 2001 From: Heba <31887807+helayoty@users.noreply.github.com> Date: Mon, 25 Mar 2024 21:13:53 -0700 Subject: [PATCH 11/23] chore: Add GitHub issue/PR templates (#316) - Add report bug template - Add feature request template - Add PR template Signed-off-by: Heba Elayoty --- .github/ISSUE_TEMPLATE/feature_request.md | 16 +++++++++++++++ .github/ISSUE_TEMPLATE/report_bug.md | 25 +++++++++++++++++++++++ .github/PULL_REQUEST_TEMPLATE.md | 11 ++++++++++ 3 files changed, 52 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/ISSUE_TEMPLATE/report_bug.md create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..4239464fb --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,16 @@ +--- +name: Feature request +about: Suggest a new feature for Kaito +title: '' +labels: 'enhancement' +assignees: '' + +--- + +**Is your feature request related to a problem? 
Please describe.** + +**Describe the solution you'd like** + +**Describe alternatives you've considered** + +**Additional context** \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/report_bug.md b/.github/ISSUE_TEMPLATE/report_bug.md new file mode 100644 index 000000000..0d3dff0e8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/report_bug.md @@ -0,0 +1,25 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: 'bug' +assignees: '' + +--- + +**Describe the bug** + +**Steps To Reproduce** + +**Expected behavior** + +**Logs** + +**Environment** + +- Kubernetes version (use `kubectl version`): +- OS (e.g: `cat /etc/os-release`): +- Install tools: +- Others: + +**Additional context** \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..272b76da7 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,11 @@ +**Reason for Change**: + + +**Requirements** + +- [ ] added unit tests and e2e tests (if applicable). + +**Issue Fixed**: + + +**Notes for Reviewers**: \ No newline at end of file From 99b7d57c93e5c7e2bc8a2821f170b02b4cac051d Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Tue, 26 Mar 2024 12:28:13 -0500 Subject: [PATCH 12/23] fix: Adjust default model params (#310) Signed-off-by: Ishaan Sehgal Co-authored-by: Fei Guo --- .../kind-cluster/determine_models.py | 4 +-- .../text-generation/inference_api.py | 4 +-- .../tests/test_inference_api.py | 4 +-- presets/models/supported_models.yaml | 25 +++++++++++++------ 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/.github/workflows/kind-cluster/determine_models.py b/.github/workflows/kind-cluster/determine_models.py index 402365441..5ace3ba63 100644 --- a/.github/workflows/kind-cluster/determine_models.py +++ b/.github/workflows/kind-cluster/determine_models.py @@ -117,10 +117,10 @@ def check_modified_models(pr_branch): def main(): pr_branch = os.environ.get("PR_BRANCH", "main") # If not specified default to 'main' - force_run_all = os.environ.get("FORCE_RUN_ALL", False) # If not specified default to False + force_run_all = os.environ.get("FORCE_RUN_ALL", "false") # If not specified default to False affected_models = [] - if force_run_all: + if force_run_all != "false": affected_models = [model['name'] for model in YAML_PR['models']] else: # Logic to determine affected models diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py index bf739844d..f6c604a54 100644 --- a/presets/inference/text-generation/inference_api.py +++ b/presets/inference/text-generation/inference_api.py @@ -125,11 +125,11 @@ def health_check(): class GenerateKwargs(BaseModel): max_length: int = 200 # Length of input prompt+max_new_tokens min_length: int = 0 - do_sample: bool = False + do_sample: bool = True early_stopping: bool = False num_beams: int = 1 temperature: float = 1.0 - top_k: int = 50 + top_k: int = 10 top_p: float = 1 typical_p: float = 1 repetition_penalty: float = 1 diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py index d6506b08b..c15b0f38f 100644 --- a/presets/inference/text-generation/tests/test_inference_api.py +++ b/presets/inference/text-generation/tests/test_inference_api.py @@ -156,9 +156,9 @@ def test_default_generation_params(configured_app): _, kwargs = mock_pipeline.call_args assert kwargs['max_length'] == 200 assert kwargs['min_length'] == 
0 - assert kwargs['do_sample'] is False + assert kwargs['do_sample'] is True assert kwargs['temperature'] == 1.0 - assert kwargs['top_k'] == 50 + assert kwargs['top_k'] == 10 assert kwargs['top_p'] == 1 assert kwargs['typical_p'] == 1 assert kwargs['repetition_penalty'] == 1 diff --git a/presets/models/supported_models.yaml b/presets/models/supported_models.yaml index 0f68002c5..0441a945a 100644 --- a/presets/models/supported_models.yaml +++ b/presets/models/supported_models.yaml @@ -34,23 +34,30 @@ models: type: text-generation version: https://huggingface.co/tiiuae/falcon-7b/commit/898df1396f35e447d5fe44e0a3ccaaaa69f30d36 runtime: tfs - tag: 0.0.3 + tag: 0.0.4 - name: falcon-7b-instruct type: text-generation version: https://huggingface.co/tiiuae/falcon-7b-instruct/commit/cf4b3c42ce2fdfe24f753f0f0d179202fea59c99 runtime: tfs - tag: 0.0.3 + tag: 0.0.4 + # Tag history: + # 0.0.4 - Adjust default model params (#310) + # 0.0.3 - Update Default Params (#294) + # 0.0.2 - Inference API Cleanup (#233) + # 0.0.1 - Initial Release - name: falcon-40b type: text-generation version: https://huggingface.co/tiiuae/falcon-40b/commit/4a70170c215b36a3cce4b4253f6d0612bb7d4146 runtime: tfs - tag: 0.0.3 + tag: 0.0.5 - name: falcon-40b-instruct type: text-generation version: https://huggingface.co/tiiuae/falcon-40b-instruct/commit/ecb78d97ac356d098e79f0db222c9ce7c5d9ee5f runtime: tfs - tag: 0.0.3 - # Tag history: + tag: 0.0.5 + # Tag history for 40b models: + # 0.0.5 - Adjust default model params (#310) + # 0.0.4 - Skipped due to incomplete upload issue # 0.0.3 - Update Default Params (#294) # 0.0.2 - Inference API Cleanup (#233) # 0.0.1 - Initial Release @@ -60,13 +67,14 @@ models: type: text-generation version: https://huggingface.co/mistralai/Mistral-7B-v0.1/commit/26bca36bde8333b5d7f72e9ed20ccda6a618af24 runtime: tfs - tag: 0.0.3 + tag: 0.0.4 - name: mistral-7b-instruct type: text-generation version: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/commit/b70aa86578567ba3301b21c8a27bea4e8f6d6d61 runtime: tfs - tag: 0.0.3 + tag: 0.0.4 # Tag history: + # 0.0.4 - Adjust default model params (#310) # 0.0.3 - Update Default Params (#294) # 0.0.2 - Inference API Cleanup (#233) # 0.0.1 - Initial Release @@ -76,7 +84,8 @@ models: type: text-generation version: https://huggingface.co/microsoft/phi-2/commit/b10c3eba545ad279e7208ee3a5d644566f001670 runtime: tfs - tag: 0.0.2 + tag: 0.0.3 # Tag history: + # 0.0.3 - Adjust default model params (#310) # 0.0.2 - Update Default Params (#294) # 0.0.1 - Initial Release From 8eddac305857a9ecc19fb2047d724ec0f1c04dd9 Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Tue, 26 Mar 2024 13:32:04 -0700 Subject: [PATCH 13/23] fix: Update Model Tags (#311) Signed-off-by: Ishaan Sehgal --- presets/models/falcon/model.go | 8 ++++---- presets/models/mistral/model.go | 4 ++-- presets/models/phi/model.go | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/presets/models/falcon/model.go b/presets/models/falcon/model.go index 7501dce23..863a2fb52 100644 --- a/presets/models/falcon/model.go +++ b/presets/models/falcon/model.go @@ -37,10 +37,10 @@ var ( PresetFalcon40BInstructModel = PresetFalcon40BModel + "-instruct" PresetFalconTagMap = map[string]string{ - "Falcon7B": "0.0.3", - "Falcon7BInstruct": "0.0.3", - "Falcon40B": "0.0.3", - "Falcon40BInstruct": "0.0.3", + "Falcon7B": "0.0.4", + "Falcon7BInstruct": "0.0.4", + "Falcon40B": "0.0.5", + "Falcon40BInstruct": "0.0.5", } baseCommandPresetFalcon = "accelerate launch" diff --git a/presets/models/mistral/model.go 
b/presets/models/mistral/model.go index 7089eafb6..910e2da83 100644 --- a/presets/models/mistral/model.go +++ b/presets/models/mistral/model.go @@ -27,8 +27,8 @@ var ( PresetMistral7BInstructModel = PresetMistral7BModel + "-instruct" PresetMistralTagMap = map[string]string{ - "Mistral7B": "0.0.3", - "Mistral7BInstruct": "0.0.3", + "Mistral7B": "0.0.4", + "Mistral7BInstruct": "0.0.4", } baseCommandPresetMistral = "accelerate launch" diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index 2e54dce38..32df386cc 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -22,7 +22,7 @@ var ( PresetPhi2Model = "phi-2" PresetPhiTagMap = map[string]string{ - "Phi2": "0.0.2", + "Phi2": "0.0.3", } baseCommandPresetPhi = "accelerate launch" From d079f365510dd186f9235dee4cb32eb7f6306a74 Mon Sep 17 00:00:00 2001 From: Heba <31887807+helayoty@users.noreply.github.com> Date: Thu, 28 Mar 2024 17:21:04 -0700 Subject: [PATCH 14/23] ci: Update release workflow (#319) **Reason for Change**: - Automate the release workflow. - Use [Reusable workflows](https://docs.github.com/en/actions/using-workflows/reusing-workflows#reusable-workflows-and-starter-workflows) instead of workflow_dispatch. This will enhance the usability of e2e pipeline. - Remove unnecessary jobs. - Fix goreleaser to pick the current tag. - Add option to update k8s version for upcoming pipelines. **Requirements** - [ ] added unit tests and e2e tests (if applicable). **Issue Fixed**: **Notes for Reviewers**: --------- Signed-off-by: Heba Elayoty --- .github/workflows/build-publish-image.yml | 84 ---------- .github/workflows/codeql.yml | 2 +- .github/workflows/create-release.yml | 10 +- .../{kaito-e2e.yaml => e2e-workflow.yml} | 128 ++++++++------- .github/workflows/helm-chart.yml | 12 +- .github/workflows/kaito-e2e.yml | 29 ++++ .../workflows/{lint-go.yaml => lint-go.yml} | 10 +- .github/workflows/markdown-link-check.yml | 2 +- .github/workflows/publish-gh-image.yml | 151 ++++++++++++++++++ .github/workflows/publish-image-acr.yml | 119 -------------- .github/workflows/publish-mcr-image.yml | 79 +++++++++ .github/workflows/tests.yml | 15 +- Makefile | 9 +- 13 files changed, 367 insertions(+), 283 deletions(-) delete mode 100644 .github/workflows/build-publish-image.yml rename .github/workflows/{kaito-e2e.yaml => e2e-workflow.yml} (62%) create mode 100644 .github/workflows/kaito-e2e.yml rename .github/workflows/{lint-go.yaml => lint-go.yml} (69%) create mode 100644 .github/workflows/publish-gh-image.yml delete mode 100644 .github/workflows/publish-image-acr.yml create mode 100644 .github/workflows/publish-mcr-image.yml diff --git a/.github/workflows/build-publish-image.yml b/.github/workflows/build-publish-image.yml deleted file mode 100644 index eb430341f..000000000 --- a/.github/workflows/build-publish-image.yml +++ /dev/null @@ -1,84 +0,0 @@ -name: Create, Scan and Publish KAITO image -on: - pull_request: - branches: - - main - - release-** - types: [ closed ] - -permissions: - contents: write - packages: write - -env: - REGISTRY: ghcr.io - GO_VERSION: '1.20' - IMAGE_NAME: 'workspace' - -jobs: - export-registry: - if: github.event.pull_request.merged == true && contains(github.event.pull_request.title, 'update manifest and helm charts') - runs-on: ubuntu-20.04 - environment: preset-env - outputs: - registry: ${{ steps.export.outputs.registry }} - steps: - - id: export - run: | - # registry must be in lowercase - echo "registry=$(echo "${{ env.REGISTRY }}/${{ github.repository }}" | tr [:upper:] [:lower:])" 
>> $GITHUB_OUTPUT - - publish-images: - if: github.event.pull_request.merged == true && contains(github.event.pull_request.title, 'update manifest and helm charts') - needs: - - export-registry - env: - REGISTRY: ${{ needs.export-registry.outputs.registry }} - runs-on: ubuntu-20.04 - environment: preset-env - steps: - - id: get-tag - name: Get tag - run: echo "IMG_TAG=$(echo ${{ github.event.pull_request.head.ref }} | tr -d release-)" >> $GITHUB_ENV - - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 0 - ref: ${{ env.IMG_TAG }} - - - name: Login to ${{ env.REGISTRY }} - uses: docker/login-action@343f7c4344506bcbf9b4de18042ae17996df046d - with: - registry: ${{ env.REGISTRY }} - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Set Image tag - run: | - ver=${{ env.IMG_TAG }} - echo "IMG_TAG=${ver#"v"}" >> $GITHUB_ENV - - name: Build image - run: | - OUTPUT_TYPE=type=registry make docker-build-kaito - env: - VERSION: ${{ env.IMG_TAG }} - - - name: Scan ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ env.IMG_TAG }} - uses: aquasecurity/trivy-action@master - with: - image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ env.IMG_TAG }} - format: 'table' - exit-code: '1' - ignore-unfixed: true - vuln-type: 'os,library' - severity: 'CRITICAL,HIGH' - timeout: '5m0s' - env: - TRIVY_USERNAME: ${{ github.actor }} - TRIVY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} - - - name: 'Dispatch release tag' - uses: peter-evans/repository-dispatch@v3 - with: - token: ${{ secrets.GITHUB_TOKEN }} - event-type: release-tag - client-payload: '{"isRelease": true,"registry": "$${{ env.REGISTRY }}","tag": "v${{ env.IMG_TAG }}"}' diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index b918270bc..1651267c0 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -18,7 +18,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: true fetch-depth: 0 diff --git a/.github/workflows/create-release.yml b/.github/workflows/create-release.yml index 935a88779..71f530adc 100644 --- a/.github/workflows/create-release.yml +++ b/.github/workflows/create-release.yml @@ -1,8 +1,7 @@ name: Create release on: repository_dispatch: - types: [ release-tag ] - branches: [ release-** ] + types: [ create-release ] permissions: id-token: write @@ -14,18 +13,20 @@ env: jobs: create-release: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - name: Harden Runner uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0 with: egress-policy: audit + - name: Set up Go ${{ env.GO_VERSION }} uses: actions/setup-go@v5 with: go-version: ${{ env.GO_VERSION }} + - name: Checkout the repository at the given SHA from the artifact - uses: actions/checkout@v4 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: true fetch-depth: 0 @@ -38,3 +39,4 @@ jobs: args: release --rm-dist --timeout 60m --debug env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GORELEASER_CURRENT_TAG: ${{ github.event.client_payload.tag }} diff --git a/.github/workflows/kaito-e2e.yaml b/.github/workflows/e2e-workflow.yml similarity index 62% rename from .github/workflows/kaito-e2e.yaml rename to .github/workflows/e2e-workflow.yml index 120df9f80..86979b41a 100644 --- a/.github/workflows/kaito-e2e.yaml +++ b/.github/workflows/e2e-workflow.yml @@ -1,40 +1,66 @@ -name: e2e-test - -concurrency: - group: ${{ github.workflow 
}}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true +name: kaito-e2e-workflow on: - push: - branches: [main] - paths-ignore: ['docs/**', '**.md', '**.mdx', '**.png', '**.jpg'] - pull_request: - branches: [main] - paths-ignore: ['docs/**', '**.md', '**.mdx', '**.png', '**.jpg'] - repository_dispatch: - types: [ release-tag ] - branches: [ release-** ] - -env: - GO_VERSION: "1.20" + workflow_call: + inputs: + git_sha: + type: string + required: true + tag: + type: string + isRelease: + type: boolean + default: false + registry: + type: string + region: + type: string + description: "the azure location to run the e2e test in" + default: "eastus" + k8s_version: + type: string + default: "1.27" + secrets: + E2E_CLIENT_ID: + required: true + E2E_TENANT_ID: + required: true + E2E_SUBSCRIPTION_ID: + required: true + E2E_AMRT_SECRET_NAME: + required: true + E2E_ACR_AMRT_USERNAME: + required: true + E2E_ACR_AMRT_PASSWORD: + required: true permissions: - id-token: write # This is required for requesting the JWT contents: read # This is required for actions/checkout jobs: e2e-tests: runs-on: ubuntu-latest + permissions: + contents: read + id-token: write # This is required for requesting the JWT environment: e2e-test + env: + GO_VERSION: "1.20" + steps: - - name: Shorten SHA - if: ${{ !github.event.client_payload.isRelease }} - id: vars - run: echo "pr_sha_short=$(git rev-parse --short ${{ github.event.pull_request.head.sha }})" >> $GITHUB_OUTPUT + - name: Harden Runner + uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0 + with: + egress-policy: audit + + - name: Checkout + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + with: + ref: ${{ inputs.git_sha }} - name: Set e2e Resource and Cluster Name run: | - rand=${{ steps.vars.outputs.pr_sha_short }} + rand=$(git rev-parse --short ${{ inputs.git_sha }}) if [ "$rand" = "" ]; then rand=$RANDOM @@ -46,36 +72,22 @@ jobs: echo "REGISTRY=kaito${rand}.azurecr.io" >> $GITHUB_ENV - name: Set Registry - if: ${{ github.event.client_payload.isRelease }} + if: ${{ inputs.isRelease }} run: | - echo "REGISTRY=${{ github.event.client_payload.registry }}" >> $GITHUB_ENV - echo "VERSION=$(echo ${{ github.event.client_payload.tag }} | tr -d v)" >> $GITHUB_ENV + echo "REGISTRY=${{ inputs.registry }}" >> $GITHUB_ENV + echo "VERSION=$(echo ${{ inputs.tag }} | tr -d v)" >> $GITHUB_ENV - name: Set up Go ${{ env.GO_VERSION }} uses: actions/setup-go@v5 with: go-version: ${{ env.GO_VERSION }} - - name: Checkout - if: ${{ !github.event.client_payload.isRelease }} - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 0 - - - name: Checkout - uses: actions/checkout@v4 - if: ${{ github.event.client_payload.isRelease }} - with: - fetch-depth: 0 - submodules: true - ref: ${{ env.REPO_TAG }} - - - uses: azure/login@v1.6.1 + - name: Az login + uses: azure/login@8c334a195cbb38e46038007b304988d888bf676a # v2.0.0 with: - client-id: ${{ secrets.AZURE_CLIENT_ID }} - tenant-id: ${{ secrets.AZURE_TENANT_ID }} - subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + client-id: ${{ secrets.E2E_CLIENT_ID }} + tenant-id: ${{ secrets.E2E_TENANT_ID }} + subscription-id: ${{ secrets.E2E_SUBSCRIPTION_ID }} - uses: azure/setup-helm@v4 with: @@ -104,7 +116,7 @@ jobs: az identity create --name gpuIdentity --resource-group ${{ env.CLUSTER_NAME }} - name: build KAITO image - if: ${{ !github.event.client_payload.isRelease }} + if: ${{ !inputs.isRelease }} shell: bash run: | make docker-build-kaito @@ -120,6 +132,8 
@@ jobs: AZURE_ACR_NAME: ${{ env.CLUSTER_NAME }} AZURE_RESOURCE_GROUP: ${{ env.CLUSTER_NAME }} AZURE_CLUSTER_NAME: ${{ env.CLUSTER_NAME }} + AZURE_LOCATION: ${{ inputs.region }} + AKS_K8S_VERSION: ${{ inputs.k8s_version }} - name: Install gpu-provisioner helm chart shell: bash @@ -130,18 +144,18 @@ jobs: AZURE_RESOURCE_GROUP: ${{ env.CLUSTER_NAME }} AZURE_CLUSTER_NAME: ${{ env.CLUSTER_NAME }} - - uses: azure/login@v1.6.1 + - uses: azure/login@8c334a195cbb38e46038007b304988d888bf676a # v2.0.0 with: - client-id: ${{ secrets.AZURE_CLIENT_ID }} - tenant-id: ${{ secrets.AZURE_TENANT_ID }} - subscription-id: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + client-id: ${{ secrets.E2E_CLIENT_ID }} + tenant-id: ${{ secrets.E2E_TENANT_ID }} + subscription-id: ${{ secrets.E2E_SUBSCRIPTION_ID }} - name: Create Role Assignment uses: azure/CLI@v1.0.9 with: inlineScript: | IDENTITY_PRINCIPAL_ID="$(az identity show --name gpuIdentity --resource-group ${{ env.CLUSTER_NAME }} --query 'principalId' -otsv)" - az role assignment create --assignee ${IDENTITY_PRINCIPAL_ID} --scope "/subscriptions/${{ secrets.AZURE_SUBSCRIPTION_ID }}/resourceGroups/${{ env.CLUSTER_NAME }}" --role "Contributor" + az role assignment create --assignee ${IDENTITY_PRINCIPAL_ID} --scope "/subscriptions/${{ secrets.E2E_SUBSCRIPTION_ID }}/resourceGroups/${{ env.CLUSTER_NAME }}" --role "Contributor" - name: Create Azure Federated Identity uses: azure/CLI@v1.0.9 @@ -164,10 +178,10 @@ jobs: - name: Add Secret Credentials run: | - kubectl create secret docker-registry ${{secrets.AMRT_SECRET_NAME}} \ - --docker-server=${{secrets.ACR_AMRT_USERNAME}}.azurecr.io \ - --docker-username=${{secrets.ACR_AMRT_USERNAME}} \ - --docker-password=${{secrets.ACR_AMRT_PASSWORD}} + kubectl create secret docker-registry ${{ secrets.E2E_AMRT_SECRET_NAME }} \ + --docker-server=${{ secrets.E2E_ACR_AMRT_USERNAME }}.azurecr.io \ + --docker-username=${{ secrets.E2E_ACR_AMRT_USERNAME }} \ + --docker-password=${{ secrets.E2E_ACR_AMRT_PASSWORD }} - name: Log kaito-workspace run: | @@ -179,8 +193,8 @@ jobs: env: AZURE_CLUSTER_NAME: ${{ env.CLUSTER_NAME }} RUN_LLAMA_13B: ${{ env.RUN_LLAMA_13B }} - AI_MODELS_REGISTRY: ${{secrets.ACR_AMRT_USERNAME}}.azurecr.io - AI_MODELS_REGISTRY_SECRET: ${{secrets.AMRT_SECRET_NAME}} + AI_MODELS_REGISTRY: ${{ secrets.E2E_ACR_AMRT_USERNAME }}.azurecr.io + AI_MODELS_REGISTRY_SECRET: ${{ secrets.E2E_AMRT_SECRET_NAME }} - name: Cleanup e2e resources if: ${{ always() }} diff --git a/.github/workflows/helm-chart.yml b/.github/workflows/helm-chart.yml index 2666cd92f..2df542691 100644 --- a/.github/workflows/helm-chart.yml +++ b/.github/workflows/helm-chart.yml @@ -1,10 +1,8 @@ name: publish_helm_chart on: - workflow_run: - workflows: [ "Create, Scan and Publish KAITO image" ] - types: [ completed ] - branches: [ release-** ] + repository_dispatch: + types: [ create-release ] permissions: id-token: write # This is required for requesting the JWT @@ -15,15 +13,15 @@ permissions: pull-requests: read jobs: - release: + publish-helm: runs-on: ubuntu-latest - if: ${{ github.event.workflow_run.conclusion == 'success' }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: true fetch-depth: 0 + ref: ${{ github.event.client_payload.tag }} - name: Publish Workspace Helm chart uses: stefanprodan/helm-gh-pages@v1.7.0 diff --git a/.github/workflows/kaito-e2e.yml b/.github/workflows/kaito-e2e.yml new file mode 100644 index 000000000..b4000d30c --- /dev/null +++ 
b/.github/workflows/kaito-e2e.yml @@ -0,0 +1,29 @@ +name: pr-e2e-test + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +on: + pull_request: + paths-ignore: ['docs/**', '**.md', '**.mdx', '**.png', '**.jpg'] + +env: + GO_VERSION: "1.20" + +permissions: + id-token: write # This is required for requesting the JWT + contents: read # This is required for actions/checkout + +jobs: + run-e2e: + uses: ./.github/workflows/e2e-workflow.yml + with: + git_sha: ${{ github.event.pull_request.head.sha }} + secrets: + E2E_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} + E2E_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + E2E_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + E2E_AMRT_SECRET_NAME: ${{ secrets.AMRT_SECRET_NAME }} + E2E_ACR_AMRT_USERNAME: ${{ secrets.ACR_AMRT_USERNAME }} + E2E_ACR_AMRT_PASSWORD: ${{ secrets.ACR_AMRT_PASSWORD }} diff --git a/.github/workflows/lint-go.yaml b/.github/workflows/lint-go.yml similarity index 69% rename from .github/workflows/lint-go.yaml rename to .github/workflows/lint-go.yml index 50b1d90b6..6e8371137 100644 --- a/.github/workflows/lint-go.yaml +++ b/.github/workflows/lint-go.yml @@ -17,14 +17,20 @@ env: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest permissions: contents: read steps: - - uses: actions/checkout@v4 + - name: Harden Runner + uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0 + with: + egress-policy: audit + + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: true fetch-depth: 0 + - name: Set up Go ${{ env.GO_VERSION }} uses: actions/setup-go@v5 with: diff --git a/.github/workflows/markdown-link-check.yml b/.github/workflows/markdown-link-check.yml index 92c438985..00e6c44f7 100644 --- a/.github/workflows/markdown-link-check.yml +++ b/.github/workflows/markdown-link-check.yml @@ -10,7 +10,7 @@ jobs: markdown-link-check: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - uses: gaurav-nelson/github-action-markdown-link-check@v1 with: # this will only show errors in the output diff --git a/.github/workflows/publish-gh-image.yml b/.github/workflows/publish-gh-image.yml new file mode 100644 index 000000000..775901441 --- /dev/null +++ b/.github/workflows/publish-gh-image.yml @@ -0,0 +1,151 @@ +name: Create, Scan and Publish KAITO image +on: + workflow_dispatch: + inputs: + release_version: + description: 'tag to be created for this image (i.e. 
vxx.xx.xx)' + required: true + pull_request: + types: [ closed ] + + +permissions: + id-token: write + contents: write + packages: write + +env: + GO_VERSION: '1.20' + IMAGE_NAME: 'workspace' + REGISTRY: ghcr.io + +jobs: + check-tag: + if: >- + github.event_name == 'workflow_dispatch' || + ( + github.event_name == 'pull_request' && + github.event.pull_request.merged == true && + contains(github.event.pull_request.title, 'update manifest and helm charts') + ) + runs-on: ubuntu-latest + environment: preset-env + outputs: + tag: ${{ steps.get-tag.outputs.tag }} + steps: + - name: validate version + if: github.event_name == 'workflow_dispatch' + run: | + echo "${{ github.event.inputs.release_version }}" | grep -E 'v[0-9]+\.[0-9]+\.[0-9]+$' + + - id: get-tag + name: Get tag + run: | + if [[ ${{ github.event_name }} == 'workflow_dispatch' ]]; then + echo "tag=$(echo ${{ github.event.inputs.release_version }})" >> $GITHUB_OUTPUT + else + echo "tag=$(echo ${{ github.event.pull_request.head.ref }} | tr -d release-)" >> $GITHUB_OUTPUT + fi + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - id: check-tag + name: Check for Tag + run: | + TAG="${{ steps.get-tag.outputs.tag }}" + if git show-ref --tags --verify --quiet "refs/tags/${TAG}"; then + echo "create_tag=$(echo 'false' )" >> $GITHUB_OUTPUT + else + echo "create_tag=$(echo 'true' )" >> $GITHUB_OUTPUT + fi + - name: 'Create tag' + if: steps.check-tag.outputs.create_tag == 'true' + uses: actions/github-script@v7 + with: + script: | + github.rest.git.createRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: 'refs/tags/${{ steps.get-tag.outputs.tag }}', + sha: context.sha + }) + + build-scan-publish-gh-images: + runs-on: ubuntu-latest + needs: [ check-tag ] + environment: preset-env + steps: + - id: get-registry + run: | + # registry must be in lowercase + echo "registry_repository=$(echo "${{ env.REGISTRY }}/${{ github.repository }}" | tr [:upper:] [:lower:])" >> $GITHUB_OUTPUT + + - id: get-tag + name: Get tag + run: | + echo "IMG_TAG=$(echo ${{ needs.check-tag.outputs.tag }} | tr -d v)" >> $GITHUB_ENV + + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + with: + submodules: true + fetch-depth: 0 + ref: ${{ needs.check-tag.outputs.tag }} + + - name: Login to ${{ steps.get-registry.outputs.registry_repository }} + uses: docker/login-action@343f7c4344506bcbf9b4de18042ae17996df046d + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build image + run: | + OUTPUT_TYPE=type=registry make docker-build-kaito + env: + VERSION: ${{ needs.check-tag.outputs.tag }} + REGISTRY: ${{ steps.get-registry.outputs.registry_repository }} + + - name: Scan ${{ steps.get-registry.outputs.registry_repository }}/${{ env.IMAGE_NAME }}:${{ env.IMG_TAG }} + uses: aquasecurity/trivy-action@master + with: + image-ref: ${{ steps.get-registry.outputs.registry_repository }}/${{ env.IMAGE_NAME }}:${{ env.IMG_TAG }} + format: 'table' + exit-code: '1' + ignore-unfixed: true + vuln-type: 'os,library' + severity: 'CRITICAL,HIGH' + timeout: '5m0s' + env: + TRIVY_USERNAME: ${{ github.actor }} + TRIVY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} + + run-e2e-gh-image: + needs: [ check-tag, build-scan-publish-gh-images ] + uses: ./.github/workflows/e2e-workflow.yml + with: + git_sha: ${{ github.sha }} + isRelease: true + registry: ${{ needs.get-registry.outputs.registry_repository }} + tag: ${{ needs.check-tag.outputs.tag }} + secrets: + E2E_CLIENT_ID: 
${{ secrets.AZURE_CLIENT_ID }} + E2E_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + E2E_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + E2E_AMRT_SECRET_NAME: ${{ secrets.AMRT_SECRET_NAME }} + E2E_ACR_AMRT_USERNAME: ${{ secrets.ACR_AMRT_USERNAME }} + E2E_ACR_AMRT_PASSWORD: ${{ secrets.ACR_AMRT_PASSWORD }} + + publish-mcr-image: + runs-on: ubuntu-latest + environment: preset-env + needs: [ check-tag, run-e2e-gh-image ] + steps: + - name: 'Dispatch release tag' + uses: peter-evans/repository-dispatch@v3 + with: + token: ${{ secrets.GITHUB_TOKEN }} + event-type: publish-mcr-image + client-payload: '{"tag": "${{ needs.check-tag.outputs.tag }}"}' diff --git a/.github/workflows/publish-image-acr.yml b/.github/workflows/publish-image-acr.yml deleted file mode 100644 index a307bd942..000000000 --- a/.github/workflows/publish-image-acr.yml +++ /dev/null @@ -1,119 +0,0 @@ -name: Push image to ACR -on: - workflow_dispatch: - inputs: - release_version: - description: 'tag to be created for this image (i.e. vxx.xx.xx)' - required: true - -permissions: - id-token: write - contents: write - packages: write - -env: - GO_VERSION: '1.20' - IMAGE_NAME: 'workspace' - -jobs: - check-tag: - runs-on: - labels: [ "self-hosted", "1ES.Pool=1es-aks-kaito-agent-pool-ubuntu" ] - environment: publish-mcr - outputs: - tag: ${{ steps.get-tag.outputs.tag }} - steps: - - name: validate version - run: | - echo "${{ github.event.inputs.release_version }}" | grep -E 'v[0-9]+\.[0-9]+\.[0-9]+$' - - id: get-tag - name: Get tag - run: echo "tag=$(echo ${{ github.event.inputs.release_version }})" >> $GITHUB_OUTPUT - - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Check for Tag - run: | - TAG="${{ steps.get-tag.outputs.tag }}" - if git show-ref --tags --verify --quiet "refs/tags/${TAG}"; then - echo "create_tag=$(echo 'false' )" >> $GITHUB_ENV - else - echo "create_tag=$(echo 'true' )" >> $GITHUB_ENV - fi - - name: 'Create tag' - if: ${{ env.create_tag == 'true' }} - uses: actions/github-script@v7 - with: - script: | - github.rest.git.createRef({ - owner: context.repo.owner, - repo: context.repo.repo, - ref: 'refs/tags/${{ steps.get-tag.outputs.tag }}', - sha: context.sha - }) - - publish: - runs-on: - labels: [ "self-hosted", "1ES.Pool=1es-aks-kaito-agent-pool-ubuntu" ] - environment: publish-mcr - needs: - - check-tag - steps: - - id: get-tag - name: Get tag - run: echo "IMG_TAG=$(echo ${{ needs.check-tag.outputs.tag }} | tr -d v)" >> $GITHUB_ENV - - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 - with: - go-version: ${{ env.GO_VERSION }} - - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - submodules: true - ref: ${{ needs.check-tag.outputs.tag }} - - - name: 'Build Image' - run: | - OUTPUT_TYPE=type=docker ARCH=arm64 make docker-build-kaito - env: - VERSION: ${{ env.IMG_TAG }} - REGISTRY: ${{ secrets.KAITO_MCR_REGISTRY }}/public/aks/kaito - - - name: Scan ${{ secrets.KAITO_MCR_REGISTRY }}/public/aks/kaito/${{ env.IMAGE_NAME }}:${{ env.IMG_TAG }} - uses: aquasecurity/trivy-action@master - with: - image-ref: ${{ secrets.KAITO_MCR_REGISTRY }}/public/aks/kaito/${{ env.IMAGE_NAME }}:${{ env.IMG_TAG }} - format: 'table' - exit-code: '1' - ignore-unfixed: true - vuln-type: 'os,library' - severity: 'CRITICAL,HIGH' - timeout: '5m0s' - env: - TRIVY_USERNAME: ${{ github.actor }} - TRIVY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} - - - name: 'Dispatch tag to e2e test' - uses: peter-evans/repository-dispatch@v3 - with: - token: ${{ secrets.GITHUB_TOKEN }} - event-type: 
release-tag - client-payload: '{"isRelease": true,"registry": "mcr.microsoft.com/aks/kaito","tag": "${{ needs.check-tag.outputs.tag }}"}' - - name: Authenticate to ACR - run: | - az login --identity - az acr login -n ${{ secrets.KAITO_MCR_REGISTRY }} - - name: 'Publish to ACR' - id: Publish - run: | - OUTPUT_TYPE=type=registry make docker-build-kaito - env: - VERSION: ${{ env.IMG_TAG }} - REGISTRY: ${{ secrets.KAITO_MCR_REGISTRY }}/public/aks/kaito diff --git a/.github/workflows/publish-mcr-image.yml b/.github/workflows/publish-mcr-image.yml new file mode 100644 index 000000000..a021be19f --- /dev/null +++ b/.github/workflows/publish-mcr-image.yml @@ -0,0 +1,79 @@ +name: Push image to MCR +on: + repository_dispatch: + types: [ publish-mcr-image ] + +permissions: + contents: write + packages: write + +env: + GO_VERSION: '1.20' + IMAGE_NAME: 'workspace' + +jobs: + build-publish-mcr-image: + runs-on: + labels: [ "self-hosted", "1ES.Pool=1es-aks-kaito-agent-pool-ubuntu" ] + environment: publish-mcr + steps: + - name: Set up Go ${{ env.GO_VERSION }} + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Set Image tag + run: | + ver=${{ github.event.client_payload.tag }} + echo "IMG_TAG=${ver#"v"}" >> $GITHUB_ENV + + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + submodules: true + ref: ${{ github.event.client_payload.tag }} + + - name: Authenticate to ACR + run: | + az login --identity + az acr login -n ${{ secrets.KAITO_MCR_REGISTRY }} + + - name: 'Build and Publish to MCR' + id: Publish + run: | + OUTPUT_TYPE=type=registry make docker-build-kaito + env: + VERSION: ${{ env.IMG_TAG }} + REGISTRY: ${{ secrets.KAITO_MCR_REGISTRY }}/public/aks/kaito + + run-e2e-mcr: + permissions: + contents: read + id-token: write + statuses: write + needs: [ build-publish-mcr-image ] + uses: ./.github/workflows/e2e-workflow.yml + with: + git_sha: ${{ github.sha }} + isRelease: true + registry: "mcr.microsoft.com/aks/kaito" + tag: ${{ github.event.client_payload.tag }} + secrets: + E2E_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} + E2E_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + E2E_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + E2E_AMRT_SECRET_NAME: ${{ secrets.AMRT_SECRET_NAME }} + E2E_ACR_AMRT_USERNAME: ${{ secrets.ACR_AMRT_USERNAME }} + E2E_ACR_AMRT_PASSWORD: ${{ secrets.ACR_AMRT_PASSWORD }} + + create-release: + runs-on: ubuntu-latest + environment: publish-mcr + needs: [ run-e2e-mcr ] + steps: + - name: 'Dispatch release tag' + uses: peter-evans/repository-dispatch@v3 + with: + token: ${{ secrets.GITHUB_TOKEN }} + event-type: create-release + client-payload: '{"tag": "${{ github.event.client_payload.tag }}"}' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 03be3669c..8a5d668d8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,17 +24,22 @@ jobs: runs-on: ubuntu-latest environment: unit-tests steps: - - name: Set up Go ${{ env.GO_VERSION }} - uses: actions/setup-go@v5 + - name: Harden Runner + uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0 with: - go-version: ${{ env.GO_VERSION }} + egress-policy: audit - - name: Check out the code in the Go module directory - uses: actions/checkout@v4 + - name: Check out the code + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 with: submodules: true fetch-depth: 0 + - name: Set up Go ${{ env.GO_VERSION }} + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + -
name: Run unit tests & Generate coverage run: | make unit-test diff --git a/Makefile b/Makefile index f02db04bf..7d4f8409c 100644 --- a/Makefile +++ b/Makefile @@ -24,6 +24,7 @@ GINKGO := $(TOOLS_BIN_DIR)/$(GINKGO_BIN)-$(GINKGO_VER) AZURE_SUBSCRIPTION_ID ?= $(AZURE_SUBSCRIPTION_ID) AZURE_LOCATION ?= eastus +AKS_K8S_VERSION ?= 1.27.2 AZURE_RESOURCE_GROUP ?= demo AZURE_CLUSTER_NAME ?= kaito-demo AZURE_RESOURCE_GROUP_MC=MC_$(AZURE_RESOURCE_GROUP)_$(AZURE_CLUSTER_NAME)_$(AZURE_LOCATION) @@ -118,12 +119,13 @@ create-acr: ## Create test ACR .PHONY: create-aks-cluster create-aks-cluster: ## Create test AKS cluster (with msi, oidc, and workload identity enabled) - az aks create --name $(AZURE_CLUSTER_NAME) --resource-group $(AZURE_RESOURCE_GROUP) --attach-acr $(AZURE_ACR_NAME) \ - --node-count 1 --generate-ssh-keys --enable-managed-identity --enable-workload-identity --enable-oidc-issuer -o none + az aks create --name $(AZURE_CLUSTER_NAME) --resource-group $(AZURE_RESOURCE_GROUP) --location $(AZURE_LOCATION) \ + --attach-acr $(AZURE_ACR_NAME) --kubernetes-version $(AKS_K8S_VERSION) --node-count 1 --generate-ssh-keys \ + --enable-managed-identity --enable-workload-identity --enable-oidc-issuer -o none .PHONY: create-aks-cluster-with-kaito create-aks-cluster-with-kaito: ## Create test AKS cluster (with msi, oidc and kaito enabled) - az aks create --name $(AZURE_CLUSTER_NAME) --resource-group $(AZURE_RESOURCE_GROUP) --node-count 1 \ + az aks create --name $(AZURE_CLUSTER_NAME) --resource-group $(AZURE_RESOURCE_GROUP) --location $(AZURE_LOCATION) --node-count 1 \ --generate-ssh-keys --enable-managed-identity --enable-oidc-issuer --enable-ai-toolchain-operator -o none az aks get-credentials --name $(AZURE_CLUSTER_NAME) --resource-group $(AZURE_RESOURCE_GROUP) @@ -264,6 +266,7 @@ lint: $(GOLANGCI_LINT) .PHONY: release-manifest release-manifest: @sed -i -e 's/^VERSION ?= .*/VERSION ?= ${VERSION}/' ./Makefile + @sed -i -e "s/version: .*/version: ${IMG_TAG}/" ./charts/kaito/workspace/Chart.yaml @sed -i -e "s/appVersion: .*/appVersion: ${IMG_TAG}/" ./charts/kaito/workspace/Chart.yaml @sed -i -e "s/tag: .*/tag: ${IMG_TAG}/" ./charts/kaito/workspace/values.yaml @sed -i -e 's/IMG_TAG=.*/IMG_TAG=${IMG_TAG}/' ./charts/kaito/workspace/README.md From 2df7725be275c52c6ac27fa053d739280c524798 Mon Sep 17 00:00:00 2001 From: Heba <31887807+helayoty@users.noreply.github.com> Date: Thu, 28 Mar 2024 18:34:57 -0700 Subject: [PATCH 15/23] fix: Update helm chart to use Release.Namespace (#322) **Reason for Change**: Update the Helm chart to use `Release.Namespace` instead of the chart name, so the install namespace is chosen at release time; a minimal sketch of the pattern follows below. **Requirements** - [ ] added unit tests and e2e tests (if applicable).
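For context, this is the pattern the chart moves to. A minimal sketch, assuming a hypothetical template file and resource name (`example-sa` is illustrative, not taken from the chart): every namespaced template renders `{{ .Release.Namespace }}`, so resources land in whatever namespace the release is installed into.

```yaml
# Sketch only: hypothetical templates/example-serviceaccount.yaml
# The namespace comes from the Helm release, not from .Values or the chart name.
apiVersion: v1
kind: ServiceAccount
metadata:
  name: example-sa
  namespace: {{ .Release.Namespace }}
```

With this pattern the chart no longer needs to create a Namespace object of its own; `helm install ... --namespace <ns> --create-namespace` both creates the namespace and scopes every templated resource to it.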
**Issue Fixed**: **Notes for Reviewers**: Signed-off-by: Heba Elayoty --- Makefile | 4 ++-- charts/kaito/gpu-provisioner/README.md | 2 +- .../gpu-provisioner/templates/clusterrole-core.yaml | 2 +- .../gpu-provisioner/templates/configmap-logging.yaml | 2 +- charts/kaito/gpu-provisioner/templates/configmap.yaml | 2 +- charts/kaito/gpu-provisioner/templates/deployment.yaml | 9 +-------- charts/kaito/gpu-provisioner/templates/role.yaml | 2 +- charts/kaito/gpu-provisioner/templates/rolebinding.yaml | 6 +++--- .../kaito/gpu-provisioner/templates/serviceaccount.yaml | 2 +- charts/kaito/gpu-provisioner/values.yaml | 1 - charts/kaito/workspace/README.md | 4 +++- charts/kaito/workspace/templates/deployment.yaml | 9 +-------- .../workspace/templates/nvidia-device-plugin-ds.yaml | 2 +- charts/kaito/workspace/templates/role.yaml | 2 +- charts/kaito/workspace/templates/role_binding.yaml | 2 +- .../kaito/workspace/templates/secret-webhook-cert.yaml | 2 +- charts/kaito/workspace/templates/service.yaml | 2 +- charts/kaito/workspace/templates/serviceaccount.yaml | 2 +- docs/installation.md | 5 +++-- 19 files changed, 25 insertions(+), 37 deletions(-) diff --git a/Makefile b/Makefile index 7d4f8409c..c1676a052 100644 --- a/Makefile +++ b/Makefile @@ -147,7 +147,7 @@ az-patch-install-helm: ## Update Azure client env vars and settings in helm valu yq -i '(.image.repository) = "$(REGISTRY)/workspace"' ./charts/kaito/workspace/values.yaml yq -i '(.image.tag) = "$(IMG_TAG)"' ./charts/kaito/workspace/values.yaml - helm install kaito-workspace ./charts/kaito/workspace + helm install kaito-workspace ./charts/kaito/workspace --namespace $(KAITO_NAMESPACE) --create-namespace ##@ Build @@ -219,7 +219,7 @@ gpu-provisioner-helm: ## Update Azure client env vars and settings in helm valu yq -i '(.workloadIdentity.clientId) = "$(IDENTITY_CLIENT_ID)"' ./charts/kaito/gpu-provisioner/values.yaml yq -i '(.workloadIdentity.tenantId) = "$(AZURE_TENANT_ID)"' ./charts/kaito/gpu-provisioner/values.yaml - helm install kaito-gpu-provisioner ./charts/kaito/gpu-provisioner + helm install kaito-gpu-provisioner ./charts/kaito/gpu-provisioner --namespace $(GPU_NAMESPACE) --create-namespace ##@ Build Dependencies diff --git a/charts/kaito/gpu-provisioner/README.md b/charts/kaito/gpu-provisioner/README.md index d2dc949cd..81440c4e8 100644 --- a/charts/kaito/gpu-provisioner/README.md +++ b/charts/kaito/gpu-provisioner/README.md @@ -9,7 +9,7 @@ A Helm chart for gpu-provisioner To install the chart with the release name `gpu-provisioner`: ```bash -helm install gpu-provisioner ./charts/gpu-provisioner +helm install gpu-provisioner ./charts/gpu-provisioner --namespace=gpu-provisioner --create-namespace ``` ## Values diff --git a/charts/kaito/gpu-provisioner/templates/clusterrole-core.yaml b/charts/kaito/gpu-provisioner/templates/clusterrole-core.yaml index b697ce02a..b9c2840f3 100644 --- a/charts/kaito/gpu-provisioner/templates/clusterrole-core.yaml +++ b/charts/kaito/gpu-provisioner/templates/clusterrole-core.yaml @@ -15,7 +15,7 @@ roleRef: subjects: - kind: ServiceAccount name: gpu-provisioner - namespace: {{ .Values.namespace }} + namespace: {{ .Release.Namespace }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole diff --git a/charts/kaito/gpu-provisioner/templates/configmap-logging.yaml b/charts/kaito/gpu-provisioner/templates/configmap-logging.yaml index ce293dd76..5c27dc8b5 100644 --- a/charts/kaito/gpu-provisioner/templates/configmap-logging.yaml +++ b/charts/kaito/gpu-provisioner/templates/configmap-logging.yaml @@ -2,7 +2,7 
@@ apiVersion: v1 kind: ConfigMap metadata: name: gpu-provisioner-config-logging - namespace: {{ .Values.namespace }} + namespace: {{ .Release.Namespace }} labels: {{- include "gpu-provisioner.labels" . | nindent 4 }} {{- with .Values.additionalAnnotations }} diff --git a/charts/kaito/gpu-provisioner/templates/configmap.yaml b/charts/kaito/gpu-provisioner/templates/configmap.yaml index 10474bab0..3d51e651e 100644 --- a/charts/kaito/gpu-provisioner/templates/configmap.yaml +++ b/charts/kaito/gpu-provisioner/templates/configmap.yaml @@ -2,7 +2,7 @@ apiVersion: v1 kind: ConfigMap metadata: name: gpu-provisioner-global-settings - namespace: {{ .Values.namespace }} + namespace: {{ .Release.Namespace }} labels: {{- include "gpu-provisioner.labels" . | nindent 4 }} {{- with .Values.additionalAnnotations }} diff --git a/charts/kaito/gpu-provisioner/templates/deployment.yaml b/charts/kaito/gpu-provisioner/templates/deployment.yaml index 4234e024f..fe746958a 100644 --- a/charts/kaito/gpu-provisioner/templates/deployment.yaml +++ b/charts/kaito/gpu-provisioner/templates/deployment.yaml @@ -1,15 +1,8 @@ -apiVersion: v1 -kind: Namespace -metadata: - name: {{ .Values.namespace }} - labels: - {{- include "gpu-provisioner.labels" . | nindent 4 }} ---- apiVersion: apps/v1 kind: Deployment metadata: name: {{ include "gpu-provisioner.fullname" . }} - namespace: {{ .Values.namespace }} + namespace: {{ .Release.Namespace }} labels: azure.workload.identity/use: "true" {{- include "gpu-provisioner.labels" . | nindent 4 }} diff --git a/charts/kaito/gpu-provisioner/templates/role.yaml b/charts/kaito/gpu-provisioner/templates/role.yaml index e23ac6fe5..3ebd01fce 100644 --- a/charts/kaito/gpu-provisioner/templates/role.yaml +++ b/charts/kaito/gpu-provisioner/templates/role.yaml @@ -2,7 +2,7 @@ apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: {{ include "gpu-provisioner.fullname" . }} - namespace: {{ .Values.namespace }} + namespace: {{ .Release.Namespace }} labels: {{- include "gpu-provisioner.labels" . | nindent 4 }} {{- with .Values.additionalAnnotations }} diff --git a/charts/kaito/gpu-provisioner/templates/rolebinding.yaml b/charts/kaito/gpu-provisioner/templates/rolebinding.yaml index 89606bb73..9c3140135 100644 --- a/charts/kaito/gpu-provisioner/templates/rolebinding.yaml +++ b/charts/kaito/gpu-provisioner/templates/rolebinding.yaml @@ -2,7 +2,7 @@ apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: {{ include "gpu-provisioner.fullname" . }} - namespace: {{ .Values.namespace }} + namespace: {{ .Release.Namespace }} labels: {{- include "gpu-provisioner.labels" . 
| nindent 4 }} {{- with .Values.additionalAnnotations }} @@ -16,7 +16,7 @@ roleRef: subjects: - kind: ServiceAccount name: gpu-provisioner - namespace: {{ .Values.namespace }} + namespace: {{ .Release.Namespace }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding @@ -36,4 +36,4 @@ roleRef: subjects: - kind: ServiceAccount name: gpu-provisioner - namespace: {{ .Values.namespace }} \ No newline at end of file + namespace: {{ .Release.Namespace }} \ No newline at end of file diff --git a/charts/kaito/gpu-provisioner/templates/serviceaccount.yaml b/charts/kaito/gpu-provisioner/templates/serviceaccount.yaml index 02e2e57a2..0b6c5b11a 100644 --- a/charts/kaito/gpu-provisioner/templates/serviceaccount.yaml +++ b/charts/kaito/gpu-provisioner/templates/serviceaccount.yaml @@ -2,7 +2,7 @@ apiVersion: v1 kind: ServiceAccount metadata: name: gpu-provisioner - namespace: {{ .Values.namespace }} + namespace: {{ .Release.Namespace }} labels: {{- include "gpu-provisioner.labels" . | nindent 4 }} annotations: diff --git a/charts/kaito/gpu-provisioner/values.yaml b/charts/kaito/gpu-provisioner/values.yaml index 8c40f9e8d..74ec2853f 100644 --- a/charts/kaito/gpu-provisioner/values.yaml +++ b/charts/kaito/gpu-provisioner/values.yaml @@ -1,4 +1,3 @@ -namespace: gpu-provisioner # -- Overrides the chart's name. nameOverride: "" # -- Overrides the chart's computed fullname. diff --git a/charts/kaito/workspace/README.md b/charts/kaito/workspace/README.md index 497ecbfb5..331f8f77f 100644 --- a/charts/kaito/workspace/README.md +++ b/charts/kaito/workspace/README.md @@ -6,7 +6,9 @@ export REGISTRY= export IMG_NAME=workspace export IMG_TAG=0.2.1 -helm install workspace ./charts/kaito/workspace --set image.repository=${REGISTRY}/$(IMG_NAME) --set image.tag=$(IMG_TAG) +helm install workspace ./charts/kaito/workspace \ +--set image.repository=${REGISTRY}/$(IMG_NAME) --set image.tag=$(IMG_TAG) \ +--namespace kaito-workspace --create-namespace ``` ## Values diff --git a/charts/kaito/workspace/templates/deployment.yaml b/charts/kaito/workspace/templates/deployment.yaml index 72b35b412..53d7e4a79 100644 --- a/charts/kaito/workspace/templates/deployment.yaml +++ b/charts/kaito/workspace/templates/deployment.yaml @@ -1,15 +1,8 @@ -apiVersion: v1 -kind: Namespace -metadata: - name: {{ include "kaito.fullname" . }} - labels: - {{- include "kaito.labels" . | nindent 4 }} ---- apiVersion: apps/v1 kind: Deployment metadata: name: {{ include "kaito.fullname" . }} - namespace: {{ include "kaito.fullname" . }} + namespace: {{ .Release.Namespace }} labels: {{- include "kaito.labels" . | nindent 4 }} spec: diff --git a/charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml b/charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml index 680310f26..21974de27 100644 --- a/charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml +++ b/charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml @@ -2,7 +2,7 @@ apiVersion: apps/v1 kind: DaemonSet metadata: name: nvidia-device-plugin-daemonset - namespace: {{ include "kaito.fullname" . }} + namespace: {{ .Release.Namespace }} labels: {{- include "kaito.labels" . | nindent 4 }} spec: diff --git a/charts/kaito/workspace/templates/role.yaml b/charts/kaito/workspace/templates/role.yaml index a15a8d819..af8d4c380 100644 --- a/charts/kaito/workspace/templates/role.yaml +++ b/charts/kaito/workspace/templates/role.yaml @@ -3,7 +3,7 @@ apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: {{ include "kaito.fullname" . 
}}-role - namespace: {{ include "kaito.fullname" .}} + namespace: {{ .Release.Namespace }} labels: {{- include "kaito.labels" . | nindent 4 }} rules: diff --git a/charts/kaito/workspace/templates/role_binding.yaml b/charts/kaito/workspace/templates/role_binding.yaml index fdad758ae..708b6b173 100644 --- a/charts/kaito/workspace/templates/role_binding.yaml +++ b/charts/kaito/workspace/templates/role_binding.yaml @@ -2,7 +2,7 @@ apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: {{ include "kaito.fullname" . }}-rolebinding - namespace: {{ include "kaito.fullname" . }} + namespace: {{ .Release.Namespace }} labels: {{- include "kaito.labels" . | nindent 4 }} roleRef: diff --git a/charts/kaito/workspace/templates/secret-webhook-cert.yaml b/charts/kaito/workspace/templates/secret-webhook-cert.yaml index 9fab5a666..b0c3e331e 100644 --- a/charts/kaito/workspace/templates/secret-webhook-cert.yaml +++ b/charts/kaito/workspace/templates/secret-webhook-cert.yaml @@ -2,7 +2,7 @@ apiVersion: v1 kind: Secret metadata: name: workspace-webhook-cert - namespace: {{ include "kaito.fullname" .}} + namespace: {{ .Release.Namespace }} labels: {{- include "kaito.labels" . | nindent 4 }} data: diff --git a/charts/kaito/workspace/templates/service.yaml b/charts/kaito/workspace/templates/service.yaml index 8573cc221..b8cb5de91 100644 --- a/charts/kaito/workspace/templates/service.yaml +++ b/charts/kaito/workspace/templates/service.yaml @@ -2,7 +2,7 @@ apiVersion: v1 kind: Service metadata: name: {{ include "kaito.fullname" . }} - namespace: {{ include "kaito.fullname" .}} + namespace: {{ .Release.Namespace }} labels: {{- include "kaito.labels" . | nindent 4 }} spec: diff --git a/charts/kaito/workspace/templates/serviceaccount.yaml b/charts/kaito/workspace/templates/serviceaccount.yaml index aa6d44af6..dc3036f99 100644 --- a/charts/kaito/workspace/templates/serviceaccount.yaml +++ b/charts/kaito/workspace/templates/serviceaccount.yaml @@ -2,6 +2,6 @@ apiVersion: v1 kind: ServiceAccount metadata: name: {{ include "kaito.fullname" . }}-sa - namespace: {{ include "kaito.fullname" . }} + namespace: {{ .Release.Namespace }} labels: {{- include "kaito.labels" . | nindent 4 }} diff --git a/docs/installation.md b/docs/installation.md index 1baa519f3..e441d1cfe 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -41,7 +41,7 @@ az aks install-cli Install the Workspace controller. ```bash -helm install workspace ./charts/kaito/workspace +helm install workspace ./charts/kaito/workspace --namespace kaito-workspace --create-namespace ``` Note that if you have installed another node provisioning controller that supports Karpenter-core APIs, the following steps for installing `gpu-provisioner` can be skipped. 
@@ -105,7 +105,8 @@ settings: EOF # install gpu-provisioner using values override file -helm install gpu-provisioner ./charts/kaito/gpu-provisioner -f values.override.yaml +helm install gpu-provisioner ./charts/kaito/gpu-provisioner \ +--namespace gpu-provisioner --create-namespace -f values.override.yaml ``` #### Create the federated credential From 8d0976f8809a3c55d568a1f830a6bbcba1ed2a80 Mon Sep 17 00:00:00 2001 From: Heba <31887807+helayoty@users.noreply.github.com> Date: Thu, 28 Mar 2024 21:55:23 -0700 Subject: [PATCH 16/23] release: update manifest and helm charts for v0.2.2 (#324) --- Makefile | 2 +- README.md | 3 ++- charts/kaito/workspace/Chart.yaml | 4 ++-- charts/kaito/workspace/README.md | 2 +- charts/kaito/workspace/values.yaml | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index c1676a052..91c4a1928 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ # Image URL to use all building/pushing image targets REGISTRY ?= YOUR_REGISTRY IMG_NAME ?= workspace -VERSION ?= v0.2.1 +VERSION ?= v0.2.2 IMG_TAG ?= $(subst v,,$(VERSION)) ROOT_DIR := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) diff --git a/README.md b/README.md index 6678ecd54..58b369554 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ # Kubernetes AI Toolchain Operator (Kaito) +![GitHub Release](https://img.shields.io/github/v/release/Azure/kaito) [![Go Report Card](https://goreportcard.com/badge/github.com/Azure/kaito)](https://goreportcard.com/report/github.com/Azure/kaito) ![GitHub go.mod Go version](https://img.shields.io/github/go-mod/go-version/Azure/kaito) [![codecov](https://codecov.io/gh/Azure/kaito/graph/badge.svg?token=XAQLLPB2AR)](https://codecov.io/gh/Azure/kaito) | ![notification](docs/img/bell.svg) What is NEW! | |-------------------------------------------------| -| Latest Release: March 19th, 2024. Kaito v0.2.1. | +| Latest Release: March 28th, 2024. Kaito v0.2.2. | | First Release: Nov 15th, 2023. Kaito v0.1.0. | Kaito is an operator that automates the AI/ML inference model deployment in a Kubernetes cluster. diff --git a/charts/kaito/workspace/Chart.yaml b/charts/kaito/workspace/Chart.yaml index 35c17d199..d3e7034a2 100644 --- a/charts/kaito/workspace/Chart.yaml +++ b/charts/kaito/workspace/Chart.yaml @@ -6,13 +6,13 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.2.1 +version: 0.2.2 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to # follow Semantic Versioning. They should reflect the version the application is using. # It is recommended to use it with quotes. 
-appVersion: "0.2.1" +appVersion: "0.2.2" home: https://github.com/Azure/kaito sources: - https://github.com/Azure/kaito diff --git a/charts/kaito/workspace/README.md b/charts/kaito/workspace/README.md index 331f8f77f..f937019b8 100644 --- a/charts/kaito/workspace/README.md +++ b/charts/kaito/workspace/README.md @@ -5,7 +5,7 @@ ```bash export REGISTRY= export IMG_NAME=workspace -export IMG_TAG=0.2.1 +export IMG_TAG=0.2.2 helm install workspace ./charts/kaito/workspace \ --set image.repository=${REGISTRY}/$(IMG_NAME) --set image.tag=$(IMG_TAG) \ --namespace kaito-workspace --create-namespace diff --git a/charts/kaito/workspace/values.yaml b/charts/kaito/workspace/values.yaml index 90ae02156..3841f064c 100644 --- a/charts/kaito/workspace/values.yaml +++ b/charts/kaito/workspace/values.yaml @@ -5,7 +5,7 @@ replicaCount: 1 image: repository: mcr.microsoft.com/aks/kaito/workspace pullPolicy: IfNotPresent - tag: 0.2.1 + tag: 0.2.2 imagePullSecrets: [] podAnnotations: {} podSecurityContext: From c49b93a500bd149ddbea1bde8c46e323b8fb9ab6 Mon Sep 17 00:00:00 2001 From: Heba <31887807+helayoty@users.noreply.github.com> Date: Mon, 1 Apr 2024 09:28:24 -0700 Subject: [PATCH 17/23] ci: Remove PR trigger from release workflow (#330) **Reason for Change**: Remove the PR trigger from the release workflow, since we follow a branch-based release process. **Requirements** - [ ] added unit tests and e2e tests (if applicable). **Issue Fixed**: **Notes for Reviewers**: Signed-off-by: Heba Elayoty --- .github/workflows/publish-gh-image.yml | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/.github/workflows/publish-gh-image.yml b/.github/workflows/publish-gh-image.yml index 775901441..508d67946 100644 --- a/.github/workflows/publish-gh-image.yml +++ b/.github/workflows/publish-gh-image.yml @@ -5,9 +5,6 @@ on: release_version: description: 'tag to be created for this image (i.e. vxx.xx.xx)' required: true - pull_request: - types: [ closed ] - permissions: id-token: write @@ -21,31 +18,19 @@ jobs: check-tag: - if: >- - github.event_name == 'workflow_dispatch' || - ( - github.event_name == 'pull_request' && - github.event.pull_request.merged == true && - contains(github.event.pull_request.title, 'update manifest and helm charts') - ) runs-on: ubuntu-latest environment: preset-env outputs: tag: ${{ steps.get-tag.outputs.tag }} steps: - name: validate version - if: github.event_name == 'workflow_dispatch' run: | echo "${{ github.event.inputs.release_version }}" | grep -E 'v[0-9]+\.[0-9]+\.[0-9]+$' - id: get-tag name: Get tag run: | - if [[ ${{ github.event_name }} == 'workflow_dispatch' ]]; then echo "tag=$(echo ${{ github.event.inputs.release_version }})" >> $GITHUB_OUTPUT - else - echo "tag=$(echo ${{ github.event.pull_request.head.ref }} | tr -d release-)" >> $GITHUB_OUTPUT - fi - name: Checkout uses: actions/checkout@v4 From ee9101aa669e13b0ab53ec85c86f926adad41ddd Mon Sep 17 00:00:00 2001 From: Heba <31887807+helayoty@users.noreply.github.com> Date: Mon, 1 Apr 2024 12:34:13 -0700 Subject: [PATCH 18/23] fix: Add registry as a pipeline job output (#329) **Reason for Change**: - The job output for `registry` was missing. - Referencing an output from another job must go through the job name (`needs.<job>.outputs`), not the step name; a minimal sketch follows below. **Requirements** - [ ] added unit tests and e2e tests (if applicable).
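For reviewers, a minimal sketch of the rule this fix relies on (job and step names here are illustrative, not taken from the actual workflow): a step output becomes visible to other jobs only after the producing job re-exports it under `outputs:`, and the consuming job addresses it through `needs.<job-name>`, never through the step id.

```yaml
# Sketch only: hypothetical job/step names.
jobs:
  producer:
    runs-on: ubuntu-latest
    outputs:
      # Re-export the step output at the job level so other jobs can read it.
      registry: ${{ steps.compute.outputs.registry }}
    steps:
      - id: compute
        run: echo "registry=ghcr.io/example" >> "$GITHUB_OUTPUT"
  consumer:
    needs: [ producer ]
    runs-on: ubuntu-latest
    steps:
      # Cross-job references go through the job name, not the step id.
      - run: echo "${{ needs.producer.outputs.registry }}"
```

Without the job-level `outputs:` mapping, `needs.producer.outputs.registry` silently evaluates to an empty string, which is exactly the failure mode this patch fixes.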
**Issue Fixed**: **Notes for Reviewers**: --------- Signed-off-by: Heba <31887807+helayoty@users.noreply.github.com> --- .github/workflows/publish-gh-image.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/publish-gh-image.yml b/.github/workflows/publish-gh-image.yml index 508d67946..1f53f43db 100644 --- a/.github/workflows/publish-gh-image.yml +++ b/.github/workflows/publish-gh-image.yml @@ -62,6 +62,8 @@ jobs: runs-on: ubuntu-latest needs: [ check-tag ] environment: preset-env + outputs: + registry_repository: ${{ steps.get-registry.outputs.registry_repository }} steps: - id: get-registry run: | @@ -113,7 +115,7 @@ jobs: with: git_sha: ${{ github.sha }} isRelease: true - registry: ${{ needs.get-registry.outputs.registry_repository }} + registry: ${{ needs.build-scan-publish-gh-images.outputs.registry_repository }} tag: ${{ needs.check-tag.outputs.tag }} secrets: E2E_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} From 08dd1f4002b99a8d221a14f0a7188a929e60df58 Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Mon, 1 Apr 2024 15:14:15 -0700 Subject: [PATCH 19/23] feat: Initialize Fine-Tuning Interface and Core Methods - Part 3 (#308) Setup --- api/v1alpha1/workspace_validation_test.go | 29 ++++++- .../kaito_workspace_tuning_falcon_7b.yaml | 20 +++++ .../kaito_workspace_falcon_40b-instruct.yaml | 0 .../kaito_workspace_falcon_40b.yaml | 0 .../kaito_workspace_falcon_7b-instruct.yaml | 0 .../kaito_workspace_falcon_7b.yaml | 0 .../kaito_workspace_llama2_13b-chat.yaml | 0 .../kaito_workspace_llama2_13b.yaml | 0 .../kaito_workspace_llama2_70b-chat.yaml | 0 .../kaito_workspace_llama2_70b.yaml | 0 .../kaito_workspace_llama2_7b-chat.yaml | 0 .../kaito_workspace_llama2_7b.yaml | 0 .../kaito_workspace_mistral_7b-instruct.yaml | 0 .../kaito_workspace_mistral_7b.yaml | 0 .../kaito_workspace_phi-2.yaml | 0 pkg/controllers/workspace_controller.go | 65 +++++++++++++--- pkg/inference/preset-inferences.go | 10 +-- pkg/inference/preset-inferences_test.go | 2 +- pkg/model/interface.go | 24 +++--- pkg/tuning/preset-tuning-types.go | 21 ++++++ pkg/tuning/preset-tuning.go | 14 ++++ pkg/utils/testModel.go | 30 ++++++-- presets/models/falcon/README.md | 8 +- presets/models/falcon/model.go | 75 +++++++++++++++---- presets/models/llama2/README.md | 6 +- presets/models/llama2/model.go | 36 ++++++--- presets/models/llama2chat/README.md | 6 +- presets/models/llama2chat/model.go | 37 ++++++--- presets/models/mistral/README.md | 4 +- presets/models/mistral/model.go | 37 +++++++-- presets/models/phi/README.md | 2 +- presets/models/phi/model.go | 25 ++++++- 32 files changed, 357 insertions(+), 94 deletions(-) create mode 100644 examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml rename examples/{ => inference}/kaito_workspace_falcon_40b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_40b.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_7b-instruct.yaml (100%) rename examples/{ => inference}/kaito_workspace_falcon_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_13b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_13b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_70b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_70b.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_7b-chat.yaml (100%) rename examples/{ => inference}/kaito_workspace_llama2_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_mistral_7b-instruct.yaml (100%) 
rename examples/{ => inference}/kaito_workspace_mistral_7b.yaml (100%) rename examples/{ => inference}/kaito_workspace_phi-2.yaml (100%) create mode 100644 pkg/tuning/preset-tuning-types.go create mode 100644 pkg/tuning/preset-tuning.go diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 11631e67b..d196beabc 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -21,8 +21,15 @@ var perGPUMemoryRequirement string type testModel struct{} -func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ + GPUCountRequirement: gpuCountRequirement, + TotalGPUMemoryRequirement: totalGPUMemoryRequirement, + PerGPUMemoryRequirement: perGPUMemoryRequirement, + } +} +func (*testModel) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, PerGPUMemoryRequirement: perGPUMemoryRequirement, @@ -31,11 +38,22 @@ func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { func (*testModel) SupportDistributedInference() bool { return false } +func (*testModel) SupportTuning() bool { + return true +} type testModelPrivate struct{} -func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModelPrivate) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ + ImageAccessMode: "private", + GPUCountRequirement: gpuCountRequirement, + TotalGPUMemoryRequirement: totalGPUMemoryRequirement, + PerGPUMemoryRequirement: perGPUMemoryRequirement, + } +} +func (*testModelPrivate) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ ImageAccessMode: "private", GPUCountRequirement: gpuCountRequirement, TotalGPUMemoryRequirement: totalGPUMemoryRequirement, @@ -45,6 +63,9 @@ func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam { func (*testModelPrivate) SupportDistributedInference() bool { return false } +func (*testModelPrivate) SupportTuning() bool { + return true +} func RegisterValidationTestModels() { var test testModel diff --git a/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml b/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml new file mode 100644 index 000000000..6d6ed7831 --- /dev/null +++ b/examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml @@ -0,0 +1,20 @@ +apiVersion: kaito.sh/v1alpha1 +kind: Workspace +metadata: + name: workspace-tuning-falcon-7b +spec: + resource: + instanceType: "Standard_NC12s_v3" + labelSelector: + matchLabels: + app: tuning-falcon-7b + tuning: + preset: + name: falcon-7b + method: lora + config: tuning-config-map # ConfigMap containing tuning arguments + input: + name: tuning-data + hostPath: /path/to/your/input/data # dataset on node + output: + hostPath: /path/to/store/output # Tuning Output diff --git a/examples/kaito_workspace_falcon_40b-instruct.yaml b/examples/inference/kaito_workspace_falcon_40b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_falcon_40b-instruct.yaml rename to examples/inference/kaito_workspace_falcon_40b-instruct.yaml diff --git a/examples/kaito_workspace_falcon_40b.yaml b/examples/inference/kaito_workspace_falcon_40b.yaml similarity index 100% rename from examples/kaito_workspace_falcon_40b.yaml rename to 
examples/inference/kaito_workspace_falcon_40b.yaml diff --git a/examples/kaito_workspace_falcon_7b-instruct.yaml b/examples/inference/kaito_workspace_falcon_7b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_falcon_7b-instruct.yaml rename to examples/inference/kaito_workspace_falcon_7b-instruct.yaml diff --git a/examples/kaito_workspace_falcon_7b.yaml b/examples/inference/kaito_workspace_falcon_7b.yaml similarity index 100% rename from examples/kaito_workspace_falcon_7b.yaml rename to examples/inference/kaito_workspace_falcon_7b.yaml diff --git a/examples/kaito_workspace_llama2_13b-chat.yaml b/examples/inference/kaito_workspace_llama2_13b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_13b-chat.yaml rename to examples/inference/kaito_workspace_llama2_13b-chat.yaml diff --git a/examples/kaito_workspace_llama2_13b.yaml b/examples/inference/kaito_workspace_llama2_13b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_13b.yaml rename to examples/inference/kaito_workspace_llama2_13b.yaml diff --git a/examples/kaito_workspace_llama2_70b-chat.yaml b/examples/inference/kaito_workspace_llama2_70b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_70b-chat.yaml rename to examples/inference/kaito_workspace_llama2_70b-chat.yaml diff --git a/examples/kaito_workspace_llama2_70b.yaml b/examples/inference/kaito_workspace_llama2_70b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_70b.yaml rename to examples/inference/kaito_workspace_llama2_70b.yaml diff --git a/examples/kaito_workspace_llama2_7b-chat.yaml b/examples/inference/kaito_workspace_llama2_7b-chat.yaml similarity index 100% rename from examples/kaito_workspace_llama2_7b-chat.yaml rename to examples/inference/kaito_workspace_llama2_7b-chat.yaml diff --git a/examples/kaito_workspace_llama2_7b.yaml b/examples/inference/kaito_workspace_llama2_7b.yaml similarity index 100% rename from examples/kaito_workspace_llama2_7b.yaml rename to examples/inference/kaito_workspace_llama2_7b.yaml diff --git a/examples/kaito_workspace_mistral_7b-instruct.yaml b/examples/inference/kaito_workspace_mistral_7b-instruct.yaml similarity index 100% rename from examples/kaito_workspace_mistral_7b-instruct.yaml rename to examples/inference/kaito_workspace_mistral_7b-instruct.yaml diff --git a/examples/kaito_workspace_mistral_7b.yaml b/examples/inference/kaito_workspace_mistral_7b.yaml similarity index 100% rename from examples/kaito_workspace_mistral_7b.yaml rename to examples/inference/kaito_workspace_mistral_7b.yaml diff --git a/examples/kaito_workspace_phi-2.yaml b/examples/inference/kaito_workspace_phi-2.yaml similarity index 100% rename from examples/kaito_workspace_phi-2.yaml rename to examples/inference/kaito_workspace_phi-2.yaml diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index a2e4fc18d..042249f15 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -9,8 +9,8 @@ import ( "strings" "time" - appsv1 "k8s.io/api/apps/v1" - "k8s.io/utils/clock" + "github.com/azure/kaito/pkg/tuning" + batchv1 "k8s.io/api/batch/v1" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" @@ -21,6 +21,7 @@ import ( "github.com/azure/kaito/pkg/utils/plugin" "github.com/go-logr/logr" "github.com/samber/lo" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" apierrors 
"k8s.io/apimachinery/pkg/api/errors" @@ -28,6 +29,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" + "k8s.io/utils/clock" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" @@ -109,16 +111,22 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err = c.applyInference(ctx, wObj); err != nil { - if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, - "workspaceFailed", err.Error()); updateErr != nil { - klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) - return reconcile.Result{}, updateErr + if wObj.Tuning != nil { + if err = c.applyTuning(ctx, wObj); err != nil { + return reconcile.Result{}, err + } + } + if wObj.Inference != nil { + if err = c.applyInference(ctx, wObj); err != nil { + if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, + "workspaceFailed", err.Error()); updateErr != nil { + klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) + return reconcile.Result{}, updateErr + } + return reconcile.Result{}, err } - return reconcile.Result{}, err } - // TODO apply TrainingSpec if err = c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionTrue, "workspaceReady", "workspace is ready"); err != nil { klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -423,6 +431,41 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al return nil } +func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { + var err error + func() { + if wObj.Tuning.Preset != nil { + presetName := string(wObj.Tuning.Preset.Name) + model := plugin.KaitoModelRegister.MustGet(presetName) + + tuningParam := model.GetTuningParameters() + existingObj := &batchv1.Job{} + if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { + klog.InfoS("A tuning workload already exists for workspace", "workspace", klog.KObj(wObj)) + if err = resources.CheckResourceStatus(existingObj, c.Client, tuningParam.ReadinessTimeout); err != nil { + return + } + } else if apierrors.IsNotFound(err) { + var workloadObj client.Object + // Need to create a new workload + workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, tuningParam, c.Client) + if err != nil { + return + } + if err = resources.CheckResourceStatus(workloadObj, c.Client, tuningParam.ReadinessTimeout); err != nil { + return + } + } + } + }() + + if err != nil { + return err + } + + return nil +} + // applyInference applies inference spec. 
func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { var err error @@ -455,7 +498,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil { klog.InfoS("An inference workload already exists for workspace", "workspace", klog.KObj(wObj)) - if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.DeploymentTimeout); err != nil { + if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.ReadinessTimeout); err != nil { return } } else if apierrors.IsNotFound(err) { @@ -465,7 +508,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a if err != nil { return } - if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.DeploymentTimeout); err != nil { + if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.ReadinessTimeout); err != nil { return } } diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 9b02012b7..4c4792b54 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -67,7 +67,7 @@ var ( } ) -func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) error { +func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) error { existingService := &corev1.Service{} err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService) if err != nil { @@ -92,7 +92,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl return nil } -func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) (string, []corev1.LocalObjectReference) { +func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) { imageName := string(workspaceObj.Inference.Preset.Name) imageTag := inferenceObj.Tag imagePullSecretRefs := []corev1.LocalObjectReference{} @@ -110,7 +110,7 @@ func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, in } func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, - inferenceObj *model.PresetInferenceParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) { + inferenceObj *model.PresetParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) { if inferenceObj.TorchRunParams != nil && supportDistributedInference { if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceObj); err != nil { klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj) @@ -141,7 +141,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work // torchrun baseCommand // and sets the GPU resources required for inference. // Returns the command and resource configuration. 
-func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetInferenceParam) ([]string, corev1.ResourceRequirements) { +func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) { torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams) torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams) modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams) @@ -159,7 +159,7 @@ return commands, resourceRequirements } -func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) ([]corev1.Volume, []corev1.VolumeMount) { +func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) { volume := []corev1.Volume{} volumeMount := []corev1.VolumeMount{} diff --git a/pkg/inference/preset-inferences_test.go b/pkg/inference/preset-inferences_test.go index cd8df067c..31bf0551e 100644 --- a/pkg/inference/preset-inferences_test.go +++ b/pkg/inference/preset-inferences_test.go @@ -62,7 +62,7 @@ func TestCreatePresetInference(t *testing.T) { useHeadlessSvc := false - var inferenceObj *model.PresetInferenceParam + var inferenceObj *model.PresetParam model := plugin.KaitoModelRegister.MustGet(tc.modelName) inferenceObj = model.GetInferenceParameters() diff --git a/pkg/model/interface.go b/pkg/model/interface.go index 217c1f889..3a054cf25 100644 --- a/pkg/model/interface.go +++ b/pkg/model/interface.go @@ -7,27 +7,29 @@ import ( ) type Model interface { - GetInferenceParameters() *PresetInferenceParam + GetInferenceParameters() *PresetParam + GetTuningParameters() *PresetParam SupportDistributedInference() bool //If true, the model workload will be a StatefulSet, using the torch elastic runtime framework. + SupportTuning() bool } -// PresetInferenceParam defines the preset inference parameters for a model. -type PresetInferenceParam struct { +// PresetParam defines the preset parameters for a model, shared by inference and tuning. +type PresetParam struct { ModelFamilyName string // The name of the model family. ImageAccessMode string // Defines whether the Image is Public or Private. DiskStorageRequirement string // Disk storage requirements for the model. - GPUCountRequirement string // Number of GPUs required for the inference. - TotalGPUMemoryRequirement string // Total GPU memory required for the inference. + GPUCountRequirement string // Number of GPUs required for the Preset. + TotalGPUMemoryRequirement string // Total GPU memory required for the Preset. PerGPUMemoryRequirement string // GPU memory required per GPU. TorchRunParams map[string]string // Parameters for configuring the torchrun command. - TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed inference using torchrun (elastic). - ModelRunParams map[string]string // Parameters for running the model inference. - // DeploymentTimeout defines the maximum duration for pulling the Preset image. + TorchRunRdzvParams map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic). + // BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line. + BaseCommand string + ModelRunParams map[string]string // Parameters for running the model training/inference. + // ReadinessTimeout defines the maximum duration for creating the workload.
// This timeout accommodates the size of the image, ensuring pull completion // even under slower network conditions or unforeseen delays. - DeploymentTimeout time.Duration - // BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line. - BaseCommand string + ReadinessTimeout time.Duration // WorldSize defines the number of processes required for distributed inference. WorldSize int Tag string // The model image tag diff --git a/pkg/tuning/preset-tuning-types.go b/pkg/tuning/preset-tuning-types.go new file mode 100644 index 000000000..51f36511d --- /dev/null +++ b/pkg/tuning/preset-tuning-types.go @@ -0,0 +1,21 @@ +package tuning + +import corev1 "k8s.io/api/core/v1" + +const ( + DefaultNumProcesses = "1" + DefaultNumMachines = "1" + DefaultMachineRank = "0" + DefaultGPUIds = "all" +) + +var ( + DefaultAccelerateParams = map[string]string{ + "num_processes": DefaultNumProcesses, + "num_machines": DefaultNumMachines, + "machine_rank": DefaultMachineRank, + "gpu_ids": DefaultGPUIds, + } + + DefaultImagePullSecrets = []corev1.LocalObjectReference{} +) diff --git a/pkg/tuning/preset-tuning.go b/pkg/tuning/preset-tuning.go new file mode 100644 index 000000000..d9dfbd477 --- /dev/null +++ b/pkg/tuning/preset-tuning.go @@ -0,0 +1,14 @@ +package tuning + +import ( + "context" + kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" + "github.com/azure/kaito/pkg/model" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func CreatePresetTuning(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, + tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) { + // TODO + return nil, nil +} diff --git a/pkg/utils/testModel.go b/pkg/utils/testModel.go index 99e3d8aca..5acd05ac5 100644 --- a/pkg/utils/testModel.go +++ b/pkg/utils/testModel.go @@ -12,27 +12,45 @@ import ( type testModel struct{} -func (*testModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: "1", - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, + } +} +func (*testModel) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ + GPUCountRequirement: "1", + ReadinessTimeout: time.Duration(30) * time.Minute, } } func (*testModel) SupportDistributedInference() bool { return false } +func (*testModel) SupportTuning() bool { + return true +} type testDistributedModel struct{} -func (*testDistributedModel) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*testDistributedModel) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ GPUCountRequirement: "1", - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, + } +} +func (*testDistributedModel) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ + GPUCountRequirement: "1", + ReadinessTimeout: time.Duration(30) * time.Minute, } } func (*testDistributedModel) SupportDistributedInference() bool { return true } +func (*testDistributedModel) SupportTuning() bool { + return true +} func RegisterTestModel() { var test testModel diff --git a/presets/models/falcon/README.md b/presets/models/falcon/README.md index 81a1ced6f..e8cd895e6 100644 --- a/presets/models/falcon/README.md +++ b/presets/models/falcon/README.md @@ -1,10 +1,10 @@ ## Supported Models |Model name| 
Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false| -|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/kaito_workspace_falcon_7b.yaml)|Deployment| false| -|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false| -|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/kaito_workspace_falcon_40b.yaml)|Deployment| false| +|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/inference/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false| +|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/inference/kaito_workspace_falcon_7b.yaml)|Deployment| false| +|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/inference/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false| +|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/inference/kaito_workspace_falcon_40b.yaml)|Deployment| false| ## Image Source - **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR). diff --git a/presets/models/falcon/model.go b/presets/models/falcon/model.go index 863a2fb52..bc7f882af 100644 --- a/presets/models/falcon/model.go +++ b/presets/models/falcon/model.go @@ -54,8 +54,8 @@ var falconA falcon7b type falcon7b struct{} -func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -64,22 +64,40 @@ func (*falcon7b) GetInferenceParameters() *model.PresetInferenceParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. 
TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7B"], } - } +func (*falcon7b) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Falcon", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "2", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", + //TorchRunParams: tuning.DefaultAccelerateParams, // TODO + //ModelRunParams: falconRunTuningParams, // TODO + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon7B"], + } +} + func (*falcon7b) SupportDistributedInference() bool { return false } +func (*falcon7b) SupportTuning() bool { + return true +} var falconB falcon7bInst type falcon7bInst struct{} -func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon7bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -88,22 +106,28 @@ func (*falcon7bInst) GetInferenceParameters() *model.PresetInferenceParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon7BInstruct"], } } +func (*falcon7bInst) GetTuningParameters() *model.PresetParam { + return nil // Instruct models have already been fine-tuned; further fine-tuning is not recommended +} func (*falcon7bInst) SupportDistributedInference() bool { return false } +func (*falcon7bInst) SupportTuning() bool { + return false +} var falconC falcon40b type falcon40b struct{} -func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon40b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "400", @@ -112,22 +136,40 @@ func (*falcon40b) GetInferenceParameters() *model.PresetInferenceParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement.
TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon40B"], } } +func (*falcon40b) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Falcon", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "2", + TotalGPUMemoryRequirement: "90Gi", + PerGPUMemoryRequirement: "16Gi", + //TorchRunParams: tuning.DefaultAccelerateParams, // TODO + //ModelRunParams: falconRunTuningParams, // TODO + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetFalcon, + Tag: PresetFalconTagMap["Falcon40B"], + } +} func (*falcon40b) SupportDistributedInference() bool { return false } +func (*falcon40b) SupportTuning() bool { + return true +} var falconD falcon40bInst type falcon40bInst struct{} -func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*falcon40bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Falcon", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "400", @@ -136,12 +178,17 @@ func (*falcon40bInst) GetInferenceParameters() *model.PresetInferenceParam { PerGPUMemoryRequirement: "0Gi", // We run Falcon using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: falconRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetFalcon, Tag: PresetFalconTagMap["Falcon40BInstruct"], } } - +func (*falcon40bInst) GetTuningParameters() *model.PresetParam { + return nil // Instruct models have already been fine-tuned; further fine-tuning is not recommended +} func (*falcon40bInst) SupportDistributedInference() bool { return false } +func (*falcon40bInst) SupportTuning() bool { + return false +} diff --git a/presets/models/llama2/README.md b/presets/models/llama2/README.md index e6a40563a..ba2646a2b 100644 --- a/presets/models/llama2/README.md +++ b/presets/models/llama2/README.md @@ -1,9 +1,9 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|llama2-7b |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_7b.yaml)|Deployment| false| -|llama2-13b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_13b.yaml)|StatefulSet| true| -|llama2-70b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_70b.yaml)|StatefulSet| true| +|llama2-7b |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_7b.yaml)|Deployment| false| +|llama2-13b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_13b.yaml)|StatefulSet| true| +|llama2-70b|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_70b.yaml)|StatefulSet| true| ## Image
Source - **Private**: User needs to manage the lifecycle of the inference service images that contain model weights (e.g., managing image tags). The images are available in user's private container registry. diff --git a/presets/models/llama2/model.go b/presets/models/llama2/model.go index 30c97b7fd..6a62a8987 100644 --- a/presets/models/llama2/model.go +++ b/presets/models/llama2/model.go @@ -38,8 +38,8 @@ var llama2A llama2Text7b type llama2Text7b struct{} -func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "34Gi", @@ -49,23 +49,29 @@ func (*llama2Text7b) GetInferenceParameters() *model.PresetInferenceParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(10) * time.Minute, + ReadinessTimeout: time.Duration(10) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Text7b) GetTuningParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text7b) SupportDistributedInference() bool { return false } +func (*llama2Text7b) SupportTuning() bool { + return false +} var llama2B llama2Text13b type llama2Text13b struct{} -func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text13b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "46Gi", @@ -75,22 +81,28 @@ func (*llama2Text13b) GetInferenceParameters() *model.PresetInferenceParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(20) * time.Minute, + ReadinessTimeout: time.Duration(20) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 2, // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Text13b) GetTuningParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text13b) SupportDistributedInference() bool { return true } +func (*llama2Text13b) SupportTuning() bool { + return false +} var llama2C llama2Text70b type llama2Text70b struct{} -func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Text70b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "158Gi", @@ -100,12 +112,18 @@ func (*llama2Text70b) GetInferenceParameters() *model.PresetInferenceParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 8, // Tag: llama has private image access mode. 
The image tag is determined by the user. } } +func (*llama2Text70b) GetTuningParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Text70b) SupportDistributedInference() bool { return true } +func (*llama2Text70b) SupportTuning() bool { + return false +} diff --git a/presets/models/llama2chat/README.md b/presets/models/llama2chat/README.md index 53e241fab..0cf9ec3be 100644 --- a/presets/models/llama2chat/README.md +++ b/presets/models/llama2chat/README.md @@ -1,9 +1,9 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|llama2-7b-chat |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_7b-chat.yaml)|Deployment| false| -|llama2-13b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_13b-chat.yaml)|StatefulSet| true| -|llama2-70b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/kaito_workspace_llama2_70b-chat.yaml)|StatefulSet| true| +|llama2-7b-chat |[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_7b-chat.yaml)|Deployment| false| +|llama2-13b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_13b-chat.yaml)|StatefulSet| true| +|llama2-70b-chat|[meta](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)|[link](../../../examples/inference/kaito_workspace_llama2_70b-chat.yaml)|StatefulSet| true| ## Image Source - **Private**: User needs to manage the lifecycle of the inference service images that contain model weights (e.g., managing image tags). The images are available in user's private container registry. diff --git a/presets/models/llama2chat/model.go b/presets/models/llama2chat/model.go index cc0d8d4c6..89225bef5 100644 --- a/presets/models/llama2chat/model.go +++ b/presets/models/llama2chat/model.go @@ -38,8 +38,8 @@ var llama2chatA llama2Chat7b type llama2Chat7b struct{} -func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "34Gi", @@ -49,23 +49,28 @@ func (*llama2Chat7b) GetInferenceParameters() *model.PresetInferenceParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(10) * time.Minute, + ReadinessTimeout: time.Duration(10) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 1, // Tag: llama has private image access mode. The image tag is determined by the user. 
} - +} +func (*llama2Chat7b) GetTuningParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning } func (*llama2Chat7b) SupportDistributedInference() bool { return false } +func (*llama2Chat7b) SupportTuning() bool { + return false +} var llama2chatB llama2Chat13b type llama2Chat13b struct{} -func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat13b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "46Gi", @@ -75,22 +80,28 @@ func (*llama2Chat13b) GetInferenceParameters() *model.PresetInferenceParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(20) * time.Minute, + ReadinessTimeout: time.Duration(20) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 2, // Tag: llama has private image access mode. The image tag is determined by the user. } } +func (*llama2Chat13b) GetTuningParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Chat13b) SupportDistributedInference() bool { return true } +func (*llama2Chat13b) SupportTuning() bool { + return false +} var llama2chatC llama2Chat70b type llama2Chat70b struct{} -func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*llama2Chat70b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "LLaMa2", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePrivate), DiskStorageRequirement: "158Gi", @@ -100,12 +111,18 @@ func (*llama2Chat70b) GetInferenceParameters() *model.PresetInferenceParam { TorchRunParams: inference.DefaultTorchRunParams, TorchRunRdzvParams: inference.DefaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetLlama, WorldSize: 8, // Tag: llama has private image access mode. The image tag is determined by the user. 
} } +func (*llama2Chat70b) GetTuningParameters() *model.PresetParam { + return nil // Currently doesn't support fine-tuning +} func (*llama2Chat70b) SupportDistributedInference() bool { return true } +func (*llama2Chat70b) SupportTuning() bool { + return false +} diff --git a/presets/models/mistral/README.md b/presets/models/mistral/README.md index 4d0c56ba6..2d037f7a4 100644 --- a/presets/models/mistral/README.md +++ b/presets/models/mistral/README.md @@ -1,8 +1,8 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|mistral-7b-instruct |[mistralai](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|[link](../../../examples/kaito_workspace_mistral_7b-instruct.yaml)|Deployment| false| -|mistral-7b |[mistralai](https://huggingface.co/mistralai/Mistral-7B-v0.1) |[link](../../../examples/kaito_workspace_mistral_7b.yaml)|Deployment| false| +|mistral-7b-instruct |[mistralai](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)|[link](../../../examples/inference/kaito_workspace_mistral_7b-instruct.yaml)|Deployment| false| +|mistral-7b |[mistralai](https://huggingface.co/mistralai/Mistral-7B-v0.1) |[link](../../../examples/inference/kaito_workspace_mistral_7b.yaml)|Deployment| false| ## Image Source diff --git a/presets/models/mistral/model.go b/presets/models/mistral/model.go index 910e2da83..78115e805 100644 --- a/presets/models/mistral/model.go +++ b/presets/models/mistral/model.go @@ -42,8 +42,8 @@ var mistralA mistral7b type mistral7b struct{} -func (*mistral7b) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*mistral7b) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "100Gi", @@ -52,22 +52,41 @@ func (*mistral7b) GetInferenceParameters() *model.PresetInferenceParam { PerGPUMemoryRequirement: "0Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: mistralRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetMistral, Tag: PresetMistralTagMap["Mistral7B"], } } +func (*mistral7b) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Mistral", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "100Gi", + GPUCountRequirement: "1", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", // We run Mistral using native vertical model parallel, no per GPU memory requirement. 
+ //TorchRunParams: tuning.DefaultAccelerateParams, + //ModelRunParams: mistralRunParams, + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetMistral, + Tag: PresetMistralTagMap["Mistral7B"], + } +} + func (*mistral7b) SupportDistributedInference() bool { return false } +func (*mistral7b) SupportTuning() bool { + return true +} var mistralB mistral7bInst type mistral7bInst struct{} -func (*mistral7bInst) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*mistral7bInst) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Mistral", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "100Gi", @@ -76,12 +95,18 @@ func (*mistral7bInst) GetInferenceParameters() *model.PresetInferenceParam { PerGPUMemoryRequirement: "0Gi", // We run mistral using native vertical model parallel, no per GPU memory requirement. TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: mistralRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetMistral, Tag: PresetMistralTagMap["Mistral7BInstruct"], } } +func (*mistral7bInst) GetTuningParameters() *model.PresetParam { + return nil // Instruct models have already been fine-tuned; further fine-tuning is not recommended +} func (*mistral7bInst) SupportDistributedInference() bool { return false } +func (*mistral7bInst) SupportTuning() bool { + return false +} diff --git a/presets/models/phi/README.md b/presets/models/phi/README.md index 7caeadb84..1e77252a5 100644 --- a/presets/models/phi/README.md +++ b/presets/models/phi/README.md @@ -1,7 +1,7 @@ ## Supported Models |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference| |----|:----:|:----:| :----: |:----: | -|phi-2 |[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/kaito_workspace_phi-2.yaml)|Deployment| false| +|phi-2 |[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/inference/kaito_workspace_phi-2.yaml)|Deployment| false| ## Image Source diff --git a/presets/models/phi/model.go b/presets/models/phi/model.go index 32df386cc..3f6e8b120 100644 --- a/presets/models/phi/model.go +++ b/presets/models/phi/model.go @@ -36,8 +36,8 @@ var phiA phi2 type phi2 struct{} -func (*phi2) GetInferenceParameters() *model.PresetInferenceParam { - return &model.PresetInferenceParam{ +func (*phi2) GetInferenceParameters() *model.PresetParam { + return &model.PresetParam{ ModelFamilyName: "Phi", ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), DiskStorageRequirement: "50Gi", @@ -46,12 +46,29 @@ func (*phi2) GetInferenceParameters() *model.PresetInferenceParam { PerGPUMemoryRequirement: "0Gi", // We run Phi using native vertical model parallel, no per GPU memory requirement.
TorchRunParams: inference.DefaultAccelerateParams, ModelRunParams: phiRunParams, - DeploymentTimeout: time.Duration(30) * time.Minute, + ReadinessTimeout: time.Duration(30) * time.Minute, BaseCommand: baseCommandPresetPhi, Tag: PresetPhiTagMap["Phi2"], } - +} +func (*phi2) GetTuningParameters() *model.PresetParam { + return &model.PresetParam{ + ModelFamilyName: "Phi", + ImageAccessMode: string(kaitov1alpha1.ModelImageAccessModePublic), + DiskStorageRequirement: "50Gi", + GPUCountRequirement: "1", + TotalGPUMemoryRequirement: "16Gi", + PerGPUMemoryRequirement: "16Gi", // We run Phi using native vertical model parallel, no per GPU memory requirement. + // TorchRunParams: inference.DefaultAccelerateParams, + // ModelRunParams: phiRunParams, + ReadinessTimeout: time.Duration(30) * time.Minute, + BaseCommand: baseCommandPresetPhi, + Tag: PresetPhiTagMap["Phi2"], + } } func (*phi2) SupportDistributedInference() bool { return false } +func (*phi2) SupportTuning() bool { + return true +} From 05fb90c1be906d0639b5a6a6511863992bea4ca8 Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Mon, 1 Apr 2024 16:35:36 -0700 Subject: [PATCH 20/23] feat: Add sample front end helm chart (#320) **Reason for Change**: This helm chart provides a straightforward example for deploying a sample UI on top of a kaito workspace. --- charts/DemoUI/inference/.helmignore | 23 ++++++ charts/DemoUI/inference/Chart.yaml | 15 ++++ charts/DemoUI/inference/README.md | 44 +++++++++++ charts/DemoUI/inference/templates/NOTES.txt | 22 ++++++ .../DemoUI/inference/templates/_helpers.tpl | 62 +++++++++++++++ .../inference/templates/deployment.yaml | 77 +++++++++++++++++++ .../DemoUI/inference/templates/service.yaml | 15 ++++ .../inference/templates/serviceaccount.yaml | 13 ++++ charts/DemoUI/inference/values.yaml | 47 +++++++++++ demo/README.md | 7 ++ demo/inferenceUI/README.md | 61 +++++++++++++++ demo/inferenceUI/chainlit.py | 53 +++++++++++++ 12 files changed, 439 insertions(+) create mode 100644 charts/DemoUI/inference/.helmignore create mode 100644 charts/DemoUI/inference/Chart.yaml create mode 100644 charts/DemoUI/inference/README.md create mode 100644 charts/DemoUI/inference/templates/NOTES.txt create mode 100644 charts/DemoUI/inference/templates/_helpers.tpl create mode 100644 charts/DemoUI/inference/templates/deployment.yaml create mode 100644 charts/DemoUI/inference/templates/service.yaml create mode 100644 charts/DemoUI/inference/templates/serviceaccount.yaml create mode 100644 charts/DemoUI/inference/values.yaml create mode 100644 demo/README.md create mode 100644 demo/inferenceUI/README.md create mode 100644 demo/inferenceUI/chainlit.py diff --git a/charts/DemoUI/inference/.helmignore b/charts/DemoUI/inference/.helmignore new file mode 100644 index 000000000..0e8a0eb36 --- /dev/null +++ b/charts/DemoUI/inference/.helmignore @@ -0,0 +1,23 @@ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. 
+.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ diff --git a/charts/DemoUI/inference/Chart.yaml b/charts/DemoUI/inference/Chart.yaml new file mode 100644 index 000000000..8b6bb96bf --- /dev/null +++ b/charts/DemoUI/inference/Chart.yaml @@ -0,0 +1,15 @@ +apiVersion: v2 +name: inference +description: A Helm chart for chainlit +type: application +version: 0.1.0 +appVersion: "0.1.0" +sources: + - https://github.com/Azure/kaito +maintainers: + - name: ishaansehgal99 + email: ishaanforthewin@gmail.com + - name: Fei-Guo + email: vrgf2003@gmail.com + - name: helayoty + email: hebaelayoty@gmail.com diff --git a/charts/DemoUI/inference/README.md b/charts/DemoUI/inference/README.md new file mode 100644 index 000000000..ba314e6a9 --- /dev/null +++ b/charts/DemoUI/inference/README.md @@ -0,0 +1,44 @@ +# KAITO Demo Frontend Helm Chart +## Install +Before deploying the Demo front-end, you must set the `workspaceServiceURL` environment variable to point to your Workspace Service inference endpoint. + +To set this value, modify the `values.override.yaml` file or use the `--set` flag during Helm install/upgrade: + +```bash +helm install inference-frontend ./charts/DemoUI/inference --set env.workspaceServiceURL="http://<cluster-ip>:80/chat" +``` + +Or through a custom `values` file (`values.override.yaml`): +```bash +helm install inference-frontend ./charts/DemoUI/inference -f values.override.yaml +``` + +## Values + +| Key | Type | Default | Description | |-------------------------------|--------|-------------------------|-------------------------------------------------------| | `replicaCount` | int | `1` | Number of replicas | | `image.repository` | string | `"python"` | Image repository | | `image.pullPolicy` | string | `"IfNotPresent"` | Image pull policy | | `image.tag` | string | `"3.8"` | Image tag | | `imagePullSecrets` | list | `[]` | Specify image pull secrets | | `podAnnotations` | object | `{}` | Annotations to add to the pod | | `serviceAccount.create` | bool | `false` | Specifies whether a service account should be created | | `serviceAccount.name` | string | `""` | The name of the service account to use | | `service.type` | string | `"ClusterIP"` | Service type | | `service.port` | int | `8000` | Service port | | `env.workspaceServiceURL` | string | `""` | Workspace Service URL for the inference endpoint | | `resources.limits.cpu` | string | `"500m"` | CPU limit | | `resources.limits.memory` | string | `"256Mi"` | Memory limit | | `resources.requests.cpu` | string | `"10m"` | CPU request | | `resources.requests.memory` | string | `"128Mi"` | Memory request | | `livenessProbe.exec.command` | list | `["pgrep", "chainlit"]` | Command for liveness probe | | `readinessProbe.exec.command` | list | `["pgrep", "chainlit"]` | Command for readiness probe | | `nodeSelector` | object | `{}` | Node labels for pod assignment | | `tolerations` | list | `[]` | Tolerations for pod assignment | | `affinity` | object | `{}` | Affinity for pod assignment | | `ingress.enabled` | bool | `false` | Enable or disable ingress | + +### Liveness and Readiness Probes + +The `livenessProbe` and `readinessProbe` are configured to check if the Chainlit application is running by using `pgrep` to find the process. Adjust these probes as necessary for your deployment.
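The chart itself only wires `env.workspaceServiceURL` into the container; the request contract belongs to the workspace inference service behind that URL. As a rough illustration of what the front-end ultimately sends, here is a minimal sketch in Python, assuming the payload shape used by `demo/inferenceUI/chainlit.py` later in this patch; the service name in the URL is a placeholder taken from the example in `values.yaml`:

```python
import requests

# Placeholder endpoint; substitute your own workspace service URL.
URL = "http://workspace-falcon-7b.default.svc.cluster.local:80/chat"

# Minimal request body; generate_kwargs mirrors a subset of the demo UI's defaults.
data = {
    "prompt": "Tell me a joke",
    "return_full_text": False,
    "clean_up_tokenization_spaces": True,
    "generate_kwargs": {"max_length": 200, "do_sample": True, "top_k": 10},
}

response = requests.post(URL, json=data)
response.raise_for_status()
# The inference service returns generated text under the "Result" key.
print(response.json().get("Result", "No result found"))
```

If the request succeeds, the printed text is what the Chainlit UI would render as the assistant's reply.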
diff --git a/charts/DemoUI/inference/templates/NOTES.txt b/charts/DemoUI/inference/templates/NOTES.txt new file mode 100644 index 000000000..113532096 --- /dev/null +++ b/charts/DemoUI/inference/templates/NOTES.txt @@ -0,0 +1,22 @@ +Get the application URL by running these commands: +{{- if .Values.ingress.enabled }} +{{- range $host := .Values.ingress.hosts }} + {{- range .paths }} + http{{ if $.Values.ingress.tls }}s{{ end }}://{{ $host.host }}{{ .path }} + {{- end }} +{{- end }} +{{- else if contains "NodePort" .Values.service.type }} + export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "inference.fullname" . }}) + export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}") + echo http://$NODE_IP:$NODE_PORT +{{- else if contains "LoadBalancer" .Values.service.type }} + NOTE: It may take a few minutes for the LoadBalancer IP to be available. + You can watch the status of by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "inference.fullname" . }}' + export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "inference.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}") + echo http://$SERVICE_IP:{{ .Values.service.port }} +{{- else if contains "ClusterIP" .Values.service.type }} + export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "inference.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}") + export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}") + echo "Visit http://127.0.0.1:8080 to use your application" + kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8080:$CONTAINER_PORT +{{- end }} diff --git a/charts/DemoUI/inference/templates/_helpers.tpl b/charts/DemoUI/inference/templates/_helpers.tpl new file mode 100644 index 000000000..72fa79999 --- /dev/null +++ b/charts/DemoUI/inference/templates/_helpers.tpl @@ -0,0 +1,62 @@ +{{/* +Expand the name of the chart. +*/}} +{{- define "inference.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). +If release name contains chart name it will be used as a full name. +*/}} +{{- define "inference.fullname" -}} +{{- if .Values.fullnameOverride }} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- $name := default .Chart.Name .Values.nameOverride }} +{{- if contains $name .Release.Name }} +{{- .Release.Name | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} +{{- end }} +{{- end }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "inference.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "inference.labels" -}} +helm.sh/chart: {{ include "inference.chart" . }} +{{ include "inference.selectorLabels" . 
}} +{{- if .Chart.AppVersion }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels +*/}} +{{- define "inference.selectorLabels" -}} +app.kubernetes.io/name: {{ include "inference.name" . }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} + +{{/* +Create the name of the service account to use +*/}} +{{- define "inference.serviceAccountName" -}} +{{- if .Values.serviceAccount.create }} +{{- default (include "inference.fullname" .) .Values.serviceAccount.name }} +{{- else }} +{{- default "default" .Values.serviceAccount.name }} +{{- end }} +{{- end }} diff --git a/charts/DemoUI/inference/templates/deployment.yaml b/charts/DemoUI/inference/templates/deployment.yaml new file mode 100644 index 000000000..143f2a887 --- /dev/null +++ b/charts/DemoUI/inference/templates/deployment.yaml @@ -0,0 +1,77 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "inference.fullname" . }} + labels: + {{- include "inference.labels" . | nindent 4 }} +spec: + selector: + matchLabels: + {{- include "inference.selectorLabels" . | nindent 6 }} + template: + metadata: + {{- with .Values.podAnnotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + labels: + {{- include "inference.labels" . | nindent 8 }} + {{- with .Values.podLabels }} + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + serviceAccountName: {{ include "inference.serviceAccountName" . }} + securityContext: + {{- toYaml .Values.podSecurityContext | nindent 8 }} + containers: + - name: {{ .Chart.Name }} + securityContext: + {{- toYaml .Values.securityContext | nindent 12 }} + image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + command: ["/bin/sh"] + args: + - -c + - | + mkdir -p /app/frontend && \ + pip install chainlit requests && \ + wget -O /app/frontend/inference.py https://raw.githubusercontent.com/Azure/kaito/main/demo/inferenceUI/chainlit.py && \ + chainlit run frontend/inference.py -w + env: + - name: WORKSPACE_SERVICE_URL + value: "{{ .Values.env.workspaceServiceURL }}" + workingDir: /app + ports: + - name: http + containerPort: {{ .Values.service.port }} + protocol: TCP + livenessProbe: + {{- toYaml .Values.livenessProbe | nindent 12 }} + readinessProbe: + {{- toYaml .Values.readinessProbe | nindent 12 }} + resources: + {{- toYaml .Values.resources | nindent 12 }} + {{- with .Values.volumeMounts }} + volumeMounts: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- with .Values.volumes }} + volumes: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/charts/DemoUI/inference/templates/service.yaml b/charts/DemoUI/inference/templates/service.yaml new file mode 100644 index 000000000..f25596c31 --- /dev/null +++ b/charts/DemoUI/inference/templates/service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "inference.fullname" . }} + labels: + {{- include "inference.labels" . 
| nindent 4 }} +spec: + type: {{ .Values.service.type }} + ports: + - port: {{ .Values.service.port }} + targetPort: http + protocol: TCP + name: http + selector: + {{- include "inference.selectorLabels" . | nindent 4 }} diff --git a/charts/DemoUI/inference/templates/serviceaccount.yaml b/charts/DemoUI/inference/templates/serviceaccount.yaml new file mode 100644 index 000000000..bfb0cd8ee --- /dev/null +++ b/charts/DemoUI/inference/templates/serviceaccount.yaml @@ -0,0 +1,13 @@ +{{- if .Values.serviceAccount.create -}} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ include "inference.serviceAccountName" . }} + labels: + {{- include "inference.labels" . | nindent 4 }} + {{- with .Values.serviceAccount.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +automountServiceAccountToken: {{ .Values.serviceAccount.automount }} +{{- end }} diff --git a/charts/DemoUI/inference/values.yaml b/charts/DemoUI/inference/values.yaml new file mode 100644 index 000000000..518b43aa4 --- /dev/null +++ b/charts/DemoUI/inference/values.yaml @@ -0,0 +1,47 @@ +# values.yaml for Chainlit Front-end + +replicaCount: 1 +image: + repository: python + pullPolicy: IfNotPresent + tag: "3.8" +imagePullSecrets: [] +podAnnotations: {} +serviceAccount: + create: false + name: "" +service: + type: ClusterIP + port: 8000 + # env: + # Workspace Service URL + # Specify the URL for the Workspace Service inference endpoint. Use the DNS name within the cluster for reliability. + # + # Examples: + # Cluster IP: "http://<cluster-ip>:80/chat" + # DNS name: "http://<service-name>.<namespace>.svc.cluster.local:80/chat" + # e.g., "http://workspace-falcon-7b.default.svc.cluster.local:80/chat" + # + # workspaceServiceURL: "" +resources: + limits: + cpu: 500m + memory: 256Mi + requests: + cpu: 10m + memory: 128Mi +livenessProbe: + exec: + command: + - pgrep + - chainlit +readinessProbe: + exec: + command: + - pgrep + - chainlit +nodeSelector: {} +tolerations: [] +affinity: {} +ingress: + enabled: false \ No newline at end of file diff --git a/demo/README.md b/demo/README.md new file mode 100644 index 000000000..f76da1a24 --- /dev/null +++ b/demo/README.md @@ -0,0 +1,7 @@ +## Kaito Demos Overview + +Welcome to the KAITO demos directory! Here you'll find a collection of demonstration +applications designed to showcase various functionalities and +integrations with the KAITO Workspace. Feel free to explore! + +For specific instructions and details, please refer to the README.md file within each demo's directory. \ No newline at end of file diff --git a/demo/inferenceUI/README.md b/demo/inferenceUI/README.md new file mode 100644 index 000000000..f803dd72e --- /dev/null +++ b/demo/inferenceUI/README.md @@ -0,0 +1,61 @@ +## KAITO InferenceUI Demo + +The KAITO InferenceUI Demo provides a sample front-end application that demonstrates +how to interface with the KAITO Workspace for inference tasks. +This guide covers deploying the front-end as a Helm chart in a Kubernetes environment +as well as how to run the Python application independently. + +### Prerequisites + +- A Kubernetes cluster with Helm installed +- Access to the KAITO Workspace Service endpoint + +## Deployment with Helm +Deploy the KAITO InferenceUI Demo by setting the +workspaceServiceURL environment variable to your +Workspace Service endpoint.
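Before installing, it can be worth confirming that the endpoint is actually reachable from inside the cluster (for example from a debug pod). Below is a minimal pre-flight sketch in Python, assuming the `/healthz` route that the inference server elsewhere in this patch series exposes (see `presets/inference/text-generation/inference_api.py`); the URL is a placeholder to replace with your own service name and namespace:

```python
import requests

# Placeholder; substitute your service name and namespace.
workspace_service_url = "http://workspace-falcon-7b.default.svc.cluster.local:80/chat"

# The health probe is served at the service root rather than under /chat.
base_url = workspace_service_url.rsplit("/chat", 1)[0]
response = requests.get(f"{base_url}/healthz", timeout=5)
print(response.status_code, response.json())  # a healthy service returns 200 with {"status": "Healthy"}
```

The sections below cover the supported ways to set this value.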
+ + ### Configuring the Workspace Service URL - Using the --set flag: + + ``` + helm install inference-frontend ./charts/DemoUI/inference --set env.workspaceServiceURL="http://<service-name>.<namespace>.svc.cluster.local:80/chat" + ``` + - Using a custom `values.override.yaml` file: + ``` + env: + workspaceServiceURL: "http://<service-name>.<namespace>.svc.cluster.local:80/chat" + ``` + Then deploy with custom values file: + ``` + helm install inference-frontend ./charts/DemoUI/inference -f ./charts/DemoUI/inference/values.override.yaml + ``` + +Replace `<service-name>` and `<namespace>` with your service's name and Kubernetes namespace. +This DNS naming convention ensures reliable service resolution within your cluster. + +## Accessing the Application +After deploying, access the KAITO InferenceUI based on your service type: +- NodePort + ``` + export NODE_PORT=$(kubectl get --namespace default -o jsonpath="{.spec.ports[0].nodePort}" services inference-frontend) + export NODE_IP=$(kubectl get nodes --namespace default -o jsonpath="{.items[0].status.addresses[0].address}") + echo "Access your application at http://$NODE_IP:$NODE_PORT" + ``` +- LoadBalancer (It may take a few minutes for the LoadBalancer IP to be available): + ``` + export SERVICE_IP=$(kubectl get svc --namespace default inference-frontend --template "{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}") + echo "Access your application at http://$SERVICE_IP:8000" + ``` +- ClusterIP (Use port-forwarding to access your application locally): + ``` + export POD_NAME=$(kubectl get pods --namespace default -l "app.kubernetes.io/name=inference" -o jsonpath="{.items[0].metadata.name}") + kubectl --namespace default port-forward $POD_NAME 8080:8000 + echo "Visit http://127.0.0.1:8080 to use your application" + ``` + +--- + +For additional support or to report issues, please contact the development team at . diff --git a/demo/inferenceUI/chainlit.py b/demo/inferenceUI/chainlit.py new file mode 100644 index 000000000..5e192db2b --- /dev/null +++ b/demo/inferenceUI/chainlit.py @@ -0,0 +1,53 @@ +import os + +import chainlit as cl +import requests + +URL = os.environ.get('WORKSPACE_SERVICE_URL') +@cl.step +def inference(prompt): + # Endpoint URL + data = { + "prompt": prompt, + "return_full_text": False, + "clean_up_tokenization_spaces": True, + "generate_kwargs": { + "max_length": 1024, + "min_length": 0, + "do_sample": True, + "top_k": 10, + "early_stopping": False, + "num_beams": 1, + "temperature": 1.0, + "top_p": 1, + "typical_p": 1, + "repetition_penalty": 1 + } + } + + response = requests.post(URL, json=data) + + if response.status_code == 200: + response_data = response.json() + return response_data.get("Result", "No result found") + else: + return f"Error: Received response code {response.status_code}" + +@cl.on_message +async def main(message: cl.Message): + """ + This function is called every time a user inputs a message in the UI. + It sends back an intermediate response from inference, followed by the final answer. + + Args: + message: The user's message. + + Returns: + None. + """ + + # Call inference + response = inference(message.content) + + # Send the final answer + await cl.Message(content=response).send() \ No newline at end of file From cf7cf94d9cdb61a66ff034429e4363ba2bcb8842 Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal <ishaanforthewin@gmail.com> Date: Mon, 1 Apr 2024 16:40:52 -0700 Subject: [PATCH 21/23] feat: Add API Docs and Improve Inference Readability (#331) **Reason for Change**: This change adds plenty more documentation and OpenAPI spec for our inference.
It also enables the use of preferred non-NVIDIA nodes without crashing. **Requirements** - [x] added unit tests and e2e tests (if applicable). **Issue Fixed**: Fixes #321 **Notes for Reviewers**: --- .../inference/text-generation/api_spec.json | 599 ++++++++++++++++++ .../text-generation/inference_api.py | 284 ++++++++- .../text-generation/requirements.txt | 1 + .../tests/test_inference_api.py | 86 ++- 4 files changed, 919 insertions(+), 51 deletions(-) create mode 100644 presets/inference/text-generation/api_spec.json diff --git a/presets/inference/text-generation/api_spec.json b/presets/inference/text-generation/api_spec.json new file mode 100644 index 000000000..480fa97e4 --- /dev/null +++ b/presets/inference/text-generation/api_spec.json @@ -0,0 +1,599 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "FastAPI", + "version": "0.1.0" + }, + "paths": { + "/": { + "get": { + "summary": "Home Endpoint", + "description": "A simple endpoint that indicates the server is running.\nNo parameters are required. Returns a message indicating the server status.", + "operationId": "home__get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HomeResponse" + } + } + } + } + } + } + }, + "/healthz": { + "get": { + "summary": "Health Check Endpoint", + "operationId": "health_check_healthz_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HealthStatus" + }, + "example": { + "status": "Healthy" + } + } + } + }, + "500": { + "description": "Error Response", + "content": { + "application/json": { + "examples": { + "model_uninitialized": { + "summary": "Model not initialized", + "value": { + "detail": "Model not initialized" + } + }, + "pipeline_uninitialized": { + "summary": "Pipeline not initialized", + "value": { + "detail": "Pipeline not initialized" + } + } + } + } + } + } + } + } + }, + "/chat": { + "post": { + "summary": "Chat Endpoint", + "description": "Processes chat requests, generating text based on the specified pipeline (text generation or conversational).\nValidates required parameters based on the pipeline and returns the generated text.", + "operationId": "generate_text_chat_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/UnifiedRequestModel" + }, + "examples": { + "text_generation_example": { + "summary": "Text Generation Example", + "description": "An example of a text generation request.", + "value": { + "prompt": "Tell me a joke", + "return_full_text": true, + "clean_up_tokenization_spaces": false, + "generate_kwargs": { + "max_length": 200, + "min_length": 0, + "do_sample": true, + "early_stopping": false, + "num_beams": 1, + "temperature": 1, + "top_k": 10, + "top_p": 1, + "typical_p": 1, + "repetition_penalty": 1, + "eos_token_id": 11 + } + } + }, + "conversation_example": { + "summary": "Conversation Example", + "description": "An example of a conversational request.", + "value": { + "messages": [ + { + "role": "user", + "content": "What is your favourite condiment?" + }, + { + "role": "assistant", + "content": "Well, im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!" + }, + { + "role": "user", + "content": "Do you have mayonnaise recipes?" 
+ } + ], + "return_full_text": true, + "clean_up_tokenization_spaces": false, + "generate_kwargs": { + "max_length": 200, + "min_length": 0, + "do_sample": true, + "early_stopping": false, + "num_beams": 1, + "temperature": 1, + "top_k": 10, + "top_p": 1, + "typical_p": 1, + "repetition_penalty": 1, + "eos_token_id": 11 + } + } + } + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {}, + "examples": { + "text_generation": { + "summary": "Text Generation Response", + "value": { + "Result": "Generated text based on the prompt." + } + }, + "conversation": { + "summary": "Conversation Response", + "value": { + "Result": "Response to the last message in the conversation." + } + } + } + } + } + }, + "400": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "examples": { + "missing_prompt": { + "summary": "Missing Prompt", + "value": { + "detail": "Text generation parameter prompt required" + } + }, + "missing_messages": { + "summary": "Missing Messages", + "value": { + "detail": "Conversational parameter messages required" + } + } + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } + }, + "/metrics": { + "get": { + "summary": "Metrics Endpoint", + "description": "Provides system metrics, including GPU details if available, or CPU and memory usage otherwise.\nUseful for monitoring the resource utilization of the server running the ML models.", + "operationId": "get_metrics_metrics_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MetricsResponse" + }, + "examples": { + "gpu_metrics": { + "summary": "Example when GPUs are available", + "value": { + "gpu_info": [ + { + "id": "GPU-1234", + "name": "GeForce GTX 950", + "load": "25.00%", + "temperature": "55 C", + "memory": { + "used": "1.00 GB", + "total": "2.00 GB" + } + } + ] + } + }, + "cpu_metrics": { + "summary": "Example when only CPU is available", + "value": { + "cpu_info": { + "load_percentage": 20, + "physical_cores": 4, + "total_cores": 8, + "memory": { + "used": "4.00 GB", + "total": "16.00 GB" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "CPUInfo": { + "properties": { + "load_percentage": { + "type": "number", + "title": "Load Percentage" + }, + "physical_cores": { + "type": "integer", + "title": "Physical Cores" + }, + "total_cores": { + "type": "integer", + "title": "Total Cores" + }, + "memory": { + "$ref": "#/components/schemas/MemoryInfo" + } + }, + "type": "object", + "required": [ + "load_percentage", + "physical_cores", + "total_cores", + "memory" + ], + "title": "CPUInfo" + }, + "ErrorResponse": { + "properties": { + "detail": { + "type": "string", + "title": "Detail" + } + }, + "type": "object", + "required": [ + "detail" + ], + "title": "ErrorResponse" + }, + "GPUInfo": { + "properties": { + "id": { + "type": 
"string", + "title": "Id" + }, + "name": { + "type": "string", + "title": "Name" + }, + "load": { + "type": "string", + "title": "Load" + }, + "temperature": { + "type": "string", + "title": "Temperature" + }, + "memory": { + "$ref": "#/components/schemas/MemoryInfo" + } + }, + "type": "object", + "required": [ + "id", + "name", + "load", + "temperature", + "memory" + ], + "title": "GPUInfo" + }, + "GenerateKwargs": { + "properties": { + "max_length": { + "type": "integer", + "title": "Max Length", + "default": 200 + }, + "min_length": { + "type": "integer", + "title": "Min Length", + "default": 0 + }, + "do_sample": { + "type": "boolean", + "title": "Do Sample", + "default": true + }, + "early_stopping": { + "type": "boolean", + "title": "Early Stopping", + "default": false + }, + "num_beams": { + "type": "integer", + "title": "Num Beams", + "default": 1 + }, + "temperature": { + "type": "number", + "title": "Temperature", + "default": 1 + }, + "top_k": { + "type": "integer", + "title": "Top K", + "default": 10 + }, + "top_p": { + "type": "number", + "title": "Top P", + "default": 1 + }, + "typical_p": { + "type": "number", + "title": "Typical P", + "default": 1 + }, + "repetition_penalty": { + "type": "number", + "title": "Repetition Penalty", + "default": 1 + }, + "pad_token_id": { + "type": "integer", + "title": "Pad Token Id" + }, + "eos_token_id": { + "type": "integer", + "title": "Eos Token Id", + "default": 11 + } + }, + "type": "object", + "title": "GenerateKwargs", + "example": { + "max_length": 200, + "temperature": 0.7, + "top_p": 0.9, + "additional_param": "Example value" + } + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "type": "array", + "title": "Detail" + } + }, + "type": "object", + "title": "HTTPValidationError" + }, + "HealthStatus": { + "properties": { + "status": { + "type": "string", + "title": "Status", + "example": "Healthy" + } + }, + "type": "object", + "required": [ + "status" + ], + "title": "HealthStatus" + }, + "HomeResponse": { + "properties": { + "message": { + "type": "string", + "title": "Message", + "example": "Server is running" + } + }, + "type": "object", + "required": [ + "message" + ], + "title": "HomeResponse" + }, + "MemoryInfo": { + "properties": { + "used": { + "type": "string", + "title": "Used" + }, + "total": { + "type": "string", + "title": "Total" + } + }, + "type": "object", + "required": [ + "used", + "total" + ], + "title": "MemoryInfo" + }, + "Message": { + "properties": { + "role": { + "type": "string", + "title": "Role" + }, + "content": { + "type": "string", + "title": "Content" + } + }, + "type": "object", + "required": [ + "role", + "content" + ], + "title": "Message" + }, + "MetricsResponse": { + "properties": { + "gpu_info": { + "items": { + "$ref": "#/components/schemas/GPUInfo" + }, + "type": "array", + "title": "Gpu Info" + }, + "cpu_info": { + "$ref": "#/components/schemas/CPUInfo" + } + }, + "type": "object", + "title": "MetricsResponse" + }, + "UnifiedRequestModel": { + "properties": { + "prompt": { + "type": "string", + "title": "Prompt", + "description": "Prompt for text generation. Required for text-generation pipeline. Do not use with 'messages'." 
+ }, + "return_full_text": { + "type": "boolean", + "title": "Return Full Text", + "description": "Return full text if True, else only added text", + "default": true + }, + "clean_up_tokenization_spaces": { + "type": "boolean", + "title": "Clean Up Tokenization Spaces", + "description": "Clean up extra spaces in text output", + "default": false + }, + "prefix": { + "type": "string", + "title": "Prefix", + "description": "Prefix added to prompt" + }, + "handle_long_generation": { + "type": "string", + "title": "Handle Long Generation", + "description": "Strategy to handle long generation" + }, + "generate_kwargs": { + "allOf": [ + { + "$ref": "#/components/schemas/GenerateKwargs" + } + ], + "title": "Generate Kwargs", + "description": "Additional kwargs for generate method" + }, + "messages": { + "items": { + "$ref": "#/components/schemas/Message" + }, + "type": "array", + "title": "Messages", + "description": "Messages for conversational model. Required for conversational pipeline. Do not use with 'prompt'." + } + }, + "type": "object", + "title": "UnifiedRequestModel" + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "type": "array", + "title": "Location" + }, + "msg": { + "type": "string", + "title": "Message" + }, + "type": { + "type": "string", + "title": "Error Type" + } + }, + "type": "object", + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError" + } + } + } +} \ No newline at end of file diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py index f6c604a54..c23a15c6b 100644 --- a/presets/inference/text-generation/inference_api.py +++ b/presets/inference/text-generation/inference_api.py @@ -2,13 +2,15 @@ # Licensed under the MIT license. 
import os from dataclasses import asdict, dataclass, field -from typing import Any, Dict, List, Optional +from typing import Annotated, Any, Dict, List, Optional import GPUtil +import psutil import torch import transformers import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import Body, FastAPI, HTTPException +from fastapi.responses import Response from pydantic import BaseModel, Extra, Field from transformers import (AutoModelForCausalLM, AutoTokenizer, GenerationConfig, HfArgumentParser) @@ -35,7 +37,7 @@ class ModelConfig: load_in_8bit: bool = field(default=False, metadata={"help": "Load model in 8-bit mode"}) torch_dtype: Optional[str] = field(default=None, metadata={"help": "The torch dtype for the pre-trained model"}) device_map: str = field(default="auto", metadata={"help": "The device map for the pre-trained model"}) - + # Method to process additional arguments def process_additional_args(self, addt_args: List[str]): """ @@ -51,7 +53,7 @@ def process_additional_args(self, addt_args: List[str]): else: value = True # Assign a True value for standalone flags i += 1 # Move to the next item - + addt_args_dict[key] = value # Update the ModelConfig instance with the additional args @@ -102,20 +104,57 @@ def __post_init__(self): try: # Attempt to load the generation configuration default_generate_config = GenerationConfig.from_pretrained( - args.pretrained_model_name_or_path, + args.pretrained_model_name_or_path, local_files_only=args.local_files_only ).to_dict() except Exception as e: default_generate_config = {} -@app.get('/') +class HomeResponse(BaseModel): + message: str = Field(..., example="Server is running") +@app.get('/', response_model=HomeResponse, summary="Home Endpoint") def home(): - return "Server is running", 200 + """ + A simple endpoint that indicates the server is running. + No parameters are required. Returns a message indicating the server status. + """ + return {"message": "Server is running"} -@app.get("/healthz") +class HealthStatus(BaseModel): + status: str = Field(..., example="Healthy") +@app.get( + "/healthz", + response_model=HealthStatus, + summary="Health Check Endpoint", + responses={ + 200: { + "description": "Successful Response", + "content": { + "application/json": { + "example": {"status": "Healthy"} + } + } + }, + 500: { + "description": "Error Response", + "content": { + "application/json": { + "examples": { + "model_uninitialized": { + "summary": "Model not initialized", + "value": {"detail": "Model not initialized"} + }, + "pipeline_uninitialized": { + "summary": "Pipeline not initialized", + "value": {"detail": "Pipeline not initialized"} + } + } + } + } + } + } +) def health_check(): - if not torch.cuda.is_available(): - raise HTTPException(status_code=500, detail="No GPU available") if not model: raise HTTPException(status_code=500, detail="Model not initialized") if not pipeline: @@ -137,10 +176,22 @@ class GenerateKwargs(BaseModel): eos_token_id: Optional[int] = tokenizer.eos_token_id class Config: extra = Extra.allow # Allows for additional fields not explicitly defined + schema_extra = { + "example": { + "max_length": 200, + "temperature": 0.7, + "top_p": 0.9, + "additional_param": "Example value" + } + } + +class Message(BaseModel): + role: str + content: str class UnifiedRequestModel(BaseModel): # Fields for text generation - prompt: Optional[str] = Field(None, description="Prompt for text generation") + prompt: Optional[str] = Field(None, description="Prompt for text generation. Required for text-generation pipeline. 
Do not use with 'messages'.") return_full_text: Optional[bool] = Field(True, description="Return full text if True, else only added text") clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output") prefix: Optional[str] = Field(None, description="Prefix added to prompt") @@ -148,10 +199,103 @@ class UnifiedRequestModel(BaseModel): generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method") # Field for conversational model - messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model") + messages: Optional[List[Message]] = Field(None, description="Messages for conversational model. Required for conversational pipeline. Do not use with 'prompt'.") + def messages_to_dict_list(self): + return [message.dict() for message in self.messages] if self.messages else [] + +class ErrorResponse(BaseModel): + detail: str -@app.post("/chat") -def generate_text(request_model: UnifiedRequestModel): +@app.post( + "/chat", + summary="Chat Endpoint", + responses={ + 200: { + "description": "Successful Response", + "content": { + "application/json": { + "examples": { + "text_generation": { + "summary": "Text Generation Response", + "value": { + "Result": "Generated text based on the prompt." + } + }, + "conversation": { + "summary": "Conversation Response", + "value": { + "Result": "Response to the last message in the conversation." + } + } + } + } + } + }, + 400: { + "model": ErrorResponse, + "description": "Validation Error", + "content": { + "application/json": { + "examples": { + "missing_prompt": { + "summary": "Missing Prompt", + "value": {"detail": "Text generation parameter prompt required"} + }, + "missing_messages": { + "summary": "Missing Messages", + "value": {"detail": "Conversational parameter messages required"} + } + } + } + } + }, + 500: { + "model": ErrorResponse, + "description": "Internal Server Error" + } + } +) +def generate_text( + request_model: Annotated[ + UnifiedRequestModel, + Body( + openapi_examples={ + "text_generation_example": { + "summary": "Text Generation Example", + "description": "An example of a text generation request.", + "value": { + "prompt": "Tell me a joke", + "return_full_text": True, + "clean_up_tokenization_spaces": False, + "prefix": None, + "handle_long_generation": None, + "generate_kwargs": GenerateKwargs().dict(), + }, + }, + "conversation_example": { + "summary": "Conversation Example", + "description": "An example of a conversational request.", + "value": { + "messages": [ + {"role": "user", "content": "What is your favourite condiment?"}, + {"role": "assistant", "content": "Well, im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever im cooking up in the kitchen!"}, + {"role": "user", "content": "Do you have mayonnaise recipes?"} + ], + "return_full_text": True, + "clean_up_tokenization_spaces": False, + "prefix": None, + "handle_long_generation": None, + "generate_kwargs": GenerateKwargs().dict(), + }, + }, + }, + ), + ], +): + """ + Processes chat requests, generating text based on the specified pipeline (text generation or conversational). + Validates required parameters based on the pipeline and returns the generated text. 
+    """
     user_generate_kwargs = request_model.generate_kwargs.dict() if request_model.generate_kwargs else {}
     generate_kwargs = {**default_generate_config, **user_generate_kwargs}
@@ -176,12 +320,12 @@ def generate_text(request_model: UnifiedRequestModel):
 
             return {"Result": result}
 
-    elif args.pipeline == "conversational": 
+    elif args.pipeline == "conversational":
         if not request_model.messages:
             raise HTTPException(status_code=400, detail="Conversational parameter messages required")
 
         response = pipeline(
-            request_model.messages,
+            request_model.messages_to_dict_list(),
             clean_up_tokenization_spaces=request_model.clean_up_tokenization_spaces,
             **generate_kwargs
         )
@@ -190,27 +334,101 @@ def generate_text(request_model: UnifiedRequestModel):
     else:
         raise HTTPException(status_code=400, detail="Invalid pipeline type")
 
-@app.get("/metrics")
+class MemoryInfo(BaseModel):
+    used: str
+    total: str
+
+class CPUInfo(BaseModel):
+    load_percentage: float
+    physical_cores: int
+    total_cores: int
+    memory: MemoryInfo
+
+class GPUInfo(BaseModel):
+    id: str
+    name: str
+    load: str
+    temperature: str
+    memory: MemoryInfo
+
+class MetricsResponse(BaseModel):
+    gpu_info: Optional[List[GPUInfo]] = None
+    cpu_info: Optional[CPUInfo] = None
+
+@app.get(
+    "/metrics",
+    response_model=MetricsResponse,
+    summary="Metrics Endpoint",
+    responses={
+        200: {
+            "description": "Successful Response",
+            "content": {
+                "application/json": {
+                    "examples": {
+                        "gpu_metrics": {
+                            "summary": "Example when GPUs are available",
+                            "value": {
+                                "gpu_info": [{"id": "GPU-1234", "name": "GeForce GTX 950", "load": "25.00%", "temperature": "55 C", "memory": {"used": "1.00 GB", "total": "2.00 GB"}}],
+                                "cpu_info": None  # Indicates CPU info might not be present when GPUs are available
+                            }
+                        },
+                        "cpu_metrics": {
+                            "summary": "Example when only CPU is available",
+                            "value": {
+                                "gpu_info": None,  # Indicates GPU info might not be present when only CPU is available
+                                "cpu_info": {"load_percentage": 20.0, "physical_cores": 4, "total_cores": 8, "memory": {"used": "4.00 GB", "total": "16.00 GB"}}
+                            }
+                        }
+                    }
+                }
+            }
+        },
+        500: {
+            "description": "Internal Server Error",
+            "model": ErrorResponse,
+        }
+    }
+)
 def get_metrics():
+    """
+    Provides system metrics, including GPU details if available, or CPU and memory usage otherwise.
+    Useful for monitoring the resource utilization of the server running the ML models.
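+
+    Example responses (illustrative values only):
+      GPU host:      {"gpu_info": [{"id": "GPU-1234", "load": "25.00%", ...}], "cpu_info": null}
+      CPU-only host: {"gpu_info": null, "cpu_info": {"load_percentage": 20.0, ...}}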
+    """
     try:
-        gpus = GPUtil.getGPUs()
-        gpu_info = []
-        for gpu in gpus:
-            gpu_info.append({
-                "id": gpu.id,
-                "name": gpu.name,
-                "load": f"{gpu.load * 100:.2f}%", # Format as percentage
-                "temperature": f"{gpu.temperature} C",
-                "memory": {
-                    "used": f"{gpu.memoryUsed / 1024:.2f} GB",
-                    "total": f"{gpu.memoryTotal / 1024:.2f} GB"
-                }
-            })
-        return {"gpu_info": gpu_info}
+        if torch.cuda.is_available():
+            gpus = GPUtil.getGPUs()
+            gpu_info = [GPUInfo(
+                id=gpu.id,
+                name=gpu.name,
+                load=f"{gpu.load * 100:.2f}%",
+                temperature=f"{gpu.temperature} C",
+                memory=MemoryInfo(
+                    used=f"{gpu.memoryUsed / 1024:.2f} GB",  # GPUtil reports memory in MiB
+                    total=f"{gpu.memoryTotal / 1024:.2f} GB"
+                )
+            ) for gpu in gpus]
+            return MetricsResponse(gpu_info=gpu_info)
+        else:
+            # Gather CPU metrics
+            cpu_usage = psutil.cpu_percent(interval=1, percpu=False)
+            physical_cores = psutil.cpu_count(logical=False)
+            total_cores = psutil.cpu_count(logical=True)
+            virtual_memory = psutil.virtual_memory()
+            memory = MemoryInfo(
+                used=f"{virtual_memory.used / (1024 ** 3):.2f} GB",
+                total=f"{virtual_memory.total / (1024 ** 3):.2f} GB"
+            )
+            cpu_info = CPUInfo(
+                load_percentage=cpu_usage,
+                physical_cores=physical_cores,
+                total_cores=total_cores,
+                memory=memory
+            )
+            return MetricsResponse(cpu_info=cpu_info)
     except Exception as e:
-        return {"error": str(e)}
+        raise HTTPException(status_code=500, detail=str(e))
 
 if __name__ == "__main__":
     local_rank = int(os.environ.get("LOCAL_RANK", 0)) # Default to 0 if not set
     port = 5000 + local_rank # Adjust port based on local rank
-    uvicorn.run(app=app, host='0.0.0.0', port=port)
+    uvicorn.run(app=app, host='0.0.0.0', port=port)
\ No newline at end of file
diff --git a/presets/inference/text-generation/requirements.txt b/presets/inference/text-generation/requirements.txt
index 8a7c50dbe..1d1c845a5 100644
--- a/presets/inference/text-generation/requirements.txt
+++ b/presets/inference/text-generation/requirements.txt
@@ -8,6 +8,7 @@ uvicorn[standard]==0.23.2
 bitsandbytes==0.42.0
 deepspeed==0.11.1
 gputil==1.4.0
+psutil==5.9.8
 # For UTs
 pytest==8.0.0
 httpx==0.26.0
\ No newline at end of file
diff --git a/presets/inference/text-generation/tests/test_inference_api.py b/presets/inference/text-generation/tests/test_inference_api.py
index c15b0f38f..535d0f4e2 100644
--- a/presets/inference/text-generation/tests/test_inference_api.py
+++ b/presets/inference/text-generation/tests/test_inference_api.py
@@ -4,7 +4,6 @@
 from unittest.mock import patch
 
 import pytest
-import torch
 from fastapi.testclient import TestClient
 from transformers import AutoTokenizer
 
@@ -44,7 +43,7 @@ def test_conversational(configured_app):
     client = TestClient(configured_app)
     messages = [
         {"role": "user", "content": "What is your favourite condiment?"},
-        {"role": "assistant", "content": "Well, Im quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever Im cooking up in the kitchen!"},
+        {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
         {"role": "user", "content": "Do you have mayonnaise recipes?"}
     ]
     request_data = {
@@ -102,17 +101,12 @@ def test_missing_prompt(configured_app):
 def test_read_main(configured_app):
     client = TestClient(configured_app)
     response = client.get("/")
-    server_msg, status_code = response.json()
-    assert server_msg == "Server is running"
-    assert status_code == 200
+    assert response.status_code == 200
+    assert response.json() == {"message": "Server is running"}
 
 def test_health_check(configured_app):
-    device = "GPU" if torch.cuda.is_available() else "CPU"
-    if device != "GPU":
-        pytest.skip("Skipping healthz endpoint check - running on CPU")
     client = TestClient(configured_app)
     response = client.get("/healthz")
-    # Assuming we have a GPU available
     assert response.status_code == 200
     assert response.json() == {"status": "Healthy"}
 
@@ -122,17 +116,73 @@ def test_get_metrics(configured_app):
     assert response.status_code == 200
     assert "gpu_info" in response.json()
 
+def test_get_metrics_with_gpus(configured_app):
+    client = TestClient(configured_app)
+    # Define a simple mock GPU object with the necessary attributes
+    class MockGPU:
+        def __init__(self, id, name, load, temperature, memoryUsed, memoryTotal):
+            self.id = id
+            self.name = name
+            self.load = load
+            self.temperature = temperature
+            self.memoryUsed = memoryUsed
+            self.memoryTotal = memoryTotal
+
+    # Create a mock GPU object with the desired attributes
+    mock_gpu = MockGPU(
+        id="GPU-1234",
+        name="GeForce GTX 950",
+        load=0.25, # 25%
+        temperature=55, # 55 C
+        memoryUsed=1024, # 1 GB, in MiB as reported by GPUtil
+        memoryTotal=2048 # 2 GB, in MiB as reported by GPUtil
+    )
+
+    # Mock torch.cuda.is_available to simulate an environment with GPUs
+    # Mock GPUtil.getGPUs to return a list containing the mock GPU object
+    with patch('torch.cuda.is_available', return_value=True), \
+         patch('GPUtil.getGPUs', return_value=[mock_gpu]):
+        response = client.get("/metrics")
+        assert response.status_code == 200
+        data = response.json()
+
+        # Assertions to verify that the GPU info is correctly returned in the response
+        assert data["gpu_info"] != []
+        assert len(data["gpu_info"]) == 1
+        gpu_data = data["gpu_info"][0]
+
+        assert gpu_data["id"] == "GPU-1234"
+        assert gpu_data["name"] == "GeForce GTX 950"
+        assert gpu_data["load"] == "25.00%"
+        assert gpu_data["temperature"] == "55 C"
+        assert gpu_data["memory"]["used"] == "1.00 GB"
+        assert gpu_data["memory"]["total"] == "2.00 GB"
+        assert data["cpu_info"] is None # Assuming CPU info is not present when GPUs are available
+
 def test_get_metrics_no_gpus(configured_app):
     client = TestClient(configured_app)
-    with patch('GPUtil.getGPUs', return_value=[]) as mock_getGPUs:
+    # Mock GPUtil.getGPUs to simulate an environment without GPUs
+    with patch('torch.cuda.is_available', return_value=False), \
+         patch('psutil.cpu_percent', return_value=20.0), \
+         patch('psutil.cpu_count', side_effect=[4, 8]), \
+         patch('psutil.virtual_memory') as mock_virtual_memory:
+        mock_virtual_memory.return_value.used = 4 * (1024 ** 3) # 4 GB
+        mock_virtual_memory.return_value.total = 16 * (1024 ** 3) # 16 GB
         response = client.get("/metrics")
         assert response.status_code == 200
-        assert response.json()["gpu_info"] == []
+        data = response.json()
+        assert data["gpu_info"] is None # No GPUs available
+        assert data["cpu_info"] is not None # CPU info should be present
+        assert data["cpu_info"]["load_percentage"] == 20.0
+        assert data["cpu_info"]["physical_cores"] == 4
+        assert 
data["cpu_info"]["total_cores"] == 8 + assert data["cpu_info"]["memory"]["used"] == "4.00 GB" + assert data["cpu_info"]["memory"]["total"] == "16.00 GB" def test_default_generation_params(configured_app): if configured_app.test_config['pipeline'] != 'text-generation': pytest.skip("Skipping non-text-generation tests") - + client = TestClient(configured_app) request_data = { @@ -144,14 +194,14 @@ def test_default_generation_params(configured_app): with patch('inference_api.pipeline') as mock_pipeline: mock_pipeline.return_value = [{"generated_text": "Mocked response"}] # Mock the response of the pipeline function - + response = client.post("/chat", json=request_data) - + assert response.status_code == 200 data = response.json() assert "Result" in data assert data["Result"] == "Mocked response", "The response content doesn't match the expected mock response" - + # Check the default args _, kwargs = mock_pipeline.call_args assert kwargs['max_length'] == 200 @@ -187,7 +237,7 @@ def test_generation_with_max_length(configured_app): data = response.json() print("Response: ", data["Result"]) assert "Result" in data, "The response should contain a 'Result' key" - + tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path']) prompt_tokens = tokenizer.tokenize(prompt) total_tokens = tokenizer.tokenize(data["Result"]) # data["Result"] includes the input prompt @@ -207,7 +257,7 @@ def test_generation_with_min_length(configured_app): client = TestClient(configured_app) prompt = "This prompt requests a response of a certain minimum length to test the functionality." min_length = 30 - max_length = 40 + max_length = 40 request_data = { "prompt": prompt, @@ -221,7 +271,7 @@ def test_generation_with_min_length(configured_app): assert response.status_code == 200 data = response.json() assert "Result" in data, "The response should contain a 'Result' key" - + tokenizer = AutoTokenizer.from_pretrained(configured_app.test_config['model_path']) prompt_tokens = tokenizer.tokenize(prompt) total_tokens = tokenizer.tokenize(data["Result"]) # data["Result"] includes the input prompt From 77aa95b5af48ce5da97b32c32d9fc2c06b77faa9 Mon Sep 17 00:00:00 2001 From: Ishaan Sehgal Date: Mon, 1 Apr 2024 22:18:38 -0700 Subject: [PATCH 22/23] chore: Factoring out reusable presets logic - Part 4 (#332) --- pkg/inference/preset-inferences.go | 92 ++++++++++-------------------- pkg/utils/common-preset.go | 72 +++++++++++++++++++++++ 2 files changed, 101 insertions(+), 63 deletions(-) create mode 100644 pkg/utils/common-preset.go diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 4c4792b54..92069e146 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -5,6 +5,7 @@ package inference import ( "context" "fmt" + "github.com/azure/kaito/pkg/utils" "os" "strconv" @@ -19,10 +20,9 @@ import ( ) const ( - ProbePath = "/healthz" - Port5000 = int32(5000) - InferenceFile = "inference_api.py" - DefaultVolumeMountPath = "/dev/shm" + ProbePath = "/healthz" + Port5000 = int32(5000) + InferenceFile = "inference_api.py" ) var ( @@ -92,21 +92,21 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl return nil } -func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) { - imageName := string(workspaceObj.Inference.Preset.Name) - imageTag := inferenceObj.Tag +func GetInferenceImageInfo(ctx context.Context, workspaceObj 
*kaitov1alpha1.Workspace, presetObj *model.PresetParam) (string, []corev1.LocalObjectReference) { imagePullSecretRefs := []corev1.LocalObjectReference{} - if inferenceObj.ImageAccessMode == "private" { - imageName = string(workspaceObj.Inference.Preset.PresetOptions.Image) + if presetObj.ImageAccessMode == "private" { + imageName := workspaceObj.Inference.Preset.PresetOptions.Image for _, secretName := range workspaceObj.Inference.Preset.PresetOptions.ImagePullSecrets { imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName}) } return imageName, imagePullSecretRefs + } else { + imageName := string(workspaceObj.Inference.Preset.Name) + imageTag := presetObj.Tag + registryName := os.Getenv("PRESET_REGISTRY_NAME") + imageName = fmt.Sprintf("%s/kaito-%s:%s", registryName, imageName, imageTag) + return imageName, imagePullSecretRefs } - - registryName := os.Getenv("PRESET_REGISTRY_NAME") - imageName = registryName + fmt.Sprintf("/kaito-%s:%s", imageName, imageTag) - return imageName, imagePullSecretRefs } func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, @@ -118,17 +118,25 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work } } - volume, volumeMount := configVolume(workspaceObj, inferenceObj) + var volumes []corev1.Volume + var volumeMounts []corev1.VolumeMount + volume, volumeMount := utils.ConfigSHMVolume(workspaceObj) + if volume.Name != "" { + volumes = append(volumes, volume) + } + if volumeMount.Name != "" { + volumeMounts = append(volumeMounts, volumeMount) + } commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj) - image, imagePullSecrets := GetImageInfo(ctx, workspaceObj, inferenceObj) + image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceObj) var depObj client.Object if supportDistributedInference { depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands, - containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) + containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts) } else { depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands, - containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) + containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts) } err := resources.CreateResource(ctx, depObj, kubeClient) if client.IgnoreAlreadyExists(err) != nil { @@ -142,10 +150,10 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work // and sets the GPU resources required for inference. // Returns the command and resource configuration. 
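+// For example (illustrative values only): with BaseCommand "torchrun", TorchRunParams
+// {"nproc_per_node": "1"}, empty TorchRunRdzvParams, and ModelRunParams {"pipeline": "text-generation"},
+// the result is: /bin/sh -c "torchrun --nproc_per_node=1 inference_api.py --pipeline=text-generation"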
func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) { - torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams) - torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams) - modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams) - commands := shellCommand(torchCommand + " " + modelCommand) + torchCommand := utils.BuildCmdStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams) + torchCommand = utils.BuildCmdStr(torchCommand, inferenceObj.TorchRunRdzvParams) + modelCommand := utils.BuildCmdStr(InferenceFile, inferenceObj.ModelRunParams) + commands := utils.ShellCmd(torchCommand + " " + modelCommand) resourceRequirements := corev1.ResourceRequirements{ Requests: corev1.ResourceList{ @@ -158,45 +166,3 @@ func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetP return commands, resourceRequirements } - -func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) { - volume := []corev1.Volume{} - volumeMount := []corev1.VolumeMount{} - - // Signifies multinode inference requirement - if *wObj.Resource.Count > 1 { - // Append share memory volume to any existing volumes - volume = append(volume, corev1.Volume{ - Name: "dshm", - VolumeSource: corev1.VolumeSource{ - EmptyDir: &corev1.EmptyDirVolumeSource{ - Medium: "Memory", - }, - }, - }) - - volumeMount = append(volumeMount, corev1.VolumeMount{ - Name: volume[0].Name, - MountPath: DefaultVolumeMountPath, - }) - } - - return volume, volumeMount -} - -func shellCommand(command string) []string { - return []string{ - "/bin/sh", - "-c", - command, - } -} - -func buildCommandStr(baseCommand string, torchRunParams map[string]string) string { - updatedBaseCommand := baseCommand - for key, value := range torchRunParams { - updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value) - } - - return updatedBaseCommand -} diff --git a/pkg/utils/common-preset.go b/pkg/utils/common-preset.go new file mode 100644 index 000000000..91cbe2e92 --- /dev/null +++ b/pkg/utils/common-preset.go @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
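+
+// Package utils collects reusable preset helpers shared across controllers:
+// volume configuration (shared memory, data) and shell command construction.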
+package utils
+
+import (
+	"fmt"
+	kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
+	corev1 "k8s.io/api/core/v1"
+)
+
+const (
+	DefaultVolumeMountPath = "/dev/shm"
+)
+
+func ConfigSHMVolume(wObj *kaitov1alpha1.Workspace) (corev1.Volume, corev1.VolumeMount) {
+	volume := corev1.Volume{}
+	volumeMount := corev1.VolumeMount{}
+
+	// A node count above one signifies multinode inference, which requires shared memory
+	if *wObj.Resource.Count > 1 {
+		// Configure a shared memory volume backed by a memory-medium emptyDir
+		volume = corev1.Volume{
+			Name: "dshm",
+			VolumeSource: corev1.VolumeSource{
+				EmptyDir: &corev1.EmptyDirVolumeSource{
+					Medium: "Memory",
+				},
+			},
+		}
+
+		volumeMount = corev1.VolumeMount{
+			Name:      volume.Name,
+			MountPath: DefaultVolumeMountPath,
+		}
+	}
+
+	return volume, volumeMount
+}
+
+func ConfigDataVolume() ([]corev1.Volume, []corev1.VolumeMount) {
+	var volumes []corev1.Volume
+	var volumeMounts []corev1.VolumeMount
+	volumes = append(volumes, corev1.Volume{
+		Name: "data-volume",
+		VolumeSource: corev1.VolumeSource{
+			EmptyDir: &corev1.EmptyDirVolumeSource{},
+		},
+	})
+
+	volumeMounts = append(volumeMounts, corev1.VolumeMount{
+		Name:      "data-volume",
+		MountPath: "/data",
+	})
+	return volumes, volumeMounts
+}
+
+func ShellCmd(command string) []string {
+	return []string{
+		"/bin/sh",
+		"-c",
+		command,
+	}
+}
+
+func BuildCmdStr(baseCommand string, torchRunParams map[string]string) string {
+	updatedBaseCommand := baseCommand
+	for key, value := range torchRunParams {
+		updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
+	}
+
+	return updatedBaseCommand
+}

From fde6369d760f92c31531e55ea79684084c4dab34 Mon Sep 17 00:00:00 2001
From: Heba <31887807+helayoty@users.noreply.github.com>
Date: Wed, 3 Apr 2024 17:37:48 -0700
Subject: [PATCH 23/23] fix: Update namespace in the helm chart (#337)

**Reason for Change**:
- Update the namespace to use the release namespace.

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
Fixes #336

**Notes for Reviewers**:

---------

Signed-off-by: Heba <31887807+helayoty@users.noreply.github.com>
Co-authored-by: ishaansehgal99
---
 charts/kaito/workspace/templates/clusterrole_binding.yaml     | 2 +-
 charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml | 2 +-
 charts/kaito/workspace/templates/role_binding.yaml            | 2 +-
 charts/kaito/workspace/templates/webhooks.yaml                | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/charts/kaito/workspace/templates/clusterrole_binding.yaml b/charts/kaito/workspace/templates/clusterrole_binding.yaml
index 4451b3f5e..3076a4bfb 100644
--- a/charts/kaito/workspace/templates/clusterrole_binding.yaml
+++ b/charts/kaito/workspace/templates/clusterrole_binding.yaml
@@ -11,4 +11,4 @@ roleRef:
 subjects:
   - kind: ServiceAccount
     name: {{ include "kaito.fullname" . }}-sa
-    namespace: {{ include "kaito.fullname" . }}
+    namespace: {{ .Release.Namespace }}
diff --git a/charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml b/charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml
index 21974de27..850eb4562 100644
--- a/charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml
+++ b/charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml
@@ -63,7 +63,7 @@ apiVersion: scheduling.k8s.io/v1
 kind: PriorityClass
 metadata:
   name: high-priority-nonpreempting
-  namespace: {{ include "kaito.fullname" . }}
+  namespace: {{ .Release.Namespace }}
   labels:
     {{- include "kaito.labels" . 
| nindent 4 }} value: 1000000 diff --git a/charts/kaito/workspace/templates/role_binding.yaml b/charts/kaito/workspace/templates/role_binding.yaml index 708b6b173..c67e6bc13 100644 --- a/charts/kaito/workspace/templates/role_binding.yaml +++ b/charts/kaito/workspace/templates/role_binding.yaml @@ -12,4 +12,4 @@ roleRef: subjects: - kind: ServiceAccount name: {{ include "kaito.fullname" . }}-sa - namespace: {{ include "kaito.fullname" . }} + namespace: {{ .Release.Namespace }} diff --git a/charts/kaito/workspace/templates/webhooks.yaml b/charts/kaito/workspace/templates/webhooks.yaml index 440804a72..9b501304b 100644 --- a/charts/kaito/workspace/templates/webhooks.yaml +++ b/charts/kaito/workspace/templates/webhooks.yaml @@ -10,7 +10,7 @@ webhooks: clientConfig: service: name: {{ include "kaito.fullname" . }} - namespace: {{ include "kaito.fullname" . }} + namespace: {{ .Release.Namespace }} port: {{ .Values.webhook.port }} failurePolicy: Fail sideEffects: None