Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions src/liger_kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
except ImportError:
PEFT_AVAILABLE = False

import functools

import torch


Expand Down Expand Up @@ -65,6 +67,91 @@ def is_npu_available() -> bool:
return False


# NVIDIA: CUDA compute capability (major, minor) -> coarse arch family
_NVIDIA_ARCH_BY_CC = {
(7, 0): "volta_turing", # Volta V100
(7, 5): "volta_turing", # Turing T4 / RTX 20xx
(8, 0): "ampere_ada", # Ampere A100
(8, 6): "ampere_ada", # Ampere RTX 30xx / A40
(8, 9): "ampere_ada", # Ada Lovelace RTX 40xx / L4 / L40
(9, 0): "hopper", # H100 / H200
(10, 0): "blackwell", # B100 / B200 / GB200 (sm_100)
(10, 3): "blackwell_ultra", # B300 / GB300 (sm_103)
(12, 0): "blackwell_consumer", # RTX 50xx (sm_120)
}

# AMD: gfx target (gcnArchName) -> coarse arch family
_AMD_ARCH_BY_GFX = {
"gfx908": "cdna", # MI100
"gfx90a": "cdna2", # MI200
"gfx940": "cdna3", # MI300
"gfx941": "cdna3",
"gfx942": "cdna3", # MI300X/MI300A
"gfx1100": "rdna3", # RX 7900
"gfx1101": "rdna3",
"gfx1102": "rdna3",
}
Comment thread
Tcc0403 marked this conversation as resolved.


def _infer_nvidia_arch(device_id: int) -> str:
major, minor = torch.cuda.get_device_capability(device_id)
return _NVIDIA_ARCH_BY_CC.get((major, minor), f"sm_{major}{minor}")


def _infer_amd_arch(device_id: int) -> str:
# gcnArchName looks like "gfx942:sramecc+:xnack-"; keep the gfx target only.
gfx = getattr(torch.cuda.get_device_properties(device_id), "gcnArchName", "").split(":")[0]
return _AMD_ARCH_BY_GFX.get(gfx, gfx or "cuda")


def _infer_xpu_arch(device_id: int) -> str:
name = torch.xpu.get_device_properties(device_id).name.lower()
if any(tag in name for tag in ("max", "pvc", "ponte")):
return "pvc" # Ponte Vecchio / Data Center GPU Max
if any(tag in name for tag in ("arc", "battlemage", "alchemist")):
return "arc"
return "xpu"


def _infer_npu_arch(device_id: int) -> str:
name = torch.npu.get_device_properties(device_id).name.lower()
if "910" in name:
return "ascend910"
if "310" in name:
return "ascend310"
return "npu"
Comment thread
Tcc0403 marked this conversation as resolved.


@functools.lru_cache(maxsize=None)
def infer_device_arch(device_id: int = 0) -> str:
"""
Get a coarse architecture/generation name for the current device.

Returns a family name when detectable, falling back to the device type
from ``infer_device()`` (e.g. ``"cpu"``) otherwise:

- NVIDIA: ``"volta_turing"``, ``"ampere_ada"``, ``"hopper"``, ``"blackwell"``,
``"blackwell_ultra"``, ``"blackwell_consumer"`` (else ``"sm_<major><minor>"``)
- AMD: ``"cdna"``, ``"cdna2"``, ``"cdna3"``, ``"rdna3"`` (else the raw gfx target)
- Intel: ``"pvc"``, ``"arc"`` (else ``"xpu"``)
- Ascend: ``"ascend910"``, ``"ascend310"`` (else ``"npu"``)

The result is cached; call ``infer_device_arch.cache_clear()`` to reset.
"""
device = infer_device()
try:
if device == "cuda":
# ROCm reports as "cuda" in torch; torch.version.hip distinguishes AMD.
return _infer_amd_arch(device_id) if torch.version.hip else _infer_nvidia_arch(device_id)
if device == "xpu":
return _infer_xpu_arch(device_id)
if device == "npu":
return _infer_npu_arch(device_id)
except Exception:
return device
return device


def transformers_version_dispatch(
required_version: str,
before_fn,
Expand Down
165 changes: 165 additions & 0 deletions test/transformers/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from types import SimpleNamespace

import pytest
import torch

import liger_kernel.utils as utils

from liger_kernel.utils import _infer_amd_arch
from liger_kernel.utils import _infer_npu_arch
from liger_kernel.utils import _infer_nvidia_arch
from liger_kernel.utils import _infer_xpu_arch
from liger_kernel.utils import infer_device
from liger_kernel.utils import infer_device_arch


@pytest.fixture(autouse=True)
def clear_arch_cache():
# infer_device_arch is lru_cached; isolate every test from cached results.
infer_device_arch.cache_clear()
yield
infer_device_arch.cache_clear()


# -----------------------------------------------------------------------------
# Per-architecture mapping (helpers, with the underlying torch calls mocked)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
"capability, expected",
[
((7, 0), "volta_turing"),
((7, 5), "volta_turing"),
((8, 0), "ampere_ada"),
((8, 6), "ampere_ada"),
((8, 9), "ampere_ada"),
((9, 0), "hopper"),
((10, 0), "blackwell"),
((10, 3), "blackwell_ultra"),
((12, 0), "blackwell_consumer"),
((8, 7), "sm_87"), # known major, unknown minor -> fallback keeps the minor
((11, 0), "sm_110"), # unknown cc -> sm_<major><minor> fallback
],
)
def test_infer_nvidia_arch(monkeypatch, capability, expected):
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device_id=0: capability)
assert _infer_nvidia_arch(0) == expected


