import torch
import gc
import gradio as gr
from main import setup, execute_task
from arguments import parse_args
import os
import shutil
import glob
import time
import threading
import argparse


def _release_memory():
    """Free unreachable Python objects, then release PyTorch's cached CUDA memory.

    ``gc.collect()`` runs first so tensors kept alive only by garbage cycles
    go back to the caching allocator before ``torch.cuda.empty_cache()``
    releases the unoccupied cached blocks.
    """
    gc.collect()
    torch.cuda.empty_cache()


def _iteration_sort_key(path):
    """Sort key: integer-named files (``0.png`` < ``2.png`` < ``10.png``) first,
    numerically; any other name (e.g. ``best_image.png``) after, alphabetically."""
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.isdecimal():
        return (0, int(stem), stem)
    return (1, 0, stem)


def list_iter_images(save_dir):
    """Return the image files found directly inside *save_dir*, ordered for display.

    ``glob`` gives no ordering guarantee (it depends on the file system), so
    the paths are sorted with ``_iteration_sort_key``: per-step images named
    by an integer come first in numeric order, everything else follows.

    Args:
        save_dir: Directory to scan (not recursive).

    Returns:
        A list of path strings; empty if nothing matches.
    """
    # Only these lowercase extensions are matched; add more if needed.
    image_extensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp']
    image_paths = []
    for ext in image_extensions:
        image_paths.extend(glob.glob(os.path.join(save_dir, f'*.{ext}')))
    image_paths.sort(key=_iteration_sort_key)
    return image_paths


