| | import shutil |
| | import traceback |
| | from io import BytesIO |
| | from urllib.parse import urlparse |
| |
|
| | import cv2 |
| | import numpy as np |
| | import pydicom |
| | import requests |
| | import torch |
| | import torch.nn.functional as F |
| | from PIL import Image |
| | from transformers import BitImageProcessor, BlipImageProcessor |
| |
|
| |
|
@torch.no_grad()
def model_inference(image, text, model, image_processor, tokenizer):
    """Score how well ``text`` matches ``image`` and build a per-pixel similarity map.

    Args:
        image: A file path or ``PIL.Image.Image``; resolved through ``load_image``.
        text: The query passed to ``tokenizer``.
        model: Object exposing ``device`` and ``compute_logits``; the result of
            ``compute_logits`` is indexed with ``"logits"`` and
            ``"similarity_scores"``.
        image_processor: Callable whose output has a ``"pixel_values"`` entry;
            also forwarded to ``interpolate_similarity_scores``.
        tokenizer: Callable returning an object that supports ``.to(device)``.

    Returns:
        Tuple ``(similarity_prob, similarity_map)``: the sigmoid of the model
        logits, and the sigmoid of the patch similarity scores resized to the
        original image's (height, width), taking the first entry of the
        leading dimension.
    """
    pil_image = load_image(image)
    width, height = pil_image.size  # PIL reports (width, height)

    pixel_values = np.array(image_processor(pil_image)["pixel_values"])
    pixel_tensor = torch.FloatTensor(pixel_values).to(model.device)

    text_inputs = tokenizer(
        text, padding=True, truncation=True, return_tensors="pt"
    ).to(model.device)

    # compute_logits receives the tokenized text wrapped in a list.
    model_output = model.compute_logits(pixel_tensor, [text_inputs])
    similarity_prob = model_output["logits"].sigmoid()

    flat_scores = model_output["similarity_scores"].view(-1)
    resized_scores = interpolate_similarity_scores(
        flat_scores, (height, width), image_processor
    )
    return similarity_prob, resized_scores.sigmoid()[0]
| |
|
| |
|
@torch.no_grad()
def model_inference_multiple_text(image, text_list, model, image_processor, tokenizer):
    """Run ``model_inference`` for each query in ``text_list`` against one image.

    Args:
        image: A file path or ``PIL.Image.Image`` accepted by ``load_image``.
        text_list: Non-empty iterable of text queries.
        model, image_processor, tokenizer: Forwarded unchanged to
            ``model_inference``.

    Returns:
        Tuple ``(probs, similarity_maps)``: the per-query outputs of
        ``model_inference`` stacked with ``torch.stack`` along a new leading
        dimension, in the order of ``text_list``.

    Raises:
        ValueError: If ``text_list`` is empty (there would be nothing to stack).
    """
    texts = list(text_list)
    if not texts:
        raise ValueError("text_list must contain at least one text query.")

    # Decode the image once up front. load_image returns PIL images unchanged,
    # so the per-query calls below no longer re-read (and, for DICOM,
    # re-normalize and re-equalize) the file for every query.
    pil_image = load_image(image)

    probs, similarity_maps = [], []
    for text in texts:
        prob, similarity_map = model_inference(
            pil_image, text, model, image_processor, tokenizer
        )
        probs.append(prob)
        similarity_maps.append(similarity_map)

    return torch.stack(probs), torch.stack(similarity_maps)
