Skip to content

Commit

Permalink
Merge pull request #124 from awslabs/call-agent-with-chat-history
Browse files Browse the repository at this point in the history
Split processRequest into 2 method to classify and process the agent's response separatetly
  • Loading branch information
brnaba-aws authored Dec 2, 2024
2 parents 97b5e90 + 42e84a7 commit 2160060
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 100 deletions.
130 changes: 70 additions & 60 deletions python/src/multi_agent_orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,53 +97,39 @@ async def dispatch_to_agent(self,

return response

async def route_request(self,
user_input: str,
user_id: str,
session_id: str,
additional_params: Dict[str, str] = {}) -> AgentResponse:
self.execution_times.clear()

async def classify_request(self,
user_input: str,
user_id: str,
session_id: str) -> ClassifierResult:
"""Classify user request with conversation history."""
try:
chat_history = await self.storage.fetch_all_chats(user_id, session_id) or []
classifier_result:ClassifierResult = await self.measure_execution_time(
classifier_result = await self.measure_execution_time(
"Classifying user intent",
lambda: self.classifier.classify(user_input, chat_history)
)

if self.config.LOG_CLASSIFIER_OUTPUT:
self.print_intent(user_input, classifier_result)

except Exception as error:
self.logger.error(f"Error during intent classification: {str(error)}")
return AgentResponse(
metadata=self.create_metadata(None,
user_input,
user_id,
session_id,
additional_params),
output=self.config.CLASSIFICATION_ERROR_MESSAGE
if self.config.CLASSIFICATION_ERROR_MESSAGE else
str(error),
streaming=False
)
if not classifier_result.selected_agent:
if self.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED and self.default_agent:
classifier_result = self.get_fallback_result()
self.logger.info("Using default agent as no agent was selected")

if not classifier_result.selected_agent:
if self.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED and self.default_agent:
classifier_result = self.get_fallback_result()
self.logger.info("Using default agent as no agent was selected")
else:
return AgentResponse(
metadata= self.create_metadata(classifier_result,
user_input,
user_id,
session_id,
additional_params),
output= ConversationMessage(role=ParticipantRole.ASSISTANT.value,
content=[{'text': self.config.NO_SELECTED_AGENT_MESSAGE}]),
streaming=False
)
return classifier_result

except Exception as error:
self.logger.error(f"Error during intent classification: {str(error)}")
raise error

async def agent_process_request(self,
user_input: str,
user_id: str,
session_id: str,
classifier_result: ClassifierResult,
additional_params: Dict[str, str] = {}) -> AgentResponse:
"""Process agent response and handle chat storage."""
try:
agent_response = await self.dispatch_to_agent({
"user_input": user_input,
Expand All @@ -154,49 +140,73 @@ async def route_request(self,
})

metadata = self.create_metadata(classifier_result,
user_input,
user_id,
session_id,
additional_params)
user_input,
user_id,
session_id,
additional_params)

# save question
await self.save_message(
ConversationMessage(
role=ParticipantRole.USER.value,
content=[{'text':user_input}]
content=[{'text': user_input}]
),
user_id,
session_id,
classifier_result.selected_agent
)

if isinstance(agent_response, ConversationMessage):
# save the response
await self.save_message(agent_response,
user_id,
session_id,
classifier_result.selected_agent)

user_id,
session_id,
classifier_result.selected_agent)

return AgentResponse(
metadata=metadata,
output=agent_response,
streaming=classifier_result.selected_agent.is_streaming_enabled()
)
metadata=metadata,
output=agent_response,
streaming=classifier_result.selected_agent.is_streaming_enabled()
)

except Exception as error:
self.logger.error(f"Error during agent dispatch or processing:{str(error)}")
return AgentResponse(
metadata= self.create_metadata(classifier_result,
user_input,
user_id,
session_id,
additional_params),
output = self.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE \
if self.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE else str(error),
self.logger.error(f"Error during agent processing: {str(error)}")
raise error

async def route_request(self,
user_input: str,
user_id: str,
session_id: str,
additional_params: Dict[str, str] = {}) -> AgentResponse:
"""Route user request to appropriate agent."""
self.execution_times.clear()

try:
classifier_result = await self.classify_request(user_input, user_id, session_id)

if not classifier_result.selected_agent:
return AgentResponse(
metadata=self.create_metadata(classifier_result, user_input, user_id, session_id, additional_params),
output=ConversationMessage(
role=ParticipantRole.ASSISTANT.value,
content=[{'text': self.config.NO_SELECTED_AGENT_MESSAGE}]
),
streaming=False
)

return await self.agent_process_request(
user_input,
user_id,
session_id,
classifier_result,
additional_params
)

except Exception as error:
return AgentResponse(
metadata=self.create_metadata(None, user_input, user_id, session_id, additional_params),
output=self.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE or str(error),
streaming=False
)

finally:
self.logger.print_execution_times(self.execution_times)

Expand Down
97 changes: 57 additions & 40 deletions typescript/src/orchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,57 +333,51 @@ export class MultiAgentOrchestrator {
}
}

