Skip to content

Commit

Permalink
[oneDPL] Remove usage of __kernel_name_generator as not required fo…
Browse files Browse the repository at this point in the history
…r non-compiled Kernel's (#1935)
  • Loading branch information
SergeyKopienko authored Jan 15, 2025
1 parent a1aaf97 commit ea87681
Showing 1 changed file with 112 additions and 96 deletions.
208 changes: 112 additions & 96 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1776,100 +1776,46 @@ struct __parallel_find_or_nd_range_tuner<oneapi::dpl::__internal::__device_backe
};
#endif // !_ONEDPL_FPGA_EMU

// Single work-group variant of the base pattern shared by __parallel_or and __parallel_find.
// The tag type _BrickTag selects the behaviour.
//
// Submits one kernel whose nd_range holds exactly one work-group of __wgroup_size work-items.
// Every work-item runs __pred to update its private found state, the private states are then combined
// across the work-group, and the work-item with local id 0 stores the combined value into the result storage.
// Returns whatever __wait_and_get_value yields for the event of the submitted kernel.
template <typename KernelName, bool __or_tag_check, typename _ExecutionPolicy, typename _BrickTag,
          typename __FoundStateType, typename _Predicate, typename... _Ranges>
__FoundStateType
__parallel_find_or_impl_one_wg(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec,
                               _BrickTag __brick_tag, const std::size_t __rng_n, const std::size_t __wgroup_size,
                               const __FoundStateType __init_value, _Predicate __pred, _Ranges&&... __rngs)
{
    using _FoundStateStorage = __result_and_scratch_storage<_ExecutionPolicy, __FoundStateType>;
    _FoundStateStorage __found_storage{__exec, 1, 0};

    // Number of elements to be processed by each work-item.
    const auto __n_iters = oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __wgroup_size);

    // The global and the local range coincide, so the launch consists of a single work-group.
    const sycl::range</*dim=*/1> __wg_range(__wgroup_size);

    // Submit the kernel; the event is kept to wait for the result afterwards.
    auto __kernel_event = __exec.queue().submit([&](sycl::handler& __cgh) {
        oneapi::dpl::__ranges::__require_access(__cgh, __rngs...);
        auto __found_acc =
            __found_storage.template __get_result_acc<sycl::access_mode::write>(__cgh, __dpl_sycl::__no_init{});

        __cgh.parallel_for<KernelName>(
            sycl::nd_range</*dim=*/1>(__wg_range, __wg_range), [=](sycl::nd_item</*dim=*/1> __item_id) {
                // Step 1: every work-item starts from the initial found state.
                __FoundStateType __wi_state = __init_value;

                // Step 2: look for an element satisfying the predicate.
                // __wi_state may still hold the initial value afterwards:
                //   - when no element satisfies the predicate;
                //   - when an early exit from the sub-group occurred; the state is then updated
                //     by the group operation in step 3.
                __pred(__item_id, __rng_n, __n_iters, __wgroup_size, __wi_state, __brick_tag, __rngs...);

                // Step 3: combine the states of all work-items of the group:
                //   - __dpl_sycl::__any_of_group for the __parallel_or_tag;
                //   - __dpl_sycl::__minimum for the __parallel_find_forward_tag and
                //     __dpl_sycl::__maximum for the __parallel_find_backward_tag,
                //     applied through _BrickTag::_LocalResultsReduceOp.
                if constexpr (__or_tag_check)
                    __wi_state = __dpl_sycl::__any_of_group(__item_id.get_group(), __wi_state);
                else
                    __wi_state = __dpl_sycl::__reduce_over_group(__item_id.get_group(), __wi_state,
                                                                 typename _BrickTag::_LocalResultsReduceOp{});

                // Only the work-item with local id 0 publishes the combined state as the result.
                if (__item_id.get_local_id(0) == 0)
                {
                    _FoundStateStorage::__get_usm_or_buffer_accessor_ptr(__found_acc)[0] = __wi_state;
                }
            });
    });

    // Wait for the kernel and hand the stored value back to the caller.
    return __found_storage.__wait_and_get_value(__kernel_event);
}
// Primary template: declared only, never defined here. The partial specialization that follows matches
// __internal::__optional_kernel_name<KernelName...> as the second argument and supplies operator().
template <bool __or_tag_check, typename KernelName>
struct __parallel_find_or_impl_one_wg;

// Base pattern for __parallel_or and __parallel_find. The execution depends on tag type _BrickTag.
template <typename KernelName, bool __or_tag_check, typename _ExecutionPolicy, typename _BrickTag, typename _AtomicType,
typename _Predicate, typename... _Ranges>
_AtomicType
__parallel_find_or_impl_multiple_wgs(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec,
_BrickTag __brick_tag, const std::size_t __rng_n, const std::size_t __n_groups,
const std::size_t __wgroup_size, const _AtomicType __init_value, _Predicate __pred,
_Ranges&&... __rngs)
{
auto __result = __init_value;

// Calculate the number of elements to be processed by each work-item.
const auto __iters_per_work_item = oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __n_groups * __wgroup_size);

