ReNO / app.py
fffiloni's picture
Update app.py
2ae6665 verified
raw
history blame
6.39 kB
import gradio as gr
from main import main
from arguments import parse_args
import os
import shutil
import glob
def list_iter_images(save_dir):
    """Return the paths of the image files stored directly inside ``save_dir``.

    Only files whose names end in one of the known image extensions are
    returned; sub-directories are not searched. Results are grouped by
    extension, in the order the extensions are listed below.

    Args:
        save_dir: Directory to scan. It is treated as a literal path, so
            characters that are special to ``glob`` (``[``, ``]``, ``*``,
            ``?``) are matched verbatim rather than as wildcards.

    Returns:
        A list of matching file paths (empty when the directory is missing
        or contains no images).
    """
    # Extensions to look for; add more here if needed.
    image_extensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp']

    # The directory name can contain arbitrary text (the caller builds it from
    # the prompt), so neutralise glob metacharacters in it. Without this, a
    # directory such as "a [red] car" would never match itself.
    literal_dir = glob.escape(save_dir)

    image_paths = []
    for ext in image_extensions:
        image_paths.extend(glob.glob(os.path.join(literal_dir, f'*.{ext}')))

    # Echo what was found so it shows up in the application logs.
    print(image_paths)
    return image_paths
def clean_dir(save_dir):
    """Empty ``save_dir`` without removing the directory itself.

    Regular files and symbolic links are unlinked, sub-directories are
    removed recursively. A failure on one entry is reported on stdout and
    does not stop the remaining entries from being processed. A short status
    line is printed for each of the three outcomes (missing, already empty,
    cleaned).

    Args:
        save_dir: Path of the directory to empty.
    """
    # Nothing to do when the directory is not there at all.
    if not os.path.exists(save_dir):
        print(f"{save_dir} does not exist.")
        return

    entries = os.listdir(save_dir)
    if len(entries) == 0:
        print(f"{save_dir} exists but is empty.")
        return

    for entry in entries:
        file_path = os.path.join(save_dir, entry)
        try:
            # Links are unlinked rather than followed, so a link pointing at
            # a directory never has its target deleted.
            is_plain_entry = os.path.isfile(file_path) or os.path.islink(file_path)
            if is_plain_entry:
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Best effort: report the problem and carry on with the rest.
            print(f"Failed to delete {file_path}. Reason: {e}")

    print(f"All files in {save_dir} have been deleted.")
def generate_image(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Run one optimisation for ``prompt`` and report the resulting files.

    The arguments returned by ``parse_args()`` are overridden with the values
    chosen in the UI, the output directory for this configuration is emptied,
    and ``main`` is invoked with a callback that forwards the current step to
    the Gradio progress tracker.

    Args:
        prompt: Text prompt; also used as the last component of the output
            directory.
        model: Model identifier stored in ``args.model``.
        num_iterations: Number of iterations stored in ``args.n_iters``.
        learning_rate: Learning rate stored in ``args.lr``.
        progress: Gradio progress tracker.

    Returns:
        A ``(best_image_path, status_message, iteration_image_paths)`` tuple.
        The first and last items are ``None`` when the best image is not
        found on disk or when an exception is raised while running.
    """
    args = parse_args()

    # Values coming from the UI (or fixed for this app) take precedence over
    # whatever parse_args() produced.
    ui_overrides = {
        "task": "single",
        "prompt": prompt,
        "model": model,
        "n_iters": num_iterations,
        "lr": learning_rate,
        "cache_dir": "./HF_model_cache",
        "save_dir": "./outputs",
        "save_all_images": True,
    }
    for field, value in ui_overrides.items():
        setattr(args, field, value)

    # Name of the per-configuration sub-directory.
    # NOTE(review): presumably this has to match the naming used inside
    # main() so that the files it writes are found below — confirm.
    name_parts = [
        f"{args.model}",
        "_" + args.prompt if args.task == "t2i-compbench" else "",
        "_no-optim" if args.no_optim else "",
        f"_{args.seed if args.task != 'geneval' else ''}",
        f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}",
        f"_reg{args.reg_weight if args.enable_reg else '0'}",
    ]
    # One optional suffix per reward that is switched on, in a fixed order.
    for reward in ("pickscore", "clip", "hps", "imagereward", "aesthetic"):
        if getattr(args, "enable_" + reward):
            name_parts.append("_" + reward + str(getattr(args, reward + "_weighting")))
    settings = "".join(name_parts)

    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
    # Start from an empty directory so results of earlier runs are not shown.
    clean_dir(save_dir)

    def progress_callback(step):
        # Translate the iteration counter into a fraction for the tracker.
        progress(step / num_iterations, f"Iteration {step}/{num_iterations}")

    try:
        # The returned values are not used here; results are read from disk.
        _best_image, _init_rewards, _best_rewards = main(args, progress_callback)

        best_path = f"{save_dir}/best_image.png"
        if not os.path.exists(best_path):
            return None, "Image generation completed, but the file was not found.", None

        per_iteration = list_iter_images(save_dir)
        return best_path, f"Image generated successfully and saved at {best_path}", per_iteration
    except Exception as err:
        # Surface the failure in the status box instead of raising.
        return None, f"An error occurred: {str(err)}", None
# Create Gradio interface
# NOTE(review): the nesting of the `with` blocks below was reconstructed from
# the order of the statements — confirm it matches the intended layout.

# Heading and introductory text, both rendered as Markdown at the top.
title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        # Badges linking to the code repository and the paper.
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://github.com/ExplainableML/ReNO'>
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href='https://arxiv.org/abs/2406.04312v1'>
<img src='https://img.shields.io/badge/Paper-Arxiv-red'>
</a>
</div>
""")

        with gr.Row():
            # Input controls.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"], label="Model", value="sd-turbo")
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
                submit_btn = gr.Button("Submit")
                # Sample prompts; each example fills only the prompt textbox.
                gr.Examples(
                    examples = [
                        "A minimalist logo design of a reindeer, fully rendered. The reindeer features distinct, complete shapes using bold and flat colors. The design emphasizes simplicity and clarity, suitable for logo use with a sharp outline and white background.",
                        "A blue scooter is parked near a curb in front of a green vintage car",
                        "A impressionistic oil painting: a lone figure walking on a misty beach, a weathered lighthouse on a cliff, seagulls above crashing waves",
                        "A bird with 8 legs",
                        "An orange chair to the right of a black airplane",
                        "A pink elephant and a grey cow",
                    ],
                    inputs = [prompt]
                )

            # Output widgets.
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Best Generated Image")
                status = gr.Textbox(label="Status")
                iter_gallery = gr.Gallery(label="Iterations", columns=4)

    # Clicking the button calls generate_image with the four inputs; its three
    # return values are written to the image, status box and gallery in order.
    submit_btn.click(
        fn = generate_image,
        inputs = [prompt, chosen_model, n_iter, learning_rate],
        outputs = [output_image, status, iter_gallery]
    )

# Launch the app
# Queueing is enabled before launch; errors are shown in the UI and the API
# page is hidden.
demo.queue().launch(show_error=True, show_api=False)