pharmapsychotic committed
Commit e900656
1 Parent(s): 6e346ad

CLIP Interrogator 2.0

Files changed (8)
  1. app.py +197 -11
  2. data/artists.txt +5271 -0
  3. data/flavors.txt +0 -0
  4. data/mediums.txt +95 -0
  5. data/movements.txt +200 -0
  6. example.jpg +0 -0
  7. example01.jpg +0 -0
  8. example02.jpg +0 -0
app.py CHANGED
@@ -1,14 +1,27 @@
-import gradio as gr
 import sys
-import torch
-import torchvision.transforms as T
-import torchvision.transforms.functional as TF
-
 sys.path.append('src/blip')
 sys.path.append('src/clip')
 
 import clip
+import gradio as gr
+import hashlib
+import io
+import IPython
+import ipywidgets as widgets
+import math
+import numpy as np
+import os
+import pickle
+import requests
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
 from models.blip import blip_decoder
+from PIL import Image
+from torch import nn
+from torch.nn import functional as F
+from tqdm import tqdm
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -24,6 +37,68 @@ clip_model_name = 'ViT-L/14'
 clip_model, clip_preprocess = clip.load(clip_model_name, device=device)
 clip_model.to(device).eval()
 
+chunk_size = 2048
+flavor_intermediate_count = 2048
+
+
+class LabelTable():
+    def __init__(self, labels, desc):
+        self.labels = labels
+        self.embeds = []
+
+        hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
+
+        os.makedirs('./cache', exist_ok=True)
+        cache_filepath = f"./cache/{desc}.pkl"
+        if desc is not None and os.path.exists(cache_filepath):
+            with open(cache_filepath, 'rb') as f:
+                data = pickle.load(f)
+                if data['hash'] == hash:
+                    self.labels = data['labels']
+                    self.embeds = data['embeds']
+
+        if len(self.labels) != len(self.embeds):
+            self.embeds = []
+            chunks = np.array_split(self.labels, max(1, len(self.labels)/chunk_size))
+            for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
+                text_tokens = clip.tokenize(chunk).cuda()
+                with torch.no_grad():
+                    text_features = clip_model.encode_text(text_tokens).float()
+                text_features /= text_features.norm(dim=-1, keepdim=True)
+                text_features = text_features.half().cpu().numpy()
+                for i in range(text_features.shape[0]):
+                    self.embeds.append(text_features[i])
+
+            with open(cache_filepath, 'wb') as f:
+                pickle.dump({"labels":self.labels, "embeds":self.embeds, "hash":hash}, f)
+
+    def _rank(self, image_features, text_embeds, top_count=1):
+        top_count = min(top_count, len(text_embeds))
+        similarity = torch.zeros((1, len(text_embeds))).to(device)
+        text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(device)
+        for i in range(image_features.shape[0]):
+            similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1)
+        _, top_labels = similarity.cpu().topk(top_count, dim=-1)
+        return [top_labels[0][i].numpy() for i in range(top_count)]
+
+    def rank(self, image_features, top_count=1):
+        if len(self.labels) <= chunk_size:
+            tops = self._rank(image_features, self.embeds, top_count=top_count)
+            return [self.labels[i] for i in tops]
+
+        num_chunks = int(math.ceil(len(self.labels)/chunk_size))
+        keep_per_chunk = int(chunk_size / num_chunks)
+
+        top_labels, top_embeds = [], []
+        for chunk_idx in tqdm(range(num_chunks)):
+            start = chunk_idx*chunk_size
+            stop = min(start+chunk_size, len(self.embeds))
+            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
+            top_labels.extend([self.labels[start+i] for i in tops])
+            top_embeds.extend([self.embeds[start+i] for i in tops])
+
+        tops = self._rank(image_features, top_embeds, top_count=top_count)
+        return [top_labels[i] for i in tops]
 
 def generate_caption(pil_image):
     gpu_image = T.Compose([
@@ -36,18 +111,129 @@ def generate_caption(pil_image):
     caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
     return caption[0]
 
+def load_list(filename):
+    with open(filename, 'r', encoding='utf-8', errors='replace') as f:
+        items = [line.strip() for line in f.readlines()]
+    return items
+
+def rank_top(image_features, text_array):
+    text_tokens = clip.tokenize([text for text in text_array]).cuda()
+    with torch.no_grad():
+        text_features = clip_model.encode_text(text_tokens).float()
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+
+    similarity = torch.zeros((1, len(text_array)), device=device)
+    for i in range(image_features.shape[0]):
+        similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
+
+    _, top_labels = similarity.cpu().topk(1, dim=-1)
+    return text_array[top_labels[0][0].numpy()]
+
+def similarity(image_features, text):
+    text_tokens = clip.tokenize([text]).cuda()
+    with torch.no_grad():
+        text_features = clip_model.encode_text(text_tokens).float()
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+    similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
+    return similarity[0][0]
+
+def interrogate(image):
+    caption = generate_caption(image)
+
+    images = clip_preprocess(image).unsqueeze(0).cuda()
+    with torch.no_grad():
+        image_features = clip_model.encode_image(images).float()
+    image_features /= image_features.norm(dim=-1, keepdim=True)
+
+    flaves = flavors.rank(image_features, flavor_intermediate_count)
+    best_medium = mediums.rank(image_features, 1)[0]
+    best_artist = artists.rank(image_features, 1)[0]
+    best_trending = trendings.rank(image_features, 1)[0]
+    best_movement = movements.rank(image_features, 1)[0]
+
+    best_prompt = caption
+    best_sim = similarity(image_features, best_prompt)
+
+    def check(addition):
+        nonlocal best_prompt, best_sim
+        prompt = best_prompt + ", " + addition
+        sim = similarity(image_features, prompt)
+        if sim > best_sim:
+            best_sim = sim
+            best_prompt = prompt
+            return True
+        return False
+
+    def check_multi_batch(opts):
+        nonlocal best_prompt, best_sim
+        prompts = []
+        for i in range(2**len(opts)):
+            prompt = best_prompt
+            for bit in range(len(opts)):
+                if i & (1 << bit):
+                    prompt += ", " + opts[bit]
+            prompts.append(prompt)
+
+        prompt = rank_top(image_features, prompts)
+        sim = similarity(image_features, prompt)
+        if sim > best_sim:
+            best_sim = sim
+            best_prompt = prompt
+
+    check_multi_batch([best_medium, best_artist, best_trending, best_movement])
+
+    extended_flavors = set(flaves)
+    for _ in tqdm(range(25), desc="Flavor chain"):
+        try:
+            best = rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
+            flave = best[len(best_prompt)+2:]
+            if not check(flave):
+                break
+            extended_flavors.remove(flave)
+        except:
+            # exceeded max prompt length
+            break
+
+    return best_prompt
+
+
+sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
+trending_list = [site for site in sites]
+trending_list.extend(["trending on "+site for site in sites])
+trending_list.extend(["featured on "+site for site in sites])
+trending_list.extend([site+" contest winner" for site in sites])
+
+raw_artists = load_list('data/artists.txt')
+artists = [f"by {a}" for a in raw_artists]
+artists.extend([f"inspired by {a}" for a in raw_artists])
+
+artists = LabelTable(artists, "artists")
+flavors = LabelTable(load_list('data/flavors.txt'), "flavors")
+mediums = LabelTable(load_list('data/mediums.txt'), "mediums")
+movements = LabelTable(load_list('data/movements.txt'), "movements")
+trendings = LabelTable(trending_list, "trendings")
+
+
 def inference(image):
-    return generate_caption(image)
+    return interrogate(image)
 
 inputs = [gr.inputs.Image(type='pil')]
 outputs = gr.outputs.Textbox(label="Output")
 
 title = "CLIP Interrogator"
-description = "First test of CLIP Interrogator on HuggingSpace"
+description = "Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!"
 article = """
-<p style='text-align: center'>
-<a href="">Colab Notebook</a> /
-<a href="">Github repo</a>
+<p>
+Example art by <a href="https://pixabay.com/illustrations/watercolour-painting-art-effect-4799014/">Layers</a>
+and <a href="https://pixabay.com/illustrations/animal-painting-cat-feline-pet-7154059/">Lin Tong</a>
+from pixabay.com
+</p>
+
+<p>
+Has this been helpful to you? Follow me on twitter
+<a href="https://twitter.com/pharmapsychotic">@pharmapsychotic</a>
+and check out more tools at my
+<a href="https://pharmapsychotic.com/tools.html">Ai generative art tools list</a>
 </p>
 """
 
@@ -57,5 +243,5 @@ gr.Interface(
     outputs,
     title=title, description=description,
     article=article,
-    examples=[['example.jpg']]
+    examples=[['example01.jpg'], ['example02.jpg']]
 ).launch(enable_queue=True)
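
As background for the ranking code this commit adds: LabelTable._rank and rank_top both score candidate labels by a softmax over CLIP image-to-text similarities. Below is a minimal standalone sketch of that single step. It is illustrative only and not part of the commit; the label list is a placeholder, example01.jpg is one of the example images added in this commit, and it assumes the same openai/CLIP package and ViT-L/14 model that app.py loads.

import clip
import torch
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = clip.load('ViT-L/14', device=device)

# Placeholder labels; in app.py these come from the data/*.txt label files.
labels = ['oil on canvas', 'watercolor painting', '3d render', 'pencil sketch']
image = preprocess(Image.open('example01.jpg')).unsqueeze(0).to(device)

with torch.no_grad():
    image_features = model.encode_image(image).float()
    text_features = model.encode_text(clip.tokenize(labels).to(device)).float()

# Normalize, then softmax over image-to-label similarities,
# mirroring what LabelTable._rank does for each chunk of labels.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
probs = (image_features @ text_features.T).softmax(dim=-1)
print(labels[int(probs.argmax())])
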
data/artists.txt ADDED
@@ -0,0 +1,5271 @@