fix: Better handling of vllm vs non-vllm
ishaansehgal99 committed Jan 14, 2025
1 parent 45ba8cb commit fc6f0ba
Showing 1 changed file with 31 additions and 10 deletions.

presets/ragengine/inference/inference.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import logging
 from typing import Any
 from dataclasses import field
 from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen
@@ -10,12 +11,17 @@
 from urllib.parse import urlparse, urljoin
 from ragengine.config import LLM_INFERENCE_URL, LLM_ACCESS_SECRET #, LLM_RESPONSE_FIELD
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 OPENAI_URL_PREFIX = "https://api.openai.com"
 HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co"

 class Inference(CustomLLM):
     params: dict = field(default_factory=dict)
     _default_model: str = None
+    _custom_api_endpoint_type: str = None  # "vllm", "non-vllm", or None
 
     def set_params(self, params: dict) -> None:
         self.params = params
@@ -57,11 +63,24 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse

         # DEBUG: Call the debugging function
         # self._debug_curl_command(data)
-        return self._post_request(
-            data,
-            headers={"Authorization": f"Bearer {LLM_ACCESS_SECRET}", "Content-Type": "application/json"}
-        )
+        try:
+            return self._post_request(data, headers={
+                "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
+                "Content-Type": "application/json"
+            })
+        except Exception as e:
+            # Check for vLLM-specific missing model error
+            if "missing" in str(e) and "model" in str(e):
+                logger.warning("Detected missing 'model' parameter. Fetching default model and retrying...")
+                self._default_model = self._fetch_default_model()  # Fetch default model dynamically
+                if self._default_model:
+                    data["model"] = self._default_model
+                    return self._post_request(data, headers={
+                        "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
+                        "Content-Type": "application/json"
+                    })
+            raise  # Re-raise the exception if not recoverable
 
     def _get_models_endpoint(self) -> str:
         """
@@ -82,17 +101,19 @@ def _fetch_default_model(self) -> str:
             }
 
             response = requests.get(models_url, headers=headers)
-            response.raise_for_status()  # Raise an exception for HTTP errors
+            response.raise_for_status()  # Raise an exception for HTTP errors (includes 404)
 
             models = response.json().get("data", [])
+            self._custom_api_endpoint_type = "vllm"
             return models[0].get("id") if models else None
         except Exception as e:
-            print(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.")
+            logger.error(f"Error fetching models from {models_url}: {e}. \"model\" parameter will not be included with inference call.")
+            self._custom_api_endpoint_type = "non-vllm"
             return None

     def _get_default_model(self) -> str:
         """
-        Returns the cached default model if available; otherwise fetches and caches it.
+        Returns the cached default model if available, otherwise fetches and caches it.
         """
         if not self._default_model:
             self._default_model = self._fetch_default_model()
@@ -105,7 +126,7 @@ def _post_request(self, data: dict, headers: dict) -> CompletionResponse:
             response_data = response.json()
             return CompletionResponse(text=str(response_data))
         except requests.RequestException as e:
-            print(f"Error during POST request to {LLM_INFERENCE_URL}: {e}")
+            logger.error(f"Error during POST request to {LLM_INFERENCE_URL}: {e}")
             raise

     def _debug_curl_command(self, data: dict) -> None:
@@ -122,8 +143,8 @@ def _debug_curl_command(self, data: dict) -> None:
             }.items()])
             + f" -d '{json.dumps(data)}'"
         )
-        print("Equivalent curl command:")
-        print(curl_command)
+        logger.info("Equivalent curl command:")
+        logger.info(curl_command)

     @property
     def metadata(self) -> LLMMetadata:

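For context, here is a minimal, self-contained sketch of the fallback pattern this commit introduces. It is illustrative only, not the module itself: the URLs are placeholders, complete() is a hypothetical helper, and the detection heuristic mirrors the commit's check that a rejected request's error text mentions both "missing" and "model" (as vLLM's OpenAI-compatible validation error does), while non-vLLM endpoints never match and surface their original error.

# Sketch of the retry-with-default-model pattern (assumptions: an
# OpenAI-compatible server such as vLLM at the placeholder URLs below;
# omitting "model" makes vLLM return a validation error whose body
# mentions "missing" and "model").
import requests

INFERENCE_URL = "http://localhost:8000/v1/completions"  # placeholder
MODELS_URL = "http://localhost:8000/v1/models"          # placeholder

def complete(prompt: str) -> dict:
    data = {"prompt": prompt}
    resp = requests.post(INFERENCE_URL, json=data, timeout=30)
    if not resp.ok and "missing" in resp.text and "model" in resp.text:
        # vLLM path: probe the models endpoint and retry once with a default model.
        models = requests.get(MODELS_URL, timeout=30).json().get("data", [])
        if models:
            data["model"] = models[0].get("id")  # first served model as the default
            resp = requests.post(INFERENCE_URL, json=data, timeout=30)
    resp.raise_for_status()  # non-vLLM endpoints surface their original error here
    return resp.json()

A non-vLLM endpoint never takes the retry branch, which is the same vllm/non-vllm split the commit records in _custom_api_endpoint_type.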