|
|
|
|
|
import random |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
from transformers import BatchEncoding, PreTrainedTokenizer |
|
|
|
|
|
""" |
|
|
Mixin for all modalities, each mixin has: |
|
|
- preprocess function that takes in path or data and returns tensor |
|
|
- construct_input function that takes in tensor and returns dict with batch |
|
|
dimension for model input |
|
|
- key string for model input dict |
|
|
""" |
|
|
|
|
|
|
|
|
class ECHO_Mixin:
    """Preprocessing for echocardiogram (echo) video inputs.

    A single frame is taken from a video. Pixels whose HSV value falls inside
    ``[LOWER_YELLOW, UPPER_YELLOW]`` are blacked out, and the frame is min-max
    rescaled to 0-255 ``uint8``. ``ECHO_TRANSFORMS`` then converts it to a
    normalized ``(C, *IMAGE_SIZE)`` tensor. ``ECHO_KEY`` is the key for this
    modality in the model-input dict.
    """

    # Inclusive HSV bounds passed to cv2.inRange; matching pixels are set to
    # black. NOTE(review): presumably strips yellow on-screen annotations.
    LOWER_YELLOW: list[int] = [20, 50, 50]
    UPPER_YELLOW: list[int] = [100, 255, 255]

    IMAGE_SIZE: tuple[int, int] = (224, 224)

    # Same values as the OpenAI CLIP image normalization constants.
    NORM_MEAN: tuple[float, float, float] = (0.48145466, 0.4578275, 0.40821073)
    NORM_STD: tuple[float, float, float] = (0.26862954, 0.26130258, 0.27577711)

    # Built once at class-definition time from the constants above, so
    # overriding IMAGE_SIZE / NORM_* in a subclass does NOT change it.
    ECHO_TRANSFORMS = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(IMAGE_SIZE),
            transforms.Normalize(
                mean=NORM_MEAN,
                std=NORM_STD,
            ),
        ]
    )

    ECHO_KEY: str = "echo"

    def grabimage(self, split: str, data: dict[str, np.ndarray]) -> np.ndarray:
        """Pick one frame from one video in ``data`` and preprocess it.

        The video is always chosen at random from ``data``. For
        ``split == "train"`` the frame index is random as well. For any other
        split, frame 0 is used.

        Args:
            split: Dataset split name; only ``"train"`` enables random frames.
            data: Mapping of case id to a video array indexed by frame on axis 0.

        Returns:
            The frame processed by :meth:`extract_echoframe` (``uint8`` array).

        Raises:
            IndexError: If ``data`` is empty.
        """
        caseofinterest = random.choice(list(data))
        video = data[caseofinterest]
        if split == "train":
            # randrange(n) draws exactly like choice(list(range(n))) without
            # materializing the index list.
            imageindice = random.randrange(video.shape[0])
        else:
            imageindice = 0
        return self.extract_echoframe(imageindice, video)

    def extract_echoframe(self, imageindice: int, video: np.ndarray) -> np.ndarray:
        """Return frame ``imageindice`` of ``video`` with the HSV-masked pixels removed.

        Pixels whose HSV value lies in ``[LOWER_YELLOW, UPPER_YELLOW]`` are set to
        0. The frame is then shifted so its minimum is 0 and scaled so its
        maximum is 255, and cast to ``uint8``. A constant frame becomes all
        zeros instead of NaN.

        Args:
            imageindice: Index of the frame along axis 0 of ``video``.
            video: Frame stack; each frame must be a 3-channel image accepted by
                ``cv2.cvtColor(..., COLOR_BGR2HSV)``. NOTE(review): presumably
                BGR ``uint8`` as produced by OpenCV - confirm.

        Returns:
            The processed frame as a new ``uint8`` array; ``video`` is not modified.
        """
        # Copy: video[i] is a view, and the blackout below must not leak back
        # into the caller's video array.
        frame = video[imageindice].copy()

        hsv_image = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        mask = cv2.inRange(
            hsv_image, np.array(self.LOWER_YELLOW), np.array(self.UPPER_YELLOW)
        )
        frame[mask > 0] = [0, 0, 0]

        image = frame.astype(np.float32)
        image -= image.min()
        peak = image.max()
        # Guard against a constant frame, which would otherwise divide by zero
        # and yield NaN (undefined once cast to uint8).
        if peak > 0:
            image /= peak
            image *= 255
        return image.astype(np.uint8)

    def _transform_frame(self, image: np.ndarray) -> torch.Tensor:
        """Apply ``ECHO_TRANSFORMS`` to a ``uint8`` frame and ensure a tensor results."""
        # NOTE(review): PIL treats the 3 channels as RGB, while the frame came
        # through a BGR->HSV path; no channel swap happens here - confirm this
        # matches training-time preprocessing.
        transformed = self.ECHO_TRANSFORMS(Image.fromarray(image))
        if not isinstance(transformed, torch.Tensor):
            # A silent fallback would skip resizing/normalization and return a
            # tensor of a different size and scale, so fail loudly instead.
            raise TypeError("ECHO_TRANSFORMS must return a torch.Tensor")
        return transformed

    def preprocess_echoseries(
        self, video_dict: dict[str, np.ndarray], split: str = "valid"
    ) -> torch.Tensor:
        """Select and preprocess one frame from a dict of echo videos.

        With the default ``split="valid"``, frame 0 of a randomly chosen video
        is used (see :meth:`grabimage`).

        Returns:
            The transformed frame as a ``(C, H, W)`` tensor.

        Raises:
            TypeError: If frame selection or the transforms return an unexpected type.
        """
        image = self.grabimage(split, video_dict)
        if not isinstance(image, np.ndarray):
            raise TypeError("Expected image to be a numpy ndarray")
        return self._transform_frame(image)

    def preprocess_single_echo(self, avi_path: str) -> torch.Tensor:
        """Open an AVI file and preprocess its first frame (inference mode).

        Args:
            avi_path: Path to a video file readable by ``cv2.VideoCapture``.

        Returns:
            The transformed first frame as a ``(C, H, W)`` tensor.

        Raises:
            ValueError: If no frame can be read from ``avi_path``.
        """
        cap = cv2.VideoCapture(avi_path)
        try:
            success, frame = cap.read()
        finally:
            # Always free the capture handle, even if read() raises.
            cap.release()

        if not success or frame is None:
            raise ValueError(f"Could not read frame from AVI file: {avi_path}")

        # Wrap the single frame as a 1-frame "video" for extract_echoframe.
        image = self.extract_echoframe(0, frame[np.newaxis])
        return self._transform_frame(image)
