diff --git a/emlx/c_src/emlx_nif.cpp b/emlx/c_src/emlx_nif.cpp index f18a408..7e8d799 100644 --- a/emlx/c_src/emlx_nif.cpp +++ b/emlx/c_src/emlx_nif.cpp @@ -160,6 +160,13 @@ NIF(astype) { TENSOR(mlx::core::astype(*t, type, device)); } +NIF(copy) { + TENSOR_PARAM(0, t); + DEVICE_PARAM(1, device); + + TENSOR(mlx::core::copy(*t, device)); +} + // Builds the resource binary for `to_blob` in `out_env`. Used by both the // legacy direct-call NIF and `command_queue_post_to_blob` (which builds the // term inside the worker thread for delivery via enif_send). May throw @@ -1683,6 +1690,7 @@ ASYNC_NIF(full) ASYNC_NIF(arange) ASYNC_NIF(eye) ASYNC_NIF(reshape) +ASYNC_NIF(copy) ASYNC_NIF(astype) ASYNC_NIF(view) ASYNC_NIF(broadcast_to) @@ -1817,6 +1825,7 @@ static ErlNifFunc nif_funcs[] = { {"arange", 6, arange_async}, {"eye", 5, eye_async}, {"reshape", 4, reshape_async}, + {"copy", 3, copy_async}, {"astype", 4, astype_async}, {"view", 4, view_async}, {"broadcast_to", 4, broadcast_to_async}, @@ -1979,4 +1988,3 @@ static ErlNifFunc nif_funcs[] = { {"kv_cache_sdpa_update", 9, kv_cache_sdpa_update_async}}; ERL_NIF_INIT(Elixir.EMLX.NIF, nif_funcs, load, NULL, upgrade, NULL) - diff --git a/emlx/lib/emlx.ex b/emlx/lib/emlx.ex index d33f545..791702c 100644 --- a/emlx/lib/emlx.ex +++ b/emlx/lib/emlx.ex @@ -179,6 +179,7 @@ defmodule EMLX do ## Manipulation deftensor reshape(tensor, shape) deftensor broadcast_to(tensor, shape) + deftensor copy(tensor) deftensor astype(tensor, type) deftensor as_strided(tensor, shape, strides, offset) deftensor view(tensor, type) diff --git a/emlx/lib/emlx/backend.ex b/emlx/lib/emlx/backend.ex index 9033eb3..f558ba3 100644 --- a/emlx/lib/emlx/backend.ex +++ b/emlx/lib/emlx/backend.ex @@ -159,27 +159,86 @@ defmodule EMLX.Backend do @impl true def backend_copy( %T{ - shape: logical_shape, - type: logical_type, - names: names, data: %Backend{ - ref: ref, - shape: packed_shape, - type: packed_type, - quantization_config: %EMLX.Quantization.Config{} = cfg + quantization_config: nil } - }, + } = tensor, EMLX.Backend, opts ) do + copy_to_device(tensor, device_option(opts)) + end + + @impl true + def backend_copy( + %T{ + data: %Backend{ + quantization_config: %EMLX.Quantization.Config{} + } + } = tensor, + EMLX.Backend, + opts + ) do + copy_quantized_to_device(tensor, opts) + end + + @impl true + def backend_copy( + %T{data: %Nx.BinaryBackend{state: binary}, type: type, shape: shape} = tensor, + EMLX.Backend, + opts + ) do + binary + |> maybe_modify_binary(type, to_nx_type(to_mlx_type(type))) + |> EMLX.from_blob(shape, to_mlx_type(type), device_option(opts)) + |> to_nx(tensor) + end + + @impl true + def backend_copy(%T{type: type, shape: shape} = tensor, backend, opts) do + Nx.from_binary(to_binary(tensor, Nx.size(tensor)), type, backend: {backend, opts}) + |> Nx.reshape(shape) + end + + defp copy_to_device(%T{data: %Backend{ref: {device, _ref}}} = tensor, device) do + tensor + |> from_nx() + |> EMLX.copy() + |> to_nx(tensor) + end + + defp copy_to_device(%T{data: %Backend{ref: ref}} = tensor, target_device) do + ref + |> EMLX.to_device(target_device) + |> to_nx(tensor) + end + + defp copy_quantized_to_device( + %T{ + shape: logical_shape, + type: logical_type, + names: names, + data: %Backend{ + ref: ref, + shape: packed_shape, + type: packed_type, + quantization_config: %EMLX.Quantization.Config{} = cfg + } + }, + opts + ) do # Preserve quantization_config when copying a quantized tensor to EMLX.Backend. # The generic backend_copy goes through to_binary/from_binary which drops the config. target_device = device_option(opts) device_opts = {EMLX.Backend, device: target_device} - packed_size = Enum.reduce(Tuple.to_list(packed_shape), 1, &*/2) - packed_binary = EMLX.to_blob(ref, packed_size) - new_ref = EMLX.from_blob(packed_binary, packed_shape, :uint32, target_device) + new_ref = + if elem(ref, 0) == target_device do + ref + |> EMLX.copy() + else + EMLX.to_device(ref, target_device) + end new_scales = Nx.backend_copy(cfg.scales, device_opts) new_biases = Nx.backend_copy(cfg.biases, device_opts) @@ -198,9 +257,14 @@ defmodule EMLX.Backend do end @impl true - def backend_copy(%T{type: type, shape: shape} = tensor, backend, opts) do - Nx.from_binary(to_binary(tensor, Nx.size(tensor)), type, backend: {backend, opts}) - |> Nx.reshape(shape) + def backend_transfer(%T{data: %Backend{ref: {device, _ref}}} = tensor, EMLX.Backend, opts) do + if device == device_option(opts) do + tensor + else + new_tensor = backend_copy(tensor, EMLX.Backend, opts) + backend_deallocate(tensor) + new_tensor + end end @impl true diff --git a/emlx/test/emlx/nx_test.exs b/emlx/test/emlx/nx_test.exs index 4a7b801..dbaa188 100644 --- a/emlx/test/emlx/nx_test.exs +++ b/emlx/test/emlx/nx_test.exs @@ -27,6 +27,79 @@ defmodule EMLX.NxTest do ] @unary_ops [:abs, :bitwise_not, :ceil, :floor, :negate, :round, :sign, :argmax, :argmin] + test "backend_transfer to the same EMLX device is a no-op" do + tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32, backend: {EB, device: :cpu}) + + assert Nx.backend_transfer(tensor, {EB, device: :cpu}) |> EB.from_nx() == EB.from_nx(tensor) + end + + test "EMLX.copy creates an independent tensor ref" do + tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32, backend: {EB, device: :cpu}) + ref = EB.from_nx(tensor) + + copied_ref = EMLX.copy(ref) + + assert copied_ref != ref + + Nx.backend_deallocate(tensor) + assert Nx.to_flat_list(EB.to_nx(copied_ref, tensor)) == [1.0, 2.0, 3.0] + end + + test "backend_copy to the same EMLX device creates an independent tensor ref" do + tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32, backend: {EB, device: :cpu}) + + copied = Nx.backend_copy(tensor, {EB, device: :cpu}) + + assert EB.from_nx(copied) != EB.from_nx(tensor) + + Nx.backend_deallocate(tensor) + assert Nx.to_flat_list(copied) == [1.0, 2.0, 3.0] + end + + @tag :metal + test "backend_copy to the same GPU device creates an independent tensor ref" do + tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f32, backend: {EB, device: :gpu}) + + copied = Nx.backend_copy(tensor, {EB, device: :gpu}) + + assert EB.from_nx(copied) != EB.from_nx(tensor) + + Nx.backend_deallocate(tensor) + assert Nx.to_flat_list(copied) == [1.0, 2.0, 3.0] + end + + test "backend_copy to the same EMLX device creates independent quantized refs" do + tensor = Nx.iota({2, 32}, type: :f32, backend: {EB, device: :cpu}) + quantized = EMLX.quantize(tensor, group_size: 32, type: {:s, 4}) + + expected = EMLX.dequantize(quantized) + copied = Nx.backend_copy(quantized, {EB, device: :cpu}) + + assert EB.from_nx(copied) != EB.from_nx(quantized) + + assert EB.from_nx(copied.data.quantization_config.scales) != + EB.from_nx(quantized.data.quantization_config.scales) + + assert EB.from_nx(copied.data.quantization_config.biases) != + EB.from_nx(quantized.data.quantization_config.biases) + + Nx.backend_deallocate(quantized) + Nx.backend_deallocate(quantized.data.quantization_config.scales) + Nx.backend_deallocate(quantized.data.quantization_config.biases) + + assert_all_close(EMLX.dequantize(copied), expected) + end + + @tag :metal + test "backend_copy between EMLX devices preserves values" do + tensor = Nx.iota({3, 4}, type: :f32, backend: {EB, device: :cpu}) + + copied = Nx.backend_copy(tensor, {EB, device: :gpu}) + + assert elem(EB.from_nx(copied), 0) == :gpu + assert Nx.to_binary(copied) == Nx.to_binary(tensor) + end + defp test_binary_op(op, data_a \\ [[5, 6], [7, 8]], data_b \\ [[4, 3], [2, 1]], type_a, type_b) do a = Nx.tensor(data_a, type: type_a) b = Nx.tensor(data_b, type: type_b)