diff --git a/benchmarks/triton_kernels_benchmark/benchmark_driver.py b/benchmarks/triton_kernels_benchmark/benchmark_driver.py index d119f3962e..996b28b7ea 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_driver.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_driver.py @@ -1,59 +1,19 @@ import os -import hashlib -import importlib.util -import tempfile from pathlib import Path from triton.backends.compiler import GPUTarget from triton.backends.driver import DriverBase -from triton.runtime.cache import get_cache_manager -from triton.runtime.build import _build, quiet from triton._utils import parse_list_string +from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER import torch -_dirname = os.getenv("ZE_PATH", default="/usr/local") - -include_dir = [ - os.path.join(_dirname, "include"), - os.path.join(torch.utils.cmake_prefix_path, "../../include"), - os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include") -] - -oneapi_root = os.getenv("ONEAPI_ROOT") -if oneapi_root: - include_dir += [ - os.path.join(oneapi_root, "compiler/latest/include"), - os.path.join(oneapi_root, "compiler/latest/include/sycl") - ] - -library_dir = [os.path.join(_dirname, "lib"), os.path.join(torch.utils.cmake_prefix_path, "../../lib")] -libraries = ["ze_loader", "sycl", "torch"] - - -def compile_module_from_src(src, name): - key = hashlib.sha256(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.cpp") - with open(src_path, "w", encoding="utf-8") as f: - f.write(src) - with quiet(): - so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.so", binary=True) - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - - # ------------------------ # Utils # ------------------------ +COMPILATION_HELPER.inject_pytorch_dep() + class XPUUtils: diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 7be311ce4f..aa7e536543 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -68,6 +68,10 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]: class CompilationHelper: _library_dir: list[str] _include_dir: list[str] + libraries: list[str] + + # for benchmarks + _build_with_pytorch_dep: bool = False def __init__(self): self._library_dir = None @@ -77,6 +81,12 @@ def __init__(self): if os.name != "nt": self.libraries += ["sycl"] + def inject_pytorch_dep(self): + # must be called before any cached properties (if pytorch is needed) + if self._build_with_pytorch_dep is False: + self._build_with_pytorch_dep = True + self.libraries += ['torch'] + @cached_property def _compute_compilation_options_lazy(self): ze_root = os.getenv("ZE_PATH", default="/usr/local") @@ -91,9 +101,18 @@ def _compute_compilation_options_lazy(self): dirname = os.path.dirname(os.path.realpath(__file__)) include_dir += [os.path.join(dirname, "include")] - # TODO: do we need this? library_dir += [os.path.join(dirname, "lib")] + if self._build_with_pytorch_dep: + import torch + + torch_path = torch.utils.cmake_prefix_path + include_dir += [ + os.path.join(torch_path, "../../include"), + os.path.join(torch_path, "../../include/torch/csrc/api/include"), + ] + library_dir += [os.path.join(torch_path, "../../lib")] + self._library_dir = library_dir self._include_dir = include_dir @@ -113,7 +132,7 @@ def libsycl_dir(self) -> Optional[str]: return self._libsycl_dir -compilation_helper = CompilationHelper() +COMPILATION_HELPER = CompilationHelper() def compile_module_from_src(src, name): @@ -127,10 +146,10 @@ def compile_module_from_src(src, name): with open(src_path, "w") as f: f.write(src) extra_compiler_args = [] - if compilation_helper.libsycl_dir: - extra_compiler_args += ['-Wl,-rpath,' + compilation_helper.libsycl_dir] - so = _build(name, src_path, tmpdir, compilation_helper.library_dir, compilation_helper.include_dir, - compilation_helper.libraries, extra_compile_args=extra_compiler_args) + if COMPILATION_HELPER.libsycl_dir: + extra_compiler_args += ['-Wl,-rpath,' + COMPILATION_HELPER.libsycl_dir] + so = _build(name, src_path, tmpdir, COMPILATION_HELPER.library_dir, COMPILATION_HELPER.include_dir, + COMPILATION_HELPER.libraries, extra_compile_args=extra_compiler_args) with open(so, "rb") as f: cache_path = cache.put(f.read(), file_name, binary=True) import importlib.util