Skip to content

Commit

Permalink
Reuse compile_module_from_src func in benchmark_driver.py (#3051)
Browse files Browse the repository at this point in the history
Part of #2540

---------

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
  • Loading branch information
anmyachev authored Dec 21, 2024
1 parent 7d03355 commit e4e0905
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 49 deletions.
46 changes: 3 additions & 43 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,19 @@
import os
import hashlib
import importlib.util
import tempfile
from pathlib import Path

from triton.backends.compiler import GPUTarget
from triton.backends.driver import DriverBase
from triton.runtime.cache import get_cache_manager
from triton.runtime.build import _build, quiet
from triton._utils import parse_list_string
from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER

import torch

_dirname = os.getenv("ZE_PATH", default="/usr/local")

include_dir = [
os.path.join(_dirname, "include"),
os.path.join(torch.utils.cmake_prefix_path, "../../include"),
os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include")
]

oneapi_root = os.getenv("ONEAPI_ROOT")
if oneapi_root:
include_dir += [
os.path.join(oneapi_root, "compiler/latest/include"),
os.path.join(oneapi_root, "compiler/latest/include/sycl")
]

library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")]
libraries = ["ze_loader", "sycl", "torch"]


def compile_module_from_src(src, name):
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
cache_path = cache.get_file(f"{name}.so")
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.cpp")
with open(src_path, "w", encoding="utf-8") as f:
f.write(src)
with quiet():
so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
spec = importlib.util.spec_from_file_location(name, cache_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod


# ------------------------
# Utils
# ------------------------

COMPILATION_HELPER.inject_pytorch_dep()


class XPUUtils:

Expand Down
31 changes: 25 additions & 6 deletions third_party/intel/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]:
class CompilationHelper:
_library_dir: list[str]
_include_dir: list[str]
libraries: list[str]

# for benchmarks
_build_with_pytorch_dep: bool = False

def __init__(self):
self._library_dir = None
Expand All @@ -77,6 +81,12 @@ def __init__(self):
if os.name != "nt":
self.libraries += ["sycl"]

def inject_pytorch_dep(self):
# must be called before any cached properties (if pytorch is needed)
if self._build_with_pytorch_dep is False:
self._build_with_pytorch_dep = True
self.libraries += ['torch']

@cached_property
def _compute_compilation_options_lazy(self):
ze_root = os.getenv("ZE_PATH", default="/usr/local")
Expand All @@ -91,9 +101,18 @@ def _compute_compilation_options_lazy(self):

dirname = os.path.dirname(os.path.realpath(__file__))
include_dir += [os.path.join(dirname, "include")]
# TODO: do we need this?
library_dir += [os.path.join(dirname, "lib")]

if self._build_with_pytorch_dep:
import torch

torch_path = torch.utils.cmake_prefix_path
include_dir += [
os.path.join(torch_path, "../../include"),
os.path.join(torch_path, "../../include/torch/csrc/api/include"),
]
library_dir += [os.path.join(torch_path, "../../lib")]

self._library_dir = library_dir
self._include_dir = include_dir

Expand All @@ -113,7 +132,7 @@ def libsycl_dir(self) -> Optional[str]:
return self._libsycl_dir


compilation_helper = CompilationHelper()
COMPILATION_HELPER = CompilationHelper()


def compile_module_from_src(src, name):
Expand All @@ -127,10 +146,10 @@ def compile_module_from_src(src, name):
with open(src_path, "w") as f:
f.write(src)
extra_compiler_args = []
if compilation_helper.libsycl_dir:
extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir]
so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir,
compilation_helper.libraries, extra_compile_args=extra_compiler_args)
if COMPILATION_HELPER.libsycl_dir:
extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER.libsycl_dir]
so = _build(name, src_path, tmpdir, COMPILATION_HELPER.library_dir, COMPILATION_HELPER.include_dir,
COMPILATION_HELPER.libraries, extra_compile_args=extra_compiler_args)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), file_name, binary=True)
import importlib.util
Expand Down

0 comments on commit e4e0905

Please sign in to comment.