diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h index 63148e0be63..41cc600dc0d 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h @@ -28,10 +28,12 @@ #include #include #include +#include #include "../../iterator_impl.h" #include "../../execution_impl.h" #include "../../utils_ranges.h" +#include "../../utils.h" #include "sycl_defs.h" #include "parallel_backend_sycl_utils.h" @@ -258,12 +260,8 @@ struct __parallel_for_large_submitter; template struct __parallel_for_large_submitter<__internal::__optional_kernel_name<_Name...>, _RangeTypes...> { - // Flatten the range as std::tuple value types in the range are likely coming from separate ranges in a zip - // iterator. - using _FlattenedRangesTuple = typename oneapi::dpl::__internal::__flatten_std_or_internal_tuple< - std::tuple...>>::type; - static constexpr std::size_t __min_type_size = - oneapi::dpl::__internal::__min_tuple_type_size_v<_FlattenedRangesTuple>; + static constexpr std::size_t __min_type_size = oneapi::dpl::__internal::__min_nested_type_size< + std::tuple...>>::value; // __iters_per_work_item is set to 1, 2, 4, 8, or 16 depending on the smallest type in the // flattened ranges. This allows us to launch enough work per item to saturate the device's memory // bandwidth. This heuristic errs on the side of launching more work per item than what is needed to diff --git a/include/oneapi/dpl/pstl/tuple_impl.h b/include/oneapi/dpl/pstl/tuple_impl.h index c758a4a3f1b..239734d4861 100644 --- a/include/oneapi/dpl/pstl/tuple_impl.h +++ b/include/oneapi/dpl/pstl/tuple_impl.h @@ -793,25 +793,6 @@ struct __decay_with_tuple_specialization<::std::tuple<_Args...>> template using __decay_with_tuple_specialization_t = typename __decay_with_tuple_specialization<_Args...>::type; -// Flatten nested std::tuple or oneapi::dpl::__internal::tuple types into a single std::tuple. -template -struct __flatten_std_or_internal_tuple -{ - using type = std::tuple<_T>; -}; - -template -struct __flatten_std_or_internal_tuple> -{ - using type = decltype(std::tuple_cat(std::declval::type>()...)); -}; - -template -struct __flatten_std_or_internal_tuple> -{ - using type = decltype(std::tuple_cat(std::declval::type>()...)); -}; - } // namespace __internal } // namespace dpl } // namespace oneapi diff --git a/include/oneapi/dpl/pstl/utils.h b/include/oneapi/dpl/pstl/utils.h index 10d60d8c5d6..1848d33eaea 100644 --- a/include/oneapi/dpl/pstl/utils.h +++ b/include/oneapi/dpl/pstl/utils.h @@ -25,6 +25,7 @@ #include #include #include +#include #if _ONEDPL_BACKEND_SYCL # include "hetero/dpcpp/sycl_defs.h" @@ -784,29 +785,21 @@ union __lazy_ctor_storage } }; -// Utility that returns the smallest type size in a tuple. -template -class __min_tuple_type_size; - +// Returns the smallest type within a set of potentially nested template types. +// E.g. If we consider the type: T = tuple, int, double>, +// then __min_nested_type_size::value returns sizeof(short). template -class __min_tuple_type_size> +struct __min_nested_type_size { - public: - static constexpr std::size_t value = sizeof(_T); + constexpr static std::size_t value = sizeof(_T); }; -template -class __min_tuple_type_size> +template