From 1ec8472a9de2f80e8eea1e18d6635ee0ae324518 Mon Sep 17 00:00:00 2001 From: Sam Sucik Date: Wed, 27 Mar 2024 17:10:31 +0100 Subject: [PATCH] Fix and improve experience when using Jinja to iterate over objects DATANG-3679 (#14) * Add another Jinja function for parsing AST strings that aren't valid JSON strings * Show meaningful error message instead of breaking Prompterator when there are issues with usage of fromjson or fromAstString in Jinja templates * Format code * Add brief documentation in README for using Jinja templates in prompts * Improve the Jinja example in README * Further improve the Jinja templating documentation * Fix reading of backslash-escaped stuff from CSV files * Improve Jinja docs in README * Improve naming and remove useless arg * Simplify Jinja-related docs and add some general usage tips --- README.md | 40 ++++++++++++++++++++++++++++++++++++++++ prompterator/main.py | 22 ++++++++++++++++------ prompterator/utils.py | 33 ++++++++++++++++++++++++++++++--- 3 files changed, 86 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index c9290bb..3f3d9ad 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,46 @@ The image will be rendered inside the displayed dataframe and next to the "gener (*Note: you also need an `OPENAI_API_KEY` environment variable to use `gpt-4-vision-preview`*) +## Usage guide + +### Input format + +Prompterator accepts CSV files as input. Additionally, the CSV data should follow these rules: +- be parseable using a +[`pd.read_csv`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html) +call with the default argument values. This means e.g. having **column names** in the first row, +using **comma** as the separator, and enclosing values (where needed) in **double quotes** (`"`) +- have a column named `text` + +### Using input data in prompts + +The user/system prompt textboxes support [Jinja](https://jinja.palletsprojects.com/) templates. +Given a column named `text` in your uploaded CSV data, you can include values from this column by +writing the simple `{{text}}` template in your prompt. + +If the values in your column represent more complex objects, you can still work with them but make +sure they are either valid JSON strings or valid Python expressions accepted by +[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval). + +To parse string representations of objects, use: +- `fromjson`: for valid JSON strings, e.g. `'["A", "B"]'` +- `fromAstString`: for Python expressions such as dicts/lists/tuples/... (see the accepted types of + [`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval)), e.g. `"{'key': 'value'}"` + +For example, given a CSV column `texts` with a value `"[""A"", ""B"", ""C""]"`, you can utilise this template to enumerate the individual list items +in your prompt: +```jinja +{% for item in fromjson(texts) -%} +- {{ item }} +{% endfor %} +``` +which would lead to this in your prompt: +``` +- A +- B +- C +``` + ## Paper You can find more information on Prompterator in the associated paper: https://aclanthology.org/2023.emnlp-demo.43/ diff --git a/prompterator/main.py b/prompterator/main.py index 476ae29..ce0dded 100644 --- a/prompterator/main.py +++ b/prompterator/main.py @@ -1,5 +1,6 @@ import logging import os +import traceback as tb from collections import OrderedDict from datetime import datetime @@ -113,12 +114,21 @@ def run_prompt(progress_ui_area): model_instance = m.MODEL_INSTANCES[model.name] model_params = {param: st.session_state[param] for param in model.configurable_params} df_old = st.session_state.df.copy() - model_inputs = { - i: u.create_model_input( - model, model_instance, user_prompt_template, system_prompt_template, row + + try: + model_inputs = { + i: u.create_model_input( + model, model_instance, user_prompt_template, system_prompt_template, row + ) + for i, row in df_old.iterrows() + } + except Exception as e: + traceback = u.format_traceback_for_markdown(tb.format_exc()) + st.error( + f"Couldn't prepare model inputs due to this error: {e}\n\nFull error " + f"message:\n\n{traceback}" ) - for i, row in df_old.iterrows() - } + return if len(model_inputs) == 0: st.error("No input data to generate texts from!") @@ -604,7 +614,7 @@ def show_dataframe(): def process_uploaded_file(): if st.session_state.uploaded_file is not None: - df = pd.read_csv(st.session_state.uploaded_file, header=0) + df = pd.read_csv(st.session_state.uploaded_file) assert c.TEXT_ORIG_COL in df.columns st.session_state.responses_generated_externally = c.TEXT_GENERATED_COL in df.columns initialise_session_from_uploaded_file(df) diff --git a/prompterator/utils.py b/prompterator/utils.py index 6122068..218f53f 100644 --- a/prompterator/utils.py +++ b/prompterator/utils.py @@ -1,8 +1,10 @@ +import ast import concurrent.futures import itertools import json import logging import os +import re import socket import time from collections import Counter @@ -10,6 +12,7 @@ from concurrent.futures.process import BrokenProcessPool from datetime import datetime from functools import partial +from typing import Any import jinja2 import openai @@ -233,11 +236,30 @@ def create_model_input( @st.cache_resource def jinja_env() -> jinja2.Environment: - def from_json(text: str): - return json.loads(text) + def fromjson(text: str) -> Any: + try: + return json.loads(text) + except json.decoder.JSONDecodeError as e: + raise ValueError( + f"The string you passed into `fromjson` is not a valid JSON string: " f"`{text}`" + ) from e + + def fromAstString(text: str) -> Any: + try: + return ast.literal_eval(text) + except Exception as e: + raise ValueError( + f"The string you passed into `fromAstString` is not a valid " + f"input: `{text}`. Generally, try passing a valid string " + f"representation of a " + f"Python dictionary/list/set/tuple or other simple types. For more " + f"details, refer to " + f"[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval)." + ) from e env = jinja2.Environment() - env.globals["fromjson"] = from_json + env.globals["fromjson"] = fromjson + env.globals["fromAstString"] = fromAstString return env @@ -311,3 +333,8 @@ def insert_hidden_html_marker(helper_element_id, target_streamlit_element=None): """, unsafe_allow_html=True, ) + + +def format_traceback_for_markdown(text): + text = re.sub(r" ", " ", text) + return re.sub(r"\n", "\n\n", text)