Ahsen Khaliq committed on
Commit b2cb3d5
1 Parent(s): 7135121

Create app.py

Files changed (1)
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
+ import os
+ # install runtime dependencies (gradio is needed for the web demo at the bottom of this file)
+ os.system("pip install transformers gdown torch numpy tqdm Pillow scikit-image gradio")
+ os.system("pip install git+https://github.com/openai/CLIP.git")
+ # download the pretrained caption model checkpoint and save it under an explicit local name,
+ # so that model_path below points at a known file
+ os.system("gdown https://drive.google.com/uc?id=14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT -O model_weights.pt")
+ model_path = 'model_weights.pt'
+ 
+ import clip
+ import gradio as gr
+ from torch import nn
+ import numpy as np
+ import torch
+ import torch.nn.functional as nnf
+ from typing import Tuple, List, Union, Optional
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+ from tqdm import tqdm, trange
+ import skimage.io as io
+ import PIL.Image
+ 
+ # tensor/device aliases used in the type hints and device handling below
+ T = torch.Tensor
+ D = torch.device
+ CPU = torch.device('cpu')
+ 
+ def CUDA(device_idx: int) -> D:
+     # fall back to CPU if no GPU is available
+     if not torch.cuda.is_available():
+         return CPU
+     return torch.device(f'cuda:{device_idx}')
+ 
+ class MLP(nn.Module):
+ 
+     def forward(self, x: T) -> T:
+         return self.model(x)
+ 
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+ 
+ 
+ class ClipCaptionModel(nn.Module):
+ 
+     #@functools.lru_cache #FIXME
+     def get_dummy_token(self, batch_size: int, device: D) -> T:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+ 
+     def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         # project the CLIP embedding to prefix_length GPT-2 token embeddings and prepend it to the text embeddings
+         prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+         #print(embedding_text.size()) #torch.Size([5, 67, 768])
+         #print(prefix_projections.size()) #torch.Size([5, 1, 768])
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+ 
+     def __init__(self, prefix_length: int, prefix_size: int = 512):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         if prefix_length > 10:  # not enough memory
+             self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
+         else:
+             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
+ 
+ 
+ class ClipCaptionPrefix(ClipCaptionModel):
+     # variant that exposes only the projection parameters, keeping GPT-2 frozen
+ 
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+ 
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+ 
+ 
+ #@title Caption prediction
+ 
+ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                   entry_length=67, temperature=1., stop_token: str = '.'):
+     # beam-search decoding over GPT-2, starting from a prefix embedding (or a text prompt)
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 # first step: expand the single prefix into beam_size candidate beams
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 # later steps: score every (beam, token) continuation with length
+                 # normalisation and keep the best beam_size of them
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+ 
+ 
+ def generate2(
+         model,
+         tokenizer,
+         tokens=None,
+         prompt=None,
+         embed=None,
+         entry_count=1,
+         entry_length=67,  # maximum number of words
+         top_p=0.8,
+         temperature=1.,
+         stop_token: str = '.',
+ ):
+     # greedy decoding with nucleus (top-p) filtering of the logits
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+ 
+     with torch.no_grad():
+ 
+         for entry_idx in trange(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+ 
+                 generated = model.gpt.transformer.wte(tokens)
+ 
+             for i in range(entry_length):
+ 
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 # drop every token outside the smallest set whose cumulative probability exceeds top_p
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+ 
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+ 
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+ 
+     return generated_list[0]
+ 
+ is_gpu = False
+ 
+ # CUDA() and CPU are the device helpers defined above; model_path is the checkpoint downloaded via gdown
+ device = CUDA(0) if is_gpu else "cpu"
+ clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ 
+ prefix_length = 10
+ 
+ model = ClipCaptionModel(prefix_length)
+ model.load_state_dict(torch.load(model_path, map_location=CPU))
+ model = model.eval()
+ model = model.to(device)
+ 
+ 
+ def inference(img):
+     use_beam_search = False
+     image = io.imread(img.name)
+     pil_image = PIL.Image.fromarray(image)
+     image = preprocess(pil_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         # encode the image with CLIP, then project it to a GPT-2 prefix embedding
+         prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+         prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+     if use_beam_search:
+         generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+     else:
+         generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+     return generated_text_prefix
+ 
+ title = "CLIP prefix captioning"
+ description = "Demo for ClipCap: image captioning by projecting CLIP features into a GPT-2 prefix. To use it, simply upload an image. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.09734'>ClipCap: CLIP Prefix for Image Captioning</a> | <a href='https://github.com/rmokady/CLIP_prefix_caption'>Github Repo</a></p>"
+ 
+ gr.Interface(
+     inference,
+     gr.inputs.Image(type="file", label="Input"),
+     gr.outputs.Textbox(label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     enable_queue=True
+ ).launch(debug=True)
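
For quick local testing, inference() can also be exercised without the web UI, since it only reads the .name attribute of the object that gr.inputs.Image(type="file") passes in. The snippet below is a minimal sketch and not part of the commit: it assumes an image named example.jpg sits next to app.py and uses SimpleNamespace as a stand-in for Gradio's temp-file object; run it in the same Python session as the definitions above (for example with the .launch() call commented out).

    # hypothetical smoke test: call inference() directly on a local file
    from types import SimpleNamespace

    fake_upload = SimpleNamespace(name="example.jpg")  # stand-in for Gradio's temp-file object; path is an assumption
    print(inference(fake_upload))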