diff --git a/docs/superpowers/plans/2026-06-09-producer-trace-validation.md b/docs/superpowers/plans/2026-06-09-producer-trace-validation.md new file mode 100644 index 0000000000..92e068d54c --- /dev/null +++ b/docs/superpowers/plans/2026-06-09-producer-trace-validation.md @@ -0,0 +1,467 @@ +# Producer Trace Validation Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add a focused validation suite for producer trace so we can merge with confidence that end-to-end trace capture, viewer artifacts, failure visibility, and default `trace_function` semantics all work on real training runs. + +**Architecture:** Split validation into two layers. Keep cheap semantic checks in fast/unit tests, and keep real LMDeploy + Ray validation in three opt-in smoke runs backed by a small artifact assertion helper. Avoid a separate training run per concern by making the trace-enabled smoke cover both the remote return-value chain and the multi-batch latest/all scope behavior. + +**Tech Stack:** `unittest`, existing trace/viewer/hotspot builders, real `run_rl.sh` training entry, LMDeploy-backed rollout, Ray actors, JSONL trace shards, static HTML generation. + +--- + +## Scope decision + +These checks are all worth keeping, but they should not all become separate expensive jobs: + +1. **Trace enabled real smoke:** required. +2. **Ray remote chain correctness:** required, but fold into the trace-enabled smoke by asserting `.end` status from real JSONL. +3. **Trace disabled regression:** required. +4. **One representative failure path:** required; choose **judger failure** first because it is easier to make deterministic than rollout timeout and less infrastructure-heavy than abort. +5. 
**Multi-batch latest/all scope:** required, but fold into the trace-enabled smoke by running **2 produce batches** instead of 1. +6. **Explicit `target="group"` path:** required, but this is a **fast/unit test**, not a real training smoke. +7. **Compile / unit / type gate:** required. + +So the final validation matrix should be: + +- **Fast:** unit tests for explicit `group` target and existing trace semantics +- **Smoke A:** trace enabled + 2 produce batches +- **Smoke B:** trace disabled +- **Smoke C:** judger failure +- **Gate:** compileall + `tests/rl/test_trace.py` + targeted mypy/lint + +## File structure + +**Modify** +- `tests/rl/test_trace.py` + - Add the missing explicit `target="group"` semantic test. +- `examples/v1/config/rl_grpo_gsm8k_judge.py` + - Do not modify this file for smoke execution; keep it as the main example. + +**Create** +- `examples/v1/config/testing/rl_trace_smoke_enabled.py` + - Tiny real-training config with `trace_config.enabled=True`, `total_train_steps=2`, small batch sizes, and trace viewer enabled. +- `examples/v1/config/testing/rl_trace_smoke_disabled.py` + - Same tiny config, but `trace_config.enabled=False`. +- `examples/v1/config/testing/rl_trace_smoke_judger_fail.py` + - Same tiny config, but with a deterministic failing judger to produce `.error` or terminal failed status. +- `tests/rl/trace_smoke_assertions.py` + - Pure Python artifact assertion helper / CLI for checking trace dir, JSONL, viewer payload, hotspot payload, latest/all scope, and stage/status expectations from a completed smoke run. +- `docs/zh_cn/developer_guide/trace_validation.md` or `docs/superpowers/specs/trace-validation.md` (optional follow-up) + - Only if we want to document the exact smoke commands for future maintainers. Not required for the first landing. 
+ +## Validation targets + +The artifact assertion helper must verify these exact conditions: + +### Enabled smoke +- Rank0 log contains `Producer Trace Viewer: http://...` +- `work_dir/producer_trace/producer_trace_*.jsonl` exists and is non-empty +- `build_viewer_payload_from_events(...)` returns: + - `task_summary.total_tasks > 0` + - `task_summary.running_tasks >= 0` + - `task_summary.completed_tasks >= 0` + - `task_summary.current_stage_counts` non-empty during the run or `latest_stage_counts` non-empty after the run +- `build_hotspot_payload_from_events(..., scope="latest-produce-batch")` succeeds +- `build_hotspot_payload_from_events(..., scope="all")` succeeds +- At least one trace has all of these stages: + - `xtuner.agent_loop.generate_sample.start` + - `xtuner.agent_loop.generate_sample.end` + - `xtuner.rollout_controller.generate.start` + - `xtuner.rollout_controller.generate.end` + - `xtuner.rollout_worker.generate.start` + - `xtuner.rollout_worker.generate.end` +- For that trace, the `.end` event status is a terminal status such as `completed`, and is not stuck at the old start-only state +- At least **2 distinct `produce_batch_id`** values appear in the JSONL +- Viewer payload default scope shows only the latest batch +- Hotspot payload supports both `latest-produce-batch` and `all`, and `all.task_count >= latest.task_count` + +### Disabled smoke +- Rank0 log does **not** contain `Producer Trace Viewer:` +- `work_dir/producer_trace` does not exist, or exists but has no `producer_trace_*.jsonl` +- No offline hotspot HTML is generated by default + +### Judger failure smoke +- Trace JSONL exists +- At least one event stage ends with `.error`, preferably `xtuner.judger.judge.error` +- If the failure path surfaces as terminal failed status instead of `.error`, assert at least one trace ends in `status in {"failed", "aborted"}` with a non-empty `error_msg` +- Failure must be visible in artifacts; it must not silently disappear from JSONL + +### Fast 
semantic test +- A function decorated with `@trace_function(..., target="group")` and accepting `list[RolloutState]` records start/end for each task in the group +- This must stay a unit test, not a training smoke + +## Task 1: Add the missing fast semantic test for explicit `target="group"` + +**Files:** +- Modify: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add the failing test** + +```python +async def test_trace_function_respects_explicit_group_target(self): + store = InMemoryTraceStore(TraceConfig(enabled=True, max_events=20, max_events_per_trace=20)) + + @trace_function("custom.group", target="group") + async def traced_group(group: list[RolloutState]) -> list[RolloutState]: + for state in group: + state.status = Status.COMPLETED + return group + + states = [make_state(uid=1), make_state(uid=2)] + with use_trace_recorder(TraceRecorder(store)): + await traced_group(states) + + self.assertEqual( + [event.stage for event in store.get_timeline("gsm8k:1")], + ["custom.group.start", "custom.group.end"], + ) + self.assertEqual( + [event.stage for event in store.get_timeline("gsm8k:2")], + ["custom.group.start", "custom.group.end"], + ) +``` + +- [ ] **Step 2: Run the single test and verify it fails first** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceCoreBehaviorTest.test_trace_function_respects_explicit_group_target -v +``` + +Expected: FAIL before implementation if the resolver is wrong. + +- [ ] **Step 3: Implement the minimal test-only changes** + +If the current resolver already supports this path, the implementation step is only adding the test. No production code should be changed unless the test reveals a real bug. + +- [ ] **Step 4: Run the focused test and full trace unit file** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceCoreBehaviorTest.test_trace_function_respects_explicit_group_target -v +python -m unittest discover -s tests/rl -p test_trace.py +``` + +Expected: PASS. 
+ +## Task 2: Add a reusable artifact assertion helper for real smoke runs + +**Files:** +- Create: `tests/rl/trace_smoke_assertions.py` + +- [ ] **Step 1: Write a helper skeleton that can load JSONL and build both viewer payloads** + +```python +from __future__ import annotations + +import argparse +from pathlib import Path + +from xtuner.tools.producer_trace_analysis import load_trace_jsonl +from xtuner.tools.producer_trace_hotspots import build_hotspot_payload_from_events +from xtuner.tools.producer_trace_viewer import build_viewer_payload_from_events + + +def load_events(trace_dir: Path): + events = load_trace_jsonl(trace_dir) + if not events: + raise AssertionError(f"No trace events found under {trace_dir}") + return events +``` + +- [ ] **Step 2: Add focused assertion functions** + +```python +def assert_enabled_smoke(trace_dir: Path, trainer_log: Path) -> None: + text = trainer_log.read_text(encoding="utf-8") + assert "Producer Trace Viewer: http://" in text + events = load_events(trace_dir) + latest_payload = build_viewer_payload_from_events(events, trace_source=str(trace_dir)) + hotspot_latest = build_hotspot_payload_from_events(events, trace_source=str(trace_dir), scope="latest-produce-batch") + hotspot_all = build_hotspot_payload_from_events(events, trace_source=str(trace_dir), scope="all") + assert latest_payload["task_summary"]["total_tasks"] > 0 + assert hotspot_all["task_count"] >= hotspot_latest["task_count"] +``` + +```python +def assert_disabled_smoke(work_dir: Path, trainer_log: Path) -> None: + text = trainer_log.read_text(encoding="utf-8") + assert "Producer Trace Viewer:" not in text + trace_dir = work_dir / "producer_trace" + assert not any(trace_dir.glob("producer_trace_*.jsonl")) if trace_dir.exists() else True +``` + +```python +def assert_failure_smoke(trace_dir: Path) -> None: + events = load_events(trace_dir) + assert any(event.stage.endswith(".error") for event in events) or any( + event.status in {"failed", "aborted"} and 
event.error_msg for event in events + ) +``` + +- [ ] **Step 3: Add a CLI** + +```python +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--scenario", choices=["enabled", "disabled", "judger-fail"], required=True) + parser.add_argument("--work-dir", type=Path, required=True) + parser.add_argument("--trainer-log", type=Path, required=True) + args = parser.parse_args() +``` + +- [ ] **Step 4: Run a syntax-only check** + +Run: + +```bash +python -m compileall -q tests/rl/trace_smoke_assertions.py +``` + +Expected: PASS. + +## Task 3: Add a tiny trace-enabled smoke config that covers real trace, Ray return-value semantics, and multi-batch scope + +**Files:** +- Create: `examples/v1/config/testing/rl_trace_smoke_enabled.py` + +- [ ] **Step 1: Base the config on the smallest working GSM8K + judger path** + +Use the existing `examples/v1/config/rl_grpo_gsm8k_judge.py` as the reference, but reduce runtime: + +```python +train_optimizer_steps = 1 +train_batch_size = 2 +prompt_repeat_k = 1 +producer_trace_config = TraceConfig( + enabled=True, + output_dir=Path(work_dir) / "producer_trace", + viewer_enabled=True, +) +``` + +- [ ] **Step 2: Force at least 2 produce batches** + +```python +total_train_steps = 2 +``` + +If one train step can still consume the whole trace batch in the chosen config, lower the per-step sample count until two distinct `produce_batch_id` values are guaranteed. + +- [ ] **Step 3: Keep the real rollout stack** + +Do not stub out: +- rollout controller +- rollout worker +- judger +- Ray actor path + +This smoke exists precisely to verify the real remote chain. 
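After a run, the two-batch guarantee from Step 2 and the remote-chain requirement from Step 3 could be spot-checked directly against the raw shards. The following is a minimal sketch, not the real helper: it assumes each JSONL record is a flat object with `trace_id`, `stage`, `status`, and `produce_batch_id` keys, and the set of terminal statuses is also an assumption. The function names are hypothetical.

```python
import json
from collections import defaultdict
from pathlib import Path

# Stage names copied from the "Enabled smoke" validation targets.
REQUIRED_STAGES = {
    "xtuner.agent_loop.generate_sample.start",
    "xtuner.agent_loop.generate_sample.end",
    "xtuner.rollout_controller.generate.start",
    "xtuner.rollout_controller.generate.end",
    "xtuner.rollout_worker.generate.start",
    "xtuner.rollout_worker.generate.end",
}
# Assumption: these are the statuses that count as terminal.
TERMINAL_STATUSES = {"completed", "failed", "aborted"}


def read_records(trace_dir: Path) -> list:
    """Read every non-empty line of every producer_trace_*.jsonl shard."""
    records = []
    for shard in sorted(trace_dir.glob("producer_trace_*.jsonl")):
        for line in shard.read_text(encoding="utf-8").splitlines():
            if line.strip():
                records.append(json.loads(line))
    return records


def distinct_produce_batch_ids(records: list) -> set:
    """Collect the distinct produce_batch_id values (assumed top-level key)."""
    return {r["produce_batch_id"] for r in records if r.get("produce_batch_id") is not None}


def traces_with_full_chain(records: list) -> list:
    """Return trace ids that saw all required stages with terminal .end statuses."""
    seen_stages = defaultdict(set)
    end_ok = defaultdict(lambda: True)
    for r in records:
        trace_id, stage = r["trace_id"], r["stage"]
        seen_stages[trace_id].add(stage)
        is_required_end = stage in REQUIRED_STAGES and stage.endswith(".end")
        if is_required_end and r.get("status") not in TERMINAL_STATUSES:
            # An .end event stuck at a start-only status fails the chain check.
            end_ok[trace_id] = False
    return sorted(t for t, s in seen_stages.items() if REQUIRED_STAGES <= s and end_ok[t])
```

If `len(distinct_produce_batch_ids(records)) < 2`, the per-step sample count is probably still too large for the chosen config.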
+ +- [ ] **Step 4: Run the enabled smoke manually** + +Run: + +```bash +bash -x examples/v1/scripts/run_rl.sh \ + examples/v1/config/testing/rl_trace_smoke_enabled.py \ + "lmdeploy" \ + "$QWEN3_PATH" \ + "$GSM8K_DATA_PATH" \ + "$GSM8K_EVAL_DATA_PATH" +``` + +Expected: +- training completes +- rank0 log prints `Producer Trace Viewer: http://127.0.0.1:...` +- `work_dir/producer_trace/producer_trace_*.jsonl` exists + +- [ ] **Step 5: Run artifact assertions** + +Run: + +```bash +python tests/rl/trace_smoke_assertions.py \ + --scenario enabled \ + --work-dir "$WORK_DIR" \ + --trainer-log "$WORK_DIR/logs/train.log" +``` + +Expected: PASS. + +## Task 4: Add a trace-disabled smoke config + +**Files:** +- Create: `examples/v1/config/testing/rl_trace_smoke_disabled.py` + +- [ ] **Step 1: Reuse the tiny smoke config shape, but disable trace** + +```python +producer_trace_config = TraceConfig( + enabled=False, + output_dir=None, +) +``` + +- [ ] **Step 2: Keep everything else as close as possible to the enabled smoke** + +Only vary the trace switch. Do not silently shrink or change the rollout stack differently here, or the regression value drops. + +- [ ] **Step 3: Run the disabled smoke** + +Run: + +```bash +bash -x examples/v1/scripts/run_rl.sh \ + examples/v1/config/testing/rl_trace_smoke_disabled.py \ + "lmdeploy" \ + "$QWEN3_PATH" \ + "$GSM8K_DATA_PATH" \ + "$GSM8K_EVAL_DATA_PATH" +``` + +- [ ] **Step 4: Run artifact assertions** + +Run: + +```bash +python tests/rl/trace_smoke_assertions.py \ + --scenario disabled \ + --work-dir "$WORK_DIR" \ + --trainer-log "$WORK_DIR/logs/train.log" +``` + +Expected: PASS, with no viewer startup and no trace JSONL. + +## Task 5: Add a deterministic judger-failure smoke config + +**Files:** +- Create: `examples/v1/config/testing/rl_trace_smoke_judger_fail.py` + +- [ ] **Step 1: Introduce a deterministic failure path** + +Prefer failing judger construction or judger execution over rollout timeout. 
The failure should be explicit and deterministic. + +Example direction: + +```python +class AlwaysFailJudger(...): + async def judge(self, rollout_states): + raise RuntimeError("trace smoke judger failure") +``` + +Or point the tiny smoke config at a tiny judger implementation already in the repo if one exists. + +- [ ] **Step 2: Keep trace enabled** + +```python +producer_trace_config = TraceConfig( + enabled=True, + output_dir=Path(work_dir) / "producer_trace", + viewer_enabled=True, +) +``` + +- [ ] **Step 3: Run the failure smoke** + +Run: + +```bash +bash -x examples/v1/scripts/run_rl.sh \ + examples/v1/config/testing/rl_trace_smoke_judger_fail.py \ + "lmdeploy" \ + "$QWEN3_PATH" \ + "$GSM8K_DATA_PATH" \ + "$GSM8K_EVAL_DATA_PATH" +``` + +Expected: +- training fails or aborts predictably +- trace JSONL still exists +- at least one `.error` event or failed terminal status is visible + +- [ ] **Step 4: Run artifact assertions** + +Run: + +```bash +python tests/rl/trace_smoke_assertions.py \ + --scenario judger-fail \ + --work-dir "$WORK_DIR" \ + --trainer-log "$WORK_DIR/logs/train.log" +``` + +Expected: PASS. + +## Task 6: Add the merge gate command set + +**Files:** +- Modify: none required if commands live in PR description / team checklist +- Optional create: `docs/superpowers/specs/trace-validation.md` + +- [ ] **Step 1: Define the fast gate** + +Run: + +```bash +python -m compileall -q \ + xtuner/v1/rl/trace.py \ + xtuner/tools/producer_trace_analysis.py \ + xtuner/tools/producer_trace_viewer.py \ + xtuner/tools/producer_trace_hotspots.py \ + xtuner/v1/rl/agent_loop_manager/producer.py \ + tests/rl/test_trace.py \ + tests/rl/trace_smoke_assertions.py + +python -m unittest discover -s tests/rl -p test_trace.py +``` + +- [ ] **Step 2: Define the targeted type gate** + +Run the repo’s targeted type/lint command for these touched files only. 
If the local environment has `mypy` directly available: + +```bash +python -m mypy \ + xtuner/v1/rl/trace.py \ + xtuner/v1/rl/agent_loop_manager/producer.py \ + tests/rl/test_trace.py \ + tests/rl/trace_smoke_assertions.py +``` + +If the repo uses a wrapper command instead, use that wrapper instead of inventing a new type-check entrypoint. + +- [ ] **Step 3: Define the real pre-merge gate** + +Required before merging: + +```text +1. tests/rl/test_trace.py passes +2. enabled 2-batch smoke passes +3. disabled smoke passes +4. judger-fail smoke passes +``` + +## Self-review checklist + +- This plan keeps all required concerns, but avoids 6 separate training runs. +- The enabled smoke is intentionally doing extra work so it covers: + - viewer startup + - JSONL emission + - hotspot generation + - Ray remote `.end` status correctness + - latest vs all batch scope +- The explicit `group` target check stays a unit test, where it belongs. +- The failure path uses judger failure first because it is the lowest-flake representative error path. + +## Recommended execution order + +1. Task 1: add missing fast test +2. Task 2: add artifact assertion helper +3. Task 3: enabled 2-batch smoke +4. Task 4: disabled smoke +5. Task 5: judger-fail smoke +6. Task 6: final compile / unit / type gate + diff --git a/docs/superpowers/plans/2026-06-09-unified-trace-viewer.md b/docs/superpowers/plans/2026-06-09-unified-trace-viewer.md new file mode 100644 index 0000000000..5f6c81ae51 --- /dev/null +++ b/docs/superpowers/plans/2026-06-09-unified-trace-viewer.md @@ -0,0 +1,370 @@ +# Unified Trace Viewer Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. 
+ +**Goal:** Merge the current online producer trace viewer and offline hotspot viewer into one unified viewer that defaults to all tasks, shows overview + stage stats + task list + task detail, and uses one shared payload model for live and offline modes. + +**Architecture:** Move viewer analysis toward a single shared payload builder in `producer_trace_analysis.py`, then make both `producer_trace_viewer.py` and `producer_trace_hotspots.py` render the same page model. Keep live/offline differences in data loading only. Preserve `latest batch` as a client-visible filter option, but make `all tasks` the default semantic view. + +**Tech Stack:** Python dataclasses, existing `TraceEvent` JSONL shards, current producer trace analysis helpers, static HTML + inline JavaScript, `unittest`. + +--- + +## File Structure + +- `xtuner/tools/producer_trace_analysis.py` + - Shared analysis layer for unified task rows, stage stats, per-task chart rows, and per-scope payloads. +- `xtuner/tools/producer_trace_viewer.py` + - Live viewer server + offline snapshot entrypoint for the unified page. +- `xtuner/tools/producer_trace_hotspots.py` + - Becomes a thin compatibility wrapper around the unified offline page builder. +- `xtuner/v1/rl/trace.py` + - Update `TraceConfig.viewer_scope` default if unified viewer should default to all tasks. +- `tests/rl/test_trace.py` + - Unified viewer payload tests, default scope tests, and compatibility tests. +- `docs/superpowers/specs/2026-06-09-trace-next-phase-working-notes.md` + - Keep requirement decisions synchronized as implementation proceeds. 
+ +## Task 1: Add Unified Analysis Payload + +**Files:** +- Modify: `xtuner/tools/producer_trace_analysis.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add the failing tests for unified summary semantics** + +Add tests that assert: + +- overview uses all tasks by default +- stage summary exposes: + - `running_tasks` + - `visited_tasks` + - `avg_s` + - `p95_s` + - `max_s` +- task detail data includes both text timeline events and graphical spans +- failed tasks are counted separately + +Sketch: + +```python +def test_unified_view_payload_reports_overview_stage_stats_and_task_detail(self): + payload = build_unified_trace_payload_from_events(events, trace_source="/tmp/trace") + self.assertEqual(payload["default_scope"], "all") + self.assertEqual(payload["views"]["all"]["overview"]["total_tasks"], 3) + self.assertEqual(payload["views"]["all"]["overview"]["failed_tasks"], 1) + stage = payload["views"]["all"]["stage_stats"][0] + self.assertIn("running_tasks", stage) + self.assertIn("visited_tasks", stage) + detail = payload["views"]["all"]["task_details"]["gsm8k:1"] + self.assertTrue(detail["timeline_events"]) + self.assertTrue(detail["timeline_spans"]) +``` + +- [ ] **Step 2: Run the targeted test to confirm it fails** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_unified_view_payload_reports_overview_stage_stats_and_task_detail +``` + +Expected: + +- FAIL because `build_unified_trace_payload_from_events` or equivalent fields do not exist yet. 
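The stage summary fields asserted in Step 1 (`running_tasks`, `visited_tasks`, `avg_s`, `p95_s`, `max_s`) could be derived roughly as follows. This is only a sketch under stated assumptions: the `(trace_id, stage, start_s, end_s)` span tuple is a hypothetical input shape, and nearest-rank is just one possible definition of p95; the real analysis layer may choose differently.

```python
import math
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class StageStats:
    stage: str
    running_tasks: int
    visited_tasks: int
    avg_s: float
    p95_s: float
    max_s: float


def nearest_rank_p95(values: list) -> float:
    """Nearest-rank 95th percentile of a non-empty list (assumed definition)."""
    ordered = sorted(values)
    rank = max(1, math.ceil(0.95 * len(ordered)))
    return ordered[rank - 1]


def summarize_stages(spans: list) -> list:
    """Aggregate spans given as (trace_id, stage, start_s, end_s or None)."""
    durations = defaultdict(list)
    visited = defaultdict(set)
    running = defaultdict(set)
    for trace_id, stage, start_s, end_s in spans:
        visited[stage].add(trace_id)
        if end_s is None:
            # An open span means the task is still inside this stage.
            running[stage].add(trace_id)
        else:
            durations[stage].append(end_s - start_s)
    stats = []
    for stage in sorted(visited):
        closed = durations[stage]
        stats.append(
            StageStats(
                stage=stage,
                running_tasks=len(running[stage]),
                visited_tasks=len(visited[stage]),
                avg_s=sum(closed) / len(closed) if closed else 0.0,
                p95_s=nearest_rank_p95(closed) if closed else 0.0,
                max_s=max(closed) if closed else 0.0,
            )
        )
    return stats
```

Only closed spans contribute to the duration statistics here, so a stage with nothing but running tasks reports zeros rather than a misleading partial duration.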
+ +- [ ] **Step 3: Add shared dataclasses / payload builders in `producer_trace_analysis.py`** + +Implement shared analysis primitives instead of keeping viewer/hotspot summaries separate: + +- enhanced task row +- per-stage stats +- per-task detail payload +- scope-aware top-level payload + +The new payload should conceptually look like: + +```python +{ + "default_scope": "all", + "available_scopes": ["all", "latest-produce-batch"], + "views": { + "all": { + "overview": {...}, + "stage_stats": [...], + "task_rows": [...], + "task_details": {...}, + }, + "latest-produce-batch": {...}, + }, +} +``` + +- [ ] **Step 4: Re-run the targeted unified payload test** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_unified_view_payload_reports_overview_stage_stats_and_task_detail +``` + +Expected: + +- PASS + +## Task 2: Replace Separate Viewer/Hotspot Pages With One Unified Page + +**Files:** +- Modify: `xtuner/tools/producer_trace_viewer.py` +- Modify: `xtuner/tools/producer_trace_hotspots.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add failing tests for unified page structure** + +Add assertions that rendered HTML contains the new sections and no longer contains removed sections: + +```python +def test_unified_viewer_html_contains_new_sections(self): + html = render_unified_trace_html(payload, live=False) + self.assertIn("Total tasks", html) + self.assertIn("Failed", html) + self.assertIn("Stage", html) + self.assertIn("Task Timeline", html) + self.assertNotIn("Suspect Open Spans", html) + self.assertNotIn("Latest Stage Distribution", html) +``` + +- [ ] **Step 2: Run the targeted HTML test to confirm it fails** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_unified_viewer_html_contains_new_sections +``` + +Expected: + +- FAIL because the old HTML still renders the old layout. 
+ +- [ ] **Step 3: Implement the unified HTML / JS page in `producer_trace_viewer.py`** + +Refactor page structure to: + +- header +- overview cards +- scope toggle +- stage summary table +- task list with filters +- task detail: + - text timeline + - chart timeline below + +The JS should: + +- switch between `all` and `latest-produce-batch` +- filter task rows by: + - state + - current stage + - search text +- render task detail for the selected row + +- [ ] **Step 4: Convert `producer_trace_hotspots.py` into a compatibility entrypoint** + +Make the offline hotspots script reuse the unified offline page builder instead of maintaining a separate page model. + +Compatibility behavior: + +- existing CLI entry still works +- output HTML is the unified viewer page +- offline mode loads static payload only + +- [ ] **Step 5: Re-run the targeted HTML test** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_unified_viewer_html_contains_new_sections +``` + +Expected: + +- PASS + +## Task 3: Flip Default Viewer Semantics to All Tasks + +**Files:** +- Modify: `xtuner/v1/rl/trace.py` +- Modify: `xtuner/tools/producer_trace_viewer.py` +- Modify: `xtuner/tools/producer_trace_hotspots.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add failing tests for default scope** + +Add tests that assert: + +- `TraceConfig.viewer_scope` defaults to `"all"` +- CLI default scope for unified viewer is `"all"` +- live payload chooses `all` as `default_scope` + +Sketch: + +```python +def test_trace_config_defaults_viewer_scope_to_all(self): + self.assertEqual(TraceConfig().viewer_scope, "all") +``` + +- [ ] **Step 2: Run the targeted default-scope tests** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_trace_config_defaults_viewer_scope_to_all +``` + +Expected: + +- FAIL because current default is `latest-produce-batch`. 
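For the CLI half of the default-scope assertions in Step 1, a sketch of the intended behavior might look like this. The `--scope` flag name and the `build_scope_parser` helper are hypothetical; the existing entrypoints may already expose the scope under a different option.

```python
import argparse

AVAILABLE_SCOPES = ("all", "latest-produce-batch")
DEFAULT_SCOPE = "all"


def build_scope_parser() -> argparse.ArgumentParser:
    """Build a parser whose scope option defaults to all tasks."""
    parser = argparse.ArgumentParser(description="unified trace viewer (sketch)")
    # Hypothetical flag: keeps latest-produce-batch selectable, but not the default.
    parser.add_argument("--scope", choices=AVAILABLE_SCOPES, default=DEFAULT_SCOPE)
    return parser
```

Keeping the scope names in one tuple lets the payload's `available_scopes` and the CLI choices stay in sync.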
+ +- [ ] **Step 3: Change default viewer scope to `all`** + +Update: + +- `TraceConfig.viewer_scope` +- CLI defaults for unified viewer/offline page entrypoints +- any tests or assumptions that still rely on `latest-produce-batch` as the default + +- [ ] **Step 4: Keep `latest-produce-batch` as an optional filter** + +Do not remove the capability. Keep it available in: + +- payload `available_scopes` +- UI scope selector +- offline CLI option + +- [ ] **Step 5: Re-run the default-scope tests** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_trace_config_defaults_viewer_scope_to_all +``` + +Expected: + +- PASS + +## Task 4: Add Viewer Tests for Failed Tasks and Task Detail Behavior + +**Files:** +- Modify: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add failing tests for failed-task accounting** + +Add tests that assert: + +- `failed_tasks` is counted in overview +- failed tasks appear in `task_rows` +- `error_msg` appears in task detail only + +Sketch: + +```python +def test_unified_viewer_counts_failed_tasks_and_keeps_error_msg_in_task_detail(self): + payload = build_unified_trace_payload_from_events(events, trace_source="/tmp/trace") + overview = payload["views"]["all"]["overview"] + self.assertEqual(overview["failed_tasks"], 1) + row = next(row for row in payload["views"]["all"]["task_rows"] if row["trace_id"] == "gsm8k:9") + self.assertEqual(row["status"], "failed") + detail = payload["views"]["all"]["task_details"]["gsm8k:9"] + self.assertIn("trace smoke judger failure", json.dumps(detail, ensure_ascii=False)) + self.assertNotIn("error_msg", row) +``` + +- [ ] **Step 2: Run the targeted failed-task test** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_unified_viewer_counts_failed_tasks_and_keeps_error_msg_in_task_detail +``` + +Expected: + +- FAIL until failed-task handling and detail structure are correct. 
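The split that the failing test above expects (failure counted in the overview, status on the row, `error_msg` only in the detail) can be illustrated with a small sketch. It assumes events are plain dicts already sorted by time with `trace_id`, `stage`, `status`, and an optional `error_msg`; the function name is hypothetical and the real payload builder will carry more fields.

```python
from collections import defaultdict


def split_rows_and_details(events: list) -> dict:
    """Build overview, task_rows, and task_details from time-ordered events."""
    by_trace = defaultdict(list)
    for event in events:
        by_trace[event["trace_id"]].append(event)

    task_rows = []
    task_details = {}
    for trace_id, trace_events in by_trace.items():
        last = trace_events[-1]
        # Rows stay compact: status and current stage only, no error_msg.
        task_rows.append(
            {"trace_id": trace_id, "status": last["status"], "current_stage": last["stage"]}
        )
        # Details keep the full event records, including error_msg.
        task_details[trace_id] = {"timeline_events": list(trace_events)}

    failed = sum(1 for row in task_rows if row["status"] == "failed")
    return {
        "overview": {"total_tasks": len(task_rows), "failed_tasks": failed},
        "task_rows": task_rows,
        "task_details": task_details,
    }
```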
+ +- [ ] **Step 3: Finish the analysis/payload wiring for failed tasks** + +Make sure: + +- overview counts failed tasks +- task rows expose status and current stage +- task detail contains full event records including `error_msg` +- task rows do not duplicate the full `error_msg` + +- [ ] **Step 4: Re-run the targeted failed-task test** + +Run: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_unified_viewer_counts_failed_tasks_and_keeps_error_msg_in_task_detail +``` + +Expected: + +- PASS + +## Task 5: Full Verification + +**Files:** +- Verify touched files only + +- [ ] **Step 1: Run unified trace tests** + +Run: + +```bash +python -m unittest discover -s tests/rl -p test_trace.py +``` + +Expected: + +- PASS + +- [ ] **Step 2: Run compile checks** + +Run: + +```bash +python -m compileall -q xtuner/tools/producer_trace_analysis.py xtuner/tools/producer_trace_viewer.py xtuner/tools/producer_trace_hotspots.py xtuner/v1/rl/trace.py tests/rl/test_trace.py +``` + +Expected: + +- PASS + +- [ ] **Step 3: Run diff sanity** + +Run: + +```bash +git diff --check +``` + +Expected: + +- no whitespace / merge-marker issues + +- [ ] **Step 4: Optional live smoke after unit verification** + +Run: + +```bash +bash -x examples/v1/scripts/run_rl.sh examples/v1/config/testing/rl_trace_smoke_enabled.py lmdeploy "$MODEL_PATH" "$DATA_PATH" "$EVAL_DATA_PATH" +``` + +Expected: + +- unified live viewer starts +- page defaults to all tasks +- scope selector can switch to latest batch + diff --git a/docs/superpowers/plans/2026-06-11-otel-trace-parity.md b/docs/superpowers/plans/2026-06-11-otel-trace-parity.md new file mode 100644 index 0000000000..0fda13882e --- /dev/null +++ b/docs/superpowers/plans/2026-06-11-otel-trace-parity.md @@ -0,0 +1,1455 @@ +# OTel Trace Parity Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan 
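task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Bring the current trace implementation up to parity with the key OTel tracing capabilities of `4f22de76b7efe336b9bb80a3a27b70d50117b18a`, while changing as little XTuner business code as possible. + +**Architecture:** Keep `xtuner.v1.rl.trace` as the only tracing facade inside XTuner; business code does not import OpenTelemetry directly. New capabilities are concentrated at a few boundary points: the tracing helpers, sampler sample identity, agent loop task context propagation, session server HTTP context propagation, and inference latency spans. The Jaeger reference configuration and documentation live in `recipe/otle` and stay out of the main training path. + +**Tech Stack:** Python, OpenTelemetry API/SDK, OTLP HTTP/gRPC exporter, Jaeger 2.x memory storage, aiohttp session server. + +--- + +## Design principles + +1. **Minimal intrusion into business code.** Where the producer, agent loop, and rollout worker already use `trace_function` / `trace_span`, do not rewrite them; add a small number of calls only where trace propagation is missing. +2. **Do not leak the OpenTelemetry API into business modules.** Apart from `xtuner.v1.rl.trace`, other modules call only the XTuner trace helpers. +3. **Add `RolloutState.trace_id` explicitly.** For now, make trace identity a first-class field; if it later turns out to duplicate the semantics of `uid` exactly, fold `trace_id` into `uid`. +4. **Keep current behavior by default.** When trace is disabled, every new helper must be a no-op and must not affect the training path. +5. **Fill in the trace chain first, presentation later.** This plan only adds the OTel trace capabilities that are missing relative to 4f22; the existing in-house viewer is not extended in this round. +6. 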
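**No broad exception fallbacks.** Keep only the no-op guard that stops a failure in tracing itself from affecting training; business exceptions are still raised or logged along their original path. + +## 4f22 features to cover + +- The OTel exporter supports the standard environment variables, OTLP HTTP, OTLP gRPC, and a console exporter. +- The sampler generates a stable `trace_id` for each task and writes it to `RolloutState.trace_id`. +- `trace.py` is the single place that converts `RolloutState.trace_id` into an OTel trace id and query attributes; `case.id` is only an alias kept for compatibility with 4f22 query habits. +- The agent loop writes the trace attributes that `trace.py` builds from `RolloutState` into the OTel task context and baggage. +- The session server can restore trace context from the header/body and inject context into requests to the worker. +- The session server records spans such as `forward_worker`, `stream_read`, `on_request`, and `on_response`. +- The session server records attributes such as first chunk, first output token, first content, token count, and finish reason. +- `recipe/otle` provides a Jaeger memory reference config and documentation; it does not ship a startup script, to avoid bringing environment assumptions such as tmux, curl, linux-amd64, and external downloads into the repo. + +## Explicitly out of scope + +- Do not treat the `tensor_parallel_size=2 -> 1` change in the example as part of the trace parity work. +- Apart from the new `RolloutState.trace_id`, do not change any other `RolloutState` pydantic field. +- Do not call the OTel API directly in the producer / rollout controller / rollout worker. +- Do not add metrics; no Prometheus/Grafana. +- Do not rewrite the existing online/offline viewer. + +## File structure + +### Modify + +- `xtuner/v1/rl/trace.py` + - Add an OTel exporter factory. + - Add compatibility with the standard OTel environment variables. + - Add a context propagation helper. + - Add a task context / baggage helper. + - Add `RolloutState.trace_id` reading, a stable OTel trace id mapping, and trace attribute construction. + - Add a low-level span helper without a task target, for internal use by the session server. + +- `xtuner/v1/data_proto/rl_data.py` + - Add `trace_id: str | None = None` to `RolloutState`. + +- `xtuner/v1/rl/agent_loop_manager/sampler.py` + - Generate a stable `trace_id` and write it to `RolloutState.trace_id`. + +- `xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py` + - Build trace attributes from `RolloutState`. + - While the runner executes, write the task context and baggage through the trace helper. + +- `xtuner/v1/rl/rollout/session_server.py` + - Restore and inject trace context. + - Add session server spans and inference latency attributes. + +- `pyproject.toml` + - If only the gRPC exporter is currently a dependency, add `opentelemetry-exporter-otlp-proto-http` or use `opentelemetry-exporter-otlp`. + +### Create + +- `recipe/otle/README.md` + - Document 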
the Jaeger ports, OTLP endpoint, environment variables, and how to look up `case.id`. + +- `recipe/otle/jaeger/jaeger-memory.yaml` + - Jaeger 2.x memory storage configuration. + +### Test + +- `tests/rl/test_trace.py` + - Add lightweight unit tests for trace id generation, baggage no-op, the context helper, and the session server helper. + +> Note: the user has explicitly asked that unit tests not be run without permission. During execution, run test commands only after the user approves. + +--- + +## Task 1: Extend `TraceConfig` and add an OTel exporter factory + +**Purpose:** The current branch supports only OTLP gRPC. Add the OTLP HTTP / gRPC / console exporter capabilities from 4f22, and accept the standard OTel environment variables. + +**Files:** +- Modify: `xtuner/v1/rl/trace.py` +- Modify: `pyproject.toml` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add the minimal config fields to `TraceConfig`** + +Suggested fields: + +```python +class TraceConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + enabled: bool = False + otel_endpoint: str = "http://127.0.0.1:4317" + otel_protocol: Literal["grpc", "http/protobuf"] = "grpc" + otel_exporter: Literal["otlp", "console"] = "otlp" + otel_service_name: str = "xtuner-rl" + jaeger_query_url: str | None = None +``` + +Rationale: + +- `otel_protocol` is the only new option that is strictly required; it distinguishes Jaeger OTLP HTTP `14318/v1/traces` from gRPC `4317/14317`. +- `otel_exporter="console"` is for local debugging only and is equivalent to the 4f22 console exporter. +- Do not add an `XTUNER_OTEL_ENABLED` config field, to avoid config sprawl; enablement is still controlled by `TraceConfig.enabled`. + +- [ ] **Step 2: Switch `OtelTraceSink.__init__` to an exporter factory** + +Implementation: + +```python +def _build_otel_exporter(config: TraceConfig) -> Any: + if config.otel_exporter == "console": + from opentelemetry.sdk.trace.export import ConsoleSpanExporter + + return ConsoleSpanExporter() + + if config.otel_protocol == "http/protobuf": + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + return OTLPSpanExporter(endpoint=config.otel_endpoint) + + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + return OTLPSpanExporter(endpoint=config.otel_endpoint, insecure=True) +``` + +Then `OtelTraceSink` uses: + +```python +exporter = _build_otel_exporter(config) +``` + +- [ ] **Step 3: Extend the import error message** + +The current message mentions only 
`opentelemetry-exporter-otlp-proto-grpc`. Change it to: + +```python +"OpenTelemetry tracing requires opentelemetry-api, opentelemetry-sdk, " +"and an OTLP exporter package. Install xtuner[trace] or install " +"opentelemetry-exporter-otlp-proto-grpc / opentelemetry-exporter-otlp-proto-http." +``` + +- [ ] **Step 4: Extend the `pyproject.toml` trace extra** + +If it currently reads: + +```toml +trace = ["opentelemetry-api", "opentelemetry-sdk", "opentelemetry-exporter-otlp-proto-grpc"] +``` + +change it to: + +```toml +trace = [ + "opentelemetry-api", + "opentelemetry-sdk", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-exporter-otlp-proto-http", +] +``` + +- [ ] **Step 5: Unit test design** + +Add to or modify `tests/rl/test_trace.py`: + +```python +def test_trace_config_accepts_http_protocol(): + config = TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_service_name="xtuner-test", + ) + assert config.otel_protocol == "http/protobuf" +``` + +Command for the execution phase; run it only after the user approves: + +```bash +python -m unittest tests.rl.test_trace.TestTraceConfig.test_trace_config_accepts_http_protocol +``` + +--- + +## Task 2: Support the standard OTel environment variables + +**Purpose:** Let 4f22-style launch scripts work on the current branch while keeping the existing `XTUNER_TRACE_*` propagation. + +**Files:** +- Modify: `xtuner/v1/rl/trace.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add constants for the standard environment variables** + +Add at the top of `trace.py`: + +```python +OTEL_TRACES_EXPORTER_ENV = "OTEL_TRACES_EXPORTER" +OTEL_EXPORTER_OTLP_PROTOCOL_ENV = "OTEL_EXPORTER_OTLP_PROTOCOL" +OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV = "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" +OTEL_SERVICE_NAME_ENV = "OTEL_SERVICE_NAME" +XTUNER_OTEL_ENABLED_ENV = "XTUNER_OTEL_ENABLED" +AGENT_OTEL_ENABLED_ENV = "AGENT_OTEL_ENABLED" +XTUNER_OTEL_SERVICE_NAME_ENV = "XTUNER_OTEL_SERVICE_NAME" +``` + +- [ ] **Step 2: Make `_export_env` export both the current variables and the standard OTel variables** + +Minimal behavior: + +```python +os.environ[TRACE_ENV_ENABLED] = "1" +os.environ[TRACE_ENV_OTEL_ENDPOINT] = config.otel_endpoint +os.environ[TRACE_ENV_OTEL_SERVICE_NAME] = config.otel_service_name + 
+os.environ[XTUNER_OTEL_ENABLED_ENV] = "1" +os.environ.setdefault(AGENT_OTEL_ENABLED_ENV, "1") +os.environ.setdefault(OTEL_TRACES_EXPORTER_ENV, config.otel_exporter) +os.environ.setdefault(OTEL_EXPORTER_OTLP_PROTOCOL_ENV, config.otel_protocol) +os.environ.setdefault(OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV, config.otel_endpoint) +os.environ.setdefault(OTEL_SERVICE_NAME_ENV, config.otel_service_name) +os.environ.setdefault(XTUNER_OTEL_SERVICE_NAME_ENV, config.otel_service_name) +``` + +Why `setdefault` is used: + +- When the user has already set the standard OTel variables, their explicit configuration is not overwritten. +- Ray workers and child processes can still see the variables they need. + +- [ ] **Step 3: Make `_load_config_from_env` read the standard OTel environment variables** + +Read priority: + +1. `XTUNER_TRACE_*`, for compatibility with the current branch. +2. Standard `OTEL_*`, for compatibility with 4f22. + +Suggested implementation: + +```python +enabled = os.environ.get(TRACE_ENV_ENABLED) == "1" or _env_truthy(os.environ.get(XTUNER_OTEL_ENABLED_ENV)) +if not enabled: + return TraceConfig() + +endpoint = ( + os.environ.get(TRACE_ENV_OTEL_ENDPOINT) + or os.environ.get(OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV) + or TraceConfig.model_fields["otel_endpoint"].default +) +protocol = ( + os.environ.get(OTEL_EXPORTER_OTLP_PROTOCOL_ENV) + or TraceConfig.model_fields["otel_protocol"].default +) +service_name = ( + os.environ.get(TRACE_ENV_OTEL_SERVICE_NAME) + or os.environ.get(XTUNER_OTEL_SERVICE_NAME_ENV) + or os.environ.get(OTEL_SERVICE_NAME_ENV) + or TraceConfig.model_fields["otel_service_name"].default +) +``` + +- [ ] **Step 4: Unit test design** + +```python +def test_load_trace_config_from_standard_otel_env(monkeypatch): + monkeypatch.setenv("XTUNER_OTEL_ENABLED", "1") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://127.0.0.1:14318/v1/traces") + monkeypatch.setenv("OTEL_SERVICE_NAME", "xtuner-test") + + config = TraceRuntimeManager._load_config_from_env() + + assert config.enabled is True + assert config.otel_protocol == "http/protobuf" + assert config.otel_endpoint == "http://127.0.0.1:14318/v1/traces" + assert 
config.otel_service_name == "xtuner-test" +``` + +If the current test framework does not use `monkeypatch`, use `unittest.mock.patch.dict(os.environ, ..., clear=True)` instead. + +--- + +## Task 3: Add context propagation and task context helpers to `trace.py` + +**Purpose:** Keep business modules from depending on the OpenTelemetry API directly, while adding the cross-process context propagation that 4f22 provides. + +**Files:** +- Modify: `xtuner/v1/rl/trace.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add `trace_enabled_from_env`** + +```python +def trace_enabled_from_env() -> bool: + exporter = (os.environ.get(OTEL_TRACES_EXPORTER_ENV) or "").strip().lower() + if exporter in {"none", "false", "off", "0"}: + return False + return ( + os.environ.get(TRACE_ENV_ENABLED) == "1" + or _env_truthy(os.environ.get(XTUNER_OTEL_ENABLED_ENV)) + or _env_truthy(os.environ.get(AGENT_OTEL_ENABLED_ENV)) + or bool( + os.environ.get(OTEL_TRACES_EXPORTER_ENV) + or os.environ.get(OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV) + ) + ) +``` + +- [ ] **Step 2: Add the context helpers** + +```python +def extract_trace_context(headers: Mapping[str, str] | None) -> Any: + if not trace_enabled_from_env(): + return None + try: + from opentelemetry import propagate + + return propagate.extract(headers or {}) + except Exception: + return None + + +def inject_trace_context(headers: MutableMapping[str, str]) -> None: + if not trace_enabled_from_env(): + return + try: + from opentelemetry import propagate + + propagate.inject(headers) + except Exception: + return +``` + +- [ ] **Step 3: Add `use_trace_context`** + +```python +@contextmanager +def use_trace_context(trace_context: Any) -> Iterator[None]: + if not trace_enabled_from_env() or trace_context is None: + yield + return + + token = None + try: + from opentelemetry import context as otel_context + + token = otel_context.attach(trace_context) + except Exception: + yield + return + + try: + yield + finally: + if token is not None: + try: + otel_context.detach(token) + except Exception: + # Use `pass`, not `return`: a `return` inside `finally` would swallow an in-flight business exception. + pass +``` + +- [ ] **Step 4: Add `trace_baggage` and `trace_task_context`** + +```python 
+@contextmanager +def trace_baggage(attrs: Mapping[str, Any] | None) -> Iterator[None]: + if not attrs or not trace_enabled_from_env(): + yield + return + + token = None + try: + from opentelemetry import baggage, context + + ctx = context.get_current() + for key, value in attrs.items(): + if value is None: + continue + ctx = baggage.set_baggage(str(key), str(value), context=ctx) + token = context.attach(ctx) + except Exception: + yield + return + + try: + yield + finally: + if token is not None: + try: + context.detach(token) + except Exception: + # Use `pass`, not `return`: a `return` inside `finally` would swallow an in-flight business exception. + pass +``` + +Only `trace_baggage` is shown above; `trace_task_context` must set both the current OTel trace context and the baggage while the runner executes (see Task 7). + +- [ ] **Step 5: Unit test design** + +```python +def test_trace_baggage_noops_when_disabled(): + with patch.dict(os.environ, {"XTUNER_TRACE_ENABLED": "0"}, clear=True): + with trace_baggage({"xtuner.trace_id": "gsm8k:abc"}): + assert True +``` + +```python +def test_inject_context_noops_when_disabled(): + headers = {} + with patch.dict(os.environ, {"XTUNER_TRACE_ENABLED": "0"}, clear=True): + inject_trace_context(headers) + assert headers == {} +``` + +--- + +## Task 4: Add a low-level span helper that needs no RolloutState target + +**Purpose:** The session server has no `RolloutState` internally, so it cannot naturally use the current task-oriented `trace_span(target, name)`. It needs a helper that is responsible only for OTel spans, but the helper still lives in `trace.py` so that business modules do not import OTel. + +**Files:** +- Modify: `xtuner/v1/rl/trace.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add an attribute-setting helper** + +```python +def set_otel_span_attrs(span: Any, **attrs: Any) -> None: + if span is None: + return + for key, value in attrs.items(): + if value is None: + continue + try: + if isinstance(value, (str, bool, int, float)): + span.set_attribute(key, value) + else: + span.set_attribute(key, str(value)) + except Exception: + continue +``` + +- [ ] **Step 2: Add an exception helper** + +```python +def record_otel_exception(span: Any, exc: BaseException) -> None: + if span is None: + return + try: + span.record_exception(exc) + span.set_attribute("error.type", type(exc).__name__) + span.set_attribute("error.message", str(exc)) + except Exception: + return 
+``` + +- [ ] **Step 3: Add the `otel_span` context manager** + +```python +@contextmanager +def otel_span(name: str, **attrs: Any) -> Iterator[Any]: + if not trace_enabled_from_env(): + yield None + return + + try: + from opentelemetry import trace as otel_trace + + tracer = otel_trace.get_tracer("xtuner.v1.rl.trace") + except Exception: + # Tracing setup failed: fall back to a no-op span. + yield None + return + + # Yield exactly once. Wrapping the body's `yield` in an outer `except Exception: yield None` + # would yield a second time after a business exception, and contextlib would raise RuntimeError. + with tracer.start_as_current_span(name) as span: + set_otel_span_attrs(span, **attrs) + try: + yield span + except Exception as exc: + record_otel_exception(span, exc) + raise +``` + +- [ ] **Step 4: Add manual begin/end helpers** + +In the streaming case the span has to live across the `async for` lifetime: + +```python +def begin_otel_span(name: str, **attrs: Any) -> Any: + if not trace_enabled_from_env(): + return None + try: + from opentelemetry import trace as otel_trace + + span = otel_trace.get_tracer("xtuner.v1.rl.trace").start_span(name) + set_otel_span_attrs(span, **attrs) + return span + except Exception: + return None + + +def end_otel_span(span: Any, exc: BaseException | None = None, **attrs: Any) -> None: + if span is None: + return + if exc is not None: + record_otel_exception(span, exc) + set_otel_span_attrs(span, **attrs) + try: + span.end() + except Exception: + return +``` + +- [ ] **Step 5: Unit test design** + +```python +def test_otel_span_noops_when_disabled(): + with patch.dict(os.environ, {"XTUNER_TRACE_ENABLED": "0"}, clear=True): + with otel_span("xtuner.test.disabled") as span: + assert span is None +``` + +--- + +## Task 5: Add `RolloutState.trace_id` and the trace identity helpers + +**Purpose:** Give every rollout task an explicit, stable trace identity that follows it end to end. For now, add `RolloutState.trace_id`; if it later turns out to duplicate the semantics of `uid` exactly, converge `trace_id` into `uid`. + +**Files:** +- Modify: `xtuner/v1/data_proto/rl_data.py` +- Modify: `xtuner/v1/rl/trace.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add the field to `RolloutState`** + +In `xtuner/v1/data_proto/rl_data.py`, add it next to the state fields of `RolloutState`: + +```python + # --- state --- + uid: int | None = None + trace_id: str | None = None # stable task trace 
identity, preferred for tracing; may later converge with uid + task_name: str | None = None +``` + +Notes: + +- `uid` keeps its existing business purpose; the current rollout/session/replay semantics do not change. +- `trace_id` expresses the tracing identity only. +- `trace_id` is a `str`, which makes it easy to hash stably, carry a task prefix, and pass into OTel baggage and Jaeger tags. + +- [ ] **Step 2: Add stable serialization and a run id helper to `trace.py`** + +Add these imports at the top of `trace.py`: + +```python +import hashlib +import json +``` + +If `os` and `Any` are already imported in the file, reuse the existing imports. Then add: + +```python +def _json_dumps_stable(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str, separators=(",", ":")) + + +def get_trace_run_id() -> str | None: + return ( + os.environ.get("XTUNER_OTEL_RUN_ID") + or os.environ.get("RUN_ID") + or os.environ.get("MODEL_NAME") + or os.environ.get("WORK_DIR") + ) +``` + +- [ ] **Step 3: Add `build_rollout_trace_id` to `trace.py`** + +```python +def build_rollout_trace_id( + state: RolloutState, + *, + repeat_index: int | None = None, +) -> str: + if state.message_uid is not None: + payload = { + "task_name": state.task_name, + "data_source": state.data_source, + "message_uid": state.message_uid, + "repeat_index": repeat_index, + } + else: + payload = { + "task_name": state.task_name, + "data_source": state.data_source, + "message": state.message, + "repeat_index": repeat_index, + } + digest = hashlib.sha1(_json_dumps_stable(payload).encode("utf-8")).hexdigest()[:16] + prefix = state.task_name or "unknown" + return f"{prefix}:{digest}" +``` + +Notes: + +- `repeat_index` is part of the hash, so multiple rollout samples of the same prompt do not share a trace id. +- `task_name` goes into the prefix so that a human can recognize the task; the part that is actually stable is the hash. + +- [ ] **Step 4: Add a trace id read helper to `trace.py`** + +```python +def get_rollout_trace_id(state: RolloutState) -> str | None: + if state.trace_id: + return state.trace_id + if state.uid is not None: + return f"{state.task_name or 'unknown'}:{state.uid}" + return None +``` + +Notes: + +- New data uses `trace_id` first. +- Old data falls back to `uid`, so existing trace paths keep working. +- The fallback is not guaranteed to be stable across experiments; it exists only for compatibility. + +- [ ] **Step 5: In `trace.py`, add the 
attributes builder** + +```python +def build_rollout_trace_attributes(state: RolloutState) -> dict[str, Any]: + trace_id = get_rollout_trace_id(state) + attrs: dict[str, Any] = {} + if trace_id is not None: + attrs["xtuner.trace_id"] = trace_id + attrs["case.id"] = trace_id # alias for 4f22 Jaeger query habits; case.id is not stored separately in RolloutState + run_id = get_trace_run_id() + if run_id: + attrs["run.id"] = run_id + if state.task_name is not None: + attrs["task.name"] = state.task_name + if state.uid is not None: + attrs["xtuner.uid"] = state.uid + if state.message_uid is not None: + attrs["sample.message_uid"] = state.message_uid + if state.data_source is not None: + attrs["sample.data_source"] = ( + _json_dumps_stable(state.data_source) if isinstance(state.data_source, dict) else str(state.data_source) + ) + return attrs +``` + +Notes: + +- `case.id` is a compatibility alias, so the experience of querying by `case.id` described in the 4f22 docs is not broken. +- The primary concept is `xtuner.trace_id` / `RolloutState.trace_id`. +- In the first version, sample repeat information is not stored in `RolloutState`; if it is needed later, it can be added as an attribute without affecting the trace identity. + +- [ ] **Step 6: Make the existing TraceEventBuilder prefer `trace_id`** + +Wherever a trace id is currently built from `task_name + uid`, read `RolloutState.trace_id` first. If a helper cannot currently reach the `RolloutState`, keep its original logic and avoid broad signature changes. + +Target behavior: + +```python +trace_id = get_rollout_trace_id(state) +``` + +fallback: + +```python +trace_id = TraceEventBuilder.trace_id(task_name, uid) +``` + +- [ ] **Step 7: Unit test design** + +```python +def test_build_rollout_trace_id_is_stable_for_same_sample(): + state = RolloutState( + task_name="gsm8k", + message=[{"role": "user", "content": "1+1?"}], + data_source={"dataset": "demo"}, + message_uid=123, + ) + + trace_id1 = build_rollout_trace_id(state, repeat_index=1) + trace_id2 = build_rollout_trace_id(state, repeat_index=1) + + assert trace_id1 == trace_id2 + assert trace_id1.startswith("gsm8k:") +``` + +```python +def test_build_rollout_trace_id_changes_with_repeat_index(): + state = RolloutState( + task_name="gsm8k", + message=[{"role": "user", "content": "1+1?"}], + 
data_source={"dataset": "demo"}, + message_uid=123, + ) + + assert build_rollout_trace_id(state, repeat_index=0) != build_rollout_trace_id(state, repeat_index=1) +``` + +```python +def test_build_rollout_trace_attributes_uses_trace_id_before_uid(): + state = RolloutState( + uid=99, + trace_id="gsm8k:stable", + task_name="gsm8k", + message=[{"role": "user", "content": "1+1?"}], + ) + + attrs = build_rollout_trace_attributes(state) + + assert attrs["xtuner.trace_id"] == "gsm8k:stable" + assert attrs["case.id"] == "gsm8k:stable" + assert attrs["xtuner.uid"] == 99 +``` + +--- + +## Task 6: The sampler writes a stable `trace_id` + +**Purpose:** The sampler is where tasks are created, so `RolloutState.trace_id` is written exactly once here for each new task. The sampler does not generate OTel attrs and knows nothing about `case.id` / baggage. + +**Files:** +- Modify: `xtuner/v1/rl/agent_loop_manager/sampler.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add the import** + +```python +from xtuner.v1.rl.trace import build_rollout_trace_id +``` + +- [ ] **Step 2: Write the trace_id in `Sampler.sample`** + +After `task_name` is written to each `RolloutState`, add: + +```python +if state.trace_id is None: + state.trace_id = build_rollout_trace_id(state, repeat_index=repeat_index) +``` + +Full placement: + +```python +for task, dataset_sampler in self.task_to_sampler.items(): + for group in dataset_sampler.sample(num_samples): + for repeat_index, state in enumerate(group): + state.task_name = task + if state.trace_id is None: + state.trace_id = build_rollout_trace_id(state, repeat_index=repeat_index) + results.append(group) +``` + +Notes: + +- `_DatasetSampler` handles only dataset-level sampling and does not know the final `task_name`. +- `Sampler.sample` is the smallest entry point at which the task name is fixed, so the first version generates `trace_id` here. +- In both deterministic and non-deterministic modes, the same data with the same repeat index yields the same `trace_id`. +- When the replay buffer returns old samples, a sample's existing `trace_id` is kept; if an old sample has none, the trace helper later falls back to `uid`. + +- [ ] **Step 3: Unit test design** + +If `_DatasetSampler` is cheap to construct, test the group returned by `sample_from_dataloader`: + +```python +def test_dataset_sampler_fills_trace_id_for_each_repeat(): + sampler = 
_DatasetSampler(dataloader=_FakeDataloader(...), prompt_repeat_k=2) + + group = sampler.sample_from_dataloader() + + assert group[0].trace_id is not None + assert group[1].trace_id is not None + assert group[0].trace_id != group[1].trace_id +``` + +Per Step 2, `trace_id` is assigned in `Sampler.sample`, so these assertions only hold for a group that has passed through `Sampler.sample`; adjust the test entry point accordingly. If constructing a dataloader is expensive, keep the Task 5 `build_rollout_trace_id` unit tests and let the real smoke run verify the sampler behavior. + +--- + +## Task 7: Agent items carry the trace identity; the agent loop propagates the OTel task context + +**Purpose:** Make the `trace_span(item, ...)` calls inside the sandbox/localhost runner belong to the same task trace as the outer `RolloutState`. This also adds the `_otel_baggage` capability from 4f22 and additionally sets the current OTel trace context, but the agent loop neither imports OpenTelemetry directly nor reads `extra_fields["otel"]`. + +**Files:** +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/schemas.py` +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py` +- Modify: `xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add `trace_id` to `AgentRolloutItem`** + +```python +class AgentRolloutItem(BaseModel): + ... 
+ uid: int | None = None + trace_id: str | None = None +``` + +- [ ] **Step 2: Both agent loops copy `trace_id` from `RolloutState`** + +localhost: + +```python +item.uid = rollout_state.uid +item.trace_id = rollout_state.trace_id +item.group_id = rollout_state.message_uid +``` + +sandbox: + +```python +rollout_item.uid = rollout_state.uid +rollout_item.trace_id = rollout_state.trace_id +rollout_item.group_id = rollout_state.message_uid +``` + +Notes: + +- The runner already calls `trace_span(item, ...)` internally; as long as the item carries a `trace_id`, `TraceEventBuilder` uses it first. +- This is less invasive than passing `trace_id=...` explicitly to every span in the runner. +- `trace_task_context` does not replace `trace_span`; it only sets the current OTel trace context and baggage while the runner executes, so that child OTel spans or HTTP injection stay attached to the same task trace. + +- [ ] **Step 3: Add the import to the agent loop** + +```python +from xtuner.v1.rl.trace import build_rollout_trace_attributes, trace_task_context +``` + +- [ ] **Step 4: Modify localhost `generate_sample`** + +Change: + +```python +result = await self._run_item(item) +``` + +to: + +```python +result = await self._run_item(item, trace_attrs=build_rollout_trace_attributes(rollout_state)) +``` + +- [ ] **Step 5: Modify the localhost `_run_item` signature and context** + +Change: + +```python +async def _run_item(self, item: AgentRolloutItem) -> AgentRolloutItem: + runner = _resolve_runner(item.pipeline) + if runner is None: + raise ValueError("AgentRolloutItem.pipeline is required.") + with ctx_session_id.set(str(item.uid)): + return await runner.run(item) +``` + +to: + +```python +async def _run_item(self, item: AgentRolloutItem, trace_attrs: dict[str, Any] | None = None) -> AgentRolloutItem: + runner = _resolve_runner(item.pipeline) + if runner is None: + raise ValueError("AgentRolloutItem.pipeline is required.") + with ctx_session_id.set(str(item.uid)), trace_task_context(trace_attrs): + return await runner.run(item) +``` + +- [ ] **Step 6: Unit test design** + +Avoid a real dependency on the lagent/OTel backend. Use a fake runner to verify that `_run_item` accepts `trace_attrs` and returns normally: + +```python +class _FakeRunner: + async def run(self, item): + 
return item + + +async def test_run_item_accepts_trace_attrs(): + loop = AgentInLocalhostLoop(...) + item = AgentRolloutItem(uid=1, pipeline=_FakeRunner(), ...) + result = await loop._run_item(item, trace_attrs={"xtuner.trace_id": "gsm8k:abc"}) + assert result is item +``` + +If constructing `AgentInLocalhostLoop` is too expensive, test only the `trace_task_context` no-op and leave the real chain to the smoke test. + +Also add a lightweight schema/identity-level test that verifies `TraceEventBuilder` prefers `AgentRolloutItem.trace_id`: + +```python +async def test_trace_event_prefers_agent_rollout_item_trace_id(): + item = AgentRolloutItem(id="case-1", data_source="gsm8k", instruction="problem.txt", uid=123, trace_id="gsm8k:stable") + + await trace_event(item, "custom.agent_item") + + # `event` is the trace event captured by the test sink; the capture code is omitted in this plan. + assert event.trace_id == "gsm8k:stable" +``` + +--- + +## Task 8: The session server restores and injects the trace context + +**Purpose:** Let requests coming from agent/lagent join the session server / worker spans in a single trace. + +**Files:** +- Modify: `xtuner/v1/rl/rollout/session_server.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add the import** + +```python +from xtuner.v1.rl.trace import extract_trace_context, inject_trace_context, use_trace_context +``` + +- [ ] **Step 2: Extend `_SESSION_SERVER_ONLY_KEYS`** + +Change: + +```python +_SESSION_SERVER_ONLY_KEYS = {"session_id"} +``` + +to: + +```python +_SESSION_SERVER_ONLY_KEYS = {"session_id", "_otel_trace_context"} +``` + +- [ ] **Step 3: Modify the top-level logic of `_handle_request`** + +Keep the current "read the request body once" logic and add parsing of the body trace context: + +```python +request_body = await request.read() +body_trace_context = None +if request_body: + try: + body_data = json.loads(request_body) + body_trace_context = body_data.get("_otel_trace_context") if isinstance(body_data, dict) else None + except json.JSONDecodeError: + body_trace_context = None + +traceparent_header = request.headers.get("traceparent") +traceparent_body = None +if isinstance(body_trace_context, dict): + traceparent_body = body_trace_context.get("traceparent") + +parent_context_source = "header" if traceparent_header else "none" +parent_context = extract_trace_context(request.headers) +if traceparent_body: + parent_context = extract_trace_context(body_trace_context) + parent_context_source = "body" + +with use_trace_context(parent_context): + return await self._handle_request_impl( + request, + request_body=request_body, + traceparent_header_present=bool(traceparent_header), + traceparent_body_present=bool(traceparent_body), + traceparent_context_source=parent_context_source, + ) +``` + +- [ ] **Step 4: Move the original `_handle_request` body into `_handle_request_impl`** + +New signature: + +```python +async def _handle_request_impl( + self, + request: web.Request, + *, + request_body: bytes, + traceparent_header_present: bool, + traceparent_body_present: bool, + traceparent_context_source: str, +) -> web.Response: + ... +``` + +This is a local rearrangement and does not change business logic. + +- [ ] **Step 5: Inject the context before forwarding** + +After `forward_headers` is built, add: + +```python +inject_trace_context(forward_headers) +``` + +- [ ] **Step 6: Make sure `_otel_trace_context` is not forwarded to the worker** + +The existing on_request filtering of `_SESSION_SERVER_ONLY_KEYS` already removes this field. Check that the non-trace-store path filters it in the same way. + +- [ ] **Step 7: Unit test design** + +Extract a pure function to lower the cost of testing: + +```python +def _extract_body_trace_context(request_body: bytes) -> dict[str, Any] | None: + if not request_body: + return None + try: + body_data = json.loads(request_body) + except json.JSONDecodeError: + return None + if not isinstance(body_data, dict): + return None + context = body_data.get("_otel_trace_context") + return context if isinstance(context, dict) else None +``` + +Test: + +```python +def test_extract_body_trace_context(): + body = json.dumps({"_otel_trace_context": {"traceparent": "00-" + "1" * 32 + "-" + "2" * 16 + "-01"}}).encode() + assert _extract_body_trace_context(body)["traceparent"].startswith("00-") +``` + +--- + +## Task 9: Add key spans and inference latency attributes to the session server + +**Purpose:** Add the most valuable diagnostic data from 4f22: request forwarding time, stream read time, first chunk, first output token, first content, output tokens, and finish reason. + +**Files:** +- Modify: `xtuner/v1/rl/rollout/session_server.py` +- Test: 
`tests/rl/test_trace.py` + +- [ ] **Step 1: Add the imports** + +```python +import time +from xtuner.v1.rl.trace import begin_otel_span, end_otel_span, otel_span, set_otel_span_attrs +``` + +- [ ] **Step 2: Add token count helpers** + +```python +def _list_len(value: Any) -> int | None: + return len(value) if isinstance(value, list) else None + + +def _choices_output_ids_len(data: dict) -> int: + total = 0 + for choice in data.get("choices") or []: + output_ids = choice.get("output_ids") + if isinstance(output_ids, list): + total += len(output_ids) + return total + + +def _response_output_ids_len(data: dict) -> int | None: + output_ids = data.get("output_ids") + if isinstance(output_ids, list): + return len(output_ids) + total = _choices_output_ids_len(data) + return total if total > 0 else None +``` + +- [ ] **Step 3: Add spans to `on_request` / `on_response`** + +Keep the original logic and push the body down into `_on_request_impl` / `_on_response_impl`: + +```python +async def on_request(self, req_body: dict, *, trace_enabled: bool = True) -> dict: + with otel_span( + "xtuner.session_server.on_request", + session_id=req_body.get("session_id"), + trace_store_enabled=trace_enabled, + messages=len(req_body.get("messages") or []) if isinstance(req_body.get("messages"), list) else None, + tools=len(req_body.get("tools") or []) if isinstance(req_body.get("tools"), list) else None, + ) as span: + return await self._on_request_impl(req_body, trace_enabled=trace_enabled, span=span) +``` + +`on_response` records, in the same way: + +```python +output_tokens +response_chars +finish_reason +``` + +- [ ] **Step 4: Add local spans around tokenizer / trace store operations** + +Add spans only at the key points, to avoid fragmenting the code: + +```python +with otel_span("xtuner.session_server.apply_chat_template", session_id=session_id): + prompt_text = self.tokenizer.apply_chat_template(...) 
+ +with otel_span("xtuner.session_server.trace_store.search_prompt", session_id=session_id): + prefix, nodes = await self.store.search.remote(session_id, prompt_text, filter_none=True) + +with otel_span("xtuner.session_server.trace_store.insert_response", session_id=session_id, output_tokens=len(output_token_ids)): + await self.store.insert.remote(...) +``` + +- [ ] **Step 5: Add the `forward_worker` span around worker forwarding** + +Before the `ClientSession` request: + +```python +forward_span = begin_otel_span( + "xtuner.session_server.forward_worker", + target_url=target_url, + stream=is_stream, + request_bytes=len(request_body) if request_body else 0, + timeout_s=self.request_timeout, + input_tokens=input_tokens, + max_tokens=max_tokens, + model=request_data.get("model") if request_data else None, + http_method=request.method, + http_path=request.path, + worker_base_url=self.worker_base_url, + traceparent_header_present=traceparent_header_present, + traceparent_body_present=traceparent_body_present, + traceparent_context_source=traceparent_context_source, +) +``` + +On failure and on normal completion: + +```python +try: + ... 
+except Exception as exc: + end_otel_span(forward_span, exc=exc) + raise +else: + end_otel_span(forward_span, response_bytes=len(raw_response) if raw_response is not None else None) +``` + +- [ ] **Step 6: Collect first-token and related attributes for stream responses** + +In the stream branch: + +```python +stream_span = begin_otel_span( + "xtuner.session_server.stream_read", + target_url=target_url, + input_tokens=input_tokens, + max_tokens=max_tokens, +) +stream_start = time.perf_counter() +first_chunk_ms = None +first_output_token_ms = None +first_content_ms = None +chunk_count = 0 +raw_response_bytes = 0 +output_tokens = 0 +finish_reason = None +``` + +Inside `async for line in resp.content`: + +```python +chunk_count += 1 +raw_response_bytes += len(line) +if first_chunk_ms is None: + first_chunk_ms = (time.perf_counter() - stream_start) * 1000 + +if request_data is not None and line.startswith(b"data: ") and line.strip() != b"data: [DONE]": + try: + data = json.loads(line.decode("utf-8")[6:]) + event_output_tokens = _choices_output_ids_len(data) + if event_output_tokens > 0 and first_output_token_ms is None: + first_output_token_ms = (time.perf_counter() - stream_start) * 1000 + output_tokens += event_output_tokens + for choice in data.get("choices") or []: + delta = choice.get("delta") or {} + if delta.get("content") and first_content_ms is None: + first_content_ms = (time.perf_counter() - stream_start) * 1000 + if choice.get("finish_reason"): + finish_reason = choice.get("finish_reason") + except Exception: + pass +``` + +In the `finally` block: + +```python +end_otel_span( + stream_span, + first_chunk_ms=first_chunk_ms, + first_output_token_ms=first_output_token_ms, + first_content_ms=first_content_ms, + chunks=chunk_count, + raw_response_bytes=raw_response_bytes, + output_tokens=output_tokens if output_tokens > 0 else None, + finish_reason=finish_reason, + client_alive=client_alive, +) +set_otel_span_attrs( + forward_span, + first_chunk_ms=first_chunk_ms, + first_output_token_ms=first_output_token_ms, + 
first_content_ms=first_content_ms, + output_tokens=output_tokens if output_tokens > 0 else None, + finish_reason=finish_reason, +) +``` + +- [ ] **Step 7: Record response/token attrs for non-stream responses** + +```python +with otel_span("xtuner.session_server.read_response", target_url=target_url): + raw_response = await resp.read() + +set_otel_span_attrs(forward_span, response_bytes=len(raw_response)) +``` + +After parsing the JSON: + +```python +usage = clean_data.get("usage") if isinstance(clean_data, dict) else None +set_otel_span_attrs( + forward_span, + output_tokens=_response_output_ids_len(clean_data) if isinstance(clean_data, dict) else None, + prompt_tokens=usage.get("prompt_tokens") if isinstance(usage, dict) else None, + completion_tokens=usage.get("completion_tokens") if isinstance(usage, dict) else None, + total_tokens=usage.get("total_tokens") if isinstance(usage, dict) else None, +) +``` + +- [ ] **Step 8: Unit test design** + +Pure function tests are enough; no aiohttp server is started: + +```python +def test_choices_output_ids_len(): + data = { + "choices": [ + {"output_ids": [1, 2]}, + {"output_ids": [3]}, + {"delta": {"content": "x"}}, + ] + } + assert _choices_output_ids_len(data) == 3 +``` + +```python +def test_response_output_ids_len_prefers_top_level_output_ids(): + assert _response_output_ids_len({"output_ids": [1, 2, 3]}) == 3 +``` + +--- + +## Task 10: Add the `recipe/otle` Jaeger reference configuration + +**Purpose:** Give users a reproducible Jaeger memory reference configuration and an XTuner trace configuration example. XTuner ships no launch script; users start the Jaeger/OTLP backend themselves. + +**Files:** +- Create: `recipe/otle/README.md` +- Create: `recipe/otle/jaeger/jaeger-memory.yaml` + +- [ ] **Step 1: Add the Jaeger memory configuration** + +Contents of `recipe/otle/jaeger/jaeger-memory.yaml`: + +```yaml +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:14317 + http: + endpoint: 0.0.0.0:14318 + +processors: + batch: + +exporters: + jaeger_storage_exporter: + trace_storage: memstore + +extensions: + jaeger_storage: + backends: + memstore: + memory: + max_traces: 100000 + jaeger_query: + storage: + traces: memstore + base_path: / + 
http: + endpoint: 0.0.0.0:16686 + grpc: + endpoint: 0.0.0.0:16685 + +service: + telemetry: + metrics: + level: none + extensions: [jaeger_storage, jaeger_query] + pipelines: + traces: + receivers: [otlp] + processors: [batch] + exporters: [jaeger_storage_exporter] +``` + +- [ ] **Step 2: Add the README** + +The README must state clearly: + +```text +Jaeger UI: http://127.0.0.1:16686 +OTLP HTTP endpoint: http://127.0.0.1:14318/v1/traces +OTLP gRPC endpoint: 127.0.0.1:14317 +``` + +Training configuration example: + +```python +trace_config = TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_service_name="xtuner-rl", + jaeger_query_url="http://127.0.0.1:16686", +) +``` + +Query instructions: + +```text +Search by service.name for xtuner-rl. +Search by tag xtuner.trace_id=. +To keep the 4f22 query habit, you can also search by case.id=. +Use run.id to tell baseline / opt_v1 apart. +``` + +--- + +## Task 11: Integration verification plan + +**Purpose:** After development is complete, verify that the feature really works. Note: these commands may only be run after the user explicitly allows it. + +- [ ] **Step 1: Static check** + +Command: + +```bash +python -m compileall -q xtuner/v1/rl/trace.py xtuner/v1/rl/agent_loop_manager/sampler.py xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py xtuner/v1/rl/rollout/session_server.py tests/rl/test_trace.py +``` + +Expected: + +```text +no output +``` + +- [ ] **Step 2: Trace unit tests** + +Command: + +```bash +python -m unittest tests.rl.test_trace +``` + +Expected: + +```text +OK +``` + +- [ ] **Step 3: Jaeger backend smoke** + +Precondition: the user has started the Jaeger/OTLP backend themselves, using `recipe/otle/jaeger/jaeger-memory.yaml` or an equivalent configuration. + +Command: + +```bash +curl -fsS http://127.0.0.1:16686/api/services +``` + +Expected: + +```json +{"data":[],"total":0,"limit":0,"offset":0,"errors":null} +``` + +- [ ] **Step 4: Trace-enabled real training smoke** + +Required configuration: + +```python +trace_config = TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_service_name="xtuner-rl", + jaeger_query_url="http://127.0.0.1:16686", +) +``` + +Things to confirm: + +- rank0 prints the Jaeger URL. +- `xtuner-rl` appears in the Jaeger services list. +- A single trace shows the 
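sampler-generated `trace_id`, and its attributes include `xtuner.trace_id` / `case.id` / `run.id`. +- The localhost agent span and the session server span appear under the same trace. +- `xtuner.session_server.forward_worker` has `first_chunk_ms` / `first_output_token_ms`. + +- [ ] **Step 5: Trace-disabled regression** + +Configuration: + +```python +trace_config = TraceConfig(enabled=False) +``` + +Things to confirm: + +- The viewer is not started and no viewer URL is printed. +- No OTel backend needs to be available. +- Training does not fail because of OTel dependencies or a missing Jaeger. + +- [ ] **Step 6: Failure path smoke** + +Deliberately trigger a judger or session server exception and confirm: + +- The corresponding span status is error. +- The span attributes include `error.type` / `error.message`. +- The original business error path is not swallowed by tracing. + +--- + +## Suggested merge order + +1. Tasks 1-4: merge the tracing infrastructure first, making sure the no-op behavior and the exporter selection logic are clear. +2. Tasks 5-7: then merge `trace_id`, the sampler write, and baggage; the impact is small. +3. Tasks 8-9: merge the session server last; it carries the highest risk and needs focused review. +4. Task 10: the reference configuration and documentation can be merged independently. +5. Task 11: verify once the user allows it; do not run tests without authorization. + +## Review focus + +- Whether `trace.py` is still the only tracing facade, and whether business modules avoid importing `opentelemetry` directly. +- Whether `TraceConfig(enabled=False)` is a complete no-op. +- Whether the standard `OTEL_*` environment variables are never overwritten when the user has set them explicitly. +- Whether `RolloutState.trace_id` only adds a tracing identity and leaves the existing semantics of `uid`, `session_uid`, and the replay buffer unchanged. +- Whether the session server reads the request body only once. +- Whether `_otel_trace_context` is never passed through to the worker. +- Whether the stream branch still ends its span when the client disconnects. +- Whether exceptions inside the tracing helpers never affect training. diff --git a/docs/superpowers/plans/2026-06-16-otel-task-trace-followup.md b/docs/superpowers/plans/2026-06-16-otel-task-trace-followup.md new file mode 100644 index 0000000000..f13b7ad0fa --- /dev/null +++ b/docs/superpowers/plans/2026-06-16-otel-task-trace-followup.md @@ -0,0 +1,1060 @@ +# XTuner OTel Task Trace Follow-up Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. 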
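+ +**Goal:** Fill in the task-level observability that the current OTel / Jaeger trace version still lacks, so that the sandbox / localhost agent loops, validate failures, trajectory materialization, LLM/tool stage semantics, and tracing across independent processes can all be observed reliably. + +**Architecture:** Keep `xtuner.v1.rl.trace` as the only tracing facade. Business modules use only `trace_function` / `trace_span` / `trace_task_context` and do not import OpenTelemetry directly. New spans are added only at stable business boundaries. The Jaeger dashboard keeps rebuilding task state from the Jaeger Query API, and the JSONL backend is not reintroduced on the training path. + +**Tech Stack:** Python, OpenTelemetry API/SDK, OTLP exporter, Jaeger Query API, existing XTuner trace helpers, `unittest`. + +--- + +## Background + +The main trace path has already moved from the early JSONL producer trace to OpenTelemetry + Jaeger: + +- The training side initializes the OTel exporter through `TraceConfig(enabled=True, ...)`. +- Each task is tied together across the producer, agent loop, rollout, session server, sandbox, and other stages through `RolloutState.trace_id` / `xtuner.trace_id`. +- rank0 starts the `XTuner Task Trace Dashboard` automatically. +- The detail view of a single task embeds the Jaeger Native Trace. + +This document covers only the 5 features that review confirmed still need development: + +1. High-level spans for the sandbox / localhost agent loops. +2. Sandbox validate business failures must mark the span as an error. +3. A dedicated span for trajectory materialization. +4. Stage semantics for LLM calls / tool calls. +5. 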
Distributed tracing for lagent / independent processes. + +Not covered: + +- Automatically saving call stacks for slow tasks. That is a separate feature and should get its own plan. +- Inference-engine internal metrics, such as GPU, queue, prefill, decode, and KV cache. +- Rewriting the Jaeger dashboard UI. +- Restoring JSONL as the main backend on the training path. + +## Current implementation baseline + +Core files: + +- `xtuner/v1/rl/trace.py` + - `TraceConfig` + - `trace_function` + - `trace_span` + - `trace_task_context` + - `inject_trace_context` + - `extract_trace_context` + - `otel_span` +- `xtuner/tools/jaeger_trace_dashboard.py` + - Pulls traces from the Jaeger Query API. + - Converts them into `TraceEvent`. + - Reuses the unified viewer payload. +- `xtuner/tools/producer_trace_analysis.py` + - Owns the task overview, stage summary, and timeline span records. +- `xtuner/tools/producer_trace_viewer.py` + - Renders the dashboard HTML. + - The task detail view already contains the `Jaeger Native Trace` iframe. + +Existing default instrumentation: + +- producer: + - `xtuner.producer.sample_group` + - `xtuner.producer.generate_group` + - `xtuner.producer.put_generated_group` +- generic agent loop: + - `xtuner.agent_loop.generate_group` + - `xtuner.agent_loop.generate_sample` + - `xtuner.judger.judge` +- rollout: + - `xtuner.rollout_controller.generate` + - `xtuner.rollout_worker.generate` + - `xtuner.rollout_engine.generate` +- sandbox runner: + - `xtuner.sandbox.run_total` + - `xtuner.sandbox.acquire` + - `xtuner.sandbox.infer` + - `xtuner.sandbox.validate` + - `xtuner.sandbox.entry:` +- localhost runner: + - `xtuner.localhost.run_total` + - `xtuner.localhost.infer` + - `xtuner.localhost.validate` + - `xtuner.localhost.judger` + - `xtuner.localhost.agent` +- session server: + - `xtuner.session_server.on_request` + - `xtuner.session_server.forward_worker` + - `xtuner.session_server.stream_read` + - `xtuner.session_server.read_response` + - `xtuner.session_server.on_response` + +## Design principles + +1. **Minimal intrusion into XTuner business code.** Add spans only at stable boundaries; do not scatter trace logic into every internal helper. +2. **Business code does not call the OTel API directly.** Keep using `trace_function` / `trace_span` / `trace_task_context`. +3. **Span names express stable stages; attributes express details.** For example, use `xtuner.stage.kind="llm_call"` instead of encoding every detail into the span name. +4. 
**Failure paths must be clearly visible in Jaeger.** On a business failure, call `span.mark_error(...)` rather than only writing the final `RolloutState.error_msg`. +5. **Do not record large objects.** Prompts, responses, tool results, images, and tensors stay out of the default span attributes. +6. **Validate against Jaeger / the dashboard payload.** Every new feature must be reconstructible by the dashboard from OTel spans. + +## File structure + +### Modify + +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py` + - Add `trace_function` to the high-level entry points of the sandbox agent loop. + - Add `trace_span` around trajectory materialization. + - Keep `trace_task_context` when calling the runner. + +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py` + - Fix validate business failures not marking the span as an error. + - Add lightweight `xtuner.stage.kind` attributes to the high-level sandbox stages. + +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py` + - Add `xtuner.stage.kind="entry"` and entry-related attributes to the entry span. + - If an entry can later be identified as a tool call, mark it with the attribute `xtuner.stage.kind="tool_call"`. + +- `xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py` + - Add `trace_function` to the high-level entry points of the localhost agent loop. + - Add `trace_span` around trajectory materialization. + +- `xtuner/v1/rl/agent_loop/localhost_agent_loop/runner.py` + - Keep the existing error-marking semantics for validate failures. + - Add lightweight `xtuner.stage.kind` attributes to the high-level localhost stages. + +- `xtuner/v1/rl/rollout/session_trace.py` + - Add `xtuner.stage.kind="llm_call"` to the spans related to session server LLM requests. + - Confirm that trace context is injected into the forward headers. + +- `xtuner/v1/rl/rollout/session_server.py` + - Make sure `_otel_trace_context` is not forwarded to the actual inference worker. + - If needed, add test coverage for restoring trace context from the request body. + +- `xtuner/tools/producer_trace_analysis.py` + - Update stage labels / grouping so the dashboard main view shows `agent_in_sandbox`, `materialize_trajectory`, `llm_call`, and `entry/tool_call` more clearly. + +- `xtuner/tools/jaeger_trace_dashboard.py` + - Confirm that `xtuner.stage.kind` from the Jaeger tags reaches `TraceEvent.attributes`. + - If the current payload already includes attributes, only add tests. + +- `tests/rl/test_trace.py` + - Add unit tests for the trace API, the Jaeger payload, and the dashboard payload. + +### Optional cross-repo changes + +Producing child spans inside the lagent process as well requires a very thin OTel adapter in the lagent repo. This XTuner repo is only responsible for passing the context across. 
+ +Suggested standalone file on the lagent side: + +- `lagent/.../trace.py` + - Extract `traceparent` / `tracestate` / `baggage` from environment variables, HTTP headers, or the task payload. + - Use OTel to create child spans such as `lagent.llm_call` and `lagent.tool_call`. + +## Task 1: High-level spans for the sandbox / localhost agent loops + +**Purpose:** Let the dashboard and Jaeger show directly that a single sample entered the high-level boundary of the sandbox / localhost agent loop, instead of exposing only the internal runner stages. + +**Files:** +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py` +- Modify: `xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add the import to the sandbox agent loop** + +In `xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py`, change the import to: + +```python +from xtuner.v1.rl.trace import build_rollout_trace_attributes, trace_function, trace_task_context +``` + +- [ ] **Step 2: Add a group span to sandbox `generate_group`** + +Change the sandbox agent loop's `generate_group` to: + +```python + @trace_function("xtuner.agent_in_sandbox.generate_group") + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: + async def generate_one(state: RolloutState) -> RolloutState: + if self._sample_semaphore is None: + return await self.generate_sample(state, **kwargs) + async with self._sample_semaphore: + return await self.generate_sample(state, **kwargs) + + pending_tasks = [] + for state in rollout_state: + state.sample_params = self.sample_params + task = create_task(generate_one(state)) + pending_tasks.append(task) + generated_samples = asyncio.gather(*pending_tasks) + group_samples = await generated_samples + return group_samples +``` + +This relies on `trace_function` resolving the parameter name `rollout_state` by default. Because `rollout_state` is a `list[RolloutState]`, a group start/end is recorded for every task in the list. + +- [ ] **Step 3: Add a sample span to sandbox `generate_sample`** + +Change the sandbox agent loop's `generate_sample` to: + +```python + @trace_function("xtuner.agent_in_sandbox.generate_sample") + async def generate_sample(self, rollout_state: 
RolloutState, **kwargs) -> RolloutState: + try: + rollout_item = rollout_state.extra_fields["rollout_item"].model_copy(deep=True) + if rollout_state.uid is None: + rollout_state.uid = uuid.uuid4().int + rollout_item.uid = rollout_state.uid + rollout_item.trace_id = rollout_state.trace_id + rollout_item.group_id = rollout_state.message_uid + await self._throttle_sandbox_create() + result = await self._run_item(rollout_item, trace_attrs=build_rollout_trace_attributes(rollout_state)) + await self._fill_rollout_state(rollout_state, result) + return rollout_state + except Exception as exc: + rollout_state.status = Status.COMPLETED if self.mode == "eval" else Status.FAILED + rollout_state.finish_reason = "error" + if self.mode == "eval": + rollout_state.reward = {"score": 0.0} + rollout_state.response = "" + rollout_state.extra_fields["agent_status"] = "exception" + rollout_state.error_msg = f"{type(exc).__name__}: {exc}" + self.logger.error(f"[AgentInSandboxLoop] failed: {exc}\n{traceback.format_exc()}") + return rollout_state +``` + +Note: this does not change the business semantics of swallowing the exception. Because the function returns the updated `rollout_state`, the `.end` event emitted by `trace_function` picks up the latest failed/completed status. + +- [ ] **Step 4: Add the import to the localhost agent loop** + +In `xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py`, change the import to: + +```python +from xtuner.v1.rl.trace import build_rollout_trace_attributes, trace_function, trace_task_context +``` + +- [ ] **Step 5: Add a group span to localhost `generate_group`** + +Change the localhost agent loop's `generate_group` to: + +```python + @trace_function("xtuner.agent_in_localhost.generate_group") + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: + async def generate_one(state: RolloutState) -> RolloutState: + if self._sample_semaphore is None: + return await self.generate_sample(state, **kwargs) + async with self._sample_semaphore: + return await self.generate_sample(state, **kwargs) + + tasks: list[asyncio.Task[RolloutState]] = [] + for 
state in rollout_state: + state.sample_params = self.sample_params + task = create_task(generate_one(state)) + tasks.append(task) + + return await asyncio.gather(*tasks) +``` + +- [ ] **Step 6: Add a sample span to localhost `generate_sample`** + +Change the localhost agent loop's `generate_sample` to: + +```python + @trace_function("xtuner.agent_in_localhost.generate_sample") + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: + try: + if self.sample_timeout_s is not None and self.sample_timeout_s > 0: + return await asyncio.wait_for( + self._generate_sample_impl(rollout_state), + timeout=self.sample_timeout_s, + ) + return await self._generate_sample_impl(rollout_state) + except asyncio.TimeoutError: + self.logger.warning( + f"[AgentInLocalhostLoop] sample timed out after {self.sample_timeout_s:.1f}s " + f"(uid={rollout_state.uid}, group_id={rollout_state.message_uid})." + ) + return self._fail_rollout_state( + rollout_state, + finish_reason="timeout", + error_msg=f"TimeoutError: localhost agent sample exceeded {self.sample_timeout_s:.1f}s", + agent_status="timeout", + ) + except Exception as exc: + if self.mode == "train" and _is_trace_key_mismatch(exc): + raise + self.logger.error(f"[AgentInLocalhostLoop] failed: {exc}\n{traceback.format_exc()}") + return self._fail_rollout_state( + rollout_state, + finish_reason="error", + error_msg=f"{type(exc).__name__}: {exc}", + agent_status="exception", + ) +``` + +- [ ] **Step 7: Add a unit test covering the high-level span names** + +In `tests/rl/test_trace.py`, add a decorator-semantics test that does not depend on a real sandbox: + +```python + async def test_trace_function_records_group_and_sample_like_agent_loop_spans(self): + sink = RecordingTraceSink() + states = [make_state(uid=1), make_state(uid=2)] + + @trace_function("xtuner.agent_in_sandbox.generate_group") + async def traced_group(rollout_state: list[RolloutState]) -> list[RolloutState]: + return [ + state.model_copy(update={"status": Status.COMPLETED}, deep=True) + for state in rollout_state + ] 
+ + with use_trace_recorder(TraceRecorder(sink)): + await traced_group(states) + + stages_by_trace = { + trace_id: [event.stage for event in timeline_from_events(sink.events, trace_id)] + for trace_id in ("gsm8k:1", "gsm8k:2") + } + self.assertEqual( + stages_by_trace["gsm8k:1"], + ["xtuner.agent_in_sandbox.generate_group.start", "xtuner.agent_in_sandbox.generate_group.end"], + ) + self.assertEqual( + stages_by_trace["gsm8k:2"], + ["xtuner.agent_in_sandbox.generate_group.start", "xtuner.agent_in_sandbox.generate_group.end"], + ) +``` + +- [ ] **Step 8: Run the focused test once the user allows it** + +Command: + +```bash +python -m unittest tests.rl.test_trace.TraceCoreBehaviorTest.test_trace_function_records_group_and_sample_like_agent_loop_spans -v +``` + +Expected: + +```text +OK +``` + +## Task 2: Mark the span as an error on sandbox validate business failures + +**Purpose:** When sandbox validate returns `failed=True` without raising a Python exception, `xtuner.sandbox.validate` and `xtuner.sandbox.run_total` must also show up as error spans in Jaeger and carry `error_msg`. + +**Files:** +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Change the scope of the validate span** + +Change the current logic: + +```python + t2 = time.monotonic() + async with trace_span(item, "xtuner.sandbox.validate", **trace_kwargs): + score, failed = await self.validate.run(item, pool) + t_validate = time.monotonic() - t2 + item.reward = score + if failed: + return self._fail( + item, + _first_judger_error(item) + or RolloutError( + stage="validate", + category="validate_failed", + type="JudgerValidator", + message="all judgers failed" if not item.judgers else "validate failed", + ), + ) +``` + +to: + +```python + t2 = time.monotonic() + async with trace_span(item, "xtuner.sandbox.validate", **trace_kwargs) as validate_span: + score, failed = await self.validate.run(item, pool) + t_validate = time.monotonic() - t2 + item.reward = score + if failed: + error = _first_judger_error(item) or RolloutError( + stage="validate", + category="validate_failed", + 
type="JudgerValidator", + message="all judgers failed" if not item.judgers else "validate failed", + ) + error_msg = _format_error(error) + validate_span.mark_error(error_msg) + total_span.mark_error(error_msg) + return self._fail(item, error) +``` + +With this change, `trace_span` writes `xtuner.sandbox.validate.error` on exit, and `run_total` also writes `.error` on exit. + +- [ ] **Step 2: Leave the exception path unchanged** + +Do not change the business logic of `except Exception as exc:`. Real Python exceptions are still recorded automatically by `trace_span`, with `error_type` and `error_stacktrace`. + +- [ ] **Step 3: Add a semantics test for `trace_span.mark_error`** + +If `tests/rl/test_trace.py` already has a similar test, add only one case that is closer to a business failure: + +```python + async def test_trace_span_mark_error_records_error_event_without_exception(self): + sink = RecordingTraceSink() + state = make_state(uid=77) + + with use_trace_recorder(TraceRecorder(sink)): + async with trace_span(state, "xtuner.sandbox.validate") as span: + span.mark_error("validate/validate_failed: validate failed") + + timeline = timeline_from_events(sink.events, "gsm8k:77") + self.assertEqual( + [event.stage for event in timeline], + ["xtuner.sandbox.validate.start", "xtuner.sandbox.validate.error"], + ) + self.assertEqual(timeline[-1].error_msg, "validate/validate_failed: validate failed") + self.assertIsNone(timeline[-1].error_stacktrace) +``` + +- [ ] **Step 4: Run the focused test once the user allows it** + +Command: + +```bash +python -m unittest tests.rl.test_trace.TraceCoreBehaviorTest.test_trace_span_mark_error_records_error_event_without_exception -v +``` + +Expected: + +```text +OK +``` + +## Task 3: A dedicated span for trajectory materialization + +**Purpose:** If a sample has finished sandbox infer/validate but is stuck in training trajectory export, the chat template, the tokenizer, or the trace store export, the dashboard needs to show `materialize_trajectory` directly. + +**Files:** +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py` +- Modify: `xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Add `trace_span` to the sandbox import** + +In 
`xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py`, change the import to: + +```python +from xtuner.v1.rl.trace import build_rollout_trace_attributes, trace_function, trace_span, trace_task_context +``` + +- [ ] **Step 2: Wrap training trajectory generation in sandbox `_fill_rollout_state`** + +In train mode, wrap the logic from `_load_latest_trace_segment(...)` through `rollout_state.labels = ...`: + +```python + async with trace_span( + rollout_state, + "xtuner.agent_in_sandbox.materialize_trajectory", + agent_status=item.status.value, + ) as span: + messages, tools = _load_latest_trace_segment(item.artifacts, require_tools=True) + span.annotate( + agent_message_count=len(messages), + agent_has_tools=tools is not None, + ) + if not messages: + raise ValueError("Agent artifacts must contain at least one trainable messages trace.") + session_id = rollout_state.uid + + trace_store = get_store() + text = self.tokenizer.apply_chat_template( + canonicalize_messages_for_chat_template(messages), + tools=tools, + tokenize=False, + add_generation_prompt=False, + ) + prompt_text = text[:-1] if text.endswith("\n") else text + data = await trace_store.export_training_trace.remote(str(session_id), prompt_text) + + rollout_state.input_ids = data["input_ids"] + rollout_state.labels = data["labels"] + span.annotate( + input_tokens=len(rollout_state.input_ids or []), + label_tokens=len(rollout_state.labels or []), + ) +``` + +Do not write `prompt_text`, message contents, or tool contents into span attributes; record only counts and token lengths. + +- [ ] **Step 3: Add `trace_span` to the localhost import** + +In `xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py`, change the import to: + +```python +from xtuner.v1.rl.trace import build_rollout_trace_attributes, trace_function, trace_span, trace_task_context +``` + +- [ ] **Step 4: Wrap localhost train trajectory generation** + +In the localhost agent loop, find the call: + +```python +data = await get_store().export_training_trace.remote(str(rollout_state.uid), prompt_text) +``` + +Wrap the corresponding train-mode materialize logic as: + +```python + 
async with trace_span( + rollout_state, + "xtuner.agent_in_localhost.materialize_trajectory", + agent_status=item.status.value, + ) as span: + messages, tools = _load_latest_trace_segment(item.artifacts, require_tools=True) + span.annotate( + agent_message_count=len(messages), + agent_has_tools=tools is not None, + ) + if not messages: + raise ValueError("Agent artifacts must contain at least one trainable messages trace.") + segment = { + "messages": messages, + "tools": tools, + } + text = self.tokenizer.apply_chat_template( + canonicalize_messages_for_chat_template(segment["messages"]), + tools=segment["tools"], + tokenize=False, + add_generation_prompt=False, + ) + prompt_text = text[:-1] if text.endswith("\n") else text + data = await get_store().export_training_trace.remote(str(rollout_state.uid), prompt_text) + + rollout_state.input_ids = data["input_ids"] + rollout_state.labels = data["labels"] + span.annotate( + input_tokens=len(rollout_state.input_ids or []), + label_tokens=len(rollout_state.labels or []), + ) +``` + +If the current localhost file uses the same variable names with a slightly different structure, keep the existing business logic and only wrap the same boundary in the span. + +- [ ] **Step 5: Add a test that attributes are preserved** + +In `tests/rl/test_trace.py`, add: + +```python + async def test_trace_span_records_materialize_attributes(self): + sink = RecordingTraceSink() + state = make_state(uid=88) + + with use_trace_recorder(TraceRecorder(sink)): + async with trace_span( + state, + "xtuner.agent_in_sandbox.materialize_trajectory", + agent_status="completed", + ) as span: + span.annotate(agent_message_count=3, agent_has_tools=True, input_tokens=11, label_tokens=7) + + timeline = timeline_from_events(sink.events, "gsm8k:88") + self.assertEqual( + [event.stage for event in timeline], + [ + "xtuner.agent_in_sandbox.materialize_trajectory.start", + "xtuner.agent_in_sandbox.materialize_trajectory.end", + ], + ) + self.assertEqual(timeline[-1].attributes["agent_message_count"], 3) + self.assertEqual(timeline[-1].attributes["agent_has_tools"], True) + 
self.assertEqual(timeline[-1].attributes["input_tokens"], 11) + self.assertEqual(timeline[-1].attributes["label_tokens"], 7) +``` + +- [ ] **Step 6: Run the focused test once the user allows it** + +Command: + +```bash +python -m unittest tests.rl.test_trace.TraceCoreBehaviorTest.test_trace_span_records_materialize_attributes -v +``` + +Expected: + +```text +OK +``` + +## Task 4: Stage semantics for LLM calls / tool calls + +**Purpose:** The dashboard should show more than technical function names. It should let users read stages by user-facing category: agent loop, sandbox acquire, LLM call, tool/entry, judge, materialize, and so on. The first phase does this with attributes and viewer labels, without forcing a move to OpenInference. + +**Files:** +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py` +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py` +- Modify: `xtuner/v1/rl/agent_loop/localhost_agent_loop/runner.py` +- Modify: `xtuner/v1/rl/agent_loop/localhost_agent_loop/stage.py` +- Modify: `xtuner/v1/rl/agent_loop/localhost_agent_loop/judger.py` +- Modify: `xtuner/v1/rl/rollout/session_trace.py` +- Modify: `xtuner/tools/producer_trace_analysis.py` +- Test: `tests/rl/test_trace.py` + +- [ ] **Step 1: Define the first set of stage kind strings** + +The first version does not add a new enum; it uses a stable attribute key directly: + +```text +xtuner.stage.kind +``` + +Allowed values: + +```text +agent_loop +sandbox +agent_run +entry +tool_call +llm_call +judge +materialize +``` + +Notes: + +- `entry` is the real execution boundary of a sandbox shell entry. +- Write `tool_call` only when the entry is confirmed to be a tool call; otherwise do not label every shell entry as a tool call. +- The session server's `forward_worker` / `stream_read` belong to `llm_call`. + +- [ ] **Step 2: Add the stage kind to the sandbox runner spans** + +In `xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py`, change to: + +```python + async with trace_span( + item, + "xtuner.sandbox.run_total", + **trace_kwargs, + **{"xtuner.stage.kind": "agent_loop"}, + ) as total_span: +``` + +`acquire`: + +```python + async with trace_span( + item, + "xtuner.sandbox.acquire", + **trace_kwargs, + **{"xtuner.stage.kind": "sandbox"}, + ) as acquire_span: +``` + +`infer`: + +```python + async with trace_span( + item, + "xtuner.sandbox.infer", + **trace_kwargs, + 
**{"xtuner.stage.kind": "agent_run"}, + ) as infer_span: +``` + +`validate`: + +```python + async with trace_span( + item, + "xtuner.sandbox.validate", + **trace_kwargs, + **{"xtuner.stage.kind": "judge"}, + ) as validate_span: +``` + +- [ ] **Step 3: Add the stage kind to the sandbox entry** + +In `ShellEntry.run` in `xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py`, change to: + +```python + async with trace_span( + item, + f"xtuner.sandbox.entry:{self.name}", + task_name=item.data_source, + uid=item.uid if item.uid is not None else item.id, + task_id=item.id, + entry_kind="ShellEntry", + entry_name=self.name, + **{"xtuner.stage.kind": "entry"}, + ): +``` + +In `DetachedShellEntry.run`, change to: + +```python + async with trace_span( + item, + f"xtuner.sandbox.entry:{self.name}", + task_name=item.data_source, + uid=item.uid if item.uid is not None else item.id, + task_id=item.id, + entry_kind="DetachedShellEntry", + entry_name=self.name, + **{"xtuner.stage.kind": "entry"}, + ): +``` + +- [ ] **Step 4: Add the stage kind to the localhost runner / stage / judger** + +`xtuner/v1/rl/agent_loop/localhost_agent_loop/runner.py`: + +```python +async with trace_span(item, "xtuner.localhost.run_total", **trace_kwargs, **{"xtuner.stage.kind": "agent_loop"}) as total_span: +``` + +```python +async with trace_span(item, "xtuner.localhost.infer", **trace_kwargs, **{"xtuner.stage.kind": "agent_run"}) as infer_span: +``` + +```python +async with trace_span(item, "xtuner.localhost.validate", **trace_kwargs, **{"xtuner.stage.kind": "judge"}) as validate_span: +``` + +`xtuner/v1/rl/agent_loop/localhost_agent_loop/stage.py`: + +```python +async with trace_span( + item, + "xtuner.localhost.agent", + agent_name=agent.name, + **{"xtuner.stage.kind": "agent_run"}, +) as span: +``` + +`xtuner/v1/rl/agent_loop/localhost_agent_loop/judger.py`: + +```python +async with trace_span( + item, + "xtuner.localhost.judger", + judger_name=self.name, + **{"xtuner.stage.kind": "judge"}, +) as span: +``` + +If the parameter names in the current code differ slightly, keep the existing 
attributes and only append `xtuner.stage.kind`. + +- [ ] **Step 5: Add the stage kind to the session server LLM spans** + +In `ForwardRequestTrace.start` in `xtuner/v1/rl/rollout/session_trace.py`, add to `begin_otel_span("xtuner.session_server.forward_worker", ...)`: + +```python + **{"xtuner.stage.kind": "llm_call"}, +``` + +In `StreamResponseTrace.start`, add to `begin_otel_span("xtuner.session_server.stream_read", ...)`: + +```python + **{"xtuner.stage.kind": "llm_call"}, +``` + +- [ ] **Step 6: Update the dashboard stage labels** + +In `TRACE_STAGE_LABELS` in `xtuner/tools/producer_trace_analysis.py`, add: + +```python + "xtuner.agent_in_sandbox.generate_group": "sandbox.generate_group", + "xtuner.agent_in_sandbox.generate_sample": "sandbox.generate_sample", + "xtuner.agent_in_sandbox.materialize_trajectory": "sandbox.materialize", + "xtuner.agent_in_localhost.generate_group": "localhost.generate_group", + "xtuner.agent_in_localhost.generate_sample": "localhost.generate_sample", + "xtuner.agent_in_localhost.materialize_trajectory": "localhost.materialize", + "xtuner.sandbox.run_total": "sandbox.run_total", + "xtuner.sandbox.acquire": "sandbox.acquire", + "xtuner.sandbox.infer": "sandbox.infer", + "xtuner.sandbox.validate": "sandbox.validate", + "xtuner.localhost.run_total": "localhost.run_total", + "xtuner.localhost.infer": "localhost.infer", + "xtuner.localhost.validate": "localhost.validate", + "xtuner.localhost.judger": "localhost.judger", + "xtuner.localhost.agent": "localhost.agent", + "xtuner.session_server.forward_worker": "llm.forward", + "xtuner.session_server.stream_read": "llm.stream", +``` + +Keep the fallback behavior of `display_trace_stage()` so that a new span without a configured label still renders normally. + +- [ ] **Step 7: Add a Jaeger payload attribute test** + +In `tests/rl/test_trace.py`, add: + +```python + def test_jaeger_dashboard_preserves_stage_kind_attribute(self): + trace_id = "b" * 32 + raw_trace = make_jaeger_trace( + [ + make_jaeger_span( + trace_id=trace_id, + span_id="1", + operation="xtuner.session_server.forward_worker", + 
process_id="p-xtuner", + start_us=1_000_000, + duration_us=200_000, + tags={ + "xtuner.trace_id": "gsm8k:llm", + "xtuner.task_name": "gsm8k", + "xtuner.uid": "llm", + "xtuner.stage.kind": "llm_call", + }, + ) + ] + ) + + payload = build_dashboard_payload_from_jaeger_traces( + [raw_trace], + service_name="xtuner-rl", + now_s=2.0, + ) + detail = payload["views"]["all"]["task_details"]["gsm8k:llm"] + timeline = detail["timeline"] + self.assertTrue( + any( + event["stage"] == "xtuner.session_server.forward_worker.end" + and event["attributes"].get("xtuner.stage.kind") == "llm_call" + for event in timeline + ) + ) +``` + +- [ ] **Step 8: Run the focused test once the user allows it** + +Command: + +```bash +python -m unittest tests.rl.test_trace.TraceStoreAndViewerTest.test_jaeger_dashboard_preserves_stage_kind_attribute -v +``` + +Expected: + +```text +OK +``` + +## Task 5: Distributed tracing for lagent / independent processes + +**Purpose:** If some stages of an XTuner task run in lagent or another independent process, they should still be joined under the same OTel trace id. The XTuner side passes the context; the independent process extracts the context and creates child spans. + +**Files:** +- Modify: `xtuner/v1/rl/trace.py` +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/hooks.py` +- Modify: `xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py` +- Modify: `xtuner/v1/rl/rollout/session_trace.py` +- Test: `tests/rl/test_trace.py` +- Optional external repo: lagent trace adapter. 
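The hand-off described in the purpose above can be sketched as a round trip through environment variables. This is a minimal illustration under stated assumptions, not the XTuner or lagent implementation: `carrier_to_env` and `env_to_carrier` are hypothetical helper names, the `OTEL_PROPAGATOR_` prefix is the convention proposed by this plan rather than a standard OTel variable, and the header values are placeholders.

```python
PREFIX = "OTEL_PROPAGATOR_"  # prefix proposed in this plan, not a standard OTel setting


def carrier_to_env(carrier: dict[str, str]) -> dict[str, str]:
    """Parent side (hypothetical helper): map propagation headers to prefixed env vars."""
    return {f"{PREFIX}{key.upper().replace('-', '_')}": value for key, value in carrier.items()}


def env_to_carrier(env: dict[str, str]) -> dict[str, str]:
    """Child side (hypothetical helper): rebuild the header-style carrier from env vars."""
    return {
        name[len(PREFIX):].lower(): value
        for name, value in env.items()
        if name.startswith(PREFIX) and value
    }


# Placeholder values; a real carrier would come from the OTel propagator.
parent_carrier = {
    "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
    "baggage": "xtuner.trace_id=gsm8k%3A1",
}
child_env = {"PATH": "/usr/bin", **carrier_to_env(parent_carrier)}
restored = env_to_carrier(child_env)
```

The round trip is lossless here because `traceparent`, `tracestate`, and `baggage` contain no hyphens; a header name with a hyphen would come back with an underscore, so a real adapter would likely use an explicit name mapping instead.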
+ +- [ ] **Step 1: Add a carrier helper to `trace.py`** + +`inject_trace_context(headers: MutableMapping[str, str]) -> None` already exists. To avoid repeating the dict creation in business code, add: + +```python +def make_trace_context_carrier() -> dict[str, str]: + carrier: dict[str, str] = {} + inject_trace_context(carrier) + return carrier +``` + +The dict returned by this helper may be empty. An empty dict means tracing is disabled or there is currently no context to inject. + +- [ ] **Step 2: Add a test for the carrier helper** + +In `tests/rl/test_trace.py`, add: + +```python + def test_make_trace_context_carrier_noops_when_trace_disabled(self): + reset_trace_for_test() + from xtuner.v1.rl.trace import make_trace_context_carrier + + self.assertEqual(make_trace_context_carrier(), {}) +``` + +This test covers only the disabled no-op. With OTel enabled, the exact carrier headers are decided by the OTel propagator and are not hard-coded in the unit test. + +- [ ] **Step 3: Inject trace context into the sandbox entry environment variables** + +In `xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py`, add trace context to the env passed to `ShellEntry._execute(...)` and `DetachedShellEntry._run_detached(...)`. + +A local helper is suggested: + +```python +def _merge_trace_env(env: dict[str, str]) -> dict[str, str]: + carrier = make_trace_context_carrier() + if not carrier: + return env + merged = dict(env) + for key, value in carrier.items(): + merged[f"OTEL_PROPAGATOR_{key.upper().replace('-', '_')}"] = value + return merged +``` + +Then: + +```python +outcome = await self._execute(client, _merge_trace_env(self.env)) +``` + +lagent or a child process is expected to read: + +```text +OTEL_PROPAGATOR_TRACEPARENT +OTEL_PROPAGATOR_TRACESTATE +OTEL_PROPAGATOR_BAGGAGE +``` + +Note: the `OTEL_PROPAGATOR_*` prefix is used to avoid confusion with the standard OTel SDK environment variables. The standard `OTEL_*` variables are still used for exporter configuration. + +- [ ] **Step 4: Keep using standard headers on the HTTP request path** + +The session server already injects standard headers through `SessionTraceContext.inject_forward_headers(headers)`. Keep the standard header names: + +```text +traceparent +tracestate +baggage +``` + +Do not change them to `OTEL_PROPAGATOR_*`. The environment variable carrier is only for shell / lagent child processes. + +- [ ] **Step 5: Define a minimal protocol for the lagent-side adapter** + +The lagent side reads the environment variables and rebuilds the carrier: + +```python +import os + + +def load_xtuner_trace_carrier_from_env() -> dict[str, str]: + 
carrier = {} + mapping = { + "traceparent": "OTEL_PROPAGATOR_TRACEPARENT", + "tracestate": "OTEL_PROPAGATOR_TRACESTATE", + "baggage": "OTEL_PROPAGATOR_BAGGAGE", + } + for header, env_name in mapping.items(): + value = os.environ.get(env_name) + if value: + carrier[header] = value + return carrier +``` + +The minimal shape of creating a child span on the lagent side: + +```python +from opentelemetry import propagate, trace +from opentelemetry.context import attach, detach + + +def run_with_xtuner_parent(stage_name: str, fn, *args, **kwargs): + carrier = load_xtuner_trace_carrier_from_env() + ctx = propagate.extract(carrier) + token = attach(ctx) + try: + tracer = trace.get_tracer("lagent") + with tracer.start_as_current_span(stage_name) as span: + span.set_attribute("xtuner.stage.kind", "tool_call") + return fn(*args, **kwargs) + finally: + detach(token) +``` + +This adapter does not depend on XTuner internal modules; it depends only on OpenTelemetry. + +- [ ] **Step 6: Add an env merge unit test on the XTuner side** + +In `tests/rl/test_trace.py`, avoid depending on real OTel headers and use a patch instead: + +```python + def test_sandbox_trace_env_uses_prefixed_propagator_keys(self): + from xtuner.v1.rl.agent_loop.sandbox_agent_loop import sandbox as sandbox_module + + with patch( + "xtuner.v1.rl.agent_loop.sandbox_agent_loop.sandbox.make_trace_context_carrier", + return_value={"traceparent": "00-abc-def-01", "baggage": "xtuner.trace_id=gsm8k%3A1"}, + ): + merged = sandbox_module._merge_trace_env({"A": "1"}) + + self.assertEqual(merged["A"], "1") + self.assertEqual(merged["OTEL_PROPAGATOR_TRACEPARENT"], "00-abc-def-01") + self.assertEqual(merged["OTEL_PROPAGATOR_BAGGAGE"], "xtuner.trace_id=gsm8k%3A1") +``` + +If `_merge_trace_env` lives inside `sandbox.py`, do not let underscore-prefixed helpers multiply; keep only this one. + +- [ ] **Step 7: End-to-end smoke acceptance** + +Once the user allows it, run one real training round that includes lagent / sandbox child processes, and confirm that two services, or at least two process sources, appear under the same `xtuner.trace_id` in Jaeger: + +```text +service = exp-1 +tag = xtuner.trace_id=<trace id of some task> +``` + +Expected: + +- The XTuner spans still exist. +- The sandbox / lagent child spans use the same OTel trace id. +- The child spans' 
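parent-child relationships can be expanded in Jaeger Native Trace. +- If the lagent side has not adopted the OTel SDK yet, the XTuner side still shows at least the shell / entry span that carries the context, and training is unaffected. + +## Final acceptance + +Run these checks once the user allows it; do not run tests without authorization: + +```bash +python -m unittest discover -s tests/rl -p test_trace.py +``` + +Expected: + +```text +OK +``` + +The real smoke run must cover: + +1. Sandbox agent loop, normal path: + - `xtuner.agent_in_sandbox.generate_sample` is visible in Jaeger / the dashboard. + - `xtuner.sandbox.run_total/acquire/infer/validate` are visible. + - `xtuner.agent_in_sandbox.materialize_trajectory` is visible. +2. Sandbox validate failed path: + - `xtuner.sandbox.validate` is an error span. + - The task detail shows `error_msg`. +3. Session server LLM call: + - `xtuner.session_server.forward_worker` has `xtuner.stage.kind=llm_call`. + - The first-token latency fields are still present. +4. Localhost agent loop: + - `xtuner.agent_in_localhost.generate_sample` is visible in Jaeger / the dashboard. +5. Independent-process propagation: + - If the lagent-side adapter is in place, lagent child spans are visible under the same task trace. + +## Documentation sync + +After development is complete, update: + +- `docs/superpowers/specs/2026-06-12-current-trace-capabilities.md` + - Add the new spans and stage kinds to the description of current capabilities. + - State that validate business failures are marked as errors. + - State that trajectory materialization is now observable. +- `docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + - Move the goal "the first version must cover agent in sandbox" to the implemented items. + - Spell out the adapter protocol for lagent independent-process propagation. +- `recipe/otle/README.md` + - Change the example service name to a short one, such as `exp-1`. + - Add instructions for viewing the dashboard. + +## Open Questions + +These questions do not block Task 1-3, but they affect the final shape of Task 4-5: + +1. Should every sandbox `entry` count as a tool call? The current plan does not do this; it only labels it as `entry`. +2. Should the service name of lagent child spans be `lagent`, or should they share `exp-1` with XTuner? The current recommendation is a separate service for lagent, which makes it easier to tell process sources apart in Jaeger. +3. 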
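Should the dashboard stage summary add a column aggregated by `xtuner.stage.kind`? The current plan keeps aggregation by span name and adds kind aggregation later if the page becomes too crowded. + diff --git a/docs/superpowers/specs/2026-06-05-producer-task-trace-design.md b/docs/superpowers/specs/2026-06-05-producer-task-trace-design.md new file mode 100644 index 0000000000..9888d4500c --- /dev/null +++ b/docs/superpowers/specs/2026-06-05-producer-task-trace-design.md @@ -0,0 +1,920 @@ +# Producer Task Trace Minimal Design + +Date: 2026-06-05 + +## Background + +In the current RL rollout production pipeline, every sample emitted by the producer crosses several asynchronous boundaries: producer, agent loop, rollout controller, rollout worker, judger, replay buffer, and so on. Existing logs show some local information, but they cannot reliably answer: + +- Which step is a given sample currently at? +- Has a given sample started or finished the judger? +- Is a given sample stuck waiting for the rollout worker to return? +- In agentic RL, has a given sample reached the tool call or environment interaction stage the user cares about? + +The goal of the first trace version is not a full observability platform. It is to leave a bounded stage timeline for every sample the producer emits, and to provide a minimal visual interface where users can inspect state directly. + +## Feature overview + +The first trace version offers users five core features. + +### 1. Record the execution path of every sample + +Trace records one timeline per producer sample. A timeline consists of a series of events, and each event marks the sample reaching a stage, for example: + +```text +xtuner.producer.sampled +xtuner.rollout.controller.start +xtuner.judger.start +xtuner.judger.end +xtuner.final.completed +``` + +The problem this solves: users can tell which key stages a sample passed through after the producer emitted it. + +Implementation: + +- The built-in pipeline is instrumented at key boundaries such as the producer, agent loop, rollout controller, judger, and replay buffer. +- Every event is bound to a `trace_id`, usually `{task_name}:{uid}`. +- Events under the same `trace_id` form a timeline in time order. + +### 2. View the current execution state of every sample + +The current state of a sample is the last event in its timeline. The first version maintains neither a separate latest-state table nor flags. + +For example, if the last event of a sample is: + +```text +xtuner.rollout.controller.start +``` + +then the viewer shows this sample as waiting for the rollout controller / worker to return. If the last event is: + +```text +xtuner.judger.start +``` + +then it is shown as having entered the judger stage. + +The problem this solves: users do not need to dig through logs; opening the viewer shows where each task is currently stuck. + +Implementation: + +- The viewer rebuilds the timeline from JSONL. +- For each `trace_id`, the last event is taken as latest. +- The latest stage is shown in the task table and the state distribution chart. + +### 3. 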
Locate the most suspicious stuck stage + +The count distribution of latest stages alone is not enough to tell what is stalling training. For example: + +```text +xtuner.rollout.controller.start 80 +xtuner.judger.start 20 +``` + +This only says that 80 samples are currently in the rollout controller and 20 samples are in the judger. It does not directly say whether the 80 rollout controller samples or the 20 judger samples are the ones stuck. + +The viewer therefore needs to show the age distribution of open spans. An open span is a span that has a `{stage}.start` but no matching `{stage}.end` or `{stage}.error` yet. + +Example: + +```text +stage open oldest p50 p95 oldest_trace +xtuner.rollout.controller 80 42m 39m 41m gsm8k:17 +xtuner.judger 20 2m 2m 2m gsm8k:91 +``` + +Problem solved: users can see which stage has spans that have stayed open for a long time, and thereby locate the most suspicious stuck stage. + +Implementation: + +- The viewer matches `.start/.end/.error` within each trace's timeline. +- For spans that are still open, it computes `open_age = now - start_timestamp_s`. +- It aggregates `open_count`, `oldest_open_age`, `p50_open_age`, `p95_open_age`, and `oldest_trace_id` by span name. +- The top of the viewer shows `Suspect Open Spans` first. + +### 4. View the full timeline of a single sample + +After clicking a row in the viewer's task table, users can see the full event sequence of that sample: + +```text +xtuner.producer.sampled +xtuner.agent_loop.sample.start +xtuner.rollout.start +xtuner.rollout.controller.start +xtuner.rollout.worker_selected +xtuner.rollout.controller.end +xtuner.rollout.end +xtuner.judger.start +xtuner.judger.end +xtuner.agent_loop.sample.end +xtuner.replay_buffer.put.start +xtuner.final.completed +xtuner.replay_buffer.put.end +``` + +Problem solved: when a sample fails, hangs, times out, or misbehaves, users can see exactly which step it reached and what happened before. + +Implementation: + +- The viewer groups JSONL events by `trace_id`. +- Selecting a task shows the events of that group. +- The `elapsed_s` and `error_msg` of each event are shown in the details. + +### 5. 
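Support user-defined stage instrumentation + +Built-in stages only cover paths known to XTuner. Tool calls, environment interactions, and user-defined agent loops in agentic RL require users to declare the stages they care about. + +The first version provides three instrumentation APIs: + +- `trace_event(target, name)`: records an instantaneous event. +- `trace_span(target, name)`: records the start/end/error of a code block. +- `trace_function(name, target=...)`: a decorator that wraps a whole function call as a span. + +Example: + +```python +async def call_tool(state, tool_name, payload): + stage = f"user.tool.{tool_name}" + async with trace_span(state, stage): + return await tool_registry[tool_name](payload) +``` + +Generated events: + +```text +user.tool.calculator.start +user.tool.calculator.end +``` + +Problem solved: for later agentic RL work, users can record the stages they care about into the same sample timeline and filter by `user.*` stages in the viewer. + +## First-Version Goals + +1. Record a bounded timeline for each rollout sample emitted by the producer. +2. Show the current execution stage of each sample in the viewer. +3. Show suspect open spans in the viewer to help judge which stage training is most likely stuck in. +4. Show the full timeline of a single sample in the viewer. +5. Let users instrument custom agent loops, tool wrappers, and environment steps themselves. +6. When trace is enabled, write JSONL; the viewer rebuilds state from JSONL. +7. Trace is off by default; when off, all `trace_*` APIs are no-ops. + +## First-Version Non-Goals + +1. No arbitrary extra fields on each event. +2. No storage of full prompts, responses, tool results, images, tensors, or routed experts. +3. No complex online dashboard, permissions, multi-user collaboration, or remote training cluster observation. +4. No OpenTelemetry integration. +5. No callback-style hook system. +6. No separate watchpoint flags index. +7. No extra configuration such as collection-side stage prefix filtering, flush modes, or per-event-type switches. +8. No centralized trace store across Ray actors. +9. No requirement that all user agent loops are instrumented automatically; user-defined logic records itself through `trace_event/trace_span/trace_function`. + +## Design Principles + +1. Minimal functionality first: first let users see the current state of producer tasks clearly, then consider complex analysis. +2. Single source of truth: the timeline is the only source of truth; latest and reached-stage are both derived from it. +3. Viewer-first on the user side: the first version does not expose manager query interfaces as a user API. +4. Clean instrumentation interface: business code should prefer `trace_event`, `trace_span`, and `trace_function`, and should not pass `self.tracer` around. +5. Stage names carry semantics: the first version has no extra event fields, so tool names, environment names, and the like are encoded into the stage string for now. +6. 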
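Distributed capability comes later: the first version only records stages visible on the producer/manager side and does not introduce a Ray trace actor. + +## Core Terms + +### Task + +The task type along the `task_name` dimension, for example `gsm8k`, `geo3k`, `tool_agent`. It is not one concrete rollout execution instance. + +### Sample + +One concrete rollout execution instance, corresponding to one `RolloutState`. This is the smallest object that trace tracks. + +### Trace ID + +Uniquely identifies a sample. Recommended format: + +```text +{task_name}:{uid} +``` + +If `task_name` is missing, use `unknown`. If `uid` is missing, the first version does not record the event. No temporary id is generated here, because users cannot query a temporary id by sample and it would pollute the trace store. + +### Stage + +An open string that denotes the stage a sample has reached. + +Built-in XTuner stages use the `xtuner.*` prefix: + +```text +xtuner.producer.sampled +xtuner.rollout.controller.start +xtuner.judger.start +xtuner.final.completed +``` + +User-defined stages use the `user.*` prefix: + +```text +user.tool.calculator.start +user.tool.calculator.end +user.env.browser.step.start +user.env.browser.step.end +``` + +The first version has no arbitrary extra fields, so the information users care about should go into the stage name where possible. For example, do not record `stage="user.tool.end", tool_name="calculator"`; record `stage="user.tool.calculator.end"` instead. + +### Timeline + +The event sequence of each trace id. It is the only source of truth in the first version. + +### Latest + +The last event in the timeline of a trace id. It is not separately maintained state. + +### Reached Stage + +Whether a given stage or stage prefix has appeared in the timeline of a trace id. It is obtained by scanning the timeline; no separate flags are maintained. + +## Data Structures + +### TraceConfig + +```python +class TraceConfig(BaseModel): + # Master switch. When False, NoopTraceRecorder is used and all trace_* APIs become no-ops. + enabled: bool = False + + # Upper bound on the number of events TraceStore keeps in memory globally. + # Beyond it, the oldest events are evicted; this only affects temporary in-process queries. + # History already written to JSONL is not subject to this limit. + max_events: int = 100_000 + + # Upper bound on the number of events kept in memory for a single trace_id. + # Limits memory growth when one agentic sample runs tool calls or environment interactions for a long time. + max_events_per_trace: int = 256 + + # Output directory for JSONL shards. When None, the manager resolves it to worker_log_dir / "producer_traces". + # Once trace is enabled, the first version always uses the buffered JSONL writer and offers no separate dump switch. + output_dir: Path | None = None +``` + +Configuration semantics: + +- When `enabled=False`, all `trace_event/trace_span/trace_function` calls are no-ops. +- When `enabled=True`, the in-memory timeline and the buffered JSONL writer are enabled. +- `max_events` and `max_events_per_trace` only limit what is kept in memory, not the JSONL shards already written. +- The first version provides no separate switches for latest/timeline/flags, because the timeline can already derive latest and reached-stage. +- 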
The first version provides no collection-side stage allowlist. If some built-in stages later prove too noisy, filtering will be added separately. + +### TraceEvent + +```python +@dataclass +class TraceEvent: + trace_id: str + timestamp_s: float + event: Literal["enter", "exit", "error", "instant"] + stage: str + + status: str | None = None + task_name: str | None = None + uid: int | str | None = None + session_uid: int | str | None = None + + train_step: int | None = None + model_step: int | None = None + producer_future_step: int | None = None + produce_batch_id: str | None = None + + worker_rank: int | None = None + + duration_ms: int | None = None + ok: bool | None = None + err: str | None = None + error_msg: str | None = None + + extra: dict[str, Any] = field(default_factory=dict) +``` + +Field descriptions: + +- `trace_id`: primary key, usually `{task_name}:{uid}`. +- `timestamp_s`: wall clock time, from `time.time()`. +- `event`: event type. + - `enter`: a span is entered. + - `exit`: a span ends normally. + - `error`: a span ends with an exception or a business failure. + - `instant`: an instantaneous event that does not denote a closable span. +- `stage`: stage name. Under the standard semantics, `.start/.end/.error` is no longer encoded into `stage`; examples are `run_total`, `infer`, `validate`, `xtuner.rollout_worker.generate`. +- `status`: current sample status, for example `completed`, `aborted`, `failed`. +- `task_name`: producer task name. +- `uid`: sample uid. +- `session_uid`: session id used for rollout routing. +- `train_step`: current training step. +- `model_step`: model version used by the rollout sample. +- `producer_future_step`: the future step the disagg producer is producing. +- `produce_batch_id`: producer batch identifier, usually derived from `train_step/model_step/producer_future_step`. +- `worker_rank`: rollout worker rank. +- `duration_ms`: duration of an `exit` or `error` event, in milliseconds. +- `ok`: result flag of an `exit` or `error` event. `True` for a normal end, `False` for an exception or a business failure. +- `err`: short error string, compatible with the legacy sandbox `span(...).mark_error()` and exception paths. +- `error_msg`: error message summary. +- `extra`: lightweight extra fields passed by the user, for example `task_id`, `entry_kind`, `sandbox_name`, `sandbox_env_id`. Large objects such as prompts, responses, tool results, images, and tensors are not written here. + +Design principles: + +- `event` and `stage` must be separate. `stage` expresses "which stage"; `event` expresses "enter / exit / error / instant record". +- The viewer and analysis prefer `event` to determine span lifecycle. +- For compatibility with JSONL that already exists, a reader may interpret the legacy `stage` 
suffix as an event type: + - `{name}.start` is equivalent to `event="enter", stage="{name}"`. + - `{name}.end` is equivalent to `event="exit", stage="{name}"`. + - `{name}.error` is equivalent to `event="error", stage="{name}"`. +- Newly written events should use the standard fields and no longer rely on parsing the lifecycle from the stage string suffix. + +## Instrumentation API + +The interface exposed in the first version follows the shape of slime, but keeps only the minimum needed for task stage tracking. `TraceRecorder` and `TraceStore` still exist underneath, but business code should not pass `self.tracer` around. Before the production loop starts, the manager binds the current recorder to the trace context; the `trace_*` APIs fetch the recorder from the current context. + +The first version does not implement the attrs/attrs_getter, debug dump, or `bind_trace(sample)` described in the slime docs. The viewer only does minimal JSONL visualization and does not replicate slime's full timeline viewer. For now the target is always a `RolloutState` or a `list[RolloutState]`. + +### trace_event + +`trace_event()` records an instantaneous stage. + +```python +await trace_event(state, "xtuner.judger.start") +await trace_event(states, "xtuner.producer.sampled") +await trace_event(state, "user.tool.calculator.end") +``` + +Suggested signature: + +```python +async def trace_event( + target: RolloutState | Sequence[RolloutState] | None, + name: str, + *, + event: Literal["instant", "enter", "exit", "error"] = "instant", + status: str | None = None, + task_name: str | None = None, + uid: int | str | None = None, + train_step: int | None = None, + model_step: int | None = None, + producer_future_step: int | None = None, + worker_rank: int | None = None, + duration_ms: int | None = None, + ok: bool | None = None, + err: str | None = None, + error_msg: str | None = None, + **extra: Any, +) -> None: + ... 
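+``` + +Purpose: + +- Records a single-point state that has already happened, for example sampled, worker selected, final completed. +- The standard event type written by default is `event="instant"`. +- `trace_span()` and `trace_function()` can reuse the underlying recording capability to write `enter/exit/error`. +- `**extra` is only for lightweight extra fields, for example `task_id`, `entry_kind`, `sandbox_name`. +- `target` can be a single `RolloutState` or a group of `RolloutState`. +- When `target` is not empty, `task_name`, `uid`, `session_uid`, and `status` are extracted from the `RolloutState`. +- If `uid` is missing, no event is recorded. + +### trace_span + +`trace_span()` records the start/end/error of a code block. + +```python +async with trace_span(state, "user.tool.calculator"): + result = await calculator(expr) +``` + +Purpose: + +- On enter, records `event="enter", stage=name`. +- On normal exit, records `event="exit", stage=name` with `duration_ms` and `ok=True`. +- On exceptional exit, records `event="error", stage=name` with `duration_ms`, `ok=False`, and `err/error_msg`, then re-raises the original exception. +- Suited to a small piece of logic inside a function, for example a tool call, an environment step, or the rollout controller. + +When `target` is a group of states, `trace_span()` writes the same stage for each state. It does not infer internal progress; it writes the stage declared by the caller to the timeline of each target. The current stage of each task is still determined by the last event of its own timeline. + +### trace_function + +`trace_function()` is a function-level span decorator. It suits boundaries where "the whole function call is one stage" and reduces changes to the body of the original business function. + +Suggested interface: + +```python +def trace_function( + name: str, + *, + target: str | None = None, + target_getter: Callable[..., RolloutState | Sequence[RolloutState] | None] | None = None, + result: Literal["input", "return"] = "return", +) -> Callable: + ... +``` + +Purpose: + +- `target="rollout_state"`: takes the target from the arguments of the decorated function. +- `target_getter=...`: used when the target has to be derived from several arguments. The first version recommends that built-in code prefer `target=...`. +- `result="input"`: the end/error event is written to the input target. +- `result="return"`: the start/error event is written to the input target; on a normal return, if the return value is a `RolloutState` or a `list[RolloutState]`, the end event is written to the return value. +- If neither `target` nor `target_getter` is provided, the first version does no automatic inference and is a no-op. + +Example: + +```python +@trace_function("xtuner.producer.generate_group", target="rollout_state", result="return") +async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: + ... 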
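+``` + +This decorator is equivalent to: + +- Before the call, writing `event="enter", stage="xtuner.producer.generate_group"` for each input state. +- On an exception, writing `event="error", stage="xtuner.producer.generate_group"` for each input state. +- On success, writing `event="exit", stage="xtuner.producer.generate_group"` for each returned state. + +The decorator does not automatically know how far execution has progressed inside the function. If a sample enters a finer stage inside the function, the inner path still has to record it with `trace_event()`, `trace_span()`, or another `@trace_function()`. + +### Trace Context + +After creating `TraceRuntime`, the manager binds the current recorder around the producer's production loop: + +```python +with use_trace_recorder(trace_runtime.recorder): + await producer.produce() +``` + +Binding uses `contextvars.ContextVar`. Ordinary async call chains can access the current recorder directly; Ray actors do not share this context automatically. The first version does not trace stages inside the rollout worker/backend, precisely to avoid introducing a distributed trace store in the minimal version. + +### NoopTraceRecorder + +When trace is off, the current recorder is a no-op. Business code can therefore write directly: + +```python +await trace_event(state, "xtuner.judger.start") +``` + +There is no need to check at every instrumentation point whether trace is enabled. + +## Viewer + +In the first version, users mainly inspect trace through a visual interface rather than by calling manager query interfaces. A local viewer is recommended: + +```bash +python tools/producer_trace_viewer.py {output_dir} +``` + +The viewer only reads JSONL shards. It writes nothing to the training process and does not require the training process to expose a new public RPC. While JSONL keeps being appended during a run, the viewer can show near-real-time state through manual refresh or simple polling; after training ends, the same set of JSONL shards can also be opened offline. + +### Presentation + +A local single-page HTML is enough for the first-version viewer. It contains four blocks: + +1. Suspect Open Spans: aggregated by open span, showing `open_count`, `oldest`, `p50`, `p95`, and `oldest_trace_id`. +2. Latest Stage Distribution: aggregated by latest stage / status, showing how many samples are currently in each stage. +3. Task table: one trace id per row, showing `task_name`, `uid`, latest stage, current open span, open age, status, and most recent error. +4. 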
Single-task timeline: clicking a row shows the full event sequence of that trace id. + +Mock-up: + +```text +Producer Trace Viewer +total: 1024 | running: 312 | completed: 660 | failed: 52 + +Suspect Open Spans +stage open oldest p50 p95 oldest_trace +xtuner.rollout.controller 80 42m 39m 41m gsm8k:17 +xtuner.judger 20 2m 2m 2m gsm8k:91 + +Latest Stage Distribution +xtuner.rollout.controller.start 180 +xtuner.judger.start 61 +user.tool.calculator.start 37 +xtuner.final.completed 660 + +Task Table Selected Task Timeline +trace_id latest_stage open_span age +gsm8k:17 rollout.controller.start xtuner.rollout.controller 42m +gsm8k:91 judger.start xtuner.judger 2m +tool:456 user.tool.calculator user.tool.calculator 11s +``` + +### Filtering + +The first-version viewer supports these filters: + +- Filter by `task_name`. +- Filter by latest stage, for example only samples stuck at `xtuner.rollout.controller.start`. +- Filter by open span, for example only samples that currently have an open `xtuner.judger` span. +- Filter by stage prefix, for example samples that have entered `user.tool.calculator.`. +- Search by `uid` or `trace_id`. +- Filter by `status` or by whether `error_msg` is present. + +### Computation Rules + +- current/latest = the last event of each `trace_id`. +- timeline = the event sequence of the same `trace_id`. +- reached-stage = an exact or prefix match against the timeline of that `trace_id`. +- open span = an `event="enter", stage=name` that is not yet followed by an `event="exit"` or `event="error"` with the same name. +- open age = `now - start_event.timestamp_s`. +- suspect open span = `open_count`, `oldest_open_age`, `p50_open_age`, and `p95_open_age` aggregated by open span name. + +Open spans are computed from the standard `event` field first. For compatibility with legacy JSONL, the viewer may keep recognizing the `.start/.end/.error` suffixes. Instantaneous events such as `xtuner.rollout.worker_selected` only appear in the timeline and the latest stage; they do not take part in open span statistics. + +All of these computations come from the timeline; no extra flags or latest index is maintained. + +## Storage + +### JSONL + +When trace is enabled, every event produced enters the trace writer, but events are not required to reach disk synchronously one by one. The first version uses a buffered JSONL writer: + +```text +trace_event / trace_span / trace_function +-> TraceRecorder.mark(event) +-> InMemoryTraceStore.append(event) + 1. Synchronously update the in-memory timeline + 2. 
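Put the JSON line into the TraceJsonlWriter buffer +-> TraceJsonlWriter writes to the JSONL shard in batches in the background +``` + +The reasons for this design: + +- The viewer needs to see the stages reached before training got stuck, so dumping cannot wait until a whole rollout round has finished. +- But a synchronous open/write/flush per event would generate heavy disk I/O and slow training down, especially on shared storage, NFS, and high-frequency agentic tool calls. +- A buffered writer keeps the business path to in-memory operations only; disk writes are done in batches in the background. + +Write visibility semantics: + +- The in-memory timeline is already updated before `append(event)` returns. +- JSONL reaches disk with a short delay; the first-version target is that data is usually flushed to the file within about 1 second. +- If training is merely stuck but the process is still alive, the background writer keeps flushing, and the viewer can see recent events. +- If the process crashes outright, a small number of trace events in the last flush interval / buffer may be lost. This is a trade-off between performance and completeness. +- Normal shutdown, end of training, and rollout step boundaries should call `trace_runtime.flush()` or `trace_runtime.close()` to write out as much of the buffer as possible. + +To keep a single file from growing without bound, the first version does not write one `producer_trace.jsonl`; it writes fixed-size JSONL shards: + +```text +{output_dir}/producer_trace_000000.jsonl +{output_dir}/producer_trace_000001.jsonl +{output_dir}/producer_trace_000002.jsonl +``` + +Write strategy: + +- The writer keeps the file handle of the current shard open and does not open/close the file per event. +- Events first enter an in-memory buffer and are written to the current active shard in batches in the background. +- A flush is triggered when the buffer reaches a fixed number of events or a fixed number of bytes, or when a fixed time has passed since the last flush. +- When the active shard reaches a fixed size, the writer switches to the next shard. +- The first version recommends controlling shard size with an implementation constant, for example `TRACE_JSONL_SHARD_BYTES = 256 * 1024 * 1024`, without exposing an extra config option. +- The first version recommends controlling flush behavior with implementation constants, for example: + +```python +TRACE_JSONL_FLUSH_INTERVAL_S = 1.0 +TRACE_JSONL_FLUSH_EVENTS = 1024 +TRACE_JSONL_FLUSH_BYTES = 1 * 1024 * 1024 +``` + +- These flush parameters are not exposed as `TraceConfig` options in the first version, to keep the configuration surface from growing. +- The viewer reads all `producer_trace_*.jsonl` under `output_dir` and rebuilds timelines by shard index and in-file order. + +Size control: + +- Event lines contain only the small fields of `TraceEvent`; prompts, responses, tool results, images, tensors, and routed experts are not written. +- `max_events` and `max_events_per_trace` only limit memory and do not delete JSONL shards already written. +- Sharding keeps any single file from getting very large, but the total directory size still grows with the total number of events. +- The first version does not delete old shards automatically, because deleting historical events could prevent the viewer from recovering the start event of a long-lived open span. + +Performance boundaries: + +- Normally, a trace event does not wait for the disk write to complete. +- A writer flush is an ordinary file flush and does not run `fsync` for each batch of events. +- A failed writer flush only logs a warning and does not fail the training job because trace could not reach disk. +- If the writer buffer stays backed up for a long time, disk writes are not keeping up with event production. The first version surfaces this through a warning and does not introduce complex sampling, dropping, or dynamic degradation strategies. +- If the event volume in agentic scenarios later becomes too high, collection-side filtering, compression, retention policies, or more sophisticated asynchronous write strategies can be introduced. +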
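The shard layout and the crash semantics above suggest how a reader might rebuild timelines. The following is a minimal sketch, not the viewer's actual implementation: `load_timelines` and `latest_stage` are hypothetical helper names, and it assumes each JSONL line is a JSON object carrying at least `trace_id` and `stage`, that shard indices are zero-padded as in `producer_trace_000000.jsonl` so a plain string sort matches shard order, and that a crash may leave one truncated final line that should be skipped.

```python
import json
from collections import defaultdict
from pathlib import Path


def load_timelines(output_dir: Path) -> dict[str, list[dict]]:
    """Rebuild per-trace timelines from producer_trace_*.jsonl shards (hypothetical helper)."""
    timelines: dict[str, list[dict]] = defaultdict(list)
    # Assumption: zero-padded shard indices, so lexicographic order equals shard order.
    for shard in sorted(output_dir.glob("producer_trace_*.jsonl")):
        with shard.open(encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    event = json.loads(line)
                except json.JSONDecodeError:
                    # A crash can leave a partially written last line; skip it.
                    continue
                timelines[event["trace_id"]].append(event)
    return dict(timelines)


def latest_stage(timelines: dict[str, list[dict]]) -> dict[str, str]:
    """latest = the stage of the last event of each trace's timeline."""
    return {trace_id: events[-1]["stage"] for trace_id, events in timelines.items()}
```

Because events are appended in arrival order, reading shards in index order and lines in file order keeps each timeline ordered without sorting by timestamp; a real viewer might still sort by `timestamp_s` as a safeguard.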
+What JSONL is for: + +- The viewer rebuilds the timeline and latest state from JSONL. +- Offline troubleshooting after training ends. +- History stays visible after the online in-memory timeline has been evicted. +- Events already written are preserved as far as possible when the process exits. + +### InMemoryTraceStore + +The first-version store only maintains a bounded timeline, used for temporary in-process state and unit tests. The synchronous part of `append(event)` only updates the in-memory timeline and enqueues into the writer buffer; it does not flush to disk. + +```python +class TraceStore: + def append(self, event: TraceEvent) -> None: ... + + def get_timeline(self, trace_id: str) -> list[TraceEvent]: ... + + def get_latest(self, trace_id: str) -> TraceEvent | None: ... + + def has_stage( + self, + trace_id: str, + *, + stage: str | None = None, + stage_prefix: str | None = None, + status: str | None = None, + ) -> bool: ... + + def query_latest( + self, + *, + task_name: str | None = None, + stage: str | None = None, + stage_prefix: str | None = None, + status: str | None = None, + limit: int = 100, + ) -> list[TraceEvent]: ... + + def flush_jsonl(self) -> None: ... + + def close(self) -> None: ... +``` + +These methods are internal implementation and test entry points, not a user-facing manager public API in the first version. The first version does not need thin helpers such as `trace_store_get_timeline()`, and exposing a set of user query methods on `AgentLoopManager` is not recommended. Users mainly inspect state through the viewer. + +`flush_jsonl()` and `close()` are used at the end of training, on normal shutdown, and before tests read JSONL, to write out the writer buffer. Business instrumentation paths should not call them directly. + +## Stage Naming Convention + +### Built-in Stages + +Producer: + +```text +xtuner.producer.sampled +xtuner.producer.generate_group.start +xtuner.producer.generate_group.end +xtuner.producer.generate_group.error +xtuner.producer.put_buffer.start +xtuner.producer.put_buffer.end +xtuner.producer.put_buffer.error +``` + +Agent loop: + +```text +xtuner.agent_loop.group.start +xtuner.agent_loop.group.end +xtuner.agent_loop.group.error +xtuner.agent_loop.sample.start +xtuner.agent_loop.sample.end +xtuner.agent_loop.sample.error +``` + +Rollout: + +```text +xtuner.rollout.start +xtuner.rollout.end +xtuner.rollout.error +xtuner.rollout.controller.start +xtuner.rollout.worker_selected +xtuner.rollout.controller.end +xtuner.rollout.controller.error +``` + +Judger: + +```text +xtuner.judger.start +xtuner.judger.end +xtuner.judger.error 
+``` + +Replay buffer / final: + +```text +xtuner.replay_buffer.put.start +xtuner.replay_buffer.put.end +xtuner.final.completed +xtuner.final.aborted +xtuner.final.expired +xtuner.final.failed +``` + +### User Stages + +User instrumentation always uses `user.*`: + +```text +user.tool.{tool_name}.start +user.tool.{tool_name}.end +user.tool.{tool_name}.error +user.env.{env_name}.step.start +user.env.{env_name}.step.end +user.env.{env_name}.step.error +``` + +## Code Integration Points + +### producer.py + +`ProduceContext.generate_group()` is the common entry point where the producer calls the agent loop. Business code should not hand-write a per-state loop here; prefer a decorator that wraps the function call as one stage. + +```python +@trace_function("xtuner.producer.generate_group", target="rollout_state", result="return") +async def generate_group(self, rollout_state, *, enable_partial_rollout=False): + # The existing local/ray branches stay unchanged. + ... +``` + +This only shows that each sample has entered the stage where the producer waits for the agent loop. To know whether a sample has entered `agent_loop.sample`, `judger`, or a user tool stage, those finer stages must keep calling `trace_event()`, `trace_span()`, or `@trace_function()` on their own execution paths. + +`ProduceContext.put_generated_group()` is the common entry point before a group enters the replay buffer: + +```python +@trace_function("xtuner.producer.put_buffer", target="group", result="input") +async def put_generated_group(self, group): + ... +``` + +### sampler.py + +After sampling `RolloutState`s, record: + +```python +await trace_event(states, "xtuner.producer.sampled") +``` + +`task_name` and `uid` are available here, which makes it the earliest stable moment at which the trace_id is formed. + +### agent_loop.py + +`AgentLoop.generate_group()` is responsible for the concurrent generation of a group of samples: + +```python +@trace_function("xtuner.agent_loop.group", target="rollout_state", result="return") +async def generate_group(self, rollout_state: list[RolloutState], **kwargs): + ... +``` + +To see the generation stage of each sample, the single-sample path can be decorated as well: + +```python +@trace_function("xtuner.agent_loop.sample", target="rollout_state", result="return") +async def generate_sample(self, rollout_state: RolloutState, **kwargs): + ... 
+``` + +`AgentLoop.run_judger()` is the single instrumentation point for judger start/end/error. Even if the concrete judger implementation changes later, this entry point stays stable. + +### single_turn_agent_loop.py + +`SingleTurnAgentLoop.generate_sample()` does not record judger start/end again, to avoid instrumenting the same stage in several places. It only needs to record that the sample entered rollout and its state after rollout returns; the judger is handled by `run_judger()`. + +### rollout/controller.py + +`RolloutController.generate()` is the boundary where a worker is selected and its return is awaited: + +```python +async with trace_span(state, "xtuner.rollout.controller"): + worker, rank = await self.select_worker(state) + await trace_event(state, "xtuner.rollout.worker_selected", worker_rank=rank) + response = await worker.generate.remote(...) +``` + +If no worker is available or the worker call fails, `trace_span` records `xtuner.rollout.controller.error`. + +### rollout/worker.py and rollout/lmdeploy.py + +The first version does not instrument the rollout worker actor or the backend subclasses. These stages usually run in a separate actor/process, and without a centralized trace store, events inside the worker cannot be reliably collected into the producer/manager store. + +The first version only records, through `RolloutController.generate()`, the worker selected / worker returned / worker failed states visible to the manager. Fine-grained information such as HTTP start/end/retry, response parse, and backend name is left to later versions. + +### judger + +Judger implementations include native, remote, and composed judgers. The first version instruments `xtuner.judger.start/end/error` once in `AgentLoop.run_judger()` and does not repeat the instrumentation in each judger subclass. + +If native/remote/composed need to be distinguished later, a finer `stage` can be added around the `run_judger()` call; the first version adds no extra judger type field. + +### replay_buffer.py + +`ReplayBuffer.put()` is the stable place for the final status, because it handles model version/staleness and may also turn a group into expired: + +```python +await trace_event(group, "xtuner.replay_buffer.put.start") +... 
+await trace_event(state, "xtuner.final.completed", status=state.status.value) +await trace_event(result_group, "xtuner.replay_buffer.put.end") +``` + +If the status is aborted/failed/expired, the corresponding final stage is recorded. + +## Distributed Scope + +The first version has no distributed trace store. The trace store lives only in the producer/manager process and records the sample stages the producer can observe directly. + +The reasons: + +- The current goal is to track the running information of each task/sample emitted by the producer, not to build a full-pipeline distributed observability system. +- Introducing a `RayTraceStoreActor` would bring problems of remote append, actor lifecycle, query consistency, performance, and failure handling. +- The first version can first get the manager-visible stages working: producer, agent loop, judger, replay buffer, and so on. + +If it is later confirmed that fine-grained HTTP/retry/parse stages inside the rollout worker actor must be traced, a centralized Ray trace actor can be added separately. That should be a second-version capability, not part of the first-version default design. + +## Examples + +### Built-in Pipeline + +The timeline of a successful sample may be: + +```text +xtuner.producer.sampled +xtuner.agent_loop.sample.start +xtuner.rollout.start +xtuner.rollout.controller.start +xtuner.rollout.worker_selected +xtuner.rollout.controller.end +xtuner.rollout.end +xtuner.judger.start +xtuner.judger.end +xtuner.agent_loop.sample.end +xtuner.replay_buffer.put.start +xtuner.final.completed +xtuner.replay_buffer.put.end +``` + +At this point, the latest stage of this sample in the viewer's task table is `xtuner.replay_buffer.put.end`. Opening the sample shows in its timeline that it went through `xtuner.judger.start` and `xtuner.judger.end`. + +### Agentic RL Tool Calls + +Instrumentation in a user tool wrapper: + +```python +async def call_tool(state, tool_name, payload): + stage_prefix = f"user.tool.{tool_name}" + async with trace_span(state, stage_prefix): + return await tool_registry[tool_name](payload) +``` + +Generated stages: + +```text +user.tool.calculator.start +user.tool.calculator.end +``` + +To see which samples called the calculator, filter by the stage prefix `user.tool.calculator.` in the viewer. + +## Testing Suggestions + +### Unit Tests + +1. Under `NoopTraceRecorder`, `trace_event/trace_span/trace_function` write no events. +2. After `InMemoryTraceStore.append`, `get_timeline` returns the event. +3. `get_latest` returns the last event of the timeline. +4. `has_stage(stage=...)` exact match works. +5. `has_stage(stage_prefix=...)` prefix match works. +6. `max_events_per_trace` evicts the oldest event of a single trace. +7. `max_events` evicts the globally oldest event. +8. 
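After `flush_jsonl()` / `close()`, the JSONL writer has written complete fields and no large objects. +9. The recorder catches store append exceptions, logs a warning, and does not affect the caller. +10. `@trace_function` can write start/end/error for both a single state and a list of states. +11. The viewer rebuilds the latest stage, task table, and single-task timeline from JSONL. + +### Integration Tests + +1. `xtuner.producer.sampled` is visible after the producer samples. +2. The corresponding stages are visible before and after the rollout controller waits for the worker to return. +3. Judger start/end appear in the same trace timeline. +4. The final stage is correct after replay buffer put. +5. After manager-visible instrumentation points have written, the JSONL viewer can rebuild the latest stage and the timeline. + +## Follow-ups + +These capabilities stay out of the first version for now: + +1. Event extra fields: for recording finer information such as tool name, env name, retry count, and token count. +2. Callback-style hooks: trigger user callbacks after an event is stored, to generate derived events or external notifications. +3. Watchpoint hit index: maintain a fast boolean index for common stages, to avoid scanning the timeline every time. +4. Collection-side stage filtering: record only certain prefixes, to reduce trace volume in high-frequency tool call scenarios. +5. Write performance strategies: fire-and-forget, batch append, periodic flush. +6. JSONL compression and retention policies: for example compressing closed shards, cleaning old traces per run, keeping only the most recent N shards. +7. A more sophisticated offline analysis CLI: for example batch export of stuck samples, sorting by stage duration, cross-run comparison. +8. Online remote dashboard: permissions, multi-user collaboration, cross-machine aggregation, training process RPC. +9. Centralized Ray trace actor: introduced only when stages inside worker actors really need to be aggregated. +10. Rollout worker/backend internal stages: HTTP start/end/retry, response parse start/end, backend name, and so on. + +## First-Version Acceptance Criteria + +1. With trace enabled, each producer sample shows at least the key stages sampled, rollout, judger, and final. +2. The viewer shows the current latest stage of each sample, and that value comes from the last event of the sample's timeline. +3. The viewer shows suspect open spans with open count, oldest, p50, p95, and oldest trace. +4. The viewer shows the latest stage distribution of all samples. +5. The viewer shows the full timeline of a single sample. +6. The viewer can filter by latest stage, stage prefix, task_name, and uid / trace_id. +7. Users can record `user.*` stages in custom tool calls through `trace_event/trace_span/trace_function` and see them in the viewer. +8. When trace is off, the business pipeline needs no extra checks, and instrumentation calls have no side effects. +9. 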
The first version does not depend on a distributed store; the viewer rebuilds trace state from JSONL. diff --git a/docs/superpowers/specs/2026-06-05-producer-task-trace-pseudocode.md b/docs/superpowers/specs/2026-06-05-producer-task-trace-pseudocode.md new file mode 100644 index 0000000000..7761ca3008 --- /dev/null +++ b/docs/superpowers/specs/2026-06-05-producer-task-trace-pseudocode.md @@ -0,0 +1,1338 @@ +# Producer Task Trace Minimal Pseudocode + +Date: 2026-06-05 + +Note: this is pseudocode for the minimal feature set. The first version has the timeline as its only online source of truth; latest is derived from the last event of the timeline, and whether a stage has been reached is found by scanning the timeline. The first version does not support arbitrary extra event fields; the tool names, environment stages, and business stages users care about are encoded into the stage string for now, for example `user.tool.calculator.end`. + +## 1. Core Types + +Target file: + +```text +xtuner/v1/rl/trace.py +``` + +### 1.1 TraceConfig + +```python +class TraceConfig(BaseModel): + # Master switch. When False, NoopTraceRecorder is used; callers can still call the trace_* APIs unconditionally. + enabled: bool = False + + # Upper bound on the number of events TraceStore keeps in memory globally. Beyond it, the oldest events are evicted. + # Only affects online timeline queries, not history already appended to JSONL. + max_events: int = 100_000 + + # Upper bound on the number of timeline events for a single trace_id. Keeps a single agentic sample + # from producing too many events in long tool call / environment interaction loops. + max_events_per_trace: int = 256 + + # Output directory for JSONL shards. When None, the manager resolves it to worker_log_dir / "producer_traces". + # Once trace is enabled, the first version always uses the buffered JSONL writer and offers no separate dump switch. + output_dir: Path | None = None +``` + +### 1.2 TraceEvent + +```python +@dataclass +class TraceEvent: + trace_id: str + stage: str + timestamp_s: float + + status: str | None = None + task_name: str | None = None + uid: int | str | None = None + session_uid: int | str | None = None + + train_step: int | None = None + model_step: int | None = None + producer_future_step: int | None = None + + worker_rank: int | None = None + elapsed_s: float | None = None + error_msg: str | None = None +``` + +## 2. 
Event Construction + +### 2.1 Trace ID + +```python +def build_trace_id_from_state( + state: RolloutState | None, + *, + task_name: str | None = None, + uid: int | str | None = None, +) -> str | None: + resolved_task_name = task_name + resolved_uid = uid + + if state is not None: + resolved_task_name = resolved_task_name or state.task_name + if resolved_uid is None: + resolved_uid = state.uid + + if resolved_uid is None: + return None + + return f"{resolved_task_name or 'unknown'}:{resolved_uid}" +``` + +### 2.2 Build Event + +```python +def build_event( + state: RolloutState | None, + stage: str, + *, + status: str | None = None, + task_name: str | None = None, + uid: int | str | None = None, + train_step: int | None = None, + model_step: int | None = None, + producer_future_step: int | None = None, + worker_rank: int | None = None, + elapsed_s: float | None = None, + error_msg: str | None = None, +) -> TraceEvent | None: + resolved_trace_id = build_trace_id_from_state( + state, + task_name=task_name, + uid=uid, + ) + if resolved_trace_id is None: + return None + + if state is not None: + resolved_task_name = task_name or state.task_name + resolved_uid = uid if uid is not None else state.uid + resolved_status = status or state.status.value + session_uid = state.session_uid + else: + resolved_task_name = task_name + resolved_uid = uid + resolved_status = status + session_uid = None + + return TraceEvent( + trace_id=resolved_trace_id, + stage=stage, + timestamp_s=time.time(), + status=resolved_status, + task_name=resolved_task_name, + uid=resolved_uid, + session_uid=session_uid, + train_step=train_step, + model_step=model_step, + producer_future_step=producer_future_step, + worker_rank=worker_rank, + elapsed_s=elapsed_s, + error_msg=error_msg, + ) +``` + +## 3. 
Store + +### 3.1 InMemoryTraceStore + +```python +TRACE_JSONL_SHARD_BYTES = 256 * 1024 * 1024 +TRACE_JSONL_FLUSH_INTERVAL_S = 1.0 +TRACE_JSONL_FLUSH_EVENTS = 1024 +TRACE_JSONL_FLUSH_BYTES = 1 * 1024 * 1024 + + +class BufferedTraceJsonlWriter: + def __init__(self, output_dir: Path): + output_dir.mkdir(parents=True, exist_ok=True) + self._jsonl_dir = output_dir + self._jsonl_shard_index = 0 + self._jsonl_shard_bytes = 0 + self._jsonl_path = self._jsonl_dir / "producer_trace_000000.jsonl" + self._file = self._jsonl_path.open("a", encoding="utf-8") + + self._buffer: list[str] = [] + self._buffer_bytes = 0 + self._last_flush_s = time.monotonic() + self._stop = False + self._lock = threading.RLock() + self._write_lock = threading.Lock() + self._condition = threading.Condition(self._lock) + self._thread = threading.Thread( + target=self._flush_loop, + name="xtuner-trace-jsonl-writer", + daemon=True, + ) + self._thread.start() + + def append(self, event: TraceEvent) -> None: + line = json.dumps(dataclasses.asdict(event), ensure_ascii=False) + "\n" + line_bytes = len(line.encode("utf-8")) + + with self._condition: + self._buffer.append(line) + self._buffer_bytes += line_bytes + + if ( + len(self._buffer) >= TRACE_JSONL_FLUSH_EVENTS + or self._buffer_bytes >= TRACE_JSONL_FLUSH_BYTES + ): + self._condition.notify() + + def flush(self) -> None: + with self._condition: + self._condition.notify() + self._flush_once() + + def close(self) -> None: + with self._condition: + self._stop = True + self._condition.notify() + self._thread.join(timeout=5.0) + self._flush_once() + self._file.close() + + def _flush_loop(self) -> None: + while True: + with self._condition: + if self._stop: + break + self._condition.wait(timeout=TRACE_JSONL_FLUSH_INTERVAL_S) + self._flush_once() + + def _flush_once(self) -> None: + with self._lock: + if not self._buffer: + return + + lines = self._buffer + self._buffer = [] + self._buffer_bytes = 0 + self._last_flush_s = time.monotonic() + + with 
self._write_lock: + try: + for line in lines: + self._write_line(line) + self._file.flush() + except Exception: + logger.warning("Trace JSONL flush failed", exc_info=True) + + def _write_line(self, line: str) -> None: + line_bytes = len(line.encode("utf-8")) + self._rotate_jsonl_if_needed(line_bytes) + self._file.write(line) + self._jsonl_shard_bytes += line_bytes + + def _rotate_jsonl_if_needed(self, next_line_bytes: int) -> None: + if self._jsonl_shard_bytes + next_line_bytes <= TRACE_JSONL_SHARD_BYTES: + return + self._file.flush() + self._file.close() + self._jsonl_shard_index += 1 + self._jsonl_shard_bytes = 0 + self._jsonl_path = ( + self._jsonl_dir / f"producer_trace_{self._jsonl_shard_index:06d}.jsonl" + ) + self._file = self._jsonl_path.open("a", encoding="utf-8") + + +@dataclass +class _StoredEvent: + sequence: int + event: TraceEvent + + +class InMemoryTraceStore: + def __init__(self, config: TraceConfig): + self.config = config + self._timeline: dict[str, deque[_StoredEvent]] = defaultdict(deque) + self._global_order: OrderedDict[int, _StoredEvent] = OrderedDict() + self._next_sequence = 0 + self._lock = threading.RLock() + + output_dir = resolve_output_dir(config.output_dir) + self._jsonl_writer = BufferedTraceJsonlWriter(output_dir) + + def append(self, event: TraceEvent) -> None: + with self._lock: + record = _StoredEvent( + sequence=self._next_sequence, + event=event, + ) + self._next_sequence += 1 + + trace_events = self._timeline[event.trace_id] + trace_events.append(record) + self._global_order[record.sequence] = record + + self._evict_trace_if_needed(event.trace_id) + self._evict_global_if_needed() + self._jsonl_writer.append(event) + + def _evict_trace_if_needed(self, trace_id: str) -> None: + trace_events = self._timeline[trace_id] + while len(trace_events) > self.config.max_events_per_trace: + oldest = trace_events.popleft() + self._global_order.pop(oldest.sequence, None) + + if not trace_events: + del self._timeline[trace_id] + + def 
_evict_global_if_needed(self) -> None: + while len(self._global_order) > self.config.max_events: + _, oldest = self._global_order.popitem(last=False) + trace_events = self._timeline.get(oldest.event.trace_id) + + if trace_events is None: + continue + + if trace_events and trace_events[0].sequence == oldest.sequence: + trace_events.popleft() + + if not trace_events: + del self._timeline[oldest.event.trace_id] + + def flush_jsonl(self) -> None: + self._jsonl_writer.flush() + + def close(self) -> None: + self._jsonl_writer.close() + + def get_timeline(self, trace_id: str) -> list[TraceEvent]: + with self._lock: + return [ + record.event + for record in self._timeline.get(trace_id, ()) + ] + + def get_latest(self, trace_id: str) -> TraceEvent | None: + with self._lock: + trace_events = self._timeline.get(trace_id) + if not trace_events: + return None + return trace_events[-1].event + + def has_stage( + self, + trace_id: str, + *, + stage: str | None = None, + stage_prefix: str | None = None, + status: str | None = None, + ) -> bool: + if stage is None and stage_prefix is None and status is None: + return False + + with self._lock: + events = [ + record.event + for record in self._timeline.get(trace_id, ()) + ] + + for event in events: + if stage is not None and event.stage != stage: + continue + if stage_prefix is not None and not event.stage.startswith(stage_prefix): + continue + if status is not None and event.status != status: + continue + return True + + return False + + def query_latest( + self, + *, + task_name: str | None = None, + stage: str | None = None, + stage_prefix: str | None = None, + status: str | None = None, + limit: int = 100, + ) -> list[TraceEvent]: + with self._lock: + newest_records = list(reversed(self._global_order.values())) + + matched = [] + seen_trace_ids = set() + for record in newest_records: + event = record.event + if event.trace_id in seen_trace_ids: + continue + seen_trace_ids.add(event.trace_id) + + if task_name is not None and 
event.task_name != task_name: + continue + if stage is not None and event.stage != stage: + continue + if stage_prefix is not None and not event.stage.startswith(stage_prefix): + continue + if status is not None and event.status != status: + continue + matched.append(event) + if len(matched) >= limit: + break + + return matched +``` + +Notes: + +- The store keeps no separate latest dict; `get_latest()` returns the last entry of the timeline. +- The store keeps no boolean hit index; `has_stage()` scans that trace's timeline directly. +- The store does no stage filtering; in the first version every instrumented event is recorded. +- When trace is enabled, JSONL is always written in shards; in-memory eviction never deletes files already written. +- The first version has no centralized store across Ray actors. The store lives only locally in the producer/manager. + +## 4. Recorder and Public Trace API + +```python +class TraceRecorder: + def __init__(self, store: InMemoryTraceStore): + self.store = store + + async def mark( + self, + state: RolloutState | None, + stage: str, + **kwargs, + ) -> None: + event = build_event(state, stage, **kwargs) + if event is None: + return + try: + self.store.append(event) + except Exception: + logger.warning("Trace append failed", exc_info=True) + + async def mark_many( + self, + states: Sequence[RolloutState], + stage: str, + **kwargs, + ) -> None: + for state in states: + await self.mark(state, stage, **kwargs) + + @asynccontextmanager + async def span( + self, + state: RolloutState | None, + stage_prefix: str, + **kwargs, + ): + start_s = time.monotonic() + await self.mark(state, f"{stage_prefix}.start", **kwargs) + + try: + yield + except Exception as exc: + elapsed_s = time.monotonic() - start_s + await self.mark( + state, + f"{stage_prefix}.error", + elapsed_s=elapsed_s, + error_msg=short_error(exc), + **kwargs, + ) + raise + else: + elapsed_s = time.monotonic() - start_s + await self.mark( + state, + f"{stage_prefix}.end", + elapsed_s=elapsed_s, + **kwargs, + ) + + @asynccontextmanager + async def span_many( + self, + states: Sequence[RolloutState], + stage_prefix: str, + **kwargs, + ): + # span_many only records a shared boundary for all given samples. 
+ # Finer per-sample progress must be recorded by trace_event/trace_span inside each + # sample's own execution path. + states = list(states) + start_s = time.monotonic() + await self.mark_many(states, f"{stage_prefix}.start", **kwargs) + + try: + yield + except Exception as exc: + elapsed_s = time.monotonic() - start_s + await self.mark_many( + states, + f"{stage_prefix}.error", + elapsed_s=elapsed_s, + error_msg=short_error(exc), + **kwargs, + ) + raise + else: + elapsed_s = time.monotonic() - start_s + await self.mark_many( + states, + f"{stage_prefix}.end", + elapsed_s=elapsed_s, + **kwargs, + ) +``` + +### 4.2 NoopTraceRecorder + +```python +class NoopTraceRecorder: + async def mark(self, *args, **kwargs) -> None: + return None + + async def mark_many(self, *args, **kwargs) -> None: + return None + + @asynccontextmanager + async def span(self, *args, **kwargs): + yield + + @asynccontextmanager + async def span_many(self, *args, **kwargs): + yield +``` + +### 4.3 Trace Context + +```python +_NOOP_TRACE_RECORDER = NoopTraceRecorder() +_CURRENT_TRACE_RECORDER: ContextVar[ + TraceRecorder | NoopTraceRecorder +] = ContextVar( + "xtuner_current_trace_recorder", + default=_NOOP_TRACE_RECORDER, +) + + +@contextmanager +def use_trace_recorder(recorder: TraceRecorder | NoopTraceRecorder): + token = _CURRENT_TRACE_RECORDER.set(recorder) + try: + yield + finally: + _CURRENT_TRACE_RECORDER.reset(token) + + +def current_trace_recorder() -> TraceRecorder | NoopTraceRecorder: + return _CURRENT_TRACE_RECORDER.get() +``` + +Notes: + +- The manager calls `use_trace_recorder(runtime.recorder)` around the outer producer loop. +- Ordinary async call chains read the current recorder through `contextvars`. +- The first version does not rely on this context propagating across Ray actors, so rollout worker/backend internals are not instrumented yet. + +### 4.4 Target Normalization + +```python +def _as_state_list(value) -> list[RolloutState]: + if value is None: + return [] + if isinstance(value, RolloutState): + return [value] + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + return 
[item for item in value if isinstance(item, RolloutState)] + return [] +``` + +In the first version, a target must be a `RolloutState` or a `Sequence[RolloutState]`. If the target is empty and no explicit `task_name/uid` is given, no event is recorded. + +### 4.5 trace_event + +```python +async def trace_event( + target: RolloutState | Sequence[RolloutState] | None, + name: str, + **kwargs, +) -> None: + recorder = current_trace_recorder() + states = _as_state_list(target) + + if states: + await recorder.mark_many(states, name, **kwargs) + return + + await recorder.mark(None, name, **kwargs) +``` + +### 4.6 trace_span + +```python +@asynccontextmanager +async def trace_span( + target: RolloutState | Sequence[RolloutState] | None, + name: str, + **kwargs, +): + start_s = time.monotonic() + await trace_event(target, f"{name}.start", **kwargs) + + try: + yield + except Exception as exc: + await trace_event( + target, + f"{name}.error", + elapsed_s=time.monotonic() - start_s, + error_msg=short_error(exc), + **kwargs, + ) + raise + else: + await trace_event( + target, + f"{name}.end", + elapsed_s=time.monotonic() - start_s, + **kwargs, + ) +``` + +### 4.7 trace_function decorator + +```python +def _resolve_trace_function_target( + target_name: str | None, + target_getter: Callable | None, + args, + kwargs, + bound_arguments, +): + if target_getter is not None: + return target_getter(*args, **kwargs) + if target_name is None: + return None + return bound_arguments.get(target_name) + + +def trace_function( + name: str, + *, + target: str | None = None, + target_getter: Callable | None = None, + result: Literal["input", "return"] = "return", +): + def decorate(fn): + signature = inspect.signature(fn) + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + bound = signature.bind_partial(*args, **kwargs) + input_target = _resolve_trace_function_target( + target, + target_getter, + args, + kwargs, + bound.arguments, + ) + + start_s = time.monotonic() + await trace_event(input_target, f"{name}.start") + + try: + value = await fn(*args, 
**kwargs) + except Exception as exc: + await trace_event( + input_target, + f"{name}.error", + elapsed_s=time.monotonic() - start_s, + error_msg=short_error(exc), + ) + raise + + if result == "return": + end_target = value if _as_state_list(value) else input_target + else: + end_target = input_target + + await trace_event( + end_target, + f"{name}.end", + elapsed_s=time.monotonic() - start_s, + ) + return value + + return wrapper + + return decorate +``` + +Notes: + +- The decorator supports async functions only. +- The decorator only resolves one state or a group of states through `target` or `target_getter`; it does not inspect what happens inside the function. +- `result="return"` fits functions that return a new or updated `RolloutState`. +- `result="input"` fits cases where the return value is not a state but the input state is updated in place. +- If neither `target` nor `target_getter` is passed explicitly, the first version does no automatic inference and is a no-op. + +## 5. Runtime Construction + +```python +@dataclass +class TraceRuntime: + config: TraceConfig + recorder: TraceRecorder | NoopTraceRecorder + store: InMemoryTraceStore | None + + def flush(self) -> None: + if self.store is not None: + self.store.flush_jsonl() + + def close(self) -> None: + if self.store is not None: + self.store.close() + + +def build_trace_runtime( + config: TraceConfig, + *, + worker_log_dir: Path, +) -> TraceRuntime: + if not config.enabled: + return TraceRuntime( + config=config, + recorder=NoopTraceRecorder(), + store=None, + ) + + resolved_config = config.model_copy( + update={ + "output_dir": config.output_dir + or worker_log_dir / "producer_traces", + }, + ) + + store = InMemoryTraceStore(resolved_config) + return TraceRuntime( + config=resolved_config, + recorder=TraceRecorder(store), + store=store, + ) +``` + +The manager binds the current recorder around the outer producer loop: + +```python +async def run_producer_loop(self): + try: + with use_trace_recorder(self.trace_runtime.recorder): + await self.producer.produce() + finally: + self.trace_runtime.close() +``` + +## 6. 
Internal Store Reads and Viewer Data + +The first version exposes no user-facing query API on `AgentLoopManager` and needs no thin helpers such as `trace_store_get_timeline()`. Users check status through the viewer; the store's read methods are used only for internal tests, internal debugging, or reuse when the viewer reads an in-memory snapshot. + +`InMemoryTraceStore` already provides these internal methods: + +```python +store.get_timeline(trace_id) +store.get_latest(trace_id) +store.has_stage(trace_id, stage="xtuner.judger.start") +store.has_stage(trace_id, stage_prefix="user.tool.calculator.") +store.query_latest(stage="xtuner.rollout.controller.start") +``` + +The main data source for the first-version viewer is JSONL: + +```python +@dataclass +class TraceViewerRow: + trace_id: str + task_name: str | None + uid: int | str | None + latest_stage: str + open_span: str | None + open_age_s: float | None + status: str | None + latest_timestamp_s: float + error_msg: str | None + event_count: int + + +@dataclass +class OpenSpanSummary: + span: str + open_count: int + oldest_open_age_s: float + p50_open_age_s: float + p95_open_age_s: float + oldest_trace_id: str + + +def load_trace_jsonl(output_dir: Path) -> dict[str, list[TraceEvent]]: + timelines: dict[str, list[TraceEvent]] = defaultdict(list) + for path in sorted(output_dir.glob("producer_trace_*.jsonl")): + for line in path.read_text(encoding="utf-8").splitlines(): + event = TraceEvent.from_json(line) + timelines[event.trace_id].append(event) + return timelines + + +def span_name_from_stage(stage: str, suffix: str) -> str | None: + token = f".{suffix}" + if not stage.endswith(token): + return None + return stage[: -len(token)] + + +def get_open_spans(events: list[TraceEvent]) -> dict[str, TraceEvent]: + open_spans: dict[str, TraceEvent] = {} + for event in events: + start_span = span_name_from_stage(event.stage, "start") + if start_span is not None: + open_spans[start_span] = event + continue + + end_span = span_name_from_stage(event.stage, "end") + error_span = span_name_from_stage(event.stage, "error") + closed_span = end_span or error_span + if closed_span is not None: + open_spans.pop(closed_span, None) + return open_spans + + +def 
percentile(sorted_values: list[float], percentile_value: int) -> float: + if not sorted_values: + return 0.0 + index = round((len(sorted_values) - 1) * percentile_value / 100) + return sorted_values[index] + + +def build_viewer_rows( + timelines: dict[str, list[TraceEvent]], + *, + now_s: float, +) -> list[TraceViewerRow]: + rows = [] + for trace_id, events in timelines.items(): + if not events: + continue + latest = events[-1] + open_spans = get_open_spans(events) + newest_open_span = None + newest_open_event = None + if open_spans: + newest_open_span, newest_open_event = list(open_spans.items())[-1] + rows.append( + TraceViewerRow( + trace_id=trace_id, + task_name=latest.task_name, + uid=latest.uid, + latest_stage=latest.stage, + open_span=newest_open_span, + open_age_s=( + now_s - newest_open_event.timestamp_s + if newest_open_event is not None + else None + ), + status=latest.status, + latest_timestamp_s=latest.timestamp_s, + error_msg=latest.error_msg, + event_count=len(events), + ) + ) + return rows + + +def build_open_span_summaries( + timelines: dict[str, list[TraceEvent]], + *, + now_s: float, +) -> list[OpenSpanSummary]: + ages_by_span: dict[str, list[tuple[float, str]]] = defaultdict(list) + for trace_id, events in timelines.items(): + for span, start_event in get_open_spans(events).items(): + ages_by_span[span].append((now_s - start_event.timestamp_s, trace_id)) + + summaries = [] + for span, ages_and_trace_ids in ages_by_span.items(): + ages_and_trace_ids.sort() + ages = [age for age, _ in ages_and_trace_ids] + oldest_age, oldest_trace_id = ages_and_trace_ids[-1] + summaries.append( + OpenSpanSummary( + span=span, + open_count=len(ages), + oldest_open_age_s=oldest_age, + p50_open_age_s=percentile(ages, 50), + p95_open_age_s=percentile(ages, 95), + oldest_trace_id=oldest_trace_id, + ) + ) + return sorted( + summaries, + key=lambda item: (item.oldest_open_age_s, item.open_count), + reverse=True, + ) +``` + +
In the viewer, current/latest, timeline, reached-stage, and open span are all derived from the single `timelines` data structure. + +## 7. Code Insertion Points + +### 7.1 Producer + +Target file: + +```text +xtuner/v1/rl/agent_loop_manager/producer.py +``` + +`generate_group()`: the first version records only events that can be bound to a concrete sample; pure group-level events without a `uid` do not enter the trace. Business code should not hand-write per-state instrumentation loops; prefer `@trace_function`. + +```python +@trace_function("xtuner.producer.generate_group", target="rollout_state", result="return") +async def generate_group( + self, + rollout_state: list[RolloutState], + *, + enable_partial_rollout: bool = False, +): + if isinstance(self.agent_loop, ray.actor.ActorHandle): + group = await self.agent_loop.generate_group.remote( + rollout_state, + enable_partial_rollout=enable_partial_rollout, + ) + else: + group = await self.agent_loop.generate_group( + rollout_state, + enable_partial_rollout=enable_partial_rollout, + ) + return group +``` + +This span only indicates that the group of samples is waiting on the producer side for the agent loop to return. Whether a sample later enters finer stages such as `agent_loop.sample`, `judger`, or user tools depends on the corresponding execution path adding its own `@trace_function`, `trace_span`, or `trace_event`. + +`put_generated_group()`: + +```python +@trace_function("xtuner.producer.put_buffer", target="group", result="input") +async def put_generated_group(self, group): + await self.replay_buffer.put(group) +``` + +### 7.2 Sampler + +Target file: + +```text +xtuner/v1/rl/agent_loop_manager/sampler.py +``` + +After `RolloutState` objects are sampled: + +```python +states = [RolloutState(...)] +await trace_event(states, "xtuner.producer.sampled", train_step=train_step) +``` + +If it is inconvenient for the sampler to run inside the trace context, `sampled` can instead be recorded after the producer receives the sampled states. Prefer the insertion point that does not significantly widen the sampler's parameter surface. + +### 7.3 Agent Loop + +Target file: + +```text +xtuner/v1/rl/agent_loop/agent_loop.py +``` + +`generate_group()`: + +```python +@trace_function("xtuner.agent_loop.group", target="rollout_state", result="return") +async def generate_group(self, rollout_state: list[RolloutState]): + ... 
+``` + +`generate_sample()`: + +```python +@trace_function("xtuner.agent_loop.sample", target="rollout_state", result="return") +async def generate_sample(self, rollout_state: RolloutState, **kwargs): + ... +``` + +If local hand-written instrumentation needs to be kept, `generate_one` can also be wrapped inside `generate_group()`: + +```python +async def generate_group(self, rollout_state: list[RolloutState]): + await trace_event(rollout_state, "xtuner.agent_loop.group.start") + + async def generate_one(state): + async with trace_span(state, "xtuner.agent_loop.sample"): + return await self.generate_sample(state) + + results = await asyncio.gather(*(generate_one(state) for state in rollout_state)) + await trace_event(results, "xtuner.agent_loop.group.end") + return results +``` + +`run_judger()`: + +```python +@trace_function("xtuner.judger", target="state", result="return") +async def run_judger(self, state: RolloutState) -> RolloutState: + return await self.judger.judge(state) +``` + +### 7.4 Single Turn Agent Loop + +Target file: + +```text +xtuner/v1/rl/agent_loop/single_turn_agent_loop.py +``` + +Record only the stage around rollout; do not record judger again: + +```python +async def generate_sample(self, state: RolloutState) -> RolloutState: + async with trace_span(state, "xtuner.rollout"): + state = await self.rollout_controller.generate(state) + + state = await self.run_judger(state) + return state +``` + +### 7.5 Rollout Controller + +Target file: + +```text +xtuner/v1/rl/rollout/controller.py +``` + +```python +async def generate(self, state: RolloutState) -> RolloutState: + async with trace_span(state, "xtuner.rollout.controller"): + worker, worker_rank = await self.select_worker(state) + await trace_event( + state, + "xtuner.rollout.worker_selected", + worker_rank=worker_rank, + ) + return await worker.generate.remote(state) +``` + +### 7.6 Rollout Worker and Backend + +Target files: + +```text +xtuner/v1/rl/rollout/worker.py +xtuner/v1/rl/rollout/lmdeploy.py +``` + +The first version does not modify these files. Worker/backend internal stages usually run in separate actors/processes; without a centralized trace store, these events cannot easily be aggregated into the local producer/manager store. 
+ +The first version records only the manager-visible worker selected / worker returned / worker failed events through `RolloutController.generate()`. Fine-grained information such as HTTP start/end/retry, response parse, and backend name is deferred to later versions. + +### 7.7 Replay Buffer + +Target file: + +```text +xtuner/v1/rl/replay_buffer.py +``` + +```python +async def put(self, group: list[RolloutState]): + await trace_event(group, "xtuner.replay_buffer.put.start") + + result_group = await self._put_impl(group) + + for state in result_group: + final_stage = final_stage_from_status(state.status) + await trace_event(state, final_stage, status=state.status.value) + + await trace_event(result_group, "xtuner.replay_buffer.put.end") + + return result_group + + +def final_stage_from_status(status: RolloutStatus) -> str: + if status == RolloutStatus.COMPLETED: + return "xtuner.final.completed" + if status == RolloutStatus.ABORTED: + return "xtuner.final.aborted" + if status == RolloutStatus.EXPIRED: + return "xtuner.final.expired" + if status == RolloutStatus.FAILED: + return "xtuner.final.failed" + return "xtuner.final.unknown" +``` + +## 8. User Instrumentation + +### 8.1 Tool Wrapper + +```python +async def call_tool( + state: RolloutState, + tool_name: str, + payload, +): + stage_prefix = f"user.tool.{normalize_stage_part(tool_name)}" + + async with trace_span(state, stage_prefix): + return await tool_registry[tool_name](payload) +``` + +Resulting timeline: + +```text +user.tool.calculator.start +user.tool.calculator.end +``` + +Filtering by the stage prefix `user.tool.calculator.` in the viewer shows the samples that called the calculator. + +### 8.2 Environment Step + +```python +async def env_step(state, env, action): + stage_prefix = f"user.env.{normalize_stage_part(env.name)}.step" + async with trace_span(state, stage_prefix): + return await env.step(action) +``` + +Resulting timeline: + +```text +user.env.browser.step.start +user.env.browser.step.end +``` + +## 9. 
Tests + +### 9.1 Latest Comes From Timeline + +```python +async def test_latest_comes_from_timeline(tmp_path): + config = TraceConfig(enabled=True, output_dir=tmp_path) + store = InMemoryTraceStore(config) + + store.append(TraceEvent("task:1", "xtuner.producer.sampled", time.time())) + store.append(TraceEvent("task:1", "xtuner.judger.start", time.time())) + + assert store.get_timeline("task:1")[-1].stage == "xtuner.judger.start" + assert store.get_latest("task:1").stage == "xtuner.judger.start" +``` + +### 9.2 Stage Scan + +```python +async def test_has_stage_scans_timeline(tmp_path): + config = TraceConfig(enabled=True, output_dir=tmp_path) + store = InMemoryTraceStore(config) + + store.append(TraceEvent("tool_agent:1", "user.tool.calculator.start", time.time())) + store.append(TraceEvent("tool_agent:1", "user.tool.calculator.end", time.time())) + + assert store.has_stage("tool_agent:1", stage="user.tool.calculator.end") + assert store.has_stage("tool_agent:1", stage_prefix="user.tool.calculator.") + assert not store.has_stage("tool_agent:1", stage="xtuner.judger.start") +``` + +### 9.3 Per Trace Eviction + +```python +async def test_max_events_per_trace(tmp_path): + config = TraceConfig( + enabled=True, + output_dir=tmp_path, + max_events=100, + max_events_per_trace=2, + ) + store = InMemoryTraceStore(config) + + store.append(TraceEvent("task:1", "stage.1", time.time())) + store.append(TraceEvent("task:1", "stage.2", time.time())) + store.append(TraceEvent("task:1", "stage.3", time.time())) + + timeline = store.get_timeline("task:1") + assert [event.stage for event in timeline] == ["stage.2", "stage.3"] + assert store.get_latest("task:1").stage == "stage.3" +``` + +### 9.4 Global Eviction + +```python +async def test_max_events_global(tmp_path): + config = TraceConfig( + enabled=True, + output_dir=tmp_path, + max_events=2, + max_events_per_trace=10, + ) + store = InMemoryTraceStore(config) + + store.append(TraceEvent("task:1", "stage.1", time.time())) + 
store.append(TraceEvent("task:2", "stage.2", time.time())) + store.append(TraceEvent("task:3", "stage.3", time.time())) + + assert store.get_timeline("task:1") == [] + assert store.get_latest("task:2").stage == "stage.2" + assert store.get_latest("task:3").stage == "stage.3" +``` + +### 9.5 Noop Recorder + +```python +async def test_noop_trace_api_is_safe(): + await trace_event(None, "xtuner.any.stage") + + async with trace_span(None, "xtuner.any.span"): + pass + + @trace_function("xtuner.any.function", target="state") + async def run(state): + return state + + await run(None) +``` + +### 9.6 JSONL + +```python +async def test_jsonl_written(tmp_path): + config = TraceConfig(enabled=True, output_dir=tmp_path) + store = InMemoryTraceStore(config) + + store.append(TraceEvent("task:1", "xtuner.producer.sampled", time.time())) + store.flush_jsonl() + + path = tmp_path / "producer_trace_000000.jsonl" + rows = [json.loads(line) for line in path.read_text().splitlines()] + assert rows[0]["trace_id"] == "task:1" + assert rows[0]["stage"] == "xtuner.producer.sampled" +``` + +### 9.7 Recorder Boundary Failure + +```python +async def test_recorder_swallow_append_failure(caplog): + class BrokenStore: + def append(self, event): + raise RuntimeError("broken") + + recorder = TraceRecorder(BrokenStore()) + + await recorder.mark(None, "xtuner.producer.sampled", task_name="task", uid=1) + + assert "Trace append failed" in caplog.text +``` + +### 9.8 Trace Function Decorator + +```python +async def test_trace_function_records_return_states(tmp_path): + config = TraceConfig(enabled=True, output_dir=tmp_path) + store = InMemoryTraceStore(config) + + class Runner: + @trace_function("xtuner.producer.generate_group", target="states", result="return") + async def run(self, states): + return states + + state = RolloutState(task_name="task", uid=1, ...) 
+ + with use_trace_recorder(TraceRecorder(store)): + await Runner().run([state]) + + assert [ + event.stage + for event in store.get_timeline("task:1") + ] == [ + "xtuner.producer.generate_group.start", + "xtuner.producer.generate_group.end", + ] +``` + +### 9.9 Viewer Rows From JSONL + +```python +async def test_viewer_rows_use_latest_event(tmp_path): + path = tmp_path / "producer_trace_000000.jsonl" + path.write_text( + "\n".join( + [ + json.dumps( + { + "trace_id": "task:1", + "stage": "xtuner.producer.sampled", + "timestamp_s": 1.0, + "task_name": "task", + "uid": 1, + } + ), + json.dumps( + { + "trace_id": "task:1", + "stage": "xtuner.judger.start", + "timestamp_s": 2.0, + "task_name": "task", + "uid": 1, + } + ), + ] + ) + ) + + timelines = load_trace_jsonl(tmp_path) + rows = build_viewer_rows(timelines, now_s=12.0) + summaries = build_open_span_summaries(timelines, now_s=12.0) + + assert rows[0].trace_id == "task:1" + assert rows[0].latest_stage == "xtuner.judger.start" + assert rows[0].open_span == "xtuner.judger" + assert rows[0].open_age_s == 10.0 + assert summaries[0].span == "xtuner.judger" + assert summaries[0].open_count == 1 + assert summaries[0].oldest_open_age_s == 10.0 + assert [event.stage for event in timelines["task:1"]] == [ + "xtuner.producer.sampled", + "xtuner.judger.start", + ] +``` + +## 10. Minimal Implementation Order + +1. Add `xtuner/v1/rl/trace.py` with `TraceConfig`, `TraceEvent`, store, recorder, trace context, `trace_event`, `trace_span`, `trace_function`, and runtime builder. +2. Add unit tests for store, recorder boundary, trace context, `trace_event`, `trace_span`, `trace_function`, JSONL, and viewer data reconstruction. +3. Thread `TraceConfig` and `TraceRuntime` through agent loop manager construction. +4. Add producer and sampler insertion points. +5. Add agent loop and judger insertion points. +6. Add rollout controller insertion points. +7. Add replay buffer final stage insertion points. +8. 
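Add `tools/producer_trace_viewer.py` to read JSONL and render suspect open spans, latest-stage summary, task table, and selected-task timeline. +9. Add a small doc example showing user `trace_span(state, "user.tool.calculator")` and how to open the viewer. + +## 11. Deferred Extensions + +These capabilities are deliberately left out of the first version: + +1. Arbitrary extra fields on events. +2. Callback-style hooks. +3. Watchpoint flags index. +4. Collection-side stage prefix allowlist. +5. Fire-and-forget / batch flush. +6. Per-stage-type configuration of whether to collect. +7. JSONL compression and retention policy: compress closed shards, clean up old traces per run, keep only the most recent N shards. +8. More advanced offline analysis CLI: bulk export, sorting by duration, cross-run comparison. +9. Online remote dashboard: permissions, multi-user collaboration, cross-machine aggregation, training-process RPC. +10. Centralized Ray trace actor. +11. Rollout worker/backend internal stages. diff --git a/docs/superpowers/specs/2026-06-09-trace-next-phase-working-notes.md b/docs/superpowers/specs/2026-06-09-trace-next-phase-working-notes.md new file mode 100644 index 0000000000..79ba9459dc --- /dev/null +++ b/docs/superpowers/specs/2026-06-09-trace-next-phase-working-notes.md @@ -0,0 +1,1858 @@ +# XTuner RL Observability: Next-Phase Working Notes + +> This is a living document for the next phase of XTuner RL system observability work. It records confirmed requirements, the current technical understanding, open questions, and the conclusions of each round of discussion. Later discussions will be appended to this document. + +## Purpose of this document + +This document is not yet the final implementation plan. + +The goals at this stage are: + +1. Collect the user's next-phase requirements in one place. +2. Preserve discussion context so that decisions can be traced during later design and implementation. +3. Clearly distinguish: + - Confirmed requirements + - Current understanding + - Open questions +4. 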
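Build a stable working draft before writing the formal implementation plan. + +## Current baseline capabilities + +The current trace already has these capabilities: + +- A global trace runtime, plus `trace_function` / `trace_span` / `trace_event` +- JSONL trace shards written under `producer_trace/` +- Online producer trace viewer +- Offline hotspot HTML viewer +- Two scopes: `latest-produce-batch` and `all` +- Existing validation covers: + - trace API semantics + - online viewer payload + - offline hotspot payload + - trainer auto-starting the viewer + - real smoke run with trace enabled + - real smoke run with async + partial-rollout + - real smoke run with judger failure + +Currently relevant files: + +- `xtuner/v1/rl/trace.py` +- `xtuner/tools/producer_trace_analysis.py` +- `xtuner/tools/producer_trace_viewer.py` +- `xtuner/tools/producer_trace_hotspots.py` +- `xtuner/v1/train/rl_trainer.py` + +Key files related to the sandbox agent loop in the next phase: + +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py` +- plus any possibly relevant infer / validator / sandbox stage files under `xtuner/v1/rl/agent_loop/sandbox_agent_loop/` + +## Current overall goal + +The goal has expanded from "task traces emitted by the producer" to "task-level observability for the whole XTuner RL system". + +The core question is: during training, where is each task currently, and is it still running, completed, failed, or stuck at some stage? + +This goal covers the full lifecycle of an RL task, not the local latency inside any single module: + +- sampler sampling / constructing tasks +- producer dispatching tasks +- agent loop organizing one or more rollouts +- rollout controller scheduling generation +- rollout worker / inference backend running LLM generate +- sandbox / tool call performing external interaction +- judger / reward / verifier evaluating results +- trainer / replay buffer / producer collecting results + +Design priorities: + +1. Task status can be viewed in real time while training is running. +2. When training hangs, the stage where a task is stuck can be located. +3. After training ends or exits abnormally, the run can be reviewed offline. +4. No hard dependency on an external trace backend. +5. The data model gradually aligns with OpenTelemetry / OpenInference, with exporter support in the future. + +## Confirmed requirements + +### 1. 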
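Merge the online viewer and the offline viewer + +The online and offline viewers are currently used separately, which makes for a poor experience. + +Target direction: + +- Merge into one unified viewer +- One page structure covers both live viewing and offline viewing after training ends +- Focus on all tasks by default, rather than only the task status of a single batch + +The page mainly consists of these sections: + +- `overview` +- `stage summary` +- `task list` +- `task detail` + +#### 1.1 Overview requirements + +The viewer must show at least: + +- Total number of samples +- Number of completed samples +- Number of running samples +- Number of failed samples + +Confirmed so far: + +- `overview` uses the minimal design +- It does not add average total duration, P95 total duration, current train step, latest batch id, or similar information +- `overview` only conveys overall status and does not duplicate the stage summary below + +#### 1.2 Stage summary requirements + +The viewer must show: + +- How many tasks are currently in each stage +- How many tasks have passed through each stage cumulatively +- Duration statistics for each stage: + - average + - p95 + - max + +The user's intent is that stage-level duration statistics should appear directly in the main viewer, rather than requiring a separate offline hotspot page. + +Confirmed so far: + +- `stage summary` uses a dual-metric design +- Each stage shows all of: + - Number of tasks currently running + - Cumulative number of tasks that have passed through the stage + - `avg / p95 / max` + +This means the unified viewer's stage summary must express both: + +- Which stage tasks are currently stuck in +- Full-run stage statistics + +#### 1.2.1 Retention semantics for all tasks + +The user's latest explicit requirements: + +- The retention semantics of the viewer and the underlying trace should target all tasks +- It should not, by default, keep only the task status of one batch + +This means the later design needs to revisit these points: + +- The viewer's default scope +- The retention policy of the live trace index +- Whether to keep `latest-produce-batch` as an optional filter view rather than the default main view +- The statistics in overview / stage summary / task list should cover all tasks by default + +#### 1.3 Sections to remove + +These two sections will no longer be needed: + +- `Suspect Open Spans` +- `Latest Stage Distribution` + +#### 1.4 Task list and task detail requirements + +The viewer needs to include: + +- A list of all tasks +- Task filtering +- After opening a task, both of these are visible: + - Text timeline + - Graphical timeline +- The viewer must display `error_msg` + +The graphical timeline here means the kind of timeline drawn in the current offline hotspot viewer. + +Confirmed so far: + +- `task detail` uses a top-bottom layout +- The top half shows the text timeline +- The bottom half shows the graphical timeline +- The viewer must display `error_msg` + +Current understanding: + +- The text timeline serves event-level troubleshooting +- The graphical timeline serves observation of stage durations and nesting +- No side-by-side layout, to avoid squeezing the horizontal space of the graphical timeline +- `error_msg` must at least be visible in the viewer, to ease troubleshooting of failed tasks + +### 2. 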
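Migrate the sandbox runner to the current trace system + +Target file: + +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py` + +The spans currently used there include: + +- `with span(uid_obs, "run_total", task_id=tid)` +- `with span(uid_obs, "acquire", task_id=tid)` +- `with span(uid_obs, "infer", task_id=tid)` +- `with span(uid_obs, "validate", task_id=tid)` + +Requirements: + +- The functionality these spans express must be covered by the current trace system +- The implementation builds primarily on the current trace system +- The user prefers converting to `trace_function` wherever possible +- After migration, functionality must pass validation + +The migration must cover the information the old `span(...)` saved: + +- `ts`: `time.time()` wall-clock timestamp +- `event`: `enter` or `exit` +- `uid`: task uid +- `stage`: stage name, for example `run_total`, `acquire`, `infer`, `validate` +- Lightweight extra fields passed by the user, for example `task_id` +- `exit` events additionally include: + - `duration_ms` + - `ok` + - `err` + - Fields appended by `annotate()`, for example sandbox name, env id, sandbox url, image + +The new producer trace event model should standardize these capabilities instead of continuing to encode all semantics into the stage string. + +### 3. The sandbox agent loop is the key trace target for the next phase + +The user's focus in the next phase is no longer just the generic producer trace, but the sandbox agent loop. + +The user wants trace stages to carry business semantics, not just low-level technical function names. + +Example stages given by the user include: + +- `sampler` +- `create sandbox` +- `llm call1` +- `toolcall1` +- `llm call2` +- `toolcall2` +- `judger` + +An important premise here: + +- The user explicitly stated that this requirement is not yet clear enough +- Follow-up questions are needed to clarify this stage model until it is implementable + +## Current technical understanding + +### A. Merging the viewers is first a data model problem, not just a page styling problem + +Current state: + +- The online viewer and the offline hotspot viewer differ in more than rendering +- Their underlying payload structures and focus also differ + +This means: + +- The sensible order is to unify the analysis / payload layer first +- Then unify the page structure +- In the end, the online/offline difference should reduce to: + - Same page structure + - Same payload schema + - Different data source modes + +### B. `trace_function` cannot non-invasively replace every inline `with span(...)` + +`trace_function` is a function decorator. + +In `runner.py`, however, `acquire / infer / validate` are currently inline code blocks inside a single `run()`. + +This means: + +- If the goal is to "convert to `trace_function` wherever possible" +- Then the natural implementation is most likely to split `run()` into several helper methods and decorate those helpers + +For example, it might be split into: + +- `_acquire_infer_sandbox(...)` +- `_run_infer(...)` +- `_run_validate(...)` + +This is not an extra refactor but a structural adjustment dictated by the instrumentation approach. + +### C. Sandbox stage naming most likely needs two layers + +There currently appear to be two possibilities: + +1. Low-level trace stages: close to real function / real operation boundaries +2. 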
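Viewer display stages: close to the business stages users understand + +The currently preferred direction is: + +- Keep the fine-grained information of the low-level trace +- Apply a stable semantic stage mapping at the viewer layer + +This is not finalized yet; it depends on further clarification with the user. + +### D. The event semantics of the old sandbox `span(...)` must be folded into the standard `TraceEvent` + +In the old sandbox trace, `with span(uid_obs, "run_total", task_id=tid)` writes two kinds of records to `$WORK_DIR/trace/spans.*.jsonl`: + +```json +{"ts": 1710000000.0, "event": "enter", "uid": "sample_uid", "stage": "run_total", "task_id": "task_id"} +{"ts": 1710000001.2, "event": "exit", "uid": "sample_uid", "stage": "run_total", "duration_ms": 1200, "ok": true, "err": null, "task_id": "task_id"} +``` + +If an exception is raised inside the span, the `exit` event records: + +```json +{"event": "exit", "stage": "run_total", "ok": false, "err": "RuntimeError: ...", "duration_ms": 1200} +``` + +If no exception is raised but the business logic considers the run failed, the old implementation must call `SpanHandle.mark_error(...)` explicitly; otherwise `ok` stays `true`. + +The old `SpanHandle.annotate(...)` appends runtime-supplied fields to the `exit` event. For example, the `acquire` span appends: + +- `sandbox_name` +- `sandbox_env_id` +- `sandbox_url` +- `sandbox_image` + +Current conclusions: + +- The new `TraceEvent` needs first-class support for an `event` field instead of relying only on the `.start/.end/.error` suffixes of `stage`. +- The suggested standard event types are: + - `enter` + - `exit` + - `error` + - `instant` +- `timestamp_s` corresponds to the old `ts`. +- `duration_ms` corresponds to the old `duration_ms` and applies to `exit/error` events. +- `ok` / `err` must be kept as standard fields to express both exception failures and business failures. +- The user-supplied `task_id` and `entry_kind`, as well as the sandbox information appended by `annotate()`, should go into the `extra` field. +- `error_msg` continues to carry the full error summary shown in the viewer; `err` provides compatibility with the old span semantics and short error strings. +- The viewer / analysis should prefer the standard `event` field when computing open spans, durations, and the error timeline. +- For compatibility with existing producer trace JSONL, the reader must still recognize the old `.start/.end/.error` stage suffixes. + +The goal of the sandbox migration is therefore not simply to move the old `spans.*.jsonl` into a new directory, but to let `trace_function/trace_span/trace_event` write equivalent information that flows into the unified viewer. + +### E. 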
OpenTelemetry / Langfuse / Phoenix / OpenInference compared with the current XTuner trace + +Survey date: 2026-06-11. + +Here "the current XTuner trace" means the producer trace already implemented in the code, not the future target design: + +- Data model: `TraceEvent(trace_id, stage, timestamp_s, status, task_name, uid, session_uid, train_step, model_step, producer_future_step, produce_batch_id, worker_rank, elapsed_s, error_msg)`. +- Span representation: `trace_span/trace_function` currently encode the lifecycle as `{name}.start`, `{name}.end`, `{name}.error`. +- Storage: a process-local global trace runtime, with an in-memory timeline plus buffered JSONL shards. +- viewer: unified online/offline viewer with overview, stage stats, task list, task detail, text timeline, graphical timeline, and error display. +- Constraints: large objects such as prompts, responses, tool results, images, and tensors are not stored; the focus is the current stage of each task/sample, duration hotspots, and failure causes. +- Current gaps: no first-class `event=enter/exit/error/instant` field, no first-class `extra/attrs` field, no `span_id/parent_span_id`, no standard context propagation, and no direct OTLP export. + +#### OpenTelemetry + +Official semantics: + +- A trace represents the execution path of a request or task through the system. +- A span is the basic unit of a trace and contains name, parent span id, start/end timestamp, span context, attributes, events, links, and status. +- Context propagation is the core of distributed tracing; it associates spans generated in different places with the same trace. +- An exporter can send traces to stdout, the OpenTelemetry Collector, or other backends. + +Mapping to XTuner: + +| OpenTelemetry concept | Current XTuner counterpart | Current gap | +| --- | --- | --- | +| trace id | `trace_id = task_name:uid` | Similar semantics, but not an OTel 128-bit trace id, and no trace flags/state | +| span | `.start/.end/.error` event pair produced by `trace_span/trace_function` | No independent `span_id`; nesting can only be inferred from the stage stack and time | +| span parent | None | Parent-child cannot be expressed precisely; it can only be rebuilt from time and stage names | +| attributes | A few fixed fields, for example `worker_rank`, `train_step` | No generic `extra/attrs`, so `task_id`, `entry_kind`, and sandbox information have no natural home | +| span events | Instant events from `trace_event(...)` | Currently mixed with span lifecycle in the `stage` string | +| span status | `status` / `error_msg` | `status` is the sample state, not the span status; `ok/err` are missing | +| context propagation | Ray env vars propagate the trace config | The current trace/span context is not propagated | +| exporter | JSONL writer | No OTLP exporter, no collector/backend integration | + +Takeaways for XTuner: + +- XTuner should adopt OTel's 
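data boundaries: `event`, `stage/name`, `attributes`, and `status` should be separate fields. +- If nested lanes need to be displayed precisely later, `span_id/parent_span_id` is more robust than guessing from the stack. +- If Ray actors, the inference server, and the judger service need to be linked into a single chain later, context propagation must be introduced, propagating at least `trace_id` and the current parent span. +- The first version should not take the OpenTelemetry SDK/Collector as a hard dependency, because deployment, sampling, collector, exporter, and backend queries would noticeably increase system complexity. +- A safer route: keep lightweight JSONL + viewer internally while making the data model gradually compatible with OTel concepts, then add an offline/optional JSONL -> OTLP exporter later. + +#### Langfuse + +Official positioning: + +- Langfuse is tracing/observability for LLM applications. +- It focuses on the full request lifecycle: prompt, model response, token usage, latency, tool step, retrieval step, cost, eval, and so on. +- Its UI is built around trace / session / observation and suits inspecting what happened inside one LLM app request. + +Mapping to XTuner: + +| Langfuse capability | Current XTuner trace | Difference | +| --- | --- | --- | +| trace/session/observation | Task timeline under a `trace_id` | XTuner works at training sample/task granularity, not online request/session granularity | +| prompt/response recording | Not recorded | For size and privacy reasons, XTuner deliberately does not write large text | +| token usage/cost | Not recorded | XTuner currently cares about stage progress and duration, not API cost | +| tool/retrieval step | Planned via sandbox/agent loop trace stages | No standard `span_kind=TOOL/RETRIEVER` yet | +| eval/score | `status/error_msg`, judger stage duration | No Langfuse-style eval dataset / score dashboard | +| async batching | buffered JSONL writer | Similar goal: do not block the main business path | + +Takeaways for XTuner: + +- For agentic RL, Langfuse's trace/session/observation hierarchy is a useful reference, but its full prompt/response recording should not be copied as is. +- XTuner is better off saving only a "structured execution skeleton": LLM call, tool call, sandbox, judger, error, duration, and a few extras. +- If the user later needs prompt/response/tool result, it should be an optional debug dump, not a default trace event. +- Langfuse fits better as a future optional external exporter than as the internal storage of the current producer trace. + +#### Phoenix + +Official positioning: + +- Phoenix is an observability UI for AI/agent/RAG that emphasizes quick onboarding through OpenTelemetry. +- It can trace LLM calls, tool executions, and RAG retrievals, and places related operations under parent spans. +- It also covers annotations, evaluation, and a session view, which suits analyzing agent failure modes. + +Mapping to XTuner: + +| Phoenix capability | Current XTuner trace | Difference | +| --- | --- | --- | +| OpenTelemetry integration | None | XTuner does not produce OTLP spans today | +| Full request context via parent spans | Timeline + open span inference | XTuner lacks 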
`span_id/parent_span_id` | +| Automatic or semi-automatic LLM/tool/RAG tracing | Requires explicit `trace_function/trace_span` | XTuner currently leans on manual instrumentation | +| annotation/eval UI | None | The XTuner viewer only shows run state, duration, and errors | +| session debug | `trace_id` task detail | XTuner takes the training sample view, not the user session view | + +Takeaways for XTuner: + +- Phoenix's UI direction of "expanding one agent request by parent span" is close to what we want for the sandbox agent loop. +- But the XTuner viewer's primary job is locating where a sample is stuck when training hangs, not being an eval/annotation platform. +- XTuner can borrow Phoenix's span tree and span kind rendering while keeping its own overview / stage stats / task list. + +#### OpenInference + +Official positioning: + +- OpenInference is a set of AI observability semantic conventions built on top of OpenTelemetry. +- It defines the attribute schema and span kind taxonomy for AI workloads. +- Typical span kinds include `LLM`, `AGENT`, `CHAIN`, `TOOL`, `RETRIEVER`, `RERANKER`, `EMBEDDING`, `GUARDRAIL`, `EVALUATOR`, `PROMPT`. + +Mapping to XTuner: + +| OpenInference concept | Mappable XTuner sandbox trace stage | +| --- | --- | +| `AGENT` | `agent_loop.generate_sample` / sandbox agent turn | +| `CHAIN` | sampler, producer wrapper, rollout orchestration | +| `LLM` | rollout generate / lmdeploy / sglang call | +| `TOOL` | toolcall, sandbox command, external API call | +| `EVALUATOR` | judger / reward / verifier | +| `GUARDRAIL` | format check, safety check, abort condition | +| `PROMPT` | prompt template / chat template construction | + +Takeaways for XTuner: + +- `span_kind` fits well as part of XTuner `TraceEvent.extra` or as a future standard field. +- The main stages of the sandbox agent loop can be defined as `sampler -> sandbox/acquire -> llm -> tool -> llm -> tool -> judger`, with `span_kind` expressing the stage type. +- The first version should not copy OpenInference's full prompt/message/tool argument schema, otherwise JSONL would bloat quickly. +- Borrow only the span kind taxonomy at first, to make the viewer's colors, filters, and stage summary more readable. + +#### Weave / MLflow in relation to OpenTelemetry / OpenInference + +These four do not sit at the same layer; they should be distinguished as "standard, AI semantic convention, platform product". + +| Item | Type | Is it a standard | Provides UI / storage | Relation to OpenTelemetry | Relation to OpenInference | +| --- | --- | --- | --- | --- | --- | +| OpenTelemetry | General observability standard, SDK, and protocol ecosystem | Yes | No final business UI of its own; needs a backend | Itself | OpenInference builds on it | +| OpenInference | AI trace semantic convention | Yes, as domain semantic conventions for AI | 
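No final business UI of its own; needs a backend | Based on OpenTelemetry span / OTLP trace | Itself | +| MLflow Tracing | LLM / Agent trace platform | Not a base standard; a product/platform | Provides tracking server, storage, UI, eval/feedback, and more | Officially described as fully OpenTelemetry-compatible, with support for GenAI semantic conventions | Can carry GenAI / OpenInference style semantics | +| Weave | W&B's LLM / Agent trace platform | Not a base standard; a product/platform | Provides Weave backend, UI, trace compare, evaluation, and more | The `Call` concept resembles an OpenTelemetry span but is not the OTel protocol itself | Can express LLM/tool/agent semantics but is not an implementation of the OpenInference standard | + +More specifically: + +- OpenTelemetry solves the general problem of "how trace data is represented, propagated, and exported". +- OpenInference solves the domain semantics problem of "what fields in AI / LLM / agent traces are called and how span kinds are classified". +- MLflow solves the platform problem of "where traces land, how to view them in a UI, and how to combine them with experiment management/eval/feedback". +- Weave solves the platform problem of "quickly recording and viewing LLM / agent call trees within the W&B ecosystem". + +On the verl `rollout_trace.rst` path: + +- `backend=weave`: records trajectory / agent loop / tool call to Weave. +- `backend=mlflow`: records trajectory / agent loop / tool call to MLflow Tracing. +- The OpenTelemetry SDK is not used directly. +- No OpenInference semantic convention is declared directly. + +Implications for XTuner: + +- If the goal is a minimal, low-intrusion task trace usable on the training site, keeping the built-in JSONL + viewer is reasonable. +- If the goal is to support optional external exporters later, the data model should align with OpenTelemetry / OpenInference. +- If the goal is external platform integration, start with exporters: + - JSONL -> MLflow Tracing. + - JSONL -> OTLP / OpenTelemetry. + - JSONL -> Weave. +- XTuner should not bind to Weave or MLflow by default, otherwise users would have to deploy or log in to an external platform before they can view traces. + +#### Current recommendations + +Judging from these options, the best next step for XTuner trace is to "align with the standard trace model without introducing a full external observability platform": + +1. In the near term, add `TraceEvent.event`, `duration_ms`, `ok`, `err`, and `extra` to cover the capabilities of the old sandbox `span(...)`. +2. The sandbox agent loop can add a lightweight `span_kind`, with an enum initially modeled on OpenInference: `AGENT`, `CHAIN`, `LLM`, `TOOL`, `EVALUATOR`. +3. For now, do not record prompt, response, or tool result by default; record only stage, duration, status, error, and lightweight extras. +4. Keep the viewer distinctly XTuner: overview, stage stats, task list, task detail, rather than a generic Langfuse/Phoenix clone. +5. In the medium term, add `span_id/parent_span_id` if nested lanes keep getting more complex. +6. In the long term, support an optional OTLP exporter or a JSONL -> OTLP converter as needed. + +### F. 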
Survey of observability in common RL / RLHF frameworks + +Scope: + +- `verl` +- `slime` +- `AReaL` +- `OpenRLHF` +- `ROLL` + +#### verl + +Confirmed capabilities: + +- The README states support for experiment tracking: `wandb`, `swanlab`, `mlflow`, `tensorboard`. +- The rollout config has `trace: TraceConfig`, but it is not an OpenTelemetry / OpenInference trace. +- The rollout trace backends that `verl.utils.rollout_trace` currently supports explicitly are: + - `weave` + - `mlflow` + - `None` / `null`, which disables trace +- A search of the local source found no direct integration of verl rollout trace with `OpenTelemetry`, `OpenInference`, or `OTLP`. +- Prometheus + Grafana monitoring of rollout server metrics is supported, but that is a metrics backend, not a tracing backend. +- PyTorch profiler / Nsight profiling is supported and outputs performance traces readable by Chrome tracing / Perfetto. + +##### verl TraceConfig and trace backend + +`verl.workers.config.rollout.TraceConfig` fields: + +| Field | Default | Meaning | +| --- | --- | --- | +| `project_name` | `None` | Project name in the trace backend; defaults to `trainer.project_name` | +| `experiment_name` | `None` | Experiment / run name in the trace backend; defaults to `trainer.experiment_name` | +| `backend` | `None` | Rollout trace backend; currently `mlflow` and `weave` are supported; `None` disables it | +| `token2text` | `False` | Whether to decode `prompt_ids` / `response_ids` into text before writing them to the trace output | +| `max_samples_per_step_per_worker` | `None` | Maximum number of unique samples traced per agent worker per step; `None` traces everything | + +With `backend="weave"`: + +- Initialization calls `weave.init(project_name)`. +- `rollout_trace_attr(...)` uses `weave.attributes(attributes)` to attach attributes to subsequent calls. +- `rollout_trace_op` creates a call through the Weave client: + - `op=func.__qualname__` + - `inputs=function arguments` + - `attributes=current call attributes` +- On success it calls `finish_call(output=...)`. +- On exception it calls `finish_call(exception=...)`. + +With `backend="mlflow"`: + +- Initialization calls `mlflow.config.enable_async_logging()`. +- The tracking URI is read from the `MLFLOW_TRACKING_URI` environment variable and defaults to `sqlite:////tmp/mlruns.db`. +- It calls `mlflow.set_experiment(project_name)`. +- `rollout_trace_attr(...)` creates a span with `mlflow.start_span(name=...)` and writes trace tags with `mlflow.set_trace_tag(...)`. +- Async functions use `mlflow.start_span(name=func.__qualname__)` and call 
`span.set_inputs(...)` and `span.set_outputs(...)`. +- Sync functions go straight through `mlflow.trace(func)`. + +With `backend=None`: + +- `rollout_trace_op` runs the original function directly. +- No span is created and nothing is written to a trace backend. + +verl trace sampling logic: + +- With `max_samples_per_step_per_worker=N`, each agent worker randomly selects at most `N` unique sample indices per step. +- If one sample has multiple rollouts, all rollouts of a selected sample are traced. +- Trace attributes include `step`, `sample_index`, `rollout_n`, `validate`, `experiment_name`. + +So the conclusion on verl's rollout trace backend is: + +- It is not OpenTelemetry. +- It is not OpenInference. +- It is not a local JSONL timeline viewer. +- It is an external LLM / experiment tracking backend such as Weave or MLflow. +- This corresponds to verl's `docs/advance/rollout_trace.rst`, not to the Prometheus/Grafana rollout metrics documentation. + +##### verl Prometheus / Grafana is metrics, not trace + +verl's Prometheus / Grafana pipeline is mainly for system observation of the async rollout server: + +- Prometheus scrapes time series metrics from the `/metrics` endpoint of the vLLM / SGLang rollout server. +- verl obtains `server_addresses` after `AgentLoopManager` initializes the rollout server. +- If `actor_rollout_ref.rollout.prometheus.enable=True`, verl automatically updates the Prometheus config, writes the rollout server addresses into the scrape targets, and tries to trigger a Prometheus reload. +- Grafana reads the Prometheus time series and renders dashboards. +- Such dashboards suit watching rollout server throughput, latency, cache hit, queue, long-tail, idle resources, and so on. +- They cannot express the span tree of a single sample, nor directly answer "which step of toolcall / llm / judger a given uid is stuck in". + +Conditions for enabling Prometheus / Grafana include: + +- `actor_rollout_ref.rollout.mode="async"` +- `actor_rollout_ref.rollout.disable_log_stats=False` +- `actor_rollout_ref.rollout.prometheus.enable=True` + +Relation to XTuner: + +- verl's `rollout_trace_op` is more like an LLM trace attached to external Weave/MLflow; the current XTuner trace is built-in JSONL + viewer. +- The direct lesson of verl's trace backend choice for XTuner is that "external backends should be optional", not that the internal task trace should be bound to a platform. +- verl's Prometheus/Grafana is system metrics, not task lifecycle trace; the current XTuner viewer leans toward current task state, timeline, and stage duration. +- verl already supports trace sampling, which XTuner may need to add for large-scale agentic scenarios. +- If XTuner later supports external observability platforms, it can add exporters, for example JSONL -> OpenTelemetry / OpenInference / MLflow / Weave, rather than replacing the current built-in 
viewer. + +Local reference files: + +- `/mnt/shared-storage-user/duanyanhui/workspace/code/verl/verl/utils/rollout_trace.py` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/verl/docs/advance/rollout_trace.rst` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/verl/docs/advance/grafana_prometheus.md` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/verl/docs/perf/torch_profiling.md` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/verl/verl/workers/config/rollout.py` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/verl/verl/trainer/config/rollout/rollout.yaml` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/verl/verl/experimental/agent_loop/agent_loop.py` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/verl/verl/experimental/agent_loop/prometheus_utils.py` + +#### slime + +Confirmed capabilities: + +- Provides lightweight sample-level trace. +- Trace data is stored in the rollout debug dump `.pt`. +- Running `tools/trace_timeline_viewer.py` offline generates the cache JSON and the HTML viewer. +- In the viewer each row is a sample, bars represent spans, and dots represent instant events. +- The API includes: + - `trace_span(target, name, attrs=...)` + - `trace_event(target, name, attrs=...)` + - `trace_function(name, ...)` + - `bind_trace(sample)` +- `trace_function` supports `target`, `target_getter`, `attrs_getter`. +- The docs recommend writing `target=...` explicitly rather than relying on automatic inference. +- `trace_span` can append attributes through the span handle during execution. + +Relation to XTuner: + +- slime is the closest match to XTuner's current goals. +- slime's strength is its clean trace API and viewer semantics, especially `attrs_getter` and span update. +- XTuner's strength is the online viewer and real-time state during training; slime is mainly viewed offline after the rollout dump. +- The XTuner sandbox migration should borrow slime's API shape directly while keeping its own task overview and stage stats. + +Official reference: + +- https://github.com/THUDM/slime/blob/main/docs/zh/developer_guide/trace.md + +#### AReaL + +Confirmed capabilities: + +- Dependencies and config include `wandb`, `tensorboardx`, `swanlab`. +- Has `perf_tracer`: + - Outputs `traces.jsonl` + - Supports `trace_scope(...)` + - Supports `atrace_scope(...)` + - Supports `instant(...)` + - Supports `@trace_perf(...)` + - Categories include `compute`, `comm`, `io`, `sync`, `scheduler`, `instr`, `misc` + - Can save per step, with configurable profile steps +- Has `session_tracer`: + 
- Outputs `sessions.jsonl` + - Records the per-session lifecycle + - Phases include `generate`, `reward`, `toolcall` + - Records `task_id`, `session_id`, `status`, `reason`, `submit_ts`, `finalized_ts`, `total_s`, `generate_s`, `reward_s`, `toolcall_s` +- The scaffolding example has `ChatTracer` / `TraceTrajectoryMaker` for multi-turn chat / interaction trace, and feeds the trace result into reward computation. +- In the SGLang inference service, if `server_args.enable_trace` is set, it calls SGLang's `process_tracing_init(server_args.otlp_traces_endpoint, "sglang")`, so at least the inference service layer can connect to an OTLP trace endpoint. + +##### How AReaL session_tracer is implemented + +AReaL's `session_tracer` is implemented in `areal/utils/perf_tracer.py` and is not a separate external backend. It hangs under the global `PerfTracer`: + +- `PerfTracerConfig.session_tracer` is an optional config. +- When `SessionTracerConfig.enabled=True`, `PerfTracer._configure_session_tracer(...)` creates the `SessionTracer`. +- The output path is generated by `_default_trace_path(...)` and defaults to `logs/{user}/{experiment_name}/{trial_name}/session_tracer/.../sessions-r{rank}.jsonl`. +- `flush_threshold` defaults to `256`, meaning ready sessions are flushed in a batch once they reach the threshold. + +Its core data structure is `SessionRecord`, not a stream of raw span events: + +- Each logical task first calls `register_task(task_id)`. +- Each sample / rollout calls `register_session(task_id)` inside `@session_context()`, which generates an incrementing `session_id`. +- The current `task_id` and `session_id` propagate along the async call chain through `contextvars`. +- Phases record start/end through `trace_session("reward")` or `atrace_session_phase("generate")`. +- The built-in phases are: + - `generate` + - `reward` + - `toolcall` +- Each phase may run multiple times and is stored internally as a list of `PhaseSpan(start_ts, end_ts)`. + +Lifecycle events: + +- `mark_generate_start` / `mark_generate_end` +- `mark_reward_start` / `mark_reward_end` +- `mark_toolcall_start` / `mark_toolcall_end` +- `mark_finalized` +- `increment_counter` + +The final JSONL is a session summary, one line per session, with typical fields: + +- `task_id` +- `session_id` +- `rank` +- `role` +- `status` +- `reason` +- `submit_ts` +- `finalized_ts` +- `total_s` +- `generate_s` +- `reward_s` +- `toolcall_s` +- `phases` +- `counters` + +Flush rules: + +- A session whose `SessionRecord.status in {"rejected", "failed", "dropped"}` is 
ready. +- A session with `status == "accepted"` and `finalized_ts is not None` is ready. +- Ready sessions are written to JSONL once they reach `flush_threshold`. +- `PerfTracer.save(force=True)` or process exit also forces a flush. + +Typical call chain: + +- `WorkflowExecutor.submit(...)` / the dataloader path calls `perf_tracer.register_task(task_id)`. +- The workflow executor calls `perf_tracer.set_task_id(task_id)` before running a task. +- In the workflow, `@session_context()` creates a session for the current task. +- In `areal/workflow/rlvr.py`: + - `async with atrace_session_phase("generate")` wraps `engine.agenerate(...)`. + - `@trace_session("reward")` wraps reward computation. +- In the agent example: + - `async with atrace_session_phase("toolcall")` wraps tool execution. +- On accepted / rejected / failed, the workflow executor calls `trace_session_event("mark_finalized", task_id=task_id, status=..., reason=...)` to mark the final state of all sessions under that task. + +Offline visualization: + +- `areal/tools/plot_session_trace.py` reads `sessions.jsonl`. +- It shows histograms of `total_s/generate_s/reward_s/toolcall_s`. +- It also parses the per-session timeline from the `phases` field. + +Relation to XTuner: + +- AReaL's `session_tracer` closely resembles the task lifecycle XTuner wants, especially the phase duration statistics for `generate/reward/toolcall`. +- AReaL's `perf_tracer` leans toward the system performance view of Chrome tracing / Perfetto; the current XTuner viewer leans toward the task list and each task's current stage. +- AReaL's session records offer a good reference: condense each task's key phase statistics into a one-line summary while keeping the detailed timeline. +- AReaL already touches OTLP at the inference service layer, which suggests that "lightweight trace inside XTuner + optional OTLP in the inference backend" is a reasonable layering. + +Local reference files: + +- `/mnt/shared-storage-user/duanyanhui/workspace/code/AReaL/areal/api/cli_args.py` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/AReaL/areal/utils/perf_tracer.py` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/AReaL/examples/scaffolding/controllers.py` +- `/mnt/shared-storage-user/duanyanhui/workspace/code/AReaL/areal/experimental/inference_service/sglang/scheduler.py` + +#### OpenRLHF + +Confirmed capabilities: + +- The README states support for `Wandb` and `TensorBoard` logging. +- A search of the current local source found no dedicated trace timeline / span API. +- PPO observability mainly takes the form of training metrics and per-phase timing, for example 
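`timing/make_experience`, `timing/ppo_train`, `timing/broadcast`, `timing/generation`, `timing/step_total`. +- The reward function can write custom metrics to wandb through `extra_logs`. + +Relation to XTuner: + +- OpenRLHF leans toward traditional experiment metrics logging. +- It does not visibly provide a viewer for "which stage each sample is currently stuck in". +- What the current XTuner trace adds is the task-level live debugging that metrics logging like OpenRLHF's handles poorly. + +Official reference: + +- https://github.com/OpenRLHF/OpenRLHF + +#### ROLL + +Confirmed information: + +- The public paper abstract mentions that ROLL includes a rollout scheduler and performs fine-grained management of each sample's lifecycle during the rollout stage. +- No verifiable official source entry point or documentation was found in this survey, so it cannot be confirmed whether it uses wandb, tensorboard, OpenTelemetry, Langfuse, Phoenix, or a self-developed trace. + +Reference value for XTuner: + +- ROLL emphasizes sample lifecycle management, which matches the goal of the current XTuner trace. +- But without source/documentation evidence, it should not be used as a reference for concrete API or storage design. + +Reference: + +- https://arxiv.org/abs/2506.06122 + +#### Cross-framework conclusions + +| Framework | experiment metrics | task/sample trace | live viewer | external tracing backend | profiling / system metrics | +| --- | --- | --- | --- | --- | --- | +| XTuner (current) | Existing training logging, counted separately | Yes, producer task JSONL timeline | Yes, built-in viewer | None | No dedicated integration | +| verl | wandb / swanlab / mlflow / tensorboard | Yes, rollout trace | Depends on backend / tools | Weave / MLflow | Prometheus/Grafana, PyTorch profiler, Nsight | +| slime | wandb / tensorboard | Yes, sample trace | Mainly offline viewer | No obvious external tracing backend | SGLang profiling / Chrome tracing | +| AReaL | wandb / tensorboardx / swanlab | Yes, session_tracer / scaffolding trace | Mainly files and external tools | SGLang OTLP trace endpoint | perf_tracer, Chrome tracing / Perfetto style | +| OpenRLHF | wandb / tensorboard | No dedicated span timeline found | None | None | per-phase timing metrics | +| ROLL | Unconfirmed | The paper mentions sample lifecycle management | Unconfirmed | Unconfirmed | Unconfirmed | + +Direct conclusions for XTuner: + +- wandb/tensorboard metrics alone are not enough; they cannot answer "which of 100 tasks are stuck in which stage". +- slime / AReaL / verl all show that sample/session/rollout trace is a real need in agentic RL. +- OpenTelemetry / OpenInference suit being the semantic model reference but need not be runtime dependencies of the first version. +- XTuner's most valuable differentiator today is "online task viewer + stage distribution + stage duration stats"; this direction should be kept and strengthened. +- The most important next step is to fill in the standard event fields and the sandbox agent loop 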
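stages, rather than integrating an external observability platform first. + +## Current candidate workstream split + +This is not yet the final set of implementation tasks, only the current candidate split. + +### Workstream 1: unify viewer payload and rendering + +Possible scope: + +- Converge viewer and hotspot analysis onto one unified payload schema +- Merge the online and offline page structures +- Support in a single viewer: + - overview + - stage stats + - task list + - task detail + +Files most likely involved: + +- `xtuner/tools/producer_trace_analysis.py` +- `xtuner/tools/producer_trace_viewer.py` +- `xtuner/tools/producer_trace_hotspots.py` + +### Workstream 2: task-level failure display and filtering experience + +Possible scope: + +- Show failed tasks explicitly in the unified viewer +- Provide task filters, for example: + - all / running / completed / failed + - stage filter + - task id / trace id / task name filter + +This matters because the current viewer can reveal failures but not intuitively enough. + +### Workstream 3: sandbox runner trace migration + +Possible scope: + +- Split `Runner.run()` if needed +- Replace the current inline `span(...)` with `trace_function` +- Keep the existing stage functionality equivalent +- Cover verification of both success and failure paths + +Main target file: + +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py` + +May also involve: + +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/trace.py` +- infer / validator related modules + +### Workstream 4: sandbox semantic stage model + +Possible scope: + +- Define the canonical stages of the sandbox agent loop +- Decide whether these stages are: + - The main stages written natively by trace + - Or business stages aggregated in the viewer layer +- Decide how multi-turn LLM / tool calls are represented in the viewer + +This part is currently blocked on requirement clarification and cannot yet enter the final implementation plan. + +## Open questions + +These questions are not settled and must be clarified before the implementation plan is formally written. + +### Q1. Are sandbox semantic stages a presentation-layer concept or native trace stages? + +Two designs: + +- Design A + - Keep low-level fine-grained stages in the trace + - Map them to business stages in the viewer layer + +- Design B + - Write business stages directly as the main stages in the trace + +Current recommendation: + +- Design A is safer and easier to keep compatible with the current trace_function system + +But this still needs final confirmation from the user. + +### Q2. How exactly should sandbox stages be cut? + +The user's examples include: + +- `llm call1` +- `toolcall1` +- `llm call2` +- `toolcall2` + +This raises several concrete questions: + +- Is stage identity based on the interaction turn number? +- Should the viewer display them as: + - `llm_call_1`, `tool_call_1`, `llm_call_2`, ... + - Or uniformly as `llm_call` / `tool_call`, with multiple spans internally? +- Does the viewer need to show "which turn" explicitly, or only aggregate by category? + +### Q3. Which operations does "create sandbox" (创建沙盒) cover? + +Possible readings: + +- Only sandbox acquire +- acquire + sandbox URL / env setup +- acquire + some initialization hooks + +This boundary needs to be pinned down later. + +### Q4. Should the unified viewer keep live-only information? 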
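+ +Online and offline modes inherently differ in time semantics: + +- Online mode can show the current in-flight stage +- Offline mode can only show the final records already written to disk + +What needs confirming: + +- Whether the same page structure should degrade gracefully in offline mode +- Or whether some live-only information should simply disappear in offline mode + +### Q4.1 Should the `latest batch` semantics be demoted to a secondary filter? + +The user's latest explicit requirement is to keep all tasks by default, not only the task states within one batch. + +So this needs further confirmation: + +- Whether `latest-produce-batch` still needs to be kept +- If kept, whether it should be: + - An optional filter + - Or removed entirely + +Current leaning: + +- All tasks as the default main view +- `latest batch`, if kept, should be demoted to an optional filter rather than the default view + +### Q5. What is the best way to show failed tasks in the unified viewer? + +Not yet finalized: + +- Show only the failed count +- Add a failed filter to the task list +- Add an error stage distribution +- Add a root cause stage +- Show an error summary on each task row + +This matters because failures can be located today but the interaction is not direct enough. + +### Q5.1 Where should `error_msg` be displayed? + +The user has explicitly required the unified viewer to display `error_msg`. + +What still needs confirming: + +- Show the full `error_msg` only in task detail +- Do not add a separate error summary column to the task list +- Do not require `error_msg` to be searchable in the task list for now + +## Discussion log + +### 2026-06-09 - Record 1 - Introducing next-phase requirements + +The user confirmed that the previous phase of development is essentially complete and began raising next-phase requirements. + +Initial requirements: + +1. Merge the online viewer and the offline viewer +2. Move the `span(...)` capabilities in the sandbox runner to the current trace system +3. 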
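Make the sandbox agent loop the main focus of the next trace phase + +### 2026-06-09 - Record 2 - Viewer structure further clarified + +The user further clarified that the unified viewer should include: + +- overview +- Task status statistics +- Number of tasks in each stage +- Duration statistics for each stage: avg / p95 / max +- A list of all tasks with filters +- A text timeline and a graphical timeline for a single task + +The user also stated that the following two current sections are no longer needed: + +- `Suspect Open Spans` +- `Latest Stage Distribution` + +### 2026-06-09 - Record 3 - Sandbox runner trace migration goal clarified + +The user clarified: + +- `runner.py` currently uses `with span(uid_obs, "validate", task_id=tid)` and related spans +- These capabilities must be covered by the current trace system +- The preferred implementation direction is converting to `trace_function` + +### 2026-06-09 - Record 4 - The sandbox agent loop becomes the core of the next phase + +The user further clarified that the next phase should focus on the sandbox agent loop. + +The ideal stage sequence the user gave as an example: + +- sampler +- create sandbox (创建沙盒) +- llm call1 +- toolcall1 +- llm call2 +- toolcall2 +- judger + +The user also stated: + +- This requirement is not yet clear enough +- It needs to be clarified through follow-up questions + +### 2026-06-09 - Record 5 - Current code inspection conclusions + +Code inspection found: + +- `runner.py` currently has four inline `span(...)`: + - `run_total` + - `acquire` + - `infer` + - `validate` +- The producer trace online viewer and the offline hotspot viewer are currently separate tools with different payload emphases + +Resulting conclusions: + +- Viewer merging should start with unifying analysis / payload +- If `runner.py` is to rely mainly on `trace_function`, the inline blocks will very likely need to be split into decorated helper methods + +### 2026-06-09 - Record 6 - Viewer keeps all tasks by default + +The user confirmed the understanding of the unified viewer and further specified: + +- All tasks should be saved going forward, not only the task states within one batch + +Direct design implications: + +- The viewer's default statistics scope should shift from `latest batch` to `all tasks` +- The `latest batch` semantics, if they continue to exist, should be demoted to an optional filter +- The retention policy of the live trace index needs re-evaluation; display can no longer default to revolving around the latest batch + +### 2026-06-09 - Record 7 - overview uses the minimal version + +The user confirmed: + +- `overview` uses the minimal version +- It shows only these 4 numbers: + - Total sample count + - completed count + - running count + - failed count + +Current conclusions: + +- `overview` does not include overall duration statistics +- `overview` does not include training context fields +- Stage duration statistics all go in `stage summary` + +### 2026-06-09 - Record 8 - stage summary uses two measures + +The user confirmed: + +- `stage summary` uses the two-measure scheme + +Current conclusions: + +- Each stage must show: + - Number of tasks currently running in it + - Cumulative number of tasks that have passed through it + - `avg / p95 / max` +- The unified viewer must serve both: + - The live view of "where things are stuck now" + - The all-task view of "which stages were hottest over the whole run" + +### 2026-06-09 - Record 9 - task detail uses a top-bottom layout + +The user confirmed: + +- `task detail` uses a top-bottom layout + +Current conclusions: + +- 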
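The top half holds the text timeline +- The bottom half holds the graphical timeline +- The graphical timeline should keep enough horizontal space and avoid the compression a left-right layout would cause + +### 2026-06-09 - Record 10 - task list filter scheme confirmed + +The user agreed to the first version of the task filter set: + +- Status filter: + - all + - running + - completed + - failed +- Stage filter: + - Filter by current stage +- Search box: + - Supports `trace_id / uid / task_name` +- scope filter: + - `all tasks / latest batch` + +Current conclusions: + +- The default view is still `all tasks` +- `latest batch`, if kept, is only an optional filter + +### 2026-06-09 - Record 11 - viewer must display error_msg + +New requirement from the user: + +- The unified viewer must display `error_msg` + +Current conclusions: + +- `error_msg` is now a formal viewer requirement + +### 2026-06-09 - Record 12 - error_msg shown only in task detail + +The user confirmed: + +- `error_msg` uses the simplest display scheme +- The full `error_msg` is shown only in task detail + +Current conclusions: + +- The task list gets no separate `error` column +- The first version of task list search does not include `error_msg` +- Detailed error content for failed tasks is viewed in task detail + +### 2026-06-09 - Record 13 - unified viewer phase one implemented + +The current implementation has landed and passed unit tests. Scope: + +- `producer_trace_analysis.py` + - Added a unified payload builder + - Uniformly generates: + - `overview` + - `stage_stats` + - `task_rows` + - `task_details` + - Duration statistics in `stage_stats` only count closed spans + - Open spans only count toward: + - The currently running stage + - `running_tasks` + - `visited_tasks` +- `producer_trace_viewer.py` + - Online / offline use the same HTML page structure + - The page has switched to: + - `overview` + - `stage summary` + - `task list` + - `task detail` + - Removed: + - `Suspect Open Spans` + - `Latest Stage Distribution` +- `trace.py` + - The default of `TraceConfig.viewer_scope` has been switched to `all` +- `producer_trace_hotspots.py` + - The CLI entry now outputs the unified viewer page + - The original hotspot payload builder is kept for compatibility with existing tests and analysis logic + +Current implementation semantics: + +- The default main view is `all tasks` +- `latest-produce-batch` is kept as an optional scope filter +- `error_msg` is shown in task detail + +### 2026-06-11 - Record 14 - Old sandbox span JSONL semantics must be carried into the new TraceEvent + +The user pointed out that before continuing to discuss the three concrete design questions of the sandbox agent loop, the following need to be sorted out first: + +- What `with span(uid_obs, "run_total", task_id=tid)` itself does +- Exactly which fields its JSONL output contains +- What every current `with span(...)` call site in XTuner records + +Code inspection conclusions: + +- The old `span(...)` is a synchronous context manager that can wrap async calls. +- Each span writes an `enter` event on entry. +- Each span writes an `exit` event on exit. +- The `exit` event records 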
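`duration_ms`, `ok`, and `err`. +- Fields passed by the caller such as `task_id` and `entry_kind` are written directly to JSONL. +- Fields appended by `SpanHandle.annotate(...)` appear only in the `exit` event. + +The current `with span(...)` call sites in XTuner are concentrated in the sandbox agent loop: + +- `runner.py` + - `run_total` + - `acquire` + - `infer` + - `validate` +- `sandbox.py` + - `entry:{self.name}` with `entry_kind="ShellEntry"` + - `entry:{self.name}` with `entry_kind="DetachedShellEntry"` + +The user confirmed: + +- The new `TraceEvent` must also preserve these recording capabilities. + +Current design conclusions: + +- `event` and `stage` must be separated. +- `event` denotes `enter/exit/error/instant`. +- `stage` denotes the pure stage name, for example `run_total`, `infer`, `validate`. +- `duration_ms`, `ok`, and `err` become standard fields of `exit/error` events. +- The user fields and `annotate()` fields of old spans go into `extra`. +- viewer and analysis should prefer the `event` field going forward and no longer treat the `.start/.end/.error` suffixes as the only standard. +- The old `.start/.end/.error` suffixes remain as reader compatibility logic. + +### 2026-06-11 - Record 15 - Comparison of open-source trace / observability options + +The user asked for a more detailed comparison of OpenTelemetry, Langfuse, Phoenix, OpenInference, and the current XTuner trace, plus a survey of observability in common RL frameworks. + +The survey conclusions are recorded in this document: + +- `OpenTelemetry`: suitable as a reference for the trace data model and future distributed context propagation; the first version should not bring in the full SDK/Collector as a hard dependency. +- `Langfuse`: suited to request-level prompt/response/tool/cost tracing for LLM apps; XTuner should not record large text and tool results by default. +- `Phoenix`: suited to agent/RAG trace UI and parent-span display; XTuner can borrow its span tree / span kind presentation. +- `OpenInference`: the best source for `span_kind` semantics of the sandbox agent loop, for example `AGENT/CHAIN/LLM/TOOL/EVALUATOR`. +- `verl`: has rollout trace with Weave/MLflow backends, plus Prometheus/Grafana and PyTorch profiler. +- `slime`: closest to XTuner's goals, with sample-level trace, `trace_span/trace_event/trace_function/bind_trace`, and an offline timeline viewer. +- `AReaL`: has `perf_tracer`, `session_tracer`, and scaffolding trace, and supports an OTLP trace endpoint at the SGLang inference service layer. +- `OpenRLHF`: mainly wandb/tensorboard metrics and phase timing; no dedicated sample/span timeline viewer was found. +- `ROLL`: the paper explicitly mentions sample lifecycle management, but no verifiable source/documentation was found to confirm the specific observability framework. + +Current design conclusions: + +- XTuner keeps the lightweight built-in JSONL + viewer. +- The data model aligns with OpenTelemetry/OpenInference. +- Sandbox agent loop stage semantics borrow from OpenInference and AReaL session phases. +- The API shape stays close to slime. +- Do not make 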
platforms such as Langfuse/Phoenix a required dependency of the first version. +- If an `.error` event has no explicit `error_msg`, the viewer falls back to showing a summary of the failing stage + +Current verification results: + +- `python -m unittest discover -s tests/rl -p test_trace.py` + - 14 tests passed +- `python -m compileall -q ...` + - Passed +- `git diff --check` + - Passed + +### 2026-06-11 - Record 16 - verl trace backend clarified + +The user clarified that the question is what the backend of verl trace actually is, for example whether it is OpenTelemetry or OpenInference, not metrics systems such as Prometheus/Grafana. + +Code inspection conclusions: + +- The actual backend of verl rollout trace is `weave` or `mlflow`. +- `TraceConfig.backend=None` disables trace. +- No direct integration of rollout trace with `OpenTelemetry`, `OpenInference`, or `OTLP` was found in the local verl source. +- Prometheus/Grafana is the rollout server metrics pipeline, not a tracing backend: + - Prometheus scrapes `/metrics` of the vLLM/SGLang rollout server. + - Grafana reads Prometheus data and renders dashboards. + - It suits watching throughput, latency, cache, queue, long-tail, and idle resources. + - It cannot express a sample's span tree or the stage a uid is currently executing. + +Current design conclusions: + +- Future surveys of external frameworks must distinguish: + - tracing backend / tracing protocol, for example OpenTelemetry, OpenInference, MLflow tracing, Weave tracing. + - metrics backend, for example Prometheus/Grafana. + - profiler trace, for example PyTorch profiler, Nsight, Chrome trace, Perfetto. +- The comparison target for XTuner's built-in trace is mainly tracing backends / task lifecycle viewers, not Prometheus/Grafana. +- The user experience of XTuner's built-in trace is closer to the rollout trace described in verl `docs/advance/rollout_trace.rst` than to the rollout metrics described in verl `docs/advance/grafana_prometheus.md`. +- If XTuner connects to external platforms in the future, exporters are the recommended route, rather than binding the core task trace storage to a backend. + +### 2026-06-11 - Record 17 - Weave / MLflow / OpenTelemetry / OpenInference layering + +The user asked how Weave and MLflow compare with OpenTelemetry and OpenInference. + +Current conclusions: + +- OpenTelemetry is a general trace standard and an SDK / exporter / collector ecosystem. +- OpenInference is an AI observability semantic convention built on top of OpenTelemetry that defines AI-related span kinds and attributes. +- MLflow Tracing is an LLM / Agent trace platform that provides storage, UI, eval, feedback, and more, and is compatible with OpenTelemetry / GenAI semantic conventions. +- Weave is W&B's LLM / Agent trace platform whose core concepts are `Op`, `Call`, `Trace`, `Thread`; `Call` resembles an OTel span, but Weave itself is not the OpenTelemetry protocol. + +Design implications for XTuner: + +- The built-in trace 
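should stay lightweight and not bind to Weave or MLflow by default. +- The internal data model should align with OpenTelemetry / OpenInference, especially the concepts `span_id`, `parent_span_id`, `span_kind`, and `attributes`. +- External platform integration should be provided as exporters, for example JSONL -> MLflow, JSONL -> OTLP, JSONL -> Weave. +- The default viewer still solves on-site training problems: which stage a task is in, which stage takes long, and why it failed. + +### 2026-06-11 - Record 18 - AReaL session_tracer implementation + +The user asked how AReaL's `session_tracer` is implemented. + +Code inspection conclusions: + +- `session_tracer` is implemented in `areal/utils/perf_tracer.py` and hangs under `PerfTracer`. +- The config is `PerfTracerConfig.session_tracer: SessionTracerConfig | None`. +- `SessionTracerConfig` has only two core fields: + - `enabled` + - `flush_threshold` +- It is neither an external backend nor OpenTelemetry; it is AReaL's own local JSONL summary tracer. +- The output file is `sessions-r{rank}.jsonl`, by default under the `session_tracer` subdirectory. +- It uses `contextvars` to propagate the current `task_id` and `session_id`. +- `@session_context()` registers a session under the current task. +- `trace_session(...)` / `atrace_session_phase(...)` record phase start/end. +- The built-in phases are `generate`, `reward`, `toolcall`. +- Each phase can occur multiple times, and the totals accumulate into `generate_s`, `reward_s`, `toolcall_s`. +- `mark_finalized` writes the final state `accepted/rejected/failed/dropped` and `reason`. +- Ready sessions are written to JSONL in a batch once they reach `flush_threshold`; force save / exit also flushes. +- `areal/tools/plot_session_trace.py` can generate offline HTML from `sessions.jsonl`, including a duration histogram and per-session timeline. + +Takeaways for XTuner: + +- AReaL's session summary can serve as a reference for an optional aggregation layer, but XTuner still needs the event timeline to support the online current-stage viewer. +- AReaL's `task_id -> sessions`, `contextvars`, `phase summary`, and `flush_threshold` designs are all worth referencing. +- If XTuner later supports summary JSONL, it can be aggregated offline from the event timeline, without necessarily maintaining two sets of state at runtime. + +### 2026-06-11 - Record 19 - Recommendation on choosing an XTuner trace approach + +The user asked whether XTuner is better served by a self-developed trace or by adopting an open-source solution directly, and which one to choose if open source. + +Current assessment: + +- XTuner's core problem today is troubleshooting on the training site: + - Which stage a task issued by the producer is currently in. + - Which tasks are running / completed / failed. + - When things stall, which of rollout / judger / toolcall / sampler is the slow stage. + - Working without external services in rank / Ray / multi-process scenarios. +- This requirement is not identical to that of general LLM observability platforms. +- Weave / MLflow fit better as external trace backends or exporters than as XTuner's built-in default for the first version. +- OpenTelemetry / OpenInference 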
更适合作为数据模型和语义规范参考,而不是第一版运行时强依赖。 + +推荐方案: + +- XTuner 继续保留自研轻量 trace runtime、JSONL、在线/离线 viewer。 +- 内部数据模型向 OpenTelemetry / OpenInference 靠齐: + - `trace_id` + - `span_id` + - `parent_span_id` + - `event` + - `stage` + - `span_kind` + - `attributes` + - `status/error` + - `start/end/duration` +- 默认只记录轻量结构化信息,不默认记录 prompt / response / tool result。 +- 后续按需增加 exporter: + - JSONL -> OTLP / OpenTelemetry + - JSONL -> MLflow + - JSONL -> Weave +- 如果必须优先选一个开源外部方案,建议优先考虑 MLflow / OpenTelemetry-compatible 路线,而不是 Weave: + - MLflow 更容易自建和内网部署。 + - MLflow 和 OpenTelemetry / GenAI semantic conventions 兼容性更好。 + - Weave 更适合已经深度使用 W&B 的用户,但不适合作为 XTuner 默认依赖。 + +不建议的方案: + +- 不建议第一版直接把 XTuner trace 存储替换成 Weave 或 MLflow,因为这会让用户必须先部署或登录外部系统才能看训练状态。 +- 不建议第一版直接全量引入 OpenTelemetry SDK / Collector / OTLP 后端,因为这会显著提高部署复杂度。 +- 不建议默认记录大文本、tool result、图片、tensor,因为训练规模下 JSONL / backend 成本会快速失控。 + +### 2026-06-11 - 记录 20 - 目标升级为 XTuner RL 系统可观测性 + +用户更新目标:现在要做的是 XTuner 整个 RL 系统的可观测性,核心是观测每个 task 的执行状态。 + +新的目标边界: + +- 不再只是 producer trace。 +- 不再只是 rollout trace。 +- 不再只是 sandbox agent loop trace。 +- 而是 task-level observability across XTuner RL system。 + +核心观测对象: + +- `task` +- `sample` +- `rollout` +- `session` +- `trajectory` + +这些对象之间的关系后续需要在正式设计里定义清楚,尤其是: + +- 一个 task 是否可以有多个 rollout。 +- 一个 task 是否等价于一个 sample。 +- 一个 rollout 是否对应一个 session。 +- rejected / failed / timeout 的 task 是否也要保留完整 timeline。 + +系统阶段覆盖范围: + +- sampler +- producer +- agent_loop +- rollout_controller +- rollout_worker +- inference backend +- sandbox / toolcall +- judger / reward / verifier +- trainer 回收结果 + +当前设计方向: + +- 继续保留 XTuner 内置 trace runtime、JSONL、在线/离线 viewer。 +- 以 task 当前状态为第一优先级。 +- 以 stage duration / hotspot 为第二优先级。 +- OpenTelemetry / OpenInference 用作数据模型和语义参考。 +- Weave / MLflow / OTLP 作为未来 exporter,而不是默认 backend。 + +### 2026-06-11 - 记录 21 - 如果彻底接入开源 trace 实现 + +用户询问:如果 XTuner 想彻底接入开源实现,而不是继续自研默认 trace,应该怎么选。 + +当前判断: + +- 如果目标是“彻底接入开源实现”,推荐的基础栈是: + - OpenTelemetry Python SDK:负责 trace/span/context/export。 + - 
OpenInference semantic conventions:负责 LLM / agent / tool / evaluator 的 AI 语义字段。 + - Phoenix 或 MLflow Tracing:作为可自建的 trace backend / UI。 +- 不建议选择 Weave 作为默认彻底接入方案: + - Weave 适合 W&B 生态用户。 + - 但它不是 OpenTelemetry / OpenInference 标准层。 + - 默认依赖它会把 XTuner 观测能力绑定到 W&B 生态。 +- Langfuse 也是可自建 LLM observability 平台,但更偏应用请求和产品化 LLM trace;对 XTuner 的 RL task current-state viewer 需要额外适配。 + +推荐排序: + +1. OpenTelemetry + OpenInference + Phoenix + - 最贴近 AI trace 标准。 + - Phoenix 原生面向 OpenInference / OTel 风格 AI tracing。 + - 适合 agent / tool / LLM / evaluator span tree。 + - 风险是 XTuner 的“每个 task 当前状态”需要在 Phoenix 之外补 task summary / live dashboard。 +2. OpenTelemetry + OpenInference + MLflow Tracing + - 更偏实验管理和自建 tracking。 + - 和训练生态更自然。 + - UI 可以看 trace,但 task live state dashboard 仍然需要额外开发。 +3. Langfuse + - 适合 LLM app observability。 + - 对 prompt / response / cost / session 支持强。 + - 对 RL 训练 task 状态、批量 rollout、judger、trainer 回收这些语义不是最贴合。 +4. Weave + - 适合 W&B 用户和快速接 agent trace。 + - 不建议作为 XTuner 的开源默认栈。 + +关键设计影响: + +- 一旦彻底接入 OpenTelemetry,XTuner 的 `trace_function` 不应该只写本地 JSONL,而应该创建 OTel span。 +- `task_id` / `uid` / `sample_id` / `rollout_id` 应作为 trace attributes。 +- `agent_loop` / `rollout.generate` / `toolcall` / `judger` 应映射为 spans。 +- `span_kind` 应采用 OpenInference 语义,例如 `AGENT`、`LLM`、`TOOL`、`EVALUATOR`、`CHAIN`。 +- online viewer 的“当前 task 状态”不能完全依赖通用 trace backend;需要: + - 要么额外做一个 task state aggregator。 + - 要么从 OTel span stream / backend query 中实时聚合。 + +因此,如果坚持彻底开源化,推荐架构是: + +- OTel/OpenInference 作为 trace 数据生成和标准语义。 +- Phoenix 或 MLflow 作为外部 trace UI。 +- XTuner 仍保留一个很薄的 task state aggregator,用于实时展示 task current state。 + +### 2026-06-11 - 记录 22 - 开源接入方案新增目标 + +用户补充 XTuner RL observability 的目标: + +- 支持分布式链路追踪:同一个 task 的不同阶段可能运行在 XTuner、lagent、Ray actor、inference backend 等不同进程或 repo 中,需要能连成同一条 trace。 +- 支持自定义 attributes:用户可以在一次 span 里传入自己感兴趣的字段。 +- 支持自动保存慢 task 或失败 task 的调用栈。 +- 保证 trace 结果可复现:同一个稳定数据样本在两次实验中的 trace id 应该一致。 +- 尽量少侵入式修改 XTuner 代码。 + +已经同步更新到独立设计文档: + +- 
`docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + +当前设计含义: + +- 分布式链路追踪应基于 W3C Trace Context:`traceparent`、`tracestate`、必要时加 `baggage`。 +- 对 lagent 这类外部 repo,优先通过环境变量、HTTP header、CLI 参数或 payload 传播 context,避免强依赖 XTuner 内部代码。 +- 自定义 attributes 同时写入 OTel span 和 XTuner JSONL,但需要 sanitize 和大小限制。 +- 慢/失败调用栈作为 artifact 保存,span / event 只记录 artifact 路径。 +- trace id 由稳定 task identity hash 得到,不能依赖时间戳、递增计数器或实验名。 +- 业务代码继续使用 XTuner `trace_function` / `trace_span`,OTel/OpenInference 作为内部 backend。 + +### 2026-06-11 - 记录 23 - viewer 也优先使用开源实现 + +用户补充:viewer 也可以使用开源实现。 + +当前设计调整: + +- Phoenix / MLflow 这类开源 viewer 应作为用户查看 trace 的主入口。 +- XTuner 不默认维护完整自研 HTML viewer。 +- XTuner 保留一个薄的 task state aggregator / dashboard adapter。 +- task state aggregator 只负责补开源 viewer 未必天然具备的训练现场视角: + - 总 task 数。 + - running / completed / failed。 + - 当前阶段分布。 + - 阶段 avg / p95 / max。 + - failed task 和 error summary。 + - task uid 到 external trace id / trace URL 的映射。 +- 单个 task 的 span tree、LLM call、toolcall、judger、stack artifact 链接尽量交给 Phoenix / MLflow 展示。 + +已同步更新独立设计文档: + +- `docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + +### 2026-06-11 - 记录 24 - 当前阶段只关注每个样本的 tracing + +用户明确收窄当前阶段范围: + +- 先不用管推理系统的 metrics。 +- 当前重点是每个样本的 tracing。 + +这意味着文档和后续实现计划里,`inference backend` / `LLM call` 只表示 sample timeline 里的一个 span,而不是推理系统整体观测。 + +保留关注: + +- 某个 sample 的 LLM call 什么时候开始。 +- 某个 sample 的 LLM call 什么时候结束。 +- 某个 sample 的 LLM call 耗时多久。 +- 某个 sample 的 LLM call 是否失败。 +- 某个 sample 的 LLM call 如何和 sampler、agent loop、toolcall、judger 串成同一条 trace。 + +当前不关注: + +- 推理引擎 QPS。 +- 吞吐。 +- GPU 利用率。 +- queue depth。 +- batching / micro-batching 状态。 +- KV cache / prefix cache 命中率。 +- 服务端整体 latency histogram。 + +后续如果要做推理系统 metrics,应作为独立 observability 子项目处理,可以考虑 Prometheus / Grafana 或 OTel Metrics,但不混入 sample-level trace 设计。 + +### 2026-06-11 - 记录 25 - 第一版 viewer/backend 选择 Jaeger + +用户明确:viewer 的后端第一版先用 Jaeger。 + +最新决策: + +- 第一版开源 trace 栈采用 
`OpenTelemetry + OTLP + Jaeger`。 +- Jaeger 作为第一版 trace backend / viewer。 +- XTuner 通过 OTLP exporter 把 sample span 发给 Jaeger。 +- Jaeger UI 用来查看单个 sample 的 span tree、duration、error、attributes。 +- Phoenix 和 MLflow 不作为第一版默认 backend: + - Phoenix 后续作为 AI trace viewer 备选。 + - MLflow 后续作为 experiment trace viewer 备选。 +- OpenInference 不作为第一版 Jaeger 路线的硬依赖,只作为后续 AI 语义增强。 + +第一版 Jaeger 路线需要保留的 XTuner 自有能力: + +- `trace_function` / `trace_span` / `trace_event` 业务插桩 API。 +- sample/task 稳定 identity。 +- task current-state aggregator。 +- stage summary:running/completed/failed、当前阶段分布、avg/p95/max。 +- failed / slow task 的 stack artifact。 +- Ray / lagent / sandbox 的 W3C trace context 传播。 + +已同步更新独立设计文档: + +- `docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + +### 2026-06-11 - 记录 26 - OTel API 封装目标与 OpenInference 依赖边界 + +用户提出两个疑问: + +1. 为什么业务代码不直接调用 OTel API,而是继续调用 `trace_function` / `trace_span`。 +2. OpenTelemetry 和 OpenInference 在 XTuner 中分别做什么,为什么看起来两个库都需要。 + +当前结论: + +- 业务代码继续调用 XTuner trace API 的目标是低侵入和语义集中。 +- producer、agent loop、rollout worker、judger 等业务模块只声明 sample 进入了哪个阶段,不直接处理 OTel 的 `TracerProvider`、`SpanProcessor`、context、exporter 等概念。 +- XTuner trace runtime 统一负责: + - 从 `RolloutState` / `list[RolloutState]` 解析 task id。 + - 生成稳定 trace id。 + - 写入 `xtuner.stage.kind`、status、error、rank、batch id 等 attributes。 + - 处理 `enabled=False`、JSONL fallback、Jaeger exporter、属性裁剪、慢/失败 stack artifact。 + - 处理 Ray / lagent / sandbox 的 trace context 传播。 +- OpenTelemetry 是第一版必需层,负责 trace/span/context propagation/OTLP export。 +- OpenInference 不是第一版 Jaeger 路线的必需依赖,只是后续 AI 语义增强: + - 把 `llm_call`、`tool_call`、`judge` 映射成标准 LLM / TOOL / EVALUATOR 语义。 + - 方便后续接 Phoenix 等更懂 AI trace 语义的 viewer。 + - 方便和其他 LLM/agent observability 工具共享字段语义。 + +第一版依赖边界更新为: + +```text +必需:OpenTelemetry SDK + OTLP exporter + Jaeger + XTuner xtuner.* attributes +可选:OpenInference semantic conventions / Phoenix / MLflow +``` + +已同步更新独立设计文档: + +- 
`docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + +### 2026-06-11 - 记录 27 - 第一版 trace 必须原生面向 Agent in Sandbox + +用户强调:trace 设计必须原生地以 `agent in sandbox` 的 agent loop 为主场景,代码路径是: + +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py` +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py` +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py` +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/validator.py` +- `xtuner/v1/rl/agent_loop/sandbox_agent_loop/schemas.py` + +当前代码结构理解: + +- `AgentInSandboxLoop.generate_sample()` 从 `RolloutState.extra_fields["rollout_item"]` 拿到 `AgentRolloutItem`。 +- `Runner.run()` 是单 sample 的 sandbox 执行主链路,当前旧 span 是 `run_total -> acquire -> infer -> validate`。 +- `SandboxStage.run()` 负责 `pre -> entries -> post`。 +- `ShellEntry` / `DetachedShellEntry` 是 sandbox 中真正执行命令或 daemon 的 observable entry。 +- `JudgerValidator` 会 fan-out 多个 judger,并聚合 score。 +- `AgentInSandboxLoop._fill_rollout_state()` 会把 agent artifacts 转成训练所需的 `input_ids/labels/logprobs`,其中 `trace_store.export_training_trace` 也必须纳入 trace。 + +文档调整结论: + +- 第一版不应把 sandbox 仅仅看成普通 rollout 的一个 `toolcall`。 +- 第一版 canonical span tree 应以 sandbox agent loop 为主: + +```text +xtuner.agent_in_sandbox.generate_sample + ├── xtuner.agent_in_sandbox.throttle_sandbox_create + ├── xtuner.sandbox.runner.run_total + │ ├── xtuner.sandbox.acquire + │ ├── xtuner.sandbox.infer + │ │ ├── xtuner.sandbox.stage.pre + │ │ ├── xtuner.sandbox.entry + │ │ ├── xtuner.llm_call + │ │ └── xtuner.sandbox.stage.post + │ └── xtuner.sandbox.validate + │ ├── xtuner.sandbox.judger + │ └── xtuner.sandbox.judger.aggregate + └── xtuner.agent_in_sandbox.materialize_trajectory +``` + +旧 sandbox `span(...)` 迁移关系: + +- `run_total` -> `xtuner.sandbox.runner.run_total` +- `acquire` -> `xtuner.sandbox.acquire` +- `infer` -> `xtuner.sandbox.infer` +- `validate` -> `xtuner.sandbox.validate` +- `entry:{name}` -> `xtuner.sandbox.entry` + +第一版需要复用 sandbox domain objects: + +- 
`AgentRolloutItem` +- `StageRecord` +- `EntryRecord` +- `RolloutError` + +关键 attributes 包括: + +- `xtuner.agent.item_id` +- `xtuner.agent.uid` +- `xtuner.agent.group_id` +- `xtuner.sandbox.name` +- `xtuner.sandbox.env_id` +- `xtuner.sandbox.url` +- `xtuner.entry.name` +- `xtuner.entry.kind` +- `xtuner.entry.return_code` +- `xtuner.judger.name` +- `xtuner.judger.score` +- `xtuner.error.stage` +- `xtuner.error.category` +- `xtuner.error.message` +- `xtuner.diagnostic_artifact_path` + +### 2026-06-11 - 记录 28 - runtime 默认 OpenTelemetry,不再保留自研 JSONL 后端 + +用户澄清:trace 默认就应该使用 OpenTelemetry,之前后端自己记录 JSONL 的代码可以删掉,整体实现要尽量少。 + +当前结论: + +- `trace_function` / `trace_span` 作为业务插桩 API 保留。 +- API 背后的 runtime 默认是 OpenTelemetry SDK + OTLP exporter。 +- `TraceConfig.enabled=True` 时直接创建 OTLP exporter,不再有 `otel_enabled` 开关。 +- `TraceConfig` 收敛为少量必要参数: + - `enabled` + - `otel_endpoint` + - `otel_service_name` + - `jaeger_query_url` +- 训练路径不再写 `producer_trace_*.jsonl`,不再维护 `InMemoryTraceStore`、`BufferedTraceJsonlWriter`、`TraceEventDispatcher` 这套自研后端。 +- rank0 不再自动启动读取本地 JSONL 的 producer trace viewer;如果用户配置 `jaeger_query_url`,rank0 只打印 Jaeger viewer 地址。 +- `xtuner.tools.producer_trace_*` 这类历史 JSONL viewer/analysis 工具如果保留,只作为旧数据兼容工具,不是新 trace runtime 的组成部分。 + +这个决定会降低当前 PR 里的实现面: + +- 没有 JSONL queue / flush / shard / drop 策略。 +- 没有 OTel JSONL exporter。 +- 不需要在 trainer 里解析或创建 trace output directory。 +- Ray 远端只传播 OTel endpoint 和 service name 等必要环境变量。 + +已同步更新独立设计文档: + +- `docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + +### 2026-06-11 - 记录 28 - 必须覆盖旧 sandbox `span(..., "validate")` 功能 + +用户强调:新 trace 必须覆盖旧 sandbox 调用: + +```python +with span(uid_obs, "validate", task_id=tid): + score, failed = await self.validate.run(item, pool) +``` + +这是一条硬性兼容约束,不只是把阶段改名为 `xtuner.sandbox.validate`。 + +旧 `span(...)` 能力需要完整保留: + +- enter 事件。 +- exit 事件。 +- `ts` 时间戳。 +- `uid`。 +- `stage`。 +- `task_id`。 +- `duration_ms`。 +- `ok`。 +- `err`。 +- `SpanHandle.annotate(...)` 
追加字段。 +- `SpanHandle.mark_error(...)` 标记错误。 +- 异常路径记录。 + +迁移后对应关系: + +- `uid` -> `xtuner.agent.uid` / trace identity。 +- `stage="validate"` -> span name `xtuner.sandbox.validate`,并保留 `xtuner.legacy_stage="validate"`。 +- `task_id=tid` -> `xtuner.agent.item_id`。 +- `event="enter"` -> span start timestamp。 +- `event="exit"` -> span end timestamp。 +- `duration_ms` -> span duration / JSONL fallback `duration_ms`。 +- `ok` -> span status + `xtuner.span.ok`。 +- `err` -> `xtuner.error.message` + `xtuner.span.err`。 +- `annotate(...)` -> span attributes。 +- `mark_error(...)` -> error status + `xtuner.error.*` attributes。 + +同样需要覆盖的旧 span: + +- `run_total` +- `acquire` +- `infer` +- `validate` +- `entry:{name}` + +验收要求: + +- `validate` 成功时能看到耗时和 judger summary。 +- `validate` 失败时能看到 `ok=false`、error stage、error category、error message。 +- Jaeger 中能看到 `xtuner.sandbox.validate` span。 +- JSONL fallback 或 OTel span attributes 能重建旧脚本依赖的核心字段。 + +已同步更新独立设计文档: + +- `docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + +### 2026-06-11 - 记录 29 - XTuner 默认插桩阶段划分 + +用户要求梳理 XTuner 默认应该包含哪些插桩阶段。 + +当前结论:默认插桩分为 P0 / P1 / P2。 + +P0 默认必须采集,缺失会影响卡住定位或旧功能兼容: + +- `xtuner.sampler.sample` +- `xtuner.producer.generate_group` +- `xtuner.agent_in_sandbox.generate_sample` +- `xtuner.sandbox.runner.run_total` +- `xtuner.sandbox.acquire` +- `xtuner.sandbox.infer` +- `xtuner.sandbox.entry` +- `xtuner.llm_call` +- `xtuner.sandbox.validate` +- `xtuner.sandbox.judger` +- `xtuner.sandbox.judger.aggregate` +- `xtuner.agent_in_sandbox.materialize_trajectory` +- `xtuner.producer.put_generated_group` +- `xtuner.sample.final` + +P1 默认采集,但 viewer 默认折叠: + +- `xtuner.agent_in_sandbox.throttle_sandbox_create` +- `xtuner.sandbox.stage.pre` +- `xtuner.sandbox.stage.post` +- `xtuner.sandbox.release_all` +- `xtuner.agent_in_sandbox.export_training_trace` +- `xtuner.partial_rollout_handler.preprocess` +- `xtuner.partial_rollout_handler.postprocess` + +P2 默认不采集,只在 debug 开启: + +- sandbox 
health check 每次轮询。 +- entry monitor 每次 probe / poll。 +- 每个 upload / download 文件。 +- stdout / stderr / daemon log 全文。 +- replay buffer 高频管理调用。 +- tokenizer 内部细节。 +- 推理 engine metrics。 + +viewer 默认展示阶段应合并为: + +```text +sample -> queue -> agent_loop -> sandbox.acquire -> sandbox.infer -> llm_call / entry -> validate -> commit -> final +``` + +普通 single-turn / rollout 路径仍保留兼容插桩: + +- `xtuner.agent_loop.generate_group` +- `xtuner.agent_loop.generate_sample` +- `xtuner.rollout_controller.generate` +- `xtuner.rollout_worker.generate` +- `xtuner.rollout_engine.generate` +- `xtuner.judger.judge` + +已同步更新独立设计文档: + +- `docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + +### 2026-06-11 - 记录 30 - LLM call 需要记录首 token 返回耗时 + +用户补充需求:trace 需要记录一次推理引擎调用从发出请求到返回第一个 token 的时间,用来判断慢点是不是发生在推理引擎首 token 返回之前。 + +当前结论:这个能力应该放在 sample 维度的 `xtuner.llm_call` span 上,而不是扩展成推理系统 metrics。也就是说,我们仍然只关注“某个 task 的这次 LLM 调用”,不采集 engine QPS、GPU 利用率、queue depth、batching、KV cache 等系统指标。 + +`xtuner.llm_call` 需要记录以下字段: + +- `xtuner.llm.request_start_ts` +- `xtuner.llm.first_token_ts` +- `xtuner.llm.time_to_first_token_ms` +- `xtuner.llm.total_ms` +- `xtuner.llm.first_token_observed` +- `xtuner.llm.prompt_tokens` +- `xtuner.llm.completion_tokens` +- `xtuner.llm.backend` +- `xtuner.llm.stream` + +语义约束: + +- streaming 请求:从请求发出到收到第一个有效 token / chunk 的时间,写入 `time_to_first_token_ms`,并在 span 上追加 `first_token` event。 +- 非 streaming 请求:无法精确知道首 token 时间,只能记录请求发出到完整响应返回的时间;此时 `first_token_observed=false`,不能把完整响应耗时伪装成首 token 耗时。 +- viewer 中 `llm_call` 需要同时显示 `TTFT` 和 `total`,这样才能区分“首 token 前慢”和“首 token 后生成慢”。 +- 默认不记录 prompt / response 文本;如果后端提供 usage,可以记录 prompt / completion token 数。 + +实现时需要重点检查两类路径: + +- sandbox agent daemon / session server 的 streaming 路径:这里最适合捕获真实 first token。 +- rollout worker 调 inference engine 的 HTTP 边界:如果当前接口是非 streaming,只能先记录总请求耗时和 `first_token_observed=false`;后续要精确 TTFT 需要切 streaming 或让 engine 端显式上报 first-token event。 + +已同步更新独立设计文档: + +- 
`docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + +## 后续更新规则 + +从现在开始,每一轮讨论都往这份文档里追加以下内容: + +1. 新确认的需求或决策 +2. 新发现的代码约束 +3. 新识别出的待澄清问题 +4. 如果有,补充拒绝某个替代方案的理由 + +在正式设计文档和实现计划写出来之前,这份文档就是这一阶段的单一工作记录源。 diff --git a/docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md b/docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md new file mode 100644 index 0000000000..fee43ef312 --- /dev/null +++ b/docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md @@ -0,0 +1,1281 @@ +# XTuner RL Observability 开源 Trace 接入设计 + +日期:2026-06-11 + +## 1. 背景 + +XTuner 之前的 trace 原型包含自研 JSONL writer 和本地 viewer。当前目标已经调整为:trace runtime 默认使用 OpenTelemetry,span 通过 OTLP 发到 Jaeger;不再维护训练路径里的自研 JSONL 后端。XTuner 只保留 `trace_function` / `trace_span` 这类 task-level 插桩 API 和 `xtuner.*` attributes,用来围绕每个 task 的完整生命周期回答它当前执行到了哪里、运行了多久、是否失败、失败原因是什么,以及系统整体卡在哪个阶段。 + +当前阶段只关注每个样本的 tracing,不关注推理系统整体 metrics。文档里提到 inference backend / LLM call 时,含义是“某个样本在推理阶段对应的 span”,例如这次 LLM call 开始、结束、耗时、错误、上下游 trace context,而不是 engine QPS、GPU 利用率、queue depth、吞吐、KV cache、batching 等系统指标。 + +第一版 trace 功能要原生面向 `agent in sandbox` 的 agent loop,而不是只从普通 single-turn rollout 抽象出来。主关注代码是 `xtuner/v1/rl/agent_loop/sandbox_agent_loop/`: + +- `agent_in_sandbox_loop.py`:把 `RolloutState.extra_fields["rollout_item"]` 里的 `AgentRolloutItem` 跑完,并把结果填回标准 `RolloutState`。 +- `runner.py`:单个 `AgentRolloutItem` 的顶层执行器,负责 `run_total -> acquire -> infer -> validate`。 +- `sandbox.py`:`SandboxStage` 的 `pre -> entries -> post` 执行,以及 `ShellEntry` / `DetachedShellEntry`。 +- `validator.py`:多个 sandbox judger 的 fan-out 和 score 聚合。 +- `schemas.py`:`AgentRolloutItem`、`StageRecord`、`EntryRecord`、`RolloutError` 等领域对象。 + +因此后续 trace 的阶段模型必须能直接表达 sandbox agent loop 的结构:sandbox 获取、agent daemon / shell entry、hook、judger、entry 失败、daemon log / diagnostics artifact、训练 trajectory materialize,而不是只表达普通 `rollout.generate -> judger`。 + 
+如果希望彻底接入开源实现,而不是长期维护一套完全自研的 trace backend,需要重新拆分问题: + +- trace 数据怎么产生。 +- trace 语义怎么标准化。 +- trace 如何跨进程、跨 Ray actor、跨 inference backend 传播。 +- trace 发到哪里。 +- 用户在哪里看。 +- XTuner 仍然需要保留哪些自己的 task-level 聚合能力。 + +本文档讨论的是“彻底接入开源 trace/observability 栈”的目标方案。 + +## 2. 目标 + +开源接入方案需要满足以下目标: + +1. 以 task 为中心观测整个 RL 系统。 +2. 第一版原生支持 agent in sandbox agent loop:sampler、producer、sandbox agent loop、sandbox acquire、infer stage、stage hook、entry、LLM call、validate / judger、trajectory materialize、trainer 回收结果等阶段。 +3. 支持在线查看 task 当前状态。 +4. 支持基于 Jaeger / OTel span 数据分析 task timeline 和阶段耗时热点。 +5. 尽量采用开源标准和开源 backend,避免锁死在单一 SaaS 平台。 +6. 默认不记录 prompt、response、tool result、图片、tensor 等大对象,除非用户显式开启。 +7. 支持分布式链路追踪:一条 task 的不同阶段可能运行在不同 repo、不同进程、不同 Ray actor、不同 HTTP 服务中,例如 lagent 独立进程。 +8. 支持用户自定义 span attributes:用户可以在一次 span 中传入自己关心的结构化字段。 +9. 支持自动保存慢 task 或失败 task 的调用栈,用于定位长时间卡住或异常失败的位置。 +10. 保证 trace identity 可复现:同一个稳定数据样本在两次实验中的 trace id 应该一致。 +11. 尽量少侵入式修改 XTuner 代码:业务代码继续使用 XTuner 的 trace API,不直接依赖 OTel API。 +12. 推理阶段只记录 sample 维度的 tracing span,不采集推理系统 metrics。 +13. 必须覆盖旧 sandbox `with span(uid_obs, "validate", task_id=tid)` 以及同类 `span(...)` 的功能语义,不能迁移后丢失 enter/exit、duration、ok/err、task_id、annotation、error 等信息。 + +## 3. 非目标 + +第一阶段不追求: + +- 替代所有训练 metrics logging。 +- 替代 Prometheus/Grafana 系统指标监控。 +- 采集推理系统 metrics,例如 QPS、吞吐、GPU 利用率、queue depth、engine batching、KV cache 命中率。 +- 默认记录完整 prompt / response / tool result。 +- 默认接入 Weave / W&B SaaS。 +- 要求 lagent 或其他外部 repo 大规模重构才能接入 trace。 +- 默认对所有 task 保存调用栈。 + +## 4. 
推荐开源技术栈 + +推荐默认栈: + +```text +XTuner trace_function / trace_span + ↓ +OpenTelemetry SDK + OTLP exporter + ↓ +XTuner sample trace attributes + ↓ +Jaeger open-source trace backend / viewer + ↓ +XTuner task state aggregator / dashboard adapter, optional +``` + +备选栈: + +```text +XTuner trace_function / trace_span + ↓ +OpenTelemetry SDK + OTLP exporter + ↓ +XTuner sample trace attributes + ↓ +Phoenix 或 MLflow Tracing + ↓ +XTuner task state aggregator / dashboard adapter +``` + +如果后续需要更标准的 AI trace 语义,可以在上述链路里增加 OpenInference attributes: + +```text +XTuner trace_function / trace_span + ↓ +OpenTelemetry SDK + OTLP exporter + ↓ +XTuner sample trace attributes + optional OpenInference semantic attributes + ↓ +Jaeger / Phoenix / MLflow + ↓ +XTuner task state aggregator / dashboard adapter +``` + +分工: + +| 组件 | 角色 | 用途 | +| --- | --- | --- | +| OpenTelemetry | 通用 tracing 标准和 SDK | 生成 span、管理 trace context、跨进程传播、导出 OTLP | +| Jaeger | 第一版开源 trace backend / viewer | 接收 OTLP trace,查看单个 sample 的 span tree、duration、error、attributes | +| OpenInference | 可选 AI observability 语义规范 | 后续用于标准化 LLM、agent、tool、evaluator 等 AI span 的 kind 和 attributes | +| Phoenix | 可选开源 AI trace backend / viewer | 后续查看 OpenInference / OTel 风格的 agent trace | +| MLflow Tracing | 可选开源 trace / experiment 平台 | 后续把 LLM/agent trace 和 experiment、eval、feedback 结合 | +| XTuner task state aggregator | XTuner 自己的可选薄聚合层 | 后续如需“所有 task 当前状态”总览,可从 Jaeger / OTel span 查询结果聚合;第一版不再依赖本地 JSONL | + +第一版明确不保留训练路径里的自研 JSONL backend: + +- `TraceConfig` 不再有 `output_dir`、`max_events`、`max_events_per_trace`、`viewer_*`、`otel_enabled`、`otel_exporter` 这类旧参数。 +- `TraceConfig.enabled=True` 时默认创建 OpenTelemetry OTLP exporter。 +- rank0 不再自动启动读取 JSONL 的本地 producer trace viewer;如果配置了 `jaeger_query_url`,只打印 Jaeger viewer 地址。 +- 历史 JSONL viewer/analysis 工具如果仍保留,只能作为旧数据兼容工具,不是当前 trace runtime 的后端。 + +### 4.1 明确使用哪些开源库 + +这里需要明确一点:没有一个开源库能完整解决 XTuner 的全部需求。比较合适的方案是用开源库承担标准化、传播、存储和通用 trace UI,XTuner 只保留 task-level 语义和很薄的聚合胶水。 + +默认推荐的 Python 包: + +| 
功能 | 默认使用的开源库 | XTuner 仍然负责什么 | +| --- | --- | --- | +| span API 和 trace context | `opentelemetry-api` | 业务代码不直接调用 OTel API,而是继续调用 `trace_function` / `trace_span` | +| span 创建、processor、sampler | `opentelemetry-sdk` | 初始化全局 tracer provider,并把 XTuner trace config 映射成 OTel 配置 | +| OTLP 导出 | `opentelemetry-exporter-otlp-proto-grpc` | 配置 endpoint、batch processor、resource attributes,例如 experiment id / rank / worker role | +| HTTP OTLP 导出备选 | `opentelemetry-exporter-otlp-proto-http` | 只在目标 backend 更适合 HTTP OTLP 时启用 | +| 跨进程 context 传播 | OpenTelemetry 内置 W3C Trace Context / Baggage propagator | 在 Ray task payload、环境变量、HTTP header、lagent 请求参数里 inject / extract `traceparent`、`tracestate`、必要的 `baggage` | +| sample trace 语义 | XTuner 自定义 `xtuner.*` attributes | 第一版直接表达 task id、stage、status、error、duration、rank、batch id 等 sample tracing 信息 | +| AI trace 语义增强 | `openinference-semantic-conventions` | 可选把 XTuner 的 sampler / rollout / llm / tool / judger 阶段映射成 OpenInference attributes | +| AI instrumentation 基础能力 | `openinference-instrumentation` | 可选启用;核心 task stage 仍由 XTuner 自己插桩 | +| OpenAI / LangChain / LlamaIndex 等自动插桩 | `openinference-instrumentation-openai`、`openinference-instrumentation-langchain`、`openinference-instrumentation-llama-index` 等 | 仅当对应依赖真的出现在 agent loop 或 lagent 进程中时可选启用,不作为 XTuner 默认依赖 | +| 开源 viewer / backend 默认方案 | Jaeger binary / container,例如 `jaegertracing/jaeger` | XTuner 通过 OTLP 把 span 发给 Jaeger;Jaeger UI 展示单 sample span tree | +| 开源 AI viewer 备选方案 | `arize-phoenix` | 后续如果更关注 LLM/tool/evaluator 语义,可以把 Phoenix 作为 AI trace viewer | +| 开源 experiment viewer 备选方案 | `mlflow` | 当用户更希望 trace 和 experiment tracking 放在同一个平台时,用 MLflow Tracing 作为后端 | +| HTTP client 请求 span | `opentelemetry-instrumentation-requests`、`opentelemetry-instrumentation-httpx`、`opentelemetry-instrumentation-aiohttp-client` | 只用于捕获某个 sample 发往 inference server / sandbox service / tool service 的请求耗时和错误;不采集服务端系统 metrics | +| logging 关联 trace id | `opentelemetry-instrumentation-logging` | 可选把 
trace id / span id 注入日志,方便从日志跳回 trace | +| 慢 task / 失败 task 调用栈 | Python 标准库 `traceback`、`faulthandler`、`sys._current_frames`、`asyncio.Task.get_stack` | XTuner 自己实现采样触发条件、artifact 落盘、artifact 路径写入 span attributes | + +第一阶段建议把依赖拆成 extras,而不是让 XTuner 默认安装所有 trace 相关包: + +```text +xtuner[trace] + opentelemetry-api + opentelemetry-sdk + opentelemetry-exporter-otlp-proto-grpc + +xtuner[trace-jaeger] + xtuner[trace] + # Jaeger 本身作为外部 binary/container 启动,不是 Python package。 + +xtuner[trace-ai] + openinference-semantic-conventions + +xtuner[trace-phoenix] + xtuner[trace] + xtuner[trace-ai] + arize-phoenix + +xtuner[trace-mlflow] + xtuner[trace] + mlflow + +xtuner[trace-http] + opentelemetry-instrumentation-requests + opentelemetry-instrumentation-httpx + opentelemetry-instrumentation-aiohttp-client +``` + +### 4.2 每一层具体怎么用 + +#### 4.2.1 XTuner 业务插桩层 + +继续保留 XTuner 自己的 API: + +- `trace_function(stage=...)` +- `trace_span(stage=...)` +- `trace_event(stage=...)` +- 后续可以加 `trace_annotate(...)` + +原因是 XTuner 的业务代码关心的是 task 和 stage,而不是 OTel 的 SpanKind、TracerProvider、Context、SpanProcessor。这样可以保证后续从 Jaeger 切到 Phoenix / MLflow,或者同时写 JSONL fallback,都不需要改业务代码。 + +这一层的目标不是隐藏 OTel,而是把 OTel 和 XTuner 业务语义解耦: + +- 业务模块只声明“当前 sample 进入了哪个阶段”,例如 `llm_call`、`tool_call`、`judge`。 +- trace runtime 统一负责把 XTuner 概念翻译成 OTel span、attributes、status、event、context propagation。 +- `RolloutState` / `list[RolloutState]` 到 task id、stable trace id、stage kind、error msg 的解析集中在 trace runtime,不散落到 producer、agent loop、rollout worker、judger 等模块。 +- `enabled=False`、OTLP exporter、属性裁剪、慢/失败 stack artifact 等策略集中实现,避免业务代码写大量观测系统细节。 +- 后续切换或新增 backend 时,业务插桩点不需要重写。 + +也就是说,业务代码调用 XTuner API 是为了保留 XTuner 的 task-level 语义和低侵入性;OpenTelemetry 是这个 API 背后的一个 backend/runtime。 + +#### 4.2.2 OpenTelemetry 标准层 + +XTuner trace runtime 在内部创建 OTel span: + +- `trace_function("xtuner.rollout.generate")` 进入时创建 span。 +- 函数正常返回时结束 span,设置 `status=OK`。 +- 函数异常时结束 span,设置 `status=ERROR`,记录 `error.type`、`error.message`。 +- task 
id、stage、rank、worker role、train step、produce batch id 等作为 span attributes。 + +跨进程时使用 OTel 的 propagator: + +- 父进程把当前 context inject 到 carrier。 +- Ray actor / lagent process / HTTP service 从 carrier extract。 +- 子进程创建 span 时自动成为同一条 trace 的 child span。 + +carrier 可以是: + +- Ray task 的 kwargs / payload。 +- rollout state 的轻量字段。 +- HTTP headers。 +- lagent 子进程环境变量。 +- sandbox request body 中的 trace 字段。 + +#### 4.2.3 OpenInference 语义层,后续可选 + +OpenInference 不负责创建底层 trace runtime,不负责跨进程传播,不负责 OTLP 导出,也不负责 viewer。它负责让 AI / agent trace 的字段有统一语义。 + +第一版 Jaeger 路线不需要强依赖 OpenInference。第一版只需要: + +```text +OpenTelemetry SDK + OTLP exporter + Jaeger + XTuner 自定义 xtuner.* attributes +``` + +OpenInference 的价值在后续: + +- 当我们希望 `llm_call`、`tool_call`、`judge` 在 AI trace viewer 里有更标准的语义时,再映射到 OpenInference。 +- 当我们接 Phoenix 这类更懂 OpenInference 的 viewer 时,OpenInference attributes 会提升 UI 可读性。 +- 当用户希望和其他 LLM/agent observability 工具共享 trace 语义时,OpenInference 可以减少自定义字段解释成本。 + +XTuner 可以把阶段映射成类似语义: + +| XTuner stage | OpenInference 语义方向 | +| --- | --- | +| `xtuner.sampler.sample` | dataset / input preparation span | +| `xtuner.agent_loop.generate_sample` | agent span | +| `xtuner.rollout.generate` | agent 或 chain span | +| `xtuner.lmdeploy.generate` | LLM span | +| `xtuner.sandbox.create` | tool / environment span | +| `xtuner.tool.call` | tool span | +| `xtuner.judger.judge` | evaluator span | + +第一版不要强行把所有字段都塞进 OpenInference 已有字段里。更合理的是: + +- OpenInference 有标准字段时优先用标准字段。 +- XTuner 特有字段使用 `xtuner.*` attribute namespace。 +- 不默认记录大对象,只记录轻量 metadata,例如 task id、stage、status、rank、duration、error message、artifact path。 + +因此当前依赖关系应理解为: + +```text +OpenTelemetry: 第一版必需,负责 trace/span/context/export。 +OpenInference: 后续可选,负责 AI/LLM/agent 语义命名。 +``` + +#### 4.2.4 Jaeger viewer / backend + +默认推荐 Jaeger 作为第一个开源 viewer/backend: + +- Jaeger 能直接接收 OTLP trace。 +- Jaeger UI 能展示单个 sample 的 span tree、duration、error、attributes。 +- Jaeger 对第一版 sample-level distributed tracing 足够直接,不需要先引入 AI 专用 viewer。 + +开发阶段可以先用 
all-in-one: + +```bash +docker run --rm --name jaeger \ + -p 16686:16686 \ + -p 4317:4317 \ + -p 4318:4318 \ + jaegertracing/jaeger:2.19.0 +``` + +其中: + +- `16686` 是 Jaeger UI。 +- `4317` 是 OTLP gRPC。 +- `4318` 是 OTLP HTTP。 + +XTuner 需要做的是: + +- 启动或连接 Jaeger。 +- 把 OTLP endpoint 配到 OTel exporter,例如 OTLP gRPC 对应的 `http://localhost:4317`。 +- 在 span attributes 里写清楚 `xtuner.task.uid`、`xtuner.stage`、`xtuner.status`、`xtuner.error_msg` 等字段。 +- 额外维护 task state aggregator,用于计算 overview、running/completed/failed、stage distribution、stage avg/p95/max。 + +#### 4.2.5 Phoenix / MLflow 备选 + +Phoenix 更适合作为后续 AI trace viewer: + +- 如果要更好地看 LLM / tool / evaluator 语义,可以增加 OpenInference attributes 后接 Phoenix。 +- 它适合 agentic RL 的多轮 LLM call / tool call / judge 调试。 + +MLflow 更适合作为训练实验管理的一部分: + +- 如果用户已经用 MLflow 管理 experiment,trace 放进 MLflow 更自然。 +- MLflow 的强项是 experiment / run / artifact / eval / trace 放在一起。 +- 对 task current-state dashboard,它不一定比 Jaeger 更自然。 + +因此 Phoenix / MLflow 不作为第一版默认 viewer,但作为后续正式支持的 backend 备选。 + +#### 4.2.6 lagent / sandbox 独立进程 + +如果 task 的某些阶段运行在 lagent repo 或独立 sandbox 进程中,推荐做法是: + +1. XTuner 创建父 span。 +2. XTuner 通过 OTel propagator 生成 `traceparent` / `tracestate`。 +3. 通过环境变量、HTTP header 或任务 payload 传给 lagent。 +4. lagent 进程用 `opentelemetry-api` / `opentelemetry-sdk` extract context。 +5. lagent 内部创建 `llm call`、`toolcall`、`validate` 等 child span。 +6. 
lagent 和 XTuner 都导出到同一个 Jaeger backend。 + +这样不需要 lagent 依赖 XTuner 内部模块,只需要双方都遵守 OTel context propagation。 + +如果 lagent 暂时不能改,XTuner 至少可以在调用 lagent 的边界上包一个 span,先记录“lagent 整体耗时”和错误。 + +### 4.3 不建议作为默认依赖的库 + +这些库可以作为可选 exporter 或后续集成,但不建议作为第一版默认方案: + +| 库 | 不作为默认的原因 | +| --- | --- | +| Weave | 更绑定 W&B 生态,不是 OTel / OpenInference 标准层 | +| Langfuse | 更偏 LLM 应用 tracing / prompt / session 管理,XTuner 的核心需求是 RL task 状态和分布式链路 | +| Prometheus / Grafana | 适合系统 metrics,例如 GPU、QPS、latency histogram,不适合表达单条 task 的 parent-child trace | + +### 4.4 和推理系统 metrics 的边界 + +当前阶段不做推理系统 metrics。 + +保留的内容: + +- 某个 sample 进入 LLM call 的时间。 +- 某个 sample 的 LLM call 结束时间和耗时。 +- 某个 sample 的 LLM call 是否失败。 +- 某个 sample 的 LLM call 属于哪条 trace、哪个 parent span。 +- 必要的轻量 attributes,例如 backend name、worker rank、request id。 + +不做的内容: + +- inference engine QPS。 +- engine throughput。 +- GPU 利用率和显存。 +- queue depth。 +- batching / micro-batching 状态。 +- KV cache 命中率。 +- prefix cache 命中率。 +- 服务端整体 latency histogram。 + +如果后续需要推理系统 metrics,应作为独立 observability 子项目处理。它可以继续使用 Prometheus / Grafana 或 OTel Metrics,但不应该混进 sample-level trace 设计里。 + +关键原则: + +- OTel 是第一版标准层,业务代码不直接调用 OTel API。 +- OpenInference 是后续可选 AI 语义增强,不作为第一版 Jaeger 路线的硬依赖。 +- XTuner `trace_function` / `trace_span` 是唯一推荐的业务插桩入口。 +- Jaeger 是第一版开源 trace backend / viewer。 +- Phoenix / MLflow 后续作为可选 AI trace viewer 或 experiment trace viewer。 +- 跨进程传播使用 W3C Trace Context:`traceparent`、`tracestate`、必要时加 `baggage`。 +- 对 lagent 这类独立 repo / 独立进程,优先通过环境变量、HTTP header、CLI 参数、任务 payload 传递 trace context,而不是要求它依赖 XTuner 内部实现。 +- 用户自定义字段统一进入 OTel span attributes,并做 sanitize 和大小限制;第一版训练路径不再写本地 JSONL。 +- 慢/失败 task 的调用栈以独立 artifact 保存,span 只记录 artifact 路径,避免 span attributes 膨胀。 + +## 5. 
为什么不是直接用 Weave + +Weave 的优点: + +- W&B 生态内体验好。 +- LLM / agent trace UI 成熟。 +- 支持 Op、Call、Trace、Thread。 +- 对函数输入、输出、错误、latency 的展示比较直接。 + +但它不适合作为 XTuner 默认开源栈: + +- Weave 更偏 W&B 平台生态。 +- 它不是 OpenTelemetry / OpenInference 标准层。 +- 默认依赖 Weave 会让 XTuner 的可观测性绑定到 W&B。 +- 大规模训练 trace 数据可能带来网络、配额和成本压力。 + +结论:Weave 可以作为可选 exporter,但不建议作为 XTuner 默认开源 trace backend。 + +## 6. Jaeger、Phoenix 和 MLflow 的取舍 + +### 6.1 Jaeger + +Jaeger 是第一版推荐 viewer/backend。 + +Jaeger 适合先解决 XTuner 当前最核心的问题: + +- 单个 sample 的完整 span tree。 +- 每个 span 的 start/end/duration。 +- error span。 +- span attributes,例如 `xtuner.task.uid`、`xtuner.stage`、`xtuner.status`、`xtuner.error_msg`。 +- 跨进程 parent-child 关系。 + +优点: + +- 和 OpenTelemetry / OTLP 直接对接。 +- 部署简单,开发阶段可以用 all-in-one 容器。 +- UI 足够验证 sample tracing 主链路。 +- 不要求 XTuner 第一版引入 AI 专用 trace 平台。 + +不足: + +- Jaeger 是通用分布式 tracing UI,不是 AI 专用 viewer。 +- 不会天然理解 LLM/tool/evaluator 语义。 +- 对“100 个 task 当前分别在哪个阶段”这种训练现场视角,仍需要 XTuner 自己聚合。 + +第一版接受这个取舍:先把 sample-level distributed tracing 跑通,再考虑 AI 专用 viewer。 + +### 6.2 Phoenix + +Phoenix 更贴近 OpenInference 生态,适合展示 AI trace: + +- LLM span。 +- agent span。 +- tool span。 +- retriever span。 +- evaluator span。 +- parent-child span tree。 + +优点: + +- 和 OpenInference / OpenTelemetry 的语义更自然。 +- 适合 agentic RL 的多轮 LLM call / tool call / judge 调试。 +- 开源可自建。 + +不足: + +- 第一版不选它作为默认 backend,避免同时引入 AI 语义层和 viewer 迁移。 +- 它也不一定天然提供 XTuner 需要的 task-level live state dashboard。 + +### 6.3 MLflow Tracing + +MLflow 更偏训练实验和模型生命周期: + +- experiment tracking。 +- LLM / agent tracing。 +- evaluation。 +- feedback。 +- self-hosted tracking server。 + +优点: + +- 开源、自建路线清晰。 +- 和训练实验管理更自然。 +- 官方强调兼容 OpenTelemetry 和 GenAI semantic conventions。 +- 对内网部署比 SaaS 更友好。 + +不足: + +- UI 更偏 experiment / trace 管理,不是专门为 RL task current-state dashboard 设计。 +- 多 worker 写本地 backend 时需要处理并发和存储配置。 + +### 6.4 推荐优先级 + +第一版优先: + +```text +OpenTelemetry + OTLP + Jaeger +``` + +后续如果目标是标准 AI trace,再考虑: + +```text +OpenTelemetry + OpenInference + Phoenix +``` + 
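+Later, if the goal is deep integration with training experiment management, consider: + +```text +OpenTelemetry + MLflow Tracing +``` + +For the first version, XTuner builds the OTel exporter abstraction first, so that Jaeger, Phoenix, and MLflow can all be plugged in later. + +## 7. Why an XTuner task state aggregator is still needed + +OpenTelemetry / Jaeger / Phoenix / MLflow are good at displaying a trace tree, for example: + +```text +task + ├── sampler + ├── rollout.generate + │ └── lmdeploy.generate + ├── toolcall + └── judger +``` + +But XTuner's core online questions are: + +```text +What is the current total number of tasks? +How many are running? +How many are completed? +How many have failed? +Which tasks are stuck in rollout.generate? +Which tasks are stuck in judger? +What are the avg / p95 / max of each stage? +``` + +Questions like these require task current-state aggregation. A general trace backend can provide the raw spans, but does not necessarily provide the aggregated view most needed at the XTuner training site. + +So even if the viewer uses an open-source solution, XTuner should still keep a very thin task state aggregator: + +- It does not define the trace standard. +- It does not provide long-term storage of all traces. +- It does not replace open-source viewers such as Jaeger / Phoenix / MLflow. +- It only aggregates OTel span events into a task current-state summary. +- It preferably writes the summary back as trace attributes / aggregated trace summary / dashboard index that the open-source backend can query. +- If the open-source viewer meets the need, no full XTuner in-house viewer is maintained in addition. + +## 8. Data model mapping + +### 8.0 Reproducible trace identity + +The trace id must, as far as possible, be determined by stable inputs rather than by runtime incrementing counters, timestamps, future step, or per-process random numbers. + +Recommended rule: + +```text +stable_task_key = canonical_json({ + "task_name": task_name, + "task_uid": task_uid, + "sample_id": sample_id, + "rollout_id": rollout_id, +}) +trace_id = first_16_bytes(blake2s(stable_task_key)) +``` + +Design requirements: + +- The same data sample with the same rollout identity gets the same trace id across two experiments. +- `experiment_name`, `trial_name`, `train_step`, `producer_future_step` do not participate in trace id generation by default, because these fields change between experiments. +- If one data sample produces multiple rollouts, `rollout_id` or `rollout_n` should be part of the identity so that multiple rollouts do not share one root trace. +- If the sampler cannot provide a stable uid, the trace runtime may fall back to a runtime uid, but must mark `xtuner.trace_id_stable=false`. +- For OTel, the trace id is a 128-bit hex value; it can be generated from the hash above, while span ids can still be generated at runtime. + +### 8.1 XTuner task identity + +The following identities need to be made explicit: + +| XTuner concept | OTel span / XTuner attribute mapping | +| --- | --- | +| task | root span or root span attribute | +| task uid | `attributes["xtuner.task.uid"]` | +| task name | 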
`attributes["xtuner.task.name"]` | +| trace id stable | `attributes["xtuner.trace_id_stable"]` | +| sample id | `attributes["xtuner.sample.id"]` | +| rollout id | `attributes["xtuner.rollout.id"]` | +| produce batch id | `attributes["xtuner.produce_batch_id"]` | +| train step | `attributes["xtuner.train_step"]` | +| model step | `attributes["xtuner.model_step"]` | +| rank | `attributes["xtuner.rank"]` | +| worker rank | `attributes["xtuner.worker_rank"]` | + +### 8.2 Agent in Sandbox main path + +The first-version trace main path should follow the real execution path from `AgentInSandboxLoop.generate_sample()` to `Runner.run()`. + +Recommended span tree: + +```text +xtuner.sample + ├── xtuner.producer.sample_group + ├── xtuner.producer.generate_group + │ └── xtuner.agent_in_sandbox.generate_sample + │ ├── xtuner.agent_in_sandbox.throttle_sandbox_create + │ ├── xtuner.sandbox.runner.run_total + │ │ ├── xtuner.sandbox.acquire + │ │ ├── xtuner.sandbox.infer + │ │ │ ├── xtuner.sandbox.stage.pre + │ │ │ ├── xtuner.sandbox.entry + │ │ │ ├── xtuner.llm_call + │ │ │ └── xtuner.sandbox.stage.post + │ │ └── xtuner.sandbox.validate + │ │ ├── xtuner.sandbox.judger + │ │ │ ├── xtuner.sandbox.stage.pre + │ │ │ ├── xtuner.sandbox.entry + │ │ │ └── xtuner.sandbox.stage.post + │ │ └── xtuner.sandbox.judger.aggregate + │ └── xtuner.agent_in_sandbox.materialize_trajectory + ├── xtuner.producer.put_generated_group + └── xtuner.sample.final +``` + +Notes: + +- `xtuner.agent_in_sandbox.generate_sample` is the agent-in-sandbox root span for a single sample. +- `xtuner.sandbox.runner.run_total` corresponds to the old `with span(uid_obs, "run_total", task_id=tid)`. +- `xtuner.sandbox.acquire` corresponds to the sandbox pool's `pool.get(...)` and needs to record sandbox name, env id, url, image, and workspace. +- `xtuner.sandbox.infer` corresponds to the agent execution stage and should not be called just `toolcall`; it can contain setup hook, agent daemon entry, shell entry, LLM call, and post hook. +- `xtuner.sandbox.validate` corresponds to the validation stage; each judger inside is its own `xtuner.sandbox.judger` span. +- `xtuner.sandbox.entry` is the unified span for `ShellEntry` / `DetachedShellEntry`; entry name, entry kind, mode, return 
code, and pid/rc/stdout/stderr artifact paths are written as attributes. +- `xtuner.llm_call` is the sample-level model call span. If the LLM request is triggered by the agent daemon inside the sandbox, it must also be associated with the current sample through trace context / session id. +- `xtuner.agent_in_sandbox.materialize_trajectory` corresponds to `trace_store.export_training_trace` and tokenizer processing. This stage can hang after the sandbox has already finished, so it must be included in the trace. + +### 8.3 Migration table for old sandbox spans + +The current `sandbox_agent_loop.trace.span(...)` writes short stage names. After migrating to OTel / Jaeger, stage names need to be normalized: + +| Current old stage | New span name | `xtuner.stage.kind` | Key attributes | +| --- | --- | --- | --- | +| `run_total` | `xtuner.sandbox.runner.run_total` | `agent_loop` | `xtuner.agent.item_id`, `xtuner.agent.uid`, `xtuner.agent.group_id` | +| `acquire` | `xtuner.sandbox.acquire` | `sandbox` | `xtuner.sandbox.name`, `xtuner.sandbox.env_id`, `xtuner.sandbox.url`, `xtuner.sandbox.image`, `xtuner.sandbox.workspace` | +| `infer` | `xtuner.sandbox.infer` | `agent_run` | `xtuner.agent.item_id`, `xtuner.sandbox.name`, `xtuner.stage.status` | +| `validate` | `xtuner.sandbox.validate` | `judge` | `xtuner.judger.count`, `xtuner.judger.aggregator`, `xtuner.judger.on_error` | +| `entry:{name}` | `xtuner.sandbox.entry` | `entry` | `xtuner.entry.name`, `xtuner.entry.kind`, `xtuner.entry.mode`, `xtuner.entry.return_code` | + +The old `SpanHandle.annotate(...)` maps to new span attributes; the old `SpanHandle.mark_error(...)` maps to OTel span status/error attributes: + +```text +xtuner.error.stage +xtuner.error.category +xtuner.error.type +xtuner.error.message +xtuner.error.retryable +``` + +The old `event=enter/exit` semantics are expressed in OTel by span start/end; if the JSONL fallback is still kept, it can continue writing `enter/exit`, but these should be produced by the same trace runtime rather than by a separate sandbox writer producing a disconnected set of files. + +#### 8.3.1 Coverage constraints for the old `span(...)` functionality + +The new trace API must fully cover the capabilities of the old sandbox `span(...)`, in particular: + +```python +with span(uid_obs, "validate", task_id=tid): + score, failed = await self.validate.run(item, pool) +``` + +The old implementation writes two records: + +```text +{"event": "enter", "ts", "uid", "stage": "validate", "task_id": tid} +{"event": "exit", "ts", "uid", "stage": "validate", 
"task_id": tid, "duration_ms", "ok", "err"} +``` + +The migrated `xtuner.sandbox.validate` span must retain at least the following equivalent information: + +| Old field / behavior | New OTel / XTuner trace expression | +| --- | --- | +| `uid` | `xtuner.agent.uid` / association via OTel trace id | +| `stage="validate"` | span name `xtuner.sandbox.validate`, plus `xtuner.legacy_stage="validate"` | +| `task_id=tid` | `xtuner.agent.item_id` | +| `event="enter"` | span start timestamp | +| `event="exit"` | span end timestamp | +| `duration_ms` | OTel span duration; the JSONL fallback keeps writing `duration_ms` | +| `ok` | span status; also written as `xtuner.span.ok` | +| `err` | `xtuner.error.message`; also written as `xtuner.span.err` | +| `SpanHandle.annotate(...)` | span attributes | +| `SpanHandle.mark_error(...)` | span status error + `xtuner.error.*` attributes | +| Exception path | `span.record_exception(...)` + status error; the JSONL fallback writes an error event or an `ok=false` exit | + +`validate` is not the only stage that needs compatibility. All of the following old calls must be covered by the same rules: + +- `with span(uid_obs, "run_total", task_id=tid)` +- `with span(uid_obs, "acquire", task_id=tid)` +- `with span(uid_obs, "infer", task_id=tid)` +- `with span(uid_obs, "validate", task_id=tid)` +- `with span(uid_obs, f"entry:{self.name}", entry_kind=...)` + +Acceptance requirements: + +- With the new trace, the core fields that old viewers / scripts depend on can all be reconstructed from the JSONL fallback or OTel span attributes. +- When `validate` succeeds, the duration and score / judger summary are visible. +- When `validate` fails, `ok=false`, error stage, error category, and error message are visible. +- In Jaeger, the `xtuner.sandbox.validate` span is visible and is a child span of the current sample trace. + +### 8.4 Agent in Sandbox attributes + +The first version records only lightweight, structured fields and does not record the full text of prompt, response, stdout, stderr, or tool result by default. + +`AgentRolloutItem` related: + +- `xtuner.agent.item_id` +- `xtuner.agent.uid` +- `xtuner.agent.group_id` +- `xtuner.agent.data_source` +- `xtuner.agent.ability` +- `xtuner.agent.status` +- `xtuner.agent.reward` + +`StageRecord` / sandbox related: + +- `xtuner.sandbox.name` +- `xtuner.sandbox.image` +- `xtuner.sandbox.env_id` +- `xtuner.sandbox.url` +- `xtuner.sandbox.workspace` +- `xtuner.stage.phase`: `pre`, `entry`, `post` +- `xtuner.stage.status` + 
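
As a rough illustration of how a sandbox stage record might be flattened into the `xtuner.sandbox.*` / `xtuner.stage.*` attributes listed above, here is a minimal sketch. `StageRecordSketch` and `stage_attributes` are hypothetical names introduced only for this example; the fields of the real `StageRecord` are not specified in this document, so the field set below is an assumption.

```python
from dataclasses import dataclass
from typing import Optional

# Allowed values for `xtuner.stage.phase`, as listed above.
ALLOWED_PHASES = ("pre", "entry", "post")


@dataclass
class StageRecordSketch:
    """Hypothetical stand-in for a sandbox stage record (not the real class)."""

    sandbox_name: str
    phase: str
    status: str
    image: Optional[str] = None
    env_id: Optional[str] = None
    url: Optional[str] = None
    workspace: Optional[str] = None


def stage_attributes(record: StageRecordSketch) -> dict:
    """Flatten a record into lightweight span attributes, skipping unset fields."""
    if record.phase not in ALLOWED_PHASES:
        raise ValueError(f"unknown stage phase: {record.phase!r}")
    candidates = {
        "xtuner.sandbox.name": record.sandbox_name,
        "xtuner.sandbox.image": record.image,
        "xtuner.sandbox.env_id": record.env_id,
        "xtuner.sandbox.url": record.url,
        "xtuner.sandbox.workspace": record.workspace,
        "xtuner.stage.phase": record.phase,
        "xtuner.stage.status": record.status,
    }
    # Drop unset fields: null is not a valid OTel attribute value.
    return {key: value for key, value in candidates.items() if value is not None}
```

Keeping the mapping in one small function like this would make it easy to assert in unit tests that only short scalar values, and no stdout / stderr text, end up on the span.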
+`EntryRecord` related: + +- `xtuner.entry.id` +- `xtuner.entry.name` +- `xtuner.entry.kind`: `ShellEntry` or `DetachedShellEntry` +- `xtuner.entry.mode` +- `xtuner.entry.return_code` +- `xtuner.entry.outcome.source` +- `xtuner.entry.outcome.reason` +- `xtuner.entry.retryable` +- `xtuner.entry.pid_file` +- `xtuner.entry.rc_file` +- `xtuner.entry.stdout_artifact` +- `xtuner.entry.stderr_artifact` + +`JudgerValidator` / judger related: + +- `xtuner.judger.name` +- `xtuner.judger.aggregator` +- `xtuner.judger.on_error` +- `xtuner.judger.weight` +- `xtuner.judger.score` +- `xtuner.judger.usable` + +Error related: + +- `xtuner.error.stage` +- `xtuner.error.category` +- `xtuner.error.type` +- `xtuner.error.message` +- `xtuner.error.retryable` +- `xtuner.diagnostic_artifact_path` +- `xtuner.daemon_log_artifact_path` + +### 8.5 Stage kind mapping + +The first-version Jaeger route does not require OpenInference. XTuner first uses its own `xtuner.stage.kind` to express business stages; if Phoenix / OpenInference is integrated later, these kinds are then mapped to OpenInference span kinds. + +| XTuner stage | First-version `xtuner.stage.kind` | Optional later OpenInference kind | Notes | +| --- | --- | --- | --- | +| sampler | `sample` | `CHAIN` | Builds tasks, samples data | +| producer | `queue` / `commit` | `CHAIN` | Task dispatch and result collection | +| agent_loop | `agent_loop` | `AGENT` | Agentic rollout of a single task | +| rollout_controller.generate | `llm_call` | `CHAIN` | Schedules generation requests | +| rollout_worker.generate | `llm_call` | `LLM` | Executes the actual generation | +| lmdeploy / sglang call | `llm_call` | `LLM` | Sample-level inference request | +| sandbox acquire / setup | `sandbox` | `CHAIN` | Environment preparation | +| sandbox infer | `agent_run` | `AGENT` | Main execution stage of agent in sandbox | +| sandbox stage hook | `sandbox_hook` | `CHAIN` | pre / post hook | +| sandbox entry | `entry` | `TOOL` | shell entry, daemon entry, tool commands | +| toolcall | `tool_call` | `TOOL` | Tool call in the agent's semantic sense; not equivalent to every sandbox entry | +| judger / reward / verifier | `judge` | `EVALUATOR` | Evaluation, scoring, verification | +| trajectory materialize | `commit` | `CHAIN` | Converts the agent transcript into training samples | +| format check / guard check | `guard` | `GUARDRAIL` | Format or safety checks | + +### 8.6 
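Span attributes + +Recommended lightweight fields to keep: + +- `xtuner.task.uid` +- `xtuner.task.name` +- `xtuner.task.status` +- `xtuner.stage` +- `xtuner.train_step` +- `xtuner.model_step` +- `xtuner.producer_future_step` +- `xtuner.produce_batch_id` +- `xtuner.rank` +- `xtuner.worker_rank` +- `xtuner.error_type` +- `xtuner.error_msg` + +User-defined attributes: + +- Passed in via `trace_span(..., attributes={...})` or `trace_function(..., attributes_getter=...)`. +- Go into OTel span attributes. +- Also written to the `attributes` field of the XTuner JSONL. +- Recommended namespaces: + - Framework fields use `xtuner.*`. + - User fields use `user.*` or a business-specific prefix. +- Allowed value types should follow OTel attribute restrictions: + - `str` + - `bool` + - `int` + - `float` + - short lists of the above types +- Passing dicts / large objects / tensors / long text is not recommended. When truly needed, save them as an artifact and record the path or a summary in the attribute. +- The trace runtime should do lightweight sanitization, such as limiting string length and filtering out non-serializable objects. + +Not recorded by default: + +- prompt text. +- response text. +- tool result text. +- images. +- tensors. +- large metadata. + +These can be enabled as a debug dump or through explicit configuration. + +### 8.7 Default instrumentation stages + +Principles for default instrumentation: + +- `agent in sandbox` is the primary main path. +- Default spans must be able to answer "which stage is the sample currently stuck in". +- Default spans must cover the capabilities of the old sandbox `span(...)`. +- Collection granularity can be finer than viewer display granularity; the viewer collapses wrapper / hook details by default. +- Large objects are not recorded by default, and inference-system metrics are not recorded by default. + +#### 8.7.1 P0: must be collected by default + +These stages must be instrumented by default. Missing any one of them hurts hang localization or compatibility with old functionality. + +| Span name | stage kind | Instrumentation point | Purpose | +| --- | --- | --- | --- | +| `xtuner.sampler.sample` | `sample` | `Sampler.sample()` or the producer's sample wrapper | Records a sample entering rollout from the dataloader / replay buffer / expired pool | +| `xtuner.producer.generate_group` | `queue` | `ProduceContext.generate_group()` | Records the total time for the producer to dispatch a group to the agent loop and wait for the return | +| `xtuner.agent_in_sandbox.generate_sample` | `agent_loop` | `AgentInSandboxLoop.generate_sample()` | Root business span of a single sandbox sample | +| `xtuner.sandbox.runner.run_total` | `agent_loop` | `Runner.run()` | Covers the old `run_total`; represents the full sandbox lifecycle of one `AgentRolloutItem` | +| `xtuner.sandbox.acquire` | `sandbox` | `Runner.run()` calling `pool.get(...)` | Covers the old `acquire`; localizes hangs in sandbox creation / acquisition / health check | +| `xtuner.sandbox.infer` | `agent_run` | `Runner.run()` calling 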
`self.infer.run(...)` | Covers the old `infer`; represents the agent executing in the sandbox | +| `xtuner.sandbox.entry` | `entry` | `ShellEntry.run()` / `DetachedShellEntry.run()` | Covers the old `entry:{name}`; localizes a specific shell / daemon entry that hangs or fails | +| `xtuner.llm_call` | `llm_call` | sandbox agent daemon / session server / rollout HTTP boundary | Records the sample-level LLM call total time, time to first token, and errors | +| `xtuner.sandbox.validate` | `judge` | `Runner.run()` calling `self.validate.run(...)` | Covers the old `validate`; localizes hangs in validation and scoring | +| `xtuner.sandbox.judger` | `judge` | `JudgerValidator._run_one()` | One span per judger; localizes a specific judger that fails or is slow | +| `xtuner.sandbox.judger.aggregate` | `judge` | `JudgerValidator._aggregate()` | Records the aggregation strategy, number of usable judgers, and final score | +| `xtuner.agent_in_sandbox.materialize_trajectory` | `commit` | `AgentInSandboxLoop._fill_rollout_state()` | Records converting the agent trajectory into training samples, including `trace_store.export_training_trace` | +| `xtuner.producer.put_generated_group` | `commit` | `ProduceContext.put_generated_group()` | Records result filtering, statistics, and writing to the replay buffer | +| `xtuner.sample.final` | `final` | trace runtime / producer collection point | Records the sample's final status: completed / failed / aborted / filtered / expired | + +For P0 spans, the viewer's default display stages are suggested to be merged into: + +```text +sample -> queue -> agent_loop -> sandbox.acquire -> sandbox.infer -> llm_call / entry -> validate -> commit -> final +``` + +`xtuner.llm_call` must additionally record first-token metrics, used to determine whether the slow point is before the inference engine returns the first token: + +- `xtuner.llm.request_start_ts` +- `xtuner.llm.first_token_ts` +- `xtuner.llm.time_to_first_token_ms` +- `xtuner.llm.total_ms` +- `xtuner.llm.first_token_observed` +- `xtuner.llm.prompt_tokens` +- `xtuner.llm.completion_tokens` +- `xtuner.llm.backend` +- `xtuner.llm.stream` + +Semantic requirements: + +- Streaming requests: the time from sending the request to receiving the first valid token / chunk is written to `time_to_first_token_ms`, and a `first_token` event is added on the span. +- Non-streaming requests: the first-token time cannot be known precisely; only the time from sending the request to the full response returning can be recorded. In this case `first_token_observed=false`; do not disguise the full response time as the first-token time. +- If the backend returns usage / token counts, write the prompt / completion token counts; prompt / response text is not recorded by default. +- In the viewer, `llm_call` should show both `TTFT` and `total`, to distinguish "first token 
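slow" from "subsequent generation slow". + +#### 8.7.2 P1: collected by default, but collapsed in the viewer by default + +These stages are valuable for localizing sandbox details, but the main view should not expand them by default. + +| Span name | stage kind | Instrumentation point | Purpose | +| --- | --- | --- | --- | +| `xtuner.agent_in_sandbox.throttle_sandbox_create` | `queue` | `AgentInSandboxLoop._throttle_sandbox_create()` | Localizes whether a sample is stuck on sandbox creation throttling | +| `xtuner.sandbox.stage.pre` | `sandbox_hook` | `SandboxStage._run_phase("pre", ...)` | Localizes upload / install / setup hooks | +| `xtuner.sandbox.stage.post` | `sandbox_hook` | `SandboxStage._run_phase("post", ...)` | Localizes download / parse / cleanup hooks | +| `xtuner.sandbox.release_all` | `sandbox` | the `finally` block of `Runner.run()` | Localizes a sample that has finished while sandbox cleanup hangs | +| `xtuner.agent_in_sandbox.export_training_trace` | `commit` | call site of `trace_store.export_training_trace` | Distinguishes whether trajectory materialization is stuck in the trace store or in the tokenizer | +| `xtuner.partial_rollout_handler.preprocess` | `llm_call` | partial rollout worker | Partial rollout compatibility; localizes pre-resume processing | +| `xtuner.partial_rollout_handler.postprocess` | `llm_call` | partial rollout worker | Partial rollout compatibility; localizes post-resume processing | + +P1 spans are kept in full in Jaeger; the XTuner task summary / viewer expands them only in task detail by default. + +#### 8.7.3 P2: not collected by default, enabled for debugging + +These stages tend to generate a large number of events or noise, so no spans are inserted by default. + +| Stage | Why it is not collected by default | +| --- | --- | +| Every health-check poll in `SandboxPool._wait_healthy` | High frequency and easily floods the output; by default only the total time and failure reason are recorded on `xtuner.sandbox.acquire` | +| Every probe / poll of `EntryMonitor` | High frequency; by default only the entry total time, timeout, and pid/rc file results are recorded | +| Every uploaded/downloaded file | The number of files can be large; by default the hook span and a summary are recorded | +| Full text of stdout / stderr / daemon log | Large; by default written to an artifact, with only the path and a summary in span attributes | +| High-frequency management calls such as `replay_buffer.count/get/take_batch` | More about internal system state than a main stage of sample execution | +| Step-by-step processing inside the tokenizer | Included in the `materialize_trajectory` total time by default | +| Inference engine metrics | System metrics are out of scope for the current phase; only the sample-level `llm_call` is of interest | + +#### 8.7.4 Compatibility with the ordinary single-turn / rollout path + +Although first-version acceptance focuses on `AgentInSandboxLoop`, the ordinary single-turn path also needs to keep default instrumentation: + +| Span name | stage kind | Instrumentation point | +| --- | --- | --- | +| `xtuner.agent_loop.generate_group` | `agent_loop` | 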
`AgentLoop.generate_group()` | +| `xtuner.agent_loop.generate_sample` | `agent_loop` | `SingleTurnAgentLoop.generate_sample()` | +| `xtuner.rollout_controller.generate` | `llm_call` | `RolloutController.generate()` | +| `xtuner.rollout_worker.generate` | `llm_call` | `RolloutWorker.generate()` | +| `xtuner.rollout_engine.generate` | `llm_call` | the rollout worker's actual HTTP request | +| `xtuner.judger.judge` | `judge` | `AgentLoop.run_judger()` | + +The viewer display of these spans can be collapsed into: + +```text +sample -> queue -> agent_loop -> llm_call -> judge -> commit -> final +``` + +## 9. Instrumentation approach + +On the user side, keep using XTuner's `trace_function` / `trace_span`; do not require business code to use the OTel API directly. + +Recommended: + +```python +@trace_function("xtuner.rollout.generate") +async def generate(...): + ... +``` + +Custom attributes are supported: + +```python +with trace_span( + "xtuner.judger.evaluate", + attributes={ + "user.dataset": "gsm8k", + "user.difficulty": difficulty, + "xtuner.judger.name": judger_name, + }, +): + ... +``` + +Decorator form: + +```python +@trace_function( + "xtuner.toolcall.execute", + span_kind="TOOL", + attributes_getter=lambda self, request, **kwargs: { + "user.tool_name": request.tool_name, + "user.timeout_s": request.timeout, + }, +) +async def execute_tool(...): + ... +``` + +`trace_function` selects the backend internally according to configuration: + +```text +backend=jsonl -> current built-in JSONL writer +backend=otel -> OpenTelemetry span +backend=dual -> JSONL + OpenTelemetry span +backend=disabled -> no-op +``` + +This ensures that: + +- Business code is not bound to OTel. +- Migration can be smooth. +- The XTuner viewer and an external trace backend can be kept at the same time. +- Tests can continue to be written around the XTuner trace API rather than depending directly on external services. +- External repos such as lagent, if they cannot conveniently depend on XTuner, can integrate only OTel / W3C Trace Context; XTuner is responsible for injecting the trace context at the call boundary. + +## 10. Cross-process and Ray propagation + +After fully adopting OTel, the biggest risk is context propagation. + +What needs to be propagated: + +- trace id. +- parent span id. +- task uid. +- sample / rollout id. +- trace config. +- Lightweight fields in user-defined attributes that can be propagated across processes. + +Between Ray actors / workers, Python contextvars alone are not enough. The trace context needs to be passed explicitly in the task payload or runtime env. + +Candidate approaches: + +1. 
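Carry the trace context in RolloutState / task payload. +2. Carry the trace context in Ray remote call kwargs. +3. Carry the W3C traceparent in the HTTP request header. +4. For lmdeploy / sglang requests, pass the request id and trace attributes where possible. +5. For independent processes such as lagent, inject the trace context through environment variables or CLI arguments when starting the process. + +Recommended minimal approach: + +- XTuner internal task payloads explicitly carry `trace_context`. +- HTTP calls support the `traceparent` header. +- OTel context is managed automatically only within a process; across processes it must be passed explicitly. + +### 10.1 lagent / external repo / new-process path + +Target scenario: + +```text +XTuner task + └── sandbox agent loop + └── lagent process + ├── llm call + ├── toolcall + └── validate +``` + +Requirements: + +- Spans in lagent must hang under the XTuner task trace rather than forming an isolated trace. +- lagent does not have to depend on XTuner code. +- When the lagent process exits or fails, XTuner can still see the corresponding span error or last status. + +Recommended protocol: + +- When XTuner starts the lagent process, it injects: + - `TRACEPARENT` + - `TRACESTATE` + - `OTEL_EXPORTER_OTLP_ENDPOINT` + - `OTEL_SERVICE_NAME=lagent` + - `XTUNER_TASK_UID` + - `XTUNER_TRACE_ATTRIBUTES_JSON` +- If lagent is called over HTTP / RPC, use the headers: + - `traceparent` + - `tracestate` + - `baggage` +- If tasks are passed through files, queues, or Ray objects, carry in the payload: + - `trace_context` + - `task_uid` + - `stable_task_key` + +Minimal integration on the lagent side: + +- Read the W3C trace context. +- Use the OTel SDK to extract the parent context. +- Create child spans. +- In the first version, use XTuner stage attributes to label `agent_loop`, `llm_call`, `tool_call`, `judge`. +- If OpenInference is integrated later, add AI span kinds such as `AGENT`, `LLM`, `TOOL`, `EVALUATOR`. + +If lagent's code cannot be changed for now: + +- XTuner at least creates a wrapper span at the lagent call boundary, for example `xtuner.lagent.run`. +- The wrapper span records the process pid, a summary of launch arguments, exit code, timeout, and stderr artifact path. +- This does not reveal lagent internals, but it narrows the range where the task is stuck down to lagent. + +## 11. 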
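Writing and performance + +After adopting OpenTelemetry, blocking training must still be avoided. + +Requirements: + +- Use BatchSpanProcessor. +- The exporter sends asynchronously in batches. +- An unavailable trace backend must not block the main training flow. +- Support sampling. +- Support recording only lightweight attributes. + +Recommended configuration options: + +```python +trace_config = dict( + enabled=True, + backend="otel", + otlp_endpoint="http://localhost:4317", + viewer_backend="jaeger", + viewer_url="http://localhost:16686", + service_name="xtuner-rl", + sample_rate=1.0, + record_payload=False, +) +``` + +### 11.1 Saving call stacks of slow / failed tasks + +Goals: + +- Automatically save the exception traceback for failed tasks. +- Automatically save the call stack for tasks / spans whose duration exceeds a threshold. +- Do not save stacks for all tasks, to keep performance and storage overhead under control. + +Trigger conditions: + +- A span raises an exception. +- A task fails. +- A task exceeds `slow_task_threshold_s`. +- A single span exceeds `slow_span_threshold_s`. +- The user manually calls `trace_dump_stack(reason=...)`. + +What is saved: + +- Python traceback. +- Current thread stack. +- asyncio task stack. +- Process pid, rank, role, task uid, span id. +- For a child process such as lagent, save the child process's local stack artifact or the wrapper span's stderr / log artifact. + +Storage format: + +- Stack content is written to a separate artifact file, for example: + +```text +{output_dir}/trace_artifacts/stacks/{trace_id}/{span_id}.txt +``` + +- Trace events / span attributes record only: + - `xtuner.stack_artifact_path` + - `xtuner.stack_reason` + - `xtuner.stack_captured=true` + +How slow spans are captured: + +- Simple version: when a span ends and its duration is found to exceed the threshold, save the traceback / callsite information at end time. +- Full version: maintain an active span registry at runtime; a background watchdog periodically scans active spans and captures a live stack for threads or asyncio tasks that exceed the threshold. + +Recommended phasing: + +- Phase one: failure traceback + slow-span artifact at span end. +- Phase two: active span watchdog, to get a live stack "while it is stuck". + +## 12. 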
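Viewer approach + +The viewer also prefers an open-source implementation; a full XTuner in-house viewer is not maintained by default. + +Preferred viewer: + +- Jaeger. + +Jaeger should be the main entry point for users to view traces in the first version: + +- The complete span tree of a single task. +- LLM call. +- toolcall. +- judger. +- failed span. +- slow span. +- stack artifact links. +- span attributes, such as task uid, stage, status, error msg, worker rank. + +Phoenix / MLflow come later as optional viewers: + +- Phoenix is better suited to displaying AI trace semantics. +- MLflow is better suited to combining with experiment tracking. + +XTuner keeps only a thin task state aggregator / dashboard adapter to supply the training-site view that open-source viewers may not provide out of the box: + +- Total number of tasks. +- running / completed / failed. +- Current stage distribution. +- Stage avg / p95 / max. +- failed tasks and error msg summary. +- Mapping from task uid to external trace id / trace URL. + +Preferred implementation: + +1. As far as possible, write the task summary as OTel attributes / span events / aggregated trace summary, so that Jaeger can filter and jump directly. +2. If the open-source viewer's dashboard can display the task summary, do not maintain a separate XTuner HTML viewer. +3. If the open-source viewer cannot yet express "the current stage distribution of 100 tasks", XTuner provides only a minimal dashboard adapter: + - Read from an OTel/backend query or a local span stream. + - Output the task summary. + - The page only does overview and navigation, and no longer replicates a full trace timeline. + +In other words, the trace tree and single-task timeline are left to the open-source viewer as much as possible; XTuner itself only adds RL task current-state aggregation. + +## 13. 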
Implementation phases + +### Phase 1: OTel data model adaptation + +- Extend the TraceEvent fields. +- Add `span_id`, `parent_span_id`, `span_kind`, `attributes`. +- Add deterministic trace id generation. +- Support user-defined attributes. +- Use `AgentRolloutItem`, `StageRecord`, `EntryRecord`, `RolloutError` as the first-version domain mapping, rather than starting a separate data model disconnected from the sandbox agent loop. +- Stay compatible with the old sandbox `run_total/acquire/infer/validate/entry:{name}` semantics. +- Keep the current JSONL and viewer usable. +- Do not connect an external backend. + +### Phase 2: OpenTelemetry backend + +- `trace_function` creates OTel spans. +- Support the OTLP exporter. +- Support BatchSpanProcessor. +- Support service name, endpoint, sample rate. +- Support dual mode: JSONL + OTel. +- Support W3C trace context inject / extract. + +### Phase 3: Jaeger end-to-end validation + +- Run a minimal real agent-in-sandbox training. +- Verify that Jaeger shows the sandbox span tree of a single sample. +- Verify the nesting of `run_total -> acquire -> infer -> entry -> validate -> judger -> materialize_trajectory`. +- Verify that task uid / stage / status / error msg appear in span attributes. +- Verify cross-process parent-child relationships. +- Verify that a sandbox entry failure shows error stage, error category, and diagnostic artifact path in Jaeger. +- Verify that training is unaffected when the trace backend is unavailable. + +### Phase 4: OpenInference semantic enhancement (optional) + +- Label the main stages with span kinds. +- Standardize the XTuner attribute namespace. +- Do a minimal field mapping for LLM / TOOL / EVALUATOR spans. + +### Phase 5: Distributed tracing validation + +- The Ray actor path has correct parent-child relationships. +- The rollout controller -> rollout worker path is correct. +- The XTuner -> sandbox agent daemon / lagent new-process path is correct. +- HTTP requests carry `traceparent`. +- A child process failure or timeout can write back error / artifact. + +### Phase 6: Call stacks for slow / failed tasks + +- Failed tasks save a traceback artifact. +- Slow spans save a stack artifact. +- The viewer shows stack artifact links. +- The active span watchdog is added later. + +### Phase 7: Jaeger viewer + task summary adapter + +- Use Jaeger as the main viewer. +- Store the external trace id / trace URL in the task summary. +- If the Jaeger dashboard can display the task summary directly, do not build an XTuner in-house HTML viewer. +- If not, XTuner implements only a minimal task summary adapter. + +## 14. 
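Risks + +| Risk | Description | Mitigation | +| --- | --- | --- | +| Deployment complexity | Jaeger has to run as an external backend | Use Jaeger all-in-one during development; keep the JSONL fallback and disabled mode | +| Performance impact | Multiple spans per task can add overhead at scale | BatchSpanProcessor, sampling, lightweight attributes | +| Data bloat | prompt/response/tool result can be very large | Do not record payload by default | +| Broken cross-process traces | Ray actor / HTTP requests may lose context | Explicit trace_context and traceparent | +| Broken link to the independent lagent process | External repos do not necessarily adopt the XTuner trace API | Use W3C trace context; record a wrapper span when lagent cannot be changed | +| Deterministic trace id collisions | An insufficient stable identity design makes different rollouts share a trace id | The stable key must include task uid and rollout identity | +| Oversized custom attributes | Users pass long text or non-serializable objects | attribute sanitization, length limits, moving large objects to artifacts | +| Call stack saving overhead | Capturing stacks for everything slows training | Trigger only for failed/slow tasks; implement the watchdog in phases | +| Viewer capability gaps | Open-source viewers may not directly provide an XTuner task status panel | Fill the gap with the task state aggregator / dashboard adapter | +| Backend lock-in | Binding directly to a single backend limits users | backend abstraction + exporter | + +## 15. Current recommendation + +If the goal is a fully open-source trace stack, the first version is recommended to use: + +```text +OpenTelemetry + OTLP + Jaeger +``` + +At the same time, keep XTuner's lightweight task state aggregator; the viewer prefers Jaeger. + +The following must also be satisfied: + +- Use W3C Trace Context to link XTuner, Ray actors, HTTP, and the independent lagent process. +- Use deterministic trace ids so that the same data sample is reproducible across experiments. +- Support user-defined attributes, but restrict large objects by default. +- Automatically save call stack artifacts for failed and slow tasks. +- First-version acceptance takes `AgentInSandboxLoop` as the main path and must clearly show `acquire`, `infer`, `entry`, `validate`, `judger`, `materialize_trajectory`. +- Business code keeps using the XTuner trace API and integrates OTel with as little intrusion as possible. + +If AI trace semantic display matters more later, the second choice is: + +```text +OpenTelemetry + OpenInference + Phoenix +``` + +If training experiment management matters more later, the third choice is: + +```text +OpenTelemetry + MLflow Tracing +``` + +Weave is not recommended as the default open-source implementation, but can serve as an optional exporter. + +## 16. 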
References + +- OpenTelemetry Python exporters: https://opentelemetry.io/docs/languages/python/exporters/ +- OpenTelemetry traces: https://opentelemetry.io/docs/concepts/signals/traces/ +- Jaeger getting started: https://www.jaegertracing.io/docs/latest/getting-started/ +- Jaeger deployment: https://www.jaegertracing.io/docs/latest/deployment/ +- OpenInference specification: https://arize-ai.github.io/openinference/spec/ +- Phoenix tracing: https://arize.com/docs/phoenix/tracing +- MLflow tracing: https://mlflow.org/docs/latest/genai/tracing/ +- Weave tracing: https://docs.wandb.ai/weave/guides/tracking/tracing/ +- verl rollout trace: https://github.com/verl-project/verl/blob/main/docs/advance/rollout_trace.rst diff --git a/docs/superpowers/specs/2026-06-12-current-trace-capabilities.md b/docs/superpowers/specs/2026-06-12-current-trace-capabilities.md new file mode 100644 index 0000000000..1760b77897 --- /dev/null +++ b/docs/superpowers/specs/2026-06-12-current-trace-capabilities.md @@ -0,0 +1,559 @@ +# Current Trace Capabilities + +This document records the features already included in XTuner RL trace on the current branch. It focuses on what can be observed now, through which interfaces, where the data is sent, and which default instrumentation points exist, and does not repeat the historical design of the early JSONL approach. + +## 1. Overall positioning + +The core goal of the current trace is: centered on a single rollout sample / task, observe its execution path, current stage, stage durations, error causes, and exception stacks within the XTuner RL system. + +The current implementation uses OpenTelemetry as the trace data format and export protocol: + +- XTuner business code mainly uses the two lightweight wrapper interfaces `xtuner_trace_function` / `xtuner_trace_span`. +- Trace data is sent to an OTLP backend through the OTel exporter. +- The currently recommended viewer/backend is Jaeger. +- Each task is linked by a stable `trace_id` across the producer, agent loop, rollout controller, rollout worker, session server, sandbox/localhost agent, and other stages. + +## 2. 
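TraceConfig + +The entry configuration is `xtuner.v1.rl.trace.TraceConfig`. + +```python +TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_exporter="otlp", + otel_service_name="xtuner-rl", + jaeger_query_url="http://127.0.0.1:16686", +) +``` + +Field meanings: + +- `enabled`: whether trace is on. When off, the trace API is a noop and no data is exported. +- `otel_endpoint`: the OTLP endpoint that OTel traces are exported to. +- `otel_protocol`: the OTLP transport protocol; `grpc` and `http/protobuf` are currently supported. +- `otel_exporter`: the OTel exporter type; `otlp` and `console` are currently supported. +- `otel_service_name`: the `service.name` written to the OTel Resource, shown as the service name in Jaeger. +- `jaeger_query_url`: the Jaeger Query/UI address. Rank0 prints this address when training starts; the internal viewer can also use it to fetch Jaeger traces by task. + +Trainer initialization calls `configure_trace(trace_config)`. If the configuration has `jaeger_query_url`, rank0 prints: + +```text +Jaeger Trace Viewer: http://127.0.0.1:16686 +XTuner Task Trace Dashboard: http://127.0.0.1: +``` + +The two addresses serve different purposes: + +- `Jaeger Trace Viewer`: the native Jaeger UI, suited to viewing a single task's complete span tree, parent-child relationships, span attributes, and error events. +- `XTuner Task Trace Dashboard`: the XTuner aggregated dashboard. Its backend still queries Jaeger, but the page is task-centric and shows an overview, current stage distribution, stage duration statistics, and a task list. + +`XTuner Task Trace Dashboard` is started automatically by the trainer on rank0 by default; users do not need to start an extra script manually. It queries the Jaeger Query API by `otel_service_name` and preferentially uses the current `run.id` to isolate multiple experiments under the same service. + +## 3. 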
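Environment variables and distributed propagation + +Whether the main process enables trace is determined only by `TraceConfig.enabled`. When `enabled=True`, `configure_trace()` converts `TraceConfig` into standard OpenTelemetry environment variables for Ray actors, worker processes, and agent child processes to read. + +Standard OTel environment variables: + +- `OTEL_TRACES_EXPORTER`: exporter type, for example `otlp`, `console`, `none`. +- `OTEL_EXPORTER_OTLP_PROTOCOL`: OTLP transport protocol, for example `grpc`, `http/protobuf`. +- `OTEL_EXPORTER_OTLP_TRACES_ENDPOINT`: OTLP traces endpoint. +- `OTEL_SERVICE_NAME`: OTel service name, shown as the service in Jaeger. + +XTuner additionally keeps one optional environment variable: + +- `XTUNER_OTEL_RUN_ID`: the experiment run id. It is neither an enable switch nor exporter configuration; it is only written to span attributes so that queries can be aggregated per training experiment. + +On the Ray actor side, `merge_trace_runtime_env(actor_options)` merges the environment variables above into `runtime_env.env_vars`. Each process holds its own local OTel runtime, but because `trace_id`, OTel context, and the standard OTel configuration are consistent, the exported spans can be aggregated by the backend under the same task trace. + +When `TraceConfig.enabled=False`, XTuner does not propagate OTel environment variables. Whether a child process enables trace depends only on whether the standard OTel environment variables are present; if `OTEL_TRACES_EXPORTER` is `none`, `false`, `off`, or `0`, trace is explicitly off. + +## 4. Task Identity + +Each rollout sample currently has a stable business trace identity. + +`RolloutState` gains a new field: + +```python +trace_id: str | None = None +``` + +The sampler fills in `state.trace_id` during the sample stage: + +- It preferentially generates a stable hash from `task_name`, `data_source`, `message_uid`, `repeat_index`. +- If there is no `message_uid`, it generates a stable hash from `task_name`, `data_source`, `message`, `repeat_index`. +- The generated format looks like: `gsm8k:<16-character sha1 prefix>`. + +Priority when building a trace event: + +1. If the target has a non-empty `trace_id`, use it directly. +2. Otherwise fall back to `task_name:uid`. +3. If there is neither a `trace_id` nor a resolvable `uid`, the trace event is not recorded. + +The exported OTel span attributes include: + +- `xtuner.trace_id` +- `case.id` +- `xtuner.task_name` +- `xtuner.uid` +- `xtuner.session_uid` +- `xtuner.status` +- `xtuner.train_step` +- `xtuner.model_step` +- `xtuner.producer_future_step` +- `xtuner.produce_batch_id` +- `xtuner.worker_rank` + +Here `case.id` exists for compatibility with some common trace viewer / query conventions; its value is the same as `xtuner.trace_id`. + +## 5. 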
Trace Event Schema + +Internal task-level events are expressed with `TraceEvent`. Main fields: + +- `trace_id`: task trace identity. +- `stage`: stage name, for example `xtuner.rollout_worker.generate.start`. +- `timestamp_s`: event timestamp, in seconds. +- `status`: current task status. +- `task_name`: task name. +- `uid`: task uid. +- `session_uid`: group/session-level id. +- `train_step`: current training step. +- `model_step`: current model step. +- `producer_future_step`: producer future step. +- `produce_batch_id`: producer batch identity. +- `worker_rank`: rollout worker rank. +- `elapsed_s`: span duration, mainly present on `.end` / `.error` events. +- `error_msg`: error cause. +- `error_type`: Python exception type, for example `ValueError`. +- `error_stacktrace`: Python exception stack trace. +- `attributes`: custom attributes passed in by the user or the instrumentation point. + +When exporting to OTel: + +- A `.start` event starts an OTel span. +- A `.start` event also exports an extra, very short lifecycle marker span whose attributes include `xtuner.lifecycle_marker=True` and `xtuner.stage_event=start`. This lets the Jaeger-backed online dashboard see which stage a task has already entered before a long span ends. +- An `.end` event ends the corresponding OTel span and marks status OK. +- An `.error` event ends the corresponding OTel span and marks status ERROR. +- If the `.error` comes from a real exception, a standard OTel `exception` event is written, containing `exception.type`, `exception.message`, `exception.stacktrace`. +- If the `.error` comes from a business failure via `mark_error()`, only the error cause is written and no exception stack is fabricated. + +Stage duration statistics still come from the real duration spans; lifecycle markers are only used to determine running/pending/current stage online and do not participate in avg/p95/max statistics. + +## 6. Instrumentation API + +### 6.1 xtuner_trace_function + +`xtuner_trace_function` decorates a function and records the function call as a task-level span. + +Example: + +```python +@xtuner_trace_function("xtuner.rollout_worker.generate") +async def generate(self, rollout_state: RolloutState) -> RolloutState: + ... 
+``` + +Default behavior: + +- By default the target is resolved from the parameter named `rollout_state`. +- If the function returns a `RolloutState` or `list[RolloutState]`, the `.end` event uses the return value as the target, so it can record the latest `status` after return. +- If the parameter name is non-standard, pass something like `target="group"` explicitly. +- If the function raises, `.error` is recorded with `error_msg`, `error_type`, `error_stacktrace`, and the exception is then re-raised. + +Typical use cases: + +- producer methods. +- agent loop methods. +- rollout controller / rollout worker methods. +- partial rollout preprocess/postprocess. + +### 6.2 xtuner_trace_span + +`xtuner_trace_span` is used to hand-write a trace span around an async block and replaces the old sandbox interface `with span(uid_obs, "validate", task_id=tid)`. + +Example: + +```python +async with xtuner_trace_span(item, "xtuner.sandbox.validate", task_id=item.id) as span: + score, failed = await self.validate.run(item, pool) + if failed: + span.mark_error("validate_failed") +``` + +Capabilities: + +- Automatically records `.start` / `.end`. +- When the block raises, automatically records `.error`, including exception type and stack. +- `span.annotate(**fields)` appends information discovered at runtime to the exit event, for example sandbox url, env id, agent name. +- `span.mark_error(message)` records a business failure as `.error` without a Python exception. + +### 6.3 trace_event + +`trace_event(target, name, **kwargs)` is the lower-level event recording interface. It currently serves mainly as the foundation for `xtuner_trace_function` and `xtuner_trace_span`; business code should not reach for it first. + +### 6.4 otel_trace_span / begin_otel_span / end_otel_span + +These interfaces are for low-level OTel spans and do not depend on `RolloutState`: + +- `otel_trace_span(name, **attrs)`: a synchronous context manager that creates a plain span under the current OTel context; it does not resolve `trace_id`, task status, or XTuner task metadata. +- `begin_otel_span(name, **attrs)`: manually starts a span. +- `end_otel_span(span, exc=None, **attrs)`: manually ends a span. + +They are currently used mainly for HTTP proxy / stream reading logic such as the session server. + +### 6.5 trace_task_context + +`trace_task_context(attrs)` attaches the current task's trace context and baggage to the current context, so that subsequent child calls, HTTP forwarding, and agent child processes can join the same task trace. + +Typical locations: + +- `AgentInSandboxLoop._run_item` +- `AgentInLocalhostLoop._run_item` + +## 7. 
Default Instrumentation Stages + +### 7.1 Producer + +The producer side currently includes by default: + +- `xtuner.producer.sample_group` +- `xtuner.producer.generate_group` +- `xtuner.producer.put_generated_group` + +These stages carry: + +- `task_name` +- `train_step` +- `model_step` +- `producer_future_step` +- `produce_batch_id` + +`generate_group` writes the trace step information above into each `RolloutState.extra_fields`, so that later cross-process stages keep carrying the same batch semantics. + +### 7.2 Agent Loop + +The generic agent loop includes by default: + +- `xtuner.agent_loop.generate_group` +- `xtuner.agent_loop.generate_sample` +- `xtuner.judger.judge` + +`xtuner.judger.judge` covers both single-sample judge and batch judge. + +### 7.3 Rollout Controller / Worker / Engine + +The rollout path includes by default: + +- `xtuner.rollout_controller.generate` +- `xtuner.rollout_worker.generate` +- `xtuner.rollout_engine.generate` + +`xtuner.rollout_engine.generate` wraps the stage that actually sends the HTTP request to the inference engine, so that inference engine call latency can be observed directly. + +Partial rollout stages: + +- `xtuner.partial_rollout_handler.preprocess` +- `xtuner.partial_rollout_handler.postprocess` + +### 7.4 Sandbox Agent Loop + +The sandbox agent loop is the main coverage target of the current trace design. It includes by default: + +- `xtuner.agent_in_sandbox.generate_group` +- `xtuner.agent_in_sandbox.generate_sample` +- `xtuner.sandbox.run_total` +- `xtuner.sandbox.acquire` +- `xtuner.sandbox.infer` +- `xtuner.sandbox.validate` +- `xtuner.sandbox.entry:` +- `xtuner.agent_in_sandbox.materialize_trajectory` + +Where: + +- `agent_in_sandbox.generate_group` / `generate_sample` cover the high-level entry points of the sandbox agent loop. +- `run_total` covers the total duration of a single sandbox task from start to finish. +- `acquire` covers sandbox environment acquisition and records `sandbox_name`, `sandbox_env_id`, `sandbox_url`, and `sandbox_image` via `annotate()`. +- `infer` covers agent inference/execution, with `xtuner.stage.kind=agent_run`. +- `validate` covers validation/evaluation, with `xtuner.stage.kind=judge`. If validate returns a business failure without raising a Python exception, it is still recorded as an error span via `mark_error()`. +- `entry:` covers execution of a specific entry script inside the sandbox, distinguishes `ShellEntry` from `DetachedShellEntry`, and writes `entry_name`, `entry_kind`, and `xtuner.stage.kind=entry`. A failing entry return code is marked as an error span. +- `materialize_trajectory` covers the stage that converts agent results into training `input_ids` / `labels`; it records only lightweight attributes such as message count, whether tools are present, and token count, and does not record large objects such as prompt / response / tool result. + +This covers what the old sandbox trace provided through calls like: + +```python +with span(uid_obs, "validate", task_id=tid): + ... +``` + +and additionally adds OTel context, exception stack traces, and Jaeger visualization. + +The sandbox shell / detached entry injects the current trace context into the subprocess environment variables: + +- `OTEL_PROPAGATOR_TRACEPARENT` +- `OTEL_PROPAGATOR_TRACESTATE` +- `OTEL_PROPAGATOR_BAGGAGE` + +It also passes the standard OTel exporter environment variables to the subprocess, for example `OTEL_TRACES_EXPORTER`, `OTEL_EXPORTER_OTLP_TRACES_ENDPOINT`, and `OTEL_SERVICE_NAME`. lagent or any other standalone process then only needs to restore the carrier from these environment variables to attach its own child spans to the same task trace. + +### 7.5 Localhost Agent Loop + +The localhost agent loop includes by default: + +- `xtuner.agent_in_localhost.generate_group` +- `xtuner.agent_in_localhost.generate_sample` +- `xtuner.localhost.run_total` +- `xtuner.localhost.infer` +- `xtuner.localhost.validate` +- `xtuner.localhost.judger` +- `xtuner.localhost.agent` +- `xtuner.agent_in_localhost.materialize_trajectory` + +`xtuner.localhost.agent` records information about the agent actually selected via `annotate()`. + +`xtuner.agent_in_localhost.materialize_trajectory` covers the stage that converts localhost agent results into a training trajectory; it records lightweight attributes and no large text. + +### 7.6 Session Server + +The session server side includes low-level HTTP / stream spans by default: + +- `xtuner.session_server.on_request` +- `xtuner.session_server.forward_worker` +- `xtuner.session_server.stream_read` +- `xtuner.session_server.read_response` +- `xtuner.session_server.on_response` + +`xtuner.session_server.forward_worker` and `xtuner.session_server.stream_read` write `xtuner.stage.kind=llm_call`, which marks them in the task timeline as sample-level LLM calls. + +Recorded information includes: + +- target URL +- stream / non-stream +- request bytes +- timeout +- input tokens +- max tokens +- model +- HTTP method/path +- worker base URL +- trace context source +- HTTP status +- response bytes +- output tokens +- prompt/completion/total tokens +- stream chunk count +- first chunk latency +- first output token latency +- first content latency +- finish reason + +`first_output_token_ms` is used to tell whether the inference engine is slow to return its first token. + +## 8.
Error Recording + +The trace currently records three kinds of error information. + +### 8.1 Python exceptions + +If code wrapped by `xtuner_trace_function` / `xtuner_trace_span` / `otel_trace_span` raises an exception, the following is recorded: + +- Error status: OTel span status = ERROR. +- Error reason: `error_msg`. +- Exception type: `error_type` / `exception.type`. +- Exception stack trace: `error_stacktrace` / `exception.stacktrace`. + +### 8.2 Business failures + +If the function does not raise but the business result is a failure, call: + +```python +span.mark_error("return_code=1: failed") +``` + +This changes the exit event from `.end` to `.error` and records `error_msg`. No `error_stacktrace` is recorded in this case, because there is no real Python exception. + +### 8.3 RolloutState error_msg + +If a module writes an error into `rollout_state.error_msg`, the `.end` event emitted after the call returns records the latest `status`. The task detail / viewer side can then show the task's failed status and error reason. + +## 9. Viewer / Backend + +### 9.1 Jaeger + +Jaeger is the currently recommended viewer/backend. + +Usage: + +- Start a Jaeger instance that supports OTLP. +- Point `TraceConfig.otel_endpoint` at the Jaeger OTLP endpoint. +- Point `TraceConfig.jaeger_query_url` at the Jaeger UI/Query address. +- In Jaeger, select the service that matches `otel_service_name`. +- Query a single task by tag: + +```text +xtuner.trace_id= +``` + +You can also query by: + +```text +case.id= +``` + +### 9.2 XTuner Task Trace Dashboard + +`xtuner/tools/jaeger_trace_dashboard.py` provides a Jaeger-backed aggregate dashboard. It is not a new trace backend and stores no extra data; it only pulls OTel spans through the Jaeger Query API and rebuilds task state keyed by `xtuner.trace_id`. + +The page contains: + +- overview: total / pending / running / completed / failed task counts. +- stage summary: for each stage, the number of tasks currently in it, completed span count, failed span count, average duration, p95, and max. +- task list: filterable by status, stage, and trace_id / uid / task_name. +- task detail: clicking a task shows a text timeline and an embedded Jaeger timeline. + +At training startup, if `TraceConfig.enabled=True` and `jaeger_query_url` is configured, rank0 starts this dashboard automatically and prints: + +```text +XTuner Task Trace Dashboard: http://127.0.0.1: +``` + +It can also be started manually: + +```bash +python -m xtuner.tools.jaeger_trace_dashboard \ + --jaeger-query-url http://127.0.0.1:16686 \ + --service xtuner-rl \ + --run-id +``` + +To generate a static HTML snapshot: + +```bash +python -m xtuner.tools.jaeger_trace_dashboard \ + --jaeger-query-url http://127.0.0.1:16686 \ + --service xtuner-rl \ + --run-id \ + -o
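/tmp/xtuner_task_trace_dashboard.html +``` + +### 9.3 Task trace analysis / view helpers + +The old JSONL producer viewer / hotspot CLI is no longer kept as part of the current implementation. + +What remains are neutrally named reusable modules: + +- `xtuner/tools/task_trace_analysis.py`: builds the overview, stage summary, task list, task detail, and timeline payload from `TraceEvent`. +- `xtuner/tools/task_trace_view.py`: renders the unified dashboard HTML and provides normalize/fetch helpers for Jaeger single-trace query results. + +The core task trace has migrated to OTel/Jaeger; JSONL is no longer a trace data backend. + +## 10. Currently Supported Backend Types + +Any backend that supports OTel / OTLP can in principle be used. The code currently supports directly: + +- OTLP gRPC. +- OTLP HTTP/protobuf. +- console exporter, for local debugging. + +Currently recommended combination: + +- backend/viewer: Jaeger. +- exporter: `otlp`. +- protocol: `http/protobuf` or `grpc`, depending on the endpoint Jaeger exposes. + +## 11. Known Limitations + +Points the current implementation does not cover or has not yet settled: + +- The focus is on per-sample/task tracing; system metrics of the inference system itself, such as GPU metrics, KV cache metrics, and queue metrics, are not covered. +- The old JSONL viewer/tool has been removed; the aggregate dashboard now reads directly from the Jaeger Query API. +- `trace_id` and `uid` both exist for now. `trace_id` is kept explicitly, and merging it with `uid` will be considered once the semantics stabilize. +- The session server already records request, response, and stream latency. The trace only commits to the session server's view and does not break down the inference engine's internal queue / prefill / decode. + +## 12.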
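Follow-up Development Items + +### 12.1 Automatically saving call stacks for slow tasks + +Call stack recording on the exception path is already supported: + +- Python exceptions record `error_stacktrace` via `xtuner_trace_span` / `xtuner_trace_function`. +- After export to OTel, the standard `exception.stacktrace` is written. + +However, tasks that run for a long time without raising do not yet save a call stack automatically. A slow-task diagnostic capability needs to be added: + +- Support a configurable slow-task threshold, for example a span that has lasted more than N seconds without ending. +- When triggered, save the Python call stack of the process that owns the task. +- The call stack must be associated with the corresponding task / span as an OTel span event or attribute. +- Sampling and size must be controlled, to avoid producing too much trace data when many tasks are stuck. +- The viewer must show the slow-task stacktrace in task detail, together with the trigger time and the owning process. + +This feature differs from the error stack trace: the error stack trace answers "where was the exception raised on failure", while the slow-task stack answers "what is being executed right now while stuck". + +### 12.2 Time to first token from the session server's view + +The inference engine code is not modified; streaming inference call latency is observed only from the session server's view. This capability answers: + +> After the session server has sent the request to the upstream worker, how long until it sees the response headers, the first stream chunk, the first batch of output tokens, the first content, and the end of the complete stream. + +Related spans: + +- `xtuner.session_server.forward_worker` +- `xtuner.session_server.stream_read` + +Key fields currently recorded: + +- `upstream_headers_ms`: time from the session server issuing the upstream request to receiving the response headers. +- `first_chunk_ms`: time from starting to read the stream body to receiving the first chunk. +- `first_output_token_ms`: time from starting to read the stream body to receiving the first batch of `output_ids`. +- `first_content_ms`: time from starting to read the stream body to receiving the first `delta.content`. +- `first_chunk_from_forward_ms`: time from issuing the upstream request to receiving the first chunk. +- `first_output_token_from_forward_ms`: time from issuing the upstream request to receiving the first batch of `output_ids`. +- `first_content_from_forward_ms`: time from issuing the upstream request to receiving the first `delta.content`. +- `stream_read_ms`: time from starting to read the stream body to the end of the stream. +- `stream_complete_from_forward_ms`: time from issuing the upstream request to the end of the complete stream. + +These fields are written as OTel span attributes and shown in the task detail text timeline of the XTuner Task Trace Dashboard. Explicitly out of scope for now: + +- Breaking down the engine queue. +- Breaking down prefill / decode internals. +- OTel integration inside the inference engine. +- Modifying lmdeploy / vLLM / sglang engine code. + +## 13.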
Minimal Usage Example + +Configuration: + +```python +from xtuner.v1.rl.trace import TraceConfig + +trace_config = TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_service_name="xtuner-rl", + jaeger_query_url="http://127.0.0.1:16686", +) +``` + +Business instrumentation: + +```python +from xtuner.v1.rl.trace import xtuner_trace_function, xtuner_trace_span + + +@xtuner_trace_function("xtuner.example.generate") +async def generate(rollout_state): + async with xtuner_trace_span(rollout_state, "xtuner.example.inner", custom_field="value") as span: + result = await do_work() + if not result.ok: + span.mark_error(result.error) + return rollout_state +``` + +Jaeger query: + +```text +service = xtuner-rl +tag = xtuner.trace_id= +``` diff --git a/docs/superpowers/specs/2026-06-16-xtuner-rl-task-trace-final.md b/docs/superpowers/specs/2026-06-16-xtuner-rl-task-trace-final.md new file mode 100644 index 0000000000..9e73801b45 --- /dev/null +++ b/docs/superpowers/specs/2026-06-16-xtuner-rl-task-trace-final.md @@ -0,0 +1,797 @@ +# XTuner RL Task Trace Final Development Document + +Date: 2026-06-16 + +This document is the final wrap-up for this round of XTuner RL task-level trace development. It merges and consolidates the earlier design, the implementation plan, the viewer proposal, the OTel/Jaeger migration proposal, and the sandbox follow-up proposal, and describes only the implementation state, supported capabilities, and unfinished development items that should now be treated as authoritative. + +Historical documents are kept for tracing the discussion, but where a historical document conflicts with this one, this document takes precedence. + +## 1.
Document Sources + +The following kinds of documents were produced during this round of development: + +- Early producer trace design: + - `docs/superpowers/specs/2026-06-05-producer-task-trace-design.md` + - `docs/superpowers/specs/2026-06-05-producer-task-trace-pseudocode.md` +- Unified online / offline viewer design: + - `docs/superpowers/plans/2026-06-09-unified-trace-viewer.md` + - `docs/superpowers/specs/2026-06-09-trace-next-phase-working-notes.md` +- OTel / open-source trace backend design: + - `docs/superpowers/specs/2026-06-11-xtuner-rl-observability-open-source-trace-design.md` + - `docs/superpowers/plans/2026-06-11-otel-trace-parity.md` +- Current capability record and follow-up plan: + - `docs/superpowers/specs/2026-06-12-current-trace-capabilities.md` + - `docs/superpowers/plans/2026-06-16-otel-task-trace-followup.md` +- Jaeger / OTLP usage notes: + - `recipe/otle/README.md` + - `recipe/otle/jaeger/jaeger-memory.yaml` + +The early JSONL approach has been replaced by the later OTel / Jaeger approach. Designs in the early documents such as the JSONL writer, the local trace store on the training path, and the pure-JSONL online viewer are no longer the main line. + +## 2. Current Goals + +The goal of the trace feature is task-level observability for the XTuner RL system. + +The core questions are: + +- Where each rollout sample / task currently is in its execution. +- When training is stuck, which stages the tasks are distributed across. +- When a task fails, which stage failed, the error reason, and the exception stack trace. +- When a stage is slow, hotspot statistics such as avg / p95 / max. +- That a task crossing producer, agent loop, rollout worker, session server, sandbox / localhost agent loop, and subprocesses can still be stitched into a single trace. + +The current phase covers per-sample tracing only, not overall metrics of the inference system. In other words, the trace does not collect system metrics such as GPU utilization, engine queue depth, the prefill/decode breakdown, KV cache, QPS, or throughput. + +## 3.
Overall Architecture + +The current implementation uses OpenTelemetry as the trace data format, context propagation mechanism, and export protocol, with Jaeger as the default viewer/backend. + +Overall path: + +```text +XTuner trace APIs: + xtuner_trace_function / xtuner_trace_span / otel_trace_span + ↓ +xtuner.v1.rl.trace + ↓ +OpenTelemetry SDK + ↓ +OTLP exporter + ↓ +Jaeger + ↓ +XTuner Task Trace Dashboard + Jaeger Native Trace +``` + +Design principles: + +- Business code calls only the XTuner wrappers `xtuner_trace_function` / `xtuner_trace_span` and does not depend on the OTel API directly. +- The `span name` expresses a stable business stage, for example `xtuner.rollout_worker.generate`. +- Details go into attributes, for example `xtuner.trace_id`, `xtuner.stage.kind`, and `xtuner.worker_rank`. +- Large objects such as prompt, response, tool result, images, and tensors are not recorded by default. +- Standard OTLP backends are supported. The task dashboard query layer currently implements only the Jaeger Query API. +- The training path no longer maintains a home-grown JSONL backend. + +## 4. TraceConfig + +The configuration entry point is `xtuner.v1.rl.trace.TraceConfig`. + +Example: + +```python +from xtuner.v1.rl.trace import TraceConfig + +trace_config = TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_exporter="otlp", + otel_service_name="exp-1", + jaeger_query_url="http://127.0.0.1:16686", +) +``` + +Field meanings: + +- `enabled`: whether tracing is on. When off, the trace APIs are no-ops and export no data. +- `otel_endpoint`: the OTLP endpoint that OTel traces are exported to. +- `otel_protocol`: OTLP transport protocol; supports `grpc` and `http/protobuf`. +- `otel_exporter`: exporter type; supports `otlp` and `console`. +- `otel_service_name`: the `service.name` of the OTel Resource, shown in the Jaeger service dropdown. A short experiment name such as `exp-1` is recommended. +- `jaeger_query_url`: Jaeger Query/UI address. When set, rank0 starts the XTuner Task Trace Dashboard automatically and prints the addresses of the native Jaeger UI and the XTuner dashboard. + +The Trainer calls `configure_trace(trace_config)` during initialization. If `enabled=True` and `jaeger_query_url` is configured, rank0 prints: + +```text +Jaeger Trace Viewer: http://127.0.0.1:16686 +XTuner Task Trace Dashboard: http://127.0.0.1: +``` + +## 5.
Environment Variables + +Whether tracing is on is decided by `TraceConfig.enabled`. When on, XTuner converts the configuration into standard OpenTelemetry environment variables for Ray actors, worker processes, the session server, and sandbox subprocesses to read. + +Standard OTel environment variables that are kept: + +- `OTEL_TRACES_EXPORTER`: exporter type, for example `otlp`, `console`, `none`. +- `OTEL_EXPORTER_OTLP_PROTOCOL`: OTLP protocol, for example `grpc`, `http/protobuf`. +- `OTEL_EXPORTER_OTLP_TRACES_ENDPOINT`: OTLP traces endpoint. +- `OTEL_SERVICE_NAME`: OTel service name. + +Additionally kept by XTuner: + +- `XTUNER_OTEL_RUN_ID`: the run id of the current experiment, written as a span attribute so the dashboard can aggregate queries per training run. + +Historical environment variables that have been removed or are no longer used: + +- `XTUNER_TRACE_ENABLED` +- `XTUNER_TRACE_OTEL_ENDPOINT` +- `XTUNER_TRACE_OTEL_SERVICE_NAME` +- `XTUNER_TRACE_OTEL_PROTOCOL` +- `XTUNER_TRACE_OTEL_EXPORTER` +- `XTUNER_OTEL_ENABLED` +- `AGENT_OTEL_ENABLED` +- `XTUNER_OTEL_SERVICE_NAME` + +These variables were removed because the enable switch should come from `TraceConfig.enabled` and the OTel exporter configuration should use the standard OTel environment variables, which avoids XTuner-specific variables duplicating the OTel standard. + +## 6. Task Identity + +Each rollout sample has a stable trace identity. + +New field on `RolloutState`: + +```python +trace_id: str | None = None +``` + +The sampler writes `state.trace_id` during the sample stage: + +- Preferably, a stable hash is generated from `task_name`, `data_source`, `message_uid`, and `repeat_index`. +- When there is no `message_uid`, a stable hash is generated from `task_name`, `data_source`, `message`, and `repeat_index`. +- The format looks like `gsm8k:<16-character sha1 prefix>`. + +Trace id resolution priority: + +1. If the target has a non-empty `trace_id`, use it directly. +2. Otherwise fall back to `task_name:uid`. +3. If the task identity cannot be resolved, skip the event. + +Commonly exported attributes: + +- `xtuner.trace_id` +- `case.id` +- `xtuner.task_name` +- `xtuner.uid` +- `xtuner.session_uid` +- `xtuner.status` +- `xtuner.train_step` +- `xtuner.model_step` +- `xtuner.producer_future_step` +- `xtuner.produce_batch_id` +- `xtuner.worker_rank` +- `xtuner.stage.kind` + +`case.id` has the same value as `xtuner.trace_id` and exists for compatibility with common trace viewer / query conventions. + +## 7.
Instrumentation API + +### 7.1 xtuner_trace_function + +`xtuner_trace_function` decorates a function and records each call as a task-level span. + +Example: + +```python +from xtuner.v1.rl.trace import xtuner_trace_function + + +@xtuner_trace_function("xtuner.rollout_worker.generate") +async def generate(self, rollout_state): + ... +``` + +Default behavior: + +- The target is resolved from the parameter named `rollout_state` by default. +- The target can be a `RolloutState` or a `list[RolloutState]`. +- When the function returns a `RolloutState` or `list[RolloutState]`, the `.end` event uses the return value as the target, so it records the latest `status`, `error_msg`, and other fields after the call returns. +- When the parameter name is non-standard, pass `target="group"` explicitly. +- When the function raises, `.error` is recorded with `error_msg`, `error_type`, and `error_stacktrace`, and the exception is then re-raised. + +### 7.2 xtuner_trace_span + +`xtuner_trace_span` wraps a hand-written async block in a task-level span. + +It replaces the following pattern in the old sandbox code: + +```python +with span(uid_obs, "validate", task_id=tid): + ... +``` + +Example: + +```python +from xtuner.v1.rl.trace import xtuner_trace_span + + +async with xtuner_trace_span(item, "xtuner.sandbox.validate", task_id=item.id) as span: + score, failed = await self.validate.run(item, pool) + if failed: + span.mark_error("validate_failed") +``` + +Supported capabilities: + +- Records start / end automatically. +- Records error automatically when the block raises a Python exception, and saves the exception type and stack trace. +- `span.annotate(**fields)` appends fields discovered at runtime to the exit event, for example sandbox url, env id, or agent name. +- `span.mark_error(message)` marks a business failure as an error span when no Python exception was raised. + +### 7.3 trace_task_context + +`trace_task_context(attrs)` puts the current task's trace context and baggage into the current context, so that subsequent sub-calls, HTTP forwarding, and sandbox entry subprocesses join the same task trace. + +Typical call sites: + +- The per-task execution entry of the sandbox agent loop. +- The per-task execution entry of the localhost agent loop. +- The session server restoring trace context from the request body. + +### 7.4 otel_trace_span / begin_otel_span / end_otel_span + +These are low-level OTel span helpers that do not depend on `RolloutState`. `otel_trace_span` is a thin wrapper that creates a plain span under the current OTel context and does not resolve `trace_id`, task status, or XTuner task metadata. They are currently used mainly for the session server's HTTP request, upstream forward, and stream read. + +Business code should prefer `xtuner_trace_function` / `xtuner_trace_span`. Use `otel_trace_span` only at low-level system boundaries that have no clear `RolloutState` target but are already inside some trace context. `begin_otel_span` / `end_otel_span` are even lower-level helpers for managing the span lifecycle manually. + +## 8. Existing Default Instrumentation Points + +### 8.1 Producer / Sampler + +Default stages: + +- `xtuner.producer.sample_group` +- `xtuner.producer.generate_group` +- `xtuner.producer.put_generated_group` + +Related capabilities: + +- The sampler writes `trace_id` and `task_name` into each `RolloutState`. +- The producer writes `train_step`, `model_step`, `producer_future_step`, and `produce_batch_id` into the task trace metadata. +- `put_generated_group` supports `target="group"`, covering list targets with a non-standard parameter name. + +### 8.2 Generic Agent Loop + +Default stages: + +- `xtuner.agent_loop.generate_group` +- `xtuner.agent_loop.generate_sample` +- `xtuner.judger.judge` + +`xtuner_trace_function` uses the post-return state in `.end`, so the task's latest status after the judger is visible. + +### 8.3 Rollout Controller / Worker / Engine + +Default stages: + +- `xtuner.rollout_controller.generate` +- `xtuner.rollout_worker.generate` +- `xtuner.rollout_engine.generate` + +`xtuner.rollout_engine.generate` wraps the stage that actually sends the HTTP request to the inference engine and is used to observe how long a sample spends on the inference engine request. + +Partial rollout default stages: + +- `xtuner.partial_rollout_handler.preprocess` +- `xtuner.partial_rollout_handler.postprocess` + +### 8.4 Sandbox Agent Loop + +The sandbox agent loop is the current main coverage target. + +Default stages: + +- `xtuner.agent_in_sandbox.generate_group` +- `xtuner.agent_in_sandbox.generate_sample` +- `xtuner.sandbox.run_total` +- `xtuner.sandbox.acquire` +- `xtuner.sandbox.infer` +- `xtuner.sandbox.validate` +- `xtuner.sandbox.entry:` +- `xtuner.agent_in_sandbox.materialize_trajectory` + +Stage semantics: + +- `generate_group` / `generate_sample`: high-level entry points of the sandbox agent loop. +- `run_total`: total duration of a single sandbox task from start to finish. +- `acquire`: sandbox environment acquisition; records `sandbox_name`, `sandbox_env_id`, `sandbox_url`, and `sandbox_image`. +- `infer`: agent execution / inference, with `xtuner.stage.kind=agent_run`. +- `validate`: validation / evaluation, with `xtuner.stage.kind=judge`. If validate returns a business failure without raising, `mark_error()` is still called. +- `entry:`: execution of an entry script inside the sandbox; records `entry_name`, `entry_kind`, and `xtuner.stage.kind=entry`. An entry that returns a failure is marked as an error span. +- `materialize_trajectory`: the stage that converts agent results into training `input_ids` / `labels`; records only lightweight attributes such as message count, tool count, and token count. + +The sandbox entry subprocess receives two kinds of environment variables: + +- The standard OTel exporter environment variables. +- The propagator carrier of the current task: + - `OTEL_PROPAGATOR_TRACEPARENT` + - `OTEL_PROPAGATOR_TRACESTATE` + - `OTEL_PROPAGATOR_BAGGAGE` + +lagent or any other standalone process can then restore the context from the environment variables and keep creating child spans on the same trace. + +### 8.5 Localhost Agent Loop + +Default stages: + +- `xtuner.agent_in_localhost.generate_group` +- `xtuner.agent_in_localhost.generate_sample` +- `xtuner.localhost.run_total` +- `xtuner.localhost.infer` +- `xtuner.localhost.validate` +- `xtuner.localhost.judger` +- `xtuner.localhost.agent` +- `xtuner.agent_in_localhost.materialize_trajectory` + +`xtuner.localhost.agent` records information about the agent actually used. + +### 8.6 Session Server + +Default stages: + +- `xtuner.session_server.on_request` +- `xtuner.session_server.forward_worker` +- `xtuner.session_server.stream_read` +- `xtuner.session_server.read_response` +- `xtuner.session_server.on_response` + +`forward_worker` and `stream_read` write `xtuner.stage.kind=llm_call`. + +Key fields recorded by the session server: + +- target URL +- stream / non-stream +- request bytes +- timeout +- input tokens +- max tokens +- model +- HTTP method/path +- upstream worker base URL +- trace context source +- HTTP status +- response bytes +- output tokens +- prompt / completion / total tokens +- stream chunk count +- first chunk latency +- first output token latency +- first content latency +- finish reason + +Key latency fields: + +- `upstream_headers_ms` +- `first_chunk_ms` +- `first_output_token_ms` +- `first_content_ms` +- `first_chunk_from_forward_ms` +- `first_output_token_from_forward_ms` +- `first_content_from_forward_ms` +- `stream_read_ms` +- `stream_complete_from_forward_ms` + +These fields let you judge, from the session server's view, whether the inference service is slow to return its first token. + +## 9.
Error Recording + +Three kinds of error recording are currently supported. + +### 9.1 Python exceptions + +If code wrapped by `xtuner_trace_function` / `xtuner_trace_span` / `otel_trace_span` raises an exception, the following is recorded: + +- OTel span status = ERROR. +- `error_msg` +- `error_type` +- `error_stacktrace` +- The standard OTel exception event: + - `exception.type` + - `exception.message` + - `exception.stacktrace` + +### 9.2 Business failures + +If there is no Python exception but the business result is a failure, call: + +```python +span.mark_error("validate failed") +``` + +This marks the span as error and records `error_msg`. No exception stack trace is fabricated in this case. + +### 9.3 RolloutState error_msg + +If a module writes an error into `rollout_state.error_msg`, the `.end` event emitted after the call returns records the latest status and error msg. The viewer's task detail can show that error reason. + +## 10. Viewer / Backend + +### 10.1 Jaeger + +The currently recommended backend/viewer is Jaeger. + +Reference configuration: + +- `recipe/otle/jaeger/jaeger-memory.yaml` + +Default ports: + +- Jaeger UI: `http://127.0.0.1:16686` +- OTLP HTTP: `http://127.0.0.1:14318/v1/traces` +- OTLP gRPC: `127.0.0.1:14317` + +In Jaeger you can query by service and tag: + +```text +service = exp-1 +tag = xtuner.trace_id= +``` + +You can also use: + +```text +case.id= +``` + +### 10.2 XTuner Task Trace Dashboard + +`xtuner/tools/jaeger_trace_dashboard.py` provides a Jaeger-backed aggregate dashboard. + +It is not a new trace backend and stores no extra trace data; it only pulls OTel spans through the Jaeger Query API and rebuilds task state keyed by `xtuner.trace_id`. + +The page contains: + +- overview: + - total tasks + - pending tasks + - running tasks + - completed tasks + - failed tasks +- stage summary: + - number of tasks currently running in the stage + - number of tasks that have passed through the stage + - failed span count + - avg + - p95 + - max +- task list: + - status filter + - stage filter + - trace_id / uid / task_name search + - error msg display +- task detail: + - text timeline + - graphical timeline + - error msg + - embedded Jaeger Native Trace view + +At training startup, if `TraceConfig.enabled=True` and `jaeger_query_url` is non-empty, rank0 starts the dashboard automatically; the user does not need to start it manually. + +It can also be started manually: + +```bash +python -m xtuner.tools.jaeger_trace_dashboard \ + --jaeger-query-url http://127.0.0.1:16686 \ + --service exp-1 \ + --run-id +``` + +To generate a static HTML snapshot: + +```bash +python -m xtuner.tools.jaeger_trace_dashboard \ + --jaeger-query-url
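http://127.0.0.1:16686 \ + --service exp-1 \ + --run-id \ + -o /tmp/xtuner_task_trace_dashboard.html +``` + +### 10.3 Task trace analysis / view helpers + +The old JSONL producer viewer / hotspot CLI is no longer kept as part of the current implementation. What remains are neutrally named reusable modules: + +- `xtuner/tools/task_trace_analysis.py` +- `xtuner/tools/task_trace_view.py` + +Their current roles are: + +- `task_trace_analysis.py`: builds the overview, stage summary, task list, task detail, and timeline payload from `TraceEvent`. +- `task_trace_view.py`: renders the unified dashboard HTML and provides normalize/fetch helpers for Jaeger single-trace query results. +- `jaeger_trace_dashboard.py`: acts as the Jaeger Query API adapter, converting Jaeger spans into `TraceEvent` and then reusing the analysis and rendering logic above. + +JSONL is no longer the trace backend on the training path, and no new JSONL viewer CLI is provided. + +## 11. Problems Users Can Solve Today + +### 11.1 Seeing where tasks are stuck when training hangs + +The XTuner dashboard shows: + +- Current total / pending / running / completed / failed task counts. +- How many tasks are currently running in each stage. +- avg / p95 / max duration of each stage. +- The timeline and Jaeger span tree of a single task. + +For example, if out of 100 tasks: + +- 80 are in `xtuner.rollout_controller.generate` +- 20 are in `xtuner.judger.judge` + +the dashboard can further provide: + +- The running task count of these two stages. +- avg / p95 / max of historically completed spans in these two stages. +- For each specific task, its current last stage, whether it failed, and its error msg. +- The full timeline after clicking into a single task. + +This is more precise than looking only at the latest stage distribution, but you still need to combine the specific task detail with stage durations to determine which stage is really stuck. + +### 11.2 Telling whether the inference service is slow + +The session server spans record: + +- How long after the request is sent upstream the headers arrive. +- How long until the first stream chunk arrives. +- How long until the first batch of output tokens arrives. +- How long until the first content arrives. +- How long reading the complete stream takes. + +These metrics let you judge, from the session server's view, whether the inference service is slow to respond. + +### 11.3 Investigating failed tasks + +Failed tasks are shown as failed in the dashboard task list; clicking one shows: + +- The failed stage. +- `error_msg`. +- The Python exception type and stacktrace. +- The error span and exception event in Jaeger Native Trace. + +A validate business failure is marked as an error span even when no Python exception was raised. + +### 11.4 Stitching one task together across processes + +XTuner currently injects the OTel exporter env and the current task carrier at the Ray actor / sandbox entry / session server boundaries. + +Spans inside the XTuner repository can already be stitched together through the same `trace_id`. An external repo or standalone process only needs to read environment variables such as `OTEL_PROPAGATOR_TRACEPARENT` and create child spans with OTel to continue the same trace. + +## 12.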
Test and Validation Assets + +Unit tests added: + +- `tests/rl/test_trace.py` + +Coverage includes: + +- Default target behavior of `xtuner_trace_function`. +- `xtuner_trace_function` using the return value as the end target. +- `list[RolloutState]` targets. +- Non-standard target parameter names. +- `xtuner_trace_span` start / end / error / annotate / mark_error. +- trace_id generation and propagation. +- OTel attribute normalization. +- Jaeger dashboard payload reconstruction. +- sandbox / localhost runner trace attributes. +- session trace helper. + +Smoke configs added: + +- `examples/v1/config/testing/trace_smoke_common.py` +- `examples/v1/config/testing/rl_trace_smoke_enabled.py` +- `examples/v1/config/testing/rl_trace_smoke_disabled.py` +- `examples/v1/config/testing/rl_trace_smoke_judger_fail.py` + +These configs are used for real-training smoke runs: + +- trace enabled. +- trace disabled. +- judger failed / error path. + +## 13. Known Limitations + +### 13.1 The dashboard query layer is currently bound to Jaeger + +The spans themselves are exported through standard OTel / OTLP and can in principle be sent to any OTLP backend. + +However, the `XTuner Task Trace Dashboard` currently pulls data through the Jaeger Query API, so the dashboard query layer is bound to Jaeger for now. Switching to a backend such as Phoenix, MLflow, Tempo, or Zipkin requires adding a corresponding query adapter. + +### 13.2 Only the session server's view is observed; inference engine internals are not broken down + +The internal code of inference engines such as lmdeploy / vLLM / sglang is not modified. + +So the following is visible: + +- The session server issuing the upstream request. +- upstream headers latency. +- first token latency. +- stream read latency. + +The following is not visible: + +- engine queue. +- prefill. +- decode. +- batch merge. +- KV cache. +- GPU kernel / runtime metrics. + +### 13.3 Large objects are not recorded by default + +Span attributes do not record by default: + +- prompt. +- response. +- tool result. +- images. +- tensors. +- full message content. + +If these need to be recorded later, it should be an explicit opt-in that accounts for redaction, size limits, and storage cost. + +### 13.4 trace_id and uid still coexist + +`trace_id` is kept to provide a stable, reproducible task identity. `uid` still serves the existing business logic. + +Once the semantics stabilize, merging `trace_id` and `uid` or unifying their naming can be re-evaluated. + +## 14.
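Unfinished Development Items + +### 14.1 Automatically saving call stacks for slow tasks + +Exception stacktraces are already supported, but automatic stack sampling for a span that lasts a long time without raising is not. + +To be implemented: + +- A configurable slow-span threshold. +- Sampling the Python stack of the owning process when a span exceeds the threshold without ending. +- Support for asyncio task stacks. +- Saving a stack artifact or writing a span event. +- Showing the slow-task stack in the viewer's task detail. +- Controlling the sample count and output size, to avoid producing too much data when many tasks are stuck. + +This feature answers: + +> The task has not failed, but where is the code currently executing while it is stuck? + +### 14.2 Child spans inside external repos / the lagent process + +XTuner already injects the trace context into the environment variables of the sandbox entry subprocess. + +What remains is for lagent or other external repos to read these env vars internally and create their own child spans. + +A very thin OTel adapter needs to be added in the external repo: + +- Restore the context from `OTEL_PROPAGATOR_TRACEPARENT` / `OTEL_PROPAGATOR_TRACESTATE` / `OTEL_PROPAGATOR_BAGGAGE`. +- Create child spans for stages inside the external repo such as LLM call, tool call, and validate. +- No large-scale changes to the external repo are required. + +### 14.3 Finer-grained agent LLM / tool stage semantics + +Already available: + +- `llm_call` from the session server's view. +- `entry` from the sandbox entry's view. +- `judge` from the validate / judger view. + +But the sequence users originally expected is not yet fully expressed: + +```text +sampler -> create sandbox -> llm call 1 -> tool call 1 -> llm call 2 -> tool call 2 -> judger +``` + +The following still needs to be clarified against the actual sandbox / lagent execution model: + +- Which boundary represents one LLM call. +- Which boundary represents one tool call. +- How a tool call's name, status, duration, and error are recorded. +- Whether `xtuner.stage.kind=tool_call` should be standardized. +- Whether to introduce OpenInference semantic attributes. + +### 14.4 OpenInference semantic mapping + +The current implementation is based on OTel plus XTuner's custom `xtuner.*` attributes. + +OpenInference is not yet a default dependency, and there is no complete mapping. + +If Phoenix or another AI trace viewer is integrated later, the following can be added: + +- OpenInference attributes for LLM spans. +- OpenInference attributes for Tool spans. +- OpenInference attributes for Evaluator / judger spans. +- OpenInference attributes for Agent spans. + +The first version deliberately did not introduce OpenInference, to reduce dependencies and intrusiveness while keeping the current Jaeger route simple. + +### 14.5 Task dashboard adapters for non-Jaeger backends + +Spans can be exported to any OTLP backend, but the aggregate dashboard implements only the Jaeger Query API. + +Direct support for an equivalent dashboard on other backends would require: + +- A Phoenix query adapter. +- An MLflow query adapter. +- A Tempo query adapter. +- A Zipkin query adapter. + +Otherwise these backends can only be viewed through their native trace UI and cannot use the XTuner dashboard's task overview / stage summary directly. + +### 14.6 Production-grade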
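Jaeger storage and deployment + +What is currently provided is a memory backend reference configuration: + +- `recipe/otle/jaeger/jaeger-memory.yaml` + +It is suitable for local preview and smoke runs, not for long-running production training. + +For long jobs, the following needs to be added: + +- Persistent storage configuration. +- A trace retention policy. +- Query performance validation under a large number of spans. +- Pagination / incremental refresh optimization for dashboard queries. + +### 14.7 Trace data size control + +Large objects are already avoided, but there is no complete attribute truncation / sampling strategy yet. + +To be considered: + +- A length limit per attribute. +- A limit on the number of events per span. +- A maximum number of spans per trace. +- Whether high-frequency marker spans need sampling. +- Control over the dashboard query window and refresh frequency. + +## 15. Suggested Order for Follow-up Development + +The suggested order is: + +1. Slow task stack sampling. +2. OTel adapter for lagent / external processes. +3. Clearly define the stage boundaries of LLM call / tool call and add default instrumentation. +4. Optional mapping to OpenInference semantic attributes. +5. Non-Jaeger backend adapters. +6. Production-grade Jaeger storage and dashboard query performance optimization. + +## 16. Minimal Usage + +1. Start Jaeger, using `recipe/otle/jaeger/jaeger-memory.yaml` as the reference configuration. +2. Turn on trace in the training config: + +```python +from xtuner.v1.rl.trace import TraceConfig + +trace_config = TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_service_name="exp-1", + jaeger_query_url="http://127.0.0.1:16686", +) +``` + +3. After starting training, look in the rank0 log for: + +```text +Jaeger Trace Viewer: http://127.0.0.1:16686 +XTuner Task Trace Dashboard: http://127.0.0.1: +``` + +4. Use the XTuner dashboard to see overall task status and stage hotspots. +5. Click into a single task to see the text timeline, graphical timeline, and Jaeger Native Trace. +6.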
In the native Jaeger UI, query a single task by tag: + +```text +xtuner.trace_id= +``` diff --git a/examples/v1/config/agentic_rl_qwen3p5vl_mtp_ep_code.py b/examples/v1/config/agentic_rl_qwen3p5vl_mtp_ep_code.py index 820b2d0d4f..7884f61948 100644 --- a/examples/v1/config/agentic_rl_qwen3p5vl_mtp_ep_code.py +++ b/examples/v1/config/agentic_rl_qwen3p5vl_mtp_ep_code.py @@ -14,6 +14,7 @@ from xtuner.v1.rl.loss import GRPOLossConfig from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trace import TraceConfig from xtuner.v1.rl.trainer import RolloutImportanceSampling, WorkerConfig from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.rl.utils import AcceleratorResourcesConfig @@ -112,6 +113,13 @@ def _build_dataset_cfg(specs, default_recipe): debug_rollout = _env_bool("DEBUG_ROLLOUT") experimental_name = os.environ.get("EXPERIMENT_NAME", "localhost_agent_rl") +trace_config = TraceConfig( + enabled=_env_bool("TRACE_ENABLED", True), + otel_endpoint=os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://127.0.0.1:4317"), + otel_protocol=os.environ.get("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc"), + otel_service_name=os.environ.get("OTEL_SERVICE_NAME", "exp-1"), + jaeger_query_url=os.environ.get("JAEGER_QUERY_URL", "http://127.0.0.1:16686"), +) total_epochs = 10 total_train_steps = int(os.environ["TOTAL_TRAIN_STEPS"]) if os.environ.get("TOTAL_TRAIN_STEPS") else None train_batch_size = int(os.environ.get("TRAIN_BATCH_SIZE", 128)) @@ -316,6 +324,7 @@ def _build_dataset_cfg(specs, default_recipe): enable_evaluate=enable_evaluate, enable_initial_evaluate=enable_initial_evaluate, evaluate_step=evaluate_step, + trace_config=trace_config, work_dir=work_dir, auto_resume=_env_bool("AUTO_RESUME", True), load_checkpoint_cfg=LoadCheckpointConfig(load_optimizer_states=False, load_optimizer_args=False), diff --git a/examples/v1/config/rl_grpo_gsm8k_judge.py b/examples/v1/config/rl_grpo_gsm8k_judge.py index
c7c3255f53..9ca9252d37 100644 --- a/examples/v1/config/rl_grpo_gsm8k_judge.py +++ b/examples/v1/config/rl_grpo_gsm8k_judge.py @@ -22,6 +22,7 @@ from xtuner.v1.rl.agent_loop_manager import AgentLoopManagerConfig, SamplerConfig, SyncProduceStrategyConfig, TaskSpecConfig from xtuner.v1.rl.evaluator import EvaluatorConfig from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.trace import TraceConfig from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig # env @@ -44,6 +45,8 @@ max_prompt_length = 512 max_response_length = 1024 pack_max_length = 32 * 1024 +producer_trace_enabled = os.environ.get("PRODUCER_TRACE_ENABLED", "0") == "1" +producer_trace_config = TraceConfig(enabled=producer_trace_enabled) # 1. resources resources = AcceleratorResourcesConfig( @@ -200,6 +203,7 @@ enable_evaluate=True, enable_initial_evaluate=False, evaluate_step=evaluate_step, + trace_config=producer_trace_config, work_dir=work_dir, seed=123, debug_rollout=False, diff --git a/examples/v1/config/testing/rl_trace_smoke_disabled.py b/examples/v1/config/testing/rl_trace_smoke_disabled.py new file mode 100644 index 0000000000..9227d4256c --- /dev/null +++ b/examples/v1/config/testing/rl_trace_smoke_disabled.py @@ -0,0 +1,4 @@ +from examples.v1.config.testing.trace_smoke_common import build_trace_smoke_trainer + + +trainer = build_trace_smoke_trainer(trace_enabled=False, fail_judger=False) diff --git a/examples/v1/config/testing/rl_trace_smoke_enabled.py b/examples/v1/config/testing/rl_trace_smoke_enabled.py new file mode 100644 index 0000000000..cf61873ad4 --- /dev/null +++ b/examples/v1/config/testing/rl_trace_smoke_enabled.py @@ -0,0 +1,4 @@ +from examples.v1.config.testing.trace_smoke_common import build_trace_smoke_trainer + + +trainer = build_trace_smoke_trainer(trace_enabled=True, fail_judger=False) diff --git a/examples/v1/config/testing/rl_trace_smoke_judger_fail.py b/examples/v1/config/testing/rl_trace_smoke_judger_fail.py new file mode 100644 index 
0000000000..963ca12188 --- /dev/null +++ b/examples/v1/config/testing/rl_trace_smoke_judger_fail.py @@ -0,0 +1,4 @@ +from examples.v1.config.testing.trace_smoke_common import build_trace_smoke_trainer + + +trainer = build_trace_smoke_trainer(trace_enabled=True, fail_judger=True) diff --git a/examples/v1/config/testing/trace_smoke_common.py b/examples/v1/config/testing/trace_smoke_common.py new file mode 100644 index 0000000000..62a8c26d51 --- /dev/null +++ b/examples/v1/config/testing/trace_smoke_common.py @@ -0,0 +1,220 @@ +import asyncio +import os +from pathlib import Path +from typing import Literal, cast + +from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.rl.advantage import GRPOAdvantageConfig +from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig +from xtuner.v1.rl.agent_loop_manager import ( + AgentLoopManagerConfig, + AsyncProduceStrategyConfig, + SamplerConfig, + TaskSpecConfig, +) +from xtuner.v1.rl.evaluator import EvaluatorConfig +from xtuner.v1.rl.judger import GSM8KJudgerConfig, JudgerConfig +from xtuner.v1.rl.loss import GRPOLossConfig +from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trace import TraceConfig +from xtuner.v1.rl.trainer import WorkerConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, CPUResourcesConfig +from xtuner.v1.train.rl_trainer import RLColocateTrainerConfig + + +def _env_int(name: str, default: int) -> int: + return int(os.environ.get(name, str(default))) + + +def _env_float(name: str, default: float) -> float: + return float(os.environ.get(name, str(default))) + + +def _trace_config_from_env(*, enabled: bool) -> TraceConfig: + protocol = 
os.environ.get("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc") + if protocol not in {"grpc", "http/protobuf"}: + protocol = "grpc" + exporter = os.environ.get("OTEL_TRACES_EXPORTER", "otlp") + if exporter not in {"otlp", "console"}: + exporter = "otlp" + return TraceConfig( + enabled=enabled, + otel_endpoint=os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://127.0.0.1:4317"), + otel_protocol=cast(Literal["grpc", "http/protobuf"], protocol), + otel_exporter=cast(Literal["otlp", "console"], exporter), + otel_service_name=os.environ.get("OTEL_SERVICE_NAME", "exp-1"), + jaeger_query_url=os.environ.get("JAEGER_QUERY_URL", "http://127.0.0.1:16686"), + ) + + +async def _sleeping_trace_smoke_judger(*, response, label, extra_info): + await asyncio.sleep(float(extra_info.get("sleep_s", 0.0))) + return {"score": float(extra_info.get("score", 0.0))} + + +async def _raise_trace_smoke_judger_failure(*, response, label, extra_info): + await asyncio.sleep(float(extra_info.get("sleep_s", 0.0))) + raise RuntimeError("trace smoke judger failure") + + +def build_trace_smoke_trainer(*, trace_enabled: bool, fail_judger: bool = False) -> RLColocateTrainerConfig: + work_dir = os.environ["WORK_DIR"] + model_path = os.environ["MODEL_PATH"] + data_path = os.environ["DATA_PATH"] + + total_train_steps = _env_int("XTUNER_TRACE_SMOKE_TOTAL_TRAIN_STEPS", 2) + train_optimizer_steps = 1 + train_batch_size = _env_int("XTUNER_TRACE_SMOKE_TRAIN_BATCH_SIZE", 16) * train_optimizer_steps + prompt_repeat_k = _env_int("XTUNER_TRACE_SMOKE_PROMPT_REPEAT_K", 1) + num_workers = _env_int("XTUNER_TRACE_SMOKE_NUM_WORKERS", 1) + max_prompt_length = _env_int("XTUNER_TRACE_SMOKE_MAX_PROMPT_LENGTH", 256) + max_response_length = _env_int("XTUNER_TRACE_SMOKE_MAX_RESPONSE_LENGTH", 128) + pack_max_length = _env_int("XTUNER_TRACE_SMOKE_PACK_MAX_LENGTH", 4096) + over_sample_threshold = float(os.environ.get("XTUNER_TRACE_SMOKE_OVER_SAMPLE_THRESHOLD", "1")) + judger_sleep_s = _env_float("XTUNER_TRACE_SMOKE_JUDGER_SLEEP_S", 
0.0) + + trace_config = _trace_config_from_env(enabled=trace_enabled) + + resources = AcceleratorResourcesConfig( + accelerator="GPU", + num_workers=num_workers, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, + ) + + rollout_config = RolloutConfig( + env="trace_smoke_gsm8k", + device=resources.accelerator, + model_path=model_path, + dtype="bfloat16", + tensor_parallel_size=1, + expert_parallel_size=1, + gpu_memory_utilization=0.6, + context_length=max_prompt_length + max_response_length, + ) + + if fail_judger: + judger_config = JudgerConfig( + judger_name="trace_smoke/failing_judger", + reward_handler=_raise_trace_smoke_judger_failure, + extra_info={"sleep_s": judger_sleep_s}, + cpu_resources=CPUResourcesConfig(num_workers=1, num_cpus_per_worker=1), + ) + elif judger_sleep_s > 0: + judger_config = JudgerConfig( + judger_name="trace_smoke/slow_judger", + reward_handler=_sleeping_trace_smoke_judger, + extra_info={"sleep_s": judger_sleep_s, "score": 0.0}, + cpu_resources=CPUResourcesConfig(num_workers=1, num_cpus_per_worker=1), + ) + else: + judger_config = GSM8KJudgerConfig( + judger_name="openai/gsm8k", + cpu_resources=CPUResourcesConfig(num_workers=1, num_cpus_per_worker=1), + ) + + lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) + fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) + model_cfg = get_model_config_from_hf(Path(model_path)) + if hasattr(model_cfg, "balancing_loss_cfg"): + model_cfg.balancing_loss_cfg = None + if hasattr(model_cfg, "z_loss_cfg"): + model_cfg.z_loss_cfg = None + + train_worker_cfg = WorkerConfig( + model_cfg=model_cfg, + load_from=model_path, + optim_cfg=AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1), + loss_cfg=GRPOLossConfig( + policy_loss_cfg=dict( + cliprange_high=0.28, + cliprange_low=0.2, + loss_type="vanilla", + clip_ratio_c=10.0, + log_prob_diff_min=-20.0, + log_prob_diff_max=20.0, + ), + ignore_idx=-100, + use_kl_loss=False, + kl_loss_coef=0.0, + 
kl_loss_type="low_var_kl", + mode="chunk", + chunk_size=256, + ), + lr_cfg=lr_cfg, + fsdp_cfg=fsdp_cfg, + sp_size=1, + optimizer_steps=train_optimizer_steps, + pack_max_length=pack_max_length, + ) + + train_dataset = DatasetConfig(name="trace_smoke_gsm8k", anno_path=data_path) + tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length) + train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}] + dataloader_cfg = DataloaderConfig( + dataset_config_list=train_dataset_cfg, + pack_max_length=pack_max_length, + collator="fake_collator", + pack_level="none", + ) + sampler_config = SamplerConfig( + dataloader_cfg=dataloader_cfg, + prompt_repeat_k=prompt_repeat_k, + ) + training_sample_params = SampleParams( + max_tokens=max_response_length, + top_k=0, + top_p=1.0, + temperature=1.0, + min_tokens=0, + ) + agent_loop_config = SingleTurnAgentLoopConfig( + hf_checkpoint=model_path, + sample_params=training_sample_params, + ) + produce_strategy_config = AsyncProduceStrategyConfig( + over_sample_threshold=over_sample_threshold, + enable_partial_rollout=True, + max_staleness=0, + tail_batch_trigger_size=0, + ) + agent_loop_manager_cfg = AgentLoopManagerConfig( + tasks=TaskSpecConfig( + task_name="train_task", + agent_loop_config=agent_loop_config, + judger_config=judger_config, + produce_strategy_config=produce_strategy_config, + sampler_config=sampler_config, + ), + ) + + return RLColocateTrainerConfig( + resources=resources, + train_worker_cfg=train_worker_cfg, + rollout_config=rollout_config, + tokenizer_path=model_path, + replay_buffer_config=SyncReplayBufferConfig(), + agent_loop_manager_cfg=agent_loop_manager_cfg, + eval_agent_loop_manager_cfg=None, + evaluator_config=EvaluatorConfig(compute_metric_func=None), + load_from=model_path, + total_train_steps=total_train_steps, + train_batch_size=train_batch_size, + advantage_estimator_config=GRPOAdvantageConfig(eps=1e-8), + enable_evaluate=False, + enable_initial_evaluate=False, + 
trace_config=trace_config, + work_dir=work_dir, + checkpoint_interval=-1, + checkpoint_maxkeep=-1, + hf_interval=-1, + hf_max_keep=-1, + seed=123, + debug_rollout=False, + exp_tracker="jsonl", + ) diff --git a/pyproject.toml b/pyproject.toml index a3af546cf6..d192220ee7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,12 @@ rl = [ "mathruler", "pylatexenc" ] +trace = [ + "opentelemetry-api", + "opentelemetry-sdk", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-exporter-otlp-proto-http", +] video = [ "decord", "av", diff --git a/recipe/otle/README.md b/recipe/otle/README.md new file mode 100644 index 0000000000..5bbe90e0d0 --- /dev/null +++ b/recipe/otle/README.md @@ -0,0 +1,38 @@ +# XTuner OTel Trace + +XTuner exports rollout traces through OpenTelemetry. Use any OTLP-compatible +backend, such as Jaeger, and point `TraceConfig.otel_endpoint` at that backend. + +The included `jaeger/jaeger-memory.yaml` is a reference Jaeger memory config. +If you start Jaeger with that config, the default endpoints are: + +- Jaeger UI: `http://127.0.0.1:16686` +- OTLP HTTP: `http://127.0.0.1:14318/v1/traces` +- OTLP gRPC: `127.0.0.1:14317` + +Configure XTuner: + +```python +trace_config = TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_service_name="exp-1", + jaeger_query_url="http://127.0.0.1:16686", +) +``` + +Equivalent environment variables: + +```bash +export OTEL_TRACES_EXPORTER=otlp +export OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf +export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://127.0.0.1:14318/v1/traces +export OTEL_SERVICE_NAME=exp-1 +``` + +Query in Jaeger: + +- Select service `exp-1`. +- Search tag `xtuner.trace_id=`. +- `case.id=` is also emitted for compatibility with older trace docs. 
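The environment variables listed in the README can be resolved defensively before they are handed to `TraceConfig`. The following is a minimal sketch, not XTuner code: `resolve_otel_settings` is a hypothetical helper that mirrors the fallback behaviour of `_trace_config_from_env` in `trace_smoke_common.py` (unknown protocol or exporter values fall back to `grpc` and `otlp`). The defaults are copied from that function; the dictionary keys are assumed to match the keyword arguments it passes to `TraceConfig`.

```python
import os
from typing import Mapping, Optional

# Values accepted by the smoke config; anything else falls back to the default.
_VALID_PROTOCOLS = {"grpc", "http/protobuf"}
_VALID_EXPORTERS = {"otlp", "console"}


def resolve_otel_settings(env: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical helper: read the standard OTEL_* variables with safe fallbacks."""
    env = os.environ if env is None else env

    protocol = env.get("OTEL_EXPORTER_OTLP_PROTOCOL", "grpc")
    if protocol not in _VALID_PROTOCOLS:
        protocol = "grpc"

    exporter = env.get("OTEL_TRACES_EXPORTER", "otlp")
    if exporter not in _VALID_EXPORTERS:
        exporter = "otlp"

    return {
        "otel_endpoint": env.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://127.0.0.1:4317"),
        "otel_protocol": protocol,
        "otel_exporter": exporter,
        "otel_service_name": env.get("OTEL_SERVICE_NAME", "exp-1"),
        "jaeger_query_url": env.get("JAEGER_QUERY_URL", "http://127.0.0.1:16686"),
    }


if __name__ == "__main__":
    # Example: the HTTP settings from the README, passed as an explicit mapping.
    settings = resolve_otel_settings(
        {
            "OTEL_EXPORTER_OTLP_PROTOCOL": "http/protobuf",
            "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": "http://127.0.0.1:14318/v1/traces",
        }
    )
    print(settings)
```

Accepting an explicit mapping keeps the helper testable without mutating `os.environ`. If the key names do match, the result could be passed on as `TraceConfig(enabled=True, **settings)`.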
diff --git a/recipe/otle/jaeger/jaeger-memory.yaml b/recipe/otle/jaeger/jaeger-memory.yaml new file mode 100644 index 0000000000..533ad65925 --- /dev/null +++ b/recipe/otle/jaeger/jaeger-memory.yaml @@ -0,0 +1,40 @@ +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:14317 + http: + endpoint: 0.0.0.0:14318 + +processors: + batch: + +exporters: + jaeger_storage_exporter: + trace_storage: memstore + +extensions: + jaeger_storage: + backends: + memstore: + memory: + max_traces: 100000 + jaeger_query: + storage: + traces: memstore + base_path: / + http: + endpoint: 0.0.0.0:16686 + grpc: + endpoint: 0.0.0.0:16685 + +service: + telemetry: + metrics: + level: none + extensions: [jaeger_storage, jaeger_query] + pipelines: + traces: + receivers: [otlp] + processors: [batch] + exporters: [jaeger_storage_exporter] diff --git a/recipe/otle/tools/otlp_http_sink.py b/recipe/otle/tools/otlp_http_sink.py new file mode 100755 index 0000000000..dc8e1f1e81 --- /dev/null +++ b/recipe/otle/tools/otlp_http_sink.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import time +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path + +from google.protobuf.json_format import MessageToDict +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( + ExportTraceServiceRequest, + ExportTraceServiceResponse, +) + + +def _hex(value: bytes) -> str: + return value.hex() + + +def _attrs(attrs) -> dict: + result = {} + for attr in attrs: + val = attr.value + if val.HasField("string_value"): + result[attr.key] = val.string_value + elif val.HasField("bool_value"): + result[attr.key] = val.bool_value + elif val.HasField("int_value"): + result[attr.key] = val.int_value + elif val.HasField("double_value"): + result[attr.key] = val.double_value + else: + result[attr.key] = MessageToDict(val) + return result + + +def build_handler(root: Path): + raw_dir = root / "received" + 
summary_path = root / "spans.jsonl" + raw_dir.mkdir(parents=True, exist_ok=True) + + class Handler(BaseHTTPRequestHandler): + server_version = "otlp-http-sink/0.1" + + def do_POST(self) -> None: + if self.path != "/v1/traces": + self.send_response(404) + self.end_headers() + return + + length = int(self.headers.get("content-length") or 0) + body = self.rfile.read(length) + ts = time.time() + raw_path = raw_dir / f"{ts:.6f}_{id(body)}.pb" + raw_path.write_bytes(body) + + req = ExportTraceServiceRequest() + try: + req.ParseFromString(body) + count = 0 + with summary_path.open("a", encoding="utf-8") as fp: + for resource_span in req.resource_spans: + resource = _attrs(resource_span.resource.attributes) + for scope_span in resource_span.scope_spans: + scope = {"name": scope_span.scope.name, "version": scope_span.scope.version} + for span in scope_span.spans: + count += 1 + fp.write( + json.dumps( + { + "recv_ts": ts, + "trace_id": _hex(span.trace_id), + "span_id": _hex(span.span_id), + "parent_span_id": _hex(span.parent_span_id), + "name": span.name, + "start_unix_nano": span.start_time_unix_nano, + "end_unix_nano": span.end_time_unix_nano, + "duration_ms": ( + span.end_time_unix_nano - span.start_time_unix_nano + ) + / 1e6, + "status": span.status.code, + "status_message": span.status.message, + "attributes": _attrs(span.attributes), + "resource": resource, + "scope": scope, + }, + ensure_ascii=False, + ) + + "\n" + ) + print(f"received traces bytes={len(body)} spans={count} raw={raw_path}", flush=True) + except Exception as exc: + print(f"failed to decode traces bytes={len(body)} raw={raw_path}: {exc}", flush=True) + + resp = ExportTraceServiceResponse().SerializeToString() + self.send_response(200) + self.send_header("content-type", "application/x-protobuf") + self.send_header("content-length", str(len(resp))) + self.end_headers() + self.wfile.write(resp) + + def log_message(self, fmt: str, *args) -> None: + print(f"{self.address_string()} - {fmt % args}", 
flush=True) + + return Handler + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=4318) + parser.add_argument("--root", default="/tmp/otelcol") + args = parser.parse_args() + + root = Path(args.root) + root.mkdir(parents=True, exist_ok=True) + server = ThreadingHTTPServer((args.host, args.port), build_handler(root)) + print(f"OTLP HTTP sink listening on {args.host}:{args.port}, summaries -> {root / 'spans.jsonl'}", flush=True) + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/tests/rl/test_trace.py b/tests/rl/test_trace.py new file mode 100644 index 0000000000..fc2415eeaa --- /dev/null +++ b/tests/rl/test_trace.py @@ -0,0 +1,1641 @@ +import json +import tempfile +import unittest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from xtuner.tools import task_trace_view +from xtuner.tools.jaeger_trace_dashboard import build_dashboard_payload_from_jaeger_traces +from xtuner.tools.task_trace_analysis import build_unified_trace_payload_from_events +from xtuner.tools.task_trace_view import ( + render_unified_trace_html, +) +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.rollout.session_server import ( + _choices_output_ids_len, + _extract_body_trace_context, + _response_output_ids_len, +) +from xtuner.v1.rl.rollout.session_trace import ForwardRequestTrace +from xtuner.v1.rl.trace import ( + TraceConfig, + TraceEvent, + TraceRecorder, + _span_attributes, + build_rollout_trace_attributes, + build_rollout_trace_id, + close_trace, + configure_trace, + get_trace_env_vars, + inject_trace_context, + merge_trace_runtime_env, + reset_trace_for_test, + trace_baggage, + trace_event, + xtuner_trace_function, + xtuner_trace_span, + trace_task_context, + use_trace_recorder, +) + + +def make_state(uid: int = 1, task_name: str = "gsm8k", status: Status = Status.INIT) 
-> RolloutState: + return RolloutState( + message=[{"role": "user", "content": "What is 1 + 1?"}], + uid=uid, + task_name=task_name, + session_uid=uid + 1000, + status=status, + ) + + +def make_event( + trace_id: str, + stage: str, + timestamp_s: float, + *, + task_name: str = "gsm8k", + uid: int = 1, + status: str = "init", + elapsed_s: float | None = None, + train_step: int | None = None, + model_step: int | None = None, + producer_future_step: int | None = None, + produce_batch_id: str | None = None, + error_msg: str | None = None, +) -> TraceEvent: + return TraceEvent( + trace_id=trace_id, + stage=stage, + timestamp_s=timestamp_s, + status=status, + task_name=task_name, + uid=uid, + session_uid=uid + 1000, + train_step=train_step, + model_step=model_step, + producer_future_step=producer_future_step, + produce_batch_id=produce_batch_id, + worker_rank=None, + elapsed_s=elapsed_s, + error_msg=error_msg, + ) + + +def make_batch_event( + trace_id: str, + stage: str, + timestamp_s: float, + *, + uid: int, + batch: str, + train_step: int, + model_step: int = 0, + producer_future_step: int = 0, + elapsed_s: float | None = None, + status: str = "init", + error_msg: str | None = None, +) -> TraceEvent: + return make_event( + trace_id, + stage, + timestamp_s, + uid=uid, + status=status, + elapsed_s=elapsed_s, + train_step=train_step, + model_step=model_step, + producer_future_step=producer_future_step, + produce_batch_id=batch, + error_msg=error_msg, + ) + + +class RecordingTraceSink: + def __init__(self) -> None: + self.events: list[TraceEvent] = [] + self.flushed = False + self.closed = False + + def append(self, event: TraceEvent) -> None: + self.events.append(event) + + def flush(self) -> bool: + self.flushed = True + return True + + def close(self) -> bool: + self.closed = True + return True + + +def timeline_from_events(events: list[TraceEvent], trace_id: str) -> list[TraceEvent]: + return [event for event in events if event.trace_id == trace_id] + + +def 
make_jaeger_span( + *, + trace_id: str, + span_id: str, + operation: str, + process_id: str, + start_us: int, + duration_us: int, + tags: dict[str, object], + parent_span_id: str | None = None, +) -> dict[str, object]: + references = [] + if parent_span_id is not None: + references.append({"refType": "CHILD_OF", "spanID": parent_span_id}) + return { + "traceID": trace_id, + "spanID": span_id, + "operationName": operation, + "processID": process_id, + "startTime": start_us, + "duration": duration_us, + "references": references, + "tags": [{"key": key, "value": value} for key, value in tags.items()], + "logs": [], + } + + +def make_jaeger_trace(spans: list[dict[str, object]]) -> dict[str, object]: + return { + "traceID": "a" * 32, + "processes": { + "p-xtuner": {"serviceName": "xtuner-rl", "tags": []}, + "p-lagent": {"serviceName": "lagent", "tags": []}, + "p-old": {"serviceName": "old-run", "tags": []}, + }, + "spans": spans, + } + + +class TraceCoreBehaviorTest(unittest.IsolatedAsyncioTestCase): + def tearDown(self): + reset_trace_for_test() + + async def test_trace_api_records_event_span_function(self): + sink = RecordingTraceSink() + recorder = TraceRecorder(sink) + state = make_state(uid=123) + + with use_trace_recorder(recorder): + await trace_event(state, "custom.prepare") + + async with xtuner_trace_span(state, "custom.work"): + state.status = Status.COMPLETED + + @xtuner_trace_function("custom.fn") + async def traced_fn(rollout_state: RolloutState) -> RolloutState: + await trace_event(rollout_state, "custom.fn.inner") + return rollout_state + + await traced_fn(state) + + recorder.flush() + recorder.close() + timeline = timeline_from_events(sink.events, "gsm8k:123") + + expected_stages = [ + "custom.prepare", + "custom.work.start", + "custom.work.end", + "custom.fn.start", + "custom.fn.inner", + "custom.fn.end", + ] + self.assertEqual([event.stage for event in timeline], expected_stages) + self.assertEqual([event.stage for event in sink.events], 
expected_stages) + self.assertEqual(timeline[-1].status, "completed") + self.assertTrue(sink.flushed) + self.assertTrue(sink.closed) + + async def test_xtuner_trace_function_resolves_target_and_dynamic_kwargs(self): + sink = RecordingTraceSink() + batch_id = "train_step=1/model_step=2/producer_future_step=3" + + class Worker: + def __init__(self): + self.worker_rank = 7 + + @xtuner_trace_function( + "custom.worker", + target="state", + trace_kwargs_getter=lambda self, *args, **kwargs: { + "produce_batch_id": batch_id, + "worker_rank": self.worker_rank, + }, + ) + async def run(self, state: RolloutState) -> RolloutState: + return state.model_copy(update={"status": Status.COMPLETED}, deep=True) + + state = make_state(uid=456) + with use_trace_recorder(TraceRecorder(sink)): + await Worker().run(state) + timeline = timeline_from_events(sink.events, "gsm8k:456") + + self.assertEqual([event.stage for event in timeline], ["custom.worker.start", "custom.worker.end"]) + self.assertEqual([event.produce_batch_id for event in timeline], [batch_id, batch_id]) + self.assertEqual([event.worker_rank for event in timeline], [7, 7]) + self.assertEqual(timeline[0].status, "init") + self.assertEqual(timeline[-1].status, "completed") + + async def test_xtuner_trace_function_records_error_event_and_reraises(self): + sink = RecordingTraceSink() + state = make_state(uid=789) + + @xtuner_trace_function("custom.failure", target="state") + async def failing_fn(state: RolloutState) -> None: + raise ValueError("rollout failed") + + with use_trace_recorder(TraceRecorder(sink)): + with self.assertRaisesRegex(ValueError, "rollout failed"): + await failing_fn(state) + + timeline = timeline_from_events(sink.events, "gsm8k:789") + self.assertEqual([event.stage for event in timeline], ["custom.failure.start", "custom.failure.error"]) + self.assertIsNotNone(timeline[-1].elapsed_s) + self.assertTrue(timeline[-1].error_msg.startswith("ValueError: rollout failed")) + 
self.assertEqual(timeline[-1].error_type, "ValueError") + self.assertIn("raise ValueError", timeline[-1].error_stacktrace or "") + + async def test_xtuner_trace_function_respects_explicit_group_target(self): + sink = RecordingTraceSink() + + @xtuner_trace_function("custom.group", target="group") + async def traced_group(group: list[RolloutState]) -> list[RolloutState]: + for state in group: + state.status = Status.COMPLETED + return group + + states = [make_state(uid=1), make_state(uid=2)] + with use_trace_recorder(TraceRecorder(sink)): + await traced_group(states) + timeline_1 = timeline_from_events(sink.events, "gsm8k:1") + timeline_2 = timeline_from_events(sink.events, "gsm8k:2") + + self.assertEqual( + [event.stage for event in timeline_1], + ["custom.group.start", "custom.group.end"], + ) + self.assertEqual( + [event.stage for event in timeline_2], + ["custom.group.start", "custom.group.end"], + ) + self.assertEqual(timeline_1[-1].status, "completed") + self.assertEqual(timeline_2[-1].status, "completed") + + async def test_xtuner_trace_function_records_agent_loop_group_span_for_each_task(self): + sink = RecordingTraceSink() + + @xtuner_trace_function("xtuner.agent_in_sandbox.generate_group") + async def traced_group(rollout_state: list[RolloutState]) -> list[RolloutState]: + return [state.model_copy(update={"status": Status.COMPLETED}, deep=True) for state in rollout_state] + + states = [make_state(uid=1), make_state(uid=2)] + with use_trace_recorder(TraceRecorder(sink)): + await traced_group(states) + + self.assertEqual( + [event.stage for event in timeline_from_events(sink.events, "gsm8k:1")], + ["xtuner.agent_in_sandbox.generate_group.start", "xtuner.agent_in_sandbox.generate_group.end"], + ) + self.assertEqual( + [event.stage for event in timeline_from_events(sink.events, "gsm8k:2")], + ["xtuner.agent_in_sandbox.generate_group.start", "xtuner.agent_in_sandbox.generate_group.end"], + ) + self.assertEqual(timeline_from_events(sink.events, 
"gsm8k:1")[-1].status, "completed") + self.assertEqual(timeline_from_events(sink.events, "gsm8k:2")[-1].status, "completed") + + async def test_xtuner_trace_function_records_start_for_returned_target(self): + sink = RecordingTraceSink() + + @xtuner_trace_function("custom.sample") + async def sample_group() -> list[RolloutState]: + return [make_state(uid=11)] + + with use_trace_recorder(TraceRecorder(sink)): + await sample_group() + timeline = timeline_from_events(sink.events, "gsm8k:11") + + self.assertEqual([event.stage for event in timeline], ["custom.sample.start", "custom.sample.end"]) + self.assertLessEqual(timeline[0].timestamp_s, timeline[1].timestamp_s) + self.assertIsNotNone(timeline[1].elapsed_s) + + async def test_xtuner_trace_span_supports_uid_style_target_attributes_and_annotations(self): + sink = RecordingTraceSink() + item = SimpleNamespace(data_source="sandbox_task", uid=123, group_id=456, status="running") + + with use_trace_recorder(TraceRecorder(sink)): + async with xtuner_trace_span(item, "xtuner.sandbox.acquire", task_id="task-123") as span: + span.annotate( + sandbox_name="default", + sandbox_env_id="env-1", + sandbox_url="http://sandbox", + ) + item.status = "completed" + + timeline = timeline_from_events(sink.events, "sandbox_task:123") + + self.assertEqual( + [event.stage for event in timeline], + ["xtuner.sandbox.acquire.start", "xtuner.sandbox.acquire.end"], + ) + self.assertEqual(timeline[0].attributes, {"task_id": "task-123"}) + self.assertEqual(timeline[-1].status, "completed") + self.assertEqual( + timeline[-1].attributes, + { + "task_id": "task-123", + "sandbox_name": "default", + "sandbox_env_id": "env-1", + "sandbox_url": "http://sandbox", + }, + ) + + async def test_xtuner_trace_span_mark_error_records_error_event_without_exception(self): + sink = RecordingTraceSink() + item = SimpleNamespace(data_source="sandbox_task", uid=124, group_id=None, status="running") + + with use_trace_recorder(TraceRecorder(sink)): + async with 
xtuner_trace_span(item, "xtuner.sandbox.infer", task_id="task-124") as span: + span.mark_error("return_code=1: failed") + item.status = "failed" + + timeline = timeline_from_events(sink.events, "sandbox_task:124") + + self.assertEqual( + [event.stage for event in timeline], + ["xtuner.sandbox.infer.start", "xtuner.sandbox.infer.error"], + ) + self.assertEqual(timeline[-1].status, "failed") + self.assertEqual(timeline[-1].error_msg, "return_code=1: failed") + self.assertEqual(timeline[-1].attributes, {"task_id": "task-124"}) + + async def test_xtuner_trace_span_records_materialize_trajectory_attributes(self): + sink = RecordingTraceSink() + state = make_state(uid=88) + + with use_trace_recorder(TraceRecorder(sink)): + async with xtuner_trace_span( + state, + "xtuner.agent_in_sandbox.materialize_trajectory", + agent_status="completed", + **{"xtuner.stage.kind": "materialize"}, + ) as span: + span.annotate(agent_message_count=3, agent_has_tools=True, input_tokens=11, label_tokens=7) + + timeline = timeline_from_events(sink.events, "gsm8k:88") + self.assertEqual( + [event.stage for event in timeline], + [ + "xtuner.agent_in_sandbox.materialize_trajectory.start", + "xtuner.agent_in_sandbox.materialize_trajectory.end", + ], + ) + self.assertEqual(timeline[-1].attributes["agent_status"], "completed") + self.assertEqual(timeline[-1].attributes["xtuner.stage.kind"], "materialize") + self.assertEqual(timeline[-1].attributes["agent_message_count"], 3) + self.assertEqual(timeline[-1].attributes["agent_has_tools"], True) + self.assertEqual(timeline[-1].attributes["input_tokens"], 11) + self.assertEqual(timeline[-1].attributes["label_tokens"], 7) + + def test_make_trace_context_carrier_noops_when_trace_disabled(self): + reset_trace_for_test() + from xtuner.v1.rl.trace import make_trace_context_carrier + + self.assertEqual(make_trace_context_carrier(), {}) + + def test_sandbox_trace_env_uses_prefixed_propagator_and_otel_config_keys(self): + pytest.importorskip("lagent") + from 
xtuner.v1.rl.agent_loop.sandbox_agent_loop import sandbox as sandbox_module + + with patch( + "xtuner.v1.rl.agent_loop.sandbox_agent_loop.sandbox.make_trace_context_carrier", + return_value={"traceparent": "00-abc-def-01", "baggage": "xtuner.trace_id=gsm8k%3A1"}, + ), patch( + "xtuner.v1.rl.agent_loop.sandbox_agent_loop.sandbox.get_trace_env_vars", + return_value={ + "OTEL_TRACES_EXPORTER": "otlp", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": "http://127.0.0.1:4317", + "OTEL_SERVICE_NAME": "exp-1", + }, + ): + merged = sandbox_module._merge_trace_env({"A": "1"}) + + self.assertEqual(merged["A"], "1") + self.assertEqual(merged["OTEL_PROPAGATOR_TRACEPARENT"], "00-abc-def-01") + self.assertEqual(merged["OTEL_PROPAGATOR_BAGGAGE"], "xtuner.trace_id=gsm8k%3A1") + self.assertEqual(merged["OTEL_TRACES_EXPORTER"], "otlp") + self.assertEqual(merged["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"], "http://127.0.0.1:4317") + self.assertEqual(merged["OTEL_SERVICE_NAME"], "exp-1") + + async def test_localhost_runner_records_task_spans(self): + pytest.importorskip("lagent") + from xtuner.v1.rl.agent_loop.localhost_agent_loop.runner import LocalhostRunner + from xtuner.v1.rl.agent_loop.sandbox_agent_loop.schemas import AgentRolloutItem, StageResult, StageStatus + + class FakeInfer: + async def run(self, item: AgentRolloutItem, record): + record.status = StageStatus.COMPLETED + return StageResult(return_code=0) + + class FakeValidate: + name = "validate" + + async def run(self, item: AgentRolloutItem, record): + record.status = StageStatus.COMPLETED + return 1.0 + + sink = RecordingTraceSink() + item = AgentRolloutItem( + id="local-task-1", + data_source="localhost", + instruction="solve", + uid=321, + ) + + with use_trace_recorder(TraceRecorder(sink)): + result = await LocalhostRunner(infer=FakeInfer(), validate=FakeValidate()).run(item) + + timeline = timeline_from_events(sink.events, "localhost:321") + + self.assertEqual(result.status.value, "completed") + self.assertEqual(result.reward, 1.0) 
+ self.assertEqual( + [event.stage for event in timeline], + [ + "xtuner.localhost.run_total.start", + "xtuner.localhost.infer.start", + "xtuner.localhost.infer.end", + "xtuner.localhost.validate.start", + "xtuner.localhost.validate.end", + "xtuner.localhost.run_total.end", + ], + ) + self.assertEqual(timeline[0].attributes, {"task_id": "local-task-1", "xtuner.stage.kind": "agent_loop"}) + self.assertEqual( + timeline[3].attributes, + {"task_id": "local-task-1", "validate_name": "validate", "xtuner.stage.kind": "judge"}, + ) + + def test_trace_runtime_env_is_propagated_to_ray_actor_options(self): + with patch("xtuner.v1.rl.trace.OtelTraceSink", return_value=RecordingTraceSink()): + configure_trace( + TraceConfig( + enabled=True, + otel_endpoint="http://otel-collector:4317", + otel_service_name="xtuner-test", + ) + ) + actor_options = {"num_cpus": 1, "runtime_env": {"env_vars": {"EXISTING": "1"}}} + + env_vars = get_trace_env_vars() + merged = merge_trace_runtime_env(actor_options) + close_trace() + + self.assertIs(merged, actor_options) + self.assertNotIn("XTUNER_TRACE_ENABLED", env_vars) + self.assertNotIn("XTUNER_TRACE_OTEL_ENDPOINT", env_vars) + self.assertEqual(env_vars["OTEL_TRACES_EXPORTER"], "otlp") + self.assertEqual(env_vars["OTEL_EXPORTER_OTLP_PROTOCOL"], "grpc") + self.assertEqual(env_vars["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"], "http://otel-collector:4317") + self.assertEqual(env_vars["OTEL_SERVICE_NAME"], "xtuner-test") + self.assertEqual(actor_options["runtime_env"]["env_vars"]["EXISTING"], "1") + self.assertNotIn("XTUNER_TRACE_ENABLED", actor_options["runtime_env"]["env_vars"]) + self.assertEqual( + actor_options["runtime_env"]["env_vars"]["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"], + "http://otel-collector:4317", + ) + self.assertEqual(actor_options["runtime_env"]["env_vars"]["OTEL_SERVICE_NAME"], "xtuner-test") + self.assertEqual(get_trace_env_vars(), {}) + + def test_trace_config_accepts_http_protocol(self): + config = TraceConfig( + enabled=True, + 
otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_service_name="xtuner-test", + ) + + self.assertEqual(config.otel_protocol, "http/protobuf") + self.assertEqual(config.otel_endpoint, "http://127.0.0.1:14318/v1/traces") + + def test_standard_otel_env_is_propagated_to_ray_actor_options(self): + with patch("xtuner.v1.rl.trace.OtelTraceSink", return_value=RecordingTraceSink()): + configure_trace( + TraceConfig( + enabled=True, + otel_endpoint="http://127.0.0.1:14318/v1/traces", + otel_protocol="http/protobuf", + otel_service_name="xtuner-test", + ) + ) + actor_options = {"runtime_env": {"env_vars": {}}} + + env_vars = get_trace_env_vars() + merge_trace_runtime_env(actor_options) + close_trace() + + self.assertNotIn("XTUNER_OTEL_ENABLED", env_vars) + self.assertNotIn("AGENT_OTEL_ENABLED", env_vars) + self.assertEqual(env_vars["OTEL_TRACES_EXPORTER"], "otlp") + self.assertEqual(env_vars["OTEL_EXPORTER_OTLP_PROTOCOL"], "http/protobuf") + self.assertEqual(env_vars["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"], "http://127.0.0.1:14318/v1/traces") + self.assertEqual(env_vars["OTEL_SERVICE_NAME"], "xtuner-test") + self.assertEqual(actor_options["runtime_env"]["env_vars"]["OTEL_EXPORTER_OTLP_PROTOCOL"], "http/protobuf") + + def test_build_rollout_trace_id_is_stable_and_repeat_sensitive(self): + state = RolloutState( + task_name="gsm8k", + message=[{"role": "user", "content": "What is 1 + 1?"}], + data_source={"dataset": "demo"}, + message_uid=123, + ) + + trace_id = build_rollout_trace_id(state, repeat_index=1) + + self.assertEqual(trace_id, build_rollout_trace_id(state, repeat_index=1)) + self.assertNotEqual(trace_id, build_rollout_trace_id(state, repeat_index=2)) + self.assertTrue(trace_id.startswith("gsm8k:")) + + def test_build_rollout_trace_attributes_prefers_trace_id_over_uid(self): + state = make_state(uid=99) + state.trace_id = "gsm8k:stable" + + with patch.dict("os.environ", {"XTUNER_OTEL_RUN_ID": "run-a"}, clear=False): + attrs = 
build_rollout_trace_attributes(state) + + self.assertEqual(attrs["xtuner.trace_id"], "gsm8k:stable") + self.assertEqual(attrs["case.id"], "gsm8k:stable") + self.assertEqual(attrs["xtuner.uid"], 99) + self.assertEqual(attrs["run.id"], "run-a") + + def test_build_rollout_trace_attributes_stringifies_large_uid(self): + large_uid = 34472246533035500442377505850707626501 + state = make_state(uid=large_uid) + state.trace_id = "gsm8k:large" + + attrs = build_rollout_trace_attributes(state) + + self.assertEqual(attrs["xtuner.uid"], str(large_uid)) + + def test_span_attributes_stringify_large_identifiers(self): + large_uid = 199977222912148395370827013746483739450 + large_session_uid = 34472246533035500442377505850707626501 + event = TraceEvent( + trace_id="gsm8k:large", + stage="custom.start", + timestamp_s=1.0, + task_name="gsm8k", + uid=large_uid, + session_uid=large_session_uid, + ) + + attrs = _span_attributes(event, "custom", "start") + + self.assertEqual(attrs["xtuner.uid"], str(large_uid)) + self.assertEqual(attrs["xtuner.session_uid"], str(large_session_uid)) + + async def test_trace_event_prefers_rollout_state_trace_id(self): + sink = RecordingTraceSink() + state = make_state(uid=123) + state.trace_id = "gsm8k:stable" + + with use_trace_recorder(TraceRecorder(sink)): + await trace_event(state, "custom.trace_id") + + self.assertEqual([event.trace_id for event in sink.events], ["gsm8k:stable"]) + + async def test_trace_event_prefers_task_like_trace_id(self): + sink = RecordingTraceSink() + item = SimpleNamespace( + data_source="gsm8k", + uid=123, + trace_id="gsm8k:stable", + ) + + with use_trace_recorder(TraceRecorder(sink)): + await trace_event(item, "custom.agent_item") + + self.assertEqual([event.trace_id for event in sink.events], ["gsm8k:stable"]) + + def test_trace_baggage_and_context_helpers_noop_when_disabled(self): + headers = {} + with patch.dict("os.environ", {"OTEL_TRACES_EXPORTER": "none"}, clear=True): + with trace_baggage({"xtuner.trace_id": 
"gsm8k:stable"}): + inject_trace_context(headers) + with trace_task_context({"xtuner.trace_id": "gsm8k:stable"}): + inject_trace_context(headers) + + self.assertEqual(headers, {}) + + def test_session_server_trace_context_and_token_helpers(self): + traceparent = "00-" + "1" * 32 + "-" + "2" * 16 + "-01" + body = json.dumps({"_otel_trace_context": {"traceparent": traceparent}}).encode() + + self.assertEqual(_extract_body_trace_context(body), {"traceparent": traceparent}) + self.assertIsNone(_extract_body_trace_context(b"not-json")) + self.assertEqual(_choices_output_ids_len({"choices": [{"output_ids": [1, 2]}, {"output_ids": [3]}]}), 3) + self.assertEqual(_response_output_ids_len({"output_ids": [1, 2, 3]}), 3) + self.assertEqual(_response_output_ids_len({"choices": [{"output_ids": [1]}, {"delta": {"content": "x"}}]}), 1) + + def test_session_server_stream_trace_records_forward_relative_first_token_latency(self): + span_attrs: list[dict[str, object]] = [] + end_attrs: list[dict[str, object]] = [] + + def record_attrs(_span, **attrs): + span_attrs.append(attrs) + + def record_end(_span, exc=None, **attrs): + end_attrs.append(attrs) + + with ( + patch("xtuner.v1.rl.rollout.session_trace.time.perf_counter") as perf_counter, + patch("xtuner.v1.rl.rollout.session_trace.begin_otel_span", return_value="stream-span"), + patch("xtuner.v1.rl.rollout.session_trace.set_otel_span_attrs", side_effect=record_attrs), + patch("xtuner.v1.rl.rollout.session_trace.end_otel_span", side_effect=record_end), + ): + perf_counter.side_effect = [ + 10.05, # response headers received + 10.10, # stream_read span starts + 10.20, # first chunk + 10.35, # first output token and content + 10.60, # stream finishes + ] + forward_trace = ForwardRequestTrace(span="forward-span", start_s=10.0) + + forward_trace.set_http_status(200) + stream_trace = forward_trace.start_stream( + target_url="http://worker/v1/chat/completions", + request_data={"input_ids": [1, 2, 3], "max_tokens": 8}, + ) + 
stream_trace.on_line(b": keep-alive\n") + stream_trace.on_line( + b'data: {"choices":[{"delta":{"content":"A"},"output_ids":[42]}],' + b'"usage":{"prompt_tokens":3,"completion_tokens":1,"total_tokens":4}}\n' + ) + stream_trace.finish(client_alive=True) + + self.assertEqual(span_attrs[0]["http_status"], 200) + self.assertAlmostEqual(span_attrs[0]["upstream_headers_ms"], 50.0, places=6) + self.assertAlmostEqual(span_attrs[-1]["first_output_token_from_forward_ms"], 350.0, places=6) + self.assertAlmostEqual(end_attrs[0]["first_output_token_ms"], 250.0, places=6) + self.assertAlmostEqual(end_attrs[0]["first_output_token_from_forward_ms"], 350.0, places=6) + self.assertAlmostEqual(end_attrs[0]["stream_complete_from_forward_ms"], 600.0, places=6) + + def test_unified_payload_reports_latest_batch_and_current_stalls(self): + old_batch = "train_step=1/model_step=0/producer_future_step=0" + latest_batch = "train_step=2/model_step=0/producer_future_step=0" + events = [ + make_batch_event( + "gsm8k:1", + "xtuner.rollout_controller.generate.start", + 5.0, + uid=1, + batch=old_batch, + train_step=1, + ), + make_batch_event( + "gsm8k:2", + "xtuner.rollout_controller.generate.start", + 10.0, + uid=2, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:3", + "xtuner.judger.judge.start", + 12.0, + uid=3, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:4", + "xtuner.producer.put_generated_group.start", + 13.0, + uid=4, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:4", + "xtuner.producer.put_generated_group.end", + 15.0, + uid=4, + batch=latest_batch, + train_step=2, + elapsed_s=2.0, + status="completed", + ), + ] + + with patch("xtuner.tools.task_trace_analysis.time.time", return_value=30.0): + payload = build_unified_trace_payload_from_events( + events, + trace_source="/tmp/trace", + default_scope="latest-produce-batch", + ) + + view = payload["views"]["latest-produce-batch"] + overview = view["overview"] + 
self.assertEqual(payload["default_scope"], "latest-produce-batch") + self.assertEqual(view["event_count"], 4) + self.assertEqual(overview["total_tasks"], 3) + self.assertEqual(overview["running_tasks"], 2) + self.assertEqual(overview["completed_tasks"], 1) + stage_stats = {item["stage"]: item for item in view["stage_stats"]} + self.assertEqual(stage_stats["rollout.generate"]["running_tasks"], 1) + self.assertEqual(stage_stats["judger"]["running_tasks"], 1) + self.assertEqual({row["trace_id"] for row in view["task_rows"]}, {"gsm8k:2", "gsm8k:3", "gsm8k:4"}) + + details = view["task_details"] + self.assertEqual(details["gsm8k:2"]["duration_s"], 20.0) + self.assertEqual(details["gsm8k:3"]["duration_s"], 18.0) + + def test_unified_view_payload_reports_overview_stage_stats_and_task_detail(self): + latest_batch = "train_step=2/model_step=0/producer_future_step=0" + events = [ + make_batch_event( + "gsm8k:1", + "xtuner.rollout_controller.generate.start", + 10.0, + uid=1, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:1", + "xtuner.rollout_controller.generate.end", + 12.0, + uid=1, + batch=latest_batch, + train_step=2, + elapsed_s=2.0, + status="completed", + ), + make_batch_event( + "gsm8k:1", + "xtuner.judger.judge.start", + 13.0, + uid=1, + batch=latest_batch, + train_step=2, + status="completed", + ), + make_batch_event( + "gsm8k:1", + "xtuner.judger.judge.end", + 16.0, + uid=1, + batch=latest_batch, + train_step=2, + elapsed_s=3.0, + status="completed", + ), + make_batch_event( + "gsm8k:2", + "xtuner.judger.judge.start", + 20.0, + uid=2, + batch=latest_batch, + train_step=2, + ), + make_event( + "gsm8k:2", + "xtuner.judger.judge.error", + 24.0, + uid=2, + status="failed", + elapsed_s=4.0, + train_step=2, + model_step=0, + producer_future_step=0, + produce_batch_id=latest_batch, + ), + make_batch_event( + "gsm8k:3", + "xtuner.rollout_controller.generate.start", + 30.0, + uid=3, + batch=latest_batch, + train_step=2, + ), + ] + + with 
patch("xtuner.tools.task_trace_analysis.time.time", return_value=40.0): + payload = build_unified_trace_payload_from_events(events, trace_source="/tmp/trace") + + self.assertEqual(payload["default_scope"], "all") + self.assertEqual(payload["available_scopes"], ["all", "latest-produce-batch"]) + + all_view = payload["views"]["all"] + overview = all_view["overview"] + self.assertEqual(overview["total_tasks"], 3) + self.assertEqual(overview["completed_tasks"], 1) + self.assertEqual(overview["running_tasks"], 1) + self.assertEqual(overview["failed_tasks"], 1) + + stage_stats = {item["stage"]: item for item in all_view["stage_stats"]} + self.assertEqual(stage_stats["rollout.generate"]["running_tasks"], 1) + self.assertEqual(stage_stats["rollout.generate"]["visited_tasks"], 2) + self.assertEqual(stage_stats["rollout.generate"]["avg_s"], 2.0) + self.assertEqual(stage_stats["judger"]["visited_tasks"], 2) + self.assertEqual(stage_stats["judger"]["max_s"], 4.0) + + rows = {row["trace_id"]: row for row in all_view["task_rows"]} + self.assertEqual(rows["gsm8k:1"]["status"], "completed") + self.assertEqual(rows["gsm8k:2"]["status"], "failed") + self.assertEqual(rows["gsm8k:3"]["status"], "running") + self.assertEqual(rows["gsm8k:3"]["current_stage"], "rollout.generate") + + detail = all_view["task_details"]["gsm8k:2"] + self.assertTrue(detail["timeline_events"]) + self.assertTrue(detail["timeline_spans"]) + self.assertIn("error_msg", detail) + self.assertTrue(detail["error_msg"]) + self.assertNotIn("error_msg", rows["gsm8k:2"]) + + def test_unified_viewer_html_contains_new_sections(self): + payload = { + "generated_at_s": 1.0, + "trace_source": "/tmp/trace", + "live_mode": False, + "default_scope": "all", + "available_scopes": ["all", "latest-produce-batch"], + "views": { + "all": { + "overview": { + "total_tasks": 3, + "pending_tasks": 0, + "completed_tasks": 1, + "running_tasks": 1, + "failed_tasks": 1, + }, + "stage_stats": [ + { + "stage": "judger", + "current_tasks": 1, + 
"pending_tasks": 0, + "running_tasks": 1, + "done_spans": 2, + "visited_tasks": 2, + "avg_s": 1.0, + "p95_s": 2.0, + "max_s": 3.0, + } + ], + "task_rows": [], + "task_details": {}, + }, + "latest-produce-batch": { + "overview": { + "total_tasks": 1, + "pending_tasks": 0, + "completed_tasks": 0, + "running_tasks": 1, + "failed_tasks": 0, + }, + "stage_stats": [], + "task_rows": [], + "task_details": {}, + }, + }, + } + + html = render_unified_trace_html(payload, live=False) + self.assertIn("Total tasks", html) + self.assertIn("Pending", html) + self.assertIn("Current Stage Distribution & Duration", html) + self.assertIn("Done", html) + self.assertIn("Failed", html) + self.assertIn("Task Timeline", html) + self.assertIn("Text Timeline", html) + self.assertIn("Jaeger Native Trace", html) + self.assertIn("formatAttributes", html) + self.assertIn("event.attributes", html) + self.assertNotIn("Suspect Open Spans", html) + self.assertNotIn("Latest Stage Distribution", html) + self.assertNotIn("Jaeger Timeline", html) + + def test_unified_viewer_marks_error_timeline_with_message_and_red_span(self): + batch = "train_step=1/model_step=0/producer_future_step=0" + error_msg = "RuntimeError: judger unavailable" + events = [ + make_batch_event( + "gsm8k:1", + "xtuner.rollout_controller.generate.start", + 1.0, + uid=1, + batch=batch, + train_step=1, + ), + make_batch_event( + "gsm8k:1", + "xtuner.rollout_controller.generate.error", + 4.0, + uid=1, + batch=batch, + train_step=1, + elapsed_s=3.0, + status="failed", + error_msg=error_msg, + ), + ] + + payload = build_unified_trace_payload_from_events(events, trace_source="/tmp/trace", now_s=5.0) + detail = payload["views"]["all"]["task_details"]["gsm8k:1"] + + self.assertEqual(detail["error_msg"], error_msg) + self.assertEqual(detail["timeline_events"][-1]["error_msg"], error_msg) + self.assertEqual(detail["timeline_spans"][0]["outcome"], "error") + self.assertEqual(detail["timeline_spans"][0]["color"], "#dc2626") + + html = 
render_unified_trace_html(payload, live=False) + self.assertIn("timeline-error-text", html) + self.assertIn("event.error_msg", html) + + def test_unified_viewer_prepares_embedded_jaeger_native_trace(self): + batch = "train_step=1/model_step=0/producer_future_step=0" + events = [ + make_batch_event( + "gsm8k:42", + "xtuner.rollout_controller.generate.start", + 1.0, + uid=42, + batch=batch, + train_step=1, + ), + make_batch_event( + "gsm8k:42", + "xtuner.rollout_controller.generate.end", + 3.0, + uid=42, + batch=batch, + train_step=1, + elapsed_s=2.0, + status="completed", + ), + ] + + payload = build_unified_trace_payload_from_events( + events, + trace_source="/tmp/trace", + jaeger_query_url="http://127.0.0.1:16686/", + ) + view = payload["views"]["all"] + row = view["task_rows"][0] + detail = view["task_details"]["gsm8k:42"] + + self.assertRegex(detail["otel_trace_id"], r"^[0-9a-f]{32}$") + self.assertEqual(row["otel_trace_id"], detail["otel_trace_id"]) + self.assertTrue(payload["jaeger_query_enabled"]) + + html = render_unified_trace_html(payload, live=False) + self.assertIn("renderJaegerNativeViewer", html) + self.assertIn("jaegerTraceUrl", html) + self.assertIn("jaeger-native-frame", html) + self.assertIn("Open in Jaeger", html) + self.assertIn("Jaeger Native Trace", html) + self.assertNotIn("/api/jaeger/trace/", html) + self.assertNotIn("renderJaegerTimeline", html) + self.assertNotIn("Jaeger Timeline", html) + + def test_jaeger_trace_response_is_normalized_for_embedded_timeline(self): + raw_response = { + "data": [ + { + "traceID": "0" * 31 + "1", + "processes": { + "p1": { + "serviceName": "xtuner-rl", + "tags": [{"key": "host.name", "value": "worker-0"}], + } + }, + "spans": [ + { + "traceID": "0" * 31 + "1", + "spanID": "span-root", + "operationName": "xtuner.agent_loop.generate_sample", + "processID": "p1", + "startTime": 1_000_000, + "duration": 5_000_000, + "references": [], + "tags": [{"key": "xtuner.uid", "value": 42}], + "logs": [], + }, + { + 
"traceID": "0" * 31 + "1", + "spanID": "span-judge", + "operationName": "xtuner.judger.judge", + "processID": "p1", + "startTime": 3_000_000, + "duration": 1_000_000, + "references": [{"refType": "CHILD_OF", "spanID": "span-root"}], + "tags": [ + {"key": "error", "value": True}, + {"key": "error.message", "value": "judge failed"}, + ], + "logs": [], + }, + ], + } + ] + } + + payload = task_trace_view.normalize_jaeger_trace_response(raw_response, "0" * 31 + "1") + + self.assertTrue(payload["found"]) + self.assertEqual(payload["trace_id"], "0" * 31 + "1") + self.assertEqual(payload["span_count"], 2) + self.assertEqual(payload["duration_us"], 5_000_000) + spans = {span["span_id"]: span for span in payload["spans"]} + self.assertEqual(spans["span-root"]["service_name"], "xtuner-rl") + self.assertIsNone(spans["span-root"]["parent_span_id"]) + self.assertEqual(spans["span-judge"]["parent_span_id"], "span-root") + self.assertTrue(spans["span-judge"]["is_error"]) + self.assertEqual(spans["span-judge"]["error_msg"], "judge failed") + + def test_jaeger_trace_fetch_uses_query_api(self): + class FakeResponse: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def read(self): + return json.dumps({"data": []}).encode("utf-8") + + with patch("xtuner.tools.task_trace_view.urlopen", return_value=FakeResponse()) as urlopen_mock: + payload = task_trace_view.fetch_jaeger_trace_payload("http://127.0.0.1:16686/", "0" * 31 + "1") + + self.assertFalse(payload["found"]) + urlopen_mock.assert_called_once() + request = urlopen_mock.call_args.args[0] + self.assertEqual(request.full_url, "http://127.0.0.1:16686/api/traces/" + "0" * 31 + "1") + + def test_unified_viewer_reports_pending_and_current_stage_distribution(self): + batch = "train_step=3/model_step=0/producer_future_step=0" + events = [ + make_batch_event( + "gsm8k:10", + "xtuner.task.registered", + 1.0, + uid=10, + batch=batch, + train_step=3, + status="pending", + ), + make_batch_event( 
+ "gsm8k:11", + "xtuner.rollout_controller.generate.start", + 2.0, + uid=11, + batch=batch, + train_step=3, + ), + make_batch_event( + "gsm8k:12", + "xtuner.rollout_controller.generate.start", + 3.0, + uid=12, + batch=batch, + train_step=3, + ), + make_batch_event( + "gsm8k:12", + "xtuner.rollout_controller.generate.end", + 5.0, + uid=12, + batch=batch, + train_step=3, + elapsed_s=2.0, + status="completed", + ), + make_batch_event( + "gsm8k:13", + "xtuner.judger.judge.start", + 6.0, + uid=13, + batch=batch, + train_step=3, + ), + make_batch_event( + "gsm8k:13", + "xtuner.judger.judge.error", + 9.0, + uid=13, + batch=batch, + train_step=3, + elapsed_s=3.0, + status="failed", + error_msg="RuntimeError: judge failed", + ), + ] + + payload = build_unified_trace_payload_from_events(events, trace_source="/tmp/trace", now_s=12.0) + view = payload["views"]["all"] + + self.assertEqual( + view["overview"], + { + "total_tasks": 4, + "pending_tasks": 1, + "running_tasks": 1, + "completed_tasks": 1, + "failed_tasks": 1, + }, + ) + + rows = {row["trace_id"]: row for row in view["task_rows"]} + self.assertEqual(rows["gsm8k:10"]["status"], "pending") + self.assertEqual(rows["gsm8k:10"]["current_stage"], "pending") + self.assertEqual(rows["gsm8k:11"]["status"], "running") + self.assertEqual(rows["gsm8k:11"]["current_stage"], "rollout.generate") + + stage_stats = {item["stage"]: item for item in view["stage_stats"]} + self.assertEqual(stage_stats["pending"]["current_tasks"], 1) + self.assertEqual(stage_stats["pending"]["pending_tasks"], 1) + self.assertEqual(stage_stats["rollout.generate"]["current_tasks"], 1) + self.assertEqual(stage_stats["rollout.generate"]["running_tasks"], 1) + self.assertEqual(stage_stats["rollout.generate"]["done_spans"], 1) + self.assertEqual(stage_stats["rollout.generate"]["avg_s"], 2.0) + self.assertEqual(stage_stats["judger"]["error_count"], 1) + + def test_trace_config_defaults_to_otlp_export(self): + config = TraceConfig() + + 
self.assertEqual(config.otel_endpoint, "http://127.0.0.1:4317") + self.assertEqual(config.otel_service_name, "xtuner-rl") + + def test_unified_payload_builds_nested_spans_and_stage_stats(self): + old_batch = "train_step=1/model_step=0/producer_future_step=0" + latest_batch = "train_step=2/model_step=0/producer_future_step=0" + events = [ + make_batch_event( + "gsm8k:1", + "xtuner.producer.generate_group.start", + 0.0, + uid=1, + batch=old_batch, + train_step=1, + ), + make_batch_event( + "gsm8k:1", + "xtuner.producer.generate_group.end", + 1.0, + uid=1, + batch=old_batch, + train_step=1, + elapsed_s=1.0, + ), + make_batch_event( + "gsm8k:2", + "xtuner.producer.generate_group.start", + 100.0, + uid=2, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:2", + "xtuner.agent_loop.generate_group.start", + 101.0, + uid=2, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:2", + "xtuner.agent_loop.generate_sample.start", + 102.0, + uid=2, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:2", + "xtuner.rollout_controller.generate.start", + 103.0, + uid=2, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:2", + "xtuner.rollout_worker.generate.start", + 104.0, + uid=2, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:2", + "xtuner.rollout_engine.generate.start", + 105.0, + uid=2, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:2", + "xtuner.rollout_engine.generate.end", + 108.0, + uid=2, + batch=latest_batch, + train_step=2, + elapsed_s=3.0, + ), + make_batch_event( + "gsm8k:2", + "xtuner.rollout_worker.generate.end", + 109.0, + uid=2, + batch=latest_batch, + train_step=2, + elapsed_s=5.0, + ), + make_batch_event( + "gsm8k:2", + "xtuner.rollout_controller.generate.end", + 110.0, + uid=2, + batch=latest_batch, + train_step=2, + elapsed_s=7.0, + ), + make_batch_event( + "gsm8k:2", + "xtuner.judger.judge.start", + 111.0, + uid=2, + 
batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:2", + "xtuner.judger.judge.end", + 115.0, + uid=2, + batch=latest_batch, + train_step=2, + elapsed_s=4.0, + ), + make_batch_event( + "gsm8k:2", + "xtuner.agent_loop.generate_sample.end", + 116.0, + uid=2, + batch=latest_batch, + train_step=2, + elapsed_s=14.0, + ), + make_batch_event( + "gsm8k:2", + "xtuner.agent_loop.generate_group.end", + 117.0, + uid=2, + batch=latest_batch, + train_step=2, + elapsed_s=16.0, + ), + make_batch_event( + "gsm8k:2", + "xtuner.producer.generate_group.end", + 118.0, + uid=2, + batch=latest_batch, + train_step=2, + elapsed_s=18.0, + ), + make_batch_event( + "gsm8k:3", + "xtuner.rollout_controller.generate.start", + 1000.0, + uid=3, + batch=latest_batch, + train_step=2, + ), + make_batch_event( + "gsm8k:3", + "xtuner.rollout_controller.generate.end", + 1010.0, + uid=3, + batch=latest_batch, + train_step=2, + elapsed_s=10.0, + ), + ] + + payload = build_unified_trace_payload_from_events( + events, + trace_source="/tmp/trace", + default_scope="latest-produce-batch", + ) + latest_view = payload["views"]["latest-produce-batch"] + all_view = payload["views"]["all"] + + self.assertEqual(latest_view["overview"]["total_tasks"], 2) + self.assertEqual(all_view["overview"]["total_tasks"], 3) + + nested_spans = latest_view["task_details"]["gsm8k:2"]["timeline_spans"] + self.assertEqual( + [span["display_stage"] for span in nested_spans], + [ + "producer.generate", + "agent_loop.generate_group", + "agent_loop.generate_sample", + "rollout.generate", + "rollout_worker.generate", + "engine.generate", + "judger", + ], + ) + self.assertEqual([span["depth"] for span in nested_spans], [0, 1, 2, 3, 4, 5, 3]) + self.assertEqual(nested_spans[0]["left_pct"], 0.0) + self.assertEqual(latest_view["task_details"]["gsm8k:3"]["timeline_spans"][0]["left_pct"], 0.0) + + stats_by_stage = {stat["stage"]: stat for stat in latest_view["stage_stats"]} + 
self.assertEqual(stats_by_stage["engine.generate"]["avg_s"], 3.0) + self.assertEqual(stats_by_stage["engine.generate"]["p95_s"], 3.0) + self.assertEqual(stats_by_stage["engine.generate"]["max_s"], 3.0) + self.assertEqual(stats_by_stage["rollout.generate"]["done_spans"], 2) + self.assertEqual(stats_by_stage["rollout.generate"]["avg_s"], 8.5) + self.assertEqual(stats_by_stage["rollout.generate"]["max_s"], 10.0) + + +class JaegerTraceDashboardTest(unittest.TestCase): + def test_jaeger_dashboard_preserves_stage_kind_attribute(self): + trace_id = "b" * 32 + raw_trace = make_jaeger_trace( + [ + make_jaeger_span( + trace_id=trace_id, + span_id="llm-forward", + operation="xtuner.session_server.forward_worker", + process_id="p-xtuner", + start_us=1_000_000, + duration_us=200_000, + tags={ + "xtuner.trace_id": "gsm8k:llm", + "case.id": "gsm8k:llm", + "xtuner.stage": "xtuner.session_server.forward_worker", + "xtuner.stage_event": "end", + "xtuner.task_name": "gsm8k", + "xtuner.uid": "llm", + "xtuner.stage.kind": "llm_call", + "xtuner.attr.task_id": "task-llm", + }, + ) + ] + ) + + payload = build_dashboard_payload_from_jaeger_traces( + [raw_trace], + service_name="xtuner-rl", + now_s=2.0, + ) + + detail = payload["views"]["all"]["task_details"]["gsm8k:llm"] + self.assertTrue( + any( + event["stage"] == "xtuner.session_server.forward_worker.end" + and (event.get("attributes") or {}).get("xtuner.stage.kind") == "llm_call" + and (event.get("attributes") or {}).get("task_id") == "task-llm" + for event in detail["timeline_events"] + ) + ) + + def test_jaeger_dashboard_reports_running_completed_and_stage_durations(self): + spans = [ + make_jaeger_span( + trace_id="a" * 32, + span_id="marker-running", + operation="xtuner.localhost.infer.start", + process_id="p-xtuner", + start_us=1_000_000, + duration_us=1, + tags={ + "xtuner.lifecycle_marker": True, + "xtuner.trace_id": "task:running", + "case.id": "task:running", + "xtuner.stage": "xtuner.localhost.infer", + "xtuner.stage_event": 
"start", + "xtuner.task_name": "agent", + "xtuner.uid": "1", + "xtuner.status": "running", + "run.id": "run-a", + }, + ), + make_jaeger_span( + trace_id="a" * 32, + span_id="marker-completed", + operation="xtuner.localhost.judger.start", + process_id="p-xtuner", + start_us=2_000_000, + duration_us=1, + tags={ + "xtuner.lifecycle_marker": True, + "xtuner.trace_id": "task:completed", + "case.id": "task:completed", + "xtuner.stage": "xtuner.localhost.judger", + "xtuner.stage_event": "start", + "xtuner.task_name": "agent", + "xtuner.uid": "2", + "xtuner.status": "running", + "run.id": "run-a", + }, + ), + make_jaeger_span( + trace_id="a" * 32, + span_id="judger-completed", + operation="xtuner.localhost.judger", + process_id="p-xtuner", + start_us=2_000_000, + duration_us=4_000_000, + tags={ + "xtuner.trace_id": "task:completed", + "case.id": "task:completed", + "xtuner.stage": "xtuner.localhost.judger", + "xtuner.stage_event": "end", + "xtuner.task_name": "agent", + "xtuner.uid": "2", + "xtuner.status": "completed", + "run.id": "run-a", + "first_output_token_from_forward_ms": 350.0, + "stream_complete_from_forward_ms": 600.0, + }, + ), + make_jaeger_span( + trace_id="a" * 32, + span_id="toolcall-linked-service", + operation="xtuner.sandbox.toolcall", + process_id="p-lagent", + start_us=2_500_000, + duration_us=2_000_000, + tags={ + "xtuner.trace_id": "task:completed", + "case.id": "task:completed", + "xtuner.stage": "xtuner.sandbox.toolcall", + "xtuner.stage_event": "end", + "xtuner.task_name": "agent", + "xtuner.uid": "2", + "xtuner.status": "completed", + "run.id": "run-a", + }, + parent_span_id="judger-completed", + ), + make_jaeger_span( + trace_id="a" * 32, + span_id="old-run-span", + operation="xtuner.localhost.judger", + process_id="p-xtuner", + start_us=1_000_000, + duration_us=60_000_000, + tags={ + "xtuner.trace_id": "task:completed", + "case.id": "task:completed", + "xtuner.stage": "xtuner.localhost.judger", + "xtuner.stage_event": "end", + "xtuner.status": 
"completed", + "run.id": "run-b", + }, + ), + ] + + payload = build_dashboard_payload_from_jaeger_traces( + [make_jaeger_trace(spans)], + service_name="xtuner-rl", + run_id="run-a", + jaeger_query_url="http://127.0.0.1:16686", + now_s=10.0, + ) + + view = payload["views"]["all"] + self.assertEqual( + view["overview"], + { + "total_tasks": 2, + "pending_tasks": 0, + "running_tasks": 1, + "completed_tasks": 1, + "failed_tasks": 0, + }, + ) + + rows = {row["trace_id"]: row for row in view["task_rows"]} + self.assertEqual(rows["task:running"]["status"], "running") + self.assertEqual(rows["task:running"]["current_stage"], "localhost.infer") + self.assertEqual(rows["task:completed"]["status"], "completed") + + stage_stats = {item["stage"]: item for item in view["stage_stats"]} + self.assertEqual(stage_stats["localhost.infer"]["running_tasks"], 1) + self.assertEqual(stage_stats["localhost.judger"]["done_spans"], 1) + self.assertEqual(stage_stats["localhost.judger"]["avg_s"], 4.0) + self.assertEqual(stage_stats["localhost.judger"]["p95_s"], 4.0) + self.assertEqual(stage_stats["localhost.judger"]["max_s"], 4.0) + self.assertEqual(stage_stats["sandbox.toolcall"]["done_spans"], 1) + self.assertNotEqual(stage_stats["localhost.judger"]["max_s"], 60.0) + detail = view["task_details"]["task:completed"] + latency_events = [ + event + for event in detail["timeline_events"] + if (event.get("attributes") or {}).get("first_output_token_from_forward_ms") == 350.0 + ] + self.assertTrue(latency_events) + self.assertEqual(latency_events[0]["attributes"]["stream_complete_from_forward_ms"], 600.0) + + def test_jaeger_dashboard_filters_to_service_when_run_id_is_missing(self): + spans = [ + make_jaeger_span( + trace_id="a" * 32, + span_id="xtuner-span", + operation="xtuner.localhost.judger", + process_id="p-xtuner", + start_us=1_000_000, + duration_us=1_000_000, + tags={ + "xtuner.trace_id": "task:1", + "case.id": "task:1", + "xtuner.stage": "xtuner.localhost.judger", + "xtuner.stage_event": 
"end", + "xtuner.status": "completed", + }, + ), + make_jaeger_span( + trace_id="a" * 32, + span_id="linked-without-run-id", + operation="xtuner.sandbox.toolcall", + process_id="p-lagent", + start_us=1_000_000, + duration_us=9_000_000, + tags={ + "xtuner.trace_id": "task:1", + "case.id": "task:1", + "xtuner.stage": "xtuner.sandbox.toolcall", + "xtuner.stage_event": "end", + "xtuner.status": "completed", + }, + ), + ] + + payload = build_dashboard_payload_from_jaeger_traces( + [make_jaeger_trace(spans)], + service_name="xtuner-rl", + now_s=3.0, + ) + + stage_stats = {item["stage"]: item for item in payload["views"]["all"]["stage_stats"]} + self.assertIn("localhost.judger", stage_stats) + self.assertNotIn("sandbox.toolcall", stage_stats) + + +class TraceTrainerIntegrationTest(unittest.TestCase): + def tearDown(self): + reset_trace_for_test() + + def test_trainer_logs_jaeger_viewer_url_on_rank0(self): + with patch.dict("sys.modules", {"causal_conv1d_cuda": MagicMock()}): + from xtuner.v1.train import rl_trainer + + BaseRLTrainer = rl_trainer.BaseRLTrainer + + with tempfile.TemporaryDirectory() as tmp_dir: + trainer = object.__new__(BaseRLTrainer) + trainer._meta = SimpleNamespace(latest_exp=SimpleNamespace(exp_dir=tmp_dir)) + trainer.logger = MagicMock() + handle = SimpleNamespace(url="http://127.0.0.1:12345", close=MagicMock()) + cfg = SimpleNamespace( + trace_config=TraceConfig( + enabled=True, + jaeger_query_url="http://127.0.0.1:16686", + ) + ) + + with ( + patch.object(rl_trainer, "get_rank", return_value=0), + patch.object(rl_trainer, "get_trace_run_id", return_value="run-a"), + patch("xtuner.v1.rl.trace.OtelTraceSink", return_value=RecordingTraceSink()), + patch( + "xtuner.tools.jaeger_trace_dashboard.start_jaeger_trace_dashboard", + return_value=handle, + ) as start_dashboard, + ): + trainer._init_trace(cfg) + trainer._close_trace() + + trainer.logger.info.assert_any_call("Jaeger Trace Viewer: http://127.0.0.1:16686") + 
trainer.logger.info.assert_any_call("XTuner Task Trace Dashboard: http://127.0.0.1:12345") + start_dashboard.assert_called_once_with( + "http://127.0.0.1:16686", + service_name="xtuner-rl", + run_id="run-a", + ) + handle.close.assert_called_once() diff --git a/xtuner/tools/jaeger_trace_dashboard.py b/xtuner/tools/jaeger_trace_dashboard.py new file mode 100644 index 0000000000..cf471e6391 --- /dev/null +++ b/xtuner/tools/jaeger_trace_dashboard.py @@ -0,0 +1,593 @@ +from __future__ import annotations + +import argparse +import dataclasses +import http.server +import json +import threading +from collections.abc import Iterable +from pathlib import Path +from typing import Any +from urllib.parse import parse_qs, urlencode, urlparse +from urllib.request import Request, urlopen + +from xtuner.tools.task_trace_analysis import ( + TRACE_VIEWER_SCOPE_ALL, + build_unified_trace_payload_from_events, +) +from xtuner.tools.task_trace_view import ( + fetch_jaeger_trace_payload, + normalize_jaeger_query_url, + normalize_jaeger_trace_response, + render_unified_trace_html, +) +from xtuner.v1.rl.trace import TraceEvent + + +JAEGER_QUERY_TIMEOUT_S = 3.0 +JAEGER_DEFAULT_LOOKBACK_S = 60 * 60 +JAEGER_DEFAULT_LIMIT = 500 +_MICROSECONDS_PER_SECOND = 1_000_000.0 +_SESSION_LATENCY_ATTRS = frozenset( + { + "upstream_headers_ms", + "first_chunk_ms", + "first_chunk_from_forward_ms", + "first_output_token_ms", + "first_output_token_from_forward_ms", + "first_content_ms", + "first_content_from_forward_ms", + "stream_read_ms", + "stream_complete_from_forward_ms", + "chunks", + "raw_response_bytes", + "output_tokens", + "prompt_tokens", + "completion_tokens", + "total_tokens", + "finish_reason", + "client_alive", + "http_status", + } +) +_TASK_TIMELINE_ATTRS = _SESSION_LATENCY_ATTRS | frozenset( + { + "xtuner.stage.kind", + "task_id", + "sandbox_name", + "sandbox_env_id", + "sandbox_url", + "sandbox_image", + "entry_kind", + "entry_name", + "validate_name", + "judger_name", + "reward_key", + 
"stage_name", + "agent_name", + "agent_config", + "agent_status", + "agent_message_count", + "agent_has_tools", + "input_tokens", + "label_tokens", + } +) + + +@dataclasses.dataclass +class JaegerTraceDashboardHandle: + server: http.server.ThreadingHTTPServer + thread: threading.Thread + url: str + closed: bool = False + + def close(self) -> None: + if self.closed: + return + self.closed = True + self.server.shutdown() + self.server.server_close() + self.thread.join(timeout=5) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Serve an XTuner task dashboard backed by Jaeger Query API.") + parser.add_argument( + "--jaeger-query-url", + default="http://127.0.0.1:16686", + help="Jaeger Query/UI base URL.", + ) + parser.add_argument( + "--service", + default="xtuner-rl", + help="OpenTelemetry service.name to query from Jaeger.", + ) + parser.add_argument( + "--run-id", + default=None, + help="Optional run.id attribute used to isolate one training run.", + ) + parser.add_argument("--host", default="127.0.0.1", help="Dashboard bind host.") + parser.add_argument("--port", type=int, default=0, help="Dashboard bind port. 
Defaults to an available port.") + parser.add_argument( + "--refresh-interval", + type=float, + default=2.0, + help="Browser refresh interval in seconds.", + ) + parser.add_argument( + "--lookback", + type=int, + default=JAEGER_DEFAULT_LOOKBACK_S, + help="Jaeger query lookback window in seconds.", + ) + parser.add_argument( + "--limit", + type=int, + default=JAEGER_DEFAULT_LIMIT, + help="Maximum number of traces fetched from Jaeger per refresh.", + ) + parser.add_argument( + "-o", + "--output", + type=Path, + default=None, + help="Write a static dashboard HTML snapshot instead of serving.", + ) + return parser.parse_args() + + +def fetch_jaeger_dashboard_payload( + jaeger_query_url: str, + *, + service_name: str, + run_id: str | None = None, + lookback_s: int = JAEGER_DEFAULT_LOOKBACK_S, + limit: int = JAEGER_DEFAULT_LIMIT, + timeout_s: float = JAEGER_QUERY_TIMEOUT_S, + now_s: float | None = None, +) -> dict[str, Any]: + traces = fetch_jaeger_traces( + jaeger_query_url, + service_name=service_name, + lookback_s=lookback_s, + limit=limit, + timeout_s=timeout_s, + ) + return build_dashboard_payload_from_jaeger_traces( + traces, + service_name=service_name, + run_id=run_id, + jaeger_query_url=jaeger_query_url, + lookback_s=lookback_s, + limit=limit, + now_s=now_s, + ) + + +def fetch_jaeger_traces( + jaeger_query_url: str, + *, + service_name: str, + lookback_s: int = JAEGER_DEFAULT_LOOKBACK_S, + limit: int = JAEGER_DEFAULT_LIMIT, + timeout_s: float = JAEGER_QUERY_TIMEOUT_S, +) -> list[dict[str, Any]]: + base_url = _require_jaeger_query_url(jaeger_query_url) + query = urlencode( + { + "service": service_name, + "lookback": f"{max(1, lookback_s)}s", + "limit": max(1, limit), + } + ) + request = Request(f"{base_url}/api/traces?{query}", headers={"Accept": "application/json"}) + with urlopen(request, timeout=timeout_s) as response: + raw = json.loads(response.read().decode("utf-8")) + data = raw.get("data") if isinstance(raw, dict) else None + return data if 
isinstance(data, list) else [] + + +def build_dashboard_payload_from_jaeger_traces( + traces: Iterable[dict[str, Any]], + *, + service_name: str, + run_id: str | None = None, + jaeger_query_url: str | None = None, + lookback_s: int = JAEGER_DEFAULT_LOOKBACK_S, + limit: int = JAEGER_DEFAULT_LIMIT, + now_s: float | None = None, +) -> dict[str, Any]: + events = jaeger_traces_to_trace_events(traces, service_name=service_name, run_id=run_id) + payload = build_unified_trace_payload_from_events( + events, + trace_source=f"Jaeger service={service_name}", + default_scope=TRACE_VIEWER_SCOPE_ALL, + jaeger_query_url=jaeger_query_url, + now_s=now_s, + ) + payload["title"] = "XTuner Task Trace Dashboard" + payload["source"] = "jaeger" + payload["jaeger_service"] = service_name + payload["jaeger_run_id"] = run_id + payload["jaeger_lookback_s"] = lookback_s + payload["jaeger_limit"] = limit + return payload + + +def jaeger_traces_to_trace_events( + traces: Iterable[dict[str, Any]], + *, + service_name: str, + run_id: str | None = None, +) -> list[TraceEvent]: + normalized_spans = _normalize_dashboard_spans(traces) + included_spans = _select_service_run_spans(normalized_spans, service_name=service_name, run_id=run_id) + marker_starts = { + _event_key(span) + for span in included_spans + if _is_lifecycle_marker(span) and _span_lifecycle(span) == "start" + } + + events: list[TraceEvent] = [] + seen: set[tuple[str, str, float, int | str | None]] = set() + for span in sorted(included_spans, key=lambda item: (item["start_time_us"], item["duration_us"])): + if _is_lifecycle_marker(span): + lifecycle = _span_lifecycle(span) + if lifecycle is None: + continue + _append_event(events, seen, _event_from_span(span, suffix=lifecycle)) + continue + + lifecycle = _span_lifecycle(span) + if lifecycle not in {"end", "error"}: + lifecycle = "error" if span["is_error"] else "end" + if _event_key(span) not in marker_starts: + _append_event(events, seen, _event_from_span(span, suffix="start", 
at_start=True)) + _append_event(events, seen, _event_from_span(span, suffix=lifecycle)) + + return sorted(events, key=lambda event: (event.timestamp_s, event.trace_id, event.stage)) + + +def start_jaeger_trace_dashboard( + jaeger_query_url: str, + *, + service_name: str, + run_id: str | None = None, + host: str = "127.0.0.1", + port: int = 0, + refresh_interval_s: float = 2.0, + lookback_s: int = JAEGER_DEFAULT_LOOKBACK_S, + limit: int = JAEGER_DEFAULT_LIMIT, +) -> JaegerTraceDashboardHandle: + jaeger_query_url = _require_jaeger_query_url(jaeger_query_url) + + class JaegerDashboardHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self) -> None: + path = self.path.split("?", 1)[0] + if path in {"/", "/index.html"}: + payload = self._build_payload() + html = render_unified_trace_html( + payload, + live=True, + api_url="/api/trace", + refresh_interval_s=refresh_interval_s, + ) + self._send_bytes(html.encode("utf-8"), "text/html; charset=utf-8") + return + + if path == "/api/trace": + self._send_json(self._build_payload()) + return + + if path == "/api/services": + self._send_json({"services": fetch_jaeger_services(jaeger_query_url)}) + return + + if path.startswith("/api/jaeger/trace/"): + otel_trace_id = path.rsplit("/", 1)[-1] + self._send_json(fetch_jaeger_trace_payload(jaeger_query_url, otel_trace_id)) + return + + self.send_error(404) + + def _build_payload(self) -> dict[str, Any]: + query = parse_qs(urlparse(self.path).query) + selected_service = query.get("service", [service_name])[0] or service_name + selected_run_id = query.get("run_id", [run_id or ""])[0] or None + selected_lookback = _parse_positive_int(query.get("lookback", [str(lookback_s)])[0], lookback_s) + selected_limit = _parse_positive_int(query.get("limit", [str(limit)])[0], limit) + return fetch_jaeger_dashboard_payload( + jaeger_query_url, + service_name=selected_service, + run_id=selected_run_id, + lookback_s=selected_lookback, + limit=selected_limit, + ) + + def _send_json(self, 
payload: dict[str, Any], status: int = 200) -> None: + body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + self._send_bytes(body, "application/json; charset=utf-8", status=status) + + def _send_bytes(self, body: bytes, content_type: str, *, status: int = 200) -> None: + self.send_response(status) + self.send_header("Content-Type", content_type) + self.send_header("Content-Length", str(len(body))) + self.send_header("Cache-Control", "no-store") + self.end_headers() + self.wfile.write(body) + + def log_message(self, format: str, *args: Any) -> None: + return + + server = http.server.ThreadingHTTPServer((host, port), JaegerDashboardHandler) + server_host, server_port = server.server_address + display_host = "127.0.0.1" if server_host in {"", "0.0.0.0"} else server_host + thread = threading.Thread(target=server.serve_forever, name="JaegerTraceDashboard", daemon=True) + thread.start() + return JaegerTraceDashboardHandle( + server=server, + thread=thread, + url=f"http://{display_host}:{server_port}", + ) + + +def serve_jaeger_trace_dashboard( + jaeger_query_url: str, + *, + service_name: str, + run_id: str | None = None, + host: str = "127.0.0.1", + port: int = 0, + refresh_interval_s: float = 2.0, + lookback_s: int = JAEGER_DEFAULT_LOOKBACK_S, + limit: int = JAEGER_DEFAULT_LIMIT, +) -> None: + handle = start_jaeger_trace_dashboard( + jaeger_query_url, + service_name=service_name, + run_id=run_id, + host=host, + port=port, + refresh_interval_s=refresh_interval_s, + lookback_s=lookback_s, + limit=limit, + ) + print(f"Serving XTuner Task Trace Dashboard on {handle.url}", flush=True) + print(f"Jaeger Query: {jaeger_query_url}", flush=True) + print(f"Service: {service_name}", flush=True) + if run_id: + print(f"Run ID: {run_id}", flush=True) + try: + handle.thread.join() + except KeyboardInterrupt: + pass + finally: + handle.close() + + +def write_dashboard_html(payload: dict[str, Any], output_path: Path) -> None: + 
output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(render_unified_trace_html(payload, live=False), encoding="utf-8") + + +def fetch_jaeger_services( + jaeger_query_url: str, + *, + timeout_s: float = JAEGER_QUERY_TIMEOUT_S, +) -> list[str]: + base_url = _require_jaeger_query_url(jaeger_query_url) + request = Request(f"{base_url}/api/services", headers={"Accept": "application/json"}) + with urlopen(request, timeout=timeout_s) as response: + raw = json.loads(response.read().decode("utf-8")) + services = raw.get("data") if isinstance(raw, dict) else None + return sorted(str(service) for service in services) if isinstance(services, list) else [] + + +def _normalize_dashboard_spans(traces: Iterable[dict[str, Any]]) -> list[dict[str, Any]]: + spans: list[dict[str, Any]] = [] + for trace in traces: + trace_id = str(trace.get("traceID") or "") + if not trace_id: + continue + normalized = normalize_jaeger_trace_response({"data": [trace]}, trace_id) + for span in normalized.get("spans", []): + if isinstance(span, dict): + spans.append(span) + return spans + + +def _select_service_run_spans( + spans: list[dict[str, Any]], + *, + service_name: str, + run_id: str | None = None, +) -> list[dict[str, Any]]: + selected_service_spans = [span for span in spans if span.get("service_name") == service_name] + if run_id is not None: + selected_trace_ids = { + span.get("trace_id") + for span in selected_service_spans + if isinstance(span.get("tags"), dict) and str(span["tags"].get("run.id")) == run_id + } + return [ + span + for span in spans + if span.get("trace_id") in selected_trace_ids + and isinstance(span.get("tags"), dict) + and str(span["tags"].get("run.id")) == run_id + ] + + run_ids = { + str(span["tags"]["run.id"]) + for span in selected_service_spans + if isinstance(span.get("tags"), dict) and span["tags"].get("run.id") is not None + } + if not run_ids: + return selected_service_spans + + selected_trace_ids = {span.get("trace_id") for span in 
selected_service_spans} + return [ + span + for span in spans + if span.get("trace_id") in selected_trace_ids + and isinstance(span.get("tags"), dict) + and str(span["tags"].get("run.id")) in run_ids + ] + + +def _event_from_span(span: dict[str, Any], *, suffix: str, at_start: bool = False) -> TraceEvent: + tags = span.get("tags") if isinstance(span.get("tags"), dict) else {} + start_s = float(span.get("start_time_us") or 0) / _MICROSECONDS_PER_SECOND + duration_s = max(0.0, float(span.get("duration_us") or 0) / _MICROSECONDS_PER_SECOND) + timestamp_s = start_s if at_start or suffix == "start" else start_s + duration_s + stage = _span_stage(span) + error_msg = _span_error_message(span) + return TraceEvent( + trace_id=_logical_trace_id(span), + stage=f"{stage}.{suffix}", + timestamp_s=timestamp_s, + status=_string_or_none(tags.get("xtuner.status")), + task_name=_string_or_none(tags.get("xtuner.task_name") or tags.get("task.name")), + uid=tags.get("xtuner.uid"), + session_uid=tags.get("xtuner.session_uid"), + train_step=_int_or_none(tags.get("xtuner.train_step")), + model_step=_int_or_none(tags.get("xtuner.model_step")), + producer_future_step=_int_or_none(tags.get("xtuner.producer_future_step")), + produce_batch_id=_string_or_none(tags.get("xtuner.produce_batch_id")), + worker_rank=_int_or_none(tags.get("xtuner.worker_rank")), + elapsed_s=None if suffix == "start" else duration_s, + error_msg=error_msg if suffix == "error" else None, + error_type=_string_or_none(tags.get("error.type")) if suffix == "error" else None, + attributes=_event_attributes_from_tags(tags), + ) + + +def _append_event( + events: list[TraceEvent], + seen: set[tuple[str, str, float, int | str | None]], + event: TraceEvent, +) -> None: + key = (event.trace_id, event.stage, event.timestamp_s, event.uid) + if key in seen: + return + seen.add(key) + events.append(event) + + +def _event_key(span: dict[str, Any]) -> tuple[str, str, float]: + return ( + _logical_trace_id(span), + _span_stage(span), + 
float(span.get("start_time_us") or 0) / _MICROSECONDS_PER_SECOND, + ) + + +def _span_stage(span: dict[str, Any]) -> str: + tags = span.get("tags") if isinstance(span.get("tags"), dict) else {} + stage = tags.get("xtuner.stage") or span.get("operation_name") + return str(stage or "unknown") + + +def _span_lifecycle(span: dict[str, Any]) -> str | None: + tags = span.get("tags") if isinstance(span.get("tags"), dict) else {} + lifecycle = tags.get("xtuner.stage_event") + if lifecycle in {"start", "end", "error", "event"}: + return str(lifecycle) + operation_name = str(span.get("operation_name") or "") + for suffix in ("start", "end", "error"): + if operation_name.endswith(f".{suffix}"): + return suffix + return None + + +def _is_lifecycle_marker(span: dict[str, Any]) -> bool: + tags = span.get("tags") if isinstance(span.get("tags"), dict) else {} + return tags.get("xtuner.lifecycle_marker") is True + + +def _logical_trace_id(span: dict[str, Any]) -> str: + tags = span.get("tags") if isinstance(span.get("tags"), dict) else {} + value = tags.get("xtuner.trace_id") or tags.get("case.id") or span.get("trace_id") + return str(value or "unknown") + + +def _span_error_message(span: dict[str, Any]) -> str | None: + tags = span.get("tags") if isinstance(span.get("tags"), dict) else {} + for key in ("xtuner.error.message", "error.message", "otel.status_description"): + value = tags.get(key) + if value: + return str(value) + value = span.get("error_msg") + return str(value) if value else None + + +def _event_attributes_from_tags(tags: dict[str, Any]) -> dict[str, Any] | None: + attrs = {key: tags[key] for key in _TASK_TIMELINE_ATTRS if key in tags} + for key, value in tags.items(): + if key.startswith("xtuner.attr."): + attrs[key.removeprefix("xtuner.attr.")] = value + return attrs or None + + +def _string_or_none(value: Any) -> str | None: + if value is None: + return None + return str(value) + + +def _int_or_none(value: Any) -> int | None: + if value is None: + return None + if 
isinstance(value, bool): + return None + try: + return int(value) + except (TypeError, ValueError, OverflowError): + return None + + +def _parse_positive_int(value: str, default: int) -> int: + try: + parsed = int(value) + except (TypeError, ValueError): + return default + return parsed if parsed > 0 else default + + +def _require_jaeger_query_url(jaeger_query_url: str | None) -> str: + normalized = normalize_jaeger_query_url(jaeger_query_url) + if normalized is None: + raise ValueError("jaeger_query_url is required") + return normalized + + +def main() -> None: + args = parse_args() + if args.output is None: + serve_jaeger_trace_dashboard( + args.jaeger_query_url, + service_name=args.service, + run_id=args.run_id, + host=args.host, + port=args.port, + refresh_interval_s=args.refresh_interval, + lookback_s=args.lookback, + limit=args.limit, + ) + return + + payload = fetch_jaeger_dashboard_payload( + args.jaeger_query_url, + service_name=args.service, + run_id=args.run_id, + lookback_s=args.lookback, + limit=args.limit, + ) + write_dashboard_html(payload, args.output) + print(args.output) + + +if __name__ == "__main__": + main() diff --git a/xtuner/tools/task_trace_analysis.py b/xtuner/tools/task_trace_analysis.py new file mode 100644 index 0000000000..65faf848c8 --- /dev/null +++ b/xtuner/tools/task_trace_analysis.py @@ -0,0 +1,622 @@ +from __future__ import annotations + +import dataclasses +import time +from collections import Counter, defaultdict +from typing import Any, Iterable, Literal + +from xtuner.v1.rl.trace import TraceEvent, stable_otel_trace_id + + +TraceViewerScope = Literal["latest-produce-batch", "all"] +TRACE_VIEWER_SCOPE_LATEST_PRODUCE_BATCH: TraceViewerScope = "latest-produce-batch" +TRACE_VIEWER_SCOPE_ALL: TraceViewerScope = "all" +TRACE_STAGE_LABELS = { + "xtuner.producer.sample_group": "sampler", + "xtuner.producer.generate_group": "producer.generate", + "xtuner.producer.put_generated_group": "producer.put", + 
"xtuner.agent_loop.generate_group": "agent_loop.generate_group", + "xtuner.agent_loop.generate_sample": "agent_loop.generate_sample", + "xtuner.agent_in_sandbox.generate_group": "sandbox.generate_group", + "xtuner.agent_in_sandbox.generate_sample": "sandbox.generate_sample", + "xtuner.agent_in_sandbox.materialize_trajectory": "sandbox.materialize", + "xtuner.agent_in_localhost.generate_group": "localhost.generate_group", + "xtuner.agent_in_localhost.generate_sample": "localhost.generate_sample", + "xtuner.agent_in_localhost.materialize_trajectory": "localhost.materialize", + "xtuner.rollout_controller.generate": "rollout.generate", + "xtuner.rollout_worker.generate": "rollout_worker.generate", + "xtuner.rollout_engine.generate": "engine.generate", + "xtuner.judger.judge": "judger", + "xtuner.sandbox.run_total": "sandbox.run_total", + "xtuner.sandbox.acquire": "sandbox.acquire", + "xtuner.sandbox.infer": "sandbox.infer", + "xtuner.sandbox.validate": "sandbox.validate", + "xtuner.localhost.run_total": "localhost.run_total", + "xtuner.localhost.infer": "localhost.infer", + "xtuner.localhost.validate": "localhost.validate", + "xtuner.localhost.judger": "localhost.judger", + "xtuner.localhost.agent": "localhost.agent", + "xtuner.session_server.forward_worker": "llm.forward", + "xtuner.session_server.stream_read": "llm.stream", +} + +TRACE_STAGE_PALETTE = [ + "#2563eb", + "#059669", + "#d97706", + "#7c3aed", + "#0891b2", + "#4d7c0f", + "#9333ea", + "#0f766e", + "#475569", + "#b45309", + "#1d4ed8", + "#0d9488", +] +TRACE_ERROR_COLOR = "#dc2626" + + +@dataclasses.dataclass +class TraceSpanRecord: + trace_id: str + span: str + display_stage: str + start_s: float + end_s: float + duration_s: float + outcome: str + depth: int = 0 + task_name: str | None = None + uid: int | str | None = None + status: str | None = None + train_step: int | None = None + worker_rank: int | None = None + error_msg: str | None = None + + +def display_trace_stage(span: str | None) -> str: + if not 
span: + return "unknown" + if span in TRACE_STAGE_LABELS: + return TRACE_STAGE_LABELS[span] + if span.startswith("xtuner.") and span.endswith(".request"): + return span.removeprefix("xtuner.") + return span.removeprefix("xtuner.") + + +def events_to_timelines(events: Iterable[TraceEvent]) -> dict[str, list[TraceEvent]]: + timelines: dict[str, list[TraceEvent]] = defaultdict(list) + for event in events: + timelines[event.trace_id].append(event) + for trace_id in timelines: + timelines[trace_id].sort(key=lambda event: event.timestamp_s) + return dict(timelines) + + +def filter_trace_events_by_scope( + events: Iterable[TraceEvent], + scope: TraceViewerScope = TRACE_VIEWER_SCOPE_LATEST_PRODUCE_BATCH, +) -> list[TraceEvent]: + event_list = list(events) + if scope == TRACE_VIEWER_SCOPE_ALL: + return event_list + if scope != TRACE_VIEWER_SCOPE_LATEST_PRODUCE_BATCH: + raise ValueError(f"Unsupported trace viewer scope: {scope!r}") + + latest_batch_id = _latest_produce_batch_id(event_list) + if latest_batch_id is not None: + return [event for event in event_list if event.produce_batch_id == latest_batch_id] + + latest_key = _latest_produce_batch_key(event_list) + if latest_key is None: + return event_list + return [event for event in event_list if _produce_batch_key(event) == latest_key] + + +def _latest_produce_batch_key(events: list[TraceEvent]) -> tuple[int, int, int] | None: + keys: list[tuple[int, int, int]] = [] + for event in events: + key = _produce_batch_key(event) + if key is not None: + keys.append(key) + return max(keys) if keys else None + + +def _latest_produce_batch_id(events: list[TraceEvent]) -> str | None: + latest_event: TraceEvent | None = None + latest_sort_key: tuple[int, int, int, float] | None = None + for event in events: + if event.produce_batch_id is None: + continue + batch_key = _produce_batch_key(event) + if batch_key is None: + sort_key = (-1, -1, -1, event.timestamp_s) + else: + sort_key = (*batch_key, event.timestamp_s) + if latest_sort_key 
is None or sort_key > latest_sort_key: + latest_event = event + latest_sort_key = sort_key + if latest_event is None: + return None + return latest_event.produce_batch_id + + +def _produce_batch_key(event: TraceEvent) -> tuple[int, int, int] | None: + if event.train_step is None: + return None + return ( + event.train_step, + -1 if event.model_step is None else event.model_step, + -1 if event.producer_future_step is None else event.producer_future_step, + ) + + +def _get_open_spans(events: list[TraceEvent]) -> list[tuple[str, TraceEvent]]: + stacks: dict[str, list[TraceEvent]] = defaultdict(list) + for event in events: + span, suffix = _split_span_stage(event.stage) + if span is None or suffix is None: + continue + if suffix == "start": + stacks[span].append(event) + elif suffix in {"end", "error"} and stacks.get(span): + stacks[span].pop() + open_spans: list[tuple[str, TraceEvent]] = [] + for span, stack in stacks.items(): + open_spans.extend((span, event) for event in stack) + return open_spans + + +def _split_span_stage(stage: str) -> tuple[str | None, str | None]: + for suffix in (".start", ".end", ".error"): + if stage.endswith(suffix): + return stage[: -len(suffix)], suffix[1:] + return None, None + + +def _percentile(sorted_values: list[float], percentile: float) -> float: + if not sorted_values: + return 0.0 + if len(sorted_values) == 1: + return sorted_values[0] + position = (len(sorted_values) - 1) * percentile + lower = int(position) + upper = min(lower + 1, len(sorted_values) - 1) + if lower == upper: + return sorted_values[lower] + fraction = position - lower + return sorted_values[lower] * (1 - fraction) + sorted_values[upper] * fraction + + +def build_trace_span_records( + events: Iterable[TraceEvent], + *, + include_open: bool = True, + now_s: float | None = None, +) -> list[TraceSpanRecord]: + events_by_trace: dict[str, list[TraceEvent]] = defaultdict(list) + latest_event_s = 0.0 + for event in events: + 
events_by_trace[event.trace_id].append(event) + latest_event_s = max(latest_event_s, event.timestamp_s) + now_s = latest_event_s if now_s is None else now_s + + records: list[TraceSpanRecord] = [] + for trace_id, trace_events in events_by_trace.items(): + stacks: dict[str, list[TraceEvent]] = defaultdict(list) + for event in sorted(trace_events, key=lambda item: item.timestamp_s): + span, suffix = _split_span_stage(event.stage) + if span is None or suffix is None: + continue + if suffix == "start": + stacks[span].append(event) + continue + if suffix in {"end", "error"} and stacks.get(span): + start_event = stacks[span].pop() + records.append(_build_span_record(start_event, event, outcome=suffix)) + continue + if suffix in {"end", "error"} and event.elapsed_s is not None: + records.append(_build_elapsed_only_span_record(event, span=span, outcome=suffix)) + + if include_open: + for span_starts in stacks.values(): + for start_event in span_starts: + span, _ = _split_span_stage(start_event.stage) + if span is None: + continue + records.append(_build_open_span_record(start_event, span=span, now_s=now_s)) + + _assign_depths(records) + return sorted(records, key=lambda record: (record.trace_id, record.start_s, record.depth, record.end_s)) + + +def build_timeline_stage_records(records: Iterable[TraceSpanRecord]) -> list[TraceSpanRecord]: + return sorted(records, key=lambda record: (record.trace_id, record.start_s, record.depth, record.end_s)) + + +def build_stage_colors(records: Iterable[TraceSpanRecord]) -> dict[str, str]: + stages = sorted({record.display_stage for record in records}) + return {stage: TRACE_STAGE_PALETTE[index % len(TRACE_STAGE_PALETTE)] for index, stage in enumerate(stages)} + + +def build_stage_span_stats(records: Iterable[TraceSpanRecord]) -> list[dict[str, Any]]: + grouped: dict[str, list[TraceSpanRecord]] = defaultdict(list) + for record in records: + grouped[record.display_stage].append(record) + + stats = [] + for stage, stage_records in 
grouped.items(): + closed_records = [record for record in stage_records if record.outcome != "open"] + durations = sorted(record.duration_s for record in closed_records) + unique_trace_ids = {record.trace_id for record in stage_records} + total_s = sum(durations) + stats.append( + { + "stage": stage, + "span_count": len(closed_records), + "done_spans": len(closed_records), + "visited_tasks": len(unique_trace_ids), + "open_count": sum(1 for record in stage_records if record.outcome == "open"), + "error_count": sum(1 for record in stage_records if record.outcome == "error"), + "total_s": total_s, + "avg_s": total_s / len(durations) if durations else 0.0, + "p50_s": _percentile(durations, 0.50), + "p95_s": _percentile(durations, 0.95), + "max_s": durations[-1] if durations else 0.0, + } + ) + return sorted(stats, key=lambda item: (item["p95_s"], item["total_s"]), reverse=True) + + +def span_record_to_payload( + record: TraceSpanRecord, + timeline_start_s: float, + timeline_duration_s: float, + stage_colors: dict[str, str], +) -> dict[str, Any]: + left_pct = (record.start_s - timeline_start_s) / timeline_duration_s * 100.0 + width_pct = max(0.2, record.duration_s / timeline_duration_s * 100.0) + return { + **dataclasses.asdict(record), + "left_pct": left_pct, + "width_pct": min(width_pct, max(0.2, 100.0 - left_pct)), + "top_px": record.depth * 22, + "color": TRACE_ERROR_COLOR if record.outcome == "error" else stage_colors[record.display_stage], + } + + +def build_unified_trace_payload_from_events( + events: Iterable[TraceEvent], + *, + trace_source: str, + default_scope: TraceViewerScope = TRACE_VIEWER_SCOPE_ALL, + jaeger_query_url: str | None = None, + now_s: float | None = None, +) -> dict[str, Any]: + event_list = list(events) + if default_scope not in {TRACE_VIEWER_SCOPE_ALL, TRACE_VIEWER_SCOPE_LATEST_PRODUCE_BATCH}: + default_scope = TRACE_VIEWER_SCOPE_ALL + normalized_jaeger_query_url = _normalize_jaeger_query_url(jaeger_query_url) + return { + "generated_at_s": 
time.time() if now_s is None else now_s, + "trace_source": trace_source, + "default_scope": default_scope, + "available_scopes": [TRACE_VIEWER_SCOPE_ALL, TRACE_VIEWER_SCOPE_LATEST_PRODUCE_BATCH], + "jaeger_query_url": normalized_jaeger_query_url, + "jaeger_query_enabled": normalized_jaeger_query_url is not None, + "views": { + TRACE_VIEWER_SCOPE_ALL: _build_unified_trace_view( + event_list, + trace_source=trace_source, + scope=TRACE_VIEWER_SCOPE_ALL, + now_s=now_s, + ), + TRACE_VIEWER_SCOPE_LATEST_PRODUCE_BATCH: _build_unified_trace_view( + event_list, + trace_source=trace_source, + scope=TRACE_VIEWER_SCOPE_LATEST_PRODUCE_BATCH, + now_s=now_s, + ), + }, + } + + +def _build_unified_trace_view( + events: Iterable[TraceEvent], + *, + trace_source: str, + scope: TraceViewerScope, + now_s: float | None = None, +) -> dict[str, Any]: + filtered_events = filter_trace_events_by_scope(events, scope) + timelines = events_to_timelines(filtered_events) + task_rows = _build_unified_task_rows(timelines, now_s=now_s) + span_records = build_timeline_stage_records(build_trace_span_records(filtered_events, now_s=now_s)) + records_by_trace = _group_trace_span_records(span_records) + stage_colors = build_stage_colors(span_records) + stage_stats = _build_unified_stage_stats(task_rows, span_records) + task_details = _build_task_details( + timelines, + records_by_trace, + stage_colors, + now_s=now_s, + ) + + return { + "scope": scope, + "trace_source": trace_source, + "event_count": len(filtered_events), + "trace_count": len(timelines), + "overview": _build_unified_overview(task_rows), + "stage_stats": stage_stats, + "task_rows": task_rows, + "task_details": task_details, + "stage_colors": stage_colors, + } + + +def _build_unified_task_rows( + timelines: dict[str, list[TraceEvent]] | Iterable[TraceEvent], + *, + now_s: float | None = None, +) -> list[dict[str, Any]]: + if not isinstance(timelines, dict): + timelines = events_to_timelines(timelines) + now_s = time.time() if now_s is None 
else now_s + rows: list[dict[str, Any]] = [] + for trace_id, events in timelines.items(): + if not events: + continue + sorted_events = sorted(events, key=lambda event: event.timestamp_s) + first = sorted_events[0] + latest = sorted_events[-1] + open_spans = _get_open_spans(sorted_events) + newest_open = max(open_spans, key=lambda span: span[1].timestamp_s) if open_spans else None + is_pending = _is_pending_task(latest, open_spans) + current_span = newest_open[0] if newest_open is not None else _span_from_stage(latest.stage) + current_stage = "pending" if is_pending else display_trace_stage(current_span) + is_failed = latest.stage.endswith(".error") or latest.status == "failed" + is_completed = not is_failed and newest_open is None and latest.status == "completed" + status = "failed" if is_failed else "completed" if is_completed else "pending" if is_pending else "running" + end_s = now_s if newest_open is not None else latest.timestamp_s + otel_trace_id = stable_otel_trace_id(trace_id) + rows.append( + { + "trace_id": trace_id, + "otel_trace_id": otel_trace_id, + "task_name": latest.task_name, + "uid": latest.uid, + "status": status, + "current_stage": current_stage, + "latest_stage": latest.stage, + "latest_timestamp_s": latest.timestamp_s, + "event_count": len(sorted_events), + "duration_s": max(0.0, end_s - first.timestamp_s), + "open_span": newest_open[0] if newest_open is not None else None, + "open_age_s": now_s - newest_open[1].timestamp_s if newest_open is not None else None, + "produce_batch_id": latest.produce_batch_id, + } + ) + status_order = {"failed": 0, "running": 1, "pending": 2, "completed": 3} + return sorted( + rows, + key=lambda row: (status_order.get(row["status"], 9), -(row["open_age_s"] or 0.0), -row["latest_timestamp_s"]), + ) + + +def _build_unified_overview(task_rows: list[dict[str, Any]]) -> dict[str, int]: + status_counts = Counter(row["status"] for row in task_rows) + return { + "total_tasks": len(task_rows), + "pending_tasks": 
status_counts.get("pending", 0), + "completed_tasks": status_counts.get("completed", 0), + "running_tasks": status_counts.get("running", 0), + "failed_tasks": status_counts.get("failed", 0), + } + + +def _build_unified_stage_stats( + task_rows: list[dict[str, Any]], + span_records: list[TraceSpanRecord], +) -> list[dict[str, Any]]: + running_counts = Counter(row["current_stage"] for row in task_rows if row["status"] == "running") + pending_counts = Counter(row["current_stage"] for row in task_rows if row["status"] == "pending") + span_stats = {item["stage"]: item for item in build_stage_span_stats(span_records)} + stages = sorted(set(running_counts) | set(pending_counts) | set(span_stats)) + + stats = [] + for stage in stages: + item = dict(span_stats.get(stage, {})) + item.setdefault("stage", stage) + item.setdefault("visited_tasks", 0) + item.setdefault("span_count", 0) + item.setdefault("done_spans", item["span_count"]) + item.setdefault("open_count", 0) + item.setdefault("error_count", 0) + item.setdefault("total_s", 0.0) + item.setdefault("avg_s", 0.0) + item.setdefault("p50_s", 0.0) + item.setdefault("p95_s", 0.0) + item.setdefault("max_s", 0.0) + item["running_tasks"] = running_counts.get(stage, 0) + item["pending_tasks"] = pending_counts.get(stage, 0) + item["current_tasks"] = item["running_tasks"] + item["pending_tasks"] + stats.append(item) + return sorted( + stats, + key=lambda item: (item["current_tasks"], item["visited_tasks"], item["p95_s"], item["max_s"]), + reverse=True, + ) + + +def _build_task_details( + timelines: dict[str, list[TraceEvent]], + records_by_trace: dict[str, list[TraceSpanRecord]], + stage_colors: dict[str, str], + *, + now_s: float | None = None, +) -> dict[str, dict[str, Any]]: + details: dict[str, dict[str, Any]] = {} + now_s = time.time() if now_s is None else now_s + for trace_id, events in timelines.items(): + sorted_events = sorted(events, key=lambda event: event.timestamp_s) + first = sorted_events[0] + latest = 
sorted_events[-1] + open_spans = _get_open_spans(sorted_events) + is_pending = _is_pending_task(latest, open_spans) + current_span = open_spans[-1][0] if open_spans else _span_from_stage(latest.stage) + current_stage = "pending" if is_pending else display_trace_stage(current_span) + is_failed = latest.stage.endswith(".error") or latest.status == "failed" + is_completed = latest.status == "completed" and not open_spans and not is_failed + trace_records = records_by_trace.get(trace_id, []) + trace_start = min((record.start_s for record in trace_records), default=first.timestamp_s) + trace_end = max((record.end_s for record in trace_records), default=latest.timestamp_s) + if trace_end < latest.timestamp_s: + trace_end = latest.timestamp_s + trace_window_s = max(0.001, trace_end - trace_start) + max_depth = max((record.depth for record in trace_records), default=0) + latest_error_event = next((event for event in reversed(sorted_events) if event.error_msg), None) + if latest_error_event is None: + latest_error_event = next( + (event for event in reversed(sorted_events) if event.stage.endswith(".error")), + None, + ) + otel_trace_id = stable_otel_trace_id(trace_id) + if is_failed: + status = "failed" + elif is_completed: + status = "completed" + elif is_pending: + status = "pending" + else: + status = "running" + details[trace_id] = { + "trace_id": trace_id, + "otel_trace_id": otel_trace_id, + "task_name": latest.task_name, + "uid": latest.uid, + "status": status, + "current_stage": current_stage, + "duration_s": max(0.0, (now_s if open_spans else latest.timestamp_s) - first.timestamp_s), + "error_msg": latest_error_event.error_msg + if latest_error_event is not None and latest_error_event.error_msg + else f"Error at {latest_error_event.stage}" + if latest_error_event is not None + else None, + "timeline_events": [event.to_dict() for event in sorted_events], + "timeline_spans": [ + span_record_to_payload(record, trace_start, trace_window_s, stage_colors) for record in 
trace_records + ], + "row_height_px": 56 + (max_depth + 1) * 22, + } + return details + + +def _is_pending_task(latest: TraceEvent, open_spans: list[tuple[str, TraceEvent]]) -> bool: + if open_spans: + return False + if latest.status == "pending": + return True + return latest.stage in { + "xtuner.task.registered", + "xtuner.task.pending", + "xtuner.producer.task_registered", + "xtuner.producer.task.registered", + } + + +def _normalize_jaeger_query_url(jaeger_query_url: str | None) -> str | None: + if jaeger_query_url is None: + return None + jaeger_query_url = jaeger_query_url.strip() + if not jaeger_query_url: + return None + return jaeger_query_url.rstrip("/") + + +def _span_from_stage(stage: str | None) -> str | None: + if stage is None: + return None + span, suffix = _split_span_stage(stage) + if span is not None and suffix is not None: + return span + return stage + + +def _build_span_record(start_event: TraceEvent, end_event: TraceEvent, *, outcome: str) -> TraceSpanRecord: + span, _ = _split_span_stage(start_event.stage) + assert span is not None + duration_s = max(0.0, end_event.timestamp_s - start_event.timestamp_s) + return TraceSpanRecord( + trace_id=start_event.trace_id, + span=span, + display_stage=display_trace_stage(span), + start_s=start_event.timestamp_s, + end_s=end_event.timestamp_s, + duration_s=duration_s, + outcome=outcome, + task_name=end_event.task_name or start_event.task_name, + uid=end_event.uid if end_event.uid is not None else start_event.uid, + status=end_event.status or start_event.status, + train_step=end_event.train_step if end_event.train_step is not None else start_event.train_step, + worker_rank=end_event.worker_rank if end_event.worker_rank is not None else start_event.worker_rank, + error_msg=end_event.error_msg, + ) + + +def _build_elapsed_only_span_record(end_event: TraceEvent, *, span: str, outcome: str) -> TraceSpanRecord: + elapsed_s = max(0.0, end_event.elapsed_s or 0.0) + return TraceSpanRecord( + 
trace_id=end_event.trace_id, + span=span, + display_stage=display_trace_stage(span), + start_s=end_event.timestamp_s - elapsed_s, + end_s=end_event.timestamp_s, + duration_s=elapsed_s, + outcome=outcome, + task_name=end_event.task_name, + uid=end_event.uid, + status=end_event.status, + train_step=end_event.train_step, + worker_rank=end_event.worker_rank, + error_msg=end_event.error_msg, + ) + + +def _build_open_span_record(start_event: TraceEvent, *, span: str, now_s: float) -> TraceSpanRecord: + duration_s = max(0.0, now_s - start_event.timestamp_s) + return TraceSpanRecord( + trace_id=start_event.trace_id, + span=span, + display_stage=display_trace_stage(span), + start_s=start_event.timestamp_s, + end_s=now_s, + duration_s=duration_s, + outcome="open", + task_name=start_event.task_name, + uid=start_event.uid, + status=start_event.status, + train_step=start_event.train_step, + worker_rank=start_event.worker_rank, + error_msg=start_event.error_msg, + ) + + +def _assign_depths(records: list[TraceSpanRecord]) -> None: + records_by_trace = _group_trace_span_records(records) + for trace_records in records_by_trace.values(): + active_end_times: list[float] = [] + for record in sorted(trace_records, key=lambda item: (item.start_s, -item.end_s)): + active_end_times = [end_s for end_s in active_end_times if end_s > record.start_s] + record.depth = len(active_end_times) + active_end_times.append(record.end_s) + + +def _group_trace_span_records(records: Iterable[TraceSpanRecord]) -> dict[str, list[TraceSpanRecord]]: + grouped: dict[str, list[TraceSpanRecord]] = defaultdict(list) + for record in records: + grouped[record.trace_id].append(record) + for trace_id in grouped: + grouped[trace_id].sort(key=lambda record: (record.start_s, record.depth, record.end_s)) + return dict(grouped) diff --git a/xtuner/tools/task_trace_view.py b/xtuner/tools/task_trace_view.py new file mode 100644 index 0000000000..b464384fe1 --- /dev/null +++ b/xtuner/tools/task_trace_view.py @@ -0,0 +1,1024 
@@ +from __future__ import annotations + +import json +from typing import Any +from urllib.error import HTTPError, URLError +from urllib.parse import quote +from urllib.request import Request, urlopen + + +JAEGER_QUERY_TIMEOUT_S = 2.0 + + +def normalize_jaeger_query_url(jaeger_query_url: str | None) -> str | None: + if jaeger_query_url is None: + return None + jaeger_query_url = jaeger_query_url.strip() + if not jaeger_query_url: + return None + return jaeger_query_url.rstrip("/") + + +def fetch_jaeger_trace_payload( + jaeger_query_url: str, + otel_trace_id: str, + *, + timeout_s: float = JAEGER_QUERY_TIMEOUT_S, +) -> dict[str, Any]: + base_url = normalize_jaeger_query_url(jaeger_query_url) + if base_url is None: + return _jaeger_error_payload(otel_trace_id, "Jaeger query URL is not configured.") + + request = Request( + f"{base_url}/api/traces/{quote(otel_trace_id)}", + headers={"Accept": "application/json"}, + ) + try: + with urlopen(request, timeout=timeout_s) as response: + raw = json.loads(response.read().decode("utf-8")) + except HTTPError as exc: + return _jaeger_error_payload(otel_trace_id, f"Jaeger query failed with HTTP {exc.code}.") + except (OSError, URLError, TimeoutError, json.JSONDecodeError) as exc: + return _jaeger_error_payload(otel_trace_id, f"Jaeger query failed: {exc}") + return normalize_jaeger_trace_response(raw, otel_trace_id) + + +def normalize_jaeger_trace_response(raw_response: dict[str, Any], otel_trace_id: str) -> dict[str, Any]: + traces = raw_response.get("data") or [] + if not traces: + return { + "trace_id": otel_trace_id, + "found": False, + "span_count": 0, + "duration_us": 0, + "spans": [], + } + + trace = traces[0] + processes = trace.get("processes") or {} + normalized_spans: list[dict[str, Any]] = [] + trace_start_us: int | None = None + trace_end_us: int | None = None + + for span in trace.get("spans") or []: + start_us = int(span.get("startTime") or 0) + duration_us = int(span.get("duration") or 0) + end_us = start_us + 
max(0, duration_us) + trace_start_us = start_us if trace_start_us is None else min(trace_start_us, start_us) + trace_end_us = end_us if trace_end_us is None else max(trace_end_us, end_us) + + process = processes.get(span.get("processID"), {}) + tags = _jaeger_tags_to_dict(span.get("tags") or []) + error_msg = _jaeger_error_message(tags, span.get("logs") or []) + is_error = bool(tags.get("error")) or str(tags.get("otel.status_code", "")).upper() == "ERROR" + normalized_spans.append( + { + "trace_id": span.get("traceID") or trace.get("traceID") or otel_trace_id, + "span_id": str(span.get("spanID") or ""), + "parent_span_id": _jaeger_parent_span_id(span.get("references") or []), + "operation_name": span.get("operationName") or "unknown", + "service_name": process.get("serviceName") or span.get("processID") or "unknown", + "start_time_us": start_us, + "duration_us": duration_us, + "tags": tags, + "is_error": is_error, + "error_msg": error_msg, + } + ) + + trace_start_us = trace_start_us or 0 + trace_end_us = trace_end_us or trace_start_us + for span in normalized_spans: + span["relative_start_us"] = max(0, span["start_time_us"] - trace_start_us) + + return { + "trace_id": trace.get("traceID") or otel_trace_id, + "found": True, + "span_count": len(normalized_spans), + "duration_us": max(0, trace_end_us - trace_start_us), + "spans": sorted(normalized_spans, key=lambda item: (item["start_time_us"], item["duration_us"])), + } + + +def _jaeger_error_payload(otel_trace_id: str, message: str) -> dict[str, Any]: + return { + "trace_id": otel_trace_id, + "found": False, + "span_count": 0, + "duration_us": 0, + "spans": [], + "error": message, + } + + +def _jaeger_tags_to_dict(tags: list[dict[str, Any]]) -> dict[str, Any]: + result: dict[str, Any] = {} + for tag in tags: + key = tag.get("key") + if key is None: + continue + result[str(key)] = tag.get("value") + return result + + +def _jaeger_parent_span_id(references: list[dict[str, Any]]) -> str | None: + for reference in 
references: + if reference.get("refType") == "CHILD_OF" and reference.get("spanID") is not None: + return str(reference["spanID"]) + if references and references[0].get("spanID") is not None: + return str(references[0]["spanID"]) + return None + + +def _jaeger_error_message(tags: dict[str, Any], logs: list[dict[str, Any]]) -> str | None: + for key in ("error.message", "otel.status_description"): + value = tags.get(key) + if value: + return str(value) + for log in logs: + fields = _jaeger_tags_to_dict(log.get("fields") or []) + value = fields.get("error.message") or fields.get("event") + if value: + return str(value) + return None + + +def render_unified_trace_html( + payload: dict[str, Any], + *, + live: bool = False, + api_url: str = "/api/trace", + refresh_interval_s: float = 1.0, +) -> str: + data_json = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).replace(" + + + + + XTuner Task Trace Dashboard + + + +
+ [Dashboard HTML template: markup not recoverable here; surviving text in order]
+ Heading: XTuner Task Trace Dashboard
+ Panel: Current Stage Distribution & Duration
+ Stage table columns: Stage | Current | Pending | Running | Done | Avg | P95 | Max
+ Task table columns: Task | Status | Stage | Duration | Events
+ Panel: Task Timeline
+ + + + +""" diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index 836d333df2..5f0af57aed 100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -124,6 +124,7 @@ class RolloutState(BaseModel): # --- 状态 --- uid: int | None = None + trace_id: str | None = None # Stable task trace identity; may be folded into uid after semantics converge. task_name: str | None = None status: Status = Status.INIT error_msg: str | None = None diff --git a/xtuner/v1/rl/agent_loop/agent_loop.py b/xtuner/v1/rl/agent_loop/agent_loop.py index 835683149a..b32bb27060 100644 --- a/xtuner/v1/rl/agent_loop/agent_loop.py +++ b/xtuner/v1/rl/agent_loop/agent_loop.py @@ -13,6 +13,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status from xtuner.v1.rl.judger import Judger from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.trace import xtuner_trace_function from xtuner.v1.rl.utils import ( JUDGER_PAUSE_JUDGE_TASK_TIMEOUT_S, CPUActorLauncher, @@ -189,6 +190,7 @@ def __init__( @abstractmethod async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: ... + @xtuner_trace_function("xtuner.agent_loop.generate_group") async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: pending_tasks = [] for state in rollout_state: @@ -208,6 +210,7 @@ async def run_judger(self, rollout_state: RolloutState) -> RolloutState: ... @overload async def run_judger(self, rollout_state: list[RolloutState]) -> list[RolloutState]: ... 
+ @xtuner_trace_function("xtuner.judger.judge") async def run_judger(self, rollout_state: RolloutState | list[RolloutState]) -> RolloutState | list[RolloutState]: assert self.judger is not None if isinstance(rollout_state, list): diff --git a/xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py b/xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py index 7d66f0e203..1764929330 100644 --- a/xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py +++ b/xtuner/v1/rl/agent_loop/localhost_agent_loop/agent_in_localhost_loop.py @@ -18,6 +18,12 @@ from xtuner.v1.rl.rollout import RolloutController from xtuner.v1.rl.rollout.chat_template import canonicalize_messages_for_chat_template from xtuner.v1.rl.rollout.trace_store import get_store +from xtuner.v1.rl.trace import ( + build_rollout_trace_attributes, + trace_task_context, + xtuner_trace_function, + xtuner_trace_span, +) from xtuner.v1.rl.utils import create_task from ..agent_loop import AgentLoop, AgentLoopConfig @@ -125,6 +131,7 @@ def __init__( self._sample_semaphore = asyncio.Semaphore(max_concurrent_samples) if max_concurrent_samples else None self.mode = mode + @xtuner_trace_function("xtuner.agent_in_localhost.generate_group") async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: async def generate_one(state: RolloutState) -> RolloutState: if self._sample_semaphore is None: @@ -140,6 +147,7 @@ async def generate_one(state: RolloutState) -> RolloutState: return await asyncio.gather(*tasks) + @xtuner_trace_function("xtuner.agent_in_localhost.generate_sample") async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: try: if self.sample_timeout_s is not None and self.sample_timeout_s > 0: @@ -175,8 +183,9 @@ async def _generate_sample_impl(self, rollout_state: RolloutState) -> RolloutSta if rollout_state.uid is None: rollout_state.uid = uuid.uuid4().int item.uid = rollout_state.uid + 
item.trace_id = rollout_state.trace_id item.group_id = rollout_state.message_uid - result = await self._run_item(item) + result = await self._run_item(item, trace_attrs=build_rollout_trace_attributes(rollout_state)) await self._fill_rollout_state(rollout_state, result) return rollout_state @@ -199,11 +208,11 @@ def _fail_rollout_state( rollout_state.error_msg = error_msg return rollout_state - async def _run_item(self, item: AgentRolloutItem) -> AgentRolloutItem: + async def _run_item(self, item: AgentRolloutItem, trace_attrs: dict[str, Any] | None = None) -> AgentRolloutItem: runner = _resolve_runner(item.pipeline) if runner is None: raise ValueError("AgentRolloutItem.pipeline is required.") - with ctx_session_id.set(str(item.uid)): + with ctx_session_id.set(str(item.uid)), trace_task_context(trace_attrs): return await runner.run(item) async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRolloutItem) -> None: @@ -235,23 +244,39 @@ async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRoll if item.status != RolloutStatus.COMPLETED: return - segment = item.artifacts["messages"][-1] - text = self.tokenizer.apply_chat_template( - canonicalize_messages_for_chat_template(segment["messages"]), - tools=segment["tools"], - tokenize=False, - add_generation_prompt=False, - ) - prompt_text = text[:-1] if text.endswith("\n") else text - data = await get_store().export_training_trace.remote(str(rollout_state.uid), prompt_text) - - rollout_state.input_ids = data["input_ids"] - rollout_state.labels = data["labels"] - rollout_state.response_ids = [ - token_id for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) if label != -100 - ] - rollout_state.logprobs = data["logprobs"] - rollout_state.routed_experts = data["routed_experts"] + async with xtuner_trace_span( + rollout_state, + "xtuner.agent_in_localhost.materialize_trajectory", + agent_status=item.status.value, + **{"xtuner.stage.kind": "materialize"}, + ) as span: + 
segment = item.artifacts["messages"][-1] + messages = segment["messages"] + tools = segment["tools"] + span.annotate( + agent_message_count=len(messages), + agent_has_tools=tools is not None, + ) + text = self.tokenizer.apply_chat_template( + canonicalize_messages_for_chat_template(messages), + tools=tools, + tokenize=False, + add_generation_prompt=False, + ) + prompt_text = text[:-1] if text.endswith("\n") else text + data = await get_store().export_training_trace.remote(str(rollout_state.uid), prompt_text) + + rollout_state.input_ids = data["input_ids"] + rollout_state.labels = data["labels"] + rollout_state.response_ids = [ + token_id for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) if label != -100 + ] + rollout_state.logprobs = data["logprobs"] + rollout_state.routed_experts = data["routed_experts"] + span.annotate( + input_tokens=len(rollout_state.input_ids or []), + label_tokens=len(rollout_state.labels or []), + ) content = response_message.get("content") rollout_state.response = content if isinstance(content, str) else (str(content) if content is not None else "") diff --git a/xtuner/v1/rl/agent_loop/localhost_agent_loop/judger.py b/xtuner/v1/rl/agent_loop/localhost_agent_loop/judger.py index 1979139ab1..6489b9b8e4 100644 --- a/xtuner/v1/rl/agent_loop/localhost_agent_loop/judger.py +++ b/xtuner/v1/rl/agent_loop/localhost_agent_loop/judger.py @@ -16,6 +16,7 @@ StageStatus, ) from xtuner.v1.rl.judger.native import Judger +from xtuner.v1.rl.trace import xtuner_trace_span class LocalhostJudgerStage: @@ -42,44 +43,53 @@ def __init__( async def run(self, item: AgentRolloutItem, record: StageRecord) -> float: record.status = StageStatus.RUNNING record.started_at = record.started_at or time.monotonic() - try: - # reward_model stays as-is (dataset-provided ground_truth/style etc.). - # Per-rollout artifacts (response message, agent trace) flow through extra_fields. 
- reward_model = dict(item.reward_model or {}) - segment = item.artifacts["messages"][-1] - response_message = item.artifacts.get("response_message") or {} - content = response_message.get("content") - response = content if isinstance(content, str) else (str(content) if content is not None else "") - - rollout_state = RolloutState( - message=[{"role": "user", "content": item.instruction}], - response=response, - reward_model=reward_model, - extra_fields={ - "agent_messages": segment["messages"], - "response_message": response_message, - }, - status=Status.COMPLETED, - ) - judged = await self.judger.judge(rollout_state) - reward_payload = judged.reward or {} - if self.reward_key not in reward_payload: - raise KeyError(f"judger reward payload has no {self.reward_key!r}: {reward_payload!r}") - record.metadata["reward"] = reward_payload - record.score = float(reward_payload[self.reward_key]) - record.status = StageStatus.COMPLETED - return record.score - except Exception as exc: - record.status = StageStatus.FAILED - record.error = record.error or RolloutError( - stage=self.name, - category="judger", - type=type(exc).__name__, - message=str(exc), - ) - raise - finally: - record.finished_at = time.monotonic() + async with xtuner_trace_span( + item, + "xtuner.localhost.judger", + task_name=item.data_source, + uid=item.uid if item.uid is not None else item.id, + task_id=item.id, + judger_name=self.name, + reward_key=self.reward_key, + **{"xtuner.stage.kind": "judge"}, + ): + try: + # reward_model stays as-is (dataset-provided ground_truth/style etc.). + # Per-rollout artifacts (response message, agent trace) flow through extra_fields. 
+ reward_model = dict(item.reward_model or {}) + segment = item.artifacts["messages"][-1] + response_message = item.artifacts.get("response_message") or {} + content = response_message.get("content") + response = content if isinstance(content, str) else (str(content) if content is not None else "") + rollout_state = RolloutState( + message=[{"role": "user", "content": item.instruction}], + response=response, + reward_model=reward_model, + extra_fields={ + "agent_messages": segment["messages"], + "response_message": response_message, + }, + status=Status.COMPLETED, + ) + judged = await self.judger.judge(rollout_state) + reward_payload = judged.reward or {} + if self.reward_key not in reward_payload: + raise KeyError(f"judger reward payload has no {self.reward_key!r}: {reward_payload!r}") + record.metadata["reward"] = reward_payload + record.score = float(reward_payload[self.reward_key]) + record.status = StageStatus.COMPLETED + return record.score + except Exception as exc: + record.status = StageStatus.FAILED + record.error = record.error or RolloutError( + stage=self.name, + category="judger", + type=type(exc).__name__, + message=str(exc), + ) + raise + finally: + record.finished_at = time.monotonic() __all__ = ["LocalhostJudgerStage"] diff --git a/xtuner/v1/rl/agent_loop/localhost_agent_loop/runner.py b/xtuner/v1/rl/agent_loop/localhost_agent_loop/runner.py index 79a17e2266..d690042b3d 100644 --- a/xtuner/v1/rl/agent_loop/localhost_agent_loop/runner.py +++ b/xtuner/v1/rl/agent_loop/localhost_agent_loop/runner.py @@ -16,6 +16,7 @@ StageRecord, StageStatus, ) +from xtuner.v1.rl.trace import xtuner_trace_span from xtuner.v1.utils import get_logger @@ -35,39 +36,71 @@ async def run(self, item: AgentRolloutItem) -> AgentRolloutItem: raise ValueError("AgentRolloutItem.instruction is required by LocalhostRunner.run") item.status = RolloutStatus.RUNNING tid = item.id + trace_kwargs = { + "task_name": item.data_source, + "uid": item.uid if item.uid is not None else tid, + 
"task_id": tid, + } + t_infer: float | None = None t_validate: float | None = None try: - await self.infer.run(item, item.infer) - if item.infer.status != StageStatus.COMPLETED: - return self._fail(item, item.infer.error) - - if self.validate is not None: + async with xtuner_trace_span( + item, + "xtuner.localhost.run_total", + **trace_kwargs, + **{"xtuner.stage.kind": "agent_loop"}, + ) as total_span: t0 = time.monotonic() - validate_name = getattr(self.validate, "name", "validate") - validate_record = item.judgers.setdefault( - validate_name, - StageRecord(), - ) - try: - score = float(await self.validate.run(item, validate_record)) - except Exception: - return self._fail( - item, - validate_record.error - or _first_judger_error(item) - or RolloutError( - stage=validate_name, - category="validate_failed", - type=type(self.validate).__name__, - message="validate failed", - ), + async with xtuner_trace_span( + item, + "xtuner.localhost.infer", + **trace_kwargs, + **{"xtuner.stage.kind": "agent_run"}, + ) as infer_span: + await self.infer.run(item, item.infer) + if item.infer.status != StageStatus.COMPLETED: + infer_span.mark_error(_format_error(item.infer.error)) + t_infer = time.monotonic() - t0 + if item.infer.status != StageStatus.COMPLETED: + total_span.mark_error(_format_error(item.infer.error)) + return self._fail(item, item.infer.error) + + if self.validate is not None: + t1 = time.monotonic() + validate_name = getattr(self.validate, "name", "validate") + validate_record = item.judgers.setdefault( + validate_name, + StageRecord(), ) - t_validate = time.monotonic() - t0 - item.reward = score - - item.status = RolloutStatus.COMPLETED - return item + async with xtuner_trace_span( + item, + "xtuner.localhost.validate", + validate_name=validate_name, + **trace_kwargs, + **{"xtuner.stage.kind": "judge"}, + ) as validate_span: + try: + score = float(await self.validate.run(item, validate_record)) + except Exception: + error = ( + validate_record.error + or 
_first_judger_error(item) + or RolloutError( + stage=validate_name, + category="validate_failed", + type=type(self.validate).__name__, + message="validate failed", + ) + ) + validate_span.mark_error(_format_error(error)) + total_span.mark_error(_format_error(error)) + return self._fail(item, error) + t_validate = time.monotonic() - t1 + item.reward = score + + item.status = RolloutStatus.COMPLETED + return item except Exception as exc: promoted = ( item.infer.error @@ -82,7 +115,7 @@ async def run(self, item: AgentRolloutItem) -> AgentRolloutItem: get_logger().error(f"[{tid}] traceback:\n{traceback.format_exc()}") return self._fail(item, promoted) finally: - self._log_final(tid, item, t_validate) + self._log_final(tid, item, t_infer, t_validate) # -- internals -- @@ -99,13 +132,21 @@ def _fail(self, item: AgentRolloutItem, error: RolloutError | None) -> AgentRoll get_logger().error(f"[{item.id}] failed: unknown error") return item - def _log_final(self, tid: str, item: AgentRolloutItem, t_validate: float | None) -> None: + def _log_final( + self, + tid: str, + item: AgentRolloutItem, + t_infer: float | None, + t_validate: float | None, + ) -> None: agent_name = item.infer.agent.name if item.infer.agent is not None else "?" 
parts = [f"status={item.status.value}", f"agent={agent_name}"] if item.reward is not None: parts.append(f"reward={item.reward:.4f}") if item.infer.started_at and item.infer.finished_at: parts.append(f"t_infer={item.infer.finished_at - item.infer.started_at:.1f}s") + elif t_infer is not None: + parts.append(f"t_infer={t_infer:.1f}s") if t_validate is not None: parts.append(f"t_validate={t_validate:.1f}s") if item.status == RolloutStatus.FAILED and item.error: @@ -120,4 +161,12 @@ def _first_judger_error(item: AgentRolloutItem) -> RolloutError | None: return None +def _format_error(error: RolloutError | None) -> str: + if error is None: + return "unknown error" + stage = f"{error.stage}/" if error.stage else "" + typ = f" ({error.type})" if error.type else "" + return f"{stage}{error.category}{typ}: {error.message}" + + __all__ = ["LocalhostRunner"] diff --git a/xtuner/v1/rl/agent_loop/localhost_agent_loop/stage.py b/xtuner/v1/rl/agent_loop/localhost_agent_loop/stage.py index 42c621a7db..f8132ce53c 100644 --- a/xtuner/v1/rl/agent_loop/localhost_agent_loop/stage.py +++ b/xtuner/v1/rl/agent_loop/localhost_agent_loop/stage.py @@ -19,6 +19,7 @@ StageResult, StageStatus, ) +from xtuner.v1.rl.trace import xtuner_trace_span from xtuner.v1.utils import get_logger @@ -44,44 +45,58 @@ async def run(self, item: AgentRolloutItem, record: StageRecord) -> StageResult: record.status = StageStatus.RUNNING record.started_at = record.started_at or time.monotonic() agent = None - try: - spec = self._pick_agent(item, record) - agent = create_object(deepcopy(_resolve_agent_config(spec.config))) - output = await agent(item.instruction) - response_message = output.model_dump(mode="json") if hasattr(output, "model_dump") else None - if response_message is None: - raise TypeError("Agent forward must return an AgentMessage-like object.") - item.artifacts["response_message"] = response_message - messages = agent.get_messages() - if not isinstance(messages, list) or not messages: - raise 
ValueError("Agent messages artifact must be a non-empty list.") - segment = messages[-1] - if not isinstance(segment, dict) or "messages" not in segment or "tools" not in segment: - raise ValueError("Agent messages trace segment must contain messages and tools.") - if not isinstance(segment["messages"], list): - raise TypeError("Agent messages trace segment.messages must be a list.") - item.artifacts["messages"] = messages - content = response_message.get("content") - stdout = content if isinstance(content, str) else (str(content) if content is not None else "") - result = StageResult(stdout=stdout, return_code=0) - record.entry_result = result - record.status = StageStatus.COMPLETED - return result - except Exception as exc: - record.status = StageStatus.FAILED - record.error = record.error or RolloutError( - stage=self.name, - category="agent_exception", - type=type(exc).__name__, - message=str(exc), - ) - result = StageResult(return_code=None, error=str(exc), stderr=str(exc)) - record.entry_result = result - return result - finally: - record.finished_at = time.monotonic() - if agent is not None: - await _close_agent(agent) + async with xtuner_trace_span( + item, + "xtuner.localhost.agent", + task_name=item.data_source, + uid=item.uid if item.uid is not None else item.id, + task_id=item.id, + stage_name=self.name, + **{"xtuner.stage.kind": "agent_run"}, + ) as span: + try: + spec = self._pick_agent(item, record) + span.annotate( + agent_name=spec.name, + agent_config=spec.config if isinstance(spec.config, str) else "", + ) + agent = create_object(deepcopy(_resolve_agent_config(spec.config))) + output = await agent(item.instruction) + response_message = output.model_dump(mode="json") if hasattr(output, "model_dump") else None + if response_message is None: + raise TypeError("Agent forward must return an AgentMessage-like object.") + item.artifacts["response_message"] = response_message + messages = agent.get_messages() + if not isinstance(messages, list) or not 
messages: + raise ValueError("Agent messages artifact must be a non-empty list.") + segment = messages[-1] + if not isinstance(segment, dict) or "messages" not in segment or "tools" not in segment: + raise ValueError("Agent messages trace segment must contain messages and tools.") + if not isinstance(segment["messages"], list): + raise TypeError("Agent messages trace segment.messages must be a list.") + item.artifacts["messages"] = messages + content = response_message.get("content") + stdout = content if isinstance(content, str) else (str(content) if content is not None else "") + result = StageResult(stdout=stdout, return_code=0) + record.entry_result = result + record.status = StageStatus.COMPLETED + return result + except Exception as exc: + record.status = StageStatus.FAILED + record.error = record.error or RolloutError( + stage=self.name, + category="agent_exception", + type=type(exc).__name__, + message=str(exc), + ) + span.mark_error(f"{type(exc).__name__}: {exc}") + result = StageResult(return_code=None, error=str(exc), stderr=str(exc)) + record.entry_result = result + return result + finally: + record.finished_at = time.monotonic() + if agent is not None: + await _close_agent(agent) def _pick_agent(self, item: AgentRolloutItem, record: StageRecord) -> LocalhostAgentSpec: group_id = item.group_id or 0 diff --git a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py index 55feb25f5f..36937c5bf6 100644 --- a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py +++ b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/agent_in_sandbox_loop.py @@ -15,6 +15,12 @@ from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status from xtuner.v1.rl.judger import Judger from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.trace import ( + build_rollout_trace_attributes, + trace_task_context, + xtuner_trace_function, + xtuner_trace_span, +) from 
xtuner.v1.rl.utils import create_task from ...rollout.chat_template import canonicalize_messages_for_chat_template @@ -256,6 +262,7 @@ async def _throttle_sandbox_create(self) -> None: if self._sandbox_create_limiter is not None: await self._sandbox_create_limiter.acquire() + @xtuner_trace_function("xtuner.agent_in_sandbox.generate_group") async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: async def generate_one(state: RolloutState) -> RolloutState: if self._sample_semaphore is None: @@ -272,15 +279,17 @@ async def generate_one(state: RolloutState) -> RolloutState: group_samples = await generated_samples return group_samples + @xtuner_trace_function("xtuner.agent_in_sandbox.generate_sample") async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: try: rollout_item = rollout_state.extra_fields["rollout_item"].model_copy(deep=True) if rollout_state.uid is None: rollout_state.uid = uuid.uuid4().int rollout_item.uid = rollout_state.uid + rollout_item.trace_id = rollout_state.trace_id rollout_item.group_id = rollout_state.message_uid await self._throttle_sandbox_create() - result = await self._run_item(rollout_item) + result = await self._run_item(rollout_item, trace_attrs=build_rollout_trace_attributes(rollout_state)) await self._fill_rollout_state(rollout_state, result) return rollout_state except Exception as exc: @@ -294,11 +303,12 @@ async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> Rollou self.logger.error(f"[AgentInSandboxLoop] failed: {exc}\n{traceback.format_exc()}") return rollout_state - async def _run_item(self, item: AgentRolloutItem) -> AgentRolloutItem: + async def _run_item(self, item: AgentRolloutItem, trace_attrs: dict[str, Any] | None = None) -> AgentRolloutItem: runner = _resolve_runner(item.pipeline, str(item.uid)) if runner is None: raise ValueError("AgentRolloutItem.pipeline is required.") - return await runner.run(item) + with 
trace_task_context(trace_attrs): + return await runner.run(item) async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRolloutItem) -> None: if self.mode == "eval": @@ -313,30 +323,44 @@ async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRoll if item.status != RolloutStatus.COMPLETED: return - messages, tools = _load_latest_trace_segment(item.artifacts, require_tools=True) - if not messages: - raise ValueError("Agent artifacts must contain at least one trainable messages trace.") - session_id = rollout_state.uid - - trace_store = get_store() - text = self.tokenizer.apply_chat_template( - canonicalize_messages_for_chat_template(messages), - tools=tools, - tokenize=False, - add_generation_prompt=False, - ) - prompt_text = text[:-1] if text.endswith("\n") else text - data = await trace_store.export_training_trace.remote(str(session_id), prompt_text) - - rollout_state.input_ids = data["input_ids"] - rollout_state.labels = data["labels"] - # Agentic training consumes input_ids/labels directly. response_ids is - # filled here only so rollout throughput logging can print rollout_tgs. 
- rollout_state.response_ids = [ - token_id for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) if label != -100 - ] - rollout_state.logprobs = data["logprobs"] - rollout_state.routed_experts = data["routed_experts"] + async with xtuner_trace_span( + rollout_state, + "xtuner.agent_in_sandbox.materialize_trajectory", + agent_status=item.status.value, + **{"xtuner.stage.kind": "materialize"}, + ) as span: + messages, tools = _load_latest_trace_segment(item.artifacts, require_tools=True) + span.annotate( + agent_message_count=len(messages), + agent_has_tools=tools is not None, + ) + if not messages: + raise ValueError("Agent artifacts must contain at least one trainable messages trace.") + session_id = rollout_state.uid + + trace_store = get_store() + text = self.tokenizer.apply_chat_template( + canonicalize_messages_for_chat_template(messages), + tools=tools, + tokenize=False, + add_generation_prompt=False, + ) + prompt_text = text[:-1] if text.endswith("\n") else text + data = await trace_store.export_training_trace.remote(str(session_id), prompt_text) + + rollout_state.input_ids = data["input_ids"] + rollout_state.labels = data["labels"] + # Agentic training consumes input_ids/labels directly. response_ids is + # filled here only so rollout throughput logging can print rollout_tgs. 
+ rollout_state.response_ids = [ + token_id for token_id, label in zip(data["input_ids"][1:], data["labels"][1:]) if label != -100 + ] + rollout_state.logprobs = data["logprobs"] + rollout_state.routed_experts = data["routed_experts"] + span.annotate( + input_tokens=len(rollout_state.input_ids or []), + label_tokens=len(rollout_state.labels or []), + ) def _fill_eval_rollout_state(self, rollout_state: RolloutState, item: AgentRolloutItem) -> None: is_success = item.status == RolloutStatus.COMPLETED diff --git a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py index e2db7609d2..db9c1d79a3 100644 --- a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py +++ b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py @@ -25,8 +25,8 @@ RolloutError, RolloutStatus, ) -from xtuner.v1.rl.agent_loop.sandbox_agent_loop.trace import span from xtuner.v1.rl.agent_loop.sandbox_agent_loop.validator import JudgerValidator +from xtuner.v1.rl.trace import xtuner_trace_span from xtuner.v1.utils import get_logger @@ -82,15 +82,29 @@ async def run(self, item: AgentRolloutItem) -> AgentRolloutItem: item.infer.workspace = infer_spec.workspace_path tid = item.id - uid_obs = str(item.uid) if item.uid is not None else "" t_acquire: float | None = None t_infer: float | None = None t_validate: float | None = None + trace_kwargs = { + "task_name": item.data_source, + "uid": item.uid if item.uid is not None else tid, + "task_id": tid, + } try: - with span(uid_obs, "run_total", task_id=tid) as total_span: + async with xtuner_trace_span( + item, + "xtuner.sandbox.run_total", + **trace_kwargs, + **{"xtuner.stage.kind": "agent_loop"}, + ) as total_span: # ─── acquire infer sandbox ─────────────────────────────── t0 = time.monotonic() - with span(uid_obs, "acquire", task_id=tid) as acquire_span: + async with xtuner_trace_span( + item, + "xtuner.sandbox.acquire", + **trace_kwargs, + **{"xtuner.stage.kind": "sandbox"}, + ) as acquire_span: 
infer_client = await pool.get(infer_sandbox, record=item.infer) item.infer.sandbox_env_id = pool.env_id(infer_sandbox) sandbox_url = pool.url(infer_sandbox) @@ -106,7 +120,12 @@ async def run(self, item: AgentRolloutItem) -> AgentRolloutItem: # ─── infer ────────────────────────────────────────────── t1 = time.monotonic() - with span(uid_obs, "infer", task_id=tid) as infer_span: + async with xtuner_trace_span( + item, + "xtuner.sandbox.infer", + **trace_kwargs, + **{"xtuner.stage.kind": "agent_run"}, + ) as infer_span: infer_result = await self.infer.run(infer_client, item, item.infer) if not infer_result.ok: infer_span.mark_error(_format_error(item.infer.error)) @@ -117,21 +136,26 @@ async def run(self, item: AgentRolloutItem) -> AgentRolloutItem: # ─── validate ─────────────────────────────────────────── t2 = time.monotonic() - with span(uid_obs, "validate", task_id=tid): + async with xtuner_trace_span( + item, + "xtuner.sandbox.validate", + **trace_kwargs, + **{"xtuner.stage.kind": "judge"}, + ) as validate_span: score, failed = await self.validate.run(item, pool) - t_validate = time.monotonic() - t2 - item.reward = score - if failed: - return self._fail( - item, - _first_judger_error(item) - or RolloutError( + t_validate = time.monotonic() - t2 + item.reward = score + if failed: + error = _first_judger_error(item) or RolloutError( stage="validate", category="validate_failed", type="JudgerValidator", message="all judgers failed" if not item.judgers else "validate failed", - ), - ) + ) + error_msg = _format_error(error) + validate_span.mark_error(error_msg) + total_span.mark_error(error_msg) + return self._fail(item, error) item.status = RolloutStatus.COMPLETED return item diff --git a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py index 70a7252ae6..409be55773 100644 --- a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py +++ b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py @@ -46,10 +46,24 
@@ async def hook(client, item, record) -> None StageResult, StageStatus, ) -from xtuner.v1.rl.agent_loop.sandbox_agent_loop.trace import span +from xtuner.v1.rl.trace import get_trace_env_vars, make_trace_context_carrier, xtuner_trace_span from xtuner.v1.utils import get_logger +def _merge_trace_env(env: dict[str, str]) -> dict[str, str]: + carrier = make_trace_context_carrier() + trace_env = get_trace_env_vars() + if not carrier and not trace_env: + return env + + merged = dict(trace_env) + merged.update(env) + for key, value in carrier.items(): + if value: + merged[f"OTEL_PROPAGATOR_{key.upper().replace('-', '_')}"] = value + return merged + + # ───────────────────────────────────────────────────────────────── # Hook base # ───────────────────────────────────────────────────────────────── @@ -462,12 +476,22 @@ async def run( record.entries.append(entry) entry.status = StageStatus.RUNNING entry.started_at = time.monotonic() - uid_obs = str(item.uid) if item.uid is not None else "" - with span(uid_obs, f"entry:{self.name}", entry_kind="ShellEntry"): + async with xtuner_trace_span( + item, + f"xtuner.sandbox.entry:{self.name}", + task_name=item.data_source, + uid=item.uid if item.uid is not None else item.id, + task_id=item.id, + entry_kind="ShellEntry", + entry_name=self.name, + **{"xtuner.stage.kind": "entry"}, + ) as span: try: - outcome = await self._execute(client, self.env) + outcome = await self._execute(client, _merge_trace_env(self.env)) if not outcome.ok and self.failure is not None: outcome = await self.failure.handle(client, item, record, entry, outcome) + if not outcome.ok: + span.mark_error(outcome.result.error or outcome.result.stderr or outcome.reason or "entry failed") self._finish_record(entry, outcome) return outcome except Exception as exc: @@ -573,13 +597,23 @@ async def run( record.entries.append(entry) entry.status = StageStatus.RUNNING entry.started_at = time.monotonic() - uid_obs = str(item.uid) if item.uid is not None else "" - with 
span(uid_obs, f"entry:{self.name}", entry_kind="DetachedShellEntry"): + async with xtuner_trace_span( + item, + f"xtuner.sandbox.entry:{self.name}", + task_name=item.data_source, + uid=item.uid if item.uid is not None else item.id, + task_id=item.id, + entry_kind="DetachedShellEntry", + entry_name=self.name, + **{"xtuner.stage.kind": "entry"}, + ) as span: try: - outcome = await self._run_detached(client, item, entry, self.env) + outcome = await self._run_detached(client, item, entry, _merge_trace_env(self.env)) await self._fill_output_files(client, entry, outcome.result) if not outcome.ok and self.failure is not None: outcome = await self.failure.handle(client, item, record, entry, outcome) + if not outcome.ok: + span.mark_error(outcome.result.error or outcome.result.stderr or outcome.reason or "entry failed") self._finish_record(entry, outcome) return outcome except Exception as exc: diff --git a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/schemas.py b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/schemas.py index f179e13163..4c4c2d2a98 100644 --- a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/schemas.py +++ b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/schemas.py @@ -172,6 +172,8 @@ class AgentRolloutItem(BaseModel): task the agent sees. Pipeline exports it as ``$TASK_INSTRUCTION``. - ``task_root``: host task directory available to the current worker. - ``uid``: rollout identity supplied by the outer dataflow. + - ``trace_id``: stable per-task trace identity supplied by the outer + dataflow. - ``pipeline`` / ``pipeline_overrides``: lazy runner config binding. 
Runtime/result fields are filled in place by :class:`Runner.run`: @@ -196,6 +198,7 @@ class AgentRolloutItem(BaseModel): task_root: Path | None = None group_id: int | None = None # rollout group identity (same query's K-rollouts share) uid: int | None = None # per-rollout unique id (different across K-rollouts of same query) + trace_id: str | None = None pipeline: PipelineConfig | Any | None = None pipeline_overrides: dict[str, Any] = Field(default_factory=dict) diff --git a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/trace.py b/xtuner/v1/rl/agent_loop/sandbox_agent_loop/trace.py deleted file mode 100644 index abcc3e9582..0000000000 --- a/xtuner/v1/rl/agent_loop/sandbox_agent_loop/trace.py +++ /dev/null @@ -1,411 +0,0 @@ -"""Per-sample trace emission for the InstallAgentEnvironment pipeline. - -Four outputs under ``$WORK_DIR/trace/``: - -* ``fates.{actor_id}.{pid}.jsonl`` — one terminal line per sample, capturing - whether it ended up COMPLETED or SKIPPED plus the stage/reason. -* ``spans.{actor_id}.{pid}.jsonl`` — one line per stage enter/exit, with - duration and ok/err. Sampled inside the async pipeline via a context - manager whose yield region may ``await``. -* ``llm_calls.{pid}.jsonl`` — one line per ``/v1/chat/completions`` request - served by :class:`RolloutController`, with total / tokenize / rollout / - post durations and token counts. Independent of the per-sample fate/span - writer because the controller is its own actor with no owning sample uid. -* ``diagnostics/{ts}_{task_id}_{uid}.log`` (+ optional ``.daemon.log``) — - unstructured per-failure bundle with the pulled daemon log tail. Written - whenever a sample fails with a non-zero entry rc or runner-level - exception; the header file is always produced so a missing daemon log - still leaves a breadcrumb (``daemon_log_file=(unavailable)``). - -Each Ray actor writes its own two files so concurrent writes never contend -a single file descriptor across processes. 
Line-buffered + per-line flush -keeps ``tail -f`` and crash-recovery both simple. - -If the ``WORK_DIR`` environment variable is unset the module degrades to a -zero-cost no-op — tests and local one-shot scripts never have to care. -""" - -from __future__ import annotations - -import json -import os -import time -import traceback -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Iterator, TextIO - -from xtuner.v1.utils import get_logger - - -_writer: _TraceWriter | None = None - - -def init_writer(actor_id: str | None = None) -> None: - """Open the per-actor fates/spans files. Safe to call twice (subsequent - calls are no-ops). - - Args: - actor_id (str | None): Stable identifier for this actor process, - folded into the output filename to keep per-actor files apart. - When ``None`` the pid alone is used. - """ - global _writer - if _writer is not None: - return - work_dir = os.environ.get("WORK_DIR") - if not work_dir: - get_logger().info("[trace] WORK_DIR not set; trace emission disabled") - return - try: - trace_dir = Path(work_dir) / "trace" - trace_dir.mkdir(parents=True, exist_ok=True) - _writer = _TraceWriter(trace_dir, actor_id) - get_logger().info(f"[trace] writing to {trace_dir} (actor={actor_id or 'none'}, pid={os.getpid()})") - except Exception as exc: - get_logger().warning(f"[trace] init failed ({exc}); trace emission disabled") - _writer = None - - -def emit_fate( - uid: str, - task_id: str | None, - group_id: str | None, - final: str, - failed_stage: str | None = None, - reason: str | None = None, - **extra: Any, -) -> None: - """Record a sample's terminal outcome. - - Args: - uid (str): Per-sample observation id. - task_id (str | None): Dataset task id; best-effort identifier. - group_id (str | None): Shared id across prompt_repeat_k siblings. - final (str): ``"COMPLETED"`` or ``"SKIPPED"``. - failed_stage (str | None): Stage name where failure originated. - reason (str | None): Human-readable error string. 
- **extra (Any): Additional keys merged into the record. - """ - if _writer is None: - return - record: dict[str, Any] = { - "ts": time.time(), - "uid": uid, - "task_id": task_id, - "group_id": group_id, - "final": final, - "failed_stage": failed_stage, - "reason": reason, - } - if extra: - record.update(extra) - _writer.write_fate(record) - - -class SpanHandle: - """Mutable handle yielded by :func:`span` so the caller can flag logical - failures that do not raise (e.g. a subprocess returning non-zero without - throwing) or attach runtime-discovered fields (e.g. a sandbox URL that only - becomes known after ``acquire`` succeeds). - - Fields default to "success"; see ``mark_error`` and ``annotate``. - """ - - def __init__(self) -> None: - self.ok: bool = True - self.err: str | None = None - # Runtime-discovered fields merged into the span record at emit time. - # Separate from the constructor ``extra`` kwargs because those are - # fixed at span-entry, whereas these are set mid-block. - self.annotations: dict[str, Any] = {} - - def mark_error(self, err: str) -> None: - self.ok = False - self.err = err - - def annotate(self, **fields: Any) -> None: - """Attach runtime-discovered fields to this span record (merged at emit - time). - - Useful for values that are only known after the guarded - code ran — e.g. the sandbox URL returned by ``SandboxPool.get``. - """ - self.annotations.update(fields) - - -@contextmanager -def span(uid: str, stage: str, **extra: Any) -> Iterator[SpanHandle]: - """Time a stage and emit an enter + exit span pair. - - The context manager is synchronous but safe to wrap ``await``-bearing - code because the yield region runs on the caller's event loop. 
- - Two records are written per span so consumers tailing the file in real - time can see "stage entered" before the stage finishes: - - * ``{"event": "enter", "ts", "uid", "stage", **extra}`` - * ``{"event": "exit", "ts", "uid", "stage", "duration_ms", "ok", "err", - **extra, **annotations}`` - - Args: - uid (str): Per-sample observation id. - stage (str): Short stage name (e.g. ``"acquire"``, ``"infer"``). - **extra (Any): Additional keys merged into both records. - - Returns: - SpanHandle: Yielded handle; call ``handle.mark_error(...)`` or - ``handle.annotate(...)`` inside the block to customize the exit record. - """ - handle = SpanHandle() - t_start = time.monotonic() - if _writer is not None: - enter: dict[str, Any] = { - "ts": time.time(), - "event": "enter", - "uid": uid, - "stage": stage, - } - if extra: - enter.update(extra) - _writer.write_span(enter) - try: - yield handle - except BaseException as exc: - handle.ok = False - handle.err = f"{type(exc).__name__}: {exc}" - raise - finally: - duration_ms = int((time.monotonic() - t_start) * 1000) - if _writer is not None: - record: dict[str, Any] = { - "ts": time.time(), - "event": "exit", - "uid": uid, - "stage": stage, - "duration_ms": duration_ms, - "ok": handle.ok, - "err": handle.err, - } - if extra: - record.update(extra) - if handle.annotations: - record.update(handle.annotations) - _writer.write_span(record) - - -def _reset_for_testing() -> None: - """Close the writer and clear module state. - - For unit tests only. 
- """ - global _writer, _llm_writer, _llm_writer_ready - if _writer is not None: - _writer.close() - _writer = None - if _llm_writer is not None: - try: - _llm_writer.close() - except Exception: - pass - _llm_writer = None - _llm_writer_ready = False - - -# ───────────────────────────────────────────────────────────────── -# LLM call stream (separate writer, opened lazily per process) -# ───────────────────────────────────────────────────────────────── -# -# ``RolloutController`` emits one record per ``/v1/chat/completions`` via -# :func:`emit_llm_call`. The writer is lazy so install_agent_env actors -# (which never call emit_llm_call) don't create empty files. - -_llm_writer: TextIO | None = None -_llm_writer_ready: bool = False - - -def _ensure_llm_writer() -> None: - global _llm_writer, _llm_writer_ready - if _llm_writer_ready: - return - _llm_writer_ready = True - work_dir = os.environ.get("WORK_DIR") - if not work_dir: - return - try: - trace_dir = Path(work_dir) / "trace" - trace_dir.mkdir(parents=True, exist_ok=True) - path = trace_dir / f"llm_calls.{os.getpid()}.jsonl" - _llm_writer = open(path, "a", buffering=1, encoding="utf-8") - get_logger().info(f"[trace] LLM call stream → {path}") - except Exception as exc: - get_logger().warning(f"[trace] LLM call stream init failed ({exc}); disabled") - _llm_writer = None - - -def emit_llm_call( - total_ms: int, - tokenize_ms: int, - rollout_ms: int, - post_ms: int, - prompt_tokens: int, - completion_tokens: int, - **extra: Any, -) -> None: - """Record timing and token counts for one LLM request. - - Called from ``RolloutController`` on every ``/v1/chat/completions`` — - unlike the slow-request warning log this captures *all* requests so - ``view.py --llm-stats`` can compute real p50 / p95 / p99. - - Args: - total_ms (int): End-to-end request duration. - tokenize_ms (int): Tokenize phase duration. - rollout_ms (int): Rollout worker call duration. - post_ms (int): Post-processing (detokenize + response packing). 
- prompt_tokens (int): Input token count. - completion_tokens (int): Output token count. - **extra (Any): Additional fields merged into the record. - """ - _ensure_llm_writer() - if _llm_writer is None: - return - record: dict[str, Any] = { - "ts": time.time(), - "total_ms": total_ms, - "tokenize_ms": tokenize_ms, - "rollout_ms": rollout_ms, - "post_ms": post_ms, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - } - if extra: - record.update(extra) - try: - _llm_writer.write(json.dumps(record, ensure_ascii=False, default=str) + "\n") - except Exception as exc: - get_logger().warning(f"[trace] LLM call write failed: {exc}") - - -class _TraceWriter: - def __init__(self, trace_dir: Path, actor_id: str | None) -> None: - suffix = f"{actor_id or 'noactor'}.{os.getpid()}" - self._fates_path = trace_dir / f"fates.{suffix}.jsonl" - self._spans_path = trace_dir / f"spans.{suffix}.jsonl" - self._fates: TextIO = open(self._fates_path, "a", buffering=1, encoding="utf-8") - self._spans: TextIO = open(self._spans_path, "a", buffering=1, encoding="utf-8") - - def write_fate(self, record: dict[str, Any]) -> None: - self._write(self._fates, record) - - def write_span(self, record: dict[str, Any]) -> None: - self._write(self._spans, record) - - def close(self) -> None: - for fp in (self._fates, self._spans): - try: - fp.close() - except Exception: - pass - - @staticmethod - def _write(fp: TextIO, record: dict[str, Any]) -> None: - try: - line = json.dumps(record, ensure_ascii=False, default=str) - except Exception: - line = json.dumps( - {"_encode_error": traceback.format_exc(limit=1), "_repr": repr(record)}, - ensure_ascii=False, - ) - try: - fp.write(line + "\n") - except Exception as exc: - get_logger().warning(f"[trace] write failed: {exc}") - - -# ───────────────────────────────────────────────────────────────── -# Failure-path diagnostics (unstructured dump) -# ───────────────────────────────────────────────────────────────── -# -# 
``emit_diagnostic`` writes a per-failure bundle under -# ``$WORK_DIR/trace/diagnostics/`` whenever a sample dies with a non-zero -# entry rc or runner-level exception. The header ``.log`` is always -# written so a missing ``.daemon.log`` (sandbox unreachable, TTL expired, -# etc.) still leaves a breadcrumb with the ``download_err`` explanation. -# -# Caller (``runner._dump_skipped_diagnostic``) owns the sandbox client and -# is responsible for downloading the daemon log bytes; this function just -# writes files. Keeping HTTP out of trace.py means the module has no -# runtime deps on lagent/sandbox-client internals. - -_DIAGNOSTIC_TAIL_PREVIEW_LINES = 50 - - -def emit_diagnostic( - task_id: str | None, - uid: str | None, - data_source: str | None, - exception_type: str, - exception_msg: str, - daemon_log: bytes | None = None, - download_err: str | None = None, -) -> None: - """Persist a per-failure diagnostic bundle under - ``$WORK_DIR/trace/diagnostics/``. - - Args: - task_id (str | None): Dataset task id; used in filename + header. - uid (str | None): Per-sample observation id; truncated to 12 chars - in the filename so multiple sibling failures don't collide. - data_source (str | None): Data source label (e.g. ``"tb2-rl"``). - exception_type (str): Class name of the exception that triggered - the dump. - exception_msg (str): Short error message. - daemon_log (bytes | None): Full ``/tmp/agent_daemon.log`` bytes if - download succeeded; ``None`` when the caller's download failed - (sandbox unreachable, etc.). - download_err (str | None): Short description of the download - failure when ``daemon_log is None``. 
- """ - work_dir = os.environ.get("WORK_DIR") - if not work_dir: - return - diag_dir = Path(work_dir) / "trace" / "diagnostics" - try: - diag_dir.mkdir(parents=True, exist_ok=True) - except Exception as exc: - get_logger().debug(f"[trace] diagnostic dir mkdir failed: {exc}") - return - - ts = time.strftime("%H%M%S") - uid_short = (uid or "nouid")[:12] - base = diag_dir / f"{ts}_{task_id or 'notask'}_{uid_short}" - - full_size = 0 - tail_preview = "(no daemon log)" - if daemon_log is not None: - full_size = len(daemon_log) - text = daemon_log.decode(errors="replace") - tail_preview = "\n".join(text.splitlines()[-_DIAGNOSTIC_TAIL_PREVIEW_LINES:]) - try: - base.with_suffix(".daemon.log").write_bytes(daemon_log) - except Exception as exc: - get_logger().debug(f"[trace] daemon log write failed: {exc}") - elif download_err is not None: - tail_preview = f"(could not pull daemon log: {download_err})" - - try: - base.with_suffix(".log").write_text( - f"task_id={task_id}\n" - f"uid={uid}\n" - f"data_source={data_source}\n" - f"timestamp={time.time()}\n" - f"exception_type={exception_type}\n" - f"exception={exception_msg}\n" - f"daemon_log_bytes={full_size}\n" - f"daemon_log_file={base.with_suffix('.daemon.log').name if daemon_log else '(unavailable)'}\n" - f"---daemon_log_tail_preview (last {_DIAGNOSTIC_TAIL_PREVIEW_LINES} lines)---\n" - f"{tail_preview}\n" - ) - except Exception as exc: - get_logger().debug(f"[trace] diagnostic header write failed at {base}: {exc}") diff --git a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py index 86e1539da8..681f641160 100644 --- a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py +++ b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py @@ -1,6 +1,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status from xtuner.v1.rl.judger import Judger from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.trace import xtuner_trace_function from .agent_loop 
import AgentLoop, AgentLoopConfig @@ -64,6 +65,7 @@ def __init__( enable_batch_judge=enable_batch_judge, ) + @xtuner_trace_function("xtuner.agent_loop.generate_sample") async def generate_sample( self, rollout_state: RolloutState, diff --git a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py index 2b45a6ca15..a32328ccbe 100644 --- a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py +++ b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py @@ -12,6 +12,7 @@ from xtuner.v1.rl.judger import ComposedJudgerConfig, JudgerConfig, build_judger from xtuner.v1.rl.replay_buffer import ReplayBuffer from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.trace import flush_trace from xtuner.v1.utils import get_logger from .produce_utils import ( @@ -296,6 +297,7 @@ async def produce_batch( f"[AgentLoopManager][{self.name}] produce_batch done " f"elapsed={time.perf_counter() - start:.3f}, completed_groups={len(result.rollout_states)}" ) + flush_trace() return result async def save( diff --git a/xtuner/v1/rl/agent_loop_manager/produce_utils.py b/xtuner/v1/rl/agent_loop_manager/produce_utils.py index c9e990934f..6f9a9bddeb 100644 --- a/xtuner/v1/rl/agent_loop_manager/produce_utils.py +++ b/xtuner/v1/rl/agent_loop_manager/produce_utils.py @@ -13,6 +13,14 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status, get_group_status, reset_rollout_response from xtuner.v1.rl.agent_loop import AgentLoopSpec from xtuner.v1.rl.replay_buffer import ReplayBuffer +from xtuner.v1.rl.trace import ( + TRACE_EXTRA_MODEL_STEP, + TRACE_EXTRA_PRODUCE_BATCH_ID, + TRACE_EXTRA_PRODUCER_FUTURE_STEP, + TRACE_EXTRA_TRAIN_STEP, + build_produce_batch_id, + xtuner_trace_function, +) from xtuner.v1.rl.utils import ( AGENT_LOOP_PAUSE_REQUEST_TIMEOUT_S, PRODUCER_PAUSE_PENDING_TASK_TIMEOUT_S, @@ -120,13 +128,32 @@ class BaseProduceContext: def consumer_step(self) -> int: return self.train_step + def trace_kwargs(self) -> 
dict[str, Any]: + producer_future_step = getattr(self.progress, "producer_future_step", self.train_step) + produce_batch_id = build_produce_batch_id(self.train_step, self.model_step, producer_future_step) + return { + "task_name": self.task_name, + "train_step": self.train_step, + "model_step": self.model_step, + "producer_future_step": producer_future_step, + "produce_batch_id": produce_batch_id, + } + async def expired_count(self) -> int: return await self.replay_buffer.count(task_name=self.task_name, group_status=Status.EXPIRED) + @xtuner_trace_function( + "xtuner.producer.sample_group", + trace_kwargs_getter=lambda self, *args, **kwargs: self.trace_kwargs(), + ) async def sample_group(self, *, from_expired_pool: bool) -> list[RolloutState]: group_status = [Status.EXPIRED, Status.ABORTED] if from_expired_pool else [Status.ABORTED] return await self.sampler.sample(task_name=self.task_name, group_status=group_status) + @xtuner_trace_function( + "xtuner.producer.generate_group", + trace_kwargs_getter=lambda self, *args, **kwargs: self.trace_kwargs(), + ) async def generate_group( self, rollout_state: list[RolloutState], @@ -134,6 +161,14 @@ async def generate_group( enable_partial_rollout: bool = False, ) -> list[RolloutState]: # The strategy does not care whether agent_loop is a Ray actor or a local object. + trace_kwargs = self.trace_kwargs() + for state in rollout_state: + extra_fields = dict(state.extra_fields or {}) + extra_fields[TRACE_EXTRA_TRAIN_STEP] = trace_kwargs["train_step"] + extra_fields[TRACE_EXTRA_MODEL_STEP] = trace_kwargs["model_step"] + extra_fields[TRACE_EXTRA_PRODUCER_FUTURE_STEP] = trace_kwargs["producer_future_step"] + extra_fields[TRACE_EXTRA_PRODUCE_BATCH_ID] = trace_kwargs["produce_batch_id"] + state.extra_fields = extra_fields start = time.perf_counter() if isinstance(self.agent_loop, ray.actor.ActorHandle): result = await self.agent_loop.generate_group.remote( @@ -147,13 +182,20 @@ async def generate_group( ) elapsed = time.perf_counter() - start for item in result: -
extra_fields = getattr(item, "extra_fields", None) - if extra_fields is None: - extra_fields = {} - setattr(item, "extra_fields", extra_fields) + extra_fields = dict(item.extra_fields or {}) + extra_fields[TRACE_EXTRA_TRAIN_STEP] = trace_kwargs["train_step"] + extra_fields[TRACE_EXTRA_MODEL_STEP] = trace_kwargs["model_step"] + extra_fields[TRACE_EXTRA_PRODUCER_FUTURE_STEP] = trace_kwargs["producer_future_step"] + extra_fields[TRACE_EXTRA_PRODUCE_BATCH_ID] = trace_kwargs["produce_batch_id"] extra_fields[GROUP_GENERATE_TIME_KEY] = elapsed + item.extra_fields = extra_fields return result + @xtuner_trace_function( + "xtuner.producer.put_generated_group", + target="group", + trace_kwargs_getter=lambda self, *args, **kwargs: self.trace_kwargs(), + ) async def put_generated_group(self, group: list[RolloutState]) -> bool: # Only COMPLETED groups need business filtering; ABORTED / EXPIRED groups keep their original status. is_completed = get_group_status(group) == Status.COMPLETED diff --git a/xtuner/v1/rl/agent_loop_manager/sampler.py b/xtuner/v1/rl/agent_loop_manager/sampler.py index e5b0170bbe..b35945d2c6 100644 --- a/xtuner/v1/rl/agent_loop_manager/sampler.py +++ b/xtuner/v1/rl/agent_loop_manager/sampler.py @@ -12,6 +12,7 @@ from xtuner.v1.datasets.config import DataloaderConfig from xtuner.v1.datasets.dataloader import Dataloader from xtuner.v1.rl.replay_buffer import ReplayBuffer +from xtuner.v1.rl.trace import build_rollout_trace_id from xtuner.v1.utils import XTUNER_DETERMINISTIC from xtuner.v1.utils.logger import get_logger @@ -128,11 +129,20 @@ def __init__( self.replay_buffer = replay_buffer async def sample(self, task_name: str, group_status: list[Status] | None = None) -> list[RolloutState]: + group = None for status in group_status or []: buffer_data = await self.replay_buffer.get(1, task_name=task_name, group_status=status) if buffer_data: - return buffer_data[0] - return self.sample_from_dataloader() + group = buffer_data[0] + break + if group is None: + group = self.sample_from_dataloader() + for
repeat_index, state in enumerate(group): + if state.task_name is None: + state.task_name = task_name + if state.trace_id is None: + state.trace_id = build_rollout_trace_id(state, repeat_index=repeat_index) + return group def save(self, checkpoint_path: Path | str) -> None: """Save the sampler's dataloader state to checkpoint.""" diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index efd49d62d6..d584394326 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -12,6 +12,7 @@ from transformers import AutoTokenizer from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.trace import xtuner_trace_function from xtuner.v1.rl.utils import AutoAcceleratorWorkers from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger @@ -173,6 +174,7 @@ def get_generate_concurrency(self) -> int: return active_workers * concurrency_per_worker @ray.method(concurrency_group=ROLLOUT_CONCURRENCY_GROUP_GENERATE) + @xtuner_trace_function("xtuner.rollout_controller.generate") async def generate(self, rollout_state: RolloutState) -> RolloutState: if XTUNER_DETERMINISTIC: sample_params = rollout_state.sample_params.model_copy(deep=True) diff --git a/xtuner/v1/rl/rollout/session_server.py b/xtuner/v1/rl/rollout/session_server.py index 666cfc740b..456809914d 100644 --- a/xtuner/v1/rl/rollout/session_server.py +++ b/xtuner/v1/rl/rollout/session_server.py @@ -8,9 +8,23 @@ from aiohttp import ClientConnectionResetError, ClientSession, ClientTimeout, web from transformers import AutoTokenizer +from xtuner.v1.rl.trace import otel_trace_span, set_otel_span_attrs from xtuner.v1.utils import get_logger from .chat_template import canonicalize_messages_for_chat_template +from .session_trace import ( + ForwardRequestTrace, + SessionTraceContext, +) +from .session_trace import ( + _choices_output_ids_len as _choices_output_ids_len, +) +from .session_trace import ( + _extract_body_trace_context as 
_extract_body_trace_context, +) +from .session_trace import ( + _response_output_ids_len as _response_output_ids_len, +) from .trace_store import TokenizedSegment, get_store @@ -73,7 +87,7 @@ def _extract_output_logprobs(choice: dict, output_token_ids: list[int]) -> list[ return [item[0] for item in output_token_logprobs] -_SESSION_SERVER_ONLY_KEYS = {"session_id"} +_SESSION_SERVER_ONLY_KEYS = {"session_id", "_otel_trace_context"} def _bool_request_value(value: Any, default: bool = False) -> bool: @@ -135,6 +149,16 @@ def __init__( async def on_request(self, req_body: dict, *, trace_enabled: bool = True) -> dict: """Hook for processing/modifying the request before forwarding.""" + with otel_trace_span( + "xtuner.session_server.on_request", + session_id=req_body.get("session_id"), + trace_store_enabled=trace_enabled, + messages=len(req_body.get("messages") or []) if isinstance(req_body.get("messages"), list) else None, + tools=len(req_body.get("tools") or []) if isinstance(req_body.get("tools"), list) else None, + ) as span: + return await self._on_request_impl(req_body, trace_enabled=trace_enabled, span=span) + + async def _on_request_impl(self, req_body: dict, *, trace_enabled: bool, span: Any = None) -> dict: if not trace_enabled: worker_req = {k: v for k, v in req_body.items() if k not in _SESSION_SERVER_ONLY_KEYS} if "logprobs" in worker_req: @@ -154,6 +178,7 @@ async def on_request(self, req_body: dict, *, trace_enabled: bool = True) -> dic add_generation_prompt=True, tokenize=False, ) + set_otel_span_attrs(span, prompt_chars=len(prompt_text)) # 2. 
The store does a string prefix match. prefix, nodes = await self.store.search.remote(session_id, prompt_text, filter_none=True) @@ -164,6 +189,7 @@ async def on_request(self, req_body: dict, *, trace_enabled: bool = True) -> dic delta_ids = self.tokenizer.encode(delta, add_special_tokens=False) await self.store.insert.remote(session_id, prompt_text, TokenizedSegment(text=delta, token_ids=delta_ids)) input_ids = reduce(add, [node.value.token_ids for node in nodes] + [delta_ids]) + set_otel_span_attrs(span, prefix_chars=len(prefix), delta_chars=len(delta), input_tokens=len(input_ids)) # 3. Assemble the OpenAI chat completions request. worker_req = { @@ -184,6 +210,14 @@ async def on_request(self, req_body: dict, *, trace_enabled: bool = True) -> dic async def on_response(self, worker_resp: dict, *, trace_enabled: bool = True) -> dict: """Hook for processing the parsed response received from the worker.""" + with otel_trace_span( + "xtuner.session_server.on_response", + session_id=worker_resp.get("session_id"), + trace_store_enabled=trace_enabled, + ) as span: + return await self._on_response_impl(worker_resp, trace_enabled=trace_enabled, span=span) + + async def _on_response_impl(self, worker_resp: dict, *, trace_enabled: bool, span: Any = None) -> dict: if not trace_enabled: return {k: v for k, v in worker_resp.items() if k not in {"messages", "tools"}} @@ -198,6 +232,13 @@ async def on_response(self, worker_resp: dict, *, trace_enabled: bool = True) -> "SessionServer response choice has no output_ids; " "cannot export a training trace for this assistant turn."
) + message = choice.get("message") or {} + set_otel_span_attrs( + span, + output_tokens=len(output_token_ids), + response_chars=len(message.get("content") or ""), + finish_reason=choice.get("finish_reason"), + ) output_logprobs = _extract_output_logprobs(choice, output_token_ids) raw_routed_expert = choice.get("routed_experts") # Raw routed_expert for this call; may be None @@ -295,9 +336,24 @@ async def stop(self): async def _handle_request(self, request: web.Request) -> web.Response: """Proxy handler for the worker API.""" - # Read the request body request_body = await request.read() + trace_context = SessionTraceContext.from_request(request.headers, request_body) + with trace_context.use(): + return await self._handle_request_impl( + request, + request_body=request_body, + trace_context=trace_context, + ) + + async def _handle_request_impl( + self, + request: web.Request, + *, + request_body: bytes, + trace_context: SessionTraceContext, + ) -> web.Response: request_data = session_id = messages = None + tools = None trace_enabled = False orig_return_logprob = orig_return_token_ids = orig_return_routed_experts = False if request_body: @@ -332,6 +388,7 @@ async def _handle_request(self, request: web.Request) -> web.Response: forward_headers.pop("host", None) forward_headers.pop("Content-Length", None) forward_headers.pop("content-length", None) + trace_context.inject_forward_headers(forward_headers) # Re-build Path req_path = request.match_info["path"] @@ -379,72 +436,105 @@ def _clean_data(data: dict) -> bool: # tool_calls/reasoning_content payloads can exceed the 64KB default and trigger # "Chunk too big" from readuntil(b"\n").
timeout = ClientTimeout(total=self.request_timeout, sock_connect=30) - async with ClientSession(read_bufsize=self.read_bufsize, timeout=timeout) as client: - async with client.request( - method=request.method, url=target_url, headers=forward_headers, data=request_body - ) as resp: - # Setup proper stream vs sync response objects - if is_stream: - response_chunks = [] - response = web.StreamResponse( - status=resp.status, - headers={ - k: v - for k, v in resp.headers.items() - if k.lower() not in ("transfer-encoding", "content-length", "content-encoding") - }, - ) - await response.prepare(request) - # If the downstream client closes the socket mid-stream - # (e.g. AsyncAPIClient bails out on a finish_reason=='error' - # chunk after the prompt overflowed the session window), - # keep draining the upstream so the trace is still recorded - # in full but stop attempting to write to the closed socket. - client_alive = True - async for line in resp.content: - # Keep unmodified line for trace store parsing - if trace_enabled: - response_chunks.append(line) - - # Dynamically prune added fields before writing to client - if request_data is not None and line.startswith(b"data: ") and line.strip() != b"data: [DONE]": + raw_response: bytes | None = None + forward_trace = ForwardRequestTrace.start( + request, + target_url=target_url, + timeout_s=self.request_timeout, + worker_base_url=self.worker_base_url, + request_body=request_body, + request_data=request_data, + trace_context=trace_context, + ) + try: + async with ClientSession(read_bufsize=self.read_bufsize, timeout=timeout) as client: + async with client.request( + method=request.method, url=target_url, headers=forward_headers, data=request_body + ) as resp: + forward_trace.set_http_status(resp.status) + # Setup proper stream vs sync response objects + if is_stream: + response_chunks = [] + response = web.StreamResponse( + status=resp.status, + headers={ + k: v + for k, v in resp.headers.items() + if k.lower() not in 
("transfer-encoding", "content-length", "content-encoding") + }, + ) + await response.prepare(request) + # If the downstream client closes the socket mid-stream + # (e.g. AsyncAPIClient bails out on a finish_reason=='error' + # chunk after the prompt overflowed the session window), + # keep draining the upstream so the trace is still recorded + # in full but stop attempting to write to the closed socket. + client_alive = True + stream_trace = forward_trace.start_stream( + target_url=target_url, + request_data=request_data, + ) + try: + async for line in resp.content: + stream_trace.on_line(line) + # Keep unmodified line for trace store parsing + if trace_enabled: + response_chunks.append(line) + + # Dynamically prune added fields before writing to client + if ( + request_data is not None + and line.startswith(b"data: ") + and line.strip() != b"data: [DONE]" + ): + try: + text = line.decode("utf-8") + data = json.loads(text[6:]) + if _clean_data(data): + line = ("data: " + json.dumps(data) + "\n").encode("utf-8") + except Exception: + pass + + # Delay [DONE] only while a training trace still needs to be exported. 
+ if client_alive and (not trace_enabled or line.strip() != b"data: [DONE]"): + try: + await response.write(line) + except (ConnectionError, ClientConnectionResetError): + client_alive = False + finally: + stream_trace.finish(client_alive=client_alive) + + raw_response = b"".join(response_chunks) if trace_enabled else b"" + else: + with forward_trace.read_response_span(target_url=target_url): + raw_response = await resp.read() + final_raw_response = raw_response + forward_trace.on_non_stream_response(raw_response) + + if request_data is not None: try: - text = line.decode("utf-8") - data = json.loads(text[6:]) - if _clean_data(data): - line = ("data: " + json.dumps(data) + "\n").encode("utf-8") + clean_data = json.loads(raw_response) + if isinstance(clean_data, dict) and _clean_data(clean_data): + final_raw_response = json.dumps(clean_data).encode("utf-8") except Exception: pass - # Delay [DONE] only while a training trace still needs to be exported. - if client_alive and (not trace_enabled or line.strip() != b"data: [DONE]"): - try: - await response.write(line) - except (ConnectionError, ClientConnectionResetError): - client_alive = False - - raw_response = b"".join(response_chunks) if trace_enabled else b"" - else: - raw_response = await resp.read() - final_raw_response = raw_response - - if request_data is not None: - try: - clean_data = json.loads(raw_response) - if _clean_data(clean_data): - final_raw_response = json.dumps(clean_data).encode("utf-8") - except Exception: - pass - - response = web.Response( - status=resp.status, - headers={ - k: v - for k, v in resp.headers.items() - if k.lower() not in ("transfer-encoding", "content-length", "content-encoding") - }, - body=final_raw_response, # Modified raw response without our injected trace params - ) + response = web.Response( + status=resp.status, + headers={ + k: v + for k, v in resp.headers.items() + if k.lower() not in ("transfer-encoding", "content-length", "content-encoding") + }, + 
body=final_raw_response, # Modified raw response without our injected trace params + ) + except Exception as exc: + forward_trace.finish(exc=exc) + raise + else: + forward_trace.finish(raw_response=raw_response) + if raw_response is None: + raw_response = b"" # Apply abstract on_response processing response_data = None @@ -656,3 +746,12 @@ async def stop(self) -> None: if self.server is not None: await self.server.stop() self.server = None + + +__all__ = [ + "SessionServer", + "SessionServerActor", + "_choices_output_ids_len", + "_extract_body_trace_context", + "_response_output_ids_len", +] diff --git a/xtuner/v1/rl/rollout/session_trace.py b/xtuner/v1/rl/rollout/session_trace.py new file mode 100644 index 0000000000..7b49974cd2 --- /dev/null +++ b/xtuner/v1/rl/rollout/session_trace.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +import contextlib +import json +import time +from collections.abc import Iterator, Mapping +from dataclasses import dataclass +from typing import Any + +from aiohttp import web + +from xtuner.v1.rl.trace import ( + begin_otel_span, + end_otel_span, + extract_trace_context, + inject_trace_context, + otel_trace_span, + set_otel_span_attrs, + use_trace_context, +) + + +def _extract_body_trace_context(request_body: bytes) -> dict[str, Any] | None: + if not request_body: + return None + try: + body_data = json.loads(request_body) + except json.JSONDecodeError: + return None + if not isinstance(body_data, dict): + return None + trace_context = body_data.get("_otel_trace_context") + return trace_context if isinstance(trace_context, dict) else None + + +def _list_len(value: Any) -> int | None: + return len(value) if isinstance(value, list) else None + + +def _choices_output_ids_len(data: dict) -> int: + total = 0 + for choice in data.get("choices") or []: + output_ids = choice.get("output_ids") + if isinstance(output_ids, list): + total += len(output_ids) + return total + + +def _response_output_ids_len(data: dict) -> int | None: + 
output_ids = data.get("output_ids") + if isinstance(output_ids, list): + return len(output_ids) + total = _choices_output_ids_len(data) + return total if total > 0 else None + + +@dataclass(frozen=True) +class SessionTraceContext: + parent_context: Any + traceparent_header_present: bool + traceparent_body_present: bool + source: str + + @classmethod + def from_request(cls, headers: Mapping[str, Any], request_body: bytes) -> SessionTraceContext: + body_trace_context = _extract_body_trace_context(request_body) + traceparent_header = headers.get("traceparent") + traceparent_body = body_trace_context.get("traceparent") if body_trace_context else None + parent_context = extract_trace_context(headers) + source = "header" if traceparent_header else "none" + if traceparent_body: + parent_context = extract_trace_context(body_trace_context) + source = "body" + return cls( + parent_context=parent_context, + traceparent_header_present=bool(traceparent_header), + traceparent_body_present=bool(traceparent_body), + source=source, + ) + + @contextlib.contextmanager + def use(self) -> Iterator[None]: + with use_trace_context(self.parent_context): + yield + + def inject_forward_headers(self, headers: dict[str, str]) -> None: + inject_trace_context(headers) + + +@dataclass +class StreamResponseTrace: + forward_trace: ForwardRequestTrace + span: Any + start_s: float + forward_start_s: float + first_chunk_ms: float | None = None + first_chunk_from_forward_ms: float | None = None + first_output_token_ms: float | None = None + first_output_token_from_forward_ms: float | None = None + first_content_ms: float | None = None + first_content_from_forward_ms: float | None = None + chunk_count: int = 0 + raw_response_bytes: int = 0 + output_tokens: int = 0 + prompt_tokens: int | None = None + completion_tokens: int | None = None + total_tokens: int | None = None + finish_reason: str | None = None + + @classmethod + def start( + cls, + forward_trace: ForwardRequestTrace, + *, + target_url: str, 
+ input_tokens: int | None, + max_tokens: Any, + forward_start_s: float, + ) -> StreamResponseTrace: + start_s = time.perf_counter() + return cls( + forward_trace=forward_trace, + span=begin_otel_span( + "xtuner.session_server.stream_read", + target_url=target_url, + input_tokens=input_tokens, + max_tokens=max_tokens, + upstream_headers_ms=forward_trace.upstream_headers_ms, + **{"xtuner.stage.kind": "llm_call"}, + ), + start_s=start_s, + forward_start_s=forward_start_s, + ) + + def on_line(self, line: bytes) -> None: + now_s = time.perf_counter() + elapsed_ms = self._elapsed_ms(now_s) + elapsed_from_forward_ms = self._elapsed_from_forward_ms(now_s) + self.chunk_count += 1 + self.raw_response_bytes += len(line) + if self.first_chunk_ms is None: + self.first_chunk_ms = elapsed_ms + self.first_chunk_from_forward_ms = elapsed_from_forward_ms + if not line.startswith(b"data: ") or line.strip() == b"data: [DONE]": + return + try: + data = json.loads(line.decode("utf-8")[6:]) + except Exception: + return + event_output_tokens = _choices_output_ids_len(data) + if event_output_tokens > 0 and self.first_output_token_ms is None: + self.first_output_token_ms = elapsed_ms + self.first_output_token_from_forward_ms = elapsed_from_forward_ms + self.output_tokens += event_output_tokens + usage = data.get("usage") + if isinstance(usage, dict): + self.prompt_tokens = usage.get("prompt_tokens", self.prompt_tokens) + self.completion_tokens = usage.get("completion_tokens", self.completion_tokens) + self.total_tokens = usage.get("total_tokens", self.total_tokens) + for choice in data.get("choices") or []: + delta = choice.get("delta") or {} + if delta.get("content") and self.first_content_ms is None: + self.first_content_ms = elapsed_ms + self.first_content_from_forward_ms = elapsed_from_forward_ms + if choice.get("finish_reason"): + self.finish_reason = choice.get("finish_reason") + + def finish(self, *, client_alive: bool) -> None: + attrs = self.attrs(client_alive=client_alive, 
finish_s=time.perf_counter()) + end_otel_span(self.span, **attrs) + self.forward_trace.set_attrs(**attrs) + + def attrs(self, *, client_alive: bool, finish_s: float | None = None) -> dict[str, Any]: + finish_s = time.perf_counter() if finish_s is None else finish_s + return { + "first_chunk_ms": self.first_chunk_ms, + "first_chunk_from_forward_ms": self.first_chunk_from_forward_ms, + "first_output_token_ms": self.first_output_token_ms, + "first_output_token_from_forward_ms": self.first_output_token_from_forward_ms, + "first_content_ms": self.first_content_ms, + "first_content_from_forward_ms": self.first_content_from_forward_ms, + "stream_read_ms": max(0.0, (finish_s - self.start_s) * 1000), + "stream_complete_from_forward_ms": max(0.0, (finish_s - self.forward_start_s) * 1000), + "chunks": self.chunk_count, + "raw_response_bytes": self.raw_response_bytes, + "output_tokens": self.output_tokens if self.output_tokens > 0 else None, + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "total_tokens": self.total_tokens, + "finish_reason": self.finish_reason, + "client_alive": client_alive, + } + + def _elapsed_ms(self, now_s: float | None = None) -> float: + now_s = time.perf_counter() if now_s is None else now_s + return (now_s - self.start_s) * 1000 + + def _elapsed_from_forward_ms(self, now_s: float | None = None) -> float: + now_s = time.perf_counter() if now_s is None else now_s + return (now_s - self.forward_start_s) * 1000 + + +class ForwardRequestTrace: + def __init__(self, span: Any, start_s: float | None = None) -> None: + self.span = span + self.start_s = time.perf_counter() if start_s is None else start_s + self.upstream_headers_ms: float | None = None + + @classmethod + def start( + cls, + request: web.Request, + *, + target_url: str, + request_body: bytes, + request_data: dict[str, Any] | None, + timeout_s: float, + worker_base_url: str, + trace_context: SessionTraceContext, + ) -> ForwardRequestTrace: + start_s = 
time.perf_counter() + is_stream = bool(request_data.get("stream", False)) if request_data else False + input_tokens = _list_len(request_data.get("input_ids")) if request_data else None + max_tokens = request_data.get("max_tokens") if request_data else None + span = begin_otel_span( + "xtuner.session_server.forward_worker", + target_url=target_url, + stream=is_stream, + request_bytes=len(request_body) if request_body else 0, + timeout_s=timeout_s, + input_tokens=input_tokens, + max_tokens=max_tokens, + model=request_data.get("model") if request_data else None, + http_method=request.method, + http_path=request.path, + worker_base_url=worker_base_url, + traceparent_header_present=trace_context.traceparent_header_present, + traceparent_body_present=trace_context.traceparent_body_present, + traceparent_context_source=trace_context.source, + **{"xtuner.stage.kind": "llm_call"}, + ) + return cls(span, start_s=start_s) + + def set_attrs(self, **attrs: Any) -> None: + set_otel_span_attrs(self.span, **attrs) + + def set_http_status(self, status: int) -> None: + attrs: dict[str, Any] = {"http_status": status} + if self.upstream_headers_ms is None: + self.upstream_headers_ms = self._elapsed_ms() + attrs["upstream_headers_ms"] = self.upstream_headers_ms + self.set_attrs(**attrs) + + def start_stream( + self, + *, + target_url: str, + request_data: dict[str, Any] | None, + ) -> StreamResponseTrace: + input_tokens = _list_len(request_data.get("input_ids")) if request_data else None + max_tokens = request_data.get("max_tokens") if request_data else None + return StreamResponseTrace.start( + self, + target_url=target_url, + input_tokens=input_tokens, + max_tokens=max_tokens, + forward_start_s=self.start_s, + ) + + @contextlib.contextmanager + def read_response_span(self, *, target_url: str) -> Iterator[None]: + with otel_trace_span("xtuner.session_server.read_response", target_url=target_url): + yield + + def on_non_stream_response(self, raw_response: bytes) -> None: + 
self.set_attrs(response_bytes=len(raw_response)) + try: + data = json.loads(raw_response) + except Exception: + return + if not isinstance(data, dict): + return + usage = data.get("usage") + self.set_attrs( + output_tokens=_response_output_ids_len(data), + prompt_tokens=usage.get("prompt_tokens") if isinstance(usage, dict) else None, + completion_tokens=usage.get("completion_tokens") if isinstance(usage, dict) else None, + total_tokens=usage.get("total_tokens") if isinstance(usage, dict) else None, + ) + + def finish(self, *, raw_response: bytes | None = None, exc: BaseException | None = None) -> None: + if exc is not None: + end_otel_span(self.span, exc=exc) + return + end_otel_span(self.span, response_bytes=len(raw_response) if raw_response is not None else None) + + def _elapsed_ms(self) -> float: + return (time.perf_counter() - self.start_s) * 1000 + + +__all__ = [ + "ForwardRequestTrace", + "SessionTraceContext", + "StreamResponseTrace", + "_choices_output_ids_len", + "_extract_body_trace_context", + "_response_output_ids_len", +] diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py index 7991280db8..718a85558f 100644 --- a/xtuner/v1/rl/rollout/utils.py +++ b/xtuner/v1/rl/rollout/utils.py @@ -11,6 +11,7 @@ from ray import ObjectRef as RayObjectRef from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.trace import xtuner_trace_function from xtuner.v1.rl.utils import free_object_refs from xtuner.v1.utils import get_logger @@ -275,6 +276,7 @@ class PartialRolloutHandler: def __init__(self) -> None: self.logger = get_logger(self.__class__.__name__) + @xtuner_trace_function("xtuner.partial_rollout_handler.preprocess") def preprocess(self, rollout_state: RolloutState, max_tokens: int) -> RolloutState: # Set up token and length variable response_ids = list(rollout_state.response_ids or []) @@ -291,6 +293,7 @@ def preprocess(self, rollout_state: RolloutState, max_tokens: int) -> RolloutSta ) return rollout_state + 
@xtuner_trace_function("xtuner.partial_rollout_handler.postprocess") async def postprocess( self, rollout_state: RolloutState, diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index d5f2c7391b..afac4afe1b 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -28,6 +28,7 @@ reset_rollout_response, update_status_from_finish_reason, ) +from xtuner.v1.rl.trace import merge_trace_runtime_env, xtuner_trace_function, xtuner_trace_span from xtuner.v1.rl.utils import ( AutoAcceleratorWorkers, CPUResourcesConfig, @@ -445,13 +446,15 @@ def build(self, placement_group: "PlacementGroup"): ) generate_max_concurrency = self.get_controller_generate_concurrency(placement_group) get_logger().info(f"Calculated RolloutController generate concurrency: {generate_max_concurrency}") + actor_options = {"num_cpus": num_workers} + merge_trace_runtime_env(actor_options) return ( ray.remote( concurrency_groups={ ROLLOUT_CONCURRENCY_GROUP_GENERATE: generate_max_concurrency, }, )(RolloutController) - .options(num_cpus=num_workers) + .options(**actor_options) .remote(self, placement_group) ) @@ -678,6 +681,9 @@ async def _decode_routed_experts(self, routed_experts: Any) -> Any: return routed_experts @ray.method(concurrency_group=ROLLOUT_CONCURRENCY_GROUP_GENERATE) + @xtuner_trace_function( + "xtuner.rollout_worker.generate", trace_kwargs_getter=lambda self, *args, **kwargs: {"worker_rank": self.rank} + ) async def generate(self, rollout_state: RolloutState) -> RolloutState: try: # TODO(@duanyanhui): @@ -739,7 +745,12 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: for attempt in range(max_retries + 1): is_last_attempt = attempt == max_retries - http_result = await self._safe_post_request(endpoint_url, headers=headers, payload=payload) + async with xtuner_trace_span( + rollout_state, + "xtuner.rollout_engine.generate", + worker_rank=self.rank, + ): + http_result = await self._safe_post_request(endpoint_url, 
headers=headers, payload=payload) # Case 1: HTTP Request is Successful if http_result.response: diff --git a/xtuner/v1/rl/trace.py b/xtuner/v1/rl/trace.py new file mode 100644 index 0000000000..3ef9ec6d5b --- /dev/null +++ b/xtuner/v1/rl/trace.py @@ -0,0 +1,1596 @@ +from __future__ import annotations + +import atexit +import contextlib +import contextvars +import dataclasses +import functools +import hashlib +import inspect +import json +import os +import threading +import time +import traceback +from collections.abc import Mapping, MutableMapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, cast + +from pydantic import BaseModel, ConfigDict + +from xtuner.v1.data_proto.rl_data import RolloutState +from xtuner.v1.utils import get_logger + + +logger = get_logger() +_NANOSECONDS_PER_SECOND = 1_000_000_000 +_OTEL_INT64_MIN = -(2**63) +_OTEL_INT64_MAX = 2**63 - 1 + +OTEL_TRACES_EXPORTER_ENV = "OTEL_TRACES_EXPORTER" +OTEL_EXPORTER_OTLP_PROTOCOL_ENV = "OTEL_EXPORTER_OTLP_PROTOCOL" +OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV = "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" +OTEL_SERVICE_NAME_ENV = "OTEL_SERVICE_NAME" +XTUNER_OTEL_RUN_ID_ENV = "XTUNER_OTEL_RUN_ID" +TRACE_EXTRA_TRAIN_STEP = "_trace_train_step" +TRACE_EXTRA_MODEL_STEP = "_trace_model_step" +TRACE_EXTRA_PRODUCER_FUTURE_STEP = "_trace_producer_future_step" +TRACE_EXTRA_PRODUCE_BATCH_ID = "_trace_produce_batch_id" +TRACE_RECORD_KWARGS = frozenset( + { + "task_name", + "uid", + "session_uid", + "status", + "train_step", + "model_step", + "producer_future_step", + "produce_batch_id", + "worker_rank", + "elapsed_s", + "error_msg", + "error_type", + "error_stacktrace", + "timestamp_s", + "attributes", + } +) + + +# Config and event schema. +class TraceConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + # Whether producer task tracing is enabled. 
+ enabled: bool = False + # OTLP traces endpoint used by the OpenTelemetry exporter. + otel_endpoint: str = "http://127.0.0.1:4317" + # OpenTelemetry exporter transport protocol. + otel_protocol: Literal["grpc", "http/protobuf"] = "grpc" + # OpenTelemetry exporter backend. "console" is intended for local debugging. + otel_exporter: Literal["otlp", "console"] = "otlp" + # OpenTelemetry service.name resource attribute. + otel_service_name: str = "xtuner-rl" + # Optional Jaeger Query/UI base URL used by the producer trace viewer to fetch each task's OTel trace. + jaeger_query_url: str | None = None + + +@dataclass(frozen=True) +class TraceEvent: + trace_id: str + stage: str + timestamp_s: float + status: str | None = None + task_name: str | None = None + uid: int | str | None = None + session_uid: int | str | None = None + train_step: int | None = None + model_step: int | None = None + producer_future_step: int | None = None + produce_batch_id: str | None = None + worker_rank: int | None = None + elapsed_s: float | None = None + error_msg: str | None = None + error_type: str | None = None + error_stacktrace: str | None = None + attributes: dict[str, Any] | None = None + + def to_dict(self) -> dict[str, Any]: + return dataclasses.asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> TraceEvent: + return cls( + trace_id=str(data["trace_id"]), + stage=str(data["stage"]), + timestamp_s=float(data["timestamp_s"]), + status=data.get("status"), + task_name=data.get("task_name"), + uid=data.get("uid"), + session_uid=data.get("session_uid"), + train_step=data.get("train_step"), + model_step=data.get("model_step"), + producer_future_step=data.get("producer_future_step"), + produce_batch_id=data.get("produce_batch_id"), + worker_rank=data.get("worker_rank"), + elapsed_s=data.get("elapsed_s"), + error_msg=data.get("error_msg"), + error_type=data.get("error_type"), + error_stacktrace=data.get("error_stacktrace"), + 
attributes=TraceEventBuilder.normalize_attributes(data.get("attributes")), + ) + + +class TraceEventBuilder: + @classmethod + def trace_id(cls, task_name: str | None, uid: int | str | None) -> str | None: + if uid is None: + return None + return f"{task_name or 'unknown'}:{uid}" + + @classmethod + def produce_batch_id( + cls, + train_step: int | None, + model_step: int | None, + producer_future_step: int | None, + ) -> str | None: + if train_step is None and model_step is None and producer_future_step is None: + return None + return ( + f"train_step={cls._format_batch_value(train_step)}/" + f"model_step={cls._format_batch_value(model_step)}/" + f"producer_future_step={cls._format_batch_value(producer_future_step)}" + ) + + @classmethod + def build( + cls, + target: Any, + stage: str, + *, + task_name: str | None = None, + uid: int | str | None = None, + session_uid: int | str | None = None, + status: Any | None = None, + train_step: int | None = None, + model_step: int | None = None, + producer_future_step: int | None = None, + produce_batch_id: str | None = None, + worker_rank: int | None = None, + elapsed_s: float | None = None, + error_msg: str | None = None, + error_type: str | None = None, + error_stacktrace: str | None = None, + timestamp_s: float | None = None, + attributes: Mapping[str, Any] | None = None, + **custom_attributes: Any, + ) -> TraceEvent | None: + resolved_task_name = task_name if task_name is not None else getattr(target, "task_name", None) + if resolved_task_name is None: + resolved_task_name = getattr(target, "data_source", None) + resolved_uid = uid if uid is not None else getattr(target, "uid", None) + explicit_trace_id = getattr(target, "trace_id", None) + trace_id: str | None + if isinstance(explicit_trace_id, str) and explicit_trace_id: + trace_id = explicit_trace_id + else: + trace_id = cls.trace_id(resolved_task_name, resolved_uid) + if trace_id is None: + return None + + resolved_status = status if status is not None else 
getattr(target, "status", None) + resolved_session_uid = session_uid if session_uid is not None else getattr(target, "session_uid", None) + if resolved_session_uid is None: + resolved_session_uid = getattr(target, "group_id", None) + trace_extra_fields = cls._extra_fields(target) + if train_step is None: + train_step = trace_extra_fields.get(TRACE_EXTRA_TRAIN_STEP) + if model_step is None: + model_step = trace_extra_fields.get(TRACE_EXTRA_MODEL_STEP, getattr(target, "model_step", None)) + if producer_future_step is None: + producer_future_step = trace_extra_fields.get(TRACE_EXTRA_PRODUCER_FUTURE_STEP) + if produce_batch_id is None: + produce_batch_id = trace_extra_fields.get(TRACE_EXTRA_PRODUCE_BATCH_ID) + if produce_batch_id is None: + produce_batch_id = cls.produce_batch_id(train_step, model_step, producer_future_step) + + return TraceEvent( + trace_id=trace_id, + stage=stage, + timestamp_s=time.time() if timestamp_s is None else timestamp_s, + status=cls._stringify_status(resolved_status), + task_name=resolved_task_name, + uid=resolved_uid, + session_uid=resolved_session_uid, + train_step=train_step, + model_step=model_step, + producer_future_step=producer_future_step, + produce_batch_id=produce_batch_id, + worker_rank=worker_rank, + elapsed_s=elapsed_s, + error_msg=error_msg, + error_type=error_type, + error_stacktrace=error_stacktrace, + attributes=cls.normalize_attributes(attributes, custom_attributes), + ) + + @staticmethod + def short_error(exc: BaseException, max_len: int = 500) -> str: + message = f"{type(exc).__name__}: {exc}" + if len(message) <= max_len: + return message + return message[: max_len - 3] + "..." + + @staticmethod + def error_type(exc: BaseException) -> str: + return type(exc).__name__ + + @staticmethod + def stacktrace(exc: BaseException, max_len: int = 20_000) -> str: + stack = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)).rstrip() + if len(stack) <= max_len: + return stack + return stack[: max_len - 3] + "..." 
+ + @classmethod + def custom_attributes_from_kwargs(cls, kwargs: Mapping[str, Any]) -> dict[str, Any] | None: + custom_attributes = {key: value for key, value in kwargs.items() if key not in TRACE_RECORD_KWARGS} + attributes = kwargs.get("attributes") + return cls.normalize_attributes(attributes if isinstance(attributes, Mapping) else None, custom_attributes) + + @classmethod + def normalize_attributes( + cls, + attributes: Mapping[str, Any] | None, + custom_attributes: Mapping[str, Any] | None = None, + ) -> dict[str, Any] | None: + result: dict[str, Any] = {} + for source in (attributes, custom_attributes): + if not source: + continue + for key, value in source.items(): + normalized = cls._normalize_attribute_value(value) + if normalized is not None: + result[str(key)] = normalized + return result or None + + @classmethod + def _normalize_attribute_value(cls, value: Any) -> Any: + if value is None: + return None + enum_value = getattr(value, "value", None) + if isinstance(enum_value, (str, int, float, bool)): + return enum_value + if isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, Path): + return str(value) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [ + item if isinstance(item, (str, int, float, bool)) else str(item) for item in value if item is not None + ] + return str(value) + + @staticmethod + def _extra_fields(target: Any) -> dict[str, Any]: + extra_fields = getattr(target, "extra_fields", None) + if isinstance(extra_fields, dict): + return cast(dict[str, Any], extra_fields) + return {} + + @staticmethod + def _format_batch_value(value: int | None) -> str: + return "none" if value is None else str(value) + + @staticmethod + def _stringify_status(status: Any) -> str | None: + if status is None: + return None + value = getattr(status, "value", None) + if value is not None: + return str(value) + return str(status) + + +def _json_dumps_stable(value: Any) -> str: + return 
json.dumps(value, ensure_ascii=False, sort_keys=True, default=str, separators=(",", ":")) + + +def _is_otel_int64(value: int) -> bool: + return _OTEL_INT64_MIN <= value <= _OTEL_INT64_MAX + + +def _normalize_otel_attribute_value(value: Any) -> Any: + if value is None: + return None + if isinstance(value, bool): + return value + enum_value = getattr(value, "value", None) + if enum_value is not None: + return _normalize_otel_attribute_value(enum_value) + if isinstance(value, int): + return value if _is_otel_int64(value) else str(value) + if isinstance(value, (str, float)): + return value + if isinstance(value, Path): + return str(value) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + normalized_items = [_normalize_otel_attribute_value(item) for item in value if item is not None] + return [item for item in normalized_items if item is not None] + return str(value) + + +def _normalize_otel_attributes(attrs: Mapping[str, Any]) -> dict[str, Any]: + normalized_attrs: dict[str, Any] = {} + for key, value in attrs.items(): + normalized = _normalize_otel_attribute_value(value) + if normalized is not None: + normalized_attrs[str(key)] = normalized + return normalized_attrs + + +def get_trace_run_id() -> str | None: + return ( + os.environ.get(XTUNER_OTEL_RUN_ID_ENV) + or os.environ.get("RUN_ID") + or os.environ.get("MODEL_NAME") + or os.environ.get("WORK_DIR") + ) + + +def build_rollout_trace_id( + state: RolloutState, + *, + repeat_index: int | None = None, +) -> str: + payload: dict[str, Any] + if state.message_uid is not None: + payload = { + "task_name": state.task_name, + "data_source": state.data_source, + "message_uid": state.message_uid, + "repeat_index": repeat_index, + } + else: + payload = { + "task_name": state.task_name, + "data_source": state.data_source, + "message": state.message, + "repeat_index": repeat_index, + } + digest = hashlib.sha1(_json_dumps_stable(payload).encode("utf-8")).hexdigest()[:16] + prefix = 
state.task_name or "unknown" + return f"{prefix}:{digest}" + + +def get_rollout_trace_id(state: RolloutState) -> str | None: + if state.trace_id: + return state.trace_id + if state.uid is not None: + return f"{state.task_name or 'unknown'}:{state.uid}" + return None + + +def build_rollout_trace_attributes(state: RolloutState) -> dict[str, Any]: + trace_id = get_rollout_trace_id(state) + attrs: dict[str, Any] = {} + if trace_id is not None: + attrs["xtuner.trace_id"] = trace_id + attrs["case.id"] = trace_id + run_id = get_trace_run_id() + if run_id: + attrs["run.id"] = run_id + if state.task_name is not None: + attrs["task.name"] = state.task_name + if state.uid is not None: + attrs["xtuner.uid"] = state.uid + if state.message_uid is not None: + attrs["sample.message_uid"] = state.message_uid + if state.data_source is not None: + attrs["sample.data_source"] = ( + _json_dumps_stable(state.data_source) if isinstance(state.data_source, dict) else str(state.data_source) + ) + return _normalize_otel_attributes(attrs) + + +# Recorder API used by trace_event, xtuner_trace_span, and xtuner_trace_function. 
+class TraceRecorder: + def __init__(self, otel_sink: Any) -> None: + self.otel_sink = otel_sink + + def record( + self, + target: Any, + stage: str, + *, + task_name: str | None = None, + uid: int | str | None = None, + session_uid: int | str | None = None, + status: Any | None = None, + train_step: int | None = None, + model_step: int | None = None, + producer_future_step: int | None = None, + produce_batch_id: str | None = None, + worker_rank: int | None = None, + elapsed_s: float | None = None, + error_msg: str | None = None, + error_type: str | None = None, + error_stacktrace: str | None = None, + timestamp_s: float | None = None, + attributes: Mapping[str, Any] | None = None, + **custom_attributes: Any, + ) -> TraceEvent | None: + event = build_trace_event( + target, + stage, + task_name=task_name, + uid=uid, + session_uid=session_uid, + status=status, + train_step=train_step, + model_step=model_step, + producer_future_step=producer_future_step, + produce_batch_id=produce_batch_id, + worker_rank=worker_rank, + elapsed_s=elapsed_s, + error_msg=error_msg, + error_type=error_type, + error_stacktrace=error_stacktrace, + timestamp_s=timestamp_s, + attributes=attributes, + **custom_attributes, + ) + if event is None: + return None + try: + self.otel_sink.append(event) + except Exception: + logger.exception("Failed to export OpenTelemetry trace event stage=%s trace_id=%s", stage, event.trace_id) + return None + return event + + def record_many(self, targets: Iterable[Any], stage: str, **kwargs: Any) -> list[TraceEvent]: + events: list[TraceEvent] = [] + for target in targets: + event = self.record(target, stage, **kwargs) + if event is not None: + events.append(event) + return events + + async def mark(self, target: Any, stage: str, **kwargs: Any) -> TraceEvent | None: + return self.record(target, stage, **kwargs) + + async def mark_many(self, targets: Iterable[Any], stage: str, **kwargs: Any) -> list[TraceEvent]: + return self.record_many(targets, stage, **kwargs) 
+ + def flush(self) -> bool: + return bool(self.otel_sink.flush()) + + def close(self) -> bool: + return bool(self.otel_sink.close()) + + +class NoopTraceRecorder: + def record(self, *args: Any, **kwargs: Any) -> None: + return None + + def record_many(self, *args: Any, **kwargs: Any) -> list[TraceEvent]: + return [] + + async def mark(self, *args: Any, **kwargs: Any) -> None: + return None + + async def mark_many(self, *args: Any, **kwargs: Any) -> list[TraceEvent]: + return [] + + def flush(self) -> bool: + return True + + def close(self) -> bool: + return True + + +@dataclass +class _ActiveOtelSpan: + name: str + span: Any + + +def _otel_import_error() -> RuntimeError: + return RuntimeError( + "OpenTelemetry tracing requires opentelemetry-api, opentelemetry-sdk, and an OTLP exporter package. " + "Install xtuner[trace] or install opentelemetry-exporter-otlp-proto-grpc / " + "opentelemetry-exporter-otlp-proto-http." + ) + + +def _build_otel_exporter(config: TraceConfig) -> Any: + if config.otel_exporter == "console": + from opentelemetry.sdk.trace.export import ConsoleSpanExporter + + return ConsoleSpanExporter() + if config.otel_protocol == "http/protobuf": + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + return OTLPSpanExporter(endpoint=config.otel_endpoint) + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + return OTLPSpanExporter(endpoint=config.otel_endpoint, insecure=True) + + +class OtelTraceSink: + def __init__(self, config: TraceConfig) -> None: + try: + from opentelemetry import trace as otel_trace + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.trace import NonRecordingSpan, SpanContext, Status, StatusCode, TraceFlags, TraceState + except ImportError as exc: + raise _otel_import_error() from exc + + self.config = config + self._otel_trace 
= otel_trace + self._non_recording_span_cls = NonRecordingSpan + self._span_context_cls = SpanContext + self._status_cls = Status + self._status_code_cls = StatusCode + self._trace_flags_cls = TraceFlags + self._trace_state_cls = TraceState + resource = Resource.create({"service.name": config.otel_service_name}) + self._provider = TracerProvider(resource=resource) + try: + exporter = _build_otel_exporter(config) + except ImportError as exc: + raise _otel_import_error() from exc + self._provider.add_span_processor( + BatchSpanProcessor(exporter, schedule_delay_millis=100, max_export_batch_size=512) + ) + try: + otel_trace.set_tracer_provider(self._provider) + except Exception: + pass + self._tracer = self._provider.get_tracer("xtuner.v1.rl.trace") + self._active: dict[str, list[_ActiveOtelSpan]] = {} + self._lock = threading.RLock() + + def append(self, event: TraceEvent) -> None: + name, lifecycle = _split_lifecycle_stage(event.stage) + if lifecycle == "start": + self._start_span(event, name) + return + if lifecycle in {"end", "error"}: + self._end_span(event, name, error=lifecycle == "error") + return + self._record_instant_span(event, name) + + def flush(self) -> bool: + return bool(self._provider.force_flush(timeout_millis=5_000)) + + def close(self) -> bool: + self._close_open_spans() + flushed = self.flush() + self._provider.shutdown() + return flushed + + def _start_span(self, event: TraceEvent, name: str) -> None: + with self._lock: + parent_context = self._current_parent_context(event.trace_id) + span = self._tracer.start_span( + name, + context=parent_context, + start_time=_unix_nano(event.timestamp_s), + attributes=_span_attributes(event, name, "start"), + ) + self._active.setdefault(event.trace_id, []).append(_ActiveOtelSpan(name=name, span=span)) + self._record_lifecycle_marker(event, name, "start") + + def _end_span(self, event: TraceEvent, name: str, *, error: bool) -> None: + with self._lock: + active = self._pop_active_span(event.trace_id, name) + 
if active is None: + self._record_orphan_span(event, name, error=error) + return + span = active.span + for key, value in _span_attributes(event, name, "error" if error else "end").items(): + span.set_attribute(key, value) + if event.elapsed_s is not None: + span.set_attribute("xtuner.elapsed_ms", event.elapsed_s * 1000.0) + if error: + span.set_status(self._status_cls(self._status_code_cls.ERROR, event.error_msg or "error")) + span.add_event( + _error_event_name(event), + _error_event_attributes(event), + timestamp=_unix_nano(event.timestamp_s), + ) + else: + span.set_status(self._status_cls(self._status_code_cls.OK)) + span.end(end_time=_unix_nano(event.timestamp_s)) + + def _record_instant_span(self, event: TraceEvent, name: str) -> None: + with self._lock: + parent_context = self._current_parent_context(event.trace_id) + span = self._tracer.start_span( + name, + context=parent_context, + start_time=_unix_nano(event.timestamp_s), + attributes=_span_attributes(event, name, "event"), + ) + span.end(end_time=_unix_nano(event.timestamp_s) + 1) + + def _record_lifecycle_marker(self, event: TraceEvent, name: str, lifecycle: str) -> None: + marker_name = f"{name}.{lifecycle}" + parent_context = self._current_parent_context(event.trace_id) + attrs = _span_attributes(event, name, lifecycle) + attrs["xtuner.lifecycle_marker"] = True + span = self._tracer.start_span( + marker_name, + context=parent_context, + start_time=_unix_nano(event.timestamp_s), + attributes=attrs, + ) + span.end(end_time=_unix_nano(event.timestamp_s) + 1) + + def _record_orphan_span(self, event: TraceEvent, name: str, *, error: bool) -> None: + start_s = event.timestamp_s - event.elapsed_s if event.elapsed_s is not None else event.timestamp_s + parent_context = self._root_context(event.trace_id) + span = self._tracer.start_span( + name, + context=parent_context, + start_time=_unix_nano(start_s), + attributes=_span_attributes(event, name, "error" if error else "end"), + ) + if event.elapsed_s is not 
None: + span.set_attribute("xtuner.elapsed_ms", event.elapsed_s * 1000.0) + if error: + span.set_status(self._status_cls(self._status_code_cls.ERROR, event.error_msg or "error")) + span.add_event( + _error_event_name(event), + _error_event_attributes(event), + timestamp=_unix_nano(event.timestamp_s), + ) + else: + span.set_status(self._status_cls(self._status_code_cls.OK)) + span.end(end_time=_unix_nano(event.timestamp_s)) + + def _pop_active_span(self, trace_id: str, name: str) -> _ActiveOtelSpan | None: + stack = self._active.get(trace_id) + if not stack: + return None + for index in range(len(stack) - 1, -1, -1): + if stack[index].name == name: + active = stack.pop(index) + if not stack: + self._active.pop(trace_id, None) + return active + return None + + def _current_parent_context(self, trace_id: str) -> Any: + stack = self._active.get(trace_id) + if stack: + return self._otel_trace.set_span_in_context(stack[-1].span) + return self._root_context(trace_id) + + def _root_context(self, trace_id: str) -> Any: + span_context = self._span_context_cls( + trace_id=_stable_trace_id(trace_id), + span_id=_stable_span_id(f"{trace_id}:root"), + is_remote=True, + trace_flags=self._trace_flags_cls(self._trace_flags_cls.SAMPLED), + trace_state=self._trace_state_cls(), + ) + return self._otel_trace.set_span_in_context(self._non_recording_span_cls(span_context)) + + def _close_open_spans(self) -> None: + with self._lock: + active_by_trace = self._active + self._active = {} + now_nano = _unix_nano(time.time()) + for stack in active_by_trace.values(): + while stack: + active = stack.pop() + active.span.set_attribute("xtuner.span.closed_by_runtime", True) + active.span.end(end_time=now_nano) + + +def _split_lifecycle_stage(stage: str) -> tuple[str, str | None]: + for suffix, lifecycle in ((".start", "start"), (".end", "end"), (".error", "error")): + if stage.endswith(suffix): + return stage[: -len(suffix)], lifecycle + return stage, None + + +def _span_attributes(event: 
TraceEvent, name: str, lifecycle: str) -> dict[str, Any]: + raw_attrs = { + "xtuner.trace_id": event.trace_id, + "case.id": event.trace_id, + "xtuner.stage": name, + "xtuner.stage_event": lifecycle, + "xtuner.task_name": event.task_name, + "xtuner.uid": event.uid, + "xtuner.session_uid": event.session_uid, + "xtuner.status": event.status, + "xtuner.train_step": event.train_step, + "xtuner.model_step": event.model_step, + "xtuner.producer_future_step": event.producer_future_step, + "xtuner.produce_batch_id": event.produce_batch_id, + "xtuner.worker_rank": event.worker_rank, + "xtuner.error.message": event.error_msg, + "error.type": event.error_type, + "run.id": get_trace_run_id(), + } + attrs = {key: value for key, value in raw_attrs.items() if value is not None} + if event.attributes: + for key, value in event.attributes.items(): + if key == "xtuner.stage.kind": + attrs[key] = value + else: + attrs[f"xtuner.attr.{key}"] = value + return _normalize_otel_attributes(attrs) + + +def _error_event_attributes(event: TraceEvent) -> dict[str, Any]: + attrs: dict[str, Any] = {} + if event.error_msg is not None: + attrs["error.message"] = event.error_msg + attrs["exception.message"] = event.error_msg + if event.error_type is not None: + attrs["error.type"] = event.error_type + attrs["exception.type"] = event.error_type + if event.error_stacktrace is not None: + attrs["exception.stacktrace"] = event.error_stacktrace + if event.status is not None: + attrs["xtuner.status"] = event.status + return attrs + + +def _error_event_name(event: TraceEvent) -> str: + if event.error_type is not None or event.error_stacktrace is not None: + return "exception" + return "error" + + +def stable_otel_trace_id(value: str) -> str: + trace_id = _stable_trace_id(value) + return f"{trace_id:032x}" + + +def _stable_trace_id(value: str) -> int: + trace_id = int.from_bytes(hashlib.blake2b(value.encode("utf-8"), digest_size=16).digest(), "big") + return trace_id or 1 + + +def _stable_span_id(value: str) 
-> int: + span_id = int.from_bytes(hashlib.blake2b(value.encode("utf-8"), digest_size=8).digest(), "big") + return span_id or 1 + + +def _unix_nano(timestamp_s: float) -> int: + return int(timestamp_s * _NANOSECONDS_PER_SECOND) + + +class TraceTargetResolver: + @classmethod + def as_rollout_state_list(cls, value: Any) -> list[RolloutState]: + if value is None: + return [] + if isinstance(value, RolloutState): + return [value] + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + states: list[RolloutState] = [] + for item in value: + if isinstance(item, RolloutState): + states.append(item) + return states + return [] + + @classmethod + def resolve( + cls, + bound_arguments: dict[str, Any], + *, + target: str | RolloutState | Sequence[RolloutState] | Callable[..., Any] | None, + target_getter: Callable[..., Any] | None, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + if target_getter is not None: + return target_getter(*args, **kwargs) + if isinstance(target, str): + return bound_arguments.get(target) + if callable(target): + return target(*args, **kwargs) + if target is not None: + return target + rollout_state = bound_arguments.get("rollout_state") + rollout_states = cls.as_rollout_state_list(rollout_state) + if rollout_states: + return rollout_states if len(rollout_states) > 1 else rollout_states[0] + for value in bound_arguments.values(): + states = cls.as_rollout_state_list(value) + if states: + return states if len(states) > 1 else states[0] + return None + + @classmethod + def record_event(cls, target: Any, name: str, **kwargs: Any) -> TraceEvent | list[TraceEvent] | None: + targets = cls.as_rollout_state_list(target) + recorder = current_trace_recorder() + if targets: + if len(targets) == 1: + return recorder.record(targets[0], name, **kwargs) + return recorder.record_many(targets, name, **kwargs) + return recorder.record(target, name, **kwargs) + + @classmethod + async def mark_event(cls, target: Any, name: 
str, **kwargs: Any) -> TraceEvent | list[TraceEvent] | None: + targets = cls.as_rollout_state_list(target) + recorder = current_trace_recorder() + if targets: + if len(targets) == 1: + return await recorder.mark(targets[0], name, **kwargs) + return await recorder.mark_many(targets, name, **kwargs) + return await recorder.mark(target, name, **kwargs) + + +_NOOP_TRACE_RECORDER = NoopTraceRecorder() +_CURRENT_TRACE_RECORDER: contextvars.ContextVar[TraceRecorder | NoopTraceRecorder | None] = contextvars.ContextVar( + "xtuner_current_trace_recorder", + default=None, +) + + +# Global trace runtime. Each process owns one local runtime, propagated through Ray env vars. +def current_trace_recorder() -> TraceRecorder | NoopTraceRecorder: + recorder = _CURRENT_TRACE_RECORDER.get() + if recorder is not None: + return recorder + return get_tracer() + + +@contextlib.contextmanager +def use_trace_recorder(recorder: TraceRecorder | NoopTraceRecorder): + token = _CURRENT_TRACE_RECORDER.set(recorder) + try: + yield + finally: + _CURRENT_TRACE_RECORDER.reset(token) + + +def configure_trace(config: TraceConfig | None) -> TraceRecorder | NoopTraceRecorder: + return _TRACE_RUNTIME_MANAGER.configure(config) + + +def get_tracer() -> TraceRecorder | NoopTraceRecorder: + return _TRACE_RUNTIME_MANAGER.get_tracer() + + +def flush_trace() -> bool: + return _TRACE_RUNTIME_MANAGER.flush() + + +def close_trace() -> None: + _TRACE_RUNTIME_MANAGER.close() + + +def reset_trace_for_test() -> None: + close_trace() + _CURRENT_TRACE_RECORDER.set(None) + + +def get_trace_env_vars() -> dict[str, str]: + return _TRACE_RUNTIME_MANAGER.env_vars() + + +def merge_trace_runtime_env(actor_options: dict[str, Any]) -> dict[str, Any]: + return _TRACE_RUNTIME_MANAGER.merge_runtime_env(actor_options) + + +def build_trace_id(task_name: str | None, uid: int | str | None) -> str | None: + return TraceEventBuilder.trace_id(task_name, uid) + + +def build_produce_batch_id( + train_step: int | None, + model_step: int | 
None, + producer_future_step: int | None, +) -> str | None: + return TraceEventBuilder.produce_batch_id(train_step, model_step, producer_future_step) + + +def build_trace_event( + target: Any, + stage: str, + *, + task_name: str | None = None, + uid: int | str | None = None, + session_uid: int | str | None = None, + status: Any | None = None, + train_step: int | None = None, + model_step: int | None = None, + producer_future_step: int | None = None, + produce_batch_id: str | None = None, + worker_rank: int | None = None, + elapsed_s: float | None = None, + error_msg: str | None = None, + error_type: str | None = None, + error_stacktrace: str | None = None, + timestamp_s: float | None = None, + attributes: Mapping[str, Any] | None = None, + **custom_attributes: Any, +) -> TraceEvent | None: + return TraceEventBuilder.build( + target, + stage, + task_name=task_name, + uid=uid, + session_uid=session_uid, + status=status, + train_step=train_step, + model_step=model_step, + producer_future_step=producer_future_step, + produce_batch_id=produce_batch_id, + worker_rank=worker_rank, + elapsed_s=elapsed_s, + error_msg=error_msg, + error_type=error_type, + error_stacktrace=error_stacktrace, + timestamp_s=timestamp_s, + attributes=attributes, + **custom_attributes, + ) + + +def trace_enabled_from_env() -> bool: + exporter = (os.environ.get(OTEL_TRACES_EXPORTER_ENV) or "").strip().lower() + if exporter in {"none", "false", "off", "0"}: + return False + return bool(os.environ.get(OTEL_TRACES_EXPORTER_ENV) or os.environ.get(OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV)) + + +def extract_trace_context(headers: Mapping[str, Any] | None) -> Any: + if not trace_enabled_from_env(): + return None + try: + from opentelemetry import propagate + + return propagate.extract(headers or {}) + except Exception: + return None + + +def inject_trace_context(headers: MutableMapping[str, str]) -> None: + if not trace_enabled_from_env(): + return + try: + get_tracer() + from opentelemetry import propagate 
+ + propagate.inject(headers) + except Exception: + return + + +def make_trace_context_carrier() -> dict[str, str]: + carrier: dict[str, str] = {} + inject_trace_context(carrier) + return carrier + + +@contextlib.contextmanager +def use_trace_context(trace_context: Any) -> Iterator[None]: + if not trace_enabled_from_env() or trace_context is None: + yield + return + + token = None + try: + from opentelemetry import context as otel_context + + token = otel_context.attach(trace_context) + except Exception: + yield + return + + try: + yield + finally: + if token is not None: + try: + otel_context.detach(token) + except Exception: + # Use pass, not return: a return inside finally would swallow an exception raised by the with-body. + pass + + +@contextlib.contextmanager +def trace_baggage(attrs: Mapping[str, Any] | None) -> Iterator[None]: + if not attrs or not trace_enabled_from_env(): + yield + return + + token = None + try: + from opentelemetry import baggage, context + + ctx = context.get_current() + for key, value in attrs.items(): + if value is None: + continue + ctx = baggage.set_baggage(str(key), str(value), context=ctx) + token = context.attach(ctx) + except Exception: + yield + return + + try: + yield + finally: + if token is not None: + try: + context.detach(token) + except Exception: + # Use pass, not return: a return inside finally would swallow an exception raised by the with-body. + pass + + +@contextlib.contextmanager +def trace_task_context(attrs: Mapping[str, Any] | None) -> Iterator[None]: + """Attach task trace context and baggage for child OpenTelemetry spans.""" + if not attrs or not trace_enabled_from_env(): + yield + return + + trace_id = attrs.get("xtuner.trace_id") or attrs.get("case.id") + if not isinstance(trace_id, str) or not trace_id: + with trace_baggage(attrs): + yield + return + + token = None + otel_context_module = None + try: + get_tracer() + from opentelemetry import baggage + from opentelemetry import context as imported_otel_context + from opentelemetry import trace as otel_trace + from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags, TraceState + + span_context = SpanContext( +
trace_id=_stable_trace_id(trace_id), + span_id=_stable_span_id(f"{trace_id}:task_context"), + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + trace_state=TraceState(), + ) + ctx = otel_trace.set_span_in_context(NonRecordingSpan(span_context)) + for key, value in attrs.items(): + if value is None: + continue + ctx = baggage.set_baggage(str(key), str(value), context=ctx) + token = imported_otel_context.attach(ctx) + otel_context_module = imported_otel_context + except Exception: + with trace_baggage(attrs): + yield + return + + try: + yield + finally: + if token is not None and otel_context_module is not None: + try: + otel_context_module.detach(token) + except Exception: + # Use pass, not return: a return inside finally would swallow an exception raised by the with-body. + pass + + +def set_otel_span_attrs(span: Any, **attrs: Any) -> None: + if span is None: + return + for key, value in _normalize_otel_attributes(attrs).items(): + if value is None: + continue + try: + span.set_attribute(key, value) + except Exception: + continue + + +def record_otel_exception(span: Any, exc: BaseException) -> None: + if span is None: + return + try: + from opentelemetry.trace import Status, StatusCode + + span.record_exception(exc) + span.set_status(Status(StatusCode.ERROR, TraceEventBuilder.short_error(exc))) + span.set_attribute("error.type", TraceEventBuilder.error_type(exc)) + span.set_attribute("error.message", str(exc)) + except Exception: + return + + +@contextlib.contextmanager +def otel_trace_span(name: str, **attrs: Any) -> Iterator[Any]: + """Create a thin OpenTelemetry span in the current context. + + This helper does not resolve ``RolloutState``, ``trace_id``, task status, or + XTuner task metadata. Use it for low-level boundaries that already run under + an extracted OTel context, such as session-server HTTP forwarding. + Task-aware XTuner code should prefer ``xtuner_trace_function`` or + ``xtuner_trace_span``.
+ """ + if not trace_enabled_from_env(): + yield None + return + try: + get_tracer() + from opentelemetry import trace as otel_trace + + tracer = otel_trace.get_tracer("xtuner.v1.rl.trace") + except Exception: + yield None + return + with tracer.start_as_current_span(name) as span: + set_otel_span_attrs(span, **attrs) + try: + yield span + except Exception as exc: + record_otel_exception(span, exc) + raise + + +def begin_otel_span(name: str, **attrs: Any) -> Any: + if not trace_enabled_from_env(): + return None + try: + get_tracer() + from opentelemetry import trace as otel_trace + + span = otel_trace.get_tracer("xtuner.v1.rl.trace").start_span(name) + set_otel_span_attrs(span, **attrs) + return span + except Exception: + return None + + +def end_otel_span(span: Any, exc: BaseException | None = None, **attrs: Any) -> None: + if span is None: + return + if exc is not None: + record_otel_exception(span, exc) + set_otel_span_attrs(span, **attrs) + try: + span.end() + except Exception: + return + + +async def trace_event(target: Any, name: str, **kwargs: Any) -> TraceEvent | list[TraceEvent] | None: + return await TraceTargetResolver.mark_event(target, name, **kwargs) + + +class TraceSpanHandle: + def __init__(self, attributes: Mapping[str, Any] | None = None) -> None: + self.error_msg: str | None = None + self.attributes = dict(attributes or {}) + + def mark_error(self, error_msg: str) -> None: + self.error_msg = error_msg + + def annotate(self, **fields: Any) -> None: + attributes = TraceEventBuilder.normalize_attributes(fields) + if attributes: + self.attributes.update(attributes) + + +@contextlib.asynccontextmanager +async def xtuner_trace_span(target: Any, name: str, **kwargs: Any): + """Trace an inline async block as one XTuner rollout-task span. + + This is the replacement for the old sandbox ``span(uid, stage, **extra)`` + context manager. 
``target`` can be a ``RolloutState`` or any object with + ``task_name``/``data_source``, ``uid``, and optional ``status`` fields. + Unknown keyword arguments are stored as event attributes; known trace + fields such as ``task_name`` and ``uid`` keep their normal meaning. + + The yielded handle supports: + - ``annotate(**fields)``: append runtime-discovered attributes to the exit + event, such as sandbox URL or env id. + - ``mark_error(message)``: record a ``.error`` event even when the block + returns normally but the business result failed. + """ + start_time = time.monotonic() + await trace_event(target, f"{name}.start", **kwargs) + handle = TraceSpanHandle(TraceEventBuilder.custom_attributes_from_kwargs(kwargs)) + end_kwargs = dict(kwargs) + for key in ("attributes", "elapsed_s", "error_msg", "timestamp_s"): + end_kwargs.pop(key, None) + try: + yield handle + except Exception as exc: + error_msg = TraceEventBuilder.short_error(exc) + error_type = TraceEventBuilder.error_type(exc) + error_stacktrace = TraceEventBuilder.stacktrace(exc) + handle.mark_error(error_msg) + for key in handle.attributes: + end_kwargs.pop(key, None) + await trace_event( + target, + f"{name}.error", + elapsed_s=time.monotonic() - start_time, + error_msg=error_msg, + error_type=error_type, + error_stacktrace=error_stacktrace, + attributes=handle.attributes, + **end_kwargs, + ) + raise + else: + for key in handle.attributes: + end_kwargs.pop(key, None) + stage_suffix = "end" if handle.error_msg is None else "error" + await trace_event( + target, + f"{name}.{stage_suffix}", + elapsed_s=time.monotonic() - start_time, + error_msg=handle.error_msg, + attributes=handle.attributes, + **end_kwargs, + ) + + +class TraceFunctionDecorator: + def __init__( + self, + name: str | Callable[..., str], + *, + target: str | RolloutState | Sequence[RolloutState] | Callable[..., Any] | None = None, + target_getter: Callable[..., Any] | None = None, + trace_kwargs_getter: Callable[..., dict[str, Any] | None] | 
None = None, + trace_kwargs: dict[str, Any] | None = None, + ) -> None: + self.name = name + self.target = target + self.target_getter = target_getter + self.trace_kwargs_getter = trace_kwargs_getter + self.trace_kwargs = trace_kwargs or {} + + def decorate(self, func: Callable[..., Any]): + signature = inspect.signature(func) + if inspect.iscoroutinefunction(func): + return self._decorate_async(func, signature) + return self._decorate_sync(func, signature) + + def _start_target(self, signature: inspect.Signature, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + bound = signature.bind_partial(*args, **kwargs) + bound.apply_defaults() + return TraceTargetResolver.resolve( + bound.arguments, + target=self.target, + target_getter=self.target_getter, + args=args, + kwargs=kwargs, + ) + + def _end_target(self, start_target: Any, return_value: Any) -> Any: + if TraceTargetResolver.as_rollout_state_list(return_value): + return return_value + return start_target + + def _name(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: + if callable(self.name): + return self.name(*args, **kwargs) + return self.name + + def _kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]: + resolved = dict(self.trace_kwargs) + if self.trace_kwargs_getter is None: + return resolved + dynamic_kwargs = self.trace_kwargs_getter(*args, **kwargs) + if dynamic_kwargs: + resolved.update(dynamic_kwargs) + return resolved + + def _decorate_async(self, func: Callable[..., Any], signature: inspect.Signature): + @functools.wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any): + start_target = self._start_target(signature, args, kwargs) + trace_name = self._name(args, kwargs) + trace_kwargs = self._kwargs(args, kwargs) + start_time = time.monotonic() + start_timestamp_s = time.time() + start_event = await trace_event( + start_target, + f"{trace_name}.start", + timestamp_s=start_timestamp_s, + **trace_kwargs, + ) + try: + return_value = await func(*args, 
**kwargs) + except Exception as exc: + await trace_event( + start_target, + f"{trace_name}.error", + elapsed_s=time.monotonic() - start_time, + error_msg=TraceEventBuilder.short_error(exc), + error_type=TraceEventBuilder.error_type(exc), + error_stacktrace=TraceEventBuilder.stacktrace(exc), + **trace_kwargs, + ) + raise + end_target = self._end_target(start_target, return_value) + if start_event is None and TraceTargetResolver.as_rollout_state_list(end_target): + await trace_event( + end_target, + f"{trace_name}.start", + timestamp_s=start_timestamp_s, + **trace_kwargs, + ) + await trace_event( + end_target, + f"{trace_name}.end", + elapsed_s=time.monotonic() - start_time, + **trace_kwargs, + ) + return return_value + + return async_wrapper + + def _decorate_sync(self, func: Callable[..., Any], signature: inspect.Signature): + @functools.wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any): + start_target = self._start_target(signature, args, kwargs) + trace_name = self._name(args, kwargs) + trace_kwargs = self._kwargs(args, kwargs) + start_time = time.monotonic() + start_timestamp_s = time.time() + start_event = TraceTargetResolver.record_event( + start_target, + f"{trace_name}.start", + timestamp_s=start_timestamp_s, + **trace_kwargs, + ) + try: + return_value = func(*args, **kwargs) + except Exception as exc: + TraceTargetResolver.record_event( + start_target, + f"{trace_name}.error", + elapsed_s=time.monotonic() - start_time, + error_msg=TraceEventBuilder.short_error(exc), + error_type=TraceEventBuilder.error_type(exc), + error_stacktrace=TraceEventBuilder.stacktrace(exc), + **trace_kwargs, + ) + raise + end_target = self._end_target(start_target, return_value) + if start_event is None and TraceTargetResolver.as_rollout_state_list(end_target): + TraceTargetResolver.record_event( + end_target, + f"{trace_name}.start", + timestamp_s=start_timestamp_s, + **trace_kwargs, + ) + TraceTargetResolver.record_event( + end_target, + f"{trace_name}.end", + 
elapsed_s=time.monotonic() - start_time, + **trace_kwargs, + ) + return return_value + + return sync_wrapper + + +def xtuner_trace_function( + name: str | Callable[..., str], + *, + target: str | RolloutState | Sequence[RolloutState] | Callable[..., Any] | None = None, + target_getter: Callable[..., Any] | None = None, + trace_kwargs_getter: Callable[..., dict[str, Any] | None] | None = None, + **trace_kwargs: Any, +): + """Trace a whole sync/async function as one XTuner rollout-task span. + + Target resolution for the `.start` event: + - If `target_getter` is provided, use its return value. + - Else if `target` is provided, resolve that explicit target. + - Else prefer the argument named `rollout_state` when it is a `RolloutState` + or `list[RolloutState]`. + - Else fall back to the first `RolloutState` / `list[RolloutState]` found + in the bound arguments. + + Target resolution for the `.end` event: + - If the function returns a `RolloutState` or `list[RolloutState]`, use the + return value so the end event reflects the latest task state. + - Otherwise reuse the start target. + + In practice this means standard XTuner functions whose task parameter is + named `rollout_state` usually do not need to pass `target=...`. Functions + with non-standard parameter names such as `group` should still pass an + explicit `target`. + """ + return TraceFunctionDecorator( + name, + target=target, + target_getter=target_getter, + trace_kwargs_getter=trace_kwargs_getter, + trace_kwargs=trace_kwargs, + ).decorate + + +# Runtime wrapper that owns the OpenTelemetry exporter lifecycle. 
+@dataclass +class TraceRuntime: + config: TraceConfig + recorder: TraceRecorder | NoopTraceRecorder + + def flush(self) -> bool: + return self.recorder.flush() + + def close(self) -> bool: + return self.recorder.close() + + +def build_trace_runtime(config: TraceConfig | None) -> TraceRuntime: + if config is None: + config = TraceConfig() + if not config.enabled: + return TraceRuntime(config=config, recorder=_NOOP_TRACE_RECORDER) + + return TraceRuntime(config=config, recorder=TraceRecorder(OtelTraceSink(config))) + + +class TraceRuntimeManager: + def __init__(self) -> None: + self._runtime: TraceRuntime | None = None + self._identity: tuple[Any, ...] | None = None + self._lock = threading.RLock() + self._atexit_registered = False + + def configure(self, config: TraceConfig | None) -> TraceRecorder | NoopTraceRecorder: + config = config or TraceConfig() + self._export_env(config) + return self._replace_runtime(config).recorder + + def get_tracer(self) -> TraceRecorder | NoopTraceRecorder: + config = self._load_config_from_env() + identity = self._config_identity(config) + with self._lock: + if self._runtime is not None and self._identity == identity: + return self._runtime.recorder + return self._replace_runtime(config).recorder + + def flush(self) -> bool: + with self._lock: + runtime = self._runtime + if runtime is not None: + return runtime.flush() + return True + + def close(self) -> None: + with self._lock: + runtime = self._runtime + self._runtime = None + self._identity = None + self._clear_env() + if runtime is not None: + runtime.close() + + def env_vars(self) -> dict[str, str]: + if not trace_enabled_from_env(): + return {} + env_vars: dict[str, str] = {} + for name in ( + OTEL_TRACES_EXPORTER_ENV, + OTEL_EXPORTER_OTLP_PROTOCOL_ENV, + OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV, + OTEL_SERVICE_NAME_ENV, + XTUNER_OTEL_RUN_ID_ENV, + ): + value = os.environ.get(name) + if value is not None: + env_vars[name] = value + return env_vars + + def 
merge_runtime_env(self, actor_options: dict[str, Any]) -> dict[str, Any]: + trace_env_vars = self.env_vars() + if not trace_env_vars: + return actor_options + runtime_env = dict(actor_options.get("runtime_env") or {}) + env_vars = dict(runtime_env.get("env_vars") or {}) + env_vars.update(trace_env_vars) + runtime_env["env_vars"] = env_vars + actor_options["runtime_env"] = runtime_env + return actor_options + + def _replace_runtime(self, config: TraceConfig) -> TraceRuntime: + identity = self._config_identity(config) + with self._lock: + if self._runtime is not None and self._identity == identity: + return self._runtime + old_runtime = self._runtime + self._runtime = build_trace_runtime(config) + self._identity = identity + self._register_atexit() + runtime = self._runtime + if old_runtime is not None: + old_runtime.close() + return runtime + + def _register_atexit(self) -> None: + if self._atexit_registered: + return + atexit.register(close_trace) + self._atexit_registered = True + + @staticmethod + def _export_env(config: TraceConfig) -> None: + if not config.enabled: + TraceRuntimeManager._clear_env() + return + + os.environ[OTEL_TRACES_EXPORTER_ENV] = config.otel_exporter + os.environ[OTEL_EXPORTER_OTLP_PROTOCOL_ENV] = config.otel_protocol + os.environ[OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV] = config.otel_endpoint + os.environ[OTEL_SERVICE_NAME_ENV] = config.otel_service_name + run_id = get_trace_run_id() + if run_id: + os.environ[XTUNER_OTEL_RUN_ID_ENV] = run_id + + @staticmethod + def _clear_env() -> None: + for name in ( + OTEL_TRACES_EXPORTER_ENV, + OTEL_EXPORTER_OTLP_PROTOCOL_ENV, + OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV, + OTEL_SERVICE_NAME_ENV, + ): + os.environ.pop(name, None) + + @staticmethod + def _load_config_from_env() -> TraceConfig: + if not trace_enabled_from_env(): + return TraceConfig() + + protocol = os.environ.get(OTEL_EXPORTER_OTLP_PROTOCOL_ENV) or TraceConfig.model_fields["otel_protocol"].default + raw_exporter = 
os.environ.get(OTEL_TRACES_EXPORTER_ENV) + if raw_exporter == "otlp_proto_http": + protocol = "http/protobuf" + raw_exporter = "otlp" + elif raw_exporter == "otlp_proto_grpc": + protocol = "grpc" + raw_exporter = "otlp" + if protocol not in {"grpc", "http/protobuf"}: + protocol = TraceConfig.model_fields["otel_protocol"].default + exporter = raw_exporter or TraceConfig.model_fields["otel_exporter"].default + if exporter not in {"otlp", "console"}: + exporter = TraceConfig.model_fields["otel_exporter"].default + return TraceConfig( + enabled=True, + otel_endpoint=( + os.environ.get(OTEL_EXPORTER_OTLP_TRACES_ENDPOINT_ENV) + or TraceConfig.model_fields["otel_endpoint"].default + ), + otel_protocol=cast(Literal["grpc", "http/protobuf"], protocol), + otel_exporter=cast(Literal["otlp", "console"], exporter), + otel_service_name=os.environ.get( + OTEL_SERVICE_NAME_ENV, + TraceConfig.model_fields["otel_service_name"].default, + ), + ) + + @staticmethod + def _config_identity(config: TraceConfig) -> tuple[Any, ...]: + return ( + config.enabled, + config.otel_endpoint, + config.otel_protocol, + config.otel_exporter, + config.otel_service_name, + ) + + +_TRACE_RUNTIME_MANAGER = TraceRuntimeManager() diff --git a/xtuner/v1/rl/utils/ray_accelerator_worker.py b/xtuner/v1/rl/utils/ray_accelerator_worker.py index a7d52d981a..8b9d40a1b2 100644 --- a/xtuner/v1/rl/utils/ray_accelerator_worker.py +++ b/xtuner/v1/rl/utils/ray_accelerator_worker.py @@ -15,6 +15,8 @@ ) from typing_extensions import Annotated +from xtuner.v1.rl.trace import merge_trace_runtime_env + from .ray_utils import find_master_addr_and_port, get_accelerator_ids @@ -458,6 +460,7 @@ def from_placement_group( (rank, bundle_index). 
""" pg_options = cls.get_pg_options(pg) + merge_trace_runtime_env(pg_options) device_type = cls.get_device_type(pg) sorted_bundle_idxs, master_addr, master_port, world_size = cls.get_spmd_info(pg) diff --git a/xtuner/v1/rl/utils/ray_cpu_worker.py b/xtuner/v1/rl/utils/ray_cpu_worker.py index ed9ebc3783..e69349f69b 100644 --- a/xtuner/v1/rl/utils/ray_cpu_worker.py +++ b/xtuner/v1/rl/utils/ray_cpu_worker.py @@ -17,6 +17,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from typing_extensions import Annotated +from xtuner.v1.rl.trace import merge_trace_runtime_env from xtuner.v1.utils.logger import get_logger @@ -185,6 +186,7 @@ def build_actor( } if resolved_memory is not None and resolved_memory > 0: actor_options["memory"] = resolved_memory + merge_trace_runtime_env(actor_options) if pg is None: return actor_cls.options(**actor_options).remote(*init_args, **init_kwargs) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 7073b881ff..b060c76d24 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -42,6 +42,7 @@ ) from xtuner.v1.rl.rollout.controller import RolloutControllerProxy from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.trace import TraceConfig, close_trace, configure_trace, get_trace_run_id from xtuner.v1.rl.trainer.controller import TrainingController from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem from xtuner.v1.rl.utils import ( @@ -327,6 +328,7 @@ class BaseRLTrainerConfig(BaseModel): train_batch_size: int advantage_estimator_config: BaseAdvantageConfig = Field(default_factory=GRPOAdvantageConfig) sync_weights_interval: int = 1 + trace_config: TraceConfig = Field(default_factory=TraceConfig) enable_evaluate: bool = True enable_initial_evaluate: bool = False @@ -570,6 +572,7 @@ def _init_common(self, cfg: BaseRLTrainerConfig, *, meta_path: str, logger_tag: self._init_load_source(cfg) self._init_save_config(cfg) log_dir = 
self._init_logger(cfg, logger_tag) + self._init_trace(cfg) self._save_runtime_environment(log_dir) self._init_train_state(cfg) self._init_train_worker_config(cfg, log_dir) @@ -611,6 +614,36 @@ def _init_save_config(self, cfg: BaseRLTrainerConfig) -> None: self._checkpoint_no_save_replay_buffer = cfg.checkpoint_no_save_replay_buffer self._load_checkpoint_cfg = self._resolve_load_checkpoint_cfg(cfg.auto_resume, cfg.load_checkpoint_cfg) + def _init_trace(self, cfg: BaseRLTrainerConfig) -> None: + trace_config = cfg.trace_config + self._trace_config = trace_config + self._trace_dashboard_handle: Any | None = None + configure_trace(trace_config) + self._maybe_log_trace_viewer(trace_config) + + def _maybe_log_trace_viewer(self, trace_config: TraceConfig) -> None: + if not trace_config.enabled or trace_config.jaeger_query_url is None: + return + if get_rank() != 0: + return + self.logger.info(f"Jaeger Trace Viewer: {trace_config.jaeger_query_url}") + from xtuner.tools.jaeger_trace_dashboard import start_jaeger_trace_dashboard + + handle = start_jaeger_trace_dashboard( + trace_config.jaeger_query_url, + service_name=trace_config.otel_service_name, + run_id=get_trace_run_id(), + ) + self._trace_dashboard_handle = handle + self.logger.info(f"XTuner Task Trace Dashboard: {handle.url}") + + def _close_trace(self) -> None: + trace_dashboard_handle = getattr(self, "_trace_dashboard_handle", None) + if trace_dashboard_handle is not None: + trace_dashboard_handle.close() + self._trace_dashboard_handle = None + close_trace() + def _init_logger(self, cfg: BaseRLTrainerConfig, logger_tag: str) -> Path: log_dir = self.exp_dir / "logs" log_dir.mkdir(parents=True, exist_ok=True) @@ -1655,6 +1688,12 @@ def _sync_weights_from_train_workers(self) -> None: self.logger.info("Rollout workers updated weights from train workers.") def fit(self): + try: + return self._fit() + finally: + self._close_trace() + + def _fit(self): self.logger.info("Start RL training") if self._cur_step >= 
self._total_train_steps: self.logger.info(f"Train steps {self._total_train_steps} reached, stop training") @@ -1857,8 +1896,11 @@ def _resume_from_checkpoint(self, checkpoint_path: Path | str) -> None: asyncio_run(self.agent_loop_manager.continue_produce(model_step=saved_model_step)) def fit(self): - # Expose a synchronous fit; internally, an async loop organizes the producer/consumer. - return asyncio_run(self._fit()) + # Keep the synchronous fit interface externally; internally, an async loop organizes the producer/consumer. + try: + return asyncio_run(self._fit()) + finally: + self._close_trace() async def _get_batch_or_raise_producer_failure( self,