# NOTE(review): `spaces` is kept as the very first import, as in the original
# ordering; the ZeroGPU `spaces` package is presumably sensitive to being
# imported before torch/CUDA initialisation — confirm before reordering.
import spaces
import os
import datetime
import traceback
import einops
import gradio as gr
from gradio_imageslider import ImageSlider
import numpy as np
import torch
import random
from PIL import Image
from pathlib import Path
from torchvision import transforms
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights

from pytorch_lightning import seed_everything
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler

from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
from myutils.misc import load_dreambooth_lora, rand_name
from myutils.wavelet_color_fix import wavelet_color_fix
from annotator.retinaface import RetinaFaceDetection

# Switch between the full PASD UNet/ControlNet and the "light" variant.
use_pasd_light = False
face_detector = RetinaFaceDetection()

if use_pasd_light:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel
else:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel

pretrained_model_path = "checkpoints/stable-diffusion-v1-5"
ckpt_path = "runs/pasd/checkpoint-100000"
# Personalized weights merged into the base model. Alternatives used before:
#   "checkpoints/personalized_models/toonyou_beta3.safetensors"
#   "checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors"
dreambooth_lora_path = "checkpoints/personalized_models/majicmixRealistic_v6.safetensors"

weight_dtype = torch.float16
device = "cuda"

# Base Stable Diffusion components come from the SD-1.5 folder; the PASD-trained
# UNet and ControlNet come from the PASD checkpoint.
scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")
unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")

# Inference only: no gradients needed anywhere.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)

unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)

text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)

validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)
#validation_pipeline.enable_vae_tiling()
# NOTE(review): `_init_tiled_vae` is a private helper of the project pipeline;
# presumably it enables tiled VAE decoding with 224px tiles to bound memory.
validation_pipeline._init_tiled_vae(decoder_tile_size=224)

# ImageNet classifier used to add the detected object category to the prompt.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
resnet = resnet50(weights=weights)
resnet.eval()


def resize_image(image_path, target_height):
    """Load an image and resize it to ``target_height`` pixels tall, keeping aspect ratio.

    Args:
        image_path: Path of the image file to open.
        target_height: Desired output height in pixels.

    Returns:
        A new ``PIL.Image`` of size ``(new_width, target_height)``, resampled with
        Lanczos. ``new_width`` is the proportionally scaled width, truncated to an
        int and never smaller than 1.

    Raises:
        gr.Error: If ``image_path`` is None (e.g. the upload was cleared).
    """
    if image_path is None:
        raise gr.Error("Please upload an input image.")
    # Keep the file handle scoped to the `with`; resize() returns a new, loaded image.
    with Image.open(image_path) as img:
        ratio = target_height / float(img.size[1])
        new_width = max(1, int(float(img.size[0]) * ratio))
        resized_img = img.resize((new_width, target_height), Image.LANCZOS)
    return resized_img


@spaces.GPU(enable_queue=True)
def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
    """Super-resolve an uploaded image with the PASD pipeline.

    Args:
        input_image: File path of the uploaded image (the UI uses ``type="filepath"``).
        prompt: User prompt. The ResNet-50 top-1 category is appended when its
            softmax score is >= 0.1.
        a_prompt: Extra positive prompt appended after ``prompt``.
        n_prompt: Negative prompt.
        denoise_steps: Number of inference steps (cast to int).
        upscale: Integer upscaling factor (cast to int).
        alpha: ControlNet ``conditioning_scale``.
        cfg: Classifier-free ``guidance_scale``.
        seed: Random seed. It is passed through ``seed_everything``, and the value
            that function returns also seeds the pipeline's generator.

    Returns:
        ``((input_jpg_path, result_jpg_path), result_jpg_path)``, for the
        before/after slider and the download component. If the pipeline raises,
        the result is a black 512x512 placeholder and the traceback is printed.
    """
    input_image = resize_image(input_image, 512)

    # Used to name the JPEG files written to the working directory.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    # Gradio sliders may deliver floats; PIL sizes and step counts need ints.
    rscale = int(upscale)
    num_steps = int(denoise_steps)

    with torch.no_grad():
        # Seed the global RNGs, then seed the pipeline's own generator with the
        # same (possibly bounds-corrected) value. A freshly constructed
        # torch.Generator is NOT affected by global seeding, so without
        # manual_seed the Seed slider had no effect on it.
        seed = seed_everything(int(seed))
        generator = torch.Generator(device=device).manual_seed(seed)

        input_image = input_image.convert('RGB')

        # Classify the image and, if confident enough, add its category to the prompt.
        batch = preprocess(input_image).unsqueeze(0)
        prediction = resnet(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]
        if score >= 0.1:
            prompt += f"{category_name}" if prompt=='' else f", {category_name}"

        prompt = a_prompt if prompt=='' else f"{prompt}, {a_prompt}"

        ori_width, ori_height = input_image.size
        input_image = input_image.resize((ori_width * rscale, ori_height * rscale))
        # Round both sides down to multiples of 8 (presumably required by the
        # latent/VAE downsampling of the pipeline — confirm).
        input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
        width, height = input_image.size

        try:
            # NOTE(review): the leading positional `None` fills the project
            # pipeline's first argument; its meaning is not visible here.
            image = validation_pipeline(
                None,
                prompt,
                input_image,
                num_inference_steps=num_steps,
                generator=generator,
                height=height,
                width=width,
                guidance_scale=cfg,
                negative_prompt=n_prompt,
                conditioning_scale=alpha,
                eta=0.0,
            ).images[0]

            # Match the colours of the output to the conditioning input.
            image = wavelet_color_fix(image, input_image)
            # Undo the multiple-of-8 rounding: return exactly upscale x original size.
            image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception:
            # Best-effort demo: report the full error but still return a placeholder.
            traceback.print_exc()
            image = Image.new(mode="RGB", size=(512, 512))

    # Save result and (resized) input as JPEG so the UI components can serve them.
    image.save(f'result_{timestamp}.jpg', 'JPEG')
    input_image.save(f'input_{timestamp}.jpg', 'JPEG')

    return (f"input_{timestamp}.jpg", f"result_{timestamp}.jpg"), f"result_{timestamp}.jpg"


title = "Pixel-Aware Stable Diffusion for Real-ISR"
description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them."
article = "Github Repo Pytorch"
#examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]

# Page styling: centre the main column and lay out the project-link row.
css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
}
#project-links{
    margin: 0 0 12px !important;
    column-gap: 8px;
    display: flex;
    justify-content: center;
    flex-wrap: nowrap;
    flex-direction: row;
    align-items: center;
}
"""

# UI: input image plus prompt and advanced settings on the left; before/after
# slider and downloadable result on the right. Submitting calls `inference`.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(f"""
Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="filepath", sources=["upload"], value="samples/frog.png")
                prompt_in = gr.Textbox(label="Prompt", value="Frog")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
                    neg_prompt = gr.Textbox(label="Negative Prompt",value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                    denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
                    upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
                    condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
                    classifier_free_guidance = gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
                    # randomize=True gives a fresh random initial seed on each page load.
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                b_a_slider = ImageSlider(label="B/A result", position=0.5)
                file_output = gr.File(label="Downloadable image result")

    # Input order must match the positional parameters of `inference`.
    submit_btn.click(
        fn = inference,
        inputs = [
            input_image,
            prompt_in,
            added_prompt,
            neg_prompt,
            denoise_steps,
            upsample_scale,
            condition_scale,
            classifier_free_guidance,
            seed
        ],
        outputs = [
            b_a_slider,
            file_output
        ]
    )

demo.queue().launch()