| import sys |
| import types |
| import datetime |
| import re |
| from pathlib import Path |
|
|
| import huggingface_hub |
|
|
| |
| |
| |
# Compatibility shim: huggingface_hub removed `cached_download` in recent
# releases, but some pinned dependencies (e.g. older diffusers) still import
# it at load time. Forward it to the modern `hf_hub_download`.
# NOTE(review): the historical `cached_download` took a URL while
# `hf_hub_download` takes repo_id/filename — this shim assumes callers only
# need the attribute to exist or already pass hf_hub-style arguments; confirm
# against the pinned dependency versions.
if not hasattr(huggingface_hub, "cached_download"):
    def cached_download(*args, **kwargs):
        # Delegate to the current download API.
        return huggingface_hub.hf_hub_download(*args, **kwargs)
    huggingface_hub.cached_download = cached_download
|
|
| import torch |
| import numpy as np |
| import einops |
| import spaces |
| import gradio as gr |
|
|
| from PIL import Image |
| from torchvision import transforms |
| import torch.nn.functional as F |
| from torchvision.models import resnet50, ResNet50_Weights |
|
|
| from pytorch_lightning import seed_everything |
| from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor |
| from diffusers import ( |
| AutoencoderKL, |
| DDIMScheduler, |
| PNDMScheduler, |
| DPMSolverMultistepScheduler, |
| UniPCMultistepScheduler, |
| ) |
|
|
| |
| |
| |
# Monkey-patch CUDA device introspection so that any library querying the GPU
# at import/setup time sees an NVIDIA A10G (compute capability 8.6, ~24 GB).
# Presumably needed because no CUDA device is attached outside a @spaces.GPU
# call on ZeroGPU Spaces — TODO confirm.
torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 6)
torch.cuda.get_device_properties = lambda *args, **kwargs: types.SimpleNamespace(
    name="NVIDIA A10G",
    major=8,
    minor=6,
    total_memory=23836033024,  # bytes (~23.8 GB), matching an A10G
    multi_processor_count=80,
)
|
|
| |
| |
| |
# Fetch the PASD run checkpoints (all four variants) into PASD/runs.
huggingface_hub.snapshot_download(
    repo_id="camenduru/PASD",
    allow_patterns=[
        "pasd/**",
        "pasd_light/**",
        "pasd_light_rrdb/**",
        "pasd_rrdb/**",
    ],
    local_dir="PASD/runs",
)

# Personalized SD checkpoint, merged into the pipeline later via
# load_dreambooth_lora().
huggingface_hub.hf_hub_download(
    repo_id="camenduru/PASD",
    filename="majicmixRealistic_v6.safetensors",
    local_dir="PASD/checkpoints/personalized_models",
)

# Face-detector weights used by PASD's RetinaFace annotator.
huggingface_hub.hf_hub_download(
    repo_id="akhaliq/RetinaFace-R50",
    filename="RetinaFace-R50.pth",
    local_dir="PASD/annotator/ckpts",
)

# Make the vendored PASD repo importable (pipelines, myutils, annotator, models).
sys.path.append("./PASD")
|
|
|
|
| |
| |
| |
def patch_file(path_str: str, replacements: list[tuple[str, str]]) -> None:
    """Apply literal (old, new) text replacements to a file in place.

    Missing files and read/write failures are reported on stdout and
    skipped rather than raised, so a failed patch never aborts startup.
    A status line is printed whether or not the file changed.
    """
    target = Path(path_str)
    if not target.exists():
        print(f"[patch] file not found: {target}")
        return

    try:
        contents = target.read_text(encoding="utf-8")
    except Exception as exc:
        print(f"[patch] failed reading {target}: {exc}")
        return

    patched = contents
    for needle, replacement in replacements:
        patched = patched.replace(needle, replacement)

    if patched == contents:
        # Nothing matched (or the file is already patched): leave it alone.
        print(f"[patch] no changes: {target}")
        return

    try:
        target.write_text(patched, encoding="utf-8")
        print(f"[patch] updated: {target}")
    except Exception as exc:
        print(f"[patch] failed writing {target}: {exc}")
