Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions examples/kernels/flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
drv =
sys: out: out.packages.${sys}.redistributable.${"tvm-ffi${tvmFfiVersion}-${cudaVersion}-${sys}"};
}
{
name = "relu-tvm-ffi-compiler-flags-kernel";
path = ./relu-tvm-ffi-compiler-flags;
drv =
sys: out: out.packages.${sys}.redistributable.${"tvm-ffi${tvmFfiVersion}-${cudaVersion}-${sys}"};
}
{
name = "extra-data";
path = ./extra-data;
Expand Down Expand Up @@ -225,6 +231,12 @@
drv =
sys: out: out.packages.${sys}.redistributable.${"tvm-ffi${tvmFfiVersion}-${xpuVersion}-${sys}"};
}
{
name = "relu-tvm-ffi-compiler-flags-kernel";
path = ./relu-tvm-ffi-compiler-flags;
drv =
sys: out: out.packages.${sys}.redistributable.${"tvm-ffi${tvmFfiVersion}-${xpuVersion}-${sys}"};
}
{
name = "relu-compiler-flags";
path = ./relu-compiler-flags;
Expand Down
1 change: 1 addition & 0 deletions examples/kernels/relu-compiler-flags/build.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h",
]
cxx-flags = ["-DCANARY_IN_THE_KERNEL"]

[kernel.activation_xpu]
backend = "xpu"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#include "registration.h"
#include "torch_binding.h"

#ifndef CANARY_IN_THE_KERNEL
#error "Framework cxx-flags are not correctly handled."
#endif

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("relu(Tensor! out, Tensor input) -> ()");
#if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
Expand Down
62 changes: 62 additions & 0 deletions examples/kernels/relu-tvm-ffi-compiler-flags/CARD.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
---
library_name: kernels
{% if license %}license: {{ license }}
{% endif %}---

