Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion emlx/c_src/emlx_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ NIF(astype) {
TENSOR(mlx::core::astype(*t, type, device));
}

NIF(copy) {
TENSOR_PARAM(0, t);
DEVICE_PARAM(1, device);

TENSOR(mlx::core::copy(*t, device));
}

// Builds the resource binary for `to_blob` in `out_env`. Used by both the
// legacy direct-call NIF and `command_queue_post_to_blob` (which builds the
// term inside the worker thread for delivery via enif_send). May throw
Expand Down Expand Up @@ -1683,6 +1690,7 @@ ASYNC_NIF(full)
ASYNC_NIF(arange)
ASYNC_NIF(eye)
ASYNC_NIF(reshape)
ASYNC_NIF(copy)
ASYNC_NIF(astype)
ASYNC_NIF(view)
ASYNC_NIF(broadcast_to)
Expand Down Expand Up @@ -1817,6 +1825,7 @@ static ErlNifFunc nif_funcs[] = {
{"arange", 6, arange_async},
{"eye", 5, eye_async},
{"reshape", 4, reshape_async},
{"copy", 3, copy_async},
{"astype", 4, astype_async},
{"view", 4, view_async},
{"broadcast_to", 4, broadcast_to_async},
Expand Down Expand Up @@ -1979,4 +1988,3 @@ static ErlNifFunc nif_funcs[] = {
{"kv_cache_sdpa_update", 9, kv_cache_sdpa_update_async}};

ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL)

1 change: 1 addition & 0 deletions emlx/lib/emlx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ defmodule EMLX do
## Manipulation
deftensor reshape(tensor, shape)
deftensor broadcast_to(tensor, shape)
deftensor copy(tensor)
deftensor astype(tensor, type)
deftensor as_strided(tensor, shape, strides, offset)
deftensor view(tensor, type)
Expand Down
92 changes: 78 additions & 14 deletions emlx/lib/emlx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -159,27 +159,86 @@ defmodule EMLX.Backend do
@impl true
def backend_copy(
%T{
shape: logical_shape,
type: logical_type,
names: names,
data: %Backend{
ref: ref,
shape: packed_shape,
type: packed_type,
quantization_config: %EMLX.Quantization.Config{} = cfg
quantization_config: nil
}
},
} = tensor,
EMLX.Backend,
opts
) do
copy_to_device(tensor, device_option(opts))
end

@impl true
def backend_copy(
%T{
data: %Backend{
quantization_config: %EMLX.Quantization.Config{}
}
} = tensor,
EMLX.Backend,
opts
) do
copy_quantized_to_device(tensor, opts)
end

@impl true
def backend_copy(
%T{data: %Nx.BinaryBackend{state: binary}, type: type, shape: shape} = tensor,
EMLX.Backend,
opts
) do
binary
|> maybe_modify_binary(type, to_nx_type(to_mlx_type(type)))
|> EMLX.from_blob(shape, to_mlx_type(type), device_option(opts))
|> to_nx(tensor)
end

@impl true
def backend_copy(%T{type: type, shape: shape} = tensor, backend, opts) do
Nx.from_binary(to_binary(tensor, Nx.size(tensor)), type, backend: {backend, opts})
|> Nx.reshape(shape)
end

defp copy_to_device(%T{data: %Backend{ref: {device, _ref}}} = tensor, device) do
tensor
|> from_nx()
|> EMLX.copy()
|> to_nx(tensor)
end

defp copy_to_device(%T{data: %Backend{ref: ref}} = tensor, target_device) do
ref
|> EMLX.to_device(target_device)
|> to_nx(tensor)
end

defp copy_quantized_to_device(
%T{
shape: logical_shape,
type: logical_type,
names: names,
data: %Backend{
ref: ref,
shape: packed_shape,
type: packed_type,
quantization_config: %EMLX.Quantization.Config{} = cfg
}
},
opts
) do
# Preserve quantization_config when copying a quantized tensor to EMLX.Backend.
# The generic backend_copy goes through to_binary/from_binary which drops the config.
target_device = device_option(opts)
device_opts = {EMLX.Backend, device: target_device}