|
|
|
|
def patch_controlnet_loader_import(path_str: str) -> None:
    """Normalize the FromOriginalControlnetMixin import in PASD's controlnet.py.

    diffusers has renamed/moved this mixin across versions, so the vendored
    file is rewritten to (1) strip any direct import of it, (2) strip any
    previously injected try/except fallback block, and (3) insert a
    version-agnostic fallback block before the first class definition.
    Failures are reported on stdout instead of raised.
    """
    path = Path(path_str)
    if not path.exists():
        print(f"[patch] file not found: {path}")
        return

    try:
        text = path.read_text(encoding="utf-8")
    except Exception as e:
        print(f"[patch] failed reading {path}: {e}")
        return

    # Fallback import chain: prefer the current diffusers spelling, then the
    # older one, then a no-op stub class so that class definitions inheriting
    # from FromOriginalControlnetMixin still work when neither import exists.
    safe_block = """try:
    from diffusers.loaders import FromOriginalControlNetMixin as FromOriginalControlnetMixin
except Exception:
    try:
        from diffusers.loaders import FromOriginalControlnetMixin
    except Exception:
        class FromOriginalControlnetMixin:
            pass

"""

    original = text

    # Drop any direct `from diffusers.loaders ...` import of the mixin (any
    # capitalization/location diffusers has used), so only the injected
    # fallback block defines the name.
    text = re.sub(
        r"(?m)^from diffusers\.loaders[^\n]*FromOriginalControl\w*Mixin[^\n]*\n",
        "",
        text,
    )
    text = re.sub(
        r"(?m)^from diffusers\.loaders\.single_file_model[^\n]*FromOriginal\w+[^\n]*\n",
        "",
        text,
    )

    # Remove a stale top-level try/except fallback from a previous patch run;
    # the lambda keeps any try/except block that does not mention the mixin.
    text = re.sub(
        r"(?ms)^try:\n(?:(?: |\t).*\n)+?except Exception:\n(?:(?: |\t).*\n)+?(?=^(?:class|def|@|from |import |\Z))",
        lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
        text,
    )

    # Unify spelling: the vendored code may reference either capitalization.
    text = text.replace("FromOriginalControlNetMixin", "FromOriginalControlnetMixin")

    # Insert the fallback block just before the first class that needs the
    # mixin; if the marker is missing, prepend it to the file instead.
    marker = "class ControlNetConditioningEmbedding"
    if safe_block not in text:
        idx = text.find(marker)
        if idx != -1:
            text = text[:idx] + safe_block + text[idx:]
        else:
            text = safe_block + text

    if text != original:
        try:
            path.write_text(text, encoding="utf-8")
            print(f"[patch] normalized: {path}")
        except Exception as e:
            print(f"[patch] failed writing {path}: {e}")
    else:
        print(f"[patch] no changes: {path}")
|
|
|
|
def patch_pasd_for_diffusers() -> None:
    """Rewrite vendored PASD sources to match the installed diffusers layout.

    PASD was written against an older diffusers release; these in-place
    patches update moved/renamed imports so the vendored modules import
    cleanly. Each patch_file call is a no-op if the text is already patched.
    """
    # DiffusionPipeline moved out of diffusers.pipeline_utils.
    patch_file(
        "./PASD/pipelines/pipeline_pasd.py",
        [
            (
                "from diffusers.pipeline_utils import DiffusionPipeline",
                "from diffusers import DiffusionPipeline",
            ),
        ],
    )

    # PositionNet was renamed GLIGENTextBoundingboxProjection; alias it back.
    patch_file(
        "./PASD/models/pasd/unet_2d_condition.py",
        [
            ("    PositionNet,\n", ""),
            (
                "    GLIGENTextBoundingboxProjection,\n",
                "    GLIGENTextBoundingboxProjection as PositionNet,\n",
            ),
        ],
    )

    # AdaGroupNorm and the transformer models moved to new submodules.
    patch_file(
        "./PASD/models/pasd/unet_2d_blocks.py",
        [
            (
                "from diffusers.models.attention import AdaGroupNorm",
                "from diffusers.models.normalization import AdaGroupNorm",
            ),
            (
                "from diffusers.models.dual_transformer_2d import DualTransformer2DModel",
                "from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel",
            ),
            (
                "from diffusers.models.transformer_2d import Transformer2DModel",
                "from diffusers.models.transformers.transformer_2d import Transformer2DModel",
            ),
        ],
    )

    # controlnet.py needs regex-based handling for the loader mixin rename.
    patch_controlnet_loader_import("./PASD/models/pasd/controlnet.py")
