From d25d26427b7462e395ab5840573ffcab4c06b36c Mon Sep 17 00:00:00 2001 From: Bangqi Zhu Date: Thu, 9 Jan 2025 19:36:36 -0800 Subject: [PATCH] RAG server patch for consistency Signed-off-by: Bangqi Zhu --- presets/ragengine/inference/inference.py | 40 +++++++++++++++++- presets/ragengine/main.py | 4 +- presets/ragengine/models.py | 4 +- presets/ragengine/tests/api/test_main.py | 41 ++++++++++++++++--- .../tests/vector_store/test_base_store.py | 24 +++++++++-- presets/ragengine/vector_store/base.py | 3 ++ 6 files changed, 102 insertions(+), 14 deletions(-) diff --git a/presets/ragengine/inference/inference.py b/presets/ragengine/inference/inference.py index f48248463..8576cb667 100644 --- a/presets/ragengine/inference/inference.py +++ b/presets/ragengine/inference/inference.py @@ -6,6 +6,7 @@ from llama_index.llms.openai import OpenAI from llama_index.core.llms.callbacks import llm_completion_callback import requests +from urllib.parse import urlparse, urljoin from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD OPENAI_URL_PREFIX = "https://api.openai.com" @@ -13,12 +14,41 @@ class Inference(CustomLLM): params: dict = {} + model: str = "" def set_params(self, params: dict) -> None: self.params = params def get_param(self, key, default=None): return self.params.get(key, default) + # Get base URL + def _get_base_url(self) -> str: + parsed = urlparse(LLM_INFERENCE_URL) + base_url = f"{parsed.scheme}://{parsed.netloc}" + return urljoin(base_url, "/v1/models") + + #Fetch and set the model from the inference endpoint + def set_model(self) -> None: + + try: + models_url = self._get_base_url() + headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} + response = requests.get(models_url, headers=headers) + + if response.status_code == 404: + self.model = None + return + + response.raise_for_status() + + data = response.json() + if data.get("data") and len(data["data"]) > 0: + self.model = data["data"][0]["id"] + else: + raise ValueError("No model found in response") + + except requests.RequestException as e: + raise Exception(f"Failed to fetch model information: {str(e)}") @llm_completion_callback() def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: @@ -53,8 +83,14 @@ def _huggingface_remote_complete(self, prompt: str, **kwargs: Any) -> Completion def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: headers = {"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} - data = {"prompt": prompt, **kwargs} - + if self.model != None: + data = {"prompt": prompt, "model":self.model} + else: + data = {"prompt": prompt} + + for param in self.params: + data[param] = self.params[param] + response = requests.post(LLM_INFERENCE_URL, json=data, headers=headers) response_data = response.json() diff --git a/presets/ragengine/main.py b/presets/ragengine/main.py index 56f891178..f3e63794a 100644 --- a/presets/ragengine/main.py +++ b/presets/ragengine/main.py @@ -60,7 +60,9 @@ async def index_documents(request: IndexRequest): # TODO: Research async/sync wh @app.post("/query", response_model=QueryResponse) async def query_index(request: QueryRequest): try: - llm_params = request.llm_params or {} # Default to empty dict if no params provided + llm_params = {} + for key, value in request.model_extra.items(): + llm_params[key] = value rerank_params = request.rerank_params or {} # Default to empty dict if no params provided return rag_ops.query(request.index_name, request.query, request.top_k, llm_params, rerank_params) except Exception as e: diff --git a/presets/ragengine/models.py b/presets/ragengine/models.py index a1b2ff529..1c5ba2616 100644 --- a/presets/ragengine/models.py +++ b/presets/ragengine/models.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class Document(BaseModel): text: str @@ -22,7 +22,7 @@ class QueryRequest(BaseModel): index_name: str query: str top_k: int = 10 - llm_params: Optional[Dict] = None # Accept a dictionary for parameters + model_config = ConfigDict(extra='allow') rerank_params: Optional[Dict] = None # Accept a dictionary for parameters class ListDocumentsResponse(BaseModel): diff --git a/presets/ragengine/tests/api/test_main.py b/presets/ragengine/tests/api/test_main.py index fee67dd7b..4e3947ddc 100644 --- a/presets/ragengine/tests/api/test_main.py +++ b/presets/ragengine/tests/api/test_main.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from unittest.mock import patch +from unittest.mock import patch, MagicMock from llama_index.core.storage.index_store import SimpleIndexStore @@ -39,7 +39,15 @@ def test_index_documents_success(): assert not doc2["metadata"] @patch('requests.post') -def test_query_index_success(mock_post): +@patch('requests.get') +def test_query_index_success(mock_get, mock_post): + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + "data": [{"id": "test-model"}] + } + ) + # Define Mock Response for Custom Inference API mock_response = { "result": "This is the completion from the API" @@ -62,7 +70,7 @@ def test_query_index_success(mock_post): "index_name": "test_index", "query": "test query", "top_k": 1, - "llm_params": {"temperature": 0.7} + "temperature": 0.7 } response = client.post("/query", json=request_data) @@ -73,6 +81,28 @@ def test_query_index_success(mock_post): assert response.json()["source_nodes"][0]["score"] == pytest.approx(0.5354418754577637, rel=1e-6) assert response.json()["source_nodes"][0]["metadata"] == {} assert mock_post.call_count == 1 + assert mock_get.call_count == 1 + +@patch('requests.get') +def test_query_index_model_not_found(mock_get): + mock_get.return_value = MagicMock(status_code=404) + + request_data = { + "index_name": "test_index", + "query": "test query", + "top_k": 1, + "temperature": 0.7 + } + + index_data = { + "index_name": "test_index", + "documents": [{"text": "Test document"}] + } + response = client.post("/index", json=index_data) + assert response.status_code == 200 + + response = client.post("/query", json=request_data) + assert response.status_code == 200 @patch('requests.post') @@ -135,7 +165,7 @@ def test_reranker_and_query_with_index(mock_post): "index_name": "test_index", "query": "what is the capital of france?", "top_k": 5, - "llm_params": {"temperature": 0.7}, + "temperature": 0.7, "rerank_params": {"top_n": top_n} } @@ -171,14 +201,13 @@ def test_query_index_failure(): "index_name": "non_existent_index", # Use an index name that doesn't exist "query": "test query", "top_k": 1, - "llm_params": {"temperature": 0.7} + "temperature": 0.7 } response = client.post("/query", json=request_data) assert response.status_code == 500 assert response.json()["detail"] == "No such index: 'non_existent_index' exists." - def test_list_all_indexed_documents_success(): response = client.get("/indexed-documents") assert response.status_code == 200 diff --git a/presets/ragengine/tests/vector_store/test_base_store.py b/presets/ragengine/tests/vector_store/test_base_store.py index d3f49848f..5af58c764 100644 --- a/presets/ragengine/tests/vector_store/test_base_store.py +++ b/presets/ragengine/tests/vector_store/test_base_store.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import os -from unittest.mock import patch +from unittest.mock import patch, MagicMock import pytest from abc import ABC, abstractmethod @@ -66,7 +66,15 @@ def check_indexed_documents(self, vector_store_manager): pass @patch('requests.post') - def test_query_documents(self, mock_post, vector_store_manager): + @patch('requests.get') + def test_query_documents(self, mock_get, mock_post, vector_store_manager): + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + "data": [{"id": "test-model"}] + } + ) + mock_response = { "result": "This is the completion from the API" } @@ -87,13 +95,23 @@ def test_query_documents(self, mock_post, vector_store_manager): assert query_result["source_nodes"][0]["text"] == "First document" assert query_result["source_nodes"][0]["score"] == pytest.approx(self.expected_query_score, rel=1e-6) + mock_get.assert_called_once() + mock_post.assert_called_once_with( LLM_INFERENCE_URL, json={"prompt": "Context information is below.\n---------------------\ntype: text\n\nFirst document\n---------------------\nGiven the context information and not prior knowledge, answer the query.\nQuery: First\nAnswer: ", "formatted": True, 'temperature': 0.7}, headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}"} ) - def test_add_document(self, vector_store_manager): + @patch('requests.get') + def test_add_document(self, mock_get, vector_store_manager): + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + "data": [{"id": "test-model"}] + } + ) + documents = [Document(text="Third document", metadata={"type": "text"})] vector_store_manager.index_documents("test_index", documents) diff --git a/presets/ragengine/vector_store/base.py b/presets/ragengine/vector_store/base.py index fd45b9c38..9b8d4a2d4 100644 --- a/presets/ragengine/vector_store/base.py +++ b/presets/ragengine/vector_store/base.py @@ -111,6 +111,9 @@ def query(self, """ if index_name not in self.index_map: raise ValueError(f"No such index: '{index_name}' exists.") + if self.llm.model == "": + self.llm.set_model() + self.llm.set_params(llm_params) node_postprocessors = []