Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: Enable Proton for XPU #2635

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/build-test-reusable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ jobs:
echo TRITON_TEST_CMD="bash -v -x scripts/test-triton.sh --warning-reports --skip-pytorch-install --reports-dir $GITHUB_WORKSPACE/reports ${{ inputs.ignore_errors && '--ignore-errors' || '' }} $skiplist"
} | tee -a $GITHUB_ENV
- name: Run Proton tests
if: ${{ inputs.driver_version == 'rolling' }}
run: |
cd third_party/proton/test
pytest test_api.py test_lib.py test_profile.py test_viewer.py -s -v
cd ..
- name: Run unit tests
run: |
${{ env.TRITON_TEST_CMD }} --unit
Expand Down
10 changes: 10 additions & 0 deletions third_party/proton/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ endif()
include_directories(${JSON_INCLUDE_DIR})
include_directories(${PROTON_SRC_DIR}/include)
include_directories(${PROTON_EXTERN_DIR})
include_directories(/opt/intel/oneapi/pti/latest/include)
include_directories(/opt/intel/oneapi/compiler/latest/include)
include_directories(/opt/intel/oneapi/compiler/latest/include/sycl)

find_package(Python3 REQUIRED Interpreter Development.Module)
find_package(pybind11 CONFIG REQUIRED HINTS "${Python3_SITELIB}")
Expand All @@ -38,5 +41,12 @@ include_directories(${CUPTI_INCLUDE_DIR})
include_directories(SYSTEM ${ROCTRACER_INCLUDE_DIR})
target_compile_definitions(proton PRIVATE __HIP_PLATFORM_AMD__)

