Skip to content

Commit

Permalink
Merge branch '79-aligner-dataset-coordonnees-band' into 'master'
Browse files Browse the repository at this point in the history
Alignement Coordonnées des Datasets avec la convention utilisée par CARS

Closes #79

See merge request 3d/PandoraBox/pandora_plugins/plugin_libsgm!119
  • Loading branch information
adebardo committed Aug 31, 2023
2 parents e03ad2e + aa94a69 commit b6a9dc9
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 13 deletions.
28 changes: 17 additions & 11 deletions pandora_plugin_libsgm/abstract_sgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import logging
import sys
from abc import abstractmethod
from typing import Dict, Union, Tuple
from typing import Dict, Union, Tuple, Optional

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -149,16 +149,8 @@ def optimize_cv(self, cv: xr.Dataset, img_left: xr.Dataset, img_right: xr.Datase
cv["cost_volume"].data *= -1

# If the input images were multiband, the band used for the correlation is used
if cv.attrs["band_correl"] is not None:
# Obtain correlation band from cost_volume attributes
band_index_left = list(img_left.band.data).index(cv.attrs["band_correl"])
band_index_right = list(img_right.band.data).index(cv.attrs["band_correl"])
# Get the image band
img_left_array = np.ascontiguousarray(img_left["im"].data[band_index_left, :, :], dtype=np.float32)
img_right_array = np.ascontiguousarray(img_right["im"].data[band_index_right, :, :], dtype=np.float32)
else:
img_left_array = np.ascontiguousarray(img_left["im"].data, dtype=np.float32)
img_right_array = np.ascontiguousarray(img_right["im"].data, dtype=np.float32)
img_left_array = get_band_values(img_left, cv.attrs["band_correl"])
img_right_array = get_band_values(img_right, cv.attrs["band_correl"])

# Compute penalties
invalid_value, p1_mat, p2_mat = self._penalty.compute_penalty(cv, img_left_array, img_right_array)
Expand Down Expand Up @@ -430,3 +422,17 @@ def sgm_cpp(
)

return cost_volumes_out


def get_band_values(image_dataset: xr.Dataset, band_name: Optional[str] = None) -> np.ndarray:
"""
Get values of given band_name from image_dataset as numpy array.
if band_name is not provided or is None, returns all bands values.
:param image_dataset: dataset to extract data from.
:param band_name: band_name to extract. If None selects all bands.
:return: selected values.
"""
selection = image_dataset if band_name is None else image_dataset.sel(band_im=band_name)
return selection["im"].to_numpy()
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def left_crafted():
"no_data_mask": 1,
"crs": None,
"transform": None,
"band_list": None,
},
)
return result
Expand All @@ -185,7 +184,6 @@ def right_crafted():
"no_data_mask": 1,
"crs": None,
"transform": None,
"band_list": None,
},
)
return result
43 changes: 43 additions & 0 deletions tests/test_libsgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import xarray as xr
from pandora import matching_cost, optimization, cost_volume_confidence
from pandora.state_machine import PandoraMachine
from pandora_plugin_libsgm.abstract_sgm import get_band_values

from tests import common

Expand Down Expand Up @@ -600,3 +601,45 @@ def test_optimization_layer_with_multiband(self, user_cfg, left_rgb, right_rgb):

# Check if the calculated optimized cv is equal to the ground truth
np.testing.assert_array_equal(cost_volumes_gt["cv"], out_cv["cost_volume"].data)


@pytest.mark.parametrize(
["band_name", "expected"],
[
(None, np.array([[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]]], dtype=np.float32)),
("r", np.array([[1, 1], [1, 1]], dtype=np.float32)),
("g", np.array([[2, 2], [2, 2]], dtype=np.float32)),
("b", np.array([[3, 3], [3, 3]], dtype=np.float32)),
],
)
def test_get_band_values(band_name, expected):
"""Given a band_name, test we get expected band values."""
data = np.array(
[
[
[1, 1],
[1, 1],
],
[
[2, 2],
[2, 2],
],
[
[3, 3],
[3, 3],
],
],
dtype=np.float32,
)
input_dataset = xr.Dataset(
{"im": (["band_im", "row", "col"], data)},
coords={
"band_im": ["r", "g", "b"],
"row": np.arange(2),
"col": np.arange(2),
},
)

result = get_band_values(input_dataset, band_name)

np.testing.assert_array_equal(result, expected)

0 comments on commit b6a9dc9

Please sign in to comment.