This is the repository card of {{ repo_id }}, which has been pushed to the Hub. It was built to be used with the [`kernels` library](https://github.com/huggingface/kernels). This card was automatically generated.

## How to use
{% if functions %}

```python
# make sure `kernels` is installed: `pip install -U kernels`
from kernels import get_kernel

kernel_module = get_kernel("{{ repo_id }}", version={{ version }})
{{ functions[0] }} = kernel_module.{{ functions[0] }}

{{ functions[0] }}(...)
```
{% else %}

Usage example not available.
{% endif %}

## Available functions
{% if functions %}
{% for func in functions %}
- `{{ func }}`
{% endfor %}
{% else %}

Function list not available.
{% endif %}
{% if layers %}

## Available layers
{% for layer in layers %}
- `{{ layer }}`
{% endfor %}
{% endif %}

## Benchmarks
{% if has_benchmark %}

A benchmarking script is available for this kernel. Run `kernels benchmark {{ repo_id }} --version {{ version }}`.
{% else %}

No benchmark available yet.
{% endif %}
{% if upstream %}

## Upstream

The original source code for this kernel comes from {{ upstream }}.
{% endif %}
{% if source %}

## Source

The kernel-builder formatted source for this kernel is available at {{ source }}.
{% endif %}

43 changes: 43 additions & 0 deletions examples/kernels/relu-tvm-ffi-compiler-flags/build.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Build description for the `relu-tvm-ffi-compiler-flags` example kernel.
# The flag entries below inject preprocessor "canary" defines; the kernel
# sources abort compilation with `#error` when their define is missing.
[general]
name = "relu-tvm-ffi-compiler-flags"
version = 1
license = "Apache-2.0"
backends = [
"cpu",
"cuda",
"xpu",
]

[general.hub]
repo-id = "kernels-test/relu-tvm-ffi-compiler-flags"

# Flags for the tvm-ffi binding sources. This define is deliberately different
# from the per-kernel one below.
# NOTE(review): presumably tvm_ffi_binding.cpp checks CANARY_IN_THE_KERNEL the
# same way the kernel sources check theirs -- confirm, that file is not shown.
[tvm-ffi]
src = ["tvm-ffi-ext/tvm_ffi_binding.cpp"]
cxx-flags = ["-DCANARY_IN_THE_KERNEL"]

# relu_cuda/relu.cu fails with `#error` unless WHO_AM_I_IF_NOT_THE_CANARY is defined.
[kernel.relu_cuda]
backend = "cuda"
depends = []
cuda-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"]
src = [
"relu_cuda/relu.cu",
"util.hh",
]

# relu_cpu/relu_cpu.cpp fails with `#error` unless WHO_AM_I_IF_NOT_THE_CANARY is defined.
[kernel.relu_cpu]
backend = "cpu"
depends = []
cxx-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"]
src = [
"relu_cpu/relu_cpu.cpp",
"util.hh",
]

# relu_xpu/relu.cpp fails with `#error` unless WHO_AM_I_IF_NOT_THE_CANARY is defined.
# NOTE(review): this is the only kernel that lists "torch" in `depends`, although
# the visible XPU source includes only tvm-ffi and SYCL headers -- confirm it is needed.
[kernel.relu_xpu]
backend = "xpu"
depends = ["torch"]
sycl-flags = ["-DWHO_AM_I_IF_NOT_THE_CANARY"]
src = [
"relu_xpu/relu.cpp",
"util.hh",
]
17 changes: 17 additions & 0 deletions examples/kernels/relu-tvm-ffi-compiler-flags/flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Flake for the relu-tvm-ffi-compiler-flags example kernel.
# All outputs are produced by kernel-builder's `genKernelFlakeOutputs` helper.
{
description = "Flake for ReLU kernel";

inputs = {
# kernel-builder is taken from a relative path three directories up rather
# than from a remote flake reference.
kernel-builder.url = "path:../../..";
};

outputs =
{
self,
kernel-builder,
}:
# Pass this flake (`self`) and this directory (`path`) to the helper.
kernel-builder.lib.genKernelFlakeOutputs {
inherit self;
path = ./.;
};
}
74 changes: 74 additions & 0 deletions examples/kernels/relu-tvm-ffi-compiler-flags/relu_cpu/relu_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include <tvm/ffi/tvm_ffi.h>

#include "../util.hh"

#ifndef WHO_AM_I_IF_NOT_THE_CANARY
#error "Kernel flags are not correctly handled."
#endif

#ifdef __SSE__
#include <xmmintrin.h>
#endif

#ifdef __ARM_NEON
#include <arm_neon.h>
#endif

// NOTE: This is a minimal example kernel that is not optimized for
// performance, so we do not care about unaligned loads/stores.

#ifdef __SSE__
// Element-wise ReLU over `size` floats using SSE.
//
// Writes max(input[k], 0) to out[k] for every k in [0, size). Full groups of
// four floats go through unaligned 128-bit loads/stores; whatever is left
// over (at most three elements) is handled one float at a time.
void relu_forward_sse(float* out, const float* input, size_t size) {
  const __m128 zeros = _mm_setzero_ps();
  // Index one past the last complete group of four.
  const size_t vec_end = size - (size % 4);

  size_t pos = 0;
  while (pos < vec_end) {
    const __m128 lanes = _mm_loadu_ps(input + pos);
    _mm_storeu_ps(out + pos, _mm_max_ps(lanes, zeros));
    pos += 4;
  }

  // Scalar tail.
  while (pos < size) {
    const float v = input[pos];
    out[pos] = v > 0 ? v : 0;
    ++pos;
  }
}
#endif

#ifdef __ARM_NEON
// Element-wise ReLU over `size` floats using ARM NEON.
//
// Full groups of four floats are processed with 128-bit vector loads, a
// vector max against zero, and vector stores; the remaining elements (at most
// three) are handled one float at a time.
void relu_forward_neon(float* out, const float* input, size_t size) {
  const float32x4_t zeros = vdupq_n_f32(0);
  // Index one past the last complete group of four.
  const size_t vec_end = size - (size % 4);

  size_t pos = 0;
  while (pos < vec_end) {
    const float32x4_t lanes = vld1q_f32(input + pos);
    vst1q_f32(out + pos, vmaxq_f32(lanes, zeros));
    pos += 4;
  }

  // Scalar tail.
  while (pos < size) {
    const float v = input[pos];
    out[pos] = v > 0 ? v : 0;
    ++pos;
  }
}
#endif

using namespace tvm;

// CPU entry point: writes ReLU(input) into `out`.
//
// Both tensors are first run through the CHECK_INPUT / CHECK_DEVICE macros
// from util.hh (their exact conditions are defined there). A ValueError is
// raised when the dtypes or element counts of `input` and `out` differ, and a
// TypeError when the dtype is anything other than float32.
void relu_cpu(ffi::TensorView out, ffi::TensorView const input) {
  CHECK_INPUT(input);
  CHECK_INPUT(out);
  CHECK_DEVICE(input, out);

  TVM_FFI_CHECK(input.dtype() == out.dtype(), ValueError) << "input/output dtype mismatch";
  TVM_FFI_CHECK(input.numel() == out.numel(), ValueError) << "input/output size mismatch";

  if (input.dtype() == dl_float32) {
    auto *dst = static_cast<float *>(out.data_ptr());
    auto *src = static_cast<float *>(input.data_ptr());
    const auto count = input.numel();

    // The implementation is selected at compile time; SSE takes priority
    // when both feature macros are defined.
#if defined(__SSE__)
    relu_forward_sse(dst, src, count);
#elif defined(__ARM_NEON)
    relu_forward_neon(dst, src, count);
#else
#error "Unsupported architecture; please use a CPU with SSE or ARM NEON support."
#endif
  } else {
    TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
  }
}
50 changes: 50 additions & 0 deletions examples/kernels/relu-tvm-ffi-compiler-flags/relu_cuda/relu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <cmath>
#include <tvm/ffi/tvm_ffi.h>
#include <tvm/ffi/extra/cuda/device_guard.h>
#include <tvm/ffi/extra/c_env_api.h>

#include "../util.hh"

#ifndef WHO_AM_I_IF_NOT_THE_CANARY
#error "Kernel flags are not correctly handled."
#endif

// ReLU kernel over a buffer viewed as rows of `d` floats.
//
// Block `blockIdx.x` handles the row starting at element blockIdx.x * d. The
// threads of that block start at column threadIdx.x and advance by
// blockDim.x until the row is exhausted.
__global__ void relu_kernel(float *__restrict__ out,
                            float const *__restrict__ input, const int d) {
  // 64-bit offset of this block's row.
  const int64_t row = blockIdx.x;
  const int64_t row_offset = row * d;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const auto v = input[row_offset + col];
    out[row_offset + col] = (v > 0.0f) ? v : 0.0f;
  }
}

