# LDM-SyntheticChestX-Ray / load_data.py
# Last commit by lfolle (e5a19d6): "Add code."
import numpy as np
import random
from PIL import Image
import os
from glob import glob
from utils import get_index_to_class_mapping
from create_video import generate_video
def load_overlay(pathology_str: str, index: int):
    """Load the overlay image stored for one sample of a pathology.

    Searches ``data/<pathology_str>/`` for a file whose name ends in
    ``_<index>_overlay.png``.

    Args:
        pathology_str: Pathology name; also the name of the sub-directory of
            ``data`` that holds that pathology's files.
        index: Sample number embedded in the overlay file name.

    Returns:
        A ``PIL.Image.Image`` with its pixel data already loaded and its file
        closed. If no file matches, a 256x256 all-black grayscale ("L" mode)
        placeholder image is returned instead.
    """
    from glob import escape  # function-local: module level only imports glob()

    # Escape the directory part so a pathology name containing glob
    # metacharacters ("[", "*", "?") is matched literally.
    pattern = os.path.join(
        escape(os.path.join("data", pathology_str)),
        f"*_{index}_overlay.png",
    )
    # glob() returns matches in arbitrary order; sort so that repeated calls
    # pick the same file when more than one matches.
    matches = sorted(glob(pattern))
    if not matches:
        # Black placeholder. A 2-D uint8 array maps directly to mode "L".
        # NOTE(review): 256x256 presumably matches the generated image size
        # — confirm.
        return Image.fromarray(np.zeros((256, 256), dtype=np.uint8))
    with Image.open(matches[0]) as overlay:
        # Image.open is lazy; load the pixels now so the file handle can be
        # closed before the image is handed back to the caller.
        overlay.load()
    return overlay
def generate_image(pathology_str: str):
    """Pick one pre-generated sample image for a pathology at random.

    Reads ``data/<pathology_str>/<class_index - 1>_<n>.png``, where ``n`` is
    drawn uniformly from 1..100 (inclusive).

    Args:
        pathology_str: Pathology name. It must be one of the values returned
            by ``get_index_to_class_mapping()``; otherwise a ``KeyError`` is
            raised.

    Returns:
        ``(pixels, n)``: the image as a NumPy array, and the sample number
        that was drawn.
    """
    # The helper maps index -> class name; we need class name -> index.
    class_to_index = {
        class_name: class_index
        for class_index, class_name in get_index_to_class_mapping().items()
    }
    sample_number = random.randint(1, 100)
    # NOTE(review): the "- 1" presumably converts a 1-based mapping index to
    # the 0-based prefix used in the file names — confirm.
    file_name = f"{class_to_index[pathology_str] - 1}_{sample_number}.png"
    with Image.open(os.path.join("data", pathology_str, file_name)) as picture:
        pixels = np.array(picture)
    return pixels, sample_number
def generate_video_and_gradcam(pathology_str):
    """Build the video and the matching overlay for a random sample.

    Draws a random sample image for ``pathology_str`` and passes it to
    ``generate_video``. It then loads the overlay file that has the same
    sample number.

    Args:
        pathology_str: Pathology name, forwarded to ``generate_image`` and
            ``load_overlay``.

    Returns:
        ``(video, overlay)``. ``video`` is whatever ``generate_video``
        returns; its type is defined elsewhere — TODO confirm. ``overlay`` is
        the ``PIL.Image.Image`` returned by ``load_overlay``.
    """
    sample_pixels, sample_number = generate_image(pathology_str)
    video = generate_video(sample_pixels)
    # Reuse the drawn sample number so the overlay matches the video's source image.
    return video, load_overlay(pathology_str, sample_number)