Skip to content

Commit

Permalink
Merge pull request #78 from pauldotyu/main
Browse files Browse the repository at this point in the history
feat: add support for local LLMs running in the cluster
  • Loading branch information
chzbrgr71 authored Nov 19, 2023
2 parents 845c893 + be9b671 commit fa0cb75
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 49 deletions.
3 changes: 2 additions & 1 deletion src/ai-service/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ pytest==7.3.1
httpx
pyyaml
semantic-kernel==0.3.1.dev0
azure.identity==1.14.0
azure.identity==1.14.0
requests==2.31.0
149 changes: 102 additions & 47 deletions src/ai-service/routers/description_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,53 +6,73 @@
from dotenv import load_dotenv
from typing import Any, List, Dict
import os
import dotenv
import requests
import json

# Load environment variables from .env file
load_dotenv()
# Set the useLocalLLM and useAzureOpenAI variables based on environment variables
useLocalLLM: bool = False
useAzureOpenAI: bool = False

# Initialize the semantic kernel
kernel: sk.Kernel = sk.Kernel()
if os.environ.get("USE_LOCAL_LLM"):
useLocalLLM = os.environ.get("USE_LOCAL_LLM").lower() == "true"

kernel = sk.Kernel()
if os.environ.get("USE_AZURE_OPENAI"):
useAzureOpenAI = os.environ.get("USE_AZURE_OPENAI").lower() == "true"

# Get the Azure OpenAI deployment name, API key, and endpoint or OpenAI org id from environment variables
useAzureOpenAI: str = os.environ.get("USE_AZURE_OPENAI")
api_key: str = os.environ.get("OPENAI_API_KEY")
useAzureAD: str = os.environ.get("USE_AZURE_AD")
# if useLocalLLM and useAzureOpenAI are both set to true, raise an exception
if useLocalLLM and useAzureOpenAI:
raise Exception("USE_LOCAL_LLM and USE_AZURE_OPENAI environment variables cannot both be set to true")

if (isinstance(api_key, str) == False or api_key == "") and (isinstance(useAzureAD, str) == False or useAzureAD == ""):
raise Exception("OPENAI_API_KEY environment variable must be set")
if isinstance(useAzureOpenAI, str) == False or (useAzureOpenAI.lower() != "true" and useAzureOpenAI.lower() != "false"):
raise Exception("USE_AZURE_OPENAI environment variable must be set to 'True' or 'False' string not boolean")
# if useLocalLLM or useAzureOpenAI are set to true, get the endpoint from the environment variables
if useLocalLLM or useAzureOpenAI:
endpoint: str = os.environ.get("AI_ENDPOINT") or os.environ.get("AZURE_OPENAI_ENDPOINT")

if isinstance(endpoint, str) == False or endpoint == "":
raise Exception("AI_ENDPOINT or AZURE_OPENAI_ENDPOINT environment variable must be set when USE_LOCAL_LLM or USE_AZURE_OPENAI is set to true")

# if not using local LLM, set up the semantic kernel
if useLocalLLM:
print("Using Local LLM")
else:
print("Using OpenAI and setting up Semantic Kernel")
# Load environment variables from .env file
load_dotenv()

if useAzureOpenAI.lower() == "false":
org_id = os.environ.get("OPENAI_ORG_ID")
if isinstance(org_id, str) == False or org_id == "":
raise Exception("OPENAI_ORG_ID environment variable must be set when USE_AZURE_OPENAI is set to False")
# Add the OpenAI text completion service to the kernel
kernel.add_chat_service("dv", OpenAIChatCompletion("gpt-3.5-turbo", api_key, org_id))
# Initialize the semantic kernel
kernel: sk.Kernel = sk.Kernel()

kernel = sk.Kernel()

# Get the Azure OpenAI deployment name, API key, and endpoint or OpenAI org id from environment variables
api_key: str = os.environ.get("OPENAI_API_KEY")
useAzureAD: str = os.environ.get("USE_AZURE_AD")

if (isinstance(api_key, str) == False or api_key == "") and (isinstance(useAzureAD, str) == False or useAzureAD == ""):
raise Exception("OPENAI_API_KEY environment variable must be set")

if not useAzureOpenAI:
org_id = os.environ.get("OPENAI_ORG_ID")
if isinstance(org_id, str) == False or org_id == "":
raise Exception("OPENAI_ORG_ID environment variable must be set when USE_AZURE_OPENAI is set to False")
# Add the OpenAI text completion service to the kernel
kernel.add_chat_service("dv", OpenAIChatCompletion("gpt-3.5-turbo", api_key, org_id))

