Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for struct. output #19

Merged
merged 11 commits into from
Nov 5, 2024
29 changes: 13 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,22 +166,19 @@ which would lead to this in your prompt:
- `Function calling` - hacky way of implementing structured outputs before `Response format` was implemented into API
- `Response format` - new way of implementing structured outputs
4. Provide JSON schema in `json schema` text input (can be generated from `pydantic` model or `zod` if you use `nodejs`) where `title` must satisfy `'^[a-zA-Z0-9_-]+$'`:
```json
{
"title": "get_delivery_date",
"description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
"type": "object",
"properties": {
"order_id": {
"type": "string"
}
},
"required": ["order_id"],
"additionalProperties": false
}
```
Example from the OpenAI website (slightly modified). For more examples see [https://json-schema.org/learn/miscellaneous-examples](https://json-schema.org/learn/miscellaneous-examples).

```json
{
"title": "get_delivery_date",
"description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
"type": "object",
"properties": {
"order_id": {
"type": "string"
}
},
"required": ["order_id"],
"additionalProperties": false
}
```
### Postprocessing the model outputs

When working with LLMs, you would often postprocess the raw generated text. Prompterator
Expand Down
6 changes: 3 additions & 3 deletions prompterator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class StructuredOutputImplementation(Enum):


@dataclasses.dataclass
class StructuredOutputData:
class StructuredOutputConfig:
enabled: bool
schema: str
method: StructuredOutputImplementation
Expand All @@ -34,7 +34,7 @@ class StructuredOutputData:
@dataclasses.dataclass
class ModelInputs:
inputs: Dict[str, Any]
structured_output_data: StructuredOutputData
structured_output_data: StructuredOutputConfig


class ModelProperties(BaseModel):
Expand All @@ -47,7 +47,7 @@ class ModelProperties(BaseModel):

configurable_params: Dict[str, ConfigurableModelParameter] = {}
non_configurable_params: Dict[str, Any] = {}
supports_structured_output: Optional[list[StructuredOutputImplementation]] = [
supported_structured_output_implementations: Optional[list[StructuredOutputImplementation]] = [
StructuredOutputImplementation.NONE
]
# By default, models are sorted by their position index, which is used to order them in the UI.
Expand Down
20 changes: 10 additions & 10 deletions prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def set_up_dynamic_session_state_vars():

def run_prompt(progress_ui_area):
progress_bar = progress_ui_area.progress(0, text="generating texts")

system_prompt_template = st.session_state.system_prompt
user_prompt_template = st.session_state.user_prompt

Expand All @@ -126,7 +127,7 @@ def run_prompt(progress_ui_area):
st.session_state.selected_structured_output_method
)

structured_output_params = c.StructuredOutputData(
structured_output_params = c.StructuredOutputConfig(
structured_output_enabled, prompt_json_schema, selected_structured_output_method
)

Expand Down Expand Up @@ -427,15 +428,15 @@ def set_up_ui_generation():
label_visibility="collapsed",
key=c.PROMPT_COMMENT_KEY,
)
gkaretka marked this conversation as resolved.
Show resolved Hide resolved

selected_model: c.ModelProperties = st.session_state.model
available_structured_output_settings = selected_model.supports_structured_output
available_structured_output_settings = (
selected_model.supported_structured_output_implementations
)

# Allow structured outputs only if the model allows other implementation
# than NONE, other implementations currently include FUNCTION_CALLING
# and RESPONSE_FORMAT. Models by default do not require this parameter to be set.
structured_output_available = (
len(set(available_structured_output_settings) - {c.StructuredOutputImplementation.NONE})
> 0
c.StructuredOutputImplementation.FUNCTION_CALLING in available_structured_output_settings
or c.StructuredOutputImplementation.RESPONSE_FORMAT in available_structured_output_settings
)

structured_output_enabled = col3.toggle(
Expand Down Expand Up @@ -467,12 +468,11 @@ def set_up_ui_generation():
disabled=not model_supports_user_prompt,
)

if structured_output_available and structured_output_enabled:
if structured_output_enabled and structured_output_available:
json_input = st.container()
json_input.text_area(
label="JSON Schema",
placeholder="Your JSON schema goes here",
value=c.DEFAULT_JSON_SCHEMA,
placeholder=c.DEFAULT_JSON_SCHEMA,
key="prompt_json_schema",
height=c.PROMPT_TEXT_AREA_HEIGHT,
)
Expand Down
151 changes: 111 additions & 40 deletions prompterator/models/openai_models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import logging
import os
import time
import json

import openai
from openai import AzureOpenAI, OpenAI

from prompterator.utils import build_function_calling_tooling, build_response_format

logger = logging.getLogger(__name__)

from prompterator.constants import ( # isort:skip
CONFIGURABLE_MODEL_PARAMETER_PROPERTIES,
ModelProperties,
PrompteratorLLM,
StructuredOutputImplementation as soi,
StructuredOutputData,
StructuredOutputConfig,
)


Expand Down Expand Up @@ -89,42 +88,102 @@ def __init__(self):

super().__init__()

@staticmethod
def get_function_calling_tooling_name(json_schema):
function = json_schema.copy()
return function.pop("title")
gkaretka marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def build_function_calling_tooling(json_schema, function_name):
"""
@param function_name: name for the openai tool
@param json_schema: contains desired output schema in proper Json Schema format
@return: list[tools] is (a single function in this case) callable by OpenAI model
in function calling mode.
"""
function = json_schema.copy()
description = (
function.pop("description")
if function.get("description", None) is not None
else function_name
)
gkaretka marked this conversation as resolved.
Show resolved Hide resolved
tools = [
{
"type": "function",
"function": {
"name": function_name,
"description": description,
"parameters": function,
},
}
]

return tools

@staticmethod
def build_response_format(json_schema):
"""
@param json_schema: contains desired output schema in proper Json Schema format
@return: dict with desired response format directly usable with OpenAI API
"""
schema = {"name": json_schema.pop("title"), "schema": json_schema, "strict": True}
samsucik marked this conversation as resolved.
Show resolved Hide resolved
response_format = {"type": "json_schema", "json_schema": schema}

return response_format

@staticmethod
def enrich_model_params_of_function_calling(structured_output_data, model_params):
gkaretka marked this conversation as resolved.
Show resolved Hide resolved
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
schema = json.loads(structured_output_data.schema)
function_name = ChatGPTMixin.get_function_calling_tooling_name(schema)

model_params["tools"] = ChatGPTMixin.build_function_calling_tooling(
schema, function_name
)
model_params["tool_choice"] = {
"type": "function",
"function": {"name": function_name},
}
if structured_output_data.method == soi.RESPONSE_FORMAT:
schema = json.loads(structured_output_data.schema)
model_params["response_format"] = ChatGPTMixin.build_response_format(schema)
return model_params

@staticmethod
def process_response(structured_output_data, response_data):
gkaretka marked this conversation as resolved.
Show resolved Hide resolved
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
response_text = response_data.choices[0].message.tool_calls[0].function.arguments
elif structured_output_data.method == soi.RESPONSE_FORMAT:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content
return response_text

def call(self, idx, input, **kwargs):
structured_output_data: StructuredOutputData = kwargs["structured_output"]
structured_output_data: StructuredOutputConfig = kwargs["structured_output"]
model_params = kwargs["model_params"]

try:
samsucik marked this conversation as resolved.
Show resolved Hide resolved
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
model_params["tools"], function_name = build_function_calling_tooling(
structured_output_data.schema
)
model_params["tool_choice"] = {
"type": "function",
"function": {"name": function_name},
}
if structured_output_data.method == soi.RESPONSE_FORMAT:
model_params["response_format"] = build_response_format(
structured_output_data.schema
)
model_params = ChatGPTMixin.enrich_model_params_of_function_calling(
structured_output_data, model_params
)
except json.JSONDecodeError as e:
logger.error(
"Error occurred while loading provided json schema"
samsucik marked this conversation as resolved.
Show resolved Hide resolved
"%d. Returning an empty response.",
idx,
exc_info=e,
)
return {"idx": idx}

try:
response_data = self.client.chat.completions.create(
model=self.specific_model_name or self.name, messages=input, **model_params
)

response_text = None
if structured_output_data.enabled:
if structured_output_data.method == soi.FUNCTION_CALLING:
response_text = (
response_data.choices[0].message.tool_calls[0].function.arguments
)
elif structured_output_data.method == soi.RESPONSE_FORMAT:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content
else:
response_text = response_data.choices[0].message.content

return {"response": response_text, "data": response_data, "idx": idx}
except openai.RateLimitError as e:
logger.error(
"OpenAI API rate limit reached when generating a response for text with index "
Expand All @@ -143,6 +202,18 @@ def call(self, idx, input, **kwargs):
)
return {"idx": idx}

try:
response_text = ChatGPTMixin.process_response(structured_output_data, response_data)
return {"response": response_text, "data": response_data, "idx": idx}
except KeyError as e:
logger.error(
"Error occurred while processing response, response does not follow expected format"
samsucik marked this conversation as resolved.
Show resolved Hide resolved
"%d. Returning an empty response.",
idx,
exc_info=e,
)
return {"idx": idx}

def format_prompt(self, system_prompt, user_prompt, **kwargs):
messages = []
if len(system_prompt.strip()) > 0:
Expand All @@ -160,7 +231,7 @@ class GPT4o(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=1,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
Expand All @@ -176,7 +247,7 @@ class GPT4oAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=6,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
Expand All @@ -194,7 +265,7 @@ class GPT4oMini(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=2,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
Expand All @@ -210,7 +281,7 @@ class GPT4oMiniAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=7,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
soi.RESPONSE_FORMAT,
Expand All @@ -228,7 +299,7 @@ class GPT35Turbo(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=3,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
],
Expand All @@ -243,7 +314,7 @@ class GPT35TurboAzure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=8,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
],
Expand All @@ -260,7 +331,7 @@ class GPT4(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=4,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
],
Expand All @@ -275,7 +346,7 @@ class GPT4Azure(ChatGPTMixin):
handles_batches_of_inputs=False,
configurable_params=CONFIGURABLE_MODEL_PARAMETER_PROPERTIES.copy(),
position_index=9,
supports_structured_output=[
supported_structured_output_implementations=[
soi.NONE,
soi.FUNCTION_CALLING,
],
Expand Down
Loading