|
|
|
|
# Apply the source patches before importing anything from the vendored repo.
patch_pasd_for_diffusers()

# PASD-internal imports (resolvable via the sys.path entry added above).
from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
from myutils.misc import load_dreambooth_lora
from myutils.wavelet_color_fix import wavelet_color_fix
from annotator.retinaface import RetinaFaceDetection

# Select the full-size PASD variant; the "light" variant ships its own
# UNet/ControlNet implementations under models/pasd_light.
use_pasd_light = False
face_detector = RetinaFaceDetection()

if use_pasd_light:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel
else:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel
|
|
| |
| |
| |
# --- Model assembly -------------------------------------------------------
# The SD 1.5 base model provides scheduler / text encoder / tokenizer / VAE /
# feature extractor; the PASD checkpoint provides the fine-tuned UNet and
# ControlNet; the safetensors file personalizes the result.
pretrained_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
ckpt_path = "PASD/runs/pasd/checkpoint-100000"
dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"

weight_dtype = torch.float16  # half precision for inference
device = "cuda"

scheduler = UniPCMultistepScheduler.from_pretrained(
    pretrained_model_path, subfolder="scheduler"
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_path, subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_path, subfolder="tokenizer"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_path, subfolder="vae"
)
feature_extractor = CLIPImageProcessor.from_pretrained(
    pretrained_model_path, subfolder="feature_extractor"
)
unet = UNet2DConditionModel.from_pretrained(
    ckpt_path, subfolder="unet"
)
controlnet = ControlNetModel.from_pretrained(
    ckpt_path, subfolder="controlnet"
)

# Inference only: freeze every component.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)

# Merge the personalized DreamBooth/LoRA checkpoint into the base components.
unet, vae, text_encoder = load_dreambooth_lora(
    unet, vae, text_encoder, dreambooth_lora_path
)

text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)

validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,  # no safety checker loaded
    requires_safety_checker=False,
)

# Tiled VAE decoding keeps GPU memory bounded at high upscale factors.
validation_pipeline._init_tiled_vae(decoder_tile_size=224)
|
|
| |
| |
| |
# ImageNet ResNet-50 classifier, used by inference() to auto-augment the
# user prompt with the detected object category.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
resnet = resnet50(weights=weights)
resnet.eval()  # disable dropout/batchnorm updates; CPU inference
|
|
|
|
def resize_image(image_path: str, target_height: int) -> Image.Image:
    """Load the image at *image_path* and scale it to *target_height* pixels
    high, preserving the aspect ratio (width is rounded down to an int)."""
    with Image.open(image_path) as img:
        width, height = img.size
        scale = target_height / float(height)
        scaled_width = int(float(width) * scale)
        return img.resize((scaled_width, target_height), Image.LANCZOS)
