Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add model version #341

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
624f479
Initial Dockerfile and fastapi implementation
etredal Mar 12, 2024
fcb9dc6
Merge branch 'Azure:main' into etredal/AddModelVersion
etredal Mar 13, 2024
646aba4
Merge branch 'Azure:main' into etredal/AddModelVersion
etredal Mar 13, 2024
484c0ad
Rename, add constants, template update
etredal Mar 13, 2024
6941ea5
Merge branch 'etredal/AddModelVersion' of https://github.com/etredal/…
etredal Mar 13, 2024
a8dcafd
Merge branch 'main' into etredal/AddModelVersion
etredal Mar 13, 2024
031735d
Fix formatting
etredal Mar 14, 2024
bf79780
Merge branch 'etredal/AddModelVersion' of https://github.com/etredal/…
etredal Mar 14, 2024
62edd96
Merge branch 'Azure:main' into etredal/AddModelVersion
etredal Mar 14, 2024
afc7030
Version and endpoint
etredal Mar 14, 2024
199d699
Version updates
etredal Mar 14, 2024
e85ee06
Resync
etredal Mar 14, 2024
f298777
Phi2
etredal Mar 14, 2024
478f68b
Version
etredal Mar 20, 2024
60bada8
Versioning fixes
etredal Mar 27, 2024
29f14e0
Adding MODEL_VERSION into .txt file
etredal Mar 31, 2024
fcc19a1
MODEL VERSION HASH
etredal Apr 3, 2024
7582ee5
Get Hash
etredal Apr 4, 2024
3cff917
Merge branch 'etredal-etredal/AddModelVersion'
etredal Apr 4, 2024
cf897d2
Merge branch 'main' of https://github.com/Azure/kaito
etredal Apr 4, 2024
9db1a04
Version comments
etredal Apr 4, 2024
bf015c0
fix: Checkout Evans awesome fork
ishaansehgal99 Apr 5, 2024
6b88a93
fix: Checkout Evans awesome fork
ishaansehgal99 Apr 5, 2024
81ce9c9
feat: Document version endpoint
ishaansehgal99 Apr 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/e2e-preset-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
run: |
PR_BRANCH=${{ env.BRANCH_NAME }} \
FORCE_RUN_ALL=${{ env.FORCE_RUN_ALL }} \
PR_REPO_URL=${{ github.event.pull_request.head.repo.clone_url }} \
python3 .github/workflows/kind-cluster/determine_models.py

- name: Print Determined Models
Expand Down Expand Up @@ -274,6 +275,11 @@ jobs:
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && (steps.check_prod_image.outputs.IMAGE_EXISTS == 'false' || env.FORCE_RUN_ALL == 'true')
run: |
curl http://${{ steps.get_ip.outputs.SERVICE_IP }}:80/healthz

- name: Test version endpoint
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && (steps.check_prod_image.outputs.IMAGE_EXISTS == 'false' || env.FORCE_RUN_ALL == 'true')
run: |
curl http://${{ steps.get_ip.outputs.SERVICE_IP }}:80/version

- name: Test inference endpoint
if: steps.check_test_image.outputs.IMAGE_EXISTS == 'true' && (steps.check_prod_image.outputs.IMAGE_EXISTS == 'false' || env.FORCE_RUN_ALL == 'true')
Expand Down
14 changes: 11 additions & 3 deletions .github/workflows/kind-cluster/determine_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def models_to_build(files_changed):
seen_model_types.add(model_info["type"])
return list(models)

def check_modified_models(pr_branch):
def check_modified_models(pr_branch, pr_repo_url):
"""Check for modified models in the repository."""
repo_dir = Path.cwd() / "repo"

Expand All @@ -102,7 +102,14 @@ def check_modified_models(pr_branch):

run_command("git checkout --detach")
run_command("git fetch origin main:main")
run_command(f"git fetch origin {pr_branch}:{pr_branch}")

