feat: add native dense Qwen3 layer primitives #118
Conversation
I believe most of these are handled in emlx_axon. EMLX should be lower level. To clarify, emlx_axon is the point of contact for "let's optimize this specific model or class of models".

@polvalente That makes sense. I agree the public Qwen3 API belongs in […]
Before reworking this, I want to clarify the native boundary. The public Qwen3 API can move to […]
The practical option I see is to keep that native plumbing in […]
The alternative would be giving […]
Does the internal support approach fit what you had in mind, or would you prefer avoiding Qwen3 native kernels in […]

The reasoning makes more sense now! Thanks for clarifying. I know emlx_axon already has quite a few implementations similar to the ones you're proposing. So I think we should have a pull request that does things end to end instead of adding all functions in the lower layer first.

Thanks, that makes sense. I’m closing this in favor of an end to end […]
There is still some native support needed in […]

This adds native MLX-backed primitives for dense Qwen3 layer execution.
The main goal is to make the Qwen3 attention/MLP/layer pieces available as
tested EMLX primitives before adding higher-level EMLXAxon loading or text
generation code on top.
This includes:
The dense projection layout is `{in, out}`, so the native path can use direct
`matmul(input, weight)` calls. The Qwen3-specific KV cache helpers use
`{B, N_kv, T_max, D}` caches.

I kept this PR limited to the lower-level EMLX side. It does not add
safetensors loading or a text generation API; those are easier to review as
follow-up changes once the primitives are in place.
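
For reviewers, here is a minimal sketch of what the `{in, out}` projection layout means in plain Nx. It is not the native EMLX code: `DenseProjectionSketch` and `project/2` are hypothetical names used only for illustration, and the `{B, T, in}` input shape is an assumption. With the weight stored as `{in, out}`, the product needs no transpose, because `Nx.dot/2` contracts the last axis of the input with the first axis of a 2-D weight.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule DenseProjectionSketch do
  # Hypothetical helper, not part of EMLX.
  # Projects a {B, T, in} input through an {in, out} weight.
  def project(input, weight) do
    {_batch, _seq, in_dim} = Nx.shape(input)
    {w_in, _out} = Nx.shape(weight)

    # Validate the projection shape eagerly so a mismatch fails with a clear message.
    if in_dim != w_in do
      raise ArgumentError,
            "expected weight with #{in_dim} input rows, got #{w_in}"
    end

    # {B, T, in} x {in, out} -> {B, T, out}, no transpose required.
    Nx.dot(input, weight)
  end
end

input = Nx.iota({2, 3, 4}, type: :f32)
weight = Nx.broadcast(0.5, {4, 8})

out = DenseProjectionSketch.project(input, weight)
IO.inspect(Nx.shape(out))
# prints {2, 3, 8}
```

A weight stored as `{out, in}` would instead need a transpose (or a dot with explicit contraction axes) before the same product.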
Tests cover numerical comparisons against pure Nx references for:
There is also coverage for nonzero cache offsets, batch size 2, cache updates,
projection shape validation, and cache-length overflow before graph
construction.
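
To make the cache-offset and overflow cases concrete, here is a rough pure-Nx sketch of a cache update on a `{B, N_kv, T_max, D}` tensor. It only illustrates the behavior described above and is not the native implementation: `KVCacheSketch` and `update/3` are hypothetical names, and the argument order is an assumption. The length check runs on plain integers taken from the shapes, so an overflow is rejected before any tensor operation is issued.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule KVCacheSketch do
  # Hypothetical helper, not part of EMLX.
  # Writes `new` ({B, N_kv, T_new, D}) into `cache` ({B, N_kv, T_max, D})
  # starting at position `offset` along the time axis.
  def update(cache, new, offset) when is_integer(offset) and offset >= 0 do
    {b, n_kv, t_max, d} = Nx.shape(cache)
    {nb, nn, t_new, nd} = Nx.shape(new)

    if {nb, nn, nd} != {b, n_kv, d} do
      raise ArgumentError, "new entries do not match the cache layout"
    end

    # Checked on integers, before any tensor op is built.
    if offset + t_new > t_max do
      raise ArgumentError,
            "cache overflow: offset #{offset} + #{t_new} new steps exceeds T_max #{t_max}"
    end

    Nx.put_slice(cache, [0, 0, offset, 0], new)
  end
end

# B = 2, N_kv = 2, T_max = 8, D = 4
cache = Nx.broadcast(0.0, {2, 2, 8, 4})
new = Nx.broadcast(1.0, {2, 2, 3, 4})

# Nonzero offset: time steps 2..4 are filled, everything else stays zero.
updated = KVCacheSketch.update(cache, new, 2)
IO.inspect(Nx.to_number(Nx.sum(updated)))
# prints 48.0
```

Calling `KVCacheSketch.update(cache, new, 6)` in this sketch raises `ArgumentError`, since 6 + 3 exceeds `T_max = 8`.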
Tests:
mix test test/emlx/fast_test.exs --include metal