"""Gradio app that generates synthetic cardiac ultrasound (A4C) images.

A pretrained unconditional diffusion pipeline is loaded at import time, and a
small Gradio UI exposes a button that generates a batch of images, normalizes
their grayscale contrast, and shows them tiled in a single grid image.
"""
import gradio as gr
import numpy as np
from diffusers import UNet2DModel, DDPMPipeline, DDPMScheduler, DiffusionPipeline
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import spaces

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = DiffusionPipeline.from_pretrained("gjbooth2/Unconditional_A4C_1").to(device)

# A grayscale image whose darkest pixel is above this value is treated as
# "washed out" and replaced by a near-black placeholder.
WASHED_OUT_MIN_THRESHOLD = 55
# Upper bound of the output range used when contrast-stretching an image.
NORMALIZED_MAX = 200


def _make_grid(images, rows, cols):
    """Paste ``images`` row-major into a ``rows`` x ``cols`` grayscale grid.

    The cell size is taken from the first image. Images that fall outside the
    ``rows * cols`` canvas are clipped by PIL.
    """
    w, h = images[0].size
    grid = Image.new('L', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def _normalize_grayscale(pic):
    """Return a uint8 PIL image containing a contrast-normalized ``pic``.

    ``pic`` is the array form of a mode-'L' image (2-D, uint8).
    - If its minimum exceeds WASHED_OUT_MIN_THRESHOLD, the image is considered
      washed out. It is replaced by an all-ones (near-black) image of the same
      shape.
    - Otherwise pixels are linearly rescaled from [min, max] to
      [0, NORMALIZED_MAX] for a more homogeneous grayscale appearance.
    """
    min_val = int(pic.min())
    max_val = int(pic.max())
    if min_val > WASHED_OUT_MIN_THRESHOLD:
        normalized = np.ones(pic.shape)
    elif max_val == min_val:
        # Flat image: avoid 0/0 (NaN, whose uint8 cast is undefined).
        normalized = np.zeros(pic.shape)
    else:
        normalized = NORMALIZED_MAX * (
            (pic.astype(np.float64) - min_val) / (max_val - min_val)
        )
    return Image.fromarray(np.uint8(normalized))


def image_gen(click, rows=4, cols=4):
    """Generate ``rows * cols`` images and return them tiled in one grid image.

    ``click`` is accepted for callback compatibility and is not used.
    """
    images = pipeline(batch_size=rows * cols).images
    return _make_grid(images, rows, cols)


@spaces.GPU(duration=300)
def image_gen_modified(rows=4, cols=4):
    """Generate ``rows * cols`` images, normalize each, and tile them in a grid.

    Each generated image is converted to grayscale and passed through
    ``_normalize_grayscale`` before being pasted into the grid.
    """
    model_output = pipeline(batch_size=rows * cols).images
    pic_hold = [
        _normalize_grayscale(np.array(image.convert('L')))
        for image in model_output
    ]
    return _make_grid(pic_hold, rows, cols)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('CS 614 Greg Booth Vision Assignment')
    gr.Markdown('This gradio app can be used to generate realistic cardiac ultrasound images.')
    # NOTE(review): the surrounding "" strings look like HTML markup (e.g. an
    # <img> tag) that was stripped from this source; confirm the intended HTML.
    gr.HTML("" + 'Example anatomy' + "")
    with gr.Tab('Generate a cardiac ultrasound image'):
        playground_btn = gr.Button(value='Push me some images! (may take a couple minutes depending on hardware)')
        playground_out = gr.Image()
        # No inputs are wired, so image_gen_modified runs with its defaults.
        playground_btn.click(image_gen_modified, outputs=playground_out)

demo.launch()