"""Utility helpers: checkpoint saving, metric summaries, logging setup,
pipeline-step validation and patch localisation inside a parent volume."""
import argparse
import os
import shutil
import time
import yaml
import sys
import gdown
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from monai.config import KeysCollection
from monai.metrics import Cumulative, CumulativeAverage
from monai.networks.nets import milmodel, resnet, MILModel
from monai.transforms import (
    Compose,
    GridPatchd,
    LoadImaged,
    MapTransform,
    RandFlipd,
    RandGridPatchd,
    RandRotate90d,
    ScaleIntensityRanged,
    SplitDimd,
    ToTensord,
    ConcatItemsd,
    SelectItemsd,
    EnsureChannelFirstd,
    RepeatChanneld,
    DeleteItemsd,
    EnsureTyped,
    ClipIntensityPercentilesd,
    MaskIntensityd,
    HistogramNormalized,
    RandBiasFieldd,
    RandCropByPosNegLabeld,
    NormalizeIntensityd,
    SqueezeDimd,
    CropForegroundd,
    ScaleIntensityd,
    SpatialPadd,
    CenterSpatialCropd,
    Transposed,
    RandWeightedCropd,
)
from sklearn.metrics import cohen_kappa_score
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data.dataloader import default_collate
from torchvision.models.resnet import ResNet50_Weights
from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import wandb
import math
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
from src.model.MIL import MILModel_3D
from src.model.csPCa_model import csPCa_Model
import logging
from pathlib import Path


def save_pirads_checkpoint(model, epoch, args, filename="model.pth", best_acc=0):
    """Save checkpoint

    Writes ``{"epoch", "best_acc", "state_dict"}`` with ``torch.save`` to
    ``os.path.join(args.logdir, filename)``.

    Args:
        model: module whose ``state_dict()`` is saved.
        epoch: epoch number stored in the checkpoint.
        args: namespace providing ``logdir``, the output directory.
        filename: checkpoint file name inside ``args.logdir``.
        best_acc: best accuracy value stored alongside the weights.
    """
    state_dict = model.state_dict()
    save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
    filename = os.path.join(args.logdir, filename)
    torch.save(save_dict, filename)
    # The path needs a %s placeholder: an extra logging argument without one
    # makes the logging module report a formatting error instead of the message.
    logging.info("Saving checkpoint %s", filename)


def save_cspca_checkpoint(model, val_metric, model_dir):
    """Save the csPCa model weights and validation metrics to ``model_dir/cspca_model.pth``.

    Args:
        model: module whose ``state_dict()`` is saved.
        val_metric: mapping that must contain ``epoch``, ``loss``, ``auc``,
            ``sensitivity``, ``specificity`` and ``state``; each value is copied
            into the checkpoint (a missing key raises ``KeyError``).
        model_dir: directory the checkpoint file is written into.
    """
    state_dict = model.state_dict()
    save_dict = {
        "epoch": val_metric["epoch"],
        "loss": val_metric["loss"],
        "auc": val_metric["auc"],
        "sensitivity": val_metric["sensitivity"],
        "specificity": val_metric["specificity"],
        "state": val_metric["state"],
        "state_dict": state_dict,
    }
    torch.save(save_dict, os.path.join(model_dir, "cspca_model.pth"))
    # Use a %s placeholder; passing the value as a bare extra argument breaks
    # logging's %-formatting and the message is never written.
    logging.info("Saving model with auc: %s", val_metric["auc"])


def get_metrics(metric_dict: dict):
    """Log the mean and the 2.5th-97.5th percentile interval of each metric.

    Args:
        metric_dict: maps a metric name to a sequence of numeric values
            (presumably bootstrap replicates — confirm with callers).
    """
    for metric_name, metric_list in metric_dict.items():
        values = np.asarray(metric_list)
        if values.size == 0:
            # np.percentile cannot summarise an empty sample; report and move on.
            logging.warning(f"No values recorded for {metric_name}; skipping")
            continue
        lower, upper = np.percentile(values, [2.5, 97.5])
        mean_metric = np.mean(values)
        logging.info(f"Mean {metric_name}: {mean_metric:.3f}")
        logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")


def setup_logging(log_file):
    """Send INFO-level root logging to ``log_file``, starting from an empty file.

    Parent directories are created if needed and an existing file is truncated.
    ``logging.basicConfig`` does nothing when the root logger already has
    handlers, so only the first configuration in a process takes effect.
    """
    log_file = Path(log_file)
    log_file.parent.mkdir(parents=True, exist_ok=True)
    if log_file.exists():
        log_file.write_text("")  # overwrite with empty string
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        handlers=[
            logging.FileHandler(log_file),
        ],
    )


def validate_steps(steps):
    """Check that every step in ``steps`` comes after the steps it depends on.

    On the first violation an error is logged and the process exits with
    status 1 via ``sys.exit``.

    Args:
        steps: ordered sequence of step names; names without listed
            prerequisites are always accepted.
    """
    REQUIRES = {
        "get_segmentation_mask": ["register_and_crop"],
        "histogram_match": ["get_segmentation_mask", "register_and_crop"],
        "get_heatmap": ["get_segmentation_mask", "histogram_match", "register_and_crop"],
    }
    for i, step in enumerate(steps):
        required = REQUIRES.get(step, [])
        for req in required:
            if req not in steps[:i]:
                logging.error(
                    f"Step '{step}' requires '{req}' to be executed before it. "
                    f"Given order: {steps}"
                )
                sys.exit(1)


