Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Fixed parameter mismatch in complex gemm benchmarks #489

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 165 additions & 77 deletions common/include/common/common_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,17 +491,29 @@ static inline std::vector<blas3_param_t<scalar_t>> get_blas3_params(
} else {
return parse_csv_file<blas3_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 7) {
if (v.size() == 7) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]));
} catch (...) {
throw std::runtime_error("invalid parameter");
}
} else if (v.size() == 9) {
// Case where complex alpha and beta are passed, take the real
// part only
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[7]));
} catch (...) {
throw std::runtime_error("invalid parameter");
}
} else {
throw std::runtime_error(
"invalid number of parameters (7 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]));
} catch (...) {
throw std::runtime_error("invalid parameter");
"invalid number of parameters (7 or 9 expected)");
}
});
}
Expand Down Expand Up @@ -541,18 +553,32 @@ static inline std::vector<blas3_cplx_param_t<scalar_t>> get_blas3_cplx_params(
} else {
return parse_csv_file<blas3_cplx_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 9) {
if (v.size() == 9) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]));
} catch (...) {
throw std::runtime_error("invalid parameter");
}
} else if (v.size() == 7) {
// Case where scalar alpha and beta are passed, duplicate the value
// for both real and imaginary parts
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[5]),
str_to_scalar<scalar_t>(v[6]), str_to_scalar<scalar_t>(v[6]));
} catch (...) {
throw std::runtime_error("invalid parameter");
}

} else {
throw std::runtime_error(
"invalid number of parameters (9 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]));
} catch (...) {
throw std::runtime_error("invalid parameter");
"invalid number of parameters (9 or 7 expected)");
}
});
}
Expand Down Expand Up @@ -589,20 +615,36 @@ get_gemm_batched_strided_cplx_params(Args& args) {
} else {
return parse_csv_file<gemm_batched_strided_cplx_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 13) {
if (v.size() == 13) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]),
str_to_int<index_t>(v[11]), str_to_int<index_t>(v[12]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}

} else if (v.size() == 11) {
// Case where scalar alpha and beta are passed, duplicate the value
// for both real and imaginary parts
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[5]),
str_to_scalar<scalar_t>(v[6]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_int<index_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else {
throw std::runtime_error(
"invalid number of parameters (13 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]),
str_to_int<index_t>(v[11]), str_to_int<index_t>(v[12]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
"invalid number of parameters (13 or 11 expected)");
}
});
}
Expand Down Expand Up @@ -639,19 +681,34 @@ get_gemm_cplx_batched_params(Args& args) {
} else {
return parse_csv_file<gemm_batched_cplx_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 11) {
if (v.size() == 11) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_batch_type(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else if (v.size() == 9) {
// Case where scalar alpha and beta are passed, duplicate the value
// for both real and imaginary parts
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[5]),
str_to_scalar<scalar_t>(v[6]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_batch_type(v[8]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}

} else {
throw std::runtime_error(
"invalid number of parameters (11 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_batch_type(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
"invalid number of parameters (11 or 9 expected)");
}
});
}
Expand Down Expand Up @@ -688,18 +745,32 @@ inline std::vector<gemm_batched_param_t<scalar_t>> get_gemm_batched_params(
} else {
return parse_csv_file<gemm_batched_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 9) {
if (v.size() == 9) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_batch_type(v[8]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else if (v.size() == 11) {
// Case where complex alpha and beta are passed, take the real
// part only
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[7]),
str_to_int<index_t>(v[9]), str_to_batch_type(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}

} else {
throw std::runtime_error(
"invalid number of parameters (9 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_batch_type(v[8]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
"invalid number of parameters (9 or 11 expected)");
}
});
}
Expand Down Expand Up @@ -735,19 +806,33 @@ get_gemm_batched_strided_params(Args& args) {
} else {
return parse_csv_file<gemm_batched_strided_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 11) {
if (v.size() == 11) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_int<index_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else if (v.size() == 13) {
// Case where complex alpha and beta are passed, take the real
// part only
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[7]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]),
str_to_int<index_t>(v[11]), str_to_int<index_t>(v[12]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else {
throw std::runtime_error(
"invalid number of parameters (11 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_int<index_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
"invalid number of parameters (11 or 13 expected)");
}
});
}
Expand All @@ -766,7 +851,8 @@ get_trsm_batched_params(Args& args) {
warning_no_csv();
std::vector<trsm_batched_param_t<scalar_t>> trsm_batched_default;
constexpr index_t dmin = 512, dmax = 8192;
// Stride Multipliers are set by default and correspond to default striding
// Stride Multipliers are set by default and correspond to default
// striding
constexpr index_t stride_a_mul = 1;
constexpr index_t stride_b_mul = 1;
constexpr index_t batch_size = 8;
Expand Down Expand Up @@ -806,8 +892,9 @@ get_trsm_batched_params(Args& args) {
}
/**
* @fn get_reduction_params
* @brief Returns a vector containing the reduction benchmark parameters, either
* read from a file according to the command-line args, or the default ones.
* @brief Returns a vector containing the reduction benchmark parameters,
* either read from a file according to the command-line args, or the default
* ones.
*/
template <typename scalar_t>
static inline std::vector<reduction_param_t> get_reduction_params(Args& args) {
Expand Down Expand Up @@ -920,9 +1007,9 @@ static inline std::vector<syrk_param_t<scalar_t>> get_syrk_params(Args& args) {
}
/**
* @fn get_trsm_params
* @brief Returns a vector containing the trsm benchmark parameters (also valid
* for trmm), either read from a file according to the command-line args, or the
* default ones.
* @brief Returns a vector containing the trsm benchmark parameters (also
* valid for trmm), either read from a file according to the command-line
* args, or the default ones.
*/
template <typename scalar_t>
static inline std::vector<trsm_param_t<scalar_t>> get_trsm_params(Args& args) {
Expand Down Expand Up @@ -1274,8 +1361,9 @@ static inline std::vector<matcopy_param_t<scalar_t>> get_matcopy_params(

/**
* @fn get_omatcopy2_params
* @brief Returns a vector containing the omatcopy2 benchmark parameters, either
* read from a file according to the command-line args, or the default ones.
* @brief Returns a vector containing the omatcopy2 benchmark parameters,
* either read from a file according to the command-line args, or the default
* ones.
*/
template <typename scalar_t>
static inline std::vector<omatcopy2_param_t<scalar_t>> get_omatcopy2_params(
Expand Down