use new wrapper/prompt api (#1346)
* use new wrapper/prompt api

* wip

* mypy

* mypy

* oops

* test run report with judge
mike0sv authored Oct 18, 2024
1 parent 0d02c62 commit 5f87168
Showing 7 changed files with 126 additions and 223 deletions.
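The bulk of the change is in llm_judge.py: the module-local LLMWrapper ABC, provider registry, and JSON-parsing code are deleted in favor of the shared evidently.utils.llm wrapper/prompt API. A minimal sketch of judge construction after this change — the template fields appear in the diff below, while the provider/model arguments are assumptions inferred from the get_llm_wrapper(options) call:

from evidently.features.llm_judge import BinaryClassificationPromptTemplate, LLMJudge

# Field values are illustrative; the field names appear in the diff below.
template = BinaryClassificationPromptTemplate(
    criteria="A CONCISE answer is short but complete.",
    target_category="CONCISE",
    non_target_category="VERBOSE",
    include_reasoning=True,
)
# provider/model are assumed here; LLMJudge resolves the actual wrapper
# through get_llm_wrapper(options) as shown in generate_features below.
judge = LLMJudge(provider="openai", model="gpt-4o-mini", template=template)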
210 changes: 34 additions & 176 deletions src/evidently/features/llm_judge.py
@@ -1,99 +1,42 @@
import json
from abc import ABC
from abc import abstractmethod
from enum import Enum
from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import Union

import pandas as pd

from evidently import ColumnType
from evidently._pydantic_compat import Field
from evidently._pydantic_compat import PrivateAttr
from evidently._pydantic_compat import SecretStr
from evidently.base_metric import ColumnName
from evidently.errors import EvidentlyError
from evidently.features.generated_features import GeneratedFeatures
from evidently.options.base import Options
from evidently.options.option import Option
from evidently.pydantic_utils import EnumValueMixin
from evidently.pydantic_utils import EvidentlyBaseModel
from evidently.pydantic_utils import autoregister
from evidently.utils.data_preprocessing import DataDefinition
from evidently.utils.llm.base import LLMMessage
from evidently.utils.llm.prompts import PromptBlock
from evidently.utils.llm.prompts import PromptTemplate
from evidently.utils.llm.wrapper import LLMRequest
from evidently.utils.llm.wrapper import LLMWrapper
from evidently.utils.llm.wrapper import get_llm_wrapper

LLMMessage = Tuple[str, str]
LLMResponse = Dict[str, Union[str, float]]


class EvidentlyLLMError(EvidentlyError):
pass


class LLMResponseParseError(EvidentlyLLMError):
pass


class LLMRequestError(EvidentlyLLMError):
pass


class LLMWrapper(ABC):
__used_options__: ClassVar[List[Type[Option]]] = []

@abstractmethod
def complete(self, messages: List[LLMMessage]) -> str:
raise NotImplementedError

def get_used_options(self) -> List[Type[Option]]:
return self.__used_options__


LLMProvider = str
LLMModel = str
LLMWrapperProvider = Callable[[LLMModel, Options], LLMWrapper]
_wrappers: Dict[Tuple[LLMProvider, Optional[LLMModel]], LLMWrapperProvider] = {}


def llm_provider(name: LLMProvider, model: Optional[LLMModel]):
def dec(f: LLMWrapperProvider):
_wrappers[(name, model)] = f
return f

return dec


def get_llm_wrapper(provider: LLMProvider, model: LLMModel, options: Options) -> LLMWrapper:
key: Tuple[str, Optional[str]] = (provider, model)
if key in _wrappers:
return _wrappers[key](model, options)
key = (provider, None)
if key in _wrappers:
return _wrappers[key](model, options)
raise ValueError(f"LLM wrapper for provider {provider} model {model} not found")


class BaseLLMPromptTemplate(EvidentlyBaseModel, ABC):
class BaseLLMPromptTemplate(PromptTemplate):
class Config:
is_base_type = True

@abstractmethod
def iterate_messages(self, data: pd.DataFrame, input_columns: Dict[str, str]) -> Iterator[LLMMessage]:
raise NotImplementedError

@abstractmethod
def get_system_prompts(self) -> List[LLMMessage]:
raise NotImplementedError

@abstractmethod
def parse_response(self, response: str) -> LLMResponse:
raise NotImplementedError
def iterate_messages(self, data: pd.DataFrame, input_columns: Dict[str, str]) -> Iterator[LLMRequest[dict]]:
template = self.get_template()
for _, column_values in data[list(input_columns)].rename(columns=input_columns).iterrows():
yield LLMRequest(
messages=self.get_messages(column_values, template), response_parser=self.parse, response_type=dict
)

@abstractmethod
def list_output_columns(self) -> List[str]:
@@ -103,10 +46,6 @@ def list_output_columns(self) -> List[str]:
def get_type(self, subcolumn: Optional[str]) -> ColumnType:
raise NotImplementedError

@abstractmethod
def get_prompt_template(self) -> str:
raise NotImplementedError


class Uncertainty(str, Enum):
UNKNOWN = "unknown"
@@ -119,9 +58,6 @@ class BinaryClassificationPromptTemplate(BaseLLMPromptTemplate, EnumValueMixin):
class Config:
type_alias = "evidently:prompt_template:BinaryClassificationPromptTemplate"

template: str = (
"""{__criteria__}\n{__task__}\n\n{__as__}\n{{input}}\n{__ae__}\n\n{__instructions__}\n\n{__output_format__}"""
)
criteria: str = ""
instructions_template: str = (
"Use the following categories for classification:\n{__categories__}\n{__scoring__}\nThink step by step."
@@ -146,32 +82,6 @@ class Config:

pre_messages: List[LLMMessage] = Field(default_factory=list)

def iterate_messages(self, data: pd.DataFrame, input_columns: Dict[str, str]) -> Iterator[LLMMessage]:
prompt_template = self.get_prompt_template()
for _, column_values in data[list(input_columns)].rename(columns=input_columns).iterrows():
yield "user", prompt_template.format(**dict(column_values))

def get_prompt_template(self) -> str:
values = {
"__criteria__": self._criteria(),
"__task__": self._task(),
"__instructions__": self._instructions(),
"__output_format__": self._output_format(),
"__as__": self.anchor_start,
"__ae__": self.anchor_end,
**self.placeholders,
}
return self.template.format(**values)

def _task(self):
return (
f"Classify text between {self.anchor_start} and {self.anchor_end} "
f"into two categories: {self.target_category} and {self.non_target_category}."
)

def _criteria(self):
return self.criteria

def _instructions(self):
categories = (
(
@@ -203,30 +113,30 @@ def _uncertainty_class(self):
return self.target_category
raise ValueError(f"Unknown uncertainty value: {self.uncertainty}")

def _output_format(self):
values = []
columns = {}
def get_blocks(self) -> Sequence[PromptBlock]:
fields = {}
if self.include_category:
cat = f"{self.target_category} or {self.non_target_category}"
if self.uncertainty == Uncertainty.UNKNOWN:
cat += " or UNKNOWN"
columns[self.output_column] = f'"{cat}"'
values.append("category")
fields["category"] = (cat, self.output_column)
if self.include_score:
columns[self.output_score_column] = "<score here>"
values.append("score")
fields["score"] = ("<score here>", self.output_score_column)
if self.include_reasoning:
columns[self.output_reasoning_column] = '"<reasoning here>"'
values.append("reasoning")

keys = "\n".join(f'"{k}": {v}' for k, v in columns.items())
return f"Return {', '.join(values)} formatted as json without formatting as follows:\n{{{{\n{keys}\n}}}}"
fields["reasoning"] = ('"<reasoning here>"', self.output_reasoning_column)
return [
PromptBlock.simple(self.criteria),
PromptBlock.simple(
f"Classify text between {self.anchor_start} and {self.anchor_end} "
f"into two categories: {self.target_category} and {self.non_target_category}."
),
PromptBlock.input().anchored(self.anchor_start, self.anchor_end),
PromptBlock.simple(self._instructions()),
PromptBlock.json_output(**fields),
]

def parse_response(self, response: str) -> LLMResponse:
try:
return json.loads(response)
except json.JSONDecodeError as e:
raise LLMResponseParseError(f"Failed to parse response '{response}' as json") from e
def get_messages(self, values, template: Optional[str] = None) -> List[LLMMessage]:
return [*self.pre_messages, *super().get_messages(values)]

def list_output_columns(self) -> List[str]:
result = []
@@ -247,9 +157,6 @@ def get_type(self, subcolumn: Optional[str]) -> ColumnType:
return ColumnType.Categorical
raise ValueError(f"Unknown subcolumn {subcolumn}")

def get_system_prompts(self) -> List[LLMMessage]:
return self.pre_messages


class LLMJudge(GeneratedFeatures):
class Config:
@@ -281,12 +188,10 @@ def get_input_columns(self):
return {self.input_column: self.DEFAULT_INPUT_COLUMN}

def generate_features(self, data: pd.DataFrame, data_definition: DataDefinition, options: Options) -> pd.DataFrame:
result: List[Dict[str, Union[str, float]]] = []
result = self.get_llm_wrapper(options).run_batch_sync(
requests=self.template.iterate_messages(data, self.get_input_columns())
)

for message in self.template.iterate_messages(data, self.get_input_columns()):
messages: List[LLMMessage] = [*self.template.get_system_prompts(), message]
response = self.get_llm_wrapper(options).complete(messages)
result.append(self.template.parse_response(response))
return pd.DataFrame(result)

def list_columns(self) -> List["ColumnName"]:
@@ -300,50 +205,3 @@ def get_type(self, subcolumn: Optional[str] = None) -> ColumnType:
subcolumn = self._extract_subcolumn_name(subcolumn)

return self.template.get_type(subcolumn)


class OpenAIKey(Option):
api_key: Optional[SecretStr] = None

def __init__(self, api_key: Optional[str] = None):
self.api_key = SecretStr(api_key) if api_key is not None else None
super().__init__()

def get_value(self) -> Optional[str]:
if self.api_key is None:
return None
return self.api_key.get_secret_value()


@llm_provider("openai", None)
class OpenAIWrapper(LLMWrapper):
__used_options__: ClassVar = [OpenAIKey]

def __init__(self, model: str, options: Options):
import openai

self.model = model
self.client = openai.OpenAI(api_key=options.get(OpenAIKey).get_value())

def complete(self, messages: List[LLMMessage]) -> str:
import openai

messages = [{"role": user, "content": msg} for user, msg in messages]
try:
response = self.client.chat.completions.create(model=self.model, messages=messages) # type: ignore[arg-type]
except openai.OpenAIError as e:
raise LLMRequestError("Failed to call OpenAI complete API") from e
content = response.choices[0].message.content
assert content is not None # todo: better error
return content


@llm_provider("litellm", None)
class LiteLLMWrapper(LLMWrapper):
def __init__(self, model: str):
self.model = model

def complete(self, messages: List[LLMMessage]) -> str:
from litellm import completion

return completion(model=self.model, messages=messages).choices[0].message.content
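
This removes the last provider-specific code from llm_judge.py; OpenAIWrapper and LiteLLMWrapper presumably now live in evidently/utils/llm/wrapper.py, whose diff below only shows an import change. With get_prompt_template(), parse_response(), and get_system_prompts() gone from the base class, a template only has to describe itself as PromptBlocks. A hypothetical subclass, sketched from the block constructors used in get_blocks() above:

from typing import List, Optional, Sequence

from evidently import ColumnType
from evidently.features.llm_judge import BaseLLMPromptTemplate
from evidently.utils.llm.prompts import PromptBlock


class ToxicityPromptTemplate(BaseLLMPromptTemplate):
    """Hypothetical template classifying input as TOXIC or SAFE."""

    def get_blocks(self) -> Sequence[PromptBlock]:
        return [
            PromptBlock.simple("Classify the text between the anchors as TOXIC or SAFE."),
            PromptBlock.input().anchored("___start___", "___end___"),
            # json_output takes field=(example value, output column) pairs,
            # mirroring the fields dict built in get_blocks() above.
            PromptBlock.json_output(category=("TOXIC or SAFE", "category")),
        ]

    def list_output_columns(self) -> List[str]:
        return ["category"]

    def get_type(self, subcolumn: Optional[str]) -> ColumnType:
        return ColumnType.Categorical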
38 changes: 1 addition & 37 deletions src/evidently/ui/base.py
@@ -1,16 +1,11 @@
import asyncio
import contextlib
import datetime
import json
import threading
from abc import ABC
from abc import abstractmethod
from enum import Enum
from functools import wraps
from typing import IO
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import Iterator
@@ -59,30 +54,7 @@
from evidently.utils import NumpyEncoder
from evidently.utils.dashboard import TemplateParams
from evidently.utils.dashboard import inline_iframe_html_template

_loop = asyncio.new_event_loop()

_thr = threading.Thread(target=_loop.run_forever, name="Async Runner", daemon=True)


TA = TypeVar("TA")


def async_to_sync(awaitable: Awaitable[TA]) -> TA:
try:
asyncio.get_running_loop()
# we are in sync context but inside a running loop
if not _thr.is_alive():
_thr.start()
future = asyncio.run_coroutine_threadsafe(awaitable, _loop)
return future.result()
except RuntimeError:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(awaitable)
finally:
new_loop.close()
from evidently.utils.sync import sync_api


class BlobMetadata(BaseModel):
@@ -188,14 +160,6 @@ def _default_dashboard():
return DashboardConfig(name="", panels=[])


def sync_api(f: Callable[..., Awaitable[TA]]) -> Callable[..., TA]:
@wraps(f)
def sync_call(*args, **kwargs):
return async_to_sync(f(*args, **kwargs))

return sync_call


class Project(Entity):
entity_type: ClassVar[EntityType] = EntityType.Project
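
The event-loop machinery deleted here was consolidated into evidently/utils/sync.py, which this file now imports. A minimal usage sketch, assuming sync_api keeps the decorator signature shown in the removed code:

import asyncio

from evidently.utils.sync import sync_api


@sync_api
async def fetch_value(key: str) -> str:
    await asyncio.sleep(0)  # stand-in for real async I/O
    return f"value for {key}"


# The wrapped function blocks until the coroutine completes, reusing a
# background runner thread when called from inside a running event loop.
print(fetch_value("snapshot-id"))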

2 changes: 1 addition & 1 deletion src/evidently/ui/workspace/view.py
@@ -9,14 +9,14 @@
from evidently.suite.base_suite import Snapshot
from evidently.ui.base import Project
from evidently.ui.base import ProjectManager
from evidently.ui.base import async_to_sync
from evidently.ui.type_aliases import STR_UUID
from evidently.ui.type_aliases import ZERO_UUID
from evidently.ui.type_aliases import DatasetID
from evidently.ui.type_aliases import OrgID
from evidently.ui.type_aliases import TeamID
from evidently.ui.type_aliases import UserID
from evidently.ui.workspace.base import WorkspaceBase
from evidently.utils.sync import async_to_sync


class WorkspaceView(WorkspaceBase):
2 changes: 1 addition & 1 deletion src/evidently/utils/llm/base.py
@@ -3,7 +3,7 @@
from typing import Dict


@dataclasses.dataclass
@dataclasses.dataclass(unsafe_hash=True, frozen=True)
class LLMMessage:
role: str
content: str
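
Making LLMMessage frozen and hashable gives it value semantics: equal messages hash identically, so they can key caches or be de-duplicated in sets. A quick illustration:

from evidently.utils.llm.base import LLMMessage

m1 = LLMMessage(role="user", content="classify this text")
m2 = LLMMessage(role="user", content="classify this text")

assert m1 == m2 and hash(m1) == hash(m2)  # value equality from the dataclass
cache = {m1: '{"category": "CONCISE"}'}   # hashable, so usable as a dict key
assert cache[m2] == '{"category": "CONCISE"}'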
2 changes: 1 addition & 1 deletion src/evidently/utils/llm/wrapper.py
@@ -20,9 +20,9 @@
from evidently._pydantic_compat import SecretStr
from evidently.options.base import Options
from evidently.options.option import Option
from evidently.ui.base import sync_api
from evidently.utils.llm.base import LLMMessage
from evidently.utils.llm.errors import LLMRequestError
from evidently.utils.sync import sync_api

TResult = TypeVar("TResult")
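
This is the sync_api that run_batch_sync (used by generate_features above) relies on. A sketch of calling the wrapper directly — the LLMRequest fields match the llm_judge.py diff, while the get_llm_wrapper argument order is an assumption carried over from the helper deleted from llm_judge.py:

import json

from evidently.options.base import Options
from evidently.utils.llm.base import LLMMessage
from evidently.utils.llm.wrapper import LLMRequest, get_llm_wrapper

wrapper = get_llm_wrapper("openai", "gpt-4o-mini", Options())
requests = [
    LLMRequest(
        messages=[LLMMessage(role="user", content='Reply with {"ok": 1} as json')],
        response_parser=json.loads,  # runs on each raw completion
        response_type=dict,
    )
]
# Drives all requests through the async runner and returns parsed
# results in input order, as generate_features() depends on.
results = wrapper.run_batch_sync(requests=requests)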

(Diffs for the remaining two changed files did not load on this page.)
