diff --git a/backend/src/github_pm/api.py b/backend/src/github_pm/api.py index 206852a..4693daf 100644 --- a/backend/src/github_pm/api.py +++ b/backend/src/github_pm/api.py @@ -1,4 +1,5 @@ from collections import defaultdict +from collections.abc import Callable from datetime import datetime import re import time @@ -18,6 +19,10 @@ # We sort "semver" style milestones first, then others alphabetically VERSION_MATCH = re.compile(r"^v\d+\.\d+\.\d+$") +# Bounded retries for transient GitHub gateway timeouts (504). +_GITHUB_504_MAX_ATTEMPTS = 5 +_GITHUB_504_BACKOFF_SEC = 1.5 + class Connector: def __init__(self, github_token: str, *, github_repo: str | None = None): @@ -46,19 +51,39 @@ def __init__(self, github_token: str, *, github_repo: str | None = None): repo, ) + def _with_504_retry( + self, request: Callable[[], requests.Response] + ) -> requests.Response: + """Perform one HTTP call, retrying on 504 with a short capped backoff.""" + for attempt in range(_GITHUB_504_MAX_ATTEMPTS): + response = request() + if response.status_code == 504 and attempt < _GITHUB_504_MAX_ATTEMPTS - 1: + delay = _GITHUB_504_BACKOFF_SEC * (attempt + 1) + logger.warning( + "GitHub API 504 Gateway Timeout; waiting %.1fs before retry %d/%d", + delay, + attempt + 2, + _GITHUB_504_MAX_ATTEMPTS, + ) + time.sleep(delay) + continue + response.raise_for_status() + self.response = response + return response + def get(self, path: str, headers: dict[str, str] | None = None) -> dict: - response = self.github.get(f"{self.base_url}{path}", headers=headers) - response.raise_for_status() - self.response = response + response = self._with_504_retry( + lambda: self.github.get(f"{self.base_url}{path}", headers=headers) + ) return response.json() def get_paged(self, path: str, headers: dict[str, str] | None = None) -> list[dict]: url: str | None = f"{self.base_url}{path}" results = [] while url: - response = self.github.get(url, headers=headers) - response.raise_for_status() - self.response = response + response = self._with_504_retry( + lambda u=url: self.github.get(u, headers=headers) + ) data = response.json() logger.debug(f"{url}: {len(data)}") results.extend(data) @@ -80,8 +105,9 @@ def search_issue_items( url: str | None = f"{self.base_url}/search/issues?q={q_param}&per_page=100" results: list[dict] = [] while url: - response = self.github.get(url, headers=headers) - response.raise_for_status() + response = self._with_504_retry( + lambda u=url: self.github.get(u, headers=headers) + ) data = response.json() items = data.get("items") if isinstance(items, list): @@ -99,21 +125,21 @@ def search_issue_items( def patch( self, path: str, data: dict[str, Any], headers: dict[str, str] | None = None ) -> dict: - response = self.github.patch( - f"{self.base_url}{path}", json=data, headers=headers + response = self._with_504_retry( + lambda: self.github.patch( + f"{self.base_url}{path}", json=data, headers=headers + ) ) - response.raise_for_status() - self.response = response return response.json() def post( self, path: str, data: dict[str, Any], headers: dict[str, str] | None = None ) -> dict: - response = self.github.post( - f"{self.base_url}{path}", json=data, headers=headers + response = self._with_504_retry( + lambda: self.github.post( + f"{self.base_url}{path}", json=data, headers=headers + ) ) - response.raise_for_status() - self.response = response return response.json() def delete( @@ -122,11 +148,11 @@ def delete( data: dict[str, Any] | None = None, headers: dict[str, str] | None = None, ) -> dict: - response = self.github.delete( - f"{self.base_url}{path}", json=data, headers=headers + response = self._with_504_retry( + lambda: self.github.delete( + f"{self.base_url}{path}", json=data, headers=headers + ) ) - response.raise_for_status() - self.response = response return response.json() if response.content else {} diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index 137bab6..8655ff3 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -8,6 +8,7 @@ from fastapi import HTTPException from fastapi.testclient import TestClient import pytest +import requests from github_pm.api import ( add_label_to_issue, @@ -1208,6 +1209,59 @@ def test_get_project_endpoint(self): assert "github_repo" in data +class TestConnector504Retry: + """Tests for ``Connector`` retry behavior on GitHub HTTP 504 responses.""" + + def test_get_retries_504_then_succeeds(self): + """504 then 200 triggers one backoff sleep and succeeds.""" + mock_session = Mock() + r504 = Mock() + r504.status_code = 504 + r200 = Mock() + r200.status_code = 200 + r200.raise_for_status = Mock() + r200.json.return_value = {"ok": True} + mock_session.get.side_effect = [r504, r200] + + with ( + patch("github_pm.api.requests.session", return_value=mock_session), + patch("github_pm.api.time.sleep") as mock_sleep, + patch("github_pm.api.context") as mock_context, + ): + mock_context.github_repo = "o/r" + mock_context.github_token = "tok" + conn = Connector("tok", github_repo="o/r") + data = conn.get("/repos/o/r/issues/1") + + assert data == {"ok": True} + assert mock_session.get.call_count == 2 + mock_sleep.assert_called_once() + assert mock_sleep.call_args[0][0] == pytest.approx(1.5) + + def test_get_exhausts_retries_on_persistent_504(self): + """After max attempts, persistent 504 propagates ``HTTPError``.""" + mock_session = Mock() + r504 = Mock() + r504.status_code = 504 + err = requests.HTTPError() + err.response = r504 + r504.raise_for_status = Mock(side_effect=err) + mock_session.get.return_value = r504 + + with ( + patch("github_pm.api.requests.session", return_value=mock_session), + patch("github_pm.api.time.sleep"), + patch("github_pm.api.context") as mock_context, + ): + mock_context.github_repo = "o/r" + mock_context.github_token = "tok" + conn = Connector("tok", github_repo="o/r") + with pytest.raises(requests.HTTPError): + conn.get("/repos/o/r/issues/1") + + assert mock_session.get.call_count == 5 + + class TestSearchIssueItems: """Tests for ``Connector.search_issue_items`` (GitHub search pagination)."""