// The scope ensures data is copied back to __result on destruction of the temporary sycl::buffer
template <bool __or_tag_check, typename... KernelName>
struct __parallel_find_or_impl_one_wg<__or_tag_check, __internal::__optional_kernel_name<KernelName...>>
{
template <typename _ExecutionPolicy, typename _BrickTag, typename __FoundStateType, typename _Predicate,
typename... _Ranges>
__FoundStateType
operator()(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _BrickTag __brick_tag,
const std::size_t __rng_n, const std::size_t __wgroup_size, const __FoundStateType __init_value,
_Predicate __pred, _Ranges&&... __rngs)
{
sycl::buffer<_AtomicType, 1> __result_sycl_buf(&__result, 1); // temporary storage for global atomic
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, __FoundStateType>;
__result_and_scratch_storage_t __result_storage{__exec, 1, 0};

// Calculate the number of elements to be processed by each work-item.
const auto __iters_per_work_item = oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __wgroup_size);

// main parallel_for
__exec.queue().submit([&](sycl::handler& __cgh) {
auto __event = __exec.queue().submit([&](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rngs...);
auto __result_sycl_buf_acc = __result_sycl_buf.template get_access<access_mode::read_write>(__cgh);
auto __result_acc =
__result_storage.template __get_result_acc<sycl::access_mode::write>(__cgh, __dpl_sycl::__no_init{});

__cgh.parallel_for<KernelName>(
sycl::nd_range</*dim=*/1>(sycl::range</*dim=*/1>(__n_groups * __wgroup_size),
sycl::range</*dim=*/1>(__wgroup_size)),
__cgh.parallel_for<KernelName...>(
sycl::nd_range</*dim=*/1>(sycl::range</*dim=*/1>(__wgroup_size), sycl::range</*dim=*/1>(__wgroup_size)),
[=](sycl::nd_item</*dim=*/1> __item_id) {
auto __local_idx = __item_id.get_local_id(0);

// 1. Set initial value to local found state
_AtomicType __found_local = __init_value;
__FoundStateType __found_local = __init_value;

// 2. Find any element that satisfies pred
// - after this call __found_local may still have initial value:
// 1) if no element satisfies pred;
// 2) early exit from sub-group occurred: in this case the state of __found_local will be updated in the next group operation (3)
__pred(__item_id, __rng_n, __iters_per_work_item, __n_groups * __wgroup_size, __found_local,
__brick_tag, __rngs...);
__pred(__item_id, __rng_n, __iters_per_work_item, __wgroup_size, __found_local, __brick_tag,
__rngs...);

// 3. Reduce over group: find __dpl_sycl::__minimum (for the __parallel_find_forward_tag),
// find __dpl_sycl::__maximum (for the __parallel_find_backward_tag)
Expand All @@ -1881,22 +1827,92 @@ __parallel_find_or_impl_multiple_wgs(oneapi::dpl::__internal::__device_backend_t
__found_local = __dpl_sycl::__reduce_over_group(__item_id.get_group(), __found_local,
typename _BrickTag::_LocalResultsReduceOp{});

// Set local found state value to global atomic
if (__local_idx == 0 && __found_local != __init_value)
// Set local found state value to the global state to have the correct result
if (__local_idx == 0)
{
__dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::global_space> __found(
*__dpl_sycl::__get_accessor_ptr(__result_sycl_buf_acc));

// Update global (for all groups) atomic state with the found index
_BrickTag::__save_state_to_atomic(__found, __found_local);
__result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__result_acc)[0] =
__found_local;
}
});
});
//The end of the scope - a point of synchronization (on temporary sycl buffer destruction)

// Wait and return result
return __result_storage.__wait_and_get_value(__event);
}
};

return __result;
}
// Primary template: declared only. The partial specialization below unpacks the kernel name
// from __internal::__optional_kernel_name<KernelName...>.
template <bool __or_tag_check, typename KernelName>
struct __parallel_find_or_impl_multiple_wgs;

// Multiple work-group variant of the base pattern shared by __parallel_or and __parallel_find.
// The tag type _BrickTag selects the behaviour.
//
// Launches __n_groups work-groups of __wgroup_size work-items each. Every work-group combines the found
// states of its work-items; the work-item with local id 0 then merges a non-initial group state into
// a single global value through an atomic reference. The merged value is returned.
template <bool __or_tag_check, typename... KernelName>
struct __parallel_find_or_impl_multiple_wgs<__or_tag_check, __internal::__optional_kernel_name<KernelName...>>
{
    template <typename _ExecutionPolicy, typename _BrickTag, typename _AtomicType, typename _Predicate,
              typename... _Ranges>
    _AtomicType
    operator()(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _BrickTag __brick_tag,
               const std::size_t __rng_n, const std::size_t __n_groups, const std::size_t __wgroup_size,
               const _AtomicType __init_value, _Predicate __pred, _Ranges&&... __rngs)
    {
        // Host-side value backing the global state; starts from the initial value.
        auto __global_state = __init_value;

        // Total number of work-items in the launch.
        const std::size_t __n_work_items = __n_groups * __wgroup_size;

        // Number of elements to be processed by each work-item.
        const auto __n_iters = oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __n_work_items);

