Skip to content

Commit

Permalink
[SYCLomatic] Add support for two-way type-cast from `dpct::kernel_lib…
Browse files Browse the repository at this point in the history
…rary` and `dpct::kernel_function` to `uint64_t` conversion (#2606)
  • Loading branch information
the-slow-one authored Jan 10, 2025
1 parent 821800f commit 8b31cc7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
12 changes: 8 additions & 4 deletions clang/runtime/dpct-rt/include/dpct/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,10 @@ class kernel_library {
public:
constexpr kernel_library() : ptr{nullptr} {}
constexpr kernel_library(void *ptr) : ptr{ptr} {}
kernel_library(uint64_t addr) : ptr(reinterpret_cast<void *>(addr)) {}

operator void *() const { return ptr; }
explicit operator uint64_t() const { return reinterpret_cast<uint64_t>(ptr); }

private:
void *ptr;
Expand Down Expand Up @@ -393,15 +395,17 @@ class kernel_function {
public:
constexpr kernel_function() : ptr{nullptr} {}
constexpr kernel_function(dpct::kernel_functor ptr) : ptr{ptr} {}
kernel_function(uint64_t addr)
: ptr(reinterpret_cast<dpct::kernel_functor>(addr)) {}

operator void *() const { return ((void *)ptr); }

void operator()(sycl::queue &q, const sycl::nd_range<3> &range,
unsigned int a, void **args, void **extra) {
unsigned int a, void **args, void **extra) const {
ptr(q, range, a, args, extra);
}

explicit operator uint64_t() const { return (uint64_t)this; }
explicit operator uint64_t() const { return reinterpret_cast<uint64_t>(ptr); }

private:
dpct::kernel_functor ptr;
Expand All @@ -411,7 +415,7 @@ class kernel_function {
/// \param [in] library Handle to the kernel library.
/// \param [in] name Name of the kernel function.
static inline dpct::kernel_function
get_kernel_function(kernel_library &library, const std::string &name) {
get_kernel_function(const kernel_library &library, const std::string &name) {
#ifdef _WIN32
dpct::kernel_functor fn = reinterpret_cast<dpct::kernel_functor>(
GetProcAddress(static_cast<HMODULE>(static_cast<void *>(library)),
Expand All @@ -434,7 +438,7 @@ get_kernel_function(kernel_library &library, const std::string &name) {
/// function.
/// \param [in] kernelParams Array of pointers to kernel arguments.
/// \param [in] extra Extra arguments.
static inline void invoke_kernel_function(dpct::kernel_function &function,
static inline void invoke_kernel_function(const dpct::kernel_function &function,
sycl::queue &queue,
sycl::range<3> groupRange,
sycl::range<3> localRange,
Expand Down
26 changes: 19 additions & 7 deletions clang/test/dpct/kernel-function-typecast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,24 @@

typedef uint64_t u64;

// CHECK: u64 foo(dpct::kernel_function cuFunc, dpct::kernel_library cuMod) {
u64 foo(CUfunction cuFunc, CUmodule cuMod) {
// CHECK: cuFunc = dpct::get_kernel_function(cuMod, "kfoo");
cuModuleGetFunction(&cuFunc, cuMod, "kfoo");
u64 function = (u64)cuFunc;
// CHECK: void exec_kernel(dpct::kernel_function cuFunc, dpct::kernel_library cuMod, dpct::queue_ptr stream) {
void exec_kernel(CUfunction cuFunc, CUmodule cuMod, CUstream stream) {
u64 mod;
u64 function;

return function;
}
// verify the conversion from dpct::kernel_library to uint64_t
mod = (u64)cuMod;

// verify the conversion from uint64_t to dpct::kernel_library
// CHECK: cuFunc = dpct::get_kernel_function((dpct::kernel_library)mod, "kfoo");
cuModuleGetFunction(&cuFunc, (CUmodule)mod, "kfoo");

// verify the conversion from dpct::kernel_function to uint64_t
function = (u64)cuFunc;

void *config[] = {0};

// verify the conversion from uint64_t to dpct::kernel_function
// CHECK: dpct::invoke_kernel_function((dpct::kernel_function)function, *stream, sycl::range<3>(100, 100, 100), sycl::range<3>(100, 100, 100), 1024, NULL, config);
cuLaunchKernel((CUfunction)function, 100, 100, 100, 100, 100, 100, 1024, stream, NULL, config);
}

0 comments on commit 8b31cc7

Please sign in to comment.