else:
deployment: str = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
endpoint: str = os.environ.get("AZURE_OPENAI_ENDPOINT")
if isinstance(deployment, str) == False or isinstance(endpoint, str) == False or deployment == "" or endpoint == "":
raise Exception("AZURE_OPENAI_DEPLOYMENT_NAME and AZURE_OPENAI_ENDPOINT environment variables must be set when USE_AZURE_OPENAI is set to true")
# Add the Azure OpenAI text completion service to the kernel
if isinstance(useAzureAD, str) == True and useAzureAD.lower() == "true":
print("Authenticating to Azure OpenAI with Azure AD Workload Identity")
credential = DefaultAzureCredential()
access_token = credential.get_token("https://cognitiveservices.azure.com/.default")
kernel.add_chat_service("dv", AzureChatCompletion(deployment_name=deployment, endpoint=endpoint, api_key=access_token.token, ad_auth=True))
else:
print("Authenticating to Azure OpenAI with OpenAI API key")
kernel.add_chat_service("dv", AzureChatCompletion(deployment, endpoint, api_key))
deployment: str = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
# Add the Azure OpenAI text completion service to the kernel
if isinstance(useAzureAD, str) == True and useAzureAD.lower() == "true":
print("Authenticating to Azure OpenAI with Azure AD Workload Identity")
credential = DefaultAzureCredential()
access_token = credential.get_token("https://cognitiveservices.azure.com/.default")
kernel.add_chat_service("dv", AzureChatCompletion(deployment_name=deployment, endpoint=endpoint, api_key=access_token.token, ad_auth=True))
else:
print("Authenticating to Azure OpenAI with OpenAI API key")
kernel.add_chat_service("dv", AzureChatCompletion(deployment, endpoint, api_key))

# Import semantic skills from the "skills" directory
skills_directory: str = "skills"
productFunctions: dict = kernel.import_semantic_skill_from_directory(skills_directory, "ProductSkill")
descriptionFunction: Any = productFunctions["Description"]
# Import semantic skills from the "skills" directory
skills_directory: str = "skills"
productFunctions: dict = kernel.import_semantic_skill_from_directory(skills_directory, "ProductSkill")
descriptionFunction: Any = productFunctions["Description"]

# Define the description API router
description: APIRouter = APIRouter(prefix="/generate", tags=["generate"])
Expand All @@ -62,7 +82,7 @@ class Product:
def __init__(self, product: Dict[str, List]) -> None:
self.name: str = product["name"]
self.tags: List[str] = product["tags"]

# Define the post_description endpoint
@description.post("/description", summary="Get description for a product", operation_id="getDescription")
async def post_description(request: Request) -> JSONResponse:
Expand All @@ -73,15 +93,50 @@ async def post_description(request: Request) -> JSONResponse:
name: str = product.name
tags: List = ",".join(product.tags)

# Create a new context and invoke the description function
context: Any = kernel.create_new_context()
context["name"] = name
context["tags"] = tags
result: str = await descriptionFunction.invoke_async(context=context)
if "error" in str(result).lower():
return Response(content=str(result), status_code=status.HTTP_401_UNAUTHORIZED)
print(result)
result = str(result).replace("\n", "")
if useLocalLLM:
print("Calling local LLM")

prompt = f"Describe this pet store product using joyful, playful, and enticing language.\nProduct name: {name}\ntags: {tags}\ndescription:\""
temperature = 0.5
top_p = 0.0

url = endpoint
payload = {
"prompt": prompt,
"temperature": temperature,
"top_p": top_p
}
headers = {"Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, json=payload)

# convert response.text to json
result = json.loads(response.text)
result = result["Result"]
result = result.split("description:")[1]

# remove all double quotes
if "\"" in result:
result = result.replace("\"", "")

# # if first character is a double quote, remove it
# if result[0] == "\"":
# result = result[1:]
# # if last character is a double quote, remove it
# if result[-1] == "\"":
# result = result[:-1]

print(result)
else:
print("Calling OpenAI")
# Create a new context and invoke the description function
context: Any = kernel.create_new_context()
context["name"] = name
context["tags"] = tags
result: str = await descriptionFunction.invoke_async(context=context)
if "error" in str(result).lower():
return Response(content=str(result), status_code=status.HTTP_401_UNAUTHORIZED)
print(result)
result = str(result).replace("\n", "")

# Return the description as a JSON response
return JSONResponse(content={"description": result}, status_code=status.HTTP_200_OK)
Expand Down
2 changes: 1 addition & 1 deletion src/store-admin/src/components/ProductForm.vue
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
<div class="form-row">
<label for="product-description">Description</label>
<textarea id="product-description" placeholder="Product Description" v-model="product.description" />
<button @click="generateDescription" class="ai-button">Ask OpenAI</button>
<button @click="generateDescription" class="ai-button">Ask AI Assistant</button>
<input type="hidden" id="product-id" placeholder="Product ID" v-model="product.id" />
</div>

Expand Down

0 comments on commit fa0cb75

Please sign in to comment.