Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ defmodule Bumblebee do
"Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
"Gemma3TextForSequenceClassification" =>
{Bumblebee.Text.Gemma3Text, :for_sequence_classification},
"Gemma4ForConditionalGeneration" =>
{Bumblebee.Text.Gemma4Text, :for_causal_language_modeling},
"GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
"GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
"GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
Expand Down Expand Up @@ -273,6 +275,7 @@ defmodule Bumblebee do
"clip" => :clip,
"gemma" => :gemma,
"gemma3_text" => :gemma,
"gemma4" => :gemma,
"gpt_neox" => :gpt_neo_x,
"gpt2" => :gpt2,
"gpt_bigcode" => :gpt2,
Expand Down Expand Up @@ -777,11 +780,15 @@ defmodule Bumblebee do
end

defp params_file_loader_fun(".safetensors", opts) do
opts[:safetensors_reader] || (&Safetensors.read!(&1, lazy: true))
opts[:safetensors_reader] || (&read_safetensors_chunked/1)
end

defp params_file_loader_fun(_, _opts), do: &Bumblebee.Conversion.PyTorchLoader.load!/1

# Default reader for `.safetensors` parameter files.
#
# Despite the name, no chunking is performed here: the file at `path` is
# opened lazily via `Safetensors.read!/2` with `lazy: true`.
# NOTE(review): chunked reads of oversized tensors appear to happen later,
# when the lazy tensors are materialized in Bumblebee.Conversion.PyTorchParams
# — confirm, and consider renaming this function accordingly.
defp read_safetensors_chunked(path), do: Safetensors.read!(path, lazy: true)

@doc """
Featurizes `input` with the given featurizer.

Expand Down
47 changes: 46 additions & 1 deletion lib/bumblebee/conversion/pytorch_params.ex
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do

{value, diff} =
if all_sources_found? do
source_values = Enum.map(source_values, &Nx.to_tensor/1)
source_values = Enum.map(source_values, &lazy_to_tensor/1)
value = builder_fun.(Enum.reverse(source_values))

case verify_param_shape(param_expr, value) do
Expand Down Expand Up @@ -188,6 +188,51 @@ defmodule Bumblebee.Conversion.PyTorchParams do

# Puts `values` in front of the list already stored under `key` in `diff`.
# Raises `KeyError` when `key` is not present (see `Map.update!/3`).
defp prepend(diff, key, values) do
  Map.update!(diff, key, fn existing -> values ++ existing end)
end

# macOS pread(2) returns EINVAL when byte count > INT_MAX (~2 GB).
# For large safetensors tensors, read in 1 GB chunks instead.
@pread_chunk 1_073_741_824

# Materializes a parameter source value as an Nx tensor.
#
# Lazy safetensors entries larger than @pread_chunk bytes are read from disk
# manually, in bounded chunks, and built on Nx.BinaryBackend. Per the
# original author's rationale, this sidesteps both the macOS pread limit
# for very large byte counts and the GPU backend (EMLX) being unable to
# allocate such a tensor in a single call.
defp lazy_to_tensor(%Safetensors.FileTensor{byte_size: size} = file_tensor)
     when size > @pread_chunk do
  path = file_tensor.path
  offset = file_tensor.byte_offset
  shape = file_tensor.shape
  type = file_tensor.type

  read_and_build = fn ->
    File.open!(path, [:read, :raw], fn io ->
      io
      |> pread_chunked(offset, size)
      |> Safetensors.Shared.build_tensor(shape, type)
    end)
  end

  Nx.with_default_backend(Nx.BinaryBackend, read_and_build)
end

# Everything else (including small lazy tensors) takes the regular path.
defp lazy_to_tensor(value) do
  Nx.to_tensor(value)
end

# Reads exactly `size` bytes from `file`, starting at byte `offset`, and
# returns them as a single binary.
#
# Each individual `:file.pread/3` request is capped at @pread_chunk bytes so
# that no single read exceeds the platform limit on the byte count.
#
# Raises `RuntimeError` if the file ends, a read comes back shorter than
# requested, or the OS reports an error before `size` bytes were collected.
# This prevents a truncated file from silently yielding a shorter binary.
defp pread_chunked(file, offset, size) when is_integer(size) and size >= 0 do
  file
  |> do_pread_chunked(offset, size, [])
  |> IO.iodata_to_binary()
end

# Accumulates chunks in reverse (O(1) prepend) and restores order at the end.
defp do_pread_chunked(_file, _offset, 0, acc), do: Enum.reverse(acc)

defp do_pread_chunked(file, offset, remaining, acc) do
  count = min(remaining, @pread_chunk)

  case :file.pread(file, offset, count) do
    {:ok, data} ->
      # No copy when the handle is in binary mode; normalizes list mode.
      chunk = IO.iodata_to_binary(data)

      if byte_size(chunk) == count do
        do_pread_chunked(file, offset + count, remaining - count, [chunk | acc])
      else
        raise "short read at offset #{offset}: expected #{count} bytes, " <>
                "got #{byte_size(chunk)} (file truncated?)"
      end

    :eof ->
      raise "unexpected end of file at offset #{offset}, #{remaining} bytes still expected"

    {:error, reason} ->
      raise "failed to read #{count} bytes at offset #{offset}: #{inspect(reason)}"
  end
end

defp infer_prefixes(layers, pytorch_state, params_mapping) do
# Note: target refers to the parameters we are initializing, while
# source refers to the state we are loading from
Expand Down
43 changes: 39 additions & 4 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,14 @@ defmodule Bumblebee.Layers do
Adds a rotary embedding layer to the network.
"""
def rotary_embedding(query, key, position_ids, attention_mask, size, opts \\ []) do
opts = Keyword.validate!(opts, [:name, :scaling_strategy, max_positions: 2048, base: 10_000])
opts =
Keyword.validate!(opts, [
:name,
:scaling_strategy,
:rotary_dim,
max_positions: 2048,
base: 10_000
])

