import numpy as np
import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline
# Load the pre-trained model
pipeline = DiffusionPipeline.from_pretrained("ankush-003/retinal_fundus")
# Use the GPU only when CUDA is actually available, so the same script runs
# unchanged on CPU-only hosts; sampling is typically far slower on CPU.
import torch  # diffusers' DiffusionPipeline is torch-based, so no new dependency

if torch.cuda.is_available():
    pipeline.to("cuda")
# gradio function for generating image
def generate_image():
    """Generate one synthetic image with the module-level ``pipeline``.

    Returns:
        A ``float32`` NumPy array with values in ``[0, 1]``. This is the
        same scaling ``plt.imread`` applies when loading an 8-bit PNG, and
        it is suitable for a ``gr.Image(type="numpy")`` output.
    """
    image = pipeline().images[0]
    # Convert in memory rather than saving to and re-reading a fixed
    # "trial.png": that round trip cost disk I/O and raced when two requests
    # overlapped. Assumes the pipeline yields 8-bit PIL images (diffusers'
    # default output_type="pil") -- TODO confirm if output_type is changed.
    return np.divide(np.asarray(image), 255, dtype=np.float32)
# gradio interface
# import gradio as gr
# iface = gr.Interface(fn=generate_image, inputs=None, outputs=[gr.Image(label="Generated Image", type="numpy", tool='editor')],
# title="Image Data Generator",
# description="This tool generates synthetic images using the DiffusionPipeline model.",
# article="### Using the Image Data Generator\n\nSimply click 'Generate Image' to create a synthetic image. The generated image will be displayed below.")
# iface.launch(debug=True)
# blocks ui
import gradio as gr
def generate_multiple(num):
    """Generate ``num`` synthetic images for the gallery output.

    Args:
        num: How many images to generate. The value comes from a
            ``gr.Slider`` (or an API client) and may arrive as a float such
            as ``3.0``. ``range()`` rejects floats, so it is coerced with
            ``int()`` first.

    Returns:
        A list of ``int(num)`` images, each as returned by ``generate_image``.
    """
    return [generate_image() for _ in range(int(num))]
with gr.Blocks(theme=gr.themes.Soft()) as app:
    # NOTE(review): the `.style(...)` calls below use the Gradio 3.x styling
    # API, which was deprecated late in 3.x and removed in 4.x (its options
    # moved into the component constructors) -- confirm the pinned version.
    gr.Markdown("""
Synthetic Image Generator
""")
    with gr.Tab("Generate Single Image"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("""## Using the Synthetic Image Generator\n\nSimply click 'Generate Image' to create a synthetic image.\n""")
                gr.Image("./train.png",label="Training Image sample").style( rounded=True, scale=1)
                gen_button = gr.Button("Generate", variant="primary")
            gen_img = gr.Image( tool="select", type="numpy", label="Generated Image").style(height=512, width=512, rounded=True)
    with gr.Tab("Generate Multiple Images"):
        gr.Markdown(
            """
## Using the Synthetic Image Generator to generate multiple images
"""
        )
        # Integer step so the slider reports whole numbers of images.
        gen_number = gr.Slider(2, 5, step=1, label="Number of Images", info="Generate multiple images")
        gen_images = gr.Gallery(label="Generated Images").style(columns=[2], rows=[2], object_fit="contain", height="auto")
        gen_m_button = gr.Button("Generate Images", variant="primary")
    with gr.Accordion("Read More"):
        gr.Markdown("""
- [Images used to train the model](https://ieee-dataport.org/open-access/retinal-fundus-multi-disease-image-dataset-rfmid)
""")
    # Event listeners are registered inside the Blocks context.
    gen_button.click(generate_image, inputs=None, outputs=gen_img)
    gen_m_button.click(generate_multiple, inputs=gen_number, outputs=gen_images)

# Enable the request queue. Diffusion sampling is long-running, and queued
# requests avoid HTTP timeouts and are processed one at a time by default
# instead of overlapping on the shared pipeline.
app.queue().launch(debug=True)