Skip to content

Commit

Permalink
cleans up agent logic
Browse files Browse the repository at this point in the history
  • Loading branch information
BWMac committed Jan 15, 2025
1 parent 165570c commit 8adf5e5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 28 deletions.
22 changes: 9 additions & 13 deletions synapseclient/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
class AgentType(str, Enum):
"""
Enum representing the type of agent as defined in
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentType.html>
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentType.html>
- BASELINE is a default agent provided by Synapse.
- CUSTOM is a custom agent that has been registered by a user.
Expand All @@ -37,7 +37,7 @@ class AgentType(str, Enum):
class AgentSessionAccessLevel(str, Enum):
"""
Enum representing the access level of the agent session as defined in
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentAccessLevel.html>
<https://rest-docs.synapse.org/rest/org/sagebionetworks/repo/model/agent/AgentAccessLevel.html>
- PUBLICLY_ACCESSIBLE: The agent can only access publicly accessible data.
- READ_YOUR_PRIVATE_DATA: The agent can read the user's private data.
Expand Down Expand Up @@ -96,7 +96,7 @@ def fill_from_dict(self, synapse_response: Dict[str, str]) -> "AgentPrompt":
Converts a response from the REST API into this dataclass.
Arguments:
agent_prompt: The response from the REST API.
synapse_response: The response from the REST API.
Returns:
The AgentPrompt object.
Expand Down Expand Up @@ -146,9 +146,9 @@ class AgentSession(AgentSessionSynchronousProtocol):
id: Optional[str] = None
"""The unique ID of the agent session. Can only be used by the user that created it."""

access_level: Optional[
AgentSessionAccessLevel
] = AgentSessionAccessLevel.PUBLICLY_ACCESSIBLE
access_level: Optional[AgentSessionAccessLevel] = (
AgentSessionAccessLevel.PUBLICLY_ACCESSIBLE
)
"""The access level of the agent session.
One of PUBLICLY_ACCESSIBLE, READ_YOUR_PRIVATE_DATA, or WRITE_YOUR_PRIVATE_DATA.
Defaults to PUBLICLY_ACCESSIBLE.
Expand Down Expand Up @@ -484,14 +484,10 @@ async def prompt_async(
`Synapse.allow_client_caching(False)` this will use the last created
instance from the Synapse class constructor.
"""
# TODO: Iron this out. Make sure we cover all cases.
if session:
if session.id not in self.sessions:
await self.get_session_async(
session_id=session.id, synapse_client=synapse_client
)
else:
self.current_session = session
await self.get_session_async(
session_id=session.id, synapse_client=synapse_client
)
else:
if not self.current_session:
await self.start_session_async(synapse_client=synapse_client)
Expand Down
16 changes: 1 addition & 15 deletions synapseclient/models/mixins/asynchronous_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@ class AsynchronousCommunicator:
"""Mixin to handle communication with the Synapse Asynchronous Job service."""

def to_synapse_request(self) -> None:
"""Converts the request to a request expected of the Synapse REST API.
This is a placeholder for any additional logic that needs to be run before the exchange with Synapse.
It must be overridden by subclasses if needed.
"""
"""Converts the request to a request expected of the Synapse REST API."""
raise NotImplementedError("to_synapse_request must be implemented.")

def fill_from_dict(
Expand All @@ -31,9 +27,6 @@ def fill_from_dict(
"""
Converts a response from the REST API into this dataclass.
This is a placeholder for any additional logic that needs to be run after the exchange with Synapse.
It must be overridden by subclasses if needed.
Arguments:
synapse_response: The response from the REST API.
Expand All @@ -47,9 +40,6 @@ async def _post_exchange_async(
) -> None:
"""Any additional logic to run after the exchange with Synapse.
This is a placeholder for any additional logic that needs to be run after the exchange with Synapse.
It must be overridden by subclasses if needed.
Arguments:
synapse_client: The Synapse client to use for the request.
**kwargs: Additional arguments to pass to the request.
Expand All @@ -64,9 +54,6 @@ async def send_job_and_wait_async(
) -> "AsynchronousCommunicator":
"""Send the job to the Asynchronous Job service and wait for it to complete.
This is a placeholder for any additional logic that needs to be run after the exchange with Synapse.
It must be overridden by subclasses if needed.
Arguments:
post_exchange_args: Additional arguments to pass to the request.
synapse_client: The Synapse client to use for the request.
Expand Down Expand Up @@ -341,7 +328,6 @@ async def get_job_async(
)
job_status = AsynchronousJobStatus().fill_from_dict(async_job_status=result)
if job_status.state == AsynchronousJobState.PROCESSING:
# TODO: Is this adequate to determine if the endpoint tracks progress?
progress_tracking = any(
[
job_status.progress_message,
Expand Down

0 comments on commit 8adf5e5

Please sign in to comment.