diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py index 184ee35dbc..c0d9ea6661 100644 --- a/xtuner/v1/engine/train_engine.py +++ b/xtuner/v1/engine/train_engine.py @@ -392,12 +392,20 @@ def async_save_dcp( # temporary .incomplete directory and commit it only after every rank's # async_save future has completed. incomplete_dir = weights_dir.with_name(f"{weights_dir.name}.incomplete") - if weights_dir.exists(): + # Check existence only on rank 0 and broadcast the result so all ranks + # raise (or continue) together. Without this, NFS cache inconsistencies + # could cause some ranks to raise while others proceed to the barrier, + # resulting in a deadlock. + dir_exists = torch.tensor(int(weights_dir.exists() if dist.get_rank() == 0 else 0), dtype=torch.int32) + dist.broadcast(dir_exists, src=0, group=async_checkpoint_pg) + if dir_exists.item(): raise FileExistsError(f"Checkpoint directory already exists: {weights_dir}") if dist.get_rank() == 0: if incomplete_dir.exists(): shutil.rmtree(incomplete_dir) incomplete_dir.mkdir(parents=True, exist_ok=True) + # Ensure rank 0 finishes rmtree+mkdir before any rank proceeds to write. + dist.barrier(group=async_checkpoint_pg) # XtunerCacheWriter.stage() creates its staging cache directly in POSIX # shared memory (/dev/shm). PyTorch's ForkingPickler detects @@ -440,7 +448,20 @@ def commit_async_save() -> None: dcp_future.result() break except BaseException as exc: - if attempt == max_daemon_init_attempts or not self._is_async_checkpoint_daemon_init_error(exc): + # Use all_reduce(MAX) so all ranks agree on whether this is + # a retryable daemon-init error. Without this, ranks that + # see a non-retryable error would raise (skipping the + # barrier) while other ranks wait at the barrier, causing a + # deadlock. + is_retryable = attempt < max_daemon_init_attempts and self._is_async_checkpoint_daemon_init_error( + exc + ) + # 0 = retryable, 1 = fatal; MAX means any fatal rank wins. 
+ decision = torch.tensor(0 if is_retryable else 1, dtype=torch.int32) + dist.all_reduce(decision, op=dist.ReduceOp.MAX, group=async_checkpoint_pg) + is_fatal = bool(decision.item()) + + if is_fatal: elapsed = time.time() - t0 logger.error(f"[DCP async_save for {weights_dir}] failed after {elapsed:.2f}s: {exc}") logger.error(traceback.format_exc()) @@ -492,6 +513,51 @@ def _build_async_storage_writer(self, weights_dir: Path, *, save_optimizer: bool storage_writer.state_dict_cache = self._async_state_dict_cache return storage_writer + @classmethod + def warmup_async_save_dcp(cls, work_dir: Path) -> None: + """Warm up async DCP save infrastructure with a tiny dummy state dict. + + This triggers the full async save path — including daemon subprocess + spawn and its internal init_process_group — so that errors like port + conflicts (EADDRINUSE) surface before any real training begins. + + Args: + work_dir (Path): Working directory for temporary preflight files. + """ + preflight_dir = work_dir / ".preflight_dcp" + weights_dir = preflight_dir / "weights" + + if dist.get_rank() == 0: + if preflight_dir.exists(): + shutil.rmtree(preflight_dir) + weights_dir.mkdir(parents=True, exist_ok=True) + dist.barrier() + + dummy_state_dict = {"_preflight": torch.zeros(1)} + + try: + async_save_kwargs: dict[str, Any] = {} + state_dict_saver = importlib.import_module("torch.distributed.checkpoint.state_dict_saver") + async_checkpointer_type = getattr(state_dict_saver, "AsyncCheckpointerType", None) + if async_checkpointer_type is not None: + async_save_kwargs["async_checkpointer_type"] = async_checkpointer_type.PROCESS + + future = cast(Any, dcp.async_save)( + dummy_state_dict, + checkpoint_id=weights_dir, + **async_save_kwargs, + ) + future.result(timeout=300) + except Exception as e: + raise RuntimeError( + f"DCP warmup save failed. This usually indicates a port conflict " + f"or process group initialization issue. 
def _create_coalesced_shm_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Create a CPU state dict backed by coalesced shared-memory buffers.

    Instead of creating one shared-memory file per tensor (which leads to
    thousands of fds and triggers ``received 0 items of ancdata`` when the
    daemon subprocess tries to receive them all), this function groups tensors
    by dtype, allocates a single large shared-memory tensor per dtype, and
    returns views into that buffer.

    Nested ``dict`` / ``list`` / ``tuple`` containers are traversed
    recursively, so tensors at any depth are staged into the same per-dtype
    buffers. Only the layout is created here: the buffers are allocated with
    ``torch.empty`` and hold no meaningful data until the caller copies the
    source tensors into the returned structure.

    Args:
        state_dict (dict[str, Any]): The source state dict (tensors can be on
            any device). Values may be tensors, nested containers of tensors,
            or arbitrary non-tensor objects.

    Returns:
        dict[str, Any]: A new state dict with the same keys and nesting, where
        every non-empty tensor is a view into a dtype-coalesced shared-memory
        buffer, every zero-element tensor is its own shared CPU tensor, and
        every non-tensor leaf is the original object.
    """
    # NOTE(review): tensor subclasses such as DTensor pass the isinstance check
    # below and would be sized by their reported (global) numel()/size();
    # confirm callers hand in plain/local tensors.

    # Pass 1: count how many elements each dtype needs across the whole tree.
    numel_per_dtype: dict[torch.dtype, int] = {}

    def _count(obj: Any) -> None:
        if isinstance(obj, torch.Tensor):
            if obj.numel() > 0:
                numel_per_dtype[obj.dtype] = numel_per_dtype.get(obj.dtype, 0) + obj.numel()
        elif isinstance(obj, dict):
            for child in obj.values():
                _count(child)
        elif isinstance(obj, (list, tuple)):
            for child in obj:
                _count(child)

    _count(state_dict)

    # Allocate exactly one shared-memory buffer per dtype.
    buffers: dict[torch.dtype, torch.Tensor] = {}
    for dtype, total_numel in numel_per_dtype.items():
        buf = torch.empty(total_numel, dtype=dtype)
        buf.share_memory_()
        buffers[dtype] = buf
    # Next free element index inside each dtype buffer.
    cursors: dict[torch.dtype, int] = dict.fromkeys(buffers, 0)

    # Pass 2: rebuild the structure, carving consecutive slices out of the
    # buffers in traversal order so that views never overlap.
    def _stage(obj: Any) -> Any:
        if isinstance(obj, torch.Tensor):
            numel = obj.numel()
            if numel == 0:
                # Zero-numel tensors: just create a shared empty tensor.
                empty = torch.zeros_like(obj, device="cpu")
                empty.share_memory_()
                return empty
            start = cursors[obj.dtype]
            cursors[obj.dtype] = start + numel
            return buffers[obj.dtype][start : start + numel].view(obj.size())
        if isinstance(obj, dict):
            return {key: _stage(child) for key, child in obj.items()}
        if isinstance(obj, list):
            return [_stage(child) for child in obj]
        if isinstance(obj, tuple):
            return tuple(_stage(child) for child in obj)
        # Non-tensor leaves (ints, strings, ...) are carried over unchanged.
        return obj

    return _stage(state_dict)
def _preflight_async_checkpoint(self) -> None:
    """Warm up async DCP save to surface daemon init errors early.

    Does nothing when ``self._async_checkpoint`` is falsy. Otherwise it calls
    ``TrainEngine.warmup_async_save_dcp`` with ``self.work_dir`` and logs a
    message before and after the call; any exception raised by the warmup
    propagates to the caller.
    """
    if not self._async_checkpoint:
        return
    log_rank0.info("Preflight: warming up async DCP save infrastructure...")
    TrainEngine.warmup_async_save_dcp(work_dir=self.work_dir)
    log_rank0.info("Preflight: async DCP save infrastructure verified OK.")


def _check_async_save_health(self) -> None:
    """Non-blocking check: if any pending async save has failed, raise immediately.

    Returns without waiting when there is no pending checkpoint or when the
    pending future has not finished yet. A finished future that succeeded is
    left in ``self._pending_checkpoint`` untouched.

    Raises:
        RuntimeError: If the pending future finished with an exception (the
            original exception is chained as ``__cause__``) or was cancelled.
            In both cases ``self._pending_checkpoint`` is reset to ``None``
            first so the same failure is not reported twice.
    """
    # Local import keeps this block self-contained.
    from concurrent.futures import CancelledError

    pending = self._pending_checkpoint
    if pending is None or not pending.done():
        return

    try:
        exc = pending.exception()
    except CancelledError as cancel_exc:
        # done() is also True for cancelled futures, and exception() raises
        # CancelledError for them; report that as a failed save instead of
        # leaking a BaseException out of a health check.
        self._pending_checkpoint = None
        raise RuntimeError("Async DCP checkpoint was cancelled in background") from cancel_exc

    if exc is not None:
        self._pending_checkpoint = None
        raise RuntimeError(f"Async DCP checkpoint failed in background: {exc}") from exc
timeout: int = 3000) -> None: if self._pending_checkpoint is None: return @@ -1226,7 +1245,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool: # Save model and optimizer future: Future | None = None - if self._async_checkpoint and not is_snapshot: + if self._async_checkpoint: future = self._engine.async_save_dcp(weights_dir=weights_path) else: self._engine.save_dcp(weights_dir=weights_path)