using namespace tvm;

// CUDA entry point: writes ReLU(input) into `out`.
//
// Both tensors are first run through the CHECK_INPUT_CUDA / CHECK_DEVICE
// macros from util.hh (their exact conditions are defined there). A
// ValueError is raised when dtypes or element counts differ; a TypeError when
// the dtype is not float32. Empty tensors return immediately without
// touching the device.
void relu_cuda(ffi::TensorView out, ffi::TensorView const input) {
  CHECK_INPUT_CUDA(input);
  CHECK_INPUT_CUDA(out);
  CHECK_DEVICE(input, out);

  TVM_FFI_CHECK(input.dtype() == out.dtype(), ValueError) << "input/output dtype mismatch";
  TVM_FFI_CHECK(input.numel() == out.numel(), ValueError) << "input/output size mismatch";

  const bool nothing_to_do = input.numel() == 0;
  if (nothing_to_do) {
    return;
  }

  // Switch to the input's device for the rest of this scope and fetch the
  // stream registered for it in the tvm-ffi environment.
  const auto dev = input.device();
  ffi::CUDADeviceGuard guard(dev.device_id);
  cudaStream_t stream =
      static_cast<cudaStream_t>(TVMFFIEnvGetStream(dev.device_type, dev.device_id));

  // One block per row of the last dimension; at most 1024 threads per block.
  int d = input.size(-1);
  const int64_t rows = input.numel() / d;
  const int64_t threads(std::min(d, 1024));

  if (input.dtype() == dl_float32) {
    auto *dst = static_cast<float *>(out.data_ptr());
    const auto *src = static_cast<const float *>(input.data_ptr());
    relu_kernel<<<rows, threads, 0, stream>>>(dst, src, d);
  } else {
    TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
  }
}
43 changes: 43 additions & 0 deletions examples/kernels/relu-tvm-ffi-compiler-flags/relu_xpu/relu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include <tvm/ffi/tvm_ffi.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <sycl/sycl.hpp>

#include "../util.hh"

#ifndef WHO_AM_I_IF_NOT_THE_CANARY
#error "Kernel flags are not correctly handled."
#endif

using namespace sycl;

