# Prostate-Inference/src/data/custom_transforms.py
from __future__ import annotations

from collections.abc import Hashable, Mapping, Sequence
from typing import Union

import cv2
import numpy as np
import torch
from scipy.ndimage import binary_dilation

from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.transforms import MapTransform
from monai.transforms.transform import Transform
from monai.transforms.utils import soft_clip
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor

class DilateAndSaveMaskd(MapTransform):
"""
Custom transform to dilate binary mask and save a copy.
"""
def __init__(self, keys, dilation_size=10, copy_key="original_mask"):
super().__init__(keys)
self.dilation_size = dilation_size
self.copy_key = copy_key
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            mask = d[key].cpu().numpy() if isinstance(d[key], torch.Tensor) else d[key]
            mask = mask.squeeze(0)  # remove the leading channel dimension
            # Save a copy of the original (undilated) mask under `copy_key`.
            # Note: with multiple keys, only the last mask processed is kept there.
            d[self.copy_key] = torch.as_tensor(mask, dtype=torch.float32).unsqueeze(0)
            # Grow the mask by `dilation_size` iterations of binary dilation
            # with scipy's default (connectivity-1) structuring element.
            dilated_mask = binary_dilation(mask, iterations=self.dilation_size).astype(np.uint8)
            d[key] = torch.as_tensor(dilated_mask, dtype=torch.float32).unsqueeze(0)  # restore the channel dimension
        return d
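
# Usage sketch (illustrative, not from the original pipeline): dilate a toy
# (1, H, W) binary mask by 3 iterations while keeping the undilated copy.
#
#     data = {"mask": torch.zeros((1, 64, 64))}
#     data["mask"][0, 30:34, 30:34] = 1.0
#     out = DilateAndSaveMaskd(keys=["mask"], dilation_size=3)(data)
#     assert out["mask"].sum() > out["original_mask"].sum()
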
class ClipMaskIntensityPercentiles(Transform):
    """
    Clip image intensities to percentile bounds computed from the voxels inside a mask.

    Args:
        lower: lower percentile in [0, 100], or None to skip the lower bound.
        upper: upper percentile in [0, 100], or None to skip the upper bound.
        sharpness_factor: if provided, apply `soft_clip` with this sharpness
            instead of a hard clip; must be greater than 0.
        channel_wise: if True, compute the percentiles for each channel separately.
        dtype: output data type of the soft-clipped result. Defaults to float32.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
def __init__(
self,
lower: Union[float, None],
upper: Union[float, None],
        sharpness_factor: Union[float, None] = None,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
if lower is None and upper is None:
raise ValueError("lower or upper percentiles must be provided")
if lower is not None and (lower < 0.0 or lower > 100.0):
raise ValueError("Percentiles must be in the range [0, 100]")
if upper is not None and (upper < 0.0 or upper > 100.0):
raise ValueError("Percentiles must be in the range [0, 100]")
if upper is not None and lower is not None and upper < lower:
raise ValueError("upper must be greater than or equal to lower")
if sharpness_factor is not None and sharpness_factor <= 0:
raise ValueError("sharpness_factor must be greater than 0")
self.lower = lower
self.upper = upper
self.sharpness_factor = sharpness_factor
self.channel_wise = channel_wise
self.dtype = dtype
    def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        # Take the percentiles over the voxels inside the mask only; multiplying
        # `img` by the mask instead would pull the percentiles toward the
        # background zeros.
        masked_img = img[mask_data > 0]
        if self.sharpness_factor is not None:
            lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else None
            upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else None
            img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)
        else:
            lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else percentile(masked_img, 0)
            upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else percentile(masked_img, 100)
            img = clip(img, lower_percentile, upper_percentile)
        img = convert_to_tensor(img, track_meta=False)
        return img
def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
mask_t = convert_to_tensor(mask_data, track_meta=False)
if self.channel_wise:
img_t = torch.stack([self._clip(img=d, mask_data=mask_t[e]) for e,d in enumerate(img_t)]) # type: ignore
else:
img_t = self._clip(img=img_t, mask_data=mask_t)
img = convert_to_dst_type(img_t, dst=img)[0]
return img
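
# Usage sketch (illustrative): hard-clip a random image to the [5, 95]
# percentiles of the intensities found inside a centred square mask.
#
#     img = torch.rand(1, 32, 32)
#     mask = torch.zeros(1, 32, 32)
#     mask[0, 8:24, 8:24] = 1.0
#     clipper = ClipMaskIntensityPercentiles(lower=5.0, upper=95.0)
#     clipped = clipper(img, mask)  # same shape as `img`
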
class ClipMaskIntensityPercentilesd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`ClipMaskIntensityPercentiles`.
    The mask is read from `mask_key` and is left unchanged.
    """

    backend = ClipMaskIntensityPercentiles.backend
def __init__(
self,
keys: KeysCollection,
mask_key: str,
lower: Union[float, None],
upper: Union[float, None],
sharpness_factor: Union[float, None] = None,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.scaler = ClipMaskIntensityPercentiles(
lower=lower, upper=upper, sharpness_factor=sharpness_factor, channel_wise=channel_wise, dtype=dtype
)
self.mask_key = mask_key
def __call__(self, data: dict) -> dict:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key], d[self.mask_key])
return d
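
# Usage sketch (illustrative): the dictionary form reads the mask from
# `mask_key`, so it composes with other MONAI dictionary transforms.
#
#     xform = ClipMaskIntensityPercentilesd(
#         keys=["image"], mask_key="mask", lower=1.0, upper=99.0
#     )
#     out = xform({"image": torch.rand(1, 32, 32),
#                  "mask": (torch.rand(1, 32, 32) > 0.5).float()})
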
class ElementwiseProductd(MapTransform):
    """
    Store the element-wise product of the first two keys under `output_key`.
    """

    def __init__(self, keys: KeysCollection, output_key: str) -> None:
        super().__init__(keys)
        if len(self.keys) != 2:
            raise ValueError("ElementwiseProductd expects exactly two keys.")
        self.output_key = output_key

    def __call__(self, data) -> dict:
        d = dict(data)
        d[self.output_key] = d[self.keys[0]] * d[self.keys[1]]
        return d
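
# Usage sketch (illustrative): zero out an image outside its mask by storing
# the voxel-wise product under a new key.
#
#     prod = ElementwiseProductd(keys=["image", "mask"], output_key="masked_image")
#     out = prod({"image": torch.rand(1, 32, 32),
#                 "mask": (torch.rand(1, 32, 32) > 0.5).float()})
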
class CLAHEd(MapTransform):
"""
Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to images in a data dictionary.
Works on 2D images or 3D volumes (applied slice-by-slice).
Args:
keys (KeysCollection): Keys of the items to be transformed.
clip_limit (float): Threshold for contrast limiting. Default is 2.0.
tile_grid_size (Union[tuple, Sequence[int]]): Size of grid for histogram equalization (default: (8,8)).
"""
def __init__(
self,
keys: KeysCollection,
clip_limit: float = 2.0,
tile_grid_size: Union[tuple, Sequence[int]] = (8, 8),
) -> None:
super().__init__(keys)
self.clip_limit = clip_limit
self.tile_grid_size = tile_grid_size
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            image_ = d[key]
            image = image_.cpu().numpy() if isinstance(image_, torch.Tensor) else np.asarray(image_)
            if image.dtype != np.uint8:
                # cv2's CLAHE needs uint8; this cast assumes values in [0, 255].
                image = image.astype(np.uint8)
            clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
            if image.ndim == 3:  # (1, H, W): a single 2D image
                image_clahe = clahe.apply(image[0])
            else:  # (1, D, H, W): equalize each depth slice independently
                image_clahe = np.stack([clahe.apply(s) for s in image[0]])
            # Convert back to float32 in [0, 1] and restore the channel dimension.
            processed_img = image_clahe.astype(np.float32) / 255.0
            reshaped_ = processed_img.reshape(1, *processed_img.shape)
            device = image_.device if isinstance(image_, torch.Tensor) else "cpu"
            d[key] = torch.from_numpy(reshaped_).to(device)
        return d
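
# Usage sketch (illustrative): CLAHE on a toy (1, D, H, W) volume whose
# intensities already lie in [0, 255], as the uint8 cast above assumes.
#
#     vol = torch.rand(1, 4, 64, 64) * 255
#     out = CLAHEd(keys=["image"], clip_limit=2.0)({"image": vol})
#     # out["image"] is float32 in [0, 1], same shape and device as the input
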
class NormalizeIntensity_custom(Transform):
    """
    Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.
    Unlike the stock :py:class:`monai.transforms.NormalizeIntensity`, when no
    `subtrahend` or `divisor` is provided, the mean and std are computed only over
    the voxels where the supplied `mask_data` is positive, and the whole image is
    then normalized with those statistics. Mean and std can also be calculated on
    each channel separately.
    When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
    be the number of image channels if they are not None.
    If the input is not of floating point type, it will be converted to float32.

    Args:
        subtrahend: the amount to subtract by (usually the mean).
        divisor: the amount to divide by (usually the standard deviation).
        nonzero: kept for API compatibility; masking is driven by `mask_data` here.
        channel_wise: if True, calculate on each channel separately, otherwise calculate
            on the entire image directly. Defaults to False.
        dtype: output data type, if None, same as input image. Defaults to float32.
    """
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
def __init__(
self,
subtrahend: Union[Sequence, NdarrayOrTensor, None] = None,
divisor: Union[Sequence, NdarrayOrTensor, None] = None,
nonzero: bool = False,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
self.subtrahend = subtrahend
self.divisor = divisor
self.nonzero = nonzero
self.channel_wise = channel_wise
self.dtype = dtype
@staticmethod
def _mean(x):
if isinstance(x, np.ndarray):
return np.mean(x)
x = torch.mean(x.float())
return x.item() if x.numel() == 1 else x
@staticmethod
def _std(x):
if isinstance(x, np.ndarray):
return np.std(x)
x = torch.std(x.float(), unbiased=False)
return x.item() if x.numel() == 1 else x
    def _normalize(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:
        img, *_ = convert_data_type(img, dtype=torch.float32)
        # Statistics come only from the voxels inside the mask; this replaces the
        # `nonzero` logic of the stock transform. The mask's channel dimension is
        # squeezed, so `img` is expected to match the squeezed mask shape (which
        # is the case in the `channel_wise` path of `__call__`).
        slices_mask = mask_data.squeeze(0) > 0
        masked_img = img[slices_mask]
        _sub = sub if sub is not None else self._mean(masked_img)
        if isinstance(_sub, (torch.Tensor, np.ndarray)):
            _sub, *_ = convert_to_dst_type(_sub, img)
        _div = div if div is not None else self._std(masked_img)
        if np.isscalar(_div):
            if _div == 0.0:
                _div = 1.0
        elif isinstance(_div, (torch.Tensor, np.ndarray)):
            _div, *_ = convert_to_dst_type(_div, img)
            _div[_div == 0.0] = 1.0
        # Normalize the entire image (not just the masked voxels) with the
        # mask-derived statistics.
        img = (img - _sub) / _div
        return img
def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
mask_data = convert_to_tensor(mask_data, track_meta=get_track_meta())
dtype = self.dtype or img.dtype
if self.channel_wise:
if self.subtrahend is not None and len(self.subtrahend) != len(img):
raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.")
if self.divisor is not None and len(self.divisor) != len(img):
raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")
if not img.dtype.is_floating_point:
img, *_ = convert_data_type(img, dtype=torch.float32)
for i, d in enumerate(img):
img[i] = self._normalize( # type: ignore
d,
mask_data,
sub=self.subtrahend[i] if self.subtrahend is not None else None,
div=self.divisor[i] if self.divisor is not None else None,
)
else:
img = self._normalize(img, mask_data, self.subtrahend, self.divisor)
out = convert_to_dst_type(img, img, dtype=dtype)[0]
return out
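
# Usage sketch (illustrative): z-score an image with mean/std taken only from
# voxels inside the mask; voxels outside are shifted and scaled by the same
# statistics. `channel_wise=True` matches the per-channel shape handling in
# `_normalize` (the mask's channel dimension is squeezed there).
#
#     mask = torch.zeros(1, 32, 32)
#     mask[0, 8:24, 8:24] = 1.0
#     norm = NormalizeIntensity_custom(channel_wise=True)
#     out = norm(torch.rand(1, 32, 32), mask)
#     # out[0][mask[0] > 0] now has ~zero mean and ~unit variance
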
class NormalizeIntensity_customd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`NormalizeIntensity_custom`.
    When no `subtrahend`/`divisor` is given, mean and std are computed only over
    the voxels inside the mask found at `mask_key`, and can be calculated on each
    channel separately.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        mask_key: key of the mask used to select the voxels for the statistics.
        subtrahend: the amount to subtract by (usually the mean).
        divisor: the amount to divide by (usually the standard deviation).
        nonzero: kept for API compatibility; masking is driven by `mask_key` here.
        channel_wise: if True, calculate on each channel separately, otherwise calculate
            on the entire image directly. Defaults to False.
        dtype: output data type, if None, same as input image. Defaults to float32.
        allow_missing_keys: don't raise exception if key is missing.
    """
backend = NormalizeIntensity_custom.backend
def __init__(
self,
keys: KeysCollection,
mask_key: str,
        subtrahend: Union[NdarrayOrTensor, None] = None,
        divisor: Union[NdarrayOrTensor, None] = None,
nonzero: bool = False,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.normalizer = NormalizeIntensity_custom(subtrahend, divisor, nonzero, channel_wise, dtype)
self.mask_key = mask_key
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.normalizer(d[key], d[self.mask_key])
return d
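
if __name__ == "__main__":
    # Smoke test (illustrative only): chain the transforms above on a toy
    # single-channel 2D sample. Key names here are placeholders, not the ones
    # used by the actual training/inference configs.
    from monai.transforms import Compose

    sample = {"image": torch.rand(1, 64, 64), "mask": torch.zeros(1, 64, 64)}
    sample["mask"][0, 16:48, 16:48] = 1.0

    pipeline = Compose(
        [
            DilateAndSaveMaskd(keys=["mask"], dilation_size=2),
            ClipMaskIntensityPercentilesd(keys=["image"], mask_key="mask", lower=1.0, upper=99.0),
            NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
            ElementwiseProductd(keys=["image", "mask"], output_key="masked_image"),
        ]
    )
    out = pipeline(sample)
    print({k: tuple(v.shape) for k, v in out.items()})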