
LigerORPOTrainer doesn't work when not using FSDP #1229

Description

@Ola-Vish

🐛 Describe the bug

Hi all,
I tried using LigerORPOTrainer, but it doesn't seem to work when running on a single GPU.
I tried running the ORPO trainer example, but it fails on assert isinstance(wrapper_module, FullyShardedDataParallel).
As far as I can tell from the code, this is because _FSDPForwardRedirection is used without checking whether the model is a FullyShardedDataParallel instance, as can be seen here.

Please let me know if I understood the issue correctly, perhaps I can submit a fix 😃
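
To illustrate the kind of guard described above, here is a minimal sketch. It does not use the real Liger Kernel or PyTorch classes: ShardedWrapper and redirect_through_wrapper are hypothetical stand-ins for FullyShardedDataParallel and _FSDPForwardRedirection, and the call signature is an assumption rather than the library's actual API. The idea is only to redirect when the model really is wrapped, and to call the loss function directly otherwise.

```python
from typing import Any, Callable


class ShardedWrapper:
    """Hypothetical stand-in for FullyShardedDataParallel (not the real class)."""

    def __init__(self, module: Any) -> None:
        self.module = module


def redirect_through_wrapper(
    wrapper: Any, method: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    """Hypothetical stand-in for the redirection helper.

    Like the helper in the traceback, it asserts that it received a wrapper.
    """
    assert isinstance(wrapper, ShardedWrapper)
    return method(*args, **kwargs)


def call_loss(
    model: Any, method: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    """Redirect only when the model is wrapped; otherwise call the method directly."""
    if isinstance(model, ShardedWrapper):
        return redirect_through_wrapper(model, method, *args, **kwargs)
    return method(*args, **kwargs)


def toy_loss(x: float, y: float) -> float:
    """Placeholder for the real loss function."""
    return x - y


plain_model = object()                       # single-GPU case: no wrapper
wrapped_model = ShardedWrapper(plain_model)  # sharded case

single_gpu_result = call_loss(plain_model, toy_loss, 3.0, 1.0)
sharded_result = call_loss(wrapped_model, toy_loss, 3.0, 1.0)
```

With a guard like this, both the unwrapped and the wrapped model reach the loss function, whereas calling the redirection helper unconditionally fails the assertion for the unwrapped model.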

Reproduce

Link to reproduction code - here

I tried running this code (which is basically the same as this example code) with "python liger_orpo_trainer.py", and I get the following error:

0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ubuntu/training_code/liger_orpo_trainer.py", line 43, in <module>
    main()
  File "/home/ubuntu/training_code/liger_orpo_trainer.py", line 39, in main
    trainer.train()
  File "/home/ubuntu/training_code/sota_venv/lib/python3.12/site-packages/transformers/trainer.py", line 1424, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/training_code/sota_venv/lib/python3.12/site-packages/transformers/trainer.py", line 1506, in _inner_training_loop
    self._run_epoch(
  File "/home/ubuntu/training_code/sota_venv/lib/python3.12/site-packages/transformers/trainer.py", line 1734, in _run_epoch
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/training_code/sota_venv/lib/python3.12/site-packages/transformers/trainer.py", line 1906, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/training_code/sota_venv/lib/python3.12/site-packages/trl/trainer/orpo_trainer.py", line 873, in compute_loss
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/training_code/sota_venv/lib/python3.12/site-packages/liger_kernel/transformers/trainer/orpo_trainer.py", line 101, in get_batch_loss_metrics
    loss, aux_outputs = self.concatenated_forward(model, batch)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/training_code/sota_venv/lib/python3.12/site-packages/liger_kernel/transformers/trainer/orpo_trainer.py", line 77, in concatenated_forward
    orpo_loss, aux_outputs = _FSDPForwardRedirection()(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/training_code/sota_venv/lib/python3.12/site-packages/liger_kernel/transformers/fsdp.py", line 40, in __call__
    assert isinstance(wrapper_module, FullyShardedDataParallel)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
0%| | 0/100 [00:00<?, ?it/s]

Versions

Environment Report:

Operating System: Linux-6.17.0-1013-gcp-x86_64-with-glibc2.39
Python version: 3.12.3
Liger Kernel version: 0.8.0
PyTorch version: 2.10.0+cu128
CUDA version: 12.8
HIP(ROCm) version: Not available
Triton version: 3.6.0
Transformers version: 5.5.0
XPU version: XPU Not Available