def _find_patch_in_volume(volume, patch):
    """Return the first ``(row, col, slice)`` where ``patch`` exactly equals a window of
    ``volume[:, :, slice]``, scanning slices, then rows, then columns; ``None`` if absent."""
    H, W = volume.shape[:2]
    h, w = patch.shape
    if patch.size == 0 or h > H or w > W:
        return None
    anchor = patch[0, 0]
    for k in range(volume.shape[2]):
        plane = volume[:, :, k]
        # Pre-filter: a window can only match if its top-left pixel equals the
        # patch's top-left pixel. np.argwhere yields candidates in row-major
        # order, preserving the row-then-column scan of a brute-force search.
        candidates = np.argwhere(plane[: H - h + 1, : W - w + 1] == anchor)
        for l, m in candidates:
            if np.array_equal(plane[l : l + h, m : m + w], patch):
                return int(l), int(m), k
    return None


def get_patch_coordinate(patches_top_5, parent_image, args):
    """Locate the first slice of each patch inside ``parent_image``.

    Each patch is indexed ``[slice, row, col]`` (it was previously transposed
    with ``(1, 2, 0)`` and its ``[:, :, 0]`` view used, which is the same
    2-D array as ``patch[0]``). That 2-D array is searched for, by exact
    equality, in every ``parent_image[:, :, k]``.

    Args:
        patches_top_5: sequence of 3-D numpy arrays indexed ``[slice, row, col]``.
        parent_image: 3-D array indexed ``[row, col, slice]``.
        args: unused; kept for interface compatibility.

    Returns:
        list of ``(row, col, slice)`` top-left coordinates, one per patch that
        was found (first match in slice, row, column order). Patches with no
        exact match are skipped with a logged warning, so the list may be
        shorter than ``patches_top_5``.
    """
    volume = np.asarray(parent_image)
    coords = []
    for idx, patch_stack in enumerate(patches_top_5):
        first_slice = np.asarray(patch_stack)[0]
        location = _find_patch_in_volume(volume, first_slice)
        if location is None:
            logging.warning("Patch %d was not found in the parent image; skipping it", idx)
            continue
        coords.append(location)
    return coords


def get_parent_image(temp_data_list, args):
    """Load the first case of ``temp_data_list`` and return its normalised image.

    The entry is loaded with ``ITKReader`` (``image`` and ``mask`` keys,
    channel-first, float32). The image then goes through
    ``ClipMaskIntensityPercentilesd`` (lower=0, upper=99.5) and
    ``NormalizeIntensity_customd``, both keyed on ``mask``. The entry must also
    provide a ``label`` key, which ``EnsureTyped``/``ToTensord`` convert.

    Args:
        temp_data_list: list of data dicts; only index 0 is transformed.
        args: unused; kept for interface compatibility.

    Returns:
        numpy array of channel 0 of the transformed image.
    """
    transform_image = Compose(
        [
            LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
            ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
            NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
            EnsureTyped(keys=["label"], dtype=torch.float32),
            ToTensord(keys=["image", "label"]),
        ]
    )
    dataset_image = Dataset(data=temp_data_list, transform=transform_image)
    return dataset_image[0]["image"][0].numpy()


# NOTE(review): the string below is a disabled draft of a patch-visualisation
# helper. It is never executed (its free names patches_top_5, parent_image and
# args are not defined at module level), and its search loop duplicates
# get_patch_coordinate. Consider deleting it or turning it into a real function.
'''
def visualise_patches():
    sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
    rows = len(patches_top_5)
    img = sample[0]
    coords = []
    rows, h, w, slices = sample.shape
    fig, axes = plt.subplots(nrows=rows, ncols=slices, figsize=(slices * 3, rows * 3))
    for i in range(rows):
        for j in range(slices):
            ax = axes[i, j]
            if j == 0:
                for k in range(parent_image.shape[2]):
                    img_temp = parent_image[:, :, k]
                    H, W = img_temp.shape
                    h, w = sample[i, :, :, j].shape
                    a,b = 0, 0 # Initialize a and b
                    bool1 = False
                    for l in range(H - h + 1):
                        for m in range(W - w + 1):
                            if np.array_equal(img_temp[l:l+h, m:m+w], sample[i, :, :, j]):
                                a,b = l, m # top-left corner
                                coords.append((a,b,k))
                                bool1 = True
                                break
                        if bool1:
                            break
                    if bool1:
                        break
            ax.imshow(parent_image[:, :, k+j], cmap='gray')
            rect = patches.Rectangle((b, a), args.tile_size, args.tile_size, linewidth=2, edgecolor='red', facecolor='none')
            ax.add_patch(rect)
            ax.axis('off')
    plt.tight_layout()
    plt.show()
    a=1
'''