[transformers] Add tiled MLP patching via the transformers register_patch_mapping API #1276
Open
kashif wants to merge 9 commits into
Liger already ships a memory-efficient tiled SwiGLU/GEGLU MLP, but there was no way to reach it through the normal patching path. You had to construct the module by hand, so the activation-memory win was hidden from `apply_liger_kernel_to_*` / `use_liger_kernel=True`.

This wires it in using the official transformers monkey patching API, with one mechanism rather than per-model edits:

- `apply_liger_tiled_mlp(model=None, num_shards=None)`:
  - Pre-load (`model=None`): registers the swap via `transformers.monkey_patching.register_patch_mapping`, so `from_pretrained` / `from_config` build the tiled MLP directly. Weights load by name since the gate/up/down projections match.
  - Already loaded (`model=...`): walks the model and patches in place every MLP whose class name is in the mapping, reusing the existing weights. This covers the already-loaded case the official API cannot retrofit.
- `_apply_liger_kernel_to_instance` gains `tiled_mlp` / `tiled_mlp_num_shards`, so `use_liger_kernel=True` and TRL can opt in uniformly for every supported model.
- `LIGER_TILED_MLP_PATCH_MAPPING` lists the 15 models that share the SwiGLU or GEGLU layout (llama, mistral, ministral, qwen2/3, qwen2_vl, smollm3, mllama, pixtral, llama4, olmo2/3, exaone4 -> tiled SwiGLU; gemma/2/3 -> tiled GEGLU).

MoE (mixtral, qwen3_moe), fused gate_up (phi3), Gemma4 and Granite are intentionally excluded since they use bespoke layouts and need their own tiled variants.
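To make the in-place instance path concrete, here is a minimal pure-Python sketch of swapping modules by class name while reusing their projections. It is a toy under stated assumptions, not the code in this PR: `Linear`, `LlamaMLP`, `TiledSwiGLUMLP`, `Block`, `Model`, `TOY_PATCH_MAPPING` and `patch_in_place` are all hypothetical stand-ins, and no torch or transformers API is used.

```python
class Linear:
    """Toy stand-in for a weight-carrying projection layer."""
    def __init__(self, name):
        self.name = name


class LlamaMLP:
    """Toy MLP with the gate/up/down layout described above."""
    def __init__(self):
        self.gate_proj = Linear("gate")
        self.up_proj = Linear("up")
        self.down_proj = Linear("down")


class TiledSwiGLUMLP:
    """Hypothetical tiled replacement that keeps the same projection names."""
    def __init__(self, gate_proj, up_proj, down_proj, num_shards=None):
        # Reuse the existing projection objects instead of allocating new ones.
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj
        self.num_shards = num_shards


class Block:
    def __init__(self):
        self.mlp = LlamaMLP()
        self.attn = object()  # left untouched by the patch


class Model:
    def __init__(self, n_layers):
        self.layers = [Block() for _ in range(n_layers)]


# Hypothetical mapping, keyed by class name like the mapping the PR describes.
TOY_PATCH_MAPPING = {"LlamaMLP": TiledSwiGLUMLP}


def patch_in_place(module, mapping, num_shards=None):
    """Recursively swap children whose class name is in `mapping`.

    Returns the number of modules that were replaced.
    """
    patched = 0
    for attr, child in list(vars(module).items()):
        if isinstance(child, list):
            # Descend into containers of sub-modules.
            for item in child:
                patched += patch_in_place(item, mapping, num_shards)
            continue
        replacement = mapping.get(type(child).__name__)
        if replacement is not None:
            tiled = replacement(
                child.gate_proj, child.up_proj, child.down_proj, num_shards
            )
            setattr(module, attr, tiled)
            patched += 1
        elif hasattr(child, "__dict__"):
            patched += patch_in_place(child, mapping, num_shards)
    return patched
```

Keying the mapping on class names is what lets a single walk cover every architecture with the same layout; excluded layouts simply never match a key and are left as they are.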
Distributed note: tiled MLP is strictly opt-in. Its distributed behaviour is whatever the underlying `LigerTiledMLPFunction` provides, which targets DeepSpeed/ZeRO. FSDP/FSDP2 and plain DDP are not yet handled by the default tiled backward (see open PRs #1128 and #1125), so enabling tiled MLP under those backends can corrupt gradients today. This PR only does the wiring and does not change that kernel.

Tests cover the central instance path, the standalone instance patch, and the pre-load registry across both a silu and a gelu model. Full `test_monkey_patch.py` passes (65).
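For readers new to the technique, the following toy sketch shows why sharding an MLP along the sequence dimension leaves the forward result unchanged: a SwiGLU MLP acts on each position independently. It uses plain Python lists rather than tensors, the function names (`swiglu_row`, `mlp_full`, `mlp_tiled`) are illustrative assumptions, and it does not model the backward pass, which is where the distributed caveat above applies.

```python
import math


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def matvec(weight, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def swiglu_row(x, w_gate, w_up, w_down):
    """SwiGLU MLP for one position: down(silu(gate(x)) * up(x))."""
    gate = [silu(g) for g in matvec(w_gate, x)]
    up = matvec(w_up, x)
    return matvec(w_down, [g * u for g, u in zip(gate, up)])


def mlp_full(rows, w_gate, w_up, w_down):
    """Process every sequence position in one pass."""
    return [swiglu_row(r, w_gate, w_up, w_down) for r in rows]


def mlp_tiled(rows, w_gate, w_up, w_down, num_shards):
    """Process the sequence shard by shard.

    In a tensor implementation only one shard's intermediate activations
    would be live at a time, which is the source of the memory saving.
    """
    shard_size = max(1, math.ceil(len(rows) / num_shards))
    out = []
    for start in range(0, len(rows), shard_size):
        shard = rows[start:start + shard_size]
        out.extend(mlp_full(shard, w_gate, w_up, w_down))
    return out
```

Because each position is computed with exactly the same operations in both paths, the tiled and untiled outputs match element for element in this toy.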