diff --git a/python/src/multi_agent_orchestrator/orchestrator.py b/python/src/multi_agent_orchestrator/orchestrator.py index a6b4fe82..6e239b1b 100644 --- a/python/src/multi_agent_orchestrator/orchestrator.py +++ b/python/src/multi_agent_orchestrator/orchestrator.py @@ -97,16 +97,14 @@ async def dispatch_to_agent(self, return response - async def route_request(self, - user_input: str, - user_id: str, - session_id: str, - additional_params: Dict[str, str] = {}) -> AgentResponse: - self.execution_times.clear() - + async def classify_request(self, + user_input: str, + user_id: str, + session_id: str) -> ClassifierResult: + """Classify user request with conversation history.""" try: chat_history = await self.storage.fetch_all_chats(user_id, session_id) or [] - classifier_result:ClassifierResult = await self.measure_execution_time( + classifier_result = await self.measure_execution_time( "Classifying user intent", lambda: self.classifier.classify(user_input, chat_history) ) @@ -114,36 +112,24 @@ async def route_request(self, if self.config.LOG_CLASSIFIER_OUTPUT: self.print_intent(user_input, classifier_result) - except Exception as error: - self.logger.error(f"Error during intent classification: {str(error)}") - return AgentResponse( - metadata=self.create_metadata(None, - user_input, - user_id, - session_id, - additional_params), - output=self.config.CLASSIFICATION_ERROR_MESSAGE - if self.config.CLASSIFICATION_ERROR_MESSAGE else - str(error), - streaming=False - ) + if not classifier_result.selected_agent: + if self.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED and self.default_agent: + classifier_result = self.get_fallback_result() + self.logger.info("Using default agent as no agent was selected") - if not classifier_result.selected_agent: - if self.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED and self.default_agent: - classifier_result = self.get_fallback_result() - self.logger.info("Using default agent as no agent was selected") - else: - return AgentResponse( - metadata= self.create_metadata(classifier_result, - user_input, - user_id, - session_id, - additional_params), - output= ConversationMessage(role=ParticipantRole.ASSISTANT.value, - content=[{'text': self.config.NO_SELECTED_AGENT_MESSAGE}]), - streaming=False - ) + return classifier_result + except Exception as error: + self.logger.error(f"Error during intent classification: {str(error)}") + raise error + + async def agent_process_request(self, + user_input: str, + user_id: str, + session_id: str, + classifier_result: ClassifierResult, + additional_params: Dict[str, str] = {}) -> AgentResponse: + """Process agent response and handle chat storage.""" try: agent_response = await self.dispatch_to_agent({ "user_input": user_input, @@ -154,16 +140,15 @@ async def route_request(self, }) metadata = self.create_metadata(classifier_result, - user_input, - user_id, - session_id, - additional_params) + user_input, + user_id, + session_id, + additional_params) - # save question await self.save_message( ConversationMessage( role=ParticipantRole.USER.value, - content=[{'text':user_input}] + content=[{'text': user_input}] ), user_id, session_id, @@ -171,32 +156,57 @@ async def route_request(self, ) if isinstance(agent_response, ConversationMessage): - # save the response await self.save_message(agent_response, - user_id, - session_id, - classifier_result.selected_agent) - + user_id, + session_id, + classifier_result.selected_agent) return AgentResponse( - metadata=metadata, - output=agent_response, - streaming=classifier_result.selected_agent.is_streaming_enabled() - ) + metadata=metadata, + output=agent_response, + streaming=classifier_result.selected_agent.is_streaming_enabled() + ) except Exception as error: - self.logger.error(f"Error during agent dispatch or processing:{str(error)}") - return AgentResponse( - metadata= self.create_metadata(classifier_result, - user_input, - user_id, - session_id, - additional_params), - output = self.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE \ - if self.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE else str(error), + self.logger.error(f"Error during agent processing: {str(error)}") + raise error + + async def route_request(self, + user_input: str, + user_id: str, + session_id: str, + additional_params: Dict[str, str] = {}) -> AgentResponse: + """Route user request to appropriate agent.""" + self.execution_times.clear() + + try: + classifier_result = await self.classify_request(user_input, user_id, session_id) + + if not classifier_result.selected_agent: + return AgentResponse( + metadata=self.create_metadata(classifier_result, user_input, user_id, session_id, additional_params), + output=ConversationMessage( + role=ParticipantRole.ASSISTANT.value, + content=[{'text': self.config.NO_SELECTED_AGENT_MESSAGE}] + ), streaming=False ) + return await self.agent_process_request( + user_input, + user_id, + session_id, + classifier_result, + additional_params + ) + + except Exception as error: + return AgentResponse( + metadata=self.create_metadata(None, user_input, user_id, session_id, additional_params), + output=self.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE or str(error), + streaming=False + ) + finally: self.logger.print_execution_times(self.execution_times) diff --git a/typescript/src/orchestrator.ts b/typescript/src/orchestrator.ts index 77cde09e..622f65bd 100644 --- a/typescript/src/orchestrator.ts +++ b/typescript/src/orchestrator.ts @@ -333,47 +333,41 @@ export class MultiAgentOrchestrator { } } - async routeRequest( + async classifyRequest( userInput: string, userId: string, - sessionId: string, - additionalParams: Record = {} - ): Promise { - this.executionTimes = new Map(); - let classifierResult: ClassifierResult; - const chatHistory = (await this.storage.fetchAllChats(userId, sessionId)) || []; - + sessionId: string + ): Promise { try { - classifierResult = await this.measureExecutionTime( + const chatHistory = await this.storage.fetchAllChats(userId, sessionId) || []; + const classifierResult = await this.measureExecutionTime( "Classifying user intent", () => this.classifier.classify(userInput, chatHistory) ); - + this.logger.printIntent(userInput, classifierResult); + + if (!classifierResult.selectedAgent && this.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED && this.defaultAgent) { + const fallbackResult = this.getFallbackResult(); + this.logger.info("Using default agent as no agent was selected"); + return fallbackResult; + } + + return classifierResult; } catch (error) { this.logger.error("Error during intent classification:", error); - return { - metadata: this.createMetadata(null, userInput, userId, sessionId, additionalParams), - output: this.config.CLASSIFICATION_ERROR_MESSAGE ? this.config.CLASSIFICATION_ERROR_MESSAGE: String(error), - streaming: false, - }; + throw error; } - + } + + async agentProcessRequest( + userInput: string, + userId: string, + sessionId: string, + classifierResult: ClassifierResult, + additionalParams: Record = {} + ): Promise { try { - // Handle case where no agent was selected - if (!classifierResult.selectedAgent) { - if (this.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED && this.defaultAgent) { - classifierResult = this.getFallbackResult(); - this.logger.info("Using default agent as no agent was selected"); - } else { - return { - metadata: this.createMetadata(classifierResult, userInput, userId, sessionId, additionalParams), - output: this.config.NO_SELECTED_AGENT_MESSAGE!, - streaming: false, - }; - } - } - const agentResponse = await this.dispatchToAgent({ userInput, userId, @@ -381,9 +375,9 @@ export class MultiAgentOrchestrator { classifierResult, additionalParams, }); - + const metadata = this.createMetadata(classifierResult, userInput, userId, sessionId, additionalParams); - + if (this.isAsyncIterable(agentResponse)) { const accumulatorTransform = new AccumulatorTransform(); this.processStreamInBackground( @@ -400,8 +394,7 @@ export class MultiAgentOrchestrator { streaming: true, }; } - - // Check if we should save the conversation + if (classifierResult?.selectedAgent.saveChat) { await saveConversationExchange( userInput, @@ -413,25 +406,49 @@ export class MultiAgentOrchestrator { this.config.MAX_MESSAGE_PAIRS_PER_AGENT ); } - - + return { metadata, output: agentResponse, streaming: false, }; } catch (error) { - this.logger.error("Error during agent dispatch or processing:", error); - + this.logger.error("Error during agent processing:", error); + throw error; + } + } + + async routeRequest( + userInput: string, + userId: string, + sessionId: string, + additionalParams: Record = {} + ): Promise { + this.executionTimes = new Map(); + + try { + const classifierResult = await this.classifyRequest(userInput, userId, sessionId); + + if (!classifierResult.selectedAgent) { + return { + metadata: this.createMetadata(classifierResult, userInput, userId, sessionId, additionalParams), + output: this.config.NO_SELECTED_AGENT_MESSAGE!, + streaming: false, + }; + } + + return await this.agentProcessRequest(userInput, userId, sessionId, classifierResult, additionalParams); + } catch (error) { return { - metadata: this.createMetadata(classifierResult, userInput, userId, sessionId, additionalParams), - output: this.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE ? this.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE: String(error), + metadata: this.createMetadata(null, userInput, userId, sessionId, additionalParams), + output: this.config.GENERAL_ROUTING_ERROR_MSG_MESSAGE || String(error), streaming: false, }; } finally { this.logger.printExecutionTimes(this.executionTimes); } } + private async processStreamInBackground( agentResponse: AsyncIterable,