diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 6237dc61..d10a56e1 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -142,6 +142,8 @@ defmodule Bumblebee do
     "Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
     "Gemma3TextForSequenceClassification" =>
       {Bumblebee.Text.Gemma3Text, :for_sequence_classification},
+    "Gemma4ForConditionalGeneration" =>
+      {Bumblebee.Text.Gemma4Text, :for_causal_language_modeling},
     "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
     "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
     "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
@@ -273,6 +275,7 @@ defmodule Bumblebee do
     "clip" => :clip,
     "gemma" => :gemma,
     "gemma3_text" => :gemma,
+    "gemma4" => :gemma,
     "gpt_neox" => :gpt_neo_x,
     "gpt2" => :gpt2,
     "gpt_bigcode" => :gpt2,
@@ -777,11 +780,17 @@ defmodule Bumblebee do
   end
 
   defp params_file_loader_fun(".safetensors", opts) do
-    opts[:safetensors_reader] || (&Safetensors.read!(&1, lazy: true))
+    opts[:safetensors_reader] || (&read_safetensors_lazy/1)
   end
 
   defp params_file_loader_fun(_, _opts), do: &Bumblebee.Conversion.PyTorchLoader.load!/1
 
+  # Default safetensors reader. Tensors are loaded lazily here; oversized
+  # ones are read in chunks later, when the params are materialized.
+  defp read_safetensors_lazy(path) do
+    Safetensors.read!(path, lazy: true)
+  end
+
   @doc """
   Featurizes `input` with the given featurizer.
 
diff --git a/lib/bumblebee/conversion/pytorch_params.ex b/lib/bumblebee/conversion/pytorch_params.ex
index c17fef85..0576fe8b 100644
--- a/lib/bumblebee/conversion/pytorch_params.ex
+++ b/lib/bumblebee/conversion/pytorch_params.ex
@@ -150,7 +150,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
 
             {value, diff} =
               if all_sources_found? do
-                source_values = Enum.map(source_values, &Nx.to_tensor/1)
+                source_values = Enum.map(source_values, &lazy_to_tensor/1)
                 value = builder_fun.(Enum.reverse(source_values))
 
                 case verify_param_shape(param_expr, value) do
@@ -188,6 +188,52 @@ defmodule Bumblebee.Conversion.PyTorchParams do
 
   defp prepend(diff, key, values), do: Map.update!(diff, key, &(values ++ &1))
 
+  # macOS pread(2) returns EINVAL when byte count > INT_MAX (~2 GB).
+  # For large safetensors tensors, read in 1 GB chunks instead.
+  @pread_chunk 1_073_741_824
+
+  # Materializes a source value as an Nx tensor. Lazy safetensors entries
+  # larger than @pread_chunk are read manually in chunks; everything else
+  # goes through Nx.to_tensor/1.
+  defp lazy_to_tensor(%Safetensors.FileTensor{byte_size: size} = ft)
+       when size > @pread_chunk do
+    # Force BinaryBackend: the GPU backend (EMLX) cannot allocate tensors
+    # this large in a single call, and we must also avoid the macOS pread
+    # INT_MAX limit by reading in chunks.
+    Nx.with_default_backend(Nx.BinaryBackend, fn ->
+      File.open!(ft.path, [:read, :raw], fn file ->
+        binary = pread_chunked(file, ft.byte_offset, ft.byte_size)
+        Safetensors.Shared.build_tensor(binary, ft.shape, ft.type)
+      end)
+    end)
+  end
+
+  defp lazy_to_tensor(value), do: Nx.to_tensor(value)
+
+  # Reads exactly `size` bytes starting at `offset`, issuing pread calls of
+  # at most @pread_chunk bytes each. Raises if the file ends early or the
+  # read fails, instead of silently returning a truncated binary.
+  defp pread_chunked(file, offset, size), do: pread_chunked(file, offset, size, [])
+
+  defp pread_chunked(_file, _offset, 0, chunks) do
+    # Chunks are accumulated in reverse to keep each step O(1)
+    chunks |> Enum.reverse() |> IO.iodata_to_binary()
+  end
+
+  defp pread_chunked(file, offset, remaining, chunks) do
+    case :file.pread(file, offset, min(remaining, @pread_chunk)) do
+      {:ok, chunk} when byte_size(chunk) > 0 ->
+        # A read can come back shorter than requested when the file ends
+        # early, so advance by what was actually read
+        read = byte_size(chunk)
+        pread_chunked(file, offset + read, remaining - read, [chunk | chunks])
+
+      other ->
+        raise "failed to read #{remaining} remaining bytes at offset #{offset}, " <>
+                "got: #{inspect(other)}"
+    end
+  end
+
   defp infer_prefixes(layers, pytorch_state, params_mapping) do
     # Note: target refers to the parameters we are initializing, while
     # source refers to the state we are loading from
diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex
index 27d81990..49d333c6 100644
--- a/lib/bumblebee/layers.ex
+++
b/lib/bumblebee/layers.ex @@ -1237,7 +1237,14 @@ defmodule Bumblebee.Layers do Adds a rotary embedding layer to the network. """ def rotary_embedding(query, key, position_ids, attention_mask, size, opts \\ []) do - opts = Keyword.validate!(opts, [:name, :scaling_strategy, max_positions: 2048, base: 10_000]) + opts = + Keyword.validate!(opts, [ + :name, + :scaling_strategy, + :rotary_dim, + max_positions: 2048, + base: 10_000 + ]) output = Axon.layer( @@ -1254,15 +1261,24 @@ defmodule Bumblebee.Layers do max_positions, size, base, - scaling_strategy + scaling_strategy, + rotary_dim \\ nil ) do position = Nx.iota({sequence_length}) - range = Nx.iota({div(size, 2)}) |> Nx.multiply(2) |> Nx.divide(size) + {num_freqs, denominator} = + if rotary_dim do + {div(rotary_dim, 2), size} + else + {div(size, 2), size} + end + + range = Nx.iota({num_freqs}) |> Nx.multiply(2) |> Nx.divide(denominator) case scaling_strategy do %{type: :linear, factor: factor} -> inv_frequency = inv_frequency(base, range) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) position = Nx.divide(position, factor) positions_cos_sin(position, inv_frequency) @@ -1273,6 +1289,7 @@ defmodule Bumblebee.Layers do |> Nx.pow(size / (size - 2)) inv_frequency = inv_frequency(base, range) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) positions_cos_sin(position, inv_frequency) %{ @@ -1300,6 +1317,7 @@ defmodule Bumblebee.Layers do end inv_frequency = inv_frequency(base, range) |> Nx.divide(factor) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) {cos, sin} = positions_cos_sin(position, inv_frequency) {Nx.multiply(cos, cos_sin_factor), Nx.multiply(sin, cos_sin_factor)} @@ -1321,14 +1339,29 @@ defmodule Bumblebee.Layers do original_max_positions ) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) positions_cos_sin(position, inv_frequency) _other -> inv_frequency = inv_frequency(base, 
range) + inv_frequency = maybe_pad_inv_frequency(inv_frequency, div(size, 2), rotary_dim) positions_cos_sin(position, inv_frequency) end end + defp maybe_pad_inv_frequency(inv_frequency, _target_size, nil), do: inv_frequency + + defp maybe_pad_inv_frequency(inv_frequency, target_size, _rotary_dim) do + pad_size = target_size - Nx.axis_size(inv_frequency, 0) + + if pad_size > 0 do + padding = Nx.broadcast(Nx.tensor(0.0, type: Nx.type(inv_frequency)), {pad_size}) + Nx.concatenate([inv_frequency, padding]) + else + inv_frequency + end + end + defnp llama3_inv_frequency( inv_frequency, factor, @@ -1381,6 +1414,7 @@ defmodule Bumblebee.Layers do keyword!(opts, [ :size, :scaling_strategy, + :rotary_dim, mode: :inference, max_positions: 2048, base: 10_000 @@ -1400,7 +1434,8 @@ defmodule Bumblebee.Layers do opts[:max_positions], opts[:size], opts[:base], - opts[:scaling_strategy] + opts[:scaling_strategy], + opts[:rotary_dim] ) position_ids = Nx.as_type(position_ids, :s64) diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex index 188b0ffe..8289f2a9 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -52,7 +52,6 @@ defmodule Bumblebee.Layers.Transformer do :hidden_size, :ffn, :kernel_initializer, - :attention_head_size, :dropout_rate, :attention_dropout_rate, :query_use_bias, @@ -63,7 +62,8 @@ defmodule Bumblebee.Layers.Transformer do :block_type, :attention_scale, :query_norm, - :key_norm + :key_norm, + :value_norm ] opts = @@ -75,6 +75,7 @@ defmodule Bumblebee.Layers.Transformer do :num_blocks, :rotary_embedding, :attention_window_size, + :attention_head_size, attention_mask: Layers.none(), attention_head_mask: Layers.none(), attention_relative_bias: nil, @@ -97,6 +98,7 @@ defmodule Bumblebee.Layers.Transformer do cache = opts[:cache] rotary_embedding = opts[:rotary_embedding] attention_window_size = opts[:attention_window_size] + attention_head_size = opts[:attention_head_size] block_opts = 
Keyword.take(opts, block_opts_keys) @@ -142,12 +144,20 @@ defmodule Bumblebee.Layers.Transformer do size -> size end + block_attention_head_size = + case attention_head_size do + nil -> nil + fun when is_function(fun, 1) -> fun.(idx) + size -> size + end + {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} = block( state.hidden_state, [ attention_mask: attention_mask, attention_head_mask: block_attention_head_mask, + attention_head_size: block_attention_head_size, attention_relative_bias: attention_relative_bias, cross_hidden_state: cross_hidden_state, cross_attention_mask: cross_attention_mask, @@ -354,7 +364,8 @@ defmodule Bumblebee.Layers.Transformer do attention_scale: nil, rotary_embedding: nil, query_norm: nil, - key_norm: nil + key_norm: nil, + value_norm: nil ]) name = opts[:name] @@ -386,6 +397,7 @@ defmodule Bumblebee.Layers.Transformer do rotary_embedding = opts[:rotary_embedding] query_norm = opts[:query_norm] key_norm = opts[:key_norm] + value_norm = opts[:value_norm] ffn_fun = case ffn do @@ -446,6 +458,7 @@ defmodule Bumblebee.Layers.Transformer do rotary_embedding: rotary_embedding, query_norm: query_norm, key_norm: key_norm, + value_norm: value_norm, name: join(name, "self_attention") ) @@ -772,7 +785,8 @@ defmodule Bumblebee.Layers.Transformer do output_use_bias: true, rotary_embedding: nil, query_norm: nil, - key_norm: nil + key_norm: nil, + value_norm: nil ]) attention_mask = opts[:attention_mask] @@ -788,6 +802,7 @@ defmodule Bumblebee.Layers.Transformer do causal = opts[:causal] attention_window_size = opts[:attention_window_size] attention_scale = opts[:attention_scale] + value_norm = opts[:value_norm] dropout_rate = opts[:dropout_rate] rotary_embedding = opts[:rotary_embedding] query_norm = opts[:query_norm] @@ -846,6 +861,13 @@ defmodule Bumblebee.Layers.Transformer do key end + value = + if value_norm do + value_norm.(value, join(name, "value_norm")) + else + value + end + {query, key} = case rotary_embedding 
do opts when is_list(opts) -> @@ -856,6 +878,7 @@ defmodule Bumblebee.Layers.Transformer do :position_ids, :max_positions, :scaling_strategy, + :rotary_dim, base: 10_000, percentage: 1.0 ]) diff --git a/lib/bumblebee/text/gemma4_text.ex b/lib/bumblebee/text/gemma4_text.ex new file mode 100644 index 00000000..0b593884 --- /dev/null +++ b/lib/bumblebee/text/gemma4_text.ex @@ -0,0 +1,815 @@ +defmodule Bumblebee.Text.Gemma4Text do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 262_144, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 131_072, + doc: """ + the maximum sequence length that this model can process + """ + ], + hidden_size: [ + default: 2304, + doc: "the dimensionality of hidden layers" + ], + intermediate_size: [ + default: 9216, + doc: "the dimensionality of intermediate layers" + ], + attention_head_size: [ + default: 256, + doc: + "the size of the key, value, and query projection per attention head for sliding attention layers" + ], + global_attention_head_size: [ + default: 512, + doc: + "the size of the key, value, and query projection per attention head for global (full) attention layers" + ], + num_blocks: [ + default: 30, + doc: "the number of Transformer blocks in the model" + ], + num_attention_heads: [ + default: 8, + doc: "the number of attention heads for each attention layer in the model" + ], + num_key_value_heads: [ + default: 4, + doc: "the number of key value heads for each attention layer in the model" + ], + num_global_key_value_heads: [ + default: nil, + doc: """ + the number of key value heads for global (full) attention layers. + If nil, defaults to num_key_value_heads. 
+ """ + ], + activation: [ + default: :gelu_approx_tanh, + doc: "the activation function" + ], + rotary_embedding_base: [ + default: 1_000_000, + doc: "base for computing rotary embedding frequency for global attention layers" + ], + rotary_embedding_base_local: [ + default: 10_000, + doc: "base for computing rotary embedding frequency for local (sliding) attention layers" + ], + partial_rotary_factor: [ + default: 1.0, + doc: """ + the fraction of head dimensions to apply rotary embeddings to in global attention layers. + Sliding attention layers always use full rotation (1.0). + Extracted from rope_parameters.full_attention.partial_rotary_factor. + """ + ], + use_attention_bias: [ + default: false, + doc: + "whether or not to use bias in the query, key, value, and output projections in attention layers" + ], + layer_norm_epsilon: [ + default: 1.0e-6, + doc: "the epsilon used by RMS normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + attention_window_size: [ + default: 512, + doc: + "window size for both sides of the sliding attention window (used for `:sliding_attention` layers)" + ], + layer_types: [ + default: nil, + doc: """ + a list of layer types for each layer, where each element is either `:sliding_attention` + (local attention with sliding window) or `:full_attention` (global attention) + """ + ], + tie_word_embeddings: [ + default: true, + doc: "whether to tie input and output embedding weights" + ], + final_logit_softcapping: [ + default: nil, + doc: """ + if set, logits are capped using `tanh` to this value before the final softmax. + This prevents extreme logit values from dominating the output distribution. + Logits are scaled by tanh(logit / cap) * cap. + """ + ], + hidden_size_per_layer_input: [ + default: 256, + doc: """ + the dimensionality of the per-layer input embeddings (PLE). 
Each transformer layer + gets its own small embedding that is added to the main hidden state. + """ + ], + vocab_size_per_layer_input: [ + default: 262_144, + doc: "the vocabulary size for per-layer input embeddings" + ], + num_kv_shared_layers: [ + default: 0, + doc: """ + the number of consecutive decoder layers that share the same key-value projections. + A value of 0 means no sharing (each layer has independent KV projections). + """ + ], + use_double_wide_mlp: [ + default: false, + doc: """ + whether to use a double-width MLP with fused gate and up projections. + When true, the gate and up projections are doubled in size. + """ + ] + ] ++ + Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0) + + @moduledoc """ + Gemma 4 model family (text backbone). + + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_causal_language_modeling, + :for_sequence_classification + ] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def init_cache(spec, batch_size, max_length, _inputs) do + layer_types = spec.layer_types || generate_layer_types(spec.num_blocks) + + blocks = + Enum.map(0..(spec.num_blocks - 1), fn idx -> + head_size = + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> spec.global_attention_head_size + :sliding_attention -> spec.attention_head_size + end + + shape = 
{batch_size, max_length, spec.num_attention_heads, head_size} + zeros = Nx.broadcast(0.0, shape) + self_attention = %{key: zeros, value: zeros} + + %{self_attention: self_attention, cross_attention: %Axon.None{}} + end) + |> List.to_tuple() + + offset = Nx.tensor(0) + attention_mask = Nx.broadcast(0, {batch_size, max_length}) + %{blocks: blocks, offset: offset, attention_mask: attention_mask} + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output", + use_bias: false + ) + + pooled_logits = + Layers.if_present inputs["input_ids"] do + Axon.layer( + fn logits, input_ids, _opts -> + indices = + input_ids + |> Nx.not_equal(spec.pad_token_id) + |> Nx.sum(axes: [-1]) + |> Nx.subtract(1) + |> Nx.as_type({:s, 64}) + + Bumblebee.Utils.Nx.batched_take(logits, indices) + end, + [logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, 
spec.hidden_size}
+
+    attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads}
+
+    Bumblebee.Utils.Model.inputs_to_map([
+      Axon.input("input_ids", optional: true, shape: shape),
+      Axon.input("attention_mask", optional: true, shape: shape),
+      Axon.input("position_ids", optional: true, shape: shape),
+      Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape),
+      Axon.input("input_embeddings", optional: true, shape: hidden_shape),
+      Axon.input("cache", optional: true)
+    ])
+  end
+
+  defp core(inputs, spec) do
+    embeddings =
+      embedder(
+        inputs["input_ids"],
+        inputs["input_embeddings"],
+        spec,
+        name: "embedder"
+      )
+
+    position_ids =
+      Layers.default inputs["position_ids"] do
+        Layers.default_position_ids(embeddings)
+      end
+
+    # PLE: compute per-layer inputs; a nil or 0 width disables it (0 is truthy in Elixir)
+    per_layer_inputs =
+      if spec.hidden_size_per_layer_input not in [nil, 0] do
+        compute_per_layer_inputs(inputs["input_ids"], embeddings, spec)
+      else
+        nil
+      end
+
+    decoder_outputs =
+      decoder(
+        embeddings,
+        position_ids,
+        inputs["attention_mask"],
+        inputs["attention_head_mask"],
+        inputs["cache"],
+        spec,
+        per_layer_inputs: per_layer_inputs,
+        name: "decoder"
+      )
+
+    hidden_state =
+      Layers.rms_norm(decoder_outputs.hidden_state,
+        name: "output_norm",
+        shift: 1.0,
+        epsilon: spec.layer_norm_epsilon,
+        upcast: :all
+      )
+
+    %{
+      hidden_state: hidden_state,
+      hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state),
+      attentions: decoder_outputs.attentions,
+      cache: decoder_outputs.cache
+    }
+  end
+
+  defp compute_per_layer_inputs(input_ids, embeddings, spec) do
+    ple_dim = spec.hidden_size_per_layer_input
+    num_layers = spec.num_blocks
+    total_ple_dim = num_layers * ple_dim
+
+    # Token-identity: lookup in per-layer embedding table
+    token_identity =
+      Axon.embedding(input_ids, spec.vocab_size_per_layer_input, total_ple_dim,
+        name: "embedder.token_embedding_per_layer"
+      )
+      |> Axon.nx(fn x ->
+        # Scale by sqrt(ple_dim)
+        scale = Nx.tensor(ple_dim, type:
Nx.type(x)) |> Nx.sqrt() + x = Nx.multiply(x, scale) + # Reshape from [B, S, num_layers * ple_dim] to [B, S, num_layers, ple_dim] + shape = Nx.shape(x) + batch = elem(shape, 0) + seq = elem(shape, 1) + Nx.reshape(x, {batch, seq, num_layers, ple_dim}) + end) + + # Context-aware: project main embeddings, reshape, then norm. + # The norm weight has shape [ple_dim] because HuggingFace applies it + # after reshaping to [B, S, num_layers, ple_dim]. + context_aware = + Axon.dense(embeddings, total_ple_dim, + name: "per_layer_model_projection", + use_bias: false + ) + |> Axon.nx(fn x -> + scale = Nx.divide(1.0, Nx.sqrt(Nx.tensor(spec.hidden_size, type: Nx.type(x)))) + x = Nx.multiply(x, scale) + shape = Nx.shape(x) + batch = elem(shape, 0) + seq = elem(shape, 1) + Nx.reshape(x, {batch, seq, num_layers, ple_dim}) + end) + |> Layers.rms_norm( + name: "per_layer_projection_norm", + shift: 1.0, + epsilon: spec.layer_norm_epsilon + ) + + # Combine: (token_identity + context_aware) * (1/sqrt(2)) + Axon.layer( + fn token_id, context, _opts -> + inv_sqrt2 = Nx.tensor(1.0 / :math.sqrt(2), type: Nx.type(token_id)) + Nx.multiply(Nx.add(token_id, context), inv_sqrt2) + end, + [token_identity, context_aware] + ) + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + |> Axon.nx(fn x -> + normalization_factor = + spec.hidden_size + |> Nx.tensor(type: Nx.type(x)) + |> Nx.sqrt() + + Nx.multiply(x, normalization_factor) + end) + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + per_layer_inputs = opts[:per_layer_inputs] + + query_norm = &Layers.rms_norm(&1, shift: 1.0, epsilon: spec.layer_norm_epsilon, name: &2) + key_norm = &Layers.rms_norm(&1, shift: 1.0, 
epsilon: spec.layer_norm_epsilon, name: &2) + + value_norm = fn value, _name -> + Axon.nx(value, fn x -> + variance = Nx.mean(Nx.multiply(x, x), axes: [-1], keep_axes: true) + Nx.multiply(x, Nx.rsqrt(Nx.add(variance, spec.layer_norm_epsilon))) + end) + end + + layer_types = spec.layer_types || generate_layer_types(spec.num_blocks) + + attention_window_size = fn idx -> + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> nil + :sliding_attention -> {spec.attention_window_size, spec.attention_window_size} + end + end + + rotary_embedding = fn idx -> + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + percentage: 1.0, + rotary_dim: trunc(spec.global_attention_head_size * spec.partial_rotary_factor) + ] + + :sliding_attention -> + [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base_local, + percentage: 1.0 + ] + end + end + + attention_scale = 1.0 + non_double_wide_count = spec.num_blocks - spec.num_kv_shared_layers + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + attention_head_size: fn idx -> + case Enum.at(layer_types, idx, :sliding_attention) do + :full_attention -> spec.global_attention_head_size + :sliding_attention -> spec.attention_head_size + end + end, + attention_scale: attention_scale, + kernel_initializer: kernel_initializer(spec), + layer_norm: + &Layers.rms_norm(&1, + shift: 1.0, + name: &2, + epsilon: spec.layer_norm_epsilon, + upcast: :all + ), + ffn: fn hidden_state, ffn_name -> + idx = + ffn_name + |> String.split(".") + |> Enum.at(2) + |> String.to_integer() + + intermediate_size = + if 
spec.use_double_wide_mlp and idx >= non_double_wide_count do + spec.intermediate_size * 2 + else + spec.intermediate_size + end + + gated_ffn(hidden_state, intermediate_size, spec.hidden_size, + name: ffn_name, + activation: spec.activation + ) + end, + block_type: &gemma4_block_impl(&1, &2, &3, spec, per_layer_inputs), + causal: true, + rotary_embedding: rotary_embedding, + attention_window_size: attention_window_size, + query_norm: query_norm, + key_norm: key_norm, + value_norm: value_norm, + query_use_bias: spec.use_attention_bias, + key_use_bias: spec.use_attention_bias, + value_use_bias: spec.use_attention_bias, + output_use_bias: spec.use_attention_bias, + name: join(name, "blocks") + ) + end + + # Custom block implementation for Gemma 4's normalization structure: + # - Post-attention norm BEFORE residual add + # - Pre/post FFN norms + # - PLE applied AFTER attention+MLP as a third residual block + defp gemma4_block_impl(hidden_state, steps, name, spec, per_layer_inputs) do + # 1. Self-attention with pre/post norms + shortcut = hidden_state + + {hidden_state, attention_info} = + hidden_state + |> steps.self_attention_norm.() + |> steps.self_attention.() + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "post_attention_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = Axon.add(shortcut, hidden_state) + + # 2. FFN with pre/post norms + shortcut = hidden_state + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "pre_ffn_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = steps.ffn.(hidden_state) + + hidden_state = + Layers.rms_norm(hidden_state, + shift: 1.0, + name: join(name, "post_ffn_norm"), + epsilon: spec.layer_norm_epsilon, + upcast: :all + ) + + hidden_state = Axon.add(shortcut, hidden_state) + + # 3. 
PLE: gate hidden_state down to PLE dim, multiply with PLE signal, project back up + hidden_state = + if per_layer_inputs do + idx = + name + |> String.split(".") + |> Enum.at(2) + |> String.to_integer() + + ple_slice = + Axon.nx(per_layer_inputs, fn x -> + x[[.., .., idx, ..]] + end) + + shortcut_ple = hidden_state + + # Gate: project hidden_state DOWN to PLE dimension, then activation + gated = + Axon.dense(hidden_state, spec.hidden_size_per_layer_input, + name: join(name, "per_layer_input_gate"), + use_bias: false + ) + + gated = Layers.activation(gated, spec.activation) + + # Element-wise multiply with PLE signal + gated = Axon.multiply(gated, ple_slice) + + # Project back UP to hidden dimension + gated = + Axon.dense(gated, spec.hidden_size, + name: join(name, "per_layer_projection"), + use_bias: false + ) + + # Normalize + gated = + Layers.rms_norm(gated, + shift: 1.0, + name: join(name, "post_per_layer_input_norm"), + epsilon: spec.layer_norm_epsilon + ) + + Axon.add(shortcut_ple, gated) + else + hidden_state + end + + # 4. 
Layer scalar: multiply output by per-layer learned scalar + hidden_state = + Axon.layer( + fn hidden_state, scalar, _opts -> + Nx.multiply(hidden_state, Nx.reshape(scalar, {})) + end, + [ + hidden_state, + Axon.param("layer_scalar", fn _ -> {1} end, initializer: Axon.Initializers.ones()) + ], + name: join(name, "layer_scalar_op") + ) + + # Handle cross-attention (required by block interface but not used by Gemma 4) + {_hidden_state, cross_attention_info} = + steps.cross_attention_maybe.(hidden_state, fn _ -> + raise "cross attention not supported" + end) + + {hidden_state, attention_info, cross_attention_info} + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + logits = + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + + cap = spec.final_logit_softcapping + + if cap do + Axon.nx(logits, fn x -> + x + |> Nx.divide(cap) + |> Nx.tanh() + |> Nx.multiply(cap) + end) + else + logits + end + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + # Generate layer_types fallback: every 5th layer uses full attention + defp generate_layer_types(num_blocks) do + Enum.map(0..(num_blocks - 1), fn i -> + if rem(i + 1, 5) == 0 do + :full_attention + else + :sliding_attention + end + end) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import 
Shared.Converters + + data = data["text_config"] || data + rope_params = data["rope_parameters"] || %{} + full_attention_rope = rope_params["full_attention"] || %{} + sliding_attention_rope = rope_params["sliding_attention"] || %{} + + data = + data + |> Map.put_new("rope_theta", full_attention_rope["rope_theta"] || 1_000_000) + |> Map.put_new("rope_local_base_freq", sliding_attention_rope["rope_theta"] || 10_000) + |> Map.put_new( + "partial_rotary_factor", + full_attention_rope["partial_rotary_factor"] || 1.0 + ) + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, + num_global_key_value_heads: {"num_global_key_value_heads", optional(number())}, + attention_head_size: {"head_dim", number()}, + global_attention_head_size: {"global_head_dim", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_activation", activation()}, + use_attention_bias: {"attention_bias", boolean()}, + rotary_embedding_base: {"rope_theta", number()}, + rotary_embedding_base_local: {"rope_local_base_freq", number()}, + partial_rotary_factor: {"partial_rotary_factor", number()}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()}, + attention_window_size: {"sliding_window", optional(number())}, + layer_types: + {"layer_types", + list( + mapping(%{ + "sliding_attention" => :sliding_attention, + "full_attention" => :full_attention + }) + )}, + tie_word_embeddings: {"tie_word_embeddings", boolean()}, + final_logit_softcapping: {"final_logit_softcapping", optional(number())}, + hidden_size_per_layer_input: {"hidden_size_per_layer_input", number()}, + vocab_size_per_layer_input: {"vocab_size_per_layer_input", number()}, + num_kv_shared_layers: 
{"num_kv_shared_layers", number()}, + use_double_wide_mlp: {"use_double_wide_mlp", boolean()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(spec) do + %{ + "embedder.token_embedding" => "model.language_model.embed_tokens", + # PLE global weights + "embedder.token_embedding_per_layer" => "model.language_model.embed_tokens_per_layer", + "per_layer_model_projection" => "model.language_model.per_layer_model_projection", + "per_layer_projection_norm" => "model.language_model.per_layer_projection_norm", + # PLE per-layer weights + "decoder.blocks.{n}.per_layer_input_gate" => + "model.language_model.layers.{n}.per_layer_input_gate", + "decoder.blocks.{n}.per_layer_projection" => + "model.language_model.layers.{n}.per_layer_projection", + "decoder.blocks.{n}.post_per_layer_input_norm" => + "model.language_model.layers.{n}.post_per_layer_input_norm", + # Layer scalar + "decoder.blocks.{n}.layer_scalar_op" => "model.language_model.layers.{n}", + "decoder.blocks.{n}.layer_scalar_op.layer_scalar" => + "model.language_model.layers.{n}.layer_scalar", + # Attention projections + "decoder.blocks.{n}.self_attention.query" => + "model.language_model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => + "model.language_model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => + "model.language_model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => + "model.language_model.layers.{n}.self_attn.o_proj", + # QK-norm + "decoder.blocks.{n}.self_attention.query_norm" => + "model.language_model.layers.{n}.self_attn.q_norm", + "decoder.blocks.{n}.self_attention.key_norm" => + "model.language_model.layers.{n}.self_attn.k_norm", + # Layer norms + "decoder.blocks.{n}.self_attention_norm" => + "model.language_model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.post_attention_norm" => + 
"model.language_model.layers.{n}.post_attention_layernorm", + # FFN layer norms + "decoder.blocks.{n}.pre_ffn_norm" => + "model.language_model.layers.{n}.pre_feedforward_layernorm", + "decoder.blocks.{n}.post_ffn_norm" => + "model.language_model.layers.{n}.post_feedforward_layernorm", + # FFN projections + "decoder.blocks.{n}.ffn.gate" => "model.language_model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.language_model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.language_model.layers.{n}.mlp.down_proj", + # Output + "output_norm" => "model.language_model.norm", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, + do: "model.language_model.embed_tokens", + else: "lm_head" + ), + "sequence_classification_head.output" => "score" + } + end + end +end