|
|
|
|
@spaces.GPU(enable_queue=True)
def inference(
    input_image,
    prompt,
    a_prompt,
    n_prompt,
    denoise_steps,
    upscale,
    alpha,
    cfg,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """Run PASD super-resolution on one uploaded image.

    Args mirror the Gradio inputs: input_image is a filepath, a_prompt /
    n_prompt are the added and negative prompts, upscale is an integer scale
    factor, alpha is the ControlNet conditioning scale and cfg the
    classifier-free guidance scale. Returns (input_path, result_path,
    result_path): saved JPEG paths for the before image, the after image,
    and the downloadable file.
    """
    # NOTE(review): -1 means "randomize" on the seed slider, but it is mapped
    # to the fixed seed 0 here — confirm whether true randomization was
    # intended.
    if seed == -1:
        seed = 0

    # Normalize the working copy to height 512 before upscaling.
    input_image = resize_image(input_image, 512)
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    with torch.no_grad():
        seed_everything(seed)
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        input_image = input_image.convert("RGB")

        # Auto-caption: classify with ResNet-50 and, when reasonably
        # confident, append the ImageNet category name to the prompt.
        batch = preprocess(input_image).unsqueeze(0)
        prediction = resnet(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]

        if score >= 0.1:
            prompt += f"{category_name}" if prompt == "" else f", {category_name}"

        prompt = a_prompt if prompt == "" else f"{prompt}, {a_prompt}"

        ori_width, ori_height = input_image.size

        # Upscale, then snap both dimensions down to multiples of 8 as
        # required by the latent-space models.
        rscale = upscale
        input_image = input_image.resize(
            (input_image.size[0] * rscale, input_image.size[1] * rscale)
        )
        input_image = input_image.resize(
            (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
        )
        width, height = input_image.size

        try:
            image = validation_pipeline(
                None,  # first positional arg unused here — presumably prompt embeds; confirm against pipeline_pasd's signature
                prompt,
                input_image,
                num_inference_steps=denoise_steps,
                generator=generator,
                height=height,
                width=width,
                guidance_scale=cfg,
                negative_prompt=n_prompt,
                conditioning_scale=alpha,
                eta=0.0,
            ).images[0]

            # Restore the input's color statistics, then resize to the exact
            # requested output size (undoes the multiple-of-8 snapping).
            image = wavelet_color_fix(image, input_image)
            image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception as e:
            # Best effort: return a black placeholder instead of crashing the UI.
            print(f"[inference] error: {e}")
            image = Image.new(mode="RGB", size=(512, 512))

    result_path = f"result_{timestamp}.jpg"
    input_path = f"input_{timestamp}.jpg"

    image.save(result_path, "JPEG")
    input_image.save(input_path, "JPEG")

    return input_path, result_path, result_path
|
|
|
|
# Stylesheet for the UI below; the selectors target the elem_id anchors
# ("col-container", "project-links") used in the Blocks layout.
css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
}
#project-links{
    margin: 0 0 12px !important;
    column-gap: 8px;
    display: flex;
    justify-content: center;
    flex-wrap: nowrap;
    flex-direction: row;
    align-items: center;
}
"""
|
|
# Build the Gradio UI. Fix: the `css` string defined above was never applied —
# custom CSS must be passed to the gr.Blocks constructor (it is not a
# Blocks.launch() parameter), so attach it here.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Static header: title, paper/project badges, duplicate-space badge.
        gr.HTML("""
        <h2 style="text-align: center;">
            PASD Magnify
        </h2>
        <p style="text-align: center;">
            Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
        </p>
        <p id="project-links" align="center">
            <a href="https://github.com/yangxy/PASD"><img src="https://img.shields.io/badge/Project-Page-Green"></a>
            <a href="https://huggingface.co/papers/2308.14469"><img src="https://img.shields.io/badge/Paper-Arxiv-red"></a>
        </p>
        <p style="margin:12px auto;display: flex;justify-content: center;">
            <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
            </a>
        </p>
        """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="filepath",
                    sources=["upload"],
                    value="PASD/samples/frog.png",
                    label="Input image",
                )
                prompt_in = gr.Textbox(label="Prompt", value="Frog")

                # Tuning knobs forwarded straight to inference().
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(
                        label="Added Prompt",
                        value="clean, high-resolution, 8k, best quality, masterpiece",
                    )
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                    )
                    denoise_steps = gr.Slider(
                        label="Denoise Steps",
                        minimum=10,
                        maximum=50,
                        value=20,
                        step=1,
                    )
                    upsample_scale = gr.Slider(
                        label="Upsample Scale",
                        minimum=1,
                        maximum=4,
                        value=2,
                        step=1,
                    )
                    condition_scale = gr.Slider(
                        label="Conditioning Scale",
                        minimum=0.5,
                        maximum=1.5,
                        value=1.1,
                        step=0.1,
                    )
                    classifier_free_guidance = gr.Slider(
                        label="Classifier-free Guidance",
                        minimum=0.1,
                        maximum=10.0,
                        value=7.5,
                        step=0.1,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )

                submit_btn = gr.Button("Submit")

            with gr.Column():
                before_img = gr.Image(label="Input")
                after_img = gr.Image(label="Result")
                file_output = gr.File(label="Downloadable image result")

    # Wire the button to inference(); outputs are (before, after, download).
    submit_btn.click(
        fn=inference,
        inputs=[
            input_image,
            prompt_in,
            added_prompt,
            neg_prompt,
            denoise_steps,
            upsample_scale,
            condition_scale,
            classifier_free_guidance,
            seed,
        ],
        outputs=[
            before_img,
            after_img,
            file_output,
        ],
        api_visibility="private",
    )
|
|
# Start the app with a bounded request queue.
# Fix: `css` is not a Blocks.launch() parameter (custom CSS belongs on the
# gr.Blocks constructor), so passing it here fails on current Gradio; drop it.
demo.queue(max_size=10).launch(
    ssr_mode=False,
    mcp_server=False,
)