LSDM / app.py
QinLei086's picture
Upload 28 files
15acbf0 verified
raw
history blame contribute delete
No virus
1.82 kB
import gradio as gr
from evolution import random_walk
from generate import generate
def process_random_walk(img):
    """Run the evolution step on the input grayscale image (UI step 1).

    Wired to the "1. Process Evolution" button: it reads the
    "Upload Grayscale Image" panel and writes the "After Evolution Image"
    panel.

    Args:
        img: Image from the input panel. That component is declared with
            ``type='numpy'`` and ``image_mode='L'``, so this is a numpy
            array, or ``None`` when the panel is empty/cleared.

    Returns:
        The first of the two values returned by ``random_walk``; the second
        value is discarded.

    Raises:
        gr.Error: If no image was provided. This shows a readable message in
            the UI instead of passing ``None`` on to ``random_walk``.
    """
    if img is None:
        # Gradio delivers None for an empty/cleared Image component.
        raise gr.Error("Please upload a grayscale image first.")
    evolved_img, _ = random_walk(img)
    return evolved_img
def _first_generated_image(img, model_path, missing_msg):
    """Run ``generate`` on *img* with the weights at *model_path* and return its first output.

    Shared by both generation steps, which differ only in their weights.
    Raises ``gr.Error(missing_msg)`` when *img* is None, which is what Gradio
    passes for an empty/cleared Image component.
    """
    if img is None:
        raise gr.Error(missing_msg)
    generated_images = generate(img, model_path)
    # Only the first generated result is shown in the UI.
    return generated_images[0]


def process_first_generation(img1, model_path="pretrain_weights/b2m/unet_ema"):
    """Generate a mask from the evolved image (UI step 2, "2. Generate Masks").

    Args:
        img1: Image from the "After Evolution Image" panel (numpy array), or
            ``None`` if that panel is empty.
        model_path: Weights directory passed to ``generate``.

    Returns:
        The first element of ``generate``'s result, shown in the
        "Generated Mask Image" panel.

    Raises:
        gr.Error: If ``img1`` is None.
    """
    return _first_generated_image(
        img1, model_path, "No evolved image available; run step 1 first."
    )


def process_second_generation(img1, model_path="pretrain_weights/m2i/unet_ema"):
    """Generate the final image from the mask (UI step 3, "3. Generate Images").

    Args:
        img1: Image from the "Generated Mask Image" panel (numpy array), or
            ``None`` if that panel is empty.
        model_path: Weights directory passed to ``generate``.

    Returns:
        The first element of ``generate``'s result, shown in the
        "Final Generated Image 1" panel.

    Raises:
        gr.Error: If ``img1`` is None.
    """
    return _first_generated_image(
        img1, model_path, "No mask image available; run step 2 first."
    )
# Build the Gradio UI for the three-step pipeline, laid out as a 2x2 grid:
#   1. "Process Evolution": input image  -> evolved image  (process_random_walk)
#   2. "Generate Masks":    evolved image -> mask image    (process_first_generation)
#   3. "Generate Images":   mask image   -> final RGB image (process_second_generation)
# Every panel starts out showing an example image from figs/ so the demo has
# something to display before any step is run.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                value="figs/4.png",
                image_mode='L',
                type='numpy',
                label="Upload Grayscale Image",
            )
            process_button_1 = gr.Button("1. Process Evolution")
        with gr.Column():
            # sources=[] (here and on the other result panels): no
            # upload/webcam/clipboard input is offered; these panels are
            # filled by the pipeline steps.
            output_image_1 = gr.Image(
                value="figs/4_1.png",
                image_mode='L',
                type="numpy",
                label="After Evolution Image",
                sources=[],
            )
            process_button_2 = gr.Button("2. Generate Masks")
    with gr.Row():
        with gr.Column():
            output_image_3 = gr.Image(
                value="figs/4_1_mask.png",
                image_mode='L',
                type="numpy",
                label="Generated Mask Image",
                sources=[],
            )
            process_button_3 = gr.Button("3. Generate Images")
        with gr.Column():
            output_image_5 = gr.Image(
                value="figs/4_1.jpg",
                type="numpy",
                image_mode='RGB',
                label="Final Generated Image 1",
                sources=[],
            )

    # Each step reads the panel written by the previous step.
    process_button_1.click(
        process_random_walk,
        inputs=[input_image],
        outputs=[output_image_1],
    )
    process_button_2.click(
        process_first_generation,
        inputs=[output_image_1],
        outputs=[output_image_3],
    )
    process_button_3.click(
        process_second_generation,
        inputs=[output_image_3],
        outputs=[output_image_5],
    )

# Start the server only when run as a script (`python app.py`). Importing this
# module then builds the UI without launching it as a side effect.
if __name__ == "__main__":
    app.launch()