zhangzhe45 committed on
Commit
50d1ff1
1 Parent(s): b1893ac
app.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ from urllib.parse import urlparse
+
+ from PIL import Image
+ import requests
+ import torch
+ from timm.models.hub import download_cached_file
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+ import gradio as gr
+ from mm_commerce import BLIP_Decoder
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ def is_url(url_or_filename):
+     parsed = urlparse(url_or_filename)
+     return parsed.scheme in ("http", "https")
+
+
+ def load_checkpoint(url_or_filename):
+     if is_url(url_or_filename):
+         cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+         checkpoint = torch.load(cached_file, map_location='cpu')
+     elif os.path.isfile(url_or_filename):
+         checkpoint = torch.load(url_or_filename, map_location='cpu')
+     else:
+         raise RuntimeError('checkpoint url or path is invalid')
+     return checkpoint
+
+
+ image_size = 224
+ transform = transforms.Compose([
+     transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+
+ model = BLIP_Decoder(med_config='configs/med_large_config.json', vit='large_v2', prompt='[DEC]')
+
+ ckpt = 'https://huggingface.co/zhezh/mm_commerce_zhcn/resolve/main/model.pth'
+ sd = load_checkpoint(ckpt)
+ model.load_state_dict(sd, strict=True)
+
+ model.eval()
+ model = model.to(device)  # use the device chosen above so inputs and weights always match (also works on CPU-only hosts)
+
+
+ def inference(raw_image, strategy):
+     image = transform(raw_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         if strategy == "Beam search":
+             caption = model.generate(image, sample=False, num_beams=10, max_length=100, min_length=10)
+         else:
+             caption = model.generate(image, sample=True, top_p=0.9, max_length=100, min_length=10)
+     # drop the leading/trailing special-token text and the spaces the tokenizer inserts between Chinese characters
+     return '商品描述: ' + '"' + ''.join(caption[0][6:-5].split()) + '"'
+
+
+ inputs = [
+     gr.inputs.Image(type='pil'),
+     gr.inputs.Radio(choices=['Beam search', 'Nucleus sampling'], type="value", default="Beam search", label="文本生成策略")
+ ]
+ outputs = gr.outputs.Textbox(label="生成的标题(Output)")
+
+ title = "MM Commerce ZhCN (中文商品描述生成)"
+
+ description = "中文商品描述生成 -- By Zhe Zhang"
+
+ demo = gr.Interface(
+     inference, inputs, outputs, title=title, description=description,
+     # article=article,
+     examples=[
+         ['starrynight.jpeg', "Nucleus sampling"],
+         ['resources/examples/zhuobu.jpg', "Beam search"],
+         ['resources/examples/jiandao.jpg', "Beam search"],
+         ['resources/examples/lego-yellow.jpg', "Beam search"],
+         ['resources/examples/charger.jpg', "Beam search"],
+         ['resources/examples/charger-ugreen.jpg', "Beam search"],
+         ['resources/examples/charger-hw.jpg', "Beam search"],
+     ],
+ )
+ # demo.launch(enable_queue=True, share=True, server_name='0.0.0.0', server_port=8080,)
+ demo.launch(enable_queue=True)
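
For quick local testing, the captioning pipeline above can also be exercised without launching the Gradio UI. A minimal sketch, assuming app.py has been imported so that `transform`, `model`, and `inference` are in scope, and using one of the bundled example images:

    from PIL import Image

    raw = Image.open('resources/examples/zhuobu.jpg').convert('RGB')
    print(inference(raw, 'Beam search'))       # deterministic caption via beam search
    print(inference(raw, 'Nucleus sampling'))  # stochastic caption via top-p sampling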
configs/med_large_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 21130,
+   "encoder_width": 768,
+   "add_cross_attention": true
+ }
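
The decoder reads this file through `BertConfig.from_json_file` in mm_commerce.py and then overwrites `encoder_width` with the width of the selected vision backbone (1024 for the `large_v2` CLIP encoder), so the 768 stored above is only a default. A minimal sketch of that flow, mirroring `BLIP_Decoder.__init__`:

    from models.med import BertConfig

    cfg = BertConfig.from_json_file('configs/med_large_config.json')
    cfg.encoder_width = 1024  # overridden at runtime with the vision encoder width, as in BLIP_Decoder.__init__
    print(cfg.num_hidden_layers, cfg.hidden_size)  # 24 1024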
mm_commerce.py ADDED
@@ -0,0 +1,146 @@
+ import os
+ import warnings
+
+ from PIL import Image
+ from torchvision import transforms
+ from torchvision.transforms import InterpolationMode
+
+ warnings.filterwarnings("ignore")
+
+ from models.vit import VisionTransformer, interpolate_pos_embed
+ from models.med import BertConfig, BertModel, BertLMHeadModel
+ from transformers import BertTokenizer, CLIPConfig
+ from models.modeling_clip import CLIPModel, CLIPVisionModel, CLIPVisionConfig
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class BLIP_Decoder(nn.Module):
+     def __init__(self,
+                  med_config='configs/med_config.json',
+                  image_size=384,
+                  vit='base',
+                  vit_grad_ckpt=False,
+                  vit_ckpt_layer=0,
+                  prompt='[DEC]',
+                  ):
+         super().__init__()
+         self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
+         self.tokenizer = init_tokenizer()
+         med_config = BertConfig.from_json_file(med_config)
+         med_config.encoder_width = vision_width
+         self.text_decoder = BertLMHeadModel(config=med_config)
+
+         self.prompt = prompt
+         self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
+
+     def forward(self, image, caption):
+
+         image_embeds = self.visual_encoder(image)
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+         text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
+
+         text.input_ids[:, 0] = self.tokenizer.bos_token_id
+
+         decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
+         decoder_targets[:, :self.prompt_length] = -100
+
+         decoder_output = self.text_decoder(text.input_ids,
+                                            attention_mask=text.attention_mask,
+                                            encoder_hidden_states=image_embeds,
+                                            encoder_attention_mask=image_atts,
+                                            labels=decoder_targets,
+                                            return_dict=True,
+                                            )
+         loss_lm = decoder_output.loss
+
+         return loss_lm
+
+     def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
+         image_embeds = self.visual_encoder(image)
+
+         if not sample:
+             image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
+
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+         model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}
+
+         prompt = [self.prompt] * image.size(0)
+         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
+         input_ids[:, 0] = self.tokenizer.bos_token_id
+         input_ids = input_ids[:, :-1]
+
+         if sample:
+             # nucleus sampling
+             outputs = self.text_decoder.generate(input_ids=input_ids,
+                                                  max_length=max_length,
+                                                  min_length=min_length,
+                                                  do_sample=True,
+                                                  top_p=top_p,
+                                                  num_return_sequences=1,
+                                                  eos_token_id=self.tokenizer.sep_token_id,
+                                                  pad_token_id=self.tokenizer.pad_token_id,
+                                                  repetition_penalty=1.1,
+                                                  **model_kwargs)
+         else:
+             # beam search
+             outputs = self.text_decoder.generate(input_ids=input_ids,
+                                                  max_length=max_length,
+                                                  min_length=min_length,
+                                                  num_beams=num_beams,
+                                                  eos_token_id=self.tokenizer.sep_token_id,
+                                                  pad_token_id=self.tokenizer.pad_token_id,
+                                                  repetition_penalty=repetition_penalty,
+                                                  **model_kwargs)
+
+         captions = []
+         for output in outputs:
+             caption = self.tokenizer.decode(output, skip_special_tokens=False)
+             captions.append(caption[len(self.prompt):])
+         return captions
+
+
+ def init_tokenizer():
+     # '[DEC]' doubles as the BOS/prompt token used by BLIP_Decoder.generate
+     tokenizer = BertTokenizer.from_pretrained('resources/bert-large-chinese', do_lower_case=True)
+     tokenizer.add_special_tokens({'bos_token': '[DEC]'})
+     tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
+     tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
+     return tokenizer
+
+
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
+     assert vit in ['base', 'large', 'large_v2'], "vit parameter must be base, large, or large_v2"
+     if vit == 'base':
+         vision_width = 768
+         visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
+                                            num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+                                            drop_path_rate=0 or drop_path_rate
+                                            )
+     elif vit == 'large':
+         vision_width = 1024
+         visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
+                                            num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+                                            drop_path_rate=0.1 or drop_path_rate  # note: `0.1 or x` always evaluates to 0.1 (kept from upstream BLIP)
+                                            )
+     elif vit == 'large_v2':
+         vision_width = 1024
+         clip_config = CLIPConfig.from_pretrained('resources/clip_vit_large_patch14')
+         visual_encoder = CLIPVisionModel(clip_config)
+     return visual_encoder, vision_width
+
+
+ def load_image(image, image_size, device):
+     raw_image = Image.open(str(image)).convert('RGB')
+
+     w, h = raw_image.size
+
+     transform = transforms.Compose([
+         transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+         transforms.ToTensor(),
+         transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+     ])
+     image = transform(raw_image).unsqueeze(0).to(device)
+     return image
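
`load_image` is not used by app.py, but together with `BLIP_Decoder.generate` it allows captioning directly from a file path. A minimal sketch, assuming the checkpoint from app.py has already been downloaded to a local `model.pth` and the `resources/` files are present (both are assumptions, not part of this commit):

    import torch
    from mm_commerce import BLIP_Decoder, load_image

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BLIP_Decoder(med_config='configs/med_large_config.json', vit='large_v2', prompt='[DEC]')
    model.load_state_dict(torch.load('model.pth', map_location='cpu'), strict=True)
    model.eval()
    model = model.to(device)

    image = load_image('resources/examples/zhuobu.jpg', image_size=224, device=device)
    with torch.no_grad():
        captions = model.generate(image, sample=False, num_beams=10, max_length=100, min_length=10)
    print(captions[0])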
models/__init__.py ADDED
File without changes
models/med.py ADDED
@@ -0,0 +1,955 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
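
For intuition on the decoder path in med.py: when `is_decoder=True`, `get_extended_attention_mask` multiplies the padding mask by a lower-triangular causal mask and turns masked positions into a large negative bias that is later added to the attention scores. A toy sketch of the same construction, independent of the model weights:

    import torch

    batch_size, seq_length = 1, 4
    attention_mask = torch.ones(batch_size, seq_length)
    seq_ids = torch.arange(seq_length)
    causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]).float()
    extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    extended = (1.0 - extended) * -10000.0  # 0.0 where attention is allowed, -10000.0 where it is masked
    print(extended[0, 0])  # lower-triangular pattern of 0 / -10000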
models/modeling_clip.py ADDED
@@ -0,0 +1,1054 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch CLIP model."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.file_utils import (
27
+ ModelOutput,
28
+ add_start_docstrings,
29
+ add_start_docstrings_to_model_forward,
30
+ replace_return_docstrings,
31
+ )
32
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import logging
35
+ from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
41
+
42
+ CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
+ "openai/clip-vit-base-patch32",
44
+ # See all CLIP models at https://huggingface.co/models?filter=clip
45
+ ]
46
+
47
+
48
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
49
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
50
+ """
51
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
52
+ """
53
+ bsz, src_len = mask.size()
54
+ tgt_len = tgt_len if tgt_len is not None else src_len
55
+
56
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
57
+
58
+ inverted_mask = 1.0 - expanded_mask
59
+
60
+ return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
61
+
62
+
63
+ # contrastive loss function, adapted from
64
+ # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
65
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
66
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
67
+
68
+
69
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
70
+ caption_loss = contrastive_loss(similarity)
71
+ image_loss = contrastive_loss(similarity.T)
72
+ return (caption_loss + image_loss) / 2.0
73
+
74
+
75
+ @dataclass
76
+ class CLIPOutput(ModelOutput):
77
+ """
78
+ Args:
79
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
80
+ Contrastive loss for image-text similarity.
81
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
82
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
83
+ similarity scores.
84
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
85
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
86
+ similarity scores.
87
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
88
+ The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
89
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
90
+ The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
91
+ text_model_output(`BaseModelOutputWithPooling`):
92
+ The output of the [`CLIPTextModel`].
93
+ vision_model_output(`BaseModelOutputWithPooling`):
94
+ The output of the [`CLIPVisionModel`].
95
+ """
96
+
97
+ loss: Optional[torch.FloatTensor] = None
98
+ logits_per_image: torch.FloatTensor = None
99
+ logits_per_text: torch.FloatTensor = None
100
+ text_embeds: torch.FloatTensor = None
101
+ image_embeds: torch.FloatTensor = None
102
+ text_model_output: BaseModelOutputWithPooling = None
103
+ vision_model_output: BaseModelOutputWithPooling = None
104
+
105
+ def to_tuple(self) -> Tuple[Any]:
106
+ return tuple(
107
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
108
+ for k in self.keys()
109
+ )
110
+
111
+
112
+ class CLIPVisionEmbeddings(nn.Module):
113
+ def __init__(self, config: CLIPVisionConfig):
114
+ super().__init__()
115
+ self.config = config
116
+ self.embed_dim = config.hidden_size
117
+ self.image_size = config.image_size
118
+ self.patch_size = config.patch_size
119
+
120
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
121
+
122
+ self.patch_embedding = nn.Conv2d(
123
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
124
+ )
125
+
126
+ self.num_patches = (self.image_size // self.patch_size) ** 2
127
+ self.num_positions = self.num_patches + 1
128
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
129
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
130
+
131
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
132
+ batch_size = pixel_values.shape[0]
133
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
134
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
135
+
136
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
137
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
138
+ embeddings = embeddings + self.position_embedding(self.position_ids)
139
+ return embeddings
140
+
141
+
142
+ class CLIPTextEmbeddings(nn.Module):
143
+ def __init__(self, config: CLIPTextConfig):
144
+ super().__init__()
145
+ embed_dim = config.hidden_size
146
+
147
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
148
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
149
+
150
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
151
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
152
+
153
+ def forward(
154
+ self,
155
+ input_ids: Optional[torch.LongTensor] = None,
156
+ position_ids: Optional[torch.LongTensor] = None,
157
+ inputs_embeds: Optional[torch.FloatTensor] = None,
158
+ ) -> torch.Tensor:
159
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
160
+
161
+ if position_ids is None:
162
+ position_ids = self.position_ids[:, :seq_length]
163
+
164
+ if inputs_embeds is None:
165
+ inputs_embeds = self.token_embedding(input_ids)
166
+
167
+ position_embeddings = self.position_embedding(position_ids)
168
+ embeddings = inputs_embeds + position_embeddings
169
+
170
+ return embeddings
171
+
172
+
173
+ class CLIPAttention(nn.Module):
174
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
175
+
176
+ def __init__(self, config):
177
+ super().__init__()
178
+ self.config = config
179
+ self.embed_dim = config.hidden_size
180
+ self.num_heads = config.num_attention_heads
181
+ self.head_dim = self.embed_dim // self.num_heads
182
+ assert (
183
+ self.head_dim * self.num_heads == self.embed_dim
184
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
185
+ self.scale = self.head_dim**-0.5
186
+ self.dropout = config.attention_dropout
187
+
188
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
189
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
190
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
191
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
192
+
193
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
194
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
195
+
196
+ def forward(
197
+ self,
198
+ hidden_states: torch.Tensor,
199
+ attention_mask: Optional[torch.Tensor] = None,
200
+ causal_attention_mask: Optional[torch.Tensor] = None,
201
+ output_attentions: Optional[bool] = False,
202
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
203
+ """Input shape: Batch x Time x Channel"""
204
+
205
+ bsz, tgt_len, embed_dim = hidden_states.size()
206
+
207
+ # get query proj
208
+ query_states = self.q_proj(hidden_states) * self.scale
209
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
210
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
211
+
212
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
213
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
214
+ key_states = key_states.view(*proj_shape)
215
+ value_states = value_states.view(*proj_shape)
216
+
217
+ src_len = key_states.size(1)
218
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
219
+
220
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
221
+ raise ValueError(
222
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
223
+ )
224
+
225
+ # apply the causal_attention_mask first
226
+ if causal_attention_mask is not None:
227
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
228
+ raise ValueError(
229
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {causal_attention_mask.size()}"
230
+ )
231
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
232
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
233
+
234
+ if attention_mask is not None:
235
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
236
+ raise ValueError(
237
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
238
+ )
239
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
240
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
241
+
242
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
243
+
244
+ if output_attentions:
245
+ # this operation is a bit akward, but it's required to
246
+ # make sure that attn_weights keeps its gradient.
247
+ # In order to do so, attn_weights have to reshaped
248
+ # twice and have to be reused in the following
249
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
250
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
251
+ else:
252
+ attn_weights_reshaped = None
253
+
254
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
255
+
256
+ attn_output = torch.bmm(attn_probs, value_states)
257
+
258
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
259
+ raise ValueError(
260
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
261
+ )
262
+
263
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
264
+ attn_output = attn_output.transpose(1, 2)
265
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
266
+
267
+ attn_output = self.out_proj(attn_output)
268
+
269
+ return attn_output, attn_weights_reshaped
270
+
271
+
272
+ class CLIPMLP(nn.Module):
273
+ def __init__(self, config):
274
+ super().__init__()
275
+ self.config = config
276
+ self.activation_fn = ACT2FN[config.hidden_act]
277
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
278
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
279
+
280
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
281
+ hidden_states = self.fc1(hidden_states)
282
+ hidden_states = self.activation_fn(hidden_states)
283
+ hidden_states = self.fc2(hidden_states)
284
+ return hidden_states
285
+
286
+
287
+ class CLIPEncoderLayer(nn.Module):
288
+ def __init__(self, config: CLIPConfig):
289
+ super().__init__()
290
+ self.embed_dim = config.hidden_size
291
+ self.self_attn = CLIPAttention(config)
292
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim)
293
+ self.mlp = CLIPMLP(config)
294
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim)
295
+
296
+ def forward(
297
+ self,
298
+ hidden_states: torch.Tensor,
299
+ attention_mask: torch.Tensor,
300
+ causal_attention_mask: torch.Tensor,
301
+ output_attentions: Optional[bool] = False,
302
+ ) -> Tuple[torch.FloatTensor]:
303
+ """
304
+ Args:
305
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
306
+ attention_mask (`torch.FloatTensor`): attention mask of size
307
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
308
+ `(config.encoder_attention_heads,)`.
309
+ output_attentions (`bool`, *optional*):
310
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
311
+ returned tensors for more detail.
312
+ """
313
+ residual = hidden_states
314
+
315
+ hidden_states = self.layer_norm1(hidden_states)
316
+ hidden_states, attn_weights = self.self_attn(
317
+ hidden_states=hidden_states,
318
+ attention_mask=attention_mask,
319
+ causal_attention_mask=causal_attention_mask,
320
+ output_attentions=output_attentions,
321
+ )
322
+ hidden_states = residual + hidden_states
323
+
324
+ residual = hidden_states
325
+ hidden_states = self.layer_norm2(hidden_states)
326
+ hidden_states = self.mlp(hidden_states)
327
+ hidden_states = residual + hidden_states
328
+
329
+ outputs = (hidden_states,)
330
+
331
+ if output_attentions:
332
+ outputs += (attn_weights,)
333
+
334
+ return outputs
335
+
336
+
337
+ class CLIPPreTrainedModel(PreTrainedModel):
338
+ """
339
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
340
+ models.
341
+ """
342
+
343
+ config_class = CLIPConfig
344
+ base_model_prefix = "clip"
345
+ supports_gradient_checkpointing = True
346
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
347
+
348
+ def _init_weights(self, module):
349
+ """Initialize the weights"""
350
+ factor = self.config.initializer_factor
351
+ if isinstance(module, CLIPTextEmbeddings):
352
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
353
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
354
+ elif isinstance(module, CLIPVisionEmbeddings):
355
+ factor = self.config.initializer_factor
356
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
357
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
358
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
359
+ elif isinstance(module, CLIPAttention):
360
+ factor = self.config.initializer_factor
361
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
362
+ out_proj_std = (module.embed_dim**-0.5) * factor
363
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
364
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
365
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
366
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
367
+ elif isinstance(module, CLIPMLP):
368
+ factor = self.config.initializer_factor
369
+ in_proj_std = (
370
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
371
+ )
372
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
373
+ nn.init.normal_(module.fc1.weight, std=fc_std)
374
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
375
+ elif isinstance(module, CLIPModel):
376
+ nn.init.normal_(
377
+ module.text_projection.weight,
378
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
379
+ )
380
+ nn.init.normal_(
381
+ module.visual_projection.weight,
382
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
383
+ )
384
+
385
+ if isinstance(module, nn.LayerNorm):
386
+ module.bias.data.zero_()
387
+ module.weight.data.fill_(1.0)
388
+ if isinstance(module, nn.Linear) and module.bias is not None:
389
+ module.bias.data.zero_()
390
+
391
+ def _set_gradient_checkpointing(self, module, value=False):
392
+ if isinstance(module, CLIPEncoder):
393
+ module.gradient_checkpointing = value
394
+
395
+
396
+ CLIP_START_DOCSTRING = r"""
397
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
398
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
399
+ behavior.
400
+
401
+ Parameters:
402
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
403
+ Initializing with a config file does not load the weights associated with the model, only the
404
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
405
+ """
406
+
407
+ CLIP_TEXT_INPUTS_DOCSTRING = r"""
408
+ Args:
409
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
410
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
411
+ it.
412
+
413
+ Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
414
+ [`PreTrainedTokenizer.__call__`] for details.
415
+
416
+ [What are input IDs?](../glossary#input-ids)
417
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
418
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
419
+
420
+ - 1 for tokens that are **not masked**,
421
+ - 0 for tokens that are **masked**.
422
+
423
+ [What are attention masks?](../glossary#attention-mask)
424
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
425
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
426
+ config.max_position_embeddings - 1]`.
427
+
428
+ [What are position IDs?](../glossary#position-ids)
429
+ output_attentions (`bool`, *optional*):
430
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
431
+ tensors for more detail.
432
+ output_hidden_states (`bool`, *optional*):
433
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
434
+ more detail.
435
+ return_dict (`bool`, *optional*):
436
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
437
+ """
438
+
439
+ CLIP_VISION_INPUTS_DOCSTRING = r"""
440
+ Args:
441
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
442
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
443
+ [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
444
+ output_attentions (`bool`, *optional*):
445
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
446
+ tensors for more detail.
447
+ output_hidden_states (`bool`, *optional*):
448
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
449
+ more detail.
450
+ return_dict (`bool`, *optional*):
451
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
452
+ """
453
+
454
+ CLIP_INPUTS_DOCSTRING = r"""
455
+ Args:
456
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
457
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
458
+ it.
459
+
460
+ Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
461
+ [`PreTrainedTokenizer.__call__`] for details.
462
+
463
+ [What are input IDs?](../glossary#input-ids)
464
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
465
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
466
+
467
+ - 1 for tokens that are **not masked**,
468
+ - 0 for tokens that are **masked**.
469
+
470
+ [What are attention masks?](../glossary#attention-mask)
471
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
472
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
473
+ config.max_position_embeddings - 1]`.
474
+
475
+ [What are position IDs?](../glossary#position-ids)
476
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
477
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
478
+ [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
479
+ return_loss (`bool`, *optional*):
480
+ Whether or not to return the contrastive loss.
481
+ output_attentions (`bool`, *optional*):
482
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
483
+ tensors for more detail.
484
+ output_hidden_states (`bool`, *optional*):
485
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
486
+ more detail.
487
+ return_dict (`bool`, *optional*):
488
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
489
+ """
490
+
491
+
492
+ class CLIPEncoder(nn.Module):
493
+ """
494
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
495
+ [`CLIPEncoderLayer`].
496
+
497
+ Args:
498
+ config: CLIPConfig
499
+ """
500
+
501
+ def __init__(self, config: CLIPConfig):
502
+ super().__init__()
503
+ self.config = config
504
+ self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
505
+ self.gradient_checkpointing = False
506
+
507
+ def forward(
508
+ self,
509
+ inputs_embeds,
510
+ attention_mask: Optional[torch.Tensor] = None,
511
+ causal_attention_mask: Optional[torch.Tensor] = None,
512
+ output_attentions: Optional[bool] = None,
513
+ output_hidden_states: Optional[bool] = None,
514
+ return_dict: Optional[bool] = None,
515
+ ) -> Union[Tuple, BaseModelOutput]:
516
+ r"""
517
+ Args:
518
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
519
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
520
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
521
+ than the model's internal embedding lookup matrix.
522
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
523
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
524
+
525
+ - 1 for tokens that are **not masked**,
526
+ - 0 for tokens that are **masked**.
527
+
528
+ [What are attention masks?](../glossary#attention-mask)
529
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
530
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
531
+
532
+ - 1 for tokens that are **not masked**,
533
+ - 0 for tokens that are **masked**.
534
+
535
+ [What are attention masks?](../glossary#attention-mask)
536
+ output_attentions (`bool`, *optional*):
537
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
538
+ returned tensors for more detail.
539
+ output_hidden_states (`bool`, *optional*):
540
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
541
+ for more detail.
542
+ return_dict (`bool`, *optional*):
543
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
544
+ """
545
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
546
+ output_hidden_states = (
547
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
548
+ )
549
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
550
+
551
+ encoder_states = () if output_hidden_states else None
552
+ all_attentions = () if output_attentions else None
553
+
554
+ hidden_states = inputs_embeds
555
+ for idx, encoder_layer in enumerate(self.layers):
556
+ if output_hidden_states:
557
+ encoder_states = encoder_states + (hidden_states,)
558
+ if self.gradient_checkpointing and self.training:
559
+
560
+ def create_custom_forward(module):
561
+ def custom_forward(*inputs):
562
+ return module(*inputs, output_attentions)
563
+
564
+ return custom_forward
565
+
566
+ layer_outputs = torch.utils.checkpoint.checkpoint(
567
+ create_custom_forward(encoder_layer),
568
+ hidden_states,
569
+ attention_mask,
570
+ causal_attention_mask,
571
+ )
572
+ else:
573
+ layer_outputs = encoder_layer(
574
+ hidden_states,
575
+ attention_mask,
576
+ causal_attention_mask,
577
+ output_attentions=output_attentions,
578
+ )
579
+
580
+ hidden_states = layer_outputs[0]
581
+
582
+ if output_attentions:
583
+ all_attentions = all_attentions + (layer_outputs[1],)
584
+
585
+ if output_hidden_states:
586
+ encoder_states = encoder_states + (hidden_states,)
587
+
588
+ if not return_dict:
589
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
590
+ return BaseModelOutput(
591
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
592
+ )
593
+
594
+
595
+ class CLIPTextTransformer(nn.Module):
596
+ def __init__(self, config: CLIPTextConfig):
597
+ super().__init__()
598
+ self.config = config
599
+ embed_dim = config.hidden_size
600
+ self.embeddings = CLIPTextEmbeddings(config)
601
+ self.encoder = CLIPEncoder(config)
602
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
603
+
604
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
605
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
606
+ def forward(
607
+ self,
608
+ input_ids: Optional[torch.Tensor] = None,
609
+ attention_mask: Optional[torch.Tensor] = None,
610
+ position_ids: Optional[torch.Tensor] = None,
611
+ output_attentions: Optional[bool] = None,
612
+ output_hidden_states: Optional[bool] = None,
613
+ return_dict: Optional[bool] = None,
614
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
615
+ r"""
616
+ Returns:
617
+
618
+ """
619
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
620
+ output_hidden_states = (
621
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
622
+ )
623
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
624
+
625
+ if input_ids is None:
626
+ raise ValueError("You have to specify either input_ids")
627
+
628
+ input_shape = input_ids.size()
629
+ input_ids = input_ids.view(-1, input_shape[-1])
630
+
631
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
632
+
633
+ bsz, seq_len = input_shape
634
+ # CLIP's text model uses causal mask, prepare it here.
635
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
636
+ causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
637
+ # expand attention_mask
638
+ if attention_mask is not None:
639
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
640
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
641
+
642
+ encoder_outputs = self.encoder(
643
+ inputs_embeds=hidden_states,
644
+ attention_mask=attention_mask,
645
+ causal_attention_mask=causal_attention_mask,
646
+ output_attentions=output_attentions,
647
+ output_hidden_states=output_hidden_states,
648
+ return_dict=return_dict,
649
+ )
650
+
651
+ last_hidden_state = encoder_outputs[0]
652
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
653
+
654
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
655
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
656
+ pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
657
+
658
+ if not return_dict:
659
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
660
+
661
+ return BaseModelOutputWithPooling(
662
+ last_hidden_state=last_hidden_state,
663
+ pooler_output=pooled_output,
664
+ hidden_states=encoder_outputs.hidden_states,
665
+ attentions=encoder_outputs.attentions,
666
+ )
667
+
668
+ def _build_causal_attention_mask(self, bsz, seq_len):
669
+ # lazily create causal attention mask, with full attention between the vision tokens
670
+ # pytorch uses additive attention mask; fill with -inf
671
+ mask = torch.empty(bsz, seq_len, seq_len)
672
+ mask.fill_(float("-inf"))
673
+ mask.triu_(1) # zero out the lower diagonal
674
+ mask = mask.unsqueeze(1) # expand mask
675
+ return mask
676
+
677
+
678
+ class CLIPTextModel(CLIPPreTrainedModel):
679
+ config_class = CLIPTextConfig
680
+
681
+ def __init__(self, config: CLIPTextConfig):
682
+ super().__init__(config)
683
+ self.text_model = CLIPTextTransformer(config)
684
+ # Initialize weights and apply final processing
685
+ self.post_init()
686
+
687
+ def get_input_embeddings(self) -> nn.Module:
688
+ return self.text_model.embeddings.token_embedding
689
+
690
+ def set_input_embeddings(self, value):
691
+ self.text_model.embeddings.token_embedding = value
692
+
693
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
694
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
695
+ def forward(
696
+ self,
697
+ input_ids: Optional[torch.Tensor] = None,
698
+ attention_mask: Optional[torch.Tensor] = None,
699
+ position_ids: Optional[torch.Tensor] = None,
700
+ output_attentions: Optional[bool] = None,
701
+ output_hidden_states: Optional[bool] = None,
702
+ return_dict: Optional[bool] = None,
703
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
704
+ r"""
705
+ Returns:
706
+
707
+ Examples:
708
+
709
+ ```python
710
+ >>> from transformers import CLIPTokenizer, CLIPTextModel
711
+
712
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
713
+ >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
714
+
715
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
716
+
717
+ >>> outputs = model(**inputs)
718
+ >>> last_hidden_state = outputs.last_hidden_state
719
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
720
+ ```"""
721
+ return self.text_model(
722
+ input_ids=input_ids,
723
+ attention_mask=attention_mask,
724
+ position_ids=position_ids,
725
+ output_attentions=output_attentions,
726
+ output_hidden_states=output_hidden_states,
727
+ return_dict=return_dict,
728
+ )
729
+
730
+
731
+ class CLIPVisionTransformer(nn.Module):
732
+ def __init__(self, config: CLIPVisionConfig):
733
+ super().__init__()
734
+ self.config = config
735
+ embed_dim = config.hidden_size
736
+
737
+ self.embeddings = CLIPVisionEmbeddings(config)
738
+ self.pre_layrnorm = nn.LayerNorm(embed_dim)
739
+ self.encoder = CLIPEncoder(config)
740
+ self.post_layernorm = nn.LayerNorm(embed_dim)
741
+
742
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
743
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
744
+ def forward(
745
+ self,
746
+ pixel_values: Optional[torch.FloatTensor] = None,
747
+ output_attentions: Optional[bool] = None,
748
+ output_hidden_states: Optional[bool] = None,
749
+ return_dict: Optional[bool] = None,
750
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
751
+ r"""
752
+ Returns:
753
+
754
+ """
755
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
756
+ output_hidden_states = (
757
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
758
+ )
759
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
760
+
761
+ if pixel_values is None:
762
+ raise ValueError("You have to specify pixel_values")
763
+
764
+ hidden_states = self.embeddings(pixel_values)
765
+ hidden_states = self.pre_layrnorm(hidden_states)
766
+
767
+ encoder_outputs = self.encoder(
768
+ inputs_embeds=hidden_states,
769
+ output_attentions=output_attentions,
770
+ output_hidden_states=output_hidden_states,
771
+ return_dict=return_dict,
772
+ )
773
+
774
+ last_hidden_state = encoder_outputs[0]
775
+ pooled_output = last_hidden_state[:, 0, :]
776
+ pooled_output = self.post_layernorm(pooled_output)
777
+
778
+ if not return_dict:
779
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
780
+
781
+ return BaseModelOutputWithPooling(
782
+ last_hidden_state=last_hidden_state,
783
+ pooler_output=pooled_output,
784
+ hidden_states=encoder_outputs.hidden_states,
785
+ attentions=encoder_outputs.attentions,
786
+ )
787
+
788
+
789
+ class CLIPVisionModel(CLIPPreTrainedModel):
790
+ config_class = CLIPVisionConfig
791
+ main_input_name = "pixel_values"
792
+
793
+ def __init__(self, config: CLIPVisionConfig):
794
+ super().__init__(config)
795
+ self.vision_model = CLIPVisionTransformer(config)
796
+ # Initialize weights and apply final processing
797
+ self.post_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
798
+ self.post_init()
799
+
800
+ def get_input_embeddings(self) -> nn.Module:
801
+ return self.vision_model.embeddings.patch_embedding
802
+
803
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
804
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
805
+ def forward(
806
+ self,
807
+ pixel_values: Optional[torch.FloatTensor] = None,
808
+ output_attentions: Optional[bool] = None,
809
+ output_hidden_states: Optional[bool] = None,
810
+ return_dict: Optional[bool] = None,
811
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
812
+ r"""
813
+ Returns:
814
+
815
+ Examples:
816
+
817
+ ```python
818
+ >>> from PIL import Image
819
+ >>> import requests
820
+ >>> from transformers import CLIPProcessor, CLIPVisionModel
821
+
822
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
823
+ >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
824
+
825
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
826
+ >>> image = Image.open(requests.get(url, stream=True).raw)
827
+
828
+ >>> inputs = processor(images=image, return_tensors="pt")
829
+
830
+ >>> outputs = model(**inputs)
831
+ >>> last_hidden_state = outputs.last_hidden_state
832
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
833
+ ```"""
834
+ # return self.vision_model(
835
+ # pixel_values=pixel_values,
836
+ # output_attentions=output_attentions,
837
+ # output_hidden_states=output_hidden_states,
838
+ # return_dict=return_dict,
839
+ # )
840
+
841
+ result = self.vision_model(
842
+ pixel_values=pixel_values,
843
+ output_attentions=output_attentions,
844
+ output_hidden_states=output_hidden_states,
845
+ return_dict=return_dict,
846
+ )
847
+ result = result.last_hidden_state
848
+ # apply the extra post-norm over the full token sequence before returning it
849
+ result = self.post_norm(result)
850
+ return result
851
+
852
+
853
+ @add_start_docstrings(CLIP_START_DOCSTRING)
854
+ class CLIPModel(CLIPPreTrainedModel):
855
+ config_class = CLIPConfig
856
+
857
+ def __init__(self, config: CLIPConfig):
858
+ super().__init__(config)
859
+
860
+ if not isinstance(config.text_config, CLIPTextConfig):
861
+ raise ValueError(
862
+ f"config.text_config is expected to be of type CLIPTextConfig but is of type {type(config.text_config)}."
863
+ )
864
+
865
+ if not isinstance(config.vision_config, CLIPVisionConfig):
866
+ raise ValueError(
867
+ f"config.vision_config is expected to be of type CLIPVisionConfig but is of type {type(config.vision_config)}."
868
+ )
869
+
870
+ text_config = config.text_config
871
+ vision_config = config.vision_config
872
+
873
+ self.projection_dim = config.projection_dim
874
+ self.text_embed_dim = text_config.hidden_size
875
+ self.vision_embed_dim = vision_config.hidden_size
876
+
877
+ self.text_model = CLIPTextTransformer(text_config)
878
+ self.vision_model = CLIPVisionTransformer(vision_config)
879
+
880
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
881
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
882
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
883
+
884
+ # Initialize weights and apply final processing
885
+ self.post_init()
886
+
887
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
888
+ def get_text_features(
889
+ self,
890
+ input_ids: Optional[torch.Tensor] = None,
891
+ attention_mask: Optional[torch.Tensor] = None,
892
+ position_ids: Optional[torch.Tensor] = None,
893
+ output_attentions: Optional[bool] = None,
894
+ output_hidden_states: Optional[bool] = None,
895
+ return_dict: Optional[bool] = None,
896
+ ) -> torch.FloatTensor:
897
+ r"""
898
+ Returns:
899
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
900
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
901
+
902
+ Examples:
903
+
904
+ ```python
905
+ >>> from transformers import CLIPTokenizer, CLIPModel
906
+
907
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
908
+ >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
909
+
910
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
911
+ >>> text_features = model.get_text_features(**inputs)
912
+ ```"""
913
+ text_outputs = self.text_model(
914
+ input_ids=input_ids,
915
+ attention_mask=attention_mask,
916
+ position_ids=position_ids,
917
+ output_attentions=output_attentions,
918
+ output_hidden_states=output_hidden_states,
919
+ return_dict=return_dict,
920
+ )
921
+
922
+ pooled_output = text_outputs[1]
923
+ text_features = self.text_projection(pooled_output)
924
+
925
+ return text_features
926
+
927
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
928
+ def get_image_features(
929
+ self,
930
+ pixel_values: Optional[torch.FloatTensor] = None,
931
+ output_attentions: Optional[bool] = None,
932
+ output_hidden_states: Optional[bool] = None,
933
+ return_dict: Optional[bool] = None,
934
+ ) -> torch.FloatTensor:
935
+ r"""
936
+ Returns:
937
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
938
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
939
+
940
+ Examples:
941
+
942
+ ```python
943
+ >>> from PIL import Image
944
+ >>> import requests
945
+ >>> from transformers import CLIPProcessor, CLIPModel
946
+
947
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
948
+ >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
949
+
950
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
951
+ >>> image = Image.open(requests.get(url, stream=True).raw)
952
+
953
+ >>> inputs = processor(images=image, return_tensors="pt")
954
+
955
+ >>> image_features = model.get_image_features(**inputs)
956
+ ```"""
957
+ vision_outputs = self.vision_model(
958
+ pixel_values=pixel_values,
959
+ output_attentions=output_attentions,
960
+ output_hidden_states=output_hidden_states,
961
+ return_dict=return_dict,
962
+ )
963
+
964
+ pooled_output = vision_outputs[1] # pooled_output
965
+ image_features = self.visual_projection(pooled_output)
966
+
967
+ return image_features
968
+
969
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
970
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
971
+ def forward(
972
+ self,
973
+ input_ids: Optional[torch.LongTensor] = None,
974
+ pixel_values: Optional[torch.FloatTensor] = None,
975
+ attention_mask: Optional[torch.Tensor] = None,
976
+ position_ids: Optional[torch.LongTensor] = None,
977
+ return_loss: Optional[bool] = None,
978
+ output_attentions: Optional[bool] = None,
979
+ output_hidden_states: Optional[bool] = None,
980
+ return_dict: Optional[bool] = None,
981
+ ) -> Union[Tuple, CLIPOutput]:
982
+ r"""
983
+ Returns:
984
+
985
+ Examples:
986
+
987
+ ```python
988
+ >>> from PIL import Image
989
+ >>> import requests
990
+ >>> from transformers import CLIPProcessor, CLIPModel
991
+
992
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
993
+ >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
994
+
995
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
996
+ >>> image = Image.open(requests.get(url, stream=True).raw)
997
+
998
+ >>> inputs = processor(
999
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1000
+ ... )
1001
+
1002
+ >>> outputs = model(**inputs)
1003
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1004
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1005
+ ```"""
1006
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1007
+ vision_outputs = self.vision_model(
1008
+ pixel_values=pixel_values,
1009
+ output_attentions=output_attentions,
1010
+ output_hidden_states=output_hidden_states,
1011
+ return_dict=return_dict,
1012
+ )
1013
+
1014
+ text_outputs = self.text_model(
1015
+ input_ids=input_ids,
1016
+ attention_mask=attention_mask,
1017
+ position_ids=position_ids,
1018
+ output_attentions=output_attentions,
1019
+ output_hidden_states=output_hidden_states,
1020
+ return_dict=return_dict,
1021
+ )
1022
+
1023
+ image_embeds = vision_outputs[1]
1024
+ image_embeds = self.visual_projection(image_embeds)
1025
+
1026
+ text_embeds = text_outputs[1]
1027
+ text_embeds = self.text_projection(text_embeds)
1028
+
1029
+ # normalized features
1030
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
1031
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
1032
+
1033
+ # cosine similarity as logits
1034
+ logit_scale = self.logit_scale.exp()
1035
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1036
+ logits_per_image = logits_per_text.T
1037
+
1038
+ loss = None
1039
+ if return_loss:
1040
+ loss = clip_loss(logits_per_text)
1041
+
1042
+ if not return_dict:
1043
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1044
+ return ((loss,) + output) if loss is not None else output
1045
+
1046
+ return CLIPOutput(
1047
+ loss=loss,
1048
+ logits_per_image=logits_per_image,
1049
+ logits_per_text=logits_per_text,
1050
+ text_embeds=text_embeds,
1051
+ image_embeds=image_embeds,
1052
+ text_model_output=text_outputs,
1053
+ vision_model_output=vision_outputs,
1054
+ )
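
Note that the customised `CLIPVisionModel.forward` above returns the post-normalised token sequence directly rather than a `BaseModelOutputWithPooling`, so callers receive per-patch features plus the CLS token. A minimal, hedged sketch of how this backbone could be exercised on its own (the `models.clip_model` module path is an assumption about where this file lives; the config directory is the one added under `resources/` in this commit, and the model here is randomly initialised rather than loaded from the real checkpoint):

```python
# Hedged sketch: run the patched CLIPVisionModel and inspect the token sequence it returns.
import torch
from transformers import CLIPVisionConfig

from models.clip_model import CLIPVisionModel  # module path assumed; use this repo's copy, not transformers'

config = CLIPVisionConfig.from_pretrained("resources/clip_vit_large_patch14")
vision_model = CLIPVisionModel(config).eval()  # randomly initialised here; real weights come from the model checkpoint

pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
with torch.no_grad():
    tokens = vision_model(pixel_values)  # post-normed hidden states, not a pooled embedding

print(tokens.shape)  # torch.Size([1, 257, 1024]) for ViT-L/14 at 224 px (256 patches + CLS)
```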
models/vit.py ADDED
@@ -0,0 +1,305 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed, resize_pos_embed  # resize_pos_embed is used by _load_weights below
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
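
`interpolate_pos_embed` is what allows a ViT checkpoint trained at one input resolution to be loaded into a `VisionTransformer` built for another: it bicubically resizes the grid of patch-position embeddings while leaving the class-token row untouched. A hedged sketch of the usual call pattern (the 224→384 resolution change and the bare stand-in state dict are illustrative assumptions, not the repo's actual checkpoint):

```python
# Hedged sketch: resize a 224 px ViT-B/16 position embedding so it fits a 384 px model.
import torch
from models.vit import VisionTransformer, interpolate_pos_embed

vit_384 = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)

# stand-in for torch.load(ckpt)['model']; only the shape matters for this demo
state_dict = {"pos_embed": torch.zeros(1, 14 * 14 + 1, 768)}  # 224 / 16 = 14 patches per side
state_dict["pos_embed"] = interpolate_pos_embed(state_dict["pos_embed"], vit_384)

print(state_dict["pos_embed"].shape)  # torch.Size([1, 577, 768]) -> 24 * 24 patches + CLS
missing, unexpected = vit_384.load_state_dict(state_dict, strict=False)
```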
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ timm==0.4.12
2
+ transformers==4.15.0
3
+ fairscale==0.4.4
4
+ pycocoevalcap
5
+ torch
6
+ torchvision
7
+ Pillow
resources/bert-large-chinese/config.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "directionality": "bidi",
8
+ "eos_token_id": 2,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 4096,
14
+ "layer_norm_eps": 1e-12,
15
+ "max_position_embeddings": 512,
16
+ "model_type": "bert",
17
+ "num_attention_heads": 16,
18
+ "num_hidden_layers": 24,
19
+ "output_past": true,
20
+ "pad_token_id": 0,
21
+ "pooler_fc_size": 768,
22
+ "pooler_num_attention_heads": 12,
23
+ "pooler_num_fc_layers": 3,
24
+ "pooler_size_per_head": 128,
25
+ "pooler_type": "first_token_transform",
26
+ "type_vocab_size": 2,
27
+ "vocab_size": 21128
28
+ }
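
The config above describes a 24-layer, 1024-dimensional Chinese BERT with a 21128-token vocabulary; the `tokenizer_config.json` and `vocab.txt` added just below presumably back the Chinese tokenizer side. A quick, hedged check that the bundled copy parses as expected (paths relative to the Space root; the sample string is only an illustration):

```python
# Hedged sketch: load the bundled bert-large-chinese config and tokenizer and sanity-check them.
from transformers import BertConfig, BertTokenizer

config = BertConfig.from_pretrained("resources/bert-large-chinese")
assert (config.hidden_size, config.num_hidden_layers, config.vocab_size) == (1024, 24, 21128)

tokenizer = BertTokenizer.from_pretrained("resources/bert-large-chinese")
print(tokenizer.tokenize("黄色积木玩具"))  # mostly character-level WordPiece tokens for Chinese text
```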
resources/bert-large-chinese/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"init_inputs": []}
resources/bert-large-chinese/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
resources/clip_vit_large_patch14/config.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "_name_or_path": "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-dpsr/zhangzhe45/huggingface/openai/clip-vit-large-patch14/",
3
+ "attention_dropout": 0.0,
4
+ "dropout": 0.0,
5
+ "hidden_act": "quick_gelu",
6
+ "hidden_size": 1024,
7
+ "image_size": 224,
8
+ "initializer_factor": 1.0,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 4096,
11
+ "layer_norm_eps": 1e-05,
12
+ "model_type": "clip_vision_model",
13
+ "num_attention_heads": 16,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "transformers_version": "4.18.0"
17
+ }
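
This is a vision-only CLIP configuration (ViT-L/14: 224-pixel inputs, patch size 14, 24 layers of width 1024). A hedged sketch showing that it round-trips through `transformers` and what token count it implies for the vision encoder:

```python
# Hedged sketch: load the bundled CLIP ViT-L/14 vision config and derive its token grid.
from transformers import CLIPVisionConfig

config = CLIPVisionConfig.from_pretrained("resources/clip_vit_large_patch14")
num_patches = (config.image_size // config.patch_size) ** 2  # (224 // 14) ** 2 = 256
print(config.hidden_size, num_patches + 1)  # "1024 257": 256 patch tokens + 1 CLS token
```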
resources/examples/charger-hw.jpg ADDED

Git LFS Details

  • SHA256: ed7377f93aa019b857d539b69d550926735526b6abea417380f2feda170880a6
  • Pointer size: 130 Bytes
  • Size of remote file: 56 kB
resources/examples/charger-ugreen.jpg ADDED

Git LFS Details

  • SHA256: 7b21f015aa7ff9ddaba4d15a219a7d0763fccf34b4071a0332e168d55c32c0ef
  • Pointer size: 130 Bytes
  • Size of remote file: 65.2 kB
resources/examples/charger.jpg ADDED

Git LFS Details

  • SHA256: d499625b41a9ef26f7ac2ced6449992e3bb936940c39ab78ac05562a1722273c
  • Pointer size: 131 Bytes
  • Size of remote file: 235 kB
resources/examples/jiandao.jpg ADDED

Git LFS Details

  • SHA256: faa7d559bc1a18e06902647deefc3baa9904637d36610d8dfd102f5aba79e4f3
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
resources/examples/lego-yellow.jpg ADDED

Git LFS Details

  • SHA256: fbb391de4cd341b7d2473d58a08e33ade8b8cc38212914a79727c3c8ab529b84
  • Pointer size: 131 Bytes
  • Size of remote file: 174 kB
starrynight.jpeg ADDED

Git LFS Details

  • SHA256: 5321c78bb92a5746827fb37b7361ddc6f9bfa5ec6c96ce98917f6f696231cc49
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
transform/randaugment.py ADDED
@@ -0,0 +1,340 @@
+ import cv2
+ import numpy as np
+
+
+ ## aug functions
+ def identity_func(img):
+     return img
+
+
+ def autocontrast_func(img, cutoff=0):
+     '''
+     same output as PIL.ImageOps.autocontrast
+     '''
+     n_bins = 256
+
+     def tune_channel(ch):
+         n = ch.size
+         cut = cutoff * n // 100
+         if cut == 0:
+             high, low = ch.max(), ch.min()
+         else:
+             hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+             low = np.argwhere(np.cumsum(hist) > cut)
+             low = 0 if low.shape[0] == 0 else low[0]
+             high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+             high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+         if high <= low:
+             table = np.arange(n_bins)
+         else:
+             scale = (n_bins - 1) / (high - low)
+             offset = -low * scale
+             table = np.arange(n_bins) * scale + offset
+             table[table < 0] = 0
+             table[table > n_bins - 1] = n_bins - 1
+         table = table.clip(0, 255).astype(np.uint8)
+         return table[ch]
+
+     channels = [tune_channel(ch) for ch in cv2.split(img)]
+     out = cv2.merge(channels)
+     return out
+
+
+ def equalize_func(img):
+     '''
+     same output as PIL.ImageOps.equalize
+     PIL's implementation is different from cv2.equalize
+     '''
+     n_bins = 256
+
+     def tune_channel(ch):
+         hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+         non_zero_hist = hist[hist != 0].reshape(-1)
+         step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+         if step == 0: return ch
+         n = np.empty_like(hist)
+         n[0] = step // 2
+         n[1:] = hist[:-1]
+         table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+         return table[ch]
+
+     channels = [tune_channel(ch) for ch in cv2.split(img)]
+     out = cv2.merge(channels)
+     return out
+
+
+ def rotate_func(img, degree, fill=(0, 0, 0)):
+     '''
+     like PIL, rotate by degree, not radians
+     '''
+     H, W = img.shape[0], img.shape[1]
+     center = W / 2, H / 2
+     M = cv2.getRotationMatrix2D(center, degree, 1)
+     out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+     return out
+
+
+ def solarize_func(img, thresh=128):
+     '''
+     same output as PIL.ImageOps.solarize
+     '''
+     table = np.array([el if el < thresh else 255 - el for el in range(256)])
+     table = table.clip(0, 255).astype(np.uint8)
+     out = table[img]
+     return out
+
+
+ def color_func(img, factor):
+     '''
+     same output as PIL.ImageEnhance.Color
+     '''
+     ## implementation according to PIL definition, quite slow
+     # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+     # out = blend(degenerate, img, factor)
+     # M = (
+     #     np.eye(3) * factor
+     #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+     # )[np.newaxis, np.newaxis, :]
+     M = (
+         np.float32([
+             [0.886, -0.114, -0.114],
+             [-0.587, 0.413, -0.587],
+             [-0.299, -0.299, 0.701]]) * factor
+         + np.float32([[0.114], [0.587], [0.299]])
+     )
+     out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+     return out
+
+
+ def contrast_func(img, factor):
+     """
+     same output as PIL.ImageEnhance.Contrast
+     """
+     mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+     table = np.array([(
+         el - mean) * factor + mean
+         for el in range(256)
+     ]).clip(0, 255).astype(np.uint8)
+     out = table[img]
+     return out
+
+
+ def brightness_func(img, factor):
+     '''
+     same output as PIL.ImageEnhance.Brightness
+     '''
+     table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+     out = table[img]
+     return out
+
+
+ def sharpness_func(img, factor):
+     '''
+     The differences between this result and PIL are all on the 4 boundaries; the center
+     areas are the same
+     '''
+     kernel = np.ones((3, 3), dtype=np.float32)
+     kernel[1][1] = 5
+     kernel /= 13
+     degenerate = cv2.filter2D(img, -1, kernel)
+     if factor == 0.0:
+         out = degenerate
+     elif factor == 1.0:
+         out = img
+     else:
+         out = img.astype(np.float32)
+         degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+         out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+         out = out.astype(np.uint8)
+     return out
+
+
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
+     H, W = img.shape[0], img.shape[1]
+     M = np.float32([[1, factor, 0], [0, 1, 0]])
+     out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+     return out
+
+
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
+     '''
+     same output as PIL.Image.transform
+     '''
+     H, W = img.shape[0], img.shape[1]
+     M = np.float32([[1, 0, -offset], [0, 1, 0]])
+     out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+     return out
+
+
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
+     '''
+     same output as PIL.Image.transform
+     '''
+     H, W = img.shape[0], img.shape[1]
+     M = np.float32([[1, 0, 0], [0, 1, -offset]])
+     out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+     return out
+
+
+ def posterize_func(img, bits):
+     '''
+     same output as PIL.ImageOps.posterize
+     '''
+     out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
+     return out
+
+
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
+     H, W = img.shape[0], img.shape[1]
+     M = np.float32([[1, 0, 0], [factor, 1, 0]])
+     out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+     return out
+
+
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
+     replace = np.array(replace, dtype=np.uint8)
+     H, W = img.shape[0], img.shape[1]
+     rh, rw = np.random.random(2)
+     pad_size = pad_size // 2
+     ch, cw = int(rh * H), int(rw * W)
+     x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+     y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+     out = img.copy()
+     out[x1:x2, y1:y2, :] = replace
+     return out
+
+
+ ### level to args
+ def enhance_level_to_args(MAX_LEVEL):
+     def level_to_args(level):
+         return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+     return level_to_args
+
+
+ def shear_level_to_args(MAX_LEVEL, replace_value):
+     def level_to_args(level):
+         level = (level / MAX_LEVEL) * 0.3
+         if np.random.random() > 0.5: level = -level
+         return (level, replace_value)
+
+     return level_to_args
+
+
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+     def level_to_args(level):
+         level = (level / MAX_LEVEL) * float(translate_const)
+         if np.random.random() > 0.5: level = -level
+         return (level, replace_value)
+
+     return level_to_args
+
+
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+     def level_to_args(level):
+         level = int((level / MAX_LEVEL) * cutout_const)
+         return (level, replace_value)
+
+     return level_to_args
+
+
+ def solarize_level_to_args(MAX_LEVEL):
+     def level_to_args(level):
+         level = int((level / MAX_LEVEL) * 256)
+         return (level, )
+     return level_to_args
+
+
+ def none_level_to_args(level):
+     return ()
+
+
+ def posterize_level_to_args(MAX_LEVEL):
+     def level_to_args(level):
+         level = int((level / MAX_LEVEL) * 4)
+         return (level, )
+     return level_to_args
+
+
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
+     def level_to_args(level):
+         level = (level / MAX_LEVEL) * 30
+         if np.random.random() < 0.5:
+             level = -level
+         return (level, replace_value)
+
+     return level_to_args
+
+
+ func_dict = {
+     'Identity': identity_func,
+     'AutoContrast': autocontrast_func,
+     'Equalize': equalize_func,
+     'Rotate': rotate_func,
+     'Solarize': solarize_func,
+     'Color': color_func,
+     'Contrast': contrast_func,
+     'Brightness': brightness_func,
+     'Sharpness': sharpness_func,
+     'ShearX': shear_x_func,
+     'TranslateX': translate_x_func,
+     'TranslateY': translate_y_func,
+     'Posterize': posterize_func,
+     'ShearY': shear_y_func,
+ }
+
+ translate_const = 10
+ MAX_LEVEL = 10
+ replace_value = (128, 128, 128)
+ arg_dict = {
+     'Identity': none_level_to_args,
+     'AutoContrast': none_level_to_args,
+     'Equalize': none_level_to_args,
+     'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
+     'Solarize': solarize_level_to_args(MAX_LEVEL),
+     'Color': enhance_level_to_args(MAX_LEVEL),
+     'Contrast': enhance_level_to_args(MAX_LEVEL),
+     'Brightness': enhance_level_to_args(MAX_LEVEL),
+     'Sharpness': enhance_level_to_args(MAX_LEVEL),
+     'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
+     'TranslateX': translate_level_to_args(
+         translate_const, MAX_LEVEL, replace_value
+     ),
+     'TranslateY': translate_level_to_args(
+         translate_const, MAX_LEVEL, replace_value
+     ),
+     'Posterize': posterize_level_to_args(MAX_LEVEL),
+     'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
+ }
+
+
+ class RandomAugment(object):
+
+     def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+         self.N = N
+         self.M = M
+         self.isPIL = isPIL
+         if augs:
+             self.augs = augs
+         else:
+             self.augs = list(arg_dict.keys())
+
+     def get_random_ops(self):
+         sampled_ops = np.random.choice(self.augs, self.N)
+         return [(op, 0.5, self.M) for op in sampled_ops]
+
+     def __call__(self, img):
+         if self.isPIL:
+             img = np.array(img)
+         ops = self.get_random_ops()
+         for name, prob, level in ops:
+             if np.random.random() > prob:
+                 continue
+             args = arg_dict[name](level)
+             img = func_dict[name](img, *args)
+         return img
+
+
+ if __name__ == '__main__':
+     a = RandomAugment()
+     # use a uint8 test image so the LUT-based ops can index into their tables
+     img = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
+     a(img)
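
RandomAugment is a plain callable over uint8 HWC arrays (or PIL images when isPIL=True), so it drops straight into a torchvision Compose for training-time augmentation. A usage sketch under assumed settings follows; the op list, the N/M values, and the crop size are illustrative and not read from this repository.

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transform.randaugment import RandomAugment

# Sketch: a typical training pipeline applying 2 random ops at magnitude 5 per image.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    RandomAugment(2, 5, isPIL=True,
                  augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
    transforms.ToTensor(),   # normalization would follow here in a real pipeline
])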
utils.py ADDED
@@ -0,0 +1,278 @@
+ import math
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+     """Decay the learning rate"""
+     lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+     """Warmup the learning rate"""
+     lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+     """Decay the learning rate"""
+     lr = max(min_lr, init_lr * (decay_rate**epoch))
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+ import numpy as np
+ import io
+ import os
+ import time
+ from collections import defaultdict, deque
+ import datetime
+
+ import torch
+ import torch.distributed as dist
+
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         if not is_dist_avail_and_initialized():
+             return
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+         dist.barrier()
+         dist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value)
+
+
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError("'{}' object has no attribute '{}'".format(
+             type(self).__name__, attr))
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {}".format(name, str(meter))
+             )
+         return self.delimiter.join(loss_str)
+
+     def global_avg(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {:.4f}".format(name, meter.global_avg)
+             )
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ''
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt='{avg:.4f}')
+         data_time = SmoothedValue(fmt='{avg:.4f}')
+         space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+         log_msg = [
+             header,
+             '[{0' + space_fmt + '}/{1}]',
+             'eta: {eta}',
+             '{meters}',
+             'time: {time}',
+             'data: {data}'
+         ]
+         if torch.cuda.is_available():
+             log_msg.append('max mem: {memory:.0f}')
+         log_msg = self.delimiter.join(log_msg)
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len(iterable) - 1:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time),
+                         memory=torch.cuda.max_memory_allocated() / MB))
+                 else:
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time)))
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print('{} Total time: {} ({:.4f} s / it)'.format(
+             header, total_time_str, total_time / len(iterable)))
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+
+ def compute_acc(logits, label, reduction='mean'):
+     ret = (torch.argmax(logits, dim=1) == label).float()
+     if reduction == 'none':
+         return ret.detach()
+     elif reduction == 'mean':
+         return ret.mean().item()
+
+ def compute_n_params(model, return_str=True):
+     tot = 0
+     for p in model.parameters():
+         w = 1
+         for x in p.shape:
+             w *= x
+         tot += w
+     if return_str:
+         if tot >= 1e6:
+             return '{:.1f}M'.format(tot / 1e6)
+         else:
+             return '{:.1f}K'.format(tot / 1e3)
+     else:
+         return tot
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in the master process
+     """
+     import builtins as __builtin__
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop('force', False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def save_on_master(*args, **kwargs):
+     if is_main_process():
+         torch.save(*args, **kwargs)
+
+
+ def init_distributed_mode(args):
+     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ['WORLD_SIZE'])
+         args.gpu = int(os.environ['LOCAL_RANK'])
+     elif 'SLURM_PROCID' in os.environ:
+         args.rank = int(os.environ['SLURM_PROCID'])
+         args.gpu = args.rank % torch.cuda.device_count()
+     else:
+         print('Not using distributed mode')
+         args.distributed = False
+         return
+
+     args.distributed = True
+
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = 'nccl'
+     print('| distributed init (rank {}, world {}): {}'.format(
+         args.rank, args.world_size, args.dist_url), flush=True)
+     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                          world_size=args.world_size, rank=args.rank)
+     torch.distributed.barrier()
+     setup_for_distributed(args.rank == 0)
+
+
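
These helpers are standard training utilities: LR schedulers that mutate optimizer.param_groups in place, a SmoothedValue/MetricLogger pair for windowed logging, and distributed-run boilerplate. A sketch of how they compose in a single training epoch follows; model, loader, and loss_fn are placeholders rather than objects defined in this commit.

import torch
from utils import MetricLogger, SmoothedValue, cosine_lr_schedule

def train_one_epoch(model, loader, optimizer, loss_fn, epoch, max_epoch, device):
    # Sketch: cosine-decay the LR once per epoch, log loss/lr every 50 steps.
    cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr=1e-5, min_lr=0.0)
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

    for images, targets in metric_logger.log_every(loader, 50, header='Epoch [{}]'.format(epoch)):
        images, targets = images.to(device), targets.to(device)
        loss = loss_fn(model(images), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

    metric_logger.synchronize_between_processes()
    return metric_logger.global_avg()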