
Commit

Fix review comments.
vinaysetty committed Jan 28, 2024
1 parent 413b1d3 commit 9212173
Showing 2 changed files with 48 additions and 40 deletions.
42 changes: 24 additions & 18 deletions pkg_api/nl_to_pkg/llm/llm_connector.py
@@ -1,12 +1,10 @@
"""Module for querying LLM."""
import yaml

import os
from typing import Any, Dict

import yaml
from ollama import Client, Options
import os

_DEFAULT_ENDPOINT = "http://badne4.ux.uis.no:11434"
_DEFAULT_CONFIG_PATH = "pkg_api/nl_to_pkg/llm/configs/llm_config_llama2.yaml"


@@ -19,21 +17,33 @@ def __init__(
         Args:
             config_path: Path to the config file.
         Raises:
+            ValueError: If no model is specified in the config.
+            ValueError: If no host is specified in the config.
             FileNotFoundError: If the config file is not found.
         """
-        self._config = self._load_config(config_path)
-        self._model = self._config.get("model", "llama2")
+        self._config_path = config_path
+        self._config = self._load_config()
+        if "model" not in self._config:
+            raise ValueError(
+                "No model specified in the config. For example llama2"
+            )
+        if "host" not in self._config:
+            raise ValueError("No host specified in the config.")
+        self._client = Client(host=self._config.get("host"))
+        self._model = self._config.get("model")
         self._stream = self._config.get("stream", False)
-        self._client = Client(host=self._config.get("host", _DEFAULT_ENDPOINT))
         self._llm_options = self._get_llm_config()

-    def _generate(self, prompt: str) -> str:
+    def _generate(self, prompt: str) -> Dict[str, Any]:
         """Generates a response from LLM.
         Args:
             prompt: The prompt to be sent to LLM.
         Returns:
-            The response with metadata from LLM.
+            The dict with response and metadata from LLM.
         """
         return self._client.generate(
             self._model, prompt, options=self._llm_options, stream=self._stream
@@ -43,29 +53,25 @@ def get_response(self, prompt: str) -> str:
         """Returns the response from LLM.
         Args:
             prompt: The prompt to be sent to LLM.
         Returns:
             The response from LLM, if it was successful.
         """
         return self._generate(prompt).get("response", "")  # type: ignore

-    @staticmethod
-    def _load_config(config_path: str) -> Dict[str, Any]:
+    def _load_config(self) -> Dict[str, Any]:
         """Loads the config from the given path.
-        Args:
-            config_path: Path to the config file.
         Raises:
             FileNotFoundError: If the file is not found.
         Returns:
             A dictionary containing the config keys and values.
         """
-        if not os.path.isfile(config_path):
-            raise FileNotFoundError(f"File {config_path} not found.")
-        with open(config_path, "r") as file:
+        if not os.path.isfile(self._config_path):
+            raise FileNotFoundError(f"File {self._config_path} not found.")
+        with open(self._config_path, "r") as file:
            yaml_data = yaml.safe_load(file)
        return yaml_data

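For context, a minimal usage sketch of the updated connector, assuming a YAML config that provides the now-required model and host keys (the file path and values below are illustrative, not part of this commit):

# example_config.yaml (hypothetical contents):
#   model: llama2
#   host: http://localhost:11434
#   stream: false

from pkg_api.nl_to_pkg.llm.llm_connector import LLMConnector

# __init__ raises ValueError if "model" or "host" is missing from the config,
# and FileNotFoundError if the config file itself does not exist.
connector = LLMConnector("example_config.yaml")
print(connector.get_response("What is a personal knowledge graph?"))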
46 changes: 24 additions & 22 deletions tests/nl_to_pkg/test_llm_connector.py
@@ -1,9 +1,11 @@
"""Tests for LLM connector."""

from unittest.mock import mock_open

import pytest
from pytest_mock import MockerFixture

from pkg_api.nl_to_pkg.llm.llm_connector import LLMConnector
from unittest.mock import mock_open


@pytest.fixture
@@ -27,7 +29,11 @@ def llm_connector_mistral() -> LLMConnector:
 def test_get_response_request_params_default(
     llm_connector_default: LLMConnector,
 ) -> None:
-    """Tests that get_response sends the correct request to LLM."""
+    """Tests that get_response sends the correct request to LLM.
+
+    Args:
+        llm_connector_default: LLMConnector instance.
+    """
     response = llm_connector_default._generate("Test prompt")
     assert response
     assert response["model"] == "llama2"
@@ -40,7 +46,11 @@ def test_get_response_request_params_default(
 def test_get_response_request_params_llama(
     llm_connector_llama2: LLMConnector,
 ) -> None:
-    """Tests that get_response sends the correct request to LLM."""
+    """Tests that get_response sends the correct request to LLM.
+
+    Args:
+        llm_connector_llama2: LLMConnector instance with Llama2 instance.
+    """
     response = llm_connector_llama2._generate("Test prompt")
     assert response["model"] == "llama2"
     assert response["response"] == llm_connector_llama2.get_response(
@@ -52,7 +62,11 @@ def test_get_response_request_params_llama(
 def test_get_response_request_params_mistral(
     llm_connector_mistral: LLMConnector,
 ) -> None:
-    """Tests that get_response sends the correct request to LLM."""
+    """Tests that get_response sends the correct request to LLM.
+
+    Args:
+        llm_connector_mistral: LLMConnector instance for Mistral model config.
+    """
     response = llm_connector_mistral._generate("Test prompt")
     assert response["model"] == "mistral"
     assert response["response"] == llm_connector_mistral.get_response(
@@ -61,38 +75,26 @@ def test_get_response_request_params_mistral(
     assert response


-def test_load_config_success(mocker):
-    """Tests that _load_config returns the correct config.
-    Args:
-        mocker: Mocking object.
-    """
-    mocker.patch("builtins.open", mock_open(read_data="key: value"))
-    mocker.patch("os.path.isfile", return_value=True)
-    result = LLMConnector._load_config("fake_path")
-    assert result == {"key": "value"}


-def test_load_config_file_not_found(mocker):
+def test_load_config_file_not_found(mocker: MockerFixture) -> None:
     """Tests that _load_config raises FileNotFoundError.
     Args:
         mocker: Mocking object.
     """
     mocker.patch("os.path.isfile", return_value=False)
     with pytest.raises(FileNotFoundError):
-        LLMConnector._load_config("fake_path")
+        LLMConnector("fake_path")._load_config()


-def test_load_config_content(mocker):
+def test_load_config_content(mocker: MockerFixture) -> None:
     """Tests that _load_config returns the valid data.
     Args:
         mocker: Mocking object.
     """
     mocker.patch(
-        "builtins.open", mock_open(read_data="key1: value1\nkey2: value2")
+        "builtins.open", mock_open(read_data="model: value1\nhost: value2")
     )
     mocker.patch("os.path.isfile", return_value=True)
-    result = LLMConnector._load_config("fake_path")
-    assert result == {"key1": "value1", "key2": "value2"}
+    result = LLMConnector(config_path="fake_path")._load_config()
+    assert result == {"model": "value1", "host": "value2"}
