10 changes: 7 additions & 3 deletions src/google/adk/models/anthropic_llm.py
@@ -133,10 +133,12 @@ def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
 def to_google_genai_finish_reason(
     anthropic_stop_reason: Optional[str],
 ) -> types.FinishReason:
-  if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
+  if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use", "pause_turn"]:
     return "STOP"
   if anthropic_stop_reason == "max_tokens":
     return "MAX_TOKENS"
+  if anthropic_stop_reason == "refusal":
+    return "SAFETY"
   return "FINISH_REASON_UNSPECIFIED"


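For reference, a minimal usage sketch of the updated mapping; the import path and assertions mirror the parametrized test added below:

from google.genai import types
from google.adk.models.anthropic_llm import to_google_genai_finish_reason

# "pause_turn" now maps to STOP instead of falling through to UNSPECIFIED.
assert to_google_genai_finish_reason("pause_turn") == types.FinishReason.STOP
# "refusal" is surfaced as a safety stop.
assert to_google_genai_finish_reason("refusal") == types.FinishReason.SAFETY
# Absent or unrecognized stop reasons still fall through.
assert to_google_genai_finish_reason(None) == types.FinishReason.FINISH_REASON_UNSPECIFIED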
@@ -343,8 +345,7 @@ def message_to_generate_content_response(
             message.usage.input_tokens + message.usage.output_tokens
         ),
     ),
-    # TODO: Deal with these later.
-    # finish_reason=to_google_genai_finish_reason(message.stop_reason),
+    finish_reason=to_google_genai_finish_reason(message.stop_reason),
   )


@@ -547,6 +548,7 @@ async def _generate_content_streaming(
   redacted_thinking_blocks: dict[int, str] = {}
   input_tokens = 0
   output_tokens = 0
+  stop_reason: Optional[str] = None

   async for event in raw_stream:
     if event.type == "message_start":
@@ -603,6 +605,7 @@

     elif event.type == "message_delta":
       output_tokens = event.usage.output_tokens
+      stop_reason = event.delta.stop_reason

   # Build the final aggregated response with all content.
   all_parts: list[types.Part] = []
@@ -644,6 +647,7 @@
           candidates_token_count=output_tokens,
           total_token_count=input_tokens + output_tokens,
       ),
+      finish_reason=to_google_genai_finish_reason(stop_reason),
       partial=False,
   )

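Why message_delta: in Anthropic's streaming protocol the final stop_reason arrives on the message_delta event (together with the cumulative output-token usage), not on message_stop, which is why the handler records it there. A self-contained sketch of the capture pattern, with a fake stream standing in for the Anthropic client:

import asyncio
from types import SimpleNamespace
from typing import Optional

async def fake_stream():
  # The last message_delta carries delta.stop_reason and usage.output_tokens,
  # mirroring the event shapes the tests below construct with MagicMock.
  yield SimpleNamespace(
      type="message_delta",
      delta=SimpleNamespace(stop_reason="max_tokens"),
      usage=SimpleNamespace(output_tokens=1),
  )

async def main():
  stop_reason: Optional[str] = None
  output_tokens = 0
  async for event in fake_stream():
    if event.type == "message_delta":
      stop_reason = event.delta.stop_reason
      output_tokens = event.usage.output_tokens
  # In the real code, stop_reason now feeds to_google_genai_finish_reason(...).
  print(stop_reason, output_tokens)  # -> max_tokens 1

asyncio.run(main())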
150 changes: 150 additions & 0 deletions tests/unittests/models/test_anthropic_llm.py
@@ -1905,3 +1905,153 @@ async def test_streaming_redacted_thinking_block_preserved_in_final():

text_part = final.content.parts[1]
assert text_part.text == "Done."


# --- Tests for finish_reason ---


@pytest.mark.parametrize(
"stop_reason,expected_finish_reason",
[
("end_turn", types.FinishReason.STOP),
("stop_sequence", types.FinishReason.STOP),
("tool_use", types.FinishReason.STOP),
("max_tokens", types.FinishReason.MAX_TOKENS),
("pause_turn", types.FinishReason.STOP),
("refusal", types.FinishReason.SAFETY),
(None, types.FinishReason.FINISH_REASON_UNSPECIFIED),
],
)
def test_to_google_genai_finish_reason(stop_reason, expected_finish_reason):
"""All Anthropic stop_reason values map to the correct ADK FinishReason."""
from google.adk.models.anthropic_llm import to_google_genai_finish_reason

assert to_google_genai_finish_reason(stop_reason) == expected_finish_reason


@pytest.mark.asyncio
async def test_non_streaming_sets_finish_reason():
"""finish_reason is populated on non-streaming LlmResponse."""
llm = AnthropicLlm(model="claude-sonnet-4-20250514")
mock_message = anthropic_types.Message(
id="msg_test",
content=[anthropic_types.TextBlock(text="Hi", type="text", citations=None)],
model="claude-sonnet-4-20250514",
role="assistant",
stop_reason="end_turn",
stop_sequence=None,
type="message",
usage=anthropic_types.Usage(
input_tokens=5,
output_tokens=2,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
server_tool_use=None,
service_tier=None,
),
)
mock_client = MagicMock()
mock_client.messages.create = AsyncMock(return_value=mock_message)

llm_request = LlmRequest(
model="claude-sonnet-4-20250514",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
config=types.GenerateContentConfig(system_instruction="Test"),
)

with mock.patch.object(llm, "_anthropic_client", mock_client):
responses = [
r async for r in llm.generate_content_async(llm_request, stream=False)
]

assert len(responses) == 1
assert responses[0].finish_reason == types.FinishReason.STOP


@pytest.mark.asyncio
async def test_non_streaming_finish_reason_max_tokens():
"""finish_reason MAX_TOKENS is set when stop_reason is max_tokens."""
llm = AnthropicLlm(model="claude-sonnet-4-20250514")
mock_message = anthropic_types.Message(
id="msg_test",
content=[anthropic_types.TextBlock(text="Hi", type="text", citations=None)],
model="claude-sonnet-4-20250514",
role="assistant",
stop_reason="max_tokens",
stop_sequence=None,
type="message",
usage=anthropic_types.Usage(
input_tokens=5,
output_tokens=2,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
server_tool_use=None,
service_tier=None,
),
)
mock_client = MagicMock()
mock_client.messages.create = AsyncMock(return_value=mock_message)

llm_request = LlmRequest(
model="claude-sonnet-4-20250514",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
config=types.GenerateContentConfig(system_instruction="Test"),
)

with mock.patch.object(llm, "_anthropic_client", mock_client):
responses = [
r async for r in llm.generate_content_async(llm_request, stream=False)
]

assert responses[0].finish_reason == types.FinishReason.MAX_TOKENS


@pytest.mark.asyncio
async def test_streaming_sets_finish_reason():
"""finish_reason is populated on the final streaming LlmResponse."""
llm = AnthropicLlm(model="claude-sonnet-4-20250514")

events = [
MagicMock(
type="message_start",
message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
),
MagicMock(
type="content_block_start",
index=0,
content_block=anthropic_types.TextBlock(text="", type="text"),
),
MagicMock(
type="content_block_delta",
index=0,
delta=anthropic_types.TextDelta(text="Hi", type="text_delta"),
),
MagicMock(type="content_block_stop", index=0),
MagicMock(
type="message_delta",
delta=MagicMock(stop_reason="max_tokens"),
usage=MagicMock(output_tokens=1),
),
MagicMock(type="message_stop"),
]

mock_client = MagicMock()
mock_client.messages.create = AsyncMock(
return_value=_make_mock_stream_events(events)
)

llm_request = LlmRequest(
model="claude-sonnet-4-20250514",
contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
config=types.GenerateContentConfig(system_instruction="Test"),
)

with mock.patch.object(llm, "_anthropic_client", mock_client):
responses = [
r async for r in llm.generate_content_async(llm_request, stream=True)
]

final = responses[-1]
assert final.finish_reason == types.FinishReason.MAX_TOKENS
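
Downstream, the populated field lets callers react to truncation and refusals. A hypothetical caller-side check (reusing the llm and llm_request setup from the streaming test above; the handling branches are illustrative, not part of this PR):

async for response in llm.generate_content_async(llm_request, stream=True):
  if response.finish_reason == types.FinishReason.MAX_TOKENS:
    # Hypothetical handling: the reply was truncated, so a caller might
    # retry with a larger output-token budget.
    pass
  elif response.finish_reason == types.FinishReason.SAFETY:
    # Hypothetical handling: the model refused; surface that to the user
    # rather than retrying.
    pass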