# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import random
import urllib
# The submodules must be imported explicitly: a bare ``import urllib`` does not
# guarantee that ``urllib.request`` / ``urllib.parse`` are loaded.
import urllib.parse
import urllib.request
import hashlib
import tarfile
from typing import Optional

import torch
import clip
import numpy as np
from PIL import Image
from torch.nn import functional as F
from tqdm import tqdm


def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


@torch.no_grad()
def clip_score(prompt: str,
               images: np.ndarray,
               model_clip: torch.nn.Module,
               preprocess_clip,
               device: str) -> np.ndarray:
    """Rank images by CLIP similarity to a text prompt.

    Args:
        prompt: Text the images are compared against.
        images: Iterable of image arrays; each one is scaled by 255 and cast
            to uint8 before being handed to PIL (presumably float images in
            [0, 1] -- confirm against the caller).
        model_clip: Model exposing ``encode_image`` and ``encode_text``.
        preprocess_clip: Callable turning a PIL image into a tensor.
        device: Device the inputs are moved to before encoding.

    Returns:
        1-D array of image indices, sorted by descending cosine similarity
        between the image features and the prompt features.
    """
    images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
    images = torch.stack(images, dim=0).to(device=device)

    # Repeat the tokenized prompt so there is one text row per image.
    texts = clip.tokenize(prompt).to(device=device)
    texts = torch.repeat_interleave(texts, images.shape[0], dim=0)

    image_features = model_clip.encode_image(images)
    text_features = model_clip.encode_text(texts)

    # reshape(-1) rather than squeeze(): squeeze() collapses a single-image
    # batch to a 0-d tensor, which would make the returned rank 0-d as well.
    scores = F.cosine_similarity(image_features, text_features).reshape(-1)
    rank = torch.argsort(scores, descending=True).cpu().numpy()
    return rank


def _md5_hexdigest(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex md5 digest of the file at *path*, read in chunks."""
    digest = hashlib.md5()
    with open(path, 'rb') as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if *target* resolves to *directory* or a path beneath it."""
    directory = os.path.realpath(directory)
    target = os.path.realpath(target)
    # os.path.join(directory, '') appends exactly one trailing separator.
    return target == directory or target.startswith(os.path.join(directory, ''))


def download(url: str, root: str) -> str:
    """Download a ``.tar.gz`` archive into *root*, verify it and extract it.

    The second-to-last path segment of *url* is used as the expected md5
    checksum of the archive. If both the archive and its extracted result
    already exist under *root*, nothing is downloaded.

    Args:
        url: Address of the archive.
        root: Directory the archive is stored in and extracted to; created
            if it does not exist.

    Returns:
        ``root`` joined with the archive file name minus its ``.tar.gz``
        suffix.

    Raises:
        RuntimeError: If the md5 checksum of the downloaded file does not
            match, or an archive member would be extracted outside *root*.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    # NOTE: strips the last len('.tar.gz') characters unconditionally, so the
    # file name is assumed to end with '.tar.gz'.
    pathname = filename[:-len('.tar.gz')]

    expected_md5 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    result_path = os.path.join(root, pathname)

    # Already downloaded and extracted: reuse the existing result.
    if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
        return result_path

    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        # Content-Length is optional; without it the progress bar has no total.
        content_length = source.info().get('Content-Length')
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Hash in chunks so a large archive is never held in memory at once.
    if _md5_hexdigest(download_target) != expected_md5:
        raise RuntimeError('Model has been downloaded but the md5 checksum does not match')

    with tarfile.open(download_target, 'r:gz') as f:
        members = f.getmembers()
        pbar = tqdm(members, total=len(members))
        for member in pbar:
            # Refuse members whose path would land outside the target directory.
            if not _is_within_directory(root, os.path.join(root, member.name)):
                raise RuntimeError(f'Refusing to extract archive member outside of {root}: {member.name}')
            pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
            f.extract(member=member, path=root)

    return result_path


def realpath_url_or_path(url_or_path: str, root: Optional[str] = None) -> str:
    """Return a local path for *url_or_path*, downloading it if it is a URL.

    Args:
        url_or_path: An ``http``/``https`` URL or a local path.
        root: Directory passed to ``download`` for URLs; it must be provided
            in that case because ``download`` creates it. Unused for local
            paths.

    Returns:
        The result of ``download`` for URLs, otherwise *url_or_path* unchanged.
    """
    if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
        return download(url_or_path, root)
    return url_or_path