
[BUG] LigerFusedLinearCrossEntropyLoss(reduction='none') produces wrong gradients, in some cases grad=0 #968

Description

@Hevans123

🐛 Describe the bug

With reduction='none', LigerFusedLinearCrossEntropyLoss returns wrong gradients when the per-token loss is multiplied by a custom weight_mask before calling backward.

  1. When the target labels contain ignore_index, the gradients are incorrectly returned as all zeros.
    Code: [image]
    Output: [image]

  2. When the loss is multiplied by a custom weight_mask, the gradient is incorrect (inconsistent with LigerCrossEntropyLoss).
    Code: [image]
    Output: [image]
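For reference, the gradient expected in the weight_mask case can be written in closed form. With logits z_i = x_i W^T and per-token loss l_i, the weighted sum over i of m_i * l_i has gradient m_i * (softmax(z_i) - onehot(y_i)) W with respect to x_i, so a 0/1 mask should zero only the masked rows. Below is a minimal sketch in plain PyTorch (no Liger kernels, CPU, float64) that checks autograd against this formula; the variable names and the flattened size N are illustrative assumptions, not code from the report.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, V = 8, 10, 15            # N stands in for B * T (assumed sizes)
ignore_index = -100

x = torch.randn(N, H, dtype=torch.float64, requires_grad=True)
W = torch.randn(V, H, dtype=torch.float64, requires_grad=True)
target = torch.randint(0, V, (N,))
target[:1] = ignore_index      # first token is ignored, as in the report

# 0/1 weight_mask built the same way as in the report
mask = (target != ignore_index).to(x.dtype)

# Unfused reference: explicit matmul, then per-token cross entropy
logits = x @ W.T
per_token = F.cross_entropy(
    logits, target, ignore_index=ignore_index, reduction="none"
)
(per_token * mask).sum().backward()

# Closed-form gradient: mask_i * (softmax(z_i) - onehot(y_i))
with torch.no_grad():
    probs = torch.softmax(logits, dim=-1)
    # clamp so one_hot accepts the ignored rows; the mask zeroes them anyway
    one_hot = F.one_hot(target.clamp(min=0), V).to(x.dtype)
    dlogits = (probs - one_hot) * mask[:, None]
    expected_x_grad = dlogits @ W
    expected_W_grad = dlogits.T @ x

print(torch.allclose(x.grad, expected_x_grad))
print(torch.allclose(W.grad, expected_W_grad))
```

Under these assumptions only the row of x.grad that belongs to the ignored token should be zero; an all-zero input gradient from the fused loss would therefore disagree with both autograd and the closed form.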

The reproduction code is given below.

Reproduce

import torch

from torch.nn import CrossEntropyLoss

from liger_kernel.transformers import LigerCrossEntropyLoss
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# B, T, H, V = 2, 2048, 256, 32000
# B, T, H, V = 2, 1, 10, 15
B, T, H, V = 2, 4, 10, 15

ignore_index = -100
reduction = "none"
device = "cuda"
dtype = torch.float32
scalar = 2

atol, rtol = 1e-8, 1e-5

use_ignore_index = False

if use_ignore_index:
    # Passed
    target_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    target_flce = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    # torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
else:
    # Also passed
    target_ce = LigerCrossEntropyLoss(reduction=reduction)
    target_flce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

_tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
lin_weight = torch.randn(V, H, device=device, dtype=dtype)

# Path 1: explicit matmul, fed to LigerCrossEntropyLoss
_input1 = _tensor.detach().clone().requires_grad_(True)
lin_weight1 = lin_weight.detach().clone().requires_grad_(True)
_input1_mul_weight = _input1 @ lin_weight1.transpose(0, 1)

# Path 2: raw input and weight, fed to LigerFusedLinearCrossEntropyLoss
_input2 = _tensor.detach().clone().requires_grad_(True)
lin_weight2 = lin_weight.detach().clone().requires_grad_(True)


target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

# # Assign some random number of elements as ignore_index
# num_elements_to_assign = torch.randint(
#     1, B * T // 2, (1,)
# ).item()  # Random number of elements to set to ignore_index
# indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]  # Randomly select indices
# target[indices_to_assign] = ignore_index

# target[:B*T//2] = ignore_index
target[:1] = ignore_index

# Per-token losses from both paths (identical call in either branch)
output = target_ce(_input1_mul_weight, target)
output2 = target_flce(lin_weight2, _input2, target)

if use_ignore_index:
    # Mean over the non-ignored tokens
    mask = target != ignore_index
    loss1 = (output * mask).sum() / mask.sum()
    loss2 = (output2 * mask).sum() / mask.sum()
else:
    # Weighted sum with a 0/1 weight_mask
    mask = (target != ignore_index).type_as(output2)
    # mask = torch.randn(B * T, device=device, dtype=dtype)
    print(f'weight_mask:{mask}')
    loss1 = (output * mask).sum()
    loss2 = (output2 * mask).sum()
    # loss1 = output
    # loss2 = output2 * output2

print(f'loss1:{loss1}')
print(f'loss2:{loss2}')

loss1.backward(gradient=torch.ones_like(loss1))
loss2.backward(gradient=torch.ones_like(loss2))

print(f'grad1_sum:{torch.sum(torch.abs(_input1.grad))}')

print(f'grad2_sum:{torch.sum(torch.abs(_input2.grad))}')
# assert torch.allclose(loss1, loss2, atol=atol, rtol=rtol)
# assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)
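The summed absolute gradients printed above show that the two paths differ, but not where. As a hedged sketch, a small helper such as the one below could report which rows disagree and whether one gradient is entirely zero. The name compare_grads and its return format are hypothetical and not part of Liger Kernel; the tolerances default to the atol and rtol values used in the script.

```python
import torch

def compare_grads(ref, test, atol=1e-8, rtol=1e-5):
    """Hypothetical helper: summarise how `test` deviates from `ref`."""
    diff = (ref - test).abs()
    # A row counts as matching only if every element is within tolerance
    row_ok = torch.isclose(ref, test, atol=atol, rtol=rtol).all(dim=-1)
    return {
        "allclose": bool(row_ok.all()),
        "max_abs_diff": diff.max().item(),
        "mismatched_rows": (~row_ok).nonzero(as_tuple=True)[0].tolist(),
        "test_is_all_zero": bool((test == 0).all()),
    }

# Synthetic stand-ins for _input1.grad and _input2.grad
ref = torch.tensor([[0.0, 0.0], [1.0, -1.0], [0.5, 0.25]])
bad = torch.zeros_like(ref)
report = compare_grads(ref, bad)
print(report)
```

In the reproduction script this would be called as compare_grads(_input1.grad, _input2.grad) after both backward calls.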

Versions

Environment Versions:

Python version: 3.11.11
Liger Kernel version: 0.6.4
PyTorch version: 2.6.0+cu126
CUDA version: 12.6
Triton version: 3.2.0
Transformers version: 4.56.0
