Skip to content

Commit

Permalink
Merge pull request #91 from jumutc/master
Browse files Browse the repository at this point in the history
Resolving grayscale image handling & conversions
  • Loading branch information
muellerdo authored Jul 29, 2021
2 parents 50175b1 + c35101d commit f1edda9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
13 changes: 8 additions & 5 deletions miscnn/data_loading/interfaces/image_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,18 @@ def load_image(self, index):
# Load image from file
img_raw = Image.open(os.path.join(img_path, "imaging" + "." + \
self.img_format))
# Convert image to rgb or grayscale
if self.img_type == "grayscale":
# Convert image to rgb or grayscale if needed
if self.img_type == "grayscale" and len(img_raw.getbands()) > 1:
img_pil = img_raw.convert("LA")
elif self.img_type == "rgb":
elif self.img_type == "rgb" and img_raw.mode != "RGB":
img_pil = img_raw.convert("RGB")
else:
img_pil = img_raw
# Convert Pillow image to numpy matrix
img = np.array(img_pil)
# Keep only intensity for grayscale images
if self.img_type =="grayscale" : img = img[:,:,0]
# Keep only intensity for grayscale images if needed
if self.img_type == "grayscale" and len(img.shape) > 2:
img = img[:, :, 0]
# Return image
return img, {"type": "image"}

Expand Down
43 changes: 27 additions & 16 deletions tests/test_iointerfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,24 @@ def setUpClass(self):
# Create image and segmentation
np.random.seed(1234)
img = np.random.rand(16, 16, 16) * 255
self.img = img.astype(int)
self.img = img.astype(np.uint8)
seg = np.random.rand(16, 16, 16) * 3
self.seg = seg.astype(int)
self.seg = seg.astype(np.uint8)
# Initialize temporary directory
self.tmp_data = tempfile.TemporaryDirectory(prefix="tmp.miscnn.",
suffix=".data")
# Write PNG sample with image and segmentation
path_sample_image = opj(self.tmp_data.name, "image")
path_sample_image = opj(self.tmp_data.name, "image_01")
os.mkdir(path_sample_image)
img_pillow = Image.fromarray(img[:,:,0].astype(np.uint8))
img_pillow = Image.fromarray(self.img[:,:,0])
img_pillow.save(opj(path_sample_image, "imaging.png"))
seg_pillow = Image.fromarray(seg[:,:,0].astype(np.uint8))
seg_pillow = Image.fromarray(self.seg[:,:,0])
seg_pillow.save(opj(path_sample_image, "segmentation.png"))
path_sample_image = opj(self.tmp_data.name, "image_02")
os.mkdir(path_sample_image)
img_pillow = Image.fromarray(self.img[:, :, :3])
img_pillow.save(opj(path_sample_image, "imaging.png"))
seg_pillow = Image.fromarray(self.seg[:, :, 0])
seg_pillow.save(opj(path_sample_image, "segmentation.png"))
# Write NIfTI sample with image and segmentation
path_sample_nifti = opj(self.tmp_data.name, "nifti")
Expand All @@ -76,23 +82,28 @@ def test_IOI_IMAGE_creation(self):
interface = Image_interface()
# Initialization
def test_IOI_IMAGE_initialize(self):
interface = Image_interface(pattern="image")
interface = Image_interface(pattern="image_\\d+")
sample_list = interface.initialize(self.tmp_data.name)
self.assertEqual(len(sample_list), 1)
self.assertEqual(sample_list[0], "image")
self.assertEqual(len(sample_list), 2)
self.assertTrue("image_01" in sample_list)
self.assertTrue("image_02" in sample_list)
# Loading Images and Segmentations
def test_IOI_IMAGE_loading(self):
interface = Image_interface(pattern="image")
sample_list = interface.initialize(self.tmp_data.name)
img = interface.load_image(sample_list[0])
seg = interface.load_segmentation(sample_list[0])
self.assertTrue(np.array_equal(img[0], self.img[:,:,0]))
self.assertTrue(np.array_equal(seg, self.seg[:,:,0]))
interface = Image_interface(pattern="image_\\d+")
sample_list = sorted(interface.initialize(self.tmp_data.name))
img_01 = interface.load_image(sample_list[0])
seg_01 = interface.load_segmentation(sample_list[0])
self.assertTrue(np.array_equal(img_01[0], self.img[:, :, 0]))
self.assertTrue(np.array_equal(seg_01, self.seg[:,:,0]))
img_02 = interface.load_image(sample_list[1])
seg_02 = interface.load_segmentation(sample_list[1])
true_img_02 = np.array(Image.fromarray(self.img[:, :, :3]).convert('LA'))
self.assertTrue(np.array_equal(img_02[0], true_img_02[:,:,0]))
self.assertTrue(np.array_equal(seg_02, self.seg[:, :, 0]))
# NIFTI_interface - Loading and Storage of Predictions
def test_IOI_IMAGE_predictionhandling(self):
interface = Image_interface(pattern="image")
sample_list = interface.initialize(self.tmp_data.name)
sample = MIScnn_sample.Sample("pred.image", self.img[:,:,0], 1, 2);
sample = MIScnn_sample.Sample("pred.image", self.img[:,:,0], 1, 2)
sample.add_prediction(self.seg[:,:,0]);
interface.save_prediction(sample, self.tmp_data.name)
pred = interface.load_prediction("pred.image", self.tmp_data.name)
Expand Down

0 comments on commit f1edda9

Please sign in to comment.