Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 125 additions & 38 deletions src/google/adk/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import asyncio
from datetime import datetime
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -69,25 +70,25 @@ async def run_input_file(
memory_service=memory_service,
credential_service=credential_service,
)
with open(input_path, 'r', encoding='utf-8') as f:
with open(input_path, "r", encoding="utf-8") as f:
input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now().isoformat()
input_file.state["_time"] = datetime.now().isoformat()

session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries:
click.echo(f'[user]: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)])
click.echo(f"[user]: {query}")
content = types.Content(role="user", parts=[types.Part(text=query)])
async with Aclosing(
runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
)
) as agen:
async for event in agen:
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
if text := "".join(part.text or "" for part in event.content.parts):
click.echo(f"[{event.author}]: {text}")
return session


Expand All @@ -98,6 +99,9 @@ async def run_interactively(
session_service: BaseSessionService,
credential_service: BaseCredentialService,
memory_service: Optional[BaseMemoryService] = None,
save_session_on_runtime: bool = False,
interval: int = 60,
agent_root: Optional[Path] = None,
) -> None:
app = (
root_agent_or_app
Expand All @@ -111,25 +115,74 @@ async def run_interactively(
memory_service=memory_service,
credential_service=credential_service,
)
while True:
query = input('[user]: ')
if not query or not query.strip():
continue
if query == 'exit':
break
async with Aclosing(
runner.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=types.Content(
role='user', parts=[types.Part(text=query)]
),
)
) as agen:
async for event in agen:
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')

# Background task for periodic session saving
save_task: Optional[asyncio.Task[None]] = None

async def _periodic_save_session() -> None:
"""Periodically save the session to disk every interval seconds."""
nonlocal session, save_task
try:
while True:
await asyncio.sleep(interval)
click.echo(f"interval : {interval}")

if save_session_on_runtime and agent_root:
try:
current_session = await session_service.get_session(
app_name=session.app_name,
user_id=session.user_id,
session_id=session.id,
)

if current_session:
runtime_session_path = (
agent_root / ".adk" / "runtime_session.json"
)
runtime_session_path.parent.mkdir(parents=True, exist_ok=True)
runtime_session_path.write_text(
current_session.model_dump_json(
indent=2, exclude_none=True, by_alias=True
),
encoding="utf-8",
)
except Exception:
pass
except asyncio.CancelledError:
pass

# Start the periodic save task if enabled
if save_session_on_runtime and agent_root:
save_task = asyncio.create_task(_periodic_save_session())

try:
while True:
query = input("[user]: ")
if not query or not query.strip():
continue
if query == "exit":
break
async with Aclosing(
runner.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=types.Content(
role="user", parts=[types.Part(text=query)]
),
)
) as agen:
async for event in agen:
if event.content and event.content.parts:
if text := "".join(part.text or "" for part in event.content.parts):
click.echo(f"[{event.author}]: {text}")
finally:
if save_task:
save_task.cancel()
try:
await save_task
except asyncio.CancelledError:
pass

await runner.close()


