File size: 3,441 Bytes
25ad706
29356cb
 
 
f1ee166
92c37e9
29356cb
25ad706
0782bc0
25ad706
 
 
 
 
 
0782bc0
25ad706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29356cb
 
4a66938
25ad706
0782bc0
4a66938
 
 
 
 
 
25ad706
29356cb
25ad706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29356cb
4a66938
25ad706
 
 
 
 
29356cb
 
 
 
e2d6adc
4a66938
 
 
 
 
f1ee166
29356cb
13a4c81
 
 
 
29356cb
25ad706
29356cb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces

def split_image(image, chunk_size=512):
    """Cut *image* into a row-major grid of tiles at most ``chunk_size`` on a side.

    Args:
        image: Image object exposing ``size`` (width, height) and ``crop(box)``.
        chunk_size: Edge length of a full tile. Tiles on the right and bottom
            borders are clipped to the image bounds.

    Returns:
        A list of ``(tile, left, top)`` tuples, ordered top-to-bottom and then
        left-to-right, where ``left``/``top`` are the tile's pixel offsets in
        the source image.
    """
    full_width, full_height = image.size
    return [
        (
            image.crop((
                left,
                top,
                min(left + chunk_size, full_width),
                min(top + chunk_size, full_height),
            )),
            left,
            top,
        )
        for top in range(0, full_height, chunk_size)
        for left in range(0, full_width, chunk_size)
    ]

def stitch_image(chunks, original_size):
    """Assemble positioned tiles onto a freshly created RGB canvas.

    Args:
        chunks: Iterable of ``(tile, left, top)`` tuples. Each tile is pasted
            with its upper-left corner at ``(left, top)``.
        original_size: ``(width, height)`` of the canvas to create.

    Returns:
        A new RGB ``PIL.Image`` of size ``original_size``. Tiles are pasted in
        iteration order, so a later tile overwrites an earlier one wherever
        they overlap; pixels no tile covers keep ``Image.new``'s default fill.
    """
    canvas = Image.new('RGB', original_size)
    for tile, left, top in chunks:
        canvas.paste(tile, (left, top))
    return canvas

def upscale_chunk(chunk, model, processor, device):
    """Super-resolve a single image tile and return it as an 8-bit image.

    Args:
        chunk: Tile to upscale, in a form ``processor`` accepts.
        model: Model whose output exposes a ``reconstruction`` tensor.
        processor: Callable that turns ``chunk`` into a dict of PyTorch tensors.
        device: Device the input tensors are moved to before inference.

    Returns:
        The model's reconstruction, clamped to [0, 1] and scaled to uint8,
        as a ``PIL.Image``.
    """
    model_inputs = processor(chunk, return_tensors="pt")
    model_inputs = dict((name, tensor.to(device)) for name, tensor in model_inputs.items())

    with torch.no_grad():
        reconstruction = model(**model_inputs).reconstruction

    # Assumes a batch of one, i.e. (1, C, H, W): squeeze -> (C, H, W), then
    # move channels last -> (H, W, C), the layout PIL expects.
    hwc = np.moveaxis(
        reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy(),
        source=0,
        destination=-1,
    )
    return Image.fromarray((hwc * 255.0).round().astype(np.uint8))

@spaces.GPU
def main(image, model_choice, save_as_jpg=True):
    """Upscale *image* with the selected Swin2SR checkpoint and save the result.

    The image is split into 512x512 tiles (``split_image``). Each tile is
    upscaled independently, cropped back to exactly ``scale`` times its input
    size, and pasted at its scaled offset on a canvas of ``scale`` times the
    input size.

    Args:
        image: Input ``PIL.Image``; converted to RGB before processing.
        model_choice: Display name of the model; must be a key of
            ``model_paths`` below.
        save_as_jpg: Save as JPEG (quality 95) when true, otherwise as PNG.

    Returns:
        Path to the saved file, named ``upscaled_image.jpg`` or
        ``upscaled_image.png``, inside a fresh per-call temporary directory.

    Raises:
        ValueError: If no image was provided.
        KeyError: If ``model_choice`` is not a known model name.
    """
    import os
    import tempfile

    if image is None:
        raise ValueError("No image provided. Please upload an image.")
    # The model works on 3-channel input; RGBA / L / P uploads would otherwise
    # fail inside the processor or model.
    image = image.convert("RGB")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
    }
    checkpoint = model_paths[model_choice]

    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint).to(device)
    # Both checkpoints are x4; read the factor from the config instead of
    # hard-coding it, falling back to 4 if the attribute is absent.
    scale = getattr(model.config, "upscale", 4)

    upscaled_chunks = []
    for chunk, x, y in split_image(image):
        upscaled_chunk = upscale_chunk(chunk, model, processor, device)
        # The processor pads each tile on the bottom/right before inference,
        # and that padding is upscaled along with the tile. Cropping to exactly
        # scale * tile size removes it for any tile size. The previous fixed
        # 32 px crop was only right for tiles whose sides are a multiple of 8.
        upscaled_chunk = upscaled_chunk.crop((0, 0, chunk.width * scale, chunk.height * scale))
        upscaled_chunks.append((upscaled_chunk, x * scale, y * scale))

    # The canvas is the full scaled size; the old "- 32" cut off the right and
    # bottom edges of every result.
    final_size = (image.width * scale, image.height * scale)
    upscaled_image = stitch_image(upscaled_chunks, final_size)

    # A per-call temp directory keeps concurrent requests from overwriting
    # each other's output while preserving the original download filename.
    output_dir = tempfile.mkdtemp()
    if save_as_jpg:
        output_path = os.path.join(output_dir, "upscaled_image.jpg")
        upscaled_image.save(output_path, quality=95)
    else:
        output_path = os.path.join(output_dir, "upscaled_image.png")
        upscaled_image.save(output_path)
    return output_path

def gradio_interface(image, model_choice, save_as_jpg):
    """Gradio callback: run ``main`` and route any failure to the error box.

    Args:
        image: Uploaded image, passed through to ``main``.
        model_choice: Selected model name, passed through to ``main``.
        save_as_jpg: Output-format flag, passed through to ``main``.

    Returns:
        ``(file_path, None)`` on success, or ``(None, message)`` on failure,
        one value for each of the interface's two outputs (file, error text).
    """
    import traceback

    try:
        result = main(image, model_choice, save_as_jpg)
    except Exception as e:
        # UI boundary: show the error to the user, but keep the full traceback
        # in the server log instead of discarding it.
        traceback.print_exc()
        # Some exceptions (e.g. a bare AssertionError) stringify to "", which
        # would show an empty error box; fall back to the exception type name.
        return None, str(e) or type(e).__name__
    return result, None

# Gradio UI. Inputs: image, model choice, and output format.
# Outputs: a downloadable file and an error-message box, in the order that
# gradio_interface returns them.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=["PSNR Match (Recommended)", "Pixel Perfect"],
            label="Select Model",
            value="PSNR Match (Recommended)"
        ),
        gr.Checkbox(value=True, label="Save as JPEG"),
    ],
    outputs=[
        gr.File(label="Download Upscaled Image"),
        gr.Textbox(label="Error Message", visible=True)
    ],
    title="Image Upscaler",
    description="Upload an image, select a model, and upscale it. The image will be processed in 512x512 pixel chunks to handle large images efficiently.",
)

# Start the server only when this file is run as a script, so importing the
# module does not launch a web server as a side effect.
if __name__ == "__main__":
    interface.launch()