Skip to content

Commit

Permalink
feat: Add support for struct. output
Browse files Browse the repository at this point in the history
- add support for structured outputs
- with function calling
- with structured outputs

Co-authored-by: Marek Šuppa <mrshu@users.noreply.github.com>
  • Loading branch information
gkaretka and mrshu committed Oct 25, 2024
1 parent 5de48e1 commit 896999e
Show file tree
Hide file tree
Showing 5 changed files with 269 additions and 15 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,35 @@ which would lead to this in your prompt:
- C
```

### Using structured output (function calling and structured output)

*Structured Outputs is a feature that ensures the model will always generate responses that adhere to your supplied JSON Schema, so you don't need to worry about the model omitting a required key, or hallucinating an invalid enum value.* - OpenAI website

#### How to get it to work

1. Check with the model developer/provider whether the model supports some kind of structured output.
2. Toggle structured output switch
3. Select one of the supported structured output methods (a model might support all of them but also none of them):
   - `None` - no structured output is used (equivalent to the toggle being switched off)
   - `Function calling` - a legacy way of obtaining structured outputs, used before `Response format` was added to the API
   - `Response format` - the newer, recommended way of obtaining structured outputs
4. Provide JSON schema in `json schema` text input (can be generated from `pydantic` model or `zod` if you use `nodejs`):
```json
{
"title": "get_delivery_date",
"description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
"type": "object",
"properties": {
"order_id": {
"type": "string"
}
},
"required": ["order_id"],
"additionalProperties": false
}
```
   This example comes from the OpenAI website (slightly modified). For more examples, see [https://json-schema.org/learn/miscellaneous-examples](https://json-schema.org/learn/miscellaneous-examples).

### Postprocessing the model outputs

When working with LLMs, you would often postprocess the raw generated text. Prompterator
Expand Down
47 changes: 46 additions & 1 deletion prompterator/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import os
from enum import Enum
from typing import Any, Dict, Generic, Optional, TypeVar

from pydantic import BaseModel
Expand All @@ -16,6 +18,25 @@ class ConfigurableModelParameter(GenericModel, Generic[DataT]):
step: DataT


class StructuredOutputImplementation(Enum):
    """Available ways of asking a model for schema-conforming output.

    The member values are the human-readable labels rendered verbatim in
    the UI's structured-output method selector.
    """

    NONE = "None"
    FUNCTION_CALLING = "Function calling"
    RESPONSE_FORMAT = "Response format"


@dataclasses.dataclass
class StructuredOutputData:
    """User-selected structured-output configuration for a generation run.

    When ``enabled`` is False the callers construct this with ``None`` for
    both ``schema`` and ``method``, hence the ``Optional`` annotations.
    """

    # Whether the "Structured output" toggle is switched on in the UI.
    enabled: bool
    # The JSON schema as a raw JSON string; None when structured output is disabled.
    schema: Optional[str]
    # The chosen implementation method; None when structured output is disabled.
    method: Optional[StructuredOutputImplementation]


@dataclasses.dataclass
class ModelInputs:
    """Bundle of per-row model inputs plus the structured-output settings.

    ``inputs`` maps a dataframe row index to the model-ready input built for
    that row.  (NOTE(review): the key is annotated ``str`` although the
    callers key it by ``DataFrame.iterrows()`` labels — confirm intended type.)
    """

    inputs: Dict[str, Any]
    structured_output_data: StructuredOutputData


class ModelProperties(BaseModel):
name: str
is_chat_model: bool = False
Expand All @@ -26,7 +47,9 @@ class ModelProperties(BaseModel):

configurable_params: Dict[str, ConfigurableModelParameter] = {}
non_configurable_params: Dict[str, Any] = {}

supports_structured_output: Optional[list[StructuredOutputImplementation]] = [
StructuredOutputImplementation.NONE
]
# By default, models are sorted by their position index, which is used to order them in the UI.
# The 1e6 default value is used to ensure that models without a position index are sorted last.
position_index: int = int(1e6)
Expand Down Expand Up @@ -115,3 +138,25 @@ def call(self, input, **kwargs):
PROMPT_PREVIEW_TEXT_AREA_HEIGHT = 200

DATAFILE_FILTER_ALL = "all"

# Default schema pre-filled into the "JSON Schema" text area of the UI.
# It describes a simple translation result carrying the original and the
# translated text.  Per the JSON Schema spec, "additionalProperties" only
# applies to "object" schemas, so it is set on the top-level object alone
# (placing it inside the string-typed properties would be a no-op).
DEFAULT_JSON_SCHEMA = """{
    "title": "translate_json_schema",
    "description": "Translation schema",
    "type": "object",
    "properties": {
        "originalText": {
            "type": "string"
        },
        "translatedText": {
            "type": "string"
        }
    },
    "required": [
        "originalText",
        "translatedText"
    ],
    "additionalProperties": false,
    "$schema": "http://json-schema.org/draft-07/schema#"
}"""
59 changes: 55 additions & 4 deletions prompterator/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import traceback as tb
Expand Down Expand Up @@ -113,10 +114,22 @@ def set_up_dynamic_session_state_vars():

def run_prompt(progress_ui_area):
progress_bar = progress_ui_area.progress(0, text="generating texts")

system_prompt_template = st.session_state.system_prompt
user_prompt_template = st.session_state.user_prompt

structured_output_enabled = st.session_state.structured_output_enabled
prompt_json_schema = None
selected_structured_output_method = None
if structured_output_enabled:
prompt_json_schema = st.session_state.prompt_json_schema
selected_structured_output_method = c.StructuredOutputImplementation(
st.session_state.selected_structured_output_method
)

structured_output_params = c.StructuredOutputData(
structured_output_enabled, prompt_json_schema, selected_structured_output_method
)

if not st.session_state.system_prompt.strip() and not st.session_state.user_prompt.strip():
st.error("Both prompts are empty, not running the text generation any further.")
return
Expand All @@ -127,12 +140,15 @@ def run_prompt(progress_ui_area):
df_old = st.session_state.df.copy()

try:
model_inputs = {
inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row
)
for i, row in df_old.iterrows()
}
model_inputs = c.ModelInputs(
inputs=inputs, structured_output_data=structured_output_params
)
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
st.error(
Expand All @@ -141,7 +157,7 @@ def run_prompt(progress_ui_area):
)
return

if len(model_inputs) == 0:
if len(model_inputs.inputs) == 0:
st.error("No input data to generate texts from!")
return

Expand Down Expand Up @@ -398,7 +414,7 @@ def set_up_ui_saved_datafiles():


def set_up_ui_generation():
col1, col2 = st.columns([1, 2])
col1, col2, col3 = st.columns([3, 3, 1])
col1.text_input(
placeholder="name your prompt version",
label="Prompt name",
Expand All @@ -411,6 +427,20 @@ def set_up_ui_generation():
label_visibility="collapsed",
key=c.PROMPT_COMMENT_KEY,
)
selected_model: c.ModelProperties = st.session_state.model
available_structured_output_settings = selected_model.supports_structured_output

structured_output_available = (
len(set(available_structured_output_settings) - {c.StructuredOutputImplementation.NONE})
> 0
)

structured_output_enabled = col3.toggle(
label="Structured output",
value=False,
key="structured_output_enabled",
disabled=(not structured_output_available),
)

progress_ui_area = st.empty()

Expand All @@ -434,6 +464,27 @@ def set_up_ui_generation():
disabled=not model_supports_user_prompt,
)

if structured_output_available and structured_output_enabled:
json_input = st.container()
json_input.text_area(
label="JSON Schema",
placeholder="Your JSON schema goes here",
value=c.DEFAULT_JSON_SCHEMA,
key="prompt_json_schema",
height=c.PROMPT_TEXT_AREA_HEIGHT,
)
struct_options_1, struct_options_2 = st.columns([3, 2])
struct_options_1.select_slider(
"Select which method to use for structured output",
options=[setting.value for setting in available_structured_output_settings],
key="selected_structured_output_method",
)

if u.validate_json(st.session_state.prompt_json_schema):
struct_options_2.success("JSON is valid", icon="🟢")
else:
struct_options_2.error("JSON is invalid", icon="🔴")

if "df" in st.session_state:
prompt_parsing_error_message_area = st.empty()
col1, col2 = st.columns([3, 1])
Expand Down
69 changes: 68 additions & 1 deletion prompterator/models/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import openai
from openai import AzureOpenAI, OpenAI

from prompterator.utils import build_function_calling_tooling, build_response_format

logger = logging.getLogger(__name__)

from prompterator.constants import ( # isort:skip
CONFIGURABLE_MODEL_PARAMETER_PROPERTIES,
ModelProperties,
PrompteratorLLM,
StructuredOutputImplementation as soi,
StructuredOutputData,
)


Expand Down Expand Up @@ -86,12 +90,39 @@ def __init__(self):
super().__init__()

def call(self, idx, input, **kwargs):
structured_output_data: StructuredOutputData = kwargs["structured_output"]
model_params = kwargs["model_params"]
try:
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
model_params["tools"], function_name = build_function_calling_tooling(
structured_output_data.schema
)
model_params["tool_choice"] = {
"type": "function",
"function": {"name": function_name},
}
if structured_output_data.method == soi.RESPONSE_FORMAT:
model_params["response_format"] = build_response_format(
structured_output_data.schema
)

response_data = self.client.chat.completions.create(
model=self.specific_model_name or self.name, messages=input, **model_params
)
response_text = response_data.choices[0].message.content

response_text = None
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
response_text = (
response_data.choices[0].message.tool_calls[0].function.arguments
)
elif structured_output_data.method == soi.RESPONSE_FORMAT:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content

return {"response": response_text, "data": response_data, "idx": idx}
except openai.RateLimitError as e:
Expand Down Expand Up @@ -129,6 +160,11 @@ class GPT4o(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=1,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
],
)


Expand All @@ -140,6 +176,11 @@ class GPT4oAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=6,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
],
)
openai_variant = "azure"
specific_model_name = "gpt-4o"
Expand All @@ -153,6 +194,11 @@ class GPT4oMini(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=2,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
],
)


Expand All @@ -164,6 +210,11 @@ class GPT4oMiniAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=7,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
],
)
openai_variant = "azure"
specific_model_name = "gpt-4o-mini"
Expand All @@ -177,6 +228,10 @@ class GPT35Turbo(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=3,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
],
)


Expand All @@ -188,6 +243,10 @@ class GPT35TurboAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=8,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
],
)
openai_variant = "azure"
specific_model_name = "gpt-35-turbo"
Expand All @@ -201,6 +260,10 @@ class GPT4(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=4,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
],
)


Expand All @@ -212,6 +275,10 @@ class GPT4Azure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=9,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
],
)
openai_variant = "azure"
specific_model_name = "gpt-4"
Expand Down
Loading

0 comments on commit 896999e

Please sign in to comment.