diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index e7e5337..360c18f 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -16,7 +16,7 @@ from x402.mechanisms.evm.exact.register import register_exact_evm_client from x402.mechanisms.evm.upto.register import register_upto_evm_client -from ..types import TEE_LLM, ResponseFormat, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode +from ..types import TEE_LLM, ResponseFormat, StreamChunk, TextGenerationOutput, x402SettlementMode from .opg_token import Permit2ApprovalResult, ensure_opg_approval from .tee_connection import RegistryTEEConnection, StaticTEEConnection, TEEConnectionInterface from .tee_registry import TEERegistry @@ -366,11 +366,6 @@ async def chat( if not stream: return await self._chat_request(params, messages) - # The TEE streaming endpoint omits tool call content from SSE events. - # Fall back to non-streaming and emit a single final StreamChunk. - if tools: - return self._chat_tools_as_stream(params, messages) - return self._chat_stream(params, messages) # ── Chat internals ────────────────────────────────────────────────── @@ -424,31 +419,6 @@ async def _request() -> TextGenerationOutput: except Exception as e: raise RuntimeError(f"TEE LLM chat failed: {e}") from e - async def _chat_tools_as_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: - """Non-streaming fallback for tool-call requests wrapped as a single StreamChunk.""" - result = await self._chat_request(params, messages) - chat_output = result.chat_output or {} - yield StreamChunk( - choices=[ - StreamChoice( - delta=StreamDelta( - role=chat_output.get("role"), - content=chat_output.get("content"), - tool_calls=chat_output.get("tool_calls"), - ), - index=0, - finish_reason=result.finish_reason, - ) - ], - model=params.model, - is_final=True, - tee_signature=result.tee_signature, - tee_timestamp=result.tee_timestamp, - tee_id=result.tee_id, - tee_endpoint=result.tee_endpoint, - tee_payment_address=result.tee_payment_address, - ) - async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> AsyncGenerator[StreamChunk, None]: """Async SSE streaming implementation.""" self._tee.ensure_refresh_loop() diff --git a/tests/llm_test.py b/tests/llm_test.py index 5309f28..9bb8e86 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -461,19 +461,19 @@ async def test_stream_error_raises(self, fake_http): with pytest.raises(RuntimeError, match="streaming request failed"): _ = [chunk async for chunk in gen] - async def test_tools_with_stream_falls_back_to_single_chunk(self, fake_http): - """When tools + stream=True, LLM falls back to non-streaming and yields one chunk.""" + async def test_tools_with_stream_uses_sse_chunks(self, fake_http): + """When tools + stream=True, tool call deltas are streamed through SSE.""" tools = [{"type": "function", "function": {"name": "f"}}] - fake_http.set_response( + fake_http.set_stream_response( 200, - { - "choices": [ - { - "message": {"role": "assistant", "content": None, "tool_calls": [{"id": "tc1"}]}, - "finish_reason": "tool_calls", - } - ], - }, + [ + ( + b'data: {"model":"gpt-5","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":' + b'[{"id":"tc1","index":0,"function":{"name":"f","arguments":"{\\"city\\":\\"NYC\\"}"}}]},' + b'"finish_reason":"tool_calls"}]}\n\n' + ), + b"data: [DONE]\n\n", + ], ) llm = _make_llm() @@ -487,7 +487,13 @@ async def test_tools_with_stream_falls_back_to_single_chunk(self, fake_http): assert len(chunks) == 1 assert chunks[0].is_final - assert chunks[0].choices[0].delta.tool_calls == [{"id": "tc1"}] + assert chunks[0].choices[0].delta.tool_calls == [ + { + "id": "tc1", + "index": 0, + "function": {"name": "f", "arguments": '{"city":"NYC"}'}, + } + ] assert chunks[0].choices[0].finish_reason == "tool_calls"