Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,9 @@ benchmark_results/
*.dat

# CatBoost
catboost_info/
catboost_info/

# Dev artifacts
training_folder/
*.pt
data/*
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [Unreleased]

### Added
- `rectools.fast_transformers` module — standalone transformer-based sequential recommenders that work directly with torch tensors, bypassing the `Dataset`/pandas pipeline. GPU-native sequence building via `build_sequences()` gives ~30x preprocessing speedup over `SASRecDataPreparator` on ML-20M
- `FlatSASRec` network and `FlatSASRecModel` — flat SASRec implementation without the ItemNet hierarchy. Pre-norm transformer encoder with id-embeddings, causal masking, softmax and BCE losses. Integrates with RecTools `ModelBase` for compatibility with the standard `fit`/`recommend` API
- `UniSRec` network and `UniSRecModel` — sequential recommender with pretrained text embeddings (e.g. Qwen) and a learnable PCA/BN adaptor. Three-phase training: (1) SASRec warm-up on ID embeddings, (2) adaptor-only with frozen transformer, (3) full fine-tune on pretrained embeddings. Configurable losses (softmax, BCE, gBCE, sampled_softmax), optimizers (Adam, AdamW), cosine warmup scheduler, early stopping, checkpoint save/load. `UniSRecModel.fit()` accepts raw `(user_ids, item_ids, timestamps)` tensors
- `rank_topk()` utility for batched top-k scoring with CSR-based viewed-item filtering and item whitelist support
- `align_embeddings()` for mapping pretrained embedding matrices to internal item ID order
- `GPUBatchDataset` and `make_dataloader()` — lightweight torch Dataset/DataLoader wrappers for sequence training data
- Configurable FFN blocks in `UniSRec`: `conv1d` (original paper), `linear_gelu`, `linear_relu` with adjustable expansion factor
- Tests for all `fast_transformers` submodules (143 tests)


## [0.18.0] - 21.02.2026

### Added
Expand Down
23 changes: 23 additions & 0 deletions rectools/fast_transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Fast Transformers: flat sequential recommenders without ItemNet hierarchy."""

from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, hash_item_ids, make_dataloader
from .net import FlatSASRec, SASRecBlock
from .ranking import rank_topk
from .unisrec_lightning import UniSRecLightning
from .unisrec_model import UniSRecModel
from .unisrec_net import FeedForward, UniSRec

# Explicit public API of ``rectools.fast_transformers``, grouped by submodule:
# gpu_data helpers, nets, ranking, lightning wrapper, high-level model.
__all__ = [
    "build_sequences",
    "align_embeddings",
    "hash_item_ids",
    "GPUBatchDataset",
    "make_dataloader",
    "FlatSASRec",
    "SASRecBlock",
    "rank_topk",
    "UniSRec",
    "FeedForward",
    "UniSRecLightning",
    "UniSRecModel",
]
151 changes: 151 additions & 0 deletions rectools/fast_transformers/gpu_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""GPU-native sequence building for transformer training. Pure torch, no pandas/numpy."""

import typing as tp

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset


def _splitmix64(x: torch.Tensor) -> torch.Tensor:
"""Vectorized splitmix64 bit-mixer: element-wise int64 hash over a torch tensor.

Standard library hashes (``hash()``, ``hashlib``) operate on scalar Python objects
and cannot be vectorized across GPU tensors. Splitmix64 is pure int64 arithmetic,
so it maps naturally to ``torch.Tensor`` ops and runs on any device.

Reference: https://xorshift.di.unimi.it/splitmix64.c (Vigna, 2015).
"""
x = x.long()
x = (x ^ (x >> 30)) * (-4658895280553007687) # 0xbf58476d1ce4e5b9 as signed int64
x = (x ^ (x >> 27)) * (-7723592293110705685) # 0x94d049bb133111eb as signed int64
return x ^ (x >> 31)


def hash_item_ids(item_ids: torch.Tensor, dict_size: int) -> torch.Tensor:
    """Map arbitrary integer item IDs to [1, dict_size] via splitmix64 hash."""
    mixed = _splitmix64(item_ids)
    # torch.remainder follows the divisor's sign, so even negative hash values land
    # in [0, dict_size); the +1 shift reserves id 0 for padding.
    return mixed % dict_size + 1


def build_sequences(
    user_ids: torch.Tensor,
    item_ids: torch.Tensor,
    timestamps: torch.Tensor,
    max_len: int,
    min_interactions: int = 2,
    device: str = "cuda",
    id_mapping: str = "dense",
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build left-padded next-item training sequences per user, entirely with torch ops.

    Interactions are grouped by user and ordered by timestamp; for each user the
    last ``max_len + 1`` interactions form shifted (input, target) pairs:
    ``x[t]`` holds the item at step t and ``y[t]`` the item at step t + 1.

    Parameters
    ----------
    user_ids, item_ids, timestamps : torch.Tensor
        Aligned 1-d interaction columns (one element per interaction).
    max_len : int
        Width of the output sequences; shorter histories are left-padded with 0,
        longer ones keep only the most recent interactions.
    min_interactions : int
        Users with fewer interactions than this are dropped entirely.
    device : str
        Device the inputs are moved to and the outputs are produced on.
    id_mapping : str
        "dense" maps item ids to 1..n_unique by sorted order; "hash" maps them
        through splitmix64 into [1, n_unique] (collisions are possible).

    Returns
    -------
    x : torch.Tensor of shape (n_kept_users, max_len)
        Input item sequences (0 = padding, left-padded).
    y : torch.Tensor of shape (n_kept_users, max_len)
        Target sequences: ``x`` shifted forward by one interaction.
    unique_items : torch.Tensor
        Sorted original item ids; in "dense" mode the internal id of
        ``unique_items[i]`` is ``i + 1`` (in "hash" mode the internal id comes
        from the hash instead, but the same tensor is returned).
    result_users : torch.Tensor
        Original user id for each output row.

    Raises
    ------
    ValueError
        If ``id_mapping`` is not "dense" or "hash".
    """
    user_ids = user_ids.to(device)
    item_ids = item_ids.to(device)
    timestamps = timestamps.to(device)

    unique_items = torch.unique(item_ids)
    n_unique = len(unique_items)

    if id_mapping == "dense":
        # return_inverse gives each interaction its item's index in the sorted
        # uniques; +1 reserves id 0 for padding.
        _, item_inv = torch.unique(item_ids, return_inverse=True)
        internal_items = item_inv + 1
    elif id_mapping == "hash":
        internal_items = hash_item_ids(item_ids, n_unique)
    else:
        raise ValueError(f"Unknown id_mapping: {id_mapping}. Use 'dense' or 'hash'")

    unique_users, user_inv = torch.unique(user_ids, return_inverse=True)

    # Two stable argsorts == lexicographic sort by (user, timestamp): first order
    # by time, then stably by user so equal users keep their time order.
    order1 = torch.argsort(timestamps, stable=True)
    order2 = torch.argsort(user_inv[order1], stable=True)
    order = order1[order2]

    sorted_user_inv = user_inv[order]
    sorted_items = internal_items[order]

    # Boundaries of each user's contiguous run in the sorted stream.
    changes = torch.where(sorted_user_inv[1:] != sorted_user_inv[:-1])[0] + 1
    starts = torch.cat([torch.tensor([0], device=device), changes])
    ends = torch.cat([changes, torch.tensor([len(sorted_user_inv)], device=device)])
    lengths = ends - starts

    # Drop users with too little history.
    mask = lengths >= min_interactions
    starts = starts[mask]
    ends = ends[mask]
    lengths = lengths[mask]
    n_users = len(starts)

    # Keep at most the last max_len + 1 interactions per user: max_len inputs
    # plus one trailing item that only ever serves as a target.
    capped_lens = torch.clamp(lengths, max=max_len + 1)

    # Number of (input, target) pairs each user actually contributes.
    effective_lens = torch.clamp(capped_lens - 1, min=0)
    total_elements = effective_lens.sum().item()

    x = torch.zeros(n_users, max_len, dtype=torch.long, device=device)
    y = torch.zeros(n_users, max_len, dtype=torch.long, device=device)

    if total_elements > 0:
        # Flat scatter: for every (input, target) pair compute its destination row
        # (user), its offset within the user's kept window, its source position in
        # the sorted stream, and its left-padded destination column.
        user_indices = torch.repeat_interleave(torch.arange(n_users, device=device), effective_lens)
        cumsum = effective_lens.cumsum(0)
        offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave(
            cumsum - effective_lens, effective_lens
        )

        # ends - capped_lens is the start of the kept window for each user.
        x_src = torch.repeat_interleave(ends - capped_lens, effective_lens) + offsets
        y_src = x_src + 1  # target = the interaction right after the input
        # max_len - effective_len shifts each user's pairs to the right edge (left padding).
        col_indices = max_len - torch.repeat_interleave(effective_lens, effective_lens) + offsets

        x[user_indices, col_indices] = sorted_items[x_src]
        y[user_indices, col_indices] = sorted_items[y_src]

    valid_user_indices = torch.where(mask)[0]
    # Avoid an unnecessary gather when no user was filtered out.
    result_users = unique_users[valid_user_indices] if len(valid_user_indices) < len(unique_users) else unique_users

    return x, y, unique_items, result_users


def align_embeddings(
    pretrained: torch.Tensor,
    unique_items: torch.Tensor,
    n_items: int,
    id_mapping: str = "dense",
) -> torch.Tensor:
    """Reorder a pretrained per-item embedding table into internal item-id order.

    Row 0 of the result is reserved for the padding id and stays zero; items whose
    original id falls outside ``pretrained`` also keep zero rows.

    Parameters
    ----------
    pretrained : torch.Tensor
        Embedding table indexed by original item id; shape ``(n_source_items, ...)``
        with any number of trailing dimensions.
    unique_items : torch.Tensor
        Original item ids present in the interactions, as returned by
        ``build_sequences``; in "dense" mode the internal id of
        ``unique_items[i]`` is ``i + 1``.
    n_items : int
        Number of internal item ids (the hash dictionary size in "hash" mode).
    id_mapping : str
        "dense" or "hash" — must match the mapping used in ``build_sequences``.

    Returns
    -------
    torch.Tensor
        CPU tensor of shape ``(n_items + 1, *pretrained.shape[1:])`` with the same
        dtype as ``pretrained``. (Previously the result was hard-coded to float32,
        silently casting other dtypes, supported only 2-d/3-d tables, and raised
        when ``pretrained`` lived on GPU.)

    Raises
    ------
    ValueError
        If ``id_mapping`` is not "dense" or "hash".
    """
    if id_mapping not in ("dense", "hash"):
        raise ValueError(f"Unknown id_mapping: {id_mapping}. Use 'dense' or 'hash'")

    idx = unique_items.long().cpu()
    # Ids outside the pretrained table get no embedding (their rows stay zero).
    valid = (idx >= 0) & (idx < pretrained.shape[0])

    # Keep the source dtype and generalize to any trailing shape.
    aligned = torch.zeros((n_items + 1, *pretrained.shape[1:]), dtype=pretrained.dtype)
    # Gather on the source device, then bring rows to the CPU result so a
    # GPU-resident ``pretrained`` no longer fails on assignment.
    rows = pretrained[idx[valid].to(pretrained.device)].to(aligned.device)

    if id_mapping == "dense":
        aligned[1:][valid] = rows
    else:  # "hash"
        positions = hash_item_ids(idx, n_items)
        # NOTE(review): on hash collisions one item's row overwrites another's;
        # collision handling is the caller's responsibility.
        aligned[positions[valid]] = rows

    return aligned


class GPUBatchDataset(TorchDataset):
    """Thin map-style Dataset over pre-built (x, y) sequence tensors.

    Each sample is a ``{"x": ..., "y": ...}`` dict of one row from each tensor;
    an optional ``transform`` callable may rewrite the whole dict (e.g. to
    attach sampled negatives on the fly).
    """

    def __init__(self, x: torch.Tensor, y: torch.Tensor, transform: tp.Optional[tp.Callable] = None):
        self.x = x
        self.y = y
        self.transform = transform

    def __len__(self) -> int:
        return self.x.shape[0]

    def __getitem__(self, idx: int) -> tp.Dict[str, torch.Tensor]:
        sample = {"x": self.x[idx], "y": self.y[idx]}
        return self.transform(sample) if self.transform else sample


def make_dataloader(
    x: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
    shuffle: bool = True,
    transform: tp.Optional[tp.Callable] = None,
) -> DataLoader:
    """Wrap (x, y) sequence tensors into a single-process DataLoader.

    ``num_workers`` is pinned to 0 — the tensors are presumably already on the
    target device (possibly GPU; TODO confirm), so worker processes would only
    add serialization overhead.
    """
    dataset = GPUBatchDataset(x, y, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)
    return loader
175 changes: 175 additions & 0 deletions rectools/fast_transformers/net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Flat SASRec network: pre-norm transformer encoder with plain id embeddings."""

import typing as tp

import torch
from torch import nn


class SASRecBlock(nn.Module):
    """One pre-norm SASRec encoder layer.

    Layout: LayerNorm -> multi-head self-attention -> residual,
    then LayerNorm -> position-wise FFN (4x expansion, GELU) -> residual.
    """

    def __init__(self, n_factors: int, n_heads: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(n_factors)
        self.mha = nn.MultiheadAttention(n_factors, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(n_factors)
        self.ffn = nn.Sequential(
            nn.Linear(n_factors, n_factors * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(n_factors * 4, n_factors),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: tp.Optional[torch.Tensor] = None,
        key_padding_mask: tp.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and FFN sub-layers with pre-norm residual connections."""
        normed = self.ln1(x)
        attn_out, _ = self.mha(
            normed,
            normed,
            normed,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        after_attn = x + attn_out
        return after_attn + self.ffn(self.ln2(after_attn))


class FlatSASRec(nn.Module):
"""
Flat SASRec: sequential recommender with plain id-embedding table
(no ItemNet hierarchy).

Parameters
----------
n_items : int
Total number of items (excluding padding token 0).
n_factors : int
Embedding / hidden dimension.
n_blocks : int
Number of transformer blocks.
n_heads : int
Number of attention heads.
session_max_len : int
Maximum sequence length.
dropout : float
Dropout rate.
"""

PADDING_IDX = 0

def __init__(
self,
n_items: int,
n_factors: int,
n_blocks: int,
n_heads: int,
session_max_len: int,
dropout: float = 0.1,
) -> None:
super().__init__()
self.n_items = n_items
self.n_factors = n_factors
self.session_max_len = session_max_len

# +1 for padding at index 0
self.item_emb = nn.Embedding(n_items + 1, n_factors, padding_idx=self.PADDING_IDX)
self.pos_emb = nn.Embedding(session_max_len, n_factors)
self.emb_dropout = nn.Dropout(dropout)

self.blocks = nn.ModuleList([SASRecBlock(n_factors, n_heads, dropout) for _ in range(n_blocks)])
self.final_ln = nn.LayerNorm(n_factors)

def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
return torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1)

def encode(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode full sequence.

Parameters
----------
x : LongTensor (B, L)
Item id sequences (0 = padding).

Returns
-------
Tensor (B, L, D)
"""
B, L = x.shape
positions = torch.arange(L, device=x.device).unsqueeze(0)
h = self.item_emb(x) + self.pos_emb(positions)
h = self.emb_dropout(h)

# timeline_mask: zero out padding positions to prevent NaN from attention
timeline_mask = (x != self.PADDING_IDX).unsqueeze(-1).float() # (B, L, 1)
attn_mask = self._causal_mask(L, x.device)
key_padding_mask = x == self.PADDING_IDX

for block in self.blocks:
h = h * timeline_mask
h = block(h, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
h = h * timeline_mask
h = self.final_ln(h)
return h

def encode_last(self, x: torch.Tensor) -> torch.Tensor:
"""
Encode and return only the last non-padding position representation.

Parameters
----------
x : LongTensor (B, L)

Returns
-------
Tensor (B, D)
"""
h = self.encode(x) # (B, L, D)
return h[:, -1, :] # left-padded: last position is always rightmost

def all_item_embeddings(self) -> torch.Tensor:
"""
Return embeddings for all items (1..n_items), excluding padding.

Returns
-------
Tensor (n_items, D)
"""
ids = torch.arange(1, self.n_items + 1, device=self.item_emb.weight.device)
return self.item_emb(ids)

def forward(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Training forward pass.

Parameters
----------
batch : dict
Must contain 'x' (B, L) and 'y' (B, L).
Optionally 'negatives' (B, L, N) for candidate-logits branch.

Returns
-------
logits : Tensor
If negatives present: (B, L, 1 + N) — positive + negative logits.
Otherwise: (B, L, n_items) — full catalog logits.
"""
x = batch["x"] # (B, L)
y = batch["y"] # (B, L)

h = self.encode(x) # (B, L, D)

if "negatives" in batch:
negatives = batch["negatives"] # (B, L, N)
pos_emb = self.item_emb(y).unsqueeze(3) # (B, L, D, 1)
neg_emb = self.item_emb(negatives) # (B, L, N, D)
neg_emb = neg_emb.transpose(2, 3) # (B, L, D, N)
all_emb = torch.cat([pos_emb, neg_emb], dim=3) # (B, L, D, 1+N)
logits = (h.unsqueeze(2) @ all_emb).squeeze(2) # (B, L, 1+N)
# -> shape is (B, L, 1+N) where first column is positive logit
else:
item_embs = self.all_item_embeddings() # (n_items, D)
logits = h @ item_embs.T # (B, L, n_items)
return logits
Loading