diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py
index 58df3fd9a5..31e2be4145 100644
--- a/src/google/adk/models/anthropic_llm.py
+++ b/src/google/adk/models/anthropic_llm.py
@@ -32,6 +32,7 @@ from typing import TYPE_CHECKING
 from typing import Union
 
+import anthropic
 from anthropic import AsyncAnthropic
 from anthropic import AsyncAnthropicVertex
 from anthropic import NOT_GIVEN
@@ -52,6 +53,15 @@ logger = logging.getLogger("google_adk." + __name__)
 
 
+_RATE_LIMIT_POSSIBLE_FIX_MESSAGE = """
+To mitigate rate limit errors, consider using a different model, reducing
+request frequency, or upgrading your Anthropic API plan.
+"""
+
+
+class AnthropicRateLimitError(Exception):
+  """Raised when the Anthropic API returns a rate limit error."""
+
 
 @dataclasses.dataclass
 class _ToolUseAccumulator:
@@ -494,16 +504,21 @@ async def generate_content_async(
     thinking = _build_anthropic_thinking_param(llm_request.config)
 
     if not stream:
-      message = await self._anthropic_client.messages.create(
-          model=model_to_use,
-          system=llm_request.config.system_instruction,
-          messages=messages,
-          tools=tools,
-          tool_choice=tool_choice,
-          max_tokens=self.max_tokens,
-          thinking=thinking,
-      )
-      yield message_to_generate_content_response(message)
+      try:
+        message = await self._anthropic_client.messages.create(
+            model=model_to_use,
+            system=llm_request.config.system_instruction,
+            messages=messages,
+            tools=tools,
+            tool_choice=tool_choice,
+            max_tokens=self.max_tokens,
+            thinking=thinking,
+        )
+        yield message_to_generate_content_response(message)
+      except anthropic.RateLimitError as e:
+        raise AnthropicRateLimitError(
+            f"{_RATE_LIMIT_POSSIBLE_FIX_MESSAGE}\n\n{e}"
+        ) from e
     else:
       async for response in self._generate_content_streaming(
           llm_request, messages, tools, tool_choice, thinking
@@ -528,16 +543,21 @@ async def _generate_content_streaming(
       a final aggregated LlmResponse with all content.
     """
     model_to_use = self._resolve_model_name(llm_request.model)
-    raw_stream = await self._anthropic_client.messages.create(
-        model=model_to_use,
-        system=llm_request.config.system_instruction,
-        messages=messages,
-        tools=tools,
-        tool_choice=tool_choice,
-        max_tokens=self.max_tokens,
-        stream=True,
-        thinking=thinking,
-    )
+    try:
+      raw_stream = await self._anthropic_client.messages.create(
+          model=model_to_use,
+          system=llm_request.config.system_instruction,
+          messages=messages,
+          tools=tools,
+          tool_choice=tool_choice,
+          max_tokens=self.max_tokens,
+          stream=True,
+          thinking=thinking,
+      )
+    except anthropic.RateLimitError as e:
+      raise AnthropicRateLimitError(
+          f"{_RATE_LIMIT_POSSIBLE_FIX_MESSAGE}\n\n{e}"
+      ) from e
 
     # Track content blocks being built during streaming.
     # Each entry maps a block index to its accumulated state.
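
For context on how the new typed exception can be consumed, here is a minimal sketch of a caller-side retry loop. Only `AnthropicLlm`, `AnthropicRateLimitError`, and `generate_content_async(llm_request, stream=...)` come from the diff above; the `generate_with_backoff` helper, its `llm_request` argument, and the exponential-backoff policy are hypothetical illustrations, not part of this change.

```python
import asyncio

from google.adk.models.anthropic_llm import AnthropicLlm
from google.adk.models.anthropic_llm import AnthropicRateLimitError


async def generate_with_backoff(llm: AnthropicLlm, llm_request, max_attempts=3):
  """Hypothetical caller-side retry loop around AnthropicRateLimitError."""
  for attempt in range(max_attempts):
    try:
      # Non-streaming call: yields a single aggregated response, so a
      # retry after a rate-limit error cannot duplicate partial output.
      async for response in llm.generate_content_async(llm_request, stream=False):
        yield response
      return
    except AnthropicRateLimitError:
      if attempt == max_attempts - 1:
        raise
      # Exponential backoff before retrying: 1s, 2s, 4s, ...
      await asyncio.sleep(2**attempt)
```

Because the wrapper is a distinct exception type rather than a message string, callers can write a targeted `except` like the one above without pattern-matching on SDK error text.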
diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py
index aa18d6f072..bde79897fc 100644
--- a/tests/unittests/models/test_anthropic_llm.py
+++ b/tests/unittests/models/test_anthropic_llm.py
@@ -1905,3 +1905,116 @@ async def test_streaming_redacted_thinking_block_preserved_in_final():
     text_part = final.content.parts[1]
     assert text_part.text == "Done."
+
+
+# --- Tests for Anthropic API error handling ---
+
+
+def _make_llm_request():
+  return LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(system_instruction="Test"),
+  )
+
+
+def _make_rate_limit_error():
+  import anthropic
+
+  mock_response = MagicMock()
+  mock_response.status_code = 429
+  mock_response.headers = {}
+  return anthropic.RateLimitError(
+      message="rate limit exceeded",
+      response=mock_response,
+      body={"error": {"message": "rate limit exceeded"}},
+  )
+
+
+def _make_auth_error():
+  import anthropic
+
+  mock_response = MagicMock()
+  mock_response.status_code = 401
+  mock_response.headers = {}
+  return anthropic.AuthenticationError(
+      message="invalid api key",
+      response=mock_response,
+      body={"error": {"message": "invalid api key"}},
+  )
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_rate_limit_raises_anthropic_rate_limit_error():
+  """RateLimitError is re-raised as AnthropicRateLimitError with a helpful message."""
+  from google.adk.models.anthropic_llm import AnthropicRateLimitError
+
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(side_effect=_make_rate_limit_error())
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    with pytest.raises(AnthropicRateLimitError):
+      _ = [
+          r
+          async for r in llm.generate_content_async(
+              _make_llm_request(), stream=False
+          )
+      ]
+
+
+@pytest.mark.asyncio
+async def test_streaming_rate_limit_raises_anthropic_rate_limit_error():
+  """RateLimitError is re-raised as AnthropicRateLimitError in the streaming path."""
+  from google.adk.models.anthropic_llm import AnthropicRateLimitError
+
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(side_effect=_make_rate_limit_error())
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    with pytest.raises(AnthropicRateLimitError):
+      _ = [
+          r
+          async for r in llm.generate_content_async(
+              _make_llm_request(), stream=True
+          )
+      ]
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_other_errors_propagate():
+  """Non-rate-limit errors propagate unchanged."""
+  import anthropic
+
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(side_effect=_make_auth_error())
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    with pytest.raises(anthropic.AuthenticationError):
+      _ = [
+          r
+          async for r in llm.generate_content_async(
+              _make_llm_request(), stream=False
+          )
+      ]
+
+
+@pytest.mark.asyncio
+async def test_streaming_other_errors_propagate():
+  """Non-rate-limit errors propagate unchanged in the streaming path."""
+  import anthropic
+
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(side_effect=_make_auth_error())
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    with pytest.raises(anthropic.AuthenticationError):
+      _ = [
+          r
+          async for r in llm.generate_content_async(
+              _make_llm_request(), stream=True
+          )
+      ]
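
One property the tests above exercise only implicitly: because the wrapper is raised with `from e`, the original `anthropic.RateLimitError` stays reachable via `__cause__`, so callers can still inspect the HTTP details of the underlying failure. A sketch of an explicit assertion for this, reusing the mocked client and helpers above; the test name is hypothetical and the `.response`/`.status_code` attributes are the anthropic SDK's standard `APIStatusError` fields (here backed by the `MagicMock` response):

```python
@pytest.mark.asyncio
async def test_rate_limit_error_chains_original_cause():
  """Sketch: `raise ... from e` keeps the SDK error as __cause__."""
  import anthropic

  from google.adk.models.anthropic_llm import AnthropicRateLimitError

  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
  mock_client = MagicMock()
  mock_client.messages.create = AsyncMock(side_effect=_make_rate_limit_error())

  with mock.patch.object(llm, "_anthropic_client", mock_client):
    with pytest.raises(AnthropicRateLimitError) as exc_info:
      _ = [
          r
          async for r in llm.generate_content_async(
              _make_llm_request(), stream=False
          )
      ]
  # The original anthropic.RateLimitError (and its mocked 429 response)
  # survives the re-raise via exception chaining.
  assert isinstance(exc_info.value.__cause__, anthropic.RateLimitError)
  assert exc_info.value.__cause__.response.status_code == 429
```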