Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and improve experience when using Jinja to iterate over objects DATANG-3679 #14

Merged
merged 10 commits into from
Mar 27, 2024
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,35 @@ The image will be rendered inside the displayed dataframe and next to the "gener

(*Note: you also need an `OPENAI_API_KEY` environment variable to use `gpt-4-vision-preview`*)

## Using input data in prompts
samsucik marked this conversation as resolved.
Show resolved Hide resolved

The user/system prompt textboxes support [Jinja](https://jinja.palletsprojects.com/) templates.
Given a column named `text` in your uploaded CSV data, you can use values from this column by
writing the simple `{{text}}` template in your prompt.

If the values in your column represent more complex objects such as Python dictionaries or lists,
you can still work with them but make sure that these values are enclosed in double quotes (`"`)
in your CSV file. E.g. given a column `texts` with a value like `"[\"A\", \"B\", \"C\"]'`, you can
utilise this template to enumerate the individual list items in your prompt:
samsucik marked this conversation as resolved.
Show resolved Hide resolved
```jinja
{% for item in fromjson(texts) -%}
- {{ item }}
{% endfor %}
```
which would lead to this in your prompt:
```
- A
- B
- C
```

To parse objects from their string representation like in the above example, we provide two
functions you can use in your templates:
- `fromjson`: to be used in case of _valid JSON strings_
- `fromAstString`: to parse a wider range of string representations (it's based on
[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval))


## Paper

You can find more information on Prompterator in the associated paper: https://aclanthology.org/2023.emnlp-demo.43/
Expand Down
20 changes: 15 additions & 5 deletions prompterator/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import traceback as tb
from collections import OrderedDict
from datetime import datetime

Expand Down Expand Up @@ -113,12 +114,21 @@ def run_prompt(progress_ui_area):
model_instance = m.MODEL_INSTANCES[model.name]
model_params = {param: st.session_state[param] for param in model.configurable_params}
df_old = st.session_state.df.copy()
model_inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row

try:
model_inputs = {
i: u.create_model_input(
model, model_instance, user_prompt_template, system_prompt_template, row
)
for i, row in df_old.iterrows()
}
except Exception as e:
traceback = u.format_multiline_text_for_markdown(tb.format_exc())
st.error(
f"Couldn't prepare model inputs due to this error: {e}\n\nFull error "
f"message:\n\n{traceback}"
)
for i, row in df_old.iterrows()
}
return

if len(model_inputs) == 0:
st.error("No input data to generate texts from!")
Expand Down
33 changes: 30 additions & 3 deletions prompterator/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import ast
import concurrent.futures
import itertools
import json
import logging
import os
import re
import socket
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from datetime import datetime
from functools import partial
from typing import Any

import jinja2
import openai
Expand Down Expand Up @@ -233,11 +236,30 @@ def create_model_input(

@st.cache_resource
def jinja_env() -> jinja2.Environment:
def from_json(text: str):
return json.loads(text)
def fromjson(text: str) -> Any:
try:
return json.loads(text)
except json.decoder.JSONDecodeError as e:
raise ValueError(
f"The string you passed into `fromjson` is not a valid JSON string: " f"`{text}`"
) from e

def fromAstString(text: str) -> Any:
try:
return ast.literal_eval(text)
except Exception as e:
raise ValueError(
f"The string you passed into `fromAstString` is not a valid "
f"input: `{text}`. Generally, try passing a valid string "
f"representation of a "
f"Python dictionary/list/set/tuple or other simple types. For more "
f"details, refer to "
f"[`ast.literal_eval`](https://docs.python.org/3/library/ast.html#ast.literal_eval)."
) from e

env = jinja2.Environment()
env.globals["fromjson"] = from_json
env.globals["fromjson"] = fromjson
env.globals["fromAstString"] = fromAstString
return env


Expand Down Expand Up @@ -311,3 +333,8 @@ def insert_hidden_html_marker(helper_element_id, target_streamlit_element=None):
""",
unsafe_allow_html=True,
)


def format_multiline_text_for_markdown(text):
text = re.sub(r" ", " ", text)
samsucik marked this conversation as resolved.
Show resolved Hide resolved
return re.sub(r"\n", "\n\n", text)
Loading