diff --git a/pkg_api/nl_to_pkg/llm/llm_connector.py b/pkg_api/nl_to_pkg/llm/llm_connector.py
index d76c3ae..6c545b2 100644
--- a/pkg_api/nl_to_pkg/llm/llm_connector.py
+++ b/pkg_api/nl_to_pkg/llm/llm_connector.py
@@ -1,12 +1,10 @@
 """Module for querying LLM."""
-import yaml
-
+import os
 from typing import Any, Dict
 
+import yaml
 from ollama import Client, Options
-import os
 
-_DEFAULT_ENDPOINT = "http://badne4.ux.uis.no:11434"
 _DEFAULT_CONFIG_PATH = "pkg_api/nl_to_pkg/llm/configs/llm_config_llama2.yaml"
 
 
@@ -19,21 +17,33 @@ def __init__(
 
         Args:
             config_path: Path to the config file.
+
+        Raises:
+            ValueError: If no model is specified in the config.
+            ValueError: If no host is specified in the config.
+            FileNotFoundError: If the config file is not found.
         """
-        self._config = self._load_config(config_path)
-        self._model = self._config.get("model", "llama2")
+        self._config_path = config_path
+        self._config = self._load_config()
+        if "model" not in self._config:
+            raise ValueError(
+                "No model specified in the config. For example llama2"
+            )
+        if "host" not in self._config:
+            raise ValueError("No host specified in the config.")
+        self._client = Client(host=self._config.get("host"))
+        self._model = self._config.get("model")
         self._stream = self._config.get("stream", False)
-        self._client = Client(host=self._config.get("host", _DEFAULT_ENDPOINT))
         self._llm_options = self._get_llm_config()
 
-    def _generate(self, prompt: str) -> str:
+    def _generate(self, prompt: str) -> Dict[str, Any]:
         """Generates a response from LLM.
 
         Args:
             prompt: The prompt to be sent to LLM.
 
         Returns:
-            The response with metadata from LLM.
+            The dict with response and metadata from LLM.
         """
         return self._client.generate(
             self._model, prompt, options=self._llm_options, stream=self._stream
@@ -43,29 +53,25 @@ def get_response(self, prompt: str) -> str:
         """Returns the response from LLM.
 
         Args:
-            prompt: The prompt to be sent to LLM. 
+            prompt: The prompt to be sent to LLM.
 
         Returns:
             The response from LLM, if it was successful.
         """
         return self._generate(prompt).get("response", "")  # type: ignore
 
-    @staticmethod
-    def _load_config(config_path: str) -> Dict[str, Any]:
+    def _load_config(self) -> Dict[str, Any]:
         """Loads the config from the given path.
 
-        Args:
-            config_path: Path to the config file.
-
         Raises:
             FileNotFoundError: If the file is not found.
 
         Returns:
             A dictionary containing the config keys and values.
         """
-        if not os.path.isfile(config_path):
-            raise FileNotFoundError(f"File {config_path} not found.")
-        with open(config_path, "r") as file:
+        if not os.path.isfile(self._config_path):
+            raise FileNotFoundError(f"File {self._config_path} not found.")
+        with open(self._config_path, "r") as file:
             yaml_data = yaml.safe_load(file)
             return yaml_data
 
diff --git a/tests/nl_to_pkg/test_llm_connector.py b/tests/nl_to_pkg/test_llm_connector.py
index 36099d6..e112b93 100644
--- a/tests/nl_to_pkg/test_llm_connector.py
+++ b/tests/nl_to_pkg/test_llm_connector.py
@@ -1,9 +1,11 @@
 """Tests for LLM connector."""
 
+from unittest.mock import mock_open
+
 import pytest
+from pytest_mock import MockerFixture
 
 from pkg_api.nl_to_pkg.llm.llm_connector import LLMConnector
-from unittest.mock import mock_open
 
 
 @pytest.fixture
@@ -27,7 +29,11 @@ def llm_connector_mistral() -> LLMConnector:
 def test_get_response_request_params_default(
     llm_connector_default: LLMConnector,
 ) -> None:
-    """Tests that get_response sends the correct request to LLM."""
+    """Tests that get_response sends the correct request to LLM.
+
+    Args:
+        llm_connector_default: LLMConnector instance.
+    """
     response = llm_connector_default._generate("Test prompt")
     assert response
     assert response["model"] == "llama2"
@@ -40,7 +46,11 @@ def test_get_response_request_params_default(
 def test_get_response_request_params_llama(
     llm_connector_llama2: LLMConnector,
 ) -> None:
-    """Tests that get_response sends the correct request to LLM."""
+    """Tests that get_response sends the correct request to LLM.
+
+    Args:
+        llm_connector_llama2: LLMConnector instance with the Llama2 config.
+    """
     response = llm_connector_llama2._generate("Test prompt")
     assert response["model"] == "llama2"
     assert response["response"] == llm_connector_llama2.get_response(
@@ -52,7 +62,11 @@
 def test_get_response_request_params_mistral(
     llm_connector_mistral: LLMConnector,
 ) -> None:
-    """Tests that get_response sends the correct request to LLM."""
+    """Tests that get_response sends the correct request to LLM.
+
+    Args:
+        llm_connector_mistral: LLMConnector instance with the Mistral config.
+    """
     response = llm_connector_mistral._generate("Test prompt")
     assert response["model"] == "mistral"
     assert response["response"] == llm_connector_mistral.get_response(
@@ -61,19 +75,7 @@
     assert response
 
 
-def test_load_config_success(mocker):
-    """Tests that _load_config returns the correct config.
-
-    Args:
-        mocker: Mocking object.
-    """
-    mocker.patch("builtins.open", mock_open(read_data="key: value"))
-    mocker.patch("os.path.isfile", return_value=True)
-    result = LLMConnector._load_config("fake_path")
-    assert result == {"key": "value"}
-
-
-def test_load_config_file_not_found(mocker):
+def test_load_config_file_not_found(mocker: MockerFixture) -> None:
     """Tests that _load_config raises FileNotFoundError.
 
     Args:
@@ -81,18 +83,18 @@
     """
     mocker.patch("os.path.isfile", return_value=False)
     with pytest.raises(FileNotFoundError):
-        LLMConnector._load_config("fake_path")
+        LLMConnector("fake_path")._load_config()
 
 
-def test_load_config_content(mocker):
+def test_load_config_content(mocker: MockerFixture) -> None:
     """Tests that _load_config returns the valid data.
 
     Args:
         mocker: Mocking object.
     """
     mocker.patch(
-        "builtins.open", mock_open(read_data="key1: value1\nkey2: value2")
+        "builtins.open", mock_open(read_data="model: value1\nhost: value2")
     )
     mocker.patch("os.path.isfile", return_value=True)
-    result = LLMConnector._load_config("fake_path")
-    assert result == {"key1": "value1", "key2": "value2"}
+    result = LLMConnector(config_path="fake_path")._load_config()
+    assert result == {"model": "value1", "host": "value2"}
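For reviewers, a minimal usage sketch of the behaviour introduced above (not part of the diff): with `_DEFAULT_ENDPOINT` removed, the YAML config must now supply both `model` and `host` (with `stream` remaining optional), and a missing key is reported as a `ValueError` when the connector is constructed. The file name, host URL, and config values below are illustrative placeholders, assuming `pkg_api` and its `ollama` dependency are importable.

```python
# Hypothetical sketch, not part of the PR: demonstrates the new fail-fast
# config validation in LLMConnector.__init__. File names and values are
# placeholders; only the "model" and "host" keys are required by the diff.
from pkg_api.nl_to_pkg.llm.llm_connector import LLMConnector

# A complete config would look roughly like:
#   model: llama2
#   host: http://localhost:11434
#   stream: false

# Write a config that omits "host" and confirm construction fails early.
with open("incomplete_config.yaml", "w") as file:
    file.write("model: llama2\n")

try:
    LLMConnector(config_path="incomplete_config.yaml")
except ValueError as error:
    print(error)  # "No host specified in the config."
```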