Spaces:
Sleeping
Sleeping
import torch | |
import gc | |
import gradio as gr | |
from main import setup, execute_task | |
from arguments import parse_args | |
import os | |
import shutil | |
import glob | |
import time | |
import threading | |
import argparse | |
def list_iter_images(save_dir, image_extensions=None):
    """Return the paths of all image files directly inside ``save_dir``.

    Paths are grouped by extension, in the order of ``image_extensions``.
    Within each group they are ordered naturally: numbered files come first,
    sorted numerically (``2.png`` before ``10.png``). Any other names, such as
    ``best_image.png``, follow in alphabetical order.

    Args:
        save_dir: Directory to scan (not recursive). It may contain glob
            metacharacters such as ``[`` or ``*``; they are matched literally.
        image_extensions: Optional iterable of extensions without the leading
            dot. Defaults to ``jpg``, ``jpeg``, ``png``, ``gif`` and ``bmp``.

    Returns:
        list[str]: The matching file paths. The list is empty when the
        directory holds no images or does not exist.
    """
    if image_extensions is None:
        image_extensions = ('jpg', 'jpeg', 'png', 'gif', 'bmp')

    # The caller builds save_dir from the user's prompt. Escape it so that
    # characters like '[' are not interpreted as glob patterns.
    escaped_dir = glob.escape(save_dir)

    def _natural_key(path):
        # Numbered iteration images sort numerically; everything else by name.
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.isdecimal():
            return (0, int(stem), stem)
        return (1, 0, stem)

    image_paths = []
    for ext in image_extensions:
        matches = glob.glob(os.path.join(escaped_dir, f'*.{ext}'))
        image_paths.extend(sorted(matches, key=_natural_key))
    return image_paths
def clean_dir(save_dir):
    """Empty ``save_dir`` in place, keeping the directory itself.

    Handling of each entry:
    - Regular files and symbolic links are unlinked.
    - Sub-directories are removed recursively.
    - A failure on one entry is printed, and the remaining entries are still
      processed.

    When the directory is missing or already empty, only a message is
    printed.
    """
    if not os.path.exists(save_dir):
        print(f"{save_dir} does not exist.")
        return

    entries = os.listdir(save_dir)
    if not entries:
        print(f"{save_dir} exists but is empty.")
        return

    for entry in entries:
        target = os.path.join(save_dir, entry)
        try:
            # Check islink before following: a link to a directory is unlinked,
            # never recursed into.
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f"Failed to delete {target}. Reason: {err}")
    print(f"All files in {save_dir} have been deleted.")
def start_over(gallery_state, loaded_model_setup):
    """Reset the UI before a new run and release cached GPU memory.

    The incoming states are discarded. Every state and output this handler
    drives is cleared, except the iterations gallery, which is hidden.

    Returns:
        tuple: ``(gallery_state, output_image, status, iter_gallery update,
        loaded_model_setup)``. Every value is ``None`` except the gallery
        update.
    """
    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()
    # Both states always end up as None, whatever was passed in, so the old
    # model setup cannot be re-triggered.
    hidden_gallery = gr.update(visible=False)
    return None, None, None, hidden_gallery, None
