|
|
import shutil |
|
|
import traceback |
|
|
from io import BytesIO |
|
|
from urllib.parse import urlparse |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import pydicom |
|
|
import requests |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
from transformers import BitImageProcessor, BlipImageProcessor |
|
|
|
|
|
|
|
|
@torch.no_grad()
def model_inference(image, text, model, image_processor, tokenizer):
    """
    Score a single image against a single text prompt.

    Args:
        image (str or PIL.Image.Image): Image path or loaded image; it is
            resolved through ``load_image``.
        text: Prompt handed to ``tokenizer``.
        model: Object exposing ``device`` and ``compute_logits``; the latter
            must return a mapping with "logits" and "similarity_scores".
        image_processor: Callable whose result has a "pixel_values" entry. It
            is also forwarded to ``interpolate_similarity_scores``.
        tokenizer: Callable producing model inputs that support ``.to(device)``.

    Returns:
        tuple: ``(similarity_prob, similarity_map)`` where ``similarity_prob``
        is the sigmoid of the model's "logits" and ``similarity_map`` is the
        first entry of the sigmoid of the interpolated similarity scores.
    """
    pil_image = load_image(image)

    # PIL reports (width, height); the interpolation helper unpacks
    # (height, width), so the order is swapped here.
    img_width, img_height = pil_image.size
    target_hw = (img_height, img_width)

    pixel_values = np.array(image_processor(pil_image)["pixel_values"])
    pixel_tensor = torch.FloatTensor(pixel_values).to(model.device)

    text_inputs = tokenizer(
        text, padding=True, truncation=True, return_tensors="pt"
    ).to(model.device)

    # The tokenized prompt is passed wrapped in a one-element list.
    model_out = model.compute_logits(pixel_tensor, [text_inputs])
    similarity_prob = model_out["logits"].sigmoid()

    # Flatten the patch scores before resizing them to the image resolution.
    flat_scores = model_out["similarity_scores"].view(-1)
    upsampled_scores = interpolate_similarity_scores(
        flat_scores, target_hw, image_processor
    )
    similarity_map = upsampled_scores.sigmoid()[0]

    return similarity_prob, similarity_map
|
|
|
|
|
|
|
|
@torch.no_grad()
def model_inference_multiple_text(image, text_list, model, image_processor, tokenizer):
    """
    Run ``model_inference`` once per prompt and stack the results.

    Args:
        image: Image argument forwarded unchanged to ``model_inference``.
        text_list: Iterable of prompts; each one is evaluated independently.
        model: Forwarded to ``model_inference``.
        image_processor: Forwarded to ``model_inference``.
        tokenizer: Forwarded to ``model_inference``.

    Returns:
        tuple: ``(probs, similarity_maps)``, each the ``torch.stack`` of the
        per-prompt outputs in the order of ``text_list``.
    """
    per_prompt = [
        model_inference(image, prompt, model, image_processor, tokenizer)
        for prompt in text_list
    ]
    stacked_probs = torch.stack([prob for prob, _ in per_prompt])
    stacked_maps = torch.stack([sim_map for _, sim_map in per_prompt])
    return stacked_probs, stacked_maps
