Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from __future__ import annotations

import asyncio
import contextlib
import copy
from functools import cached_property
Expand Down Expand Up @@ -97,19 +98,14 @@ class Gemini(BaseLlm):
directly (location, project, credentials, http_options, etc.),
subclass ``Gemini`` and override the ``api_client`` property::

from functools import cached_property
from google.adk.models import Gemini
from google.genai import Client

class GlobalGemini(Gemini):
@cached_property
def api_client(self) -> Client:
return Client(vertexai=True, location="global")

agent = Agent(model=GlobalGemini(model="gemini-3-pro-preview"))

Use ``@property`` instead of ``@cached_property`` if you hit asyncio
lock contention in multithreaded code.
"""

model: str = 'gemini-2.5-flash'
Expand Down Expand Up @@ -321,15 +317,32 @@ async def _generate_content_via_interactions(
):
yield llm_response

@cached_property
@property
def api_client(self) -> Client:
"""Provides the api client.

The client is cached per asyncio event loop so that each OS thread running
its own event loop receives a fresh ``google.genai.Client`` (and therefore
a fresh underlying HTTP session). Reusing a client whose event loop has
been closed causes ``RuntimeError: Event loop is closed`` in multi-threaded
deployments such as Vertex AI Agent Engine.

Returns:
The api client.
"""
from google.genai import Client

try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

cached_loop, cached_client = self.__dict__.get(
'_api_client_cache', (None, None)
)
if cached_client is not None and cached_loop is loop:
return cached_client

base_url, api_version = self._base_url_and_api_version
kwargs_for_http_options: dict[str, Any] = {
'headers': self._tracking_headers(),
Expand All @@ -345,7 +358,9 @@ def api_client(self) -> Client:
if self.model.startswith('projects/'):
kwargs['vertexai'] = True

return Client(**kwargs)
client = Client(**kwargs)
self.__dict__['_api_client_cache'] = (loop, client)
return client

@cached_property
def _api_backend(self) -> GoogleLLMVariant:
Expand Down Expand Up @@ -374,10 +389,21 @@ def _live_api_version(self) -> str:
# use v1alpha for using API KEY from Google AI Studio
return 'v1alpha'

@cached_property
@property
def _live_api_client(self) -> Client:
from google.genai import Client

try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

cached_loop, cached_client = self.__dict__.get(
'_live_api_client_cache', (None, None)
)
if cached_client is not None and cached_loop is loop:
return cached_client

base_url, _ = self._base_url_and_api_version

kwargs: dict[str, Any] = {
Expand All @@ -390,7 +416,9 @@ def _live_api_client(self) -> Client:
if self.model.startswith('projects/'):
kwargs['vertexai'] = True

return Client(**kwargs)
client = Client(**kwargs)
self.__dict__['_live_api_client_cache'] = (loop, client)
return client

@contextlib.asynccontextmanager
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
Expand Down
80 changes: 80 additions & 0 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import sys
import threading
from typing import Optional
from unittest import mock
from unittest.mock import AsyncMock
Expand Down Expand Up @@ -155,6 +157,84 @@ def test_supported_models():
)


def test_api_client_cached_per_event_loop():
  """Verify api_client returns a fresh client for each distinct event loop.

  Regression test for https://github.com/google/adk-python/issues/5538:
  RuntimeError: Event loop is closed caused by reusing a cached api_client
  whose underlying HTTP session was bound to a previous (now-closed) event
  loop. Each new OS thread in Vertex AI Agent Engine gets a fresh event loop,
  so the client must not be shared across loops.
  """
  model = Gemini(model='gemini-2.5-flash')

  with mock.patch('google.genai.Client', autospec=True) as mock_client:
    # Return a unique mock per construction so the identity assertions
    # below are meaningful.
    mock_client.side_effect = lambda **kw: mock.MagicMock()

    # Clients collected from each thread's private event loop.
    results: list[object] = []

    # Each thread runs its own event loop (see _run_and_collect), mirroring
    # one-loop-per-thread deployments. The threads run sequentially so
    # exactly two loops — and therefore two clients — are created
    # deterministically.
    t1 = threading.Thread(target=lambda: _run_and_collect(model, results))
    t2 = threading.Thread(target=lambda: _run_and_collect(model, results))
    t1.start()
    t1.join()
    t2.start()
    t2.join()

    # Two distinct event loops → two distinct Client instances.
    assert len(results) == 2
    assert results[0] is not results[1], (
        'api_client must not be shared across different event loops'
    )
    assert mock_client.call_count == 2


def _run_and_collect(model: Gemini, results: list) -> None:
  """Helper: run a fresh event loop in the current thread and capture api_client."""
  new_loop = asyncio.new_event_loop()
  asyncio.set_event_loop(new_loop)

  async def _fetch_client():
    # Accessing the property inside a coroutine ensures a running loop is
    # active, exercising the per-loop caching path.
    return model.api_client

  try:
    results.append(new_loop.run_until_complete(_fetch_client()))
  finally:
    new_loop.close()
    asyncio.set_event_loop(None)


def test_api_client_cached_within_same_event_loop():
  """Verify api_client is cached (not recreated) within the same event loop."""
  model = Gemini(model='gemini-2.5-flash')

  with mock.patch('google.genai.Client', autospec=True) as mock_client:
    # Each construction yields a distinct mock so identity comparison works.
    mock_client.side_effect = lambda **kw: mock.MagicMock()

    async def _read_twice():
      first = model.api_client
      second = model.api_client
      assert first is second, (
          'api_client must be cached within the same event loop'
      )

    asyncio.run(_read_twice())
    # Only one Client may have been constructed across both accesses.
    assert mock_client.call_count == 1


def test_gemini_api_client_creation_with_projects_prefix():
model = Gemini(
model="projects/test-project/locations/test-location/publishers/google/models/gemini-2.5-pro"
Expand Down
Loading