Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/serverless/worker.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ For more complex operations where you are downloading files or making changes to
return {"output": "Job completed successfully"}
```

## Stopping Individual Jobs

A worker can process more than one job concurrently. When a single request is cancelled, expires, or times out, the Runpod server signals the worker to stop just that request without affecting the worker's other in-progress jobs. The worker long-polls a dedicated stop channel so these signals arrive with low latency, and it cancels the task running the matching job, so a stopped job no longer consumes worker time.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"long-polls a dedicated stop channel so these signals arrive with low latency" describes server behavior as if the SDK guarantees it. The client only issues a single GET; whether the connection is held open vs. returns 204 immediately is entirely the server's contract, and monitor_stop_signals has no enforced floor delay (see the busy-spin comment), so if the server doesn't long-poll this becomes continuous polling rather than low-latency long-poll.

Suggest attributing the behavior to the server, e.g. "The worker continuously polls a dedicated stop channel; the server is expected to hold each request open (long-poll) until a stop signal is available or the poll times out." Same applies to the get_stop_signals docstring.

Separately at line 65: the cleanup advice should note that a handler catching asyncio.CancelledError must re-raise after cleanup (as the SDK's own handle_job does) — otherwise the stop is swallowed and the job is reported completed.


No handler changes are required to support this. Handlers that hold resources can perform cleanup by catching `asyncio.CancelledError` in async handlers.

## See Also

- [Worker Fitness Checks](./worker_fitness_checks.md) - Validate your worker environment at startup
Expand Down
56 changes: 56 additions & 0 deletions runpod/serverless/modules/rp_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,62 @@
job_progress = JobsProgress()


def _job_stop_url() -> Optional[str]:
"""
Prepare the URL for the worker's dedicated stop channel.

Derived from the job-take URL so it points at the same endpoint and worker.
Returns None when the job-take URL is not in the expected form.
"""
base_url = JOB_GET_URL.split("?")[0]
if "/job-take/" not in base_url:
return None
return base_url.replace("/job-take/", "/job-stop/")
Comment on lines +36 to +39


async def get_stop_signals(session: ClientSession) -> List[str]:
"""
Long-poll the dedicated stop channel for request ids the worker should stop.

The server holds the request open until a stop signal is available or the
poll times out, so cancellations and timeouts reach the worker without
waiting for the next heartbeat.

Returns:
A list of request ids to stop. Empty when the poll returned no signals.
"""
stop_url = _job_stop_url()
if not stop_url:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When _job_stop_url() returns None, the entire stop feature is silently disabled for the worker's whole lifetime. Note JOB_GET_URL is built as str(os.environ.get("RUNPOD_WEBHOOK_GET_JOB")).replace(...) — if that env var is unset it becomes the literal string "None", which has no /job-take/, so _job_stop_url() returns None and this branch makes the stop loop a permanent no-op that logs nothing. Every cancelled job on a misconfigured worker bills in full, invisibly.

This is exactly the "validate config at startup, fail fast/loud" case. Suggest validating once at worker startup (so it's not per-poll log spam) and emitting a single clear diagnostic when the stop URL can't be derived, rather than silently returning [] forever.

return []

async with session.get(stop_url) as response:
if response.status == 204:
return []

if response.status == 429:
raise TooManyRequests(
response.request_info,
response.history,
status=response.status,
message=response.reason,
)

response.raise_for_status()

if response.content_type != "application/json":

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Four server-contract violations here all return [] with zero logging:

  • non-JSON content-type on a 2xx (:71)
  • JSON parse failure (:76)
  • valid JSON but non-dict payload (:79)
  • non-string ids dropped by the isinstance(job_id, str) filter (:82)

Each one means a stop signal silently never fires, so cancelled jobs keep billing — and there's no trace to diagnose schema drift between server and worker. The guards themselves are good defensive code; they just need to be observable. Suggest a log.warn before each return (include the status/content-type/payload-type), and for the id filter, log when len(filtered) != len(raw) so a type change in jobsToStop surfaces instead of vanishing.

return []

try:
payload = await response.json()
except (aiohttp.ContentTypeError, ValueError):
return []

if not isinstance(payload, dict):
return []

return [job_id for job_id in payload.get("jobsToStop", []) if isinstance(job_id, str)]


def _job_get_url(batch_size: int = 1):
"""
Prepare the URL for making a 'get' request to the serverless API (sls).
Expand Down
68 changes: 64 additions & 4 deletions runpod/serverless/modules/rp_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, Dict, Set

from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
from .rp_job import get_job, handle_job
from .rp_job import get_job, get_stop_signals, handle_job
from .rp_logger import RunPodLogger
from .worker_state import JobsProgress, IS_LOCAL_TEST

Expand Down Expand Up @@ -48,6 +48,12 @@ def __init__(self, config: Dict[str, Any]):
self.config = config
self.job_progress = JobsProgress() # Cache the singleton instance

# maps in-progress job ids to their running tasks so individual jobs
# can be stopped without killing the whole worker
self.jobs_tasks: Dict[str, asyncio.Task] = {}

self.stop_signals_fetcher = get_stop_signals

self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)

self.concurrency_modifier = _default_concurrency_modifier
Expand All @@ -71,6 +77,9 @@ def __init__(self, config: Dict[str, Any]):
if jobs_handler := self.config.get("jobs_handler"):
self.jobs_handler = jobs_handler

if stop_signals_fetcher := self.config.get("stop_signals_fetcher"):
self.stop_signals_fetcher = stop_signals_fetcher

async def set_scale(self):
self.current_concurrency = self.concurrency_modifier(self.current_concurrency)

Expand Down Expand Up @@ -128,8 +137,9 @@ async def run(self):
# Create tasks for getting and running jobs.
jobtake_task = asyncio.create_task(self.get_jobs(session))
jobrun_task = asyncio.create_task(self.run_jobs(session))
jobstop_task = asyncio.create_task(self.monitor_stop_signals(session))

tasks = [jobtake_task, jobrun_task]
tasks = [jobtake_task, jobrun_task, jobstop_task]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment two lines down still says "Concurrently run both tasks and wait for both to finish," but this list now holds three tasks (jobtake_task, jobrun_task, jobstop_task). Suggest updating it — or, to survive the next task addition, dropping the count entirely: "Run the worker's concurrent loops until shutdown."


# Concurrently run both tasks and wait for both to finish.
await asyncio.gather(*tasks)
Expand Down Expand Up @@ -226,9 +236,10 @@ async def run_jobs(self, session: ClientSession):
# Fetch as many jobs as the concurrency allows
while len(tasks) < self.current_concurrency and not self.jobs_queue.empty():
job = await self.jobs_queue.get()
# Create a new task for each job and add it to the task list
# Create a new task for each job and track it by job id
task = asyncio.create_task(self.handle_job(session, job))
tasks.add(task)
self.jobs_tasks[job["id"]] = task

# Wait for any job to finish
if tasks:
Expand All @@ -250,7 +261,51 @@ async def run_jobs(self, session: ClientSession):


# Ensure all remaining tasks finish before stopping
await asyncio.gather(*tasks)
await asyncio.gather(*tasks, return_exceptions=True)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switching to return_exceptions=True is the right call now that stopped tasks raise CancelledError during the final drain. The side effect is that genuine handler exceptions raised at shutdown are now collected into the (unused) result list and silently discarded. CancelledError from stopped jobs is fine to ignore here, but a real handler error shouldn't vanish. Consider inspecting the results and logging non-CancelledError exceptions:

results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
    if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
        log.error(f"run_jobs | Task failed during shutdown drain: {result}")


async def monitor_stop_signals(self, session: ClientSession):
"""
Long-polls the dedicated stop channel and stops signalled jobs.

Runs in an infinite loop while the worker is alive. The Runpod server
signals a request to be stopped (for example when it is cancelled or
times out) and this loop stops just that in-progress job, leaving the
worker's other jobs running.
"""
while self.is_alive():
try:
job_ids = await self.stop_signals_fetcher(session)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop has no backpressure on its normal path. The await asyncio.sleep(0) in the finally is a cooperative yield (it lets other tasks run) — correct to have, but it adds no delay. So the polling cadence is governed entirely by how long self.stop_signals_fetcher(session) takes to return, and the design depends on the server holding that GET open (long-poll). get_stop_signals returns [] immediately on HTTP 204, and nothing here enforces hold-open behavior. If the stop channel returns quickly (server not yet long-polling, a buffering proxy, or a misconfigured endpoint), this becomes a 100%-CPU hot loop firing back-to-back GETs.

The error branches already add real backoff (429→5s, generic→1s), so the concern is only the success/empty path. Suggest a floor delay when no ids were returned, e.g. await asyncio.sleep(1) on the empty result, so the loop can't busy-spin if the server doesn't long-poll.

for job_id in job_ids:
await self.stop_job(job_id)
Comment on lines +276 to +279
except TooManyRequests:
await asyncio.sleep(5) # debounce
except asyncio.CancelledError:
raise
except Exception as error:
log.debug(f"JobScaler.monitor_stop_signals | Error: {error}.")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This catch-all logs at debug, which is suppressed at the default production log level. monitor_stop_signals is the only loop that stops billed jobs — if get_stop_signals starts failing persistently (DNS/TLS error, stop endpoint 5xx, malformed JOB_GET_URL, a bug in stop_job), every stop signal is silently dropped, cancelled jobs keep running, and customers keep getting billed while the worker looks perfectly healthy. The sibling get_jobs loop logs the same catch-all at error (with the error type), and it's no more billing-critical than this one.

Suggest matching that: log.error with type(error).__name__, and ideally a consecutive-failure counter/metric so a dead stop channel is alertable rather than invisible. The broad except is correct here (the loop must survive) — it just needs to be loud.

await asyncio.sleep(1) # don't spin on persistent errors
finally:
await asyncio.sleep(0)

async def stop_job(self, job_id: str) -> bool:
"""
Stop a single in-progress job by cancelling its running task.

Args:
job_id: The id of the job to stop.

Returns:
True if a matching in-progress job was found and stopped,
False otherwise.
"""
task = self.jobs_tasks.get(job_id)
if task is None:
log.debug(f"JobScaler.stop_job | No in-progress job for {job_id}.")
return False

log.info("Stopping job.", job_id)
task.cancel()
return True

async def handle_job(self, session: ClientSession, job: dict):
"""
Expand All @@ -268,11 +323,16 @@ async def handle_job(self, session: ClientSession, job: dict):
log.error(f"Error handling job: {err}", job["id"])
raise err

except asyncio.CancelledError:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This except asyncio.CancelledError is placed after except Exception. It works only because CancelledError derives from BaseException (not Exception) on Python 3.8+, so it isn't shadowed — but it reads as dead code, and it silently breaks the moment someone widens the first clause to BaseException or wraps the cancellation. It's also inconsistent with get_jobs, which orders CancelledError first with an explanatory comment.

Suggest reordering except asyncio.CancelledError above except Exception (and mirroring the # CancelledError is a BaseException comment from get_jobs). While here: raise err on the line above rebinds the traceback — prefer a bare raise to preserve the original origin.

log.info("Job stopped.", job["id"])
raise

finally:
# Inform Queue of a task completion
self.jobs_queue.task_done()

# Job is no longer in progress
self.job_progress.remove(job)
self.jobs_tasks.pop(job["id"], None)

log.debug("Finished Job", job["id"])
70 changes: 70 additions & 0 deletions tests/test_serverless/test_modules/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,76 @@ async def test_get_job_exception(self):
self.assertEqual(str(context.exception), "Unexpected error")


class TestGetStopSignals(IsolatedAsyncioTestCase):
"""Tests for the get_stop_signals function."""

STOP_TAKE_URL = "http://mock.url/v2/ep/job-take/pod?gpu=x"

def test_job_stop_url_derived_from_job_take(self):
with patch("runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL):
assert rp_job._job_stop_url() == "http://mock.url/v2/ep/job-stop/pod"

Comment on lines +156 to +161
def test_job_stop_url_none_when_not_job_take(self):
with patch("runpod.serverless.modules.rp_job.JOB_GET_URL", "http://mock.url/other"):
assert rp_job._job_stop_url() is None

async def test_get_stop_signals_200(self):
response = Mock(ClientResponse)
response.status = 200
response.content_type = "application/json"
response.json = make_mocked_coro(return_value={"jobsToStop": ["a", "b", 5]})

with patch("aiohttp.ClientSession") as mock_session, patch(
"runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL
):
mock_session.get.return_value.__aenter__.return_value = response
result = await rp_job.get_stop_signals(mock_session)
self.assertEqual(result, ["a", "b"])

async def test_get_stop_signals_204(self):
response = Mock(ClientResponse)
response.status = 204

with patch("aiohttp.ClientSession") as mock_session, patch(
"runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL
):
mock_session.get.return_value.__aenter__.return_value = response
result = await rp_job.get_stop_signals(mock_session)
self.assertEqual(result, [])

async def test_get_stop_signals_429(self):
response = Mock(ClientResponse)
response.status = 429

with patch("aiohttp.ClientSession") as mock_session, patch(
"runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL
):
mock_session.get.return_value.__aenter__.return_value = response
with self.assertRaises(TooManyRequests):
await rp_job.get_stop_signals(mock_session)

async def test_get_stop_signals_no_url(self):
with patch("aiohttp.ClientSession") as mock_session, patch(
"runpod.serverless.modules.rp_job.JOB_GET_URL", "http://mock.url/other"
):
result = await rp_job.get_stop_signals(mock_session)
self.assertEqual(result, [])
mock_session.get.assert_not_called()

async def test_get_stop_signals_non_dict_payload(self):
response = Mock(ClientResponse)
response.status = 200
response.content_type = "application/json"
response.json = make_mocked_coro(return_value=["not", "a", "dict"])

with patch("aiohttp.ClientSession") as mock_session, patch(
"runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL
):
mock_session.get.return_value.__aenter__.return_value = response
result = await rp_job.get_stop_signals(mock_session)
self.assertEqual(result, [])


class TestRunJob(IsolatedAsyncioTestCase):
"""Tests the run_job function"""

Expand Down
86 changes: 86 additions & 0 deletions tests/test_serverless/test_rp_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,92 @@ async def handler(_session, _config, job):

scaler.kill_worker()

@pytest.mark.asyncio
async def test_stop_job_cancels_inflight_task(job_scaler: PatchScaler):
scaler = job_scaler.scaler
job_started = asyncio.Event()
cancelled = []

async def handler(_session, _config, job):
job_started.set()
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
cancelled.append(job["id"])
raise

scaler.jobs_handler = handler
scaler.current_concurrency = 1
scaler.jobs_queue = asyncio.Queue(maxsize=1)
run_task = asyncio.create_task(scaler.run_jobs(None))

await scaler.jobs_queue.put(generate_job("stop-me"))
await asyncio.wait_for(job_started.wait(), timeout=2)

assert "stop-me" in scaler.jobs_tasks
assert await scaler.stop_job("stop-me") is True

scaler.kill_worker()
await asyncio.wait_for(run_task, timeout=2)

assert cancelled == ["stop-me"]
assert "stop-me" not in scaler.jobs_tasks
assert job_scaler.progress.count == 0

scaler.kill_worker()


@pytest.mark.asyncio
async def test_stop_job_unknown_id_returns_false(job_scaler: PatchScaler):
scaler = job_scaler.scaler
assert await scaler.stop_job("does-not-exist") is False
scaler.kill_worker()


@pytest.mark.asyncio
async def test_monitor_stop_signals_stops_jobs(job_scaler: PatchScaler):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two test-coverage notes for this area:

  1. This test monkeypatches both collaborators (scaler.stop_job and scaler.stop_signals_fetcher), so the only production code it exercises is the loop's id-forwarding — it's mocks talking to mocks, not integration. The real end-to-end wiring is already well covered by test_stop_job_cancels_inflight_task; consider a variant here where the fetcher returns ids and the real stop_job runs against a populated jobs_tasks, asserting the underlying task is cancelled.

  2. Several new branches are untested: get_stop_signals' raise_for_status() 5xx path, the wrong-content-type branch, and the JSON parse (ContentTypeError/ValueError) catch; plus monitor_stop_signals' TooManyRequests→sleep(5) branch and the CancelledError→re-raise. The sibling get_job suite already covers its 500/content-type cases — worth matching. For the backoff branches, patch asyncio.sleep so they're deterministic rather than wall-clock-dependent.

scaler = job_scaler.scaler
stopped = []

async def fake_stop_job(job_id):
stopped.append(job_id)
return True

async def fetcher(_session):
if not stopped:
return ["job-a", "job-b"]
return []

scaler.stop_job = fake_stop_job
scaler.stop_signals_fetcher = fetcher

monitor_task = asyncio.create_task(scaler.monitor_stop_signals(AsyncMock()))
await asyncio.sleep(0.05)
scaler.kill_worker()
await asyncio.wait_for(monitor_task, timeout=2)

assert sorted(stopped) == ["job-a", "job-b"]


@pytest.mark.asyncio
async def test_monitor_stop_signals_survives_errors(job_scaler: PatchScaler):
scaler = job_scaler.scaler
calls = {"value": 0}

async def fetcher(_session):
calls["value"] += 1
raise RuntimeError("boom")

scaler.stop_signals_fetcher = fetcher

monitor_task = asyncio.create_task(scaler.monitor_stop_signals(AsyncMock()))
await asyncio.sleep(0.05)
scaler.kill_worker()
await asyncio.wait_for(monitor_task, timeout=2)

assert calls["value"] >= 1


@pytest.mark.asyncio
async def test_get_jobs_feeds_workers_end_to_end(job_scaler: PatchScaler):
scaler = job_scaler.scaler
Expand Down