Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| from typing import Union, Optional | |
| from monai.transforms import MapTransform | |
| from monai.config import DtypeLike, KeysCollection | |
| from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor | |
| from monai.data.meta_obj import get_track_meta | |
| from monai.transforms.transform import Transform | |
| from monai.transforms.utils import soft_clip | |
| from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where | |
| from monai.utils.enums import TransformBackends | |
| from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype | |
| from scipy.ndimage import binary_dilation | |
| import cv2 | |
| from typing import Union, Sequence | |
| from collections.abc import Hashable, Mapping, Sequence | |
class DilateAndSaveMaskd(MapTransform):
    """
    Dilate a binary mask and keep a copy of the un-dilated mask.

    For every key the leading (channel) axis of the mask is squeezed, the
    un-dilated mask is stored under ``copy_key`` and the dilated mask replaces
    the original entry. Both results are ``float32`` tensors with the leading
    axis restored, placed on the same device as the input tensor (CPU for
    non-tensor input).

    Args:
        keys: keys of the masks to dilate. The leading axis of each mask must
            have size 1 (``ndarray.squeeze(0)`` raises otherwise).
        dilation_size: number of ``scipy.ndimage.binary_dilation`` iterations.
            Note that scipy interprets values below 1 as "repeat until the
            result no longer changes".
        copy_key: key under which the un-dilated mask is stored. The same key
            is used for every entry of ``keys``, so with several keys only the
            copy of the last processed mask survives.
    """

    def __init__(self, keys, dilation_size=10, copy_key="original_mask"):
        super().__init__(keys)
        self.dilation_size = dilation_size
        self.copy_key = copy_key

    def __call__(self, data):
        d = dict(data)
        for key in self.key_iterator(d):
            source = d[key]
            if isinstance(source, torch.Tensor):
                device = source.device
                # detach()/cpu() make the conversion work for tensors that
                # track gradients or live on an accelerator.
                mask = source.detach().cpu().numpy()
            else:
                device = None
                mask = np.asarray(source)

            # Drop the singleton channel axis so the dilation only acts on
            # the spatial dimensions.
            mask = mask.squeeze(0)

            # Keep the un-dilated mask available under a separate key.
            d[self.copy_key] = torch.tensor(mask, dtype=torch.float32, device=device).unsqueeze(0)

            dilated_mask = binary_dilation(mask, iterations=self.dilation_size).astype(np.uint8)

            # Store the dilated mask with the channel axis added back.
            d[key] = torch.tensor(dilated_mask, dtype=torch.float32, device=device).unsqueeze(0)
        return d
