import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline, AutoPipelineForText2Image
from diffusers.utils import load_image
from PIL import Image
import time
import random
import os
import gc  # Garbage collector
import logging

# --- Configuration ---

# Setup basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Ensure CPU is used
DEVICE = "cpu"
TORCH_DTYPE = torch.float32  # float16/bfloat16 not practical on CPU

# Model definitions
# We need to know the base model for LoRAs and compatible IP-Adapters
MODEL_CONFIG = {
    "BlaireSilver13/youtube-thumbnail": {
        "repo_id": "BlaireSilver13/youtube-thumbnail",
        "is_lora": True,
        "lora_filename": "FLUX-youtube-thumbnails.safetensors",
        "base_model": "black-forest-labs/FLUX.1-dev",
        "pipeline_class": AutoPipelineForText2Image,
        "ip_adapter_repo": "h94/IP-Adapter",
        "ip_adapter_weights": "ip-adapter_sd15.bin",
        "ip_adapter_image_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    },
    "itzzdeep/youtube-thumbnails-sdxl-lora": {
        "repo_id": "itzzdeep/youtube-thumbnails-sdxl-lora",
        "is_lora": True,
        "lora_filename": "youtube-thumbnails-sdxl-lora.safetensors",
        "base_model": "stabilityai/stable-diffusion-xl-base-1.0",
        "pipeline_class": AutoPipelineForText2Image,  # Handles SDXL loading better
        "ip_adapter_repo": "h94/IP-Adapter",  # SDXL IP-Adapter repo
        "ip_adapter_weights": "ip-adapter-plus_sdxl_vit-h.bin",  # SDXL weights
        "ip_adapter_image_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",  # Usually the same encoder repo
    },
    "justmalhar/flux-thumbnails-v3": {
        "repo_id": "justmalhar/flux-thumbnails-v3",
        "is_lora": False,  # Assuming this is a full SD 1.5 fine-tune based on common practice
        "base_model": None,
        "pipeline_class": StableDiffusionPipeline,
        "ip_adapter_repo": "h94/IP-Adapter",
        "ip_adapter_weights": "ip-adapter_sd15.bin",
        "ip_adapter_image_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    },
    "saq1b/mrbeast-thumbnail-style": {
        "repo_id": "saq1b/mrbeast-thumbnail-style",
        "is_lora": True,  # This is typically a LoRA
        "lora_filename": None,  # Auto-detect or specify e.g., "pytorch_lora_weights.safetensors"
        "base_model": "runwayml/stable-diffusion-v1-5",  # Common base for SD 1.5 LoRAs
        "pipeline_class": StableDiffusionPipeline,
        "ip_adapter_repo": "h94/IP-Adapter",
        "ip_adapter_weights": "ip-adapter_sd15.bin",
        "ip_adapter_image_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    },
}

AVAILABLE_MODELS = list(MODEL_CONFIG.keys())

# Global variables to potentially hold the pipeline to avoid reloading *if memory allows*.
# NOTE: On restricted CPU environments, it's SAFER to load inside the function.
# Set to None initially. Let's load dynamically inside the function for safety.
# current_pipeline = None
# current_model_key = None
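
# Minimal sketch of the module-level caching strategy described in the note above,
# for hosts with enough RAM to keep one pipeline resident between requests.
# It is NOT called by generate_thumbnail() below (which deliberately loads per call);
# the helper name `_get_or_load_pipeline` and the `_pipeline_cache` dict are illustrative only.
_pipeline_cache: dict = {"key": None, "pipe": None}

def _get_or_load_pipeline(model_key: str):
    """Return the cached pipeline if the same model was loaded last time, else load it fresh."""
    if _pipeline_cache["pipe"] is not None and _pipeline_cache["key"] == model_key:
        return _pipeline_cache["pipe"]  # Reuse the resident pipeline
    # A different model was requested: drop the old pipeline before loading the new one
    _pipeline_cache["pipe"] = None
    gc.collect()
    config = MODEL_CONFIG[model_key]
    load_id = config["base_model"] if config["is_lora"] else config["repo_id"]
    pipe = config["pipeline_class"].from_pretrained(load_id, torch_dtype=TORCH_DTYPE)
    pipe.to(DEVICE)
    _pipeline_cache["key"] = model_key
    _pipeline_cache["pipe"] = pipe
    return pipe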

# --- Helper Functions ---

def cleanup_memory():
    """Attempts to free GPU memory (less relevant for CPU but good practice)."""
    logger.info("Attempting to clean up memory...")
    try:
        # If a pipeline exists globally (if we change strategy), unload it
        # global current_pipeline, current_model_key
        # if current_pipeline is not None:
        #     logger.info(f"Unloading model {current_model_key} from memory.")
        #     del current_pipeline
        #     current_pipeline = None
        #     current_model_key = None
        gc.collect()
        if torch.cuda.is_available():  # Only empty the CUDA cache if CUDA is present
            torch.cuda.empty_cache()
        logger.info("Memory cleanup potentially done.")
    except Exception as e:
        logger.error(f"Error during memory cleanup: {e}")


# --- Main Generation Function ---

def generate_thumbnail(
    model_key: str,
    prompt: str,
    negative_prompt: str,
    reference_image_pil: Image.Image | None,  # Gradio provides a PIL image
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    ip_adapter_scale: float,
    progress=gr.Progress(),
):
    """Generates an image using the selected model, IP-Adapter, and settings."""
    start_time = time.time()
    debug_log = f"--- Generation Log ({time.strftime('%Y-%m-%d %H:%M:%S')}) ---\n"
    debug_log += f"Selected Model Key: {model_key}\n"
    debug_log += f"Prompt: {prompt}\n"
    debug_log += f"Negative Prompt: {negative_prompt}\n"
    debug_log += f"Steps: {num_inference_steps}, CFG Scale: {guidance_scale}\n"
    debug_log += f"Seed: {seed}\n"
    debug_log += f"Reference Image Provided: {'Yes' if reference_image_pil else 'No'}\n"
    debug_log += f"IP Adapter Scale: {ip_adapter_scale}\n"
    debug_log += f"Device: {DEVICE}, Dtype: {TORCH_DTYPE}\n\n"

    pipeline = None  # Ensure pipeline is defined in this scope
    try:
        if not model_key:
            raise ValueError("No model selected.")

        config = MODEL_CONFIG[model_key]
        repo_id = config["repo_id"]
        is_lora = config["is_lora"]
        base_model = config["base_model"]
        pipeline_class = config["pipeline_class"]
        ip_adapter_repo = config["ip_adapter_repo"]
        ip_adapter_weights = config["ip_adapter_weights"]
        # ip_adapter_image_encoder = config["ip_adapter_image_encoder"]  # Encoder loaded via IP-Adapter itself usually

        # --- Model Loading ---
        load_start_time = time.time()
        debug_log += f"[{time.time() - start_time:.2f}s] Cleaning up memory before loading...\n"
        progress(0.1, desc="Cleaning up memory...")
        cleanup_memory()  # Attempt cleanup before loading new model

        debug_log += f"[{time.time() - start_time:.2f}s] Loading model: {'LoRA ' + repo_id if is_lora else repo_id}...\n"
        progress(0.2, desc=f"Loading {'LoRA ' + repo_id if is_lora else repo_id}...")

        model_load_id = base_model if is_lora else repo_id
        debug_log += f"[{time.time() - start_time:.2f}s] Base/Model ID for pipeline: {model_load_id}\n"

        pipeline = pipeline_class.from_pretrained(
            model_load_id,
            torch_dtype=TORCH_DTYPE,
            # Add any specific args needed for the pipeline class if necessary
            # safety_checker=None,  # Disable safety checker if it causes issues on CPU
            # requires_safety_checker=False,
        )
        pipeline.to(DEVICE)
        debug_log += f"[{time.time() - start_time:.2f}s] Base pipeline loaded onto {DEVICE}.\n"

        if is_lora:
            lora_load_start = time.time()
            debug_log += f"[{time.time() - start_time:.2f}s] Loading LoRA weights from {repo_id}...\n"
            progress(0.4, desc=f"Loading LoRA {repo_id}...")
            try:
                lora_filename = config.get("lora_filename")  # Get specific filename if provided
                if lora_filename:
                    debug_log += f"[{time.time() - start_time:.2f}s] Using specified LoRA filename: {lora_filename}\n"
                    pipeline.load_lora_weights(repo_id, weight_name=lora_filename, torch_dtype=TORCH_DTYPE)
                else:
                    # Let diffusers try to auto-detect standard names like .safetensors or .bin
                    debug_log += f"[{time.time() - start_time:.2f}s] Attempting auto-detection of LoRA filename.\n"
                    pipeline.load_lora_weights(repo_id, torch_dtype=TORCH_DTYPE)

                # When using LoRA with diffusers >= 0.22, explicitly fuse *or* set adapters
                # pipeline.fuse_lora()  # Fusing creates a new pipeline state (might use more memory)
                pipeline.set_adapters(pipeline.get_active_adapters(), adapter_weights=1.0)  # Recommended for flexibility
                debug_log += f"[{time.time() - start_time:.2f}s] LoRA weights loaded and adapters set in {time.time() - lora_load_start:.2f}s.\n"
            except Exception as e:
                debug_log += f"[{time.time() - start_time:.2f}s] ERROR loading LoRA: {e}. Check LoRA repo structure/filename.\n"
                # Decide whether to continue without LoRA or raise an error
                raise ValueError(f"Failed to load LoRA weights for {repo_id}: {e}")

        # --- IP Adapter Loading ---
        if reference_image_pil and ip_adapter_scale > 0:
            ip_load_start = time.time()
            debug_log += f"[{time.time() - start_time:.2f}s] Loading IP-Adapter: {ip_adapter_repo} ({ip_adapter_weights})...\n"
            progress(0.6, desc="Loading IP-Adapter...")
            try:
                # Ensure the pipeline has the load_ip_adapter method
                if not hasattr(pipeline, "load_ip_adapter"):
                    raise AttributeError("The current pipeline class does not support load_ip_adapter. Check diffusers version or pipeline type.")

                # In h94/IP-Adapter the SD 1.5 weights live under "models" and the SDXL
                # weights under "sdxl_models", so pick the subfolder accordingly.
                ip_adapter_subfolder = "sdxl_models" if "sdxl" in ip_adapter_weights else "models"
                pipeline.load_ip_adapter(
                    ip_adapter_repo,
                    subfolder=ip_adapter_subfolder,
                    weight_name=ip_adapter_weights,
                    # image_encoder_folder=ip_adapter_image_encoder  # Let diffusers handle encoder loading usually
                )
                pipeline.set_ip_adapter_scale(ip_adapter_scale)
                debug_log += f"[{time.time() - start_time:.2f}s] IP-Adapter loaded and scale set ({ip_adapter_scale}) in {time.time() - ip_load_start:.2f}s.\n"

                # Prepare the image for IP-Adapter (often just needs to be a PIL image)
                ip_image = reference_image_pil.convert("RGB")
                debug_log += f"[{time.time() - start_time:.2f}s] Reference image prepared for IP-Adapter.\n"
            except Exception as e:
                debug_log += f"[{time.time() - start_time:.2f}s] WARNING: Failed to load IP-Adapter: {e}. Proceeding without image guidance.\n"
                ip_image = None
                ip_adapter_scale = 0  # Effectively disable it if loading failed
                pipeline.set_ip_adapter_scale(0)  # Ensure scale is 0
        else:
            ip_image = None
            if hasattr(pipeline, "set_ip_adapter_scale"):
                pipeline.set_ip_adapter_scale(0)  # Ensure scale is 0 if no image/scale=0
            debug_log += f"[{time.time() - start_time:.2f}s] No reference image provided or IP Adapter scale is 0. Skipping IP-Adapter loading.\n"

        debug_log += f"[{time.time() - start_time:.2f}s] Total Model & IP-Adapter Loading time: {time.time() - load_start_time:.2f}s\n"

        # --- Generation ---
        gen_start_time = time.time()
        debug_log += f"[{time.time() - start_time:.2f}s] Starting generation...\n"
        progress(0.7, desc="Generating image...")

        # Handle seed
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
            debug_log += f"[{time.time() - start_time:.2f}s] Using random seed: {seed}\n"
        generator = torch.Generator(device=DEVICE).manual_seed(seed)

        # Prepare arguments for pipeline call
        pipeline_args = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
        }

        # Add IP-Adapter image if it's loaded and ready
        if ip_image is not None and hasattr(pipeline, "set_ip_adapter_scale") and ip_adapter_scale > 0:
            pipeline_args["ip_adapter_image"] = ip_image  # Scale was set earlier with set_ip_adapter_scale
            debug_log += f"[{time.time() - start_time:.2f}s] Passing reference image to pipeline with IP scale {ip_adapter_scale}.\n"
        else:
            debug_log += f"[{time.time() - start_time:.2f}s] Not passing reference image to pipeline.\n"

        # Run inference
        with torch.inference_mode():  # More modern than no_grad for inference
            output_image = pipeline(**pipeline_args).images[0]

        gen_end_time = time.time()
        debug_log += f"[{time.time() - start_time:.2f}s] Generation finished in {gen_end_time - gen_start_time:.2f}s.\n"

        # --- Cleanup ---
        debug_log += f"[{time.time() - start_time:.2f}s] Unloading model from memory (CPU strategy)...\n"
        progress(0.95, desc="Cleaning up...")
        del pipeline  # Explicitly delete the pipeline
        cleanup_memory()  # Call garbage collection

        total_time = time.time() - start_time
        debug_log += f"\n--- Total time: {total_time:.2f} seconds ---\n"

        return output_image, debug_log

    except Exception as e:
        logger.exception(f"Error during generation for model {model_key}")  # Log full traceback
        error_time = time.time() - start_time
        debug_log += "\n\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
        debug_log += f"ERROR occurred after {error_time:.2f}s:\n{e}\n"
        debug_log += "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
        # Try cleanup even on error
        if 'pipeline' in locals() and pipeline is not None:
            del pipeline
        cleanup_memory()
        # Return None for the image, and the log containing the error
        return None, debug_log


# --- Gradio Interface ---
css = """
#warning {
    background-color: #FFCCCB; /* Light red */
    padding: 10px;
    border-radius: 5px;
    text-align: center;
    font-weight: bold;
}
#debug_log_area textarea {
    font-family: monospace;
    font-size: 10px; /* Smaller font for logs */
    white-space: pre-wrap; /* Wrap long lines */
    word-wrap: break-word; /* Break words if necessary */
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# YouTube Thumbnail Generator with IP-Adapter")
    gr.Markdown(
        "Select a thumbnail model, provide a text prompt, and optionally upload a reference image "
        "to guide the generation using IP-Adapter."
    )
    gr.HTML("