From 9a102d2031789cb9923a167a1f0a2f10a8d47fd0 Mon Sep 17 00:00:00 2001
From: Gregor Karetka
Date: Mon, 4 Nov 2024 20:03:19 +0100
Subject: [PATCH] address PR comments

---
 prompterator/models/openai_models.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/prompterator/models/openai_models.py b/prompterator/models/openai_models.py
index 469fcce..6621fc8 100644
--- a/prompterator/models/openai_models.py
+++ b/prompterator/models/openai_models.py
@@ -128,9 +128,9 @@ def build_response_format(json_schema):
 
     @staticmethod
     def enrich_model_params_of_function_calling(structured_output_config, model_params):
-        if structured_output_data.enabled:
-            if structured_output_data.method == soi.FUNCTION_CALLING:
-                schema = json.loads(structured_output_data.schema)
+        if structured_output_config.enabled:
+            if structured_output_config.method == soi.FUNCTION_CALLING:
+                schema = json.loads(structured_output_config.schema)
                 function_name = ChatGPTMixin.get_function_calling_tooling_name(schema)
 
                 model_params["tools"] = ChatGPTMixin.build_function_calling_tooling(
@@ -140,17 +140,17 @@ def enrich_model_params_of_function_calling(structured_output_config, model_para
                     "type": "function",
                     "function": {"name": function_name},
                 }
-            if structured_output_data.method == soi.RESPONSE_FORMAT:
-                schema = json.loads(structured_output_data.schema)
+            if structured_output_config.method == soi.RESPONSE_FORMAT:
+                schema = json.loads(structured_output_config.schema)
                 model_params["response_format"] = ChatGPTMixin.build_response_format(schema)
         return model_params
 
     @staticmethod
     def process_response(structured_output_config, response_data):
-        if structured_output_data.enabled:
-            if structured_output_data.method == soi.FUNCTION_CALLING:
+        if structured_output_config.enabled:
+            if structured_output_config.method == soi.FUNCTION_CALLING:
                 response_text = response_data.choices[0].message.tool_calls[0].function.arguments
-            elif structured_output_data.method == soi.RESPONSE_FORMAT:
+            elif structured_output_config.method == soi.RESPONSE_FORMAT:
                 response_text = response_data.choices[0].message.content
         else:
             response_text = response_data.choices[0].message.content
@@ -159,16 +159,17 @@ def process_response(structured_output_config, response_data):
         return response_text
 
     def call(self, idx, input, **kwargs):
-        structured_output_data: StructuredOutputConfig = kwargs["structured_output"]
+        structured_output_config: StructuredOutputConfig = kwargs["structured_output"]
         model_params = kwargs["model_params"]
         try:
             model_params = ChatGPTMixin.enrich_model_params_of_function_calling(
-                structured_output_data, model_params
+                structured_output_config, model_params
            )
         except json.JSONDecodeError as e:
             logger.error(
-                "Error occurred while loading provided json schema"
+                "Error occurred while loading provided json schema. "
+                f"Provided schema {structured_output_config.schema} "
                 "%d. Returning an empty response.",
                 idx,
                 exc_info=e,
             )
@@ -198,11 +199,13 @@ def call(self, idx, input, **kwargs):
             return {"idx": idx}
 
         try:
-            response_text = ChatGPTMixin.process_response(structured_output_data, response_data)
+            response_text = ChatGPTMixin.process_response(structured_output_config, response_data)
             return {"response": response_text, "data": response_data, "idx": idx}
         except KeyError as e:
             logger.error(
-                "Error occurred while processing response, response does not follow expected format"
+                "Error occurred while processing response, "
+                "response does not follow expected format. "
+                f"Response: {response_data} "
                 "%d. Returning an empty response.",
                 idx,
                 exc_info=e,