Merge branch 'main' of https://github.com/Azure/kaito into Ishaan/RAG
ishaansehgal99 committed Oct 21, 2024
2 parents d32169f + 941170b commit 5520950
Showing 21 changed files with 640 additions and 290 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -36,7 +36,7 @@ The above figure presents the Kaito architecture overview. Its major components
## Installation

-Please check the installation guidance [here](./docs/installation.md).
+Please check the installation guidance [here](./docs/installation.md) for deployment using Azure CLI and [here](./terraform/README.md) for deployment using Terraform.

## Quick start

2 changes: 1 addition & 1 deletion presets/test/falcon-benchmark/README.md
@@ -23,7 +23,7 @@ Ensure your `accelerate` configuration aligns with the values provided during be
- If you haven't already, you can use the Azure CLI or the Azure Portal to create and configure a GPU node pool in your AKS cluster.
<!-- markdown-link-check-disable -->
2. Building and Pushing the Docker Image:
-- First, you need to build a Docker image from the provided [Dockerfile](https://github.com/Azure/kaito/blob/main/docker/presets/inference/tfs/Dockerfile) and push it to a container registry accessible by your AKS cluster
+- First, you need to build a Docker image from the provided [Dockerfile](https://github.com/Azure/kaito/blob/main/docker/presets/models/tfs/Dockerfile) and push it to a container registry accessible by your AKS cluster
<!-- markdown-link-check-enable -->
- Example:
```
...
```
20 changes: 20 additions & 0 deletions ragengine/config.py
@@ -0,0 +1,20 @@
# config.py

# Configuration values are set via environment variables sourced from the
# RAGEngine CR and exposed to the pod. For example, InferenceURL is specified
# in the CR and passed to the pod as the INFERENCE_URL environment variable.

import os

EMBEDDING_TYPE = os.getenv("EMBEDDING_TYPE", "local")
EMBEDDING_URL = os.getenv("EMBEDDING_URL")

INFERENCE_URL = os.getenv("INFERENCE_URL", "http://localhost:5000/chat")
INFERENCE_ACCESS_SECRET = os.getenv("AccessSecret", "default-inference-secret")
# RESPONSE_FIELD = os.getenv("RESPONSE_FIELD", "result")

MODEL_ID = os.getenv("MODEL_ID", "BAAI/bge-small-en-v1.5")
VECTOR_DB_TYPE = os.getenv("VECTOR_DB_TYPE", "faiss")
INDEX_SERVICE_NAME = os.getenv("INDEX_SERVICE_NAME", "default-index-service")
ACCESS_SECRET = os.getenv("ACCESS_SECRET", "default-access-secret")
PERSIST_DIR = "storage"
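
For illustration, this is how the defaults above resolve at import time (a minimal sketch; the environment variable values are hypothetical stand-ins for what the RAGEngine CR would inject, and they must be set before the module is first imported):

```python
import os

# Hypothetical values, as the RAGEngine CR would expose them to the pod.
os.environ["INFERENCE_URL"] = "http://workspace-inference:80/chat"
os.environ["EMBEDDING_TYPE"] = "local"

from ragengine import config  # env must be set before this first import

print(config.INFERENCE_URL)  # http://workspace-inference:80/chat
print(config.MODEL_ID)       # BAAI/bge-small-en-v1.5 (default; env var unset)
```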
53 changes: 53 additions & 0 deletions ragengine/inference/inference.py
@@ -0,0 +1,53 @@
from typing import Any
from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen
from llama_index.llms.openai import OpenAI
from llama_index.core.llms.callbacks import llm_completion_callback
import requests
from ragengine.config import INFERENCE_URL, INFERENCE_ACCESS_SECRET #, RESPONSE_FIELD

class Inference(CustomLLM):
    params: dict = {}

    def set_params(self, params: dict) -> None:
        self.params = params

    def get_param(self, key, default=None):
        return self.params.get(key, default)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # Streaming is not implemented yet; fail loudly instead of silently returning None.
        raise NotImplementedError("Streaming is not supported yet.")

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
        try:
            if "openai" in INFERENCE_URL:
                return self._openai_complete(prompt, **kwargs, **self.params)
            else:
                return self._custom_api_complete(prompt, **kwargs, **self.params)
        finally:
            # Clear params after the completion is done
            self.params = {}

    def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        llm = OpenAI(
            api_key=INFERENCE_ACCESS_SECRET,
            **kwargs  # Pass all kwargs directly; kwargs may include model, temperature, max_tokens, etc.
        )
        return llm.complete(prompt)

    def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        headers = {"Authorization": f"Bearer {INFERENCE_ACCESS_SECRET}"}
        data = {"prompt": prompt, **kwargs}

        response = requests.post(INFERENCE_URL, json=data, headers=headers)
        response_data = response.json()

        # Dynamically extract the field from the response based on the specified response_field
        # completion_text = response_data.get(RESPONSE_FIELD, "No response field found")  # not necessary for now
        return CompletionResponse(text=str(response_data))

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata()
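
A minimal usage sketch of the class above (assumes a reachable endpoint at INFERENCE_URL; the parameter values are illustrative, not class defaults):

```python
from ragengine.inference.inference import Inference

llm = Inference()
llm.set_params({"temperature": 0.2, "max_tokens": 128})  # cleared after one completion
# Routes to OpenAI when "openai" appears in INFERENCE_URL, else to the custom API.
response = llm.complete("What is Kaito?")
print(response.text)
```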
1 change: 0 additions & 1 deletion ragengine/models.py
@@ -5,7 +5,6 @@
class Document(BaseModel):
    text: str
    metadata: Optional[dict] = {}
-    doc_id: Optional[str] = None

class IndexRequest(BaseModel):
    index_name: str
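
With `doc_id` removed, callers construct documents from content alone; document identity is instead derived from a content hash (see `BaseVectorStore.generate_doc_id` in the `ragengine/vector_store/base.py` diff below). A short sketch:

```python
from ragengine.models import Document

# No doc_id field anymore; identity comes from hashing the text.
doc = Document(text="Kaito manages AI model deployment on AKS.",
               metadata={"source": "readme"})
```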
1 change: 1 addition & 0 deletions ragengine/requirements.txt
@@ -1,3 +1,4 @@
# RAG Library Requirements
llama-index
llama-index-embeddings-huggingface
fastapi
97 changes: 7 additions & 90 deletions ragengine/tests/vector_store/test_faiss_store.py
@@ -3,10 +3,10 @@
from unittest.mock import patch

import pytest
-from vector_store.faiss_store import FaissVectorStoreHandler
-from models import Document
-from embedding.huggingface_local import LocalHuggingFaceEmbedding
-from config import MODEL_ID, INFERENCE_URL, INFERENCE_ACCESS_SECRET
+from ragengine.vector_store.faiss_store import FaissVectorStoreHandler
+from ragengine.models import Document
+from ragengine.embedding.huggingface_local import LocalHuggingFaceEmbedding
+from ragengine.config import MODEL_ID, INFERENCE_URL, INFERENCE_ACCESS_SECRET

@pytest.fixture(scope='session')
def init_embed_manager():
@@ -15,7 +15,7 @@ def init_embed_manager():
@pytest.fixture
def vector_store_manager(init_embed_manager):
    with TemporaryDirectory() as temp_dir:
-        print(f"Saving Temporary Test Storage at: {temp_dir}")
+        print(f"Saving temporary test storage at: {temp_dir}")
        # Mock the persistence directory
        os.environ['PERSIST_DIR'] = temp_dir
        yield FaissVectorStoreHandler(init_embed_manager)
@@ -86,100 +86,17 @@ def test_query_documents(mock_post, vector_store_manager):
        headers={"Authorization": f"Bearer {INFERENCE_ACCESS_SECRET}"}
    )
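
(The assertion tail above belongs to a test that patches `requests.post`, so no real inference call is made. A condensed sketch of that pattern; the stubbed JSON payload is assumed, not taken from the diff:)

```python
from unittest.mock import patch

@patch('requests.post')
def test_query_documents(mock_post, vector_store_manager):
    # Stub the inference endpoint; this payload is illustrative only.
    mock_post.return_value.json.return_value = {"response": "mocked answer"}
    ...
```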

"""
Commented because Refresh, Update, and Delete functionality are commented
def test_add_and_delete_document(vector_store_manager, capsys):
def test_add_document(vector_store_manager, capsys):
documents = [Document(doc_id="3", text="Third document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)

# Add a document to the existing index
new_document = Document(doc_id="4", text="Fourth document", metadata={"type": "text"})
vector_store_manager.add_document("test_index", new_document)
vector_store_manager.index_documents("test_index", new_document)

# Assert that the document exists
assert vector_store_manager.document_exists("test_index", "4")

# Delete the document - it should handle the NotImplementedError and not raise an exception
vector_store_manager.delete_document("test_index", "4")
# Capture the printed output (if any)
captured = capsys.readouterr()
# Check if the expected message about NotImplementedError was printed
assert "Delete not yet implemented for Faiss index. Skipping document 4." in captured.out
# Assert that the document still exists (since deletion wasn't implemented)
assert vector_store_manager.document_exists("test_index", "4")
def test_update_document_not_implemented(vector_store_manager, capsys):
# Test that updating a document raises a NotImplementedError and is handled properly.
# Add a document to the index
documents = [Document(doc_id="1", text="First document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
# Attempt to update the existing document
updated_document = Document(doc_id="1", text="Updated first document", metadata={"type": "text"})
vector_store_manager.update_document("test_index", updated_document)
# Capture the printed output (if any)
captured = capsys.readouterr()
# Check if the NotImplementedError message was printed
assert "Update is equivalent to deleting the document and then inserting it again." in captured.out
assert f"Update not yet implemented for Faiss index. Skipping document {updated_document.doc_id}." in captured.out
# Ensure the document remains unchanged
original_doc = vector_store_manager.get_document("test_index", "1")
assert original_doc is not None
def test_refresh_unchanged_documents(vector_store_manager, capsys):
# Test that refreshing documents does nothing on unchanged documents.
# Add documents to the index
documents = [Document(doc_id="1", text="First document", metadata={"type": "text"}),
Document(doc_id="2", text="Second document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", documents)
refresh_results = vector_store_manager.refresh_documents("test_index", documents)
# Capture the printed output (if any)
captured = capsys.readouterr()
assert captured.out == ""
assert refresh_results == [False, False]
def test_refresh_new_documents(vector_store_manager):
# Test that refreshing new documents creates them.
vector_store_manager.index_documents("test_index", [])
# Add a document to the index
documents = [Document(doc_id="1", text="First document", metadata={"type": "text"}),
Document(doc_id="2", text="Second document", metadata={"type": "text"})]
refresh_results = vector_store_manager.refresh_documents("test_index", documents)
inserted_documents = vector_store_manager.list_all_documents("test_index")
assert len(inserted_documents) == len(documents)
assert inserted_documents.keys() == {"1", "2"}
assert refresh_results == [True, True]
def test_refresh_existing_documents(vector_store_manager, capsys):
# Test that refreshing existing documents prints error.
original_documents = [Document(doc_id="1", text="First document", metadata={"type": "text"})]
vector_store_manager.index_documents("test_index", original_documents)
new_documents = [Document(doc_id="1", text="Updated document", metadata={"type": "text"}),
Document(doc_id="2", text="Second document", metadata={"type": "text"})]
refresh_results = vector_store_manager.refresh_documents("test_index", new_documents)
captured = capsys.readouterr()
# Check if the NotImplementedError message was printed
assert "Refresh not yet fully implemented for index" in captured.out
assert not refresh_results
"""

def test_persist_and_load_index_store(vector_store_manager):
    """Test that the index store is persisted and loaded correctly."""
    # Add a document and persist the index
25 changes: 6 additions & 19 deletions ragengine/vector_store/base.py
@@ -1,11 +1,16 @@
from abc import ABC, abstractmethod
from typing import Dict, List

-from models import Document
+from ragengine.models import Document
from llama_index.core import VectorStoreIndex
+import hashlib


class BaseVectorStore(ABC):
+    @staticmethod
+    def generate_doc_id(text: str) -> str:
+        """Generates a unique document ID based on the hash of the document text."""
+        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    @abstractmethod
    def index_documents(self, index_name: str, documents: List[Document]) -> List[str]:
        pass
@@ -18,24 +18,6 @@ def query(self, index_name: str, query: str, top_k: int, params: dict):
    @abstractmethod
    def add_document(self, index_name: str, document: Document):
        pass

-    """
-    @abstractmethod
-    def delete_document(self, doc_id: str, index_name: str):
-        pass
-    @abstractmethod
-    def update_document(self, document: Document, index_name: str) -> str:
-        pass
-    @abstractmethod
-    def refresh_documents(self, documents: List[Document], index_name: str) -> List[bool]:
-        pass
-    """
-
-    @abstractmethod
-    def get_document(self, index_name: str, doc_id: str) -> Document:
-        pass
-
    @abstractmethod
    def list_all_indexed_documents(self) -> Dict[str, VectorStoreIndex]:
        pass
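
As a small illustration of the content-addressed ID scheme introduced above (pure standard library; a re-statement of `generate_doc_id`, nothing beyond what the diff shows):

```python
import hashlib

def generate_doc_id(text: str) -> str:
    """Mirror of BaseVectorStore.generate_doc_id: hash text to a stable ID."""
    return hashlib.sha256(text.encode('utf-8')).hexdigest()

# Identical text always yields the identical ID, so re-indexing the same
# document is idempotent and no explicit doc_id field is needed.
assert generate_doc_id("Third document") == generate_doc_id("Third document")
print(generate_doc_id("Third document")[:12])  # first 12 hex chars of the ID
```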
