From ed25b8bf54c3ed8715dc279345ac361dcecd1def Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Tue, 21 Apr 2026 13:31:18 -0400 Subject: [PATCH 1/4] rework command framework Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com> --- csp_bot/__init__.py | 15 +- csp_bot/bot.py | 99 +++- csp_bot/commands/__init__.py | 25 +- csp_bot/commands/context.py | 174 ++++++ csp_bot/commands/executor.py | 212 +++++++ csp_bot/commands/framework.py | 163 ++++++ csp_bot/commands/legacy.py | 82 +++ csp_bot/gateway/gateway.py | 12 +- csp_bot/tests/test_bot_integration.py | 108 ++++ csp_bot/tests/test_command_framework.py | 724 ++++++++++++++++++++++++ 10 files changed, 1593 insertions(+), 21 deletions(-) create mode 100644 csp_bot/commands/context.py create mode 100644 csp_bot/commands/executor.py create mode 100644 csp_bot/commands/framework.py create mode 100644 csp_bot/commands/legacy.py create mode 100644 csp_bot/tests/test_command_framework.py diff --git a/csp_bot/__init__.py b/csp_bot/__init__.py index ef8d02e..f17659d 100644 --- a/csp_bot/__init__.py +++ b/csp_bot/__init__.py @@ -21,8 +21,13 @@ from .commands import ( BaseCommand, BaseCommandModel, + BotInfo, + Command, + CommandContext, + CommandModel, EchoCommand, HelpCommand, + LegacyCommandAdapter, NoResponseCommand, ReplyCommand, ReplyToAllCommand, @@ -30,6 +35,7 @@ ReplyToOtherCommand, ScheduleCommand, StatusCommand, + command, mention_user, ) from .gateway import CspBotGateway, Gateway, GatewayChannels, GatewayModule, GatewaySettings @@ -53,7 +59,14 @@ "DiscordConfig", "SlackConfig", "SymphonyConfig", - # Commands + # Commands — new framework + "Command", + "CommandContext", + "CommandModel", + "BotInfo", + "LegacyCommandAdapter", + "command", + # Commands — legacy "BaseCommand", "BaseCommandModel", "EchoCommand", diff --git a/csp_bot/bot.py b/csp_bot/bot.py index c641897..3333987 100644 --- a/csp_bot/bot.py +++ b/csp_bot/bot.py @@ -38,16 +38,21 @@ from .bot_config import BotConfig from .commands import ( BaseCommand, - BaseCommandModel, + BotInfo, + Command, + CommandContext, HelpCommand, ScheduleCommand, StatusCommand, + execute_command_func, + get_registered_commands, ) from .gateway import GatewayChannels, GatewayModule from .structs import ( Backend, BotCommand, BotMessage, + CommandVariant, ) log = getLogger(__name__) @@ -70,8 +75,8 @@ class Bot(GatewayModule): config: BotConfig - _command_models: List[BaseCommandModel] = PrivateAttr(default_factory=list) - _commands: Dict[str, BaseCommand] = PrivateAttr(default_factory=dict) + _command_models: List[Any] = PrivateAttr(default_factory=list) + _commands: Dict[str, Any] = PrivateAttr(default_factory=dict) _configs: Dict[Backend, Any] = PrivateAttr(default_factory=dict) _adapters: Dict[Backend, Any] = PrivateAttr(default_factory=dict) _connected_backends: Dict[Backend, Tuple[Any, asyncio.AbstractEventLoop]] = PrivateAttr(default_factory=dict) @@ -79,9 +84,14 @@ class Bot(GatewayModule): _authorized_users: Dict[Backend, Set[str]] = PrivateAttr(default_factory=dict) _bot_user_ids: Dict[Backend, str] = PrivateAttr(default_factory=dict) _bot_names: Dict[Backend, str] = PrivateAttr(default_factory=dict) + _deps: Any = PrivateAttr(default=None) _thread: Optional[threading.Thread] = PrivateAttr(None) _lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + def set_deps(self, deps: Any) -> None: + """Set shared dependency object for new command framework contexts.""" + self._deps = deps + def connect(self, channels: GatewayChannels) -> None: """Connect to configured backends and set up message processing. @@ -318,8 +328,11 @@ async def _fetch() -> Optional[Channel]: log.exception(f"Error resolving channel: {channel_identifier}") return None - def load_commands(self, command_models: List[BaseCommandModel]) -> None: - """Load command handlers from command models.""" + def load_commands(self, command_models: List[Any]) -> None: + """Load command handlers from command models and decorator registry. + + Supports both legacy BaseCommandModel and the new CommandModel. + """ log.info(f"Loading {len(command_models)} commands...") for model in command_models: try: @@ -328,14 +341,59 @@ def load_commands(self, command_models: List[BaseCommandModel]) -> None: log.critical(f"Incomplete command type - implement all abstract methods: {model.command}") raise e - command_str = command.command() + if isinstance(command, BaseCommand): + command_str = command.command() + runner: Any = command + elif isinstance(command, Command): + command_str = command.name + runner = command + else: + raise TypeError(f"Unsupported command type from model {type(model).__name__}: {type(command).__name__}") + log.info(f"Registered command: /{command_str}") if command_str in self._commands: raise Exception(f"Command already registered: {command_str}\n\t{command}\n\t{self._commands[command_str]}") - self._commands[command_str] = command + self._commands[command_str] = runner self._command_models.append(model) + # Decorator-registered commands are available globally and can be mixed + # with model-based commands. Explicit model definitions win on conflicts. + for command_name, entry in get_registered_commands().items(): + if command_name in self._commands: + continue + log.info(f"Registered decorated command: /{command_name}") + self._commands[command_name] = entry + + def _command_backends(self, command_runner: Any) -> List[str]: + """Return supported backends for either legacy or new command types.""" + if isinstance(command_runner, BaseCommand): + return command_runner.backends() + if isinstance(command_runner, Command): + return command_runner.backends + return list(getattr(command_runner, "backends", []) or []) + + def _build_command_context(self, cmd: BotCommand) -> CommandContext: + """Build CommandContext from legacy BotCommand for new framework execution.""" + bot_info = BotInfo( + id=self._get_bot_id(cmd.backend) or "", + name=self._get_bot_name(cmd.backend) or "", + version="", + ) + channel = Channel(id=cmd.channel_id, name=cmd.channel_name) + return CommandContext( + command_name=cmd.command, + source=cmd.source, + targets=list(cmd.targets), + channel=channel, + message=cmd.message, + args=list(cmd.args), + args_text=" ".join(cmd.args), + backend=cmd.backend, + bot=bot_info, + deps=self._deps, + ) + # ========================================================================= # Message Processing Nodes # ========================================================================= @@ -733,7 +791,8 @@ def _extract_commands( command_runner = self._commands[command_name] # Check backend support - if command_runner.backends() and backend not in command_runner.backends(): + command_backends = self._command_backends(command_runner) + if command_backends and backend not in command_backends: log.warning(f"Command {command_name} not supported on {backend}") return None @@ -764,7 +823,7 @@ def _extract_commands( channel_id=target_channel, channel_name=channel_name, backend=backend, - variant=command_runner.kind(), + variant=command_runner.kind() if isinstance(command_runner, BaseCommand) else CommandVariant.REPLY, message=msg, delay=None, schedule="", @@ -874,7 +933,7 @@ def _create_help_command(self, msg: Message, backend: str, channel_id: str) -> B channel_id=channel_id, channel_name=channel_name, backend=backend, - variant=command_runner.kind(), + variant=command_runner.kind() if isinstance(command_runner, BaseCommand) else CommandVariant.REPLY, message=msg, delay=None, schedule="", @@ -888,12 +947,22 @@ def _execute_command(self, cmd: BotCommand) -> Optional[Union[Message, List[Mess return None try: - if isinstance(command_runner, HelpCommand): - responses = command_runner.execute(cmd, MappingProxyType(self._commands)) - elif isinstance(command_runner, ScheduleCommand): - responses = command_runner.execute(cmd, self._scheduled) + if isinstance(command_runner, BaseCommand): + if isinstance(command_runner, HelpCommand): + responses = command_runner.execute(cmd, MappingProxyType(self._commands)) + elif isinstance(command_runner, ScheduleCommand): + responses = command_runner.execute(cmd, self._scheduled) + else: + responses = command_runner.execute(cmd) + elif isinstance(command_runner, Command): + ctx = self._build_command_context(cmd) + responses = [r for r in execute_command_func(command_runner.execute, ctx) if r is not None] + elif hasattr(command_runner, "handler"): + ctx = self._build_command_context(cmd) + responses = [r for r in execute_command_func(command_runner.handler, ctx) if r is not None] else: - responses = command_runner.execute(cmd) + log.error(f"Unsupported command runner type for {cmd.command}: {type(command_runner).__name__}") + return None except Exception: log.exception(f"Error executing command: {cmd.command}") return None diff --git a/csp_bot/commands/__init__.py b/csp_bot/commands/__init__.py index 4e98162..4419b89 100644 --- a/csp_bot/commands/__init__.py +++ b/csp_bot/commands/__init__.py @@ -2,6 +2,14 @@ Commands can leverage chatom's cross-platform features for mentions, formatting, and entity recognition. + +Two command APIs are available: + +1. ``@command`` decorator for stateless function-based commands +2. ``Command`` BaseModel subclass for stateful class-based commands + +The legacy ``BaseCommand`` hierarchy is still supported via +``LegacyCommandAdapter``. """ from csp_bot.utils import mention_user @@ -15,13 +23,28 @@ ReplyToAuthorCommand, ReplyToOtherCommand, ) +from .context import BotInfo, CommandContext from .echo import EchoCommand, EchoCommandModel +from .executor import execute_command_func +from .framework import Command, CommandEntry, CommandModel, clear_registry, command, get_registered_commands from .help import HelpCommand, HelpCommandModel +from .legacy import LegacyCommandAdapter from .schedule import ScheduleCommand, ScheduleCommandModel from .status import StatusCommand, StatusCommandModel __all__ = ( - # Base classes + # New framework + "Command", + "CommandContext", + "CommandEntry", + "CommandModel", + "BotInfo", + "LegacyCommandAdapter", + "command", + "clear_registry", + "get_registered_commands", + "execute_command_func", + # Legacy base classes "BaseCommand", "BaseCommandModel", "NoResponseCommand", diff --git a/csp_bot/commands/context.py b/csp_bot/commands/context.py new file mode 100644 index 0000000..a3007a9 --- /dev/null +++ b/csp_bot/commands/context.py @@ -0,0 +1,174 @@ +"""Command context for the new command framework. + +Provides a typed, read-only view of a command invocation that both +decorator-based and class-based commands receive. +""" + +from __future__ import annotations + +from typing import Any, Generic, List, Optional, TypeVar, Union + +from chatom import Channel, Message, User +from chatom.format import ( + FormattedAttachment, + FormattedImage, + FormattedMessage, + Table, + Text, + TextNode, + UserMention, +) + +Deps = TypeVar("Deps") + + +class BotInfo: + """Metadata about the bot instance.""" + + __slots__ = ("id", "name", "version") + + def __init__(self, id: str = "", name: str = "", version: str = ""): + self.id = id + self.name = name + self.version = version + + +class CommandContext(Generic[Deps]): + """Typed, read-only view of a command invocation. + + Both ``@command`` decorated functions and ``Command`` subclasses + receive this as their primary interface to the invocation. + + Attributes: + command_name: The name of the command being executed. + source: The user who invoked the command. + targets: Users mentioned in the command arguments. + channel: The channel where the command was issued. + message: The original chatom Message. + args: Parsed argument tokens (mentions stripped). + args_text: Raw argument string. + backend: Backend identifier ("slack", "symphony", etc.). + bot: Bot metadata. + deps: Injected dependencies. + """ + + __slots__ = ( + "command_name", + "source", + "targets", + "channel", + "message", + "args", + "args_text", + "backend", + "bot", + "deps", + ) + + def __init__( + self, + *, + command_name: str, + source: User, + targets: List[User], + channel: Channel, + message: Message, + args: List[str], + args_text: str, + backend: str, + bot: BotInfo, + deps: Any = None, + ): + self.command_name = command_name + self.source = source + self.targets = targets + self.channel = channel + self.message = message + self.args = args + self.args_text = args_text + self.backend = backend + self.bot = bot + self.deps = deps + + @property + def target(self) -> Optional[User]: + """First mentioned user, or None.""" + return self.targets[0] if self.targets else None + + def mention(self, user: Optional[User]) -> UserMention: + """Create a mention node for a user. + + Args: + user: The user to mention. Returns empty Text if None. + + Returns: + A UserMention TextNode that renders correctly per backend. + """ + if user is None: + return UserMention(user_id="", display_name="") + return UserMention( + user_id=user.id, + display_name=getattr(user, "display_name", "") or getattr(user, "name", "") or "", + ) + + def reply(self, *content: Union[TextNode, str, Table, FormattedImage, FormattedAttachment]) -> FormattedMessage: + """Build a FormattedMessage from content nodes. + + Args: + *content: Text nodes, strings, tables, images, or attachments. + + Returns: + A FormattedMessage ready for rendering. + + Example: + >>> ctx.reply(Bold(child=Text(content="Hello")), " world") + """ + msg = FormattedMessage(metadata={"backend": self.backend}) + for item in content: + if isinstance(item, str): + msg.content.append(Text(content=item)) + else: + msg.content.append(item) + return msg + + def table( + self, + data: Any, + headers: Optional[List[str]] = None, + alignment: Optional[Union[str, List[str]]] = None, + ) -> Table: + """Build a Table node from data. + + Args: + data: List of dicts, list of lists, or a pandas DataFrame. + headers: Column headers (inferred from dicts/DataFrame if omitted). + alignment: Column alignment ("left", "right", "center") or list per column. + + Returns: + A Table node. + """ + # Handle pandas DataFrame + try: + import pandas as pd + + if isinstance(data, pd.DataFrame): + headers = headers or list(data.columns) + data = data.values.tolist() + return Table.from_data(data, headers=headers) + except ImportError: + pass + + # Handle list of dicts + if data and isinstance(data[0], dict): + return Table.from_dict_list(data, columns=headers) + + # Handle list of lists + return Table.from_data(data, headers=headers) + + def image(self, url: str, alt: str = "", title: str = "") -> FormattedImage: + """Create an image node.""" + return FormattedImage(url=url, alt_text=alt, title=title) + + def attachment(self, url: str, filename: str, content_type: str = "") -> FormattedAttachment: + """Create an attachment node.""" + return FormattedAttachment(url=url, filename=filename, content_type=content_type) diff --git a/csp_bot/commands/executor.py b/csp_bot/commands/executor.py new file mode 100644 index 0000000..47043c1 --- /dev/null +++ b/csp_bot/commands/executor.py @@ -0,0 +1,212 @@ +"""Command execution engine supporting four function signatures. + +Handles sync functions, async functions, sync generators, and async +generators uniformly. The caller gets back a list of response items +(Message, FormattedMessage, str, or None) regardless of which +signature the command used. +""" + +from __future__ import annotations + +import asyncio +import inspect +import logging +import threading +from typing import Any, List, Optional, Union + +from chatom import Message +from chatom.format import FormattedMessage + +from csp_bot.structs import BotCommand + +log = logging.getLogger(__name__) + +# Module-level async event loop running in a background thread. +# Lazily initialised on first use. +_loop: Optional[asyncio.AbstractEventLoop] = None +_loop_thread: Optional[threading.Thread] = None +_loop_lock = threading.Lock() + + +def _get_event_loop() -> asyncio.AbstractEventLoop: + """Get or create the shared background event loop.""" + global _loop, _loop_thread + with _loop_lock: + if _loop is None or _loop.is_closed(): + _loop = asyncio.new_event_loop() + _loop_thread = threading.Thread( + target=_loop.run_forever, + name="csp-bot-async-loop", + daemon=True, + ) + _loop_thread.start() + return _loop + + +def _coerce_response(item: Any, backend: str) -> Optional[Union[Message, BotCommand]]: + """Coerce a command return value into a chatom Message. + + Accepts: + - None → None + - BotCommand → pass-through + - str → Message(content=str) + - Message → pass-through (ensure metadata.backend set) + - FormattedMessage → Message with rendered content + formatted_content + + Returns: + A chatom Message, BotCommand, or None. + """ + if item is None: + return None + + if isinstance(item, BotCommand): + return item + + if isinstance(item, Message): + if item.metadata is None: + item.metadata = {} + if not item.metadata.get("backend"): + item.metadata["backend"] = backend + return item + + if isinstance(item, FormattedMessage): + rendered = item.render_for(backend) + msg = Message( + content=rendered, + metadata={"backend": backend, "formatted": item}, + ) + return msg + + if isinstance(item, str): + return Message( + content=item, + metadata={"backend": backend}, + ) + + # Unknown type — try str() + log.warning(f"Command returned unexpected type {type(item)}, converting to str") + return Message( + content=str(item), + metadata={"backend": backend}, + ) + + +def execute_command_func( + fn: Any, + ctx: Any, + timeout: float = 60.0, +) -> List[Optional[Union[Message, BotCommand]]]: + """Execute a command callable and return a list of Messages. + + Detects the function signature and dispatches accordingly: + - sync function → call directly, wrap single result + - async function → run in background loop, wrap single result + - sync generator → drain, collect all yielded items + - async generator → drain in background loop, collect all + + Args: + fn: The callable (function, bound method, generator, etc.) + ctx: The CommandContext to pass. + timeout: Timeout in seconds for async operations and generators. + + Returns: + List of Messages/BotCommands. + """ + backend = getattr(ctx, "backend", "") + + if inspect.isasyncgenfunction(fn): + return _run_async_generator(fn, ctx, backend, timeout) + elif inspect.isgeneratorfunction(fn): + return _run_sync_generator(fn, ctx, backend, timeout) + elif inspect.iscoroutinefunction(fn): + return _run_async_function(fn, ctx, backend, timeout) + else: + return _run_sync_function(fn, ctx, backend) + + +def _run_sync_function(fn: Any, ctx: Any, backend: str) -> List[Optional[Union[Message, BotCommand]]]: + """Execute a plain sync function.""" + try: + result = fn(ctx) + return [_coerce_response(result, backend)] + except Exception: + log.exception("Error executing sync command") + raise + + +def _run_async_function( + fn: Any, + ctx: Any, + backend: str, + timeout: float, +) -> List[Optional[Union[Message, BotCommand]]]: + """Execute an async function in the background event loop.""" + loop = _get_event_loop() + try: + future = asyncio.run_coroutine_threadsafe(fn(ctx), loop) + result = future.result(timeout=timeout) + return [_coerce_response(result, backend)] + except asyncio.TimeoutError: + log.error(f"Async command timed out after {timeout}s") + raise + except Exception: + log.exception("Error executing async command") + raise + + +def _run_sync_generator( + fn: Any, + ctx: Any, + backend: str, + timeout: float, +) -> List[Optional[Union[Message, BotCommand]]]: + """Drain a sync generator until it yields None sentinel.""" + results: List[Optional[Union[Message, BotCommand]]] = [] + try: + gen = fn(ctx) + for item in gen: + if item is None: + break + results.append(_coerce_response(item, backend)) + except GeneratorExit: + pass + except Exception: + log.exception("Error in generator command") + raise + return results + + +async def _drain_async_gen( + fn: Any, + ctx: Any, + backend: str, +) -> List[Optional[Union[Message, BotCommand]]]: + """Async helper to drain an async generator until None sentinel.""" + results: List[Optional[Union[Message, BotCommand]]] = [] + async for item in fn(ctx): + if item is None: + break + results.append(_coerce_response(item, backend)) + return results + + +def _run_async_generator( + fn: Any, + ctx: Any, + backend: str, + timeout: float, +) -> List[Optional[Union[Message, BotCommand]]]: + """Drain an async generator in the background event loop.""" + loop = _get_event_loop() + try: + future = asyncio.run_coroutine_threadsafe( + _drain_async_gen(fn, ctx, backend), + loop, + ) + return future.result(timeout=timeout) + except asyncio.TimeoutError: + log.error(f"Async generator command timed out after {timeout}s") + raise + except Exception: + log.exception("Error in async generator command") + raise diff --git a/csp_bot/commands/framework.py b/csp_bot/commands/framework.py new file mode 100644 index 0000000..4686d25 --- /dev/null +++ b/csp_bot/commands/framework.py @@ -0,0 +1,163 @@ +"""New command framework for csp-bot. + +Provides two APIs for defining commands: + +1. ``@command`` decorator for stateless function-based commands +2. ``Command`` BaseModel subclass for stateful class-based commands + +Both support four execution signatures: sync, async, generator, async generator. +The framework detects the signature automatically and handles async bridging, +generator draining, and error handling transparently. +""" + +from __future__ import annotations + +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Type, +) + +from ccflow import BaseModel + +from csp_bot.commands.context import CommandContext + +log = logging.getLogger(__name__) + +_COMMAND_REGISTRY: Dict[str, "CommandEntry"] = {} + + +class CommandEntry: + """Internal registry entry for a command.""" + + __slots__ = ("name", "help", "backends", "handler", "is_class") + + def __init__( + self, + name: str, + help: str, + handler: Any, + backends: Optional[List[str]] = None, + is_class: bool = False, + ): + self.name = name + self.help = help + self.handler = handler + self.backends = backends or [] + self.is_class = is_class + + +def command( + name: str, + help: str = "", + backends: Optional[List[str]] = None, +) -> Callable: + """Decorator to register a function as a bot command. + + The decorated function receives a ``CommandContext`` and returns + a response. Four signatures are supported: + + - ``def f(ctx) -> T``: sync, single response + - ``async def f(ctx) -> T``: async, single response + - ``def f(ctx)``: sync generator, yields multiple responses + - ``async def f(ctx)``: async generator, yields multiple responses + + Args: + name: The command name (e.g. "echo" for /echo). + help: Help text shown by the /help command. + backends: List of backends this command supports. Empty = all. + + Returns: + The original function, registered in the global command registry. + + Example:: + + @command(name="echo", help="Echoes your message") + def echo(ctx: CommandContext) -> str: + return f"{ctx.mention(ctx.target)}: {ctx.args_text}" + """ + + def decorator(fn: Callable) -> Callable: + entry = CommandEntry( + name=name, + help=help, + handler=fn, + backends=backends, + is_class=False, + ) + _COMMAND_REGISTRY[name] = entry + # Stash metadata on the function for introspection + fn._command_name = name + fn._command_help = help + fn._command_backends = backends or [] + return fn + + return decorator + + +def get_registered_commands() -> Dict[str, CommandEntry]: + """Return a copy of the global command registry.""" + return dict(_COMMAND_REGISTRY) + + +def clear_registry() -> None: + """Clear the global command registry. Intended for testing.""" + _COMMAND_REGISTRY.clear() + + +class Command(BaseModel): + """Base class for stateful commands. + + Subclass this and implement ``execute`` to create a command that + retains state across invocations. Fields are Pydantic-validated + and composable with Hydra's ``_target_`` instantiation. + + ``execute`` supports all four signatures: sync, async, generator, + async generator. + + Example:: + + class MetsCommand(Command): + name: str = "mets" + help: str = "Show MLB standings" + api_url: str = "https://..." + + def execute(self, ctx: CommandContext) -> str: + return f"Standings from {self.api_url}" + """ + + name: str = "" + """The command name (e.g. "mets" for /mets).""" + + help: str = "" + """Help text shown by the /help command.""" + + backends: List[str] = [] + """Backends this command supports. Empty = all.""" + + def execute(self, ctx: CommandContext) -> Any: + """Execute the command. Override in subclasses. + + Supports sync return, async return, sync generator (yield), + and async generator (async yield). + """ + raise NotImplementedError(f"{type(self).__name__} must implement execute()") + + +class CommandModel(BaseModel): + """Hydra model for registering a Command subclass via config. + + Example YAML:: + + mets: + _target_: mypackage.MetsCommandModel + command: + _target_: mypackage.MetsCommand + api_url: "https://..." + """ + + command: Type[Command] = Command diff --git a/csp_bot/commands/legacy.py b/csp_bot/commands/legacy.py new file mode 100644 index 0000000..6442a16 --- /dev/null +++ b/csp_bot/commands/legacy.py @@ -0,0 +1,82 @@ +"""Legacy adapter for old BaseCommand subclasses. + +Wraps existing BaseCommand instances so they can be used in the new +command execution pipeline alongside @command functions and Command +subclasses. This allows incremental migration — old commands keep +working without any changes. +""" + +from __future__ import annotations + +import logging +from typing import Any, List + +from csp_bot.commands.base import BaseCommand +from csp_bot.commands.context import CommandContext +from csp_bot.structs import BotCommand + +log = logging.getLogger(__name__) + + +class LegacyCommandAdapter: + """Wraps a BaseCommand so it can be driven by the new framework. + + The adapter: + - Accepts a CommandContext + - Converts it to a BotCommand for the legacy execute() call + - Delegates to the original command's execute() method + - Passes through extra kwargs for HelpCommand/ScheduleCommand + """ + + __slots__ = ("_command",) + + def __init__(self, command: BaseCommand): + self._command = command + + @property + def wrapped(self) -> BaseCommand: + """The underlying BaseCommand instance.""" + return self._command + + @property + def name(self) -> str: + return self._command.command() + + @property + def help(self) -> str: + return self._command.help() + + @property + def backends(self) -> List[str]: + return self._command.backends() + + def context_to_bot_command(self, ctx: CommandContext) -> BotCommand: + """Convert a CommandContext back to a legacy BotCommand.""" + return BotCommand( + command=ctx.command_name, + args=tuple(ctx.args), + source=ctx.source, + targets=tuple(ctx.targets), + channel_id=ctx.channel.id if ctx.channel else "", + channel_name=ctx.channel.name if ctx.channel else "", + backend=ctx.backend, + variant=self._command.kind(), + message=ctx.message, + delay=None, + schedule="", + times_run=0, + ) + + def execute(self, ctx: CommandContext, **extra_kwargs) -> Any: + """Execute the legacy command with a CommandContext. + + Converts ctx → BotCommand, calls preexecute, then execute. + Returns whatever the legacy command returns (Message, BotMessage, etc.) + """ + bot_cmd = self.context_to_bot_command(ctx) + + # Legacy preexecute + bot_cmd = self._command.preexecute(bot_cmd) + + # Legacy execute — pass through extra kwargs + return self._command.execute(bot_cmd, **extra_kwargs) diff --git a/csp_bot/gateway/gateway.py b/csp_bot/gateway/gateway.py index 366ada4..e44ad2d 100644 --- a/csp_bot/gateway/gateway.py +++ b/csp_bot/gateway/gateway.py @@ -5,7 +5,7 @@ from functools import wraps from logging import getLogger -from typing import List +from typing import Any, List, Union from chatom import Message from csp import ts @@ -19,7 +19,7 @@ from pydantic import Field, model_validator from csp_bot import __version__ -from csp_bot.commands import BaseCommandModel +from csp_bot.commands import BaseCommandModel, CommandModel from csp_bot.structs import BotCommand log = getLogger(__name__) @@ -67,7 +67,8 @@ class CspBotGateway(BaseGateway): """CSP Bot Gateway with chatom integration.""" settings: GatewaySettings = Field(default_factory=GatewaySettings) - commands: List[BaseCommandModel] = [] + commands: List[Union[BaseCommandModel, CommandModel]] = [] + deps: Any = None @model_validator(mode="before") @classmethod @@ -85,7 +86,8 @@ def __init__( self, modules: List[GatewayModule] = None, channels: GatewayChannels = None, - commands: List[BaseCommandModel] = None, + commands: List[Union[BaseCommandModel, CommandModel]] = None, + deps: Any = None, *args, **kwargs, ): @@ -94,6 +96,7 @@ def __init__( modules=modules, channels=channels, commands=commands, + deps=deps, *args, **kwargs, ) @@ -105,6 +108,7 @@ def __init__( for module in self.modules: log.info(f"Checking module: {type(module).__name__} - is Bot: {isinstance(module, Bot)}") if isinstance(module, Bot): + module.set_deps(self.deps) module.load_commands(self.commands) @wraps(BaseGateway.start) diff --git a/csp_bot/tests/test_bot_integration.py b/csp_bot/tests/test_bot_integration.py index e7f4b30..6127792 100644 --- a/csp_bot/tests/test_bot_integration.py +++ b/csp_bot/tests/test_bot_integration.py @@ -17,6 +17,7 @@ from csp_bot import Bot, BotCommand, BotConfig, BotMessage from csp_bot.bot_config import SymphonyConfig from csp_bot.commands import HelpCommand, ReplyToOtherCommand +from csp_bot.commands.framework import Command, clear_registry, command from csp_bot.structs import CommandVariant # Test Fixtures @@ -263,6 +264,113 @@ def test_bot_message_to_chatom_includes_metadata(self, bot_with_symphony): assert result.metadata.get("backend") == "slack" +class TestNewFrameworkIntegration: + """Integration tests for new command framework execution in Bot.""" + + def setup_method(self): + clear_registry() + + def teardown_method(self): + clear_registry() + + def test_load_commands_discovers_decorated_command(self, bot_with_symphony): + """Decorated commands should be auto-registered by Bot.load_commands.""" + + @command(name="newecho", help="Echo via new framework") + def newecho(ctx): + return f"echo: {ctx.args_text}" + + bot_with_symphony.load_commands([]) + + assert "newecho" in bot_with_symphony._commands + entry = bot_with_symphony._commands["newecho"] + assert hasattr(entry, "handler") + + def test_execute_command_supports_new_class_command(self, bot_with_symphony): + """Class-based new framework commands should execute via _execute_command.""" + + class NewPing(Command): + name: str = "newping" + help: str = "Ping via new framework" + + def execute(self, ctx): + return ctx.reply("pong") + + bot_with_symphony._commands["newping"] = NewPing() + bot_with_symphony._bot_user_ids["symphony"] = "bot123" + bot_with_symphony._bot_names["symphony"] = "TestBot" + + cmd = BotCommand( + command="newping", + args=(), + source=User(id="user123", name="Test User"), + targets=(), + channel_id="channel789", + channel_name="test-channel", + backend="symphony", + variant=CommandVariant.REPLY, + message=Message( + id="msg123", + content="/newping", + author=User(id="user123"), + channel=Channel(id="channel789"), + ), + delay=None, + schedule="", + times_run=0, + ) + + results = bot_with_symphony._execute_command(cmd) + + assert results is not None + assert len(results) == 1 + assert isinstance(results[0], Message) + assert results[0].content == "pong" + assert results[0].metadata is not None + assert results[0].metadata.get("backend") == "symphony" + + def test_execute_command_injects_deps_into_context(self, bot_with_symphony): + """New framework commands should receive shared deps via ctx.deps.""" + + class NeedsDeps(Command): + name: str = "needsdeps" + help: str = "Check deps wiring" + + def execute(self, ctx): + return f"token={ctx.deps['token']}" + + bot_with_symphony._commands["needsdeps"] = NeedsDeps() + bot_with_symphony.set_deps({"token": "abc123"}) + bot_with_symphony._bot_user_ids["symphony"] = "bot123" + bot_with_symphony._bot_names["symphony"] = "TestBot" + + cmd = BotCommand( + command="needsdeps", + args=(), + source=User(id="user123", name="Test User"), + targets=(), + channel_id="channel789", + channel_name="test-channel", + backend="symphony", + variant=CommandVariant.REPLY, + message=Message( + id="msg123", + content="/needsdeps", + author=User(id="user123"), + channel=Channel(id="channel789"), + ), + delay=None, + schedule="", + times_run=0, + ) + + results = bot_with_symphony._execute_command(cmd) + + assert results is not None + assert len(results) == 1 + assert results[0].content == "token=abc123" + + # Command Argument Parsing Tests diff --git a/csp_bot/tests/test_command_framework.py b/csp_bot/tests/test_command_framework.py new file mode 100644 index 0000000..02ad0d6 --- /dev/null +++ b/csp_bot/tests/test_command_framework.py @@ -0,0 +1,724 @@ +"""Tests for the new command framework (Phase 1). + +Tests the @command decorator, Command base class, CommandContext, +four execution signatures (sync, async, generator, async generator), +the executor, and the legacy adapter. +""" + +import asyncio + +import pytest +from chatom import Channel, Message, User +from chatom.format import Bold, FormattedMessage, Text, UserMention + +from csp_bot.commands.base import ReplyToOtherCommand +from csp_bot.commands.context import BotInfo, CommandContext +from csp_bot.commands.executor import _coerce_response, execute_command_func +from csp_bot.commands.framework import Command, clear_registry, command, get_registered_commands +from csp_bot.commands.legacy import LegacyCommandAdapter +from csp_bot.structs import BotCommand, CommandVariant + + +def _make_ctx(**overrides) -> CommandContext: + """Build a CommandContext with sensible defaults.""" + defaults = dict( + command_name="test", + source=User(id="U1", name="alice"), + targets=[User(id="U2", name="bob")], + channel=Channel(id="C1", name="general"), + message=Message(content="/test hello", channel_id="C1"), + args=["hello"], + args_text="hello", + backend="slack", + bot=BotInfo(id="B1", name="testbot", version="0.0.1"), + deps=None, + ) + defaults.update(overrides) + return CommandContext(**defaults) + + +# =========================================================================== +# CommandContext tests +# =========================================================================== + + +class TestCommandContext: + def test_basic_attributes(self): + ctx = _make_ctx() + assert ctx.command_name == "test" + assert ctx.source.id == "U1" + assert ctx.backend == "slack" + assert ctx.args == ["hello"] + assert ctx.args_text == "hello" + + def test_target_property(self): + ctx = _make_ctx() + assert ctx.target.id == "U2" + + def test_target_none_when_empty(self): + ctx = _make_ctx(targets=[]) + assert ctx.target is None + + def test_mention_returns_user_mention(self): + ctx = _make_ctx() + node = ctx.mention(ctx.source) + assert isinstance(node, UserMention) + assert node.user_id == "U1" + + def test_mention_none_returns_empty(self): + ctx = _make_ctx() + node = ctx.mention(None) + assert isinstance(node, UserMention) + assert node.user_id == "" + + def test_reply_builds_formatted_message(self): + ctx = _make_ctx() + msg = ctx.reply("hello ", Bold(child=Text(content="world"))) + assert isinstance(msg, FormattedMessage) + assert len(msg.content) == 2 + assert isinstance(msg.content[0], Text) + assert isinstance(msg.content[1], Bold) + + def test_reply_sets_backend_metadata(self): + ctx = _make_ctx(backend="symphony") + msg = ctx.reply("test") + assert msg.metadata["backend"] == "symphony" + + def test_reply_renders_for_backend(self): + ctx = _make_ctx(backend="slack") + msg = ctx.reply(Bold(child=Text(content="hi"))) + rendered = msg.render_for("slack") + assert "*hi*" in rendered + + def test_table_from_dicts(self): + ctx = _make_ctx() + tbl = ctx.table([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) + # Should produce a valid Table node + rendered = tbl.render("markdown") + assert "a" in rendered + assert "1" in rendered + + def test_image(self): + ctx = _make_ctx() + img = ctx.image("https://example.com/img.png", alt="pic") + assert img.url == "https://example.com/img.png" + assert img.alt_text == "pic" + + def test_deps_accessible(self): + ctx = _make_ctx(deps={"api_key": "abc"}) + assert ctx.deps["api_key"] == "abc" + + +# =========================================================================== +# @command decorator tests +# =========================================================================== + + +class TestCommandDecorator: + def setup_method(self): + clear_registry() + + def teardown_method(self): + clear_registry() + + def test_registers_function(self): + @command(name="ping", help="Pong!") + def ping(ctx): + return "pong" + + reg = get_registered_commands() + assert "ping" in reg + assert reg["ping"].help == "Pong!" + assert reg["ping"].handler is ping + + def test_function_metadata(self): + @command(name="greet", help="Say hello", backends=["slack"]) + def greet(ctx): + return "hello" + + assert greet._command_name == "greet" + assert greet._command_help == "Say hello" + assert greet._command_backends == ["slack"] + + def test_multiple_commands(self): + @command(name="a", help="A") + def cmd_a(ctx): + return "a" + + @command(name="b", help="B") + def cmd_b(ctx): + return "b" + + reg = get_registered_commands() + assert "a" in reg + assert "b" in reg + + def test_decorator_returns_original_function(self): + @command(name="echo", help="Echo") + def echo(ctx): + return ctx.args_text + + # Should be callable directly + ctx = _make_ctx(args_text="hello world") + assert echo(ctx) == "hello world" + + +# =========================================================================== +# Command class tests +# =========================================================================== + + +class TestCommandClass: + def test_subclass_with_fields(self): + class MyCmd(Command): + name: str = "mycmd" + help: str = "My command" + multiplier: int = 2 + + def execute(self, ctx): + return str(int(ctx.args[0]) * self.multiplier) + + cmd = MyCmd() + assert cmd.name == "mycmd" + assert cmd.multiplier == 2 + + ctx = _make_ctx(args=["5"]) + assert cmd.execute(ctx) == "10" + + def test_subclass_with_custom_fields(self): + class Greeter(Command): + name: str = "greet" + help: str = "Greet someone" + greeting: str = "Hello" + + def execute(self, ctx): + return f"{self.greeting}, {ctx.source.name}!" + + cmd = Greeter(greeting="Howdy") + ctx = _make_ctx() + assert cmd.execute(ctx) == "Howdy, alice!" + + def test_base_command_raises_not_implemented(self): + cmd = Command(name="noop", help="noop") + ctx = _make_ctx() + with pytest.raises(NotImplementedError): + cmd.execute(ctx) + + +# =========================================================================== +# Executor tests — four signatures +# =========================================================================== + + +class TestExecutorSync: + def test_sync_returns_str(self): + def echo(ctx): + return ctx.args_text + + ctx = _make_ctx(args_text="hello") + results = execute_command_func(echo, ctx) + assert len(results) == 1 + assert results[0].content == "hello" + + def test_sync_returns_message(self): + def echo(ctx): + return Message(content="hi", metadata={"backend": "slack"}) + + ctx = _make_ctx() + results = execute_command_func(echo, ctx) + assert results[0].content == "hi" + + def test_sync_returns_formatted_message(self): + def echo(ctx): + return ctx.reply(Bold(child=Text(content="bold"))) + + ctx = _make_ctx(backend="slack") + results = execute_command_func(echo, ctx) + assert results[0].content == "*bold*" + + def test_sync_returns_none(self): + def noop(ctx): + return None + + ctx = _make_ctx() + results = execute_command_func(noop, ctx) + assert results == [None] + + def test_sync_returns_bot_command(self): + def next_cmd(ctx): + return BotCommand( + command="followup", + args=("x",), + source=ctx.source, + targets=tuple(ctx.targets), + channel_id=ctx.channel.id, + channel_name=ctx.channel.name, + backend=ctx.backend, + variant=CommandVariant.REPLY_TO_OTHER, + message=ctx.message, + delay=None, + schedule="", + times_run=0, + ) + + ctx = _make_ctx() + results = execute_command_func(next_cmd, ctx) + assert len(results) == 1 + assert isinstance(results[0], BotCommand) + assert results[0].command == "followup" + + def test_sync_raises(self): + def bad(ctx): + raise ValueError("boom") + + ctx = _make_ctx() + with pytest.raises(ValueError, match="boom"): + execute_command_func(bad, ctx) + + +class TestExecutorAsync: + def test_async_returns_str(self): + async def echo(ctx): + return ctx.args_text + + ctx = _make_ctx(args_text="async hello") + results = execute_command_func(echo, ctx) + assert len(results) == 1 + assert results[0].content == "async hello" + + def test_async_returns_formatted(self): + async def echo(ctx): + return ctx.reply("async world") + + ctx = _make_ctx(backend="discord") + results = execute_command_func(echo, ctx) + assert results[0].content == "async world" + + def test_async_raises(self): + async def bad(ctx): + raise RuntimeError("async boom") + + ctx = _make_ctx() + with pytest.raises(RuntimeError, match="async boom"): + execute_command_func(bad, ctx) + + def test_async_returns_bot_command(self): + async def next_cmd(ctx): + return BotCommand( + command="afollowup", + args=(), + source=ctx.source, + targets=tuple(ctx.targets), + channel_id=ctx.channel.id, + channel_name=ctx.channel.name, + backend=ctx.backend, + variant=CommandVariant.REPLY_TO_OTHER, + message=ctx.message, + delay=None, + schedule="", + times_run=0, + ) + + ctx = _make_ctx() + results = execute_command_func(next_cmd, ctx) + assert len(results) == 1 + assert isinstance(results[0], BotCommand) + assert results[0].command == "afollowup" + + +class TestExecutorGenerator: + def test_generator_yields_multiple(self): + def multi(ctx): + yield "first" + yield "second" + yield "third" + + ctx = _make_ctx() + results = execute_command_func(multi, ctx) + assert len(results) == 3 + assert results[0].content == "first" + assert results[1].content == "second" + assert results[2].content == "third" + + def test_generator_stops_on_none_sentinel(self): + def sparse(ctx): + yield "a" + yield None + yield "b" + + ctx = _make_ctx() + results = execute_command_func(sparse, ctx) + assert len(results) == 1 + assert results[0].content == "a" + + def test_generator_stops_on_stopiteration(self): + def finite(ctx): + yield "x" + yield "y" + # Natural exhaustion -> StopIteration + + ctx = _make_ctx() + results = execute_command_func(finite, ctx) + assert len(results) == 2 + assert results[0].content == "x" + assert results[1].content == "y" + + def test_generator_yields_bot_command(self): + def next_cmd(ctx): + yield "working" + yield BotCommand( + command="gfollowup", + args=("1",), + source=ctx.source, + targets=tuple(ctx.targets), + channel_id=ctx.channel.id, + channel_name=ctx.channel.name, + backend=ctx.backend, + variant=CommandVariant.REPLY_TO_OTHER, + message=ctx.message, + delay=None, + schedule="", + times_run=0, + ) + yield None + + ctx = _make_ctx() + results = execute_command_func(next_cmd, ctx) + assert len(results) == 2 + assert results[0].content == "working" + assert isinstance(results[1], BotCommand) + assert results[1].command == "gfollowup" + + def test_generator_yields_formatted(self): + def rich(ctx): + yield ctx.reply(Bold(child=Text(content="step 1"))) + yield ctx.reply(Bold(child=Text(content="step 2"))) + + ctx = _make_ctx(backend="slack") + results = execute_command_func(rich, ctx) + assert len(results) == 2 + assert "*step 1*" in results[0].content + assert "*step 2*" in results[1].content + + def test_generator_raises(self): + def bad_gen(ctx): + yield "ok" + raise ValueError("gen boom") + + ctx = _make_ctx() + with pytest.raises(ValueError, match="gen boom"): + execute_command_func(bad_gen, ctx) + + +class TestExecutorAsyncGenerator: + def test_async_generator_yields_multiple(self): + async def multi(ctx): + yield "first" + yield "second" + + ctx = _make_ctx() + results = execute_command_func(multi, ctx) + assert len(results) == 2 + assert results[0].content == "first" + assert results[1].content == "second" + + def test_async_generator_with_await(self): + async def fetcher(ctx): + await asyncio.sleep(0.01) + yield "fetched" + + ctx = _make_ctx() + results = execute_command_func(fetcher, ctx) + assert len(results) == 1 + assert results[0].content == "fetched" + + def test_async_generator_stops_on_none_sentinel(self): + async def sparse(ctx): + yield "start" + yield None + yield "after" + + ctx = _make_ctx() + results = execute_command_func(sparse, ctx) + assert len(results) == 1 + assert results[0].content == "start" + + def test_async_generator_stops_on_exhaustion(self): + async def finite(ctx): + yield "x" + yield "y" + # Natural async generator exhaustion + + ctx = _make_ctx() + results = execute_command_func(finite, ctx) + assert len(results) == 2 + assert results[0].content == "x" + assert results[1].content == "y" + + def test_async_generator_yields_bot_command(self): + async def next_cmd(ctx): + yield "working" + yield BotCommand( + command="agfollowup", + args=(), + source=ctx.source, + targets=tuple(ctx.targets), + channel_id=ctx.channel.id, + channel_name=ctx.channel.name, + backend=ctx.backend, + variant=CommandVariant.REPLY_TO_OTHER, + message=ctx.message, + delay=None, + schedule="", + times_run=0, + ) + yield None + + ctx = _make_ctx() + results = execute_command_func(next_cmd, ctx) + assert len(results) == 2 + assert results[0].content == "working" + assert isinstance(results[1], BotCommand) + assert results[1].command == "agfollowup" + + def test_async_generator_raises(self): + async def bad_agen(ctx): + yield "ok" + raise RuntimeError("agen boom") + + ctx = _make_ctx() + with pytest.raises(RuntimeError, match="agen boom"): + execute_command_func(bad_agen, ctx) + + +# =========================================================================== +# Executor — Command class integration +# =========================================================================== + + +class TestExecutorWithCommandClass: + def test_sync_command_class(self): + class Echo(Command): + name: str = "echo" + help: str = "Echo" + + def execute(self, ctx): + return ctx.args_text + + cmd = Echo() + ctx = _make_ctx(args_text="class hello") + results = execute_command_func(cmd.execute, ctx) + assert results[0].content == "class hello" + + def test_async_command_class(self): + class AsyncEcho(Command): + name: str = "aecho" + help: str = "Async echo" + + async def execute(self, ctx): + return ctx.args_text + + cmd = AsyncEcho() + ctx = _make_ctx(args_text="async class") + results = execute_command_func(cmd.execute, ctx) + assert results[0].content == "async class" + + def test_generator_command_class(self): + class MultiStep(Command): + name: str = "multi" + help: str = "Multi-step" + + def execute(self, ctx): + yield "step 1" + yield "step 2" + + cmd = MultiStep() + ctx = _make_ctx() + results = execute_command_func(cmd.execute, ctx) + assert len(results) == 2 + + def test_async_generator_command_class(self): + class Streamer(Command): + name: str = "stream" + help: str = "Stream" + + async def execute(self, ctx): + yield "chunk 1" + yield "chunk 2" + + cmd = Streamer() + ctx = _make_ctx() + results = execute_command_func(cmd.execute, ctx) + assert len(results) == 2 + + +# =========================================================================== +# _coerce_response tests +# =========================================================================== + + +class TestCoerceResponse: + def test_none(self): + assert _coerce_response(None, "slack") is None + + def test_str(self): + msg = _coerce_response("hello", "slack") + assert isinstance(msg, Message) + assert msg.content == "hello" + assert msg.metadata["backend"] == "slack" + + def test_message_passthrough(self): + m = Message(content="hi", metadata={"backend": "discord"}) + result = _coerce_response(m, "discord") + assert result is m + + def test_message_sets_backend_if_missing(self): + m = Message(content="hi") + result = _coerce_response(m, "symphony") + assert result.metadata["backend"] == "symphony" + + def test_formatted_message(self): + fm = FormattedMessage(content=[Bold(child=Text(content="bold"))]) + result = _coerce_response(fm, "slack") + assert isinstance(result, Message) + assert "*bold*" in result.content + assert result.metadata["backend"] == "slack" + + def test_unknown_type(self): + result = _coerce_response(42, "slack") + assert isinstance(result, Message) + assert result.content == "42" + + +# =========================================================================== +# Legacy adapter tests +# =========================================================================== + + +class _FakeEchoCommand(ReplyToOtherCommand): + """Minimal legacy command for testing.""" + + def command(self): + return "echo" + + def name(self): + return "Echo" + + def help(self): + return "Echo a message" + + def execute(self, cmd): + content = " ".join(cmd.args) if cmd.args else "" + return Message(content=content, metadata={"backend": cmd.backend}) + + +class TestLegacyAdapter: + def test_wraps_command(self): + legacy = _FakeEchoCommand() + adapter = LegacyCommandAdapter(legacy) + assert adapter.name == "echo" + assert adapter.help == "Echo a message" + assert adapter.wrapped is legacy + + def test_execute_via_adapter(self): + legacy = _FakeEchoCommand() + adapter = LegacyCommandAdapter(legacy) + ctx = _make_ctx(command_name="echo", args=["hello", "world"]) + result = adapter.execute(ctx) + assert isinstance(result, Message) + assert result.content == "hello world" + + def test_context_to_bot_command(self): + legacy = _FakeEchoCommand() + adapter = LegacyCommandAdapter(legacy) + ctx = _make_ctx( + command_name="echo", + args=["x"], + backend="symphony", + ) + bot_cmd = adapter.context_to_bot_command(ctx) + assert isinstance(bot_cmd, BotCommand) + assert bot_cmd.command == "echo" + assert bot_cmd.args == ("x",) + assert bot_cmd.backend == "symphony" + assert bot_cmd.variant == CommandVariant.REPLY_TO_OTHER + + +# =========================================================================== +# End-to-end: decorated command through executor +# =========================================================================== + + +class TestEndToEnd: + def setup_method(self): + clear_registry() + + def teardown_method(self): + clear_registry() + + def test_decorated_sync_through_executor(self): + @command(name="greet", help="Greet") + def greet(ctx): + return f"Hello, {ctx.source.name}!" + + entry = get_registered_commands()["greet"] + ctx = _make_ctx() + results = execute_command_func(entry.handler, ctx) + assert results[0].content == "Hello, alice!" + + def test_decorated_async_through_executor(self): + @command(name="agreet", help="Async greet") + async def agreet(ctx): + return f"Hi, {ctx.source.name}!" + + entry = get_registered_commands()["agreet"] + ctx = _make_ctx() + results = execute_command_func(entry.handler, ctx) + assert results[0].content == "Hi, alice!" + + def test_decorated_generator_through_executor(self): + @command(name="multi", help="Multi") + def multi(ctx): + yield "one" + yield "two" + + entry = get_registered_commands()["multi"] + ctx = _make_ctx() + results = execute_command_func(entry.handler, ctx) + assert len(results) == 2 + + def test_class_command_through_executor(self): + class Thanks(Command): + name: str = "thanks" + help: str = "Thank someone" + gifts: list = ["cookie", "cake"] + + def execute(self, ctx): + return f"{ctx.mention(ctx.target)} gets a {self.gifts[0]}" + + cmd = Thanks() + ctx = _make_ctx(backend="slack") + results = execute_command_func(cmd.execute, ctx) + # UserMention rendered as string when coerced + assert "gets a cookie" in results[0].content + + def test_formatted_message_renders_per_backend(self): + @command(name="rich", help="Rich") + def rich(ctx): + return ctx.reply( + Bold(child=Text(content="Title")), + " - details", + ) + + entry = get_registered_commands()["rich"] + + # Slack + ctx_slack = _make_ctx(backend="slack") + results = execute_command_func(entry.handler, ctx_slack) + assert "*Title*" in results[0].content + assert " - details" in results[0].content + + # Discord + ctx_discord = _make_ctx(backend="discord") + results = execute_command_func(entry.handler, ctx_discord) + assert "**Title**" in results[0].content From 9200826474d74dc973df876b2463d41c30855368 Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Tue, 21 Apr 2026 15:02:41 -0400 Subject: [PATCH 2/4] Migrate commands to concrete chatom types Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com> --- csp_bot/commands/echo.py | 3 ++- csp_bot/structs.py | 8 ++++---- csp_bot/tests/test_echo.py | 38 +++++++++----------------------------- 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/csp_bot/commands/echo.py b/csp_bot/commands/echo.py index f6dd991..184b814 100644 --- a/csp_bot/commands/echo.py +++ b/csp_bot/commands/echo.py @@ -36,7 +36,8 @@ def execute(self, command: BotCommand) -> Optional[Message]: # Add mentions for any tagged users if command.targets: - mentions = mention_users([t.to_chatom_user() for t in command.targets], command.backend) + users = [t.to_chatom_user() if hasattr(t, "to_chatom_user") else t for t in command.targets] + mentions = mention_users(users, command.backend) if mentions: content = f"{content} {mentions}".strip() diff --git a/csp_bot/structs.py b/csp_bot/structs.py index 0513875..4cc7017 100644 --- a/csp_bot/structs.py +++ b/csp_bot/structs.py @@ -8,7 +8,7 @@ from enum import Enum from typing import Tuple -from chatom import Channel, Message as ChatomMessage +from chatom import Channel, Message as ChatomMessage, User from csp_gateway.utils.struct import GatewayStruct __all__ = ( @@ -124,10 +124,10 @@ class BotCommand(GatewayStruct): args: Tuple[str] """Command arguments as parsed tokens.""" - source: object # chatom.User - stored as object for Struct compatibility + source: User """The user who issued the command.""" - targets: Tuple[object] # Tuple[chatom.User] - stored as object for Struct compatibility + targets: Tuple[User] """Users mentioned/tagged in the command.""" channel_id: str @@ -142,7 +142,7 @@ class BotCommand(GatewayStruct): variant: CommandVariant """The command response variant.""" - message: object # chatom.Message - stored as object for Struct compatibility + message: ChatomMessage """The original chatom Message.""" delay: datetime diff --git a/csp_bot/tests/test_echo.py b/csp_bot/tests/test_echo.py index 0b6a662..aa978e4 100644 --- a/csp_bot/tests/test_echo.py +++ b/csp_bot/tests/test_echo.py @@ -1,6 +1,6 @@ """Tests for the echo command.""" -from chatom import Channel, User +from chatom import Channel, Message, User from csp_bot.commands.echo import EchoCommand, EchoCommandModel from csp_bot.structs import BotCommand, CommandVariant @@ -39,7 +39,7 @@ def test_execute_with_args(self): source=User(id="u1", name="sender"), targets=(), variant=CommandVariant.REPLY, - message=None, + message=Message(content="test"), ) result = cmd.execute(bot_cmd) @@ -62,7 +62,7 @@ def test_execute_with_no_args_returns_none(self): source=User(id="u1", name="sender"), targets=(), variant=CommandVariant.REPLY, - message=None, + message=Message(content="test"), ) result = cmd.execute(bot_cmd) @@ -70,20 +70,10 @@ def test_execute_with_no_args_returns_none(self): assert result is None def test_execute_with_targets(self): - """Test executing echo with targets. - - Note: The echo command calls to_chatom_user() on targets, but the - targets are already chatom.User objects. For this test we use a mock. - """ - from unittest.mock import Mock - + """Test executing echo with targets.""" cmd = EchoCommand() channel = Channel(id="ch1", name="test-channel") - # Create a mock target that has the to_chatom_user method - mock_target = Mock() - mock_target.to_chatom_user.return_value = User(id="u123", name="testuser") - bot_cmd = BotCommand( backend="slack", command="echo", @@ -91,9 +81,9 @@ def test_execute_with_targets(self): channel_id=channel.id, channel_name=channel.name, source=User(id="u1", name="sender"), - targets=(mock_target,), + targets=(User(id="u123", name="testuser"),), variant=CommandVariant.REPLY_TO_OTHER, - message=None, + message=Message(content="test"), ) result = cmd.execute(bot_cmd) @@ -104,20 +94,10 @@ def test_execute_with_targets(self): assert "<@u123>" in result.content or "testuser" in result.content def test_execute_with_only_targets(self): - """Test executing echo with only targets (no args). - - Note: The echo command calls to_chatom_user() on targets, but the - targets are already chatom.User objects. For this test we use a mock. - """ - from unittest.mock import Mock - + """Test executing echo with only targets (no args).""" cmd = EchoCommand() channel = Channel(id="ch1", name="test-channel") - # Create a mock target that has the to_chatom_user method - mock_target = Mock() - mock_target.to_chatom_user.return_value = User(id="u123", name="testuser") - bot_cmd = BotCommand( backend="slack", command="echo", @@ -125,9 +105,9 @@ def test_execute_with_only_targets(self): channel_id=channel.id, channel_name=channel.name, source=User(id="u1", name="sender"), - targets=(mock_target,), + targets=(User(id="u123", name="testuser"),), variant=CommandVariant.REPLY_TO_OTHER, - message=None, + message=Message(content="test"), ) result = cmd.execute(bot_cmd) From f320900b3281e4f12ae3fe1f0b4f22b400a2dd2c Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Tue, 21 Apr 2026 15:10:20 -0400 Subject: [PATCH 3/4] Load commands via entrypoints, enforce backend compatibility Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com> --- csp_bot/bot.py | 97 +++++++++++++++++--- csp_bot/commands/echo.py | 3 +- csp_bot/tests/test_bot_integration.py | 112 +++++++++++++++++++++++- csp_bot/tests/test_command_framework.py | 40 --------- csp_bot/utils.py | 5 -- 5 files changed, 197 insertions(+), 60 deletions(-) diff --git a/csp_bot/bot.py b/csp_bot/bot.py index 3333987..058a4db 100644 --- a/csp_bot/bot.py +++ b/csp_bot/bot.py @@ -12,6 +12,7 @@ import asyncio import html +import importlib.metadata as importlib_metadata import re import threading import time @@ -88,6 +89,8 @@ class Bot(GatewayModule): _thread: Optional[threading.Thread] = PrivateAttr(None) _lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + _KNOWN_BACKENDS: Set[str] = {"discord", "slack", "symphony", "telegram"} + def set_deps(self, deps: Any) -> None: """Set shared dependency object for new command framework contexts.""" self._deps = deps @@ -334,6 +337,8 @@ def load_commands(self, command_models: List[Any]) -> None: Supports both legacy BaseCommandModel and the new CommandModel. """ log.info(f"Loading {len(command_models)} commands...") + self._load_entrypoint_commands() + active_backends = self._active_backends() for model in command_models: try: command = model.command() @@ -350,6 +355,9 @@ def load_commands(self, command_models: List[Any]) -> None: else: raise TypeError(f"Unsupported command type from model {type(model).__name__}: {type(command).__name__}") + if not self._is_command_backend_compatible(command_str, runner, active_backends): + continue + log.info(f"Registered command: /{command_str}") if command_str in self._commands: raise Exception(f"Command already registered: {command_str}\n\t{command}\n\t{self._commands[command_str]}") @@ -362,9 +370,86 @@ def load_commands(self, command_models: List[Any]) -> None: for command_name, entry in get_registered_commands().items(): if command_name in self._commands: continue + if not self._is_command_backend_compatible(command_name, entry, active_backends): + continue log.info(f"Registered decorated command: /{command_name}") self._commands[command_name] = entry + def _load_entrypoint_commands(self) -> None: + """Load command plugins from Python entry points. + + Entry points in the ``csp_bot.commands`` group are imported so they can + register commands through decorators or module import side effects. + If the loaded object is callable, it is invoked with no arguments. + """ + try: + try: + entry_points = importlib_metadata.entry_points(group="csp_bot.commands") + except TypeError: + all_entry_points = importlib_metadata.entry_points() + entry_points = all_entry_points.get("csp_bot.commands", []) + except Exception: + log.exception("Failed to discover csp_bot.commands entry points") + return + + for entry_point in entry_points: + try: + loaded = entry_point.load() + except Exception: + log.exception("Failed to load command entry point: %s", getattr(entry_point, "name", "")) + continue + + if callable(loaded): + try: + loaded() + except Exception: + log.exception("Failed to initialize command entry point: %s", getattr(entry_point, "name", "")) + continue + + log.info("Loaded command entry point: %s", getattr(entry_point, "name", "")) + + def _active_backends(self) -> Set[str]: + """Return configured backends for this bot instance.""" + active: Set[str] = set() + if self.config.discord: + active.add("discord") + if self.config.slack: + active.add("slack") + if self.config.symphony: + active.add("symphony") + return active + + def _normalize_command_backends(self, command_name: str, backends: List[str]) -> List[str]: + """Normalize and validate declared command backends.""" + normalized = [b.lower() for b in backends] + unknown = sorted({b for b in normalized if b not in self._KNOWN_BACKENDS}) + if unknown: + raise ValueError(f"Command '{command_name}' declared unknown backends: {', '.join(unknown)}") + return normalized + + def _is_command_backend_compatible(self, command_name: str, command_runner: Any, active_backends: Set[str]) -> bool: + """Check registration-time backend compatibility for a command.""" + declared_backends = self._command_backends(command_runner) + if not declared_backends: + return True + + normalized = self._normalize_command_backends(command_name, declared_backends) + + # If no backends are configured yet, keep command registration permissive. + if not active_backends: + return True + + if active_backends.intersection(normalized): + return True + + log.info( + "Skipping command /%s: declared backends %s do not match active backends %s", + command_name, + normalized, + sorted(active_backends), + ) + return False + def _command_backends(self, command_runner: Any) -> List[str]: """Return supported backends for either legacy or new command types.""" if isinstance(command_runner, BaseCommand): @@ -394,10 +479,6 @@ def _build_command_context(self, cmd: BotCommand) -> CommandContext: deps=self._deps, ) - # ========================================================================= - # Message Processing Nodes - # ========================================================================= - @csp.node def _process_incoming_messages(self, msg: ts[Message]) -> Outputs(bot_commands=ts[[BotCommand]], unauthorized_message=ts[Message]): """Process incoming messages to extract bot commands. @@ -521,10 +602,6 @@ def _handle_commands(self, cmd: ts[BotCommand]) -> Outputs(messages=ts[[Message] csp.schedule_alarm(a_ratelimit, timedelta(seconds=self.config.ratelimit_seconds), True) - # ========================================================================= - # Message Analysis using chatom - # ========================================================================= - def _is_message_to_bot(self, msg: Message, backend: str) -> Tuple[bool, str, str, List[User]]: """Check if a message is directed at the bot. @@ -713,10 +790,6 @@ def _is_authorized(self, msg: Message, backend: str) -> bool: authorized = self._authorized_users.get(backend, set()) return author_id in authorized - # ========================================================================= - # Command Extraction and Execution - # ========================================================================= - def _extract_commands( self, msg: Message, diff --git a/csp_bot/commands/echo.py b/csp_bot/commands/echo.py index 184b814..9f7b2d4 100644 --- a/csp_bot/commands/echo.py +++ b/csp_bot/commands/echo.py @@ -36,8 +36,7 @@ def execute(self, command: BotCommand) -> Optional[Message]: # Add mentions for any tagged users if command.targets: - users = [t.to_chatom_user() if hasattr(t, "to_chatom_user") else t for t in command.targets] - mentions = mention_users(users, command.backend) + mentions = mention_users(list(command.targets), command.backend) if mentions: content = f"{content} {mentions}".strip() diff --git a/csp_bot/tests/test_bot_integration.py b/csp_bot/tests/test_bot_integration.py index 6127792..c5f49fd 100644 --- a/csp_bot/tests/test_bot_integration.py +++ b/csp_bot/tests/test_bot_integration.py @@ -17,7 +17,7 @@ from csp_bot import Bot, BotCommand, BotConfig, BotMessage from csp_bot.bot_config import SymphonyConfig from csp_bot.commands import HelpCommand, ReplyToOtherCommand -from csp_bot.commands.framework import Command, clear_registry, command +from csp_bot.commands.framework import Command, CommandModel, clear_registry, command from csp_bot.structs import CommandVariant # Test Fixtures @@ -371,6 +371,116 @@ def execute(self, ctx): assert results[0].content == "token=abc123" +class TestRegistrationTimeBackendPolicy: + """Tests for registration-time backend compatibility checks.""" + + def setup_method(self): + clear_registry() + + def teardown_method(self): + clear_registry() + + def test_decorated_command_skipped_when_backend_not_active(self, bot_with_symphony): + """Commands limited to inactive backends should not be registered.""" + + @command(name="slack_only", help="Slack only", backends=["slack"]) + def slack_only(ctx): + return "nope" + + bot_with_symphony.load_commands([]) + + assert "slack_only" not in bot_with_symphony._commands + + def test_decorated_command_registered_when_backend_active(self, bot_with_symphony): + """Commands limited to active backends should be registered.""" + + @command(name="symphony_only", help="Symphony only", backends=["symphony"]) + def symphony_only(ctx): + return "ok" + + bot_with_symphony.load_commands([]) + + assert "symphony_only" in bot_with_symphony._commands + + def test_invalid_backend_name_raises_on_registration(self, bot_with_symphony): + """Unknown backend names should fail fast during registration.""" + + @command(name="bad_backend", help="Bad backend", backends=["not-a-backend"]) + def bad_backend(ctx): + return "nope" + + with pytest.raises(ValueError, match="unknown backends"): + bot_with_symphony.load_commands([]) + + def test_model_command_skipped_when_backend_not_active(self, bot_with_symphony): + """Model-loaded commands should obey registration-time backend filtering.""" + + class SlackOnlyCommand(Command): + name: str = "model_slack_only" + help: str = "Slack-only model command" + backends: list[str] = ["slack"] + + def execute(self, ctx): + return "nope" + + model = CommandModel(command=SlackOnlyCommand) + bot_with_symphony.load_commands([model]) + + assert "model_slack_only" not in bot_with_symphony._commands + + +class TestEntryPointCommandDiscovery: + """Tests for plugin command discovery through Python entry points.""" + + def setup_method(self): + clear_registry() + + def teardown_method(self): + clear_registry() + + def test_load_commands_discovers_entrypoint_registered_command(self, bot_with_symphony): + """Entry-point loader should import plugin and register its decorated command.""" + + def register_plugin_command(): + @command(name="from_ep", help="Registered from entry point") + def from_ep(ctx): + return "ok" + + entry_point = MagicMock() + entry_point.name = "plugin.from_ep" + entry_point.load.return_value = register_plugin_command + + with patch("csp_bot.bot.importlib_metadata.entry_points", return_value=[entry_point]): + bot_with_symphony.load_commands([]) + + assert "from_ep" in bot_with_symphony._commands + + def test_model_command_precedence_over_entrypoint_command(self, bot_with_symphony): + """Explicit model registration should win when entry point uses same command name.""" + + def register_plugin_command(): + @command(name="clash", help="Plugin command") + def clash(ctx): + return "plugin" + + class ClashModelCommand(Command): + name: str = "clash" + help: str = "Model command" + + def execute(self, ctx): + return "model" + + entry_point = MagicMock() + entry_point.name = "plugin.clash" + entry_point.load.return_value = register_plugin_command + + with patch("csp_bot.bot.importlib_metadata.entry_points", return_value=[entry_point]): + bot_with_symphony.load_commands([CommandModel(command=ClashModelCommand)]) + + assert "clash" in bot_with_symphony._commands + assert isinstance(bot_with_symphony._commands["clash"], ClashModelCommand) + + # Command Argument Parsing Tests diff --git a/csp_bot/tests/test_command_framework.py b/csp_bot/tests/test_command_framework.py index 02ad0d6..afe2ce9 100644 --- a/csp_bot/tests/test_command_framework.py +++ b/csp_bot/tests/test_command_framework.py @@ -37,11 +37,6 @@ def _make_ctx(**overrides) -> CommandContext: return CommandContext(**defaults) -# =========================================================================== -# CommandContext tests -# =========================================================================== - - class TestCommandContext: def test_basic_attributes(self): ctx = _make_ctx() @@ -109,11 +104,6 @@ def test_deps_accessible(self): assert ctx.deps["api_key"] == "abc" -# =========================================================================== -# @command decorator tests -# =========================================================================== - - class TestCommandDecorator: def setup_method(self): clear_registry() @@ -163,11 +153,6 @@ def echo(ctx): assert echo(ctx) == "hello world" -# =========================================================================== -# Command class tests -# =========================================================================== - - class TestCommandClass: def test_subclass_with_fields(self): class MyCmd(Command): @@ -205,11 +190,6 @@ def test_base_command_raises_not_implemented(self): cmd.execute(ctx) -# =========================================================================== -# Executor tests — four signatures -# =========================================================================== - - class TestExecutorSync: def test_sync_returns_str(self): def echo(ctx): @@ -491,11 +471,6 @@ async def bad_agen(ctx): execute_command_func(bad_agen, ctx) -# =========================================================================== -# Executor — Command class integration -# =========================================================================== - - class TestExecutorWithCommandClass: def test_sync_command_class(self): class Echo(Command): @@ -552,11 +527,6 @@ async def execute(self, ctx): assert len(results) == 2 -# =========================================================================== -# _coerce_response tests -# =========================================================================== - - class TestCoerceResponse: def test_none(self): assert _coerce_response(None, "slack") is None @@ -590,11 +560,6 @@ def test_unknown_type(self): assert result.content == "42" -# =========================================================================== -# Legacy adapter tests -# =========================================================================== - - class _FakeEchoCommand(ReplyToOtherCommand): """Minimal legacy command for testing.""" @@ -644,11 +609,6 @@ def test_context_to_bot_command(self): assert bot_cmd.variant == CommandVariant.REPLY_TO_OTHER -# =========================================================================== -# End-to-end: decorated command through executor -# =========================================================================== - - class TestEndToEnd: def setup_method(self): clear_registry() diff --git a/csp_bot/utils.py b/csp_bot/utils.py index 346fc15..8feb5db 100644 --- a/csp_bot/utils.py +++ b/csp_bot/utils.py @@ -120,11 +120,6 @@ def get_backend_format(backend: Backend) -> Format: return get_format_for_backend(backend) -# ============================================================================ -# Symphony MessageML formatting utilities -# ============================================================================ - - def format_with_message_ml(text: str, to_message_ml: bool = True) -> str: """Convert text to/from Symphony MessageML format. From 96ad4b83f746d56441d74b67a416e567cf306d39 Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Tue, 21 Apr 2026 21:07:55 -0400 Subject: [PATCH 4/4] Migrate built-in commands from hand-crafted per-backend strings to FormattedMessage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit executor.py — New _extract_attachments() helper converts FormattedImage/FormattedAttachment from a FormattedMessage into base Attachment objects; _coerce_response() now populates Message.attachments when returning a FormattedMessage echo.py — Replaced manual mention_users() calls with FormattedMessage + UserMention nodes help.py — Replaced per-backend if/elif rendering (Symphony , Slack mrkdwn, Discord markdown table) with FormattedMessage + Heading + Table.from_dict_list() that auto-renders for each backend status.py — Same migration: builds a Table.from_dict_list() of metrics instead of manual **bold** markdown lines schedule.py — Same migration for the schedule listing test_command_framework.py — 3 new tests for _coerce_response() covering images in content, file attachments, and mixed content+attachments Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com> --- csp_bot/__init__.py | 8 + csp_bot/bot.py | 193 +++++++- csp_bot/commands/__init__.py | 6 + csp_bot/commands/agent.py | 573 ++++++++++++++++++++++++ csp_bot/commands/context.py | 44 +- csp_bot/commands/echo.py | 32 +- csp_bot/commands/executor.py | 54 +++ csp_bot/commands/help.py | 87 ++-- csp_bot/commands/schedule.py | 32 +- csp_bot/commands/status.py | 39 +- csp_bot/persistence.py | 352 +++++++++++++++ csp_bot/structs.py | 3 + csp_bot/tests/test_agent_command.py | 496 ++++++++++++++++++++ csp_bot/tests/test_command_framework.py | 94 ++++ csp_bot/tests/test_persistence.py | 213 +++++++++ csp_bot/tests/test_schedule.py | 107 +++++ pyproject.toml | 7 + 17 files changed, 2227 insertions(+), 113 deletions(-) create mode 100644 csp_bot/commands/agent.py create mode 100644 csp_bot/persistence.py create mode 100644 csp_bot/tests/test_agent_command.py create mode 100644 csp_bot/tests/test_persistence.py create mode 100644 csp_bot/tests/test_schedule.py diff --git a/csp_bot/__init__.py b/csp_bot/__init__.py index f17659d..2d67d52 100644 --- a/csp_bot/__init__.py +++ b/csp_bot/__init__.py @@ -39,6 +39,7 @@ mention_user, ) from .gateway import CspBotGateway, Gateway, GatewayChannels, GatewayModule, GatewaySettings +from .persistence import FsspecStateStore, InMemoryStateStore, ScheduledCommandRecord, ScheduleStore, StateStore, StoredRecord from .structs import Backend, BotCommand, BotMessage, CommandVariant from .utils import format_message, get_backend_format, is_valid_url, mention_users @@ -85,6 +86,13 @@ "GatewayChannels", "GatewayModule", "GatewaySettings", + # Persistence + "FsspecStateStore", + "InMemoryStateStore", + "ScheduleStore", + "ScheduledCommandRecord", + "StateStore", + "StoredRecord", # Structs "Backend", "BotCommand", diff --git a/csp_bot/bot.py b/csp_bot/bot.py index 058a4db..a336696 100644 --- a/csp_bot/bot.py +++ b/csp_bot/bot.py @@ -49,6 +49,7 @@ get_registered_commands, ) from .gateway import GatewayChannels, GatewayModule +from .persistence import InMemoryStateStore, ScheduledCommandRecord, ScheduleStore, StateStore from .structs import ( Backend, BotCommand, @@ -81,7 +82,7 @@ class Bot(GatewayModule): _configs: Dict[Backend, Any] = PrivateAttr(default_factory=dict) _adapters: Dict[Backend, Any] = PrivateAttr(default_factory=dict) _connected_backends: Dict[Backend, Tuple[Any, asyncio.AbstractEventLoop]] = PrivateAttr(default_factory=dict) - _scheduled: Dict[str, BotCommand] = PrivateAttr(default_factory=dict) + _schedule_store: ScheduleStore = PrivateAttr(default_factory=lambda: ScheduleStore(InMemoryStateStore())) _authorized_users: Dict[Backend, Set[str]] = PrivateAttr(default_factory=dict) _bot_user_ids: Dict[Backend, str] = PrivateAttr(default_factory=dict) _bot_names: Dict[Backend, str] = PrivateAttr(default_factory=dict) @@ -91,6 +92,41 @@ class Bot(GatewayModule): _KNOWN_BACKENDS: Set[str] = {"discord", "slack", "symphony", "telegram"} + @staticmethod + def _datetime_for_now(value: Optional[datetime], now: datetime) -> Optional[datetime]: + # Persistence records use aware UTC datetimes. csp.now() is naive in + # current runtime tests, so normalize only at the csp scheduling edge. + if value is None: + return None + if value.tzinfo is not None and now.tzinfo is None: + return value.replace(tzinfo=None) + if value.tzinfo is None and now.tzinfo is not None: + return value.replace(tzinfo=now.tzinfo) + return value + + def set_state_store(self, state_store: StateStore) -> None: + """Inject the state store used for bot runtime persistence.""" + self.set_schedule_store(ScheduleStore(state_store)) + + def set_schedule_store(self, schedule_store: ScheduleStore) -> None: + """Inject a schedule store for delayed and recurring commands.""" + self._schedule_store = schedule_store + + def _restore_scheduled_commands(self, now: datetime) -> List[ScheduledCommandRecord]: + """Return future scheduled commands that should be re-armed.""" + restored = [] + for record in self._schedule_store.records(): + next_run_at = self._datetime_for_now(record.next_run_at, now) + if next_run_at and next_run_at >= now: + restored.append(record) + return restored + + def _store_scheduled_command(self, cmd: BotCommand, next_run_at: datetime) -> ScheduledCommandRecord: + return self._schedule_store.put(cmd, schedule_id=cmd.schedule_id or None, next_run_at=next_run_at) + + def _remove_scheduled_command(self, schedule_id: str) -> bool: + return self._schedule_store.remove(schedule_id) + def set_deps(self, deps: Any) -> None: """Set shared dependency object for new command framework contexts.""" self._deps = deps @@ -126,6 +162,9 @@ def connect(self, channels: GatewayChannels) -> None: log.info(f"Fetching bot info for {backend}...") self._fetch_bot_info(backend) + # Inject backends into AgentCommand subclasses + self._inject_backends_into_agent_commands() + # Subscribe to messages from all adapters # chatom provides unified Message type across all backends messages_in = csp.null_ts(Message) @@ -248,6 +287,53 @@ def _update_user_access_loop(self, backend: str) -> None: except Exception: log.exception(f"Error updating user access for {backend}") + def _inject_backends_into_agent_commands(self) -> None: + """Inject connected BackendBase instances into AgentCommand subclasses.""" + try: + from csp_bot.commands.agent import AgentCommand + except ImportError: + # pydantic-ai / chatom[agent] not installed — skip silently + return + + backends = {} + loops = {} + for name in self._adapters: + result = self._ensure_backend_connected(name) + if result: + connected_backend, loop = result + backends[name] = connected_backend + loops[name] = loop + if backends: + AgentCommand.set_backends(backends, loops=loops) + log.info(f"Injected {len(backends)} connected backends into AgentCommand: {list(backends.keys())}") + + def _track_agent_session_response(self, response: Message, command: BotCommand) -> None: + """Associate a sent response with an agent session for reply tracking. + + Uses the original command's message ID as the key that future replies + will reference (e.g., thread_ts in Slack, or message reference in Discord). + """ + metadata = response.metadata or {} + session_key = metadata.get("agent_session_key") + if not session_key: + return + + try: + from csp_bot.commands.agent import AgentCommand + except ImportError: + return + + # Use the original command's message ID — in Slack threads, replies + # reference this as thread_ts; in Discord, as message_reference. + orig_msg_id = command.message.id if command.message else None + if orig_msg_id: + AgentCommand._sessions.update_response_id(session_key, orig_msg_id) + log.debug(f"Tracked agent session response: session={session_key}, msg_id={orig_msg_id}") + + # Also track by the response message ID if it has one + if response.id and response.id != orig_msg_id: + AgentCommand._sessions.update_response_id(session_key, response.id) + def _ensure_backend_connected(self, backend: str) -> Optional[Tuple[Any, asyncio.AbstractEventLoop]]: """Ensure a connected backend exists for the given platform. @@ -531,10 +617,17 @@ def _handle_commands(self, cmd: ts[BotCommand]) -> Outputs(messages=ts[[Message] with csp.start(): csp.schedule_alarm(a_ratelimit, timedelta(seconds=self.config.ratelimit_seconds), True) + now = csp.now() + for record in self._restore_scheduled_commands(now): + next_run_at = self._datetime_for_now(record.next_run_at, now) + if next_run_at: + csp.schedule_alarm(a_scheduled, next_run_at, record.command) # Handle scheduled command triggers if csp.ticked(a_scheduled): - if a_scheduled.command in self._scheduled: + # Removed schedules may still have an outstanding CSP alarm; the + # store is the source of truth and acts as the tombstone check. + if self._schedule_store.get(a_scheduled.schedule_id) is not None: s_to_process.append(a_scheduled) # Reschedule recurring commands @@ -542,23 +635,25 @@ def _handle_commands(self, cmd: ts[BotCommand]) -> Outputs(messages=ts[[Message] now = csp.now() next_time = croniter(a_scheduled.schedule, now).get_next(datetime) if next_time >= now: + self._store_scheduled_command(a_scheduled, next_time) csp.schedule_alarm(a_scheduled, next_time, a_scheduled) else: - self._scheduled.pop(a_scheduled.command, None) + self._remove_scheduled_command(a_scheduled.schedule_id) # Handle new commands if csp.ticked(cmd): now = csp.now() + delay = self._datetime_for_now(cmd.delay, now) # Check for delayed execution - if cmd.delay and cmd.delay >= now: - self._scheduled[cmd.command] = cmd - csp.schedule_alarm(a_scheduled, cmd.delay, cmd) + if delay and delay >= now: + self._store_scheduled_command(cmd, delay) + csp.schedule_alarm(a_scheduled, delay, cmd) # Check for scheduled execution elif cmd.schedule: next_time = croniter(cmd.schedule, now).get_next(datetime) if next_time >= now: - self._scheduled[cmd.command] = cmd + self._store_scheduled_command(cmd, next_time) csp.schedule_alarm(a_scheduled, next_time, cmd) else: s_to_process.append(cmd) @@ -579,6 +674,8 @@ def _handle_commands(self, cmd: ts[BotCommand]) -> Outputs(messages=ts[[Message] if isinstance(item, Message): log.debug(f"Adding message to buffer: {item.content[:100] if item.content else 'empty'}...") s_buffer.append(item) + # Track agent session responses for reply continuity + self._track_agent_session_response(item, command) elif isinstance(item, BotCommand): next_cycle_commands.append(item) else: @@ -821,6 +918,11 @@ def _extract_commands( # Check for command syntax (supports both / and ! prefixes) if not content.startswith("/") and not content.startswith("!"): + # Check if this is a reply to an active agent session + session_cmd = self._check_agent_session_reply(msg, backend, channel_id) + if session_cmd: + return session_cmd + # If tagged but no command, show help log.info("No command prefix, showing help") return self._create_help_command(msg, backend, channel_id) @@ -905,7 +1007,7 @@ def _extract_commands( # Pre-execute hooks if isinstance(command_runner, ScheduleCommand): - return command_runner.preexecute(bot_cmd, self._scheduled, self) + return command_runner.preexecute(bot_cmd, self._schedule_store, self) elif isinstance(command_runner, StatusCommand): return command_runner.preexecute(bot_cmd, self) return command_runner.preexecute(bot_cmd) @@ -914,6 +1016,79 @@ def _extract_commands( log.exception("Error extracting command") return None + def _check_agent_session_reply(self, msg: Message, backend: str, channel_id: str) -> Optional[BotCommand]: + """Check if the message is a reply to a bot response with an active agent session. + + If so, constructs a BotCommand to continue the conversation. + """ + try: + from csp_bot.commands.agent import AgentCommand + except ImportError: + return None + + # Get the referenced message ID + ref_id = None + if msg.reference and msg.reference.message_id: + ref_id = msg.reference.message_id + elif msg.reply_to and msg.reply_to.id: + ref_id = msg.reply_to.id + # Check thread metadata (Slack thread_ts) + if not ref_id and msg.thread and msg.thread.id: + ref_id = msg.thread.id + + if not ref_id: + return None + + # Look up session by the bot response ID + session = AgentCommand._sessions.get_by_response_id(ref_id) + if session is None: + return None + + # Found an active session — route the reply to the same command + command_name = session.command_name + if command_name not in self._commands: + log.warning(f"Agent session references unknown command: {command_name}") + return None + + command_runner = self._commands[command_name] + content = msg.content or "" + # Strip bot mention if present + bot_name = self._get_bot_name(backend) + bot_id = self._get_bot_id(backend) + if bot_id: + content = re.sub(rf"<@!?{re.escape(bot_id)}>", "", content).strip() + if bot_name and content.startswith(f"@{bot_name}"): + content = content[len(f"@{bot_name}") :].strip() + + source = User( + id=msg.author.id if msg.author else msg.author_id or "", + name=msg.author.name if msg.author else "", + email=getattr(msg.author, "email", "") if msg.author else "", + handle=getattr(msg.author, "handle", "") if msg.author else "", + ) + + channel_name = "" + if msg.channel and hasattr(msg.channel, "name"): + channel_name = msg.channel.name or "" + + bot_cmd = BotCommand( + command=command_name, + args=(content,), + source=source, + targets=(), + channel_id=channel_id, + channel_name=channel_name, + backend=backend, + variant=command_runner.kind() if isinstance(command_runner, BaseCommand) else CommandVariant.REPLY, + message=msg, + delay=None, + schedule="", + times_run=0, + ) + + log.info(f"Routing reply to active agent session: command={command_name}, user={source.id}") + return command_runner.preexecute(bot_cmd) + def _parse_command_args( self, tokens: List[str], @@ -1024,7 +1199,7 @@ def _execute_command(self, cmd: BotCommand) -> Optional[Union[Message, List[Mess if isinstance(command_runner, HelpCommand): responses = command_runner.execute(cmd, MappingProxyType(self._commands)) elif isinstance(command_runner, ScheduleCommand): - responses = command_runner.execute(cmd, self._scheduled) + responses = command_runner.execute(cmd, self._schedule_store) else: responses = command_runner.execute(cmd) elif isinstance(command_runner, Command): diff --git a/csp_bot/commands/__init__.py b/csp_bot/commands/__init__.py index 4419b89..1d62674 100644 --- a/csp_bot/commands/__init__.py +++ b/csp_bot/commands/__init__.py @@ -32,6 +32,11 @@ from .schedule import ScheduleCommand, ScheduleCommandModel from .status import StatusCommand, StatusCommandModel +try: + from .agent import AgentCommand +except ImportError: + pass + __all__ = ( # New framework "Command", @@ -45,6 +50,7 @@ "get_registered_commands", "execute_command_func", # Legacy base classes + "AgentCommand", "BaseCommand", "BaseCommandModel", "NoResponseCommand", diff --git a/csp_bot/commands/agent.py b/csp_bot/commands/agent.py new file mode 100644 index 0000000..ce8ca74 --- /dev/null +++ b/csp_bot/commands/agent.py @@ -0,0 +1,573 @@ +"""AgentCommand — base class for LLM-powered bot commands. + +Provides the boilerplate for running a pydantic-ai Agent from within +a csp-bot command: background execution via a thread pool, automatic +rescheduling until the LLM call completes, and formatted response output. + +Supports **stateful sessions**: when a user replies to a bot response (or +invokes the same command again within a time window), the conversation +history is resumed automatically. Sessions expire after a configurable TTL. + +Subclasses implement :meth:`build_agent` and :meth:`build_prompt`. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import threading +from abc import abstractmethod +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union + +from chatom import Channel, Message +from chatom.backend import BackendBase +from chatom.format import Format, convert_format + +from csp_bot.commands.base import BaseCommand, ReplyCommand +from csp_bot.structs import BotCommand + +try: + from chatom.agent import BackendToolset + from chatom.agent.toolset import AccessPolicy + from pydantic_ai import Agent + from pydantic_ai.messages import ModelMessage +except ImportError as e: + raise ImportError("AgentCommand requires the 'agent' extra. Install with: pip install csp-bot[agent]") from e + +log = logging.getLogger(__name__) + +__all__ = ("AgentCommand",) + +_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="agent-cmd") + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) + + +@dataclass +class AgentSession: + """Tracks a multi-turn conversation between a user and an agent command.""" + + user_id: str + channel_id: str + command_name: str + message_history: List[ModelMessage] = field(default_factory=list) + last_active: datetime = field(default_factory=_utc_now) + bot_response_id: Optional[str] = None # ID of last bot message (for reply matching) + + def touch(self) -> None: + self.last_active = _utc_now() + + def is_expired(self, ttl_seconds: float) -> bool: + return (_utc_now() - self.last_active).total_seconds() > ttl_seconds + + +class SessionStore: + """Thread-safe store for agent sessions with automatic expiry.""" + + def __init__(self, ttl_seconds: float = 900.0): + self._sessions: Dict[str, AgentSession] = {} + self._response_index: Dict[str, str] = {} # bot_response_id -> session_key + self._ttl = ttl_seconds + self._lock = threading.Lock() + + def get(self, key: str) -> Optional[AgentSession]: + with self._lock: + session = self._sessions.get(key) + if session and session.is_expired(self._ttl): + self._remove_session(key) + return None + return session + + def get_by_response_id(self, response_id: str) -> Optional[AgentSession]: + """Look up a session by the bot's response message ID (for replies).""" + with self._lock: + key = self._response_index.get(response_id) + if key is None: + return None + session = self._sessions.get(key) + if session and session.is_expired(self._ttl): + self._remove_session(key) + return None + return session + + def put(self, key: str, session: AgentSession) -> None: + with self._lock: + self._sessions[key] = session + if session.bot_response_id: + self._response_index[session.bot_response_id] = key + + def update_response_id(self, key: str, response_id: str) -> None: + """Associate a bot response message ID with a session.""" + with self._lock: + session = self._sessions.get(key) + if session: + # Remove old mapping + if session.bot_response_id and session.bot_response_id in self._response_index: + del self._response_index[session.bot_response_id] + session.bot_response_id = response_id + self._response_index[response_id] = key + + def _remove_session(self, key: str) -> None: + """Remove a session (caller must hold lock).""" + session = self._sessions.pop(key, None) + if session and session.bot_response_id: + self._response_index.pop(session.bot_response_id, None) + + def cleanup_expired(self) -> int: + """Remove all expired sessions. Returns count removed.""" + with self._lock: + expired = [k for k, s in self._sessions.items() if s.is_expired(self._ttl)] + for key in expired: + self._remove_session(key) + return len(expired) + + +def _run_agent( + agent: Agent, + prompt: Union[str, Sequence[Any]], + loop: Optional[asyncio.AbstractEventLoop] = None, + message_history: Optional[Sequence[ModelMessage]] = None, +) -> Any: + """Run an agent on an event loop (for use in thread pool). + + If *loop* is provided (i.e. the backend's event loop), it is reused so + that aiohttp sessions bound to it remain valid. Otherwise a fresh loop + is created. + + ``prompt`` may be a plain string or a sequence of pydantic-ai user + content parts (e.g. text plus :class:`~pydantic_ai.BinaryContent` + images), enabling multimodal input. + """ + coro = agent.run(prompt, message_history=message_history) + if loop is not None and loop.is_running(): + future = asyncio.run_coroutine_threadsafe(coro, loop) + return future.result() + + owns_loop = loop is None + if owns_loop: + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + if owns_loop: + loop.close() + + +class AgentCommand(ReplyCommand): + """Base class for commands that run a pydantic-ai agent with session support. + + Subclasses implement :meth:`build_agent` and :meth:`build_prompt`, + and receive a :class:`~chatom.agent.BackendToolset` for free via + :meth:`build_toolset`. + + **Session continuity**: conversation history is preserved per + user+channel+command. If the user replies to the bot's response + (thread reply or message reference), the prior session is resumed. + Sessions expire after :attr:`session_ttl_seconds`. + + Example:: + + class AskCommand(AgentCommand): + def command(self): return "ask" + def name(self): return "Ask" + def help(self): return "/ask — Ask the AI (reply to continue)" + + def build_agent(self, command): + toolset = self.build_toolset(command) + return Agent( + "anthropic:claude-sonnet-4-6", + toolsets=[toolset] if toolset else [], + instructions="You are a helpful assistant.", + ) + + def build_prompt(self, command): + return " ".join(command.args) + """ + + _backends: ClassVar[Dict[str, BackendBase]] = {} + _backend_loops: ClassVar[Dict[str, asyncio.AbstractEventLoop]] = {} + _futures: ClassVar[Dict[str, Future]] = {} + _sessions: ClassVar[SessionStore] = SessionStore(ttl_seconds=900.0) + + # Configurable delay between polling checks (seconds) + poll_interval: int = 2 + # Maximum time to wait for agent completion (seconds) + timeout: int = 120 + # Session time-to-live (seconds). 0 disables sessions. + session_ttl_seconds: float = 900.0 + # Send a status message every N poll cycles (0 disables) + status_every_n_polls: int = 15 + # When True, image attachments on the incoming message are downloaded + # and passed to the model as multimodal input (so the agent can "see" + # images the user posted). + include_incoming_images: bool = True + # Maximum size (bytes) of an incoming image to download for the model. + max_incoming_image_bytes: int = 5_000_000 + # Status messages shown to the user while processing + status_messages: ClassVar[List[str]] = [ + "Thinking...", + "Still working on it...", + "Processing your request...", + "Almost there...", + "Crunching the data...", + ] + + def __init__(self, *args, **kwargs): + pass + + @classmethod + def set_backends( + cls, + backends: Dict[str, BackendBase], + loops: Optional[Dict[str, asyncio.AbstractEventLoop]] = None, + ) -> None: + """Inject backend instances. Called by Bot after adapter setup.""" + cls._backends = backends + cls._backend_loops = loops or {} + + @classmethod + def set_session_ttl(cls, ttl_seconds: float) -> None: + """Reconfigure the session TTL.""" + cls._sessions = SessionStore(ttl_seconds=ttl_seconds) + + @abstractmethod + def build_agent(self, command: BotCommand) -> Agent: + """Return the pydantic-ai Agent to run for this command.""" + ... + + @abstractmethod + def build_prompt(self, command: BotCommand) -> str: + """Return the user prompt string.""" + ... + + def build_toolset(self, command: BotCommand) -> Optional[BackendToolset]: + """Return a BackendToolset for the command's backend, or None. + + The toolset is configured with an AccessPolicy that enforces: + - The agent can only access the channel where the command was invoked + - The requesting user must be a member of any target channel + - DM reads are blocked by default + - Message count is capped per request + + Subclasses can override :meth:`build_access_policy` to customize. + """ + backend = self._backends.get(command.backend) + if backend is None: + return None + policy = self.build_access_policy(command) + return BackendToolset(backend=backend, access_policy=policy) + + def build_access_policy(self, command: BotCommand) -> "AccessPolicy": + """Build the access policy for this command invocation. + + Override in subclasses to customize access rules. The default + policy restricts access to the channel where the user typed the + command (not the /room redirect target). + """ + # Use the original channel where the user issued the command, + # not the /room-redirected destination. + origin_channel_id = command.message.channel_id if command.message and command.message.channel_id else command.channel_id + return AccessPolicy( + requesting_user=command.source, + invoking_channel_id=origin_channel_id, + restrict_to_invoking_channel=True, + require_membership=True, + block_dm_reads=True, + ) + + def get_model(self, model_name: str = "claude-sonnet-4-6") -> Any: + """Return a model instance configured from environment variables. + + Checks ANTHROPIC_AUTH_TOKEN / ANTHROPIC_API_KEY and ANTHROPIC_BASE_URL + to construct a properly-configured provider. + """ + from pydantic_ai.models.anthropic import AnthropicModel + from pydantic_ai.providers.anthropic import AnthropicProvider + + api_key = os.environ.get("ANTHROPIC_API_KEY") or os.environ.get("ANTHROPIC_AUTH_TOKEN") + base_url = os.environ.get("ANTHROPIC_BASE_URL") + + provider = AnthropicProvider(api_key=api_key, base_url=base_url) + return AnthropicModel(model_name, provider=provider) + + def wrap_symphony_output(self, messageml: str, command: BotCommand) -> str: + """Hook to post-process Symphony MessageML before wrapping in . + + Override in subclasses to wrap output in expandable cards, add + headers, etc. Default is identity (no wrapping). + """ + return messageml + + def _session_key(self, command: BotCommand) -> str: + """Key for session lookup: command:user:channel.""" + return f"{self.command()}:{command.source.id}:{command.channel_id}" + + def _get_session(self, command: BotCommand) -> Optional[AgentSession]: + """Find an existing session — by reply reference or by user+channel.""" + # First: check if this is a reply to a bot message + msg = command.message + reply_to_id = getattr(msg, "reply_to_id", None) or (msg.reference.message_id if getattr(msg, "reference", None) else None) + if reply_to_id: + session = self._sessions.get_by_response_id(reply_to_id) + if session: + session.touch() + return session + + # Second: check by user+channel (same command re-invoked) + session = self._sessions.get(self._session_key(command)) + if session: + session.touch() + return session + + return None + + def _create_session(self, command: BotCommand) -> AgentSession: + """Create a new session for this command invocation.""" + session = AgentSession( + user_id=command.source.id, + channel_id=command.channel_id, + command_name=self.command(), + ) + self._sessions.put(self._session_key(command), session) + return session + + def _command_key(self, command: BotCommand) -> str: + """Unique key for tracking in-flight futures.""" + # Use the message ID for stability across reschedules + msg_id = command.message.id if command.message else "" + return f"{self.command()}:{command.source.id}:{msg_id}" + + def _status_channel(self, command: BotCommand) -> Channel: + """Return the channel where status/progress messages should go. + + When /room redirect is used, status updates go to the origin channel + (where the user typed the command) rather than the redirect destination. + """ + if command.message and command.message.channel_id: + origin_id = command.message.channel_id + origin_name = command.message.channel.name if command.message.channel else "" + if origin_id != command.channel_id: + return Channel(id=origin_id, name=origin_name) + return command.channel + + def _incoming_image_attachments(self, command: BotCommand) -> List[Any]: + """Return image attachments on the incoming message, if any.""" + msg = command.message + if not msg or not getattr(msg, "attachments", None): + return [] + images = [] + for att in msg.attachments: + content_type = getattr(att, "content_type", "") or "" + att_type = getattr(getattr(att, "attachment_type", None), "value", "") + if content_type.startswith("image/") or att_type == "image": + images.append(att) + return images + + def _build_model_prompt(self, command: BotCommand, prompt: str) -> Union[str, List[Any]]: + """Assemble the prompt for the model, attaching incoming images. + + Downloads any image attachments on the incoming message (via the + backend, on its event loop) and returns a list of pydantic-ai + content parts so the model can see them. Falls back to the plain + text prompt when there are no images or download is unavailable. + """ + if not self.include_incoming_images: + return prompt + + images = self._incoming_image_attachments(command) + if not images: + return prompt + + backend = self._backends.get(command.backend) + if backend is None: + return prompt + + from pydantic_ai import BinaryContent + + backend_loop = self._backend_loops.get(command.backend) + parts: List[Any] = [prompt] + for att in images: + if getattr(att, "size", None) and att.size > self.max_incoming_image_bytes: + log.warning("Skipping incoming image %r: %s bytes exceeds limit", getattr(att, "filename", ""), att.size) + continue + try: + data = self._download_on_loop(backend, att, command.message, backend_loop) + except Exception: + log.exception("Failed to download incoming image %r", getattr(att, "filename", "")) + continue + if not data or len(data) > self.max_incoming_image_bytes: + continue + media_type = getattr(att, "content_type", "") or "image/png" + parts.append(BinaryContent(data=data, media_type=media_type)) + + # Only return a multimodal list if we actually attached an image. + return parts if len(parts) > 1 else prompt + + @staticmethod + def _download_on_loop( + backend: BackendBase, + attachment: Any, + message: Optional[Message], + loop: Optional[asyncio.AbstractEventLoop], + ) -> bytes: + """Download an attachment, reusing the backend's event loop if running.""" + coro = backend.download_attachment(attachment, message=message) + if loop is not None and loop.is_running(): + return asyncio.run_coroutine_threadsafe(coro, loop).result() + new_loop = asyncio.new_event_loop() + try: + return new_loop.run_until_complete(coro) + finally: + new_loop.close() + + def preexecute(self, command: BotCommand) -> BotCommand: + """Submit the LLM call to a thread pool and schedule a check.""" + removed_sessions = self._sessions.cleanup_expired() + if removed_sessions: + log.debug("Cleaned up %d expired agent sessions", removed_sessions) + + key = self._command_key(command) + if key not in self._futures: + try: + agent = self.build_agent(command) + prompt_text = self.build_prompt(command) + prompt = self._build_model_prompt(command, prompt_text) + except Exception: + log.exception("Error building agent/prompt for %s", self.command()) + command.args = ("ERROR: Failed to initialize the AI agent.",) + return command + + # Resolve session history + session = self._get_session(command) or self._create_session(command) + history = session.message_history or None + + # Use the backend's event loop so aiohttp sessions stay valid + backend_loop = self._backend_loops.get(command.backend) + future = _executor.submit(_run_agent, agent, prompt, backend_loop, history) + self._futures[key] = future + log.info( + "AgentCommand[%s] submitted for user %s (session history: %d msgs)", + self.command(), + command.source.name, + len(session.message_history), + ) + + command.delay = _utc_now() + timedelta(seconds=self.poll_interval) + return command + + def execute(self, command: BotCommand) -> Optional[Union[Message, List[Union[Message, "BaseCommand"]], "BaseCommand"]]: + """Return result when ready; reschedule if still running.""" + # Handle errors from preexecute + if command.args and len(command.args) == 1 and str(command.args[0]).startswith("ERROR:"): + return Message( + content=str(command.args[0]), + channel=command.channel, + metadata={"backend": command.backend}, + ) + + key = self._command_key(command) + future = self._futures.get(key) + + if future is None: + log.warning("AgentCommand[%s] no future found for key %s", self.command(), key) + return Message( + content="Sorry, something went wrong processing your request.", + channel=command.channel, + metadata={"backend": command.backend}, + ) + + if not future.done(): + # Check timeout + elapsed = command.times_run * self.poll_interval + if elapsed >= self.timeout: + self._futures.pop(key, None) + future.cancel() + return Message( + content="Sorry, the AI request timed out. Please try again.", + channel=command.channel, + metadata={"backend": command.backend}, + ) + # Reschedule — optionally with a status message + command.delay = _utc_now() + timedelta(seconds=self.poll_interval) + command.times_run += 1 + + result: List[Any] = [command] + if self.status_every_n_polls and self.status_messages: + # Status messages go to the origin channel (where the user + # typed the command), NOT the /room redirect destination. + origin_channel = self._status_channel(command) + if command.times_run == 1: + # First poll: always send initial status + status_text = self.status_messages[0] + result.append( + Message( + content=status_text, + channel=origin_channel, + metadata={"backend": command.backend}, + ) + ) + elif command.times_run % self.status_every_n_polls == 0: + idx = (command.times_run // self.status_every_n_polls) % len(self.status_messages) + status_text = self.status_messages[idx] + result.append( + Message( + content=status_text, + channel=origin_channel, + metadata={"backend": command.backend}, + ) + ) + return result + + # Future is done — get result + self._futures.pop(key, None) + try: + result = future.result() + output = str(result.output) if hasattr(result, "output") else str(result) + + # Persist conversation history in session + session = self._get_session(command) or self._create_session(command) + if hasattr(result, "all_messages"): + session.message_history = list(result.all_messages()) + session.touch() + + except Exception: + log.exception("AgentCommand[%s] agent execution failed", self.command()) + return Message( + content="Sorry, there was an error processing your request.", + channel=command.channel, + metadata={"backend": command.backend}, + ) + + # For Symphony, LLM output is markdown that must be converted to + # MessageML. Pre-wrapping here prevents the backend from treating + # the content as pre-formatted MessageML. + if command.backend == "symphony": + messageml = convert_format(output, Format.MARKDOWN, Format.SYMPHONY_MESSAGEML) + messageml = self.wrap_symphony_output(messageml, command) + output = f"{messageml}" + + # Build response — include session key in metadata so the bot can + # associate the response message ID back to the session later. + response = Message( + content=output, + channel=command.channel, + metadata={ + "backend": command.backend, + "agent_session_key": self._session_key(command), + }, + ) + return response + + def on_response_sent(self, session_key: str, response_message_id: str) -> None: + """Associate a sent message ID with the session for reply tracking. + + Called by the Bot after the response message is published, so that + future replies to that message can be routed back to this session. + """ + self._sessions.update_response_id(session_key, response_message_id) diff --git a/csp_bot/commands/context.py b/csp_bot/commands/context.py index a3007a9..13dd167 100644 --- a/csp_bot/commands/context.py +++ b/csp_bot/commands/context.py @@ -165,10 +165,42 @@ def table( # Handle list of lists return Table.from_data(data, headers=headers) - def image(self, url: str, alt: str = "", title: str = "") -> FormattedImage: - """Create an image node.""" - return FormattedImage(url=url, alt_text=alt, title=title) + def image( + self, + url: str = "", + alt: str = "", + title: str = "", + *, + data: Optional[bytes] = None, + filename: str = "", + content_type: str = "", + ) -> FormattedImage: + """Create an image node. + + Provide either a ``url`` (the image is linked/unfurled) or raw + ``data`` bytes (the image is uploaded to the chat via the backend's + file-upload API). + """ + return FormattedImage( + url=url, + alt_text=alt, + title=title, + data=data, + filename=filename, + content_type=content_type, + ) - def attachment(self, url: str, filename: str, content_type: str = "") -> FormattedAttachment: - """Create an attachment node.""" - return FormattedAttachment(url=url, filename=filename, content_type=content_type) + def attachment( + self, + url: str = "", + filename: str = "", + content_type: str = "", + *, + data: Optional[bytes] = None, + ) -> FormattedAttachment: + """Create an attachment node. + + Provide either a ``url`` or raw ``data`` bytes. When ``data`` is + given the file is uploaded to the chat rather than linked. + """ + return FormattedAttachment(url=url, filename=filename, content_type=content_type, data=data) diff --git a/csp_bot/commands/echo.py b/csp_bot/commands/echo.py index 9f7b2d4..b4291e0 100644 --- a/csp_bot/commands/echo.py +++ b/csp_bot/commands/echo.py @@ -1,15 +1,15 @@ """Echo command for csp-bot. -A simple command that echoes back messages. +A simple command that echoes back messages using FormattedMessage. """ from logging import getLogger from typing import Optional, Type from chatom import Message +from chatom.format import FormattedMessage, Text, UserMention from csp_bot.structs import BotCommand -from csp_bot.utils import mention_users from .base import BaseCommand, BaseCommandModel, ReplyToOtherCommand @@ -31,21 +31,27 @@ def help(self) -> str: def execute(self, command: BotCommand) -> Optional[Message]: log.info(f"Echo command: {command.command}") - # Build content from args - content = " ".join(command.args) if command.args else "" + text = " ".join(command.args) if command.args else "" + has_targets = bool(command.targets) - # Add mentions for any tagged users - if command.targets: - mentions = mention_users(list(command.targets), command.backend) - if mentions: - content = f"{content} {mentions}".strip() - - # If nothing to echo (no args and no targets), return None - if not content: + if not text and not has_targets: return None + msg = FormattedMessage(metadata={"backend": command.backend}) + if text: + msg.content.append(Text(content=text)) + for target in command.targets: + if msg.content: + msg.content.append(Text(content=" ")) + msg.content.append( + UserMention( + user_id=target.id, + display_name=getattr(target, "display_name", "") or getattr(target, "name", "") or "", + ) + ) + return Message( - content=content, + content=msg.render_for(command.backend), channel=command.channel, metadata={"backend": command.backend}, ) diff --git a/csp_bot/commands/executor.py b/csp_bot/commands/executor.py index 47043c1..486b012 100644 --- a/csp_bot/commands/executor.py +++ b/csp_bot/commands/executor.py @@ -15,7 +15,9 @@ from typing import Any, List, Optional, Union from chatom import Message +from chatom.base.attachment import Attachment, AttachmentType, Image as BaseImage from chatom.format import FormattedMessage +from chatom.format.attachment import FormattedAttachment, FormattedImage from csp_bot.structs import BotCommand @@ -43,6 +45,56 @@ def _get_event_loop() -> asyncio.AbstractEventLoop: return _loop +def _extract_attachments(fm: FormattedMessage) -> list: + """Extract base Attachment objects from a FormattedMessage. + + Converts FormattedAttachment/FormattedImage from both the content list + and the attachments list into chatom base Attachment/Image objects + suitable for Message.attachments. + """ + result = [] + for node in fm.content: + if isinstance(node, FormattedImage): + result.append( + BaseImage( + url=node.url, + data=node.data, + filename=node.filename, + alt_text=node.alt_text, + content_type=node.content_type or "image/png", + size=node.size if hasattr(node, "size") else None, + width=node.width, + height=node.height, + attachment_type=AttachmentType.IMAGE, + ) + ) + elif isinstance(node, FormattedAttachment): + att_type = Attachment.from_content_type(node.content_type) if node.content_type else AttachmentType.FILE + result.append( + Attachment( + filename=node.filename, + url=node.url, + data=node.data, + size=node.size, + content_type=node.content_type, + attachment_type=att_type, + ) + ) + for fa in fm.attachments: + att_type = Attachment.from_content_type(fa.content_type) if fa.content_type else AttachmentType.FILE + result.append( + Attachment( + filename=fa.filename, + url=fa.url, + data=fa.data, + size=fa.size, + content_type=fa.content_type, + attachment_type=att_type, + ) + ) + return result + + def _coerce_response(item: Any, backend: str) -> Optional[Union[Message, BotCommand]]: """Coerce a command return value into a chatom Message. @@ -71,8 +123,10 @@ def _coerce_response(item: Any, backend: str) -> Optional[Union[Message, BotComm if isinstance(item, FormattedMessage): rendered = item.render_for(backend) + attachments = _extract_attachments(item) msg = Message( content=rendered, + attachments=attachments, metadata={"backend": backend, "formatted": item}, ) return msg diff --git a/csp_bot/commands/help.py b/csp_bot/commands/help.py index 4e5536b..adf935b 100644 --- a/csp_bot/commands/help.py +++ b/csp_bot/commands/help.py @@ -1,13 +1,13 @@ """Help command for csp-bot. -Uses chatom's formatting utilities for cross-platform output. +Uses chatom's FormattedMessage for automatic cross-platform rendering. """ -from html import escape from logging import getLogger -from typing import Mapping, Optional, Type +from typing import Any, Mapping, Optional, Type from chatom import Message +from chatom.format import FormattedMessage, Heading, Table, Text from csp_bot.structs import BotCommand @@ -16,6 +16,25 @@ log = getLogger(__name__) +def _command_backends(runner: Any) -> list: + """Get backends list from either legacy or new command types.""" + if isinstance(runner, BaseCommand): + return runner.backends() + return getattr(runner, "backends", []) or [] + + +def _command_name(runner: Any) -> str: + if isinstance(runner, BaseCommand): + return runner.name() + return getattr(runner, "name", "") or "" + + +def _command_help(runner: Any) -> str: + if isinstance(runner, BaseCommand): + return runner.help() + return getattr(runner, "help", "") or "" + + class HelpCommand(ReplyCommand): """Display help for available commands.""" @@ -31,69 +50,39 @@ def help(self) -> str: def execute( self, command: BotCommand, - commands: Mapping[str, "BaseCommand"] = None, + commands: Mapping[str, Any] = None, ) -> Optional[Message]: log.info(f"Help command: {command.command}") # Collect help for each command - helps = [] + rows = [] for cmd_key, cmd_inst in commands.items(): - # Skip hidden commands if cmd_key.startswith("_"): continue - # Skip commands not supported on this backend - if cmd_inst.backends() and command.backend not in cmd_inst.backends(): + backends = _command_backends(cmd_inst) + if backends and command.backend not in backends: continue - # Filter by requested commands if args provided if command.args and cmd_key not in command.args: continue - helps.append( - ( - escape(cmd_key, quote=True), - escape(cmd_inst.name(), quote=True), - escape(cmd_inst.help(), quote=True), - ) + rows.append( + { + "Command": f"/{cmd_key}", + "Name": _command_name(cmd_inst), + "Info": _command_help(cmd_inst), + } ) - helps = sorted(helps, key=lambda x: x[0]) - - # Build response using chatom's FormattedMessage - if command.backend == "symphony": - # Symphony supports expandable cards - table = "" - for cmd_key, name, help_text in helps: - table += f"" - table += "
CommandNameInfo
/{cmd_key}{name}{help_text}
" - content = f'
Bot Commands Help
{table}
' - elif command.backend == "slack": - # Slack uses mrkdwn code blocks for tables - lines = ["*Bot Commands Help*", "```"] - lines.append(f"{'Command':<15} {'Name':<15} {'Info'}") - lines.append("-" * 60) - for cmd_key, name, help_text in helps: - lines.append(f"/{cmd_key:<14} {name:<15} {help_text}") - lines.append("```") - content = "\n".join(lines) - elif command.backend == "discord": - # Discord uses markdown tables - lines = ["# Bot Commands Help", ""] - lines.append("| Command | Name | Info |") - lines.append("|---------|------|------|") - for cmd_key, name, help_text in helps: - lines.append(f"| /{cmd_key} | {name} | {help_text} |") - content = "\n".join(lines) - else: - # Plain text fallback - lines = ["Bot Commands Help", ""] - for cmd_key, name, help_text in helps: - lines.append(f"/{cmd_key}: {name} - {help_text}") - content = "\n".join(lines) + rows.sort(key=lambda r: r["Command"]) + + msg = FormattedMessage(metadata={"backend": command.backend}) + msg.content.append(Heading(child=Text(content="Bot Commands Help"), level=3)) + msg.content.append(Table.from_dict_list(rows, columns=["Command", "Name", "Info"])) return Message( - content=content, + content=msg.render_for(command.backend), channel=command.channel, metadata={"backend": command.backend}, ) diff --git a/csp_bot/commands/schedule.py b/csp_bot/commands/schedule.py index 86d49b3..551e07e 100644 --- a/csp_bot/commands/schedule.py +++ b/csp_bot/commands/schedule.py @@ -3,11 +3,11 @@ Allows scheduling commands for delayed or recurring execution. """ -from html import escape from logging import getLogger -from typing import TYPE_CHECKING, List, Mapping, Optional, Type +from typing import TYPE_CHECKING, List, Optional, Type from chatom import Message +from chatom.format import Bold, FormattedMessage, Table, Text from croniter import CroniterBadCronError, croniter from dateparser import parse @@ -17,6 +17,7 @@ if TYPE_CHECKING: from csp_bot import Bot + from csp_bot.persistence import ScheduleStore log = getLogger(__name__) @@ -36,7 +37,7 @@ def help(self) -> str: def preexecute( self, command: BotCommand, - schedule: Mapping[str, BotCommand], + schedule: "ScheduleStore", bot_instance: "Bot", ) -> Optional[BotCommand]: log.info(f"Schedule command preexecute: {command.command}") @@ -58,7 +59,9 @@ def preexecute( if command.args[0] == "remove": for arg in command.args[1:]: - schedule.pop(arg, None) + removed = bot_instance._remove_scheduled_command(arg) + if not removed: + log.warning("No scheduled command found for id: %s", arg) return None if command.args[0] == "add": @@ -108,24 +111,17 @@ def preexecute( return None - def execute(self, command: BotCommand, schedule: Mapping[str, BotCommand]) -> Message: + def execute(self, command: BotCommand, schedule: "ScheduleStore") -> Message: log.info("Schedule list command") - # Build schedule listing - if command.backend == "symphony": - table = "" - for scheduled_cmd in schedule.values(): - table += f"" - table += "
IDCommand
{id(scheduled_cmd)}{escape(scheduled_cmd.command)}
" - content = f'
Bot Schedule
{table}
' - else: - lines = ["*Scheduled Commands*", ""] - for scheduled_cmd in schedule.values(): - lines.append(f"- {id(scheduled_cmd)}: /{scheduled_cmd.command}") - content = "\n".join(lines) + rows = [{"ID": record.schedule_id, "Command": f"/{record.command.command}"} for record in schedule.records()] + + msg = FormattedMessage(metadata={"backend": command.backend}) + msg.content.append(Bold(child=Text(content="Scheduled Commands"))) + msg.content.append(Table.from_dict_list(rows, columns=["ID", "Command"])) return Message( - content=content, + content=msg.render_for(command.backend), channel=command.channel, metadata={"backend": command.backend}, ) diff --git a/csp_bot/commands/status.py b/csp_bot/commands/status.py index bbe6a0a..d1ea983 100644 --- a/csp_bot/commands/status.py +++ b/csp_bot/commands/status.py @@ -1,6 +1,6 @@ """Status command for csp-bot. -Displays system and bot status information. +Displays system and bot status information using FormattedMessage. """ from datetime import datetime @@ -12,6 +12,7 @@ import psutil from chatom import Message +from chatom.format import Bold, FormattedMessage, Table, Text from csp_bot.structs import BotCommand @@ -47,25 +48,27 @@ def preexecute(self, command: BotCommand, bot_instance: "Bot") -> BotCommand: def execute(self, command: BotCommand) -> Optional[Message]: log.info("Status command") - # Build status information - lines = [] - - lines.append(f"**Now**: {datetime.utcnow()}") - lines.append(f"**Backends**: {', '.join(self._adapters)}") - lines.append(f"**CPU**: {psutil.cpu_percent()}%") - lines.append(f"**Memory**: {psutil.virtual_memory().percent}%") - lines.append(f"**Memory Available**: {round(psutil.virtual_memory().available * 100 / psutil.virtual_memory().total, 2)}%") - lines.append(f"**Host**: {_HOSTNAME}") - lines.append(f"**User**: {_USER}") - - current_process = psutil.Process() - lines.append(f"**PID**: {current_process.pid}") - lines.append(f"**Active Threads**: {active_count()}") - - content = "\n".join(lines) + mem = psutil.virtual_memory() + proc = psutil.Process() + + rows = [ + {"Metric": "Now", "Value": str(datetime.utcnow())}, + {"Metric": "Backends", "Value": ", ".join(self._adapters)}, + {"Metric": "CPU", "Value": f"{psutil.cpu_percent()}%"}, + {"Metric": "Memory", "Value": f"{mem.percent}%"}, + {"Metric": "Memory Available", "Value": f"{round(mem.available * 100 / mem.total, 2)}%"}, + {"Metric": "Host", "Value": _HOSTNAME}, + {"Metric": "User", "Value": _USER}, + {"Metric": "PID", "Value": str(proc.pid)}, + {"Metric": "Active Threads", "Value": str(active_count())}, + ] + + msg = FormattedMessage(metadata={"backend": command.backend}) + msg.content.append(Bold(child=Text(content="Bot Status"))) + msg.content.append(Table.from_dict_list(rows, columns=["Metric", "Value"])) return Message( - content=content, + content=msg.render_for(command.backend), channel=command.channel, metadata={"backend": command.backend}, ) diff --git a/csp_bot/persistence.py b/csp_bot/persistence.py new file mode 100644 index 0000000..e0ce6bf --- /dev/null +++ b/csp_bot/persistence.py @@ -0,0 +1,352 @@ +"""Persistence primitives for bot runtime state. + +The first implementation is in-memory so existing behavior stays simple, but the +protocol is shaped for durable backends such as SQLite or Redis. +""" + +from __future__ import annotations + +import threading +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from pickle import HIGHEST_PROTOCOL, dumps, loads +from typing import Any, Iterable, Optional, Protocol +from urllib.parse import quote, unquote + +from csp_bot.structs import BotCommand + +__all__ = ( + "FsspecStateStore", + "InMemoryStateStore", + "ScheduleStore", + "ScheduledCommandRecord", + "StateStore", + "StoredRecord", +) + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def _to_utc(value: Optional[datetime]) -> Optional[datetime]: + if value is None: + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + +def _sort_datetime(value: Optional[datetime]) -> datetime: + return _to_utc(value) or datetime.max.replace(tzinfo=timezone.utc) + + +@dataclass(frozen=True) +class StoredRecord: + """A single namespaced state record.""" + + namespace: str + key: str + value: Any + created_at: datetime + updated_at: datetime + expires_at: Optional[datetime] = None + + def is_expired(self, now: Optional[datetime] = None) -> bool: + if self.expires_at is None: + return False + return (_to_utc(now) or _utc_now()) >= _to_utc(self.expires_at) + + +class StateStore(Protocol): + """Namespace/key/value persistence interface for bot runtime state.""" + + def get(self, namespace: str, key: str, default: Any = None) -> Any: + """Return a value, or default if missing or expired.""" + ... + + def put(self, namespace: str, key: str, value: Any, ttl_seconds: Optional[float] = None) -> StoredRecord: + """Store a value with an optional TTL and return its record metadata. + + A TTL of ``None`` means no expiry. A TTL of ``0`` expires immediately. + """ + ... + + def delete(self, namespace: str, key: str) -> bool: + """Delete a value and return whether a record was removed.""" + ... + + def records(self, namespace: str, prefix: str = "") -> Iterable[StoredRecord]: + """Return unexpired records in a namespace, optionally filtered by key prefix.""" + ... + + def cleanup_expired(self, namespace: Optional[str] = None) -> int: + """Remove expired records and return the number removed.""" + ... + + def clear(self, namespace: Optional[str] = None) -> int: + """Remove records, optionally limited to one namespace.""" + ... + + +class InMemoryStateStore: + """Thread-safe in-memory StateStore implementation. + + Values are stored by reference. Durable implementations are expected to + serialize or copy values at the storage boundary. + """ + + def __init__(self) -> None: + self._records: dict[tuple[str, str], StoredRecord] = {} + self._lock = threading.Lock() + + def get(self, namespace: str, key: str, default: Any = None) -> Any: + with self._lock: + record_key = (namespace, key) + record = self._records.get(record_key) + if record is None: + return default + if record.is_expired(): + self._records.pop(record_key, None) + return default + return record.value + + def put(self, namespace: str, key: str, value: Any, ttl_seconds: Optional[float] = None) -> StoredRecord: + now = _utc_now() + record_key = (namespace, key) + with self._lock: + existing = self._records.get(record_key) + expires_at = now + timedelta(seconds=ttl_seconds) if ttl_seconds is not None else None + record = StoredRecord( + namespace=namespace, + key=key, + value=value, + created_at=existing.created_at if existing else now, + updated_at=now, + expires_at=_to_utc(expires_at), + ) + self._records[record_key] = record + return record + + def delete(self, namespace: str, key: str) -> bool: + with self._lock: + return self._records.pop((namespace, key), None) is not None + + def records(self, namespace: str, prefix: str = "") -> list[StoredRecord]: + # Cleanup and read are separate lock acquisitions; this keeps the + # StateStore protocol simple for backends with native expiry support. + self.cleanup_expired(namespace) + with self._lock: + return [ + record + for (record_namespace, record_key), record in sorted(self._records.items()) + if record_namespace == namespace and record_key.startswith(prefix) + ] + + def cleanup_expired(self, namespace: Optional[str] = None) -> int: + now = _utc_now() + with self._lock: + expired_keys = [ + record_key + for record_key, record in self._records.items() + if (namespace is None or record.namespace == namespace) and record.is_expired(now) + ] + for record_key in expired_keys: + self._records.pop(record_key, None) + return len(expired_keys) + + def clear(self, namespace: Optional[str] = None) -> int: + """Remove records, optionally limited to one namespace.""" + with self._lock: + if namespace is None: + removed = len(self._records) + self._records.clear() + return removed + removed_keys = [record_key for record_key, record in self._records.items() if record.namespace == namespace] + for record_key in removed_keys: + self._records.pop(record_key, None) + return len(removed_keys) + + +class FsspecStateStore: + """fsspec-backed StateStore implementation. + + Records are stored as one pickle payload per namespace/key below ``url``. + Pickle keeps this backend compatible with the current ``StateStore`` value + contract, which accepts arbitrary Python objects. Only use this store with + trusted storage locations and compatible csp-bot/chatom versions. + """ + + def __init__(self, url: str, **storage_options: Any) -> None: + import fsspec + + self._mapper = fsspec.get_mapper(url, create=True, **storage_options) + self._lock = threading.Lock() + + def get(self, namespace: str, key: str, default: Any = None) -> Any: + with self._lock: + map_key = self._map_key(namespace, key) + record = self._load_record(map_key) + if record is None: + return default + if record.is_expired(): + self._delete_map_key(map_key) + return default + return record.value + + def put(self, namespace: str, key: str, value: Any, ttl_seconds: Optional[float] = None) -> StoredRecord: + now = _utc_now() + map_key = self._map_key(namespace, key) + with self._lock: + existing = self._load_record(map_key) + expires_at = now + timedelta(seconds=ttl_seconds) if ttl_seconds is not None else None + record = StoredRecord( + namespace=namespace, + key=key, + value=value, + created_at=existing.created_at if existing else now, + updated_at=now, + expires_at=_to_utc(expires_at), + ) + self._mapper[map_key] = dumps(record, protocol=HIGHEST_PROTOCOL) + return record + + def delete(self, namespace: str, key: str) -> bool: + with self._lock: + return self._delete_map_key(self._map_key(namespace, key)) + + def records(self, namespace: str, prefix: str = "") -> list[StoredRecord]: + self.cleanup_expired(namespace) + encoded_namespace = self._encode(namespace) + with self._lock: + records = [] + for map_key in sorted(self._mapper.keys()): + if not map_key.startswith(f"{encoded_namespace}/"): + continue + record = self._load_record(map_key) + if record and record.key.startswith(prefix): + records.append(record) + return records + + def cleanup_expired(self, namespace: Optional[str] = None) -> int: + now = _utc_now() + encoded_namespace = self._encode(namespace) if namespace is not None else None + with self._lock: + expired_keys = [] + for map_key in list(self._mapper.keys()): + if encoded_namespace is not None and not map_key.startswith(f"{encoded_namespace}/"): + continue + record = self._load_record(map_key) + if record and record.is_expired(now): + expired_keys.append(map_key) + for map_key in expired_keys: + self._delete_map_key(map_key) + return len(expired_keys) + + def clear(self, namespace: Optional[str] = None) -> int: + with self._lock: + if namespace is None: + keys = list(self._mapper.keys()) + else: + encoded_namespace = self._encode(namespace) + keys = [map_key for map_key in self._mapper.keys() if map_key.startswith(f"{encoded_namespace}/")] + for map_key in keys: + self._delete_map_key(map_key) + return len(keys) + + @staticmethod + def _encode(value: str) -> str: + return quote(value, safe="") + + @staticmethod + def _decode(value: str) -> str: + return unquote(value) + + @classmethod + def _map_key(cls, namespace: str, key: str) -> str: + return f"{cls._encode(namespace)}/{cls._encode(key)}" + + def _load_record(self, map_key: str) -> Optional[StoredRecord]: + try: + data = self._mapper[map_key] + except KeyError: + return None + record = loads(bytes(data)) + if not isinstance(record, StoredRecord): + raise TypeError(f"Stored value is not a StoredRecord: {map_key}") + return record + + def _delete_map_key(self, map_key: str) -> bool: + try: + del self._mapper[map_key] + except KeyError: + return False + return True + + +@dataclass(frozen=True) +class ScheduledCommandRecord: + """Persistent representation of a scheduled bot command.""" + + schedule_id: str + command: BotCommand + next_run_at: Optional[datetime] + created_at: datetime + updated_at: datetime + + @property + def is_recurring(self) -> bool: + return bool(self.command.schedule) + + +class ScheduleStore: + """Typed repository for scheduled BotCommand state.""" + + namespace = "csp_bot.schedules" + + def __init__(self, store: StateStore) -> None: + self._store = store + + def put( + self, + command: BotCommand, + schedule_id: Optional[str] = None, + next_run_at: Optional[datetime] = None, + ttl_seconds: Optional[float] = None, + ) -> ScheduledCommandRecord: + """Store a scheduled command. + + If no schedule ID is provided, a new ID is generated and assigned back + to ``command.schedule_id`` so the command carried by CSP alarms can be + matched against the stored record when it fires. + """ + now = _utc_now() + resolved_schedule_id = schedule_id or getattr(command, "schedule_id", "") or uuid.uuid4().hex + command.schedule_id = resolved_schedule_id + existing = self.get(resolved_schedule_id) + record = ScheduledCommandRecord( + schedule_id=resolved_schedule_id, + command=command, + next_run_at=_to_utc(next_run_at if next_run_at is not None else command.delay), + created_at=_to_utc(existing.created_at) if existing else now, + updated_at=now, + ) + self._store.put(self.namespace, resolved_schedule_id, record, ttl_seconds=ttl_seconds) + return record + + def get(self, schedule_id: str) -> Optional[ScheduledCommandRecord]: + record = self._store.get(self.namespace, schedule_id) + if isinstance(record, ScheduledCommandRecord): + return record + return None + + def remove(self, schedule_id: str) -> bool: + return self._store.delete(self.namespace, schedule_id) + + def records(self) -> list[ScheduledCommandRecord]: + records = [record.value for record in self._store.records(self.namespace) if isinstance(record.value, ScheduledCommandRecord)] + return sorted(records, key=lambda record: (_sort_datetime(record.next_run_at), _sort_datetime(record.created_at), record.schedule_id)) + + def cleanup_expired(self) -> int: + return self._store.cleanup_expired(self.namespace) diff --git a/csp_bot/structs.py b/csp_bot/structs.py index 4cc7017..fc1502f 100644 --- a/csp_bot/structs.py +++ b/csp_bot/structs.py @@ -151,6 +151,9 @@ class BotCommand(GatewayStruct): schedule: str """Cron expression for recurring commands.""" + schedule_id: str = "" + """Stable ID for delayed or scheduled command records.""" + times_run: int """Number of times this command has run.""" diff --git a/csp_bot/tests/test_agent_command.py b/csp_bot/tests/test_agent_command.py new file mode 100644 index 0000000..8dbf5c1 --- /dev/null +++ b/csp_bot/tests/test_agent_command.py @@ -0,0 +1,496 @@ +"""Tests for AgentCommand base class.""" + +import asyncio +import threading +from concurrent.futures import Future +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import pytest +from chatom import Message, User +from chatom.backend import BackendBase + +from csp_bot.commands.agent import AgentCommand, AgentSession, SessionStore, _run_agent +from csp_bot.structs import BotCommand, CommandVariant + + +class ConcreteAgentCommand(AgentCommand): + """Minimal concrete subclass for testing.""" + + def command(self): + return "test-agent" + + def name(self): + return "Test Agent" + + def help(self): + return "/test-agent — a test agent command" + + def build_agent(self, command): + from pydantic_ai import Agent + + return Agent("test", system_prompt="You are a test agent.") + + def build_prompt(self, command): + return " ".join(command.args) if command.args else "Hello" + + +@pytest.fixture +def cmd(): + """Fresh AgentCommand instance with clean state.""" + AgentCommand._futures = {} + AgentCommand._backends = {} + AgentCommand._sessions = SessionStore(ttl_seconds=900.0) + return ConcreteAgentCommand() + + +@pytest.fixture +def bot_command(): + """Create a BotCommand for testing.""" + return BotCommand( + command="test-agent", + args=("summarize", "this"), + source=User(id="U123", name="Test User"), + targets=(), + channel_id="C456", + channel_name="general", + backend="slack", + variant=CommandVariant.REPLY, + message=Message(id="msg1", content="/test-agent summarize this"), + delay=datetime.now(timezone.utc), + schedule="", + times_run=0, + ) + + +@pytest.fixture +def mock_backend(): + """Create a mock BackendBase.""" + backend = MagicMock(spec=BackendBase) + backend.name = "slack" + backend.capabilities = MagicMock() + backend.capabilities.supports = MagicMock(return_value=True) + return backend + + +class TestSetBackends: + def test_set_backends(self, mock_backend): + AgentCommand.set_backends({"slack": mock_backend}) + assert AgentCommand._backends == {"slack": mock_backend} + # Cleanup + AgentCommand._backends = {} + + def test_build_toolset_returns_toolset_when_backend_available(self, cmd, bot_command, mock_backend): + AgentCommand.set_backends({"slack": mock_backend}) + toolset = cmd.build_toolset(bot_command) + assert toolset is not None + AgentCommand._backends = {} + + def test_build_toolset_returns_none_when_no_backend(self, cmd, bot_command): + AgentCommand.set_backends({}) + toolset = cmd.build_toolset(bot_command) + assert toolset is None + + +class TestPreexecute: + def test_submits_future_and_sets_delay(self, cmd, bot_command): + with patch("csp_bot.commands.agent._executor") as mock_executor: + mock_future = MagicMock(spec=Future) + mock_executor.submit.return_value = mock_future + + result = cmd.preexecute(bot_command) + + assert mock_executor.submit.called + assert result.delay.replace(tzinfo=timezone.utc) > datetime.now(timezone.utc) - timedelta(seconds=5) + # Future should be stored + assert len(AgentCommand._futures) == 1 + + def test_does_not_resubmit_existing_future(self, cmd, bot_command): + with patch("csp_bot.commands.agent._executor") as mock_executor: + mock_future = MagicMock(spec=Future) + mock_executor.submit.return_value = mock_future + + cmd.preexecute(bot_command) + cmd.preexecute(bot_command) + + # Should only submit once + assert mock_executor.submit.call_count == 1 + + def test_handles_build_agent_error(self, cmd, bot_command): + with patch.object(cmd, "build_agent", side_effect=RuntimeError("fail")): + result = cmd.preexecute(bot_command) + assert result.args[0].startswith("ERROR:") + + def test_cleans_up_expired_sessions(self, cmd, bot_command): + AgentCommand._sessions = SessionStore(ttl_seconds=0.01) + expired = AgentSession(user_id="U1", channel_id="C1", command_name="ask", bot_response_id="old-response") + expired.last_active = datetime.now(timezone.utc) - timedelta(seconds=1) + AgentCommand._sessions.put("old-key", expired) + + with patch("csp_bot.commands.agent._executor") as mock_executor: + mock_executor.submit.return_value = MagicMock(spec=Future) + cmd.preexecute(bot_command) + + assert "old-key" not in AgentCommand._sessions._sessions + assert "old-response" not in AgentCommand._sessions._response_index + + +class TestExecute: + def test_returns_error_message_from_preexecute(self, cmd, bot_command): + bot_command.args = ("ERROR: Failed to initialize",) + result = cmd.execute(bot_command) + assert isinstance(result, Message) + assert "ERROR:" in result.content + + def test_reschedules_when_future_not_done(self, cmd, bot_command): + key = cmd._command_key(bot_command) + mock_future = MagicMock(spec=Future) + mock_future.done.return_value = False + AgentCommand._futures[key] = mock_future + + result = cmd.execute(bot_command) + # Should return a list with the rescheduled command and a status message + assert isinstance(result, list) + commands = [r for r in result if isinstance(r, BotCommand)] + messages = [r for r in result if isinstance(r, Message)] + assert len(commands) == 1 + assert commands[0].times_run == 1 + # First poll sends initial "Thinking..." status + assert len(messages) == 1 + assert "Thinking" in messages[0].content + + def test_returns_result_when_future_done(self, cmd, bot_command): + key = cmd._command_key(bot_command) + mock_future = MagicMock(spec=Future) + mock_future.done.return_value = True + mock_result = MagicMock() + mock_result.output = "Here is your summary." + mock_result.all_messages.return_value = [{"role": "user", "content": "test"}] + mock_future.result.return_value = mock_result + AgentCommand._futures[key] = mock_future + + result = cmd.execute(bot_command) + assert isinstance(result, Message) + assert result.content == "Here is your summary." + assert key not in AgentCommand._futures + + def test_handles_future_exception(self, cmd, bot_command): + key = cmd._command_key(bot_command) + mock_future = MagicMock(spec=Future) + mock_future.done.return_value = True + mock_future.result.side_effect = RuntimeError("LLM failed") + AgentCommand._futures[key] = mock_future + + result = cmd.execute(bot_command) + assert isinstance(result, Message) + assert "error" in result.content.lower() + assert key not in AgentCommand._futures + + def test_timeout_cancels_future(self, cmd, bot_command): + key = cmd._command_key(bot_command) + mock_future = MagicMock(spec=Future) + mock_future.done.return_value = False + AgentCommand._futures[key] = mock_future + + # Simulate many poll cycles exceeding timeout + bot_command.times_run = cmd.timeout // cmd.poll_interval + 1 + result = cmd.execute(bot_command) + + assert isinstance(result, Message) + assert "timed out" in result.content.lower() + mock_future.cancel.assert_called_once() + assert key not in AgentCommand._futures + + def test_no_future_returns_error(self, cmd, bot_command): + bot_command.args = () # Clear any error args + result = cmd.execute(bot_command) + assert isinstance(result, Message) + assert "something went wrong" in result.content.lower() + + +class TestRunAgent: + def test_uses_threadsafe_submission_for_running_backend_loop(self): + class FakeAgent: + async def run(self, prompt, message_history=None): + await asyncio.sleep(0) + return {"prompt": prompt, "message_history": message_history} + + loop = asyncio.new_event_loop() + loop_ready = threading.Event() + + def run_loop(): + asyncio.set_event_loop(loop) + loop_ready.set() + loop.run_forever() + + loop_thread = threading.Thread(target=run_loop, daemon=True) + loop_thread.start() + loop_ready.wait(timeout=1) + + try: + result = _run_agent(FakeAgent(), "hello", loop, ["previous"]) + finally: + loop.call_soon_threadsafe(loop.stop) + loop_thread.join(timeout=1) + loop.close() + + assert result == {"prompt": "hello", "message_history": ["previous"]} + + +class TestSessionStore: + def test_put_and_get(self): + store = SessionStore(ttl_seconds=60.0) + session = AgentSession(user_id="U1", channel_id="C1", command_name="ask") + store.put("key1", session) + assert store.get("key1") is session + + def test_get_returns_none_for_missing_key(self): + store = SessionStore(ttl_seconds=60.0) + assert store.get("nonexistent") is None + + def test_expired_session_returns_none(self): + store = SessionStore(ttl_seconds=0.01) + session = AgentSession(user_id="U1", channel_id="C1", command_name="ask") + session.last_active = datetime.now(timezone.utc) - timedelta(seconds=1) + store.put("key1", session) + assert store.get("key1") is None + + def test_get_by_response_id(self): + store = SessionStore(ttl_seconds=60.0) + session = AgentSession(user_id="U1", channel_id="C1", command_name="ask", bot_response_id="resp1") + store.put("key1", session) + assert store.get_by_response_id("resp1") is session + + def test_get_by_response_id_returns_none_for_unknown(self): + store = SessionStore(ttl_seconds=60.0) + assert store.get_by_response_id("unknown") is None + + def test_update_response_id(self): + store = SessionStore(ttl_seconds=60.0) + session = AgentSession(user_id="U1", channel_id="C1", command_name="ask") + store.put("key1", session) + store.update_response_id("key1", "new-resp-id") + assert store.get_by_response_id("new-resp-id") is session + + def test_update_response_id_removes_old_mapping(self): + store = SessionStore(ttl_seconds=60.0) + session = AgentSession(user_id="U1", channel_id="C1", command_name="ask", bot_response_id="old-id") + store.put("key1", session) + store.update_response_id("key1", "new-id") + assert store.get_by_response_id("old-id") is None + assert store.get_by_response_id("new-id") is session + + def test_cleanup_expired(self): + store = SessionStore(ttl_seconds=0.01) + s1 = AgentSession(user_id="U1", channel_id="C1", command_name="ask") + s1.last_active = datetime.now(timezone.utc) - timedelta(seconds=1) + s2 = AgentSession(user_id="U2", channel_id="C2", command_name="ask") + store.put("key1", s1) + store.put("key2", s2) + removed = store.cleanup_expired() + assert removed == 1 + assert store.get("key1") is None + assert store.get("key2") is s2 + + def test_cleanup_expired_removes_response_index(self): + store = SessionStore(ttl_seconds=0.01) + expired = AgentSession(user_id="U1", channel_id="C1", command_name="ask", bot_response_id="old-response") + expired.last_active = datetime.now(timezone.utc) - timedelta(seconds=1) + active = AgentSession(user_id="U2", channel_id="C2", command_name="ask", bot_response_id="new-response") + store.put("old-key", expired) + store.put("new-key", active) + + removed = store.cleanup_expired() + + assert removed == 1 + assert "old-response" not in store._response_index + assert store.get_by_response_id("new-response") is active + + +class TestSessionIntegration: + """Test session creation and resumption through AgentCommand.""" + + def test_preexecute_creates_session(self, cmd, bot_command): + with patch("csp_bot.commands.agent._executor") as mock_executor: + mock_executor.submit.return_value = MagicMock(spec=Future) + cmd.preexecute(bot_command) + + key = cmd._session_key(bot_command) + session = AgentCommand._sessions.get(key) + assert session is not None + assert session.user_id == "U123" + assert session.channel_id == "C456" + assert session.command_name == "test-agent" + + def test_execute_stores_message_history(self, cmd, bot_command): + key = cmd._command_key(bot_command) + mock_future = MagicMock(spec=Future) + mock_future.done.return_value = True + mock_result = MagicMock() + mock_result.output = "Some answer" + mock_history = [MagicMock(), MagicMock()] + mock_result.all_messages.return_value = mock_history + AgentCommand._futures[key] = mock_future + mock_future.result.return_value = mock_result + + # Pre-create session as preexecute would + cmd._create_session(bot_command) + + result = cmd.execute(bot_command) + assert isinstance(result, Message) + + session = AgentCommand._sessions.get(cmd._session_key(bot_command)) + assert session is not None + assert session.message_history == mock_history + + def test_session_resumed_on_reply(self, cmd, bot_command): + """Simulate a reply to a bot message resuming an existing session.""" + from chatom.base.message import MessageReference + + # First: create a session and register a response ID + session = cmd._create_session(bot_command) + session.message_history = [MagicMock(), MagicMock()] + AgentCommand._sessions.update_response_id(cmd._session_key(bot_command), "bot-msg-123") + + # Second: create a reply message referencing the bot response + reply_msg = Message( + id="msg2", + content="Tell me more", + author=User(id="U123", name="Test User"), + reference=MessageReference(message_id="bot-msg-123"), + ) + reply_command = BotCommand( + command="test-agent", + args=("Tell me more",), + source=User(id="U123", name="Test User"), + targets=(), + channel_id="C456", + channel_name="general", + backend="slack", + variant=CommandVariant.REPLY, + message=reply_msg, + delay=datetime.now(timezone.utc), + schedule="", + times_run=0, + ) + + # The session should be found via the reference + found = cmd._get_session(reply_command) + assert found is session + assert len(found.message_history) == 2 + + def test_session_key_format(self, cmd, bot_command): + key = cmd._session_key(bot_command) + assert key == "test-agent:U123:C456" + + def test_response_metadata_includes_session_key(self, cmd, bot_command): + """Execute should include agent_session_key in response metadata.""" + key = cmd._command_key(bot_command) + mock_future = MagicMock(spec=Future) + mock_future.done.return_value = True + mock_result = MagicMock() + mock_result.output = "Answer" + mock_result.all_messages.return_value = [] + mock_future.result.return_value = mock_result + AgentCommand._futures[key] = mock_future + + cmd._create_session(bot_command) + result = cmd.execute(bot_command) + + assert isinstance(result, Message) + assert result.metadata["agent_session_key"] == "test-agent:U123:C456" + + +class TestMultimodalPrompt: + """Incoming image attachments are passed to the model as multimodal input.""" + + @staticmethod + def _image_command(backend_name="slack"): + from chatom.base import Image + + msg = Message( + id="msg-img", + content="/test-agent what is this", + author=User(id="U123", name="Test User"), + attachments=[Image(id="att1", filename="pic.png", content_type="image/png")], + ) + return BotCommand( + command="test-agent", + args=("what", "is", "this"), + source=User(id="U123", name="Test User"), + targets=(), + channel_id="C456", + channel_name="general", + backend=backend_name, + variant=CommandVariant.REPLY, + message=msg, + delay=datetime.now(timezone.utc), + schedule="", + times_run=0, + ) + + def test_incoming_image_becomes_binary_content(self, cmd): + from unittest.mock import AsyncMock + + from pydantic_ai import BinaryContent + + backend = MagicMock(spec=BackendBase) + backend.download_attachment = AsyncMock(return_value=b"PNGBYTES") + AgentCommand.set_backends({"slack": backend}) + try: + command = self._image_command() + parts = cmd._build_model_prompt(command, "what is this") + finally: + AgentCommand._backends = {} + AgentCommand._backend_loops = {} + + assert isinstance(parts, list) + assert parts[0] == "what is this" + binaries = [p for p in parts if isinstance(p, BinaryContent)] + assert len(binaries) == 1 + assert binaries[0].data == b"PNGBYTES" + assert binaries[0].media_type == "image/png" + backend.download_attachment.assert_awaited_once() + + def test_no_images_returns_plain_prompt(self, cmd, bot_command): + backend = MagicMock(spec=BackendBase) + AgentCommand.set_backends({"slack": backend}) + try: + # bot_command has no attachments + result = cmd._build_model_prompt(bot_command, "plain prompt") + finally: + AgentCommand._backends = {} + + assert result == "plain prompt" + + def test_disabled_flag_returns_plain_prompt(self, cmd): + from unittest.mock import AsyncMock + + backend = MagicMock(spec=BackendBase) + backend.download_attachment = AsyncMock(return_value=b"x") + AgentCommand.set_backends({"slack": backend}) + cmd.include_incoming_images = False + try: + command = self._image_command() + result = cmd._build_model_prompt(command, "prompt text") + finally: + AgentCommand._backends = {} + cmd.include_incoming_images = True + + assert result == "prompt text" + backend.download_attachment.assert_not_called() + + def test_download_failure_falls_back_to_text(self, cmd): + from unittest.mock import AsyncMock + + backend = MagicMock(spec=BackendBase) + backend.download_attachment = AsyncMock(side_effect=RuntimeError("boom")) + AgentCommand.set_backends({"slack": backend}) + try: + command = self._image_command() + result = cmd._build_model_prompt(command, "prompt text") + finally: + AgentCommand._backends = {} + + # No image could be attached → fall back to the plain text prompt. + assert result == "prompt text" diff --git a/csp_bot/tests/test_command_framework.py b/csp_bot/tests/test_command_framework.py index afe2ce9..6254bf4 100644 --- a/csp_bot/tests/test_command_framework.py +++ b/csp_bot/tests/test_command_framework.py @@ -554,6 +554,100 @@ def test_formatted_message(self): assert "*bold*" in result.content assert result.metadata["backend"] == "slack" + def test_formatted_message_with_image(self): + from chatom.format.attachment import FormattedImage + + fm = FormattedMessage( + content=[ + Text(content="Check this:"), + FormattedImage(url="https://example.com/chart.png", alt_text="chart"), + ] + ) + result = _coerce_response(fm, "discord") + assert isinstance(result, Message) + assert len(result.attachments) == 1 + assert result.attachments[0].url == "https://example.com/chart.png" + assert result.attachments[0].attachment_type.value == "image" + + def test_formatted_message_with_attachment(self): + from chatom.format.attachment import FormattedAttachment + + fm = FormattedMessage( + content=[Text(content="Here's the report:")], + attachments=[ + FormattedAttachment( + filename="report.pdf", + url="https://example.com/report.pdf", + content_type="application/pdf", + size=2048, + ) + ], + ) + result = _coerce_response(fm, "slack") + assert isinstance(result, Message) + assert len(result.attachments) == 1 + att = result.attachments[0] + assert att.filename == "report.pdf" + assert att.url == "https://example.com/report.pdf" + assert att.size == 2048 + + def test_formatted_message_with_mixed_content_and_attachments(self): + from chatom.format.attachment import FormattedAttachment, FormattedImage + + fm = FormattedMessage( + content=[ + Text(content="Results:"), + FormattedImage(url="https://example.com/img.png", alt_text="img"), + ], + attachments=[ + FormattedAttachment( + filename="data.csv", + url="https://example.com/data.csv", + content_type="text/csv", + ) + ], + ) + result = _coerce_response(fm, "symphony") + assert isinstance(result, Message) + assert len(result.attachments) == 2 # 1 image from content + 1 from attachments + + def test_formatted_message_preserves_binary_image_data(self): + """Binary image bytes must survive so the send path can upload them.""" + from chatom.format.attachment import FormattedImage + + fm = FormattedMessage( + content=[ + Text(content="Generated chart:"), + FormattedImage(data=b"PNGBYTES", filename="chart.png", content_type="image/png"), + ] + ) + result = _coerce_response(fm, "slack") + assert isinstance(result, Message) + assert len(result.attachments) == 1 + att = result.attachments[0] + assert att.has_data is True + assert att.data == b"PNGBYTES" + assert att.filename == "chart.png" + assert att.content_type == "image/png" + + def test_formatted_message_preserves_binary_attachment_data(self): + """Binary file bytes must survive for upload via the send path.""" + from chatom.format.attachment import FormattedAttachment + + fm = FormattedMessage( + content=[Text(content="Report:")], + attachments=[ + FormattedAttachment(filename="report.pdf", data=b"%PDF-1.7", content_type="application/pdf"), + ], + ) + result = _coerce_response(fm, "slack") + assert isinstance(result, Message) + assert len(result.attachments) == 1 + att = result.attachments[0] + assert att.has_data is True + assert att.data == b"%PDF-1.7" + assert att.attachment_type.value == "document" + def test_unknown_type(self): result = _coerce_response(42, "slack") assert isinstance(result, Message) diff --git a/csp_bot/tests/test_persistence.py b/csp_bot/tests/test_persistence.py new file mode 100644 index 0000000..591003f --- /dev/null +++ b/csp_bot/tests/test_persistence.py @@ -0,0 +1,213 @@ +"""Tests for bot runtime persistence helpers.""" + +from datetime import datetime, timedelta, timezone + +from chatom import Message, User + +from csp_bot.persistence import FsspecStateStore, InMemoryStateStore, ScheduledCommandRecord, ScheduleStore, StoredRecord +from csp_bot.structs import BotCommand, CommandVariant + + +def _make_command(command: str = "echo", message_id: str = "msg1") -> BotCommand: + return BotCommand( + command=command, + args=("hello",), + source=User(id="U123", name="Test User"), + targets=(), + channel_id="C456", + channel_name="general", + backend="slack", + variant=CommandVariant.REPLY, + message=Message(id=message_id, content=f"/{command} hello"), + delay=datetime.now(timezone.utc) + timedelta(minutes=5), + schedule="", + times_run=0, + ) + + +class TestInMemoryStateStore: + def test_stored_record_expiry_boundaries(self): + now = datetime.now(timezone.utc) + + assert StoredRecord("namespace", "never", "value", now, now).is_expired(now) is False + assert StoredRecord("namespace", "future", "value", now, now, now + timedelta(seconds=1)).is_expired(now) is False + assert StoredRecord("namespace", "now", "value", now, now, now).is_expired(now) is True + assert StoredRecord("namespace", "past", "value", now, now, now - timedelta(seconds=1)).is_expired(now) is True + + def test_put_and_get(self): + store = InMemoryStateStore() + + store.put("namespace", "key", {"value": 1}) + + assert store.get("namespace", "key") == {"value": 1} + + def test_get_returns_default_for_missing_key(self): + store = InMemoryStateStore() + + assert store.get("namespace", "missing", default="fallback") == "fallback" + + def test_ttl_expiry_removes_record_on_get(self): + store = InMemoryStateStore() + store.put("namespace", "key", "value", ttl_seconds=-1) + + assert store.get("namespace", "key") is None + assert store.records("namespace") == [] + + def test_ttl_zero_expires_immediately(self): + store = InMemoryStateStore() + store.put("namespace", "key", "value", ttl_seconds=0) + + assert store.get("namespace", "key") is None + + def test_records_filters_by_namespace_and_prefix(self): + store = InMemoryStateStore() + store.put("schedules", "slack:1", 1) + store.put("schedules", "slack:2", 2) + store.put("schedules", "discord:1", 3) + store.put("sessions", "slack:1", 4) + + records = store.records("schedules", prefix="slack:") + + assert [record.value for record in records] == [1, 2] + + def test_cleanup_expired_can_target_namespace(self): + store = InMemoryStateStore() + store.put("schedules", "expired", 1, ttl_seconds=-1) + store.put("sessions", "expired", 2, ttl_seconds=-1) + + removed = store.cleanup_expired("schedules") + + assert removed == 1 + assert store.records("schedules") == [] + assert store.get("sessions", "expired") is None + + def test_overwrite_preserves_created_at_and_updates_updated_at(self): + store = InMemoryStateStore() + first = store.put("namespace", "key", "first") + second = store.put("namespace", "key", "second") + + assert second.created_at == first.created_at + assert second.updated_at >= first.updated_at + assert store.get("namespace", "key") == "second" + + +class TestFsspecStateStore: + def test_persists_across_instances(self, tmp_path): + url = str(tmp_path / "state") + first = FsspecStateStore(url) + first.put("namespace", "key", {"value": 1}) + + second = FsspecStateStore(url) + + assert second.get("namespace", "key") == {"value": 1} + + def test_records_filters_by_namespace_and_prefix(self, tmp_path): + store = FsspecStateStore(str(tmp_path / "state")) + store.put("schedules", "slack:1", 1) + store.put("schedules", "slack:2", 2) + store.put("schedules", "discord:1", 3) + store.put("sessions", "slack:1", 4) + + records = store.records("schedules", prefix="slack:") + + assert [record.value for record in records] == [1, 2] + + def test_ttl_expiry_removes_record_on_get(self, tmp_path): + store = FsspecStateStore(str(tmp_path / "state")) + store.put("namespace", "key", "value", ttl_seconds=0) + + assert store.get("namespace", "key") is None + assert store.records("namespace") == [] + + def test_clear_namespace(self, tmp_path): + store = FsspecStateStore(str(tmp_path / "state")) + store.put("schedules", "one", 1) + store.put("sessions", "one", 2) + + assert store.clear("schedules") == 1 + assert store.records("schedules") == [] + assert store.get("sessions", "one") == 2 + + +class TestScheduleStore: + def test_put_and_get_schedule_record(self): + store = ScheduleStore(InMemoryStateStore()) + command = _make_command() + + record = store.put(command, schedule_id="schedule-1") + + assert isinstance(record, ScheduledCommandRecord) + assert store.get("schedule-1") == record + assert record.command.command == command.command + assert record.command.message.id == command.message.id + assert record.command.schedule_id == "schedule-1" + assert record.next_run_at == command.delay.replace(tzinfo=timezone.utc) + + def test_generated_ids_allow_same_command_name(self): + store = ScheduleStore(InMemoryStateStore()) + first = _make_command(command="echo", message_id="msg1") + second = _make_command(command="echo", message_id="msg2") + + first_record = store.put(first) + second_record = store.put(second) + + assert first_record.schedule_id != second_record.schedule_id + assert [record.command.message.id for record in store.records()] == [first.message.id, second.message.id] + + def test_remove_schedule_record(self): + store = ScheduleStore(InMemoryStateStore()) + store.put(_make_command(), schedule_id="schedule-1") + + assert store.remove("schedule-1") is True + assert store.get("schedule-1") is None + + def test_recurring_record_uses_command_schedule(self): + store = ScheduleStore(InMemoryStateStore()) + command = _make_command() + command.schedule = "*/5 * * * *" + + record = store.put(command, schedule_id="recurring") + + assert record.is_recurring is True + + def test_next_run_override(self): + store = ScheduleStore(InMemoryStateStore()) + command = _make_command() + next_run_at = datetime.now(timezone.utc) + timedelta(hours=1) + + record = store.put(command, schedule_id="schedule-1", next_run_at=next_run_at) + + assert record.next_run_at == next_run_at + + def test_created_at_preserved_across_updates(self): + store = ScheduleStore(InMemoryStateStore()) + command = _make_command() + first = store.put(command, schedule_id="schedule-1") + command.args = ("updated",) + second = store.put(command, schedule_id="schedule-1") + + assert second.created_at == first.created_at + assert second.updated_at >= first.updated_at + assert second.command.args == ("updated",) + + def test_cleanup_expired(self): + store = ScheduleStore(InMemoryStateStore()) + store.put(_make_command(), schedule_id="expired", ttl_seconds=-1) + + assert store.cleanup_expired() == 1 + assert store.records() == [] + + def test_fsspec_round_trips_scheduled_command(self, tmp_path): + url = str(tmp_path / "state") + first = ScheduleStore(FsspecStateStore(url)) + command = _make_command(command="echo", message_id="msg1") + first.put(command, schedule_id="schedule-1") + + second = ScheduleStore(FsspecStateStore(url)) + record = second.get("schedule-1") + + assert record is not None + assert record.schedule_id == "schedule-1" + assert record.command.command == "echo" + assert record.command.message.id == "msg1" + assert record.command.source.id == "U123" diff --git a/csp_bot/tests/test_schedule.py b/csp_bot/tests/test_schedule.py new file mode 100644 index 0000000..43d99c5 --- /dev/null +++ b/csp_bot/tests/test_schedule.py @@ -0,0 +1,107 @@ +"""Tests for schedule command persistence integration.""" + +from datetime import datetime, timedelta, timezone + +from chatom import Message, User + +from csp_bot import Bot, BotConfig +from csp_bot.commands.schedule import ScheduleCommand +from csp_bot.persistence import InMemoryStateStore, ScheduleStore +from csp_bot.structs import BotCommand, CommandVariant + + +def _make_command(command: str = "echo", message_id: str = "msg1") -> BotCommand: + return BotCommand( + command=command, + args=("hello",), + source=User(id="U123", name="Test User"), + targets=(), + channel_id="C456", + channel_name="general", + backend="slack", + variant=CommandVariant.REPLY, + message=Message(id=message_id, content=f"/{command} hello"), + delay=datetime.now(timezone.utc) + timedelta(minutes=5), + schedule="", + times_run=0, + ) + + +def test_schedule_list_uses_stable_schedule_id(): + schedule_store = ScheduleStore(InMemoryStateStore()) + schedule_store.put(_make_command(), schedule_id="schedule-1") + command = _make_command(command="schedule", message_id="list") + command.args = ("list",) + + message = ScheduleCommand().execute(command, schedule_store) + + assert "schedule-1" in message.content + assert "/echo" in message.content + + +def test_schedule_remove_uses_stable_schedule_id(): + schedule_store = ScheduleStore(InMemoryStateStore()) + schedule_store.put(_make_command(), schedule_id="schedule-1") + bot = Bot(config=BotConfig()) + bot.set_schedule_store(schedule_store) + command = _make_command(command="schedule", message_id="remove") + command.args = ("remove", "schedule-1") + + result = ScheduleCommand().preexecute(command, schedule_store, bot) + + assert result is None + assert schedule_store.get("schedule-1") is None + + +def test_bot_store_scheduled_command_allows_duplicate_command_names(): + bot = Bot(config=BotConfig()) + first = _make_command(command="echo", message_id="msg1") + second = _make_command(command="echo", message_id="msg2") + + first_record = bot._store_scheduled_command(first, first.delay) + second_record = bot._store_scheduled_command(second, second.delay) + + assert first_record.schedule_id != second_record.schedule_id + assert {record.command.message.id for record in bot._schedule_store.records()} == {"msg1", "msg2"} + + +def test_bot_set_state_store_injects_existing_schedule_records(): + state_store = InMemoryStateStore() + schedule_store = ScheduleStore(state_store) + schedule_store.put(_make_command(), schedule_id="schedule-1") + bot = Bot(config=BotConfig()) + + bot.set_state_store(state_store) + + assert bot._schedule_store.get("schedule-1") is not None + assert bot._schedule_store.get("schedule-1").command.message.id == "msg1" + + +def test_bot_restore_scheduled_commands_skips_past_records(): + bot = Bot(config=BotConfig()) + now = datetime.now(timezone.utc) + past = _make_command(message_id="past") + future = _make_command(message_id="future") + bot._schedule_store.put(past, schedule_id="past", next_run_at=now - timedelta(minutes=1)) + bot._schedule_store.put(future, schedule_id="future", next_run_at=now + timedelta(minutes=1)) + + restored = bot._restore_scheduled_commands(now) + + assert [record.schedule_id for record in restored] == ["future"] + + +def test_bot_reschedules_recurring_command_with_same_schedule_id(): + bot = Bot(config=BotConfig()) + command = _make_command(command="echo", message_id="recurring") + command.schedule = "*/5 * * * *" + command.schedule_id = "schedule-1" + first_time = datetime.now(timezone.utc) + timedelta(minutes=5) + second_time = first_time + timedelta(minutes=5) + + first_record = bot._store_scheduled_command(command, first_time) + second_record = bot._store_scheduled_command(command, second_time) + + assert first_record.schedule_id == second_record.schedule_id == "schedule-1" + assert second_record.created_at == first_record.created_at + assert second_record.next_run_at == second_time + assert [record.schedule_id for record in bot._schedule_store.records()] == ["schedule-1"] diff --git a/pyproject.toml b/pyproject.toml index 88e5f8e..bf18821 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,12 @@ dependencies = [ ] [project.optional-dependencies] +agent = [ + "chatom[agent]", +] +persistence = [ + "fsspec", +] develop = [ "build", "bump-my-version", @@ -59,6 +65,7 @@ develop = [ "csp-adapter-slack>=0.4,<0.5", "csp-adapter-symphony>=0.4,<0.5", "hatchling", + "fsspec", "mdformat", "mdformat-tables>=1", "pytest",