class ClipMaskIntensityPercentiles(Transform):
    """
    Clip image intensities to percentile bounds computed with the help of a mask.

    The percentiles are taken from ``img * (mask_data > 0)``, i.e. from the image
    with every voxel outside the mask set to zero. The resulting bounds are then
    applied to the whole image, either with a hard ``clip`` or, when
    ``sharpness_factor`` is given, with ``soft_clip``.

    NOTE(review): the zeroed voxels outside the mask take part in the percentile
    computation. If the bounds were meant to come from the masked voxels only,
    the image would have to be indexed with the mask instead. This is left
    unchanged so existing pipelines keep producing the same values; confirm the
    intent.

    Args:
        lower: lower percentile in ``[0, 100]``. If None, hard clipping uses the
            0th percentile and soft clipping passes None to ``soft_clip``.
        upper: upper percentile in ``[0, 100]``. If None, hard clipping uses the
            100th percentile and soft clipping passes None to ``soft_clip``.
        sharpness_factor: if given, ``soft_clip`` is used with this factor;
            must be greater than 0.
        channel_wise: if True, bounds are computed and applied separately for
            each entry along the first axis of the image.
        dtype: forwarded to ``soft_clip``; not used for hard clipping.

    Raises:
        ValueError: if both percentiles are None, a percentile lies outside
            ``[0, 100]``, ``upper < lower`` or ``sharpness_factor <= 0``.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self,
        lower: Union[float, None],
        upper: Union[float, None],
        sharpness_factor: Union[float, None] = None,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
    ) -> None:
        if lower is None and upper is None:
            raise ValueError("lower or upper percentiles must be provided")
        for bound in (lower, upper):
            if bound is not None and (bound < 0.0 or bound > 100.0):
                raise ValueError("Percentiles must be in the range [0, 100]")
        if upper is not None and lower is not None and upper < lower:
            raise ValueError("upper must be greater than or equal to lower")
        if sharpness_factor is not None and sharpness_factor <= 0:
            raise ValueError("sharpness_factor must be greater than 0")

        self.lower = lower
        self.upper = upper
        self.sharpness_factor = sharpness_factor
        self.channel_wise = channel_wise
        self.dtype = dtype

    def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        """Clip ``img`` to percentile bounds taken from the mask-zeroed image."""
        # Voxels outside the mask become zero but stay part of the statistics.
        masked_img = img * (mask_data > 0)
        if self.sharpness_factor is not None:
            lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else None
            upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else None
            img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)
        else:
            # A missing bound falls back to the extreme percentile of the
            # mask-zeroed image.
            lower_percentile = percentile(masked_img, self.lower if self.lower is not None else 0)
            upper_percentile = percentile(masked_img, self.upper if self.upper is not None else 100)
            img = clip(img, lower_percentile, upper_percentile)

        img = convert_to_tensor(img, track_meta=False)
        return img

    def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`.

        Args:
            img: image to clip; iterated along its first axis when
                ``channel_wise`` is True.
            mask_data: mask whose positive entries select the region used for
                the percentile computation. In channel-wise mode it provides
                either one entry per image channel or a single entry that is
                shared by all channels.

        Returns:
            The clipped image, converted back to the type of ``img``.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t = convert_to_tensor(img, track_meta=False)
        mask_t = convert_to_tensor(mask_data, track_meta=False)
        if self.channel_wise:
            # A single-channel mask is reused for every image channel, which
            # matches the broadcasting of the non-channel-wise path.
            shared_mask = mask_t.shape[0] == 1
            img_t = torch.stack(
                [
                    self._clip(img=channel, mask_data=mask_t[0 if shared_mask else idx])
                    for idx, channel in enumerate(img_t)
                ]
            )  # type: ignore
        else:
            img_t = self._clip(img=img_t, mask_data=mask_t)

        img = convert_to_dst_type(img_t, dst=img)[0]
        return img
class ClipMaskIntensityPercentilesd(MapTransform):
    """
    Dictionary-based wrapper around ``ClipMaskIntensityPercentiles``.

    Every entry selected by ``keys`` is clipped with the mask stored under
    ``mask_key`` in the same dictionary.

    Args:
        keys: keys of the images to clip.
        mask_key: key of the mask handed to the wrapped transform.
        lower: lower percentile, forwarded to the wrapped transform.
        upper: upper percentile, forwarded to the wrapped transform.
        sharpness_factor: forwarded to the wrapped transform.
        channel_wise: forwarded to the wrapped transform.
        dtype: forwarded to the wrapped transform.
        allow_missing_keys: forwarded to ``MapTransform``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        mask_key: str,
        lower: Union[float, None],
        upper: Union[float, None],
        sharpness_factor: Union[float, None] = None,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mask_key = mask_key
        self.scaler = ClipMaskIntensityPercentiles(
            lower=lower,
            upper=upper,
            sharpness_factor=sharpness_factor,
            channel_wise=channel_wise,
            dtype=dtype,
        )

    def __call__(self, data: dict) -> dict:
        """Return a shallow copy of ``data`` with every selected image clipped."""
        result = dict(data)
        for name in self.key_iterator(result):
            # The mask is looked up for each key, right after the image.
            image, mask = result[name], result[self.mask_key]
            result[name] = self.scaler(image, mask)
        return result
class ElementwiseProductd(MapTransform):
    """
    Multiply two dictionary entries element-wise and store the product.

    Args:
        keys: at least two keys. The entries stored under the first two keys
            are multiplied; any further keys are ignored.
        output_key: key under which the product is stored.

    Raises:
        ValueError: if fewer than two keys are given.
    """

    def __init__(self, keys: KeysCollection, output_key: str) -> None:
        super().__init__(keys)
        # Fail at construction time instead of with an IndexError on the first
        # call.
        if len(self.keys) < 2:
            raise ValueError(f"ElementwiseProductd requires two keys, got {len(self.keys)}.")
        self.output_key = output_key

    def __call__(self, data) -> dict:
        """Return a shallow copy of ``data`` with the product added under ``output_key``."""
        d = dict(data)
        first_key, second_key = self.keys[:2]
        d[self.output_key] = d[first_key] * d[second_key]
        return d
