keithhon commited on
Commit
0c9c6b9
1 Parent(s): 7931c5f

Upload dalle/utils/utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dalle/utils/utils.py +84 -0
dalle/utils/utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import random
9
+ import urllib
10
+ import hashlib
11
+ import tarfile
12
+ import torch
13
+ import clip
14
+ import numpy as np
15
+ from PIL import Image
16
+ from torch.nn import functional as F
17
+ from tqdm import tqdm
18
+
19
+
20
+ def set_seed(seed: int):
21
+ random.seed(seed)
22
+ np.random.seed(seed)
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+
26
+
27
+ @torch.no_grad()
28
+ def clip_score(prompt: str,
29
+ images: np.ndarray,
30
+ model_clip: torch.nn.Module,
31
+ preprocess_clip,
32
+ device: str) -> np.ndarray:
33
+ images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
34
+ images = torch.stack(images, dim=0).to(device=device)
35
+ texts = clip.tokenize(prompt).to(device=device)
36
+ texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
37
+
38
+ image_features = model_clip.encode_image(images)
39
+ text_features = model_clip.encode_text(texts)
40
+
41
+ scores = F.cosine_similarity(image_features, text_features).squeeze()
42
+ rank = torch.argsort(scores, descending=True).cpu().numpy()
43
+ return rank
44
+
45
+
46
+ def download(url: str, root: str) -> str:
47
+ os.makedirs(root, exist_ok=True)
48
+ filename = os.path.basename(url)
49
+ pathname = filename[:-len('.tar.gz')]
50
+
51
+ expected_md5 = url.split("/")[-2]
52
+ download_target = os.path.join(root, filename)
53
+ result_path = os.path.join(root, pathname)
54
+
55
+ #if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
56
+ #return result_path
57
+
58
+ with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
59
+ with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True,
60
+ unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ # if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5:
70
+ # raise RuntimeError(f'Model has been downloaded but the md5 checksum does not not match')
71
+
72
+ with tarfile.open(download_target, 'r:gz') as f:
73
+ pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
74
+ for member in pbar:
75
+ pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
76
+ f.extract(member=member, path=root)
77
+
78
+ return result_path
79
+
80
+
81
+ def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
82
+ if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
83
+ return download(url_or_path, root)
84
+ return url_or_path