From 5c08f2f6f7390bbe5f138acd25d63b0bcd56b7f8 Mon Sep 17 00:00:00 2001 From: Michael Tuttle Date: Wed, 30 Oct 2024 18:55:32 -0700 Subject: [PATCH] Allocate encoding array on heap Signed-off-by: Michael Tuttle --- .../onnx/src/QuantizeDequantizeUtils.hpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/TrainingExtensions/onnx/src/QuantizeDequantizeUtils.hpp b/TrainingExtensions/onnx/src/QuantizeDequantizeUtils.hpp index 31087f117f5..7b2f066222a 100644 --- a/TrainingExtensions/onnx/src/QuantizeDequantizeUtils.hpp +++ b/TrainingExtensions/onnx/src/QuantizeDequantizeUtils.hpp @@ -198,13 +198,14 @@ void quantizeDequantizeBroadcast(const T* inTensor, T* outTensor, const Broadcas const int64_t* encodingStrides = shapeInfo.encodingStrides.data(); // Kernels expect separate lists for each encoding type - T encVec[4][numEncodings]; + std::vector<T> encVec(4 * numEncodings); + for (int i = 0; i < numEncodings; i++) { - encVec[0][i] = encodings[i]->min; - encVec[1][i] = encodings[i]->max; - encVec[2][i] = encodings[i]->delta; - encVec[3][i] = encodings[i]->offset; + encVec[i] = encodings[i]->min; + encVec[numEncodings + i] = encodings[i]->max; + encVec[2 * numEncodings + i] = encodings[i]->delta; + encVec[3 * numEncodings + i] = encodings[i]->offset; } T* encodingVectorDevice; int64_t* stridesDevice = nullptr; @@ -217,7 +218,7 @@ void quantizeDequantizeBroadcast(const T* inTensor, T* outTensor, const Broadcas encodingVectorDevice = static_cast<T*>(allocator->allocateRaw(4 * numEncodings * sizeof(T))); // Send encoding information to device - cudaMemcpyAsync(encodingVectorDevice, encVec, 4 * numEncodings * sizeof(T), cudaMemcpyHostToDevice, + cudaMemcpyAsync(encodingVectorDevice, encVec.data(), 4 * numEncodings * sizeof(T), cudaMemcpyHostToDevice, static_cast<cudaStream_t>(stream)); // Send stride information to device @@ -237,7 +238,7 @@ void quantizeDequantizeBroadcast(const T* inTensor, T* outTensor, const Broadcas else { mode = 
DlQuantization::ComputationMode::COMP_MODE_CPU; - encodingVectorDevice = (T*) (encVec); + encodingVectorDevice = encVec.data(); } T* encodingMin = encodingVectorDevice;