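"""Gradio app that upscales images 4x with Swin2SR, optionally processing large images in 512px tiles."""
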
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gradio as gr
import spaces
import os


def resize_image(image, max_size=2048):
    """Downscale the image so neither dimension exceeds max_size, preserving aspect ratio."""
    width, height = image.size
    if width > max_size or height > max_size:
        aspect_ratio = width / height
        if width > height:
            new_width = max_size
            new_height = int(new_width / aspect_ratio)
        else:
            new_height = max_size
            new_width = int(new_height * aspect_ratio)
        image = image.resize((new_width, new_height), Image.LANCZOS)
    return image


def split_image(image, chunk_size=512):
    """Split the image into chunk_size x chunk_size tiles, returning (tile, x, y) tuples."""
    width, height = image.size
    chunks = []
    for y in range(0, height, chunk_size):
        for x in range(0, width, chunk_size):
            chunk = image.crop((x, y, min(x + chunk_size, width), min(y + chunk_size, height)))
            chunks.append((chunk, x, y))
    return chunks


def stitch_image(chunks, original_size):
    """Paste upscaled tiles back into a single canvas of the given size."""
    result = Image.new('RGB', original_size)
    for img, x, y in chunks:
        result.paste(img, (x, y))
    return result


def upscale_chunk(chunk, model, processor, device):
    """Run a single tile through the Swin2SR model and return the upscaled PIL image."""
    inputs = processor(chunk, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # Convert the CHW float tensor back to an HWC uint8 image
    output = outputs.reconstruction.data.squeeze().cpu().float().clamp_(0, 1).numpy()
    output = np.moveaxis(output, source=0, destination=-1)
    output_image = (output * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_image)


@spaces.GPU
def main(image, model_choice, save_as_jpg=True, use_tiling=True):
    # Resize the input image
    image = resize_image(image)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_paths = {
        "Pixel Perfect": "caidas/swin2SR-classical-sr-x4-64",
        "PSNR Match (Recommended)": "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
    }

    processor = AutoImageProcessor.from_pretrained(model_paths[model_choice])
    model = Swin2SRForImageSuperResolution.from_pretrained(model_paths[model_choice]).to(device)

    if use_tiling:
        # Split the image into chunks
        chunks = split_image(image)

        # Process each chunk
        upscaled_chunks = []
        for chunk, x, y in chunks:
            upscaled_chunk = upscale_chunk(chunk, model, processor, device)
            # Remove 32 pixels from bottom and right edges
            upscaled_chunk = upscaled_chunk.crop((0, 0, upscaled_chunk.width - 32, upscaled_chunk.height - 32))
            upscaled_chunks.append((upscaled_chunk, x * 4, y * 4))  # Multiply coordinates by 4 due to 4x upscaling

        # Stitch the chunks back together
        final_size = (image.width * 4 - 32, image.height * 4 - 32)  # Adjust for removed pixels
        upscaled_image = stitch_image(upscaled_chunks, final_size)
    else:
        # Process the entire image at once
        upscaled_image = upscale_chunk(image, model, processor, device)

    # Generate output filename (a resized, in-memory image may not carry a .filename attribute)
    original_filename = os.path.splitext(getattr(image, "filename", "") or "image")[0]
    output_filename = f"{original_filename}_upscaled"

    if save_as_jpg:
        output_filename += ".jpg"
        upscaled_image.save(output_filename, quality=95)
    else:
        output_filename += ".png"
        upscaled_image.save(output_filename)

    return output_filename


def gradio_interface(image, model_choice, save_as_jpg, use_tiling):
    try:
        result = main(image, model_choice, save_as_jpg, use_tiling)
        return result, None
    except Exception as e:
        return None, str(e)


interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=["PSNR Match (Recommended)", "Pixel Perfect"],
            label="Select Model",
            value="PSNR Match (Recommended)"
        ),
        gr.Checkbox(value=True, label="Save as JPEG"),
        gr.Checkbox(value=True, label="Use Tiling"),
    ],
    outputs=[
        gr.File(label="Download Upscaled Image"),
        gr.Textbox(label="Error Message", visible=True)
    ],
    title="Image Upscaler",
    description="Upload an image, select a model, and upscale it. Images larger than 2048x2048 will be resized while maintaining aspect ratio. Use tiling for efficient processing of large images.",
)

interface.launch()