LDM-SyntheticChestX-Ray / load_data.py
lfolle's picture
Add code.
e5a19d6
raw history blame
No virus
1.18 kB
import numpy as np
import random
from PIL import Image
import os
from glob import glob
from utils import get_index_to_class_mapping
from create_video import generate_video
def load_overlay(pathology_str: str, index: int):
    """Load the overlay image stored for one sample of a pathology.

    Searches ``data/<pathology_str>/`` for a file whose name matches
    ``*_<index>_overlay.png`` and opens it with PIL.

    Args:
        pathology_str: Name of the pathology sub-directory under ``data``.
        index: Sample index embedded in the overlay file name.

    Returns:
        The opened PIL image for the first matching file (in sorted path
        order), or an all-black 256x256 grayscale ("L") image when no
        overlay file exists for this index.
    """
    # `glob` in this module is the function, so pull in `escape` locally.
    from glob import escape as glob_escape

    # Escape the directory part so pathology names containing glob
    # metacharacters ("[", "]", "*", "?") are matched literally.
    pattern = os.path.join(
        glob_escape(os.path.join("data", pathology_str)),
        f"*_{index}_overlay.png",
    )
    # glob() returns matches in arbitrary filesystem order; sort so the
    # chosen overlay is deterministic when several files match.
    matches = sorted(glob(pattern))
    if not matches:
        # No overlay for this sample: return a blank (all-zero) grayscale image.
        return Image.new("L", (256, 256), 0)
    return Image.open(matches[0])
def generate_image(pathology_str: str):
    """Pick a random pre-generated image for a pathology and return its pixels.

    The image is read from
    ``data/<pathology_str>/<class_index - 1>_<idx>.png``, where
    ``class_index`` is looked up from ``get_index_to_class_mapping()`` and
    ``idx`` is drawn uniformly from 1..100 (inclusive).

    Args:
        pathology_str: Pathology name. It must be one of the values of the
            mapping returned by ``get_index_to_class_mapping()``, otherwise
            a ``KeyError`` is raised.

    Returns:
        A ``(pixels, idx)`` tuple. ``pixels`` is the image converted with
        ``np.array`` and ``idx`` is the randomly chosen sample number.
    """
    index_to_class = get_index_to_class_mapping()
    # The helper maps index -> class name; invert it to find the index
    # belonging to the requested pathology name.
    class_to_index = {name: class_index for class_index, name in index_to_class.items()}
    idx = random.randint(1, 100)  # randint is inclusive on both ends
    # NOTE(review): file names use (mapping index - 1); presumably the
    # mapping is 1-based while files are 0-based -- confirm against utils.
    image_path = os.path.join(
        "data", pathology_str, f"{class_to_index[pathology_str] - 1}_{idx}.png"
    )
    # Open inside a context manager so the file handle is released as soon
    # as the pixel data has been copied into the NumPy array.
    with Image.open(image_path) as im:
        im_to_disp = np.array(im)
    return im_to_disp, idx
def generate_video_and_gradcam(pathology_str: str):
    """Sample an image for a pathology and build its video and overlay.

    A random sample is drawn with ``generate_image``. Its pixels are passed
    to ``generate_video``, and the overlay with the same sample index is
    loaded via ``load_overlay``.

    Args:
        pathology_str: Pathology name used to select the image and overlay.

    Returns:
        A ``(video, overlay)`` tuple.
    """
    sampled_pixels, sample_idx = generate_image(pathology_str)
    # Tuple elements are evaluated left to right, so the video is generated
    # before the overlay is loaded, matching the original call order.
    return generate_video(sampled_pixels), load_overlay(pathology_str, sample_idx)