# Hugging Face Space status when this file was captured: Runtime error
import matplotlib.pyplot as plt
import numpy as np
from diffusers import DiffusionPipeline

# Model identifier handed to ``from_pretrained`` (a Hub repo id or local path).
MODEL_ID = "ankush-003/retinal_fundus"

# Load the pre-trained model once at import time; the callbacks below all
# reuse this module-level ``pipeline`` object.
pipeline = DiffusionPipeline.from_pretrained(MODEL_ID)
# To run on a GPU instead, uncomment: pipeline.to("cuda")
# gradio function for generating image | |
def generate_image(): | |
image = pipeline().images[0] | |
image.save("trial.png") | |
img = plt.imread("trial.png") | |
# Display the image (optional) | |
# plt.imshow(img) | |
# plt.axis("off") | |
# plt.show() | |
return img | |
# gradio interface | |
# import gradio as gr | |
# iface = gr.Interface(fn=generate_image, inputs=None, outputs=[gr.Image(label="Generated Image", type="numpy", tool='editor')], | |
# title="Image Data Generator", | |
# description="This tool generates synthetic images using the DiffusionPipeline model.", | |
# article="### Using the Image Data Generator\n\nSimply click 'Generate Image' to create a synthetic image. The generated image will be displayed below.") | |
# iface.launch(debug=True) | |
# blocks ui | |
import gradio as gr


def generate_multiple(num):
    """Generate ``num`` synthetic images by calling ``generate_image`` repeatedly.

    Args:
        num: How many images to produce. It comes from a ``gr.Slider``
            configured with ``step=1.0`` and may arrive as a float such as
            ``3.0``, which ``range`` rejects, so it is coerced with ``int``.

    Returns:
        list of image arrays, one per generated sample, in generation order.
    """
    return [generate_image() for _ in range(int(num))]
# Build the two-tab UI. Layout options are passed as constructor arguments:
# the chained ``.style(...)`` calls and ``gr.Image(tool=...)`` were removed in
# Gradio 4 and make the app fail at start-up.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""<h1 style="text-align: center;">Synthetic Image Generator</h1>""")

    with gr.Tab("Generate Single Image"):
        # The instructions name the button exactly as it is labelled below.
        gr.Markdown("## Using the Synthetic Generator\n\nSimply click 'Generate' to create a synthetic image. The generated image will be displayed below.")
        gen_img = gr.Image(type="numpy", label="Generated Image", height=256, width=256)
        gen_button = gr.Button("Generate", variant="primary")

    with gr.Tab("Generate Multiple Images"):
        gr.Markdown(
            """
            ## Using the Synthetic Image Generator to generate multiple images
            """
        )
        gen_number = gr.Slider(2, 5, step=1, label="Number of Images", info="Generate multiple images")
        gen_images = gr.Gallery(
            label="Generated Images",
            columns=2,
            rows=2,
            object_fit="contain",
            height="auto",
        )
        gen_m_button = gr.Button("Generate Images", variant="primary")

    with gr.Accordion("Read More"):
        gr.Markdown(
            """
            - [Images used to train the model](https://ieee-dataport.org/open-access/retinal-fundus-multi-disease-image-dataset-rfmid)
            """
        )

    # Event listeners must be registered inside the Blocks context.
    gen_button.click(generate_image, inputs=None, outputs=gen_img)
    gen_m_button.click(generate_multiple, inputs=gen_number, outputs=gen_images)

# Diffusion sampling can run longer than a plain request is allowed to, so
# route calls through the queue (already the default in Gradio 4; explicit
# here so 3.x behaves the same).
app.queue().launch(debug=True)