|
|
|
|
|
|
|
|
|
|
|
class CXR_Mixin:
    """Preprocessing for chest X-ray (CXR) images.

    A single-channel image is cropped to the bounding box of its non-zero
    pixels and replicated to 3 channels. ``CXR_TRANSFORMS`` then converts it to
    a normalized ``(3, *IMAGE_SIZE)`` tensor. ``VISION_KEY`` is the model-input
    dict key.
    """

    RESIZE: tuple[int, int] = (256, 256)
    IMAGE_SIZE: tuple[int, int] = (224, 224)

    # Single-channel statistics; Normalize broadcasts them over all 3 channels.
    NORM_MEAN: list[float] = [0.5862785803043838]
    NORM_STD: list[float] = [0.27950088968644304]

    VISION_KEY: str = "vision"

    # Built once at class-definition time; subclass overrides of the constants
    # above do NOT affect it. Resize with an (h, w) tuple ignores aspect ratio,
    # then CenterCrop trims to IMAGE_SIZE.
    CXR_TRANSFORMS = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(RESIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.Normalize(
                mean=NORM_MEAN,
                std=NORM_STD,
            ),
        ]
    )

    @staticmethod
    def remove_border(pixel_array: np.ndarray) -> np.ndarray:
        """Crop a 2-D image to the bounding box of its pixels that are > 0.

        Args:
            pixel_array: 2-D image array.

        Returns:
            A view of ``pixel_array`` restricted to the inclusive bounding box of
            its positive pixels. If no pixel is positive, ``pixel_array`` is
            returned unchanged.

        Raises:
            ValueError: If ``pixel_array`` is not 2-D.
        """
        if pixel_array.ndim != 2:
            raise ValueError(f"Expected a 2-D image, got shape {pixel_array.shape}")

        foreground = pixel_array > 0
        if not foreground.any():
            # Nothing to crop to; returning as-is beats a reduction error.
            return pixel_array

        rows = np.flatnonzero(foreground.any(axis=1))
        cols = np.flatnonzero(foreground.any(axis=0))
        # +1 because slice ends are exclusive; without it the last non-zero
        # row and column would be dropped.
        return pixel_array[rows[0] : rows[-1] + 1, cols[0] : cols[-1] + 1]

    def preprocess_loaded_cxr(self, img: np.ndarray) -> torch.Tensor:
        """Preprocess an already-loaded single-channel CXR image.

        Args:
            img: 2-D image array. NOTE(review): ``Image.fromarray`` on the 3-channel
                stack presumably requires ``uint8`` - confirm for callers.

        Returns:
            A normalized ``(3, *IMAGE_SIZE)`` tensor.

        Raises:
            TypeError: If ``CXR_TRANSFORMS`` does not return a tensor.
        """
        cxr = self.remove_border(img)
        # Replicate the single channel so PIL builds a 3-channel image.
        cxr = np.repeat(cxr[..., np.newaxis], 3, axis=-1)

        transformed = self.CXR_TRANSFORMS(Image.fromarray(cxr))
        if not isinstance(transformed, torch.Tensor):
            # A silent ToTensor fallback would skip resize/crop/normalize and
            # return a differently-shaped, unnormalized tensor.
            raise TypeError("CXR_TRANSFORMS must return a torch.Tensor")
        return transformed

    def preprocess_single_cxr(self, image_path: str) -> torch.Tensor:
        """Load a CXR image file and preprocess it (inference mode).

        The file is decoded with PIL and converted to RGB. Only channel 0 is
        kept and passed to :meth:`preprocess_loaded_cxr`.

        Args:
            image_path: Path to an image file readable by PIL.

        Returns:
            A normalized ``(3, *IMAGE_SIZE)`` tensor.
        """
        with open(image_path, "rb") as fopen:
            # convert() forces PIL to decode while the file is still open.
            image = Image.open(fopen).convert("RGB")
        return self.preprocess_loaded_cxr(np.array(image)[:, :, 0])
