| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import logging |
| import os |
| import random |
| import time |
| from datetime import datetime |
|
|
| import monai |
| import torch |
| from monai.data import MetaTensor |
| from monai.inferers.inferer import DiffusionInferer, SlidingWindowInferer |
| from monai.transforms import Compose, SaveImage |
| from monai.utils import set_determinism |
| from tqdm import tqdm |
|
|
| from .augmentation import augmentation |
| from .find_masks import find_masks |
| from .quality_check import is_outlier |
| from .utils import binarize_labels, dynamic_infer, general_mask_generation_post_process, remap_labels |
|
|
| modality_mapping = { |
| "unknown": 0, |
| "ct": 1, |
| "ct_wo_contrast": 2, |
| "ct_contrast": 3, |
| "mri": 8, |
| "mri_t1": 9, |
| "mri_t2": 10, |
| "mri_flair": 11, |
| "mri_pd": 12, |
| "mri_dwi": 13, |
| "mri_adc": 14, |
| "mri_ssfp": 15, |
| "mri_mra": 16, |
| } |
|
|
|
|
| class ReconModel(torch.nn.Module): |
| """ |
| A PyTorch module for reconstructing images from latent representations. |
| |
| Attributes: |
| autoencoder: The autoencoder model used for decoding. |
| scale_factor: Scaling factor applied to the input before decoding. |
| """ |
|
|
| def __init__(self, autoencoder, scale_factor): |
| super().__init__() |
| self.autoencoder = autoencoder |
| self.scale_factor = scale_factor |
|
|
| def forward(self, z): |
| """ |
| Decode the input latent representation to an image. |
| |
| Args: |
| z (torch.Tensor): The input latent representation. |
| |
| Returns: |
| torch.Tensor: The reconstructed image. |
| """ |
| recon_pt_nda = self.autoencoder.decode_stage_2_outputs(z / self.scale_factor) |
| return recon_pt_nda |
|
|
|
|
| def initialize_noise_latents(latent_shape, device): |
| """ |
| Initialize random noise latents for image generation with float16. |
| |
| Args: |
| latent_shape (tuple): The shape of the latent space. |
| device (torch.device): The device to create the tensor on. |
| |
| Returns: |
| torch.Tensor: Initialized noise latents. |
| """ |
| return torch.randn([1] + list(latent_shape)).half().to(device) |
|
|
|
|
| def ldm_conditional_sample_one_mask( |
| autoencoder, |
| diffusion_unet, |
| noise_scheduler, |
| scale_factor, |
| anatomy_size, |
| device, |
| latent_shape, |
| label_dict_remap_json, |
| num_inference_steps=1000, |
| autoencoder_sliding_window_infer_size=(96, 96, 96), |
| autoencoder_sliding_window_infer_overlap=0.6667, |
| ): |
| """ |
| Generate a single synthetic mask using a latent diffusion model. |
| |
| Args: |
| autoencoder (nn.Module): The autoencoder model. |
| diffusion_unet (nn.Module): The diffusion U-Net model. |
| noise_scheduler: The noise scheduler for the diffusion process. |
| scale_factor (float): Scaling factor for the latent space. |
| anatomy_size (torch.Tensor): Tensor specifying the desired anatomy sizes. |
| device (torch.device): The device to run the computation on. |
| latent_shape (tuple): The shape of the latent space. |
| label_dict_remap_json (str): Path to the JSON file for label remapping. |
| num_inference_steps (int): Number of inference steps for the diffusion process. |
| autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96]. |
| autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667. |
| |
| Returns: |
| torch.Tensor: The generated synthetic mask. |
| """ |
| recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) |
|
|
| with torch.no_grad(), torch.amp.autocast("cuda"): |
| |
| latents = initialize_noise_latents(latent_shape, device) |
| anatomy_size = torch.FloatTensor(anatomy_size).unsqueeze(0).unsqueeze(0).half().to(device) |
| |
| noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) |
| inferer_ddpm = DiffusionInferer(noise_scheduler) |
| latents = inferer_ddpm.sample( |
| input_noise=latents, |
| diffusion_model=diffusion_unet, |
| scheduler=noise_scheduler, |
| verbose=True, |
| conditioning=anatomy_size.to(device), |
| ) |
| |
| inferer = SlidingWindowInferer( |
| roi_size=autoencoder_sliding_window_infer_size, |
| sw_batch_size=1, |
| progress=True, |
| mode="gaussian", |
| overlap=autoencoder_sliding_window_infer_overlap, |
| device=torch.device("cpu"), |
| sw_device=device, |
| ) |
| synthetic_mask = dynamic_infer(inferer, recon_model, latents) |
| synthetic_mask = torch.softmax(synthetic_mask, dim=1) |
| synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True) |
| |
| synthetic_mask = remap_labels(synthetic_mask, label_dict_remap_json) |
|
|
| |
| data = synthetic_mask.squeeze().cpu().detach().numpy() |
|
|
| labels = [23, 24, 26, 27, 128] |
| target_tumor_label = None |
| for index, size in enumerate(anatomy_size[0, 0, 5:10]): |
| if size.item() != -1.0: |
| target_tumor_label = labels[index] |
|
|
| logging.info(f"target_tumor_label for postprocess:{target_tumor_label}") |
| data = general_mask_generation_post_process(data, target_tumor_label=target_tumor_label, device=device) |
| synthetic_mask = torch.from_numpy(data).unsqueeze(0).unsqueeze(0).to(device) |
|
|
| return synthetic_mask |
|
|
|
|
| def ldm_conditional_sample_one_image( |
| autoencoder, |
| diffusion_unet, |
| controlnet, |
| noise_scheduler, |
| scale_factor, |
| device, |
| combine_label_or, |
| modality_tensor, |
| spacing_tensor, |
| latent_shape, |
| output_size, |
| noise_factor, |
| num_inference_steps=1000, |
| autoencoder_sliding_window_infer_size=(96, 96, 96), |
| autoencoder_sliding_window_infer_overlap=0.6667, |
| ): |
| """ |
| Generate a single synthetic image using a latent diffusion model with controlnet. |
| |
| Args: |
| autoencoder (nn.Module): The autoencoder model. |
| diffusion_unet (nn.Module): The diffusion U-Net model. |
| controlnet (nn.Module): The controlnet model. |
| noise_scheduler: The noise scheduler for the diffusion process. |
| scale_factor (float): Scaling factor for the latent space. |
| device (torch.device): The device to run the computation on. |
| combine_label_or (torch.Tensor): The combined label tensor. |
| spacing_tensor (torch.Tensor): Tensor specifying the spacing. |
| latent_shape (tuple): The shape of the latent space. |
| output_size (tuple): The desired output size of the image. |
| noise_factor (float): Factor to scale the initial noise. |
| num_inference_steps (int): Number of inference steps for the diffusion process. |
| autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96]. |
| autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667. |
| |
| Returns: |
| tuple: A tuple containing the synthetic image and its corresponding label. |
| """ |
| |
| a_min = -1000 |
| a_max = 1000 |
| |
| b_min = 0.0 |
| b_max = 1 |
|
|
| recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) |
|
|
| with torch.no_grad(), torch.amp.autocast("cuda", enabled=True): |
| logging.info("---- Start generating latent features... ----") |
| start_time = time.time() |
| |
| combine_label = combine_label_or.to(device) |
| if ( |
| output_size[0] != combine_label.shape[2] |
| or output_size[1] != combine_label.shape[3] |
| or output_size[2] != combine_label.shape[4] |
| ): |
| logging.info( |
| "output_size is not a desired value. Need to interpolate the mask to match " |
| "with output_size. The result image will be very low quality." |
| ) |
| combine_label = torch.nn.functional.interpolate(combine_label, size=output_size, mode="nearest") |
|
|
| controlnet_cond_vis = binarize_labels(combine_label.as_tensor().long()).half() |
|
|
| |
| latents = initialize_noise_latents(latent_shape, device) * noise_factor |
|
|
| |
| noise_scheduler.set_timesteps( |
| num_inference_steps=num_inference_steps, input_img_size=torch.prod(torch.tensor(latent_shape[-3:])) |
| ) |
| |
| guidance_scale = 0 |
| all_next_timesteps = torch.cat( |
| (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype)) |
| ) |
| for t, next_t in tqdm( |
| zip(noise_scheduler.timesteps, all_next_timesteps), |
| total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)), |
| ): |
| timesteps = torch.Tensor((t,)).to(device) |
| if guidance_scale == 0: |
| down_block_res_samples, mid_block_res_sample = controlnet( |
| x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis, class_labels=modality_tensor |
| ) |
| predicted_velocity = diffusion_unet( |
| x=latents, |
| timesteps=timesteps, |
| spacing_tensor=spacing_tensor, |
| class_labels=modality_tensor, |
| down_block_additional_residuals=down_block_res_samples, |
| mid_block_additional_residual=mid_block_res_sample, |
| ) |
| else: |
| down_block_res_samples, mid_block_res_sample = controlnet( |
| x=torch.cat([latents] * 2), |
| timesteps=torch.cat([timesteps] * 2), |
| controlnet_cond=torch.cat([controlnet_cond_vis] * 2), |
| class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), |
| ) |
| model_t, model_uncond = diffusion_unet( |
| x=torch.cat([latents] * 2), |
| timesteps=timesteps, |
| spacing_tensor=torch.cat([timesteps] * 2), |
| class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), |
| down_block_additional_residuals=down_block_res_samples, |
| mid_block_additional_residual=mid_block_res_sample, |
| ).chunk(2) |
| predicted_velocity = model_uncond + guidance_scale * (model_t - model_uncond) |
| latents, _ = noise_scheduler.step(predicted_velocity, t, latents, next_timestep=next_t) |
| end_time = time.time() |
| logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----") |
| del predicted_velocity |
| torch.cuda.empty_cache() |
|
|
| |
| logging.info("---- Start decoding latent features into images... ----") |
| inferer = SlidingWindowInferer( |
| roi_size=autoencoder_sliding_window_infer_size, |
| sw_batch_size=1, |
| progress=True, |
| mode="gaussian", |
| overlap=autoencoder_sliding_window_infer_overlap, |
| device=torch.device("cpu"), |
| sw_device=device, |
| ) |
| start_time = time.time() |
| synthetic_images = dynamic_infer(inferer, recon_model, latents) |
| synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu() |
| end_time = time.time() |
| logging.info(f"---- Image decoding time: {end_time - start_time} seconds ----") |
|
|
| |
| |
| synthetic_images = (synthetic_images - b_min) / (b_max - b_min) |
| |
| synthetic_images = synthetic_images * (a_max - a_min) + a_min |
| |
| synthetic_images = crop_img_body_mask(synthetic_images, combine_label) |
| torch.cuda.empty_cache() |
|
|
| return synthetic_images, combine_label |
|
|
|
|
| def filter_mask_with_organs(combine_label, anatomy_list): |
| """ |
| Filter a mask to only include specified organs. |
| |
| Args: |
| combine_label (torch.Tensor): The input mask. |
| anatomy_list (list): List of organ labels to keep. |
| |
| Returns: |
| torch.Tensor: The filtered mask. |
| """ |
| |
| |
| combine_label = combine_label.long() |
| |
| for i in range(len(anatomy_list)): |
| organ = anatomy_list[i] |
| |
| combine_label[combine_label == organ] = -(i + 1) |
| |
| combine_label[combine_label > 0] = 0 |
| |
| combine_label = -combine_label |
| return combine_label |
|
|
|
|
| def crop_img_body_mask(synthetic_images, combine_label): |
| """ |
| Crop the synthetic image using a body mask. |
| |
| Args: |
| synthetic_images (torch.Tensor): The synthetic images. |
| combine_label (torch.Tensor): The body mask. |
| |
| Returns: |
| torch.Tensor: The cropped synthetic images. |
| """ |
| synthetic_images[combine_label == 0] = -1000 |
| return synthetic_images |
|
|
|
|
| def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing, controllable_anatomy_size): |
| """ |
| Validate input parameters for image generation. |
| |
| Args: |
| body_region (list): List of body regions. |
| anatomy_list (list): List of anatomical structures. |
| label_dict_json (str): Path to the label dictionary JSON file. |
| output_size (tuple): Desired output size of the image. |
| spacing (tuple): Desired voxel spacing. |
| controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes. |
| |
| Raises: |
| ValueError: If any input parameter is invalid. |
| """ |
| |
| if output_size[0] != output_size[1]: |
| raise ValueError(f"The first two components of output_size need to be equal, yet got {output_size}.") |
| if (output_size[0] not in [256, 384, 512]) or (output_size[2] not in [128, 256, 384, 512, 640, 768]): |
| raise ValueError( |
| ( |
| "The output_size[0] have to be chosen from [256, 384, 512], and output_size[2] " |
| f"have to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}." |
| ) |
| ) |
|
|
| if spacing[0] != spacing[1]: |
| raise ValueError(f"The first two components of spacing need to be equal, yet got {spacing}.") |
| if spacing[0] < 0.5 or spacing[0] > 3.0 or spacing[2] < 0.5 or spacing[2] > 5.0: |
| raise ValueError( |
| f"spacing[0] have to be between 0.5 and 3.0 mm, spacing[2] have to be between 0.5 and 5.0 mm, yet got {spacing}." |
| ) |
|
|
| if ( |
| output_size[0] * spacing[0] < 256 |
| or output_size[2] * spacing[2] < 128 |
| or output_size[0] * spacing[0] > 640 |
| or output_size[2] * spacing[2] > 2000 |
| ): |
| fov = [output_size[axis] * spacing[axis] for axis in range(3)] |
| raise ValueError( |
| ( |
| f"`'spacing'({spacing}mm) and 'output_size'({output_size}) together decide the output field of view (FOV). " |
| f"The FOV will be {fov}mm. We recommend the FOV in x and y axis to be at least 256mm for head, and at least " |
| "384mm for other body regions like abdomen, and less than 640mm. " |
| "For z-axis, we require it to be at least 128mm and less than 2000mm." |
| ) |
| ) |
|
|
| |
| if len(controllable_anatomy_size) > 10: |
| raise ValueError( |
| ( |
| "The output_size[0] have to be chosen from [256, 384, 512], and output_size[2] " |
| f"have to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}." |
| ) |
| ) |
| available_controllable_organ = ["liver", "gallbladder", "stomach", "pancreas", "colon"] |
| available_controllable_tumor = [ |
| "hepatic tumor", |
| "bone lesion", |
| "lung tumor", |
| "colon cancer primaries", |
| "pancreatic tumor", |
| ] |
| available_controllable_anatomy = available_controllable_organ + available_controllable_tumor |
| controllable_tumor = [] |
| controllable_organ = [] |
| for controllable_anatomy_size_pair in controllable_anatomy_size: |
| if controllable_anatomy_size_pair[0] not in available_controllable_anatomy: |
| raise ValueError( |
| ( |
| f"The controllable_anatomy have to be chosen from {available_controllable_anatomy}, " |
| f"yet got {controllable_anatomy_size_pair[0]}." |
| ) |
| ) |
| if controllable_anatomy_size_pair[0] in available_controllable_tumor: |
| controllable_tumor += [controllable_anatomy_size_pair[0]] |
| if controllable_anatomy_size_pair[0] in available_controllable_organ: |
| controllable_organ += [controllable_anatomy_size_pair[0]] |
| if controllable_anatomy_size_pair[1] == -1: |
| continue |
| if controllable_anatomy_size_pair[1] < 0 or controllable_anatomy_size_pair[1] > 1.0: |
| raise ValueError( |
| ( |
| "The controllable size scale have to be between 0 and 1,0, or equal to -1, " |
| f"yet got {controllable_anatomy_size_pair[1]}." |
| ) |
| ) |
| if len(controllable_tumor + controllable_organ) != len(list(set(controllable_tumor + controllable_organ))): |
| raise ValueError(f"Please do not repeat controllable_anatomy. Got {controllable_tumor + controllable_organ}.") |
| if len(controllable_tumor) > 1: |
| raise ValueError(f"Only one controllable tumor is supported. Yet got {controllable_tumor}.") |
|
|
| if len(controllable_anatomy_size) > 0: |
| logging.info( |
| ( |
| "`controllable_anatomy_size` is not empty.\nWe will ignore `body_region` and `anatomy_list` " |
| f"and synthesize based on `controllable_anatomy_size`: ({controllable_anatomy_size})." |
| ) |
| ) |
| else: |
| logging.info( |
| (f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `anatomy_list`: ({anatomy_list}).") |
| ) |
| |
| available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"] |
| for region in body_region: |
| if region not in available_body_region: |
| raise ValueError( |
| f"The components in body_region have to be chosen from {available_body_region}, yet got {region}." |
| ) |
|
|
| |
| with open(label_dict_json) as f: |
| label_dict = json.load(f) |
| for anatomy in anatomy_list: |
| if anatomy not in label_dict.keys(): |
| raise ValueError( |
| f"The components in anatomy_list have to be chosen from {label_dict.keys()}, yet got {anatomy}." |
| ) |
| logging.info(f"The generate results will have voxel size to be {spacing} mm, volume size to be {output_size}.") |
|
|
| return |
|
|
|
|
| class LDMSampler: |
| """ |
| A sampler class for generating synthetic medical images and masks using latent diffusion models. |
| |
| Attributes: |
| Various attributes related to model configuration, input parameters, and generation settings. |
| """ |
|
|
| def __init__( |
| self, |
| body_region, |
| anatomy_list, |
| modality, |
| all_mask_files_json, |
| all_anatomy_size_condtions_json, |
| all_mask_files_base_dir, |
| label_dict_json, |
| label_dict_remap_json, |
| autoencoder, |
| diffusion_unet, |
| controlnet, |
| noise_scheduler, |
| scale_factor, |
| mask_generation_autoencoder, |
| mask_generation_diffusion_unet, |
| mask_generation_scale_factor, |
| mask_generation_noise_scheduler, |
| device, |
| latent_shape, |
| mask_generation_latent_shape, |
| output_size, |
| output_dir, |
| controllable_anatomy_size, |
| image_output_ext=".nii.gz", |
| label_output_ext=".nii.gz", |
| real_img_median_statistics="./configs/image_median_statistics.json", |
| spacing=(1, 1, 1), |
| num_inference_steps=None, |
| mask_generation_num_inference_steps=None, |
| random_seed=None, |
| autoencoder_sliding_window_infer_size=(96, 96, 96), |
| autoencoder_sliding_window_infer_overlap=0.6667, |
| ) -> None: |
| """ |
| Initialize the LDMSampler with various parameters and models. |
| |
| Args: |
| Various parameters related to model configuration, input settings, and output specifications. |
| """ |
| self.random_seed = random_seed |
| if random_seed is not None: |
| set_determinism(seed=random_seed) |
|
|
| with open(label_dict_json, "r") as f: |
| label_dict = json.load(f) |
| self.all_anatomy_size_condtions_json = all_anatomy_size_condtions_json |
|
|
| |
| self.body_region = body_region |
| self.anatomy_list = [label_dict[organ] for organ in anatomy_list] |
| self.modality_int = modality_mapping[modality] |
| self.all_mask_files_json = all_mask_files_json |
| self.data_root = all_mask_files_base_dir |
| self.label_dict_remap_json = label_dict_remap_json |
| self.autoencoder = autoencoder |
| self.diffusion_unet = diffusion_unet |
| self.controlnet = controlnet |
| self.noise_scheduler = noise_scheduler |
| self.scale_factor = scale_factor |
| self.mask_generation_autoencoder = mask_generation_autoencoder |
| self.mask_generation_diffusion_unet = mask_generation_diffusion_unet |
| self.mask_generation_scale_factor = mask_generation_scale_factor |
| self.mask_generation_noise_scheduler = mask_generation_noise_scheduler |
| self.device = device |
| self.latent_shape = latent_shape |
| self.mask_generation_latent_shape = mask_generation_latent_shape |
| self.output_size = output_size |
| self.output_dir = output_dir |
| self.noise_factor = 1.0 |
| self.controllable_anatomy_size = controllable_anatomy_size |
| if len(self.controllable_anatomy_size): |
| logging.info("controllable_anatomy_size is given, mask generation is triggered!") |
| |
| self.anatomy_list = [label_dict[organ_and_size[0]] for organ_and_size in self.controllable_anatomy_size] |
| self.image_output_ext = image_output_ext |
| self.label_output_ext = label_output_ext |
| |
| self.num_inference_steps = num_inference_steps if num_inference_steps is not None else 1000 |
| self.mask_generation_num_inference_steps = ( |
| mask_generation_num_inference_steps if mask_generation_num_inference_steps is not None else 1000 |
| ) |
|
|
| if any(size % 16 != 0 for size in autoencoder_sliding_window_infer_size): |
| raise ValueError( |
| f"autoencoder_sliding_window_infer_size must be divisible by 16.\n Got {autoencoder_sliding_window_infer_size}" |
| ) |
| if not (0 <= autoencoder_sliding_window_infer_overlap <= 1): |
| raise ValueError( |
| ( |
| "Value of autoencoder_sliding_window_infer_overlap must be between 0 " |
| f"and 1.\n Got {autoencoder_sliding_window_infer_overlap}" |
| ) |
| ) |
| self.autoencoder_sliding_window_infer_size = autoencoder_sliding_window_infer_size |
| self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap |
|
|
| |
| self.max_try_time = 3 |
| with open(real_img_median_statistics, "r") as json_file: |
| self.median_statistics = json.load(json_file) |
| self.label_int_dict = { |
| "liver": [1], |
| "spleen": [3], |
| "pancreas": [4], |
| "kidney": [5, 14], |
| "lung": [28, 29, 30, 31, 31], |
| "brain": [22], |
| "hepatic tumor": [26], |
| "bone lesion": [128], |
| "lung tumor": [23], |
| "colon cancer primaries": [27], |
| "pancreatic tumor": [24], |
| "bone": list(range(33, 57)) + list(range(63, 98)) + [120, 122, 127], |
| } |
|
|
| |
| self.autoencoder.eval() |
| self.diffusion_unet.eval() |
| self.controlnet.eval() |
| self.mask_generation_autoencoder.eval() |
| self.mask_generation_diffusion_unet.eval() |
|
|
| self.spacing = spacing |
|
|
| self.val_transforms = Compose( |
| [ |
| monai.transforms.LoadImaged(keys=["pseudo_label"]), |
| monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]), |
| monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"), |
| monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8), |
| monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), |
| monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2), |
| ] |
| ) |
| logging.info("LDM sampler initialized.") |
|
|
| def sample_multiple_images(self, num_img): |
| """ |
| Generate multiple synthetic images and masks. |
| |
| Args: |
| num_img (int): Number of images to generate. |
| """ |
| output_filenames = [] |
| if len(self.controllable_anatomy_size) > 0: |
| |
| |
| selected_mask_files = list(range(num_img)) |
| |
| anatomy_size_condtion = self.prepare_anatomy_size_condtion(self.controllable_anatomy_size) |
| else: |
| need_resample = False |
| |
| candidate_mask_files = find_masks( |
| self.anatomy_list, self.spacing, self.output_size, True, self.all_mask_files_json, self.data_root |
| ) |
| if len(candidate_mask_files) < num_img: |
| |
| |
| logging.info("Resample mask file to get desired output size and spacing") |
| candidate_mask_files = self.find_closest_masks(num_img) |
| need_resample = True |
|
|
| selected_mask_files = self.select_mask(candidate_mask_files, num_img) |
| if len(selected_mask_files) < num_img: |
| raise ValueError( |
| ( |
| f"len(selected_mask_files) ({len(selected_mask_files)}) < num_img ({num_img}). " |
| "This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)." |
| ) |
| ) |
| num_generated_img = 0 |
| for index_s in range(len(selected_mask_files)): |
| item = selected_mask_files[index_s] |
| if num_generated_img >= num_img: |
| break |
| logging.info("---- Start preparing masks... ----") |
| start_time = time.time() |
| logging.info(f"Image will be generated based on {item}.") |
| if len(self.controllable_anatomy_size) > 0: |
| |
| (combine_label_or, spacing_tensor) = self.prepare_one_mask_and_meta_info(anatomy_size_condtion) |
| else: |
| |
| mask_file = item["mask_file"] |
| if_aug = item["if_aug"] |
| (combine_label_or, spacing_tensor) = self.read_mask_information(mask_file) |
| if need_resample: |
| combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) |
| |
| if if_aug: |
| combine_label_or = augmentation(combine_label_or, self.output_size, random_seed=self.random_seed) |
| end_time = time.time() |
| logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----") |
| torch.cuda.empty_cache() |
| |
| modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int |
| |
| synthetic_images, synthetic_labels = self.sample_one_pair(combine_label_or, modality_tensor, spacing_tensor) |
| |
| pass_quality_check = self.quality_check( |
| synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy() |
| ) |
| if pass_quality_check or (num_img - num_generated_img) >= (len(selected_mask_files) - index_s): |
| if not pass_quality_check: |
| logging.info( |
| "Generated image/label pair did not pass quality check, but will still save them. " |
| "Please consider changing spacing and output_size to facilitate a more realistic setting." |
| ) |
| num_generated_img = num_generated_img + 1 |
| |
| output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") |
| synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" |
| synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta) |
| img_saver = SaveImage( |
| output_dir=self.output_dir, |
| output_postfix=output_postfix + "_image", |
| output_ext=self.image_output_ext, |
| separate_folder=False, |
| ) |
| img_saver(synthetic_images[0]) |
| synthetic_images_filename = os.path.join( |
| self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext |
| ) |
| |
| synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) |
| label_saver = SaveImage( |
| output_dir=self.output_dir, |
| output_postfix=output_postfix + "_label", |
| output_ext=self.label_output_ext, |
| separate_folder=False, |
| ) |
| label_saver(synthetic_labels[0]) |
| synthetic_labels_filename = os.path.join( |
| self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext |
| ) |
| output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) |
| else: |
| logging.info("Generated image/label pair did not pass quality check, will re-generate another pair.") |
| return output_filenames |
|
|
| def select_mask(self, candidate_mask_files, num_img): |
| """ |
| Select mask files for image generation. |
| |
| Args: |
| candidate_mask_files (list): List of candidate mask files. |
| num_img (int): Number of images to generate. |
| |
| Returns: |
| list: Selected mask files with augmentation flags. |
| """ |
| selected_mask_files = [] |
| random.shuffle(candidate_mask_files) |
|
|
| for n in range(num_img * self.max_try_time): |
| mask_file = candidate_mask_files[n % len(candidate_mask_files)] |
| selected_mask_files.append({"mask_file": mask_file, "if_aug": True}) |
| return selected_mask_files |
|
|
| def sample_one_pair(self, combine_label_or_aug, modality_tensor, spacing_tensor): |
| """ |
| Generate a single pair of synthetic image and mask. |
| |
| Args: |
| combine_label_or_aug (torch.Tensor): Combined label tensor or augmented label. |
| modality_tensor (torch.Tensor): Tensor specifying the image modality. |
| spacing_tensor (torch.Tensor): Tensor specifying the spacing. |
| |
| Returns: |
| tuple: A tuple containing the synthetic image and its corresponding label. |
| """ |
| |
| synthetic_images, synthetic_labels = ldm_conditional_sample_one_image( |
| autoencoder=self.autoencoder, |
| diffusion_unet=self.diffusion_unet, |
| controlnet=self.controlnet, |
| noise_scheduler=self.noise_scheduler, |
| scale_factor=self.scale_factor, |
| device=self.device, |
| combine_label_or=combine_label_or_aug, |
| modality_tensor=modality_tensor, |
| spacing_tensor=spacing_tensor, |
| latent_shape=self.latent_shape, |
| output_size=self.output_size, |
| noise_factor=self.noise_factor, |
| num_inference_steps=self.num_inference_steps, |
| autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size, |
| autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap, |
| ) |
| return synthetic_images, synthetic_labels |
|
|
| def prepare_anatomy_size_condtion(self, controllable_anatomy_size): |
| """ |
| Prepare anatomy size conditions for mask generation. |
| |
| Args: |
| controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes. |
| |
| Returns: |
| list: Prepared anatomy size conditions. |
| """ |
| anatomy_size_idx = { |
| "gallbladder": 0, |
| "liver": 1, |
| "stomach": 2, |
| "pancreas": 3, |
| "colon": 4, |
| "lung tumor": 5, |
| "pancreatic tumor": 6, |
| "hepatic tumor": 7, |
| "colon cancer primaries": 8, |
| "bone lesion": 9, |
| } |
| provide_anatomy_size = [None for _ in range(10)] |
| logging.info(f"controllable_anatomy_size: {controllable_anatomy_size}") |
| for element in controllable_anatomy_size: |
| anatomy_name, anatomy_size = element |
| provide_anatomy_size[anatomy_size_idx[anatomy_name]] = anatomy_size |
|
|
| with open(self.all_anatomy_size_condtions_json, "r") as f: |
| all_anatomy_size_condtions = json.load(f) |
|
|
| |
| candidate_list = [] |
| for anatomy_size in all_anatomy_size_condtions: |
| size = anatomy_size["organ_size"] |
| diff = 0 |
| for db_size, provide_size in zip(size, provide_anatomy_size): |
| if provide_size is None: |
| continue |
| diff += abs(provide_size - db_size) |
| candidate_list.append((size, diff)) |
| candidate_condition = sorted(candidate_list, key=lambda x: x[1])[0][0] |
|
|
| |
| for element in controllable_anatomy_size: |
| anatomy_name, anatomy_size = element |
| candidate_condition[anatomy_size_idx[anatomy_name]] = anatomy_size |
|
|
| return candidate_condition |
|
|
| def prepare_one_mask_and_meta_info(self, anatomy_size_condtion): |
| """ |
| Prepare a single mask and its associated meta information. |
| |
| Args: |
| anatomy_size_condtion (list): Anatomy size conditions. |
| |
| Returns: |
| tuple: A tuple containing the prepared mask and associated tensors. |
| """ |
| combine_label_or = self.sample_one_mask(anatomy_size=anatomy_size_condtion) |
| |
| affine = torch.zeros((4, 4)) |
| affine[0, 0] = 1.5 |
| affine[1, 1] = 1.5 |
| affine[2, 2] = 1.5 |
| affine[3, 3] = 1.0 |
| combine_label_or = MetaTensor(combine_label_or, affine=affine) |
| combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) |
|
|
| spacing_tensor = torch.FloatTensor(self.spacing).unsqueeze(0).half().to(self.device) * 1e2 |
|
|
| return combine_label_or, spacing_tensor |
|
|
| def sample_one_mask(self, anatomy_size): |
| """ |
| Generate a single synthetic mask. |
| |
| Args: |
| anatomy_size (list): Anatomy size specifications. |
| |
| Returns: |
| torch.Tensor: The generated synthetic mask. |
| """ |
| |
| synthetic_mask = ldm_conditional_sample_one_mask( |
| self.mask_generation_autoencoder, |
| self.mask_generation_diffusion_unet, |
| self.mask_generation_noise_scheduler, |
| self.mask_generation_scale_factor, |
| anatomy_size, |
| self.device, |
| self.mask_generation_latent_shape, |
| label_dict_remap_json=self.label_dict_remap_json, |
| num_inference_steps=self.mask_generation_num_inference_steps, |
| autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size, |
| autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap, |
| ) |
| return synthetic_mask |
|
|
| def ensure_output_size_and_spacing(self, labels, check_contains_target_labels=True): |
| """ |
| Ensure the output mask has the correct size and spacing. |
| |
| Args: |
| labels (torch.Tensor): Input label tensor. |
| check_contains_target_labels (bool): Whether to check if the resampled mask contains target labels. |
| |
| Returns: |
| torch.Tensor: Resampled label tensor. |
| |
| Raises: |
| ValueError: If the resampled mask doesn't contain required class labels. |
| """ |
| current_spacing = [labels.affine[0, 0], labels.affine[1, 1], labels.affine[2, 2]] |
| current_shape = list(labels.squeeze().shape) |
|
|
| need_resample = False |
| |
| for i, j in zip(current_spacing, self.spacing): |
| if i != j: |
| need_resample = True |
| |
| for i, j in zip(current_shape, self.output_size): |
| if i != j: |
| need_resample = True |
| |
| if need_resample: |
| logging.info("Resampling mask to target shape and spacing") |
| logging.info(f"Resize Spacing: {current_spacing} -> {self.spacing}") |
| logging.info(f"Output size: {current_shape} -> {self.output_size}") |
| spacing = monai.transforms.Spacing(pixdim=tuple(self.spacing), mode="nearest") |
| pad_crop = monai.transforms.ResizeWithPadOrCrop(spatial_size=tuple(self.output_size)) |
| labels = pad_crop(spacing(labels.squeeze(0))).unsqueeze(0).to(labels.dtype) |
|
|
| contained_labels = torch.unique(labels) |
| if check_contains_target_labels: |
| |
| for anatomy_label in self.anatomy_list: |
| if anatomy_label not in contained_labels: |
| raise ValueError( |
| ( |
| f"Resampled mask does not contain required class labels {anatomy_label}. " |
| "Please consider increasing the output spacing or specifying a larger output size." |
| ) |
| ) |
| return labels |
|
|
| def read_mask_information(self, mask_file): |
| """ |
| Read mask information from a file. |
| |
| Args: |
| mask_file (str): Path to the mask file. |
| |
| Returns: |
| tuple: A tuple containing the mask tensor and associated information. |
| """ |
| val_data = self.val_transforms(mask_file) |
|
|
| for key in ["pseudo_label", "spacing"]: |
| val_data[key] = val_data[key].unsqueeze(0).to(self.device) |
|
|
| return (val_data["pseudo_label"], val_data["spacing"]) |
|
|
| def find_closest_masks(self, num_img): |
| """ |
| Find the closest matching masks from the database. |
| |
| Args: |
| num_img (int): Number of images to generate. |
| |
| Returns: |
| list: List of closest matching mask candidates. |
| |
| Raises: |
| ValueError: If suitable candidates cannot be found. |
| """ |
| |
| candidates = find_masks( |
| self.anatomy_list, self.spacing, self.output_size, False, self.all_mask_files_json, self.data_root |
| ) |
|
|
| if len(candidates) < num_img: |
| raise ValueError(f"candidate masks are less than {num_img}).") |
|
|
| |
| new_candidates = [] |
| for c in candidates: |
| diff = 0 |
| include_c = True |
| for axis in range(3): |
| if abs(c["dim"][axis]) < self.output_size[axis] - 64: |
| |
| include_c = False |
| break |
| |
| diff += abs( |
| (abs(c["dim"][axis] * c["spacing"][axis]) - self.output_size[axis] * self.spacing[axis]) / 10 |
| ) |
| |
| diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100) |
| |
| diff += abs(abs(c["spacing"][axis]) - self.spacing[axis]) |
| if include_c: |
| new_candidates.append((c, diff)) |
|
|
| |
| new_candidates = sorted(new_candidates, key=lambda x: x[1])[: max(2 * num_img, 5)] |
| final_candidates = [] |
|
|
| |
| image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True) |
| for c, _ in new_candidates: |
| label = image_loader(c["pseudo_label"]) |
| try: |
| label = self.ensure_output_size_and_spacing(label.unsqueeze(0)) |
| except ValueError as e: |
| if "Resampled mask does not contain required class labels" in str(e): |
| continue |
| else: |
| raise e |
| |
| c["spacing"] = self.spacing |
| c["dim"] = self.output_size |
|
|
| final_candidates.append(c) |
| if len(final_candidates) == 0: |
| raise ValueError("Cannot find body region with given anatomy list.") |
| return final_candidates |
|
|
| def quality_check(self, image_data, label_data): |
| """ |
| Perform a quality check on the generated image. |
| Args: |
| image_data (np.ndarray): The generated image. |
| label_data (np.ndarray): The corresponding whole body mask. |
| Returns: |
| bool: True if the image passes the quality check, False otherwise. |
| """ |
| outlier_results = is_outlier(self.median_statistics, image_data, label_data, self.label_int_dict) |
| for label, result in outlier_results.items(): |
| if result.get("is_outlier", False): |
| logging.info( |
| ( |
| f"Generated image quality check for label '{label}' failed: median value {result['median_value']} " |
| f"is outside the acceptable range ({result['low_thresh']} - {result['high_thresh']})." |
| ) |
| ) |
| return False |
| return True |
|
|