Skip to content

Commit

Permalink
adds syncronous interface
Browse files Browse the repository at this point in the history
  • Loading branch information
BWMac committed Jan 10, 2025
1 parent dadbac3 commit 9c7547a
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 4 deletions.
4 changes: 4 additions & 0 deletions synapseclient/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# These are all of the models that are used by the Synapse client.
from synapseclient.models.activity import Activity, UsedEntity, UsedURL
from synapseclient.models.agent import Agent, AgentSession, AgentSessionAccessLevel
from synapseclient.models.annotations import Annotations
from synapseclient.models.file import File, FileHandle
from synapseclient.models.folder import Folder
Expand Down Expand Up @@ -38,4 +39,7 @@
"TeamMember",
"UserProfile",
"UserPreference",
"Agent",
"AgentSession",
"AgentSessionAccessLevel",
]
26 changes: 22 additions & 4 deletions synapseclient/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
start_session,
update_session,
)
from synapseclient.core.async_utils import otel_trace_method
from synapseclient.core.async_utils import async_to_sync, otel_trace_method
from synapseclient.models.protocols.agent_protocol import (
AgentSessionSynchronousProtocol,
AgentSynchronousProtocol,
)


class AgentType(str, Enum):
Expand Down Expand Up @@ -71,7 +75,8 @@ class AgentPrompt:

# TODO Add example usage to the docstring
@dataclass
class AgentSession:
@async_to_sync
class AgentSession(AgentSessionSynchronousProtocol):
"""Represents a [Synapse Agent Session](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentSession.html)
Attributes:
Expand Down Expand Up @@ -259,7 +264,8 @@ async def prompt_async(

# TODO Add example usage to the docstring
@dataclass
class Agent:
@async_to_sync
class Agent(AgentSynchronousProtocol):
"""Represents a [Synapse Agent Registration](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentRegistration.html)
Attributes:
Expand Down Expand Up @@ -364,6 +370,7 @@ async def start_session_async(
synapse_client: Optional[Synapse] = None,
) -> "AgentSession":
"""Starts an agent session.
Adds the session to the Agent's sessions dictionary and sets it as the current session.
Arguments:
access_level: The access level of the agent session.
Expand All @@ -389,6 +396,17 @@ async def start_session_async(
async def get_session_async(
self, session_id: str, *, synapse_client: Optional[Synapse] = None
) -> "AgentSession":
"""Gets an existing agent session.
Adds the session to the Agent's sessions dictionary and sets it as the current session.
Arguments:
session_id: The ID of the session to get.
synapse_client: The Synapse client to use for the request.
If None, the default client will be used.
Returns:
The existing AgentSession object.
"""
session = await AgentSession(id=session_id).get_async(
synapse_client=synapse_client
)
Expand All @@ -400,7 +418,7 @@ async def get_session_async(
@otel_trace_method(
method_to_trace_name=lambda self, **kwargs: f"Prompt_Agent_Session: {self.registration_id}"
)
async def prompt(
async def prompt_async(
self,
prompt: str,
enable_trace: bool = False,
Expand Down
144 changes: 144 additions & 0 deletions synapseclient/models/protocols/agent_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Protocol for the methods of the Agent and AgentSession classes that have
synchronous counterparts generated at runtime."""

from typing import TYPE_CHECKING, Optional, Protocol

from synapseclient import Synapse

if TYPE_CHECKING:
from synapseclient.models import Agent, AgentSession, AgentSessionAccessLevel


class AgentSessionSynchronousProtocol(Protocol):
"""Protocol for the methods of the AgentSession class that have synchronous counterparts
generated at runtime."""

def start(self, *, synapse_client: Optional[Synapse] = None) -> "AgentSession":
"""Starts an agent session.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
Returns:
The new AgentSession object.
"""
return self

def get(self, *, synapse_client: Optional[Synapse] = None) -> "AgentSession":
"""Gets an existing agent session.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
Returns:
The existing AgentSession object.
"""
return self

def update(self, *, synapse_client: Optional[Synapse] = None) -> "AgentSession":
"""Updates an existing agent session.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
Returns:
The updated AgentSession object.
"""
return self

def prompt(self, *, synapse_client: Optional[Synapse] = None) -> None:
"""Sends a prompt to the agent and adds the response to the AgentSession's chat history.
Arguments:
prompt: The prompt to send to the agent.
enable_trace: Whether to enable trace for the prompt.
print_response: Whether to print the response to the console.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
"""
return None


class AgentSynchronousProtocol(Protocol):
"""Protocol for the methods of the Agent class that have synchronous counterparts
generated at runtime."""

def register(self, *, synapse_client: Optional[Synapse] = None) -> "Agent":
"""Registers an agent with the Synapse API. If agent exists, it will be retrieved.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
Returns:
The registered or existing Agent object.
"""
return self

def get(self, *, synapse_client: Optional[Synapse] = None) -> "Agent":
"""Gets an existing agent.
Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
Returns:
The existing Agent object.
"""
return self

def start_session(
self,
access_level: Optional["AgentSessionAccessLevel"] = "PUBLICLY_ACCESSIBLE",
*,
synapse_client: Optional[Synapse] = None,
) -> "AgentSession":
"""Starts an agent session.
Adds the session to the Agent's sessions dictionary and sets it as the current session.
Arguments:
access_level: The access level of the agent session.
Must be one of PUBLICLY_ACCESSIBLE, READ_YOUR_PRIVATE_DATA, or WRITE_YOUR_PRIVATE_DATA.
Defaults to PUBLICLY_ACCESSIBLE.
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
Returns:
The new AgentSession object.
"""
return AgentSession()

def get_session(
self, session_id: str, *, synapse_client: Optional[Synapse] = None
) -> "AgentSession":
"""Gets an existing agent session.
Adds the session to the Agent's sessions dictionary and sets it as the current session.
Arguments:
session_id: The ID of the session to get.
synapse_client: The Synapse client to use for the request.
If None, the default client will be used.
Returns:
The existing AgentSession object.
"""
return AgentSession()

def prompt(
self,
prompt: str,
enable_trace: bool = False,
print_response: bool = False,
session: Optional["AgentSession"] = None,
newer_than: Optional[int] = None,
*,
synapse_client: Optional[Synapse] = None,
) -> None:
"""Sends a prompt to the agent for the current session.
If no session is currently active, a new session will be started.
Arguments:
prompt: The prompt to send to the agent.
enable_trace: Whether to enable trace for the prompt.
print_response: Whether to print the response to the console.
session_id: The ID of the session to send the prompt to. If None, the current session will be used.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
"""
return None

0 comments on commit 9c7547a

Please sign in to comment.