output =
Axon.layer(
Expand All @@ -1254,15 +1261,24 @@ defmodule Bumblebee.Layers do
max_positions,
size,
base,
scaling_strategy
scaling_strategy,
rotary_dim \\ nil
) do
position = Nx.iota({sequence_length})

range = Nx.iota({div(size, 2)}) |> Nx.multiply(2) |> Nx.divide(size)
{num_freqs, denominator} =
if rotary_dim do
{div(rotary_dim, 2), size}
else
{div(size, 2), size}
end

range = Nx.iota({num_freqs}) |> Nx.multiply(2) |> Nx.divide(denominator)

case scaling_strategy do
%{type: :linear, factor: factor} ->
inv_frequency = inv_frequency(base, range)
inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
position = Nx.divide(position, factor)
positions_cos_sin(position, inv_frequency)

Expand All @@ -1273,6 +1289,7 @@ defmodule Bumblebee.Layers do
|> Nx.pow(size / (size - 2))

inv_frequency = inv_frequency(base, range)
inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
positions_cos_sin(position, inv_frequency)

%{
Expand Down Expand Up @@ -1300,6 +1317,7 @@ defmodule Bumblebee.Layers do
end

inv_frequency = inv_frequency(base, range) |> Nx.divide(factor)
inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
{cos, sin} = positions_cos_sin(position, inv_frequency)
{Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)}

Expand All @@ -1321,14 +1339,29 @@ defmodule Bumblebee.Layers do
original_max_positions
)

inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
positions_cos_sin(position, inv_frequency)

_other ->
inv_frequency = inv_frequency(base, range)
inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim)
positions_cos_sin(position, inv_frequency)
end
end

# Without a `rotary_dim` the inverse frequencies are used unchanged.
defp maybe_pad_inv_frequency(inv_frequency, _target_size, nil), do: inv_frequency

# With a `rotary_dim`, right-pads `inv_frequency` along axis 0 with zeros
# (of the same type) until it has `target_size` entries. A tensor that is
# already at least that long is returned as is.
# NOTE(review): presumably the zero frequencies leave the extra dimensions
# unrotated — confirm against positions_cos_sin.
defp maybe_pad_inv_frequency(inv_frequency, target_size, _rotary_dim) do
  case target_size - Nx.axis_size(inv_frequency, 0) do
    missing when missing > 0 ->
      zeros =
        0.0
        |> Nx.tensor(type: Nx.type(inv_frequency))
        |> Nx.broadcast({missing})

      Nx.concatenate([inv_frequency, zeros])

    _nothing_missing ->
      inv_frequency
  end
