-
Notifications
You must be signed in to change notification settings - Fork 116
feat: add per-job stop capability to serverless worker #510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,62 @@ | |
| job_progress = JobsProgress() | ||
|
|
||
|
|
||
def _job_stop_url() -> Optional[str]:
    """
    Build the URL of the worker's dedicated stop channel.

    The query string of ``JOB_GET_URL`` is dropped and every ``/job-take/``
    segment of what remains is rewritten to ``/job-stop/``.

    Returns:
        The stop-channel URL, or None when the job-take URL (without its
        query string) contains no ``/job-take/`` segment.
    """
    take_segment = "/job-take/"

    # Everything before the first "?" is the URL without its query string.
    url_without_query, _, _ = JOB_GET_URL.partition("?")

    if take_segment in url_without_query:
        return url_without_query.replace(take_segment, "/job-stop/")
    return None
|
Comment on lines
+36
to
+39
|
||
|
|
||
|
|
||
async def get_stop_signals(session: ClientSession) -> List[str]:
    """
    Poll the dedicated stop channel for request ids the worker should stop.

    Issues a single GET against the stop URL. The server is expected to hold
    the request open (long-poll) until a stop signal is available or the poll
    times out; nothing in this function enforces that.

    A response that does not match the expected contract (a JSON object whose
    ``jobsToStop`` entry is a list of strings) is reported with a warning
    instead of being dropped silently, because a dropped stop signal leaves
    the affected job running with no trace of why.

    Args:
        session: HTTP client session used to issue the GET request.

    Returns:
        A list of request ids to stop. Empty when no stop URL can be derived,
        when the server answers 204, or when the response body does not match
        the expected contract.

    Raises:
        TooManyRequests: When the server answers with HTTP 429.
        aiohttp.ClientResponseError: For any other error status, raised by
            ``response.raise_for_status()``.
    """
    stop_url = _job_stop_url()
    if not stop_url:
        # Warn once rather than on every poll: the stop URL is derived from
        # JOB_GET_URL, which is presumably fixed for the process lifetime, so
        # repeating the message would only produce log spam.
        if not getattr(get_stop_signals, "_warned_missing_url", False):
            get_stop_signals._warned_missing_url = True
            log.warn(
                "get_stop_signals | Stop URL could not be derived from the "
                "job-take URL; per-job stop signals are disabled."
            )
        return []

    async with session.get(stop_url) as response:
        # 204: the poll ended without any stop signal.
        if response.status == 204:
            return []

        if response.status == 429:
            raise TooManyRequests(
                response.request_info,
                response.history,
                status=response.status,
                message=response.reason,
            )

        response.raise_for_status()

        if response.content_type != "application/json":
            log.warn(
                "get_stop_signals | Unexpected content type "
                f"{response.content_type!r}; ignoring response."
            )
            return []

        try:
            payload = await response.json()
        except (aiohttp.ContentTypeError, ValueError) as error:
            log.warn(f"get_stop_signals | Response body is not valid JSON: {error}")
            return []

        if not isinstance(payload, dict):
            log.warn(
                "get_stop_signals | Expected a JSON object, got "
                f"{type(payload).__name__}; ignoring response."
            )
            return []

        jobs_to_stop = payload.get("jobsToStop", [])

        # Iterating a non-list would misbehave: a string yields single
        # characters as "ids", a dict yields its keys, and null raises.
        if not isinstance(jobs_to_stop, list):
            log.warn(
                "get_stop_signals | Expected 'jobsToStop' to be a list, got "
                f"{type(jobs_to_stop).__name__}; ignoring response."
            )
            return []

        job_ids = [job_id for job_id in jobs_to_stop if isinstance(job_id, str)]

        dropped = len(jobs_to_stop) - len(job_ids)
        if dropped:
            log.warn(
                f"get_stop_signals | Ignored {dropped} non-string entries "
                "in 'jobsToStop'."
            )

        return job_ids
|
|
||
|
|
||
| def _job_get_url(batch_size: int = 1): | ||
| """ | ||
| Prepare the URL for making a 'get' request to the serverless API (sls). | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
| from typing import Any, Dict, Set | ||
|
|
||
| from ...http_client import AsyncClientSession, ClientSession, TooManyRequests | ||
| from .rp_job import get_job, handle_job | ||
| from .rp_job import get_job, get_stop_signals, handle_job | ||
| from .rp_logger import RunPodLogger | ||
| from .worker_state import JobsProgress, IS_LOCAL_TEST | ||
|
|
||
|
|
@@ -48,6 +48,12 @@ def __init__(self, config: Dict[str, Any]): | |
| self.config = config | ||
| self.job_progress = JobsProgress() # Cache the singleton instance | ||
|
|
||
| # maps in-progress job ids to their running tasks so individual jobs | ||
| # can be stopped without killing the whole worker | ||
| self.jobs_tasks: Dict[str, asyncio.Task] = {} | ||
|
|
||
| self.stop_signals_fetcher = get_stop_signals | ||
|
|
||
| self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency) | ||
|
|
||
| self.concurrency_modifier = _default_concurrency_modifier | ||
|
|
@@ -71,6 +77,9 @@ def __init__(self, config: Dict[str, Any]): | |
| if jobs_handler := self.config.get("jobs_handler"): | ||
| self.jobs_handler = jobs_handler | ||
|
|
||
| if stop_signals_fetcher := self.config.get("stop_signals_fetcher"): | ||
| self.stop_signals_fetcher = stop_signals_fetcher | ||
|
|
||
| async def set_scale(self): | ||
| self.current_concurrency = self.concurrency_modifier(self.current_concurrency) | ||
|
|
||
|
|
@@ -128,8 +137,9 @@ async def run(self): | |
| # Create tasks for getting and running jobs. | ||
| jobtake_task = asyncio.create_task(self.get_jobs(session)) | ||
| jobrun_task = asyncio.create_task(self.run_jobs(session)) | ||
| jobstop_task = asyncio.create_task(self.monitor_stop_signals(session)) | ||
|
|
||
| tasks = [jobtake_task, jobrun_task] | ||
| tasks = [jobtake_task, jobrun_task, jobstop_task] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The comment two lines down still says "Concurrently run both tasks and wait for both to finish," but this list now holds three tasks (`jobtake_task`, `jobrun_task`, `jobstop_task`). Suggest updating the comment to match. |
||
|
|
||
| # Concurrently run all tasks and wait for all of them to finish. | ||
| await asyncio.gather(*tasks) | ||
|
|
@@ -226,9 +236,10 @@ async def run_jobs(self, session: ClientSession): | |
| # Fetch as many jobs as the concurrency allows | ||
| while len(tasks) < self.current_concurrency and not self.jobs_queue.empty(): | ||
| job = await self.jobs_queue.get() | ||
| # Create a new task for each job and add it to the task list | ||
| # Create a new task for each job and track it by job id | ||
| task = asyncio.create_task(self.handle_job(session, job)) | ||
| tasks.add(task) | ||
| self.jobs_tasks[job["id"]] = task | ||
|
|
||
| # Wait for any job to finish | ||
| if tasks: | ||
|
|
@@ -250,7 +261,51 @@ async def run_jobs(self, session: ClientSession): | |
|
|
||
|
|
||
| # Ensure all remaining tasks finish before stopping | ||
| await asyncio.gather(*tasks) | ||
| await asyncio.gather(*tasks, return_exceptions=True) | ||
|
Member
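There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Switching to `return_exceptions=True` means exceptions raised by the drained tasks are returned in the result list instead of propagating, and that result is discarded here. Suggest inspecting the results and logging real failures:

    results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
            log.error(f"run_jobs | Task failed during shutdown drain: {result}")

|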
||
|
|
||
async def monitor_stop_signals(self, session: ClientSession):
    """
    Poll the dedicated stop channel and stop the jobs it signals.

    Loops for as long as ``self.is_alive()`` is true. Each iteration asks
    ``self.stop_signals_fetcher`` for job ids and calls ``self.stop_job`` on
    each one, so only the signalled in-progress jobs are stopped and the
    worker's other jobs keep running.

    The server is expected to hold each poll open (long-poll). Since that is
    not guaranteed, a poll that returns no ids is followed by a short floor
    delay so a fetcher that returns immediately cannot turn this loop into a
    busy spin.

    Args:
        session: HTTP client session handed to the stop-signal fetcher.

    Raises:
        asyncio.CancelledError: Re-raised when this task itself is cancelled.
    """
    # Seconds to wait after a poll that produced no ids. Polls that did
    # produce ids re-poll immediately so follow-up signals are not delayed.
    empty_poll_delay = 1

    while self.is_alive():
        try:
            job_ids = await self.stop_signals_fetcher(session)

            for job_id in job_ids:
                await self.stop_job(job_id)

            if not job_ids:
                await asyncio.sleep(empty_poll_delay)

        except TooManyRequests:
            await asyncio.sleep(5)  # debounce
        except asyncio.CancelledError:
            raise
        except Exception as error:
            # Logged at error level: a persistently failing stop channel
            # means stop signals never reach the worker, which must be
            # visible at default log levels.
            log.error(
                "JobScaler.monitor_stop_signals | "
                f"Error Type: {type(error).__name__} | Error: {error}"
            )
            await asyncio.sleep(1)  # don't spin on persistent errors
        finally:
            # Always yield to the event loop once per iteration.
            await asyncio.sleep(0)
|
|
||
async def stop_job(self, job_id: str) -> bool:
    """
    Request cancellation of the task running a single in-progress job.

    Args:
        job_id: The id of the job to stop.

    Returns:
        True when a task is tracked for ``job_id`` and its cancellation was
        requested, False when no task is tracked under that id.
    """
    running_task = self.jobs_tasks.get(job_id)

    if running_task is not None:
        log.info("Stopping job.", job_id)
        running_task.cancel()
        return True

    log.debug(f"JobScaler.stop_job | No in-progress job for {job_id}.")
    return False
|
|
||
| async def handle_job(self, session: ClientSession, job: dict): | ||
| """ | ||
|
|
@@ -268,11 +323,16 @@ async def handle_job(self, session: ClientSession, job: dict): | |
| log.error(f"Error handling job: {err}", job["id"]) | ||
| raise err | ||
|
|
||
| except asyncio.CancelledError: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This Suggest reordering |
||
| log.info("Job stopped.", job["id"]) | ||
| raise | ||
|
|
||
| finally: | ||
| # Inform Queue of a task completion | ||
| self.jobs_queue.task_done() | ||
|
|
||
| # Job is no longer in progress | ||
| self.job_progress.remove(job) | ||
| self.jobs_tasks.pop(job["id"], None) | ||
|
|
||
| log.debug("Finished Job", job["id"]) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -238,6 +238,92 @@ async def handler(_session, _config, job): | |
|
|
||
| scaler.kill_worker() | ||
|
|
||
@pytest.mark.asyncio
async def test_stop_job_cancels_inflight_task(job_scaler: PatchScaler):
    """stop_job cancels the running handler and the job's tracking is cleared."""
    target_id = "stop-me"
    scaler = job_scaler.scaler
    handler_entered = asyncio.Event()
    cancelled_ids = []

    async def blocking_handler(_session, _config, job):
        # Signal that the job is in flight, then block until cancelled.
        handler_entered.set()
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            cancelled_ids.append(job["id"])
            raise

    scaler.jobs_handler = blocking_handler
    scaler.current_concurrency = 1
    scaler.jobs_queue = asyncio.Queue(maxsize=1)
    runner = asyncio.create_task(scaler.run_jobs(None))

    await scaler.jobs_queue.put(generate_job(target_id))
    await asyncio.wait_for(handler_entered.wait(), timeout=2)

    # The job is tracked while it runs, and stopping it reports success.
    assert target_id in scaler.jobs_tasks
    stop_result = await scaler.stop_job(target_id)
    assert stop_result is True

    scaler.kill_worker()
    await asyncio.wait_for(runner, timeout=2)

    assert cancelled_ids == [target_id]
    assert target_id not in scaler.jobs_tasks
    assert job_scaler.progress.count == 0

    scaler.kill_worker()
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_stop_job_unknown_id_returns_false(job_scaler: PatchScaler):
    """stop_job reports False for an id that has no tracked task."""
    scaler = job_scaler.scaler
    stop_result = await scaler.stop_job("does-not-exist")
    assert stop_result is False
    scaler.kill_worker()
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_monitor_stop_signals_stops_jobs(job_scaler: PatchScaler):
    """monitor_stop_signals passes every fetched id to stop_job."""
    # NOTE(review): a reviewer raised two test-coverage notes for this area;
    # their text was not captured alongside this test - confirm on the PR.
    scaler = job_scaler.scaler
    stopped_ids = []

    async def recording_stop_job(job_id):
        stopped_ids.append(job_id)
        return True

    async def one_shot_fetcher(_session):
        # Hand out the ids until the first one has been recorded as stopped.
        return [] if stopped_ids else ["job-a", "job-b"]

    scaler.stop_job = recording_stop_job
    scaler.stop_signals_fetcher = one_shot_fetcher

    monitor = asyncio.create_task(scaler.monitor_stop_signals(AsyncMock()))
    await asyncio.sleep(0.05)
    scaler.kill_worker()
    await asyncio.wait_for(monitor, timeout=2)

    assert sorted(stopped_ids) == ["job-a", "job-b"]
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_monitor_stop_signals_survives_errors(job_scaler: PatchScaler):
    """A fetcher that always raises does not stop the monitor loop from exiting cleanly."""
    scaler = job_scaler.scaler
    attempts = []

    async def failing_fetcher(_session):
        # Record the call, then fail the way a broken stop channel would.
        attempts.append(1)
        raise RuntimeError("boom")

    scaler.stop_signals_fetcher = failing_fetcher

    monitor = asyncio.create_task(scaler.monitor_stop_signals(AsyncMock()))
    await asyncio.sleep(0.05)
    scaler.kill_worker()
    await asyncio.wait_for(monitor, timeout=2)

    assert len(attempts) >= 1
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_get_jobs_feeds_workers_end_to_end(job_scaler: PatchScaler): | ||
| scaler = job_scaler.scaler | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"long-polls a dedicated stop channel so these signals arrive with low latency" describes server behavior as if the SDK guarantees it. The client only issues a single GET; whether the connection is held open vs. returns 204 immediately is entirely the server's contract, and
`monitor_stop_signals` has no enforced floor delay (see the busy-spin comment), so if the server doesn't long-poll this becomes continuous polling rather than low-latency long-poll. Suggest attributing the behavior to the server, e.g. "The worker continuously polls a dedicated stop channel; the server is expected to hold each request open (long-poll) until a stop signal is available or the poll times out." Same applies to the
`get_stop_signals` docstring. Separately at line 65: the cleanup advice should note that a handler catching
`asyncio.CancelledError` must re-raise after cleanup (as the SDK's own `handle_job` does) — otherwise the stop is swallowed and the job is reported completed.