Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[oneDPL][ranges] An implementation fix for a random access range without operator[] and size() method #1969

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,53 @@ namespace dpl
namespace ranges
{

namespace __internal
{

template<typename _T>
concept __is_subscriptable = requires(_T&& __a) { __a[0]; };

template<typename _T>
concept __is_not_subscriptable = !__is_subscriptable<_T>;

template<typename _T>
concept __is_sizeable = requires(_T&& __a) { __a.size(); };

template<typename _T>
concept __is_empty_method = requires(_T&& __a) { __a.empty(); };

template <typename _R>
struct _WrapperRAR: public _R
{
template <typename _Base>
_WrapperRAR(_Base&& __r): _R(std::forward<_Base>(__r)) {}
decltype(auto) operator[](auto __i) { return _R::begin()[__i]; }
decltype(auto) operator[](auto __i) const { return _R::begin()[__i]; }

std::enable_if_t<!__is_sizeable<_R>, std::ranges::range_size_t<_R>>
size() const { return this->_R::end() - this->_R::begin(); }

std::enable_if_t<!__is_empty_method<_R>, bool>
empty() const { return this->_R::end() - this->_R::begin() <= 0; }
};

template <__is_not_subscriptable _R>
constexpr decltype (auto)
__get_r(_R&& __r)
{
using _T = std::remove_reference_t<_R>;
return _WrapperRAR<_T>(std::forward<_R>(__r));
}

template <__is_subscriptable _R>
constexpr decltype (auto)
__get_r(_R&& __r)
{
return std::forward<_R>(__r);
}

} //__internal

// [alg.foreach]

namespace __internal
Expand All @@ -64,7 +111,7 @@ struct __for_each_fn
{
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec);
oneapi::dpl::__internal::__ranges::__pattern_for_each(
__dispatch_tag, std::forward<_ExecutionPolicy>(__exec), __r, __f, __proj);
__dispatch_tag, std::forward<_ExecutionPolicy>(__exec), __get_r(__r), __f, __proj);

return {std::ranges::begin(__r) + std::ranges::size(__r)};
}
Expand Down Expand Up @@ -149,8 +196,11 @@ struct __find_if_fn
operator()(_ExecutionPolicy&& __exec, _R&& __r, _Pred __pred, _Proj __proj = {}) const
{
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec);
return oneapi::dpl::__internal::__ranges::__pattern_find_if(__dispatch_tag,
std::forward<_ExecutionPolicy>(__exec), std::forward<_R>(__r), __pred, __proj);

auto __ra = __get_r(__r);
auto __res = oneapi::dpl::__internal::__ranges::__pattern_find_if(__dispatch_tag,
std::forward<_ExecutionPolicy>(__exec), __ra, __pred, __proj) - __ra.begin();
return __r.begin() + __res;
}
}; //__find_if_fn
} //__internal
Expand Down Expand Up @@ -213,7 +263,7 @@ struct __any_of_fn
{
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec);
return oneapi::dpl::__internal::__ranges::__pattern_any_of(__dispatch_tag,
std::forward<_ExecutionPolicy>(__exec), std::forward<_R>(__r), __pred, __proj);
std::forward<_ExecutionPolicy>(__exec), __get_r(std::forward<_R>(__r)), __pred, __proj);
}
}; //__any_of_fn
} //__internal
Expand Down
1 change: 1 addition & 0 deletions test/parallel_api/ranges/std_ranges_for_each.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ main()
test_range_algo<1>{}(dpl_ranges::for_each, for_each_checker, f_mutuable, proj_mutuable);
test_range_algo<2, P2>{}(dpl_ranges::for_each, for_each_checker, f_mutuable, &P2::x);
test_range_algo<3, P2>{}(dpl_ranges::for_each, for_each_checker, f_mutuable, &P2::proj);

#endif //_ENABLE_STD_RANGES_TESTING

return TestUtils::done(_ENABLE_STD_RANGES_TESTING);
Expand Down
24 changes: 24 additions & 0 deletions test/parallel_api/ranges/std_ranges_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,28 @@ template<typename T>
static constexpr
bool is_range<T, std::void_t<decltype(std::declval<T&>().begin())>> = true;

//a random access range, but without operator[] and size() method
template<typename R>
struct RangeRA
{
RangeRA(R r): m_r(r) {}
R m_r;
auto begin() { return m_r.begin(); }
auto end() { return m_r.end(); }
auto begin() const { return m_r.begin(); }
auto end() const { return m_r.end(); }
};

struct __range_ra_fn
{
template<typename R>
RangeRA<R>
operator()(R r)
{
return RangeRA<R>(r);
}
} __range_ra_wr;

template<typename DataType, typename Container, TestDataMode test_mode = data_in>
struct test
{
Expand Down Expand Up @@ -536,6 +558,7 @@ struct test_range_algo
test<T, host_vector<T>, mode>{max_n}(host_policies(), algo, checker, subrange_view, std::identity{}, args...);
test<T, host_vector<T>, mode>{max_n}(host_policies(), algo, checker, std::views::all, std::identity{}, args...);
test<T, host_subrange<T>, mode>{max_n}(host_policies(), algo, checker, std::views::all, std::identity{}, args...);
test<T, host_subrange<T>, mode>{max_n}(host_policies(), algo, checker, __range_ra_wr, __range_ra_wr, args...);
#if TEST_CPP20_SPAN_PRESENT
test<T, host_vector<T>, mode>{max_n}(host_policies(), algo, checker, span_view, std::identity{}, args...);
test<T, host_span<T>, mode>{max_n}(host_policies(), algo, checker, std::views::all, std::identity{}, args...);
Expand All @@ -551,6 +574,7 @@ struct test_range_algo
{
test<T, usm_vector<T>, mode>{max_n}(dpcpp_policy<call_id + 10>(), algo, checker, subrange_view, subrange_view, args...);
test<T, usm_subrange<T>, mode>{max_n}(dpcpp_policy<call_id + 30>(), algo, checker, std::identity{}, std::identity{}, args...);
test<T, usm_subrange<T>, mode>{max_n}(dpcpp_policy<call_id + 35>(), algo, checker, __range_ra_wr, __range_ra_wr, args...);
#if TEST_CPP20_SPAN_PRESENT
test<T, usm_vector<T>, mode>{max_n}(dpcpp_policy<call_id + 20>(), algo, checker, span_view, subrange_view, args...);
test<T, usm_span<T>, mode>{max_n}(dpcpp_policy<call_id + 40>(), algo, checker, std::identity{}, std::identity{}, args...);
Expand Down
Loading