async routeRequest(
async classifyRequest(
userInput: string,
userId: string,
sessionId: string,
additionalParams: Record<any, any> = {}
): Promise<AgentResponse> {
this.executionTimes = new Map();
let classifierResult: ClassifierResult;
const chatHistory = (await this.storage.fetchAllChats(userId, sessionId)) || [];

sessionId: string
): Promise<ClassifierResult> {
try {
classifierResult = await this.measureExecutionTime(
const chatHistory = await this.storage.fetchAllChats(userId, sessionId) || [];
const classifierResult = await this.measureExecutionTime(
"Classifying user intent",
() => this.classifier.classify(userInput, chatHistory)
);

this.logger.printIntent(userInput, classifierResult);

if (!classifierResult.selectedAgent && this.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED && this.defaultAgent) {
const fallbackResult = this.getFallbackResult();
this.logger.info("Using default agent as no agent was selected");
return fallbackResult;
}

return classifierResult;
} catch (error) {
this.logger.error("Error during intent classification:", error);
return {
metadata: this.createMetadata(null, userInput, userId, sessionId, additionalParams),
output: this.config.CLASSIFICATION_ERROR_MESSAGE ? this.config.CLASSIFICATION_ERROR_MESSAGE: String(error),
streaming: false,
};
throw error;
}

}

async agentProcessRequest(
userInput: string,
userId: string,
sessionId: string,
classifierResult: ClassifierResult,
additionalParams: Record<any, any> = {}
): Promise<AgentResponse> {
try {
// Handle case where no agent was selected
if (!classifierResult.selectedAgent) {
if (this.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED && this.defaultAgent) {
classifierResult = this.getFallbackResult();
this.logger.info("Using default agent as no agent was selected");
} else {
return {
metadata: this.createMetadata(classifierResult, userInput, userId, sessionId, additionalParams),
output: this.config.NO_SELECTED_AGENT_MESSAGE!,
streaming: false,
};
}
}

const agentResponse = await this.dispatchToAgent({
userInput,
userId,
sessionId,
classifierResult,
additionalParams,
});

const metadata = this.createMetadata(classifierResult, userInput, userId, sessionId, additionalParams);

if (this.isAsyncIterable(agentResponse)) {
const accumulatorTransform = new AccumulatorTransform();
this.processStreamInBackground(
Expand All @@ -400,8 +394,7 @@ export class MultiAgentOrchestrator {
streaming: true,
};
}

// Check if we should save the conversation

if (classifierResult?.selectedAgent.saveChat) {
await saveConversationExchange(
userInput,
Expand All @@ -413,25 +406,49 @@ export class MultiAgentOrchestrator {
this.config.MAX_MESSAGE_PAIRS_PER_AGENT
);
}



return {
metadata,
output: agentResponse,
streaming: false,
};
} catch (error) {
this.logger.error("Error during agent dispatch or processing:", error);

this.logger.error("Error during agent processing:", error);
throw error;
}
}

async routeRequest(
userInput: string,
userId: string,
sessionId: string,
additionalParams: Record<any, any> = {}
): Promise<AgentResponse> {
this.executionTimes = new Map();

try {
const classifierResult = await this.classifyRequest(userInput, userId, sessionId);

if (!classifierResult.selectedAgent) {
return {
metadata: this.createMetadata(classifierResult, userInput, userId, sessionId, additionalParams),
output: this.config.NO_SELECTED_AGENT_MESSAGE!,
streaming: false,
};
}

return await this.agentProcessRequest(userInput, userId, sessionId, classifierResult, additionalParams);
} catch (error) {
return {
metadata: this.createMetadata(classifierResult, userInput, userId, sessionId, additionalParams),
output: this.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE ? this.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE: String(error),
metadata: this.createMetadata(null, userInput, userId, sessionId, additionalParams),
output: this.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE || String(error),
streaming: false,
};
} finally {
this.logger.printExecutionTimes(this.executionTimes);
}
}


private async processStreamInBackground(
agentResponse: AsyncIterable<any>,
Expand Down

0 comments on commit 2160060

Please sign in to comment.