"""Gradio app that samples synthetic images from a pre-trained diffusion model.

Loads the ``ankush-003/retinal_fundus`` DiffusionPipeline once at import time
and serves a two-tab UI: one tab generates a single image, the other generates
a small set of images shown in a gallery.
"""
import numpy as np
import matplotlib.pyplot as plt  # not used by the app itself; handy for ad-hoc local previews
import gradio as gr
from diffusers import DiffusionPipeline

# Load the pre-trained model once so every request reuses the same pipeline.
pipeline = DiffusionPipeline.from_pretrained("ankush-003/retinal_fundus")
# To run on a GPU, move the pipeline there before serving requests:
# pipeline.to("cuda")


def generate_image() -> np.ndarray:
    """Sample one image from the diffusion pipeline.

    Returns:
        The image as a float32 array scaled to [0, 1] -- the same values
        ``plt.imread`` yields for an 8-bit PNG of that image.
    """
    image = pipeline().images[0]
    # Convert in memory instead of saving to a fixed "trial.png" and reading it
    # back: a shared file in the working directory lets concurrent requests
    # overwrite each other's output. Assumes an 8-bit PIL image (the object
    # exposes .save(), as the previous PNG round-trip relied on) -- dividing by
    # 255 reproduces matplotlib's PNG normalisation.
    return np.asarray(image, dtype=np.float32) / 255.0


def generate_multiple(num: float) -> list:
    """Generate ``num`` images, one pipeline call per image.

    Args:
        num: How many images to generate. A Gradio Slider may deliver this as
            a float (e.g. ``3.0``), so it is truncated to ``int`` because
            ``range`` rejects floats.

    Returns:
        A list of image arrays, each as returned by :func:`generate_image`.
    """
    return [generate_image() for _ in range(int(num))]


# NOTE(review): ``.style(...)`` and ``tool=`` are Gradio 3.x APIs that were
# removed in Gradio 4 -- pin gradio<4 or migrate these calls.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("\n\nSynthetic Image Generator\n\n")

    with gr.Tab("Generate Single Image"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    "## Using the Synthetic Image Generator\n\n"
                    "Simply click 'Generate Image' to create a synthetic image.\n"
                )
                gr.Image("./train.png", label="Training Image sample").style(
                    rounded=True, scale=1
                )
                # Label matches the instructions above, which say 'Generate Image'.
                gen_button = gr.Button("Generate Image", variant="primary")
            gen_img = gr.Image(
                tool="select", type="numpy", label="Generated Image"
            ).style(height=512, width=512, rounded=True)

    with gr.Tab("Generate Multiple Images"):
        gr.Markdown("## Using the Synthetic Image Generator to generate multiple images")
        # Integer step: the value is an image count.
        gen_number = gr.Slider(
            2, 5, step=1, label="Number of Images", info="Generate multiple images"
        )
        gen_images = gr.Gallery(label="Generated Images").style(
            columns=[2], rows=[2], object_fit="contain", height="auto"
        )
        gen_m_button = gr.Button("Generate Images", variant="primary")

    with gr.Accordion("Read More"):
        gr.Markdown(
            "- [Images used to train the model]"
            "(https://ieee-dataport.org/open-access/retinal-fundus-multi-disease-image-dataset-rfmid)"
        )

    # Wire each button to its handler while the Blocks context is open.
    gen_button.click(generate_image, inputs=None, outputs=gen_img)
    gen_m_button.click(generate_multiple, inputs=gen_number, outputs=gen_images)


if __name__ == "__main__":
    app.launch(debug=True)