Spaces:
Build error
Build error
# ------------------------------------------------------------------------------------ | |
# Minimal DALL-E | |
# Copyright (c) 2021 KakaoBrain. All Rights Reserved. | |
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] | |
# ------------------------------------------------------------------------------------ | |
import os | |
import random | |
import urllib | |
import hashlib | |
import tarfile | |
import torch | |
import clip | |
import numpy as np | |
from PIL import Image | |
from torch.nn import functional as F | |
from tqdm import tqdm | |
import torchvision.utils as vutils | |
import matplotlib.pyplot as plt | |
def set_seed(seed: int):
    """Seed every random number generator this module relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and all CUDA
    generators.

    Args:
        seed: Integer seed forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def clip_score(prompt: str,
               images: np.ndarray,
               model_clip: torch.nn.Module,
               preprocess_clip,
               device: str) -> np.ndarray:
    """Rank images by CLIP cosine similarity to a text prompt.

    Args:
        prompt: Text that every image is compared against.
        images: Iterable of image arrays. Each one is multiplied by 255 and
            cast to ``uint8`` before being wrapped in a PIL image
            (presumably float arrays in ``[0, 1]`` -- confirm with callers).
        model_clip: Model exposing ``encode_image`` and ``encode_text``.
        preprocess_clip: Callable turning a PIL image into a tensor.
        device: Device the image batch and text tokens are moved to.

    Returns:
        1-D array of image indices sorted from highest to lowest similarity.
        The result is 1-D even for a single image, and empty when ``images``
        is empty.
    """
    if len(images) == 0:
        # torch.stack() rejects an empty list; an empty ranking is the natural answer.
        return np.empty(0, dtype=np.int64)

    pil_images = [Image.fromarray((image * 255).astype(np.uint8)) for image in images]
    batch = torch.stack([preprocess_clip(pil) for pil in pil_images], dim=0).to(device=device)

    # Tokenize the prompt once and repeat it so there is one text row per image.
    texts = clip.tokenize(prompt).to(device=device)
    texts = torch.repeat_interleave(texts, batch.shape[0], dim=0)

    # Only an index array leaves this function, so no autograd graph is needed.
    with torch.no_grad():
        image_features = model_clip.encode_image(batch)
        text_features = model_clip.encode_text(texts)
        # No squeeze(): cosine_similarity already yields one score per row, and
        # squeezing would collapse the single-image case to a 0-d tensor.
        scores = F.cosine_similarity(image_features, text_features)

    rank = torch.argsort(scores, descending=True).cpu().numpy()
    return rank
def download(url: str, root: str) -> str:
    """Download a ``.tar.gz`` archive, verify its MD5 digest and extract it into ``root``.

    The expected MD5 digest is read from the second-to-last ``/``-separated
    component of ``url``. The archive is saved as ``root/<basename of url>``
    and the returned path is that name with its last ``len('.tar.gz')``
    characters removed.

    If the archive file already exists and the result path exists and is not
    a regular file, nothing is downloaded and the result path is returned
    immediately.

    Args:
        url: Location of the archive.
        root: Directory that receives the archive and its extracted content;
            created if missing.

    Returns:
        Path ``root/<archive name without '.tar.gz'>``.

    Raises:
        RuntimeError: If the MD5 digest of the downloaded data differs from
            the digest embedded in ``url``.
    """
    # A bare ``import urllib`` does not guarantee the ``request`` submodule is loaded.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    pathname = filename[:-len('.tar.gz')]
    expected_md5 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    result_path = os.path.join(root, pathname)

    already_extracted = os.path.exists(result_path) and not os.path.isfile(result_path)
    if os.path.isfile(download_target) and already_extracted:
        return result_path

    digest = hashlib.md5()
    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        # Servers may omit Content-Length; tqdm accepts total=None in that case.
        content_length = source.info().get('Content-Length')
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                # Hash while streaming so the archive is never loaded into memory at once.
                digest.update(buffer)
                loop.update(len(buffer))

    if digest.hexdigest() != expected_md5:
        raise RuntimeError('Model has been downloaded but the md5 checksum does not match')

    # NOTE(review): members are extracted without path sanitisation; the archive is
    # trusted only because its checksum matched the digest embedded in the URL.
    with tarfile.open(download_target, 'r:gz') as f:
        members = f.getmembers()
        pbar = tqdm(members, total=len(members))
        for member in pbar:
            pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
            f.extract(member=member, path=root)

    return result_path
def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
    """Resolve a model location to a local path, downloading it when it is a URL.

    Args:
        url_or_path: Either a local path or an ``http``/``https`` URL.
        root: Directory handed to ``download`` when ``url_or_path`` is a URL.
            Not used for local paths.

    Returns:
        ``url_or_path`` unchanged when its scheme is not ``http``/``https``;
        otherwise the path returned by ``download``.

    Raises:
        TypeError: If ``url_or_path`` is an ``http``/``https`` URL and ``root``
            is ``None``.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.parse`` has been loaded.
    from urllib.parse import urlparse

    if urlparse(url_or_path).scheme not in ('http', 'https'):
        return url_or_path
    if root is None:
        # Fail here with a clear message instead of inside os.makedirs(None).
        raise TypeError('root must be provided when url_or_path is an http(s) URL')
    return download(url_or_path, root)
def images_to_numpy(tensor):
    """Convert a 3-D image tensor into a ``uint8`` NumPy array.

    The first axis is moved to the end via ``transpose(1, 2, 0)``
    (presumably channels-first to channels-last -- confirm with callers).
    Values are clamped to ``[-1, 1]`` and mapped linearly onto ``[0, 255]``.

    Args:
        tensor: 3-D torch tensor; it is left unmodified.

    Returns:
        ``uint8`` array with the axes reordered as described. Fractional
        values are truncated by the cast.
    """
    array = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    # np.clip allocates a new array. For CPU tensors ``array`` shares memory
    # with ``tensor``, so clamping in place would overwrite the caller's data.
    array = np.clip(array, -1, 1)
    array = (array + 1) / 2 * 255
    return array.astype('uint8')
def save_image(ground_truth, images, out_dir, batch_idx):
    """Write every entry of ``images`` into ``out_dir`` with ``plt.imsave``.

    An entry whose shape has exactly three dimensions is saved as
    ``test_<batch_idx>_<i>.png``. Any other entry is iterated along its
    first axis and each slice is saved as ``test_<batch_idx>_<i>_<j>.png``
    (presumably a batch of images -- confirm with callers).

    Args:
        ground_truth: Currently unused. An earlier grid export built with
            ``vutils.make_grid``, which stacked the ground truth above the
            generated images, is disabled.
        images: Iterable of array-like images exposing ``.shape``.
        out_dir: Directory receiving the PNG files. It is not created here.
        batch_idx: Value interpolated into every file name.
    """
    for sample_idx, sample in enumerate(images):
        if len(sample.shape) != 3:
            # Not a single image: save each slice along the first axis.
            for slice_idx in range(sample.shape[0]):
                slice_name = 'test_%s_%s_%s.png' % (batch_idx, sample_idx, slice_idx)
                plt.imsave(os.path.join(out_dir, slice_name), sample[slice_idx])
            continue

        sample_name = 'test_%s_%s.png' % (batch_idx, sample_idx)
        plt.imsave(os.path.join(out_dir, sample_name), sample)