Skip to content

Commit

Permalink
Fix PCQ (per-channel quantization) race condition
Browse files: browse the repository at this point in the history
Signed-off-by: Michael Tuttle <quic_mtuttle@quicinc.com>
  • Loading branch information
quic-mtuttle committed Nov 12, 2024
1 parent ef829a9 commit cc24faa
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions TrainingExtensions/onnx/src/QuantizeDequantizeUtils.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -126,27 +126,29 @@ void quantizeDequantizePerChannel(
}
}

- T encVec[4][channels];
+ std::vector<T> encVec(4 * channels);

for (int i = 0; i < channels; i++)
{
- encVec[0][i] = encodings[i]->min;
- encVec[1][i] = encodings[i]->max;
- encVec[2][i] = encodings[i]->delta;
- encVec[3][i] = encodings[i]->offset;
+ encVec[i] = encodings[i]->min;
+ encVec[channels + i] = encodings[i]->max;
+ encVec[2 * channels + i] = encodings[i]->delta;
+ encVec[3 * channels + i] = encodings[i]->offset;
}
T* encodingVectorDevice;
if (useCuda)
{
#ifdef ONNX_CUDA
encodingVectorDevice = (T*) allocator->allocateRaw(4 * channels * sizeof(T));
- cudaMemcpy(encodingVectorDevice, encVec, 4 * channels * sizeof(T), cudaMemcpyHostToDevice);
+ cudaMemcpyAsync(encodingVectorDevice, encVec.data(), 4 * channels * sizeof(T), cudaMemcpyHostToDevice,
+ static_cast<cudaStream_t>(stream));
#else
throw std::runtime_error("Not compiled for GPU mode.");
#endif
}
else
{
- encodingVectorDevice = (T*) encVec;
+ encodingVectorDevice = (T*) encVec.data();
}
}

T* encodingMin = encodingVectorDevice;
Expand Down

0 comments on commit cc24faa

Please sign in to comment.