Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for struct. output #19

Merged
merged 11 commits into from
Nov 5, 2024
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,32 @@ which would lead to this in your prompt:
- C
```

### Using structured output (function calling and structured output)

*Structured Outputs is a feature that ensures the model will always generate responses that adhere to your supplied JSON Schema, so you don't need to worry about the model omitting a required key, or hallucinating an invalid enum value.* - OpenAI website

#### How to get it to work

1. Check with the model developer/provider whether the model supports some kind of structured output.
2. Toggle structured output switch
3. Select one of the supported structured output methods (a model might support all of them but also none of them):
- `None` - no structured output is used (equivalent to the toggle being off)
- `Function calling` - the older, workaround way of achieving structured outputs, used before `Response format` was added to the API
- `Response format` - new way of implementing structured outputs
4. Provide JSON schema in `json schema` text input (can be generated from `pydantic` model or `zod` if you use `nodejs`) where `title` must satisfy `'^[a-zA-Z0-9_-]+$'`:
```json
{
"title": "get_delivery_date",
"description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
"type": "object",
"properties": {
"order_id": {
"type": "string"
}
},
"required": ["order_id"],
"additionalProperties": false
}
```

### Postprocessing the model outputs

When working with LLMs, you would often postprocess the raw generated text. Prompterator
Expand Down
47 changes: 46 additions & 1 deletion prompterator/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import os
from enum import Enum
from typing import Any, Dict, Generic, Optional, TypeVar

from pydantic import BaseModel
Expand All @@ -16,6 +18,25 @@ class ConfigurableModelParameter(GenericModel, Generic[DataT]):
step: DataT


class StructuredOutputImplementation(Enum):
    """Ways a model can be asked to produce structured (JSON-schema-constrained) output.

    The enum values double as the human-readable labels shown in the UI
    selector, so they must stay in sync with what the UI stores in
    session state (``c.StructuredOutputImplementation(<label>)`` is used
    to map the selected label back to a member).
    """

    NONE = "None"  # no structured output; equivalent to the toggle being off
    FUNCTION_CALLING = "Function calling"  # legacy approach built on tool/function calls
    RESPONSE_FORMAT = "Response format"  # native response-format API support


@dataclasses.dataclass
class StructuredOutputConfig:
    """User-selected structured-output settings gathered from the UI."""

    # Whether the "Structured output" toggle is on.
    enabled: bool
    # JSON schema as a raw string. None when structured output is disabled —
    # the UI only populates this field when the toggle is on.
    schema: Optional[str]
    # Selected implementation method; None when structured output is disabled.
    method: Optional[StructuredOutputImplementation]


@dataclasses.dataclass
class ModelInputs:
    """Bundle of per-row model inputs plus the structured-output configuration."""

    # Keyed by the input dataframe's row index (typically int, but any
    # hashable index value is possible), so keys are not guaranteed strings.
    inputs: Dict[Any, Any]
    structured_output_data: StructuredOutputConfig


class ModelProperties(BaseModel):
name: str
is_chat_model: bool = False
Expand All @@ -26,7 +47,9 @@ class ModelProperties(BaseModel):

configurable_params: Dict[str, ConfigurableModelParameter] = {}
non_configurable_params: Dict[str, Any] = {}

supported_structured_output_implementations: Optional[list[StructuredOutputImplementation]] = [
StructuredOutputImplementation.NONE
]
# By default, models are sorted by their position index, which is used to order them in the UI.
# The 1e6 default value is used to ensure that models without a position index are sorted last.
position_index: int = int(1e6)
Expand Down Expand Up @@ -115,3 +138,25 @@ def call(self, input, **kwargs):
PROMPT_PREVIEW_TEXT_AREA_HEIGHT = 200

DATAFILE_FILTER_ALL = "all"

# Example JSON schema shown as the placeholder of the UI's "JSON Schema" text
# area. Kept as a raw JSON string (parsed/validated only when used). Note the
# "title" must satisfy '^[a-zA-Z0-9_-]+$' because it is reused as the function
# name when the "Function calling" structured-output method is selected.
DEFAULT_JSON_SCHEMA = """{
"title": "translate_json_schema",
"description": "Translation schema",
"type": "object",
"properties": {
"originalText": {
"type": "string",
"additionalProperties": false
},
"translatedText": {
"type": "string",
"additionalProperties": false
}
},
"required": [
"originalText",
"translatedText"
],
"additionalProperties": false,
"$schema": "http://json-schema.org/draft-07/schema#"
}"""
60 changes: 57 additions & 3 deletions prompterator/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
import traceback as tb
Expand Down Expand Up @@ -117,6 +118,19 @@ def run_prompt(progress_ui_area):
system_prompt_template = st.session_state.system_prompt
user_prompt_template = st.session_state.user_prompt

structured_output_enabled = st.session_state.structured_output_enabled
prompt_json_schema = None
selected_structured_output_method = None
if structured_output_enabled:
prompt_json_schema = st.session_state.prompt_json_schema
selected_structured_output_method = c.StructuredOutputImplementation(
st.session_state.selected_structured_output_method
)

structured_output_params = c.StructuredOutputConfig(
structured_output_enabled, prompt_json_schema, selected_structured_output_method
)

if not st.session_state.system_prompt.strip() and not st.session_state.user_prompt.strip():
st.error("Both prompts are empty, not running the text generation any further.")
return
Expand All @@ -127,12 +141,15 @@ def run_prompt(progress_ui_area):
df_old = st.session_state.df.copy()

try:
model_inputs = {
inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row
)
for i, row in df_old.iterrows()
}
model_inputs = c.ModelInputs(
inputs=inputs, structured_output_data=structured_output_params
)
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
st.error(
Expand All @@ -141,7 +158,7 @@ def run_prompt(progress_ui_area):
)
return

if len(model_inputs) == 0:
if len(model_inputs.inputs) == 0:
st.error("No input data to generate texts from!")
return

Expand Down Expand Up @@ -398,7 +415,7 @@ def set_up_ui_saved_datafiles():


def set_up_ui_generation():
col1, col2 = st.columns([1, 2])
col1, col2, col3 = st.columns([3, 3, 1])
col1.text_input(
placeholder="name your prompt version",
label="Prompt name",
Expand All @@ -412,6 +429,23 @@ def set_up_ui_generation():
key=c.PROMPT_COMMENT_KEY,
)
gkaretka marked this conversation as resolved.
Show resolved Hide resolved

selected_model: c.ModelProperties = st.session_state.model
available_structured_output_settings = (
selected_model.supported_structured_output_implementations
)

structured_output_available = (
c.StructuredOutputImplementation.FUNCTION_CALLING in available_structured_output_settings
or c.StructuredOutputImplementation.RESPONSE_FORMAT in available_structured_output_settings
)

structured_output_enabled = col3.toggle(
label="Structured output",
value=False,
key="structured_output_enabled",
disabled=(not structured_output_available),
)

progress_ui_area = st.empty()

col1, col2 = st.columns([3, 2])
Expand All @@ -434,6 +468,26 @@ def set_up_ui_generation():
disabled=not model_supports_user_prompt,
)

if structured_output_enabled and structured_output_available:
json_input = st.container()
json_input.text_area(
label="JSON Schema",
placeholder=c.DEFAULT_JSON_SCHEMA,
key="prompt_json_schema",
height=c.PROMPT_TEXT_AREA_HEIGHT,
)
struct_options_1, struct_options_2 = st.columns([3, 2])
struct_options_1.select_slider(
"Select which method to use for structured output",
options=[setting.value for setting in available_structured_output_settings],
key="selected_structured_output_method",
)

if u.validate_json(st.session_state.prompt_json_schema):
struct_options_2.success("JSON is valid", icon="🟢")
else:
struct_options_2.error("JSON is invalid", icon="🔴")
gkaretka marked this conversation as resolved.
Show resolved Hide resolved

if "df" in st.session_state:
prompt_parsing_error_message_area = st.empty()
col1, col2 = st.columns([3, 1])
Expand Down
Loading
Loading