Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions src/madengine/deployment/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,27 +842,29 @@ def _generate_sglang_disagg_command(
Generate SGLang Disaggregated launcher environment for SLURM.

SGLang Disaggregated Architecture:
- Node 0: Proxy (load balancer)
- Nodes 1 to xP: Prefill nodes
- Nodes xP+1 to xP+yD: Decode nodes
- Prefill nodes: xP
- Decode nodes: yD
- Proxy/router: a dedicated node (1 + xP + yD == nnodes) or co-located
on the first prefill node (xP + yD == nnodes). The rank-to-role
assignment is handled by the model run.sh, not by this launcher.

Minimum cluster: 3 nodes (1 proxy + 1 prefill + 1 decode)
Minimum cluster: 2 nodes (co-located proxy + 1 prefill + 1 decode)

Args:
nnodes: Total number of nodes (must be >= 3)
nnodes: Total number of nodes (must be >= 2)
nproc_per_node: GPUs per node (tensor parallel size)
master_port: Master port for coordination

Returns:
Environment setup with node role assignment

Raises:
ValueError: If nnodes < 3 (minimum for disagg)
ValueError: If nnodes < 2 (minimum for disagg)
"""
if nnodes < 3:
if nnodes < 2:
raise ValueError(
f"SGLang Disaggregated requires minimum 3 nodes "
f"(1 proxy + 1 prefill + 1 decode), got {nnodes}"
f"SGLang Disaggregated requires minimum 2 nodes "
f"(co-located proxy + 1 prefill + 1 decode), got {nnodes}"
)
Comment thread
mkuznet1 marked this conversation as resolved.

# Check if custom split is specified in additional_context
Expand All @@ -877,16 +879,28 @@ def _generate_sglang_disagg_command(
f"SGLang Disaggregated requires at least 1 prefill and 1 decode node, "
f"got prefill={prefill_nodes}, decode={decode_nodes}"
)
if prefill_nodes + decode_nodes + 1 != nnodes:
# Accept either a dedicated proxy node (1 + xP + yD == nnodes) or a
# co-located proxy/router on the first prefill node (xP + yD == nnodes),
# mirroring the vllm-disagg layout. The proxy is launched by the model
# script (not by this launcher), so both topologies are valid here.
if (prefill_nodes + decode_nodes != nnodes
and prefill_nodes + decode_nodes + 1 != nnodes):
raise ValueError(
Comment thread
mkuznet1 marked this conversation as resolved.
f"Custom split validation failed: "
f"prefill_nodes ({prefill_nodes}) + decode_nodes ({decode_nodes}) + 1 proxy "
f"must equal nnodes ({nnodes}), but got {prefill_nodes + decode_nodes + 1}"
f"Custom split validation failed: prefill_nodes ({prefill_nodes}) + "
f"decode_nodes ({decode_nodes}) = {prefill_nodes + decode_nodes} must equal "
f"nnodes ({nnodes}) for a co-located proxy, or nnodes-1 ({nnodes - 1}) for a "
f"dedicated proxy node"
)
xP = prefill_nodes
yD = decode_nodes
elif nnodes == 2:
# Co-located proxy on the first prefill node: 1 prefill + 1 decode.
# The general default below assumes a dedicated proxy (nnodes-1 worker
# nodes), which would yield yD=0 for nnodes==2 and an invalid topology.
xP = 1
yD = 1
else:
# Default split: use golden ratio for prefill/decode
# Default split: dedicated proxy + golden-ratio prefill/decode split.
# For N total nodes: 1 proxy + ~40% prefill + ~60% decode
xP = max(1, (nnodes - 1) * 2 // 5) # ~40% of worker nodes
yD = nnodes - 1 - xP # remaining nodes
Expand All @@ -895,9 +909,10 @@ def _generate_sglang_disagg_command(
# ============================================
# Cluster Configuration:
# Total Nodes: {nnodes}
# Proxy: 1 node (NODE_RANK=0)
# Prefill: {xP} nodes (NODE_RANK=1 to {xP})
# Decode: {yD} nodes (NODE_RANK={xP+1} to {nnodes-1})
# Prefill: {xP} nodes
# Decode: {yD} nodes
# Proxy/router: dedicated node or co-located on the first prefill node
# (rank-to-role assignment handled by the model run.sh)
# ============================================

# Export cluster topology
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_slurm_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,73 @@ def test_two_model_manifest_keys_match(self, tmp_path):
for mn in ("model_a", "model_b"):
assert manifest["built_models"][mn]["built_on_compute"] is True
assert "DOCKER_IMAGE_NAME" in manifest["built_models"][mn]["env_vars"]


# ---------------------------------------------------------------------------
# 6. SGLang-disagg default split topology
# ---------------------------------------------------------------------------
class TestSglangDisaggDefaultSplit:
"""`_generate_sglang_disagg_command` must produce a valid prefill/decode
split for the default (no custom prefill/decode) path, including the
co-located-proxy minimum of nnodes=2."""

@pytest.fixture
def deployment_factory(self, tmp_path: Path):
"""Build a SlurmDeployment with no custom sglang_disagg split, so
`_generate_sglang_disagg_command` takes the default-split branch."""
manifest = {
"built_images": {},
"built_models": {},
"context": {},
}
manifest_path = tmp_path / "build_manifest.json"
manifest_path.write_text(json.dumps(manifest))

cfg = DeploymentConfig(
target="slurm",
manifest_file=str(manifest_path),
additional_context={
"deploy": "slurm",
"slurm": {"output_dir": str(tmp_path / "slurm_results")},
},
)
return SlurmDeployment(cfg)

@staticmethod
def _parse_split(script: str):
"""Extract (xP, yD, total) from the exported topology env vars."""
xP = yD = total = None
for line in script.splitlines():
line = line.strip()
if line.startswith("export SGLANG_DISAGG_PREFILL_NODES="):
xP = int(line.split("=", 1)[1])
elif line.startswith("export SGLANG_DISAGG_DECODE_NODES="):
yD = int(line.split("=", 1)[1])
elif line.startswith("export SGLANG_DISAGG_TOTAL_NODES="):
total = int(line.split("=", 1)[1])
return xP, yD, total

def test_default_split_two_nodes_is_co_located(self, deployment_factory):
"""nnodes=2 default split must be co-located (xP=1, yD=1), not yD=0."""
script = deployment_factory._generate_sglang_disagg_command(
nnodes=2, nproc_per_node=8, master_port=12345
)
xP, yD, total = self._parse_split(script)
assert (xP, yD, total) == (1, 1, 2), f"expected xP=1,yD=1,total=2, got {(xP, yD, total)}"

@pytest.mark.parametrize("nnodes", [2, 3, 4, 5, 6, 8])
def test_default_split_always_has_prefill_and_decode(self, deployment_factory, nnodes):
"""Every supported nnodes must yield at least one prefill and one decode node."""
script = deployment_factory._generate_sglang_disagg_command(
nnodes=nnodes, nproc_per_node=8, master_port=12345
)
xP, yD, _ = self._parse_split(script)
assert xP >= 1, f"prefill nodes must be >= 1 for nnodes={nnodes}, got {xP}"
assert yD >= 1, f"decode nodes must be >= 1 for nnodes={nnodes}, got {yD}"

def test_below_minimum_nodes_raises(self, deployment_factory):
"""nnodes < 2 must raise ValueError (minimum cluster size)."""
with pytest.raises(ValueError):
deployment_factory._generate_sglang_disagg_command(
nnodes=1, nproc_per_node=8, master_port=12345
)