def clean_dir(save_dir):
    """Delete every file, symlink and sub-directory inside *save_dir*.

    The directory itself is kept. A failure on one entry is printed and
    skipped so the remaining entries are still removed.
    """
    if not os.path.exists(save_dir):
        print(f"{save_dir} does not exist.")
        return
    entries = os.listdir(save_dir)
    if not entries:
        print(f"{save_dir} exists but is empty.")
        return
    for filename in entries:
        file_path = os.path.join(save_dir, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # Remove file or symbolic link
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # Remove directory and its contents
        except Exception as e:
            print(f"Failed to delete {file_path}. Reason: {e}")
    print(f"All files in {save_dir} have been deleted.")


def start_over(gallery_state, loaded_model_setup):
    """Reset the UI state before a new run and free memory.

    The incoming states are accepted only because Gradio passes them in;
    they are always discarded.

    Returns:
        A 5-tuple: cleared gallery state, cleared image, cleared status text,
        a Gradio update hiding the gallery, and a cleared model setup (so a
        stale setup cannot be reused by the next steps).
    """
    _release_memory()
    return None, None, None, gr.update(visible=False), None


def setup_model(prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
    """Build ReNO arguments from the UI inputs and load the selected model.

    Each reward model is switched on, with its weighting, only when its
    checkbox is ticked. Unticked rewards keep whatever ``parse_args()``
    returns for them. ``progress`` is Gradio's progress tracker
    (``track_tqdm=True``); it is not referenced directly here.

    Returns:
        ``(status message, loaded setup)``. The setup is the 8-item list
        consumed by ``generate_image``, or ``None`` if loading failed.

    Raises:
        gr.Error: if the prompt is empty or whitespace only.
    """
    if prompt is None or not prompt.strip():
        raise gr.Error("You forgot to provide a prompt !")

    # Clear CUDA memory before loading a new model.
    _release_memory()

    # Set up arguments
    args = parse_args()
    args.task = "single"
    args.prompt = prompt
    args.model = model
    args.seed = seed
    args.n_iters = num_iterations
    args.lr = learning_rate
    args.cache_dir = "./HF_model_cache"
    args.save_dir = "./outputs"
    args.save_all_images = True

    if enable_hps:
        args.disable_hps = False
        args.hps_weighting = hps_w
    if enable_imagereward:
        args.disable_imagereward = False
        args.imagereward_weighting = imgrw_w
    if enable_pickscore:
        args.disable_pickscore = False
        args.pickscore_weighting = pcks_w
    if enable_clip:
        args.disable_clip = False
        args.clip_weighting = clip_w

    if model == "flux":
        args.cpu_offloading = True
        args.enable_multi_apply = True
        args.multi_step_model = "flux"

    try:
        args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
        loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
        return f"{model} model loaded succesfully !", loaded_setup
    except Exception as e:
        print(f"Unexpected Error: {e}")
        return f"Something went wrong with {model} loading: {e}", None


def _is_strict_subpath(path, root):
    """Return True if *path* resolves (following symlinks) strictly inside *root*."""
    try:
        real_path = os.path.realpath(path)
        real_root = os.path.realpath(root)
        common = os.path.commonpath([real_path, real_root])
    except (ValueError, OSError):
        # e.g. an embedded NUL byte, or paths on different Windows drives
        return False
    return common == real_root and real_path != real_root


def generate_image(setup_args, num_iterations):
    """Run ReNO optimisation for a loaded setup, streaming progress to the UI.

    ``execute_task`` runs in a worker thread. This generator polls the step
    numbers reported through ``progress_callback`` and, for each new step,
    yields the matching image file from the run's output folder.

    Args:
        setup_args: The 8-item list built by ``setup_model`` (``args``,
            ``trainer``, ``device``, ``dtype``, ``shape``, ``enable_grad``,
            ``multi_apply_fn``, ``settings``), or ``None`` if loading failed.
        num_iterations: Total iteration count, used only in status text.

    Yields:
        ``(image path or None, status message, gallery paths or None)``
        tuples. The gallery list is only supplied with the final result.
    """
    if setup_args is None:
        yield (None, "No model is loaded, so no image was generated. Check the model status for details.", None)
        return

    _release_memory()
    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup_args
    print(f"SETTINGS: {settings}")

    settings_dir = f"{args.save_dir}/{args.task}/{settings}"
    save_dir = f"{settings_dir}/{args.prompt[:150]}"
    # The prompt is free text from the web UI and is used verbatim as a path
    # component. Refuse anything (e.g. "../../..") that would make clean_dir()
    # empty a directory outside this run's own output folder.
    if not _is_strict_subpath(save_dir, settings_dir):
        yield (None, "This prompt cannot be used as an output folder name. Please rephrase it.", None)
        return
    clean_dir(save_dir)

    try:
        steps_completed = []
        # Shared with the worker thread: an OOM flag plus any other exception
        # it raised, so failures are reported here instead of being lost
        # inside the thread.
        error_status = {"error_occurred": False, "exception": None}

        def progress_callback(step):
            # Record a step only once, and only if it moves forward.
            if not steps_completed or step > steps_completed[-1]:
                steps_completed.append(step)
                print(f"Progress: Step {step} completed.")

        def run_main():
            try:
                execute_task(
                    args, trainer, device, dtype, shape, enable_grad,
                    multi_apply_fn, settings, progress_callback,
                )
            except torch.cuda.OutOfMemoryError as e:
                print(f"CUDA Out of Memory Error: {e}")
                error_status["error_occurred"] = True
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(f"Runtime Error: {e}")
                    error_status["error_occurred"] = True
                else:
                    print(f"Error during image generation: {e}")
                    error_status["exception"] = e
            except Exception as e:
                print(f"Error during image generation: {e}")
                error_status["exception"] = e

        main_thread = threading.Thread(target=run_main)
        main_thread.start()

        last_step_yielded = 0
        while main_thread.is_alive() and not error_status["error_occurred"]:
            if steps_completed and steps_completed[-1] > last_step_yielded:
                last_step_yielded = steps_completed[-1]
                # The image for reported step N is expected at "<N - 1>.png"
                # (assumed to match execute_task's file naming — TODO confirm).
                png_number = last_step_yielded - 1
                image_path = os.path.join(save_dir, f"{png_number}.png")
                if os.path.exists(image_path):
                    yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
                else:
                    yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
            else:
                time.sleep(0.1)  # Sleep to prevent busy waiting

        if error_status["error_occurred"]:
            _release_memory()
            yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
        else:
            main_thread.join()  # Ensure thread completion
            if error_status["exception"] is not None:
                # Surface the worker's failure through the handlers below.
                raise error_status["exception"]
            final_image_path = os.path.join(save_dir, "best_image.png")
            if os.path.exists(final_image_path):
                iter_images = list_iter_images(save_dir)
                _release_memory()
                time.sleep(0.5)
                yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
            else:
                _release_memory()
                yield (None, "Image generation completed, but no final image was found.", None)

        _release_memory()

    except torch.cuda.OutOfMemoryError as e:
        print(f"Global CUDA Out of Memory Error: {e}")
        yield (None, "CUDA out of memory.", None)
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print(f"Runtime Error: {e}")
            yield (None, "CUDA out of memory.", None)
        else:
            yield (None, f"An error occurred: {str(e)}", None)
    except Exception as e:
        print(f"Unexpected Error: {e}")
        yield (None, f"An unexpected error occurred: {str(e)}", None)


def show_gallery_output(gallery_state):
    """Show the iteration gallery when a gallery list exists, otherwise hide it."""
    if gallery_state is None:
        return gr.update(value=None, visible=False)
    return gr.update(value=gallery_state, visible=True)


# Create Gradio interface
title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
css = """
#model-status-id{
    height: 126px;
}
#model-status-id .progress-text{
    font-size: 10px!important;
}
#model-status-id .progress-level-inner{
    font-size: 8px!important;
}
"""

with gr.Blocks(css=css, analytics_enabled=False) as demo:
    # Per-session state: the loaded model setup and the iteration image paths.
    loaded_model_setup = gr.State()
    gallery_state = gr.State()

    with gr.Column():
        gr.Markdown(title)
        gr.Markdown(description)
        # NOTE(review): this HTML block is empty in this copy of the file —
        # confirm against the deployed app whether content was lost.
        gr.HTML("""
""")

        with gr.Row():
            # Left column: prompt, model choice and optimisation settings.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                with gr.Row():
                    chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
                    seed = gr.Number(label="seed", value=0)
                model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
                with gr.Row():
                    n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
                    learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Column():
                        # Each reward: an on/off checkbox plus a weight slider
                        # that starts non-interactive.
                        with gr.Row():
                            enable_hps = gr.Checkbox(label="HPS ON", value=False, scale=1)
                            hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
                        with gr.Row():
                            enable_imagereward = gr.Checkbox(label="ImageReward ON", value=False, scale=1)
                            imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
                        with gr.Row():
                            enable_pickscore = gr.Checkbox(label="PickScore ON", value=False, scale=1)
                            pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05, interactive=False, scale=3)
                        with gr.Row():
                            enable_clip = gr.Checkbox(label="CLIP ON", value=False, scale=1)
                            clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)

                submit_btn = gr.Button("Submit")

                example_prompts = [
                    "A red dog and a green cat",
                    "A pink elephant and a grey cow",
                    "A toaster riding a bike",
                    "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
                    "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
                    "An epic oil painting: a red portal infront of a cityscape, a solitary figure, and a colorful sky over snowy mountains",
                ]
                gr.Examples(examples=example_prompts, inputs=[prompt])

            # Right column: results.
            with gr.Column():
                output_image = gr.Image(type="filepath", label="Best Generated Image")
                status = gr.Textbox(label="Status")
                iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)

    def allow_weighting(weight_type):
        """Make a weight slider editable exactly when its checkbox value is True."""
        return gr.update(interactive=weight_type is True)

    # Wire every reward checkbox to its slider, in the original order.
    for reward_toggle, reward_weight in (
        (enable_hps, hps_w),
        (enable_imagereward, imgrw_w),
        (enable_pickscore, pcks_w),
        (enable_clip, clip_w),
    ):
        reward_toggle.change(
            fn=allow_weighting,
            inputs=[reward_toggle],
            outputs=[reward_weight],
            queue=False,
        )

    # Submit pipeline: reset state -> load model -> stream generation -> show gallery.
    reset_step = submit_btn.click(
        fn=start_over,
        inputs=[gallery_state, loaded_model_setup],
        outputs=[gallery_state, output_image, status, iter_gallery, loaded_model_setup],
    )
    load_step = reset_step.then(
        fn=setup_model,
        inputs=[prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate],
        outputs=[model_status, loaded_model_setup],
    )
    generate_step = load_step.then(
        fn=generate_image,
        inputs=[loaded_model_setup, n_iter],
        outputs=[output_image, status, gallery_state],
    )
    generate_step.then(
        fn=show_gallery_output,
        inputs=[gallery_state],
        outputs=iter_gallery,
    )

# Launch the app
demo.queue().launch(show_error=True, show_api=False)