gjbooth2's picture
Update app.py
c01e289 verified
raw
history blame contribute delete
No virus
2.44 kB
import gradio as gr
import numpy as np
from diffusers import UNet2DModel, DDPMPipeline, DDPMScheduler, DiffusionPipeline
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import spaces
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Diffusion pipeline identified by "gjbooth2/Unconditional_A4C_1", loaded once
# at import time, moved to `device`, and shared by the generation functions.
pipeline = DiffusionPipeline.from_pretrained(
    "gjbooth2/Unconditional_A4C_1",
).to(device)
def image_gen(click, rows=4, cols=4):
    """Generate ``rows * cols`` raw samples and tile them into one grayscale grid.

    Unlike ``image_gen_modified``, no contrast normalization is applied.

    Args:
        click: Unused. Kept so callers that pass a positional argument
            (e.g. an event value) keep working.
        rows: Number of grid rows (default 4).
        cols: Number of grid columns (default 4).

    Returns:
        A mode-``'L'`` PIL image of size ``(cols * w, rows * h)``, where
        ``w, h`` is the size of the first generated sample.

    Raises:
        ValueError: If ``rows`` or ``cols`` is less than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError("rows and cols must both be at least 1")
    # One sample per grid cell. A hard-coded batch of 16 left cells blank, or
    # placed samples outside the canvas, whenever rows * cols != 16.
    images = pipeline(batch_size=rows * cols).images
    w, h = images[0].size
    grid = Image.new('L', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        # Row-major placement. Relies on PIL converting each sample to the
        # grid's 'L' mode when pasting.
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid
# A sample whose darkest pixel is at or above this value is treated as washed
# out (equivalent to the check ``min_val > 55`` on integer pixel values).
_WASHED_OUT_MIN = 56
# Brightest value that normalized samples are stretched to.
_NORMALIZED_MAX = 200


def _normalize_sample(pic):
    """Return a uint8 contrast-normalized copy of the 2-D grayscale array *pic*.

    - Washed-out samples (darkest pixel >= ``_WASHED_OUT_MIN``) are replaced
      by an array of ones with the same shape (visually black).
    - Otherwise pixels are rescaled linearly so the darkest maps to 0 and the
      brightest to ``_NORMALIZED_MAX``. Fractions are truncated by the uint8
      cast.
    - A uniform sample (max == min) maps to all zeros instead of dividing by
      zero.
    """
    min_val = int(pic.min())
    max_val = int(pic.max())
    if min_val >= _WASHED_OUT_MIN:
        return np.ones_like(pic, dtype=np.uint8)
    span = max_val - min_val
    if span == 0:
        return np.zeros_like(pic, dtype=np.uint8)
    # Compute in float64 so uint8 arithmetic cannot wrap.
    scaled = _NORMALIZED_MAX * ((pic.astype(np.float64) - min_val) / span)
    return scaled.astype(np.uint8)


def _tile_grid(tiles, rows, cols):
    """Paste *tiles* row-major into a new mode-'L' canvas of rows x cols cells.

    Every cell is sized like the first tile.
    """
    cell_w, cell_h = tiles[0].size
    grid = Image.new('L', size=(cols * cell_w, rows * cell_h))
    for i, tile in enumerate(tiles):
        grid.paste(tile, box=(i % cols * cell_w, i // cols * cell_h))
    return grid


# NOTE(review): duration=300 presumably bounds the GPU allocation per call in
# seconds; confirm against the `spaces` package docs.
@spaces.GPU(duration=300)
def image_gen_modified(rows=4, cols=4):
    """Generate ``rows * cols`` samples, normalize their contrast, and tile them.

    Each sample is converted to grayscale and passed through
    ``_normalize_sample``. The results are pasted row-major into one grid.

    Args:
        rows: Number of grid rows (default 4).
        cols: Number of grid columns (default 4).

    Returns:
        A mode-``'L'`` PIL image of size ``(cols * w, rows * h)``, where
        ``w, h`` is the size of the first sample.

    Raises:
        ValueError: If ``rows`` or ``cols`` is less than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError("rows and cols must both be at least 1")
    # One sample per grid cell rather than a fixed batch of 16.
    samples = pipeline(batch_size=rows * cols).images
    tiles = [
        Image.fromarray(_normalize_sample(np.array(sample.convert('L'))))
        for sample in samples
    ]
    return _tile_grid(tiles, rows, cols)
# Gradio UI: a title, a description, a reference link, and one tab with a
# button that renders a generated image grid.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('CS 614 Greg Booth Vision Assignment')
    gr.Markdown('This gradio app can be used to generate realistic cardiac ultrasound images.')
    # Link to an example of the subcostal 4-chamber view. The attribute must be
    # `target="_blank"` (not `_target='blank'`) to open in a new tab. `rel`
    # keeps the opened page from accessing this one via window.opener.
    gr.HTML(
        '<a href="https://pocus.sg/topic/subcostal-4-chamber/" '
        'target="_blank" rel="noopener noreferrer">Example anatomy</a>'
    )
    with gr.Tab('Generate a cardiac ultrasound image'):
        playground_btn = gr.Button(value='Push me some images! (may take a couple minutes depending on hardware)')
        playground_out = gr.Image()
        # No inputs are passed, so the handler runs with its default arguments.
        playground_btn.click(image_gen_modified, outputs=playground_out)
demo.launch()