-
Notifications
You must be signed in to change notification settings - Fork 114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[oneDPL][ranges] support size limit for output for merge algorithm #1942
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -31,6 +31,7 @@ | |||||
#include "parallel_backend.h" | ||||||
#include "parallel_impl.h" | ||||||
#include "iterator_impl.h" | ||||||
#include "../functional" | ||||||
|
||||||
#if _ONEDPL_HETERO_BACKEND | ||||||
# include "hetero/algorithm_impl_hetero.h" // for __pattern_fill_n, __pattern_generate_n | ||||||
|
@@ -2948,6 +2949,49 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, | |||||
// merge | ||||||
//------------------------------------------------------------------------ | ||||||
|
||||||
template<typename It1, typename It2, typename ItOut, typename _Comp> | ||||||
std::pair<It1, It2> | ||||||
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably the existing implementation of |
||||||
/* __is_vector = */ std::false_type) | ||||||
{ | ||||||
while(__it_1 != __it_1_e && __it_2 != __it_2_e) | ||||||
{ | ||||||
if (__comp(*__it_1, *__it_2)) | ||||||
{ | ||||||
*__it_out = *__it_1; | ||||||
++__it_out, ++__it_1; | ||||||
} | ||||||
else | ||||||
{ | ||||||
*__it_out = *__it_2; | ||||||
++__it_out, ++__it_2; | ||||||
} | ||||||
if(__it_out == __it_out_e) | ||||||
return {__it_1, __it_2}; | ||||||
} | ||||||
|
||||||
if(__it_1 == __it_1_e) | ||||||
{ | ||||||
for(; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out) | ||||||
*__it_out = *__it_2; | ||||||
} | ||||||
else | ||||||
{ | ||||||
//assert(__it_2 == __it_2_e); | ||||||
for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out) | ||||||
*__it_out = *__it_1; | ||||||
} | ||||||
return {__it_1, __it_2}; | ||||||
} | ||||||
|
||||||
template<typename It1, typename It2, typename ItOut, typename _Comp> | ||||||
std::pair<It1, It2> | ||||||
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp, | ||||||
/* __is_vector = */ std::true_type) | ||||||
{ | ||||||
return __unseq_backend::__simd_merge(__it_1, __it_1_e, __it_2, __it_2_e, __it_out, __it_out_e, __comp); | ||||||
} | ||||||
|
||||||
template <class _ForwardIterator1, class _ForwardIterator2, class _OutputIterator, class _Compare> | ||||||
_OutputIterator | ||||||
__brick_merge(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2, | ||||||
|
@@ -2980,10 +3024,87 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt | |||||
typename _Tag::__is_vector{}); | ||||||
} | ||||||
|
||||||
template<class _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2, | ||||||
typename _Index2, typename _OutIt, typename _Index3, typename _Comp> | ||||||
std::pair<_It1, _It2> | ||||||
__pattern_merge_2(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2, | ||||||
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp) | ||||||
{ | ||||||
return __brick_merge_2(__it_1, __it_1 + __n_1, __it_2, __it_2 + __n_2, __it_out, __it_out + __n_out, __comp, | ||||||
typename _Tag::__is_vector{}); | ||||||
} | ||||||
|
||||||
template<typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2, | ||||||
typename _Index2, typename _OutIt, typename _Index3, typename _Comp> | ||||||
std::pair<_It1, _It2> | ||||||
__pattern_merge_2(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2, | ||||||
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp) | ||||||
{ | ||||||
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag; | ||||||
|
||||||
_It1 __it_res_1; | ||||||
_It2 __it_res_2; | ||||||
|
||||||
__internal::__except_handler([&]() { | ||||||
__par_backend::__parallel_for(__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out, | ||||||
[=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j) | ||||||
{ | ||||||
//a start merging point on the merge path; for each thread | ||||||
_Index1 __r = 0; //row index | ||||||
_Index2 __c = 0; //column index | ||||||
|
||||||
if(__i > 0) | ||||||
{ | ||||||
//calc merge path intersection: | ||||||
const _Index3 __d_size = | ||||||
std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1; | ||||||
|
||||||
auto __get_row = [__i, __n_1](auto __d) | ||||||
{ return std::min<_Index1>(__i, __n_1) - __d - 1; }; | ||||||
auto __get_column = [__i, __n_1](auto __d) | ||||||
{ return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0); }; | ||||||
|
||||||
oneapi::dpl::counting_iterator<_Index3> __it_d(0); | ||||||
|
||||||
auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1, | ||||||
[&](auto __d, auto __val) { | ||||||
auto __r = __get_row(__d); | ||||||
auto __c = __get_column(__d); | ||||||
|
||||||
oneapi::dpl::__internal::__compare<_Comp, oneapi::dpl::identity> | ||||||
__cmp{__comp, oneapi::dpl::identity{}}; | ||||||
const auto __res = (__cmp(__it_1[__r], __it_2[__c]) ? 1 : 0); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
return __res < __val; | ||||||
Comment on lines
+3076
to
+3078
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or the second variant: retrun !__cmp(__it_1[__r], __it_2[__c]); |
||||||
} | ||||||
); | ||||||
|
||||||
//intersection point | ||||||
__r = __get_row(__res_d); | ||||||
__c = __get_column(__res_d); | ||||||
++__r; //to get a merge matrix ceil, lying on the current diagonal | ||||||
} | ||||||
|
||||||
//serial merge n elements, starting from input x and y, to [i, j) output range | ||||||
auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
__it_2 + __c, __it_2 + __n_2, | ||||||
__it_out + __i, __it_out + __j, __comp, _IsVector{}); | ||||||
|
||||||
if(__j == __n_out) | ||||||
{ | ||||||
__it_res_1 = __res.first; | ||||||
__it_res_2 = __res.second; | ||||||
} | ||||||
}, _ONEDPL_MERGE_CUT_OFF); //grainsize | ||||||
}); | ||||||
|
||||||
return {__it_res_1, __it_res_2}; | ||||||
} | ||||||
|
||||||
template <class _IsVector, class _ExecutionPolicy, class _RandomAccessIterator1, class _RandomAccessIterator2, | ||||||
class _RandomAccessIterator3, class _Compare> | ||||||
_RandomAccessIterator3 | ||||||
__pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1, | ||||||
__pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1, | ||||||
_RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2, | ||||||
_RandomAccessIterator3 __d_first, _Compare __comp) | ||||||
{ | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -51,17 +51,19 @@ namespace __ranges | |||||
//------------------------------------------------------------------------ | ||||||
|
||||||
template <typename _BackendTag, typename _ExecutionPolicy, typename _Function, typename... _Ranges> | ||||||
void | ||||||
auto | ||||||
__pattern_walk_n(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Function __f, _Ranges&&... __rngs) | ||||||
{ | ||||||
auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...); | ||||||
using _Size = std::make_unsigned_t<std::common_type_t<oneapi::dpl::__internal::__difference_t<_Ranges>...>>; | ||||||
auto __n = std::min({_Size(__rngs.size())...}); | ||||||
if (__n > 0) | ||||||
{ | ||||||
oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), | ||||||
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n, | ||||||
::std::forward<_Ranges>(__rngs)...) | ||||||
.__deferrable_wait(); | ||||||
} | ||||||
return __n; | ||||||
} | ||||||
|
||||||
#if _ONEDPL_CPP20_RANGES_PRESENT | ||||||
|
@@ -680,44 +682,44 @@ struct __copy2_wrapper; | |||||
|
||||||
template <typename _BackendTag, typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, | ||||||
typename _Compare> | ||||||
oneapi::dpl::__internal::__difference_t<_Range3> | ||||||
std::pair<oneapi::dpl::__internal::__difference_t<_Range1>, oneapi::dpl::__internal::__difference_t<_Range2>> | ||||||
dmitriy-sobolev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, | ||||||
_Range3&& __rng3, _Compare __comp) | ||||||
{ | ||||||
auto __n1 = __rng1.size(); | ||||||
auto __n2 = __rng2.size(); | ||||||
auto __n = __n1 + __n2; | ||||||
if (__n == 0) | ||||||
return 0; | ||||||
if (__rng3.size() == 0) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
return {0, 0}; | ||||||
|
||||||
//To consider the direct copying pattern call in case just one of sequences is empty. | ||||||
if (__n1 == 0) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can make additional optimization here for the case when last(rng1) < first(rng2) |
||||||
{ | ||||||
oneapi::dpl::__internal::__ranges::__pattern_walk_n( | ||||||
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n( | ||||||
__tag, | ||||||
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy1_wrapper>( | ||||||
::std::forward<_ExecutionPolicy>(__exec)), | ||||||
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{}, | ||||||
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3)); | ||||||
return {0, __res}; | ||||||
} | ||||||
else if (__n2 == 0) | ||||||
|
||||||
if (__n2 == 0) | ||||||
{ | ||||||
oneapi::dpl::__internal::__ranges::__pattern_walk_n( | ||||||
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n( | ||||||
__tag, | ||||||
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy2_wrapper>( | ||||||
::std::forward<_ExecutionPolicy>(__exec)), | ||||||
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{}, | ||||||
::std::forward<_Range1>(__rng1), ::std::forward<_Range3>(__rng3)); | ||||||
} | ||||||
else | ||||||
{ | ||||||
__par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), | ||||||
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2), | ||||||
::std::forward<_Range3>(__rng3), __comp) | ||||||
.__deferrable_wait(); | ||||||
return {__res, 0}; | ||||||
} | ||||||
|
||||||
return __n; | ||||||
auto __res = __par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), | ||||||
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2), | ||||||
::std::forward<_Range3>(__rng3), __comp); | ||||||
|
||||||
auto __val = __res.get(); | ||||||
return {__val.first, __val.second}; | ||||||
} | ||||||
|
||||||
#if _ONEDPL_CPP20_RANGES_PRESENT | ||||||
|
@@ -727,21 +729,27 @@ auto | |||||
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, | ||||||
_Comp __comp, _Proj1 __proj1, _Proj2 __proj2) | ||||||
{ | ||||||
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only | ||||||
|
||||||
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp, | ||||||
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), | ||||||
std::invoke(__proj2, std::forward<decltype(__val2)>(__val2)));}; | ||||||
|
||||||
using _Index1 = std::ranges::range_difference_t<_R1>; | ||||||
using _Index2 = std::ranges::range_difference_t<_R2>; | ||||||
using _Index3 = std::ranges::range_difference_t<_OutRange>; | ||||||
|
||||||
_Index1 __n_1 = std::ranges::size(__r1); | ||||||
_Index2 __n_2 = std::ranges::size(__r2); | ||||||
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r)); | ||||||
|
||||||
auto __res = oneapi::dpl::__internal::__ranges::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec), | ||||||
oneapi::dpl::__ranges::views::all_read(__r1), oneapi::dpl::__ranges::views::all_read(__r2), | ||||||
oneapi::dpl::__ranges::views::all_write(__out_r), __comp_2); | ||||||
|
||||||
using __return_t = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>, | ||||||
std::ranges::borrowed_iterator_t<_OutRange>>; | ||||||
|
||||||
return __return_t{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) + | ||||||
std::ranges::size(__r2), std::ranges::begin(__out_r) + __res}; | ||||||
return __return_t{std::ranges::begin(__r1) + __res.first, std::ranges::begin(__r2) + __res.second, | ||||||
std::ranges::begin(__out_r) + __n_out}; | ||||||
} | ||||||
#endif //_ONEDPL_CPP20_RANGES_PRESENT | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we still have
#include <functional>
above?