@pytest.mark.parametrize(
"gcn_arch_name, expected",
[
("gfx908", "cdna"),
("gfx90a", "cdna2"),
("gfx942:sramecc+:xnack-", "cdna3"), # decorated name -> gfx target only
("gfx1100", "rdna3"),
("gfx9999", "gfx9999"), # unknown gfx -> raw target
("", "cuda"), # missing arch name -> device-type fallback
],
)
def test_infer_amd_arch(monkeypatch, gcn_arch_name, expected):
monkeypatch.setattr(
torch.cuda,
"get_device_properties",
lambda device_id=0: SimpleNamespace(gcnArchName=gcn_arch_name),
)
assert _infer_amd_arch(0) == expected


@pytest.mark.parametrize(
"name, expected",
[
("Intel(R) Data Center GPU Max 1550", "pvc"),
("Intel(R) Arc(TM) A770 Graphics", "arc"),
("Some Future Intel GPU", "xpu"), # unrecognized -> device-type fallback
],
)
def test_infer_xpu_arch(monkeypatch, name, expected):
monkeypatch.setattr(
torch,
"xpu",
SimpleNamespace(get_device_properties=lambda device_id=0: SimpleNamespace(name=name)),
raising=False,
)
assert _infer_xpu_arch(0) == expected


@pytest.mark.parametrize(
"name, expected",
[
("Ascend910B", "ascend910"),
("Ascend910B3", "ascend910"),
("Ascend310P3", "ascend310"),
("Future NPU", "npu"), # unrecognized -> device-type fallback
],
)
def test_infer_npu_arch(monkeypatch, name, expected):
monkeypatch.setattr(
torch,
"npu",
SimpleNamespace(get_device_properties=lambda device_id=0: SimpleNamespace(name=name)),
raising=False,
)
assert _infer_npu_arch(0) == expected


# -----------------------------------------------------------------------------
# Dispatch in infer_device_arch
# -----------------------------------------------------------------------------
def test_dispatch_nvidia(monkeypatch):
monkeypatch.setattr(utils, "infer_device", lambda: "cuda")
monkeypatch.setattr(torch.version, "hip", None, raising=False)
monkeypatch.setattr(utils, "_infer_nvidia_arch", lambda device_id: "blackwell")
monkeypatch.setattr(utils, "_infer_amd_arch", lambda device_id: pytest.fail("AMD path taken on NVIDIA"))
assert infer_device_arch() == "blackwell"


def test_dispatch_amd(monkeypatch):
# ROCm reports as "cuda"; torch.version.hip routes to the AMD helper.
monkeypatch.setattr(utils, "infer_device", lambda: "cuda")
monkeypatch.setattr(torch.version, "hip", "6.0.0", raising=False)
monkeypatch.setattr(utils, "_infer_amd_arch", lambda device_id: "cdna3")
monkeypatch.setattr(utils, "_infer_nvidia_arch", lambda device_id: pytest.fail("NVIDIA path taken on AMD"))
assert infer_device_arch() == "cdna3"


@pytest.mark.parametrize("device", ["xpu", "npu"])
def test_dispatch_xpu_npu(monkeypatch, device):
monkeypatch.setattr(utils, "infer_device", lambda: device)
monkeypatch.setattr(utils, f"_infer_{device}_arch", lambda device_id: f"{device}-arch")
assert infer_device_arch() == f"{device}-arch"


def test_falls_back_to_device_type_on_error(monkeypatch):
monkeypatch.setattr(utils, "infer_device", lambda: "cuda")
monkeypatch.setattr(torch.version, "hip", None, raising=False)

def boom(device_id):
raise RuntimeError("driver not initialized")

monkeypatch.setattr(utils, "_infer_nvidia_arch", boom)
assert infer_device_arch() == "cuda"


def test_unaccelerated_device_returns_device_type(monkeypatch):
monkeypatch.setattr(utils, "infer_device", lambda: "cpu")
assert infer_device_arch() == "cpu"


# -----------------------------------------------------------------------------
# Real-environment behavior + caching
# -----------------------------------------------------------------------------
def test_returns_nonempty_string_for_current_device():
arch = infer_device_arch()
assert isinstance(arch, str) and arch
# On an unaccelerated host this collapses to the device type.
if infer_device() == "cpu":
assert arch == "cpu"


def test_result_is_cached():
infer_device_arch.cache_clear()
first = infer_device_arch()
second = infer_device_arch()
assert first == second
info = infer_device_arch.cache_info()
assert info.hits >= 1 and info.misses == 1
Loading