Expand All @@ -140,6 +193,8 @@ async def run_cli(
input_file: Optional[str] = None,
saved_session_file: Optional[str] = None,
save_session: bool,
save_session_on_runtime: bool,
interval: int,
session_id: Optional[str] = None,
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
Expand Down Expand Up @@ -167,7 +222,7 @@ async def run_cli(
agent_parent_path = Path(agent_parent_dir).resolve()
agent_root = agent_parent_path / agent_folder_name
load_services_module(str(agent_root))
user_id = 'test_user'
user_id = "test_user"

agents_dir = str(agent_parent_path)
agent_loader = AgentLoader(agents_dir=agents_dir)
Expand All @@ -179,7 +234,7 @@ async def run_cli(
if isinstance(agent_or_app, App) and agent_or_app.name != agent_folder_name:
app_name_to_dir = {agent_or_app.name: agent_folder_name}

if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'):
if not is_env_enabled("ADK_DISABLE_LOAD_DOTENV"):
envs.load_dotenv_for_agent(agent_folder_name, agents_dir)

# Create session and artifact services using factory functions.
Expand Down Expand Up @@ -211,8 +266,8 @@ def _print_event(event) -> None:
text_parts = [part.text for part in content.parts if part.text]
if not text_parts:
return
author = event.author or 'system'
click.echo(f'[{author}]: {"".join(text_parts)}')
author = event.author or "system"
click.echo(f"[{author}]: {''.join(text_parts)}")

if input_file:
session = await run_input_file(
Expand All @@ -227,7 +282,7 @@ def _print_event(event) -> None:
)
elif saved_session_file:
# Load the saved session from file
with open(saved_session_file, 'r', encoding='utf-8') as f:
with open(saved_session_file, "r", encoding="utf-8") as f:
loaded_session = Session.model_validate_json(f.read())

# Create a new session in the service, copying state from the file
Expand All @@ -252,22 +307,54 @@ def _print_event(event) -> None:
memory_service=memory_service,
)
else:
session = await session_service.create_session(
app_name=session_app_name, user_id=user_id
)
click.echo(f'Running agent {agent_or_app.name}, type exit to exit.')
# Check for runtime saved session
runtime_session_path = agent_root / ".adk" / "runtime_session.json"
if runtime_session_path.exists():
try:
with open(runtime_session_path, "r", encoding="utf-8") as f:
loaded_session = Session.model_validate_json(f.read())

# Create a new session in the service, copying state from the file
session = await session_service.create_session(
app_name=session_app_name,
user_id=user_id,
state=loaded_session.state if loaded_session else None,
)

# Append events from the file to the new session and display them
if loaded_session:
for event in loaded_session.events:
await session_service.append_event(session, event)
_print_event(event)

click.echo(f"Loaded runtime session from {runtime_session_path}")
except Exception as e:
click.echo(f"Warning: Failed to load runtime session: {e}")
# Fall back to creating a new session
session = await session_service.create_session(
app_name=session_app_name, user_id=user_id
)
else:
session = await session_service.create_session(
app_name=session_app_name, user_id=user_id
)

click.echo(f"Running agent {agent_or_app.name}, type exit to exit.")
await run_interactively(
agent_or_app,
artifact_service,
session,
session_service,
credential_service,
memory_service=memory_service,
save_session_on_runtime=True,
interval=interval,
agent_root=agent_root,
)

if save_session:
session_id = session_id or input('Session ID to save: ')
session_path = agent_root / f'{session_id}.session.json'
session_id = session_id or input("Session ID to save: ")
session_path = agent_root / f"{session_id}.session.json"

# Fetch the session again to get all the details.
session = await session_service.get_session(
Expand All @@ -277,7 +364,7 @@ def _print_event(event) -> None:
)
session_path.write_text(
session.model_dump_json(indent=2, exclude_none=True, by_alias=True),
encoding='utf-8',
encoding="utf-8",
)

print('Session saved to', session_path)
print("Session saved to", session_path)
26 changes: 24 additions & 2 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,24 @@ def wrapper(*args, **kwargs):
@main.command("run", cls=HelpfulCommand)
@feature_options()
@adk_services_options(default_use_local_storage=True)
@click.option(
"--interval",
type=int,
default=60,
show_default=True,
help=(
"Autosave interval in seconds (only used if --save_session_on_runtime"
" is set)."
),
)
@click.option(
"--save_session_on_runtime",
type=bool,
is_flag=True,
show_default=True,
default=False,
help="Optional. Whether to save the session to a json file on runtime.",
)
@click.option(
"--save_session",
type=bool,
Expand Down Expand Up @@ -653,17 +671,19 @@ def wrapper(*args, **kwargs):
exists=True, dir_okay=True, file_okay=False, resolve_path=True
),
)
def cli_run(
def cli_run( # type: ignore[misc]
agent: str,
save_session: bool,
save_session_on_runtime: bool,
interval: int,
session_id: Optional[str],
replay: Optional[str],
resume: Optional[str],
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
use_local_storage: bool = True,
):
) -> None:
"""Runs an interactive CLI for a certain agent.

AGENT: The path to the agent source code folder.
Expand All @@ -684,6 +704,8 @@ def cli_run(
input_file=replay,
saved_session_file=resume,
save_session=save_session,
save_session_on_runtime=save_session_on_runtime,
interval=interval,
session_id=session_id,
session_service_uri=session_service_uri,
artifact_service_uri=artifact_service_uri,
Expand Down
Loading
Loading