diff --git a/synapseclient/models/agent.py b/synapseclient/models/agent.py index fdd389c36..117f61a87 100644 --- a/synapseclient/models/agent.py +++ b/synapseclient/models/agent.py @@ -14,7 +14,7 @@ ) from synapseclient.core.async_utils import async_to_sync, otel_trace_method from synapseclient.core.constants.concrete_types import AGENT_CHAT_REQUEST -from synapseclient.models.mixins.asynchronous_job import AsynchronousJob +from synapseclient.models.mixins.asynchronous_job import AsynchronousCommunicator from synapseclient.models.protocols.agent_protocol import ( AgentSessionSynchronousProtocol, AgentSynchronousProtocol, @@ -50,13 +50,15 @@ class AgentSessionAccessLevel(str, Enum): @dataclass -class AgentPrompt: +class AgentPrompt(AsynchronousCommunicator): """Represents a prompt, response, and metadata within an AgentSession. Attributes: id: The unique ID of the agent prompt. + session_id: The ID of the session that the prompt is associated with. prompt: The prompt to send to the agent. response: The response from the agent. + enable_trace: Whether tracing is enabled for the prompt. trace: The trace of the agent session. """ @@ -65,20 +67,68 @@ class AgentPrompt: id: Optional[str] = None """The unique ID of the agent prompt.""" + session_id: Optional[str] = None + """The ID of the session that the prompt is associated with.""" + prompt: Optional[str] = None """The prompt sent to the agent.""" response: Optional[str] = None """The response from the agent.""" + enable_trace: Optional[bool] = False + """Whether tracing is enabled for the prompt.""" + trace: Optional[str] = None """The trace or "thought process" of the agent when responding to the prompt.""" + def to_synapse_request(self): + """Converts the request to a request expected of the Synapse REST API.""" + return { + "concreteType": self.concrete_type, + "sessionId": self.session_id, + "chatText": self.prompt, + "enableTrace": self.enable_trace, + } + + def fill_from_dict(self, synapse_response: Dict[str, str]) -> "AgentPrompt": + """ + Converts a response from the REST API into this dataclass. + + Arguments: + agent_prompt: The response from the REST API. + + Returns: + The AgentPrompt object. + """ + self.id = synapse_response.get("jobId", None) + self.session_id = synapse_response.get("sessionId", None) + self.response = synapse_response.get("responseText", None) + return self + + async def _post_exchange_async( + self, *, synapse_client: Optional[Synapse] = None, **kwargs + ) -> None: + """Retrieves information about the trace of this prompt with the agent. + + Arguments: + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. + """ + if self.enable_trace: + trace_response = await get_trace( + prompt_id=self.id, + newer_than=kwargs.get("newer_than", None), + synapse_client=synapse_client, + ) + self.trace = trace_response["page"][0]["message"] + # TODO Add example usage to the docstring @dataclass @async_to_sync -class AgentSession(AgentSessionSynchronousProtocol, AsynchronousJob): +class AgentSession(AgentSessionSynchronousProtocol): """Represents a [Synapse Agent Session](https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentSession.html) Attributes: @@ -150,7 +200,9 @@ async def start_async( """Starts an agent session. Arguments: - synapse_client: The Synapse client to use for the request. If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. Returns: The new AgentSession object. @@ -171,7 +223,9 @@ async def get_async( """Gets an agent session. Arguments: - synapse_client: The Synapse client to use for the request. If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. Returns: The retrieved AgentSession object. @@ -194,7 +248,9 @@ async def update_async( Only updates to the access level are currently supported. Arguments: - synapse_client: The Synapse client to use for the request. If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. Returns: The updated AgentSession object. @@ -223,45 +279,24 @@ async def prompt_async( enable_trace: Whether to enable trace for the prompt. print_response: Whether to print the response to the console. newer_than: The timestamp to get trace results newer than. Defaults to None (all results). - synapse_client: The Synapse client to use for the request. If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. """ - prompt_id = await self.send_job_async( - request_type=AGENT_CHAT_REQUEST, - session_id=self.id, - prompt=prompt, - enable_trace=enable_trace, - synapse_client=synapse_client, - ) - answer_response = await self.get_job_async( - job_id=prompt_id, - request_type=AGENT_CHAT_REQUEST, - synapse_client=synapse_client, + agent_prompt = AgentPrompt( + prompt=prompt, session_id=self.id, enable_trace=enable_trace ) - response = answer_response["responseText"] - - if enable_trace: - trace_response = await get_trace( - prompt_id=prompt_id, - newer_than=newer_than, - synapse_client=synapse_client, - ) - trace = trace_response["page"][0]["message"] - - self.chat_history.append( - AgentPrompt( - id=prompt_id, - prompt=prompt, - response=response, - trace=trace, - ) + await agent_prompt.send_job_and_wait_async( + synapse_client=synapse_client, post_exchange_args={"newer_than": newer_than} ) + self.chat_history.append(agent_prompt) if print_response: print(f"PROMPT:\n{prompt}\n") - print(f"RESPONSE:\n{response}\n") + print(f"RESPONSE:\n{agent_prompt.response}\n") if enable_trace: - print(f"TRACE:\n{trace}") + print(f"TRACE:\n{agent_prompt.trace}") # TODO Add example usage to the docstring @@ -328,7 +363,9 @@ async def register_async( """Registers an agent with the Synapse API. If agent exists, it will be retrieved. Arguments: - synapse_client: The Synapse client to use for the request. If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. Returns: The registered or existing Agent object. @@ -348,7 +385,9 @@ async def get_async(self, *, synapse_client: Optional[Synapse] = None) -> "Agent """Gets an existing agent. Arguments: - synapse_client: The Synapse client to use for the request. If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. Returns: The existing Agent object. @@ -378,8 +417,9 @@ async def start_session_async( access_level: The access level of the agent session. Must be one of PUBLICLY_ACCESSIBLE, READ_YOUR_PRIVATE_DATA, or WRITE_YOUR_PRIVATE_DATA. Defaults to PUBLICLY_ACCESSIBLE. - synapse_client: The Synapse client to use for the request. - If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. Returns: The new AgentSession object. @@ -403,8 +443,9 @@ async def get_session_async( Arguments: session_id: The ID of the session to get. - synapse_client: The Synapse client to use for the request. - If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. Returns: The existing AgentSession object. @@ -439,7 +480,9 @@ async def prompt_async( print_response: Whether to print the response to the console. session_id: The ID of the session to send the prompt to. If None, the current session will be used. newer_than: The timestamp to get trace results newer than. Defaults to None (all results). - synapse_client: The Synapse client to use for the request. If None, the default client will be used. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. """ # TODO: Iron this out. Make sure we cover all cases. if session: diff --git a/synapseclient/models/mixins/asynchronous_job.py b/synapseclient/models/mixins/asynchronous_job.py index 1712eab8c..098994f2d 100644 --- a/synapseclient/models/mixins/asynchronous_job.py +++ b/synapseclient/models/mixins/asynchronous_job.py @@ -3,13 +3,87 @@ import time from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import Any, Dict, Optional +from synapseclient import Synapse from synapseclient.core.constants.concrete_types import AGENT_CHAT_REQUEST from synapseclient.core.exceptions import SynapseError, SynapseTimeoutError -if TYPE_CHECKING: - from synapseclient import Synapse +ASYNC_JOB_URIS = { + AGENT_CHAT_REQUEST: "/agent/chat/async", +} + + +class AsynchronousCommunicator: + """Mixin to handle communication with the Synapse Asynchronous Job service.""" + + def to_synapse_request(self) -> None: + """Converts the request to a request expected of the Synapse REST API. + + This is a placeholder for any additional logic that needs to be run before the exchange with Synapse. + It must be overridden by subclasses if needed. + """ + raise NotImplementedError("to_synapse_request must be implemented.") + + def fill_from_dict( + self, synapse_response: Dict[str, str] + ) -> "AsynchronousCommunicator": + """ + Converts a response from the REST API into this dataclass. + + This is a placeholder for any additional logic that needs to be run after the exchange with Synapse. + It must be overridden by subclasses if needed. + + Arguments: + synapse_response: The response from the REST API. + + Returns: + An instance of this class. + """ + raise NotImplementedError("fill_from_dict must be implemented.") + + async def _post_exchange_async( + self, synapse_client: Optional[Synapse] = None, **kwargs + ) -> None: + """Any additional logic to run after the exchange with Synapse. + + This is a placeholder for any additional logic that needs to be run after the exchange with Synapse. + It must be overridden by subclasses if needed. + + Arguments: + synapse_client: The Synapse client to use for the request. + **kwargs: Additional arguments to pass to the request. + """ + pass + + async def send_job_and_wait_async( + self, + post_exchange_args: Optional[Dict[str, Any]] = None, + *, + synapse_client: Optional[Synapse] = None, + ) -> "AsynchronousCommunicator": + """Send the job to the Asynchronous Job service and wait for it to complete. + + This is a placeholder for any additional logic that needs to be run after the exchange with Synapse. + It must be overridden by subclasses if needed. + + Arguments: + post_exchange_args: Additional arguments to pass to the request. + synapse_client: The Synapse client to use for the request. + + Returns: + An instance of this class. + """ + result = await send_job_and_wait_async( + request=self.to_synapse_request(), + request_type=self.concrete_type, + synapse_client=synapse_client, + ) + self.fill_from_dict(synapse_response=result) + await self._post_exchange_async( + **post_exchange_args, synapse_client=synapse_client + ) + return self class AsynchronousJobState(str, Enum): @@ -151,129 +225,157 @@ def fill_from_dict(self, async_job_status: dict) -> "AsynchronousJobStatus": return self -class AsynchronousJob: +async def send_job_and_wait_async( + request: Dict[str, Any], + request_type: str, + endpoint: str = None, + *, + synapse_client: Optional["Synapse"] = None, +) -> Dict[str, Any]: """ - Mixin for objects that can have Asynchronous Jobs. + Sends the job to the Synapse API and waits for the response. Request body matches: + + + Arguments: + request: A request matching . + endpoint: The endpoint to use for the request. Defaults to None. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. + + Returns: + The response body matching + + + Raises: + SynapseError: If the job fails. + SynapseTimeoutError: If the job does not complete within the timeout. """ - - ASYNC_JOB_URIS = { - AGENT_CHAT_REQUEST: "/agent/chat/async", + job_id = await send_job_async(request=request, synapse_client=synapse_client) + return { + "jobId": job_id, + **await get_job_async( + job_id=job_id, + request_type=request_type, + synapse_client=synapse_client, + endpoint=endpoint, + ), } - async def send_job_async( - self, - request_type: str, - session_id: str, - prompt: str, - enable_trace: bool, - synapse_client: Optional["Synapse"] = None, - ) -> str: - """ - Sends the job to the Synapse API. Request body matches: - - Returns the job ID. - Arguments: - request_type: The type of the job. - session_id: The ID of the session to send the prompt to. - prompt: The prompt to send to the agent. - enable_trace: Whether to enable trace for the prompt. Defaults to False. - synapse_client: If not passed in and caching was not disabled by - `Synapse.allow_client_caching(False)` this will use the last created - instance from the Synapse class constructor. +async def send_job_async( + request: Dict[str, Any], + *, + synapse_client: Optional["Synapse"] = None, +) -> str: + """ + Sends the job to the Synapse API. Request body matches: + + Returns the job ID. + + Arguments: + request: A request matching . + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. + + Returns: + The job ID retrieved from the response. + + """ + if not request: + raise ValueError("request must be provided.") - Returns: - The job ID retrieved from the response. - - """ - request = { - "concreteType": request_type, - "sessionId": session_id, - "chatText": prompt, - "enableTrace": enable_trace, - } - response = await synapse_client.rest_post_async( - uri=f"{self.ASYNC_JOB_URIS[request_type]}/start", body=json.dumps(request) - ) - return response["token"] + request_type = request.get("concreteType") - async def get_job_async( - self, - job_id: str, - request_type: str, - synapse_client: "Synapse", - endpoint: str = None, - ) -> Dict[str, Any]: - """ - Gets the job from the server using its ID. Handles progress tracking, failures and timeouts. + if not request_type or request_type not in ASYNC_JOB_URIS: + raise ValueError(f"Unsupported request type: {request_type}") - Arguments: - job_id: The ID of the job to get. - request_type: The type of the job. - synapse_client: If not passed in and caching was not disabled by - `Synapse.allow_client_caching(False)` this will use the last created - instance from the Synapse class constructor. - endpoint: The endpoint to use for the request. Defaults to None. + client = Synapse.get_client(synapse_client=synapse_client) + response = await client.rest_post_async( + uri=f"{ASYNC_JOB_URIS[request_type]}/start", body=json.dumps(request) + ) + return response["token"] - Returns: - The response body matching - - Raises: - SynapseError: If the job fails. - SynapseTimeoutError: If the job does not complete within the timeout. - """ - start_time = asyncio.get_event_loop().time() - SLEEP = 1 - TIMEOUT = 60 - - last_message = "" - last_progress = 0 - last_total = 1 - progressed = False - - while asyncio.get_event_loop().time() - start_time < TIMEOUT: - result = await synapse_client.rest_get_async( - uri=f"{self.ASYNC_JOB_URIS[request_type]}/get/{job_id}", - endpoint=endpoint, +async def get_job_async( + job_id: str, + request_type: str, + endpoint: str = None, + *, + synapse_client: Optional["Synapse"] = None, +) -> Dict[str, Any]: + """ + Gets the job from the server using its ID. Handles progress tracking, failures and timeouts. + + Arguments: + job_id: The ID of the job to get. + request_type: The type of the job. + endpoint: The endpoint to use for the request. Defaults to None. + synapse_client: If not passed in and caching was not disabled by + `Synapse.allow_client_caching(False)` this will use the last created + instance from the Synapse class constructor. + + Returns: + The response body matching + + + Raises: + SynapseError: If the job fails. + SynapseTimeoutError: If the job does not complete within the timeout. + """ + client = Synapse.get_client(synapse_client=synapse_client) + start_time = asyncio.get_event_loop().time() + SLEEP = 1 + TIMEOUT = 60 + + last_message = "" + last_progress = 0 + last_total = 1 + progressed = False + + while asyncio.get_event_loop().time() - start_time < TIMEOUT: + result = await client.rest_get_async( + uri=f"{ASYNC_JOB_URIS[request_type]}/get/{job_id}", + endpoint=endpoint, + ) + job_status = AsynchronousJobStatus().fill_from_dict(async_job_status=result) + if job_status.state == AsynchronousJobState.PROCESSING: + # TODO: Is this adequate to determine if the endpoint tracks progress? + progress_tracking = any( + [ + job_status.progress_message, + job_status.progress_current, + job_status.progress_total, + ] ) - job_status = AsynchronousJobStatus().fill_from_dict(async_job_status=result) - if job_status.state == AsynchronousJobState.PROCESSING: - # TODO: Is this adequate to determine if the endpoint tracks progress? - progress_tracking = any( - [ - job_status.progress_message, - job_status.progress_current, - job_status.progress_total, - ] - ) - progressed = ( - job_status.progress_message != last_message - or last_progress != job_status.progress_current - ) - if progress_tracking and progressed: - last_message = job_status.progress_message - last_progress = job_status.progress_current - last_total = job_status.progress_total - - synapse_client._print_transfer_progress( - last_progress, - last_total, - prefix=last_message, - isBytes=False, - ) - start_time = asyncio.get_event_loop().time() - await asyncio.sleep(SLEEP) - elif job_status.state == AsynchronousJobState.FAILED: - raise SynapseError( - f"{job_status.error_message}\n{job_status.error_details}", - async_job_status=job_status.id, + progressed = ( + job_status.progress_message != last_message + or last_progress != job_status.progress_current + ) + if progress_tracking and progressed: + last_message = job_status.progress_message + last_progress = job_status.progress_current + last_total = job_status.progress_total + + client._print_transfer_progress( + last_progress, + last_total, + prefix=last_message, + isBytes=False, ) - else: - break - else: - raise SynapseTimeoutError( - f"Timeout waiting for query results: {time.time() - start_time} seconds" + start_time = asyncio.get_event_loop().time() + await asyncio.sleep(SLEEP) + elif job_status.state == AsynchronousJobState.FAILED: + raise SynapseError( + f"{job_status.error_message}\n{job_status.error_details}", + async_job_status=job_status.id, ) + else: + break + else: + raise SynapseTimeoutError( + f"Timeout waiting for query results: {time.time() - start_time} seconds" + ) - return result + return result