import torch
from transformers import AutoTokenizer, BatchEncoding

from mixinhelpers import CXR_Mixin, ECG_Mixin, ECHO_Mixin, Text_Mixin

"""
Preprocessor classes for different modalities and their combinations.
Mixins can be combined to build preprocessors for multi-modal inputs.
Examples below are provided for ECHO+Text, ECG+Text, and CXR+Text.
"""


class BasePreprocessor:
    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)


# duo-modality preprocessors
class ECHOText_Preprocessor(BasePreprocessor, ECHO_Mixin, Text_Mixin):
    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        super().__init__(model_name=model_name)

    def preprocess_echo_text(self, echo_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess a single ECHO/report pair.

        Intended for use inside a dataloader so batches collate correctly;
        use the modality string keys to identify each modality.

        Args:
            echo_path: path to the echo .npy file
            text: text report string

        Returns:
            (echo tensor, tokenized text dict)
        """
        echo = self.preprocess_single_echo(echo_path)  # (C, H, W)
        text_inputs = self.construct_caption(
            caption=text, tokenizer=self.tokenizer, modality=self.ECHO_KEY
        )
        return echo, text_inputs


class ECGText_Preprocessor(BasePreprocessor, ECG_Mixin, Text_Mixin):
    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        super().__init__(model_name=model_name)

    def preprocess_ecg_text(self, ecg_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess a single ECG/report pair.

        Intended for use inside a dataloader so batches collate correctly;
        use the modality string keys to identify each modality.

        Args:
            ecg_path: path to the ecg .npy file
            text: text report string

        Returns:
            (ecg tensor, tokenized text dict)
        """
        ecg = self.preprocess_single_ecg(ecg_path)  # (C, L)
        text_inputs = self.construct_caption(
            caption=text, tokenizer=self.tokenizer, modality=self.ECG_KEY
        )
        return ecg, text_inputs


class CXRText_Preprocessor(BasePreprocessor, CXR_Mixin, Text_Mixin):
    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        super().__init__(model_name=model_name)

    def preprocess_cxr_text(self, cxr_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess a single CXR/report pair.

        Intended for use inside a dataloader so batches collate correctly;
        use the modality string keys to identify each modality.

        Args:
            cxr_path: path to the cxr image file
            text: text report string

        Returns:
            (cxr tensor, tokenized text dict)
        """
        cxr = self.preprocess_single_cxr(cxr_path)  # (C, H, W)
        text_inputs = self.construct_caption(
            caption=text, tokenizer=self.tokenizer, modality=self.VISION_KEY
        )
        return cxr, text_inputs
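

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API):
# wrapping ECHOText_Preprocessor in a torch Dataset so each item comes back
# as (echo tensor, tokenized report). The `samples` list of (echo_path,
# report) pairs is a hypothetical placeholder. Depending on how
# construct_caption pads its output, a custom collate_fn may be needed when
# batching with a DataLoader.
# ---------------------------------------------------------------------------
from torch.utils.data import Dataset


class EchoReportDataset(Dataset):
    def __init__(self, samples: list[tuple[str, str]]) -> None:
        # samples: list of (path to echo .npy file, report text) pairs
        self.samples = samples
        self.preprocessor = ECHOText_Preprocessor()

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, BatchEncoding]:
        echo_path, report = self.samples[idx]
        # one preprocessed sample: (echo tensor, tokenized text dict)
        return self.preprocessor.preprocess_echo_text(echo_path, report)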