Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[oneDPL][ranges] support size limit for output for merge algorithm #1942

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion include/oneapi/dpl/pstl/algorithm_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "parallel_backend.h"
#include "parallel_impl.h"
#include "iterator_impl.h"
#include "../functional"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we still have #include <functional> above?


#if _ONEDPL_HETERO_BACKEND
# include "hetero/algorithm_impl_hetero.h" // for __pattern_fill_n, __pattern_generate_n
Expand Down Expand Up @@ -2948,6 +2949,49 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec,
// merge
//------------------------------------------------------------------------

template<typename It1, typename It2, typename ItOut, typename _Comp>
std::pair<It1, It2>
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably the existing implementation of __serial_merge is more faster then this.

/* __is_vector = */ std::false_type)
{
while(__it_1 != __it_1_e && __it_2 != __it_2_e)
{
if (__comp(*__it_1, *__it_2))
{
*__it_out = *__it_1;
++__it_out, ++__it_1;
}
else
{
*__it_out = *__it_2;
++__it_out, ++__it_2;
}
if(__it_out == __it_out_e)
return {__it_1, __it_2};
}

if(__it_1 == __it_1_e)
{
for(; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
*__it_out = *__it_2;
}
else
{
//assert(__it_2 == __it_2_e);
for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
*__it_out = *__it_1;
}
return {__it_1, __it_2};
}

template<typename It1, typename It2, typename ItOut, typename _Comp>
std::pair<It1, It2>
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
/* __is_vector = */ std::true_type)
{
return __unseq_backend::__simd_merge(__it_1, __it_1_e, __it_2, __it_2_e, __it_out, __it_out_e, __comp);
}

template <class _ForwardIterator1, class _ForwardIterator2, class _OutputIterator, class _Compare>
_OutputIterator
__brick_merge(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2,
Expand Down Expand Up @@ -2980,10 +3024,87 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt
typename _Tag::__is_vector{});
}

template<class _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
std::pair<_It1, _It2>
__pattern_merge_2(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
{
return __brick_merge_2(__it_1, __it_1 + __n_1, __it_2, __it_2 + __n_2, __it_out, __it_out + __n_out, __comp,
typename _Tag::__is_vector{});
}

template<typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
std::pair<_It1, _It2>
__pattern_merge_2(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
{
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;

_It1 __it_res_1;
_It2 __it_res_2;

__internal::__except_handler([&]() {
__par_backend::__parallel_for(__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
[=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j)
{
//a start merging point on the merge path; for each thread
_Index1 __r = 0; //row index
_Index2 __c = 0; //column index

if(__i > 0)
{
//calc merge path intersection:
const _Index3 __d_size =
std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;

auto __get_row = [__i, __n_1](auto __d)
{ return std::min<_Index1>(__i, __n_1) - __d - 1; };
auto __get_column = [__i, __n_1](auto __d)
{ return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0); };

oneapi::dpl::counting_iterator<_Index3> __it_d(0);

auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1,
[&](auto __d, auto __val) {
auto __r = __get_row(__d);
auto __c = __get_column(__d);

oneapi::dpl::__internal::__compare<_Comp, oneapi::dpl::identity>
__cmp{__comp, oneapi::dpl::identity{}};
const auto __res = (__cmp(__it_1[__r], __it_2[__c]) ? 1 : 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const auto __res = (__cmp(__it_1[__r], __it_2[__c]) ? 1 : 0);
const auto __res = __cmp(__it_1[__r], __it_2[__c]) ? 1 : 0;


return __res < __val;
Comment on lines +3076 to +3078
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or the second variant:

retrun !__cmp(__it_1[__r], __it_2[__c]);

}
);

//intersection point
__r = __get_row(__res_d);
__c = __get_column(__res_d);
++__r; //to get a merge matrix ceil, lying on the current diagonal
}

//serial merge n elements, starting from input x and y, to [i, j) output range
auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1,
const auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1,

__it_2 + __c, __it_2 + __n_2,
__it_out + __i, __it_out + __j, __comp, _IsVector{});

if(__j == __n_out)
{
__it_res_1 = __res.first;
__it_res_2 = __res.second;
}
}, _ONEDPL_MERGE_CUT_OFF); //grainsize
});

return {__it_res_1, __it_res_2};
}

template <class _IsVector, class _ExecutionPolicy, class _RandomAccessIterator1, class _RandomAccessIterator2,
class _RandomAccessIterator3, class _Compare>
_RandomAccessIterator3
__pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1,
__pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1,
_RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2,
_RandomAccessIterator3 __d_first, _Compare __comp)
{
Expand Down
36 changes: 18 additions & 18 deletions include/oneapi/dpl/pstl/algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,31 +448,31 @@ auto
__pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
_Proj1 __proj1, _Proj2 __proj2)
{
static_assert(__is_parallel_tag_v<_Tag> || typename _Tag::__is_vector{});
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only

using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
std::forward<decltype(__val2)>(__val2)));};

auto __res = oneapi::dpl::__internal::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
std::ranges::begin(__r1), std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2),
std::ranges::begin(__r2) + std::ranges::size(__r2), std::ranges::begin(__out_r), __comp_2);
using _Index1 = std::ranges::range_difference_t<_R1>;
using _Index2 = std::ranges::range_difference_t<_R2>;
using _Index3 = std::ranges::range_difference_t<_OutRange>;

using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;
_Index1 __n_1 = std::ranges::size(__r1);
_Index2 __n_2 = std::ranges::size(__r2);
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));

return __return_type{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) + std::ranges::size(__r2), __res};
}
auto __it_1 = std::ranges::begin(__r1);
auto __it_2 = std::ranges::begin(__r2);
auto __it_out = std::ranges::begin(__out_r);