fetch_command = f"git fetch origin {pr_branch}:{pr_branch}"
if pr_repo_url != KAITO_REPO_URL:
# Add the PR's repo as a new remote only if it's different from the main repo
run_command("git remote add pr_repo {}".format(pr_repo_url))
fetch_command = f"git fetch pr_repo {pr_branch}"

run_command(fetch_command)
run_command(f"git checkout {pr_branch}")

files = run_command("git diff --name-only origin/main") # Returns each file on newline
Expand All @@ -118,14 +125,15 @@ def check_modified_models(pr_branch):
def main():
pr_branch = os.environ.get("PR_BRANCH", "main") # If not specified default to 'main'
force_run_all = os.environ.get("FORCE_RUN_ALL", "false") # If not specified default to False
pr_repo_url = os.environ.get("PR_REPO_URL", KAITO_REPO_URL)

affected_models = []
if force_run_all != "false":
affected_models = [model['name'] for model in YAML_PR['models']]
else:
# Logic to determine affected models
# Example: affected_models = ['model1', 'model2', 'model3']
affected_models = check_modified_models(pr_branch)
affected_models = check_modified_models(pr_branch, pr_repo_url)

# Convert the list of models into JSON matrix format
matrix = create_matrix(affected_models)
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/kind-cluster/docker-job-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ spec:
--build-arg WEIGHTS_PATH=/weights \
--build-arg VERSION={{VERSION}} \
--build-arg MODEL_TYPE={{MODEL_TYPE}} \
--build-arg IMAGE_NAME={{IMAGE_NAME}} \
--build-arg MODEL_VERSION={{MODEL_VERSION}} \
-f $DOCKERFILE_PATH /
docker push $ACR_NAME.azurecr.io/{{IMAGE_NAME}}:$VERSION
env:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/preset-image-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ jobs:
run: |
PR_BRANCH=${{ env.BRANCH_NAME }} \
FORCE_RUN_ALL=${{ env.FORCE_RUN_ALL }} \
PR_REPO_URL=${{ github.event.pull_request.head.repo.clone_url }} \
python3 .github/workflows/kind-cluster/determine_models.py

- name: Print Determined Models
Expand Down
10 changes: 8 additions & 2 deletions docker/presets/inference/llama-2/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# --build-arg WEIGHTS_PATH=/weights \
# --build-arg VERSION={{VERSION}} \
# --build-arg MODEL_TYPE={{MODEL_TYPE}} \
# --build-arg IMAGE_NAME={{IMAGE_NAME}} \
# --build-arg MODEL_VERSION={{MODEL_VERSION}} \

FROM python:3.8-slim
WORKDIR /workspace
Expand All @@ -26,8 +28,12 @@ RUN pip install 'uvicorn[standard]'
ARG WEIGHTS_PATH
ARG MODEL_TYPE
ARG VERSION
# Write the version to a file
RUN echo $VERSION > /workspace/llama/version.txt
ARG IMAGE_NAME
ARG MODEL_VERSION

# Write metadata to model_info.json file
RUN MODEL_VERSION_HASH="${MODEL_VERSION##*/}" && \
echo "{\"Model Type\": \"$MODEL_TYPE\", \"Version\": \"$VERSION\", \"Image Name\": \"$IMAGE_NAME\", \"Model Version URL\": \"$MODEL_VERSION\", \"REVISION_ID\": \"$MODEL_VERSION_HASH\"}" > /workspace/llama/model_info.json

ADD ${WEIGHTS_PATH} /workspace/llama/llama-2/weights
ADD kaito/presets/inference/${MODEL_TYPE} /workspace/llama/llama-2
7 changes: 5 additions & 2 deletions docker/presets/inference/tfs-onnx/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ FROM mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu118-py38-torch211
ARG WEIGHTS_PATH
ARG MODEL_TYPE
ARG VERSION
ARG IMAGE_NAME
ARG MODEL_VERSION

# Set the working directory
WORKDIR /workspace/tfs

