From 15880536ba03db6921f675bcd6d959b8cc4bd272 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 6 Jun 2026 22:51:59 +0000 Subject: [PATCH] Add Financial-Guardrail & Telemetry Harness with Arize Phoenix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wraps Bedrock AgentCore execution to audit token I/O at every discrete reasoning step, acting as a financial circuit breaker that automatically halts agents exceeding session budgets or stuck in repetitive loops. New modules: - src/factor/harness/ — budget tracker, loop detector, circuit breaker, Phoenix OTLP telemetry, and GuardedBedrockModel proxy - docker-compose.yml — self-hosted Arize Phoenix (UI :6006, OTLP :4317) - tests/test_harness.py — 25 tests covering all harness components Wiring: - All 5 agent creation functions accept optional circuit breaker - SSE stream emits guardrail_halt events on budget/loop violations - New endpoints: /api/v1/sessions/{id}/budget, /api/v1/guardrail/status - Config: phoenix_*, guardrail_* settings with Sonnet pricing defaults https://claude.ai/code/session_016EaoZrc2jxt1SNHipXbCpM --- docker-compose.yml | 22 +++ pyproject.toml | 2 + requirements.txt | 2 + src/factor/agents/analysis.py | 13 +- src/factor/agents/coordinator.py | 15 +- src/factor/agents/ingestion.py | 13 +- src/factor/agents/knowledge.py | 13 +- src/factor/agents/reporting.py | 13 +- src/factor/app.py | 64 +++++++ src/factor/aws/observability.py | 20 +- src/factor/config.py | 11 ++ src/factor/harness/__init__.py | 33 ++++ src/factor/harness/budget.py | 88 +++++++++ src/factor/harness/circuit_breaker.py | 108 +++++++++++ src/factor/harness/exceptions.py | 29 +++ src/factor/harness/guardrail.py | 100 ++++++++++ src/factor/harness/loop_detector.py | 73 ++++++++ src/factor/harness/model_wrapper.py | 78 ++++++++ src/factor/harness/telemetry.py | 115 ++++++++++++ tests/test_harness.py | 255 ++++++++++++++++++++++++++ 20 files changed, 1055 insertions(+), 12 deletions(-) create mode 100644 docker-compose.yml create mode 100644 src/factor/harness/__init__.py create mode 100644 src/factor/harness/budget.py create mode 100644 src/factor/harness/circuit_breaker.py create mode 100644 src/factor/harness/exceptions.py create mode 100644 src/factor/harness/guardrail.py create mode 100644 src/factor/harness/loop_detector.py create mode 100644 src/factor/harness/model_wrapper.py create mode 100644 src/factor/harness/telemetry.py create mode 100644 tests/test_harness.py diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..b28a8aa --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,22 @@ +version: "3.9" + +services: + phoenix: + image: arizephoenix/phoenix:latest + container_name: factor-phoenix + ports: + - "6006:6006" # UI + HTTP collector + - "4317:4317" # OTLP gRPC collector + environment: + - PHOENIX_WORKING_DIR=/phoenix_data + volumes: + - phoenix_data:/phoenix_data + restart: unless-stopped + healthcheck: + test: ["CMD", "wget", "--spider", "-q", "http://localhost:6006/healthz"] + interval: 10s + timeout: 5s + retries: 3 + +volumes: + phoenix_data: diff --git a/pyproject.toml b/pyproject.toml index fccc52b..62a30cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,8 @@ dependencies = [ "httpx>=0.27.0", "opentelemetry-api>=1.27.0", "opentelemetry-sdk>=1.27.0", + "opentelemetry-exporter-otlp-proto-http>=1.27.0", + "arize-phoenix-otel>=0.6.0", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 7e55222..279928d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,8 @@ python-multipart>=0.0.9 httpx>=0.27.0 opentelemetry-api>=1.27.0 opentelemetry-sdk>=1.27.0 +opentelemetry-exporter-otlp-proto-http>=1.27.0 +arize-phoenix-otel>=0.6.0 pytest>=8.3.0 pytest-asyncio>=0.24.0 moto>=5.0.0 diff --git a/src/factor/agents/analysis.py b/src/factor/agents/analysis.py index 7625071..4ef2b32 100644 --- a/src/factor/agents/analysis.py +++ b/src/factor/agents/analysis.py @@ -9,6 +9,8 @@ from factor.agents.prompts import ANALYSIS_PROMPT from factor.config import settings +from factor.harness.circuit_breaker import CircuitBreaker +from factor.harness.model_wrapper import GuardedBedrockModel from factor.tools.detection import detect_provision_type from factor.tools.scoring import score_risk from factor.tools.gaps import find_gaps @@ -17,22 +19,29 @@ logger = logging.getLogger(__name__) -def create_analysis_agent() -> Agent: +def create_analysis_agent(breaker: CircuitBreaker | None = None) -> Agent: """Create and return the Analysis Agent. The Analysis Agent handles provision classification, risk scoring, gap analysis, and cross-document comparison. + + Args: + breaker: Optional circuit breaker for financial guardrails. """ model = BedrockModel( model_id=settings.bedrock_model_id, region_name=settings.aws_region, ) + if breaker is not None: + model = GuardedBedrockModel(model, breaker) + agent = Agent( model=model, system_prompt=ANALYSIS_PROMPT, tools=[detect_provision_type, score_risk, find_gaps, compare_across_documents], ) - logger.info("Created Analysis Agent with model=%s", settings.bedrock_model_id) + logger.info("Created Analysis Agent with model=%s guarded=%s", + settings.bedrock_model_id, breaker is not None) return agent diff --git a/src/factor/agents/coordinator.py b/src/factor/agents/coordinator.py index d39b75b..7ca2251 100644 --- a/src/factor/agents/coordinator.py +++ b/src/factor/agents/coordinator.py @@ -11,6 +11,8 @@ from factor.agents.prompts import COORDINATOR_PROMPT from factor.config import settings +from factor.harness.circuit_breaker import CircuitBreaker +from factor.harness.model_wrapper import GuardedBedrockModel from factor.tools.parsing import parse_pdf, parse_docx from factor.tools.chunking import chunk_provisions from factor.tools.detection import detect_provision_type @@ -185,17 +187,23 @@ def generate_report(analysis_results: dict, output_dir: str = "./reports") -> di return report -def create_coordinator_agent() -> Agent: +def create_coordinator_agent(breaker: CircuitBreaker | None = None) -> Agent: """Create and return the Coordinator Agent. The Coordinator orchestrates the full due diligence pipeline: - ingest → analyze → search knowledge → report. + ingest -> analyze -> search knowledge -> report. + + Args: + breaker: Optional circuit breaker for financial guardrails. """ model = BedrockModel( model_id=settings.bedrock_model_id, region_name=settings.aws_region, ) + if breaker is not None: + model = GuardedBedrockModel(model, breaker) + agent = Agent( model=model, system_prompt=COORDINATOR_PROMPT, @@ -207,5 +215,6 @@ def create_coordinator_agent() -> Agent: ], ) - logger.info("Created Coordinator Agent with model=%s", settings.bedrock_model_id) + logger.info("Created Coordinator Agent with model=%s guarded=%s", + settings.bedrock_model_id, breaker is not None) return agent diff --git a/src/factor/agents/ingestion.py b/src/factor/agents/ingestion.py index 7ec0ae6..9cf7bc6 100644 --- a/src/factor/agents/ingestion.py +++ b/src/factor/agents/ingestion.py @@ -9,28 +9,37 @@ from factor.agents.prompts import INGESTION_PROMPT from factor.config import settings +from factor.harness.circuit_breaker import CircuitBreaker +from factor.harness.model_wrapper import GuardedBedrockModel from factor.tools.parsing import parse_pdf, parse_docx from factor.tools.chunking import chunk_provisions logger = logging.getLogger(__name__) -def create_ingestion_agent() -> Agent: +def create_ingestion_agent(breaker: CircuitBreaker | None = None) -> Agent: """Create and return the Ingestion Agent. The Ingestion Agent handles document parsing (PDF, DOCX) and provision chunking using anchor patterns. + + Args: + breaker: Optional circuit breaker for financial guardrails. """ model = BedrockModel( model_id=settings.bedrock_model_id, region_name=settings.aws_region, ) + if breaker is not None: + model = GuardedBedrockModel(model, breaker) + agent = Agent( model=model, system_prompt=INGESTION_PROMPT, tools=[parse_pdf, parse_docx, chunk_provisions], ) - logger.info("Created Ingestion Agent with model=%s", settings.bedrock_model_id) + logger.info("Created Ingestion Agent with model=%s guarded=%s", + settings.bedrock_model_id, breaker is not None) return agent diff --git a/src/factor/agents/knowledge.py b/src/factor/agents/knowledge.py index b753b7e..b1d7fbe 100644 --- a/src/factor/agents/knowledge.py +++ b/src/factor/agents/knowledge.py @@ -9,6 +9,8 @@ from factor.agents.prompts import KNOWLEDGE_PROMPT from factor.config import settings +from factor.harness.circuit_breaker import CircuitBreaker +from factor.harness.model_wrapper import GuardedBedrockModel from factor.tools.rag import search_synthetic_knowledge from factor.tools.classification import classify_domain from factor.tools.citations import extract_citations @@ -16,22 +18,29 @@ logger = logging.getLogger(__name__) -def create_knowledge_agent() -> Agent: +def create_knowledge_agent(breaker: CircuitBreaker | None = None) -> Agent: """Create and return the Knowledge Agent. The Knowledge Agent searches the synthetic legal knowledge base, classifies provisions by domain, and extracts/labels citations. + + Args: + breaker: Optional circuit breaker for financial guardrails. """ model = BedrockModel( model_id=settings.bedrock_model_id, region_name=settings.aws_region, ) + if breaker is not None: + model = GuardedBedrockModel(model, breaker) + agent = Agent( model=model, system_prompt=KNOWLEDGE_PROMPT, tools=[search_synthetic_knowledge, classify_domain, extract_citations], ) - logger.info("Created Knowledge Agent with model=%s", settings.bedrock_model_id) + logger.info("Created Knowledge Agent with model=%s guarded=%s", + settings.bedrock_model_id, breaker is not None) return agent diff --git a/src/factor/agents/reporting.py b/src/factor/agents/reporting.py index 033f94a..61e9584 100644 --- a/src/factor/agents/reporting.py +++ b/src/factor/agents/reporting.py @@ -9,27 +9,36 @@ from factor.agents.prompts import REPORTING_PROMPT from factor.config import settings +from factor.harness.circuit_breaker import CircuitBreaker +from factor.harness.model_wrapper import GuardedBedrockModel from factor.tools.export import build_risk_report, export_excel, export_html logger = logging.getLogger(__name__) -def create_reporting_agent() -> Agent: +def create_reporting_agent(breaker: CircuitBreaker | None = None) -> Agent: """Create and return the Reporting Agent. The Reporting Agent assembles structured risk reports and exports them in multiple formats (JSON, Excel, HTML). + + Args: + breaker: Optional circuit breaker for financial guardrails. """ model = BedrockModel( model_id=settings.bedrock_model_id, region_name=settings.aws_region, ) + if breaker is not None: + model = GuardedBedrockModel(model, breaker) + agent = Agent( model=model, system_prompt=REPORTING_PROMPT, tools=[build_risk_report, export_excel, export_html], ) - logger.info("Created Reporting Agent with model=%s", settings.bedrock_model_id) + logger.info("Created Reporting Agent with model=%s guarded=%s", + settings.bedrock_model_id, breaker is not None) return agent diff --git a/src/factor/app.py b/src/factor/app.py index 984f324..0b2f129 100644 --- a/src/factor/app.py +++ b/src/factor/app.py @@ -16,6 +16,8 @@ from factor import DISCLAIMER, __version__ from factor.config import settings +from factor.harness.guardrail import get_guardrail +from factor.harness.exceptions import CircuitBreakerTripped from factor.tools.chunking import chunk_provisions from factor.tools.detection import detect_provision_type from factor.tools.scoring import score_risk @@ -66,6 +68,11 @@ async def configure_logging(): logging.getLogger("httpx").setLevel(logging.WARNING) logger.info("Logging configured: level=%s", settings.factor_log_level) + if settings.phoenix_enabled: + from factor.aws.observability import init_tracing + init_tracing("factor") + logger.info("Phoenix telemetry initialized") + @app.get("/api/v1/health") async def health_check(): @@ -120,10 +127,20 @@ async def analyze_documents(files: list[UploadFile] = File(...)): session_store.create_session(session_id, [f.filename or "" for f in files]) + guardrail = get_guardrail() + breaker = guardrail.register_session(session_id) if settings.guardrail_enabled else None + async def event_stream() -> AsyncGenerator[dict, None]: try: yield {"event": "session", "data": json.dumps({"session_id": session_id, "disclaimer": DISCLAIMER})} + if breaker: + yield {"event": "guardrail", "data": json.dumps({ + "stage": "initialized", + "budget_usd": settings.guardrail_session_budget_usd, + "max_steps": settings.guardrail_max_steps, + })} + yield {"event": "status", "data": json.dumps({"stage": "ingestion", "message": "Parsing documents..."})} all_provisions = {} @@ -164,6 +181,9 @@ async def event_stream() -> AsyncGenerator[dict, None]: risk["document_id"] = doc_id all_risk_scores.append(risk) + if breaker: + breaker.record_step(action="score_risk", meta={"doc_id": doc_id}) + gaps = find_gaps(detected_provisions=detected_types, doc_type="unknown") for gap in gaps: gap["document_id"] = doc_id @@ -186,9 +206,28 @@ async def event_stream() -> AsyncGenerator[dict, None]: session_store.store_result(session_id, report) + if breaker: + yield {"event": "guardrail", "data": json.dumps({ + "stage": "completed", + **breaker.status(), + })} + yield {"event": "report", "data": json.dumps(report)} yield {"event": "done", "data": json.dumps({"session_id": session_id, "disclaimer": DISCLAIMER})} + + except CircuitBreakerTripped as exc: + logger.warning("Circuit breaker halted session %s: %s", session_id, exc) + session_store.update_status(session_id, "halted") + yield {"event": "guardrail_halt", "data": json.dumps({ + "halted": True, + **exc.status, + "message": str(exc), + "disclaimer": DISCLAIMER, + })} + finally: + if settings.guardrail_enabled: + guardrail.remove_session(session_id) shutil.rmtree(upload_dir, ignore_errors=True) return EventSourceResponse(event_stream()) @@ -259,6 +298,31 @@ async def export_report( return {"path": path, "format": format, "disclaimer": DISCLAIMER} +@app.get("/api/v1/sessions/{session_id}/budget") +async def get_session_budget(session_id: str): + """Get real-time budget and guardrail status for a session.""" + guardrail = get_guardrail() + status = guardrail.session_status(session_id) + if status is None: + raise HTTPException(status_code=404, detail="No active guardrail for this session") + status["disclaimer"] = DISCLAIMER + return status + + +@app.get("/api/v1/guardrail/status") +async def guardrail_overview(): + """Get guardrail status across all active sessions.""" + guardrail = get_guardrail() + return { + "enabled": settings.guardrail_enabled, + "phoenix_enabled": settings.phoenix_enabled, + "phoenix_endpoint": settings.phoenix_otlp_endpoint, + "default_budget_usd": settings.guardrail_session_budget_usd, + "active_sessions": guardrail.all_sessions(), + "disclaimer": DISCLAIMER, + } + + @app.get("/api/v1/knowledge/search") async def search_knowledge( q: str = Query(..., min_length=1), diff --git a/src/factor/aws/observability.py b/src/factor/aws/observability.py index 74a81ee..4f627b6 100644 --- a/src/factor/aws/observability.py +++ b/src/factor/aws/observability.py @@ -1,4 +1,4 @@ -"""AgentCore Observability — OpenTelemetry tracing and CloudWatch integration.""" +"""AgentCore Observability — OpenTelemetry tracing with Phoenix export.""" from __future__ import annotations @@ -8,6 +8,8 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor, ConsoleSpanExporter +from factor.config import settings + logger = logging.getLogger(__name__) _initialized = False @@ -16,6 +18,10 @@ def init_tracing(service_name: str = "factor") -> trace.Tracer: """Initialize OpenTelemetry tracing. + When Phoenix is enabled, delegates to the harness telemetry module + which sets up the OTLP exporter and guardrail span processor. + Falls back to console exporter otherwise. + Args: service_name: Name of the service for trace attribution. @@ -25,6 +31,18 @@ def init_tracing(service_name: str = "factor") -> trace.Tracer: global _initialized if not _initialized: + if settings.phoenix_enabled: + try: + from factor.harness.telemetry import init_phoenix_tracing + tracer, processor = init_phoenix_tracing(service_name) + if processor is not None: + from factor.harness.guardrail import get_guardrail + processor.set_guardrail(get_guardrail()) + _initialized = True + return tracer + except Exception: + logger.warning("Failed to init Phoenix tracing, falling back to console", exc_info=True) + provider = TracerProvider() processor = SimpleSpanProcessor(ConsoleSpanExporter()) provider.add_span_processor(processor) diff --git a/src/factor/config.py b/src/factor/config.py index a916614..568b068 100644 --- a/src/factor/config.py +++ b/src/factor/config.py @@ -32,6 +32,17 @@ class Settings(BaseSettings): factor_s3_bucket: str = "factor-documents" factor_allowed_origins: str = "*" + # Phoenix / Guardrail Harness + phoenix_enabled: bool = True + phoenix_otlp_endpoint: str = "http://localhost:6006/v1/traces" + guardrail_enabled: bool = True + guardrail_session_budget_usd: float = 5.0 + guardrail_max_steps: int = 200 + guardrail_loop_window: int = 10 + guardrail_loop_threshold: int = 5 + guardrail_input_cost_per_1m: float = 3.0 + guardrail_output_cost_per_1m: float = 15.0 + # Cognito factor_cognito_user_pool_id: str = "" factor_cognito_client_id: str = "" diff --git a/src/factor/harness/__init__.py b/src/factor/harness/__init__.py new file mode 100644 index 0000000..135a7e3 --- /dev/null +++ b/src/factor/harness/__init__.py @@ -0,0 +1,33 @@ +"""Financial-Guardrail & Telemetry Harness for Factor agents. + +Wraps Bedrock AgentCore execution to audit token I/O at every reasoning +step, enforce per-session budgets, and halt agents stuck in repetitive +high-cost reasoning loops. Traces are exported to Arize Phoenix via OTLP. +""" + +from factor.harness.exceptions import ( + CircuitBreakerTripped, + BudgetExceededError, + ReasoningLoopError, +) +from factor.harness.budget import SessionBudget +from factor.harness.loop_detector import LoopDetector +from factor.harness.circuit_breaker import CircuitBreaker +from factor.harness.guardrail import FinancialGuardrail, get_guardrail + +__all__ = [ + "CircuitBreakerTripped", + "BudgetExceededError", + "ReasoningLoopError", + "SessionBudget", + "LoopDetector", + "CircuitBreaker", + "FinancialGuardrail", + "get_guardrail", +] + + +def guarded_model(breaker): + """Lazy import to avoid requiring strands at import time.""" + from factor.harness.model_wrapper import guarded_model as _guarded_model + return _guarded_model(breaker) diff --git a/src/factor/harness/budget.py b/src/factor/harness/budget.py new file mode 100644 index 0000000..9a0da8d --- /dev/null +++ b/src/factor/harness/budget.py @@ -0,0 +1,88 @@ +"""Per-session token cost accounting.""" + +from __future__ import annotations + +import threading +import time + + +class SessionBudget: + """Accumulates token costs for a single agent session. + + Pricing defaults match Anthropic Claude Sonnet on Bedrock + ($3 / 1M input tokens, $15 / 1M output tokens). + """ + + def __init__( + self, + max_budget_usd: float, + input_cost_per_1m: float = 3.0, + output_cost_per_1m: float = 15.0, + max_steps: int = 200, + ): + self.max_budget_usd = max_budget_usd + self.input_cost_per_1m = input_cost_per_1m + self.output_cost_per_1m = output_cost_per_1m + self.max_steps = max_steps + + self._lock = threading.Lock() + self._input_tokens = 0 + self._output_tokens = 0 + self._steps = 0 + self._history: list[dict] = [] + self._created_at = time.time() + + def record(self, input_tokens: int, output_tokens: int, meta: dict | None = None) -> None: + with self._lock: + self._input_tokens += input_tokens + self._output_tokens += output_tokens + self._steps += 1 + self._history.append({ + "step": self._steps, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cumulative_cost_usd": self._cost_unlocked(), + "timestamp": time.time(), + "meta": meta or {}, + }) + + @property + def total_cost_usd(self) -> float: + with self._lock: + return self._cost_unlocked() + + def _cost_unlocked(self) -> float: + input_cost = (self._input_tokens / 1_000_000) * self.input_cost_per_1m + output_cost = (self._output_tokens / 1_000_000) * self.output_cost_per_1m + return input_cost + output_cost + + @property + def is_over_budget(self) -> bool: + with self._lock: + return self._cost_unlocked() >= self.max_budget_usd + + @property + def is_over_step_limit(self) -> bool: + with self._lock: + return self._steps >= self.max_steps + + def status(self) -> dict: + with self._lock: + cost = self._cost_unlocked() + return { + "input_tokens": self._input_tokens, + "output_tokens": self._output_tokens, + "total_tokens": self._input_tokens + self._output_tokens, + "total_cost_usd": round(cost, 6), + "budget_limit_usd": self.max_budget_usd, + "budget_remaining_usd": round(max(0, self.max_budget_usd - cost), 6), + "budget_utilization_pct": round((cost / self.max_budget_usd) * 100, 2) if self.max_budget_usd > 0 else 0, + "steps": self._steps, + "max_steps": self.max_steps, + "elapsed_seconds": round(time.time() - self._created_at, 2), + } + + @property + def history(self) -> list[dict]: + with self._lock: + return list(self._history) diff --git a/src/factor/harness/circuit_breaker.py b/src/factor/harness/circuit_breaker.py new file mode 100644 index 0000000..ed4f7eb --- /dev/null +++ b/src/factor/harness/circuit_breaker.py @@ -0,0 +1,108 @@ +"""Circuit breaker — combines budget enforcement and loop detection.""" + +from __future__ import annotations + +import logging + +from factor.harness.budget import SessionBudget +from factor.harness.loop_detector import LoopDetector +from factor.harness.exceptions import BudgetExceededError, ReasoningLoopError + +logger = logging.getLogger(__name__) + + +class CircuitBreaker: + """Financial circuit breaker for a single agent session. + + Enforces two independent trip conditions: + 1. Token spend exceeds the session budget. + 2. Agent enters a repetitive reasoning loop. + + When either condition fires, `check()` raises the corresponding + exception, hard-halting agent execution. + """ + + def __init__( + self, + session_id: str, + max_budget_usd: float = 5.0, + input_cost_per_1m: float = 3.0, + output_cost_per_1m: float = 15.0, + max_steps: int = 200, + loop_window: int = 10, + loop_threshold: int = 5, + ): + self.session_id = session_id + self.budget = SessionBudget( + max_budget_usd=max_budget_usd, + input_cost_per_1m=input_cost_per_1m, + output_cost_per_1m=output_cost_per_1m, + max_steps=max_steps, + ) + self.loop_detector = LoopDetector( + window_size=loop_window, + threshold=loop_threshold, + ) + self._tripped = False + self._trip_reason: str | None = None + + def record_step( + self, + input_tokens: int = 0, + output_tokens: int = 0, + action: str = "model_call", + meta: dict | None = None, + ) -> None: + """Record a reasoning step and check trip conditions.""" + self.budget.record(input_tokens, output_tokens, meta) + self.loop_detector.record(action) + + logger.debug( + "Session %s step: action=%s in=%d out=%d cost=$%.4f", + self.session_id, action, input_tokens, output_tokens, + self.budget.total_cost_usd, + ) + + self.check() + + def check(self) -> None: + """Raise if any trip condition is met.""" + if self.budget.is_over_budget: + self._tripped = True + self._trip_reason = "budget_exceeded" + status = self.status() + logger.warning("CIRCUIT BREAKER: budget exceeded for session %s — $%.4f / $%.2f", + self.session_id, status["total_cost_usd"], status["budget_limit_usd"]) + raise BudgetExceededError(status) + + if self.budget.is_over_step_limit: + self._tripped = True + self._trip_reason = "step_limit_exceeded" + status = self.status() + logger.warning("CIRCUIT BREAKER: step limit exceeded for session %s — %d / %d", + self.session_id, status["steps"], status["max_steps"]) + raise BudgetExceededError(status) + + if self.loop_detector.is_looping: + self._tripped = True + self._trip_reason = "reasoning_loop" + status = self.status() + logger.warning("CIRCUIT BREAKER: reasoning loop detected for session %s — %s", + self.session_id, status.get("loop_signature")) + raise ReasoningLoopError(status) + + @property + def is_tripped(self) -> bool: + return self._tripped + + def status(self) -> dict: + budget_status = self.budget.status() + loop_status = self.loop_detector.status() + return { + "session_id": self.session_id, + "tripped": self._tripped, + "reason": self._trip_reason, + **budget_status, + "loop_signature": loop_status["loop_signature"], + "is_looping": loop_status["is_looping"], + } diff --git a/src/factor/harness/exceptions.py b/src/factor/harness/exceptions.py new file mode 100644 index 0000000..31d480b --- /dev/null +++ b/src/factor/harness/exceptions.py @@ -0,0 +1,29 @@ +"""Circuit breaker exceptions — raised to hard-halt agent execution.""" + +from __future__ import annotations + + +class CircuitBreakerTripped(Exception): + """Base exception for all circuit breaker halts.""" + + def __init__(self, status: dict): + self.status = status + super().__init__(self._format(status)) + + @staticmethod + def _format(status: dict) -> str: + reason = status.get("reason", "unknown") + cost = status.get("total_cost_usd", 0) + steps = status.get("total_steps", 0) + return ( + f"Circuit breaker tripped: {reason} " + f"(cost=${cost:.4f}, steps={steps})" + ) + + +class BudgetExceededError(CircuitBreakerTripped): + """Session token spend exceeded the configured budget.""" + + +class ReasoningLoopError(CircuitBreakerTripped): + """Agent is stuck in a repetitive high-cost reasoning loop.""" diff --git a/src/factor/harness/guardrail.py b/src/factor/harness/guardrail.py new file mode 100644 index 0000000..6c7982b --- /dev/null +++ b/src/factor/harness/guardrail.py @@ -0,0 +1,100 @@ +"""FinancialGuardrail — manages per-session circuit breakers.""" + +from __future__ import annotations + +import logging +import threading + +from factor.config import settings +from factor.harness.circuit_breaker import CircuitBreaker +from factor.harness.exceptions import CircuitBreakerTripped + +logger = logging.getLogger(__name__) + +_instance: FinancialGuardrail | None = None +_instance_lock = threading.Lock() + + +class FinancialGuardrail: + """Central registry of per-session circuit breakers. + + One instance is shared across the application. Each analysis session + gets its own `CircuitBreaker` with independent budget and loop state. + """ + + def __init__(self): + self._breakers: dict[str, CircuitBreaker] = {} + self._lock = threading.Lock() + + def register_session(self, session_id: str) -> CircuitBreaker: + """Create and register a circuit breaker for a new session.""" + breaker = CircuitBreaker( + session_id=session_id, + max_budget_usd=settings.guardrail_session_budget_usd, + input_cost_per_1m=settings.guardrail_input_cost_per_1m, + output_cost_per_1m=settings.guardrail_output_cost_per_1m, + max_steps=settings.guardrail_max_steps, + loop_window=settings.guardrail_loop_window, + loop_threshold=settings.guardrail_loop_threshold, + ) + with self._lock: + self._breakers[session_id] = breaker + logger.info( + "Registered circuit breaker for session %s (budget=$%.2f, max_steps=%d)", + session_id, settings.guardrail_session_budget_usd, settings.guardrail_max_steps, + ) + return breaker + + def get_breaker(self, session_id: str) -> CircuitBreaker | None: + with self._lock: + return self._breakers.get(session_id) + + def remove_session(self, session_id: str) -> None: + with self._lock: + self._breakers.pop(session_id, None) + + def record_from_span( + self, + session_id: str | None, + input_tokens: int, + output_tokens: int, + action: str, + ) -> None: + """Called by `GuardrailSpanProcessor` when a span with token data ends. + + Silently ignores unknown session IDs (spans from background work). + """ + if session_id is None: + return + breaker = self.get_breaker(session_id) + if breaker is None: + return + try: + breaker.record_step( + input_tokens=input_tokens, + output_tokens=output_tokens, + action=action, + ) + except CircuitBreakerTripped: + # Let it propagate up the call stack to halt the agent + raise + + def session_status(self, session_id: str) -> dict | None: + breaker = self.get_breaker(session_id) + if breaker is None: + return None + return breaker.status() + + def all_sessions(self) -> list[dict]: + with self._lock: + return [b.status() for b in self._breakers.values()] + + +def get_guardrail() -> FinancialGuardrail: + """Return the singleton FinancialGuardrail instance.""" + global _instance + if _instance is None: + with _instance_lock: + if _instance is None: + _instance = FinancialGuardrail() + return _instance diff --git a/src/factor/harness/loop_detector.py b/src/factor/harness/loop_detector.py new file mode 100644 index 0000000..439f8f0 --- /dev/null +++ b/src/factor/harness/loop_detector.py @@ -0,0 +1,73 @@ +"""Repetitive reasoning loop detection.""" + +from __future__ import annotations + +import threading +from collections import deque + + +class LoopDetector: + """Detects when an agent is stuck repeating the same reasoning pattern. + + Maintains a sliding window of recent actions and flags a loop when the + same action signature appears `threshold` times within the window. + """ + + def __init__(self, window_size: int = 10, threshold: int = 5): + self.window_size = window_size + self.threshold = threshold + self._lock = threading.Lock() + self._window: deque[str] = deque(maxlen=window_size) + self._loop_detected = False + self._loop_signature: str | None = None + + def record(self, action_signature: str) -> None: + """Record an action and check for loops. + + Args: + action_signature: A string identifying the action (e.g. + tool name, or tool_name + hashed args). + """ + with self._lock: + self._window.append(action_signature) + self._check_unlocked() + + def _check_unlocked(self) -> None: + if len(self._window) < self.threshold: + return + + counts: dict[str, int] = {} + for sig in self._window: + counts[sig] = counts.get(sig, 0) + 1 + + for sig, count in counts.items(): + if count >= self.threshold: + self._loop_detected = True + self._loop_signature = sig + return + + tail = list(self._window)[-self.threshold:] + if len(set(tail)) == 1: + self._loop_detected = True + self._loop_signature = tail[0] + + @property + def is_looping(self) -> bool: + with self._lock: + return self._loop_detected + + def status(self) -> dict: + with self._lock: + return { + "is_looping": self._loop_detected, + "loop_signature": self._loop_signature, + "window": list(self._window), + "window_size": self.window_size, + "threshold": self.threshold, + } + + def reset(self) -> None: + with self._lock: + self._window.clear() + self._loop_detected = False + self._loop_signature = None diff --git a/src/factor/harness/model_wrapper.py b/src/factor/harness/model_wrapper.py new file mode 100644 index 0000000..56b2a1f --- /dev/null +++ b/src/factor/harness/model_wrapper.py @@ -0,0 +1,78 @@ +"""GuardedBedrockModel — wraps Strands BedrockModel with circuit breaker checks.""" + +from __future__ import annotations + +import logging + +from strands.models.bedrock import BedrockModel + +from factor.config import settings +from factor.harness.circuit_breaker import CircuitBreaker + +logger = logging.getLogger(__name__) + + +class GuardedBedrockModel: + """Proxy around `BedrockModel` that enforces financial guardrails. + + Before every model invocation the circuit breaker is checked. If it + has already tripped (budget exceeded or reasoning loop) the call is + blocked immediately without consuming additional tokens. + + All other attribute access is delegated transparently to the + underlying `BedrockModel` so Strands sees a duck-type-compatible + model object. + """ + + def __init__(self, model: BedrockModel, breaker: CircuitBreaker): + # Store on the instance dict directly to avoid __getattr__ recursion + object.__setattr__(self, "_model", model) + object.__setattr__(self, "_breaker", breaker) + + def __getattr__(self, name: str): + return getattr(self._model, name) + + def __setattr__(self, name: str, value): + if name.startswith("_"): + object.__setattr__(self, name, value) + else: + setattr(self._model, name, value) + + def __call__(self, *args, **kwargs): + """Intercept the model call to enforce guardrails.""" + self._breaker.check() + return self._model(*args, **kwargs) + + def converse(self, *args, **kwargs): + self._breaker.check() + return self._model.converse(*args, **kwargs) + + def invoke(self, *args, **kwargs): + self._breaker.check() + return self._model.invoke(*args, **kwargs) + + def update_config(self, *args, **kwargs): + self._breaker.check() + return self._model.update_config(*args, **kwargs) + + def format_request(self, *args, **kwargs): + return self._model.format_request(*args, **kwargs) + + def format_chunk(self, *args, **kwargs): + return self._model.format_chunk(*args, **kwargs) + + +def guarded_model(breaker: CircuitBreaker) -> GuardedBedrockModel: + """Create a GuardedBedrockModel with the standard config. + + Args: + breaker: The circuit breaker instance for this session. + + Returns: + A model proxy that halts on budget / loop violations. + """ + model = BedrockModel( + model_id=settings.bedrock_model_id, + region_name=settings.aws_region, + ) + return GuardedBedrockModel(model, breaker) diff --git a/src/factor/harness/telemetry.py b/src/factor/harness/telemetry.py new file mode 100644 index 0000000..5b88fab --- /dev/null +++ b/src/factor/harness/telemetry.py @@ -0,0 +1,115 @@ +"""Phoenix OTLP telemetry — exports OpenTelemetry traces to Arize Phoenix.""" + +from __future__ import annotations + +import logging + +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider, ReadableSpan +from opentelemetry.sdk.trace.export import ( + SimpleSpanProcessor, + SpanExporter, + SpanExportResult, + ConsoleSpanExporter, +) + +from factor.config import settings + +logger = logging.getLogger(__name__) + +_initialized = False + + +class GuardrailSpanProcessor(SimpleSpanProcessor): + """SpanProcessor that extracts token usage from OTel spans and feeds + it to the active circuit breaker for the session. + + Works with both OpenInference and OpenTelemetry GenAI semantic + conventions for token count attributes. + """ + + # Attribute names used by different instrumentation libraries + TOKEN_ATTR_MAPS = [ + # OpenTelemetry GenAI semantic conventions + ("gen_ai.usage.input_tokens", "gen_ai.usage.output_tokens"), + # OpenInference / Phoenix conventions + ("llm.token_count.prompt", "llm.token_count.completion"), + # Strands-specific (if present) + ("strands.input_tokens", "strands.output_tokens"), + ] + + def __init__(self, exporter: SpanExporter): + super().__init__(exporter) + self._guardrail = None + + def set_guardrail(self, guardrail) -> None: + self._guardrail = guardrail + + def on_end(self, span: ReadableSpan) -> None: + super().on_end(span) + + if self._guardrail is None: + return + + attrs = span.attributes or {} + input_tokens = 0 + output_tokens = 0 + + for input_key, output_key in self.TOKEN_ATTR_MAPS: + inp = attrs.get(input_key, 0) + out = attrs.get(output_key, 0) + if inp or out: + input_tokens = int(inp) + output_tokens = int(out) + break + + if input_tokens or output_tokens: + session_id = attrs.get("session.id") or attrs.get("factor.session_id") + action = span.name or "unknown" + self._guardrail.record_from_span( + session_id=session_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + action=action, + ) + + +def init_phoenix_tracing(service_name: str = "factor") -> tuple[trace.Tracer, GuardrailSpanProcessor | None]: + """Initialize OpenTelemetry with Phoenix OTLP exporter. + + Falls back to ConsoleSpanExporter when Phoenix deps are unavailable + or the endpoint is not configured. + + Returns: + Tuple of (tracer, guardrail_processor). The processor is None + when Phoenix is not available. + """ + global _initialized + + if _initialized: + return trace.get_tracer(service_name), None + + provider = TracerProvider() + guardrail_processor = None + + if settings.phoenix_enabled and settings.phoenix_otlp_endpoint: + try: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + phoenix_exporter = OTLPSpanExporter( + endpoint=settings.phoenix_otlp_endpoint, + ) + guardrail_processor = GuardrailSpanProcessor(phoenix_exporter) + provider.add_span_processor(guardrail_processor) + logger.info("Phoenix OTLP exporter configured: %s", settings.phoenix_otlp_endpoint) + except ImportError: + logger.warning("opentelemetry-exporter-otlp not installed, falling back to console") + provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter())) + else: + provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter())) + logger.info("Phoenix disabled or endpoint not set, using console exporter") + + trace.set_tracer_provider(provider) + _initialized = True + + return trace.get_tracer(service_name), guardrail_processor diff --git a/tests/test_harness.py b/tests/test_harness.py new file mode 100644 index 0000000..c31ec24 --- /dev/null +++ b/tests/test_harness.py @@ -0,0 +1,255 @@ +"""Tests for the Financial-Guardrail & Telemetry Harness.""" + +from __future__ import annotations + +import pytest + +from factor.harness.budget import SessionBudget +from factor.harness.loop_detector import LoopDetector +from factor.harness.circuit_breaker import CircuitBreaker +from factor.harness.exceptions import ( + BudgetExceededError, + ReasoningLoopError, + CircuitBreakerTripped, +) +from factor.harness.guardrail import FinancialGuardrail + + +# --------------------------------------------------------------------------- +# SessionBudget +# --------------------------------------------------------------------------- + +class TestSessionBudget: + def test_initial_state(self): + budget = SessionBudget(max_budget_usd=5.0) + status = budget.status() + assert status["input_tokens"] == 0 + assert status["output_tokens"] == 0 + assert status["total_cost_usd"] == 0 + assert status["budget_limit_usd"] == 5.0 + assert not budget.is_over_budget + + def test_cost_calculation(self): + budget = SessionBudget( + max_budget_usd=10.0, + input_cost_per_1m=3.0, + output_cost_per_1m=15.0, + ) + budget.record(1_000_000, 0) + assert budget.total_cost_usd == pytest.approx(3.0) + + budget.record(0, 1_000_000) + assert budget.total_cost_usd == pytest.approx(18.0) + + def test_over_budget(self): + budget = SessionBudget(max_budget_usd=0.01, input_cost_per_1m=3.0) + budget.record(100_000, 0) + assert budget.is_over_budget + + def test_step_limit(self): + budget = SessionBudget(max_budget_usd=100.0, max_steps=3) + budget.record(10, 10) + budget.record(10, 10) + assert not budget.is_over_step_limit + budget.record(10, 10) + assert budget.is_over_step_limit + + def test_history(self): + budget = SessionBudget(max_budget_usd=10.0) + budget.record(100, 200, {"tool": "detect"}) + budget.record(300, 400, {"tool": "score"}) + assert len(budget.history) == 2 + assert budget.history[0]["input_tokens"] == 100 + assert budget.history[1]["meta"]["tool"] == "score" + + def test_utilization_percentage(self): + budget = SessionBudget(max_budget_usd=10.0, input_cost_per_1m=10.0) + budget.record(500_000, 0) + status = budget.status() + assert status["budget_utilization_pct"] == pytest.approx(50.0) + + +# --------------------------------------------------------------------------- +# LoopDetector +# --------------------------------------------------------------------------- + +class TestLoopDetector: + def test_no_loop_initially(self): + detector = LoopDetector(window_size=10, threshold=5) + assert not detector.is_looping + + def test_detects_repetitive_actions(self): + detector = LoopDetector(window_size=10, threshold=3) + detector.record("tool_a") + detector.record("tool_a") + assert not detector.is_looping + detector.record("tool_a") + assert detector.is_looping + assert detector.status()["loop_signature"] == "tool_a" + + def test_mixed_actions_no_loop(self): + detector = LoopDetector(window_size=10, threshold=5) + for i in range(10): + detector.record(f"tool_{i}") + assert not detector.is_looping + + def test_reset_clears_state(self): + detector = LoopDetector(window_size=5, threshold=3) + for _ in range(3): + detector.record("stuck") + assert detector.is_looping + detector.reset() + assert not detector.is_looping + assert detector.status()["loop_signature"] is None + + def test_window_rotation(self): + detector = LoopDetector(window_size=5, threshold=4) + for _ in range(3): + detector.record("tool_a") + for _ in range(5): + detector.record("tool_b") + status = detector.status() + assert len(status["window"]) == 5 + assert all(a == "tool_b" for a in status["window"]) + + +# --------------------------------------------------------------------------- +# CircuitBreaker +# --------------------------------------------------------------------------- + +class TestCircuitBreaker: + def test_normal_operation(self): + breaker = CircuitBreaker( + session_id="test-session", + max_budget_usd=10.0, + max_steps=100, + ) + breaker.record_step(input_tokens=100, output_tokens=50, action="analyze") + assert not breaker.is_tripped + + def test_budget_trip(self): + breaker = CircuitBreaker( + session_id="test-session", + max_budget_usd=0.001, + input_cost_per_1m=3.0, + ) + with pytest.raises(BudgetExceededError) as exc_info: + breaker.record_step(input_tokens=1_000_000, output_tokens=0, action="big_call") + assert breaker.is_tripped + assert exc_info.value.status["reason"] == "budget_exceeded" + + def test_step_limit_trip(self): + breaker = CircuitBreaker( + session_id="test-session", + max_budget_usd=1000.0, + max_steps=3, + ) + breaker.record_step(input_tokens=10, output_tokens=10, action="a") + breaker.record_step(input_tokens=10, output_tokens=10, action="b") + with pytest.raises(BudgetExceededError) as exc_info: + breaker.record_step(input_tokens=10, output_tokens=10, action="c") + assert exc_info.value.status["reason"] == "step_limit_exceeded" + + def test_loop_trip(self): + breaker = CircuitBreaker( + session_id="test-session", + max_budget_usd=1000.0, + max_steps=1000, + loop_window=10, + loop_threshold=3, + ) + with pytest.raises(ReasoningLoopError) as exc_info: + for _ in range(5): + breaker.record_step(input_tokens=10, output_tokens=10, action="stuck_tool") + assert breaker.is_tripped + assert exc_info.value.status["reason"] == "reasoning_loop" + assert exc_info.value.status["loop_signature"] == "stuck_tool" + + def test_status_report(self): + breaker = CircuitBreaker(session_id="test-session", max_budget_usd=5.0) + breaker.record_step(input_tokens=1000, output_tokens=500, action="analyze") + status = breaker.status() + assert status["session_id"] == "test-session" + assert status["tripped"] is False + assert status["input_tokens"] == 1000 + assert status["output_tokens"] == 500 + assert status["steps"] == 1 + assert "total_cost_usd" in status + + +# --------------------------------------------------------------------------- +# FinancialGuardrail +# --------------------------------------------------------------------------- + +class TestFinancialGuardrail: + def test_register_and_retrieve(self): + guardrail = FinancialGuardrail() + breaker = guardrail.register_session("session-1") + assert breaker is not None + assert guardrail.get_breaker("session-1") is breaker + + def test_session_status(self): + guardrail = FinancialGuardrail() + guardrail.register_session("session-1") + status = guardrail.session_status("session-1") + assert status is not None + assert status["session_id"] == "session-1" + + def test_unknown_session_returns_none(self): + guardrail = FinancialGuardrail() + assert guardrail.session_status("nonexistent") is None + assert guardrail.get_breaker("nonexistent") is None + + def test_remove_session(self): + guardrail = FinancialGuardrail() + guardrail.register_session("session-1") + guardrail.remove_session("session-1") + assert guardrail.get_breaker("session-1") is None + + def test_record_from_span(self): + guardrail = FinancialGuardrail() + guardrail.register_session("session-1") + guardrail.record_from_span( + session_id="session-1", + input_tokens=100, + output_tokens=50, + action="model_call", + ) + status = guardrail.session_status("session-1") + assert status["input_tokens"] == 100 + assert status["steps"] == 1 + + def test_record_from_span_unknown_session_is_silent(self): + guardrail = FinancialGuardrail() + guardrail.record_from_span( + session_id="unknown", + input_tokens=100, + output_tokens=50, + action="model_call", + ) + + def test_all_sessions(self): + guardrail = FinancialGuardrail() + guardrail.register_session("s1") + guardrail.register_session("s2") + sessions = guardrail.all_sessions() + assert len(sessions) == 2 + ids = {s["session_id"] for s in sessions} + assert ids == {"s1", "s2"} + + +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + +class TestExceptions: + def test_circuit_breaker_tripped_hierarchy(self): + assert issubclass(BudgetExceededError, CircuitBreakerTripped) + assert issubclass(ReasoningLoopError, CircuitBreakerTripped) + + def test_exception_carries_status(self): + status = {"reason": "budget_exceeded", "total_cost_usd": 5.01, "total_steps": 42} + exc = BudgetExceededError(status) + assert exc.status is status + assert "budget_exceeded" in str(exc) + assert "5.01" in str(exc)