diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index 5479a38c1..6498bc772 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -31,6 +31,7 @@ from a2a.types.a2a_pb2 import ( Message, Task, + TaskArtifactUpdateEvent, TaskState, TaskStatus, TaskStatusUpdateEvent, @@ -68,6 +69,210 @@ def __init__(self, request_id: uuid.UUID): self.request_id = request_id +class EventConsumer: + """Consumes events from the agent and updates system state.""" + + def __init__(self, active_task: ActiveTask) -> None: + self.active_task = active_task + self.task_mode: bool | None = None + self.message_to_save: Message | None = None + + async def run(self) -> None: + """Consumes events from the agent and updates system state.""" + logger.debug('Consumer[%s]: Started', self.active_task._task_id) + try: + while True: + logger.debug( + 'Consumer[%s]: Waiting for event', + self.active_task._task_id, + ) + event = ( + await self.active_task._event_queue_agent.dequeue_event() + ) + logger.debug( + 'Consumer[%s]: Dequeued event %s', + self.active_task._task_id, + type(event).__name__, + ) + + await self._process_event(event) + + except QueueShutDown: + logger.debug( + 'Consumer[%s]: Event queue shut down', + self.active_task._task_id, + ) + except Exception as e: + logger.exception('Consumer[%s]: Failed', self.active_task._task_id) + async with self.active_task._lock: + await self.active_task._mark_task_as_failed(e) + + async def _process_event(self, event: Event) -> None: + updated_task = None + + try: + if isinstance(event, _RequestCompleted): + logger.debug( + 'Consumer[%s]: Request completed', self.active_task._task_id + ) + self.active_task._request_lock.release() + elif isinstance(event, _RequestStarted): + logger.debug( + 'Consumer[%s]: Request started', self.active_task._task_id + ) + self.message_to_save = event.request_context.message + elif isinstance(event, Message): + self._handle_message_event(event) + else: + updated_task = await self._handle_task_event(event) + if isinstance(event, Task): + event = updated_task + + if updated_task is not None: + await self._update_task_state(updated_task, event) + self.active_task._task_created.set() + + finally: + await self._enqueue_to_subscribers(event, updated_task) + + def _handle_message_event(self, event: Message) -> None: + if self.task_mode is True: + raise InvalidAgentResponseError( + 'Received Message object in task mode. Use TaskStatusUpdateEvent or TaskArtifactUpdateEvent instead.' + ) + if self.task_mode is False: + raise InvalidAgentResponseError( + 'Multiple Message objects received.' + ) + self.task_mode = False + + async def _handle_task_event( + self, + event: Task + | TaskStatusUpdateEvent + | TaskArtifactUpdateEvent + | PushNotificationEvent, + ) -> Task: + if self.task_mode is False: + raise InvalidAgentResponseError( + f'Received {type(event).__name__} in message mode. Use Task with TaskStatusUpdateEvent and TaskArtifactUpdateEvent instead.' + ) + + if isinstance(event, Task): + await self._handle_initial_task(event) + else: + await self._handle_task_modification_event(event) + + self.task_mode = True + task = await self.active_task._task_manager.get_task() + if task is None: + raise RuntimeError(f'Task {self.active_task.task_id} not found') + return task + + async def _handle_initial_task(self, event: Task) -> None: + existing_task = await self.active_task._task_manager.get_task() + if existing_task: + logger.error( + 'Task %s already exists. Ignoring task replacement.', + self.active_task._task_id, + ) + else: + await self.active_task._task_manager.save_task_event(event) + self.message_to_save = None + + async def _handle_task_modification_event( + self, + event: TaskStatusUpdateEvent + | TaskArtifactUpdateEvent + | PushNotificationEvent, + ) -> None: + if ( + isinstance(event, TaskStatusUpdateEvent) + and not self.active_task._task_created.is_set() + ): + task = await self.active_task._task_manager.get_task() + if task is None: + raise InvalidAgentResponseError( + f'Agent should enqueue Task before {type(event).__name__} event' + ) + + updated_task = await self.active_task._task_manager.ensure_task_id( + self.active_task._task_id, + event.context_id, + ) + + if self.message_to_save is not None: + updated_task = self.active_task._task_manager.update_with_message( + self.message_to_save, + updated_task, + ) + await self.active_task._task_manager.save_task_event(updated_task) + self.message_to_save = None + + self.active_task._task_manager.context_id = event.context_id + await self.active_task._task_manager.process(event) + + async def _update_task_state( + self, + updated_task: Task, + event: Task + | TaskStatusUpdateEvent + | TaskArtifactUpdateEvent + | PushNotificationEvent, + ) -> None: + is_terminal = updated_task.status.state in TERMINAL_TASK_STATES + + if is_terminal: + await self._handle_terminal_state(updated_task) + + if ( + self.active_task._push_sender + and self.active_task._task_id + and isinstance(event, PushNotificationEvent) + ): + logger.debug( + 'Consumer[%s]: Sending push notification', + self.active_task._task_id, + ) + await self.active_task._push_sender.send_notification( + self.active_task._task_id, event + ) + + async def _handle_terminal_state(self, updated_task: Task) -> None: + logger.debug( + 'Consumer[%s]: Reached terminal state %s', + self.active_task._task_id, + updated_task.status.state, + ) + if not self.active_task._is_finished.is_set(): + async with self.active_task._lock: + self.active_task._reference_count -= 1 + + self.active_task._is_finished.set() + self.active_task._request_queue.shutdown(immediate=True) + + async def _enqueue_to_subscribers( + self, event: Event, updated_task: Task | None + ) -> None: + if updated_task is not None: + updated_task_copy = Task() + updated_task_copy.CopyFrom(updated_task) + if event is updated_task: + event = updated_task_copy + updated_task = updated_task_copy + + logger.debug( + 'Consumer[%s]: Enqueuing\nEvent: %s\nUpdated Task: %s\n', + self.active_task._task_id, + event, + updated_task, + ) + await self.active_task._event_queue_subscribers.enqueue_event( + cast('Any', (event, updated_task)) + ) + self.active_task._event_queue_agent.task_done() + + class ActiveTask: """Manages the lifecycle and execution of an active A2A task. @@ -320,233 +525,17 @@ async def _run_producer(self) -> None: await self._event_queue_subscribers.close(immediate=False) logger.debug('Producer[%s]: Completed', self._task_id) - async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 - """Consumes events from the agent and updates system state. - - This continuous loop dequeues events emitted by the producer, updates the - database via `TaskManager`, and intercepts critical task states (e.g., - INPUT_REQUIRED, COMPLETED, FAILED) to cache the final result. - - Concurrency Guarantee: - Runs as a detached asyncio.Task. The loop ends gracefully when the producer - closes the queue (raising `QueueShutDown`). Upon termination, it formally sets - `_is_finished`, unblocking all global subscribers and wait() calls. - """ - logger.debug('Consumer[%s]: Started', self._task_id) - task_mode = None - message_to_save = None - # TODO: Make helper methods - # TODO: Support Task enqueue + async def _run_consumer(self) -> None: + """Consumes events from the agent and updates system state.""" try: - try: - try: - while True: - # Dequeue event. This raises QueueShutDown when finished. - logger.debug( - 'Consumer[%s]: Waiting for event', - self._task_id, - ) - new_task = None - event = await self._event_queue_agent.dequeue_event() - logger.debug( - 'Consumer[%s]: Dequeued event %s', - self._task_id, - type(event).__name__, - ) - - try: - if isinstance(event, _RequestCompleted): - logger.debug( - 'Consumer[%s]: Request completed', - self._task_id, - ) - self._request_lock.release() - elif isinstance(event, _RequestStarted): - logger.debug( - 'Consumer[%s]: Request started', - self._task_id, - ) - message_to_save = event.request_context.message - - elif isinstance(event, Message): - if task_mode is not None: - if task_mode: - raise InvalidAgentResponseError( - 'Received Message object in task mode. Use TaskStatusUpdateEvent or TaskArtifactUpdateEvent instead.' - ) - raise InvalidAgentResponseError( - 'Multiple Message objects received.' - ) - task_mode = False - logger.debug( - 'Consumer[%s]: Setting result to Message: %s', - self._task_id, - event, - ) - else: - if task_mode is False: - raise InvalidAgentResponseError( - f'Received {type(event).__name__} in message mode. Use Task with TaskStatusUpdateEvent and TaskArtifactUpdateEvent instead.' - ) - - if isinstance(event, Task): - existing_task = ( - await self._task_manager.get_task() - ) - if existing_task: - logger.error( - 'Task %s already exists. Ignoring task replacement.', - self._task_id, - ) - else: - await ( - self._task_manager.save_task_event( - event - ) - ) - # Initial task should already contain the message. - message_to_save = None - else: - if ( - isinstance(event, TaskStatusUpdateEvent) - and not self._task_created.is_set() - ): - task = ( - await self._task_manager.get_task() - ) - if task is None: - raise InvalidAgentResponseError( - f'Agent should enqueue Task before {type(event).__name__} event' - ) - - new_task = ( - await self._task_manager.ensure_task_id( - self._task_id, - event.context_id, - ) - ) - - if message_to_save is not None: - new_task = self._task_manager.update_with_message( - message_to_save, - new_task, - ) - await ( - self._task_manager.save_task_event( - new_task - ) - ) - message_to_save = None - - task_mode = True - # Save structural events (like TaskStatusUpdate) to DB. - - self._task_manager.context_id = event.context_id - if not isinstance(event, Task): - await self._task_manager.process(event) - - # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states - new_task = await self._task_manager.get_task() - if new_task is None: - raise RuntimeError( - f'Task {self.task_id} not found' - ) - if isinstance(event, Task): - event = new_task - is_interrupted = ( - new_task.status.state - in INTERRUPTED_TASK_STATES - ) - is_terminal = ( - new_task.status.state - in TERMINAL_TASK_STATES - ) - - # If we hit a breakpoint or terminal state, lock in the result. - if is_interrupted or is_terminal: - logger.debug( - 'Consumer[%s]: Setting first result as Task (state=%s)', - self._task_id, - new_task.status.state, - ) - - if is_terminal: - logger.debug( - 'Consumer[%s]: Reached terminal state %s', - self._task_id, - new_task.status.state, - ) - if not self._is_finished.is_set(): - async with self._lock: - # TODO: what about _reference_count when task is failing? - self._reference_count -= 1 - # _maybe_cleanup() is called in finally block. - - # Terminate the ActiveTask globally. - self._is_finished.set() - self._request_queue.shutdown(immediate=True) - - if is_interrupted: - logger.debug( - 'Consumer[%s]: Interrupted with state %s', - self._task_id, - new_task.status.state, - ) - - if ( - self._push_sender - and self._task_id - and isinstance(event, PushNotificationEvent) - ): - logger.debug( - 'Consumer[%s]: Sending push notification', - self._task_id, - ) - await self._push_sender.send_notification( - self._task_id, event - ) - - self._task_created.set() - - finally: - if new_task is not None: - new_task_copy = Task() - new_task_copy.CopyFrom(new_task) - new_task = new_task_copy - if isinstance(event, Task): - new_task_copy = Task() - new_task_copy.CopyFrom(event) - event = new_task_copy - - logger.debug( - 'Consumer[%s]: Enqueuing\nEvent: %s\nNew Task: %s\n', - self._task_id, - event, - new_task, - ) - await self._event_queue_subscribers.enqueue_event( - cast('Any', (event, new_task)) - ) - self._event_queue_agent.task_done() - except QueueShutDown: - logger.debug( - 'Consumer[%s]: Event queue shut down', self._task_id - ) - except Exception as e: - logger.exception('Consumer[%s]: Failed', self._task_id) - # TODO: Make the task in database as failed. - async with self._lock: - await self._mark_task_as_failed(e) - finally: - # The consumer is dead. The ActiveTask is permanently finished. - self._is_finished.set() - self._request_queue.shutdown(immediate=True) - await self._event_queue_agent.close(immediate=True) - - logger.debug('Consumer[%s]: Finishing', self._task_id) - await self._maybe_cleanup() + await EventConsumer(self).run() finally: - logger.debug('Consumer[%s]: Completed', self._task_id) + self._is_finished.set() + self._request_queue.shutdown(immediate=True) + await self._event_queue_agent.close(immediate=True) + + logger.debug('Consumer[%s]: Finishing', self._task_id) + await self._maybe_cleanup() async def subscribe( # noqa: PLR0912, PLR0915 self,