feat: add native dense Qwen3 layer primitives #118

Closed
hfiguera wants to merge 2 commits into elixir-nx:main from hfiguera:qwen3-native-dense-layer-primitives

@hfiguera

Contributor

This adds native MLX-backed primitives for dense Qwen3 layer execution.

The main goal is to make the Qwen3 attention/MLP/layer pieces available as
tested EMLX primitives before adding higher-level EMLXAxon loading or text
generation code on top.

This includes:

  • Qwen3 RoPE + KV cache attention
  • dense Qwen3 MLP
  • dense attention residual projection
  • dense attention block
  • dense transformer layer

The dense projection layout is {in, out}, so the native path can use direct
matmul(input, weight) calls. The Qwen3-specific KV cache helpers use
{B, N_kv, T_max, D} caches.
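
As a rough illustration of why the {in, out} layout matters, here is a minimal sketch in plain Python, with nested lists standing in for tensors. The helpers `matmul` and `transpose` are hypothetical and are not EMLX or MLX functions. With an {in, out} weight the projection is a single matmul, while an {out, in} weight would need a transpose first.

```python
# Illustrative only: plain Python lists stand in for tensors.
# matmul and transpose are hypothetical helpers, not EMLX or MLX functions.

def matmul(x, w):
    """Multiply x ({T, in}) by w ({in, out}), giving {T, out}."""
    n_in, n_out = len(w), len(w[0])
    if any(len(row) != n_in for row in x):
        raise ValueError("input width must equal the weight's first dimension")
    return [[sum(row[i] * w[i][j] for i in range(n_in)) for j in range(n_out)]
            for row in x]

def transpose(w):
    """Swap the two dimensions of a 2-D nested list."""
    return [list(col) for col in zip(*w)]

x = [[1.0, 2.0, 3.0]]                            # {T, in} = {1, 3}
w_in_out = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # {in, out} = {3, 2}

# {in, out} layout: the weight is used as stored.
y = matmul(x, w_in_out)

# {out, in} layout: the same projection needs a transpose before the matmul.
w_out_in = transpose(w_in_out)
y_via_transpose = matmul(x, transpose(w_out_in))
```

Both paths compute the same projection; the {in, out} layout just avoids the extra transpose step.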

I kept this PR limited to the lower-level EMLX side. It does not add
safetensors loading or a text generation API; those are easier to review as
follow-up changes once the primitives are in place.

Tests cover numerical comparisons against pure Nx references for:

  • GQA KV attention
  • Qwen3 MLP
  • attention residual projection
  • attention block
  • full transformer layer
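
For readers unfamiliar with GQA, the kind of reference such a comparison needs can be sketched as follows. This is plain Python rather than the Nx reference used in the tests, and `gqa_attention` is a hypothetical name. It assumes a single batch element, no attention mask, and a query-head count that is a multiple of the KV-head count.

```python
import math

# Illustrative reference for grouped-query attention (GQA).
# Assumptions: one batch element, no mask, n_q divisible by n_kv.
def gqa_attention(q, k, v):
    """q: {N_q, T_q, D}; k and v: {N_kv, T_k, D}. Returns {N_q, T_q, D}."""
    n_q, n_kv = len(q), len(k)
    if n_q % n_kv != 0:
        raise ValueError("number of query heads must be a multiple of KV heads")
    group = n_q // n_kv          # query heads that share one KV head
    d = len(q[0][0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for h in range(n_q):
        kh, vh = k[h // group], v[h // group]   # shared KV head for this query head
        head_out = []
        for q_vec in q[h]:
            scores = [scale * sum(a * b for a, b in zip(q_vec, k_vec)) for k_vec in kh]
            m = max(scores)                      # subtract the max for stability
            exps = [math.exp(s - m) for s in scores]
            z = sum(exps)
            probs = [e / z for e in exps]
            head_out.append([sum(p * v_vec[j] for p, v_vec in zip(probs, vh))
                             for j in range(d)])
        out.append(head_out)
    return out
```

The key point is the `h // group` index: consecutive query heads read the same cached K/V head, so the cache only stores N_kv heads.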

There is also coverage for nonzero cache offsets, batch size 2, cache updates,
projection shape validation, and rejection of cache-length overflow before
graph construction.
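
The cache-offset and overflow behavior can be pictured with a small sketch. It is plain Python with hypothetical helper names (`new_cache`, `check_cache_fits`, `write_cache`), not the native implementation. It assumes a {B, N_kv, T_max, D} cache and validates the write range up front, before any data is touched.

```python
# Illustrative only: nested lists stand in for a {B, N_kv, T_max, D} cache.
# All helper names are hypothetical.

def new_cache(b, n_kv, t_max, d):
    """Allocate a zero-filled {B, N_kv, T_max, D} cache."""
    return [[[[0.0] * d for _ in range(t_max)] for _ in range(n_kv)]
            for _ in range(b)]

def check_cache_fits(offset, t_new, t_max):
    """Reject writes that would run past T_max, before doing any work."""
    if offset < 0 or offset + t_new > t_max:
        raise ValueError(
            f"cache overflow: offset {offset} + {t_new} new tokens exceeds T_max {t_max}")

def write_cache(cache, new, offset):
    """Copy new ({B, N_kv, T_new, D}) into cache at positions offset..offset+T_new-1.

    Returns the next free offset.
    """
    t_max = len(cache[0][0])
    t_new = len(new[0][0])
    check_cache_fits(offset, t_new, t_max)   # validate first, mutate second
    for b, batch in enumerate(new):
        for h, head in enumerate(batch):
            for t, vec in enumerate(head):
                cache[b][h][offset + t] = list(vec)
    return offset + t_new
```

For example, with `T_max = 4`, writing one token at offset 2 succeeds and returns 3, while writing one token at offset 4 raises before the cache is modified.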

Tests:

  • mix test test/emlx/fast_test.exs --include metal

@polvalente

polvalente commented Jun 19, 2026

Collaborator

I believe most of these are handled in emlx_axon. EMLX should be lower level.

To clarify, emlx_axon is the point of contact for "let's optimize this specific model or class of models".

@hfiguera

Contributor Author

@polvalente That makes sense. I agree the public Qwen3 API belongs in emlx_axon, not in EMLX.

Before reworking this, I want to clarify the native boundary. The public Qwen3 API can move to emlx_axon, but the fused MLX kernels currently rely on EMLX’s NIF, tensor resource, and worker plumbing.

The practical option I see is to keep that native plumbing in emlx, expose the Qwen3 kernels only as internal support, and have emlx_axon own the public Qwen3 API.

The alternative would be giving emlx_axon its own NIF/native layer or a native plugin API for EMLX tensor refs, which seems like a bigger design.

Does the internal support approach fit what you had in mind, or would you prefer avoiding Qwen3 native kernels in emlx entirely?

@polvalente

polvalente commented Jun 20, 2026

Collaborator

The reasoning makes more sense now! Thanks for clarifying.

I know emlx_axon already has quite a few implementations similar to the ones you're proposing. So I think we should have a pull request that does things end-to-end instead of adding all the functions in the lower layer first.

@hfiguera

Contributor Author

Thanks, that makes sense.

I’m closing this in favor of an end-to-end emlx_axon PR. That keeps the user-facing Qwen3 dense loading and generation path in emlx_axon instead of trying to land the lower-level pieces first.

There is still some native support needed in emlx because the fast path has to operate on EMLX tensor refs from C++, but I agree the review should happen around the full emlx_axon path.

@hfiguera hfiguera closed this Jun 20, 2026
@hfiguera hfiguera deleted the qwen3-native-dense-layer-primitives branch June 20, 2026 21:53
2 participants