end

defnp llama3_inv_frequency(
inv_frequency,
factor,
Expand Down Expand Up @@ -1381,6 +1414,7 @@ defmodule Bumblebee.Layers do
keyword!(opts, [
:size,
:scaling_strategy,
:rotary_dim,
mode: :inference,
max_positions: 2048,
base: 10_000
Expand All @@ -1400,7 +1434,8 @@ defmodule Bumblebee.Layers do
opts[:max_positions],
opts[:size],
opts[:base],
opts[:scaling_strategy]
opts[:scaling_strategy],
opts[:rotary_dim]
)

position_ids = Nx.as_type(position_ids, :s64)
Expand Down
31 changes: 27 additions & 4 deletions lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ defmodule Bumblebee.Layers.Transformer do
:hidden_size,
:ffn,
:kernel_initializer,
:attention_head_size,
:dropout_rate,
:attention_dropout_rate,
:query_use_bias,
Expand All @@ -63,7 +62,8 @@ defmodule Bumblebee.Layers.Transformer do
:block_type,
:attention_scale,
:query_norm,
:key_norm
:key_norm,
:value_norm
]

opts =
Expand All @@ -75,6 +75,7 @@ defmodule Bumblebee.Layers.Transformer do
:num_blocks,
:rotary_embedding,
:attention_window_size,
:attention_head_size,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: nil,
Expand All @@ -97,6 +98,7 @@ defmodule Bumblebee.Layers.Transformer do
cache = opts[:cache]
rotary_embedding = opts[:rotary_embedding]
attention_window_size = opts[:attention_window_size]
attention_head_size = opts[:attention_head_size]

block_opts = Keyword.take(opts, block_opts_keys)

Expand Down Expand Up @@ -142,12 +144,20 @@ defmodule Bumblebee.Layers.Transformer do
size -> size
end

block_attention_head_size =
case attention_head_size do
nil -> nil
fun when is_function(fun, 1) -> fun.(idx)
size -> size
end

{hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
block(
state.hidden_state,
[
attention_mask: attention_mask,
attention_head_mask: block_attention_head_mask,
attention_head_size: block_attention_head_size,
attention_relative_bias: attention_relative_bias,
cross_hidden_state: cross_hidden_state,
cross_attention_mask: cross_attention_mask,
Expand Down Expand Up @@ -354,7 +364,8 @@ defmodule Bumblebee.Layers.Transformer do
attention_scale: nil,
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
key_norm: nil,
value_norm: nil
])

name = opts[:name]
Expand Down Expand Up @@ -386,6 +397,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]
value_norm = opts[:value_norm]

ffn_fun =
case ffn do
Expand Down Expand Up @@ -446,6 +458,7 @@ defmodule Bumblebee.Layers.Transformer do
rotary_embedding: rotary_embedding,
query_norm: query_norm,
key_norm: key_norm,
value_norm: value_norm,
name: join(name, "self_attention")
)

Expand Down Expand Up @@ -772,7 +785,8 @@ defmodule Bumblebee.Layers.Transformer do
output_use_bias: true,
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
key_norm: nil,
value_norm: nil
])

attention_mask = opts[:attention_mask]
Expand All @@ -788,6 +802,7 @@ defmodule Bumblebee.Layers.Transformer do
causal = opts[:causal]
attention_window_size = opts[:attention_window_size]
attention_scale = opts[:attention_scale]
value_norm = opts[:value_norm]
dropout_rate = opts[:dropout_rate]
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
Expand Down Expand Up @@ -846,6 +861,13 @@ defmodule Bumblebee.Layers.Transformer do
key
end

value =
if value_norm do
value_norm.(value, join(name, "value_norm"))
else
value
end

{query, key} =
case rotary_embedding do
opts when is_list(opts) ->
Expand All @@ -856,6 +878,7 @@ defmodule Bumblebee.Layers.Transformer do
:position_ids,
:max_positions,
:scaling_strategy,
:rotary_dim,
base: 10_000,
percentage: 1.0
])
Expand Down
Loading