        // The nested scope bounds the lifetime of the temporary sycl::buffer: its destruction at the closing
        // brace is the synchronization point at which the data is copied back into __global_state.
        {
            // Temporary storage for the global atomic.
            sycl::buffer<_AtomicType, 1> __state_buf(&__global_state, 1);

            __exec.queue().submit([&](sycl::handler& __cgh) {
                oneapi::dpl::__ranges::__require_access(__cgh, __rngs...);
                auto __state_acc = __state_buf.template get_access<access_mode::read_write>(__cgh);

                __cgh.parallel_for<KernelName...>(
                    sycl::nd_range</*dim=*/1>(sycl::range</*dim=*/1>(__n_work_items),
                                              sycl::range</*dim=*/1>(__wgroup_size)),
                    [=](sycl::nd_item</*dim=*/1> __item_id) {
                        // Step 1: every work-item starts from the initial found state.
                        _AtomicType __wi_state = __init_value;

                        // Step 2: look for an element satisfying the predicate.
                        // __wi_state may still hold the initial value afterwards:
                        //   - when no element satisfies the predicate;
                        //   - when an early exit from the sub-group occurred; the state is then updated
                        //     by the group operation in step 3.
                        __pred(__item_id, __rng_n, __n_iters, __n_work_items, __wi_state, __brick_tag, __rngs...);

                        // Step 3: combine the states of all work-items of the group:
                        //   - __dpl_sycl::__any_of_group for the __parallel_or_tag;
                        //   - __dpl_sycl::__minimum for the __parallel_find_forward_tag and
                        //     __dpl_sycl::__maximum for the __parallel_find_backward_tag,
                        //     applied through _BrickTag::_LocalResultsReduceOp.
                        if constexpr (__or_tag_check)
                            __wi_state = __dpl_sycl::__any_of_group(__item_id.get_group(), __wi_state);
                        else
                            __wi_state = __dpl_sycl::__reduce_over_group(
                                __item_id.get_group(), __wi_state, typename _BrickTag::_LocalResultsReduceOp{});

                        // Only the work-item with local id 0 touches the global atomic, and only when
                        // its group state differs from the initial value.
                        const bool __is_group_leader = __item_id.get_local_id(0) == 0;
                        if (__is_group_leader && __wi_state != __init_value)
                        {
                            __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::global_space>
                                __global_atomic(*__dpl_sycl::__get_accessor_ptr(__state_acc));

                            // Merge the state of this group into the state shared by all groups.
                            _BrickTag::__save_state_to_atomic(__global_atomic, __wi_state);
                        }
                    });
            });
            // End of the scope: synchronization happens on destruction of the temporary sycl::buffer.
        }

        return __global_state;
    }
};

// Base pattern for __parallel_or and __parallel_find. The execution depends on tag type _BrickTag.
template <typename _ExecutionPolicy, typename _Brick, typename _BrickTag, typename... _Ranges>
Expand All @@ -1907,12 +1923,6 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
_BrickTag __brick_tag, _Ranges&&... __rngs)
{
using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
using _FindOrKernelOneWG =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<__find_or_kernel_one_wg, _CustomName,
_Brick, _BrickTag, _Ranges...>;
using _FindOrKernel =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<__find_or_kernel, _CustomName, _Brick,
_BrickTag, _Ranges...>;

auto __rng_n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
assert(__rng_n > 0);
Expand All @@ -1935,8 +1945,11 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
// We shouldn't have any restrictions for _AtomicType type here
// because we have a single work-group and we don't need to use atomics for inter-work-group communication.

using __find_or_one_wg_kernel_name =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__find_or_kernel_one_wg<_CustomName>>;

// Single WG implementation
__result = __parallel_find_or_impl_one_wg<_FindOrKernelOneWG, __or_tag_check>(
__result = __parallel_find_or_impl_one_wg<__or_tag_check, __find_or_one_wg_kernel_name>()(
oneapi::dpl::__internal::__device_backend_tag{}, std::forward<_ExecutionPolicy>(__exec), __brick_tag,
__rng_n, __wgroup_size, __init_value, __pred, std::forward<_Ranges>(__rngs)...);
}
Expand All @@ -1945,8 +1958,11 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
assert("This device does not support 64-bit atomics" &&
(sizeof(_AtomicType) < 8 || __exec.queue().get_device().has(sycl::aspect::atomic64)));

using __find_or_kernel_name =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__find_or_kernel<_CustomName>>;

// Multiple WG implementation
__result = __parallel_find_or_impl_multiple_wgs<_FindOrKernel, __or_tag_check>(
__result = __parallel_find_or_impl_multiple_wgs<__or_tag_check, __find_or_kernel_name>()(
oneapi::dpl::__internal::__device_backend_tag{}, std::forward<_ExecutionPolicy>(__exec), __brick_tag,
__rng_n, __n_groups, __wgroup_size, __init_value, __pred, std::forward<_Ranges>(__rngs)...);
}
Expand Down

0 comments on commit ea87681

Please sign in to comment.