def setup_model(prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Build the ReNO argument namespace from the UI inputs and load the model.

    Clears cached CUDA memory, overrides the ``parse_args()`` namespace with
    the UI values, then calls ``setup(args)``.

    Args:
        prompt: Text prompt. It must contain at least one non-whitespace
            character.
        model: Model key chosen in the dropdown (e.g. ``"sd-turbo"``,
            ``"flux"``). ``"flux"`` additionally enables CPU offloading and
            multi-apply.
        seed: Random seed, stored as ``args.seed``.
        num_iterations: Number of optimisation iterations (``args.n_iters``).
        enable_hps, enable_imagereward, enable_pickscore, enable_clip:
            Checkbox states. When truthy, the matching ``args.disable_*`` flag
            is set to ``False`` and the matching weight is stored.
        hps_w, imgrw_w, pcks_w, clip_w: Reward weights.
        learning_rate: Optimiser learning rate (``args.lr``).
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            mirror tqdm bars; it is not referenced directly.

    Returns:
        tuple[str, list | None]: A status message, plus one of:
        - the list ``[args, trainer, device, dtype, shape, enable_grad,
          multi_apply_fn, settings]`` from ``setup``, when loading succeeds;
        - ``None``, when loading fails.

    Raises:
        gr.Error: If the prompt is missing or blank.
    """
    if prompt is None or not prompt.strip():
        raise gr.Error("You forgot to provide a prompt !")

    # Clear CUDA memory before loading a (possibly different) model.
    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()

    # Set up arguments
    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    args.seed = seed
    args.n_iters = num_iterations
    args.lr = learning_rate
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    args.save_all_images = True

    # Each checkbox only switches its reward on.
    # NOTE(review): unchecked rewards keep whatever parse_args() defaults to
    # for the disable_* flags — confirm those defaults really disable them.
    if enable_hps:
        args.disable_hps = False
        args.hps_weighting = hps_w
    if enable_imagereward:
        args.disable_imagereward = False
        args.imagereward_weighting = imgrw_w
    if enable_pickscore:
        args.disable_pickscore = False
        args.pickscore_weighting = pcks_w
    if enable_clip:
        args.disable_clip = False
        args.clip_weighting = clip_w

    if model == "flux":
        args.cpu_offloading = True
        args.enable_multi_apply = True
        args.multi_step_model = "flux"

    try:
        args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing;
        # the None setup tells later steps that nothing is loaded.
        print(f"Unexpected Error: {e}")
        return f"Something went wrong with {model} loading: {e}", None

    loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
    return f"{model} model loaded successfully !", loaded_setup
def generate_image(setup_args, num_iterations):
    """Run ReNO optimisation on a loaded setup and stream progress to the UI.

    This is a generator used as a Gradio event handler. ``execute_task`` runs
    in a worker thread. Meanwhile, this generator polls the steps reported
    through ``progress_callback`` and yields the image saved for each new step.
    Out-of-memory failures become a status message. Any other failure in the
    worker is re-raised here, so the handlers at the bottom report it in the
    UI instead of it being lost inside the thread.

    Args:
        setup_args: One of:
            - the list built by ``setup_model``: ``[args, trainer, device,
              dtype, shape, enable_grad, multi_apply_fn, settings]``;
            - ``None``, when no model is loaded.
        num_iterations: Total iteration count, used only in status text.

    Yields:
        tuple: ``(image_path or None, status message, gallery paths or None)``,
        matching the ``output_image``, ``status`` and ``gallery_state`` outputs.
    """
    if setup_args is None:
        # setup_model returns None as the setup when loading failed (and
        # start_over resets it to None). Indexing it would raise a TypeError
        # outside the error handling below, leaving no status message.
        yield (None, "No model is loaded. Please check the model status and try again.", None)
        return

    torch.cuda.empty_cache()  # Free up cached memory
    gc.collect()
    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup_args
    print(f"SETTINGS: {settings}")
    # NOTE(review): presumably mirrors the directory execute_task saves into;
    # keep the two in sync.
    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
    clean_dir(save_dir)
    try:
        torch.cuda.empty_cache()  # Free up cached memory
        gc.collect()
        steps_completed = []
        # Written by the worker thread; read here while polling and after join().
        error_status = {"error_occurred": False, "exception": None}

        def progress_callback(step):
            # Record only strictly increasing steps to avoid duplicate reports.
            if not steps_completed or step > steps_completed[-1]:
                steps_completed.append(step)
                print(f"Progress: Step {step} completed.")

        def run_main():
            try:
                execute_task(
                    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback
                )
            except torch.cuda.OutOfMemoryError as e:
                print(f"CUDA Out of Memory Error: {e}")
                error_status["error_occurred"] = True
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(f"Runtime Error: {e}")
                    error_status["error_occurred"] = True
                else:
                    error_status["exception"] = e
            except Exception as e:
                # Hand the failure back to the generator thread.
                error_status["exception"] = e

        main_thread = threading.Thread(target=run_main)
        main_thread.start()
        last_step_yielded = 0
        while main_thread.is_alive() and not error_status["error_occurred"]:
            # Check if new steps have been completed
            if steps_completed and steps_completed[-1] > last_step_yielded:
                last_step_yielded = steps_completed[-1]
                # Step N's image is expected at "(N - 1).png" in save_dir.
                png_number = last_step_yielded - 1
                image_path = os.path.join(save_dir, f"{png_number}.png")
                if os.path.exists(image_path):
                    yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
                else:
                    yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
            else:
                time.sleep(0.1)  # Sleep to prevent busy waiting

        # run_main catches everything and then returns, so this join is
        # bounded by execute_task itself.
        main_thread.join()

        if error_status["exception"] is not None:
            # Re-raise here so the handlers below format the message for the UI.
            raise error_status["exception"]

        if error_status["error_occurred"]:
            torch.cuda.empty_cache()  # Free up cached memory
            gc.collect()
            yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
        else:
            final_image_path = os.path.join(save_dir, "best_image.png")
            if os.path.exists(final_image_path):
                iter_images = list_iter_images(save_dir)
                torch.cuda.empty_cache()  # Free up cached memory
                gc.collect()
                time.sleep(0.5)
                yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
            else:
                torch.cuda.empty_cache()  # Free up cached memory
                gc.collect()
                yield (None, "Image generation completed, but no final image was found.", None)
        torch.cuda.empty_cache()  # Free up cached memory
        gc.collect()
    except torch.cuda.OutOfMemoryError as e:
        print(f"Global CUDA Out of Memory Error: {e}")
        yield (None, "CUDA out of memory.", None)
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print(f"Runtime Error: {e}")
            yield (None, "CUDA out of memory.", None)
        else:
            yield (None, f"An error occurred: {str(e)}", None)
    except Exception as e:
        print(f"Unexpected Error: {e}")
        yield (None, f"An unexpected error occurred: {str(e)}", None)
def show_gallery_output(gallery_state):
    """Return a Gradio update that shows the iterations gallery when it has a value.

    ``gallery_state`` holds the third value yielded by ``generate_image``: a
    list of image paths, or ``None``. A non-``None`` value is shown; ``None``
    hides the gallery and clears its value.
    """
    has_state = gallery_state is not None
    # When has_state is False, gallery_state is None, so passing it through
    # clears the value exactly like an explicit None would.
    return gr.update(value=gallery_state, visible=has_state)
# Create Gradio interface
# Markdown heading rendered at the top of the page.
title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
# Short usage hint rendered under the title.
description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
# Styles for the textbox with elem_id="model-status-id": a fixed height and
# small font sizes for the progress-text / progress-level elements inside it.
css="""
#model-status-id{
height: 126px;
}
#model-status-id .progress-text{
font-size: 10px!important;
}
#model-status-id .progress-level-inner{
font-size: 8px!important;
}
"""
with gr.Blocks(css=css, analytics_enabled=False) as demo:
    # Per-session state:
    # - loaded_model_setup: the setup list returned by setup_model;
    # - gallery_state: the iteration image paths yielded by generate_image.
    loaded_model_setup = gr.State()
    gallery_state = gr.State()
    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://github.com/ExplainableML/ReNO'>
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href='https://arxiv.org/abs/2406.04312v1'>
<img src='https://img.shields.io/badge/Paper-Arxiv-red'>
</a>
</div>
""")
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                with gr.Row():
                    chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
                    seed = gr.Number(label="seed", value=0)
                    model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Column():
                        # Each weight slider starts locked; its checkbox unlocks it.
                        with gr.Row():
                            enable_hps = gr.Checkbox(label="HPS ON", value=False, scale=1)
                            hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
                        with gr.Row():
                            enable_imagereward = gr.Checkbox(label="ImageReward ON", value=False, scale=1)
                            imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
                        with gr.Row():
                            enable_pickscore = gr.Checkbox(label="PickScore ON", value=False, scale=1)
                            pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05, interactive=False, scale=3)
                        with gr.Row():
                            enable_clip = gr.Checkbox(label="CLIP ON", value=False, scale=1)
                            clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        "A red dog and a green cat",
                        "A pink elephant and a grey cow",
                        "A toaster riding a bike",
                        "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
                        "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
                        "An epic oil painting: a red portal infront of a cityscape, a solitary figure, and a colorful sky over snowy mountains",
                    ],
                    inputs=[prompt],
                )
            # Right column: outputs.
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Best Generated Image")
                status = gr.Textbox(label="Status")
                iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)

    def allow_weighting(weight_type):
        """Make a reward-weight slider editable only while its checkbox is ticked."""
        return gr.update(interactive=weight_type is True)

    # Wire every reward checkbox to its weight slider (same handler for all four).
    for reward_checkbox, weight_slider in (
        (enable_hps, hps_w),
        (enable_imagereward, imgrw_w),
        (enable_pickscore, pcks_w),
        (enable_clip, clip_w),
    ):
        reward_checkbox.change(
            fn=allow_weighting,
            inputs=[reward_checkbox],
            outputs=[weight_slider],
            queue=False,
        )

    # Pipeline: reset the UI -> load the model -> run optimisation -> reveal
    # the gallery. generate_image is chained with `.success`, not `.then`.
    # If setup_model raises gr.Error (e.g. empty prompt), the chain stops
    # instead of running generate_image on the None setup that start_over
    # just stored.
    submit_btn.click(
        fn=start_over,
        inputs=[gallery_state, loaded_model_setup],  # Reset loaded model setup as well
        outputs=[gallery_state, output_image, status, iter_gallery, loaded_model_setup],
    ).then(
        fn=setup_model,
        inputs=[prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate],
        outputs=[model_status, loaded_model_setup],  # Load the new setup into the state
    ).success(
        fn=generate_image,
        inputs=[loaded_model_setup, n_iter],
        outputs=[output_image, status, gallery_state],
    ).then(
        fn=show_gallery_output,
        inputs=[gallery_state],
        outputs=iter_gallery,
    )
# Launch the app.
# queue() enables Gradio's request queue before the server starts.
# show_error=True surfaces handler exceptions in the browser.
# show_api=False hides the "Use via API" page.
demo.queue().launch(show_error=True, show_api=False)