Spaces:
Running
Running
import os | |
import sys | |
import torch | |
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
from omegaconf import OmegaConf | |
import subprocess | |
from tqdm import tqdm | |
import requests | |
import einops | |
import math | |
import random | |
import pytorch_lightning as pl | |
import spaces | |
def download_file(url, filename, *, chunk_size=1024 * 1024, timeout=60):
    """Stream ``url`` to ``filename`` while showing a tqdm progress bar.

    The payload is first written to ``filename + ".part"`` and only renamed to
    ``filename`` once the whole body has been received. An interrupted
    download therefore never leaves a truncated file at the final path. This
    matters because callers use the final path's existence to decide whether
    to download again.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path.
        chunk_size: Bytes requested per ``iter_content`` chunk.
        timeout: Timeout in seconds forwarded to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never saved as the file.
    """
    tmp_path = f"{filename}.part"
    # The context manager guarantees the streamed connection is released.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # content-length may be missing; None gives tqdm an open-ended bar.
        total_size = int(response.headers.get("content-length", 0)) or None
        with open(tmp_path, "wb") as file, tqdm(
            desc=filename,
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(chunk_size):
                progress_bar.update(file.write(data))
    # Only publish the file under its final name once it is complete.
    # A stale .part from an earlier failure is simply overwritten next time.
    os.replace(tmp_path, filename)
def setup_environment():
    """Clone the CCSR repo if needed, enter it, and ensure the checkpoint exists.

    Side effects:
        * Changes the process working directory to ``CCSR``.
        * Appends that directory to ``sys.path``. Presumably this is what lets
          the module-level ``ldm`` / ``model`` / ``utils`` imports resolve;
          verify against the repo layout.
        * Creates ``weights/`` and downloads ``real-world_ccsr.ckpt`` into it
          when it is missing.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` fails.
    """
    if not os.path.exists("CCSR"):
        print("Cloning CCSR repository...")
        # check=True makes a failed clone fail here, with git's exit status,
        # instead of as a confusing FileNotFoundError from os.chdir below.
        subprocess.run(
            ["git", "clone", "-b", "dev", "https://github.com/camenduru/CCSR.git"],
            check=True,
        )
    os.chdir("CCSR")
    sys.path.append(os.getcwd())
    os.makedirs("weights", exist_ok=True)
    checkpoint_path = "weights/real-world_ccsr.ckpt"
    if not os.path.exists(checkpoint_path):
        print("Downloading model checkpoint...")
        download_file(
            "https://huggingface.co/camenduru/CCSR/resolve/main/real-world_ccsr.ckpt",
            checkpoint_path,
        )
    else:
        print("Model checkpoint already exists. Skipping download.")
# setup_environment() chdirs into the cloned repo and extends sys.path, so the
# project imports below must stay after this call.
setup_environment()
from ldm.xformers_state import disable_xformers
from model.q_sampler import SpacedSampler
from model.ccsr_stage1 import ControlLDM
from utils.common import instantiate_from_config, load_state_dict
from utils.image import auto_resize

# Build the stage-2 model from its config and load the downloaded weights.
config = OmegaConf.load("configs/model/ccsr_stage2.yaml")
model = instantiate_from_config(config)
# NOTE(review): torch.load unpickles the file. That is acceptable only because
# the checkpoint comes from the fixed URL used by setup_environment(); never
# point this at user-supplied files.
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
load_state_dict(model, ckpt, strict=True)
# Release the CPU-side checkpoint dict so its (large) tensors are not kept
# alive for the lifetime of the app. load_state_dict presumably copies the
# tensors into the model's parameters (TODO confirm in utils/common.py).
# Dropping this reference is harmless either way.
del ckpt
model.freeze()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def _round_up_to_multiple(value, multiple=64):
    """Return the smallest multiple of ``multiple`` that is >= ``value``."""
    return -(-value // multiple) * multiple


def _prepare_control(control_img, sr_scale, tile_diffusion, tile_diffusion_size):
    """Resize a PIL image for CCSR and turn it into an NCHW float tensor in [0, 1].

    Returns:
        ``(control, input_size)``. ``control`` is the tensor on ``device``.
        ``input_size`` is the (width, height) after ``sr_scale`` has been
        applied; the final outputs are resized back to this size.
    """
    # Force three channels: np.array() of an RGBA or L image gives 4 or 1.
    control_img = control_img.convert("RGB")
    if sr_scale != 1:
        control_img = control_img.resize(
            tuple(math.ceil(x * sr_scale) for x in control_img.size),
            Image.BICUBIC,
        )
    input_size = control_img.size
    # Target passed to auto_resize: 512 without tiling, otherwise the tile size.
    control_img = auto_resize(control_img, tile_diffusion_size if tile_diffusion else 512)
    # Round each side *up* to a multiple of 64. The previous `(s // 64 + 1) * 64`
    # also enlarged sides that were already aligned (e.g. 512 -> 576), which
    # stretched the image and added sampling work for no reason.
    control_img = control_img.resize(
        tuple(_round_up_to_multiple(s, 64) for s in control_img.size), Image.LANCZOS
    )
    pixels = np.array(control_img)
    control = torch.tensor(pixels[None] / 255.0, dtype=torch.float32, device=device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    return control, input_size


def _sample_to_array(samples, size):
    """Convert sampler output (b, c, h, w; values clamped to [0, 1]) into a uint8
    HWC array of its first item, resized to ``size``."""
    x = samples.clamp(0, 1)
    x = (einops.rearrange(x, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    return np.array(Image.fromarray(x[0, ...]).resize(size, Image.LANCZOS))


@torch.no_grad()
def process(
    control_img: Image.Image,
    num_samples: int,
    sr_scale: float,
    strength: float,
    positive_prompt: str,
    negative_prompt: str,
    cfg_scale: float,
    steps: int,
    use_color_fix: bool,
    seed: int,
    tile_diffusion: bool,
    tile_diffusion_size: int,
    tile_diffusion_stride: int
):
    """Run CCSR super-resolution on ``control_img``.

    Args:
        control_img: Input PIL image. ``None`` (nothing uploaded) raises a
            user-facing ``gr.Error``.
        num_samples: Number of independent samples to draw.
        sr_scale: Upscale factor applied to the input before diffusion. The
            outputs are returned at this upscaled size.
        strength: Value written to every entry of ``model.control_scales``.
        positive_prompt: Prompt string handed to the sampler.
        negative_prompt: Negative prompt string handed to the sampler.
        cfg_scale: Classifier-free guidance scale.
        steps: Number of sampler steps.
        use_color_fix: Use ``"adain"`` color correction instead of ``"none"``.
        seed: Seed passed to ``pl.seed_everything``.
        tile_diffusion: Use ``sample_with_tile_ccsr`` instead of ``sample_ccsr``.
        tile_diffusion_size: ``tile_size`` for the tiled sampler. Also the
            auto_resize target when tiling.
        tile_diffusion_stride: ``tile_stride`` for the tiled sampler.

    Returns:
        A list of ``num_samples`` uint8 HWC numpy arrays.
    """
    if control_img is None:
        raise gr.Error("Please upload an input image first.")
    # Gradio sliders may deliver floats; range() and the samplers need ints.
    num_samples = int(num_samples)
    steps = int(steps)
    tile_diffusion_size = int(tile_diffusion_size)
    tile_diffusion_stride = int(tile_diffusion_stride)

    print(f"control image shape={control_img.size}\n"
          f"num_samples={num_samples}, sr_scale={sr_scale}, strength={strength}\n"
          f"positive_prompt='{positive_prompt}', negative_prompt='{negative_prompt}'\n"
          f"cfg scale={cfg_scale}, steps={steps}, use_color_fix={use_color_fix}\n"
          f"seed={seed}\n"
          f"tile_diffusion={tile_diffusion}, tile_diffusion_size={tile_diffusion_size}, tile_diffusion_stride={tile_diffusion_stride}")
    pl.seed_everything(seed)

    control, input_size = _prepare_control(
        control_img, sr_scale, tile_diffusion, tile_diffusion_size
    )
    height, width = control.shape[-2:]
    # Latent noise: 4 channels at 1/8 of the (64-aligned) pixel resolution.
    shape = (1, 4, height // 8, width // 8)

    # NOTE(review): 13 presumably matches the number of control outputs of the
    # ControlLDM; confirm against the model config.
    model.control_scales = [strength] * 13
    sampler = SpacedSampler(model, var_type="fixed_small")

    # The sampler receives the prompts as strings, so no text embeddings are
    # precomputed here. The old loop encoded both prompts on every iteration,
    # never used the positive one, and passed the negative one via an
    # `unconditional_conditioning` keyword. That keyword is presumably derived
    # inside the sampler from negative_prompt (TODO confirm in model/q_sampler.py).
    sampler_kwargs = dict(
        steps=steps, t_max=0.6667, t_min=0.3333, shape=shape, cond_img=control,
        positive_prompt=positive_prompt, negative_prompt=negative_prompt,
        cfg_scale=cfg_scale,
        color_fix_type="adain" if use_color_fix else "none",
    )

    preds = []
    for _ in tqdm(range(num_samples)):
        x_T = torch.randn(shape, device=device, dtype=torch.float32)
        if tile_diffusion:
            samples = sampler.sample_with_tile_ccsr(
                tile_size=tile_diffusion_size, tile_stride=tile_diffusion_stride,
                x_T=x_T, **sampler_kwargs,
            )
        else:
            samples = sampler.sample_ccsr(x_T=x_T, **sampler_kwargs)
        preds.append(_sample_to_array(samples, input_size))
    return preds
def update_output_resolution(image, scale_choice, custom_scale):
    """Describe the uploaded image's size and the size the upscaler will produce.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing is uploaded.
        scale_choice: Dropdown label: ``"Custom"``, ``"WxH (S.SSx)"`` or
            ``"WxH (NNN%)"``. An empty or ``None`` choice falls back to
            ``custom_scale`` instead of crashing.
        custom_scale: Scale factor used for the ``"Custom"`` choice.

    Returns:
        A Markdown string for the resolution panel.
    """
    if image is None:
        return "Upload an image to see the output resolution"
    width, height = image.size
    if not scale_choice or scale_choice == "Custom":
        scale = custom_scale
    elif "%" in scale_choice:
        scale = float(scale_choice.split()[-1].strip("()%")) / 100
    else:
        scale = float(scale_choice.split()[-1].strip("()x"))
    # process() resizes each side to math.ceil(side * scale). Truncating with
    # int() here under-reported the real output by a pixel (e.g. 333 * 1.5).
    out_width = math.ceil(width * scale)
    out_height = math.ceil(height * scale)
    return f"Current resolution: {width}x{height}. Output resolution: {out_width}x{out_height}"
def update_scale_choices(image):
    """Rebuild the "Output Resolution" dropdown for a newly uploaded image.

    A common target resolution is offered as ``"WxH (S.SSx)"`` when two
    conditions hold: its aspect ratio is within 0.1 of the image's, and
    reaching it would enlarge the image (factor > 1). When no target
    qualifies, 200/400/800 % presets are offered instead. ``"Custom"`` is
    always appended last, and the first entry becomes the selected value.
    Without an image only ``"Custom"`` is offered.
    """
    if image is None:
        return gr.update(choices=["Custom"], value="Custom")

    img_w, img_h = image.size
    ratio = img_w / img_h
    targets = (
        # 16:9
        (1280, 720), (1920, 1080), (2560, 1440), (3840, 2160),
        # 1:1
        (1440, 1440), (2048, 2048), (2560, 2560), (3840, 3840),
    )

    options = []
    for target_w, target_h in targets:
        # Skip targets whose aspect ratio is too far from the image's.
        if abs(target_w / target_h - ratio) >= 0.1:
            continue
        factor = max(target_w / img_w, target_h / img_h)
        if factor > 1:
            options.append(f"{target_w}x{target_h} ({factor:.2f}x)")

    if len(options) == 0:
        # Fall back to plain multiples of the input size.
        options = [f"{img_w * k}x{img_h * k} ({k * 100}%)" for k in (2, 4, 8)]

    options.append("Custom")
    return gr.update(choices=options, value=options[0])
# Layout tweaks for the Gradio app; the classes are applied via elem_classes below.
css = """
.container {max-width: 1200px; margin: auto; padding: 20px;}
.input-image {width: 100%; max-height: 500px; object-fit: contain;}
.output-gallery {display: flex; flex-wrap: wrap; justify-content: center;}
.output-image {margin: 10px; max-width: 45%; height: auto;}
.gr-form {border: 1px solid #e0e0e0; border-radius: 8px; padding: 16px; margin-bottom: 16px;}
"""

with gr.Blocks(css=css) as block:
    gr.HTML("<h1 style='text-align: center;'>CCSR Upscaler</h1>")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image", elem_classes="input-image")
            sr_scale = gr.Dropdown(
                label="Output Resolution",
                choices=["Custom"],
                value="Custom",
                interactive=True
            )
            custom_scale = gr.Slider(
                label="Custom Scale",
                minimum=1,
                maximum=8,
                value=4,
                step=0.1,
                visible=True
            )
            output_resolution = gr.Markdown("Upload an image to see the output resolution")
            run_button = gr.Button(value="Run", variant="primary")
        with gr.Column(scale=1):
            with gr.Accordion("Advanced Options", open=False):
                num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                positive_prompt = gr.Textbox(label="Positive Prompt", value="")
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
                )
                cfg_scale = gr.Slider(label="Classifier Free Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=45, step=1)
                use_color_fix = gr.Checkbox(label="Use Color Correction", value=True)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231)
                tile_diffusion = gr.Checkbox(label="Tile diffusion", value=False)
                tile_diffusion_size = gr.Slider(label="Tile diffusion size", minimum=512, maximum=1024, value=512, step=256)
                tile_diffusion_stride = gr.Slider(label="Tile diffusion stride", minimum=256, maximum=512, value=256, step=128)
    with gr.Row():
        result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery", elem_classes="output-gallery")

    def update_custom_scale(choice):
        """Show the custom-scale slider only while "Custom" is selected."""
        return gr.update(visible=choice == "Custom")

    sr_scale.change(update_custom_scale, inputs=[sr_scale], outputs=[custom_scale])

    def get_scale_value(choice, custom):
        """Translate a dropdown label into a numeric upscale factor.

        ``"Custom"`` (or an empty/``None`` choice, which used to crash) maps to
        ``custom``. ``"WxH (NNN%)"`` maps to NNN / 100, and ``"WxH (S.SSx)"``
        maps to S.SS.
        """
        if not choice or choice == "Custom":
            return custom
        label = choice.split()[-1]
        if "%" in choice:
            return float(label.strip("()%")) / 100
        return float(label.strip("()x"))

    inputs = [
        input_image, num_samples, sr_scale, strength, positive_prompt, negative_prompt,
        cfg_scale, steps, use_color_fix, seed, tile_diffusion, tile_diffusion_size,
        tile_diffusion_stride
    ]

    def run_upscale(image, sample_count, scale_choice, *rest):
        """Click handler: resolve the dropdown into a scale and call process().

        ``rest`` holds the remaining ``inputs`` values in order (strength ...
        tile_diffusion_stride), followed by the ``custom_scale`` value.
        """
        *options, custom = rest
        return process(image, sample_count, get_scale_value(scale_choice, custom), *options)

    run_button.click(
        fn=run_upscale,
        inputs=inputs + [custom_scale],
        outputs=[result_gallery]
    )
    input_image.change(
        update_scale_choices,
        inputs=[input_image],
        outputs=[sr_scale]
    )
    # Recompute the resolution preview whenever any control it depends on
    # changes. Registration order matches the original: image, dropdown, slider.
    for trigger in (input_image, sr_scale, custom_scale):
        trigger.change(
            update_output_resolution,
            inputs=[input_image, sr_scale, custom_scale],
            outputs=[output_resolution]
        )
    input_image.change(
        lambda x: gr.update(interactive=x is not None),
        inputs=[input_image],
        outputs=[sr_scale]
    )

# Diffusion inference can run for minutes. Queue the requests explicitly
# instead of relying on a long-lived HTTP request. Newer Gradio releases
# enable the queue by default, so this call is harmless there.
block.queue()
block.launch(share=True)