set_target_properties(proton PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
# set_target_properties(proton PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
target_link_libraries(proton PRIVATE Python3::Module pybind11::headers)
target_link_libraries(proton PRIVATE /opt/intel/oneapi/compiler/latest/lib/libsycl.so)
target_link_libraries(proton PRIVATE /usr/lib/x86_64-linux-gnu/libze_intel_gpu.so.1)
target_link_libraries(proton PRIVATE /usr/lib/x86_64-linux-gnu/libze_tracing_layer.so.1)
target_link_libraries(proton PRIVATE /usr/lib/x86_64-linux-gnu/libze_loader.so.1)
target_link_libraries(proton PRIVATE /opt/intel/oneapi/pti/latest/lib/libpti_view.so)
target_link_options(proton PRIVATE ${PROTON_PYTHON_LDFLAGS})
1 change: 0 additions & 1 deletion third_party/proton/csrc/include/Context/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <atomic>
#include <limits>
#include <map>
#include <mutex>
#include <optional>
#include <string>
#include <vector>
Expand Down
1 change: 1 addition & 0 deletions third_party/proton/csrc/include/Data/Metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define PROTON_DATA_METRIC_H_

#include "Utility/Traits.h"
#include <string>
#include <variant>
#include <vector>

Expand Down
1 change: 0 additions & 1 deletion third_party/proton/csrc/include/Data/TreeData.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "Context/Context.h"
#include "Data.h"
#include <stdexcept>
#include <unordered_map>

namespace proton {
Expand Down
7 changes: 6 additions & 1 deletion third_party/proton/csrc/include/Driver/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace proton {

enum class DeviceType { HIP, CUDA, COUNT };
enum class DeviceType { XPU, HIP, CUDA, COUNT };

template <DeviceType T> struct DeviceTraits;

Expand All @@ -20,6 +20,11 @@ template <> struct DeviceTraits<DeviceType::HIP> {
constexpr static const char *name = "HIP";
};

template <> struct DeviceTraits<DeviceType::XPU> {
constexpr static DeviceType type = DeviceType::XPU;
constexpr static const char *name = "XPU";
};

struct Device {
DeviceType type;
uint64_t id;
Expand Down
30 changes: 30 additions & 0 deletions third_party/proton/csrc/include/Driver/GPU/XpuApi.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef PROTON_DRIVER_GPU_SYCL_H_
#define PROTON_DRIVER_GPU_SYCL_H_

#include "Driver/Device.h"
#include <level_zero/ze_api.h>

namespace proton {

namespace xpu {

template <bool CheckSuccess> ze_result_t init(ze_init_flags_t flags);

template <bool CheckSuccess>
ze_result_t ctxSynchronize(ze_command_queue_handle_t hCommandQueue,
uint64_t timeout);

/*

template <bool CheckSuccess> CUresult ctxGetCurrent(CUcontext *pctx);

template <bool CheckSuccess> CUresult deviceGet(CUdevice *device, int ordinal);
*/

Device getDevice(uint64_t index);

} // namespace xpu

} // namespace proton

#endif // PROTON_DRIVER_GPU_SYCL_H_
116 changes: 116 additions & 0 deletions third_party/proton/csrc/include/Driver/GPU/XpuptiApi.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#ifndef PROTON_DRIVER_GPU_XPUPTI_H_
#define PROTON_DRIVER_GPU_XPUPTI_H_

#include <pti/pti_view.h>

namespace proton {

namespace xpupti {

using Pti_Activity = pti_view_record_base;

/*
template <bool CheckSuccess> CUptiResult getVersion(uint32_t *version);

template <bool CheckSuccess>
CUptiResult getContextId(CUcontext context, uint32_t *pCtxId);

template <bool CheckSuccess>
CUptiResult activityRegisterCallbacks(
CUpti_BuffersCallbackRequestFunc funcBufferRequested,
CUpti_BuffersCallbackCompleteFunc funcBufferCompleted);

template <bool CheckSuccess>
CUptiResult subscribe(CUpti_SubscriberHandle *subscriber,
CUpti_CallbackFunc callback, void *userdata);

template <bool CheckSuccess>
CUptiResult enableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain);

template <bool CheckSuccess>
CUptiResult enableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber,
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);

template <bool CheckSuccess>
CUptiResult activityEnableContext(CUcontext context, CUpti_ActivityKind kind);

template <bool CheckSuccess>
CUptiResult activityDisableContext(CUcontext context, CUpti_ActivityKind kind);
*/

template <bool CheckSuccess> pti_result viewEnable(pti_view_kind kind);

template <bool CheckSuccess> pti_result viewDisable(pti_view_kind kind);

template <bool CheckSuccess> pti_result viewFlushAll();

/*
template <bool CheckSuccess>
CUptiResult activityGetNextRecord(uint8_t *buffer, size_t validBufferSizeBytes,
CUpti_Activity **record);

template <bool CheckSuccess>
CUptiResult
activityPushExternalCorrelationId(CUpti_ExternalCorrelationKind kind,
uint64_t id);

template <bool CheckSuccess>
CUptiResult activityPopExternalCorrelationId(CUpti_ExternalCorrelationKind kind,
uint64_t *lastId);

template <bool CheckSuccess>
CUptiResult activitySetAttribute(CUpti_ActivityAttribute attr,
size_t *valueSize, void *value);

template <bool CheckSuccess>
CUptiResult unsubscribe(CUpti_SubscriberHandle subscriber);

template <bool CheckSuccess> CUptiResult finalize();

template <bool CheckSuccess>
CUptiResult getGraphExecId(CUgraphExec graph, uint32_t *pId);

template <bool CheckSuccess>
CUptiResult getGraphId(CUgraph graph, uint32_t *pId);

template <bool CheckSuccess>
CUptiResult getCubinCrc(CUpti_GetCubinCrcParams *pParams);

template <bool CheckSuccess>
CUptiResult
getSassToSourceCorrelation(CUpti_GetSassToSourceCorrelationParams *pParams);

template <bool CheckSuccess>
CUptiResult
pcSamplingGetNumStallReasons(CUpti_PCSamplingGetNumStallReasonsParams *pParams);

template <bool CheckSuccess>
CUptiResult
pcSamplingGetStallReasons(CUpti_PCSamplingGetStallReasonsParams *pParams);

template <bool CheckSuccess>
CUptiResult pcSamplingSetConfigurationAttribute(
CUpti_PCSamplingConfigurationInfoParams *pParams);

template <bool CheckSuccess>
CUptiResult pcSamplingEnable(CUpti_PCSamplingEnableParams *pParams);

template <bool CheckSuccess>
CUptiResult pcSamplingDisable(CUpti_PCSamplingDisableParams *pParams);

template <bool CheckSuccess>
CUptiResult pcSamplingGetData(CUpti_PCSamplingGetDataParams *pParams);

template <bool CheckSuccess>
CUptiResult pcSamplingStart(CUpti_PCSamplingStartParams *pParams);

template <bool CheckSuccess>
CUptiResult pcSamplingStop(CUpti_PCSamplingStopParams *pParams);
*/

} // namespace xpupti

} // namespace proton

#endif // PROTON_EXTERN_DISPATCH_H_
2 changes: 2 additions & 0 deletions third_party/proton/csrc/include/Profiler/GPUProfiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "Utility/Atomic.h"
#include "Utility/Map.h"
#include "Utility/Set.h"
#include <iostream>

#include <atomic>
#include <deque>
Expand Down Expand Up @@ -72,6 +73,7 @@ class GPUProfiler : public Profiler,
void enterOp(size_t scopeId) {
if (profiler.isOpInProgress())
return;
std::cout << "\tenterOp:: pushExternId: " << scopeId << "\n";
profiler.correlation.pushExternId(scopeId);
profiler.setOpInProgress(true);
}
Expand Down
5 changes: 0 additions & 5 deletions third_party/proton/csrc/include/Profiler/Profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,9 @@
#include "Data/Data.h"
#include "Utility/Singleton.h"

#include <atomic>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <shared_mutex>
#include <string>

namespace proton {

Expand Down
19 changes: 19 additions & 0 deletions third_party/proton/csrc/include/Profiler/Xpupti/XpuptiProfiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef PROTON_PROFILER_XPUPTI_PROFILER_H_
#define PROTON_PROFILER_XPUPTI_PROFILER_H_

#include "Profiler/GPUProfiler.h"

namespace proton {

class XpuptiProfiler : public GPUProfiler<XpuptiProfiler> {
public:
XpuptiProfiler();
virtual ~XpuptiProfiler();

private:
struct XpuptiProfilerPimpl;
};

} // namespace proton

#endif // PROTON_PROFILER_XPUPTI_PROFILER_H_
1 change: 0 additions & 1 deletion third_party/proton/csrc/include/Session/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "Utility/Singleton.h"
#include <map>
#include <memory>
#include <set>
#include <shared_mutex>
#include <string>
#include <vector>
Expand Down
1 change: 1 addition & 0 deletions third_party/proton/csrc/include/Utility/Map.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define PROTON_UTILITY_MAP_H_

#include <map>
#include <mutex>
#include <shared_mutex>

namespace proton {
Expand Down
1 change: 1 addition & 0 deletions third_party/proton/csrc/include/Utility/Set.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef PROTON_UTILITY_SET_H_
#define PROTON_UTILITY_SET_H_

#include <mutex>
#include <set>
#include <shared_mutex>

Expand Down
2 changes: 0 additions & 2 deletions third_party/proton/csrc/lib/Data/TraceData.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "Data/TraceData.h"
#include "Utility/Errors.h"

#include <stdexcept>

namespace proton {

void TraceData::startOp(const Scope &scope) { throw NotImplemented(); }
Expand Down
27 changes: 24 additions & 3 deletions third_party/proton/csrc/lib/Data/TreeData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "Data/Metric.h"
#include "Driver/Device.h"
#include "nlohmann/json.hpp"
#include <iostream>

#include <limits>
#include <map>
Expand Down Expand Up @@ -134,17 +135,28 @@ size_t TreeData::addScope(size_t parentScopeId, const std::string &name) {
}

void TreeData::addMetric(size_t scopeId, std::shared_ptr<Metric> metric) {
std::cout << "\taddMetric\n";
std::unique_lock<std::shared_mutex> lock(mutex);
auto scopeIdIt = scopeIdToContextId.find(scopeId);
// The profile data is deactived, ignore the metric
if (scopeIdIt == scopeIdToContextId.end())
if (scopeIdIt == scopeIdToContextId.end()) {
std::cout << "MARK111\n" << std::flush;
return;
}
auto contextId = scopeIdIt->second;
std::cout << "\taddMetric::contextId: " << contextId << "\n";
auto &node = tree->getNode(contextId);
if (node.metrics.find(metric->getKind()) == node.metrics.end())
if (node.metrics.find(metric->getKind()) == node.metrics.end()) {
std::cout << "MARK112\n" << std::flush;
std::cout << "duration: "
<< std::get<uint64_t>(metric->getValue(KernelMetric::Duration))
<< "\n"
<< std::flush;
node.metrics.emplace(metric->getKind(), metric);
else
} else {
std::cout << "MARK113\n" << std::flush;
node.metrics[metric->getKind()]->updateMetric(*metric);
}
}

void TreeData::addMetrics(size_t scopeId,
Expand Down Expand Up @@ -184,23 +196,32 @@ void TreeData::dumpHatchet(std::ostream &os) const {
&treeNode) {
const auto contextName = treeNode.name;
auto contextId = treeNode.id;
std::cout << "\t dumpHatchet::contextId: " << contextId << "\n";
json *jsonNode = jsonNodes[contextId];
(*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}};
(*jsonNode)["metrics"] = json::object();
for (auto [metricKind, metric] : treeNode.metrics) {
std::cout << "MARK: dumpHatchet\n";
if (metricKind == MetricKind::Kernel) {
std::cout << "metricKind == MetricKind::Kernel\n";
std::shared_ptr<KernelMetric> kernelMetric =
std::dynamic_pointer_cast<KernelMetric>(metric);
uint64_t duration =
std::get<uint64_t>(kernelMetric->getValue(KernelMetric::Duration));
std::cout << "\t dumpHatchet::duration: " << duration << "\n";
uint64_t invocations = std::get<uint64_t>(
kernelMetric->getValue(KernelMetric::Invocations));
std::cout << "\t dumpHatchet::invocations: " << invocations << "\n";
uint64_t deviceId =
std::get<uint64_t>(kernelMetric->getValue(KernelMetric::DeviceId));
std::cout << "\t dumpHatchet::deviceId: " << deviceId << "\n";
uint64_t deviceType = std::get<uint64_t>(
kernelMetric->getValue(KernelMetric::DeviceType));
std::cout << "\t dumpHatchet::deviceType: " << deviceType << "\n";
std::string deviceTypeName =
getDeviceTypeString(static_cast<DeviceType>(deviceType));
std::cout << "\t dumpHatchet::deviceTypeName: " << deviceTypeName
<< "\n";
(*jsonNode)["metrics"]
[kernelMetric->getValueName(KernelMetric::Duration)] =
duration;
Expand Down
Loading