From 9c7547ad6e8ac26aeeb1e35a364f2e7bbf39f298 Mon Sep 17 00:00:00 2001 From: bwmac Date: Fri, 10 Jan 2025 15:45:24 -0500 Subject: [PATCH] adds syncronous interface --- synapseclient/models/__init__.py | 4 + synapseclient/models/agent.py | 26 +++- .../models/protocols/agent_protocol.py | 144 ++++++++++++++++++ 3 files changed, 170 insertions(+), 4 deletions(-) create mode 100644 synapseclient/models/protocols/agent_protocol.py diff --git a/synapseclient/models/__init__.py b/synapseclient/models/__init__.py index a487a3827..09aa2ac32 100644 --- a/synapseclient/models/__init__.py +++ b/synapseclient/models/__init__.py @@ -1,5 +1,6 @@ # These are all of the models that are used by the Synapse client. from synapseclient.models.activity import Activity, UsedEntity, UsedURL +from synapseclient.models.agent import Agent, AgentSession, AgentSessionAccessLevel from synapseclient.models.annotations import Annotations from synapseclient.models.file import File, FileHandle from synapseclient.models.folder import Folder @@ -38,4 +39,7 @@ "TeamMember", "UserProfile", "UserPreference", + "Agent", + "AgentSession", + "AgentSessionAccessLevel", ] diff --git a/synapseclient/models/agent.py b/synapseclient/models/agent.py index 8cb270bd3..903a88f98 100644 --- a/synapseclient/models/agent.py +++ b/synapseclient/models/agent.py @@ -14,7 +14,11 @@ start_session, update_session, ) -from synapseclient.core.async_utils import otel_trace_method +from synapseclient.core.async_utils import async_to_sync, otel_trace_method +from synapseclient.models.protocols.agent_protocol import ( + AgentSessionSynchronousProtocol, + AgentSynchronousProtocol, +) class AgentType(str, Enum): @@ -71,7 +75,8 @@ class AgentPrompt: # TODO Add example usage to the docstring @dataclass -class AgentSession: +@async_to_sync +class AgentSession(AgentSessionSynchronousProtocol): """Represents a [Synapse Agent Session](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentSession.html) Attributes: @@ -259,7 +264,8 @@ async def prompt_async( # TODO Add example usage to the docstring @dataclass -class Agent: +@async_to_sync +class Agent(AgentSynchronousProtocol): """Represents a [Synapse Agent Registration](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistration.html) Attributes: @@ -364,6 +370,7 @@ async def start_session_async( synapse_client: Optional[Synapse] = None, ) -> "AgentSession": """Starts an agent session. + Adds the session to the Agent's sessions dictionary and sets it as the current session. Arguments: access_level: The access level of the agent session. @@ -389,6 +396,17 @@ async def start_session_async( async def get_session_async( self, session_id: str, *, synapse_client: Optional[Synapse] = None ) -> "AgentSession": + """Gets an existing agent session. + Adds the session to the Agent's sessions dictionary and sets it as the current session. + + Arguments: + session_id: The ID of the session to get. + synapse_client: The Synapse client to use for the request. + If None, the default client will be used. + + Returns: + The existing AgentSession object. + """ session = await AgentSession(id=session_id).get_async( synapse_client=synapse_client ) @@ -400,7 +418,7 @@ async def get_session_async( @otel_trace_method( method_to_trace_name=lambda self, **kwargs: f"Prompt_Agent_Session: {self.registration_id}" ) - async def prompt( + async def prompt_async( self, prompt: str, enable_trace: bool = False, diff --git a/synapseclient/models/protocols/agent_protocol.py b/synapseclient/models/protocols/agent_protocol.py new file mode 100644 index 000000000..5a0ca138c --- /dev/null +++ b/synapseclient/models/protocols/agent_protocol.py @@ -0,0 +1,144 @@ +"""Protocol for the methods of the Agent and AgentSession classes that have +synchronous counterparts generated at runtime.""" + +from typing import TYPE_CHECKING, Optional, Protocol + +from synapseclient import Synapse + +if TYPE_CHECKING: + from synapseclient.models import Agent, AgentSession, AgentSessionAccessLevel + + +class AgentSessionSynchronousProtocol(Protocol): + """Protocol for the methods of the AgentSession class that have synchronous counterparts + generated at runtime.""" + + def start(self, *, synapse_client: Optional[Synapse] = None) -> "AgentSession": + """Starts an agent session. + + Arguments: + synapse_client: The Synapse client to use for the request. If None, the default client will be used. + + Returns: + The new AgentSession object. + """ + return self + + def get(self, *, synapse_client: Optional[Synapse] = None) -> "AgentSession": + """Gets an existing agent session. + + Arguments: + synapse_client: The Synapse client to use for the request. If None, the default client will be used. + + Returns: + The existing AgentSession object. + """ + return self + + def update(self, *, synapse_client: Optional[Synapse] = None) -> "AgentSession": + """Updates an existing agent session. + + Arguments: + synapse_client: The Synapse client to use for the request. If None, the default client will be used. + + Returns: + The updated AgentSession object. + """ + return self + + def prompt(self, *, synapse_client: Optional[Synapse] = None) -> None: + """Sends a prompt to the agent and adds the response to the AgentSession's chat history. + + Arguments: + prompt: The prompt to send to the agent. + enable_trace: Whether to enable trace for the prompt. + print_response: Whether to print the response to the console. + newer_than: The timestamp to get trace results newer than. Defaults to None (all results). + synapse_client: The Synapse client to use for the request. If None, the default client will be used. + """ + return None + + +class AgentSynchronousProtocol(Protocol): + """Protocol for the methods of the Agent class that have synchronous counterparts + generated at runtime.""" + + def register(self, *, synapse_client: Optional[Synapse] = None) -> "Agent": + """Registers an agent with the Synapse API. If agent exists, it will be retrieved. + + Arguments: + synapse_client: The Synapse client to use for the request. If None, the default client will be used. + + Returns: + The registered or existing Agent object. + """ + return self + + def get(self, *, synapse_client: Optional[Synapse] = None) -> "Agent": + """Gets an existing agent. + + Arguments: + synapse_client: The Synapse client to use for the request. If None, the default client will be used. + + Returns: + The existing Agent object. + """ + return self + + def start_session( + self, + access_level: Optional["AgentSessionAccessLevel"] = "PUBLICLY_ACCESSIBLE", + *, + synapse_client: Optional[Synapse] = None, + ) -> "AgentSession": + """Starts an agent session. + Adds the session to the Agent's sessions dictionary and sets it as the current session. + Arguments: + access_level: The access level of the agent session. + Must be one of PUBLICLY_ACCESSIBLE, READ_YOUR_PRIVATE_DATA, or WRITE_YOUR_PRIVATE_DATA. + Defaults to PUBLICLY_ACCESSIBLE. + synapse_client: The Synapse client to use for the request. If None, the default client will be used. + + Returns: + The new AgentSession object. + """ + return AgentSession() + + def get_session( + self, session_id: str, *, synapse_client: Optional[Synapse] = None + ) -> "AgentSession": + """Gets an existing agent session. + Adds the session to the Agent's sessions dictionary and sets it as the current session. + + Arguments: + session_id: The ID of the session to get. + synapse_client: The Synapse client to use for the request. + If None, the default client will be used. + + Returns: + The existing AgentSession object. + """ + return AgentSession() + + def prompt( + self, + prompt: str, + enable_trace: bool = False, + print_response: bool = False, + session: Optional["AgentSession"] = None, + newer_than: Optional[int] = None, + *, + synapse_client: Optional[Synapse] = None, + ) -> None: + """Sends a prompt to the agent for the current session. + If no session is currently active, a new session will be started. + + Arguments: + prompt: The prompt to send to the agent. + enable_trace: Whether to enable trace for the prompt. + print_response: Whether to print the response to the console. + session_id: The ID of the session to send the prompt to. If None, the current session will be used. + newer_than: The timestamp to get trace results newer than. Defaults to None (all results). + synapse_client: The Synapse client to use for the request. If None, the default client will be used. + """ + return None