From d4895e093ae4c13a2963b0541d13742ef8748715 Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Mon, 23 Oct 2023 18:09:14 +0200 Subject: [PATCH 1/2] Add `mkl` support --- candle-pyo3/Cargo.toml | 2 ++ candle-pyo3/py_src/candle/__init__.py | 49 ++++++++++++++++++++------- candle-pyo3/src/lib.rs | 3 ++ 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 5ef0240d3b..8bccbcc690 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -18,6 +18,7 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-nn = { path = "../candle-nn", version = "0.3.0" } half = { workspace = true } pyo3 = { version = "0.19.0", features = ["extension-module"] } +intel-mkl-src = { workspace = true, optional = true } [build-dependencies] pyo3-build-config = "0.19" @@ -25,3 +26,4 @@ pyo3-build-config = "0.19" [features] default = [] cuda = ["candle/cuda"] +mkl = ["dep:intel-mkl-src","candle/mkl"] diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py index dc97b775d5..fe92650d11 100644 --- a/candle-pyo3/py_src/candle/__init__.py +++ b/candle-pyo3/py_src/candle/__init__.py @@ -3,27 +3,52 @@ try: from .candle import * except ImportError as e: - # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here - logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...") + # If we are in development mode, or we did not bundle the DLLs, we try to locate them here + # PyO3 wont give us any infomration about what DLLs are missing, so we can only try to load the DLLs and re-import the module + logging.warning("DLLs were not bundled with this package. Trying to locate them...") import os import platform - # Try to locate CUDA_PATH environment variable - cuda_path = os.environ.get("CUDA_PATH", None) - if cuda_path: - logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}") - if platform.system() == "Windows": - cuda_path = os.path.join(cuda_path, "bin") + def locate_cuda_dlls(): + logging.warning("Locating CUDA DLLs...") + # Try to locate CUDA_PATH environment variable + cuda_path = os.environ.get("CUDA_PATH", None) + if cuda_path: + logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}") + if platform.system() == "Windows": + cuda_path = os.path.join(cuda_path, "bin") + else: + cuda_path = os.path.join(cuda_path, "lib64") + + logging.warning(f"Adding {cuda_path} to DLL search path...") + os.add_dll_directory(cuda_path) + else: + logging.warning("CUDA_PATH environment variable not found!") + + def locate_mkl_dlls(): + # Try to locate ONEAPI_ROOT environment variable + oneapi_root = os.environ.get("ONEAPI_ROOT", None) + if oneapi_root: + if platform.system() == "Windows": + mkl_path = os.path.join( + oneapi_root, "compiler", "latest", "windows", "redist", "intel64_win", "compiler" + ) + else: + # Unsure of this is correct + mkl_path = os.path.join(oneapi_root, "mkl", "latest", "lib") + + logging.warning(f"Adding {mkl_path} to DLL search path...") + os.add_dll_directory(mkl_path) else: - cuda_path = os.path.join(cuda_path, "lib64") + logging.warning("ONEAPI_ROOT environment variable not found!") - logging.warning(f"Adding {cuda_path} to DLL search path...") - os.add_dll_directory(cuda_path) + locate_cuda_dlls() + locate_mkl_dlls() try: from .candle import * except ImportError as inner_e: - raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.") + raise ImportError("Could not locate DLLs. Please check the documentation for more information.") __doc__ = candle.__doc__ if hasattr(candle, "__all__"): diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index f16d8c1b18..29f38ff802 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -8,6 +8,9 @@ use std::sync::Arc; use half::{bf16, f16}; +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; pub fn wrap_err(err: ::candle::Error) -> PyErr { From 72c6c6bc9827dc3eafbff2dd6064576b31be995e Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Mon, 23 Oct 2023 20:40:39 +0200 Subject: [PATCH 2/2] Set `mkl` path on linux --- candle-pyo3/py_src/candle/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py index fe92650d11..38718a46cc 100644 --- a/candle-pyo3/py_src/candle/__init__.py +++ b/candle-pyo3/py_src/candle/__init__.py @@ -34,8 +34,7 @@ def locate_mkl_dlls(): oneapi_root, "compiler", "latest", "windows", "redist", "intel64_win", "compiler" ) else: - # Unsure of this is correct - mkl_path = os.path.join(oneapi_root, "mkl", "latest", "lib") + mkl_path = os.path.join(oneapi_root, "mkl", "latest", "lib", "intel64") logging.warning(f"Adding {mkl_path} to DLL search path...") os.add_dll_directory(mkl_path)