Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYNPY-1544] potential changes to mixin #1153

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 87 additions & 44 deletions synapseclient/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from synapseclient.core.async_utils import async_to_sync, otel_trace_method
from synapseclient.core.constants.concrete_types import AGENT_CHAT_REQUEST
from synapseclient.models.mixins.asynchronous_job import AsynchronousJob
from synapseclient.models.mixins.asynchronous_job import AsynchronousCommunicator
from synapseclient.models.protocols.agent_protocol import (
AgentSessionSynchronousProtocol,
AgentSynchronousProtocol,
Expand Down Expand Up @@ -50,13 +50,15 @@ class AgentSessionAccessLevel(str, Enum):


@dataclass
class AgentPrompt:
class AgentPrompt(AsynchronousCommunicator):
"""Represents a prompt, response, and metadata within an AgentSession.

Attributes:
id: The unique ID of the agent prompt.
session_id: The ID of the session that the prompt is associated with.
prompt: The prompt to send to the agent.
response: The response from the agent.
enable_trace: Whether tracing is enabled for the prompt.
trace: The trace of the agent session.
"""

Expand All @@ -65,20 +67,68 @@ class AgentPrompt:
id: Optional[str] = None
"""The unique ID of the agent prompt."""

session_id: Optional[str] = None
"""The ID of the session that the prompt is associated with."""

prompt: Optional[str] = None
"""The prompt sent to the agent."""

response: Optional[str] = None
"""The response from the agent."""

enable_trace: Optional[bool] = False
"""Whether tracing is enabled for the prompt."""

trace: Optional[str] = None
"""The trace or "thought process" of the agent when responding to the prompt."""

def to_synapse_request(self):
"""Converts the request to a request expected of the Synapse REST API."""
return {
"concreteType": self.concrete_type,
"sessionId": self.session_id,
"chatText": self.prompt,
"enableTrace": self.enable_trace,
}

def fill_from_dict(self, synapse_response: Dict[str, str]) -> "AgentPrompt":
"""
Converts a response from the REST API into this dataclass.

Arguments:
agent_prompt: The response from the REST API.

Returns:
The AgentPrompt object.
"""
self.id = synapse_response.get("jobId", None)
self.session_id = synapse_response.get("sessionId", None)
self.response = synapse_response.get("responseText", None)
return self

async def _post_exchange_async(
self, *, synapse_client: Optional[Synapse] = None, **kwargs
) -> None:
"""Retrieves information about the trace of this prompt with the agent.

Arguments:
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
if self.enable_trace:
trace_response = await get_trace(
prompt_id=self.id,
newer_than=kwargs.get("newer_than", None),
synapse_client=synapse_client,
)
self.trace = trace_response["page"][0]["message"]


# TODO Add example usage to the docstring
@dataclass
@async_to_sync
class AgentSession(AgentSessionSynchronousProtocol, AsynchronousJob):
class AgentSession(AgentSessionSynchronousProtocol):
"""Represents a [Synapse Agent Session](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentSession.html)

Attributes:
Expand Down Expand Up @@ -150,7 +200,9 @@ async def start_async(
"""Starts an agent session.

Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.

Returns:
The new AgentSession object.
Expand All @@ -171,7 +223,9 @@ async def get_async(
"""Gets an agent session.

Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.

Returns:
The retrieved AgentSession object.
Expand All @@ -194,7 +248,9 @@ async def update_async(
Only updates to the access level are currently supported.

Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.

Returns:
The updated AgentSession object.
Expand Down Expand Up @@ -223,45 +279,24 @@ async def prompt_async(
enable_trace: Whether to enable trace for the prompt.
print_response: Whether to print the response to the console.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
prompt_id = await self.send_job_async(
request_type=AGENT_CHAT_REQUEST,
session_id=self.id,
prompt=prompt,
enable_trace=enable_trace,
synapse_client=synapse_client,
)

answer_response = await self.get_job_async(
job_id=prompt_id,
request_type=AGENT_CHAT_REQUEST,
synapse_client=synapse_client,
agent_prompt = AgentPrompt(
prompt=prompt, session_id=self.id, enable_trace=enable_trace
)
response = answer_response["responseText"]

if enable_trace:
trace_response = await get_trace(
prompt_id=prompt_id,
newer_than=newer_than,
synapse_client=synapse_client,
)
trace = trace_response["page"][0]["message"]

self.chat_history.append(
AgentPrompt(
id=prompt_id,
prompt=prompt,
response=response,
trace=trace,
)
await agent_prompt.send_job_and_wait_async(
synapse_client=synapse_client, post_exchange_args={"newer_than": newer_than}
)
self.chat_history.append(agent_prompt)

if print_response:
print(f"PROMPT:\n{prompt}\n")
print(f"RESPONSE:\n{response}\n")
print(f"RESPONSE:\n{agent_prompt.response}\n")
if enable_trace:
print(f"TRACE:\n{trace}")
print(f"TRACE:\n{agent_prompt.trace}")


# TODO Add example usage to the docstring
Expand Down Expand Up @@ -328,7 +363,9 @@ async def register_async(
"""Registers an agent with the Synapse API. If agent exists, it will be retrieved.

Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.

Returns:
The registered or existing Agent object.
Expand All @@ -348,7 +385,9 @@ async def get_async(self, *, synapse_client: Optional[Synapse] = None) -> "Agent
"""Gets an existing agent.

Arguments:
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.

Returns:
The existing Agent object.
Expand Down Expand Up @@ -378,8 +417,9 @@ async def start_session_async(
access_level: The access level of the agent session.
Must be one of PUBLICLY_ACCESSIBLE, READ_YOUR_PRIVATE_DATA, or WRITE_YOUR_PRIVATE_DATA.
Defaults to PUBLICLY_ACCESSIBLE.
synapse_client: The Synapse client to use for the request.
If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.

Returns:
The new AgentSession object.
Expand All @@ -403,8 +443,9 @@ async def get_session_async(

Arguments:
session_id: The ID of the session to get.
synapse_client: The Synapse client to use for the request.
If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.

Returns:
The existing AgentSession object.
Expand Down Expand Up @@ -439,7 +480,9 @@ async def prompt_async(
print_response: Whether to print the response to the console.
session_id: The ID of the session to send the prompt to. If None, the current session will be used.
newer_than: The timestamp to get trace results newer than. Defaults to None (all results).
synapse_client: The Synapse client to use for the request. If None, the default client will be used.
synapse_client: If not passed in and caching was not disabled by
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
# TODO: Iron this out. Make sure we cover all cases.
if session:
Expand Down
Loading
Loading