diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py
index 58df3fd9a5..c47cb311d9 100644
--- a/src/google/adk/models/anthropic_llm.py
+++ b/src/google/adk/models/anthropic_llm.py
@@ -133,10 +133,12 @@ def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
 def to_google_genai_finish_reason(
     anthropic_stop_reason: Optional[str],
 ) -> types.FinishReason:
-  if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
+  if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use", "pause_turn"]:
     return "STOP"
   if anthropic_stop_reason == "max_tokens":
     return "MAX_TOKENS"
+  if anthropic_stop_reason == "refusal":
+    return "SAFETY"
   return "FINISH_REASON_UNSPECIFIED"
@@ -343,8 +345,7 @@ def message_to_generate_content_response(
             message.usage.input_tokens + message.usage.output_tokens
         ),
     ),
-      # TODO: Deal with these later.
-      # finish_reason=to_google_genai_finish_reason(message.stop_reason),
+      finish_reason=to_google_genai_finish_reason(message.stop_reason),
   )
@@ -547,6 +548,7 @@ async def _generate_content_streaming(
   redacted_thinking_blocks: dict[int, str] = {}
   input_tokens = 0
   output_tokens = 0
+  stop_reason: Optional[str] = None

   async for event in raw_stream:
     if event.type == "message_start":
@@ -603,6 +605,7 @@ async def _generate_content_streaming(

     elif event.type == "message_delta":
       output_tokens = event.usage.output_tokens
+      stop_reason = event.delta.stop_reason

   # Build the final aggregated response with all content.
   all_parts: list[types.Part] = []
@@ -644,6 +647,7 @@ async def _generate_content_streaming(
           candidates_token_count=output_tokens,
           total_token_count=input_tokens + output_tokens,
       ),
+      finish_reason=to_google_genai_finish_reason(stop_reason),
       partial=False,
   )
diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py
index aa18d6f072..f4ec7d7ee6 100644
--- a/tests/unittests/models/test_anthropic_llm.py
+++ b/tests/unittests/models/test_anthropic_llm.py
@@ -1905,3 +1905,153 @@ async def test_streaming_redacted_thinking_block_preserved_in_final():
   text_part = final.content.parts[1]
   assert text_part.text == "Done."
+
+
+# --- Tests for finish_reason ---
+
+
+@pytest.mark.parametrize(
+    "stop_reason,expected_finish_reason",
+    [
+        ("end_turn", types.FinishReason.STOP),
+        ("stop_sequence", types.FinishReason.STOP),
+        ("tool_use", types.FinishReason.STOP),
+        ("max_tokens", types.FinishReason.MAX_TOKENS),
+        ("pause_turn", types.FinishReason.STOP),
+        ("refusal", types.FinishReason.SAFETY),
+        (None, types.FinishReason.FINISH_REASON_UNSPECIFIED),
+    ],
+)
+def test_to_google_genai_finish_reason(stop_reason, expected_finish_reason):
+  """All Anthropic stop_reason values map to the correct ADK FinishReason."""
+  from google.adk.models.anthropic_llm import to_google_genai_finish_reason
+
+  assert to_google_genai_finish_reason(stop_reason) == expected_finish_reason
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_sets_finish_reason():
+  """finish_reason is populated on non-streaming LlmResponse."""
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+  mock_message = anthropic_types.Message(
+      id="msg_test",
+      content=[anthropic_types.TextBlock(text="Hi", type="text", citations=None)],
+      model="claude-sonnet-4-20250514",
+      role="assistant",
+      stop_reason="end_turn",
+      stop_sequence=None,
+      type="message",
+      usage=anthropic_types.Usage(
+          input_tokens=5,
+          output_tokens=2,
+          cache_creation_input_tokens=0,
+          cache_read_input_tokens=0,
+          server_tool_use=None,
+          service_tier=None,
+      ),
+  )
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(return_value=mock_message)
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(system_instruction="Test"),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    responses = [
+        r async for r in llm.generate_content_async(llm_request, stream=False)
+    ]
+
+  assert len(responses) == 1
+  assert responses[0].finish_reason == types.FinishReason.STOP
+
+
+@pytest.mark.asyncio
+async def test_non_streaming_finish_reason_max_tokens():
+  """finish_reason MAX_TOKENS is set when stop_reason is max_tokens."""
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+  mock_message = anthropic_types.Message(
+      id="msg_test",
+      content=[anthropic_types.TextBlock(text="Hi", type="text", citations=None)],
+      model="claude-sonnet-4-20250514",
+      role="assistant",
+      stop_reason="max_tokens",
+      stop_sequence=None,
+      type="message",
+      usage=anthropic_types.Usage(
+          input_tokens=5,
+          output_tokens=2,
+          cache_creation_input_tokens=0,
+          cache_read_input_tokens=0,
+          server_tool_use=None,
+          service_tier=None,
+      ),
+  )
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(return_value=mock_message)
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(system_instruction="Test"),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    responses = [
+        r async for r in llm.generate_content_async(llm_request, stream=False)
+    ]
+
+  assert responses[0].finish_reason == types.FinishReason.MAX_TOKENS
+
+
+@pytest.mark.asyncio
+async def test_streaming_sets_finish_reason():
+  """finish_reason is populated on the final streaming LlmResponse."""
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+
+  events = [
+      MagicMock(
+          type="message_start",
+          message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
+      ),
+      MagicMock(
+          type="content_block_start",
+          index=0,
+          content_block=anthropic_types.TextBlock(text="", type="text"),
+      ),
+      MagicMock(
+          type="content_block_delta",
+          index=0,
+          delta=anthropic_types.TextDelta(text="Hi", type="text_delta"),
+      ),
+      MagicMock(type="content_block_stop", index=0),
+      MagicMock(
+          type="message_delta",
+          delta=MagicMock(stop_reason="max_tokens"),
+          usage=MagicMock(output_tokens=1),
+      ),
+      MagicMock(type="message_stop"),
+  ]
+
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(
+      return_value=_make_mock_stream_events(events)
+  )
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(system_instruction="Test"),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    responses = [
+        r async for r in llm.generate_content_async(llm_request, stream=True)
+    ]
+
+  final = responses[-1]
+  assert final.finish_reason == types.FinishReason.MAX_TOKENS