imthanhlv committed on
Commit
93007bd
1 Parent(s): 1ca3b6f

first commit

Files changed (3)
  1. app.py +277 -0
  2. drug.jpg +0 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,277 @@
+ # Modified from https://huggingface.co/spaces/akhaliq/CLIP_prefix_captioning/blob/main/app.py
+
+ import os
+
+ # Download the model weights from Google Drive
+ os.system("gdown https://drive.google.com/uc?id=1_8v2ZUUaf9hhXP35jESXJ_Hgzl_rmufh")
+
+ import clip
+ from torch import nn
+ import numpy as np
+ import torch
+ import torch.nn.functional as nnf
+ import sys
+ from typing import Tuple, List, Union, Optional
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+ from tqdm import tqdm, trange
+ import skimage.io as io
+ import PIL.Image
+ import gradio as gr
+
+ # Type aliases carried over from the upstream template (only a few are used here)
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+ D = torch.device
+ CPU = torch.device('cpu')
+
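+ # Device helper: pick a CUDA device when one is available, otherwise fall back to CPU.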
+ def get_device(device_id: int) -> D:
+     if not torch.cuda.is_available():
+         return CPU
+     device_id = min(torch.cuda.device_count() - 1, device_id)
+     return torch.device(f'cuda:{device_id}')
+
+ CUDA = get_device
+
+
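+ # Mapping network: an MLP that projects a single CLIP embedding to prefix_length GPT-2 token embeddings.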
+ class MLP(nn.Module):
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         """Project clip output to embedding of first prefix_length tokens"""
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+                 # added some dropout here
+                 layers.append(nn.Dropout(p=0.2))
+         self.model = nn.Sequential(*layers)
+
+
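+ # Caption model: GPT-2 ("imthanhlv/gpt2news") conditioned on a learned prefix computed from the CLIP embedding.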
+ class ClipCaptionModel(nn.Module):
+     def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         """Dummy prefix tokens of shape (B, prefix_length), prepended to the labels during training."""
+         return torch.zeros(
+             batch_size, self.prefix_length, dtype=torch.int64, device=device
+         )
+
+     def forward(
+         self,
+         tokens: torch.Tensor,
+         prefix: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+     ):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix).view(
+             -1, self.prefix_length, self.gpt_embedding_size
+         )
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int = 10, prefix_size: int = 512):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained("imthanhlv/gpt2news")
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         self.clip_project = MLP(
+             (
+                 prefix_size,
+                 (self.gpt_embedding_size * prefix_length) // 2,
+                 self.gpt_embedding_size * prefix_length,
+             )
+         )
+
+
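+ # Prefix-only variant: parameters() exposes just the mapping network and train() keeps GPT-2 frozen in eval mode.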
+ class ClipCaptionPrefix(ClipCaptionModel):
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+
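+ # Decoding utilities: beam search (generate_beam) and top-p sampling (generate2) over GPT-2,
+ # starting from the projected prefix embeddings instead of a text prompt.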
+ # Caption prediction
+ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                   entry_length=67, temperature=1., stop_token: str = '.'):
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
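+ # Greedy decoding with nucleus (top-p) filtering; used when beam search is disabled.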
+ def generate2(
+     model,
+     tokenizer,
+     tokens=None,
+     prompt=None,
+     embed=None,
+     entry_count=1,
+     entry_length=67,  # maximum number of words
+     top_p=0.8,
+     temperature=1.,
+     stop_token: str = '.',
+ ):
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+     with torch.no_grad():
+         for entry_idx in trange(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+             for i in range(entry_length):
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+     return generated_list[0]
+
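+ # Load CLIP ViT-B/16 and the GPT-2 tokenizer once at startup; the Space runs on CPU (is_gpu = False).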
+ is_gpu = False
+ device = CUDA(0) if is_gpu else "cpu"
+ clip_model, preprocess = clip.load("ViT-B/16", device=device, jit=False)
+ tokenizer = GPT2Tokenizer.from_pretrained("imthanhlv/gpt2news")
+
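+ # Gradio callback: load the caption model, encode either the input image or the English sentence
+ # with CLIP, and decode a Vietnamese sentence from the projected prefix.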
+ def inference(img, text, is_translate):
+     prefix_length = 10
+     model = ClipCaptionModel(prefix_length)
+     model_path = 'sat_019.pt'
+     model.load_state_dict(torch.load(model_path, map_location=CPU))
+     model = model.eval()
+     device = CUDA(0) if is_gpu else "cpu"
+     model = model.to(device)
+     use_beam_search = True
+     if is_translate:
+         # encode text with the module-level clip.tokenize (the loaded model object has no .tokenize)
+         text = clip.tokenize([text]).to(device)
+         with torch.no_grad():
+             prefix = clip_model.encode_text(text).to(device, dtype=torch.float32)
+     else:
+         image = io.imread(img.name)
+         pil_image = PIL.Image.fromarray(image)
+         image = preprocess(pil_image).unsqueeze(0).to(device)
+         with torch.no_grad():
+             prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+
+     prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+     if use_beam_search:
+         generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+     else:
+         generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+     return generated_text_prefix
+
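+ # Gradio UI: an image input for captioning, plus a textbox and checkbox for translation mode.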
+ title = "CLIP Dual encoder"
+ description = "Translate an English sentence into Vietnamese or generate a Vietnamese caption from an image."
+ examples = [["drug.jpg", "", False], ["", "What is your name?", True]]
+
+ inputs = [
+     gr.inputs.Image(type="file", label="Image to generate Vietnamese caption"),
+     gr.inputs.Textbox(lines=2, placeholder="English sentence for translation"),
+     gr.inputs.Checkbox(label="Translate text instead of captioning the image"),
+ ]
+
+ gr.Interface(
+     inference,
+     inputs,
+     gr.outputs.Textbox(label="Vietnamese sentence"),
+     title=title,
+     description=description,
+     enable_queue=True,
+     examples=examples,
+ ).launch(debug=True)
drug.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers
+ gdown
+ torch
+ numpy
+ tqdm
+ Pillow
+ scikit-image
+ git+https://github.com/openai/CLIP.git