void relu_xpu_impl(float *output, float const *input, int const numel) {
// Create SYCL queue directly
sycl::queue queue;

// Launch SYCL kernel
queue.parallel_for(range<1>(numel), [=](id<1> idx) {
auto i = idx[0];
output[i] = input[i] > 0.0f ? input[i] : 0.0f;
}).wait();
}

using namespace tvm;

// XPU entry point: writes ReLU(input) into `out`.
//
// Both tensors are first run through the CHECK_INPUT_XPU / CHECK_DEVICE
// macros from util.hh (their exact conditions are defined there). A
// ValueError is raised when dtypes or element counts differ; a TypeError when
// the dtype is not float32.
void relu_xpu(ffi::TensorView out, ffi::TensorView const input) {
  CHECK_INPUT_XPU(input);
  CHECK_INPUT_XPU(out);
  CHECK_DEVICE(input, out);

  TVM_FFI_CHECK(input.dtype() == out.dtype(), ValueError) << "input/output dtype mismatch";
  TVM_FFI_CHECK(input.numel() == out.numel(), ValueError) << "input/output size mismatch";

  auto count = input.numel();

  if (input.dtype() == dl_float32) {
    auto *dst = static_cast<float *>(out.data_ptr());
    auto *src = static_cast<float const *>(input.data_ptr());
    // relu_xpu_impl takes the element count as `int`.
    relu_xpu_impl(dst, src, count);
  } else {
    TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import tvm_ffi

from ._ops import has_jax, jax_add_ffi_target_name_prefix, ops


def relu(x, out) -> tvm_ffi.Tensor:
    """Apply ReLU to ``x``, writing the result into ``out``.

    Both arguments are converted with ``tvm_ffi.from_dlpack`` and passed to the
    kernel that matches the device type of ``x``.

    Args:
        x: Input tensor; any object accepted by ``tvm_ffi.from_dlpack``.
        out: Pre-allocated output tensor, written in place.

    Returns:
        The ``out`` object that was passed in (not the converted tvm-ffi view).

    Raises:
        NotImplementedError: If the device type of ``x`` is not ``"cpu"``,
            ``"cuda"`` or ``"xpu"``.
    """
    x_t = tvm_ffi.from_dlpack(x)
    out_t = tvm_ffi.from_dlpack(out)

    device = x_t.device
    if device.type == "cpu":
        ops.relu_cpu(out_t, x_t)
    elif device.type == "cuda":
        ops.relu_cuda(out_t, x_t)
    elif device.type == "xpu":
        # The XPU kernel is exported as `relu_xpu`; pass the converted input
        # like the other backends do.
        # NOTE(review): confirm "xpu" is the device-type string tvm_ffi
        # reports for Intel XPU tensors.
        ops.relu_xpu(out_t, x_t)
    else:
        raise NotImplementedError(f"Unsupported device type: {device.type}")

    return out


# When JAX is available, register one of the compiled ReLU ops as a JAX FFI
# target under the prefixed name "relu". Lookup order: CUDA, then XPU, then CPU.
if has_jax:
    ops_func = getattr(ops, "relu_cuda", None)
    if not ops_func:
        ops_func = getattr(ops, "relu_xpu", None)
    if not ops_func:
        ops_func = getattr(ops, "relu_cpu", None)

    if ops_func is not None:
        from jax_tvm_ffi import register_ffi_target

        register_ffi_target(
            jax_add_ffi_target_name_prefix("relu"),
            ops_func,
            arg_spec=["rets", "args"],
            # "cpu" whenever this build exposes a CPU op, otherwise "gpu".
            platform="gpu" if not hasattr(ops, "relu_cpu") else "cpu",
        )


def relu_jax(x):
    """Apply ReLU to ``x`` through the JAX FFI target named "relu" (prefixed).

    The result is declared with the same shape and dtype as ``x``.

    Raises:
        RuntimeError: If JAX is not available (``has_jax`` is false).
    """
    if has_jax:
        import jax.ffi

        ffi_relu = jax.ffi.ffi_call(
            jax_add_ffi_target_name_prefix("relu"),
            jax.ShapeDtypeStruct(x.shape, x.dtype),
            vmap_method="broadcast_all",
        )
        return ffi_relu(x)

    raise RuntimeError(
        "JAX is not available. Please install JAX to use this function."
    )


from . import layers


__all__ = ["layers", "relu", "relu_jax"]
Loading
Loading