Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for struct. output #19

Merged
merged 11 commits into from
Nov 5, 2024
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,35 @@ which would lead to this in your prompt:
- C
```

### Using structured output (function calling and structured output)

*Structured Outputs is a feature that ensures the model will always generate responses that adhere to your supplied JSON Schema, so you don't need to worry about the model omitting a required key, or hallucinating an invalid enum value.* - OpenAI website

#### How to get it to work

1. Check with the model developer/provider whether the model supports some kind of structured output.
2. Toggle the structured output switch on
3. Select one of the supported structured output methods (a model might support all of them but also none of them):
   - `None` - no structured output is used (equivalent to the toggle being off)
   - `Function calling` - a legacy, workaround way of implementing structured outputs from before `Response format` was added to the API
- `Response format` - new way of implementing structured outputs
4. Provide a JSON schema in the `json schema` text input (it can be generated from a `pydantic` model, or from `zod` if you use Node.js); its `title` must satisfy `'^[a-zA-Z0-9_-]+$'`:
```json
{
"title": "get_delivery_date",
"description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
"type": "object",
"properties": {
"order_id": {
"type": "string"
}
},
"required": ["order_id"],
"additionalProperties": false
}
```
This example comes from the OpenAI website (slightly modified). For more examples, see [https://json-schema.org/learn/miscellaneous-examples](https://json-schema.org/learn/miscellaneous-examples).

gkaretka marked this conversation as resolved.
Show resolved Hide resolved
### Postprocessing the model outputs

When working with LLMs, you would often postprocess the raw generated text. Prompterator
Expand Down
47 changes: 46 additions & 1 deletion prompterator/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import os
from enum import Enum
from typing import Any, Dict, Generic, Optional, TypeVar

from pydantic import BaseModel
Expand All @@ -16,6 +18,25 @@ class ConfigurableModelParameter(GenericModel, Generic[DataT]):
step: DataT


class StructuredOutputImplementation(Enum):
    """Ways a model can be asked to produce schema-constrained ("structured") output.

    The string values double as the human-readable labels shown in the UI selector
    (the UI builds its options from ``setting.value``).
    """

    NONE = "None"  # structured output disabled
    FUNCTION_CALLING = "Function calling"  # legacy approach: force a tool/function call
    RESPONSE_FORMAT = "Response format"  # native `response_format` API support


@dataclasses.dataclass
class StructuredOutputData:
    """Structured-output settings collected from the UI, handed on to the model call.

    When ``enabled`` is False the caller passes ``None`` for both ``schema`` and
    ``method``, hence the Optional annotations (the original plain ``str``/enum
    annotations did not match that caller behavior).
    """

    enabled: bool  # state of the structured-output toggle in the UI
    schema: Optional[str]  # raw JSON schema string entered by the user
    method: Optional[StructuredOutputImplementation]  # chosen implementation method


@dataclasses.dataclass
class ModelInputs:
    """Everything needed for one generation run: per-row model inputs plus the
    structured-output settings shared by all of them."""

    # Keyed by dataframe row index. NOTE(review): the caller builds this via
    # `df_old.iterrows()`, whose labels are typically ints, so the `str` key
    # annotation looks inaccurate — confirm.
    inputs: Dict[str, Any]
    # Structured-output configuration applied to every input in this run.
    structured_output_data: StructuredOutputData


class ModelProperties(BaseModel):
name: str
is_chat_model: bool = False
Expand All @@ -26,7 +47,9 @@ class ModelProperties(BaseModel):

configurable_params: Dict[str, ConfigurableModelParameter] = {}
non_configurable_params: Dict[str, Any] = {}

supports_structured_output: Optional[list[StructuredOutputImplementation]] = [
gkaretka marked this conversation as resolved.
Show resolved Hide resolved
StructuredOutputImplementation.NONE
]
# By default, models are sorted by their position index, which is used to order them in the UI.
# The 1e6 default value is used to ensure that models without a position index are sorted last.
position_index: int = int(1e6)
Expand Down Expand Up @@ -115,3 +138,25 @@ def call(self, input, **kwargs):
# Height (px) of the prompt-preview text area in the Streamlit UI
PROMPT_PREVIEW_TEXT_AREA_HEIGHT = 200

# NOTE(review): presumably the "show all" option for datafile filtering — confirm
DATAFILE_FILTER_ALL = "all"

# Default schema pre-filled in the UI's "JSON Schema" text area when structured
# output is enabled. Per the README, the top-level "title" must match
# '^[a-zA-Z0-9_-]+$' (it is also used as the function name in the
# function-calling implementation).
# NOTE(review): the "additionalProperties": false entries inside the two
# string-typed properties are no-ops — the keyword only constrains objects —
# and could be removed; kept byte-identical here.
DEFAULT_JSON_SCHEMA = """{
"title": "translate_json_schema",
"description": "Translation schema",
"type": "object",
"properties": {
"originalText": {
"type": "string",
"additionalProperties": false
},
"translatedText": {
"type": "string",
"additionalProperties": false
}
},
"required": [
"originalText",
"translatedText"
],
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}"""
62 changes: 58 additions & 4 deletions prompterator/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import traceback as tb
Expand Down Expand Up @@ -113,10 +114,22 @@ def set_up_dynamic_session_state_vars():

def run_prompt(progress_ui_area):
progress_bar = progress_ui_area.progress(0, text="generating texts")

gkaretka marked this conversation as resolved.
Show resolved Hide resolved
system_prompt_template = st.session_state.system_prompt
user_prompt_template = st.session_state.user_prompt

structured_output_enabled = st.session_state.structured_output_enabled
prompt_json_schema = None
selected_structured_output_method = None
if structured_output_enabled:
prompt_json_schema = st.session_state.prompt_json_schema
selected_structured_output_method = c.StructuredOutputImplementation(
st.session_state.selected_structured_output_method
)

structured_output_params = c.StructuredOutputData(
structured_output_enabled, prompt_json_schema, selected_structured_output_method
)

if not st.session_state.system_prompt.strip() and not st.session_state.user_prompt.strip():
st.error("Both prompts are empty, not running the text generation any further.")
return
Expand All @@ -127,12 +140,15 @@ def run_prompt(progress_ui_area):
df_old = st.session_state.df.copy()

try:
model_inputs = {
inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row
)
for i, row in df_old.iterrows()
}
model_inputs = c.ModelInputs(
inputs=inputs, structured_output_data=structured_output_params
)
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
st.error(
Expand All @@ -141,7 +157,7 @@ def run_prompt(progress_ui_area):
)
return

if len(model_inputs) == 0:
if len(model_inputs.inputs) == 0:
st.error("No input data to generate texts from!")
return

Expand Down Expand Up @@ -398,7 +414,7 @@ def set_up_ui_saved_datafiles():


def set_up_ui_generation():
col1, col2 = st.columns([1, 2])
col1, col2, col3 = st.columns([3, 3, 1])
col1.text_input(
placeholder="name your prompt version",
label="Prompt name",
Expand All @@ -411,6 +427,23 @@ def set_up_ui_generation():
label_visibility="collapsed",
key=c.PROMPT_COMMENT_KEY,
)
gkaretka marked this conversation as resolved.
Show resolved Hide resolved
selected_model: c.ModelProperties = st.session_state.model
available_structured_output_settings = selected_model.supports_structured_output

# Allow structured outputs only if the model allows other implementation
# than NONE, other implementations currently include FUNCTION_CALLING
# and RESPONSE_FORMAT. Models by default do not require this parameter to be set.
structured_output_available = (
len(set(available_structured_output_settings) - {c.StructuredOutputImplementation.NONE})
gkaretka marked this conversation as resolved.
Show resolved Hide resolved
> 0
)
samsucik marked this conversation as resolved.
Show resolved Hide resolved

structured_output_enabled = col3.toggle(
label="Structured output",
value=False,
key="structured_output_enabled",
disabled=(not structured_output_available),
)

progress_ui_area = st.empty()

Expand All @@ -434,6 +467,27 @@ def set_up_ui_generation():
disabled=not model_supports_user_prompt,
)

if structured_output_available and structured_output_enabled:
gkaretka marked this conversation as resolved.
Show resolved Hide resolved
json_input = st.container()
json_input.text_area(
label="JSON Schema",
placeholder="Your JSON schema goes here",
value=c.DEFAULT_JSON_SCHEMA,
key="prompt_json_schema",
height=c.PROMPT_TEXT_AREA_HEIGHT,
)
struct_options_1, struct_options_2 = st.columns([3, 2])
struct_options_1.select_slider(
"Select which method to use for structured output",
options=[setting.value for setting in available_structured_output_settings],
key="selected_structured_output_method",
)

if u.validate_json(st.session_state.prompt_json_schema):
struct_options_2.success("JSON is valid", icon="🟢")
else:
struct_options_2.error("JSON is invalid", icon="🔴")
gkaretka marked this conversation as resolved.
Show resolved Hide resolved

if "df" in st.session_state:
prompt_parsing_error_message_area = st.empty()
col1, col2 = st.columns([3, 1])
Expand Down
69 changes: 68 additions & 1 deletion prompterator/models/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import openai
from openai import AzureOpenAI, OpenAI

from prompterator.utils import build_function_calling_tooling, build_response_format

logger = logging.getLogger(__name__)

from prompterator.constants import ( # isort:skip
CONFIGURABLE_MODEL_PARAMETER_PROPERTIES,
ModelProperties,
PrompteratorLLM,
StructuredOutputImplementation as soi,
StructuredOutputData,
)


Expand Down Expand Up @@ -86,12 +90,39 @@ def __init__(self):
super().__init__()

def call(self, idx, input, **kwargs):
structured_output_data: StructuredOutputData = kwargs["structured_output"]
model_params = kwargs["model_params"]
try:
samsucik marked this conversation as resolved.
Show resolved Hide resolved
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
model_params["tools"], function_name = build_function_calling_tooling(
samsucik marked this conversation as resolved.
Show resolved Hide resolved
structured_output_data.schema
)
model_params["tool_choice"] = {
"type": "function",
"function": {"name": function_name},
}
if structured_output_data.method == soi.RESPONSE_FORMAT:
model_params["response_format"] = build_response_format(
structured_output_data.schema
)

response_data = self.client.chat.completions.create(
model=self.specific_model_name or self.name, messages=input, **model_params
)
response_text = response_data.choices[0].message.content

response_text = None
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
response_text = (
response_data.choices[0].message.tool_calls[0].function.arguments
)
elif structured_output_data.method == soi.RESPONSE_FORMAT:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content

return {"response": response_text, "data": response_data, "idx": idx}
except openai.RateLimitError as e:
Expand Down Expand Up @@ -129,6 +160,11 @@ class GPT4o(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=1,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
],
)


Expand All @@ -140,6 +176,11 @@ class GPT4oAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=6,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
],
)
openai_variant = "azure"
specific_model_name = "gpt-4o"
Expand All @@ -153,6 +194,11 @@ class GPT4oMini(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=2,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
],
)


Expand All @@ -164,6 +210,11 @@ class GPT4oMiniAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=7,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
],
)
openai_variant = "azure"
specific_model_name = "gpt-4o-mini"
Expand All @@ -177,6 +228,10 @@ class GPT35Turbo(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=3,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
],
)


Expand All @@ -188,6 +243,10 @@ class GPT35TurboAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=8,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
],
)
openai_variant = "azure"
specific_model_name = "gpt-35-turbo"
Expand All @@ -201,6 +260,10 @@ class GPT4(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=4,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
],
)


Expand All @@ -212,6 +275,10 @@ class GPT4Azure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=9,
supports_structured_output=[
soi.NONE,
soi.FUNCTION_CALLING,
],
)
openai_variant = "azure"
specific_model_name = "gpt-4"
Expand Down
Loading
Loading