class CLAHEd(MapTransform):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to images in a data dictionary.

    Works on 2D images or 3D volumes (applied slice-by-slice along the first
    spatial axis). Only the entry at index 0 of the leading axis is processed,
    and the output always has a leading axis of size 1.

    The image is converted to ``uint8`` before equalisation: values are clipped
    to ``[0, 255]`` and then truncated. Float images scaled to ``[0, 1]`` would
    therefore collapse to 0/1 and should be rescaled beforehand. The result is
    a plain ``float32`` tensor in ``[0, 1]`` on the device of the input tensor
    (CPU for non-tensor input).

    NOTE(review): the input is assumed to be channel-first, i.e. ``(1, H, W)``
    or ``(1, H, W, D)``; confirm against the calling pipeline.

    Args:
        keys (KeysCollection): Keys of the items to be transformed.
        clip_limit (float): Threshold for contrast limiting. Default is 2.0.
        tile_grid_size (Union[tuple, Sequence[int]]): Size of grid for histogram equalization (default: (8,8)).
    """

    def __init__(
        self,
        keys: KeysCollection,
        clip_limit: float = 2.0,
        tile_grid_size: Union[tuple, Sequence[int]] = (8, 8),
    ) -> None:
        super().__init__(keys)
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __call__(self, data):
        d = dict(data)
        # One CLAHE operator serves every key and slice of this call.
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=tuple(self.tile_grid_size))
        for key in self.key_iterator(d):
            source = d[key]
            if isinstance(source, torch.Tensor):
                device = source.device
                image = source.detach().cpu().numpy()
            else:
                device = None
                image = np.asarray(source)

            if image.dtype != np.uint8:
                # Saturate instead of letting out-of-range values wrap around
                # in the cast.
                image = np.clip(image, 0, 255).astype(np.uint8)

            first_channel = image[0]
            if first_channel.ndim == 2:
                # 2D image: equalise the single plane directly.
                image_clahe = clahe.apply(np.ascontiguousarray(first_channel))
            else:
                # 3D volume: equalise each plane along the first spatial axis.
                image_clahe = np.stack([clahe.apply(np.ascontiguousarray(plane)) for plane in first_channel])

            # Convert back to float in [0,1] and restore the leading axis.
            processed_img = image_clahe.astype(np.float32) / 255.0
            result = torch.from_numpy(processed_img[np.newaxis])
            d[key] = result if device is None else result.to(device)
        return d
class NormalizeIntensity_custom(Transform):
    """
    Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.

    If no `subtrahend` or `divisor` is provided, the mean or standard deviation
    of the image voxels where the mask is positive (``mask_data > 0``) is used.
    The normalization itself is applied to the whole image, not only to the
    masked region. If the mask selects no voxel and a statistic would have to
    be computed, the image is returned without normalization instead of
    turning into NaN.

    When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
    be the number of image channels if they are not None.
    The image is converted to float32 for the computation.

    Args:
        subtrahend: the amount to subtract by (usually the mean).
        divisor: the amount to divide by (usually the standard deviation).
            Zero divisors are replaced by 1.
        nonzero: stored for interface compatibility but currently not used;
            the statistics always come from the masked region.
        channel_wise: if True, calculate on each channel separately, otherwise, calculate on
            the entire image directly. default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self,
        subtrahend: Union[Sequence, NdarrayOrTensor, None] = None,
        divisor: Union[Sequence, NdarrayOrTensor, None] = None,
        nonzero: bool = False,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
    ) -> None:
        self.subtrahend = subtrahend
        self.divisor = divisor
        self.nonzero = nonzero
        self.channel_wise = channel_wise
        self.dtype = dtype

    @staticmethod
    def _mean(x):
        """Return the mean of ``x`` (a Python scalar for a single-element result)."""
        if isinstance(x, np.ndarray):
            return np.mean(x)
        x = torch.mean(x.float())
        return x.item() if x.numel() == 1 else x

    @staticmethod
    def _std(x):
        """Return the population standard deviation of ``x``."""
        if isinstance(x, np.ndarray):
            return np.std(x)
        x = torch.std(x.float(), unbiased=False)
        return x.item() if x.numel() == 1 else x

    @staticmethod
    def _foreground(img, mask_data):
        """Return the boolean mask ``mask_data > 0`` broadcast to the shape of ``img``."""
        foreground = convert_to_tensor(mask_data, track_meta=False) > 0
        if foreground.ndim == img.ndim + 1:
            # Channel-first mask paired with a single-channel image slice.
            foreground = foreground.squeeze(0)
        if isinstance(img, torch.Tensor):
            foreground = foreground.to(img.device)
        if foreground.shape != img.shape:
            # Share e.g. a one-channel mask across all image channels.
            foreground = foreground.expand(img.shape)
        return foreground

    def _normalize(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:
        """
        Normalize ``img`` with ``sub``/``div``, falling back to masked statistics.

        Args:
            img: image (or single channel) to normalize.
            mask_data: mask whose positive entries select the voxels used for
                the mean/std computation.
            sub: value to subtract; if None the masked mean is used.
            div: value to divide by; if None the masked std is used.

        Returns:
            The normalized float32 image, or the unnormalized float32 image if
            a statistic is needed but the mask selects no voxel.
        """
        img, *_ = convert_data_type(img, dtype=torch.float32)

        masked_img = None
        if sub is None or div is None:
            foreground = self._foreground(img, mask_data)
            if not foreground.any():
                # Mean/std of an empty selection would be NaN.
                return img
            masked_img = img[foreground]

        _sub = sub if sub is not None else self._mean(masked_img)
        if isinstance(_sub, (torch.Tensor, np.ndarray)):
            _sub, *_ = convert_to_dst_type(_sub, img)

        _div = div if div is not None else self._std(masked_img)
        if np.isscalar(_div):
            if _div == 0.0:
                _div = 1.0
        elif isinstance(_div, (torch.Tensor, np.ndarray)):
            _div, *_ = convert_to_dst_type(_div, img)
            # Work on a copy so a caller-provided divisor is never modified.
            _div = _div.clone() if isinstance(_div, torch.Tensor) else _div.copy()
            _div[_div == 0.0] = 1.0

        return (img - _sub) / _div

    def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True.

        Args:
            img: image to normalize.
            mask_data: mask selecting the voxels used for the statistics. In
                channel-wise mode it may hold one entry per image channel;
                otherwise the same mask is used for every channel.

        Raises:
            ValueError: in channel-wise mode, if `subtrahend` or `divisor` does
                not have one component per image channel.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        mask_data = convert_to_tensor(mask_data, track_meta=get_track_meta())
        dtype = self.dtype or img.dtype
        if self.channel_wise:
            if self.subtrahend is not None and len(self.subtrahend) != len(img):
                raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.")
            if self.divisor is not None and len(self.divisor) != len(img):
                raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

            if not img.dtype.is_floating_point:
                img, *_ = convert_data_type(img, dtype=torch.float32)

            # Use a dedicated mask per channel only when the mask really has
            # one entry per image channel.
            per_channel_mask = (
                len(img) > 1 and mask_data.ndim == img.ndim and mask_data.shape[0] == len(img)
            )
            for i, d in enumerate(img):
                channel_mask = mask_data[i : i + 1] if per_channel_mask else mask_data
                img[i] = self._normalize(  # type: ignore
                    d,
                    channel_mask,
                    sub=self.subtrahend[i] if self.subtrahend is not None else None,
                    div=self.divisor[i] if self.divisor is not None else None,
                )
        else:
            img = self._normalize(img, mask_data, self.subtrahend, self.divisor)

        out = convert_to_dst_type(img, img, dtype=dtype)[0]
        return out
class NormalizeIntensity_customd(MapTransform):
    """
    Dictionary-based wrapper of ``NormalizeIntensity_custom``.

    Every entry selected by ``keys`` is normalized by the wrapped transform,
    which also receives the mask stored under ``mask_key`` in the same
    dictionary.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        mask_key: key of the mask handed to the wrapped transform.
        subtrahend: the amount to subtract by (usually the mean)
        divisor: the amount to divide by (usually the standard deviation)
        nonzero: forwarded to the wrapped transform, which currently does not
            use it.
        channel_wise: if True, calculate on each channel separately, otherwise, calculate on
            the entire image directly. default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
        allow_missing_keys: don't raise exception if key is missing.
    """

    backend = NormalizeIntensity_custom.backend

    def __init__(
        self,
        keys: KeysCollection,
        mask_key: str,
        subtrahend: Union[NdarrayOrTensor, None] = None,
        divisor: Union[NdarrayOrTensor, None] = None,
        nonzero: bool = False,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mask_key = mask_key
        self.normalizer = NormalizeIntensity_custom(subtrahend, divisor, nonzero, channel_wise, dtype)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        """Return a shallow copy of ``data`` with every selected image normalized."""
        result = dict(data)
        for name in self.key_iterator(result):
            # The mask is looked up for each key, right after the image.
            image, mask = result[name], result[self.mask_key]
            result[name] = self.normalizer(image, mask)
        return result