diff --git a/pyproject.toml b/pyproject.toml index 73914157..b78058c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,7 +185,6 @@ ignore = [ "EM102", # Exception f-strings "G004", # Logging f-strings "T201", # print() used for user output - "TRY003", # Raise with inline message strings # Backwards-compatibility suppressions for existing code "A001", # Variable shadows built-in diff --git a/src/seclab_taskflow_agent/_stream.py b/src/seclab_taskflow_agent/_stream.py index 9befae4f..413e0de1 100644 --- a/src/seclab_taskflow_agent/_stream.py +++ b/src/seclab_taskflow_agent/_stream.py @@ -27,6 +27,7 @@ from .render_utils import render_model_output from .sdk import TextDelta, ToolEnd from .sdk.errors import BackendRateLimitError, BackendTimeoutError +from seclab_taskflow_agent.error_utils import error_with_message # Application-level backstop: if the backend's event stream goes silent # for this long, surface a BackendTimeoutError so the retry loop can @@ -100,9 +101,7 @@ async def drive_backend_stream( except StopAsyncIteration: break except asyncio.TimeoutError as exc: - raise BackendTimeoutError( - f"Backend stream idle for {STREAM_IDLE_TIMEOUT}s" - ) from exc + raise error_with_message(BackendTimeoutError, f"Backend stream idle for {STREAM_IDLE_TIMEOUT}s") from exc watchdog_ping() if isinstance(event, TextDelta): await render_model_output( @@ -130,7 +129,7 @@ async def drive_backend_stream( except BackendRateLimitError as exc: last_rate_limit_exc = exc if rate_limit_backoff == max_rate_limit_backoff: - raise BackendTimeoutError("Max rate limit backoff reached") from exc + raise error_with_message(BackendTimeoutError, "Max rate limit backoff reached") from exc if rate_limit_backoff > max_rate_limit_backoff: rate_limit_backoff = max_rate_limit_backoff else: @@ -139,4 +138,4 @@ async def drive_backend_stream( await asyncio.sleep(rate_limit_backoff) if last_rate_limit_exc is not None: # pragma: no cover - loop always returns/raises above - raise BackendTimeoutError("Rate limit backoff exhausted") from last_rate_limit_exc + raise error_with_message(BackendTimeoutError, "Rate limit backoff exhausted") from last_rate_limit_exc diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 394395dc..3737e0ce 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -27,6 +27,7 @@ import httpx from .capi import get_AI_endpoint, get_AI_token, get_provider +from seclab_taskflow_agent.error_utils import error_with_message __all__ = [ "DEFAULT_MODEL", @@ -173,7 +174,7 @@ def __init__( if token: resolved_token = os.getenv(token, "") if not resolved_token: - raise RuntimeError(f"Token env var {token!r} is not set") + raise error_with_message(RuntimeError, f"Token env var {token!r} is not set") else: resolved_token = get_AI_token() diff --git a/src/seclab_taskflow_agent/available_tools.py b/src/seclab_taskflow_agent/available_tools.py index 577ae067..88f05978 100644 --- a/src/seclab_taskflow_agent/available_tools.py +++ b/src/seclab_taskflow_agent/available_tools.py @@ -17,6 +17,7 @@ import yaml from pydantic import ValidationError +from seclab_taskflow_agent.error_utils import error_with_message from .models import ( DOCUMENT_MODELS, @@ -108,18 +109,14 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: # Resolve package and filename from dotted path components = toolname.rsplit(".", 1) if len(components) != 2: - raise BadToolNameError( - f'Not a valid toolname: "{toolname}". ' - f'Expected format: "packagename.filename"' - ) + raise error_with_message(BadToolNameError, f'Not a valid toolname: "{toolname}". ' + f'Expected format: "packagename.filename"') package, filename = components try: pkg_dir = importlib.resources.files(package) if not pkg_dir.is_dir(): - raise BadToolNameError( - f"Cannot load {toolname} because {pkg_dir} is not a valid directory." - ) + raise error_with_message(BadToolNameError, f"Cannot load {toolname} because {pkg_dir} is not a valid directory.") filepath = pkg_dir.joinpath(filename + ".yaml") with filepath.open() as fh: raw = yaml.safe_load(fh) @@ -128,17 +125,13 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: header = raw.get("seclab-taskflow-agent", {}) filetype = header.get("filetype", "") if filetype != tooltype.value: - raise FileTypeException( - f"Error in {filepath}: expected filetype {tooltype.value!r}, " - f"got {filetype!r}." - ) + raise error_with_message(FileTypeException, f"Error in {filepath}: expected filetype {tooltype.value!r}, " + f"got {filetype!r}.") # Parse into the appropriate Pydantic model model_cls = DOCUMENT_MODELS.get(filetype) if model_cls is None: - raise BadToolNameError( - f"Unknown filetype {filetype!r} in {toolname}" - ) + raise error_with_message(BadToolNameError, f"Unknown filetype {filetype!r} in {toolname}") try: doc = model_cls(**raw) @@ -147,9 +140,7 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: for err in exc.errors(): if "Unsupported version" in str(err.get("msg", "")): raise VersionException(str(err["msg"])) from exc - raise BadToolNameError( - f"Validation error loading {toolname}: {exc}" - ) from exc + raise error_with_message(BadToolNameError, f"Validation error loading {toolname}: {exc}") from exc # Cache and return if tooltype not in self._cache: @@ -158,10 +149,8 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel: return doc except ModuleNotFoundError as exc: - raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc + raise error_with_message(BadToolNameError, f"Cannot load {toolname}: {exc}") from exc except FileNotFoundError: - raise BadToolNameError( - f"Cannot load {toolname} because {filepath} is not a valid file." - ) + raise error_with_message(BadToolNameError, f"Cannot load {toolname} because {filepath} is not a valid file.") except ValueError as exc: - raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc + raise error_with_message(BadToolNameError, f"Cannot load {toolname}: {exc}") from exc diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index a605258f..9af219ee 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -22,6 +22,7 @@ from urllib.parse import urlparse import httpx +from seclab_taskflow_agent.error_utils import error_with_message __all__ = [ "COPILOT_INTEGRATION_ID", @@ -193,7 +194,7 @@ def get_AI_token() -> str: token = os.getenv("COPILOT_TOKEN") if token: return token - raise RuntimeError("AI_API_TOKEN environment variable is not set.") + raise error_with_message(RuntimeError, "AI_API_TOKEN environment variable is not set.") # --------------------------------------------------------------------------- diff --git a/src/seclab_taskflow_agent/cli.py b/src/seclab_taskflow_agent/cli.py index 19e7790c..33e01d5a 100644 --- a/src/seclab_taskflow_agent/cli.py +++ b/src/seclab_taskflow_agent/cli.py @@ -20,6 +20,7 @@ from typing import Annotated import typer +from seclab_taskflow_agent.error_utils import error_with_message from .available_tools import AvailableTools from .banner import get_banner @@ -37,7 +38,7 @@ def _parse_global(value: str) -> tuple[str, str]: """Parse a ``KEY=VALUE`` string into a (key, value) pair.""" if "=" not in value: - raise typer.BadParameter(f"Invalid global variable format: {value!r}. Expected KEY=VALUE.") + raise error_with_message(typer.BadParameter, f"Invalid global variable format: {value!r}. Expected KEY=VALUE.") key, _, val = value.partition("=") return key.strip(), val.strip() diff --git a/src/seclab_taskflow_agent/env_utils.py b/src/seclab_taskflow_agent/env_utils.py index 6d756962..bf64823e 100644 --- a/src/seclab_taskflow_agent/env_utils.py +++ b/src/seclab_taskflow_agent/env_utils.py @@ -7,6 +7,7 @@ from typing import Any import jinja2 +from seclab_taskflow_agent.error_utils import error_with_message __all__ = ["TmpEnv", "swap_env"] @@ -49,7 +50,7 @@ def swap_env(s: str, context: dict[str, Any] | None = None) -> str: except jinja2.UndefinedError as e: raise LookupError(str(e)) except jinja2.TemplateError as e: - raise LookupError(f"Template rendering failed for: {s!r}: {e}") + raise error_with_message(LookupError, f"Template rendering failed for: {s!r}: {e}") class TmpEnv: diff --git a/src/seclab_taskflow_agent/error_utils.py b/src/seclab_taskflow_agent/error_utils.py new file mode 100644 index 00000000..c87635b3 --- /dev/null +++ b/src/seclab_taskflow_agent/error_utils.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Helpers for constructing exceptions without inline raise messages.""" + +from __future__ import annotations + +__all__ = ["error_with_message"] + +from typing import TypeVar + +ExcT = TypeVar("ExcT", bound=BaseException) + + +def error_with_message(exc_type: type[ExcT], message: str, /) -> ExcT: + """Return *exc_type* initialised with *message*.""" + return exc_type(message) diff --git a/src/seclab_taskflow_agent/mcp_lifecycle.py b/src/seclab_taskflow_agent/mcp_lifecycle.py index 117f52a8..7d5056fa 100644 --- a/src/seclab_taskflow_agent/mcp_lifecycle.py +++ b/src/seclab_taskflow_agent/mcp_lifecycle.py @@ -23,6 +23,7 @@ MCPNamespaceWrap, mcp_client_params, ) +from seclab_taskflow_agent.error_utils import error_with_message if TYPE_CHECKING: from .available_tools import AvailableTools @@ -116,7 +117,7 @@ def _print_err(line: str) -> None: client_session_timeout_seconds=client_session_timeout, ) case _: - raise ValueError(f"Unsupported MCP transport: {params['kind']}") + raise error_with_message(ValueError, f"Unsupported MCP transport: {params['kind']}") entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc, name=tb)) diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py index af7d03d5..29875702 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/client.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/client.py @@ -16,6 +16,7 @@ import yaml from seclab_taskflow_agent.path_utils import log_file_name +from seclab_taskflow_agent.error_utils import error_with_message # this is a local fork of https://github.com/riga/jsonrpyc modified for our purposes from . import jsonrpyc @@ -194,10 +195,10 @@ def _server_request_run( template_values: dict | None = None, ): if not self.active_database: - raise RuntimeError("No Active Database") + raise error_with_message(RuntimeError, "No Active Database") if not self.active_connection: - raise RuntimeError("No Active Connection") + raise error_with_message(RuntimeError, "No Active Connection") if isinstance(quick_eval_pos, dict): # A quick eval position contains: @@ -302,7 +303,7 @@ def _format(self, query): def _resolve_query_server(self): help_msg = shell_command_to_string(self.codeql_cli + ["excute", "--help"]) if not re.search("query-server2", help_msg): - raise RuntimeError("Legacy server not supported!") + raise error_with_message(RuntimeError, "Legacy server not supported!") return "query-server2" def _resolve_library_paths(self, query_path): @@ -463,11 +464,11 @@ def _file_uri_to_path(uri): # internally the codeql client will resolve both relative and full paths # regardless of root directory differences if not uri.startswith("file:///"): - raise ValueError("URI path should be formatted as absolute") + raise error_with_message(ValueError, "URI path should be formatted as absolute") # note: don't try to parse paths like "file://a/b" because that returns "/b", should be "file:///a/b" parsed = urlparse(uri) if parsed.scheme != "file": - raise ValueError(f"Not a file:// uri: {uri}") + raise error_with_message(ValueError, f"Not a file:// uri: {uri}") path = unquote(parsed.path) region = None if ":" in path: @@ -605,7 +606,7 @@ def run_query( if target: target_pos = get_query_position(query_path, target) if not target_pos: - raise ValueError(f"Could not resolve quick eval target for {target}") + raise error_with_message(ValueError, f"Could not resolve quick eval target for {target}") try: with ( QueryServer(database, keep_alive=keep_alive, log_stderr=log_stderr) as server, @@ -633,7 +634,7 @@ def run_query( case "sarif": result = server._bqrs_to_sarif(bqrs_path, server._query_info(query_path)) case _: - raise ValueError("Unsupported output format {fmt}") + raise error_with_message(ValueError, "Unsupported output format {fmt}") except Exception as e: - raise RuntimeError(f"Error in run_query: {e}") from e + raise error_with_message(RuntimeError, f"Error in run_query: {e}") from e return result diff --git a/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py b/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py index d245666a..1e22dcd6 100644 --- a/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py +++ b/src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py @@ -12,6 +12,7 @@ from pydantic import Field from seclab_taskflow_agent.path_utils import log_file_name, mcp_data_dir +from seclab_taskflow_agent.error_utils import error_with_message from .client import _debug_log, file_from_uri, list_src_files, run_query, search_in_src_archive @@ -53,10 +54,10 @@ def _resolve_query_path(language: str, query: str) -> Path: global TEMPLATED_QUERY_PATHS if language not in TEMPLATED_QUERY_PATHS: - raise RuntimeError(f"Error: Language `{language}` not supported!") + raise error_with_message(RuntimeError, f"Error: Language `{language}` not supported!") query_path = TEMPLATED_QUERY_PATHS[language].get(query) if not query_path: - raise RuntimeError(f"Error: query `{query}` not supported for `{language}`!") + raise error_with_message(RuntimeError, f"Error: query `{query}` not supported for `{language}`!") return Path(query_path) @@ -69,7 +70,7 @@ def _resolve_db_path(relative_db_path: str | Path): absolute_path = CODEQL_DBS_BASE_PATH / relative_db_path if not absolute_path.is_dir(): _debug_log(f"Database path not found: {absolute_path}") - raise RuntimeError(f"Error: Database not found at {absolute_path}!") + raise error_with_message(RuntimeError, f"Error: Database not found at {absolute_path}!") return absolute_path diff --git a/src/seclab_taskflow_agent/mcp_transport.py b/src/seclab_taskflow_agent/mcp_transport.py index 8632fd8d..6188f2a3 100644 --- a/src/seclab_taskflow_agent/mcp_transport.py +++ b/src/seclab_taskflow_agent/mcp_transport.py @@ -34,6 +34,7 @@ from urllib.parse import urlparse from agents.mcp import MCPServerStdio +from seclab_taskflow_agent.error_utils import error_with_message # Exit codes that are considered normal termination. _EXPECTED_EXIT_CODES: frozenset[int] = frozenset({0, -signal.SIGTERM}) @@ -109,7 +110,7 @@ async def async_wait_for_connection( host = parsed.hostname port = parsed.port if host is None or port is None: - raise ValueError(f"URL must include a host and port: {self.url}") + raise error_with_message(ValueError, f"URL must include a host and port: {self.url}") deadline = asyncio.get_event_loop().time() + timeout while True: try: @@ -119,7 +120,7 @@ async def async_wait_for_connection( return except (OSError, ConnectionRefusedError): if asyncio.get_event_loop().time() > deadline: - raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds") + raise error_with_message(TimeoutError, f"Could not connect to {host}:{port} after {timeout} seconds") await asyncio.sleep(poll_interval) def wait_for_connection( @@ -139,7 +140,7 @@ def wait_for_connection( host = parsed.hostname port = parsed.port if host is None or port is None: - raise ValueError(f"URL must include a host and port: {self.url}") + raise error_with_message(ValueError, f"URL must include a host and port: {self.url}") deadline = time.time() + timeout while True: try: @@ -147,7 +148,7 @@ def wait_for_connection( return except OSError: if time.time() > deadline: - raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds") + raise error_with_message(TimeoutError, f"Could not connect to {host}:{port} after {timeout} seconds") time.sleep(poll_interval) def run(self) -> None: @@ -216,7 +217,7 @@ def join_and_raise(self, timeout: float | None = None) -> None: """ self.join(timeout) if self.is_alive(): - raise RuntimeError("Process thread did not exit within timeout.") + raise error_with_message(RuntimeError, "Process thread did not exit within timeout.") if self.exception is not None: raise self.exception diff --git a/src/seclab_taskflow_agent/mcp_utils.py b/src/seclab_taskflow_agent/mcp_utils.py index d20446db..88a904dc 100644 --- a/src/seclab_taskflow_agent/mcp_utils.py +++ b/src/seclab_taskflow_agent/mcp_utils.py @@ -28,6 +28,7 @@ from .available_tools import AvailableTools from .env_utils import swap_env +from seclab_taskflow_agent.error_utils import error_with_message # Re-export transport classes and prompt builder so that existing # ``from .mcp_utils import …`` statements continue to work. @@ -209,7 +210,7 @@ def mcp_client_params( logging.debug(f"Initializing streamable toolbox: {tb}\nargs:\n{args}\nenv:\n{env}\n") exe = shutil.which(sp.command) if exe is None: - raise FileNotFoundError(f"Could not resolve path to {sp.command}") + raise error_with_message(FileNotFoundError, f"Could not resolve path to {sp.command}") start_cmd = [exe] if args: for i, v in enumerate(args): @@ -227,7 +228,7 @@ def mcp_client_params( server_params["env"] = env case _: - raise ValueError(f"Unsupported MCP transport {kind}") + raise error_with_message(ValueError, f"Unsupported MCP transport {kind}") client_params[tb] = ( server_params, diff --git a/src/seclab_taskflow_agent/models.py b/src/seclab_taskflow_agent/models.py index eff05ee6..e1efb6fc 100644 --- a/src/seclab_taskflow_agent/models.py +++ b/src/seclab_taskflow_agent/models.py @@ -29,6 +29,7 @@ from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from seclab_taskflow_agent.error_utils import error_with_message # Valid API type values for model configuration. ApiType = Literal["chat_completions", "responses"] @@ -66,9 +67,7 @@ def _normalise_version(cls, v: Any) -> str: @classmethod def _validate_version(cls, v: str) -> str: if v != SUPPORTED_VERSION: - raise ValueError( - f"Unsupported version: {v}. Only version {SUPPORTED_VERSION} is supported." - ) + raise error_with_message(ValueError, f"Unsupported version: {v}. Only version {SUPPORTED_VERSION} is supported.") return v @@ -110,7 +109,7 @@ class TaskDefinition(BaseModel): @model_validator(mode="after") def _run_xor_prompt(self) -> TaskDefinition: if self.run and self.user_prompt: - raise ValueError("shell task ('run') and prompt task ('user_prompt') are mutually exclusive") + raise error_with_message(ValueError, "shell task ('run') and prompt task ('user_prompt') are mutually exclusive") return self diff --git a/src/seclab_taskflow_agent/runner.py b/src/seclab_taskflow_agent/runner.py index b05cc6bf..eacda470 100644 --- a/src/seclab_taskflow_agent/runner.py +++ b/src/seclab_taskflow_agent/runner.py @@ -38,6 +38,7 @@ from .models import ModelConfigDocument, PersonalityDocument, TaskDefinition from .render_utils import flush_async_output, render_model_output from .sdk import AgentSpec, MCPServerSpec, get_backend, resolve_backend_name +from seclab_taskflow_agent.error_utils import error_with_message from .sdk.errors import ( BackendBadRequestError, BackendMaxTurnsError, @@ -85,9 +86,7 @@ def _resolve_model_config( models_params: dict[str, dict[str, Any]] = m_config.model_settings or {} unknown = set(models_params) - set(model_keys) if unknown: - raise ValueError( - f"Settings section of model_config file {model_config_ref} contains models not in the model section: {unknown}" - ) + raise error_with_message(ValueError, f"Settings section of model_config file {model_config_ref} contains models not in the model section: {unknown}") return model_keys, model_dict, models_params, m_config.api_type, m_config.backend @@ -110,9 +109,9 @@ def _merge_reusable_task( """ reusable_doc = available_tools.get_taskflow(task.uses) if reusable_doc is None: - raise ValueError(f"No such reusable taskflow: {task.uses}") + raise error_with_message(ValueError, f"No such reusable taskflow: {task.uses}") if len(reusable_doc.taskflow) > 1: - raise ValueError("Reusable taskflows can only contain 1 task") + raise error_with_message(ValueError, "Reusable taskflows can only contain 1 task") parent_task = reusable_doc.taskflow[0].task merged: dict[str, Any] = parent_task.model_dump(by_alias=True, exclude_defaults=True) current: dict[str, Any] = task.model_dump(by_alias=True, exclude_defaults=True) @@ -154,7 +153,7 @@ def _resolve_task_model( task_model_settings: dict[str, Any] | Any = task.model_settings or {} if not isinstance(task_model_settings, dict): - raise ValueError(f"model_settings in task {task.name or ''} needs to be a dictionary") + raise error_with_message(ValueError, f"model_settings in task {task.name or ''} needs to be a dictionary") # Task-level overrides can also set engine keys task_settings = dict(task_model_settings) @@ -205,14 +204,14 @@ async def _build_prompts_to_run( raise except json.JSONDecodeError as exc: logging.critical(f"Could not parse tool result as JSON: {last_mcp_tool_results[-1][:200]}") - raise ValueError("Tool result is not valid JSON") from exc + raise error_with_message(ValueError, "Tool result is not valid JSON") from exc text = last_result.get("text", "") try: iterable_result = json.loads(text) except json.JSONDecodeError as exc: logging.critical(f"Could not parse result text: {text}") - raise ValueError("Result text is not valid JSON") from exc + raise error_with_message(ValueError, "Result text is not valid JSON") from exc try: iter(iterable_result) except TypeError: @@ -235,7 +234,7 @@ async def _build_prompts_to_run( prompts_to_run.append(rendered_prompt) except jinja2.TemplateError as e: logging.error(f"Error rendering template for result {value}: {e}") - raise ValueError(f"Template rendering failed: {e}") + raise error_with_message(ValueError, f"Template rendering failed: {e}") # Consume only after all prompts rendered successfully so that # the result remains available for retry/resume on failure. @@ -605,7 +604,7 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo inputs = task.inputs or {} task_prompt = task.user_prompt or "" if run and task_prompt: - raise ValueError("shell task and prompt task are mutually exclusive!") + raise error_with_message(ValueError, "shell task and prompt task are mutually exclusive!") must_complete = task.must_complete max_turns = task.max_steps or DEFAULT_MAX_TURNS toolboxes_override = task.toolboxes or [] @@ -626,7 +625,7 @@ async def on_handoff_hook(context: RunContextWrapper[TContext], agent: Agent[TCo ) except jinja2.TemplateError as e: logging.error(f"Template rendering error: {e}") - raise ValueError(f"Failed to render prompt template: {e}") from e + raise error_with_message(ValueError, f"Failed to render prompt template: {e}") from e with TmpEnv(env, context={"globals": global_variables}): prompts_to_run: list[str] = await _build_prompts_to_run( @@ -660,14 +659,12 @@ async def run_prompts(async_task: bool = False, max_concurrent_tasks: int = 5) - for agent_name in current_agents: personality = available_tools.get_personality(agent_name) if personality is None: - raise ValueError(f"No such personality: {agent_name}") + raise error_with_message(ValueError, f"No such personality: {agent_name}") resolved_agents[agent_name] = personality if not resolved_agents: - raise ValueError( - "No agents resolved for this task. " - "Specify a personality with -p or provide an agents list." - ) + raise error_with_message(ValueError, "No agents resolved for this task. " + "Specify a personality with -p or provide an agents list.") async def _deploy(ra: dict, pp: str) -> bool: async with semaphore: diff --git a/src/seclab_taskflow_agent/sdk/__init__.py b/src/seclab_taskflow_agent/sdk/__init__.py index 15086922..cfc4e74f 100644 --- a/src/seclab_taskflow_agent/sdk/__init__.py +++ b/src/seclab_taskflow_agent/sdk/__init__.py @@ -31,6 +31,7 @@ TextDelta, ToolEnd, ) +from seclab_taskflow_agent.error_utils import error_with_message _ENV_VAR = "SECLAB_TASKFLOW_BACKEND" _KNOWN = ("openai_agents", "copilot_sdk") @@ -40,7 +41,7 @@ def get_backend(name: str) -> AgentBackend: """Return the backend adapter instance for *name*, importing it lazily.""" if name not in _KNOWN: - raise ValueError(f"Unknown backend {name!r}. Known: {_KNOWN}") + raise error_with_message(ValueError, f"Unknown backend {name!r}. Known: {_KNOWN}") if name not in _BACKENDS: if name == "openai_agents": from .openai_agents.backend import OpenAIAgentsBackend @@ -73,5 +74,5 @@ def resolve_backend_name( del endpoint # reserved for forward compat; not used for selection name = explicit or os.getenv(_ENV_VAR) or "openai_agents" if name not in _KNOWN: - raise ValueError(f"Unknown backend {name!r}. Known: {_KNOWN}") + raise error_with_message(ValueError, f"Unknown backend {name!r}. Known: {_KNOWN}") return name diff --git a/src/seclab_taskflow_agent/sdk/copilot_sdk/backend.py b/src/seclab_taskflow_agent/sdk/copilot_sdk/backend.py index ec08294b..bc8922dc 100644 --- a/src/seclab_taskflow_agent/sdk/copilot_sdk/backend.py +++ b/src/seclab_taskflow_agent/sdk/copilot_sdk/backend.py @@ -28,6 +28,7 @@ from ..errors import BackendBadRequestError, BackendCapabilityError, BackendUnexpectedError from .mcp import build_mcp_config from .permissions import build_permission_handler +from seclab_taskflow_agent.error_utils import error_with_message _VALID_REASONING = ("low", "medium", "high", "xhigh") @@ -66,9 +67,7 @@ def _normalize_model(model: str) -> str: about the model under test. """ if not model: - raise BackendBadRequestError( - "copilot_sdk: model is required (the SDK would otherwise pick a default)" - ) + raise error_with_message(BackendBadRequestError, "copilot_sdk: model is required (the SDK would otherwise pick a default)") return model.split("/", 1)[1] if "/" in model else model @@ -96,10 +95,8 @@ def _reasoning_effort(model_settings: dict[str, Any]) -> str | None: if raw is None: return None if raw not in _VALID_REASONING: - raise BackendBadRequestError( - f"copilot_sdk: invalid reasoning_effort {raw!r} " - f"(expected one of {_VALID_REASONING})" - ) + raise error_with_message(BackendBadRequestError, f"copilot_sdk: invalid reasoning_effort {raw!r} " + f"(expected one of {_VALID_REASONING})") return raw @@ -132,14 +129,10 @@ def validate(self, spec: AgentSpec) -> None: wire protocol per model. """ if spec.handoffs or spec.in_handoff_graph: - raise BackendCapabilityError( - "copilot_sdk: agent handoffs are not supported" - ) + raise error_with_message(BackendCapabilityError, "copilot_sdk: agent handoffs are not supported") for unsupported in ("temperature", "parallel_tool_calls"): if unsupported in spec.model_settings: - raise BackendCapabilityError( - f"copilot_sdk: model_settings.{unsupported} is not supported" - ) + raise error_with_message(BackendCapabilityError, f"copilot_sdk: model_settings.{unsupported} is not supported") async def build( self, diff --git a/src/seclab_taskflow_agent/session.py b/src/seclab_taskflow_agent/session.py index 9b771511..4b364acd 100644 --- a/src/seclab_taskflow_agent/session.py +++ b/src/seclab_taskflow_agent/session.py @@ -24,6 +24,7 @@ from pydantic import BaseModel, Field from .path_utils import _data_dir +from seclab_taskflow_agent.error_utils import error_with_message def session_dir() -> Path: @@ -121,7 +122,7 @@ def load(cls, session_id: str) -> TaskflowSession: """ path = session_dir() / f"{session_id}.json" if not path.exists(): - raise FileNotFoundError(f"No session checkpoint found: {session_id}") + raise error_with_message(FileNotFoundError, f"No session checkpoint found: {session_id}") return cls.model_validate_json(path.read_text()) @classmethod diff --git a/src/seclab_taskflow_agent/shell_utils.py b/src/seclab_taskflow_agent/shell_utils.py index 75175eca..a4b1b80a 100644 --- a/src/seclab_taskflow_agent/shell_utils.py +++ b/src/seclab_taskflow_agent/shell_utils.py @@ -8,6 +8,7 @@ import tempfile from mcp.types import CallToolResult, TextContent +from seclab_taskflow_agent.error_utils import error_with_message __all__ = ["shell_command_to_string", "shell_exec_with_temporary_file", "shell_tool_call"] @@ -23,7 +24,7 @@ def shell_command_to_string(cmd: list[str]) -> str: stdout, stderr = p.communicate() p.wait() if p.returncode: - raise RuntimeError(f"Command {cmd} failed: {stderr}") + raise error_with_message(RuntimeError, f"Command {cmd} failed: {stderr}") return stdout diff --git a/src/seclab_taskflow_agent/template_utils.py b/src/seclab_taskflow_agent/template_utils.py index 2f21d4a6..503e43f6 100644 --- a/src/seclab_taskflow_agent/template_utils.py +++ b/src/seclab_taskflow_agent/template_utils.py @@ -14,6 +14,7 @@ from .available_tools import AvailableTools from .available_tools import BadToolNameError +from seclab_taskflow_agent.error_utils import error_with_message class PromptLoader(jinja2.BaseLoader): @@ -77,7 +78,7 @@ def env_function(var_name: str, default: Optional[str] = None, required: bool = """ value = os.getenv(var_name, default) if value is None and required: - raise LookupError(f"Required environment variable {var_name} not found!") + raise error_with_message(LookupError, f"Required environment variable {var_name} not found!") return value or "" diff --git a/tests/test_error_utils.py b/tests/test_error_utils.py new file mode 100644 index 00000000..e8356f6a --- /dev/null +++ b/tests/test_error_utils.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: GitHub, Inc. +# SPDX-License-Identifier: MIT + +"""Tests for exception-construction helpers.""" + +from seclab_taskflow_agent.error_utils import error_with_message +from seclab_taskflow_agent.sdk.errors import BackendTimeoutError + + +def test_error_with_message_preserves_type_and_message(): + exc = error_with_message(BackendTimeoutError, "timed out") + + assert isinstance(exc, BackendTimeoutError) + assert str(exc) == "timed out"