Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ ignore = [
"EM102", # Exception f-strings
"G004", # Logging f-strings
"T201", # print() used for user output
"TRY003", # Raise with inline message strings

# Backwards-compatibility suppressions for existing code
"A001", # Variable shadows built-in
Expand Down
9 changes: 4 additions & 5 deletions src/seclab_taskflow_agent/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .render_utils import render_model_output
from .sdk import TextDelta, ToolEnd
from .sdk.errors import BackendRateLimitError, BackendTimeoutError
from seclab_taskflow_agent.error_utils import error_with_message

# Application-level backstop: if the backend's event stream goes silent
# for this long, surface a BackendTimeoutError so the retry loop can
Expand Down Expand Up @@ -100,9 +101,7 @@ async def drive_backend_stream(
except StopAsyncIteration:
break
except asyncio.TimeoutError as exc:
raise BackendTimeoutError(
f"Backend stream idle for {STREAM_IDLE_TIMEOUT}s"
) from exc
raise error_with_message(BackendTimeoutError, f"Backend stream idle for {STREAM_IDLE_TIMEOUT}s") from exc
watchdog_ping()
if isinstance(event, TextDelta):
await render_model_output(
Expand Down Expand Up @@ -130,7 +129,7 @@ async def drive_backend_stream(
except BackendRateLimitError as exc:
last_rate_limit_exc = exc
if rate_limit_backoff == max_rate_limit_backoff:
raise BackendTimeoutError("Max rate limit backoff reached") from exc
raise error_with_message(BackendTimeoutError, "Max rate limit backoff reached") from exc
if rate_limit_backoff > max_rate_limit_backoff:
rate_limit_backoff = max_rate_limit_backoff
else:
Expand All @@ -139,4 +138,4 @@ async def drive_backend_stream(
await asyncio.sleep(rate_limit_backoff)

if last_rate_limit_exc is not None: # pragma: no cover - loop always returns/raises above
raise BackendTimeoutError("Rate limit backoff exhausted") from last_rate_limit_exc
raise error_with_message(BackendTimeoutError, "Rate limit backoff exhausted") from last_rate_limit_exc
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import httpx

from .capi import get_AI_endpoint, get_AI_token, get_provider
from seclab_taskflow_agent.error_utils import error_with_message

__all__ = [
"DEFAULT_MODEL",
Expand Down Expand Up @@ -173,7 +174,7 @@ def __init__(
if token:
resolved_token = os.getenv(token, "")
if not resolved_token:
raise RuntimeError(f"Token env var {token!r} is not set")
raise error_with_message(RuntimeError, f"Token env var {token!r} is not set")
else:
resolved_token = get_AI_token()

Expand Down
33 changes: 11 additions & 22 deletions src/seclab_taskflow_agent/available_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import yaml
from pydantic import ValidationError
from seclab_taskflow_agent.error_utils import error_with_message

from .models import (
DOCUMENT_MODELS,
Expand Down Expand Up @@ -108,18 +109,14 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel:
# Resolve package and filename from dotted path
components = toolname.rsplit(".", 1)
if len(components) != 2:
raise BadToolNameError(
f'Not a valid toolname: "{toolname}". '
f'Expected format: "packagename.filename"'
)
raise error_with_message(BadToolNameError, f'Not a valid toolname: "{toolname}". '
f'Expected format: "packagename.filename"')
package, filename = components

try:
pkg_dir = importlib.resources.files(package)
if not pkg_dir.is_dir():
raise BadToolNameError(
f"Cannot load {toolname} because {pkg_dir} is not a valid directory."
)
raise error_with_message(BadToolNameError, f"Cannot load {toolname} because {pkg_dir} is not a valid directory.")
filepath = pkg_dir.joinpath(filename + ".yaml")
with filepath.open() as fh:
raw = yaml.safe_load(fh)
Expand All @@ -128,17 +125,13 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel:
header = raw.get("seclab-taskflow-agent", {})
filetype = header.get("filetype", "")
if filetype != tooltype.value:
raise FileTypeException(
f"Error in {filepath}: expected filetype {tooltype.value!r}, "
f"got {filetype!r}."
)
raise error_with_message(FileTypeException, f"Error in {filepath}: expected filetype {tooltype.value!r}, "
f"got {filetype!r}.")

# Parse into the appropriate Pydantic model
model_cls = DOCUMENT_MODELS.get(filetype)
if model_cls is None:
raise BadToolNameError(
f"Unknown filetype {filetype!r} in {toolname}"
)
raise error_with_message(BadToolNameError, f"Unknown filetype {filetype!r} in {toolname}")

try:
doc = model_cls(**raw)
Expand All @@ -147,9 +140,7 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel:
for err in exc.errors():
if "Unsupported version" in str(err.get("msg", "")):
raise VersionException(str(err["msg"])) from exc
raise BadToolNameError(
f"Validation error loading {toolname}: {exc}"
) from exc
raise error_with_message(BadToolNameError, f"Validation error loading {toolname}: {exc}") from exc

# Cache and return
if tooltype not in self._cache:
Expand All @@ -158,10 +149,8 @@ def _load(self, tooltype: AvailableToolType, toolname: str) -> DocumentModel:
return doc

except ModuleNotFoundError as exc:
raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc
raise error_with_message(BadToolNameError, f"Cannot load {toolname}: {exc}") from exc
except FileNotFoundError:
raise BadToolNameError(
f"Cannot load {toolname} because {filepath} is not a valid file."
)
raise error_with_message(BadToolNameError, f"Cannot load {toolname} because {filepath} is not a valid file.")
except ValueError as exc:
raise BadToolNameError(f"Cannot load {toolname}: {exc}") from exc
raise error_with_message(BadToolNameError, f"Cannot load {toolname}: {exc}") from exc
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from urllib.parse import urlparse

import httpx
from seclab_taskflow_agent.error_utils import error_with_message

__all__ = [
"COPILOT_INTEGRATION_ID",
Expand Down Expand Up @@ -193,7 +194,7 @@ def get_AI_token() -> str:
token = os.getenv("COPILOT_TOKEN")
if token:
return token
raise RuntimeError("AI_API_TOKEN environment variable is not set.")
raise error_with_message(RuntimeError, "AI_API_TOKEN environment variable is not set.")


# ---------------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Annotated

import typer
from seclab_taskflow_agent.error_utils import error_with_message

from .available_tools import AvailableTools
from .banner import get_banner
Expand All @@ -37,7 +38,7 @@
def _parse_global(value: str) -> tuple[str, str]:
"""Parse a ``KEY=VALUE`` string into a (key, value) pair."""
if "=" not in value:
raise typer.BadParameter(f"Invalid global variable format: {value!r}. Expected KEY=VALUE.")
raise error_with_message(typer.BadParameter, f"Invalid global variable format: {value!r}. Expected KEY=VALUE.")
key, _, val = value.partition("=")
return key.strip(), val.strip()

Expand Down
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any

import jinja2
from seclab_taskflow_agent.error_utils import error_with_message

__all__ = ["TmpEnv", "swap_env"]

Expand Down Expand Up @@ -49,7 +50,7 @@ def swap_env(s: str, context: dict[str, Any] | None = None) -> str:
except jinja2.UndefinedError as e:
raise LookupError(str(e))
except jinja2.TemplateError as e:
raise LookupError(f"Template rendering failed for: {s!r}: {e}")
raise error_with_message(LookupError, f"Template rendering failed for: {s!r}: {e}")


class TmpEnv:
Expand Down
17 changes: 17 additions & 0 deletions src/seclab_taskflow_agent/error_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: GitHub, Inc.
# SPDX-License-Identifier: MIT

"""Helpers for constructing exceptions without inline raise messages."""

from __future__ import annotations

__all__ = ["error_with_message"]

from typing import TypeVar

ExcT = TypeVar("ExcT", bound=BaseException)


def error_with_message(exc_type: type[ExcT], message: str, /) -> ExcT:
"""Return *exc_type* initialised with *message*."""
return exc_type(message)
3 changes: 2 additions & 1 deletion src/seclab_taskflow_agent/mcp_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
MCPNamespaceWrap,
mcp_client_params,
)
from seclab_taskflow_agent.error_utils import error_with_message

if TYPE_CHECKING:
from .available_tools import AvailableTools
Expand Down Expand Up @@ -116,7 +117,7 @@ def _print_err(line: str) -> None:
client_session_timeout_seconds=client_session_timeout,
)
case _:
raise ValueError(f"Unsupported MCP transport: {params['kind']}")
raise error_with_message(ValueError, f"Unsupported MCP transport: {params['kind']}")

entries.append(MCPServerEntry(MCPNamespaceWrap(confirms, mcp_server), server_proc, name=tb))

Expand Down
17 changes: 9 additions & 8 deletions src/seclab_taskflow_agent/mcp_servers/codeql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import yaml

from seclab_taskflow_agent.path_utils import log_file_name
from seclab_taskflow_agent.error_utils import error_with_message

# this is a local fork of https://github.com/riga/jsonrpyc modified for our purposes
from . import jsonrpyc
Expand Down Expand Up @@ -194,10 +195,10 @@ def _server_request_run(
template_values: dict | None = None,
):
if not self.active_database:
raise RuntimeError("No Active Database")
raise error_with_message(RuntimeError, "No Active Database")

if not self.active_connection:
raise RuntimeError("No Active Connection")
raise error_with_message(RuntimeError, "No Active Connection")

if isinstance(quick_eval_pos, dict):
# A quick eval position contains:
Expand Down Expand Up @@ -302,7 +303,7 @@ def _format(self, query):
def _resolve_query_server(self):
help_msg = shell_command_to_string(self.codeql_cli + ["excute", "--help"])
if not re.search("query-server2", help_msg):
raise RuntimeError("Legacy server not supported!")
raise error_with_message(RuntimeError, "Legacy server not supported!")
return "query-server2"

def _resolve_library_paths(self, query_path):
Expand Down Expand Up @@ -463,11 +464,11 @@ def _file_uri_to_path(uri):
# internally the codeql client will resolve both relative and full paths
# regardless of root directory differences
if not uri.startswith("file:///"):
raise ValueError("URI path should be formatted as absolute")
raise error_with_message(ValueError, "URI path should be formatted as absolute")
# note: don't try to parse paths like "file://a/b" because that returns "/b", should be "file:///a/b"
parsed = urlparse(uri)
if parsed.scheme != "file":
raise ValueError(f"Not a file:// uri: {uri}")
raise error_with_message(ValueError, f"Not a file:// uri: {uri}")
path = unquote(parsed.path)
region = None
if ":" in path:
Expand Down Expand Up @@ -605,7 +606,7 @@ def run_query(
if target:
target_pos = get_query_position(query_path, target)
if not target_pos:
raise ValueError(f"Could not resolve quick eval target for {target}")
raise error_with_message(ValueError, f"Could not resolve quick eval target for {target}")
try:
with (
QueryServer(database, keep_alive=keep_alive, log_stderr=log_stderr) as server,
Expand Down Expand Up @@ -633,7 +634,7 @@ def run_query(
case "sarif":
result = server._bqrs_to_sarif(bqrs_path, server._query_info(query_path))
case _:
raise ValueError("Unsupported output format {fmt}")
raise error_with_message(ValueError, "Unsupported output format {fmt}")
except Exception as e:
raise RuntimeError(f"Error in run_query: {e}") from e
raise error_with_message(RuntimeError, f"Error in run_query: {e}") from e
return result
7 changes: 4 additions & 3 deletions src/seclab_taskflow_agent/mcp_servers/codeql/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pydantic import Field

from seclab_taskflow_agent.path_utils import log_file_name, mcp_data_dir
from seclab_taskflow_agent.error_utils import error_with_message

from .client import _debug_log, file_from_uri, list_src_files, run_query, search_in_src_archive

Expand Down Expand Up @@ -53,10 +54,10 @@
def _resolve_query_path(language: str, query: str) -> Path:
global TEMPLATED_QUERY_PATHS
if language not in TEMPLATED_QUERY_PATHS:
raise RuntimeError(f"Error: Language `{language}` not supported!")
raise error_with_message(RuntimeError, f"Error: Language `{language}` not supported!")
query_path = TEMPLATED_QUERY_PATHS[language].get(query)
if not query_path:
raise RuntimeError(f"Error: query `{query}` not supported for `{language}`!")
raise error_with_message(RuntimeError, f"Error: query `{query}` not supported for `{language}`!")
return Path(query_path)


Expand All @@ -69,7 +70,7 @@ def _resolve_db_path(relative_db_path: str | Path):
absolute_path = CODEQL_DBS_BASE_PATH / relative_db_path
if not absolute_path.is_dir():
_debug_log(f"Database path not found: {absolute_path}")
raise RuntimeError(f"Error: Database not found at {absolute_path}!")
raise error_with_message(RuntimeError, f"Error: Database not found at {absolute_path}!")
return absolute_path


Expand Down
11 changes: 6 additions & 5 deletions src/seclab_taskflow_agent/mcp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from urllib.parse import urlparse

from agents.mcp import MCPServerStdio
from seclab_taskflow_agent.error_utils import error_with_message

# Exit codes that are considered normal termination.
_EXPECTED_EXIT_CODES: frozenset[int] = frozenset({0, -signal.SIGTERM})
Expand Down Expand Up @@ -109,7 +110,7 @@ async def async_wait_for_connection(
host = parsed.hostname
port = parsed.port
if host is None or port is None:
raise ValueError(f"URL must include a host and port: {self.url}")
raise error_with_message(ValueError, f"URL must include a host and port: {self.url}")
deadline = asyncio.get_event_loop().time() + timeout
while True:
try:
Expand All @@ -119,7 +120,7 @@ async def async_wait_for_connection(
return
except (OSError, ConnectionRefusedError):
if asyncio.get_event_loop().time() > deadline:
raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds")
raise error_with_message(TimeoutError, f"Could not connect to {host}:{port} after {timeout} seconds")
await asyncio.sleep(poll_interval)

def wait_for_connection(
Expand All @@ -139,15 +140,15 @@ def wait_for_connection(
host = parsed.hostname
port = parsed.port
if host is None or port is None:
raise ValueError(f"URL must include a host and port: {self.url}")
raise error_with_message(ValueError, f"URL must include a host and port: {self.url}")
deadline = time.time() + timeout
while True:
try:
with socket.create_connection((host, port), timeout=2):
return
except OSError:
if time.time() > deadline:
raise TimeoutError(f"Could not connect to {host}:{port} after {timeout} seconds")
raise error_with_message(TimeoutError, f"Could not connect to {host}:{port} after {timeout} seconds")
time.sleep(poll_interval)

def run(self) -> None:
Expand Down Expand Up @@ -216,7 +217,7 @@ def join_and_raise(self, timeout: float | None = None) -> None:
"""
self.join(timeout)
if self.is_alive():
raise RuntimeError("Process thread did not exit within timeout.")
raise error_with_message(RuntimeError, "Process thread did not exit within timeout.")
if self.exception is not None:
raise self.exception

Expand Down
5 changes: 3 additions & 2 deletions src/seclab_taskflow_agent/mcp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from .available_tools import AvailableTools
from .env_utils import swap_env
from seclab_taskflow_agent.error_utils import error_with_message

# Re-export transport classes and prompt builder so that existing
# ``from .mcp_utils import …`` statements continue to work.
Expand Down Expand Up @@ -209,7 +210,7 @@ def mcp_client_params(
logging.debug(f"Initializing streamable toolbox: {tb}\nargs:\n{args}\nenv:\n{env}\n")
exe = shutil.which(sp.command)
if exe is None:
raise FileNotFoundError(f"Could not resolve path to {sp.command}")
raise error_with_message(FileNotFoundError, f"Could not resolve path to {sp.command}")
start_cmd = [exe]
if args:
for i, v in enumerate(args):
Expand All @@ -227,7 +228,7 @@ def mcp_client_params(
server_params["env"] = env

case _:
raise ValueError(f"Unsupported MCP transport {kind}")
raise error_with_message(ValueError, f"Unsupported MCP transport {kind}")

client_params[tb] = (
server_params,
Expand Down
Loading