|
|
|
|
|
|
|
|
def interpolate_similarity_scores(similarity_scores, origin_size, image_processor):
    """
    Upsample a flat grid of patch similarity scores to the original image size.

    Args:
        similarity_scores (torch.Tensor): Scores whose last dimension holds a
            square number of entries; they are viewed as a
            ``(1, 1, patch, patch)`` grid.
        origin_size (tuple): ``(height, width)`` of the original image.
        image_processor: Selects how the grid maps back onto the image.
            ``BlipImageProcessor``: the grid is stretched over the full image.
            ``BitImageProcessor``: the grid is stretched over the centred
            square of side ``min(height, width)`` and every pixel outside
            that square is filled with -999 (presumably mirroring a centre
            crop done by that processor -- confirm against its config).

    Returns:
        torch.Tensor: Scores of shape ``(1, height, width)`` with the same
        dtype and device as ``similarity_scores``.

    Raises:
        ValueError: If ``image_processor`` is neither of the supported types.
    """
    (height, width) = origin_size
    patch_size = int(similarity_scores.shape[-1] ** 0.5)
    scores = similarity_scores.view(1, 1, patch_size, patch_size)

    if isinstance(image_processor, BlipImageProcessor):
        interpolated_scores = F.interpolate(
            scores,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        # (1, 1, H, W) -> (1, H, W)
        return interpolated_scores.squeeze(1)

    if isinstance(image_processor, BitImageProcessor):
        shortest = min(height, width)
        interpolated_scores = F.interpolate(
            scores,
            size=(shortest, shortest),
            mode="bilinear",
            align_corners=False,
        )

        cropped_left = (width - shortest) // 2
        cropped_top = (height - shortest) // 2

        # Allocate the canvas next to the scores (same dtype and device) so
        # the result is not silently moved to CPU/float32. -999 is a large
        # negative filler for the area outside the centred square; it maps to
        # ~0 once a sigmoid is applied.
        original_size_map = scores.new_full((height, width), -999.0)
        original_size_map[
            cropped_top : cropped_top + shortest,
            cropped_left : cropped_left + shortest,
        ] = interpolated_scores.view(shortest, shortest)

        # (H, W) -> (1, H, W), matching the Blip branch.
        return original_size_map.unsqueeze(0)

    # Previously this fell through to an UnboundLocalError on return.
    raise ValueError(f"Unsupported image processor type: {type(image_processor)}")
|
|
|
|
|
|
|
|
|
|
|
def dicom_to_pil_image(input_file_path, save_dir=None):
    """
    Extract the image from a DICOM file and return it as a PIL.Image object.

    The pixel data is min-max rescaled to 8 bits, inverted when the file's
    photometric interpretation is MONOCHROME1, and histogram-equalized.

    Args:
        input_file_path (str): Path to the input DICOM file.
        save_dir (str, optional): If given, the source DICOM file is copied
            there (via ``shutil.copy2``) after a successful conversion.

    Returns:
        PIL.Image.Image or None: Processed image, or None when reading or
        converting the file fails. Failures are not raised; the traceback is
        printed instead.
    """
    try:
        dcm_file = pydicom.dcmread(input_file_path)
        raw_image = dcm_file.pixel_array

        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O.
        if len(raw_image.shape) != 2:
            raise ValueError("Expecting single channel (grayscale) image.")

        # Normalize in floating point: subtracting the minimum in the native
        # integer dtype can overflow (e.g. int16 data spanning the full range).
        shifted_image = raw_image.astype(np.float64)
        shifted_image -= shifted_image.min()

        # A constant image has a zero range; skip the division so it stays
        # all-zero instead of turning into NaN.
        value_range = shifted_image.max()
        if value_range > 0:
            normalized_image = shifted_image / value_range
        else:
            normalized_image = shifted_image
        rescaled_image = (normalized_image * 255).astype(np.uint8)

        # Invert MONOCHROME1 data so it matches the other files' polarity.
        if dcm_file.PhotometricInterpretation == "MONOCHROME1":
            rescaled_image = cv2.bitwise_not(rescaled_image)

        final_image = cv2.equalizeHist(rescaled_image)

        image = Image.fromarray(final_image)

        if save_dir is not None:
            shutil.copy2(input_file_path, save_dir)

        return image
    except Exception:
        # Deliberate best-effort: report the problem and hand back None so a
        # single bad file does not abort the caller.
        print(traceback.format_exc())
        return None
|
|
|
|
|
|
|
|
def load_image(image):
    """
    Resolve an image argument to a loaded image.

    Args:
        image (str or PIL.Image.Image): Either an already loaded PIL image,
            which is returned unchanged, or a path ending in ".dcm", ".png",
            ".jpg" or ".jpeg" (case-insensitive).

    Returns:
        PIL.Image.Image: The loaded image. For ".dcm" paths the result of
        ``dicom_to_pil_image`` is returned as-is.

    Raises:
        ValueError: If the path has an unsupported extension, or if ``image``
            is neither a string nor a PIL image.
    """
    if not isinstance(image, str):
        if isinstance(image, Image.Image):
            return image
        raise ValueError(f"Invalid image type: {type(image)}")

    # Extension matching ignores case; the original path is what gets opened.
    lowered_path = image.lower()
    if lowered_path.endswith(".dcm"):
        return dicom_to_pil_image(image)
    if lowered_path.endswith((".png", ".jpg", ".jpeg")):
        return Image.open(image)

    raise ValueError(f"Invalid image type: {image}")
|
|
|