template<typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
typename _Proj1, typename _Proj2>
auto
__pattern_merge(__serial_tag</*IsVector*/std::false_type>, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
_Proj1 __proj1, _Proj2 __proj2)
{
return std::ranges::merge(std::forward<_R1>(__r1), std::forward<_R2>(__r2), std::ranges::begin(__out_r), __comp, __proj1,
__proj2);
if(__n_out == 0)
return __return_type{__it_1, __it_2, __it_out};

auto __res = __pattern_merge_2(__tag, std::forward<_ExecutionPolicy>(__exec), __it_2, __n_2, __it_1, __n_1, __it_out, __n_out, __comp_2);

return __return_type{__res.second, __res.first, __it_out + __n_out};
}

} // namespace __ranges
Expand Down
7 changes: 5 additions & 2 deletions include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1173,9 +1173,12 @@ merge(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& _
{
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec, __rng1, __rng2, __rng3);

return oneapi::dpl::__internal::__ranges::__pattern_merge(
auto __view_res = views::all_write(::std::forward<_Range3>(__rng3));
oneapi::dpl::__internal::__ranges::__pattern_merge(
__dispatch_tag, ::std::forward<_ExecutionPolicy>(__exec), views::all_read(::std::forward<_Range1>(__rng1)),
views::all_read(::std::forward<_Range2>(__rng2)), views::all_write(::std::forward<_Range3>(__rng3)), __comp);
views::all_read(::std::forward<_Range2>(__rng2)), __view_res, __comp);

return __view_res.size();
dmitriy-sobolev marked this conversation as resolved.
Show resolved Hide resolved
dmitriy-sobolev marked this conversation as resolved.
Show resolved Hide resolved
}

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3>
Expand Down
50 changes: 29 additions & 21 deletions include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,19 @@ namespace __ranges
//------------------------------------------------------------------------

template <typename _BackendTag, typename _ExecutionPolicy, typename _Function, typename... _Ranges>
void
auto
__pattern_walk_n(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Function __f, _Ranges&&... __rngs)
{
auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
using _Size = std::make_unsigned_t<std::common_type_t<oneapi::dpl::__internal::__difference_t<_Ranges>...>>;
auto __n = std::min({_Size(__rngs.size())...});
if (__n > 0)
{
oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n,
::std::forward<_Ranges>(__rngs)...)
.__deferrable_wait();
}
return __n;
}

#if _ONEDPL_CPP20_RANGES_PRESENT
Expand Down Expand Up @@ -680,44 +682,44 @@ struct __copy2_wrapper;

template <typename _BackendTag, typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3,
typename _Compare>
oneapi::dpl::__internal::__difference_t<_Range3>
std::pair<oneapi::dpl::__internal::__difference_t<_Range1>, oneapi::dpl::__internal::__difference_t<_Range2>>
dmitriy-sobolev marked this conversation as resolved.
Show resolved Hide resolved
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2,
_Range3&& __rng3, _Compare __comp)
{
auto __n1 = __rng1.size();
auto __n2 = __rng2.size();
auto __n = __n1 + __n2;
if (__n == 0)
return 0;
if (__rng3.size() == 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (__rng3.size() == 0)
if (__rng3.empty())

return {0, 0};

//To consider the direct copying pattern call in case just one of sequences is empty.
if (__n1 == 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make additional optimization here for the case when last(rng1) < first(rng2)

{
oneapi::dpl::__internal::__ranges::__pattern_walk_n(
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n(
__tag,
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy1_wrapper>(
::std::forward<_ExecutionPolicy>(__exec)),
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3));
return {0, __res};
}
else if (__n2 == 0)

if (__n2 == 0)
{
oneapi::dpl::__internal::__ranges::__pattern_walk_n(
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n(
__tag,
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy2_wrapper>(
::std::forward<_ExecutionPolicy>(__exec)),
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
::std::forward<_Range1>(__rng1), ::std::forward<_Range3>(__rng3));
}
else
{
__par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
::std::forward<_Range3>(__rng3), __comp)
.__deferrable_wait();
return {__res, 0};
}

return __n;
auto __res = __par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
::std::forward<_Range3>(__rng3), __comp);

auto __val = __res.get();
return {__val.first, __val.second};
}

#if _ONEDPL_CPP20_RANGES_PRESENT
Expand All @@ -727,21 +729,27 @@ auto
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r,
_Comp __comp, _Proj1 __proj1, _Proj2 __proj2)
{
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only

auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)),
std::invoke(__proj2, std::forward<decltype(__val2)>(__val2)));};

using _Index1 = std::ranges::range_difference_t<_R1>;
using _Index2 = std::ranges::range_difference_t<_R2>;
using _Index3 = std::ranges::range_difference_t<_OutRange>;

_Index1 __n_1 = std::ranges::size(__r1);
_Index2 __n_2 = std::ranges::size(__r2);
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));

auto __res = oneapi::dpl::__internal::__ranges::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
oneapi::dpl::__ranges::views::all_read(__r1), oneapi::dpl::__ranges::views::all_read(__r2),
oneapi::dpl::__ranges::views::all_write(__out_r), __comp_2);

using __return_t = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
std::ranges::borrowed_iterator_t<_OutRange>>;

return __return_t{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) +
std::ranges::size(__r2), std::ranges::begin(__out_r) + __res};
return __return_t{std::ranges::begin(__r1) + __res.first, std::ranges::begin(__r2) + __res.second,
std::ranges::begin(__out_r) + __n_out};
}
#endif //_ONEDPL_CPP20_RANGES_PRESENT

Expand Down
Loading
Loading