Skip to content
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,31 @@ apply_liger_kernel_to_llama(
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
```

#### Memory-efficient tiled MLP (opt-in)

For long sequences you can additionally replace the MLP with Liger's tiled MLP, which recomputes the MLP
forward pass during the backward pass, trading extra compute for a large reduction in activation memory.
It is opt-in and covers every model that shares the SwiGLU or GeGLU layout.

```python
from liger_kernel.transformers import apply_liger_tiled_mlp

# Before loading: register the tiled MLP so it is applied to any supported model on construction
apply_liger_tiled_mlp()
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")

# Or patch an already-loaded model in place
apply_liger_tiled_mlp(model=model, num_shards=4)
```

It is also reachable through the standard instance patching entry point (and therefore the Hugging Face
Trainer `use_liger_kernel` config) via the `tiled_mlp` and `tiled_mlp_num_shards` keyword arguments.

> [!NOTE]
> Distributed support, verified on 2x H100: tiled gradients match a non-tiled reference under both FSDP2
> (`torch.distributed.fsdp.fully_shard`) and DeepSpeed ZeRO-3. Plain DDP is not yet covered (a fix is in
> flight in #1125). Verify gradient correctness for your setup before relying on it.

### 3. Compose Your Own Model

You can take individual [kernels](https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#model-kernels) to compose your models.
Expand Down
20 changes: 20 additions & 0 deletions src/liger_kernel/ops/tiled_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional

import torch
import torch.distributed as dist

from liger_kernel.ops.utils import ensure_contiguous

Expand Down Expand Up @@ -44,6 +45,7 @@ def forward(
ctx.fn = fn
ctx.mlp_module = mlp_module
ctx.shards = shards
ctx.compute_params = compute_params
ctx.save_for_backward(x)

# x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
Expand All @@ -61,6 +63,7 @@ def backward(ctx, *grads) -> tuple:
(x,) = ctx.saved_tensors
mlp_module = ctx.mlp_module
shards = ctx.shards
compute_params = ctx.compute_params

x_requires_grad = x.requires_grad
x = x.detach()
Expand All @@ -78,6 +81,10 @@ def backward(ctx, *grads) -> tuple:

x_shards = list(torch.chunk(x, chunks=shards, dim=0))

# ZeRO-3 partitioned parameters carry a ds_id; collect them once so the per-shard loop only
# flips the ready flag. Parameters on other backends have no ds_id and are left untouched.
ds_params = [p for p in compute_params if hasattr(p, "ds_id")] if compute_params else []

for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)

Expand All @@ -88,6 +95,11 @@ def backward(ctx, *grads) -> tuple:
x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)

# Defer DeepSpeed's reduction until the last shard has accumulated into param.grad; the flag
# is read by ZeRO's hook during each shard's backward, so it must be set per shard.
for param in ds_params:
param.ds_grad_is_ready = i + 1 == len(x_shards)

with torch.enable_grad():
output = fn(mlp_module, x_shard)
torch.autograd.backward(output, incoming_grad_shard)
Expand Down Expand Up @@ -127,6 +139,14 @@ def apply_tiled_mlp(
# Ensure num_shards is at least 1
num_shards = max(1, num_shards)

# All ranks must run the same number of shards: a sharded-parameter backend (DeepSpeed ZeRO-3, FSDP)
# gathers weights inside each shard's recompute, so a rank that runs fewer shards stops participating
# in those collectives and deadlocks the others. Harmonize on the per-rank maximum.
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
num_shards_tensor = torch.tensor(num_shards, device=x.device)
dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
num_shards = int(num_shards_tensor.item())

return LigerTiledMLPFunction.apply(
fn,
mlp_module,
Expand Down
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_tiled_mlp # noqa: F401


# Check if 'transformers' is installed
Expand Down Expand Up @@ -160,6 +161,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_hunyuan_v1_moe",
"apply_liger_kernel_to_deepseek_v4",
"apply_liger_kernel_to_exaone4",
"apply_liger_tiled_mlp",
}

if name in monkey_patch_symbols:
Expand Down Expand Up @@ -251,5 +253,6 @@ def __getattr__(name: str):
"apply_liger_kernel_to_hunyuan_v1_moe",
"apply_liger_kernel_to_deepseek_v4",
"apply_liger_kernel_to_exaone4",
"apply_liger_tiled_mlp",
]
)
77 changes: 76 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from liger_kernel.transformers.swiglu import LigerExperts
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

try:
import peft
Expand Down Expand Up @@ -138,6 +140,69 @@ def _patch_swiglu_module(module, liger_module):
_bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)


def _patch_tiled_mlp_module(module, liger_tiled_module, num_shards=None):
    """
    Patch an already-instantiated MLP module in place with a Liger tiled MLP implementation.

    Args:
        module: The MLP module instance to patch; its existing attributes are left as they are.
        liger_tiled_module: The Liger tiled MLP class whose `_mlp_forward` and `forward` are attached
            to `module`, and whose class name `module._get_name()` reports afterwards.
        num_shards: Value stored on `module.num_shards`. Default is None.
    """
    # NOTE(review): presumably read by the tiled `forward` to pick the shard count - confirm against
    # the LigerTiled*MLP implementation.
    module.num_shards = num_shards

    # Attach the tiled implementation's methods in the same order as before: helper first, then forward.
    for method_name in ("_mlp_forward", "forward"):
        _bind_method_to_module(module, method_name, getattr(liger_tiled_module, method_name))

    _bind_method_to_module(module, "_get_name", lambda self: liger_tiled_module.__name__)


# Transformers MLP class name -> Liger tiled MLP class that replaces it, grouped by activation layout.
# Only MLPs with the separate gate/up/down SwiGLU or GEGLU layout appear here. MoE experts, fused
# gate_up (phi3) and Gemma4 use bespoke layouts and are intentionally left out until tiled variants exist.
LIGER_TILED_MLP_PATCH_MAPPING = {
    # SwiGLU-layout MLPs
    **dict.fromkeys(
        (
            "LlamaMLP",
            "MllamaTextMLP",
            "Llama4TextMLP",
            "MistralMLP",
            "MinistralMLP",
            "PixtralMLP",
            "Qwen2MLP",
            "Qwen3MLP",
            "SmolLM3MLP",
            "Exaone4MLP",
            "Olmo2MLP",
            "Olmo3MLP",
        ),
        LigerTiledSwiGLUMLP,
    ),
    # GEGLU-layout MLPs
    **dict.fromkeys(
        (
            "GemmaMLP",
            "Gemma2MLP",
            "Gemma3MLP",
        ),
        LigerTiledGEGLUMLP,
    ),
}


def apply_liger_tiled_mlp(model=None, num_shards=None, mapping=LIGER_TILED_MLP_PATCH_MAPPING) -> None:
    """
    Apply Liger's memory-efficient tiled MLP to the supported models.

    When `model` is None the replacement is registered through the official transformers patch mapping
    (`register_patch_mapping`) and applied automatically to any model later built with `from_pretrained`
    or `from_config`. When `model` is provided, every already-instantiated MLP whose class name is in
    `mapping` is patched in place, reusing its existing weights.

    Tiled MLP recomputes the MLP forward during the backward to trade compute for a large activation
    memory saving on long sequences. It is opt-in. Gradients have been verified to match a non-tiled
    reference under both FSDP2 (`torch.distributed.fsdp.fully_shard`) and DeepSpeed ZeRO-3, where the
    backward defers ZeRO-3 gradient reduction to the last shard. Plain DDP is not yet covered.

    Args:
        model (PreTrainedModel): An already-loaded model to patch in place. If None, the replacement is
            registered for future model construction instead. Default is None.
        num_shards (Optional[int]): Number of sequence shards used when patching an existing model
            instance. If None, it is computed automatically per forward. It has no effect when `model`
            is None; passing it in that mode emits a warning instead of being silently dropped.
            Default is None.
        mapping (dict): Mapping from transformers MLP class name to the Liger tiled MLP class to use.
            Defaults to all models that share the SwiGLU or GEGLU layout. The mapping is only read,
            never mutated.
    """
    if model is None:
        if num_shards is not None:
            # Registration only swaps classes; there is no instance to attach `num_shards` to, so tell
            # the caller their setting will not take effect rather than dropping it silently.
            import warnings

            warnings.warn(
                "`num_shards` is ignored when `model` is None: registration only swaps the MLP classes. "
                "Pass `model=` to patch an already-loaded model with a fixed number of shards.",
                stacklevel=2,
            )

        from transformers.monkey_patching import register_patch_mapping

        register_patch_mapping(mapping, overwrite=True)
        return

    # Instance mode: match on the class *name* so the lookup works without importing every model class.
    for module in model.modules():
        liger_tiled_module = mapping.get(type(module).__name__)
        if liger_tiled_module is not None:
            _patch_tiled_mlp_module(module, liger_tiled_module, num_shards=num_shards)


def _patch_geglu_module(module):
_bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
_bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
Expand Down Expand Up @@ -3620,12 +3685,19 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
apply_fn(**applicable_kwargs)


def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
def _apply_liger_kernel_to_instance(
model: PreTrainedModel, tiled_mlp: bool = False, tiled_mlp_num_shards: Optional[int] = None, **kwargs
) -> None:
"""
Applies Liger kernels to the provided model instance.

Args:
- model: the model instance to apply Liger kernels to
- tiled_mlp: whether to additionally replace the model's MLP with Liger's tiled MLP for
activation-memory savings on long sequences. Opt-in; see `apply_liger_tiled_mlp` for the
distributed-backend caveats.
- tiled_mlp_num_shards: number of sequence shards used by the tiled MLP, computed automatically
when None.
- kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
"""
model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
Expand All @@ -3648,3 +3720,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
)

apply_fn(model=model, **applicable_kwargs)

if tiled_mlp:
apply_liger_tiled_mlp(model=model, num_shards=tiled_mlp_num_shards)
94 changes: 94 additions & 0 deletions test/transformers/test_monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from liger_kernel.transformers import LigerQwen3MoeSwiGLUMLP
from liger_kernel.transformers import LigerRMSNorm
from liger_kernel.transformers import LigerSwiGLUMLP
from liger_kernel.transformers import LigerTiledGEGLUMLP
from liger_kernel.transformers import LigerTiledSwiGLUMLP
from liger_kernel.transformers import monkey_patch
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.falcon_h1 import lce_forward as falcon_h1_lce_forward
Expand Down Expand Up @@ -513,6 +515,98 @@ def test_apply_liger_kernel_to_instance_for_llama():
pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_llama_with_tiled_mlp():
    """`_apply_liger_kernel_to_instance(tiled_mlp=True)` gives every Llama MLP the tiled SwiGLU forward."""
    from transformers.monkey_patching import clear_patch_mapping

    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.llama.modeling_llama"):
        llama_config = transformers.models.llama.configuration_llama.LlamaConfig(
            dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(llama_config)
        tiled_forward_source = inspect.getsource(LigerTiledSwiGLUMLP.forward)

        # Before patching, no layer uses the tiled forward.
        for decoder_layer in dummy_model_instance.model.layers:
            assert inspect.getsource(decoder_layer.mlp.forward) != tiled_forward_source

        # Every other kernel is switched off so only the tiled MLP patch is under test.
        instance_kwargs = {
            "rope": False,
            "rms_norm": False,
            "fused_linear_cross_entropy": False,
            "swiglu": False,
            "tiled_mlp": True,
            "tiled_mlp_num_shards": 4,
        }

        try:
            _apply_liger_kernel_to_instance(model=dummy_model_instance, **instance_kwargs)

            for decoder_layer in dummy_model_instance.model.layers:
                assert inspect.getsource(decoder_layer.mlp.forward) == tiled_forward_source
                assert decoder_layer.mlp.num_shards == 4
        finally:
            clear_patch_mapping()


def test_apply_liger_tiled_mlp_to_instance():
    """`apply_liger_tiled_mlp(model=...)` patches each Llama MLP in place and records `num_shards`."""
    llama_config = transformers.models.llama.configuration_llama.LlamaConfig(
        dtype=torch.bfloat16,
        rms_norm_eps=1e-5,
        hidden_size=32,
        intermediate_size=64,
        hidden_act="silu",
        num_hidden_layers=2,
    )
    model = AutoModelForCausalLM.from_config(llama_config)
    tiled_forward_source = inspect.getsource(LigerTiledSwiGLUMLP.forward)

    # Freshly built layers still carry the stock forward.
    for decoder_layer in model.model.layers:
        assert inspect.getsource(decoder_layer.mlp.forward) != tiled_forward_source

    monkey_patch.apply_liger_tiled_mlp(model=model, num_shards=4)

    # After patching, each MLP runs the tiled forward with the requested shard count.
    for decoder_layer in model.model.layers:
        assert inspect.getsource(decoder_layer.mlp.forward) == tiled_forward_source
        assert decoder_layer.mlp.num_shards == 4


def test_apply_liger_tiled_mlp_registers_supported_models():
    """With no model given, `apply_liger_tiled_mlp()` makes newly built Llama/Gemma2 models use tiled MLPs."""
    from transformers.monkey_patching import clear_patch_mapping

    # One SwiGLU-layout model and one GEGLU-layout model, both kept tiny.
    llama_config = transformers.models.llama.configuration_llama.LlamaConfig(
        dtype=torch.bfloat16,
        rms_norm_eps=1e-5,
        hidden_size=32,
        intermediate_size=64,
        hidden_act="silu",
        num_hidden_layers=2,
    )
    gemma2_config = transformers.models.gemma2.configuration_gemma2.Gemma2Config(
        dtype=torch.bfloat16,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=1,
        head_dim=16,
    )

    try:
        monkey_patch.apply_liger_tiled_mlp()

        llama_model = AutoModelForCausalLM.from_config(llama_config)
        gemma2_model = AutoModelForCausalLM.from_config(gemma2_config)

        expectations = (
            (llama_model, LigerTiledSwiGLUMLP),
            (gemma2_model, LigerTiledGEGLUMLP),
        )
        for built_model, tiled_cls in expectations:
            for decoder_layer in built_model.model.layers:
                assert isinstance(decoder_layer.mlp, tiled_cls)
    finally:
        # Drop the global registration so later tests build unpatched models.
        clear_patch_mapping()


@pytest.mark.skipif(not is_qwen3_vl_available(), reason="qwen3_vl module not available")
def test_apply_liger_kernel_to_instance_for_qwen3_vl_for_conditional_generation():
# Ensure any monkey patching is cleaned up for subsequent tests
Expand Down
Loading
Loading