Skip to content

Commit

Permalink
feat: Add support for structured output and function calling (#19)
Browse files Browse the repository at this point in the history
* feat: Add support for struct. output

- add support for structured outputs
- with function calling
- with structured outputs
  • Loading branch information
gkaretka authored Nov 5, 2024
1 parent 5de48e1 commit 99a7eaf
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 16 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,32 @@ which would lead to this in your prompt:
- C
```

### Using structured output (function calling and structured output)

*Structured Outputs is a feature that ensures the model will always generate responses that adhere to your supplied JSON Schema, so you don't need to worry about the model omitting a required key, or hallucinating an invalid enum value.* - OpenAI website

#### How to get it to work

1. Check with the model developer/provider whether the model supports some kind of structured output.
2. Toggle structured output switch
3. Select one of the supported structured output methods (a model might support all of them but also none of them):
   - `None` - no structured output is used (equivalent to the toggle being in the off state)
   - `Function calling` - a workaround for obtaining structured outputs that predates the `Response format` option in the API
- `Response format` - new way of implementing structured outputs
4. Provide JSON schema in `json schema` text input (can be generated from `pydantic` model or `zod` if you use `nodejs`) where `title` must satisfy `'^[a-zA-Z0-9_-]+$'`:
```json
{
"title": "get_delivery_date",
"description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
"type": "object",
"properties": {
"order_id": {
"type": "string"
}
},
"required": ["order_id"],
"additionalProperties": false
}
```

### Postprocessing the model outputs

When working with LLMs, you would often postprocess the raw generated text. Prompterator
Expand Down
47 changes: 46 additions & 1 deletion prompterator/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import os
from enum import Enum
from typing import Any, Dict, Generic, Optional, TypeVar

from pydantic import BaseModel
Expand All @@ -16,6 +18,25 @@ class ConfigurableModelParameter(GenericModel, Generic[DataT]):
step: DataT


class StructuredOutputImplementation(Enum):
    """Methods a model may support for producing structured (JSON) output.

    The string values double as the labels shown in the UI's method selector.
    """

    # No structured output is used (same effect as the toggle being off).
    NONE = "None"
    # Legacy approach: abuse the function-calling API to force a JSON response.
    FUNCTION_CALLING = "Function calling"
    # Native approach: the API's response-format / JSON-schema parameter.
    RESPONSE_FORMAT = "Response format"


@dataclasses.dataclass
class StructuredOutputConfig:
    """User-selected structured-output settings passed along with model inputs.

    When ``enabled`` is False the caller leaves ``schema`` and ``method`` as
    ``None``, hence the Optional annotations.
    """

    # Whether the structured-output toggle is switched on in the UI.
    enabled: bool
    # JSON schema as a raw string (None when structured output is disabled).
    schema: Optional[str]
    # Chosen implementation method (None when structured output is disabled).
    method: Optional[StructuredOutputImplementation]


@dataclasses.dataclass
class ModelInputs:
    """Bundle of per-row model inputs plus the structured-output configuration.

    Groups everything a model call needs so it can be passed around as one
    object instead of two parallel arguments.
    """

    # Mapping of dataframe row index -> prepared model input for that row.
    inputs: Dict[str, Any]
    # Structured-output settings that apply to every input in ``inputs``.
    structured_output_data: StructuredOutputConfig


class ModelProperties(BaseModel):
name: str
is_chat_model: bool = False
Expand All @@ -26,7 +47,9 @@ class ModelProperties(BaseModel):

configurable_params: Dict[str, ConfigurableModelParameter] = {}
non_configurable_params: Dict[str, Any] = {}

supported_structured_output_implementations: Optional[list[StructuredOutputImplementation]] = [
StructuredOutputImplementation.NONE
]
# By default, models are sorted by their position index, which is used to order them in the UI.
# The 1e6 default value is used to ensure that models without a position index are sorted last.
position_index: int = int(1e6)
Expand Down Expand Up @@ -115,3 +138,25 @@ def call(self, input, **kwargs):
PROMPT_PREVIEW_TEXT_AREA_HEIGHT = 200

DATAFILE_FILTER_ALL = "all"

# Example JSON schema shown as the placeholder of the "JSON Schema" text area.
# Note: "additionalProperties" is a constraint on *object* schemas only, so it
# belongs at the top level here; placing it inside string-typed property
# subschemas (as an earlier revision did) has no effect per the JSON Schema
# spec and has been removed.
DEFAULT_JSON_SCHEMA = """{
    "title": "translate_json_schema",
    "description": "Translation schema",
    "type": "object",
    "properties": {
        "originalText": {
            "type": "string"
        },
        "translatedText": {
            "type": "string"
        }
    },
    "required": [
        "originalText",
        "translatedText"
    ],
    "additionalProperties": false,
    "$schema": "http://json-schema.org/draft-07/schema#"
}"""
60 changes: 57 additions & 3 deletions prompterator/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import traceback as tb
Expand Down Expand Up @@ -117,6 +118,19 @@ def run_prompt(progress_ui_area):
system_prompt_template = st.session_state.system_prompt
user_prompt_template = st.session_state.user_prompt

structured_output_enabled = st.session_state.structured_output_enabled
prompt_json_schema = None
selected_structured_output_method = None
if structured_output_enabled:
prompt_json_schema = st.session_state.prompt_json_schema
selected_structured_output_method = c.StructuredOutputImplementation(
st.session_state.selected_structured_output_method
)

structured_output_params = c.StructuredOutputConfig(
structured_output_enabled, prompt_json_schema, selected_structured_output_method
)

if not st.session_state.system_prompt.strip() and not st.session_state.user_prompt.strip():
st.error("Both prompts are empty, not running the text generation any further.")
return
Expand All @@ -127,12 +141,15 @@ def run_prompt(progress_ui_area):
df_old = st.session_state.df.copy()

try:
model_inputs = {
inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row
)
for i, row in df_old.iterrows()
}
model_inputs = c.ModelInputs(
inputs=inputs, structured_output_data=structured_output_params
)
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
st.error(
Expand All @@ -141,7 +158,7 @@ def run_prompt(progress_ui_area):
)
return

if len(model_inputs) == 0:
if len(model_inputs.inputs) == 0:
st.error("No input data to generate texts from!")
return

Expand Down Expand Up @@ -398,7 +415,7 @@ def set_up_ui_saved_datafiles():


def set_up_ui_generation():
col1, col2 = st.columns([1, 2])
col1, col2, col3 = st.columns([3, 3, 1])
col1.text_input(
placeholder="name your prompt version",
label="Prompt name",
Expand All @@ -412,6 +429,23 @@ def set_up_ui_generation():
key=c.PROMPT_COMMENT_KEY,
)

selected_model: c.ModelProperties = st.session_state.model
available_structured_output_settings = (
selected_model.supported_structured_output_implementations
)

structured_output_available = (
c.StructuredOutputImplementation.FUNCTION_CALLING in available_structured_output_settings
or c.StructuredOutputImplementation.RESPONSE_FORMAT in available_structured_output_settings
)

structured_output_enabled = col3.toggle(
label="Structured output",
value=False,
key="structured_output_enabled",
disabled=(not structured_output_available),
)

progress_ui_area = st.empty()

col1, col2 = st.columns([3, 2])
Expand All @@ -434,6 +468,26 @@ def set_up_ui_generation():
disabled=not model_supports_user_prompt,
)

if structured_output_enabled and structured_output_available:
json_input = st.container()
json_input.text_area(
label="JSON Schema",
placeholder=c.DEFAULT_JSON_SCHEMA,
key="prompt_json_schema",
height=c.PROMPT_TEXT_AREA_HEIGHT,
)
struct_options_1, struct_options_2 = st.columns([3, 2])
struct_options_1.select_slider(
"Select which method to use for structured output",
options=[setting.value for setting in available_structured_output_settings],
key="selected_structured_output_method",
)

if u.validate_json(st.session_state.prompt_json_schema):
struct_options_2.success("JSON is valid", icon="🟢")
else:
struct_options_2.error("JSON is invalid", icon="🔴")

if "df" in st.session_state:
prompt_parsing_error_message_area = st.empty()
col1, col2 = st.columns([3, 1])
Expand Down
Loading

0 comments on commit 99a7eaf

Please sign in to comment.