From 845de21be8d1cd55ef1a85ef67064321a99f8aff Mon Sep 17 00:00:00 2001 From: Matthew Michel Date: Thu, 19 Dec 2024 19:44:48 -0600 Subject: [PATCH] Address applicable comments from PR #1870 Signed-off-by: Matthew Michel --- .../hetero/dpcpp/parallel_backend_sycl_for.h | 24 +++++++++---------- .../dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h | 2 +- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_for.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_for.h index da148da1a70..ff28eab7412 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_for.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_for.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "sycl_defs.h" #include "parallel_backend_sycl_utils.h" @@ -107,8 +108,7 @@ struct __parallel_for_large_submitter<__internal::__optional_kernel_name<_Name.. __count; const std::size_t __work_item_idx = __sub_group_start_idx + __adj_elements_per_work_item * __sub_group_local_id; - return std::make_tuple(__work_item_idx, __adj_elements_per_work_item * __sub_group_size, - __is_full_sub_group); + return std::tuple(__work_item_idx, __adj_elements_per_work_item * __sub_group_size, __is_full_sub_group); } else { @@ -119,8 +119,7 @@ struct __parallel_for_large_submitter<__internal::__optional_kernel_name<_Name.. const bool __is_full_work_group = __work_group_start_idx + __iters_per_work_item * __work_group_size * __adj_elements_per_work_item <= __count; - return std::make_tuple(__work_item_idx, __work_group_size * __adj_elements_per_work_item, - __is_full_work_group); + return std::tuple(__work_item_idx, __work_group_size * __adj_elements_per_work_item, __is_full_work_group); } } @@ -147,21 +146,20 @@ struct __parallel_for_large_submitter<__internal::__optional_kernel_name<_Name.. operator()(_ExecutionPolicy&& __exec, _Fp __brick, _Index __count, _Ranges&&... __rngs) const { assert(oneapi::dpl::__ranges::__get_first_range_size(__rngs...) > 0); + const std::size_t __work_group_size = + oneapi::dpl::__internal::__max_work_group_size(__exec, __max_work_group_size); _PRINT_INFO_IN_DEBUG_MODE(__exec); - auto __event = __exec.queue().submit([__rngs..., __brick, __exec, __count](sycl::handler& __cgh) { + auto __event = __exec.queue().submit([__rngs..., __brick, __work_group_size, __count](sycl::handler& __cgh) { //get an access to data under SYCL buffer: oneapi::dpl::__ranges::__require_access(__cgh, __rngs...); constexpr static std::uint16_t __iters_per_work_item = _Fp::__preferred_iters_per_item; - const std::size_t __work_group_size = - oneapi::dpl::__internal::__max_work_group_size(__exec, __max_work_group_size); - const std::size_t __num_groups = - oneapi::dpl::__internal::__dpl_ceiling_div(__count, (__work_group_size * _Fp::__preferred_vector_size * __iters_per_work_item)); - const std::size_t __num_items = __num_groups * __work_group_size; + const std::size_t __num_groups = oneapi::dpl::__internal::__dpl_ceiling_div( + __count, (__work_group_size * _Fp::__preferred_vector_size * __iters_per_work_item)); __cgh.parallel_for<_Name...>( - sycl::nd_range(sycl::range<1>(__num_items), sycl::range<1>(__work_group_size)), + sycl::nd_range(sycl::range<1>(__num_groups * __work_group_size), sycl::range<1>(__work_group_size)), [=](sycl::nd_item __item) { - auto [__idx, __stride, __is_full] = - __stride_recommender(__item, __count, __iters_per_work_item, _Fp::__preferred_vector_size, __work_group_size); + auto [__idx, __stride, __is_full] = __stride_recommender( + __item, __count, __iters_per_work_item, _Fp::__preferred_vector_size, __work_group_size); __strided_loop<__iters_per_work_item> __execute_loop{static_cast(__count)}; if (__is_full) { diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h index 934c0a54b90..f126df3ab9e 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h @@ -19,7 +19,7 @@ #include #include #if _ONEDPL_CPP20_RANGES_PRESENT && _ONEDPL_CPP20_CONCEPTS_PRESENT -# include // std::ranges::contiguous_range +#include // std::ranges::contiguous_range #endif #include "../../utils_ranges.h"