# Write the version to a file
RUN echo $VERSION > /workspace/tfs/version.txt
# Write metadata to model_info.json file
RUN MODEL_VERSION_HASH="${MODEL_VERSION##*/}" && \
echo "{\"Model Type\": \"$MODEL_TYPE\", \"Version\": \"$VERSION\", \"Image Name\": \"$IMAGE_NAME\", \"Model Version URL\": \"$MODEL_VERSION\", \"REVISION_ID\": \"$MODEL_VERSION_HASH\"}" > /workspace/tfs/model_info.json

# First, copy just the requirements.txt file and install dependencies
# This is done before copying the code to utilize Docker's layer caching and
Expand Down
7 changes: 5 additions & 2 deletions docker/presets/inference/tfs/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ FROM python:3.10-slim
ARG WEIGHTS_PATH
ARG MODEL_TYPE
ARG VERSION
ARG IMAGE_NAME
ARG MODEL_VERSION

# Set the working directory
WORKDIR /workspace/tfs

# Write the version to a file
RUN echo $VERSION > /workspace/tfs/version.txt
# Write metadata to model_info.json file
RUN MODEL_VERSION_HASH="${MODEL_VERSION##*/}" && \
echo "{\"Model Type\": \"$MODEL_TYPE\", \"Version\": \"$VERSION\", \"Image Name\": \"$IMAGE_NAME\", \"Model Version URL\": \"$MODEL_VERSION\", \"REVISION_ID\": \"$MODEL_VERSION_HASH\"}" > /workspace/tfs/model_info.json

# First, copy just the preset files and install dependencies
# This is done before copying the code to utilize Docker's layer caching and
Expand Down
11 changes: 11 additions & 0 deletions presets/inference/llama2-chat/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import signal
import sys
import threading
import json
from typing import Optional

import GPUtil
Expand All @@ -18,6 +19,9 @@
from llama import Llama
from pydantic import BaseModel

# Constants
MODEL_INFO = "model_info.json"

# Setup argparse
parser = argparse.ArgumentParser(description="Llama API server.")
parser.add_argument("--ckpt_dir", default="weights/", help="Checkpoint directory.")
Expand Down Expand Up @@ -191,6 +195,13 @@ def get_metrics():
except Exception as e:
return {"error": str(e)}

@app_main.get("/version")
def get_version():
    """Return the model metadata baked into the image at build time.

    Reads /workspace/llama/model_info.json (written by the Dockerfile) and
    returns its parsed contents — model type, version, image name, model
    version URL, and revision hash — as the JSON response body.
    """
    # Explicit UTF-8: the metadata file is JSON written at image build time;
    # don't rely on the container's locale-dependent default encoding.
    with open(f"/workspace/llama/{MODEL_INFO}", "r", encoding="utf-8") as f:
        model_info = json.load(f)

    return model_info

def setup_worker_routes():
@app_worker.get("/healthz")
def health_check():
Expand Down
11 changes: 11 additions & 0 deletions presets/inference/llama2-completion/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import signal
import sys
import threading
import json
from typing import Optional

import GPUtil
Expand All @@ -18,6 +19,9 @@
from llama import Llama
from pydantic import BaseModel

# Constants
MODEL_INFO = "model_info.json"

# Setup argparse
parser = argparse.ArgumentParser(description="Llama API server.")
parser.add_argument("--ckpt_dir", default="weights/", help="Checkpoint directory.")
Expand Down Expand Up @@ -180,6 +184,13 @@ def get_metrics():
except Exception as e:
return {"error": str(e)}

@app_main.get("/version")
def get_version():
    """Return the model metadata baked into the image at build time.

    Reads /workspace/tfs/model_info.json (written by the Dockerfile) and
    returns its parsed contents — model type, version, image name, model
    version URL, and revision hash — as the JSON response body.
    """
    # Explicit UTF-8: the metadata file is JSON written at image build time;
    # don't rely on the container's locale-dependent default encoding.
    with open(f"/workspace/tfs/{MODEL_INFO}", "r", encoding="utf-8") as f:
        model_info = json.load(f)

    return model_info

def setup_worker_routes():
@app_worker.get("/healthz")
def health_check():
Expand Down
Loading
Loading