diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 1d49f50d79..1fe1046fee 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio from datetime import datetime from pathlib import Path from typing import Optional @@ -69,16 +70,16 @@ async def run_input_file( memory_service=memory_service, credential_service=credential_service, ) - with open(input_path, 'r', encoding='utf-8') as f: + with open(input_path, "r", encoding="utf-8") as f: input_file = InputFile.model_validate_json(f.read()) - input_file.state['_time'] = datetime.now().isoformat() + input_file.state["_time"] = datetime.now().isoformat() session = await session_service.create_session( app_name=app_name, user_id=user_id, state=input_file.state ) for query in input_file.queries: - click.echo(f'[user]: {query}') - content = types.Content(role='user', parts=[types.Part(text=query)]) + click.echo(f"[user]: {query}") + content = types.Content(role="user", parts=[types.Part(text=query)]) async with Aclosing( runner.run_async( user_id=session.user_id, session_id=session.id, new_message=content @@ -86,8 +87,8 @@ async def run_input_file( ) as agen: async for event in agen: if event.content and event.content.parts: - if text := ''.join(part.text or '' for part in event.content.parts): - click.echo(f'[{event.author}]: {text}') + if text := "".join(part.text or "" for part in event.content.parts): + click.echo(f"[{event.author}]: {text}") return session @@ -98,6 +99,9 @@ async def run_interactively( session_service: BaseSessionService, credential_service: BaseCredentialService, memory_service: Optional[BaseMemoryService] = None, + save_session_on_runtime: bool = False, + interval: int = 60, + agent_root: Optional[Path] = None, ) -> None: app = ( root_agent_or_app @@ -111,25 +115,74 @@ async def run_interactively( memory_service=memory_service, credential_service=credential_service, ) - while True: - query = input('[user]: ') - if not query or not query.strip(): - continue - if query == 'exit': - break - async with Aclosing( - runner.run_async( - user_id=session.user_id, - session_id=session.id, - new_message=types.Content( - role='user', parts=[types.Part(text=query)] - ), - ) - ) as agen: - async for event in agen: - if event.content and event.content.parts: - if text := ''.join(part.text or '' for part in event.content.parts): - click.echo(f'[{event.author}]: {text}') + + # Background task for periodic session saving + save_task: Optional[asyncio.Task[None]] = None + + async def _periodic_save_session() -> None: + """Periodically save the session to disk every interval seconds.""" + nonlocal session, save_task + try: + while True: + await asyncio.sleep(interval) + click.echo(f"interval : {interval}") + + if save_session_on_runtime and agent_root: + try: + current_session = await session_service.get_session( + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + ) + + if current_session: + runtime_session_path = ( + agent_root / ".adk" / "runtime_session.json" + ) + runtime_session_path.parent.mkdir(parents=True, exist_ok=True) + runtime_session_path.write_text( + current_session.model_dump_json( + indent=2, exclude_none=True, by_alias=True + ), + encoding="utf-8", + ) + except Exception: + pass + except asyncio.CancelledError: + pass + + # Start the periodic save task if enabled + if save_session_on_runtime and agent_root: + save_task = asyncio.create_task(_periodic_save_session()) + + try: + while True: + query = input("[user]: ") + if not query or not query.strip(): + continue + if query == "exit": + break + async with Aclosing( + runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=types.Content( + role="user", parts=[types.Part(text=query)] + ), + ) + ) as agen: + async for event in agen: + if event.content and event.content.parts: + if text := "".join(part.text or "" for part in event.content.parts): + click.echo(f"[{event.author}]: {text}") + finally: + if save_task: + save_task.cancel() + try: + await save_task + except asyncio.CancelledError: + pass + await runner.close() @@ -140,6 +193,8 @@ async def run_cli( input_file: Optional[str] = None, saved_session_file: Optional[str] = None, save_session: bool, + save_session_on_runtime: bool, + interval: int, session_id: Optional[str] = None, session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, @@ -167,7 +222,7 @@ async def run_cli( agent_parent_path = Path(agent_parent_dir).resolve() agent_root = agent_parent_path / agent_folder_name load_services_module(str(agent_root)) - user_id = 'test_user' + user_id = "test_user" agents_dir = str(agent_parent_path) agent_loader = AgentLoader(agents_dir=agents_dir) @@ -179,7 +234,7 @@ async def run_cli( if isinstance(agent_or_app, App) and agent_or_app.name != agent_folder_name: app_name_to_dir = {agent_or_app.name: agent_folder_name} - if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'): + if not is_env_enabled("ADK_DISABLE_LOAD_DOTENV"): envs.load_dotenv_for_agent(agent_folder_name, agents_dir) # Create session and artifact services using factory functions. @@ -211,8 +266,8 @@ def _print_event(event) -> None: text_parts = [part.text for part in content.parts if part.text] if not text_parts: return - author = event.author or 'system' - click.echo(f'[{author}]: {"".join(text_parts)}') + author = event.author or "system" + click.echo(f"[{author}]: {''.join(text_parts)}") if input_file: session = await run_input_file( @@ -227,7 +282,7 @@ def _print_event(event) -> None: ) elif saved_session_file: # Load the saved session from file - with open(saved_session_file, 'r', encoding='utf-8') as f: + with open(saved_session_file, "r", encoding="utf-8") as f: loaded_session = Session.model_validate_json(f.read()) # Create a new session in the service, copying state from the file @@ -252,10 +307,39 @@ def _print_event(event) -> None: memory_service=memory_service, ) else: - session = await session_service.create_session( - app_name=session_app_name, user_id=user_id - ) - click.echo(f'Running agent {agent_or_app.name}, type exit to exit.') + # Check for runtime saved session + runtime_session_path = agent_root / ".adk" / "runtime_session.json" + if runtime_session_path.exists(): + try: + with open(runtime_session_path, "r", encoding="utf-8") as f: + loaded_session = Session.model_validate_json(f.read()) + + # Create a new session in the service, copying state from the file + session = await session_service.create_session( + app_name=session_app_name, + user_id=user_id, + state=loaded_session.state if loaded_session else None, + ) + + # Append events from the file to the new session and display them + if loaded_session: + for event in loaded_session.events: + await session_service.append_event(session, event) + _print_event(event) + + click.echo(f"Loaded runtime session from {runtime_session_path}") + except Exception as e: + click.echo(f"Warning: Failed to load runtime session: {e}") + # Fall back to creating a new session + session = await session_service.create_session( + app_name=session_app_name, user_id=user_id + ) + else: + session = await session_service.create_session( + app_name=session_app_name, user_id=user_id + ) + + click.echo(f"Running agent {agent_or_app.name}, type exit to exit.") await run_interactively( agent_or_app, artifact_service, @@ -263,11 +347,14 @@ def _print_event(event) -> None: session_service, credential_service, memory_service=memory_service, + save_session_on_runtime=True, + interval=interval, + agent_root=agent_root, ) if save_session: - session_id = session_id or input('Session ID to save: ') - session_path = agent_root / f'{session_id}.session.json' + session_id = session_id or input("Session ID to save: ") + session_path = agent_root / f"{session_id}.session.json" # Fetch the session again to get all the details. session = await session_service.get_session( @@ -277,7 +364,7 @@ def _print_event(event) -> None: ) session_path.write_text( session.model_dump_json(indent=2, exclude_none=True, by_alias=True), - encoding='utf-8', + encoding="utf-8", ) - print('Session saved to', session_path) + print("Session saved to", session_path) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 07ccc15892..3b5db0339c 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -605,6 +605,24 @@ def wrapper(*args, **kwargs): @main.command("run", cls=HelpfulCommand) @feature_options() @adk_services_options(default_use_local_storage=True) +@click.option( + "--interval", + type=int, + default=60, + show_default=True, + help=( + "Autosave interval in seconds (only used if --save_session_on_runtime" + " is set)." + ), +) +@click.option( + "--save_session_on_runtime", + type=bool, + is_flag=True, + show_default=True, + default=False, + help="Optional. Whether to save the session to a json file on runtime.", +) @click.option( "--save_session", type=bool, @@ -653,9 +671,11 @@ def wrapper(*args, **kwargs): exists=True, dir_okay=True, file_okay=False, resolve_path=True ), ) -def cli_run( +def cli_run( # type: ignore[misc] agent: str, save_session: bool, + save_session_on_runtime: bool, + interval: int, session_id: Optional[str], replay: Optional[str], resume: Optional[str], @@ -663,7 +683,7 @@ def cli_run( artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = True, -): +) -> None: """Runs an interactive CLI for a certain agent. AGENT: The path to the agent source code folder. @@ -684,6 +704,8 @@ def cli_run( input_file=replay, saved_session_file=resume, save_session=save_session, + save_session_on_runtime=save_session_on_runtime, + interval=interval, session_id=session_id, session_service_uri=session_service_uri, artifact_service_uri=artifact_service_uri, diff --git a/tests/unittests/cli/test_save_session_on_runtime.py b/tests/unittests/cli/test_save_session_on_runtime.py new file mode 100644 index 0000000000..3e25d9fcd2 --- /dev/null +++ b/tests/unittests/cli/test_save_session_on_runtime.py @@ -0,0 +1,114 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the save_session_on_runtime feature.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from unittest import mock + +from click.testing import CliRunner +from google.adk.cli.cli import run_cli +from google.adk.cli.cli_tools_click import main +from google.adk.sessions.session import Session + + +def test_cli_run_has_save_session_on_runtime_option(): + """Test that the run command has the --save_session_on_runtime option.""" + runner = CliRunner() + result = runner.invoke(main, ["run", "--help"]) + assert "--save_session_on_runtime" in result.output + + +@mock.patch("google.adk.cli.cli_tools_click.run_cli") # patch where it's USED +def test_cli_run_passes_save_session_on_runtime_flag(mock_run_cli): + """Test that the run command passes the save_session_on_runtime flag to run_cli.""" + runner = CliRunner() + result = runner.invoke( + main, + ["run", "contributing/samples/hello_world", "--save_session_on_runtime"], + input="exit\n", + ) + + print(f"Exit code: {result.exit_code}") + print(f"Output: {result.output}") + if result.exception: + print(f"Exception: {result.exception}") + import traceback + + traceback.print_exception( + type(result.exception), result.exception, result.exception.__traceback__ + ) + + assert mock_run_cli.called, ( + f"run_cli was not called. Exit code: {result.exit_code}, Output:" + f" {result.output}" + ) + call_args = mock_run_cli.call_args + assert call_args is not None + assert call_args.kwargs.get("save_session_on_runtime") is True + + +@mock.patch("google.adk.cli.cli.create_session_service_from_options") +@mock.patch("google.adk.cli.cli.create_artifact_service_from_options") +@mock.patch("google.adk.cli.cli.create_memory_service_from_options") +@mock.patch("google.adk.cli.cli.InMemoryCredentialService") +@mock.patch("google.adk.cli.cli.AgentLoader") +@mock.patch("google.adk.cli.cli.load_services_module") +@mock.patch("google.adk.cli.cli.envs.load_dotenv_for_agent") +@mock.patch("google.adk.cli.cli.run_interactively") +def test_run_cli_saves_session_periodically( + mock_run_interactively, + mock_load_dotenv, + mock_load_services_module, + mock_agent_loader, + mock_credential_service, + mock_memory_service, + mock_artifact_service, + mock_session_service, +): + """Test that run_cli calls run_interactively with save_session_on_runtime=True when flag is set.""" + mock_session_service_instance = mock_session_service.return_value + mock_session_service_instance.create_session = mock.AsyncMock( + return_value=mock.Mock(spec=Session) + ) + mock_session_service_instance.get_session = mock.AsyncMock( + return_value=mock.Mock(spec=Session) + ) + + mock_agent_loader_instance = mock_agent_loader.return_value + mock_agent_loader_instance.load_agent = mock.Mock(return_value=mock.Mock()) + + asyncio.run( + run_cli( + agent_parent_dir="/fake/parent", + agent_folder_name="fake_agent", + save_session=False, + save_session_on_runtime=True, + interval=60, + session_id=None, + session_service_uri=None, + artifact_service_uri=None, + memory_service_uri=None, + use_local_storage=True, + ) + ) + + mock_run_interactively.assert_called_once() + call_args = mock_run_interactively.call_args + assert call_args is not None + assert call_args.kwargs.get("save_session_on_runtime") is True + assert call_args.kwargs.get("agent_root") is not None diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index f7df1bf17f..add8a22638 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -191,6 +191,8 @@ async def test_run_cli_with_input_file(fake_agent, tmp_path: Path) -> None: input_file=str(input_path), saved_session_file=None, save_session=False, + save_session_on_runtime=False, + interval=60, ) @@ -217,6 +219,8 @@ async def test_run_cli_loads_services_module( input_file=str(input_path), saved_session_file=None, save_session=False, + save_session_on_runtime=False, + interval=60, ) assert loaded_dirs == [str(agent_root.resolve())] @@ -255,6 +259,8 @@ def _session_factory(**_: Any) -> InMemorySessionService: input_file=str(input_path), saved_session_file=None, save_session=False, + save_session_on_runtime=False, + interval=60, ) assert created_app_names @@ -277,13 +283,15 @@ async def test_run_cli_save_session( if session_file.exists(): session_file.unlink() - await cli.run_cli( - agent_parent_dir=str(parent_dir), - agent_folder_name=folder_name, - input_file=None, - saved_session_file=None, - save_session=True, - ) + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=None, + saved_session_file=None, + save_session=False, + save_session_on_runtime=False, + interval=60, + ) assert session_file.exists() data = json.loads(session_file.read_text()) @@ -355,6 +363,8 @@ async def test_run_cli_accepts_memory_scheme( session_service_uri="memory://", artifact_service_uri="memory://", memory_service_uri="memory://", + save_session_on_runtime=False, + interval=60, ) @@ -388,6 +398,8 @@ def _raise_invalid_memory_uri( saved_session_file=None, save_session=False, memory_service_uri="unknown://x", + save_session_on_runtime=False, + interval=60, ) @@ -441,6 +453,8 @@ async def _run_input_file( saved_session_file=None, save_session=False, memory_service_uri="memory://", + save_session_on_runtime=False, + interval=60, ) assert Path(captured_factory_args["base_dir"]) == parent_dir.resolve() @@ -486,6 +500,8 @@ def _memory_factory( saved_session_file=None, save_session=False, memory_service_uri="memory://", + save_session_on_runtime=False, + interval=60, ) assert "create_memory" in call_order