
[transformers] Add tiled MLP patching via the transformers register_patch_mapping API #1276

Open
kashif wants to merge 9 commits into linkedin:main from kashif:llama-tiled-mlp-patch

Conversation

kashif commented Jun 28, 2026

Contributor

Liger already ships a memory-efficient tiled SwiGLU/GEGLU MLP, but there was no way to reach it through the normal patching path. You had to construct the module by hand, so the activation-memory win was hidden from apply_liger_kernel_to_* / use_liger_kernel=True.
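For readers new to the idea, the following is a minimal pure-Python sketch of what tiling an MLP along the sequence dimension means. It is not Liger's kernel: the function names (`swiglu_mlp`, `tiled_swiglu_mlp`) and the list-of-lists weights are hypothetical, and the sketch assumes that tiling splits the tokens into `num_shards` chunks that are processed one after another. In a real tensor implementation the gate/up intermediate would then be materialised for one shard at a time rather than for the whole sequence, which is where the activation-memory saving comes from.

```python
import math


def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def matvec(weight, x):
    # weight is a list of rows; returns weight @ x
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def swiglu_mlp(tokens, w_gate, w_up, w_down):
    """Untiled reference: down(silu(gate(x)) * up(x)) for every token."""
    out = []
    for x in tokens:
        gate = matvec(w_gate, x)
        up = matvec(w_up, x)
        hidden = [silu(g) * u for g, u in zip(gate, up)]
        out.append(matvec(w_down, hidden))
    return out


def tiled_swiglu_mlp(tokens, w_gate, w_up, w_down, num_shards):
    """Hypothetical tiled variant: run the MLP shard by shard, then concatenate."""
    shard_len = max(1, math.ceil(len(tokens) / num_shards))
    out = []
    for start in range(0, len(tokens), shard_len):
        shard = tokens[start:start + shard_len]
        # Only this shard's intermediate activations are alive at once.
        out.extend(swiglu_mlp(shard, w_gate, w_up, w_down))
    return out
```

Because each token is processed independently by the MLP, the tiled and untiled results are identical; only the peak size of the intermediate changes.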

This wires it in through the official transformers monkey-patching API, using one mechanism rather than per-model edits:

  • apply_liger_tiled_mlp(model=None, num_shards=None):
    • pre-load (model=None): registers the swap via transformers.monkey_patching.register_patch_mapping, so from_pretrained / from_config build the tiled MLP directly. Weights load by name since gate/up/down projections match.
    • instance (model=...): walks the model and patches in place every MLP whose class name is in the mapping, reusing existing weights. This covers the already-loaded case the official API cannot retrofit.
  • _apply_liger_kernel_to_instance gains tiled_mlp / tiled_mlp_num_shards, so use_liger_kernel=True and TRL can opt in uniformly for every supported model.
  • LIGER_TILED_MLP_PATCH_MAPPING lists the 15 models that share the SwiGLU or GEGLU layout (llama, mistral, ministral, qwen2/3, qwen2_vl, smollm3, mllama, pixtral, llama4, olmo2/3, exaone4 -> tiled SwiGLU; gemma/2/3 -> tiled GEGLU).
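The instance path described above can be illustrated with a toy, framework-free sketch. Everything here is hypothetical (`Module`, `ToyMLP`, `ToyTiledMLP`, `PATCH_MAPPING`, `patch_instance`); it is neither the PR's code nor the transformers API. It only shows the general pattern of walking a module tree, matching children by class name against a mapping, and swapping them in place while reusing the existing projection objects.

```python
class Module:
    """Tiny stand-in for a framework module: child modules are plain attributes."""

    def named_children(self):
        # Snapshot as a list so attributes can be replaced while iterating.
        return [(k, v) for k, v in vars(self).items() if isinstance(v, Module)]


class Linear(Module):
    def __init__(self, label):
        self.label = label


class ToyMLP(Module):
    def __init__(self):
        self.gate_proj = Linear("gate")
        self.up_proj = Linear("up")
        self.down_proj = Linear("down")


class ToyTiledMLP(Module):
    """Hypothetical tiled replacement with the same projection names."""

    def __init__(self, gate_proj, up_proj, down_proj, num_shards=None):
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj
        self.num_shards = num_shards


class ToyLayer(Module):
    def __init__(self):
        self.mlp = ToyMLP()


class ToyModel(Module):
    def __init__(self, num_layers):
        for i in range(num_layers):
            setattr(self, f"layer_{i}", ToyLayer())


# Original class name -> replacement class (illustrative mapping only).
PATCH_MAPPING = {"ToyMLP": ToyTiledMLP}


def patch_instance(model, mapping, num_shards=None):
    """Recursively swap mapped children in place; return how many were patched."""
    patched = 0
    for name, child in model.named_children():
        target = mapping.get(type(child).__name__)
        if target is not None:
            # Reuse the existing projection objects, so no weights are copied.
            replacement = target(
                child.gate_proj, child.up_proj, child.down_proj, num_shards
            )
            setattr(model, name, replacement)
            patched += 1
        else:
            patched += patch_instance(child, mapping, num_shards)
    return patched
```

Matching on class name rather than on model type is what lets a single mapping cover many architectures in this sketch; a layout that does not expose separate gate, up, and down projections would need its own replacement class.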

MoE models (mixtral, qwen3_moe), fused gate_up (phi3), Gemma4, and Granite are intentionally excluded: they use bespoke layouts and need their own tiled variants.

Distributed note: tiled MLP is strictly opt-in. Its distributed behaviour is whatever the underlying LigerTiledMLPFunction provides, which targets DeepSpeed/ZeRO. FSDP/FSDP2 and plain DDP are not yet handled by the default tiled backward (see open PRs #1128 and #1125), so enabling tiled MLP under those backends can corrupt gradients today. This PR only does the wiring and does not change that kernel.

Tests cover the central instance path, the standalone instance patch, and the pre-load registry, each across both a silu and a gelu model. The full test_monkey_patch.py suite passes (65).

kashif changed the title from "Wire tiled MLP into llama patching via register_patch_mapping" to "Add tiled MLP patching via the transformers register_patch_mapping API" Jun 28, 2026
kashif force-pushed the llama-tiled-mlp-patch branch from 319a911 to c4a41e3 June 28, 2026 20:35
kashif force-pushed the llama-tiled-mlp-patch branch from b34ff77 to 133e322 June 28, 2026 20:59
kashif force-pushed the llama-tiled-mlp-patch branch from 133e322 to c4f7a81 June 28, 2026 21:29

kashif commented Jun 28, 2026

Contributor Author

cc @vaibhavjindal

kashif changed the title from "Add tiled MLP patching via the transformers register_patch_mapping API" to "[transformers] Add tiled MLP patching via the transformers register_patch_mapping API" Jul 1, 2026
