From e8dfa4cf449f01d63fc987ee4638a23e89cfc712 Mon Sep 17 00:00:00 2001
From: hellovai
Date: Wed, 29 Nov 2023 11:55:51 -0800
Subject: [PATCH] Make anthropic a chat client (#149)

---
 .../providers/anthropic_provider.py           | 28 ++++++++++++-------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/clients/python/baml_core/registrations/providers/anthropic_provider.py b/clients/python/baml_core/registrations/providers/anthropic_provider.py
index 8b9f0d3ba..3ea5a0d85 100644
--- a/clients/python/baml_core/registrations/providers/anthropic_provider.py
+++ b/clients/python/baml_core/registrations/providers/anthropic_provider.py
@@ -2,7 +2,12 @@
 import typing
 
-from baml_core.provider_manager import LLMProvider, register_llm_provider, LLMResponse
+from baml_core.provider_manager import (
+    LLMChatProvider,
+    register_llm_provider,
+    LLMResponse,
+    LLMChatMessage,
+)
 
 
 def _hydrate_anthropic_tokenizer() -> None:
@@ -15,7 +20,7 @@ def _hydrate_anthropic_tokenizer() -> None:
 
 @register_llm_provider("baml-anthropic")
 @typing.final
-class AnthropicProvider(LLMProvider):
+class AnthropicProvider(LLMChatProvider):
     __kwargs: typing.Dict[str, typing.Any]
 
     def _to_error_code(self, e: Exception) -> typing.Optional[int]:
@@ -34,13 +39,7 @@ def __init__(
         ), "Either use max_retries with Anthropic via options or retry via BAML, not both"
 
         super().__init__(
-            chat_to_prompt=lambda chat: "".join(
-                map(
-                    lambda c: f'{anthropic.HUMAN_PROMPT if c["role"] != "system" else anthropic.AI_PROMPT} {c["content"]}',
-                    chat,
-                )
-            )
-            + anthropic.AI_PROMPT,
+            prompt_to_chat=lambda chat: {"role": "human", "content": chat},
             **kwargs,
         )
 
@@ -79,7 +78,16 @@ def __init__(
     def _validate(self) -> None:
         pass
 
-    async def _run(self, prompt: str) -> LLMResponse:
+    async def _run_chat(self, messages: typing.List[LLMChatMessage]) -> LLMResponse:
+        prompt = (
+            "".join(
+                map(
+                    lambda c: f'{anthropic.HUMAN_PROMPT if c["role"] != "system" else anthropic.AI_PROMPT} {c["content"]}',
+                    messages,
+                )
+            )
+            + anthropic.AI_PROMPT
+        )
         prompt_tokens = await self.__client.count_tokens(prompt)
         response = typing.cast(
             anthropic.types.Completion,
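
Note for readers skimming the diff: the sketch below restates the message-flattening logic that this patch moves from the constructor's chat_to_prompt lambda into _run_chat. It is a minimal, self-contained approximation, not BAML's actual code: the HUMAN_PROMPT and AI_PROMPT values mirror the constants from the legacy anthropic Python SDK, the LLMChatMessage shape is assumed from the import added in the diff (the real type may differ), and chat_to_anthropic_prompt is a hypothetical name used only for illustration.

    # Sketch of the chat-to-prompt flattening moved into _run_chat.
    from typing import List, TypedDict

    HUMAN_PROMPT = "\n\nHuman:"   # mirrors anthropic.HUMAN_PROMPT in the legacy SDK
    AI_PROMPT = "\n\nAssistant:"  # mirrors anthropic.AI_PROMPT in the legacy SDK


    class LLMChatMessage(TypedDict):
        # Assumed shape, based on the dict built by prompt_to_chat in the diff.
        role: str
        content: str


    def chat_to_anthropic_prompt(messages: List[LLMChatMessage]) -> str:
        # As in the diff: non-system messages become Human turns, system
        # messages are emitted as Assistant turns, and a trailing AI_PROMPT
        # cues the model to complete the next Assistant turn.
        parts = [
            f'{HUMAN_PROMPT if m["role"] != "system" else AI_PROMPT} {m["content"]}'
            for m in messages
        ]
        return "".join(parts) + AI_PROMPT


    if __name__ == "__main__":
        msgs: List[LLMChatMessage] = [
            {"role": "system", "content": "You are terse."},
            {"role": "human", "content": "What is BAML?"},
        ]
        print(chat_to_anthropic_prompt(msgs))
        # -> "\n\nAssistant: You are terse.\n\nHuman: What is BAML?\n\nAssistant:"

The trailing AI_PROMPT is what turns a chat transcript into a completion-style prompt: it leaves the model positioned to write the next Assistant turn, which is the format the old Anthropic completions API expects.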