Spaces:
Configuration error
Configuration error
import io | |
import os | |
import sys | |
import numpy as np | |
import PIL | |
import requests | |
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms as T | |
import torchvision.transforms.functional as TF | |
from PIL import Image, ImageDraw, ImageFont | |
def download_image(url, timeout=30):
    """Download ``url`` over HTTP(S) and open the response body as a PIL image.

    Args:
        url: address of the image to fetch.
        timeout: seconds passed to ``requests.get``; without it a server that
            never responds would block the caller indefinitely. Pass ``None``
            to restore the old wait-forever behaviour.

    Returns:
        A ``PIL.Image.Image`` opened from the downloaded bytes.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the request exceeds ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    # The whole body is held in memory, so PIL's lazy open reads from a
    # buffer that stays valid after this function returns.
    return PIL.Image.open(io.BytesIO(resp.content))
def preprocess(img, target_image_size=256, map_dalle=False):
    """Resize and center-crop a PIL image into a batched square tensor.

    The shorter side is scaled to ``target_image_size`` with Lanczos
    resampling. The result is center-cropped to
    ``target_image_size x target_image_size`` and converted with
    ``torchvision.transforms.ToTensor``; for 8-bit images that yields
    values in [0, 1]. A leading batch axis of size 1 is added.

    Args:
        img: PIL image to preprocess.
        target_image_size: side length of the square output.
        map_dalle: accepted for call compatibility but not used here.

    Returns:
        Tensor of shape (1, C, target_image_size, target_image_size), where C
        depends on the image mode.

    Raises:
        ValueError: if the image's shorter side is below ``target_image_size``.
    """
    shortest = min(img.size)
    if shortest < target_image_size:
        raise ValueError(f'min dim for image {shortest} < {target_image_size}')

    width, height = img.size
    scale = target_image_size / shortest
    # PIL's .size is (width, height) but torchvision expects (height, width).
    new_hw = (round(scale * height), round(scale * width))

    resized = TF.resize(img, new_hw, interpolation=PIL.Image.LANCZOS)
    square = TF.center_crop(resized, output_size=[target_image_size, target_image_size])
    return T.ToTensor()(square).unsqueeze(0)
def preprocess_vqgan(x):
    """Rescale values from the [0, 1] range to the [-1, 1] range (x -> 2x - 1).

    Works on anything supporting scalar multiplication and subtraction
    (Python floats, NumPy arrays, torch tensors).
    """
    return x * 2.0 - 1.0
def custom_to_pil(x, process=True, mode="RGB"):
    """Convert a channel-first image tensor to a PIL image.

    Args:
        x: tensor laid out as (C, H, W); ``permute(1, 2, 0)`` requires exactly
            three axes. It is detached and moved to CPU first.
        process: if True, clamp ``x`` to [-1, 1], rescale to [0, 1] and
            scale to uint8 (by truncation). If False, ``x`` is used as-is;
            its dtype must then already be one ``Image.fromarray`` accepts
            (e.g. uint8).
        mode: PIL mode of the returned image; converted if the array
            produced a different mode.

    Returns:
        A ``PIL.Image.Image`` in ``mode``.
    """
    x = x.detach().cpu()
    if process:
        x = torch.clamp(x, -1., 1.)
        x = (x + 1.) / 2.
    arr = x.permute(1, 2, 0).numpy()
    if process:
        arr = (255 * arr).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        # Image.fromarray has no mode for an (H, W, 1) array; drop the
        # singleton channel so single-channel tensors become mode "L" images
        # (then converted to `mode` below like any other input).
        arr = arr[..., 0]
    img = Image.fromarray(arr)
    if img.mode != mode:
        img = img.convert(mode)
    return img
def get_pil(x):
    """Map a (C, H, W) tensor with values in [-1, 1] to an (H, W, C) tensor in [0, 1].

    Values outside [-1, 1] are clamped first. Despite the name, the result
    is a torch tensor (channel-last view), not a PIL image.
    """
    unit_range = (x.clamp(-1., 1.) + 1.) / 2.
    return unit_range.permute(1, 2, 0)
def loop_post_process(x):
    """Map a batched (1, C, H, W) tensor in [-1, 1] to a (1, C, H, W) tensor in [0, 1].

    Uses ``get_pil`` for the clamp-and-rescale step, then restores the
    channel-first layout and the leading batch axis. An unbatched (C, H, W)
    input is also accepted and comes back with a batch axis added.
    """
    # Remove only the leading batch axis. A bare squeeze() would also drop
    # singleton channel/spatial axes (grayscale or 1-pixel-wide images),
    # leaving fewer than three axes for get_pil's permute(1, 2, 0).
    hwc = get_pil(x.squeeze(0))
    return hwc.permute(2, 0, 1).unsqueeze(0)
def stack_reconstructions(input, x0, x1, x2, x3, titles=()):
    """Paste five equally-sized PIL images side by side into one RGB strip.

    Args:
        input: leftmost panel; its size defines the width/height of every panel.
        x0, x1, x2, x3: remaining panels, placed left to right.
        titles: optional labels; ``titles[i]`` is drawn in white at the
            top-left corner of panel ``i``.

    Returns:
        A new RGB ``PIL.Image.Image`` of size (5 * width, height).
    """
    # All five panels must match, including x0 (previously unchecked, so a
    # mismatched x0 was silently pasted cropped or padded with black).
    assert input.size == x0.size == x1.size == x2.size == x3.size
    w, h = input.size[0], input.size[1]
    panels = (input, x0, x1, x2, x3)
    img = Image.new("RGB", (len(panels) * w, h))
    for i, panel in enumerate(panels):
        img.paste(panel, (i * w, 0))
    draw = ImageDraw.Draw(img)
    for i, title in enumerate(titles):
        # NOTE(review): `font` is not defined in this function; it must exist
        # as a module-level name when titles are given — confirm it is defined.
        draw.text((i * w, 0), f'{title}', (255, 255, 255), font=font)  # coordinates, text, color, font
    return img