From 000a8244f89af3d7f21e0072ff919e0bdee7b4c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20=C5=A0uppa?= Date: Tue, 27 Aug 2024 17:28:20 +0200 Subject: [PATCH] add: Display rendered System and User prompts (#12) * Ensure that when jinja2 is used, the rendered versions of the System and User prompts are also shown in the UI. Signed-off-by: mrshu --- README.md | 3 ++ prompterator/constants.py | 1 + prompterator/main.py | 88 ++++++++++++++++++++++++++++++++------- 3 files changed, 76 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index d304fd5..8e1bca5 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,9 @@ using **comma** as the separator, and enclosing values (where needed) in **doubl ### Using input data in prompts The user/system prompt textboxes support [Jinja](https://jinja.palletsprojects.com/) templates. +Don't worry if you're new to Jinja -- Prompterator can show you a real-time "compiled" preview of +your prompts to help you write the templates. + Given a column named `text` in your uploaded CSV data, you can include values from this column by writing the simple `{{text}}` template in your prompt. diff --git a/prompterator/constants.py b/prompterator/constants.py index 984e4bb..3d2f247 100644 --- a/prompterator/constants.py +++ b/prompterator/constants.py @@ -112,5 +112,6 @@ def call(self, input, **kwargs): DEFAULT_ROW_NO = 0 DATA_POINT_TEXT_AREA_HEIGHT = 180 PROMPT_TEXT_AREA_HEIGHT = 300 +PROMPT_PREVIEW_TEXT_AREA_HEIGHT = 200 DATAFILE_FILTER_ALL = "all" diff --git a/prompterator/main.py b/prompterator/main.py index af4e9fb..92f7351 100644 --- a/prompterator/main.py +++ b/prompterator/main.py @@ -68,6 +68,7 @@ def initialise_session_from_uploaded_file(df): st.session_state["df"] = df st.session_state[c.COLS_TO_SHOW_KEY] = [c.TEXT_ORIG_COL] + st.session_state.row = st.session_state.df.iloc[c.DEFAULT_ROW_NO] if st.session_state.responses_generated_externally: st.session_state.enable_labelling = True @@ -432,17 +433,36 @@ def set_up_ui_generation(): height=c.PROMPT_TEXT_AREA_HEIGHT, disabled=not model_supports_user_prompt, ) - col1, col2 = st.columns([1, 2]) + if "df" in st.session_state: - cols_for_interpolation = set(st.session_state.df.columns).difference( - c.COLS_NOT_FOR_PROMPT_INTERPOLATION + prompt_parsing_error_message_area = st.empty() + col1, col2 = st.columns([3, 1]) + cols_for_interpolation = list( + set(st.session_state.df.columns).difference(c.COLS_NOT_FOR_PROMPT_INTERPOLATION) ) col1.write( f"These are the columns available in the data, feel free to include them in " - f"your prompt: {cols_for_interpolation}" + f"your prompt(s): `{'`, `'.join(cols_for_interpolation)}`." ) + with col2: + st.toggle(label="show prompt preview", value=False, key="show_prompt_preview") + + if st.session_state.show_prompt_preview: + col1, col2 = st.columns([3, 2]) + set_up_prompt_preview( + col1, + st.session_state.system_prompt, + prompt_parsing_error_message_area, + prompt_kind="system", + ) + set_up_prompt_preview( + col2, + st.session_state.user_prompt, + prompt_parsing_error_message_area, + prompt_kind="user", + ) - col2.button( + st.button( label="Run prompt", on_click=run_prompt, kwargs={"progress_ui_area": progress_ui_area}, @@ -488,22 +508,32 @@ def _get_coloured_patch(patch): ) -def set_up_prompt_attrs_area(st_container): - env = u.jinja_env() - parsed_content = env.parse(st.session_state.system_prompt) - vars = meta.find_undeclared_variables(parsed_content) +def set_up_prompt_vars_area(st_container, error_container): + try: + parsed_content = u.jinja_env().parse( + st.session_state.system_prompt + st.session_state.user_prompt + ) + used_vars = meta.find_undeclared_variables(parsed_content) + except Exception as e: + traceback = u.format_traceback_for_markdown(tb.format_exc()) + error_container.error( + f"Couldn't parse the Jinja templates in the prompt(s). Ensure the " + f"templates are valid. Short error message: {e}\n\n" + f"Full error message:\n\n{traceback}" + ) + used_vars = set() - if c.TEXT_ORIG_COL in vars: - vars.remove(c.TEXT_ORIG_COL) + if c.TEXT_ORIG_COL in used_vars: + used_vars.remove(c.TEXT_ORIG_COL) - if len(vars) > 0: + if len(used_vars) > 0: # create text of used prompt's variables and their values vars_values = "" - for var in vars: + for var in used_vars: vars_values += var + ":\n " + st.session_state.row.get(var, "none") + "\n" st_container.text_area( - label=f"Attributes used in a prompt", + label=f"Attributes other than `{c.TEXT_ORIG_COL}` used in the prompt(s)", key="attributes", value=vars_values, disabled=True, @@ -511,6 +541,31 @@ def set_up_prompt_attrs_area(st_container): ) +def set_up_prompt_preview(st_container, prompt, error_container, prompt_kind="system"): + if "row" in st.session_state: + try: + prompt = u.jinja_env().from_string(prompt).render(**st.session_state.row.to_dict()) + except Exception as e: + traceback = u.format_traceback_for_markdown(tb.format_exc()) + error_container.error( + f"Couldn't show {prompt_kind} prompt preview due to an error " + f"when parsing and rendering the Jinja templates. Ensure that " + f"your templates are valid. Short error message: {e}\n\n" + f"Full error message:\n\n{traceback}" + ) + prompt = "ERROR" + else: + prompt = "" + + st_container.text_area( + label=f"{prompt_kind.title()} prompt preview", + key=f"{prompt_kind}_prompt_jinja2", + value=prompt, + disabled=True, + height=c.PROMPT_PREVIEW_TEXT_AREA_HEIGHT, + ) + + def display_image(st_container, base64_str): st_container.markdown( f""" @@ -592,15 +647,16 @@ def _handle_skip_past_label_rows_toggle(): def set_up_ui_labelling(): + prompt_parsing_error_message_area = st.empty() col1_orig, col2_orig = st.columns([1, 1]) text_orig_length = len(st.session_state.get("text_orig", "")) col1_orig.text_area( - label=f"Original text ({text_orig_length} chars)", + label=f"Original text (`{c.TEXT_ORIG_COL}` column) ({text_orig_length} chars)", key="text_orig", disabled=True, height=c.DATA_POINT_TEXT_AREA_HEIGHT, ) - set_up_prompt_attrs_area(col2_orig) + set_up_prompt_vars_area(col2_orig, prompt_parsing_error_message_area) if "image" in st.session_state.row: display_image(col2_orig, st.session_state.row["image"])