| |
|
| |
|
def interpolate_similarity_scores(
    similarity_scores, origin_size, image_processor, fill_value=-999.0
):
    """Resize flat patch similarity scores back onto the original image grid.

    Args:
        similarity_scores: Tensor whose last dimension holds the scores of a
            square patch grid (that length must be a perfect square).
        origin_size: ``(height, width)`` of the original image.
        image_processor: A ``BlipImageProcessor`` (scores are stretched over the
            whole image) or a ``BitImageProcessor`` (scores cover the centred
            square of side ``min(height, width)``).
        fill_value: Score written outside that centred square in the
            ``BitImageProcessor`` case. The default of -999 becomes ~0 after
            a sigmoid.

    Returns:
        Tensor of shape ``(1, height, width)`` with the dtype and device of the
        interpolated scores (i.e. of ``similarity_scores``) in both branches.

    Raises:
        ValueError: If the number of patch scores is not a perfect square.
        TypeError: If ``image_processor`` is neither supported type.
    """
    import math

    height, width = origin_size
    num_patches = similarity_scores.shape[-1]
    grid = math.isqrt(num_patches)  # exact integer sqrt, no float rounding
    if grid * grid != num_patches:
        raise ValueError(
            f"Expected a square number of patch scores, got {num_patches}."
        )
    scores = similarity_scores.reshape(1, 1, grid, grid)

    if isinstance(image_processor, BlipImageProcessor):
        resized = F.interpolate(
            scores, size=(height, width), mode="bilinear", align_corners=False
        )
        return resized.squeeze(1)

    if isinstance(image_processor, BitImageProcessor):
        # NOTE(review): assumes the processor's visible region is the centred
        # square of side min(height, width) (shortest-edge resize followed by a
        # centre crop of the same size) — confirm against the processor config.
        shortest = min(height, width)
        resized = F.interpolate(
            scores, size=(shortest, shortest), mode="bilinear", align_corners=False
        )
        left = (width - shortest) // 2
        top = (height - shortest) // 2

        # Allocate on the scores' device/dtype so this branch matches the Blip
        # branch instead of silently moving the result to a CPU float32 tensor.
        canvas = torch.full(
            (height, width),
            fill_value,
            dtype=resized.dtype,
            device=resized.device,
        )
        canvas[top : top + shortest, left : left + shortest] = resized[0, 0]
        return canvas.unsqueeze(0)

    # Previously fell through to an UnboundLocalError on `interpolated_scores`.
    raise TypeError(
        "Unsupported image processor type: "
        f"{type(image_processor).__name__}; expected BlipImageProcessor "
        "or BitImageProcessor."
    )
| |
|
| |
|
| | |
def dicom_to_pil_image(input_file_path, save_dir=None):
    """
    Extract the image from a DICOM file and return it as a PIL.Image object.

    The pixel data is min-max scaled to 0-255, inverted when the file's
    PhotometricInterpretation is MONOCHROME1, and histogram-equalized.

    Args:
        input_file_path (str): Path to the input DICOM file.
        save_dir (str, optional): If given, the original DICOM file is copied
            there with ``shutil.copy2`` after a successful conversion.
    Returns:
        PIL.Image.Image: Processed 8-bit image, or ``None`` if any step fails
        (the traceback is printed instead of raised).
    """
    try:
        dcm_file = pydicom.dcmread(input_file_path)
        raw_image = dcm_file.pixel_array

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if raw_image.ndim != 2:
            raise ValueError("Expecting single channel (grayscale) image.")

        # Min-max scale to [0, 1]. Work in float64: subtracting the minimum in
        # the original integer dtype overflows and wraps for wide-range signed
        # data (e.g. int16 spanning -32768..32767).
        shifted = raw_image.astype(np.float64) - raw_image.min()
        value_range = shifted.max()
        if value_range > 0:
            shifted /= value_range
        # A constant image (zero range) stays all-zero instead of dividing by
        # zero and producing NaN pixels.
        rescaled_image = (shifted * 255).astype(np.uint8)

        # MONOCHROME1 displays the minimum value as white; invert so that
        # higher values render brighter, as in MONOCHROME2.
        if dcm_file.PhotometricInterpretation == "MONOCHROME1":
            rescaled_image = cv2.bitwise_not(rescaled_image)

        final_image = cv2.equalizeHist(rescaled_image)
        image = Image.fromarray(final_image)

        if save_dir is not None:
            shutil.copy2(input_file_path, save_dir)

        return image
    except Exception:
        # Deliberate best-effort: report the failure and return None.
        print(traceback.format_exc())
        return None
| |
|
| |
|
def load_image(image):
    """
    Load an image from a file path or a PIL.Image object.

    Args:
        image (str, os.PathLike or PIL.Image.Image): Path to a ``.dcm``,
            ``.png``, ``.jpg`` or ``.jpeg`` file (extension matched
            case-insensitively), or an already-loaded PIL image, which is
            returned unchanged.
    Returns:
        PIL.Image.Image: Processed image.
    Raises:
        ValueError: If the path has an unsupported extension, the DICOM file
            could not be converted, or ``image`` is of an unsupported type.
    """
    import os

    # Accept pathlib.Path and other path-like objects as well as plain strings.
    if isinstance(image, os.PathLike):
        image = os.fspath(image)

    if isinstance(image, str):
        lowered = image.lower()
        if lowered.endswith(".dcm"):
            loaded = dicom_to_pil_image(image)
            # dicom_to_pil_image returns None on failure; surface that here
            # instead of handing None to callers expecting a PIL image.
            if loaded is None:
                raise ValueError(f"Failed to load DICOM image: {image}")
            return loaded
        if lowered.endswith((".png", ".jpg", ".jpeg")):
            return Image.open(image)
        raise ValueError(f"Invalid image type: {image}")

    if not isinstance(image, Image.Image):
        raise ValueError(f"Invalid image type: {type(image)}")
    return image
| |
|