feat: add more logging
ishaansehgal99 committed Jan 14, 2025
1 parent fc6f0ba commit dc3a0be
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions presets/ragengine/inference/inference.py
@@ -17,6 +17,10 @@
 
 OPENAI_URL_PREFIX = "https://api.openai.com"
 HUGGINGFACE_URL_PREFIX = "https://api-inference.huggingface.co"
+DEFAULT_HEADERS = {
+    "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
+    "Content-Type": "application/json"
+}
 
 class Inference(CustomLLM):
     params: dict = field(default_factory=dict)
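
Annotation: the hunk above replaces per-call-site header dicts with one module-level constant. A minimal sketch of the resulting pattern, assuming LLM_ACCESS_SECRET is read from config at import time; the post_json helper, example secret, and URL are hypothetical, for illustration only.

import requests

LLM_ACCESS_SECRET = "sk-example"  # hypothetical; the real value comes from config
DEFAULT_HEADERS = {
    "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
    "Content-Type": "application/json",
}

def post_json(url: str, payload: dict) -> dict:
    # Every call site reuses the module-level headers instead of rebuilding them.
    response = requests.post(url, json=payload, headers=DEFAULT_HEADERS)
    response.raise_for_status()
    return response.json()

One consequence of this design: the f-string is evaluated once at import, so rotating LLM_ACCESS_SECRET requires rebuilding the dict or restarting the process.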
@@ -64,24 +68,24 @@ def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse
         # DEBUG: Call the debugging function
         # self._debug_curl_command(data)
         try:
-            return self._post_request(data, headers={
-                "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
-                "Content-Type": "application/json"
-            })
+            return self._post_request(data, headers=DEFAULT_HEADERS)
         except Exception as e:
+            err_msg = str(e)
             # Check for vLLM-specific missing model error
-            if "missing" in str(e) and "model" in str(e):
-                logger.warning("Detected missing 'model' parameter. Fetching default model and retrying...")
+            if "missing" in err_msg and "model" in err_msg and "Field required" in err_msg:
+                logger.warning(
+                    f"Detected missing 'model' parameter in API response. "
+                    f"Response: {err_msg}. Attempting to fetch the default model..."
+                )
                 self._default_model = self._fetch_default_model()  # Fetch default model dynamically
                 if self._default_model:
                     logger.info(f"Default model '{self._default_model}' fetched successfully. Retrying request...")
                     data["model"] = self._default_model
-                    return self._post_request(data, headers={
-                        "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
-                        "Content-Type": "application/json"
-                    })
+                    return self._post_request(data, headers=DEFAULT_HEADERS)
                 else:
                     logger.error("Failed to fetch a default model. Aborting retry.")
             raise  # Re-raise the exception if not recoverable
 
     def _get_models_endpoint(self) -> str:
         """
         Constructs the URL for the /v1/models endpoint based on LLM_INFERENCE_URL.
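
Annotation: the retry path above now requires all three substrings before treating a failure as the vLLM missing-model error, which narrows the match to FastAPI/Pydantic-style validation payloads. A self-contained sketch of the check against an assumed error shape (the exact body vLLM returns may differ):

SAMPLE_ERROR = (
    '{"detail":[{"type":"missing","loc":["body","model"],'
    '"msg":"Field required"}]}'
)  # assumed validation-error shape, for illustration only

def is_missing_model_error(err_msg: str) -> bool:
    # Mirrors the condition in _custom_api_complete above.
    return ("missing" in err_msg
            and "model" in err_msg
            and "Field required" in err_msg)

assert is_missing_model_error(SAMPLE_ERROR)
assert not is_missing_model_error("500 Internal Server Error")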
@@ -95,12 +99,7 @@ def _fetch_default_model(self) -> str:
         """
         try:
             models_url = self._get_models_endpoint()
-            headers = {
-                "Authorization": f"Bearer {LLM_ACCESS_SECRET}",
-                "Content-Type": "application/json"
-            }
-
-            response = requests.get(models_url, headers=headers)
+            response = requests.get(models_url, headers=DEFAULT_HEADERS)
             response.raise_for_status()  # Raise an exception for HTTP errors (includes 404)
 
             models = response.json().get("data", [])
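
Annotation: _fetch_default_model reads the "data" list of the /v1/models response; the rest of the method is collapsed in this diff. Assuming the OpenAI-compatible schema, where each entry carries the model name under "id" (an assumption, not shown above), a plausible continuation picks the first entry:

payload = {
    "object": "list",
    "data": [{"id": "example-model", "object": "model"}],  # hypothetical entry
}

models = payload.get("data", [])
default_model = models[0]["id"] if models else None
print(default_model)  # -> example-model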