Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from pathlib import Path | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| from torchvision import transforms | |
| from torchvision.transforms import InterpolationMode | |
| import pandas as pd | |
| from transformers import AutoTokenizer | |
# Whether we are running under continuous integration. This is True only when
# the ``CI`` environment variable equals "true" (compared case-insensitively);
# an unset variable or any other value yields False.
IS_CI = os.environ.get("CI", "false").lower() == "true"
class TriageDataset(Dataset):
    """Dataset of triage cases read from a CSV file.

    Each CSV row may provide a free-text ``emr_text`` column, an ``image_path``
    column, an optional ``triage_level`` label (``"low"``, ``"medium"`` or
    ``"high"``) and an optional ``patient_id``. Depending on ``mode``, items
    contain tokenized text, a transformed image, or both.

    Args:
        csv_file: Anything accepted by ``pandas.read_csv``.
        tokenizer_name: Model id or path passed to
            ``AutoTokenizer.from_pretrained``. The tokenizer is loaded in every
            mode, including ``"image"``.
        max_length: Token sequence length; text is padded/truncated to it.
        transform: Callable applied to each RGB PIL image. When ``None`` (image
            modes only), a default augmentation pipeline is used.
        mode: ``"text"``, ``"image"`` or ``"multimodal"`` (case-insensitive).
        image_base_dir: Directory against which relative ``image_path`` values
            are resolved. Required for ``"image"`` and ``"multimodal"``.

    Raises:
        ValueError: If ``mode`` is not a supported value, or if an image mode
            is requested without ``image_base_dir``.
    """

    _MODES = ("text", "image", "multimodal")

    def __init__(
        self,
        csv_file,
        tokenizer_name="emilyalsentzer/Bio_ClinicalBERT",
        max_length=128,
        transform=None,
        mode="multimodal",
        image_base_dir=None,
    ):
        # Normalise *before* validating so mixed-case modes such as "Text" are
        # accepted. Raise instead of assert so the check survives `python -O`.
        normalized_mode = mode.lower() if isinstance(mode, str) else mode
        if normalized_mode not in self._MODES:
            raise ValueError("Mode must be one of: 'text', 'image', or 'multimodal'")
        needs_images = normalized_mode in ("image", "multimodal")
        if needs_images and image_base_dir is None:
            # Checked before the CSV/tokenizer are loaded so misconfiguration
            # fails fast.
            raise ValueError("image directory must be provided for image or multimodal mode.")

        self.df = pd.read_csv(csv_file)  # one row per sample
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_length = max_length
        self.mode = normalized_mode

        # Always define the image attributes so they exist in text mode too.
        self.image_base_dir = None
        self.transform = transform
        if needs_images:
            self.image_base_dir = Path(image_base_dir).resolve()
            # `is None` rather than truthiness, so a user transform that
            # happens to be falsy (e.g. an empty container) is still honoured.
            if transform is None:
                self.transform = self._default_transform()

        # Label mapping from triage level to class index.
        self.label_map = {"low": 0, "medium": 1, "high": 2}

    @staticmethod
    def _default_transform():
        """Return the default image pipeline: resize, augment, ToTensor.

        NOTE(review): RandomResizedCrop, RandomRotation and ColorJitter make
        this pipeline stochastic, so it also augments images at validation or
        inference time. No normalisation step follows ToTensor. Pass an
        explicit deterministic ``transform`` for evaluation.
        """
        return transforms.Compose(
            [
                transforms.Resize((256, 256)),  # resize first
                transforms.RandomResizedCrop(
                    224,
                    scale=(0.9, 1.0),
                    interpolation=InterpolationMode.BILINEAR,
                ),  # slight zoom-in/out
                transforms.RandomRotation(degrees=10),  # +/- 10 degree rotation
                transforms.ColorJitter(
                    brightness=0.3, contrast=0.3
                ),  # simulate slight imaging variations
                transforms.GaussianBlur(kernel_size=3),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        """Return the number of rows (samples) so the DataLoader can size batches."""
        return len(self.df)

    def _encode_text(self, text):
        """Tokenize one text into 1-D ``input_ids`` and ``attention_mask`` tensors."""
        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch dimension of size 1; drop
        # it so the DataLoader's collate step can stack samples itself.
        return tokens["input_ids"].squeeze(0), tokens["attention_mask"].squeeze(0)

    def _resolve_image_path(self, raw_path):
        """Resolve a CSV ``image_path`` value to a filesystem path.

        Absolute paths are returned unchanged. Relative paths are joined to
        ``image_base_dir``, except those starting with
        ``data/<image_base_dir name>/``. For those, the leading ``data``
        component is dropped and the remainder is joined to the base
        directory's parent, so the base directory name is not duplicated.
        """
        path = Path(raw_path)
        if path.is_absolute():
            return path
        if path.parts[:2] == ("data", self.image_base_dir.name):
            return self.image_base_dir.parent / Path(*path.parts[1:])
        return self.image_base_dir / path

    def _load_image(self, image_path):
        """Open ``image_path`` as RGB and return ``self.transform`` applied to it.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist. The message is
                prefixed with ``[CI]`` or ``[LOCAL]`` according to ``IS_CI``.
        """
        if not image_path.exists():
            msg = f"Image file not found: {image_path}"
            raise FileNotFoundError(f"[CI] {msg}" if IS_CI else f"[LOCAL] {msg}")
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return self.transform(image)

    def __getitem__(self, idx):
        """Build the sample dict for positional row ``idx``.

        Keys, when applicable:
            input_ids, attention_mask: tokenized ``emr_text`` (text modes).
            raw_text: the untokenized ``emr_text`` (text modes).
            image: transformed image (image modes).
            label: ``torch.long`` class index, only when ``triage_level``
                (after trimming whitespace and lower-casing) is a known level.
            patient_id, emr_text, image_path: passthrough fields for inference
                output. ``image_path`` is the resolved path in image modes and
                the raw CSV value otherwise.

        Raises:
            KeyError: If the mode needs a column (``emr_text``/``image_path``)
                that the CSV lacks.
            FileNotFoundError: If the resolved image file does not exist.
        """
        row = self.df.iloc[idx]
        output = {}

        if self.mode in ("text", "multimodal"):
            text = row["emr_text"]
            output["input_ids"], output["attention_mask"] = self._encode_text(text)
            # Raw text for inference. It was previously gated on a "text"
            # column that this dataset never reads, so it was never emitted.
            output["raw_text"] = text

        # Explicit variable instead of probing locals() later on.
        resolved_image_path = None
        if self.mode in ("image", "multimodal"):
            resolved_image_path = self._resolve_image_path(row["image_path"])
            output["image"] = self._load_image(resolved_image_path)

        # Label: tolerate stray whitespace and case ("High ", "LOW"); unknown
        # or missing levels produce no "label" key, as before.
        level = row.get("triage_level")
        if isinstance(level, str):
            level = level.strip().lower()
        if level in self.label_map:
            output["label"] = torch.tensor(self.label_map[level], dtype=torch.long)

        # Fields for inference output.
        if "patient_id" in row:
            output["patient_id"] = row["patient_id"]
        if "emr_text" in row:
            output["emr_text"] = row["emr_text"]
        if "image_path" in row:
            output["image_path"] = (
                str(resolved_image_path)
                if resolved_image_path is not None
                else row["image_path"]
            )
        return output