diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 73818db59..1701f14ef 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -20,6 +20,7 @@ amannn aproject ARequest ARun +ARoute AServer AServers AService @@ -94,6 +95,8 @@ notif npx oauthoidc oidc +oneof +oneofs Oneof OpenAPI openapiv @@ -112,6 +115,7 @@ proto protobuf Protobuf protoc +protojson pydantic pyi pypistats @@ -125,6 +129,7 @@ rmi RS256 RUF SECP256R1 +SFIXED SLF socio sse diff --git a/docs/migrations/v1_0/README.md b/docs/migrations/v1_0/README.md index da3d6ba79..0de221aca 100644 --- a/docs/migrations/v1_0/README.md +++ b/docs/migrations/v1_0/README.md @@ -465,6 +465,30 @@ app = FastAPI(routes=routes) uvicorn.run(app, host=host, port=port) ``` +`FastAPI(routes=routes)` mounts the A2A endpoints correctly, but FastAPI's OpenAPI generator only enumerates routes that are `fastapi.routing.APIRoute` instances, so the A2A endpoints will not appear in `/docs` or `/openapi.json`. To make them visible in the auto-generated OpenAPI schema — grouped into Agent Card, JSON-RPC, and REST sections — use the `add_a2a_routes_to_fastapi` helper: + +```python +from fastapi import FastAPI +import uvicorn + +from a2a.server.routes import ( + add_a2a_routes_to_fastapi, + create_agent_card_routes, + create_jsonrpc_routes, + create_rest_routes, +) + +app = FastAPI() +add_a2a_routes_to_fastapi( + app, + agent_card_routes=create_agent_card_routes(agent_card), + jsonrpc_routes=create_jsonrpc_routes(request_handler, rpc_url='/'), + rest_routes=create_rest_routes(request_handler), +) + +uvicorn.run(app, host=host, port=port) +``` + > **Example**: [`a2a-mcp-without-framework/server/__main__.py` in PR #509](https://github.com/a2aproject/a2a-samples/pull/509/files#diff-d15d39ae64c3d4e3a36cc6fb442302caf4e32a6dbd858792e7a4bed180a625ac) --- diff --git a/src/a2a/server/routes/__init__.py b/src/a2a/server/routes/__init__.py index 007e2722f..0767fd853 100644 --- a/src/a2a/server/routes/__init__.py +++ b/src/a2a/server/routes/__init__.py @@ -5,6 +5,7 @@ DefaultServerCallContextBuilder, ServerCallContextBuilder, ) +from a2a.server.routes.helpers import add_a2a_routes_to_fastapi from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes from a2a.server.routes.rest_routes import create_rest_routes @@ -12,6 +13,7 @@ __all__ = [ 'DefaultServerCallContextBuilder', 'ServerCallContextBuilder', + 'add_a2a_routes_to_fastapi', 'create_agent_card_routes', 'create_jsonrpc_routes', 'create_rest_routes', diff --git a/src/a2a/server/routes/agent_card_routes.py b/src/a2a/server/routes/agent_card_routes.py index 924a3d9dc..68e29c05f 100644 --- a/src/a2a/server/routes/agent_card_routes.py +++ b/src/a2a/server/routes/agent_card_routes.py @@ -41,6 +41,7 @@ def create_agent_card_routes( ) async def _get_agent_card(request: Request) -> Response: + """Returns the public AgentCard describing this agent's capabilities, supported transports, and skills.""" card_to_serve = agent_card if card_modifier: card_to_serve = await card_modifier(card_to_serve) diff --git a/src/a2a/server/routes/helpers/__init__.py b/src/a2a/server/routes/helpers/__init__.py new file mode 100644 index 000000000..9fef86ad7 --- /dev/null +++ b/src/a2a/server/routes/helpers/__init__.py @@ -0,0 +1,6 @@ +from a2a.server.routes.helpers.fastapi import add_a2a_routes_to_fastapi + + +__all__ = [ + 'add_a2a_routes_to_fastapi', +] diff --git a/src/a2a/server/routes/helpers/_proto_schema.py b/src/a2a/server/routes/helpers/_proto_schema.py new file mode 100644 index 000000000..7c82094d1 --- /dev/null +++ b/src/a2a/server/routes/helpers/_proto_schema.py @@ -0,0 +1,118 @@ +"""Proto → JSON Schema helpers shared across transport helpers.""" + +from typing import Any + +from google.protobuf.descriptor import Descriptor, FieldDescriptor +from google.protobuf.message import Message + +from a2a.types.a2a_pb2 import SendMessageRequest, TaskPushNotificationConfig + + +REST_BODY_TYPES: dict[tuple[str, str], type[Message]] = { + ('/message:send', 'POST'): SendMessageRequest, + ('/message:stream', 'POST'): SendMessageRequest, + ('/tasks/{id}/pushNotificationConfigs', 'POST'): TaskPushNotificationConfig, +} + +# 64-bit integer types serialize as strings in protojson. +_PROTO_SCALAR_SCHEMAS: dict[int, dict[str, Any]] = { + FieldDescriptor.TYPE_DOUBLE: {'type': 'number'}, + FieldDescriptor.TYPE_FLOAT: {'type': 'number'}, + FieldDescriptor.TYPE_INT64: {'type': 'string', 'format': 'int64'}, + FieldDescriptor.TYPE_UINT64: {'type': 'string', 'format': 'uint64'}, + FieldDescriptor.TYPE_INT32: {'type': 'integer', 'format': 'int32'}, + FieldDescriptor.TYPE_FIXED64: {'type': 'string', 'format': 'fixed64'}, + FieldDescriptor.TYPE_FIXED32: {'type': 'integer', 'format': 'fixed32'}, + FieldDescriptor.TYPE_BOOL: {'type': 'boolean'}, + FieldDescriptor.TYPE_STRING: {'type': 'string'}, + FieldDescriptor.TYPE_BYTES: {'type': 'string', 'format': 'byte'}, + FieldDescriptor.TYPE_UINT32: {'type': 'integer', 'format': 'uint32'}, + FieldDescriptor.TYPE_SFIXED32: {'type': 'integer'}, + FieldDescriptor.TYPE_SFIXED64: {'type': 'string'}, + FieldDescriptor.TYPE_SINT32: {'type': 'integer'}, + FieldDescriptor.TYPE_SINT64: {'type': 'string'}, +} + +_WELL_KNOWN_SCHEMAS: dict[str, dict[str, Any]] = { + 'google.protobuf.Timestamp': {'type': 'string', 'format': 'date-time'}, + 'google.protobuf.Duration': {'type': 'string'}, + 'google.protobuf.Struct': {'type': 'object'}, + 'google.protobuf.Value': {}, + 'google.protobuf.ListValue': {'type': 'array', 'items': {}}, + 'google.protobuf.Empty': {'type': 'object'}, + 'google.protobuf.Any': {'type': 'object'}, + 'google.protobuf.FieldMask': {'type': 'string'}, +} + + +def field_schema( + field: FieldDescriptor, components: dict[str, Any] +) -> dict[str, Any]: + if field.message_type and field.message_type.GetOptions().map_entry: + value_field = field.message_type.fields_by_name['value'] + return { + 'type': 'object', + 'additionalProperties': field_schema(value_field, components), + } + + if field.type == FieldDescriptor.TYPE_MESSAGE: + item = message_schema(field.message_type, components) + elif field.type == FieldDescriptor.TYPE_ENUM: + item = { + 'type': 'string', + 'enum': [v.name for v in field.enum_type.values], + } + else: + item = dict(_PROTO_SCALAR_SCHEMAS.get(field.type, {'type': 'string'})) + + if field.is_repeated: + return {'type': 'array', 'items': item} + return item + + +def message_schema( + descriptor: Descriptor | Any, components: dict[str, Any] +) -> dict[str, Any]: + """Returns a $ref to descriptor's schema, registering it in components if needed.""" + if descriptor.full_name in _WELL_KNOWN_SCHEMAS: + return dict(_WELL_KNOWN_SCHEMAS[descriptor.full_name]) + + name = descriptor.name + ref = {'$ref': f'#/components/schemas/{name}'} + if name in components: + return ref + + # Reserve the slot before recursing so cyclic types terminate. + components[name] = {} + + real_oneofs = [o for o in descriptor.oneofs if len(o.fields) > 1] + oneof_field_names = {f.name for o in real_oneofs for f in o.fields} + base_properties = { + f.name: field_schema(f, components) + for f in descriptor.fields + if f.name not in oneof_field_names + } + + if not real_oneofs: + components[name] = {'type': 'object', 'properties': base_properties} + return ref + + oneof_constraints = [ + { + 'oneOf': [ + { + 'type': 'object', + 'properties': {f.name: field_schema(f, components)}, + 'required': [f.name], + } + for f in oneof.fields + ] + } + for oneof in real_oneofs + ] + parts: list[dict[str, Any]] = [] + if base_properties: + parts.append({'type': 'object', 'properties': base_properties}) + parts.extend(oneof_constraints) + components[name] = parts[0] if len(parts) == 1 else {'allOf': parts} + return ref diff --git a/src/a2a/server/routes/helpers/fastapi.py b/src/a2a/server/routes/helpers/fastapi.py new file mode 100644 index 000000000..7239bdb82 --- /dev/null +++ b/src/a2a/server/routes/helpers/fastapi.py @@ -0,0 +1,196 @@ +from typing import TYPE_CHECKING, Any + +from a2a.server.routes.helpers._proto_schema import ( + REST_BODY_TYPES, + message_schema, +) +from a2a.server.routes.helpers.jsonrpc import ( + DESCRIPTION as _JSONRPC_DESCRIPTION, +) +from a2a.server.routes.helpers.jsonrpc import ( + envelope_schema as _jsonrpc_envelope_schema, +) +from a2a.utils.constants import PROTOCOL_VERSION_1_0, VERSION_HEADER + + +if TYPE_CHECKING: + from fastapi import FastAPI + from fastapi.routing import APIRoute as _A2ARoute + from starlette.routing import BaseRoute, Route + + _package_fastapi_installed = True +else: + try: + from fastapi.routing import APIRoute + from starlette.routing import Route, request_response + + class _A2ARoute(APIRoute): + """APIRoute that uses Starlette's request_response to bypass FastAPI middleware scope requirements.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.app = request_response(self.endpoint) + + _package_fastapi_installed = True + except ImportError: + Route = Any + _A2ARoute = Any + + _package_fastapi_installed = False + + +_AGENT_CARD_TAG = 'A2A: Agent Card' +_JSONRPC_TAG = 'A2A: JSON-RPC' +_REST_TAG = 'A2A: REST' + +_A2A_VERSION_HEADER = { + 'in': 'header', + 'name': VERSION_HEADER, + 'required': True, + 'schema': {'type': 'string', 'enum': [PROTOCOL_VERSION_1_0]}, + 'example': PROTOCOL_VERSION_1_0, +} + + +def _request_body_extra( + ref: dict[str, Any], description: str +) -> dict[str, Any]: + return { + 'requestBody': { + 'description': description, + 'required': True, + 'content': {'application/json': {'schema': ref}}, + }, + } + + +def _rest_body_extra( + route: 'Route', rest_bodies: dict[tuple[str, str], dict[str, Any]] +) -> dict[str, Any] | None: + methods = route.methods or set() + for (suffix, method), extra in rest_bodies.items(): + if method in methods and route.path.endswith(suffix): + return extra + return None + + +def _attach_route( + app: 'FastAPI', + route: 'BaseRoute', + tag: str, + openapi_extra: dict[str, Any] | None, + require_version_header: bool = False, +) -> None: + if not (isinstance(route, Route) and route.methods): + app.routes.append(route) + return + # Drop HEAD: Starlette adds it alongside GET, but FastAPI registers duplicate operation IDs. + methods = sorted(m for m in route.methods if m != 'HEAD') + if require_version_header: + extra = dict(openapi_extra or {}) + extra.setdefault('parameters', [_A2A_VERSION_HEADER]) + openapi_extra = extra + app.routes.append( + _A2ARoute( + path=route.path, + endpoint=route.endpoint, + methods=methods, + tags=[tag], + openapi_extra=openapi_extra, + ) + ) + + +def add_a2a_routes_to_fastapi( + app: 'FastAPI', + *, + agent_card_routes: 'list[BaseRoute] | None' = None, + jsonrpc_routes: 'list[BaseRoute] | None' = None, + rest_routes: 'list[BaseRoute] | None' = None, +) -> None: + """Mounts A2A routes on a FastAPI app and enriches them for ``/docs``. + + Re-registers Starlette routes as ``APIRoute`` instances so they appear in + the auto-generated OpenAPI schema, tagged and annotated with proto-derived + request-body schemas. + + Usage:: + + app = FastAPI() + add_a2a_routes_to_fastapi( + app, + agent_card_routes=create_agent_card_routes(agent_card), + jsonrpc_routes=create_jsonrpc_routes(request_handler, rpc_url='/'), + rest_routes=create_rest_routes(request_handler), + ) + + Args: + app: The FastAPI application to mount the routes on. + agent_card_routes: Routes returned by ``create_agent_card_routes``. + jsonrpc_routes: Routes returned by ``create_jsonrpc_routes``. + rest_routes: Routes returned by ``create_rest_routes``. + """ + if not _package_fastapi_installed: + raise ImportError( + 'The `fastapi` package is required to use ' + '`add_a2a_routes_to_fastapi`. Install it alongside ' + '`a2a-sdk[http-server]`.' + ) + + components: dict[str, Any] = {} + jsonrpc_extra = { + 'summary': 'A2A JSON-RPC endpoint', + 'description': _JSONRPC_DESCRIPTION, + **_request_body_extra( + _jsonrpc_envelope_schema(components), 'A2A JSON-RPC 2.0 request' + ), + } + rest_extras = { + key: _request_body_extra( + message_schema(cls.DESCRIPTOR, components), + f'A2A {cls.__name__}', + ) + for key, cls in REST_BODY_TYPES.items() + } + + for route in agent_card_routes or (): + _attach_route(app, route, _AGENT_CARD_TAG, openapi_extra=None) + + for route in jsonrpc_routes or (): + extra = jsonrpc_extra if isinstance(route, Route) else None + _attach_route( + app, + route, + _JSONRPC_TAG, + openapi_extra=extra, + require_version_header=True, + ) + + for route in rest_routes or (): + extra = ( + _rest_body_extra(route, rest_extras) + if isinstance(route, Route) + else None + ) + _attach_route( + app, + route, + _REST_TAG, + openapi_extra=extra, + require_version_header=True, + ) + + original_openapi = app.openapi + + def _openapi() -> dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + schema = original_openapi() + component_schemas = schema.setdefault('components', {}).setdefault( + 'schemas', {} + ) + for name, sub_schema in components.items(): + component_schemas.setdefault(name, sub_schema) + return schema + + app.openapi = _openapi # type: ignore[method-assign] diff --git a/src/a2a/server/routes/helpers/jsonrpc.py b/src/a2a/server/routes/helpers/jsonrpc.py new file mode 100644 index 000000000..a85b772fe --- /dev/null +++ b/src/a2a/server/routes/helpers/jsonrpc.py @@ -0,0 +1,86 @@ +"""JSON-RPC specific helpers for A2A server routes.""" + +from typing import Any + +from google.protobuf.message import Message + +from a2a.server.routes.helpers._proto_schema import message_schema +from a2a.types.a2a_pb2 import ( + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + SendMessageRequest, + SubscribeToTaskRequest, + TaskPushNotificationConfig, +) + + +METHOD_TYPES: dict[str, type[Message]] = { + 'SendMessage': SendMessageRequest, + 'SendStreamingMessage': SendMessageRequest, + 'GetTask': GetTaskRequest, + 'ListTasks': ListTasksRequest, + 'CancelTask': CancelTaskRequest, + 'CreateTaskPushNotificationConfig': TaskPushNotificationConfig, + 'GetTaskPushNotificationConfig': GetTaskPushNotificationConfigRequest, + 'ListTaskPushNotificationConfigs': ListTaskPushNotificationConfigsRequest, + 'DeleteTaskPushNotificationConfig': DeleteTaskPushNotificationConfigRequest, + 'SubscribeToTask': SubscribeToTaskRequest, + 'GetExtendedAgentCard': GetExtendedAgentCardRequest, +} + +DESCRIPTION = """\ +A2A JSON-RPC 2.0 endpoint. The `method` field selects the operation; +`params` must match that method's schema (see the `oneOf` below). + +**Supported methods:** + +- `SendMessage` — Send a message to the agent (returns a Task or response Message). +- `SendStreamingMessage` — Send a message and receive a Server-Sent Events stream. +- `GetTask` — Fetch a task by ID. +- `ListTasks` — List tasks with pagination and filtering. +- `CancelTask` — Cancel an in-progress task. +- `CreateTaskPushNotificationConfig` — Register a push-notification config on a task. +- `GetTaskPushNotificationConfig` — Read a single push-notification config. +- `ListTaskPushNotificationConfigs` — List all push-notification configs for a task. +- `DeleteTaskPushNotificationConfig` — Delete a push-notification config. +- `SubscribeToTask` — Subscribe to task events via Server-Sent Events. +- `GetExtendedAgentCard` — Fetch the authenticated extended agent card. +""" + + +def envelope_schema(components: dict[str, Any]) -> dict[str, Any]: + """Builds the A2ARequest JSON-RPC envelope schema with a oneOf over all method params.""" + seen_refs: set[str] = set() + params_refs: list[dict[str, Any]] = [] + for cls in METHOD_TYPES.values(): + ref = message_schema(cls.DESCRIPTOR, components) + key = ref.get('$ref', '') + if key and key not in seen_refs: + seen_refs.add(key) + params_refs.append(ref) + + components['A2ARequest'] = { + 'type': 'object', + 'required': ['jsonrpc', 'method'], + 'properties': { + 'jsonrpc': {'type': 'string', 'enum': ['2.0']}, + 'id': { + 'oneOf': [ + {'type': 'string'}, + {'type': 'integer'}, + {'type': 'null'}, + ], + }, + 'method': { + 'type': 'string', + 'enum': list(METHOD_TYPES), + }, + 'params': {'oneOf': params_refs}, + }, + } + return {'$ref': '#/components/schemas/A2ARequest'} diff --git a/tests/server/routes/helpers/__init__.py b/tests/server/routes/helpers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/routes/helpers/test_fastapi.py b/tests/server/routes/helpers/test_fastapi.py new file mode 100644 index 000000000..745bdff25 --- /dev/null +++ b/tests/server/routes/helpers/test_fastapi.py @@ -0,0 +1,178 @@ +from unittest.mock import AsyncMock + +import pytest + +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes import ( + add_a2a_routes_to_fastapi, + create_agent_card_routes, + create_jsonrpc_routes, + create_rest_routes, +) +from a2a.server.routes.helpers.fastapi import ( + _AGENT_CARD_TAG, + _JSONRPC_TAG, + _REST_TAG, +) +from a2a.types.a2a_pb2 import AgentCard, Task +from a2a.utils.constants import ( + AGENT_CARD_WELL_KNOWN_PATH, + PROTOCOL_VERSION_1_0, + VERSION_HEADER, +) + + +fastapi = pytest.importorskip('fastapi') +from fastapi import FastAPI # noqa: E402 +from fastapi.testclient import TestClient # noqa: E402 + + +@pytest.fixture +def agent_card() -> AgentCard: + return AgentCard(name='Test Agent', version='1.0.0') + + +@pytest.fixture +def mock_handler() -> AsyncMock: + return AsyncMock(spec=RequestHandler) + + +def _build_app(agent_card: AgentCard, mock_handler: AsyncMock) -> FastAPI: + app = FastAPI() + add_a2a_routes_to_fastapi( + app, + agent_card_routes=create_agent_card_routes(agent_card), + jsonrpc_routes=create_jsonrpc_routes(mock_handler, rpc_url='/'), + rest_routes=create_rest_routes(mock_handler), + ) + return app + + +def test_routes_appear_in_openapi_with_tags( + agent_card: AgentCard, mock_handler: AsyncMock +) -> None: + """Each group is documented and tagged for the Swagger UI.""" + app = _build_app(agent_card, mock_handler) + paths = app.openapi()['paths'] + + assert paths[AGENT_CARD_WELL_KNOWN_PATH]['get']['tags'] == [_AGENT_CARD_TAG] + assert paths['/']['post']['tags'] == [_JSONRPC_TAG] + assert paths['/message:send']['post']['tags'] == [_REST_TAG] + assert paths['/tasks']['get']['tags'] == [_REST_TAG] + + +def test_routes_dispatch_under_fastapi( + agent_card: AgentCard, mock_handler: AsyncMock +) -> None: + """Re-registered routes still dispatch correctly under a FastAPI app.""" + mock_handler.on_message_send.return_value = Task(id='task-123') + + app = _build_app(agent_card, mock_handler) + client = TestClient(app) + + assert client.get(AGENT_CARD_WELL_KNOWN_PATH).json()['name'] == 'Test Agent' + rpc_response = client.post( + '/', json={'jsonrpc': '2.0', 'id': '1', 'method': 'NoSuchMethod'} + ).json() + assert rpc_response['error']['code'] == -32601 + + rest_response = client.post( + '/message:send', + json={}, + headers={VERSION_HEADER: PROTOCOL_VERSION_1_0}, + ) + assert rest_response.status_code == 200 + assert rest_response.json()['task']['id'] == 'task-123' + + +def test_tenant_mount_still_dispatches(mock_handler: AsyncMock) -> None: + """`Mount` entries (tenant routing) keep dispatching after registration.""" + mock_handler.on_message_send.return_value = Task(id='tenant-task') + + app = FastAPI() + add_a2a_routes_to_fastapi(app, rest_routes=create_rest_routes(mock_handler)) + client = TestClient(app) + + response = client.post( + '/my-tenant/message:send', + json={}, + headers={VERSION_HEADER: PROTOCOL_VERSION_1_0}, + ) + assert response.status_code == 200 + context = mock_handler.on_message_send.call_args[0][1] + assert context.tenant == 'my-tenant' + + +def test_partial_groups(agent_card: AgentCard, mock_handler: AsyncMock) -> None: + """Calling with only a subset of groups works and tags only those.""" + app = FastAPI() + add_a2a_routes_to_fastapi( + app, + agent_card_routes=create_agent_card_routes(agent_card), + ) + paths = app.openapi()['paths'] + assert list(paths.keys()) == [AGENT_CARD_WELL_KNOWN_PATH] + + +def test_request_body_schemas_are_attached( + agent_card: AgentCard, mock_handler: AsyncMock +) -> None: + """JSON-RPC and REST POST bodies expose schemas derived from proto types.""" + app = _build_app(agent_card, mock_handler) + schema = app.openapi() + components = schema['components']['schemas'] + + assert 'A2ARequest' in components + assert components['A2ARequest']['properties']['method']['enum'] + rpc_body = schema['paths']['/']['post']['requestBody'] + assert rpc_body['content']['application/json']['schema'] == { + '$ref': '#/components/schemas/A2ARequest' + } + + send_body = schema['paths']['/message:send']['post']['requestBody'] + assert send_body['content']['application/json']['schema'] == { + '$ref': '#/components/schemas/SendMessageRequest' + } + + assert 'Message' in components + assert 'Part' in components + assert components['Message']['properties']['role']['enum'] == [ + 'ROLE_UNSPECIFIED', + 'ROLE_USER', + 'ROLE_AGENT', + ] + + +def test_routes_without_body_have_no_request_body( + agent_card: AgentCard, mock_handler: AsyncMock +) -> None: + """GET/DELETE/parameterless POST routes don't get a fabricated body.""" + app = _build_app(agent_card, mock_handler) + paths = app.openapi()['paths'] + + assert 'requestBody' not in paths[AGENT_CARD_WELL_KNOWN_PATH]['get'] + assert 'requestBody' not in paths['/tasks']['get'] + assert 'requestBody' not in paths['/tasks/{id}:cancel']['post'] + assert ( + 'requestBody' + not in paths['/tasks/{id}/pushNotificationConfigs/{push_id}']['delete'] + ) + + +def test_a2a_version_header_on_dispatcher_routes( + agent_card: AgentCard, mock_handler: AsyncMock +) -> None: + """JSON-RPC and REST routes declare the version header so Swagger pre-fills it.""" + app = _build_app(agent_card, mock_handler) + paths = app.openapi()['paths'] + + def _has_version_header(op: dict) -> bool: + return any( + p.get('name') == VERSION_HEADER for p in op.get('parameters', []) + ) + + assert _has_version_header(paths['/']['post']) + assert _has_version_header(paths['/message:send']['post']) + assert _has_version_header(paths['/tasks']['get']) + assert _has_version_header(paths['/tasks/{id}:cancel']['post']) + assert not _has_version_header(paths[AGENT_CARD_WELL_KNOWN_PATH]['get']) diff --git a/tests/server/routes/helpers/test_jsonrpc.py b/tests/server/routes/helpers/test_jsonrpc.py new file mode 100644 index 000000000..665047220 --- /dev/null +++ b/tests/server/routes/helpers/test_jsonrpc.py @@ -0,0 +1,71 @@ +from a2a.server.routes.helpers.jsonrpc import ( + DESCRIPTION, + METHOD_TYPES, + envelope_schema, +) + + +def test_envelope_schema_ref(): + components = {} + ref = envelope_schema(components) + assert ref == {'$ref': '#/components/schemas/A2ARequest'} + + +def test_envelope_schema_required_fields(): + components = {} + envelope_schema(components) + assert components['A2ARequest']['required'] == ['jsonrpc', 'method'] + + +def test_envelope_schema_method_enum_matches_method_types(): + components = {} + envelope_schema(components) + enum = components['A2ARequest']['properties']['method']['enum'] + assert set(enum) == set(METHOD_TYPES) + + +def test_envelope_schema_params_is_one_of(): + components = {} + envelope_schema(components) + params = components['A2ARequest']['properties']['params'] + assert 'oneOf' in params + assert len(params['oneOf']) > 0 + + +def test_envelope_schema_deduplicates_shared_param_types(): + # SendMessage and SendStreamingMessage share SendMessageRequest. + components = {} + envelope_schema(components) + refs = [ + r['$ref'] + for r in components['A2ARequest']['properties']['params']['oneOf'] + ] + assert len(refs) == len(set(refs)) + + +def test_envelope_schema_jsonrpc_version(): + components = {} + envelope_schema(components) + assert components['A2ARequest']['properties']['jsonrpc']['enum'] == ['2.0'] + + +def test_method_types_contains_all_a2a_methods(): + expected = { + 'SendMessage', + 'SendStreamingMessage', + 'GetTask', + 'ListTasks', + 'CancelTask', + 'CreateTaskPushNotificationConfig', + 'GetTaskPushNotificationConfig', + 'ListTaskPushNotificationConfigs', + 'DeleteTaskPushNotificationConfig', + 'SubscribeToTask', + 'GetExtendedAgentCard', + } + assert set(METHOD_TYPES) == expected + + +def test_description_lists_all_methods(): + for method in METHOD_TYPES: + assert method in DESCRIPTION diff --git a/tests/server/routes/helpers/test_proto_schema.py b/tests/server/routes/helpers/test_proto_schema.py new file mode 100644 index 000000000..7ae2dca73 --- /dev/null +++ b/tests/server/routes/helpers/test_proto_schema.py @@ -0,0 +1,132 @@ +from a2a.server.routes.helpers._proto_schema import ( + REST_BODY_TYPES, + field_schema, + message_schema, +) +from a2a.types.a2a_pb2 import Message, Part, SendMessageRequest + + +def test_message_schema_registers_ref(): + components = {} + ref = message_schema(SendMessageRequest.DESCRIPTOR, components) + assert ref == {'$ref': '#/components/schemas/SendMessageRequest'} + assert 'SendMessageRequest' in components + + +def test_message_schema_returns_cached_ref(): + components = {} + ref1 = message_schema(SendMessageRequest.DESCRIPTOR, components) + ref2 = message_schema(SendMessageRequest.DESCRIPTOR, components) + assert ref1 == ref2 + + +def test_message_schema_recurses_into_nested_types(): + components = {} + message_schema(SendMessageRequest.DESCRIPTOR, components) + assert 'Message' in components + assert 'Part' in components + + +def test_message_schema_well_known_type_inline(): + from google.protobuf.descriptor_pool import Default + + struct_descriptor = Default().FindMessageTypeByName( + 'google.protobuf.Struct' + ) + components = {} + schema = message_schema(struct_descriptor, components) + assert schema == {'type': 'object'} + assert 'Struct' not in components + + +def test_message_schema_oneof_becomes_allof_with_one_of_constraint(): + components = {} + message_schema(Part.DESCRIPTOR, components) + schema = components['Part'] + assert 'allOf' in schema + one_of_constraint = next(p for p in schema['allOf'] if 'oneOf' in p) + oneof_keys = {list(v['properties'])[0] for v in one_of_constraint['oneOf']} + assert {'text', 'raw', 'url', 'data'} <= oneof_keys + + +def test_message_schema_oneof_variants_have_required(): + components = {} + message_schema(Part.DESCRIPTOR, components) + one_of_constraint = next( + p for p in components['Part']['allOf'] if 'oneOf' in p + ) + for variant in one_of_constraint['oneOf']: + assert len(variant['required']) == 1 + + +def test_message_schema_multiple_oneofs_use_allof_not_cartesian_product(): + # Simulate a descriptor with two oneofs: verify allOf has one constraint + # per oneof rather than a flat list of cross-product variants. + from unittest.mock import MagicMock + + def _make_field(name): + f = MagicMock() + f.name = name + f.message_type = None + f.type = 9 # TYPE_STRING + f.is_repeated = False + return f + + def _make_oneof(fields): + o = MagicMock() + o.fields = fields + return o + + f_a, f_b = _make_field('a'), _make_field('b') + f_x, f_y = _make_field('x'), _make_field('y') + oneof1 = _make_oneof([f_a, f_b]) + oneof2 = _make_oneof([f_x, f_y]) + + descriptor = MagicMock() + descriptor.full_name = 'test.MultiOneof' + descriptor.name = 'MultiOneof' + descriptor.oneofs = [oneof1, oneof2] + descriptor.fields = [f_a, f_b, f_x, f_y] + + components = {} + message_schema(descriptor, components) + schema = components['MultiOneof'] + + # Should be allOf with two oneOf constraints (one per oneof group), + # NOT a flat oneOf with 2*2=4 Cartesian-product variants. + assert 'allOf' in schema + one_of_constraints = [p for p in schema['allOf'] if 'oneOf' in p] + assert len(one_of_constraints) == 2 + assert len(one_of_constraints[0]['oneOf']) == 2 + assert len(one_of_constraints[1]['oneOf']) == 2 + + +def test_field_schema_repeated_wraps_in_array(): + components = {} + msg_descriptor = SendMessageRequest.DESCRIPTOR.fields_by_name[ + 'message' + ].message_type + parts_field = msg_descriptor.fields_by_name['parts'] + schema = field_schema(parts_field, components) + assert schema['type'] == 'array' + assert 'items' in schema + + +def test_field_schema_enum(): + role_field = Message.DESCRIPTOR.fields_by_name['role'] + schema = field_schema(role_field, {}) + assert schema['type'] == 'string' + assert 'ROLE_USER' in schema['enum'] + assert 'ROLE_AGENT' in schema['enum'] + + +def test_field_schema_map_entry(): + metadata_field = SendMessageRequest.DESCRIPTOR.fields_by_name['metadata'] + schema = field_schema(metadata_field, {}) + assert schema == {'type': 'object'} + + +def test_rest_body_types_coverage(): + assert ('/message:send', 'POST') in REST_BODY_TYPES + assert ('/message:stream', 'POST') in REST_BODY_TYPES + assert ('/tasks/{id}/pushNotificationConfigs', 'POST') in REST_BODY_TYPES