Skip to content

Commit

Permalink
Display values of variables used in a prompt
Browse files Browse the repository at this point in the history
- extract jinja variables from a prompt,
- new text area with values of variables from a prompt.
  • Loading branch information
Daniela Ovadova committed Oct 31, 2023
1 parent 2abd830 commit 24852b5
Showing 1 changed file with 29 additions and 4 deletions.
33 changes: 29 additions & 4 deletions prompterator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import prompterator.models as m
import prompterator.utils as u

from jinja2 import meta

# needed to use the simple custom component
# from apps.scripts.components_callbacks import register_callback
# from components.rate_buttons import rate_buttons
Expand Down Expand Up @@ -450,13 +452,34 @@ def _get_coloured_patch(patch):
def set_up_ui_labelling():
col1, col2 = st.columns([1, 1])
text_orig_length = len(st.session_state.get("text_orig", ""))
col1.text_area(
orig_text_container = col1.container()
orig_text_container.text_area(
label=f"Original text ({text_orig_length} chars)",
key="text_orig",
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)
labelling_container = col2.container()

env = u.jinja_env()
parsed_content = env.parse(st.session_state.system_prompt)
vars = meta.find_undeclared_variables(parsed_content)

if len(vars) > 1:
# create text of used prompt's variables and their values
vars_values = ""
for var in vars:
if var != c.TEXT_ORIG_COL:
vars_values = vars_values + var + ":\n" + " " + st.session_state.row[var] + "\n"

orig_text_container.text_area(
label=f"Attributes used in a prompt",
key="attributes",
value=vars_values,
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)

generated_text_area = col2.container()
text_generated_length = len(st.session_state.get("text_generated", ""))
length_change_percentage = (text_generated_length - text_orig_length) / text_orig_length * 100
length_change_percentage_str = (
Expand All @@ -467,18 +490,20 @@ def set_up_ui_labelling():
)

if not st.session_state.get("show_diff", False):
labelling_container.text_area(
generated_text_area.text_area(
label=generated_text_label,
key="text_generated",
value=st.session_state.get("text_generated", ""),
disabled=True,
height=c.DATA_POINT_TEXT_AREA_HEIGHT,
)
else:
labelling_container.markdown(
generated_text_area.markdown(
create_diff_viewer(generated_text_label), unsafe_allow_html=True
)

empty_col, labeling_col = st.columns([1, 1])
labelling_container = labeling_col.container()
col1, col2, col3, col4, col5, col6, col7 = labelling_container.columns([1, 1, 5, 1, 1, 1, 2])
col1.button(
"👍",
Expand Down

0 comments on commit 24852b5

Please sign in to comment.