bczhou committed on
Commit
859131c
1 Parent(s): 8afcbeb

commit demo to space

Files changed (9)
  1. cat_with_food.png +0 -0
  2. config.py +24 -0
  3. demo.py +56 -0
  4. dog_with_frisbee.png +0 -0
  5. linear_mapping.py +278 -0
  6. main.py +116 -0
  7. pytorch_model.bin +3 -0
  8. stop_sign.png +0 -0
  9. two_bear.png +0 -0
cat_with_food.png ADDED
config.py ADDED
@@ -0,0 +1,24 @@
+ from dataclasses import dataclass
+
+ PREFIX_MAP = {
+     "openai/clip-vit-base-patch32": 50,
+     "openai/clip-vit-large-patch14": 257
+ }
+
+
+ @dataclass
+ class LinearMappingConfig:
+     image_model: str = "openai/clip-vit-base-patch32"
+     freeze_image_model: bool = True
+     text_model: str = "gpt2-large"
+     freeze_text_model: bool = True
+     image_hidden_size: int = 768
+     text_hidden_size: int = 1280
+     linear_mapping_type: str = "linear"
+     max_seq_length: int = 2048
+     image_resize: int = 224
+     add_image_token: bool = True
+     freeze_ln: bool = False
+
+     def __post_init__(self):
+         self.prefix_length = PREFIX_MAP[self.image_model]
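
A quick sanity check of how the config resolves the visual prefix length from the chosen CLIP checkpoint (a minimal sketch, assuming config.py is importable; the numbers follow from PREFIX_MAP above, i.e. 7x7 patches + 1 CLS token for ViT-B/32 and 16x16 + 1 for ViT-L/14):

    from config import LinearMappingConfig, PREFIX_MAP

    # Default image model: openai/clip-vit-base-patch32 -> 50 visual tokens.
    cfg = LinearMappingConfig()
    assert cfg.prefix_length == PREFIX_MAP["openai/clip-vit-base-patch32"] == 50

    # The large CLIP checkpoint yields 257 visual tokens; image_hidden_size
    # would also need to be set to that model's width (1024 for ViT-L/14).
    large_cfg = LinearMappingConfig(image_model="openai/clip-vit-large-patch14",
                                    image_hidden_size=1024)
    print(large_cfg.prefix_length)  # 257
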
demo.py ADDED
@@ -0,0 +1,56 @@
+ import gradio as gr
+ from linear_mapping import LinearMapping, LinearMappingConfig, LinearMappingProcessor
+ import os
+ import torch
+
+ os.environ['CURL_CA_BUNDLE'] = ''
+
+ config = LinearMappingConfig()
+ model = LinearMapping(config)
+ model.load_state_dict(torch.load("pytorch_model.bin"))
+ processor = LinearMappingProcessor(config)
+ processor.tokenizer.padding_side = 'left'
+ processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
+
+ title = "Generate Image Captions With CLIP And GPT2"
+
+
+ def generate_image_captions(image, text):
+     inputs = processor(images=image, texts=text, return_tensors="pt")
+     input_ids = inputs.get("input_ids", None)
+     pixel_values = inputs.get("pixel_values", None)
+     attention_mask = inputs.get("attention_mask", None)
+     prediction = model.generate(
+         pixel_values=pixel_values,
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         num_beams=5, max_new_tokens=50
+     )
+
+     prediction_text = processor.decode(prediction[0], skip_special_tokens=True)
+     return prediction_text
+
+
+ article = "This demo is based on the following paper: [original paper](https://arxiv.org/abs/2209.15162)"
+ description = """
+ ### Expand GPT2's language capabilities to vision with CLIP!
+ """
+ demo = gr.Interface(
+     fn=generate_image_captions,
+     inputs=[
+         gr.Image(),
+         gr.Textbox(placeholder="A picture of", lines=3)
+     ],
+     outputs="text",
+     examples=[
+         [os.path.join(os.getcwd(), 'two_bear.png'), ""],
+         [os.path.join(os.getcwd(), 'cat_with_food.png'), "Describe the picture:"],
+         [os.path.join(os.getcwd(), 'dog_with_frisbee.png'), "What is the color of the frisbee in the photo? Answer:"],
+         [os.path.join(os.getcwd(), 'stop_sign.png'), "What does the sign in the picture say? Answer:"]
+     ],
+     article=article,
+     title=title,
+     description=description
+ )
+
+ demo.launch(share=True)
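
Since the Gradio interface only wraps generate_image_captions, the same pipeline can be exercised without launching the UI, e.g. from a Python session in which the code above (minus demo.launch) has been run. A minimal sketch, reusing one of the bundled example images and prompts:

    from PIL import Image

    image = Image.open("dog_with_frisbee.png").convert("RGB")
    prompt = "What is the color of the frisbee in the photo? Answer:"
    print(generate_image_captions(image, prompt))
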
dog_with_frisbee.png ADDED
linear_mapping.py ADDED
@@ -0,0 +1,278 @@
+ from config import LinearMappingConfig
+ from transformers import (
+     GPT2TokenizerFast, GPT2LMHeadModel, AutoModel,
+     CLIPVisionModel, AutoProcessor, BatchEncoding,
+ )
+ from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModelOutput
+ import torch
+ import torch.nn as nn
+ from typing import List, Optional, Union, Tuple, Dict
+ from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
+ from torchvision.transforms.functional import InterpolationMode
+
+
+ class Transform(torch.nn.Module):
+     def __init__(self, image_size, mean, std):
+         super().__init__()
+         self.transforms = torch.nn.Sequential(
+             Resize([image_size], interpolation=InterpolationMode.BICUBIC, antialias=True),
+             CenterCrop(image_size),
+             ConvertImageDtype(torch.float32),
+             Normalize(mean, std),
+         )
+
+     def forward(self, x) -> torch.Tensor:
+         """`x` should be an instance of `PIL.Image.Image`"""
+         with torch.no_grad():
+             x = self.transforms(x)
+         return x
+
+
+ class LinearMappingProcessor:
+     """
+     A combination of ImageProcessor and GPT2TokenizerFast
+     """
+
+     def __init__(self, config: LinearMappingConfig):
+         self.image_processor = AutoProcessor.from_pretrained(config.image_model)
+         self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+         self.add_image_token = config.add_image_token
+         if config.add_image_token:
+             self.tokenizer.add_special_tokens({"cls_token": "|<image>|"})
+         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+         self.tokenizer.padding_side = "right"
+         self.prefix_length = config.prefix_length
+
+     def __call__(self, texts=None, images=None, return_tensors="pt", **kwargs):
+         """
+         The processor assumes that images and texts are equal in number.
+         """
+
+         if texts is not None and len(texts) == 0:  # empty strings should be None
+             texts = None
+
+         if images is not None:
+             image_features = self.image_processor(images=images, return_tensors=return_tensors, **kwargs)
+             image_features["attention_mask"] = torch.ones(image_features.pixel_values.size(0),
+                                                           self.prefix_length).to(dtype=torch.int64)
+             if texts is None and self.add_image_token:
+                 texts = [self.tokenizer.cls_token for _ in range(image_features.pixel_values.size(0))]
+             elif texts is not None and self.add_image_token:
+                 if isinstance(texts, str):
+                     texts = [texts]
+                 texts = [self.tokenizer.cls_token + text for text in texts]
+
+             elif texts is None:
+                 texts = self.tokenizer.bos_token
+
+         if texts is not None:
+             encoding = self.tokenizer(texts, return_tensors=return_tensors, **kwargs)
+
+         if texts is not None and images is not None:
+             encoding["pixel_values"] = image_features.pixel_values
+
+             encoding["attention_mask"] = torch.cat([
+                 image_features["attention_mask"],
+                 encoding["attention_mask"]
+             ], dim=1).to(dtype=torch.long)  # create attention mask for images
+             return encoding
+
+         elif texts is not None:
+             return encoding
+         else:
+             return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
+
+     def batch_decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to GPT2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+         refer to the docstring of this method for more information.
+         """
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to GPT2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+         the docstring of this method for more information.
+         """
+         return self.tokenizer.decode(*args, **kwargs)
+
+
+ class ImagePrefix(nn.Module):
+     """
+     Converts pixel values to prefix image prompts that are later fed to an LLM.
+     """
+
+     def __init__(self, config: LinearMappingConfig):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(config.image_model)
+         if "clip" in config.image_model:
+             self.encoder = CLIPVisionModel.from_pretrained(config.image_model)
+
+         if config.freeze_image_model:
+             for param in self.encoder.parameters():
+                 param.requires_grad = False
+
+         self.linear = nn.Linear(config.image_hidden_size, config.text_hidden_size)
+         self.ln = nn.LayerNorm(config.text_hidden_size)
+
+     def forward(
+             self, pixel_values: torch.Tensor  # B x C x H x W
+     ) -> torch.Tensor:
+         prefixes = self.encoder(pixel_values).last_hidden_state  # B x N x D
+         prefix_prompts = self.linear(prefixes)
+         return self.ln(prefix_prompts)
+
+
+ class LinearMapping(nn.Module):
+
+     def __init__(self, config: LinearMappingConfig):
+         super().__init__()
+         self.image_prefix = ImagePrefix(config)
+         self.language_model = GPT2LMHeadModel.from_pretrained(config.text_model)
+         self.processor = LinearMappingProcessor(config)
+         self.tokenizer = self.processor.tokenizer
+         self.image_processor = self.processor.image_processor
+         self.add_image_token = config.add_image_token
+         if config.add_image_token:
+             self.language_model.resize_token_embeddings(len(self.tokenizer))
+
+         if config.freeze_text_model:
+             for module in self.language_model.modules():
+                 if not isinstance(module, nn.LayerNorm) or config.freeze_ln:
+                     for param in module.parameters():
+                         param.requires_grad = False
+         if config.add_image_token:
+             # create a gradient mask for the lm_head weight and bias and hook it
+             self.language_model.lm_head.weight.requires_grad = True
+             self.weight_gradient_mask = nn.Parameter(torch.zeros_like(self.language_model.lm_head.weight),
+                                                      requires_grad=False)
+             self.weight_gradient_mask[-1, :] = 1.0
+             self.language_model.lm_head.weight.register_hook(lambda grad: grad.mul_(self.weight_gradient_mask))
+
+     def prepare_text_inputs(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.language_model.transformer.wte(input_ids.to(dtype=torch.int64))
+
+     def prepare_inputs(
+             self,
+             input_ids: Optional[torch.Tensor],
+             pixel_values: Optional[torch.Tensor]
+     ) -> Dict:
+         """
+         Prepare captions and pixel values for training.
+         It takes the captions' input ids and turns them into input embeddings,
+         and turns pixel values into prefix prompts.
+         Then it concatenates them into one whole prompt batch.
+         """
+         if input_ids is not None and pixel_values is not None:
+
+             text_embeddings = self.prepare_text_inputs(input_ids)  # B x T x D
+             prefix_prompts = self.image_prefix(pixel_values)  # B x V x D
+             inputs_embeddings = torch.cat([prefix_prompts, text_embeddings], dim=1)
+
+             prefix_labels = torch.zeros(prefix_prompts.shape[:2], device=prefix_prompts.device) - 100
+             labels = torch.cat([prefix_labels, input_ids], dim=1)  # B x (V + T)
+
+             for label in labels:
+                 for k, token in enumerate(label):
+                     if token == self.tokenizer.eos_token_id:
+                         label[k + 1:] = -100
+                         break
+             return {"hidden_states": inputs_embeddings, "labels": labels.to(dtype=torch.int64)}
+
+         elif pixel_values is not None:
+             prefix_prompts = self.image_prefix(pixel_values)  # B x V x D
+             prefix_labels = torch.zeros(prefix_prompts.shape[:2], device=prefix_prompts.device) - 100
+             return {"hidden_states": prefix_prompts, "labels": prefix_labels.to(dtype=torch.int64)}
+
+         elif input_ids is not None:
+             text_embeddings = self.prepare_text_inputs(input_ids)
+             labels = input_ids.clone()
+             for label in labels:
+                 for k, token in enumerate(label):
+                     if token == self.tokenizer.eos_token_id:
+                         label[k + 1:] = -100
+                         break
+             return {"hidden_states": text_embeddings, "labels": labels.to(dtype=torch.int64)}
+         else:
+             return {"hidden_states": None, "labels": None}
+
+     @torch.no_grad()
+     def generate(
+             self,
+             input_ids: Optional[torch.Tensor] = None,
+             pixel_values: Optional[torch.Tensor] = None,
+             **kwargs
+     ):
+         if pixel_values is None:
+             return self.language_model.generate(
+                 input_ids=input_ids,
+                 **kwargs
+             )
+         batch_size = pixel_values.size(0)
+         past_input_ids = None
+         if input_ids is None:
+             if self.add_image_token:
+                 input_ids = torch.tensor([self.tokenizer.cls_token_id for _ in range(batch_size)]).view(batch_size, -1)
+             else:
+                 input_ids = torch.tensor([self.tokenizer.bos_token_id for _ in range(batch_size)]).view(batch_size, -1)
+         if input_ids.size(-1) <= 1:
+             first_forward_outputs = self.forward(
+                 pixel_values=pixel_values
+             )
+         else:
+             first_forward_outputs = self.forward(
+                 pixel_values=pixel_values,
+                 input_ids=input_ids[:, :-1]
+             )
+             past_input_ids = input_ids[:, :-1]
+             input_ids = input_ids[:, -1].view(batch_size, -1)
+
+         past_key_values = first_forward_outputs.past_key_values
+
+         if kwargs.get("attention_mask", None) is None:
+             attention_mask_size = (past_key_values[0][0].size(0), past_key_values[0][0].size(-2))
+
+             attention_mask = torch.ones(attention_mask_size, dtype=torch.int64)
+         else:
+             attention_mask = kwargs.pop("attention_mask")
+
+         generated_token_ids = self.language_model.generate(
+             past_key_values=past_key_values,
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             **kwargs
+         )
+         if past_input_ids is not None:
+             generated_token_ids = torch.cat([past_input_ids, generated_token_ids], dim=-1)
+         return generated_token_ids
+
+     def forward(
+             self,
+             input_ids: Optional[torch.Tensor] = None,
+             pixel_values: Optional[torch.Tensor] = None,
+             labels: Optional[torch.Tensor] = None,
+             inputs_embeds: Optional[torch.Tensor] = None,
+             output_hidden_states: bool = True,
+             output_attentions: bool = True,
+             attention_mask: Optional[torch.Tensor] = None,
+             return_dict: Optional[bool] = True,
+             **kwargs
+     ) -> Union[GPT2DoubleHeadsModelOutput, Tuple]:
+         if (pixel_values is None and input_ids is None) and inputs_embeds is None:
+             raise ValueError("You have to specify inputs")
+         if inputs_embeds is not None and (pixel_values is not None or input_ids is not None):
+             raise ValueError("Either inputs_embeds or (pixel_values and input_ids) should be specified, not both")
+
+         inputs = self.prepare_inputs(input_ids, pixel_values)
+         hidden_states = inputs.get('hidden_states', None) if inputs_embeds is None else inputs_embeds
+         labels = inputs.get('labels', None) if labels is None else labels
+
+         return self.language_model(
+             inputs_embeds=hidden_states,
+             labels=labels,
+             output_hidden_states=output_hidden_states,
+             output_attentions=output_attentions,
+             attention_mask=attention_mask,
+             return_dict=return_dict,
+             **kwargs
+         )
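
For orientation, the tensor shapes with the default config: the CLIP ViT-B/32 encoder emits 50 visual tokens of width 768, ImagePrefix projects them to GPT2-large's hidden size of 1280, and prepare_inputs prepends them to the caption embeddings while masking their positions out of the loss with -100. A minimal shape check with random pixel values (a sketch; it assumes the pretrained CLIP and GPT2 checkpoints can be downloaded and loads no trained mapping weights):

    import torch
    from config import LinearMappingConfig
    from linear_mapping import LinearMapping

    model = LinearMapping(LinearMappingConfig()).eval()

    pixel_values = torch.randn(2, 3, 224, 224)    # B x C x H x W
    input_ids = torch.randint(0, 50257, (2, 7))   # B x T dummy caption tokens

    prefix = model.image_prefix(pixel_values)     # B x 50 x 1280
    inputs = model.prepare_inputs(input_ids, pixel_values)
    print(prefix.shape)                           # torch.Size([2, 50, 1280])
    print(inputs["hidden_states"].shape)          # torch.Size([2, 57, 1280])
    print(inputs["labels"].shape)                 # torch.Size([2, 57]); first 50 labels are -100
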
main.py ADDED
@@ -0,0 +1,116 @@
+ from datasets import load_dataset
+ from linear_mapping import LinearMapping, LinearMappingProcessor, LinearMappingConfig, Transform
+ import torch
+ from torchvision.io import ImageReadMode, read_image
+ from transformers import Trainer, TrainingArguments
+ import os
+ from PIL import Image
+ os.environ["WANDB_DISABLED"] = "true"
+
+ DATA_DIR = os.path.join(os.getcwd(), "coco")
+ CAPTION_COLUMN = "caption"
+ IMAGE_COLUMN = "image_path"
+
+
+ def main():
+     ds = load_dataset("ydshieh/coco_dataset_script", "2017", DATA_DIR)
+     config = LinearMappingConfig()
+     processor = LinearMappingProcessor(config)
+
+     def collate_fn(batch):
+         return {
+             'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
+             'input_ids': torch.tensor([x['input_ids'] for x in batch], dtype=torch.long),
+             'attention_mask': torch.stack([x["attention_mask"] for x in batch]),
+         }
+
+     def tokenize_fn(examples):
+         texts = list(examples[CAPTION_COLUMN])
+         if config.add_image_token:
+             texts = list(processor.tokenizer.cls_token + text for text in texts)
+         inputs = processor.tokenizer(
+             texts, padding="max_length", max_length=77,
+             return_tensors="pt", truncation=True
+         )
+         examples["input_ids"] = inputs.input_ids
+         examples["attention_mask"] = inputs.attention_mask
+         return examples
+
+     image_transformations = Transform(
+         config.image_resize,
+         [0.48145466, 0.4578275, 0.40821073],
+         [0.26862954, 0.26130258, 0.27577711]
+     )
+     image_transformations = torch.jit.script(image_transformations)
+
+     def transform_images(examples):
+         images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
+         examples["pixel_values"] = [image_transformations(image) for image in images]
+
+         examples["attention_mask"] = torch.cat([
+             torch.ones(len(images), config.prefix_length),
+             torch.tensor(examples["attention_mask"])
+         ], dim=1).to(dtype=torch.long)
+         return examples
+
+     def preprocess_fn(examples):
+
+         texts = list(examples[CAPTION_COLUMN])
+
+         images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
+         inputs = processor(
+             texts=texts, images=images, padding="max_length", truncation=True, max_length=77, return_tensors="pt"
+         )
+         return inputs
+
+     def filter_corrupt_images(examples):
+         """remove problematic images"""
+         valid_images = []
+         for image_file in examples[IMAGE_COLUMN]:
+             try:
+                 Image.open(image_file)
+                 valid_images.append(True)
+             except Exception:
+                 valid_images.append(False)
+         return valid_images
+
+     train_dataset = ds["train"]
+
+     train_dataset = train_dataset.filter(
+         function=filter_corrupt_images,
+         batched=True
+     )
+     train_dataset = train_dataset.map(
+         function=tokenize_fn,
+         batched=True,
+         remove_columns=[col for col in train_dataset.column_names if col != IMAGE_COLUMN and col != CAPTION_COLUMN],
+         load_from_cache_file=True
+     )
+     train_dataset.set_transform(transform_images)
+
+     training_args = TrainingArguments(
+         learning_rate=5e-4,
+         lr_scheduler_type='cosine',
+         output_dir='clip-gpt2-image-captioner',
+         do_train=True,
+         logging_steps=50,
+         num_train_epochs=5,
+         logging_dir='runs',
+         remove_unused_columns=False,
+         max_grad_norm=1.0,
+         per_device_train_batch_size=16,
+         save_total_limit=3,
+         warmup_steps=500
+     )
+     model = LinearMapping(config)
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         data_collator=collate_fn
+     )
+     trainer.train()
+
+
+ if __name__ == '__main__':
+     main()
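
main.py writes Trainer checkpoints under clip-gpt2-image-captioner/, while demo.py loads a flat pytorch_model.bin from the repository root (the LFS file added below). One way to produce that file, sketched here rather than taken from the committed script, is to save the full state dict once training finishes:

    # e.g. at the end of main(), after trainer.train()
    torch.save(model.state_dict(), "pytorch_model.bin")
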
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f817ef4696fa1ccb00cf19e71ed36660d9c52212fd1e953dbf52f923a7553ca0
+ size 3707484877
stop_sign.png ADDED
two_bear.png ADDED