packed_size = Enum.reduce(Tuple.to_list(packed_shape), 1, &*/2)
packed_binary = EMLX.to_blob(ref, packed_size)
new_ref = EMLX.from_blob(packed_binary, packed_shape, :uint32, target_device)
new_ref =
if elem(ref, 0) == target_device do
ref
|> EMLX.copy()
else
EMLX.to_device(ref, target_device)
end

new_scales = Nx.backend_copy(cfg.scales, device_opts)
new_biases = Nx.backend_copy(cfg.biases, device_opts)
Expand All @@ -198,9 +257,14 @@ defmodule EMLX.Backend do
end

@impl true
def backend_copy(%T{type: type, shape: shape} = tensor, backend, opts) do
Nx.from_binary(to_binary(tensor, Nx.size(tensor)), type, backend: {backend, opts})
|> Nx.reshape(shape)
def backend_transfer(%T{data: %Backend{ref: {device, _ref}}} = tensor, EMLX.Backend, opts) do
if device == device_option(opts) do
tensor
else
new_tensor = backend_copy(tensor, EMLX.Backend, opts)
backend_deallocate(tensor)
new_tensor
end
end

@impl true
Expand Down
73 changes: 73 additions & 0 deletions emlx/test/emlx/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,79 @@ defmodule EMLX.NxTest do
]
@unary_ops [:abs, :bitwise_not, :ceil, :floor, :negate, :round, :sign, :argmax, :argmin]

test "backend_transfer to the same EMLX device is a no-op" do
tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32, backend: {EB, device: :cpu})

assert Nx.backend_transfer(tensor, {EB, device: :cpu}) |> EB.from_nx() == EB.from_nx(tensor)
end

test "EMLX.copy creates an independent tensor ref" do
tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32, backend: {EB, device: :cpu})
ref = EB.from_nx(tensor)

copied_ref = EMLX.copy(ref)

assert copied_ref != ref

Nx.backend_deallocate(tensor)
assert Nx.to_flat_list(EB.to_nx(copied_ref, tensor)) == [1.0, 2.0, 3.0]
end

test "backend_copy to the same EMLX device creates an independent tensor ref" do
tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32, backend: {EB, device: :cpu})

copied = Nx.backend_copy(tensor, {EB, device: :cpu})

assert EB.from_nx(copied) != EB.from_nx(tensor)

Nx.backend_deallocate(tensor)
assert Nx.to_flat_list(copied) == [1.0, 2.0, 3.0]
end

@tag :metal
test "backend_copy to the same GPU device creates an independent tensor ref" do
tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32, backend: {EB, device: :gpu})

copied = Nx.backend_copy(tensor, {EB, device: :gpu})

assert EB.from_nx(copied) != EB.from_nx(tensor)

Nx.backend_deallocate(tensor)
assert Nx.to_flat_list(copied) == [1.0, 2.0, 3.0]
end

test "backend_copy to the same EMLX device creates independent quantized refs" do
tensor = Nx.iota({2, 32}, type: :f32, backend: {EB, device: :cpu})
quantized = EMLX.quantize(tensor, group_size: 32, type: {:s, 4})

expected = EMLX.dequantize(quantized)
copied = Nx.backend_copy(quantized, {EB, device: :cpu})

assert EB.from_nx(copied) != EB.from_nx(quantized)

assert EB.from_nx(copied.data.quantization_config.scales) !=
EB.from_nx(quantized.data.quantization_config.scales)

assert EB.from_nx(copied.data.quantization_config.biases) !=
EB.from_nx(quantized.data.quantization_config.biases)

Nx.backend_deallocate(quantized)
Nx.backend_deallocate(quantized.data.quantization_config.scales)
Nx.backend_deallocate(quantized.data.quantization_config.biases)

assert_all_close(EMLX.dequantize(copied), expected)
end

@tag :metal
test "backend_copy between EMLX devices preserves values" do
tensor = Nx.iota({3, 4}, type: :f32, backend: {EB, device: :cpu})

copied = Nx.backend_copy(tensor, {EB, device: :gpu})

assert elem(EB.from_nx(copied), 0) == :gpu
assert Nx.to_binary(copied) == Nx.to_binary(tensor)
end

defp test_binary_op(op, data_a \\ [[5, 6], [7, 8]], data_b \\ [[4, 3], [2, 1]], type_a, type_b) do
a = Nx.tensor(data_a, type: type_a)
b = Nx.tensor(data_b, type: type_b)
Expand Down
Loading