Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: Display rendered System and User prompts #12

Merged
merged 1 commit into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ using **comma** as the separator, and enclosing values (where needed) in **doubl
### Using input data in prompts

The user/system prompt textboxes support [Jinja](https://jinja.palletsprojects.com/) templates.
Don't worry if you're new to Jinja -- Prompterator can show you a real-time "compiled" preview of
your prompts to help you write the templates.

Given a column named `text` in your uploaded CSV data, you can include values from this column by
writing the simple `{{text}}` template in your prompt.

Expand Down
1 change: 1 addition & 0 deletions prompterator/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,6 @@ def call(self, input, **kwargs):
DEFAULT_ROW_NO = 0
DATA_POINT_TEXT_AREA_HEIGHT = 180
PROMPT_TEXT_AREA_HEIGHT = 300
PROMPT_PREVIEW_TEXT_AREA_HEIGHT = 200

DATAFILE_FILTER_ALL = "all"
88 changes: 72 additions & 16 deletions prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def initialise_session_from_uploaded_file(df):

st.session_state["df"] = df
st.session_state[c.COLS_TO_SHOW_KEY] = [c.TEXT_ORIG_COL]
st.session_state.row = st.session_state.df.iloc[c.DEFAULT_ROW_NO]

if st.session_state.responses_generated_externally:
st.session_state.enable_labelling = True
Expand Down Expand Up @@ -432,17 +433,36 @@ def set_up_ui_generation():
height=c.PROMPT_TEXT_AREA_HEIGHT,
disabled=not model_supports_user_prompt,
)
col1, col2 = st.columns([1, 2])

if "df" in st.session_state:
cols_for_interpolation = set(st.session_state.df.columns).difference(
c.COLS_NOT_FOR_PROMPT_INTERPOLATION
prompt_parsing_error_message_area = st.empty()
col1, col2 = st.columns([3, 1])
cols_for_interpolation = list(
set(st.session_state.df.columns).difference(c.COLS_NOT_FOR_PROMPT_INTERPOLATION)
)
col1.write(
f"These are the columns available in the data, feel free to include them in "
f"your prompt: {cols_for_interpolation}"
f"your prompt(s): `{'`, `'.join(cols_for_interpolation)}`."
)
with col2:
st.toggle(label="show prompt preview", value=False, key="show_prompt_preview")

if st.session_state.show_prompt_preview:
col1, col2 = st.columns([3, 2])
set_up_prompt_preview(
col1,
st.session_state.system_prompt,
prompt_parsing_error_message_area,
prompt_kind="system",
)
set_up_prompt_preview(
col2,
st.session_state.user_prompt,
prompt_parsing_error_message_area,
prompt_kind="user",
)

col2.button(
st.button(
label="Run prompt",
on_click=run_prompt,
kwargs={"progress_ui_area": progress_ui_area},
Expand Down Expand Up @@ -488,29 +508,64 @@ def _get_coloured_patch(patch):
)


def set_up_prompt_attrs_area(st_container):
env = u.jinja_env()
parsed_content = env.parse(st.session_state.system_prompt)
vars = meta.find_undeclared_variables(parsed_content)
def set_up_prompt_vars_area(st_container, error_container):
try:
parsed_content = u.jinja_env().parse(
st.session_state.system_prompt + st.session_state.user_prompt
)
used_vars = meta.find_undeclared_variables(parsed_content)
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
error_container.error(
f"Couldn't parse the Jinja templates in the prompt(s). Ensure the "
f"templates are valid. Short error message: {e}\n\n"
f"Full error message:\n\n{traceback}"
)
used_vars = set()

if c.TEXT_ORIG_COL in vars:
vars.remove(c.TEXT_ORIG_COL)
if c.TEXT_ORIG_COL in used_vars:
used_vars.remove(c.TEXT_ORIG_COL)

if len(vars) > 0:
if len(used_vars) > 0:
# create text of used prompt's variables and their values
vars_values = ""
for var in vars:
for var in used_vars:
vars_values += var + ":\n " + st.session_state.row.get(var, "none") + "\n"

st_container.text_area(
label=f"Attributes used in a prompt",
label=f"Attributes other than `{c.TEXT_ORIG_COL}` used in the prompt(s)",
key="attributes",
value=vars_values,
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)


def set_up_prompt_preview(st_container, prompt, error_container, prompt_kind="system"):
if "row" in st.session_state:
try:
prompt = u.jinja_env().from_string(prompt).render(**st.session_state.row.to_dict())
except Exception as e:
traceback = u.format_traceback_for_markdown(tb.format_exc())
error_container.error(
f"Couldn't show {prompt_kind} prompt preview due to an error "
f"when parsing and rendering the Jinja templates. Ensure that "
f"your templates are valid. Short error message: {e}\n\n"
f"Full error message:\n\n{traceback}"
)
prompt = "ERROR"
else:
prompt = ""

st_container.text_area(
label=f"{prompt_kind.title()} prompt preview",
key=f"{prompt_kind}_prompt_jinja2",
value=prompt,
disabled=True,
height=c.PROMPT_PREVIEW_TEXT_AREA_HEIGHT,
)


def display_image(st_container, base64_str):
st_container.markdown(
f"""
Expand Down Expand Up @@ -592,15 +647,16 @@ def _handle_skip_past_label_rows_toggle():


def set_up_ui_labelling():
prompt_parsing_error_message_area = st.empty()
col1_orig, col2_orig = st.columns([1, 1])
text_orig_length = len(st.session_state.get("text_orig", ""))
col1_orig.text_area(
label=f"Original text ({text_orig_length} chars)",
label=f"Original text (`{c.TEXT_ORIG_COL}` column) ({text_orig_length} chars)",
key="text_orig",
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)
set_up_prompt_attrs_area(col2_orig)
set_up_prompt_vars_area(col2_orig, prompt_parsing_error_message_area)

if "image" in st.session_state.row:
display_image(col2_orig, st.session_state.row["image"])
Expand Down
Loading