|
|
|
|
|
|
|
|
class ECG_Mixin:
    """Preprocessing for 12-lead ECG arrays stored as ``.npy`` files.

    An ECG is validated to be a finite ``(CHANNELS, samples)`` array, truncated
    to ``LENGTH`` samples, and standardized with fixed constants. ``ECG_KEY``
    is the model-input dict key.
    """

    # Number of samples kept per channel.
    LENGTH: int = 1000
    # NOTE(review): presumably the sampling rate in Hz; not used by the
    # methods below.
    FREQUENCY: int = 100
    CHANNELS: int = 12

    # Standardization constants, as in sklearn's StandardScaler (mean_, scale_).
    NORM_MEAN: float = 0.02547506
    NORM_SCALE: float = 0.16486814
    # Equals NORM_SCALE ** 2 (StandardScaler's var_); not used by the methods below.
    NORM_VAR: float = 0.0271815

    ECG_KEY: str = "ecg"

    def manual_standardize(self, x: np.ndarray) -> torch.Tensor:
        """Standardize ``x`` as ``(x - NORM_MEAN) / NORM_SCALE``.

        Equivalent to sklearn's ``StandardScaler.transform`` with the class
        constants.

        Args:
            x: Input array, e.g. of shape ``(12, 1000)``.

        Returns:
            A ``float32`` tensor of the same shape as ``x``.
        """
        return torch.from_numpy((x - self.NORM_MEAN) / self.NORM_SCALE).float()

    def check_ecg(self, ecg: np.ndarray) -> np.ndarray:
        """Reject non-finite ECGs and truncate to the first ``LENGTH`` samples.

        Args:
            ecg: 2-D array indexed as ``(channel, sample)``.

        Returns:
            ``ecg[:, :LENGTH]``. Shorter recordings are returned as-is (no padding).

        Raises:
            ValueError: If ``ecg`` contains NaN or +/-Inf.
        """
        # One pass: isfinite is False exactly for NaN and +/-Inf.
        if not np.isfinite(ecg).all():
            raise ValueError("ECG contains NaN or Inf values")
        return ecg[:, : self.LENGTH]

    def preprocess_single_ecg(self, ecg_path: str) -> torch.Tensor:
        """Load an ECG ``.npy`` file and preprocess it (inference mode).

        Args:
            ecg_path: Path to a ``.npy`` file holding a ``(CHANNELS, samples)`` array.

        Returns:
            A standardized ``float32`` tensor of shape ``(CHANNELS, <= LENGTH)``.

        Raises:
            ValueError: If the array is not 2-D, has the wrong channel count,
                or contains non-finite values.
        """
        ecg = np.load(ecg_path)
        # Validate rank first: a 1-D array would otherwise skip the channel
        # check and fail with an opaque IndexError inside check_ecg.
        if ecg.ndim != 2:
            raise ValueError(
                f"Expected a 2-D ECG array (channels, samples), got shape {ecg.shape}"
            )
        if ecg.shape[0] != self.CHANNELS:
            raise ValueError(f"Expected ECG with {self.CHANNELS} channels, got {ecg.shape[0]}")

        return self.manual_standardize(self.check_ecg(ecg))
