feat: Custom params for llm
ishaansehgal99 committed Oct 7, 2024
1 parent be9d6ed commit cf24953
Showing 6 changed files with 25 additions and 11 deletions.
19 changes: 15 additions & 4 deletions ragengine/inference/custom_inference.py
@@ -6,17 +6,28 @@
 from config import INFERENCE_URL, INFERENCE_ACCESS_SECRET, RESPONSE_FIELD

 class CustomInference(CustomLLM):
+    params: dict = {}
+
+    def set_params(self, params: dict) -> None:
+        self.params = params
+
+    def get_param(self, key, default=None):
+        return self.params.get(key, default)

     @llm_completion_callback()
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
         pass

     @llm_completion_callback()
     def complete(self, prompt: str, **kwargs) -> CompletionResponse:
-        if "openai" in INFERENCE_URL:
-            return self._openai_complete(prompt, **kwargs)
-        else:
-            return self._custom_api_complete(prompt, **kwargs)
+        try:
+            if "openai" in INFERENCE_URL:
+                return self._openai_complete(prompt, **kwargs, **self.params)
+            else:
+                return self._custom_api_complete(prompt, **kwargs, **self.params)
+        finally:
+            # Clear params after the completion is done
+            self.params = {}

     def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
         llm = OpenAI(

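The new `set_params`/`params` pair lets a caller attach per-request generation options that `complete` forwards to whichever backend is selected, then discards in the `finally` block so they cannot leak into the next request. A minimal usage sketch, assuming `CustomInference` can be instantiated with no arguments and that the backend accepts the option names; the import path is inferred from the file path above and the values are illustrative, not part of this commit:

    from ragengine.inference.custom_inference import CustomInference  # path assumed from the diff header

    llm = CustomInference()
    llm.set_params({"temperature": 0.7, "max_tokens": 256})  # illustrative per-request options
    response = llm.complete("Summarize the indexed documents.")
    assert llm.params == {}  # cleared by the finally block after the call
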
1 change: 1 addition & 0 deletions ragengine/models.py
@@ -14,6 +14,7 @@ class IndexRequest(BaseModel):
 class QueryRequest(BaseModel):
     query: str
     top_k: int = 10
+    params: Optional[Dict] = None  # Accept a dictionary for parameters

 class UpdateRequest(BaseModel):
     documents: List[Document]

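On the API side, `QueryRequest` now carries an optional `params` dictionary, so a client can send generation options alongside the query itself. A hypothetical construction of the request model (the endpoint that consumes `QueryRequest` is not shown in this diff; the import path is inferred from the file path above and the values are illustrative):

    from ragengine.models import QueryRequest  # path assumed from the diff header

    request = QueryRequest(query="First", top_k=1, params={"temperature": 0.7})
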
6 changes: 3 additions & 3 deletions ragengine/tests/test_faiss_store.py
@@ -20,7 +20,6 @@ def vector_store_manager(init_embed_manager):
     os.environ['PERSIST_DIR'] = temp_dir
     yield FaissVectorStoreHandler(init_embed_manager)

-
 def test_index_documents(vector_store_manager):
     documents = [
         Document(doc_id="1", text="First document", metadata={"type": "text"}),
@@ -73,16 +72,17 @@ def test_query_documents(mock_post, vector_store_manager):
     ]
     vector_store_manager.index_documents(documents, index_name="test_index")

+    params = {"temperature": 0.7}
     # Mock query and results
-    query_result = vector_store_manager.query("First", top_k=1, index_name="test_index")
+    query_result = vector_store_manager.query("First", top_k=1, index_name="test_index", params=params)

     assert query_result is not None
     assert query_result.response == "This is the completion from the API"

     mock_post.assert_called_once_with(
         INFERENCE_URL,
         # Auto-Generated by LlamaIndex
-        json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True},
+        json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7},
         headers={"Authorization": f"Bearer {INFERENCE_ACCESS_SECRET}"}
     )

2 changes: 1 addition & 1 deletion ragengine/vector_store/base.py
@@ -10,7 +10,7 @@ def index_documents(self, documents: List[Document], index_name: str) -> List[str]:
         pass

     @abstractmethod
-    def query(self, query: str, top_k: int, index_name: str):
+    def query(self, query: str, top_k: int, index_name: str, params: dict):
         pass

     @abstractmethod

4 changes: 3 additions & 1 deletion ragengine/vector_store/faiss_store.py
@@ -76,10 +76,12 @@ def add_document(self, document: Document, index_name: str):
         self.index_map[index_name].insert(llama_doc)
         self._persist(index_name)

-    def query(self, query: str, top_k: int, index_name: str):
+    def query(self, query: str, top_k: int, index_name: str, params: dict):
         """Queries the FAISS vector store."""
         if index_name not in self.index_map:
             raise ValueError(f"No such index: '{index_name}' exists.")
+        self.llm.set_params(params)
+
         query_engine = self.index_map[index_name].as_query_engine(llm=self.llm, similarity_top_k=top_k)
         return query_engine.query(query)

4 changes: 2 additions & 2 deletions ragengine/vector_store_manager/manager.py
@@ -12,9 +12,9 @@ def create(self, documents: List[Document]) -> List[str]:
         """Index new documents."""
         return self.vector_store.index_documents(documents)

-    def read(self, query: str, top_k: int):
+    def read(self, query: str, top_k: int, params: dict):
         """Query the indexed documents."""
-        return self.vector_store.query(query, top_k)
+        return self.vector_store.query(query, top_k, params)

     """
     def update(self, documents: List[Document]) -> Dict[str, List[str]]:

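Taken together, the six files thread the dictionary from the request model down to the inference call: `read()` hands `params` to the vector store, `FaissVectorStoreHandler.query()` stores it on the LLM via `set_params`, and `CustomInference.complete()` merges it into the outgoing request before clearing it. A sketch of that flow using names visible in the diffs above; the handler construction mirrors the test fixture and is an assumption here, as are the parameter values:

    handler = FaissVectorStoreHandler(embed_manager)  # as in the test fixture; embed_manager is assumed
    result = handler.query(
        "First",
        top_k=1,
        index_name="test_index",
        params={"temperature": 0.7},  # stored on the LLM via set_params, merged into complete()
    )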
