Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Fixed parameter mismatch in complex gemm benchmarks (#489)
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki authored Jan 3, 2024
1 parent 2ea49db commit 29e2c93
Showing 1 changed file with 165 additions and 77 deletions.
242 changes: 165 additions & 77 deletions common/include/common/common_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,17 +491,29 @@ static inline std::vector<blas3_param_t<scalar_t>> get_blas3_params(
} else {
return parse_csv_file<blas3_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 7) {
if (v.size() == 7) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]));
} catch (...) {
throw std::runtime_error("invalid parameter");
}
} else if (v.size() == 9) {
// Case where complex alpha and beta are passed, take the real
// part only
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[7]));
} catch (...) {
throw std::runtime_error("invalid parameter");
}
} else {
throw std::runtime_error(
"invalid number of parameters (7 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]));
} catch (...) {
throw std::runtime_error("invalid parameter");
"invalid number of parameters (7 or 9 expected)");
}
});
}
Expand Down Expand Up @@ -541,18 +553,32 @@ static inline std::vector<blas3_cplx_param_t<scalar_t>> get_blas3_cplx_params(
} else {
return parse_csv_file<blas3_cplx_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 9) {
if (v.size() == 9) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]));
} catch (...) {
throw std::runtime_error("invalid parameter");
}
} else if (v.size() == 7) {
// Case where scalar alpha and beta are passed, duplicate the value
// for both real and imaginary parts
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[5]),
str_to_scalar<scalar_t>(v[6]), str_to_scalar<scalar_t>(v[6]));
} catch (...) {
throw std::runtime_error("invalid parameter");
}

} else {
throw std::runtime_error(
"invalid number of parameters (9 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]));
} catch (...) {
throw std::runtime_error("invalid parameter");
"invalid number of parameters (9 or 7 expected)");
}
});
}
Expand Down Expand Up @@ -589,20 +615,36 @@ get_gemm_batched_strided_cplx_params(Args& args) {
} else {
return parse_csv_file<gemm_batched_strided_cplx_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 13) {
if (v.size() == 13) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]),
str_to_int<index_t>(v[11]), str_to_int<index_t>(v[12]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}

} else if (v.size() == 11) {
// Case where scalar alpha and beta are passed, duplicate the value
// for both real and imaginary parts
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[5]),
str_to_scalar<scalar_t>(v[6]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_int<index_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else {
throw std::runtime_error(
"invalid number of parameters (13 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]),
str_to_int<index_t>(v[11]), str_to_int<index_t>(v[12]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
"invalid number of parameters (13 or 11 expected)");
}
});
}
Expand Down Expand Up @@ -639,19 +681,34 @@ get_gemm_cplx_batched_params(Args& args) {
} else {
return parse_csv_file<gemm_batched_cplx_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 11) {
if (v.size() == 11) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_batch_type(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else if (v.size() == 9) {
// Case where scalar alpha and beta are passed, duplicate the value
// for both real and imaginary parts
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[5]),
str_to_scalar<scalar_t>(v[6]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_batch_type(v[8]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}

} else {
throw std::runtime_error(
"invalid number of parameters (11 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_scalar<scalar_t>(v[7]), str_to_scalar<scalar_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_batch_type(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
"invalid number of parameters (11 or 9 expected)");
}
});
}
Expand Down Expand Up @@ -688,18 +745,32 @@ inline std::vector<gemm_batched_param_t<scalar_t>> get_gemm_batched_params(
} else {
return parse_csv_file<gemm_batched_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 9) {
if (v.size() == 9) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_batch_type(v[8]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else if (v.size() == 11) {
// Case where complex alpha and beta are passed, take the real
// part only
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[7]),
str_to_int<index_t>(v[9]), str_to_batch_type(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}

} else {
throw std::runtime_error(
"invalid number of parameters (9 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_batch_type(v[8]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
"invalid number of parameters (9 or 11 expected)");
}
});
}
Expand Down Expand Up @@ -735,19 +806,33 @@ get_gemm_batched_strided_params(Args& args) {
} else {
return parse_csv_file<gemm_batched_strided_param_t<scalar_t>>(
args.csv_param, [&](std::vector<std::string>& v) {
if (v.size() != 11) {
if (v.size() == 11) {
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_int<index_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else if (v.size() == 13) {
// Case where complex alpha and beta are passed, take the real
// part only
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[7]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]),
str_to_int<index_t>(v[11]), str_to_int<index_t>(v[12]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
}
} else {
throw std::runtime_error(
"invalid number of parameters (11 expected)");
}
try {
return std::make_tuple(
v[0].c_str(), v[1].c_str(), str_to_int<index_t>(v[2]),
str_to_int<index_t>(v[3]), str_to_int<index_t>(v[4]),
str_to_scalar<scalar_t>(v[5]), str_to_scalar<scalar_t>(v[6]),
str_to_int<index_t>(v[7]), str_to_int<index_t>(v[8]),
str_to_int<index_t>(v[9]), str_to_int<index_t>(v[10]));
} catch (...) {
std::throw_with_nested(std::runtime_error("invalid parameter"));
"invalid number of parameters (11 or 13 expected)");
}
});
}
Expand All @@ -766,7 +851,8 @@ get_trsm_batched_params(Args& args) {
warning_no_csv();
std::vector<trsm_batched_param_t<scalar_t>> trsm_batched_default;
constexpr index_t dmin = 512, dmax = 8192;
// Stride Multipliers are set by default and correspond to default striding
// Stride Multipliers are set by default and correspond to default
// striding
constexpr index_t stride_a_mul = 1;
constexpr index_t stride_b_mul = 1;
constexpr index_t batch_size = 8;
Expand Down Expand Up @@ -806,8 +892,9 @@ get_trsm_batched_params(Args& args) {
}
/**
* @fn get_reduction_params
* @brief Returns a vector containing the reduction benchmark parameters, either
* read from a file according to the command-line args, or the default ones.
* @brief Returns a vector containing the reduction benchmark parameters,
* either read from a file according to the command-line args, or the default
* ones.
*/
template <typename scalar_t>
static inline std::vector<reduction_param_t> get_reduction_params(Args& args) {
Expand Down Expand Up @@ -920,9 +1007,9 @@ static inline std::vector<syrk_param_t<scalar_t>> get_syrk_params(Args& args) {
}
/**
* @fn get_trsm_params
* @brief Returns a vector containing the trsm benchmark parameters (also valid
* for trmm), either read from a file according to the command-line args, or the
* default ones.
* @brief Returns a vector containing the trsm benchmark parameters (also
* valid for trmm), either read from a file according to the command-line
* args, or the default ones.
*/
template <typename scalar_t>
static inline std::vector<trsm_param_t<scalar_t>> get_trsm_params(Args& args) {
Expand Down Expand Up @@ -1274,8 +1361,9 @@ static inline std::vector<matcopy_param_t<scalar_t>> get_matcopy_params(

/**
* @fn get_omatcopy2_params
* @brief Returns a vector containing the omatcopy2 benchmark parameters, either
* read from a file according to the command-line args, or the default ones.
* @brief Returns a vector containing the omatcopy2 benchmark parameters,
* either read from a file according to the command-line args, or the default
* ones.
*/
template <typename scalar_t>
static inline std::vector<omatcopy2_param_t<scalar_t>> get_omatcopy2_params(
Expand Down

0 comments on commit 29e2c93

Please sign in to comment.