|
|
|
|
|
|
|
|
class Text_Mixin:
    """Helpers that build and tokenize text prompts (captions) for the model."""

    # Modality keys accepted by createCaption, mapped to descriptive names.
    # NOTE(review): createCaption embeds the key itself, not the mapped value -
    # confirm this matches training-time prompts.
    MODALITY_LIST: dict[str, str] = {"echo": "echocardiogram", "ecg": "ecg", "vision": "cxr"}
    # Captions are padded/truncated to exactly this many tokens.
    MAX_LENGTH: int = 120
    # NOTE(review): presumably the word budget for get_first_n_words, whose
    # default n has the same value.
    TEXT_LENGTH: int = 100

    def get_first_n_words(self, text: str, n: int = 100) -> str:
        """Return the first ``n`` whitespace-separated words of ``text``, space-joined.

        Runs of whitespace collapse to single spaces. Per the original note,
        97.5% of texts are under 35 words, so the default rarely truncates.
        """
        words = text.split()
        return " ".join(words[:n])

    def createCaption(self, caption: str, modality: str = "") -> str:
        """Format ``caption`` into the model's prompt template.

        Args:
            caption: Free-text description.
            modality: A key of ``MODALITY_LIST`` or ``""``.

        Returns:
            ``"text : {caption}, {modality} looks like : "``.

        Raises:
            ValueError: If ``modality`` is neither ``""`` nor a ``MODALITY_LIST`` key.
        """
        # Explicit raise instead of assert: asserts are stripped under python -O.
        if modality != "" and modality not in self.MODALITY_LIST:
            raise ValueError(f"modality should be in {self.MODALITY_LIST} or empty")
        return f"text : {caption}, {modality} looks like : "

    def createTokenizedCaption(self, caption: str, tokenizer: PreTrainedTokenizer) -> BatchEncoding:
        """Tokenize ``caption`` into PyTorch tensors padded/truncated to ``MAX_LENGTH``."""
        return tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        )

    def construct_caption(
        self, caption: str, tokenizer: PreTrainedTokenizer, modality: str = ""
    ) -> BatchEncoding:
        """Build the prompt for ``caption`` and tokenize it for model input.

        Returns:
            The tokenizer's encoding, e.g. ``input_ids`` and ``attention_mask``,
            each of shape ``(1, MAX_LENGTH)``.

        Raises:
            ValueError: If ``modality`` is invalid (see :meth:`createCaption`).
        """
        caption_str = self.createCaption(caption, modality)
        return self.createTokenizedCaption(caption_str, tokenizer)
|
|
|