Skip to content

Commit

Permalink
Code review fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton Guirao <janton@nvidia.com>
  • Loading branch information
jantonguirao committed Jan 9, 2025
1 parent c32b46c commit d0775cf
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
check_numba_compatibility_cpu,
has_operator,
restrict_platform,
is_of_supported
)
from nose2.tools import params, cartesian_params
from nose_utils import assert_raises, SkipTest, attr
Expand Down Expand Up @@ -575,8 +576,6 @@ def test_preemphasis_filter_stateless(device):

@stateless_signed_off("optical_flow")
def test_optical_flow_stateless():
from test_optical_flow import is_of_supported

if not is_of_supported():
raise SkipTest("Optical Flow is not supported on this platform")
check_single_sequence_input(fn.optical_flow, "gpu")
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion dali/test/python/test_dali_variable_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
import test_utils
from segmentation_test_utils import make_batch_select_masks
from test_detection_pipeline import coco_anchors
from test_optical_flow import load_frames, is_of_supported
from test_utils import (
module_functions,
has_operator,
restrict_platform,
check_numba_compatibility_cpu,
check_numba_compatibility_gpu,
is_of_supported
)

"""
Expand Down
25 changes: 25 additions & 0 deletions dali/test/python/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from nose_utils import SkipTest


is_of_supported_var = None


def get_arch(device_id=0):
compute_cap = 0
try:
Expand Down Expand Up @@ -989,3 +992,25 @@ def load_test_operator_plugin():
except RuntimeError:
# in conda "libtestoperatorplugin" lands inside lib/ dir
plugin_manager.load_library("libtestoperatorplugin.so")


def is_of_supported(device_id=0):
global is_of_supported_var
if is_of_supported_var is not None:
return is_of_supported_var

driver_version_major = 0
try:
import pynvml

pynvml.nvmlInit()
driver_version = pynvml.nvmlSystemGetDriverVersion().decode("utf-8")
driver_version_major = int(driver_version.split(".")[0])
except ModuleNotFoundError:
print("NVML not found")

# there is an issue with OpticalFlow driver in R495 and newer on aarch64 platform
is_of_supported_var = get_arch(device_id) >= 7.5 and (
platform.machine() == "x86_64" or driver_version_major < 495
)
return is_of_supported_var
5 changes: 3 additions & 2 deletions qa/TL0_self_test_Ampere/test.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash -ex

pip_packages='${python_test_runner_package} numpy'
pip_packages='${python_test_runner_package} numpy opencv-python nvidia-ml-py==11.450.51'

target_dir=./dali/test/python

Expand Down Expand Up @@ -33,8 +33,9 @@ test_body() {
${python_new_invoke_test} -s decoder test_image

# test Optical Flow
${python_invoke_test} test_optical_flow.py
${python_new_invoke_test} -s operator_1 test_optical_flow
${python_invoke_test} test_dali_variable_batch_size.py:test_optical_flow
${python_invoke_test} test_dali_stateless_operators.py:test_optical_flow_stateless
}

pushd ../..
Expand Down

0 comments on commit d0775cf

Please sign in to comment.