Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/12449.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Retry committing container registry image rescan results after transient database serialization conflicts.
230 changes: 144 additions & 86 deletions src/ai/backend/manager/container_registry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from collections.abc import AsyncIterator, Mapping, Sequence
from contextlib import asynccontextmanager as actxmgr
from contextvars import ContextVar
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Final,
cast,
Expand All @@ -30,14 +30,18 @@
)
from ai.backend.common.docker import login as registry_login
from ai.backend.common.exception import (
BackendAIError,
InvalidImageName,
InvalidImageTag,
ProjectMismatchWithCanonical,
)
from ai.backend.common.json import read_json
from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy
from ai.backend.common.resilience.resilience import Resilience
from ai.backend.common.types import SlotName, SSLContextType
from ai.backend.common.utils import join_non_empty
from ai.backend.logging import BraceStyleAdapter
from ai.backend.logging.types import LogLevel
from ai.backend.manager.data.image.types import (
ImageData,
ImageStatus,
Expand All @@ -47,6 +51,7 @@
from ai.backend.manager.data.permission.types import RBACElementRef
from ai.backend.manager.defs import INTRINSIC_SLOTS_MIN
from ai.backend.manager.exceptions import ScanImageError, ScanTagError
from ai.backend.manager.models.container_registry import ContainerRegistryRow
from ai.backend.manager.models.image import ImageIdentifier, ImageRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from ai.backend.manager.repositories.base.rbac.entity_creator import (
Expand All @@ -61,9 +66,30 @@
"progress_reporter", default=None
)
all_updates: ContextVar[dict[ImageIdentifier, dict[str, Any]]] = ContextVar("all_updates")
# Retry policy applied to `_commit_rescan_result_once` (one attempt at committing
# the rescan results): a failed attempt is retried up to 10 times with a fixed
# 0.1 delay between attempts (presumably seconds -- confirm against `RetryArgs`).
# `BackendAIError` is listed as non-retryable, so it is expected to propagate
# without being retried.
# NOTE(review): no allow-list of retryable exceptions is given here, so errors
# other than `BackendAIError` that are not transient (e.g. a `KeyError` from a
# malformed update dict) presumably get retried as well -- confirm this is intended.
commit_rescan_result_resilience = Resilience(
policies=[
RetryPolicy(
RetryArgs(
max_retries=10,
retry_delay=0.1,
backoff_strategy=BackoffStrategy.FIXED,
non_retryable_exceptions=(BackendAIError,),
)
),
]
)


@dataclass(frozen=True)
class _RescanProgressEvent:
"""
A rescan progress message buffered during the commit transaction and emitted
only after a successful commit, so retries do not duplicate logs or over-count
the progress reporter.
"""

if TYPE_CHECKING:
from ai.backend.manager.models.container_registry import ContainerRegistryRow
# Severity used when the buffered message is finally logged by
# `commit_rescan_result`: `LogLevel.WARNING` goes to `log.warning`, any other
# level goes to `log.info`.
level: LogLevel
# Human-readable progress text; passed to the logger and, when a progress
# reporter is set, to `reporter.update(1, message=...)`.
message: str


class BaseContainerRegistry(metaclass=ABCMeta):
Expand Down Expand Up @@ -176,97 +202,129 @@ def _determine_additional_image_scopes(

async def commit_rescan_result(self) -> list[ImageData]:
scanned_images: list[ImageData] = []
_all_updates = all_updates.get()
if not _all_updates:
original_updates = all_updates.get()
if not original_updates:
log.info("No images found in registry {0}", self.registry_url)
else:
image_identifiers = [(k.canonical, k.architecture) for k in _all_updates.keys()]
async with self.db.begin_session() as session:
existing_images = await session.scalars(
sa.select(ImageRow).where(
sa.func.ROW(ImageRow.name, ImageRow.architecture).in_(image_identifiers),
)
image_identifiers = [(k.canonical, k.architecture) for k in original_updates.keys()]

scanned_images, progress_events = await self._commit_rescan_result_once(
original_updates,
image_identifiers,
)
for event in progress_events:
match event.level:
case LogLevel.WARNING:
log.warning(event.message)
case _:
log.info(event.message)
if (reporter := progress_reporter.get()) is not None:
await reporter.update(1, message=event.message)
return scanned_images

@commit_rescan_result_resilience.apply()
async def _commit_rescan_result_once(
self,
original_updates: dict[ImageIdentifier, dict[str, Any]],
image_identifiers: list[tuple[str, str]],
) -> tuple[list[ImageData], list[_RescanProgressEvent]]:
scanned_images: list[ImageData] = []
progress_events: list[_RescanProgressEvent] = []
pending_updates = dict(original_updates)

async with self.db.begin_session() as session:
# Serialize commits per registry. Some images may not exist yet, so locking
# only ImageRow cannot prevent concurrent insert races for the same registry.
await session.execute(
sa.select(ContainerRegistryRow.id)
.where(ContainerRegistryRow.id == self.registry_info.id)
.with_for_update()
)
existing_images = await session.scalars(
sa.select(ImageRow).where(
sa.func.ROW(ImageRow.name, ImageRow.architecture).in_(image_identifiers),
)
is_local = self.registry_name == "local"

for image_row in existing_images:
image_ref = image_row.image_ref
update_key = ImageIdentifier(image_ref.canonical, image_ref.architecture)
if update := _all_updates.pop(update_key, None):
image_row.config_digest = update["config_digest"]
image_row.size_bytes = update["size_bytes"]
image_row.accelerators = update.get("accels")
image_row.labels = update["labels"]
image_row.is_local = is_local
scanned_images.append(image_row.to_dataclass())

if image_row.status == ImageStatus.DELETED:
image_row.status = ImageStatus.ALIVE

progress_msg = f"Restored deleted image - {image_ref.canonical}/{image_ref.architecture} ({update['config_digest']})"
log.info(progress_msg)

if (reporter := progress_reporter.get()) is not None:
await reporter.update(1, message=progress_msg)

rbac_creators: list[RBACEntityCreator[ImageRow]] = []
for image_identifier, update in _all_updates.items():
try:
parsed_img = ImageRef.from_image_str(
image_identifier.canonical,
self.registry_info.project,
self.registry_info.registry_name,
is_local=is_local,
)
is_local = self.registry_name == "local"

for image_row in existing_images:
image_ref = image_row.image_ref
update_key = ImageIdentifier(image_ref.canonical, image_ref.architecture)
if update := pending_updates.pop(update_key, None):
image_row.config_digest = update["config_digest"]
image_row.size_bytes = update["size_bytes"]
image_row.accelerators = update.get("accels")
image_row.labels = update["labels"]
image_row.is_local = is_local
scanned_images.append(image_row.to_dataclass())

if image_row.status == ImageStatus.DELETED:
image_row.status = ImageStatus.ALIVE

progress_msg = (
f"Restored deleted image - {image_ref.canonical}/"
f"{image_ref.architecture} ({update['config_digest']})"
)
except (ProjectMismatchWithCanonical, ValueError) as e:
skip_reason = str(e)
progress_msg = f"Skipped image - {image_identifier.canonical}/{image_identifier.architecture} ({skip_reason})"
log.warning(progress_msg)
if (reporter := progress_reporter.get()) is not None:
await reporter.update(1, message=progress_msg)
continue

rbac_creators.append(
RBACEntityCreator(
spec=ImageRowCreatorSpec(
name=parsed_img.canonical,
project=self.registry_info.project,
architecture=image_identifier.architecture,
registry_id=self.registry_info.id,
is_local=is_local,
registry=parsed_img.registry,
image=join_non_empty(parsed_img.project, parsed_img.name, sep="/"),
tag=parsed_img.tag,
config_digest=update["config_digest"],
size_bytes=update["size_bytes"],
type=ImageType.COMPUTE,
accelerators=update.get("accels"),
labels=update["labels"],
status=ImageStatus.ALIVE,
),
scope_ref=RBACElementRef(
RBACElementType.CONTAINER_REGISTRY,
str(self.registry_info.id),
),
additional_scope_refs=self._determine_additional_image_scopes(
update["labels"]
),
element_type=RBACElementType.IMAGE,
),
)
progress_events.append(_RescanProgressEvent(LogLevel.INFO, progress_msg))

bulk_result = await execute_rbac_entity_creators(session, rbac_creators)
for row in bulk_result.rows:
scanned_images.append(row.to_dataclass())
rbac_creators: list[RBACEntityCreator[ImageRow]] = []
for image_identifier, update in pending_updates.items():
try:
parsed_img = ImageRef.from_image_str(
image_identifier.canonical,
self.registry_info.project,
self.registry_info.registry_name,
is_local=is_local,
)
except (ProjectMismatchWithCanonical, ValueError) as e:
skip_reason = str(e)
progress_msg = (
f"Updated image - {row.name}/{row.architecture} ({row.config_digest})"
f"Skipped image - {image_identifier.canonical}/"
f"{image_identifier.architecture} ({skip_reason})"
)
log.info(progress_msg)
if (reporter := progress_reporter.get()) is not None:
await reporter.update(1, message=progress_msg)
progress_events.append(_RescanProgressEvent(LogLevel.WARNING, progress_msg))
continue

await session.flush()
return scanned_images
rbac_creators.append(
RBACEntityCreator(
spec=ImageRowCreatorSpec(
name=parsed_img.canonical,
project=self.registry_info.project,
architecture=image_identifier.architecture,
registry_id=self.registry_info.id,
is_local=is_local,
registry=parsed_img.registry,
image=join_non_empty(parsed_img.project, parsed_img.name, sep="/"),
tag=parsed_img.tag,
config_digest=update["config_digest"],
size_bytes=update["size_bytes"],
type=ImageType.COMPUTE,
accelerators=update.get("accels"),
labels=update["labels"],
status=ImageStatus.ALIVE,
),
scope_ref=RBACElementRef(
RBACElementType.CONTAINER_REGISTRY,
str(self.registry_info.id),
),
additional_scope_refs=self._determine_additional_image_scopes(
update["labels"]
),
element_type=RBACElementType.IMAGE,
),
)

bulk_result = await execute_rbac_entity_creators(session, rbac_creators)
for row in bulk_result.rows:
scanned_images.append(row.to_dataclass())
progress_msg = (
f"Updated image - {row.name}/{row.architecture} ({row.config_digest})"
)
progress_events.append(_RescanProgressEvent(LogLevel.INFO, progress_msg))

await session.flush()

return scanned_images, progress_events

async def scan_single_ref(self, image: str) -> RescanImagesResult:
all_updates_token = all_updates.set({})
Expand Down
Loading
Loading