diff --git a/README.md b/README.md
index 8e1bca5..712280b 100644
--- a/README.md
+++ b/README.md
@@ -153,6 +153,32 @@ which would lead to this in your prompt:
 - C
 ```
 
+### Using structured output (function calling and response format)
+
+*Structured Outputs is a feature that ensures the model will always generate responses that adhere to your supplied JSON Schema, so you don't need to worry about the model omitting a required key, or hallucinating an invalid enum value.* - OpenAI website
+
+#### How to get it to work
+
+1. Check with the model developer/provider whether the model supports some kind of structured output.
+2. Toggle the structured output switch.
+3. Select one of the supported structured output methods (a model might support all of them, only some, or none at all):
+   - `None` - no structured output is used (equivalent to leaving the toggle off)
+   - `Function calling` - the original workaround for getting structured outputs, used before `Response format` was added to the API
+   - `Response format` - the newer, dedicated way of requesting structured outputs
+4. Provide a JSON schema in the `JSON schema` text input (it can be generated from a `pydantic` model, or from `zod` if you use `nodejs`; see the sketch below the example schema). The schema's `title` must satisfy `'^[a-zA-Z0-9_-]+$'`:
+   ```json
+   {
+     "title": "get_delivery_date",
+     "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
+     "type": "object",
+     "properties": {
+       "order_id": {
+         "type": "string"
+       }
+     },
+     "required": ["order_id"],
+     "additionalProperties": false
+   }
+   ```
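+
+   For example, a compatible schema can be generated from a `pydantic` model. The following is a minimal sketch (the `GetDeliveryDate` model is made up for illustration) and assumes `pydantic` v2:
+   ```python
+   import json
+
+   from pydantic import BaseModel
+
+
+   class GetDeliveryDate(BaseModel):
+       """Get the delivery date for a customer's order."""
+
+       order_id: str
+
+
+   # model_json_schema() already emits "title", "type", "properties" and "required";
+   # additionalProperties=false is added because strict structured outputs expect it on objects.
+   schema = GetDeliveryDate.model_json_schema()
+   schema["additionalProperties"] = False
+   print(json.dumps(schema, indent=2))
+   ```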
+
 ### Postprocessing the model outputs
 
 When working with LLMs, you would often postprocess the raw generated text. Prompterator
diff --git a/prompterator/constants.py b/prompterator/constants.py
index 3d2f247..2d01ee1 100644
--- a/prompterator/constants.py
+++ b/prompterator/constants.py
@@ -1,4 +1,6 @@
+import dataclasses
 import os
+from enum import Enum
 from typing import Any, Dict, Generic, Optional, TypeVar
 
 from pydantic import BaseModel
@@ -16,6 +18,25 @@ class ConfigurableModelParameter(GenericModel, Generic[DataT]):
     step: DataT
 
 
+class StructuredOutputImplementation(Enum):
+    NONE = "None"
+    FUNCTION_CALLING = "Function calling"
+    RESPONSE_FORMAT = "Response format"
+
+
+@dataclasses.dataclass
+class StructuredOutputConfig:
+    enabled: bool
+    schema: str
+    method: StructuredOutputImplementation
+
+
+@dataclasses.dataclass
+class ModelInputs:
+    inputs: Dict[str, Any]
+    structured_output_data: StructuredOutputConfig
+
+
 class ModelProperties(BaseModel):
     name: str
     is_chat_model: bool = False
@@ -26,7 +47,9 @@
     configurable_params: Dict[str, ConfigurableModelParameter] = {}
     non_configurable_params: Dict[str, Any] = {}
-
+    supported_structured_output_implementations: list[StructuredOutputImplementation] = [
+        StructuredOutputImplementation.NONE
+    ]
     # By default, models are sorted by their position index, which is used to order them in the UI.
     # The 1e6 default value is used to ensure that models without a position index are sorted last.
     position_index: int = int(1e6)
@@ -115,3 +138,25 @@ def call(self, input, **kwargs):
 PROMPT_PREVIEW_TEXT_AREA_HEIGHT = 200
 
 DATAFILE_FILTER_ALL = "all"
+
+DEFAULT_JSON_SCHEMA = """{
+    "title": "translate_json_schema",
+    "description": "Translation schema",
+    "type": "object",
+    "properties": {
+        "originalText": {
+            "type": "string",
+            "additionalProperties": false
+        },
+        "translatedText": {
+            "type": "string",
+            "additionalProperties": false
+        }
+    },
+    "required": [
+        "originalText",
+        "translatedText"
+    ],
+    "additionalProperties": false,
+    "$schema": "http://json-schema.org/draft-07/schema#"
+}"""
diff --git a/prompterator/main.py b/prompterator/main.py
index 2b4af5d..3a0cd59 100644
--- a/prompterator/main.py
+++ b/prompterator/main.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import traceback as tb
@@ -117,6 +118,19 @@ def run_prompt(progress_ui_area):
     system_prompt_template = st.session_state.system_prompt
     user_prompt_template = st.session_state.user_prompt
 
+    structured_output_enabled = st.session_state.structured_output_enabled
+    prompt_json_schema = None
+    selected_structured_output_method = None
+    if structured_output_enabled:
+        prompt_json_schema = st.session_state.prompt_json_schema
+        selected_structured_output_method = c.StructuredOutputImplementation(
+            st.session_state.selected_structured_output_method
+        )
+
+    structured_output_params = c.StructuredOutputConfig(
+        structured_output_enabled, prompt_json_schema, selected_structured_output_method
+    )
+
     if not st.session_state.system_prompt.strip() and not st.session_state.user_prompt.strip():
         st.error("Both prompts are empty, not running the text generation any further.")
         return
@@ -127,12 +141,15 @@
     df_old = st.session_state.df.copy()
 
     try:
-        model_inputs = {
+        inputs = {
             i: u.create_model_input(
                 model, model_instance, user_prompt_template, system_prompt_template, row
             )
             for i, row in df_old.iterrows()
         }
+        model_inputs = c.ModelInputs(
+            inputs=inputs, structured_output_data=structured_output_params
+        )
     except Exception as e:
         traceback = u.format_traceback_for_markdown(tb.format_exc())
         st.error(
@@ -141,7 +158,7 @@
         )
         return
 
-    if len(model_inputs) == 0:
+    if len(model_inputs.inputs) == 0:
         st.error("No input data to generate texts from!")
         return
 
@@ -398,7 +415,7 @@ def set_up_ui_saved_datafiles():
 
 
 def set_up_ui_generation():
-    col1, col2 = st.columns([1, 2])
+    col1, col2, col3 = st.columns([3, 3, 1])
     col1.text_input(
         placeholder="name your prompt version",
         label="Prompt name",
@@ -412,6 +429,23 @@ def set_up_ui_generation():
         key=c.PROMPT_COMMENT_KEY,
     )
 
+    selected_model: c.ModelProperties = st.session_state.model
+    available_structured_output_settings = (
+        selected_model.supported_structured_output_implementations
+    )
+
+    structured_output_available = (
+        c.StructuredOutputImplementation.FUNCTION_CALLING in available_structured_output_settings
+        or c.StructuredOutputImplementation.RESPONSE_FORMAT in available_structured_output_settings
+    )
+
+    structured_output_enabled = col3.toggle(
+        label="Structured output",
+        value=False,
+        key="structured_output_enabled",
+        disabled=(not structured_output_available),
+    )
+
     progress_ui_area = st.empty()
 
     col1, col2 = st.columns([3, 2])
@@ -434,6 +468,26 @@ def set_up_ui_generation():
         disabled=not model_supports_user_prompt,
     )
 
+    if structured_output_enabled and structured_output_available:
+        json_input = st.container()
+        json_input.text_area(
+            label="JSON Schema",
+            placeholder=c.DEFAULT_JSON_SCHEMA,
+            key="prompt_json_schema",
+            height=c.PROMPT_TEXT_AREA_HEIGHT,
+        )
+        struct_options_1, struct_options_2 = st.columns([3, 2])
+        struct_options_1.select_slider(
+            "Select which method to use for structured output",
+            options=[setting.value for setting in available_structured_output_settings],
+            key="selected_structured_output_method",
+        )
+
+        if u.validate_json(st.session_state.prompt_json_schema):
+            struct_options_2.success("JSON is valid", icon="🟢")
+        else:
+            struct_options_2.error("JSON is invalid", icon="🔴")
+
     if "df" in st.session_state:
         prompt_parsing_error_message_area = st.empty()
         col1, col2 = st.columns([3, 1])
diff --git a/prompterator/models/openai_models.py b/prompterator/models/openai_models.py
index 449c7bf..1a72656 100644
--- a/prompterator/models/openai_models.py
+++ b/prompterator/models/openai_models.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 import time
@@ -11,6 +12,8 @@
     CONFIGURABLE_MODEL_PARAMETER_PROPERTIES,
     ModelProperties,
     PrompteratorLLM,
+    StructuredOutputImplementation as soi,
+    StructuredOutputConfig,
 )
 
 
@@ -85,15 +88,98 @@ def __init__(self):
         super().__init__()
 
+    @staticmethod
+    def get_function_calling_tooling_name(json_schema):
+        return json_schema["title"]
+
+    @staticmethod
+    def build_function_calling_tooling(json_schema, function_name):
+        """
+        @param function_name: name for the OpenAI tool
+        @param json_schema: the desired output schema, in proper JSON Schema format
+        @return: a list of tools (containing a single function in this case) that the OpenAI
+            model can call in function-calling mode.
+        """
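+        # For the README's example schema, the returned structure looks roughly like this
+        # (illustrative, abbreviated):
+        # [
+        #     {
+        #         "type": "function",
+        #         "function": {
+        #             "name": "get_delivery_date",
+        #             "description": "Get the delivery date for a customer's order. ...",
+        #             "parameters": {"title": "get_delivery_date", "type": "object", ...},
+        #         },
+        #     }
+        # ]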
key="prompt_json_schema", + height=c.PROMPT_TEXT_AREA_HEIGHT, + ) + struct_options_1, struct_options_2 = st.columns([3, 2]) + struct_options_1.select_slider( + "Select which method to use for structured output", + options=[setting.value for setting in available_structured_output_settings], + key="selected_structured_output_method", + ) + + if u.validate_json(st.session_state.prompt_json_schema): + struct_options_2.success("JSON is valid", icon="🟢") + else: + struct_options_2.error("JSON is invalid", icon="🔴") + if "df" in st.session_state: prompt_parsing_error_message_area = st.empty() col1, col2 = st.columns([3, 1]) diff --git a/prompterator/models/openai_models.py b/prompterator/models/openai_models.py index 449c7bf..1a72656 100644 --- a/prompterator/models/openai_models.py +++ b/prompterator/models/openai_models.py @@ -1,3 +1,4 @@ +import json import logging import os import time @@ -11,6 +12,8 @@ CONFIGURABLE_MODEL_PARAMETER_PROPERTIES, ModelProperties, PrompteratorLLM, + StructuredOutputImplementation as soi, + StructuredOutputConfig, ) @@ -85,15 +88,98 @@ def __init__(self): super().__init__() + @staticmethod + def get_function_calling_tooling_name(json_schema): + return json_schema["title"] + + @staticmethod + def build_function_calling_tooling(json_schema, function_name): + """ + @param function_name: name for the openai tool + @param json_schema: contains desired output schema in proper Json Schema format + @return: list[tools] is (a single function in this case) callable by OpenAI model + in function calling mode. + """ + function = json_schema.copy() + description = function.pop("description", function_name) + tools = [ + { + "type": "function", + "function": { + "name": function_name, + "description": description, + "parameters": function, + }, + } + ] + + return tools + + @staticmethod + def build_response_format(json_schema): + """ + @param json_schema: contains desired output schema in proper Json Schema format + @return: dict with desired response format directly usable with OpenAI API + """ + schema = {"name": json_schema.pop("title"), "schema": json_schema, "strict": True} + response_format = {"type": "json_schema", "json_schema": schema} + + return response_format + + @staticmethod + def enrich_model_params_of_function_calling(structured_output_config, model_params): + if structured_output_config.enabled: + if structured_output_config.method == soi.FUNCTION_CALLING: + schema = json.loads(structured_output_config.schema) + function_name = ChatGPTMixin.get_function_calling_tooling_name(schema) + + model_params["tools"] = ChatGPTMixin.build_function_calling_tooling( + schema, function_name + ) + model_params["tool_choice"] = { + "type": "function", + "function": {"name": function_name}, + } + if structured_output_config.method == soi.RESPONSE_FORMAT: + schema = json.loads(structured_output_config.schema) + model_params["response_format"] = ChatGPTMixin.build_response_format(schema) + return model_params + + @staticmethod + def process_response(structured_output_config, response_data): + if structured_output_config.enabled: + if structured_output_config.method == soi.FUNCTION_CALLING: + response_text = response_data.choices[0].message.tool_calls[0].function.arguments + elif structured_output_config.method == soi.RESPONSE_FORMAT: + response_text = response_data.choices[0].message.content + else: + response_text = response_data.choices[0].message.content + else: + response_text = response_data.choices[0].message.content + return response_text + def call(self, idx, input, 
+        schema = {"name": json_schema.pop("title"), "schema": json_schema, "strict": True}
+        response_format = {"type": "json_schema", "json_schema": schema}
+
+        return response_format
+
+    @staticmethod
+    def enrich_model_params_of_function_calling(structured_output_config, model_params):
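+        """
+        Extend `model_params` with the arguments required by the selected structured output
+        method: `tools`/`tool_choice` for function calling, `response_format` for response
+        format. Returns `model_params` unchanged when structured output is disabled.
+        """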
+        if structured_output_config.enabled:
+            if structured_output_config.method == soi.FUNCTION_CALLING:
+                schema = json.loads(structured_output_config.schema)
+                function_name = ChatGPTMixin.get_function_calling_tooling_name(schema)
+
+                model_params["tools"] = ChatGPTMixin.build_function_calling_tooling(
+                    schema, function_name
+                )
+                model_params["tool_choice"] = {
+                    "type": "function",
+                    "function": {"name": function_name},
+                }
+            if structured_output_config.method == soi.RESPONSE_FORMAT:
+                schema = json.loads(structured_output_config.schema)
+                model_params["response_format"] = ChatGPTMixin.build_response_format(schema)
+        return model_params
+
+    @staticmethod
+    def process_response(structured_output_config, response_data):
+        if structured_output_config.enabled:
+            if structured_output_config.method == soi.FUNCTION_CALLING:
+                response_text = response_data.choices[0].message.tool_calls[0].function.arguments
+            elif structured_output_config.method == soi.RESPONSE_FORMAT:
+                response_text = response_data.choices[0].message.content
+            else:
+                response_text = response_data.choices[0].message.content
+        else:
+            response_text = response_data.choices[0].message.content
+        return response_text
+
     def call(self, idx, input, **kwargs):
+        structured_output_config: StructuredOutputConfig = kwargs["structured_output"]
         model_params = kwargs["model_params"]
+
+        try:
+            model_params = ChatGPTMixin.enrich_model_params_of_function_calling(
+                structured_output_config, model_params
+            )
+        except json.JSONDecodeError as e:
+            logger.error(
+                "Error occurred while parsing the provided JSON schema for text with index %d. "
+                "Provided schema: %s. Returning an empty response.",
+                idx,
+                structured_output_config.schema,
+                exc_info=e,
+            )
+            return {"idx": idx}
+
         try:
             response_data = self.client.chat.completions.create(
                 model=self.specific_model_name or self.name, messages=input, **model_params
             )
-            response_text = response_data.choices[0].message.content
-
-            return {"response": response_text, "data": response_data, "idx": idx}
         except openai.RateLimitError as e:
             logger.error(
                 "OpenAI API rate limit reached when generating a response for text with index "
@@ -112,6 +198,20 @@ def call(self, idx, input, **kwargs):
             )
             return {"idx": idx}
 
+        try:
+            response_text = ChatGPTMixin.process_response(structured_output_config, response_data)
+            return {"response": response_text, "data": response_data, "idx": idx}
+        except (KeyError, IndexError, TypeError, AttributeError) as e:
+            logger.error(
+                "Error occurred while processing the response for text with index %d: "
+                "the response does not follow the expected format. "
+                "Response: %s. Returning an empty response.",
+                idx,
+                response_data,
+                exc_info=e,
+            )
+            return {"idx": idx}
+
     def format_prompt(self, system_prompt, user_prompt, **kwargs):
         messages = []
         if len(system_prompt.strip()) > 0:
@@ -129,6 +229,11 @@ class GPT4o(ChatGPTMixin):
         handles_batches_of_inputs=False,
         configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
         position_index=1,
+        supported_structured_output_implementations=[
+            soi.NONE,
+            soi.FUNCTION_CALLING,
+            soi.RESPONSE_FORMAT,
+        ],
     )
 
 
@@ -140,6 +245,11 @@ class GPT4oAzure(ChatGPTMixin):
         handles_batches_of_inputs=False,
         configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
         position_index=6,
+        supported_structured_output_implementations=[
+            soi.NONE,
+            soi.FUNCTION_CALLING,
+            soi.RESPONSE_FORMAT,
+        ],
     )
     openai_variant = "azure"
     specific_model_name = "gpt-4o"
@@ -153,6 +263,11 @@ class GPT4oMini(ChatGPTMixin):
         handles_batches_of_inputs=False,
         configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
         position_index=2,
+        supported_structured_output_implementations=[
+            soi.NONE,
+            soi.FUNCTION_CALLING,
+            soi.RESPONSE_FORMAT,
+        ],
     )
 
 
@@ -164,6 +279,11 @@ class GPT4oMiniAzure(ChatGPTMixin):
         handles_batches_of_inputs=False,
         configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
         position_index=7,
+        supported_structured_output_implementations=[
+            soi.NONE,
+            soi.FUNCTION_CALLING,
+            soi.RESPONSE_FORMAT,
+        ],
     )
     openai_variant = "azure"
     specific_model_name = "gpt-4o-mini"
@@ -177,6 +297,10 @@ class GPT35Turbo(ChatGPTMixin):
         handles_batches_of_inputs=False,
         configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
         position_index=3,
+        supported_structured_output_implementations=[
+            soi.NONE,
+            soi.FUNCTION_CALLING,
+        ],
     )
 
 
@@ -188,6 +312,10 @@ class GPT35TurboAzure(ChatGPTMixin):
         handles_batches_of_inputs=False,
         configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
         position_index=8,
+        supported_structured_output_implementations=[
+            soi.NONE,
+            soi.FUNCTION_CALLING,
+        ],
     )
     openai_variant = "azure"
     specific_model_name = "gpt-35-turbo"
@@ -201,6 +329,10 @@ class GPT4(ChatGPTMixin):
         handles_batches_of_inputs=False,
         configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
         position_index=4,
+        supported_structured_output_implementations=[
+            soi.NONE,
+            soi.FUNCTION_CALLING,
+        ],
     )
 
 
@@ -212,6 +344,10 @@ class GPT4Azure(ChatGPTMixin):
         handles_batches_of_inputs=False,
         configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
         position_index=9,
+        supported_structured_output_implementations=[
+            soi.NONE,
+            soi.FUNCTION_CALLING,
+        ],
     )
     openai_variant = "azure"
     specific_model_name = "gpt-4"
diff --git a/prompterator/utils.py b/prompterator/utils.py
index f22fe04..344574c 100644
--- a/prompterator/utils.py
+++ b/prompterator/utils.py
@@ -77,7 +77,7 @@ def categorical_conditional_highlight(row, cond_column_name, palette):
 def generate_responses(
     model_properties: ModelProperties,
     model_instance: PrompteratorLLM,
-    inputs,
+    model_inputs: c.ModelInputs,
     model_params,
     progress_bar,
 ):
@@ -88,11 +88,19 @@ def generate_responses(
     }
     if model_properties.handles_batches_of_inputs:
         results = generate_responses_using_batching(
-            model_properties, model_instance, inputs, model_params, progress_bar
+            model_properties,
+            model_instance,
+            model_inputs,
+            model_params,
+            progress_bar,
         )
     else:
         results = generate_responses_using_parallelism(
-            model_properties, model_instance, inputs, model_params, progress_bar
+            model_properties,
+            model_instance,
+            model_inputs,
+            model_params,
+            progress_bar,
        )
 
     return results
@@ -112,11 +120,11 @@ def update_generation_progress_bar(bar, current, total):
 def generate_responses_using_batching(
     model_properties: ModelProperties,
     model_instance: PrompteratorLLM,
-    inputs,
+    model_inputs: c.ModelInputs,
     model_params,
     progress_bar,
 ):
-    inputs = list(inputs.values())
+    inputs = list(model_inputs.inputs.values())
     if model_properties.max_batch_size is not None:
         input_batches = split_inputs_into_batches(inputs, model_properties.max_batch_size)
     else:
@@ -127,7 +135,12 @@ def generate_responses_using_batching(
         n_attempts = 0
         while n_attempts < model_properties.max_retries:
             try:
-                result_batch = model_instance.call(n_attempts, batch, model_params=model_params)
+                result_batch = model_instance.call(
+                    n_attempts,
+                    batch,
+                    model_params=model_params,
+                    structured_output=model_inputs.structured_output_data,
+                )
                 result_batches.append(result_batch)
                 break
             except Exception as e:
@@ -153,7 +166,7 @@ def generate_responses_using_batching(
 def generate_responses_using_parallelism(
     model_properties: ModelProperties,
     model_instance: PrompteratorLLM,
-    inputs,
+    model_inputs: c.ModelInputs,
     model_params,
     progress_bar,
 ):
@@ -167,11 +180,16 @@ def generate_responses_using_parallelism(
         model_properties=model_properties,
         model_params=model_params,
     )
-
+    inputs = model_inputs.inputs
     with ThreadPoolExecutor(max_workers=len(inputs)) as executor:
         # start all jobs
         for i, input in inputs.items():
-            pj = executor.submit(generate_func, idx=i, input=input)
+            pj = executor.submit(
+                generate_func,
+                idx=i,
+                input=input,
+                structured_output=model_inputs.structured_output_data,
+            )
             processed_jobs.append(pj)
 
         # retrieve results and show progress
@@ -362,3 +380,11 @@ def insert_hidden_html_marker(helper_element_id, target_streamlit_element=None):
 def format_traceback_for_markdown(text):
     text = re.sub(r" ", " ", text)
     return re.sub(r"\n", "\n\n", text)
+
+
+def validate_json(text):
+    try:
+        json.loads(text)
+        return True
+    except json.decoder.JSONDecodeError:
+        return False
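+
+
+# Illustrative behaviour of the helper above:
+#   validate_json('{"order_id": "123"}')  -> True
+#   validate_json('{"order_id": }')       -> False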