mshukor committed
Commit 26fd00c · 1 Parent(s): 0fee199
init

This view is limited to 50 files because it contains too many changes.
- Audio_Captioning.ipynb +0 -0
- Captioning.ipynb +0 -0
- Image_gen.ipynb +301 -0
- LICENSE +201 -0
- README.md +618 -12
- README_EncouragingLoss.md +34 -0
- VG.ipynb +0 -0
- VQA.ipynb +0 -0
- Video_Captioning.ipynb +0 -0
- __pycache__/trainer.cpython-37.pyc +0 -0
- __pycache__/trainer.cpython-38.pyc +0 -0
- __pycache__/trainer.cpython-39.pyc +0 -0
- app.py +297 -0
- checkpoints.md +36 -0
- checkpoints/unival_s2_hs/checkpoint1.pt +3 -0
- checkpoints_cn.md +82 -0
- colab.md +9 -0
- criterions/__init__.py +5 -0
- criterions/__pycache__/__init__.cpython-37.pyc +0 -0
- criterions/__pycache__/__init__.cpython-38.pyc +0 -0
- criterions/__pycache__/__init__.cpython-39.pyc +0 -0
- criterions/__pycache__/clip_scst_loss.cpython-37.pyc +0 -0
- criterions/__pycache__/clip_scst_loss.cpython-38.pyc +0 -0
- criterions/__pycache__/clip_scst_loss.cpython-39.pyc +0 -0
- criterions/__pycache__/label_smoothed_cross_entropy.cpython-37.pyc +0 -0
- criterions/__pycache__/label_smoothed_cross_entropy.cpython-38.pyc +0 -0
- criterions/__pycache__/label_smoothed_cross_entropy.cpython-39.pyc +0 -0
- criterions/__pycache__/label_smoothed_cross_entropy_scst.cpython-39.pyc +0 -0
- criterions/__pycache__/label_smoothed_encouraging_loss.cpython-37.pyc +0 -0
- criterions/__pycache__/label_smoothed_encouraging_loss.cpython-38.pyc +0 -0
- criterions/__pycache__/label_smoothed_encouraging_loss.cpython-39.pyc +0 -0
- criterions/__pycache__/refcoco_scst_loss.cpython-39.pyc +0 -0
- criterions/__pycache__/scst_loss.cpython-37.pyc +0 -0
- criterions/__pycache__/scst_loss.cpython-38.pyc +0 -0
- criterions/__pycache__/scst_loss.cpython-39.pyc +0 -0
- criterions/clip_scst_loss.py +277 -0
- criterions/label_smoothed_cross_entropy.py +346 -0
- criterions/label_smoothed_cross_entropy_scst.py +555 -0
- criterions/label_smoothed_encouraging_loss.py +395 -0
- criterions/refcoco_scst_loss.py +427 -0
- data/.ipynb_checkpoints/file_dataset-checkpoint.py +107 -0
- data/__init__.py +0 -0
- data/__pycache__/__init__.cpython-37.pyc +0 -0
- data/__pycache__/__init__.cpython-38.pyc +0 -0
- data/__pycache__/__init__.cpython-39.pyc +0 -0
- data/__pycache__/audio_utils.cpython-37.pyc +0 -0
- data/__pycache__/audio_utils.cpython-39.pyc +0 -0
- data/__pycache__/data_utils.cpython-37.pyc +0 -0
- data/__pycache__/data_utils.cpython-38.pyc +0 -0
- data/__pycache__/data_utils.cpython-39.pyc +0 -0
Audio_Captioning.ipynb
ADDED
The diff for this file is too large to render.
Captioning.ipynb
ADDED
The diff for this file is too large to render.
Image_gen.ipynb
ADDED
@@ -0,0 +1,301 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "399f2fcf-9241-4910-a30d-6ca19880d0ad",
   "metadata": {},
   "source": [
    "## Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "97e68340-0096-475e-8ed8-22f5d627e3ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from fairseq import utils, tasks\n",
    "from fairseq import checkpoint_utils\n",
    "from utils.eval_utils import eval_step\n",
    "from tasks.mm_tasks import ImageGenTask\n",
    "from models.unival import UnIVALModel\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "import time\n",
    "\n",
    "\n",
    "# turn on cuda if GPU is available\n",
    "use_cuda = torch.cuda.is_available()\n",
    "# use fp16 only when GPU is available\n",
    "use_fp16 = True if use_cuda else False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "719cef65-c00c-4c9c-90b2-e660b386c3d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function fairseq.tasks.register_task.<locals>.register_task_cls(cls)>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Register caption task\n",
    "tasks.register_task('image_gen', ImageGenTask)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc9c1d7b-898b-4ac4-adf3-832891d9e4be",
   "metadata": {},
   "source": [
    "### Load model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "568bb6ea-eef9-4024-98e6-35e74b5ffeec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.sample_patch_num 784\n",
      "self.sample_audio_patch_num None\n",
      "self.sample_video_patch_num None\n",
      "self.with_cls False\n",
      "Frozen image bn <class 'models.ofa.frozen_bn.FrozenBatchNorm2d'>\n",
      "Loading: all_resnext101\n",
      "use bn: <class 'torch.nn.modules.batchnorm.BatchNorm3d'>\n",
      "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n",
      "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n",
      "load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth\n",
      "<All keys matched successfully>\n",
      "RAM memory % used: 10.5\n",
      "RAM Used (GB): 19.574349824\n",
      "encoder\n",
      "RAM memory % used: 10.5\n",
      "decoder\n",
      "RAM memory % used: 10.5\n",
      "ofa\n",
      "Working with z of shape (1, 256, 32, 32) = 262144 dimensions.\n"
     ]
    }
   ],
   "source": [
    "# Load pretrained ckpt & config\n",
    "clip_model_path='/data/mshukor/data/ofa/clip/ViT-B-16.pt'\n",
    "vqgan_model_path='/data/mshukor/data/ofa/vqgan/last.ckpt'\n",
    "vqgan_config_path='/data/mshukor/data/ofa/vqgan/model.yaml'\n",
    "\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofa_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_long/checkpoint_best.pt'\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_base_best.pt'\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_large_best.pt'\n",
    "\n",
    "# checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_1_base_s2_hsep1_long/checkpoint_best.pt'\n",
    "checkpoint_path = '/data/mshukor/logs/ofa/best_models/image_gen_ofaplus_stage_2_base_s2_hsep1_long/checkpoint_best.pt'\n",
    "\n",
    "\n",
    "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n",
    "resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'\n",
    "\n",
    "gen_images_path='results/image_gen/'\n",
    "\n",
    "overrides = {\"bpe_dir\": \"utils/BPE\",\n",
    "             \"eval_cider\": False,\n",
    "             \"beam\": 24,\n",
    "             \"max_len_b\": 1024,\n",
    "             \"max_len_a\": 0,\n",
    "             \"min_len\": 1024,\n",
    "             \"sampling_topk\": 256,\n",
    "             \"constraint_range\": \"50265,58457\",\n",
    "             \"clip_model_path\": clip_model_path,\n",
    "             \"vqgan_model_path\": vqgan_model_path,\n",
    "             \"vqgan_config_path\": vqgan_config_path,\n",
    "             \"seed\": 42,\n",
    "             \"video_model_path\": video_model_path,\n",
    "             \"resnet_model_path\": resnet_model_path,\n",
    "             \"gen_images_path\": gen_images_path,\n",
    "             \"patch_image_size\": 256,\n",
    "             \"temperature\": 1.5,\n",
    "             }\n",
    "\n",
    "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n",
    "    utils.split_paths(checkpoint_path),\n",
    "    arg_overrides=overrides\n",
    ")\n",
    "\n",
    "task.cfg.sampling_times = 2\n",
    "# Move models to GPU\n",
    "for model in models:\n",
    "    model.eval()\n",
    "    if use_fp16:\n",
    "        model.half()\n",
    "    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n",
    "        model.cuda()\n",
    "    model.prepare_for_inference_(cfg)\n",
    "\n",
    "# Initialize generator\n",
    "generator = task.build_generator(models, cfg.generation)\n",
    "\n",
    "# Text preprocess\n",
    "bos_item = torch.LongTensor([task.src_dict.bos()])\n",
    "eos_item = torch.LongTensor([task.src_dict.eos()])\n",
    "pad_idx = task.src_dict.pad()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e4a45ec-bce1-495b-8033-3b574367b360",
   "metadata": {},
   "source": [
    "### Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9f2e7e32-c9a0-43b3-bf86-2419d9f7dfe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_text(text, length=None, append_bos=False, append_eos=False):\n",
    "    s = task.tgt_dict.encode_line(\n",
    "        line=task.bpe.encode(text),\n",
    "        add_if_not_exist=False,\n",
    "        append_eos=False\n",
    "    ).long()\n",
    "    if length is not None:\n",
    "        s = s[:length]\n",
    "    if append_bos:\n",
    "        s = torch.cat([bos_item, s])\n",
    "    if append_eos:\n",
    "        s = torch.cat([s, eos_item])\n",
    "    return s\n",
    "\n",
    "\n",
    "# Construct input for image generation task\n",
    "def construct_sample(query: str):\n",
    "    code_mask = torch.tensor([True])\n",
    "    src_text = encode_text(\" what is the complete image? caption: {}\".format(query), append_bos=True,\n",
    "                           append_eos=True).unsqueeze(0)\n",
    "    src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n",
    "    sample = {\n",
    "        \"id\": np.array(['42']),\n",
    "        \"net_input\": {\n",
    "            \"src_tokens\": src_text,\n",
    "            \"src_lengths\": src_length,\n",
    "            \"code_masks\": code_mask\n",
    "        }\n",
    "    }\n",
    "    return sample\n",
    "\n",
    "\n",
    "# Function to turn FP32 to FP16\n",
    "def apply_half(t):\n",
    "    if t.dtype is torch.float32:\n",
    "        return t.to(dtype=torch.half)\n",
    "    return t\n",
    "\n",
    "\n",
    "# Function for image generation\n",
    "def image_generation(caption):\n",
    "    sample = construct_sample(caption)\n",
    "    sample = utils.move_to_cuda(sample) if use_cuda else sample\n",
    "    sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample\n",
    "    print('|Start|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n",
    "    with torch.no_grad():\n",
    "        result, scores = eval_step(task, generator, models, sample)\n",
    "\n",
    "    # return top-4 results (ranked by clip)\n",
    "    images = [result[i]['image'] for i in range(4)]\n",
    "    pic_size = 256\n",
    "    retImage = Image.new('RGB', (pic_size * 2, pic_size * 2))\n",
    "    print('|FINISHED|', time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()), caption)\n",
    "    for i in range(4):\n",
    "        loc = ((i % 2) * pic_size, int(i / 2) * pic_size)\n",
    "        retImage.paste(images[i], loc)\n",
    "    return retImage"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44dec799-c5c2-4d22-8b08-7a7ca2cdf3c9",
   "metadata": {},
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "02d5cd7a-8d63-4fa4-9da1-d4b79ec01445",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|Start| 2023-06-29 12:57:39 A brown horse in the street\n",
      "|FINISHED| 2023-06-29 12:59:03 A brown horse in the street\n"
     ]
    }
   ],
   "source": [
    "query = \"A brown horse in the street\"\n",
    "# query = \"Cattle grazing on grass near a lake surrounded by mountain.\"\n",
    "# query = 'A street scene with a double-decker bus on the road.'\n",
    "# query = 'A path.'\n",
    "\n",
    "\n",
    "retImage = image_generation(query)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a8a1654-1f17-41c7-b410-c7491a96dcee",
   "metadata": {},
   "outputs": [],
   "source": [
    "retImage.save(f'{query}.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ofa",
   "language": "python",
   "name": "ofa"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 1999-2022 Alibaba Group Holding Ltd.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,12 +1,618 @@
<!---
Copyright 2022 The OFA-Sys Team.
All rights reserved.
This source code is licensed under the Apache 2.0 license found in the LICENSE file in the root directory.
-->

todo:
models
data
all readme
animation

readme of:
rewarded soups
and others

<p align="center">
<br>
<img src="examples/OFA_logo_tp_path.svg" width="150" />
<br>
<p>
<br>

<p align="center">
<a href="modelscope.md">ModelScope</a> | <a href="checkpoints.md">Checkpoints</a> | <a href="colab.md">Colab</a> | <a href="https://huggingface.co/ofa-sys">Demo</a> | <a href="http://arxiv.org/abs/2202.03052">Paper</a> | Blog
</p>

<p align="center">
<br>
<img src="examples/demo.gif" width="800" />
<br>
<p>

[colab]: <https://colab.research.google.com/assets/colab-badge.svg>

OFA is a unified sequence-to-sequence pretrained model (supporting **English** and **Chinese**) that unifies modalities (i.e., cross-modality, vision, language) and tasks (**finetuning** and **prompt tuning** are supported): image captioning (1st at the [MSCOCO Leaderboard](https://competitions.codalab.org/competitions/3221#results)), VQA ([link](https://eval.ai/web/challenges/challenge-page/830/leaderboard/2278)), visual grounding, text-to-image generation, text classification, text generation, image classification, etc. We provide **step-by-step** instructions for pretraining and finetuning and corresponding checkpoints (check the official ckpt \[[EN](checkpoints.md)|[CN](checkpoints_cn.md)\] or [huggingface ckpt](https://huggingface.co/OFA-Sys)).

We sincerely welcome contributions to our project. Feel free to contact us or send us issues / PRs!
<br></br>

# Our installation

After installing pycocoevalcap, download the needed models:
```
python -c "from pycocoevalcap.spice.spice import Spice; tmp = Spice()"
```

# Online Demos
We provide online demos via Hugging Face Spaces for you to interact with our pretrained and finetuned models. Below are the links to the demos:
* Image Captioning \[[ModelScope](https://modelscope.cn/#/models/damo/ofa_image-caption_coco_large_en/summary) | [Spaces](https://huggingface.co/spaces/OFA-Sys/OFA-Image_Caption)\]
* Visual Grounding \[[ModelScope](https://modelscope.cn/#/models/damo/ofa_visual-grounding_refcoco_large_en/summary) | [Spaces](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Grounding)\]
* Visual Question Answering \[[ModelScope](https://modelscope.cn/#/models/damo/ofa_visual-question-answering_pretrain_large_en/summary) | [Spaces](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Question_Answering)\]
* Text-to-Image Generation \[[ModelScope](https://modelscope.cn/#/models/damo/ofa_text-to-image-synthesis_coco_large_en/summary) | [Spaces](https://huggingface.co/spaces/OFA-Sys/OFA-Text2Image_Generation)\]
* Generic Interface \[[Spaces](https://huggingface.co/spaces/OFA-Sys/OFA-Generic_Interface)\]

We also provide Colab notebooks to walk you through the procedures. Click [here](colab.md) to check them out!
<br></br>

# Use in Huggingface Transformers
We support the inference of OFA in Huggingface Transformers. Check the [README](transformers.md) and [Colab Notebook](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing) for more information. The code is released in the branch https://github.com/OFA-Sys/OFA/tree/feature/add_transformers
<br><br>


# News
* 2022.8.22: Released checkpoints and demos of **OFA** and **Chinese CLIP** on [ModelScope](https://modelscope.cn/). Check the [README](modelscope.md) for more details!
* 2022.8.16: Released the **Chinese** version of OFA. **OFA-CN** only needs switching to `bpe_dir=../../utils/BERT_CN_dict` and `bpe=bert` and using our provided Chinese checkpoints in [checkpoints_cn.md](checkpoints_cn.md). Temporarily, we only provide base-size and large-size pretrained checkpoints and finetuned checkpoints on [MUGE Caption](https://tianchi.aliyun.com/muge) and the Chinese version of RefCOCO(-/+/g) (to be released soon).
* 2022.8.5: Released support of **prompt tuning** for OFA. Check our paper [here](https://arxiv.org/abs/2208.02532)! Please see [prompt_tuning.md](prompt_tuning.md) for further details.
* 2022.7.7: Updated support of OFA on **huggingface transformers** (fixed bugs in forward, added the sequence generator from Fairseq to ensure performance, etc.). Refer to the doc [transformers.md](transformers.md) and the branch `feature/add_transformers`.
* 2022.6.17: Released the pretrained checkpoint of **OFA-Huge**. To use it, set `--arch=ofa_huge` in the script.
* 2022.5.15: OFA was accepted by **ICML 2022**.
* 2022.4.28: Added support of inference on **huggingface transformers**. For how to use it, please refer to the doc [transformers.md](transformers.md) and our [huggingface models](https://huggingface.co/OFA-Sys).
* 2022.4.16: Released lightweight pretrained models **OFA-Medium** (~93M params) and **OFA-Tiny** (~33M params) in [checkpoints.md](checkpoints.md). To use them, you just need to load the corresponding checkpoint and set `--arch=ofa_medium` or `--arch=ofa_tiny` in the scripts.

<details>
<summary><b>More News</b></summary>
<p>
<ul>
<li>2022.3.23: Added [Encouraging Loss](https://arxiv.org/pdf/2110.06537.pdf) as a feature. See [README_EncouragingLoss.md](README_EncouragingLoss.md). Leveraging this feature, OFA-Large has recently achieved improved results in both VQA (<b>test-std acc: 80.67</b>) and Image Classification (<b>test acc: 85.6</b>).</li>
<li>2022.3.21: Released codes for pretraining OFA.</li>
<li>2022.3.18: Released the finetuned <b>OFA-Base</b> (~180M parameters) checkpoints and running scripts for vision & language tasks, including: <b>Caption (146.4 CIDEr), VQA (78.07 on test-std), SNLI-VE (89.3 on dev), RefCOCO (90.67 on testA), RefCOCO+ (87.15 on testA) and RefCOCOg (82.31 on test-u)</b>.</li>
<li>2022.3.11: Released the finetuning & inference code/checkpoints for <b>Gigaword</b>.</li>
<li>2022.3.08: Released the pretrained checkpoint of <b>OFA-Base</b> in <a href="https://github.com/OFA-Sys/OFA/blob/main/checkpoints.md">checkpoints.md</a>. To use OFA-Base, you just need to load <code>ofa_base.pt</code> and change <code>--arch=ofa_large</code> to <code>--arch=ofa_base</code> in the training scripts.</li>
<li>2022.3.07: Released the finetuning & inference code/checkpoints for <b>Image Classification</b>, which achieves <b>85.0</b> accuracy on ImageNet-1K, slightly better than reported in the OFA paper.</li>
<li>2022.3.04: Released the finetuning & inference code/checkpoints for <b>Text-to-Image Generation</b>.</li>
<li>2022.3.03: Released the finetuning & inference code/checkpoints for <b>SNLI-VE</b> and <b>GLUE</b>.</li>
<li>2022.2.22: Released the finetuning & inference code/checkpoints for <b>Visual Question Answering</b>, which can reproduce <b>the reported VQA accuracy in the OFA paper (80.02 on test-std)</b>. Check our results on the <a href="https://eval.ai/web/challenges/challenge-page/830/leaderboard/2278">VQA Challenge</a>.</li>
<li>2022.2.15: Released the finetuning & inference code/checkpoints for <b>Referring Expression Comprehension</b>.</li>
<li>2022.2.10: Released the inference code & finetuned checkpoint for <b>Image Captioning</b>, which can reproduce <b>the results on the COCO Karpathy test split (149.6 CIDEr)</b>. OFA also achieves No.1 on the COCO image captioning online leaderboard <a href='https://competitions.codalab.org/competitions/3221#results'>Link</a> (marked as M6-Team).</li>
</ul>
</p>
</details>
<br></br>


# Model Card
We list the parameters and pretrained checkpoints of OFAs below. For finetuned checkpoints, please refer to [checkpoints.md](checkpoints.md).

<table border="1" width="100%">
<tr align="center">
<th>Model</th><th>Ckpt</th><th>Params</th><th>Backbone</th><th>Hidden size</th><th>Intermediate size</th><th>Num. of heads</th><th>Enc layers</th><th>Dec layers</th>
</tr>
<tr align="center">
<td>OFA<sub>Tiny</sub></td><td><a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_tiny.pt">Download</a></td><td>33M</td><td>ResNet50</td><td>256</td><td>1024</td><td>4</td><td>4</td><td>4</td>
</tr>
<tr align="center">
<td>OFA<sub>Medium</sub></td><td><a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_medium.pt">Download</a></td><td>93M</td><td>ResNet101</td><td>512</td><td>2048</td><td>8</td><td>4</td><td>4</td>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub></td><td><a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_base.pt">Download</a></td><td>180M</td><td>ResNet101</td><td>768</td><td>3072</td><td>12</td><td>6</td><td>6</td>
</tr>
<tr align="center">
<td>OFA<sub>Large</sub></td><td><a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt">Download</a></td><td>470M</td><td>ResNet152</td><td>1024</td><td>4096</td><td>16</td><td>12</td><td>12</td>
</tr>
<tr align="center">
<td>OFA<sub>Huge</sub></td><td><a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_huge.pt">Download</a></td><td>930M</td><td>ResNet152</td><td>1280</td><td>5120</td><td>16</td><td>24</td><td>12</td>
</tr>
</table>
<br></br>

# Results
Below we demonstrate the results of OFAs on cross-modal understanding and generation.

<table border="1" width="100%">
<tr align="center">
<th>Task</th><th>Image Captioning</th><th>VQA</th><th>Visual Entailment</th><th colspan="3">Referring Expression Comprehension</th>
</tr>
<tr align="center">
<td>Dataset</td><td>COCO</td><td>VQA v2</td><td>SNLI-VE</td><td>RefCOCO</td><td>RefCOCO+</td><td>RefCOCOg</td>
</tr>
<tr align="center">
<td>Split</td><td>Karpathy test (CE/CIDEr)</td><td>test-dev/test-std</td><td>val/test</td><td>val/test-a/test-b</td><td>val/test-a/test-b</td><td>val-u/test-u</td>
</tr>
<tr align="center">
<td>Metric</td><td>CIDEr</td><td>Acc.</td><td>Acc.</td><td colspan="3">Acc.</td>
</tr>
<tr align="center">
<td>OFA<sub>Tiny</sub></td><td>119.0 / 128.7</td><td>70.3 / 70.4</td><td>85.3 / 85.2</td><td>80.20 / 84.07 / 75.00</td><td>68.22 / 75.13 / 57.66</td><td>72.02 / 69.74</td>
</tr>
<tr align="center">
<td>OFA<sub>Medium</sub></td><td>130.4 / 140.3</td><td>75.4 / 75.5</td><td>86.6 / 87.0</td><td>85.34 / 87.68 / 77.92</td><td>76.09 / 83.04 / 66.25</td><td>78.76 / 78.58</td>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub></td><td>138.2 / 146.7</td><td>78.0 / 78.1</td><td>89.3 / 89.2</td><td>88.48 / 90.67 / 83.30</td><td>81.39 / 87.15 / 74.29</td><td>82.29 / 82.31</td>
</tr>
<tr align="center">
<td>OFA<sub>Large</sub></td><td>142.2 / 150.7</td><td>80.4 / 80.7</td><td>90.3 / 90.2</td><td>90.05 / 92.93 / 85.26</td><td>85.80 / 89.87 / 79.22</td><td>85.89 / 86.55</td>
</tr>
<tr align="center">
<td>OFA<sub>Huge</sub></td><td>145.3 / 154.9</td><td>82.0 / 82.0</td><td>91.0 / 91.2</td><td>92.04 / 94.03 / 88.44</td><td>87.86 / 91.70 / 80.71</td><td>88.07 / 88.78</td>
</tr>
</table>
<br></br>

# Requirements
* python 3.7.4
* pytorch 1.8.1
* torchvision 0.9.1
* JAVA 1.8 (for COCO evaluation)
<br></br>

# Installation
```bash
git clone https://github.com/OFA-Sys/OFA
pip install -r requirements.txt
```
<br></br>

# Datasets and Checkpoints
See [datasets.md](datasets.md) and [checkpoints.md](checkpoints.md).
<br></br>

# Training & Inference
Below we provide methods for training and inference on different tasks. We provide both pretrained OFA-Large and OFA-Base in [checkpoints.md](checkpoints.md). The scripts mentioned in this section are prepared for OFA-Large. For reproducing the downstream results of OFA-Base, we have also provided the corresponding finetuning and inference scripts for OFA-Base in the `run_scripts/` folder.

We recommend organizing your workspace directory like this:
```
OFA/
├── checkpoints/
│   ├── ofa_base.pt
│   ├── ofa_large.pt
│   ├── caption_large_best_clean.pt
│   └── ...
├── criterions/
├── data/
├── dataset/
│   ├── caption_data/
│   ├── gigaword_data/
│   └── ...
├── fairseq/
├── models/
├── run_scripts/
├── tasks/
├── train.py
├── trainer.py
└── utils/
```


## Image Processing
To process data efficiently, we do not store images as small individual files; instead, we encode them as base64 strings.
Transforming an image file to a base64 string is simple. Run the following code:
```python
from PIL import Image
from io import BytesIO
import base64

img = Image.open(file_name) # path to file
img_buffer = BytesIO()
img.save(img_buffer, format=img.format)
byte_data = img_buffer.getvalue()
base64_str = base64.b64encode(byte_data) # bytes
base64_str = base64_str.decode("utf-8") # str
```
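
Decoding goes the same way in reverse. Below is a minimal sketch (the variable and file names are illustrative, not part of the codebase) for turning a stored base64 string back into a PIL image, e.g. to spot-check a converted sample:
```python
from PIL import Image
from io import BytesIO
import base64

# `base64_str` is the string produced by the snippet above
byte_data = base64.b64decode(base64_str)       # str -> bytes
img = Image.open(BytesIO(byte_data))           # bytes -> PIL.Image
img.convert("RGB").save("decoded_check.png")   # write it out for visual inspection
```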

## Pretraining
Below we provide methods for pretraining OFA.

<details>
<summary><b>1. Prepare the Dataset</b></summary>
<p>
To pretrain OFA, you should first download the dataset we provide (<a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/pretrain_data/pretrain_data_examples.zip">pretrain_data_examples.zip</a>, a small subset of the original pretraining data). For your custom pretraining datasets, please prepare your training samples in the same format. <code>pretrain_data_examples.zip</code> contains 4 TSV files: <code>vision_language_examples.tsv</code>, <code>text_examples.tsv</code>, <code>image_examples.tsv</code> and <code>detection_examples.tsv</code>. Details of these files are as follows:
<br />
<ul type="circle">
<li><b>vision_language_examples.tsv</b>:
Each line contains uniq-id, image (base64 string), caption, question, answer, ground-truth objects (objects appearing in the caption or question), dataset name (source of the data) and task type (caption, qa or visual grounding). Prepared for the pretraining tasks of visual grounding, grounded captioning, image-text matching, image captioning and visual question answering.</li>
<li><b>text_examples.tsv</b>: Each line contains uniq-id and text. Prepared for the pretraining task of text infilling.</li>
<li><b>image_examples.tsv</b>: Each line contains uniq-id, image (base64 string, should be resized to 256*256 resolution) and image-code (the sparse codes for the central part of the image generated through VQ-GAN). Prepared for the pretraining task of image infilling.</li>
<li><b>detection_examples.tsv</b>: Each line contains uniq-id, image (base64 string) and bounding box annotations (the top-left and bottom-right coordinates of the bounding box, object_id and object_name, separated by commas). Prepared for the pretraining task of detection.</li>
</ul>
In addition, the folder negative_sample in pretrain_data_examples.zip contains three files, <code>all_captions.txt</code>, <code>object.txt</code> and <code>type2ans.json</code>. The data in these files are used as negative samples for the image-text matching (ITM) task.
</p>
</details>
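
As a rough illustration of the row layout described above, one caption-type line of a custom <code>vision_language_examples.tsv</code> could be assembled as below. This is only a sketch: the field values are made up, <code>to_base64</code> re-implements the conversion from the Image Processing section, and the <code>&&</code> separator for the ground-truth objects is an assumption borrowed from the convention used by the other TSV files.
```python
import base64
from io import BytesIO
from PIL import Image

def to_base64(path):
    # same conversion as in the Image Processing section above
    img = Image.open(path)
    buf = BytesIO()
    img.save(buf, format=img.format)
    return base64.b64encode(buf.getvalue()).decode("utf-8")

# field order follows the description above; values are illustrative only
row = "\t".join([
    "0",                        # uniq-id
    to_base64("example.jpg"),   # image as a base64 string
    "a dog runs on the beach",  # caption
    "",                         # question (unused for a caption sample)
    "",                         # answer (unused for a caption sample)
    "dog&&beach",               # ground-truth objects (separator assumed)
    "mydata",                   # dataset name
    "caption",                  # task type
])
with open("vision_language_examples.tsv", "a") as f:
    f.write(row + "\n")
```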
<details>
<summary><b>2. Pretraining</b></summary>
<p>
By default, the pretraining script will attempt to restore the released pretrained checkpoints of OFA-Base or OFA-Large and perform continuous pretraining. We recommend continuous pretraining, which achieves much better results than pretraining from scratch. For continuous pretraining, please download the pretrained weights in advance (see <a href='checkpoints.md'>checkpoints.md</a>) and put them in the correct directory <code>OFA/checkpoints/</code>. If not, the pretraining will begin from scratch.
</p>
<pre>
cd run_scripts/pretraining
bash pretrain_ofa_large.sh # Pretrain OFA-Large. For OFA-Base, use pretrain_ofa_base.sh
</pre>
<p>
If the pretrained OFA checkpoint is restored successfully, you will see the following information in the log:
</p>
<pre>
INFO: Loaded checkpoint ../../checkpoints/ofa_large.pt
</pre>
</details>

## Image Captioning
Below we provide the procedures to reproduce the image captioning results from our paper.
<details>
<summary><b>1. Prepare the Dataset & Checkpoints</b></summary>
<p>
Download data (see <a href='datasets.md'>datasets.md</a>) and models (see <a href='checkpoints.md'>checkpoints.md</a>) and put them in the correct directory. The dataset zipfile <code>caption_data.zip</code> contains caption_stage1_train.tsv, caption_stage2_train.tsv, caption_val.tsv and caption_test.tsv. Each image corresponds to only 1 caption in <code>caption_stage1_train.tsv</code> and to multiple captions in the other TSV files (about 5 captions per image). Each line of the dataset represents a caption sample with the following format. The fields uniq-id, image-id, caption, predicted object labels (taken from <a href='https://github.com/pzzhang/VinVL'>VinVL</a>, not used) and image base64 string are separated by tabs.
</p>
<pre>
162365  12455   the sun sets over the trees beyond some docks.  sky&&water&&dock&&pole  /9j/4AAQSkZJ....UCP/2Q==
</pre>
</details>
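
For reference, a caption sample can be read back by splitting on tabs in the field order above and decoding the image; a minimal sketch (file and variable names are illustrative):
```python
import base64
from io import BytesIO
from PIL import Image

with open("dataset/caption_data/caption_stage1_train.tsv") as f:
    line = f.readline().rstrip("\n")

# field order as described above
uniq_id, image_id, caption, object_labels, img_b64 = line.split("\t")
labels = object_labels.split("&&")  # e.g. ['sky', 'water', 'dock', 'pole']
image = Image.open(BytesIO(base64.b64decode(img_b64)))
print(uniq_id, image_id, caption, labels, image.size)
```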
<details>
<summary><b>2. Finetuning</b></summary>
<p>
Following previous standard practice, we divide the finetuning process of image captioning into two stages. In stage 1, we finetune OFA with cross-entropy loss on 4 NVIDIA-V100 GPUs with 32GB memory (expected to obtain ~139.5 CIDEr on the validation set at this stage). In stage 2, we select the best checkpoint of stage 1 and train with CIDEr optimization on 8 NVIDIA-V100 GPUs. <b>Note that CIDEr optimization is very unstable and requires careful hyperparameter tuning. If you encounter training errors in the stage-2 finetuning, you can increase the batch size or reduce the learning rate. If neither of these works, you can directly set </b><code>--freeze-resnet</code><b> to freeze the inner states of batch normalization.</b>
</p>
<pre>
cd run_scripts/caption
nohup sh train_caption_stage1.sh > train_stage1.out &  # stage 1, train with cross-entropy loss
nohup sh train_caption_stage2.sh > train_stage2.out &  # stage 2, load the best ckpt of stage1 and train with CIDEr optimization
</pre>
</details>
<details>
<summary><b>3. Inference</b></summary>
<p>
Run the following commands to get your results and evaluate your model.
</p>
<pre>
cd run_scripts/caption ; sh evaluate_caption.sh  # inference & evaluate
</pre>
</details>

## Text-to-Image Generation
This part provides procedures for the finetuning and inference of text-to-image generation. See below.

<details>
<summary><b>1. Prepare the Dataset & Checkpoints</b></summary>
<p>
Download data (see <a href="datasets.md">datasets.md</a>) and models (see <a href="checkpoints.md">checkpoints.md</a>) and put them in the correct directory. The dataset zipfile <code>coco_image_gen.zip</code> contains <code>coco_vqgan_train.tsv</code>, <code>coco_vqgan_dev.tsv</code> and <code>coco_vqgan_full_test.tsv</code>. Each line of the dataset represents a sample with the following format. The fields uniq-id, image-code (produced by <a href="https://github.com/CompVis/taming-transformers">vqgan</a>, a list of integers separated by single whitespaces) and lowercased caption are separated by tabs.
</p>
<pre>
1   6674 4336 4532 5334 3251 5461 3615 2469 ...4965 4190 1846   the people are posing for a group photo.
</pre>
<p>
The checkpoint zipfile <code>image_gen_large_best.zip</code> contains <code>image_gen_large_best.pt</code>, <code>vqgan/last.ckpt</code>, <code>vqgan/model.yaml</code> and <code>clip/Vit-B-16.pt</code>.
</p>
</details>
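
A row of this file can be split into its three fields, with the image-code parsed into a list of integer VQGAN code indices; a minimal sketch (the file path is illustrative):
```python
# read one row of the text-to-image training data described above
with open("dataset/image_gen/coco_vqgan_train.tsv") as f:
    line = f.readline().rstrip("\n")

uniq_id, image_code, caption = line.split("\t")
codes = [int(tok) for tok in image_code.split()]  # whitespace-separated VQGAN code indices
print(uniq_id, len(codes), caption)
```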
<details>
<summary><b>2. Shuffle the Training Data</b></summary>
<p>
(Optional, but achieves better results): If the disk storage is sufficient, we recommend preparing the shuffled training data for each epoch in advance.
</p>
<pre>
cd dataset/image_gen
ln coco_vqgan_train.tsv coco_vqgan_train_1.tsv
for idx in `seq 1 9`;do shuf coco_vqgan_train_${idx}.tsv > coco_vqgan_train_$[${idx}+1].tsv;done # each file is used for an epoch
</pre>
</details>
<details>
<summary><b>3. Finetuning</b></summary>
<p>
Following previous practice, we divide the finetuning process of image generation into two stages. In stage 1, we finetune OFA with cross-entropy loss on 4 8-V100-32G-GPU servers (expected to obtain ~32.5+ CLIP Score on the validation set at this stage). In stage 2, we select the last checkpoint of stage 1 and train with CLIP Score optimization on 4 8-V100-32G-GPU servers (expected to obtain ~34.0+ CLIP Score on the validation set at this stage). During validation, the generated images will be dumped into <code>_GEN_IMAGE_PATH_</code>.
</p>
<pre>
# run on each worker after the distributed and data configs have been correctly set following the guide in train_image_gen_stage1_distributed.sh
cd run_scripts/image_gen
nohup sh train_image_gen_stage1_distributed.sh # stage 1, train with cross-entropy loss
nohup sh train_image_gen_stage2_distributed.sh # stage 2, load the last ckpt of stage1 and train with CLIP Score optimization
</pre>
</details>
<details>
<summary><b>4. Inference</b></summary>
<p>
Run the command below to generate your images.
</p>
<pre>
cd run_scripts/image_gen ; sh evaluate_image_gen.sh # inference & evaluate (FID, IS and CLIP Score)
</pre>
</details>

## Visual Question Answering
Here we provide the finetuning and inference code to reproduce the VQAv2 result reported in our paper (**test-std 80.02**). We believe much improvement on accuracy can still be achieved based on this codebase :)
<details>
<summary><b>1. Prepare the Dataset & Checkpoints</b></summary>
<p>
Download data (see <a href="datasets.md">datasets.md</a>) and models (see <a href="checkpoints.md">checkpoints.md</a>) and put them in the correct directory. The dataset zipfile <code>vqa_data.zip</code> is around 100G and the decompressed data costs around 135G of disk storage; it contains the training, validation and testing samples together with other necessary data resources. (Since <code>vqa_data.zip</code> is large in size, we have also provided chunked parts of the dataset files for more convenient and stable downloading. Please refer to <a href="https://github.com/OFA-Sys/OFA/issues/68#issuecomment-1096837349">issue #68</a>.) Following common practice, VG-QA samples are also included in the training data. To adapt to the seq2seq paradigm of OFA, we transform original VQA training questions with multiple golden answers into multiple training samples. For the original VQA validation set, we keep around 10k samples for our validation and use the other samples for training. Each line of the dataset represents a VQA sample with the following format. The fields question-id, image-id, question, answer (with confidence), predicted object labels (taken from <a href="https://github.com/pzzhang/VinVL">VinVL</a>, brings around +0.1 accuracy improvement) and image base64 string are separated by tabs.
</p>
<pre>
79459   79459   is this person wearing shorts?  0.6|!+no    house&&short&&...&&sky  /9j/4AAQS...tigZ/9k=
</pre>
<p>
For fine-tuning on custom VQA-formulated tasks, please refer to issues <a href="https://github.com/OFA-Sys/OFA/issues/76">#76</a>, <a href="https://github.com/OFA-Sys/OFA/issues/105">#105</a> and <a href="https://github.com/OFA-Sys/OFA/issues/73">#73</a> for more information.
</p>
</details>
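
The answer field packs the confidence and the answer text together (in the example above, <code>0.6|!+no</code>). A minimal parsing sketch for one row, assuming a single answer per line as in the example and treating <code>|!+</code> as the separator between confidence and answer (file and variable names are illustrative):
```python
import base64
from io import BytesIO
from PIL import Image

with open("dataset/vqa_data/vqa_train.tsv") as f:
    line = f.readline().rstrip("\n")

# field order as described above
qid, image_id, question, answer_field, object_labels, img_b64 = line.split("\t")
conf, answer = answer_field.split("|!+")  # e.g. "0.6|!+no" -> ("0.6", "no")
labels = object_labels.split("&&")        # predicted object labels
image = Image.open(BytesIO(base64.b64decode(img_b64)))
print(question, answer, float(conf), labels[:3], image.size)
```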
|
345 |
+
<details>
|
346 |
+
<summary><b>2. Shuffle the Training Data</b></summary>
|
347 |
+
<p>
|
348 |
+
(Optional, but achieves better finetuning accuracy): If the disk storage is sufficient, we recommend preparing the shuffled training data for each epoch in advance. In our experiments, shuffling brings around a <b>+0.3</b> improvement in VQA accuracy.
|
349 |
+
</p>
|
350 |
+
<pre>
|
351 |
+
cd dataset/vqa_data
|
352 |
+
ln vqa_train.tsv vqa_train_1.tsv
|
353 |
+
for idx in `seq 1 9`;do shuf vqa_train_${idx}.tsv > vqa_train_$((idx+1)).tsv;done # each file is used for an epoch
|
354 |
+
</pre>
|
355 |
+
</details>
|
356 |
+
<details>
|
357 |
+
<summary><b>3. Finetuning</b></summary>
|
358 |
+
<p>
|
359 |
+
In our experiments, the VQA finetuning is performed on 4 8-A100-GPU servers (<i>with RDMA</i>). Here we provide the finetuning script <code>train_vqa_distributed.sh</code>, which supports multi-server distributed training (as well as single-server training). Please refer to the comments at the beginning of the script and set the configs correctly according to your distributed environment. If you have shuffled the training data in the previous step, please specify the training data path correctly following the guide in the script comments. <b>The command should be run on each worker.</b>
|
360 |
+
</p>
|
361 |
+
<pre>
|
362 |
+
# run on each worker after the distributed and data configs have been correctly set following the guide in train_vqa_distributed.sh
|
363 |
+
cd run_scripts/vqa
|
364 |
+
bash train_vqa_distributed.sh
|
365 |
+
</pre>
|
366 |
+
<p>
|
367 |
+
In our experiments, the finetuning costs around 36 hours (for 12 epochs). After each epoch, an evaluation on the validation set is performed. The best validation accuracy during finetuning will be around 80.8. The log is saved in <code>${log_dir}</code>.
|
368 |
+
</p>
|
369 |
+
<p>
|
370 |
+
<i>(Update on validation time cost)</i> As will be mentioned in the <i>4. Inference</i> section below, we provide 2 types of inference: beam-search and all-candidate inference. By default, all-candidate inference is used for validation during fine-tuning, which achieves better accuracy but takes much more time. We have now added an option to the training scripts, <code>--val-inference-type</code>, to switch the validation inference type during fine-tuning. If the validation takes too long, you can refer to <a href="https://github.com/OFA-Sys/OFA/pull/79">PR #79</a> to activate beam-search validation, which takes significantly less time at the cost of around 0.5-0.6 validation score compared with all-candidate validation.
|
371 |
+
</p>
|
372 |
+
</details>
|
373 |
+
<details>
|
374 |
+
<summary><b>4. Inference</b></summary>
|
375 |
+
<p>
|
376 |
+
We provide 2 types of inference, <b>beam-search</b> (much faster but with sub-optimal accuracy) and <b>all-candidate evaluation</b> (slower but with the best accuracy). <br></br>
|
377 |
+
For beam-search inference, use the script <code>evaluate_vqa_beam.sh</code>. Refer to the command below. The inference on test set costs around 16 GPU hours. After inference on test set, the result JSON file will be dumped in the <code>${result_path}</code> defined in the shell script. You can submit the result <code>test_predict.json</code> to <a href="https://eval.ai/web/challenges/challenge-page/830/overview">EvalAI</a>. Using our released finetuned checkpoint, beam-search inference will get 80.15 validation accuracy, 79.36 test-dev accuracy and 79.48 test-std accuracy (around 0.6 lower than all-candidate evaluation).
|
378 |
+
</p>
|
379 |
+
<pre>
|
380 |
+
cd run_scripts/vqa
|
381 |
+
bash evaluate_vqa_beam.sh val # specify 'val' or 'test'
|
382 |
+
</pre>
|
383 |
+
<p>
|
384 |
+
For all-candidate evaluation, we recommend using the distributed script <code>evaluate_vqa_allcand_distributed.sh</code>. Please refer to the guide in the script to set the distributed configs before running. The result JSON file will be dumped in the <code>${result_path}</code> defined in the shell script on the rank-0 server. All-candidate evaluation scores every candidate answer in the VQA dataset (a conceptual sketch is given after the command below), which achieves <b>80.82</b> validation accuracy, <b>79.87</b> test-dev accuracy and <b>80.02</b> test-std accuracy, reproducing the results reported in our paper. However, inference on the test set costs around 1k GPU hours, which is much slower.
|
385 |
+
</p>
|
386 |
+
<pre>
|
387 |
+
# run on each worker after the distributed configs have been correctly set following the guide in evaluate_vqa_allcand_distributed.sh
|
388 |
+
cd run_scripts/vqa
|
389 |
+
bash evaluate_vqa_allcand_distributed.sh val # specify 'val' or 'test'
|
390 |
+
</pre>
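<p>
Conceptually, all-candidate evaluation force-decodes every candidate answer with the seq2seq model and keeps the highest-scoring one, which is why it is so much slower than beam search. The sketch below only illustrates this idea; the helper names (e.g. <code>score_sequence</code>, <code>encode_fn</code>) are hypothetical and do not correspond to functions in this repository.
</p>
<pre>
# conceptual sketch with hypothetical helpers
import torch

def rank_candidates(model, sample, candidate_answers, encode_fn):
    scores = []
    for answer in candidate_answers:                    # thousands of candidates -> slow
        target = encode_fn(answer)                      # tokenized answer sequence
        with torch.no_grad():
            log_probs = model.score_sequence(sample, target)   # hypothetical API
        scores.append(log_probs.sum().item())
    return candidate_answers[int(torch.tensor(scores).argmax())]
</pre>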
|
391 |
+
</details>
|
392 |
+
|
393 |
+
## Referring Expression Comprehension
|
394 |
+
Here we provide the procedures to prepare the data and to train and evaluate your model on visual grounding.
|
395 |
+
<details>
|
396 |
+
<summary><b>1. Prepare the Dataset & Checkpoints</b></summary>
|
397 |
+
<p>
|
398 |
+
Download data (see <a href='datasets.md'>datasets.md</a>) and models (see <a href='checkpoints.md'>checkpoints.md</a>) and put them in the correct directory. We provide RefCOCO (split by UNC), RefCOCO+ (split by UNC) and RefCOCOg (split by UMD) datasets. See <a href='https://www.tensorflow.org/datasets/catalog/ref_coco'>RefCOCO</a> and <a href="https://github.com/lichengunc/refer">Refer</a> for more details. Note that in the original dataset, each region-coord (or bounding box) may correspond to multiple descriptive texts. We split these texts into multiple samples so that the region-coord in each sample corresponds to only one text. Each line of the processed dataset represents a sample with the following format: the uniq-id, image-id, text, region-coord (comma-separated) and image base64 string are separated by tabs (a small parsing/IoU sketch follows the sample line below).
|
399 |
+
</p>
|
400 |
+
<pre>
|
401 |
+
79_1 237367 A woman in a white blouse holding a glass of wine. 230.79,121.75,423.66,463.06 9j/4AAQ...1pAz/9k=
|
402 |
+
</pre>
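<p>
For reference, the sketch below (illustrative only, not part of the evaluation scripts) shows how the comma-separated region-coord maps to a box and how a predicted box is typically judged against it with the usual Acc@0.5 criterion (IoU above 0.5).
</p>
<pre>
# illustrative only
def parse_region(coord_str):
    x0, y0, x1, y1 = map(float, coord_str.split(","))  # e.g. "230.79,121.75,423.66,463.06"
    return x0, y0, x1, y1

def iou(box_a, box_b):
    ax0, ay0, ax1, ay1 = box_a
    bx0, by0, bx1, by1 = box_b
    iw = max(0.0, min(ax1, bx1) - max(ax0, bx0))
    ih = max(0.0, min(ay1, by1) - max(ay0, by0))
    inter = iw * ih
    union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - inter
    return inter / union if union > 0 else 0.0
</pre>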
|
403 |
+
</details>
|
404 |
+
<details>
|
405 |
+
<summary><b>2. Finetuning</b></summary>
|
406 |
+
<p>
|
407 |
+
Unlike the original paper, we finetune OFA with a drop-path rate of 0.2, and find that training with this hyper-parameter achieves better results. We will update the results reported in the paper later.
|
408 |
+
</p>
|
409 |
+
<pre>
|
410 |
+
cd run_scripts/refcoco
|
411 |
+
nohup sh train_refcoco.sh > train_refcoco.out & # finetune for refcoco
|
412 |
+
nohup sh train_refcocoplus.sh > train_refcocoplus.out & # finetune for refcoco+
|
413 |
+
nohup sh train_refcocog.sh > train_refcocog.out & # finetune for refcocog
|
414 |
+
</pre>
|
415 |
+
</details>
|
416 |
+
<details>
|
417 |
+
<summary><b>3. Inference</b></summary>
|
418 |
+
<p>
|
419 |
+
Run the following commands for the evaluation.
|
420 |
+
</p>
|
421 |
+
<pre>
|
422 |
+
cd run_scripts/refcoco ; sh evaluate_refcoco.sh # inference & evaluate for refcoco/refcoco+/refcocog
|
423 |
+
</pre>
|
424 |
+
</details>
|
425 |
+
|
426 |
+
## Visual Entailment
|
427 |
+
We provide steps for you to reproduce our results in visual entailment. See the details below.
|
428 |
+
|
429 |
+
<details>
|
430 |
+
<summary><b>1. Prepare the Dataset & Checkpoints</b></summary>
|
431 |
+
<p>
|
432 |
+
Download data (see <a href="datasets.md">datasets.md</a>) and models (see <a href="checkpoints.md">checkpoints.md</a>) and put them in the correct directory. Each line of the processed dataset represents a sample with the following format: the uniq-id, image-id, image base64 string, hypothesis, caption (or text premise) and label are separated by tabs.
|
433 |
+
</p>
|
434 |
+
<pre>
|
435 |
+
252244149.jpg#1r1n 252244149 /9j/4AAQ...MD/2Q== a man in pink and gold is chewing on a wooden toothpick. a man in pink is chewing a toothpick on the subway. neutral
|
436 |
+
</pre>
|
437 |
+
</details>
|
438 |
+
<details>
|
439 |
+
<summary><b>2. Finetuning</b></summary>
|
440 |
+
<p>
|
441 |
+
In our experiments, the SNLI-VE finetuning is performed on 8 NVIDIA-V100 GPUs with 32GB memory. In this task, we experimented with only a few sets of hyperparameters. We believe that proper hyperparameter tuning can lead to further accuracy improvement.
|
442 |
+
</p>
|
443 |
+
<pre>
|
444 |
+
cd run_scripts/snli_ve
|
445 |
+
nohup sh train_snli_ve.sh > train_snli_ve.out & # finetune for snli_ve
|
446 |
+
</pre>
|
447 |
+
</details>
|
448 |
+
<details>
|
449 |
+
<summary><b>3. Inference</b></summary>
|
450 |
+
<p>
|
451 |
+
Run the following command to obtain the results.
|
452 |
+
</p>
|
453 |
+
<pre>
|
454 |
+
cd run_scripts/snli_ve ; sh evaluate_snli_ve.sh dev # specify 'dev' or 'test'
|
455 |
+
</pre>
|
456 |
+
</details>
|
457 |
+
|
458 |
+
## GLUE
|
459 |
+
Here we provide steps for you to finetune and evaluate our model on language understanding tasks. We demonstrate our practice for the GLUE benchmark.
|
460 |
+
|
461 |
+
<details>
|
462 |
+
<summary><b>1. Prepare the Dataset & Checkpoints</b></summary>
|
463 |
+
<p>
|
464 |
+
Download data (see <a href="datasets.md">datasets.md</a>) and models (see <a href="checkpoints.md">checkpoints.md</a>) and put them in the correct directory. We provide 7 language understanding datasets from the GLUE benchmark, including CoLA, MNLI, MRPC, QNLI, QQP, RTE and SST-2. More details about these datasets can be found in this <a href="https://openreview.net/pdf?id=rJ4km2R5t7">link</a>.
|
465 |
+
</p>
|
466 |
+
</details>
|
467 |
+
<details>
|
468 |
+
<summary><b>2. Finetuning</b></summary>
|
469 |
+
<p>
|
470 |
+
For each task, we have tried multiple sets of hyperparameters (including learning rate, batch size, training epochs). The results under different sets of hyperparameters can be found in <code>${log_dir}</code>.
|
471 |
+
</p>
|
472 |
+
<pre>
|
473 |
+
cd run_scripts/glue
|
474 |
+
nohup sh train_cola.sh > train_cola.out & # finetune for cola
|
475 |
+
nohup sh train_mnli.sh > train_mnli.out & # finetune for mnli
|
476 |
+
nohup sh train_mrpc.sh > train_mrpc.out & # finetune for mrpc
|
477 |
+
nohup sh train_qnli.sh > train_qnli.out & # finetune for qnli
|
478 |
+
nohup sh train_qqp.sh > train_qqp.out & # finetune for qqp
|
479 |
+
nohup sh train_rte.sh > train_rte.out & # finetune for rte
|
480 |
+
nohup sh train_sst2.sh > train_sst2.out & # finetune for sst2
|
481 |
+
</pre>
|
482 |
+
</details>
|
483 |
+
|
484 |
+
## Image Classification on ImageNet-1K
|
485 |
+
We provide the finetuning and inference codes which reproduce **85.0 ImageNet-1K accuracy**, slightly better than reported in our paper.
|
486 |
+
|
487 |
+
<details>
|
488 |
+
<summary><b>1. Prepare the Dataset & Checkpoints</b></summary>
|
489 |
+
<p>
|
490 |
+
Download data (see <a href="datasets.md">datasets.md</a>) and models (see <a href="checkpoints.md">checkpoints.md</a>) and put them in the correct directory. Our provided data is derived from the original <a href="http://image-net.org/">ImageNet-1K</a> (ILSVRC2012 train & validation) dataset and shares the same data split with it. To formulate the classification task in the seq2seq paradigm, we use the <a href="https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt">synset words</a> provided by Caffe as the generation target for each image class. Each line of the processed dataset represents a sample with the following format: the image base64 string, the classification label (1-indexed, conforming to the order in <code>synset_words.txt</code>) and the synset words of the label are separated by tabs (a small mapping sketch follows the sample line below).
|
491 |
+
</p>
|
492 |
+
<pre>
|
493 |
+
_9j_4AAQS...fzX__Z 769 rugby ball
|
494 |
+
</pre>
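<p>
A minimal sketch (not part of the repository) of how the 1-indexed label maps back to its synset words, assuming <code>synset_words.txt</code> contains one "wnid description" entry per line in the same order:
</p>
<pre>
# illustrative only
def label_to_synset(label, synset_path="synset_words.txt"):
    with open(synset_path) as f:
        entries = [line.strip() for line in f]
    wnid, _, words = entries[label - 1].partition(" ")  # the label is 1-indexed
    return words                                        # e.g. 769 -> "rugby ball"
</pre>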
|
495 |
+
</details>
|
496 |
+
<details>
|
497 |
+
<summary><b>2. Shuffle the Training Data</b></summary>
|
498 |
+
<p>
|
499 |
+
(Optional, but achieves better finetuning accuracy): If the disk storage is sufficient, we recommend preparing the shuffled training data for each epoch in advance. In our experiments, shuffling brings around a <b>+0.2</b> improvement in ImageNet-1K accuracy.
|
500 |
+
</p>
|
501 |
+
<pre>
|
502 |
+
cd dataset/imagenet_1k_data
|
503 |
+
ln imagenet_1k_train.tsv imagenet_1k_train_1.tsv
|
504 |
+
for idx in `seq 1 9`;do shuf imagenet_1k_train_${idx}.tsv > imagenet_1k_train_$((idx+1)).tsv;done # each file is used for an epoch, one by one
|
505 |
+
</pre>
|
506 |
+
</details>
|
507 |
+
<details>
|
508 |
+
<summary><b>3. Finetuning</b></summary>
|
509 |
+
<p>
|
510 |
+
In our experiments, the ImageNet-1K finetuning is performed on 2 8-A100-GPU servers (<i>with RDMA</i>). Here we provide the finetuning script <code>train_imagenet_distributed.sh</code>, which supports multi-server distributed training (as well as single-server training). Please refer to the comments at the beginning of the script and set the configs correctly according to your distributed environment. If you have shuffled the training data in the previous step, please specify the training data path correctly following the guide in the script comments. <b>The command should be run on each worker.</b> For quick evaluation during finetuning, by default we sample 20% of the original validation split and report accuracy on this subset after each epoch. The accuracy on the validation subset is generally within ±0.1 of the accuracy on the whole validation split.
|
511 |
+
</p>
|
512 |
+
<pre>
|
513 |
+
# run on each worker after the distributed and data configs have been correctly set following the guide in train_imagenet_distributed.sh
|
514 |
+
cd run_scripts/image_classify
|
515 |
+
bash train_imagenet_distributed.sh
|
516 |
+
</pre>
|
517 |
+
<p>
|
518 |
+
In our experiments, the finetuning costs around 80 hours (for 32 epochs). The best accuracy on the validation subset during finetuning will be around 85.0. The log is saved in <code>${log_dir}</code>.
|
519 |
+
</p>
|
520 |
+
</details>
|
521 |
+
<details>
|
522 |
+
<summary><b>4. Inference</b></summary>
|
523 |
+
<p>
|
524 |
+
To get the validation accuracy on the whole ImageNet-1K validation set, run the following command. The evaluation costs around 10 GPU hours. The accuracy will be reported in the stdout (expected to be around <b>85.0</b>).
|
525 |
+
</p>
|
526 |
+
<pre>
|
527 |
+
cd run_scripts/image_classify ; sh evaluate_imagenet.sh # inference & evaluate for imagenet-1k
|
528 |
+
</pre>
|
529 |
+
</details>
|
530 |
+
|
531 |
+
## Gigaword
|
532 |
+
We provide steps for you to reproduce our results in Gigaword. See the details below.
|
533 |
+
|
534 |
+
<details>
|
535 |
+
<summary><b>1. Prepare the Dataset & Checkpoints</b></summary>
|
536 |
+
<p>
|
537 |
+
Download data (see <a href="datasets.md">datasets.md</a>) and models (see <a href="checkpoints.md">checkpoints.md</a>) and put them in the correct directory. The original dataset is taken from <a href="https://github.com/microsoft/unilm/">UniLM</a> and we organized the data into the tsv format. Each line of the processed dataset represents a sample with the following format: the source and target texts are separated by a tab.
|
538 |
+
</p>
|
539 |
+
<pre>
|
540 |
+
factory orders for manufactured goods rose #.# percent in september... us september factory orders up #.# percent
|
541 |
+
</pre>
|
542 |
+
</details>
|
543 |
+
<details>
|
544 |
+
<summary><b>2. Finetuning</b></summary>
|
545 |
+
<p>
|
546 |
+
Run the following command to train the model.
|
547 |
+
</p>
|
548 |
+
<pre>
|
549 |
+
cd run_scripts/gigaword
|
550 |
+
nohup sh train_gigaword.sh > train_gigaword.out & # finetune for gigaword
|
551 |
+
</pre>
|
552 |
+
</details>
|
553 |
+
<details>
|
554 |
+
<summary><b>3. Inference</b></summary>
|
555 |
+
<p>
|
556 |
+
Run the following command to obtain the results (~36.43 ROUGE-L).
|
557 |
+
</p>
|
558 |
+
<pre>
|
559 |
+
cd run_scripts/gigaword ; sh evaluate_gigaword.sh # inference & evaluate for gigaword
|
560 |
+
</pre>
|
561 |
+
</details>
|
562 |
+
|
563 |
+
<br></br>
|
564 |
+
|
565 |
+
# Gallery
|
566 |
+
Below we provide examples of OFA in text-to-image generation and open-ended VQA. We also demonstrate its performance on an unseen task (Grounded QA) as well as in an unseen domain (Visual Grounding on images from unseen domains).
|
567 |
+
|
568 |
+
## Text-to-Image Generation
|
569 |
+
|
570 |
+
![case1](examples/case1.png)
|
571 |
+
|
572 |
+
|
573 |
+
## Open-Ended VQA
|
574 |
+
![open_vqa](examples/open_vqa.png)
|
575 |
+
|
576 |
+
## Grounded QA (unseen task)
|
577 |
+
![grounded_qa](examples/grounded_qa.png)
|
578 |
+
|
579 |
+
## Visual Grounding (unseen domain)
|
580 |
+
![vg](examples/viusal_grounding.png)
|
581 |
+
<br></br>
|
582 |
+
|
583 |
+
# Related Codebase
|
584 |
+
* [Fairseq](https://github.com/pytorch/fairseq)
|
585 |
+
* [taming-transformers](https://github.com/CompVis/taming-transformers)
|
586 |
+
<br></br>
|
587 |
+
|
588 |
+
|
589 |
+
# Getting Involved
|
590 |
+
Feel free to submit Github issues or pull requests. Welcome to contribute to our project!
|
591 |
+
|
592 |
+
To contact us, never hesitate to send an email to `zheluo.wp@alibaba-inc.com` or `junyang.ljy@alibaba-inc.com`!
|
593 |
+
<br></br>
|
594 |
+
|
595 |
+
|
596 |
+
# Citation
|
597 |
+
Please cite our paper if you find it helpful :)
|
598 |
+
|
599 |
+
```
|
600 |
+
@article{wang2022ofa,
|
601 |
+
author = {Peng Wang and
|
602 |
+
An Yang and
|
603 |
+
Rui Men and
|
604 |
+
Junyang Lin and
|
605 |
+
Shuai Bai and
|
606 |
+
Zhikang Li and
|
607 |
+
Jianxin Ma and
|
608 |
+
Chang Zhou and
|
609 |
+
Jingren Zhou and
|
610 |
+
Hongxia Yang},
|
611 |
+
title = {OFA: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence
|
612 |
+
Learning Framework},
|
613 |
+
journal = {CoRR},
|
614 |
+
volume = {abs/2202.03052},
|
615 |
+
year = {2022}
|
616 |
+
}
|
617 |
+
```
|
618 |
+
<br></br>
|
README_EncouragingLoss.md
ADDED
@@ -0,0 +1,34 @@
1 |
+
# Finetuning with Encouraging Loss (EL)
|
2 |
+
Below we provide methods for finetuning with label smoothed encouraging loss proposed in [_Well-classified Examples are Underestimated in Classification with Deep Neural Networks_](https://arxiv.org/pdf/2110.06537.pdf) on different downstream tasks.
|
3 |
+
The implementation is in [label_smoothed_encouraging_loss.py](criterions/label_smoothed_encouraging_loss.py).
|
4 |
+
You can set the `--criterion` to `adjust_label_smoothed_encouraging_loss` to use it. This criterion has a hyper-parameter `--log-end`.
|
5 |
+
`--log-end < 1` results in an approximated and conservative version of the full encouraging loss.
|
6 |
+
A higher log_end more strongly counteracts the vanishing gradient, strengthens the modeling of the data, and increases the growth rate of the margin, but it also brings a larger gradient norm, which poses challenges to the existing optimization setup.
|
7 |
+
We recommend a higher log_end for settings where performance is already high, and 0.75 or 0.5 as your first try. A toy sketch of the loss is given below.
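For intuition, here is a standalone toy sketch of the (unsmoothed) encouraging loss with the log-end cap, written as we understand it from the paper; it is illustrative only, and the actual criterion used in this repo (with label smoothing and further details) is [label_smoothed_encouraging_loss.py](criterions/label_smoothed_encouraging_loss.py).
<pre>
# toy sketch, not the repository's criterion
import math
import torch

def encouraging_loss(log_probs, target, log_end=0.75, eps=1e-5):
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)  # -log p
    p = torch.exp(-nll)
    bonus = torch.log(torch.clamp(1.0 - p, min=eps))                         # log(1 - p)
    if log_end < 1.0:
        # continue log(1 - p) linearly (its tangent at p = log_end) to keep it bounded
        tangent = math.log(1.0 - log_end) - (p - log_end) / (1.0 - log_end)
        bonus = torch.where(p > log_end, tangent, bonus)
    return (nll + bonus).mean()  # cross-entropy plus the (negative) encouraging bonus
</pre>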
|
8 |
+
## Image Captioning
|
9 |
+
We provide the procedure for image captioning with EL below. The preprocessing is identical to the default setting.
|
10 |
+
|
11 |
+
<details>
|
12 |
+
<summary><b>Finetuning</b></summary>
|
13 |
+
<p>
|
14 |
+
We provide two scripts for stage 1.
|
15 |
+
</p>
|
16 |
+
<pre>
|
17 |
+
cd run_scripts/caption
|
18 |
+
nohup sh train_caption_stage1_el.sh > train_stage1_el.out & # stage 1, train with encouraging loss, expected cider 1.403
|
19 |
+
nohup sh train_caption_stage1_el_db.sh > train_stage1_el_db.out &  # stage 1, train with encouraging loss and drop best examples, expected cider 1.404
|
20 |
+
</pre>
|
21 |
+
</details>
|
22 |
+
|
23 |
+
## Referring Expression Comprehension
|
24 |
+
We provide the procedure for referring expression comprehension with EL below. The preprocessing is identical to the default setting.
|
25 |
+
<details>
|
26 |
+
<summary><b>Finetuning</b></summary>
|
27 |
+
<pre>
|
28 |
+
cd run_scripts/refcoco
|
29 |
+
nohup sh train_refcoco_el.sh > train_refcoco_el.out & # finetune for refcoco
|
30 |
+
nohup sh train_refcocoplus_el.sh > train_refcocoplus_el.out & # finetune for refcoco+
|
31 |
+
nohup sh train_refcocog_el.sh > train_refcocog_el.out & # finetune for refcocog
|
32 |
+
</pre>
|
33 |
+
</details>
|
34 |
+
Evaluation is also the same as the default setting.
|
VG.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
VQA.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Video_Captioning.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
__pycache__/trainer.cpython-37.pyc
ADDED
Binary file (35.9 kB). View file
|
|
__pycache__/trainer.cpython-38.pyc
ADDED
Binary file (36.4 kB). View file
|
|
__pycache__/trainer.cpython-39.pyc
ADDED
Binary file (36.9 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,297 @@
1 |
+
import os
|
2 |
+
|
3 |
+
os.system('cd fairseq;'
|
4 |
+
'pip install ./; cd ..')
|
5 |
+
os.system('ls -l')
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import gradio as gr
|
10 |
+
import cv2
|
11 |
+
from PIL import Image
|
12 |
+
from torchvision import transforms
|
13 |
+
|
14 |
+
from fairseq import utils, tasks, options
|
15 |
+
from fairseq import checkpoint_utils
|
16 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
17 |
+
|
18 |
+
from tasks.mm_tasks.caption import CaptionTask
|
19 |
+
from tasks.mm_tasks.refcoco import RefcocoTask
|
20 |
+
from tasks.mm_tasks.vqa_gen import VqaGenTask
|
21 |
+
|
22 |
+
|
23 |
+
def move2gpu(models, cfg):
|
24 |
+
for model in models:
|
25 |
+
model.eval()
|
26 |
+
if use_fp16:
|
27 |
+
model.half()
|
28 |
+
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
29 |
+
model.cuda()
|
30 |
+
model.prepare_for_inference_(cfg)
|
31 |
+
|
32 |
+
|
33 |
+
def construct_transform(patch_image_size):
|
34 |
+
mean = [0.5, 0.5, 0.5]
|
35 |
+
std = [0.5, 0.5, 0.5]
|
36 |
+
|
37 |
+
patch_resize_transform = transforms.Compose([
|
38 |
+
lambda image: image.convert("RGB"),
|
39 |
+
transforms.Resize((patch_image_size, patch_image_size), interpolation=Image.BICUBIC),
|
40 |
+
transforms.ToTensor(),
|
41 |
+
transforms.Normalize(mean=mean, std=std),
|
42 |
+
])
|
43 |
+
|
44 |
+
return patch_resize_transform
|
45 |
+
|
46 |
+
|
47 |
+
# Register tasks
|
48 |
+
tasks.register_task('caption', CaptionTask)
|
49 |
+
tasks.register_task('refcoco', RefcocoTask)
|
50 |
+
tasks.register_task('vqa_gen', VqaGenTask)
|
51 |
+
# turn on cuda if GPU is available
|
52 |
+
use_cuda = torch.cuda.is_available()
|
53 |
+
# use fp16 only when GPU is available
|
54 |
+
use_fp16 = False
|
55 |
+
|
56 |
+
# # download checkpoints
|
57 |
+
# os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/caption_demo.pt; '
|
58 |
+
# 'mkdir -p checkpoints; mv caption_demo.pt checkpoints/caption_demo.pt')
|
59 |
+
# os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/refcoco_demo.pt; '
|
60 |
+
# 'mkdir -p checkpoints; mv refcoco_demo.pt checkpoints/refcoco_demo.pt')
|
61 |
+
# os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/general_demo.pt; '
|
62 |
+
# 'mkdir -p checkpoints; mv general_demo.pt checkpoints/general_demo.pt')
|
63 |
+
|
64 |
+
|
65 |
+
checkpoint_path = 'checkpoints/unival_s2_hs/checkpoint1.pt'
|
66 |
+
|
67 |
+
# Load ckpt & config for Image Captioning
|
68 |
+
caption_overrides={"eval_cider":False, "beam":5, "max_len_b":22, "no_repeat_ngram_size":3, "seed":7, "unnormalized": False,
|
69 |
+
"bpe_dir":"utils/BPE", "video_model_path": None,}
|
70 |
+
|
71 |
+
caption_models, caption_cfg, caption_task = checkpoint_utils.load_model_ensemble_and_task(
|
72 |
+
utils.split_paths(checkpoint_path),
|
73 |
+
arg_overrides=caption_overrides
|
74 |
+
)
|
75 |
+
|
76 |
+
# Load ckpt & config for Refcoco
|
77 |
+
refcoco_overrides = {"bpe_dir":"utils/BPE", "video_model_path": None}
|
78 |
+
|
79 |
+
refcoco_models, refcoco_cfg, refcoco_task = checkpoint_utils.load_model_ensemble_and_task(
|
80 |
+
utils.split_paths(checkpoint_path),
|
81 |
+
arg_overrides=refcoco_overrides
|
82 |
+
)
|
83 |
+
refcoco_cfg.common.seed = 7
|
84 |
+
refcoco_cfg.generation.beam = 5
|
85 |
+
refcoco_cfg.generation.min_len = 4
|
86 |
+
refcoco_cfg.generation.max_len_a = 0
|
87 |
+
refcoco_cfg.generation.max_len_b = 4
|
88 |
+
refcoco_cfg.generation.no_repeat_ngram_size = 3
|
89 |
+
|
90 |
+
# Load pretrained ckpt & config for VQA
|
91 |
+
parser = options.get_generation_parser()
|
92 |
+
input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", f"--path={checkpoint_path}", "--bpe-dir=utils/BPE"]
|
93 |
+
args = options.parse_args_and_arch(parser, input_args)
|
94 |
+
vqa_cfg = convert_namespace_to_omegaconf(args)
|
95 |
+
vqa_task = tasks.setup_task(vqa_cfg.task)
|
96 |
+
vqa_models, vqa_cfg = checkpoint_utils.load_model_ensemble(
|
97 |
+
utils.split_paths(vqa_cfg.common_eval.path),
|
98 |
+
task=vqa_task
|
99 |
+
)
|
100 |
+
|
101 |
+
# Load pretrained ckpt & config for Generic Interface
|
102 |
+
parser = options.get_generation_parser()
|
103 |
+
input_args = ["", "--task=refcoco", "--beam=10", f"--path={checkpoint_path}", "--bpe-dir=utils/BPE", "--no-repeat-ngram-size=3", "--patch-image-size=384"]
|
104 |
+
args = options.parse_args_and_arch(parser, input_args)
|
105 |
+
general_cfg = convert_namespace_to_omegaconf(args)
|
106 |
+
general_task = tasks.setup_task(general_cfg.task)
|
107 |
+
general_models, general_cfg = checkpoint_utils.load_model_ensemble(
|
108 |
+
utils.split_paths(general_cfg.common_eval.path),
|
109 |
+
task=general_task
|
110 |
+
)
|
111 |
+
|
112 |
+
# move models to gpu
|
113 |
+
move2gpu(caption_models, caption_cfg)
|
114 |
+
move2gpu(refcoco_models, refcoco_cfg)
|
115 |
+
move2gpu(vqa_models, vqa_cfg)
|
116 |
+
move2gpu(general_models, general_cfg)
|
117 |
+
|
118 |
+
# Initialize generator
|
119 |
+
caption_generator = caption_task.build_generator(caption_models, caption_cfg.generation)
|
120 |
+
refcoco_generator = refcoco_task.build_generator(refcoco_models, refcoco_cfg.generation)
|
121 |
+
vqa_generator = vqa_task.build_generator(vqa_models, vqa_cfg.generation)
|
122 |
+
vqa_generator.zero_shot = True
|
123 |
+
vqa_generator.constraint_trie = None
|
124 |
+
general_generator = general_task.build_generator(general_models, general_cfg.generation)
|
125 |
+
|
126 |
+
# Construct image transforms
|
127 |
+
caption_transform = construct_transform(caption_cfg.task.patch_image_size)
|
128 |
+
refcoco_transform = construct_transform(refcoco_cfg.task.patch_image_size)
|
129 |
+
vqa_transform = construct_transform(vqa_cfg.task.patch_image_size)
|
130 |
+
general_transform = construct_transform(general_cfg.task.patch_image_size)
|
131 |
+
|
132 |
+
# Text preprocess
|
133 |
+
bos_item = torch.LongTensor([caption_task.src_dict.bos()])
|
134 |
+
eos_item = torch.LongTensor([caption_task.src_dict.eos()])
|
135 |
+
pad_idx = caption_task.src_dict.pad()
|
136 |
+
|
137 |
+
|
138 |
+
def get_symbols_to_strip_from_output(generator):
|
139 |
+
if hasattr(generator, "symbols_to_strip_from_output"):
|
140 |
+
return generator.symbols_to_strip_from_output
|
141 |
+
else:
|
142 |
+
return {generator.bos, generator.eos}
|
143 |
+
|
144 |
+
|
145 |
+
def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
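    # Split the generated token string into plain text, location "<bin_*>" tokens and
    # image "<code_*>" tokens; the plain-text pieces are BPE-decoded and re-joined into words.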
|
146 |
+
x = tgt_dict.string(x.int().cpu(), extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator))
|
147 |
+
token_result = []
|
148 |
+
bin_result = []
|
149 |
+
img_result = []
|
150 |
+
for token in x.strip().split():
|
151 |
+
if token.startswith('<bin_'):
|
152 |
+
bin_result.append(token)
|
153 |
+
elif token.startswith('<code_'):
|
154 |
+
img_result.append(token)
|
155 |
+
else:
|
156 |
+
if bpe is not None:
|
157 |
+
token = bpe.decode('{}'.format(token))
|
158 |
+
if tokenizer is not None:
|
159 |
+
token = tokenizer.decode(token)
|
160 |
+
if token.startswith(' ') or len(token_result) == 0:
|
161 |
+
token_result.append(token.strip())
|
162 |
+
else:
|
163 |
+
token_result[-1] += token
|
164 |
+
|
165 |
+
return ' '.join(token_result), ' '.join(bin_result), ' '.join(img_result)
|
166 |
+
|
167 |
+
|
168 |
+
def bin2coord(bins, w_resize_ratio, h_resize_ratio, cfg):
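    # Each "<bin_k>" token indexes one of cfg.task.num_bins positions on a grid over
    # [0, max_image_size]; dividing by the resize ratios maps the predicted box back to
    # pixel coordinates (x0, y0, x1, y1) in the original image.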
|
169 |
+
bin_list = [int(bin[5:-1]) for bin in bins.strip().split()]
|
170 |
+
coord_list = []
|
171 |
+
coord_list += [bin_list[0] / (cfg.task.num_bins - 1) * cfg.task.max_image_size / w_resize_ratio]
|
172 |
+
coord_list += [bin_list[1] / (cfg.task.num_bins - 1) * cfg.task.max_image_size / h_resize_ratio]
|
173 |
+
coord_list += [bin_list[2] / (cfg.task.num_bins - 1) * cfg.task.max_image_size / w_resize_ratio]
|
174 |
+
coord_list += [bin_list[3] / (cfg.task.num_bins - 1) * cfg.task.max_image_size / h_resize_ratio]
|
175 |
+
return coord_list
|
176 |
+
|
177 |
+
|
178 |
+
def encode_text(text, length=None, append_bos=False, append_eos=False):
|
179 |
+
line = [
|
180 |
+
caption_task.bpe.encode(' {}'.format(word.strip()))
|
181 |
+
if not word.startswith('<code_') and not word.startswith('<bin_') else word
|
182 |
+
for word in text.strip().split()
|
183 |
+
]
|
184 |
+
line = ' '.join(line)
|
185 |
+
s = caption_task.tgt_dict.encode_line(
|
186 |
+
line=line,
|
187 |
+
add_if_not_exist=False,
|
188 |
+
append_eos=False
|
189 |
+
).long()
|
190 |
+
if length is not None:
|
191 |
+
s = s[:length]
|
192 |
+
if append_bos:
|
193 |
+
s = torch.cat([bos_item, s])
|
194 |
+
if append_eos:
|
195 |
+
s = torch.cat([s, eos_item])
|
196 |
+
return s
|
197 |
+
|
198 |
+
|
199 |
+
def construct_sample(image: Image, instruction: str, transform):
|
200 |
+
patch_image = transform(image).unsqueeze(0)
|
201 |
+
patch_mask = torch.tensor([True])
|
202 |
+
|
203 |
+
instruction = encode_text(' {}'.format(instruction.lower().strip()), append_bos=True, append_eos=True).unsqueeze(0)
|
204 |
+
instruction_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in instruction])
|
205 |
+
sample = {
|
206 |
+
"id": np.array(['42']),
|
207 |
+
"net_input": {
|
208 |
+
"src_tokens": instruction,
|
209 |
+
"src_lengths": instruction_length,
|
210 |
+
"patch_images": patch_image,
|
211 |
+
"patch_masks": patch_mask,
|
212 |
+
}
|
213 |
+
}
|
214 |
+
return sample
|
215 |
+
|
216 |
+
|
217 |
+
# Function to turn FP32 to FP16
|
218 |
+
def apply_half(t):
|
219 |
+
if t.dtype is torch.float32:
|
220 |
+
return t.to(dtype=torch.half)
|
221 |
+
return t
|
222 |
+
|
223 |
+
|
224 |
+
def inference(image, task_type, instruction):
|
225 |
+
if task_type == 'Image Captioning':
|
226 |
+
task = caption_task
|
227 |
+
models = caption_models
|
228 |
+
generator = caption_generator
|
229 |
+
instruction = 'what does the image describe?'
|
230 |
+
transform = caption_transform
|
231 |
+
cfg = caption_cfg
|
232 |
+
elif task_type == 'Visual Question Answering':
|
233 |
+
task = vqa_task
|
234 |
+
models = vqa_models
|
235 |
+
generator = vqa_generator
|
236 |
+
transform = vqa_transform
|
237 |
+
cfg = vqa_cfg
|
238 |
+
elif task_type == 'Visual Grounding':
|
239 |
+
task = refcoco_task
|
240 |
+
models = refcoco_models
|
241 |
+
generator = refcoco_generator
|
242 |
+
instruction = 'which region does the text " {} " describe?'.format(instruction)
|
243 |
+
transform = refcoco_transform
|
244 |
+
cfg = refcoco_cfg
|
245 |
+
elif task_type == 'General':
|
246 |
+
task = general_task
|
247 |
+
models = general_models
|
248 |
+
generator = general_generator
|
249 |
+
transform = general_transform
|
250 |
+
cfg = general_cfg
|
251 |
+
else:
|
252 |
+
raise NotImplementedError
|
253 |
+
|
254 |
+
# Construct input sample & preprocess for GPU if cuda available
|
255 |
+
sample = construct_sample(image, instruction, transform)
|
256 |
+
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
257 |
+
sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
|
258 |
+
|
259 |
+
# Generate result
|
260 |
+
with torch.no_grad():
|
261 |
+
hypos = task.inference_step(generator, models, sample)
|
262 |
+
tokens, bins, imgs = decode_fn(hypos[0][0]["tokens"], task.tgt_dict, task.bpe, generator)
|
263 |
+
|
264 |
+
if bins.strip() != '':
|
265 |
+
w, h = image.size
|
266 |
+
w_resize_ratio = task.cfg.patch_image_size / w
|
267 |
+
h_resize_ratio = task.cfg.patch_image_size / h
|
268 |
+
img = np.asarray(image)
|
269 |
+
coord_list = bin2coord(bins, w_resize_ratio, h_resize_ratio, cfg)
|
270 |
+
cv2.rectangle(
|
271 |
+
img,
|
272 |
+
(int(coord_list[0]), int(coord_list[1])),
|
273 |
+
(int(coord_list[2]), int(coord_list[3])),
|
274 |
+
(0, 255, 0),
|
275 |
+
3
|
276 |
+
)
|
277 |
+
return img, None
|
278 |
+
else:
|
279 |
+
return None, tokens
|
280 |
+
|
281 |
+
inputs = [gr.inputs.Image(type='pil'), gr.inputs.Radio(choices=['Image Captioning',"Visual Question Answering", "Visual Grounding", "General"], type="value", default="Image Captioning", label="Task"), gr.inputs.Textbox(lines=1, label="Instruction")]
|
282 |
+
outputs = [gr.outputs.Image(type='pil'), 'text']
|
283 |
+
examples = [
|
284 |
+
['examples/pokemons.jpeg', 'Image Captioning', None],
|
285 |
+
['examples/cats.jpeg', 'Visual Question Answering', 'where are the cats?'],
|
286 |
+
['examples/one_piece.jpeg', 'Visual Grounding', 'a man in a straw hat and a red dress'],
|
287 |
+
['examples/three_houses.jpeg', 'General', 'which region does the text " a grey car " describe?'],
|
288 |
+
['examples/three_houses.jpeg', 'General', 'what color is the left car?']
|
289 |
+
]
|
290 |
+
|
291 |
+
title = "OFA"
|
292 |
+
description = "Gradio Demo for OFA: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework"
|
293 |
+
article = "<p style='text-align: center'><a href='http://arxiv.org/abs/2202.03052' target='_blank'>Paper</a> | <a href='https://github.com/OFA-Sys/OFA' target='_blank'>Github Repo</a></p>"
|
294 |
+
|
295 |
+
io = gr.Interface(fn=inference, inputs=inputs, outputs=outputs,
|
296 |
+
title=title, description=description, article=article, examples=examples, cache_examples=False)
|
297 |
+
io.launch()
|
checkpoints.md
ADDED
@@ -0,0 +1,36 @@
1 |
+
# Checkpoints
|
2 |
+
|
3 |
+
We provide links for you to download our checkpoints, including pretrained and finetuned models on different tasks. If you would like to use OFA with Transformers, please download checkpoints at [https://huggingface.co/OFA-Sys](https://huggingface.co/OFA-Sys), and check the code in the branch `feature/add_transformers`.
|
4 |
+
|
5 |
+
## Pretraining
|
6 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_huge.pt"> Pre-trained checkpoint (OFA-Huge) </a> (~930M parameters)
|
7 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a> (~470M parameters)
|
8 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_base.pt"> Pre-trained checkpoint (OFA-Base) </a> (~180M parameters)
|
9 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_medium.pt"> Pre-trained checkpoint (OFA-Medium) </a> (~93M parameters)
|
10 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_tiny.pt"> Pre-trained checkpoint (OFA-Tiny) </a> (~33M parameters)
|
11 |
+
|
12 |
+
## Finetuning (OFA-Huge)
|
13 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_huge_best.pt"> Finetuned checkpoint for Caption on COCO </a>
|
14 |
+
|
15 |
+
## Finetuning (OFA-Large)
|
16 |
+
|
17 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
|
18 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_stage1_best.pt"> Finetuned checkpoint for Caption on COCO During Stage1 Finetuning </a>
|
19 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_large_best.pt"> Finetuned checkpoint for RefCOCO </a>
|
20 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_large_best.pt"> Finetuned checkpoint for RefCOCO+ </a>
|
21 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_large_best.pt"> Finetuned checkpoint for RefCOCOg </a>
|
22 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_large_best.pt"> Finetuned checkpoint for VQAv2 </a>
|
23 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_large_best.pt"> Finetuned checkpoint for SNLI-VE </a>
|
24 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_large_best.zip"> Finetuned checkpoint for Text-to-Image Generation on COCO && CLIP checkpoint && VQGAN checkpoint </a>
|
25 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/imagenet_1k_large_best.pt"> Finetuned checkpoint for ImageNet-1K </a>
|
26 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/gigaword_large_best.pt"> Finetuned checkpoint for Gigaword </a>
|
27 |
+
|
28 |
+
|
29 |
+
## Finetuning (OFA-Base)
|
30 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_base_best.pt"> Finetuned base checkpoint for Caption on COCO </a>
|
31 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_base_best.pt"> Finetuned base checkpoint for RefCOCO </a>
|
32 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_base_best.pt"> Finetuned base checkpoint for RefCOCO+ </a>
|
33 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_base_best.pt"> Finetuned base checkpoint for RefCOCOg </a>
|
34 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/vqa_base_best.pt"> Finetuned base checkpoint for VQAv2 </a>
|
35 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/snli_ve_base_best.pt"> Finetuned base checkpoint for SNLI-VE </a>
|
36 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/image_gen_base_best.pt"> Finetuned base checkpoint for Text-to-Image Generation on COCO </a>
|
checkpoints/unival_s2_hs/checkpoint1.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b062bb0fa7c45266ee36326391e355724cccaee3119a9d3ee55d93488838a33
|
3 |
+
size 2570641445
|
checkpoints_cn.md
ADDED
@@ -0,0 +1,82 @@
1 |
+
# Checkpoints (OFA-CN)
|
2 |
+
|
3 |
+
We provide checkpoints of OFA-CN, which is the Chinese version of OFA. We provide Base-size and Large-size models, including pretrained and finetuned models on image captioning and referring expression comprehension. Note that we translated the texts in the RefCOCO(-/+/g) datasets and finetuned OFA-CN on them. We plan to release the related new datasets in the near future.
|
4 |
+
<br>
|
5 |
+
|
6 |
+
## Checkpoints
|
7 |
+
Below we provide the links for downloading the Chinese OFA checkpoints.
|
8 |
+
|
9 |
+
### Pretraining
|
10 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_large.pt"> Pretrained checkpoint (OFA-CN-Large) </a> (~443M parameters)
|
11 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_base.pt "> Pretrained checkpoint (OFA-CN-Base) </a> (~160M parameters)
|
12 |
+
|
13 |
+
### Finetuning (OFA-Large)
|
14 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_large.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
|
15 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_large.pt"> Finetuned checkpoint for RefCOCO-CN </a>
|
16 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_large.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
|
17 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_large.pt"> Finetuned checkpoint for RefCOCOg-CN </a>
|
18 |
+
|
19 |
+
### Finetuning (OFA-Base)
|
20 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_base.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
|
21 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_base.pt"> Finetuned checkpoint for RefCOCO-CN </a>
|
22 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_base.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
|
23 |
+
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_base.pt"> Finetuned checkpoint for RefCOCOg-CN </a>
|
24 |
+
<br>
|
25 |
+
|
26 |
+
## Model Card
|
27 |
+
Below we provide the basic information of the base-size and large-size OFA-CN.
|
28 |
+
|
29 |
+
<table border="1" width="100%">
|
30 |
+
<tr align="center">
|
31 |
+
<th>Model</th><th>#Params</th><th>Backbone</th><th>Hidden Size</th><th>Intermediate Size</th><th>#Heads</th><th>#Enc. Layers</th><th>#Dec. Layers</th>
|
32 |
+
</tr>
|
33 |
+
<tr align="center">
|
34 |
+
<td>OFA<sub>Base</sub></td><td>160M</td><td>ResNet101</td><td>768</td><td>3072</td><td>12</td><td>6</td><td>6</td>
|
35 |
+
</tr>
|
36 |
+
<tr align="center">
|
37 |
+
<td>OFA<sub>Large</sub></td><td>443M</td><td>ResNet152</td><td>1024</td><td>4096</td><td>16</td><td>12</td><td>12</td>
|
38 |
+
</tr>
|
39 |
+
|
40 |
+
</table>
|
41 |
+
<br>
|
42 |
+
|
43 |
+
## Results
|
44 |
+
Below we provide the results of OFA-CN and the baselines for comparison.
|
45 |
+
|
46 |
+
### [MUGE Caption](https://tianchi.aliyun.com/muge)
|
47 |
+
<table border="1" width="100%">
|
48 |
+
<tr align="center">
|
49 |
+
<td>Model</td><td>BLEU@4</td><td>ROUGE-L</td><td>CIDEr-D</td>
|
50 |
+
</tr>
|
51 |
+
<tr align="center">
|
52 |
+
<td>Trm </td><td>7.33</td><td>51.51</td><td>11.00</td>
|
53 |
+
</tr>
|
54 |
+
<tr align="center">
|
55 |
+
<td>M6</td><td>16.19</td><td>55.06</td><td>30.75</td>
|
56 |
+
</tr>
|
57 |
+
<tr align="center">
|
58 |
+
<td>OFA<sub>Base</sub></td><td>26.23</td><td>58.95</td><td>50.70</td>
|
59 |
+
</tr>
|
60 |
+
<tr align="center">
|
61 |
+
<td>OFA<sub>Large</sub></td><td><b>27.32</b></td><td><b>59.20</b></td><td><b>53.51</b></td>
|
62 |
+
</tr>
|
63 |
+
</table>
|
64 |
+
|
65 |
+
### RefCOCO-CN Series
|
66 |
+
<table border="1" width="100%">
|
67 |
+
<tr align="center">
|
68 |
+
<td>Model</td><td>RefCOCO(val/testA/testB)</td><td>RefCOCO+(val/testA/testB)</td><td>RefCOCOg(val/test-u)</td>
|
69 |
+
</tr>
|
70 |
+
<tr align="center">
|
71 |
+
<td>OFA<sub>Base</sub>(random-init)</td><td>30.13/35.07/25.03</td><td>17.89/20.90/15.83</td><td>20.30/20.45</td>
|
72 |
+
</tr>
|
73 |
+
<tr align="center">
|
74 |
+
<td>OFA<sub>Base</sub></td><td>82.18/86.07/<b>76.68</b></td><td>69.38/77.26/60.14</td><td><b>73.57/72.53</b></td>
|
75 |
+
</tr>
|
76 |
+
<tr align="center">
|
77 |
+
<td>OFA<sub>Large</sub></td><td><b>82.84/86.54</b>/76.50</td><td><b>71.30/78.56/61.85</b></td><td>71.96/71.30</td>
|
78 |
+
</tr>
|
79 |
+
</table>
|
80 |
+
<br>
|
81 |
+
|
82 |
+
|
colab.md
ADDED
@@ -0,0 +1,9 @@
1 |
+
# Colab Notebooks
|
2 |
+
|
3 |
+
We provide Colab notebooks for different downstream tasks so that you can try out OFA. See below.
|
4 |
+
|
5 |
+
* [Image Captioning in Huggingface Transformers](https://colab.research.google.com/drive/1Ho81RBV8jysZ7e0FhsSCk_v938QeDuy3?usp=sharing)
|
6 |
+
* [Generic Interface](https://colab.research.google.com/drive/1jogyZ-2rdHU3XxZOf3TBfhex1XHqX-1m?usp=sharing#scrollTo=s9Vni6YUZOpC) (using different instructions to perform various tasks with just one model.)
|
7 |
+
* [Image Captioning](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
|
8 |
+
* [Referring Expression Comprehension](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing)
|
9 |
+
* [Open-Domain Visual Question Answering](https://colab.research.google.com/drive/14v6OQe_MxV_HMnsiKfnEeMR1UMqhzZNb?usp=sharing)
|
criterions/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
from .label_smoothed_cross_entropy import AdjustLabelSmoothedCrossEntropyCriterion
|
2 |
+
from .clip_scst_loss import ClipScstRewardCriterion
|
3 |
+
from .label_smoothed_encouraging_loss import AdjustLabelSmoothedEncouragingLossCriterion
|
4 |
+
from .label_smoothed_cross_entropy_scst import AdjustLabelSmoothedCrossEntropySCSTCriterion
|
5 |
+
from .refcoco_scst_loss import RefCOCOScstRewardCriterion
|
criterions/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (457 Bytes). View file
|
|
criterions/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (438 Bytes). View file
|
|
criterions/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (634 Bytes). View file
|
|
criterions/__pycache__/clip_scst_loss.cpython-37.pyc
ADDED
Binary file (9.59 kB). View file
|
|
criterions/__pycache__/clip_scst_loss.cpython-38.pyc
ADDED
Binary file (9.74 kB). View file
|
|
criterions/__pycache__/clip_scst_loss.cpython-39.pyc
ADDED
Binary file (9.73 kB). View file
|
|
criterions/__pycache__/label_smoothed_cross_entropy.cpython-37.pyc
ADDED
Binary file (10.7 kB). View file
|
|
criterions/__pycache__/label_smoothed_cross_entropy.cpython-38.pyc
ADDED
Binary file (10.8 kB). View file
|
|
criterions/__pycache__/label_smoothed_cross_entropy.cpython-39.pyc
ADDED
Binary file (10.7 kB). View file
|
|
criterions/__pycache__/label_smoothed_cross_entropy_scst.cpython-39.pyc
ADDED
Binary file (15.3 kB). View file
|
|
criterions/__pycache__/label_smoothed_encouraging_loss.cpython-37.pyc
ADDED
Binary file (11.7 kB). View file
|
|
criterions/__pycache__/label_smoothed_encouraging_loss.cpython-38.pyc
ADDED
Binary file (11.8 kB). View file
|
|
criterions/__pycache__/label_smoothed_encouraging_loss.cpython-39.pyc
ADDED
Binary file (11.8 kB). View file
|
|
criterions/__pycache__/refcoco_scst_loss.cpython-39.pyc
ADDED
Binary file (13.7 kB). View file
|
|
criterions/__pycache__/scst_loss.cpython-37.pyc
ADDED
Binary file (9.93 kB). View file
|
|
criterions/__pycache__/scst_loss.cpython-38.pyc
ADDED
Binary file (10.1 kB). View file
|
|
criterions/__pycache__/scst_loss.cpython-39.pyc
ADDED
Binary file (11 kB). View file
|
|
criterions/clip_scst_loss.py
ADDED
@@ -0,0 +1,277 @@
1 |
+
# Copyright 2022 The OFA-Sys Team.
|
2 |
+
# All rights reserved.
|
3 |
+
# This source code is licensed under the Apache 2.0 license
|
4 |
+
# found in the LICENSE file in the root directory.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Optional
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision import transforms
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
from fairseq import metrics
|
15 |
+
from fairseq.data import data_utils
|
16 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
17 |
+
from fairseq.dataclass import FairseqDataclass
|
18 |
+
from fairseq import utils
|
19 |
+
from omegaconf import II
|
20 |
+
|
21 |
+
from models import clip
|
22 |
+
|
23 |
+
|
24 |
+
def custom_to_pil(x):
|
25 |
+
x = x.detach().cpu()
|
26 |
+
x = torch.clamp(x, -1., 1.)
|
27 |
+
x = (x + 1.) / 2.
|
28 |
+
x = x.permute(1, 2, 0).numpy()
|
29 |
+
x = (255 * x).astype(np.uint8)
|
30 |
+
x = Image.fromarray(x)
|
31 |
+
if not x.mode == "RGB":
|
32 |
+
x = x.convert("RGB")
|
33 |
+
return x
|
34 |
+
|
35 |
+
|
36 |
+
def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
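    # SCST-style policy-gradient loss: the negative log-probability of every sampled
    # token is weighted by its sequence-level reward, and padding positions are masked out.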
|
37 |
+
loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)
|
38 |
+
if ignore_index is not None:
|
39 |
+
pad_mask = target.eq(ignore_index)
|
40 |
+
loss.masked_fill_(pad_mask, 0.0)
|
41 |
+
ntokens = (~pad_mask).sum()
|
42 |
+
else:
|
43 |
+
loss = loss.squeeze(-1)
|
44 |
+
ntokens = target.numel()
|
45 |
+
if reduce:
|
46 |
+
loss = loss.sum()
|
47 |
+
return loss, ntokens
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass
|
51 |
+
class ClipScstRewardCriterionConfig(FairseqDataclass):
|
52 |
+
ignore_prefix_size: int = field(
|
53 |
+
default=0,
|
54 |
+
metadata={"help": "Ignore first N tokens"},
|
55 |
+
)
|
56 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
57 |
+
constraint_range: Optional[str] = field(
|
58 |
+
default=None,
|
59 |
+
metadata={"help": "constraint range"}
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
@register_criterion(
|
64 |
+
"clip_scst_reward_criterion", dataclass=ClipScstRewardCriterionConfig
|
65 |
+
)
|
66 |
+
class ClipScstRewardCriterion(FairseqCriterion):
|
67 |
+
CLIP_REWARD_WEIGHT = 2.5
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
task,
|
72 |
+
sentence_avg,
|
73 |
+
ignore_prefix_size=0,
|
74 |
+
constraint_range=None
|
75 |
+
):
|
76 |
+
super().__init__(task)
|
77 |
+
self.sentence_avg = sentence_avg
|
78 |
+
self.ignore_prefix_size = ignore_prefix_size
|
79 |
+
|
80 |
+
self.constraint_start = None
|
81 |
+
self.constraint_end = None
|
82 |
+
if constraint_range is not None:
|
83 |
+
constraint_start, constraint_end = constraint_range.split(',')
|
84 |
+
self.constraint_start = int(constraint_start)
|
85 |
+
self.constraint_end = int(constraint_end)
|
86 |
+
|
87 |
+
def forward(self, model, sample, update_num=0, reduce=True):
|
88 |
+
"""Compute the loss for the given sample.
|
89 |
+
|
90 |
+
Returns a tuple with three elements:
|
91 |
+
1) the loss
|
92 |
+
2) the sample size, which is used as the denominator for the gradient
|
93 |
+
3) logging outputs to display while training
|
94 |
+
"""
|
95 |
+
loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)
|
96 |
+
|
97 |
+
sample_size = (
|
98 |
+
nsentences if self.sentence_avg else ntokens
|
99 |
+
)
|
100 |
+
logging_output = {
|
101 |
+
"loss": loss.data,
|
102 |
+
"score": score,
|
103 |
+
"ntokens": ntokens,
|
104 |
+
"nsentences": nsentences,
|
105 |
+
"sample_size": sample_size,
|
106 |
+
}
|
107 |
+
return loss, sample_size, logging_output
|
108 |
+
|
109 |
+
def _calculate_clip_scores(self, gen_res, gt_text, device):
|
110 |
+
'''
|
111 |
+
gen_res: generated images, list of Image
|
112 |
+
gt_text: input captions.
|
113 |
+
device: device for clip model
|
114 |
+
'''
|
115 |
+
batch_size = len(gt_text)
|
116 |
+
gen_res_size = len(gen_res)
|
117 |
+
img_per_seq = gen_res_size // batch_size
|
118 |
+
|
119 |
+
hyp_images = torch.stack(
|
120 |
+
[self.task.clip_preprocess(gen_image) for gen_image in gen_res], dim=0
|
121 |
+
).to(device)
|
122 |
+
|
123 |
+
clip_input = clip.tokenize([text for text in gt_text]).to(device)
|
124 |
+
with torch.no_grad():
|
125 |
+
image_features = self.task.clip_model.encode_image(hyp_images)
|
126 |
+
text_features = self.task.clip_model.encode_text(clip_input)
|
127 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
128 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
129 |
+
image_features = image_features.view(batch_size, img_per_seq, -1)
|
130 |
+
text_features = text_features.view(batch_size, 1, -1)
|
131 |
+
ti_similarity = image_features @ text_features.transpose(1, 2)
|
132 |
+
ti_similarity = ti_similarity.view(-1)
|
133 |
+
|
134 |
+
scores = self.CLIP_REWARD_WEIGHT * ti_similarity
|
135 |
+
return scores
|
136 |
+
|
137 |
+
def get_generator_out(self, model, sample):
|
138 |
+
model.eval()
|
139 |
+
with torch.no_grad():
|
140 |
+
self.task.scst_generator.model.eval()
|
141 |
+
gen_out = self.task.scst_generator.generate([model], sample)
|
142 |
+
|
143 |
+
gen_target = []
|
144 |
+
gen_res = []
|
145 |
+
gt_text = []
|
146 |
+
for i in range(len(gen_out)):
|
147 |
+
with torch.no_grad():
|
148 |
+
tokens = torch.stack([item['tokens'][:-1] for item in gen_out[i]], dim=0)
|
149 |
+
tokens += -len(self.task.src_dict) + self.task.cfg.code_dict_size + self.task.cfg.num_bins
|
150 |
+
images = self.task.image_tokenizer.decode_code(
|
151 |
+
tokens.view(-1, self.task.cfg.code_image_size // 8, self.task.cfg.code_image_size // 8)
|
152 |
+
)
|
153 |
+
images = [custom_to_pil(image) for image in images]
|
154 |
+
|
155 |
+
gen_target += [item['tokens'] for item in gen_out[i]]
|
156 |
+
gen_res += images
|
157 |
+
gt_text.append(
|
158 |
+
self.task.bpe.decode(
|
159 |
+
self.task.tgt_dict.string(
|
160 |
+
utils.strip_pad(sample['net_input']['src_tokens'][i], self.padding_idx).cpu().int()
|
161 |
+
)
|
162 |
+
)[38:] # remove task instruction.
|
163 |
+
)
|
164 |
+
|
165 |
+
return gen_target, gen_res, gt_text
|
166 |
+
|
167 |
+
def get_reward_and_scores(self, gen_res, gt_text, device):
|
168 |
+
batch_size = len(gt_text)
|
169 |
+
gen_res_size = len(gen_res)
|
170 |
+
img_per_sample = gen_res_size // batch_size
|
171 |
+
|
172 |
+
scores = self._calculate_clip_scores(gen_res, gt_text, device)
|
173 |
+
sc_ = scores.reshape(batch_size, img_per_sample)
|
174 |
+
baseline = (sc_.sum(1, keepdim=True) - sc_) / (sc_.shape[1] - 1)
|
175 |
+
# sample - baseline
|
176 |
+
reward = scores.reshape(batch_size, img_per_sample)
|
177 |
+
reward = reward - baseline
|
178 |
+
reward = reward.view(-1)
|
179 |
+
|
180 |
+
return reward, scores
|
181 |
+
|
182 |
+
def get_net_output(self, model, sample, gen_target):
|
183 |
+
def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
|
184 |
+
return data_utils.collate_tokens(
|
185 |
+
sample_list,
|
186 |
+
pad_idx=self.padding_idx,
|
187 |
+
eos_idx=eos,
|
188 |
+
left_pad=False,
|
189 |
+
move_eos_to_beginning=move_eos_to_beginning,
|
190 |
+
)
|
191 |
+
|
192 |
+
batch_size = len(sample["target"])
|
193 |
+
gen_target_size = len(gen_target)
|
194 |
+
img_per_sample = gen_target_size // batch_size
|
195 |
+
|
196 |
+
model.train()
|
197 |
+
sample_src_tokens = torch.repeat_interleave(
|
198 |
+
sample['net_input']['src_tokens'], img_per_sample, dim=0
|
199 |
+
)
|
200 |
+
sample_src_lengths = torch.repeat_interleave(
|
201 |
+
sample['net_input']['src_lengths'], img_per_sample, dim=0
|
202 |
+
)
|
203 |
+
sample_code_masks = torch.repeat_interleave(
|
204 |
+
sample['net_input']['code_masks'], img_per_sample, dim=0
|
205 |
+
)
|
206 |
+
gen_prev_output_tokens = torch.as_tensor(
|
207 |
+
merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
|
208 |
+
device=sample["target"].device, dtype=torch.int64
|
209 |
+
)
|
210 |
+
gen_target_tokens = torch.as_tensor(
|
211 |
+
merge(gen_target), device=sample["target"].device, dtype=torch.int64
|
212 |
+
)
|
213 |
+
net_output = model(
|
214 |
+
src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
|
215 |
+
code_masks=sample_code_masks, prev_output_tokens=gen_prev_output_tokens
|
216 |
+
)
|
217 |
+
|
218 |
+
return net_output, gen_target_tokens
|
219 |
+
|
220 |
+
def get_lprobs_and_target(self, model, net_output, gen_target):
|
221 |
+
if self.constraint_start is not None and self.constraint_end is not None:
|
222 |
+
net_output[0][:, :, 4:self.constraint_start] = -math.inf
|
223 |
+
net_output[0][:, :, self.constraint_end:] = -math.inf
|
224 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
225 |
+
if self.ignore_prefix_size > 0:
|
226 |
+
if getattr(lprobs, "batch_first", False):
|
227 |
+
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
228 |
+
gen_target = gen_target[:, self.ignore_prefix_size :].contiguous()
|
229 |
+
else:
|
230 |
+
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
|
231 |
+
gen_target = gen_target[self.ignore_prefix_size :, :].contiguous()
|
232 |
+
return lprobs, gen_target
|
233 |
+
|
234 |
+
def compute_loss(self, model, sample, reduce=True):
|
235 |
+
gen_target, gen_res, gt_text = self.get_generator_out(model, sample)
|
236 |
+
reward, scores = self.get_reward_and_scores(gen_res, gt_text, device=sample["target"].device)
|
237 |
+
net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)
|
238 |
+
gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
|
239 |
+
loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
|
240 |
+
nsentences = gen_target_tokens.size(0)
|
241 |
+
|
242 |
+
return loss, scores.sum(), ntokens, nsentences
|
243 |
+
|
244 |
+
@classmethod
|
245 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
246 |
+
"""Aggregate logging outputs from data parallel training."""
|
247 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
248 |
+
score_sum = sum(log.get("score", 0) for log in logging_outputs)
|
249 |
+
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
250 |
+
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
251 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
252 |
+
|
253 |
+
metrics.log_scalar(
|
254 |
+
"loss", loss_sum / sample_size, sample_size, round=3
|
255 |
+
)
|
256 |
+
metrics.log_scalar(
|
257 |
+
"score", score_sum / nsentences, nsentences, round=3
|
258 |
+
)
|
259 |
+
|
260 |
+
metrics.log_scalar(
|
261 |
+
"ntokens", ntokens, 1, round=3
|
262 |
+
)
|
263 |
+
metrics.log_scalar(
|
264 |
+
"nsentences", nsentences, 1, round=3
|
265 |
+
)
|
266 |
+
metrics.log_scalar(
|
267 |
+
"sample_size", sample_size, 1, round=3
|
268 |
+
)
|
269 |
+
|
270 |
+
@staticmethod
|
271 |
+
def logging_outputs_can_be_summed() -> bool:
|
272 |
+
"""
|
273 |
+
Whether the logging outputs returned by `forward` can be summed
|
274 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
275 |
+
to True will improves distributed training speed.
|
276 |
+
"""
|
277 |
+
return True
|
criterions/label_smoothed_cross_entropy.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The OFA-Sys Team.
|
2 |
+
# All rights reserved.
|
3 |
+
# This source code is licensed under the Apache 2.0 license
|
4 |
+
# found in the LICENSE file in the root directory.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import numpy as np
|
13 |
+
from fairseq import metrics, utils
|
14 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
15 |
+
from fairseq.dataclass import FairseqDataclass
|
16 |
+
from omegaconf import II
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class AdjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
|
21 |
+
label_smoothing: float = field(
|
22 |
+
default=0.0,
|
23 |
+
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
24 |
+
)
|
25 |
+
report_accuracy: bool = field(
|
26 |
+
default=False,
|
27 |
+
metadata={"help": "report accuracy metric"},
|
28 |
+
)
|
29 |
+
ignore_prefix_size: int = field(
|
30 |
+
default=0,
|
31 |
+
metadata={"help": "Ignore first N tokens"},
|
32 |
+
)
|
33 |
+
ignore_eos: bool = field(
|
34 |
+
default=False,
|
35 |
+
metadata={"help": "Ignore eos token"},
|
36 |
+
)
|
37 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
38 |
+
drop_worst_ratio: float = field(
|
39 |
+
default=0.0,
|
40 |
+
metadata={"help": "ratio for discarding bad samples"},
|
41 |
+
)
|
42 |
+
drop_worst_after: int = field(
|
43 |
+
default=0,
|
44 |
+
metadata={"help": "steps for discarding bad samples"},
|
45 |
+
)
|
46 |
+
use_rdrop: bool = field(
|
47 |
+
default=False, metadata={"help": "use R-Drop"}
|
48 |
+
)
|
49 |
+
reg_alpha: float = field(
|
50 |
+
default=1.0, metadata={"help": "weight for R-Drop"}
|
51 |
+
)
|
52 |
+
sample_patch_num: int = field(
|
53 |
+
default=196, metadata={"help": "sample patches for v1"}
|
54 |
+
)
|
55 |
+
constraint_range: Optional[str] = field(
|
56 |
+
default=None,
|
57 |
+
metadata={"help": "constraint range"}
|
58 |
+
)
|
59 |
+
|
60 |
+
|
61 |
+
def construct_rdrop_sample(x):
|
62 |
+
if isinstance(x, dict):
|
63 |
+
for key in x:
|
64 |
+
x[key] = construct_rdrop_sample(x[key])
|
65 |
+
return x
|
66 |
+
elif isinstance(x, torch.Tensor):
|
67 |
+
return x.repeat(2, *([1] * (x.dim()-1)))
|
68 |
+
elif isinstance(x, int):
|
69 |
+
return x * 2
|
70 |
+
elif isinstance(x, np.ndarray):
|
71 |
+
return x.repeat(2)
|
72 |
+
else:
|
73 |
+
raise NotImplementedError
|
74 |
+
|
75 |
+
|
76 |
+
def kl_loss(p, q):
|
77 |
+
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
|
78 |
+
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
|
79 |
+
loss = (p_loss + q_loss) / 2
|
80 |
+
return loss
|
81 |
+
|
82 |
+
|
83 |
+
def label_smoothed_nll_loss(
|
84 |
+
lprobs, target, epsilon, update_num, reduce=True,
|
85 |
+
drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
|
86 |
+
constraint_masks=None, constraint_start=None, constraint_end=None
|
87 |
+
):
|
88 |
+
if target.dim() == lprobs.dim() - 1:
|
89 |
+
target = target.unsqueeze(-1)
|
90 |
+
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
|
91 |
+
if constraint_masks is not None:
|
92 |
+
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
|
93 |
+
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
|
94 |
+
elif constraint_start is not None and constraint_end is not None:
|
95 |
+
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
96 |
+
smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
|
97 |
+
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
|
98 |
+
else:
|
99 |
+
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
|
100 |
+
eps_i = epsilon / (lprobs.size(-1) - 1)
|
101 |
+
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
|
102 |
+
if drop_worst_ratio > 0 and update_num > drop_worst_after:
|
103 |
+
if use_rdrop:
|
104 |
+
true_batch_size = loss.size(0) // 2
|
105 |
+
_, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
|
106 |
+
loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
|
107 |
+
nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
|
108 |
+
lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
|
109 |
+
else:
|
110 |
+
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
|
111 |
+
nll_loss = nll_loss[indices]
|
112 |
+
lprobs = lprobs[indices]
|
113 |
+
|
114 |
+
ntokens = loss.numel()
|
115 |
+
nll_loss = nll_loss.sum()
|
116 |
+
loss = loss.sum()
|
117 |
+
if use_rdrop:
|
118 |
+
true_batch_size = lprobs.size(0) // 2
|
119 |
+
p = lprobs[:true_batch_size]
|
120 |
+
q = lprobs[true_batch_size:]
|
121 |
+
if constraint_start is not None and constraint_end is not None:
|
122 |
+
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
123 |
+
p = p[:, constraint_range]
|
124 |
+
q = q[:, constraint_range]
|
125 |
+
loss += kl_loss(p, q) * reg_alpha
|
126 |
+
|
127 |
+
return loss, nll_loss, ntokens
|
128 |
+
|
129 |
+
|
130 |
+
@register_criterion(
|
131 |
+
"adjust_label_smoothed_cross_entropy", dataclass=AdjustLabelSmoothedCrossEntropyCriterionConfig
|
132 |
+
)
|
133 |
+
class AdjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
task,
|
137 |
+
sentence_avg,
|
138 |
+
label_smoothing,
|
139 |
+
ignore_prefix_size=0,
|
140 |
+
ignore_eos=False,
|
141 |
+
report_accuracy=False,
|
142 |
+
drop_worst_ratio=0,
|
143 |
+
drop_worst_after=0,
|
144 |
+
use_rdrop=False,
|
145 |
+
reg_alpha=1.0,
|
146 |
+
sample_patch_num=196,
|
147 |
+
constraint_range=None
|
148 |
+
):
|
149 |
+
super().__init__(task)
|
150 |
+
self.sentence_avg = sentence_avg
|
151 |
+
self.eps = label_smoothing
|
152 |
+
self.ignore_prefix_size = ignore_prefix_size
|
153 |
+
self.ignore_eos = ignore_eos
|
154 |
+
self.report_accuracy = report_accuracy
|
155 |
+
self.drop_worst_ratio = drop_worst_ratio
|
156 |
+
self.drop_worst_after = drop_worst_after
|
157 |
+
self.use_rdrop = use_rdrop
|
158 |
+
self.reg_alpha = reg_alpha
|
159 |
+
self.sample_patch_num = sample_patch_num
|
160 |
+
|
161 |
+
self.constraint_start = None
|
162 |
+
self.constraint_end = None
|
163 |
+
if constraint_range is not None:
|
164 |
+
constraint_start, constraint_end = constraint_range.split(',')
|
165 |
+
self.constraint_start = int(constraint_start)
|
166 |
+
self.constraint_end = int(constraint_end)
|
167 |
+
|
168 |
+
def forward(self, model, sample, update_num=0, reduce=True):
|
169 |
+
"""Compute the loss for the given sample.
|
170 |
+
|
171 |
+
Returns a tuple with three elements:
|
172 |
+
1) the loss
|
173 |
+
2) the sample size, which is used as the denominator for the gradient
|
174 |
+
3) logging outputs to display while training
|
175 |
+
"""
|
176 |
+
if isinstance(sample, list):
|
177 |
+
if self.sample_patch_num > 0:
|
178 |
+
sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
|
179 |
+
# change to support len(samples) > 2
|
180 |
+
loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
|
181 |
+
loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
|
182 |
+
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
|
183 |
+
sample_size = 1
|
184 |
+
logging_output = {
|
185 |
+
"loss": loss.data,
|
186 |
+
"loss_v1": loss_v1.data,
|
187 |
+
"loss_v2": loss_v2.data,
|
188 |
+
"nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
|
189 |
+
"ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
|
190 |
+
"nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
|
191 |
+
"sample_size": 1,
|
192 |
+
"sample_size_v1": sample_size_v1,
|
193 |
+
"sample_size_v2": sample_size_v2,
|
194 |
+
}
|
195 |
+
return loss, sample_size, logging_output
|
196 |
+
|
197 |
+
if self.use_rdrop:
|
198 |
+
construct_rdrop_sample(sample)
|
199 |
+
|
200 |
+
net_output = model(**sample["net_input"])
|
201 |
+
loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
|
202 |
+
sample_size = (
|
203 |
+
sample["target"].size(0) if self.sentence_avg else ntokens
|
204 |
+
)
|
205 |
+
logging_output = {
|
206 |
+
"loss": loss.data,
|
207 |
+
"nll_loss": nll_loss.data,
|
208 |
+
"ntokens": sample["ntokens"],
|
209 |
+
"nsentences": sample["nsentences"],
|
210 |
+
"sample_size": sample_size,
|
211 |
+
}
|
212 |
+
if self.report_accuracy:
|
213 |
+
n_correct, total = self.compute_accuracy(model, net_output, sample)
|
214 |
+
logging_output["n_correct"] = utils.item(n_correct.data)
|
215 |
+
logging_output["total"] = utils.item(total.data)
|
216 |
+
|
217 |
+
return loss, sample_size, logging_output
|
218 |
+
|
219 |
+
def get_lprobs_and_target(self, model, net_output, sample):
|
220 |
+
conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
|
221 |
+
constraint_masks = None
|
222 |
+
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
|
223 |
+
constraint_masks = sample["constraint_masks"]
|
224 |
+
net_output[0].masked_fill_(~constraint_masks, -math.inf)
|
225 |
+
if self.constraint_start is not None and self.constraint_end is not None:
|
226 |
+
net_output[0][:, :, 4:self.constraint_start] = -math.inf
|
227 |
+
net_output[0][:, :, self.constraint_end:] = -math.inf
|
228 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
|
229 |
+
target = model.get_targets(sample, net_output)
|
230 |
+
if self.ignore_prefix_size > 0:
|
231 |
+
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
232 |
+
target = target[:, self.ignore_prefix_size :].contiguous()
|
233 |
+
if constraint_masks is not None:
|
234 |
+
constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
|
235 |
+
if self.ignore_eos:
|
236 |
+
bsz, seq_len, embed_dim = lprobs.size()
|
237 |
+
eos_indices = target.eq(self.task.tgt_dict.eos())
|
238 |
+
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
|
239 |
+
target = target[~eos_indices].reshape(bsz, seq_len-1)
|
240 |
+
if constraint_masks is not None:
|
241 |
+
constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
|
242 |
+
if constraint_masks is not None:
|
243 |
+
constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
|
244 |
+
return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
|
245 |
+
|
246 |
+
def compute_loss(self, model, net_output, sample, update_num, reduce=True):
|
247 |
+
lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
|
248 |
+
if constraint_masks is not None:
|
249 |
+
constraint_masks = constraint_masks[target != self.padding_idx]
|
250 |
+
# print(target.shape, self.padding_idx, lprobs.shape, target, lprobs)
|
251 |
+
lprobs = lprobs[target != self.padding_idx]
|
252 |
+
target = target[target != self.padding_idx]
|
253 |
+
loss, nll_loss, ntokens = label_smoothed_nll_loss(
|
254 |
+
lprobs,
|
255 |
+
target,
|
256 |
+
self.eps,
|
257 |
+
update_num,
|
258 |
+
reduce=reduce,
|
259 |
+
drop_worst_ratio=self.drop_worst_ratio,
|
260 |
+
drop_worst_after=self.drop_worst_after,
|
261 |
+
use_rdrop=self.use_rdrop,
|
262 |
+
reg_alpha=self.reg_alpha,
|
263 |
+
constraint_masks=constraint_masks,
|
264 |
+
constraint_start=self.constraint_start,
|
265 |
+
constraint_end=self.constraint_end
|
266 |
+
)
|
267 |
+
return loss, nll_loss, ntokens
|
268 |
+
|
269 |
+
def compute_accuracy(self, model, net_output, sample):
|
270 |
+
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
271 |
+
mask = target.ne(self.padding_idx)
|
272 |
+
n_correct = torch.sum(
|
273 |
+
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
|
274 |
+
)
|
275 |
+
total = torch.sum(mask)
|
276 |
+
return n_correct, total
|
277 |
+
|
278 |
+
@classmethod
|
279 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
280 |
+
"""Aggregate logging outputs from data parallel training."""
|
281 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
282 |
+
loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
|
283 |
+
loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
|
284 |
+
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
|
285 |
+
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
286 |
+
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
287 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
288 |
+
sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
|
289 |
+
sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
|
290 |
+
|
291 |
+
metrics.log_scalar(
|
292 |
+
"loss", loss_sum / sample_size, sample_size, round=3
|
293 |
+
)
|
294 |
+
metrics.log_scalar(
|
295 |
+
"loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
|
296 |
+
)
|
297 |
+
metrics.log_scalar(
|
298 |
+
"loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
|
299 |
+
)
|
300 |
+
metrics.log_scalar(
|
301 |
+
"nll_loss", nll_loss_sum / sample_size, ntokens, round=3
|
302 |
+
)
|
303 |
+
metrics.log_derived(
|
304 |
+
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
|
305 |
+
)
|
306 |
+
|
307 |
+
metrics.log_scalar(
|
308 |
+
"ntokens", ntokens, 1, round=3
|
309 |
+
)
|
310 |
+
metrics.log_scalar(
|
311 |
+
"nsentences", nsentences, 1, round=3
|
312 |
+
)
|
313 |
+
metrics.log_scalar(
|
314 |
+
"sample_size", sample_size, 1, round=3
|
315 |
+
)
|
316 |
+
metrics.log_scalar(
|
317 |
+
"sample_size_v1", sample_size_v1, 1, round=3
|
318 |
+
)
|
319 |
+
metrics.log_scalar(
|
320 |
+
"sample_size_v2", sample_size_v2, 1, round=3
|
321 |
+
)
|
322 |
+
|
323 |
+
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
|
324 |
+
if total > 0:
|
325 |
+
metrics.log_scalar("total", total)
|
326 |
+
n_correct = utils.item(
|
327 |
+
sum(log.get("n_correct", 0) for log in logging_outputs)
|
328 |
+
)
|
329 |
+
metrics.log_scalar("n_correct", n_correct)
|
330 |
+
metrics.log_derived(
|
331 |
+
"accuracy",
|
332 |
+
lambda meters: round(
|
333 |
+
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
|
334 |
+
)
|
335 |
+
if meters["total"].sum > 0
|
336 |
+
else float("nan"),
|
337 |
+
)
|
338 |
+
|
339 |
+
@staticmethod
|
340 |
+
def logging_outputs_can_be_summed() -> bool:
|
341 |
+
"""
|
342 |
+
Whether the logging outputs returned by `forward` can be summed
|
343 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
344 |
+
to True will improves distributed training speed.
|
345 |
+
"""
|
346 |
+
return True
|
criterions/label_smoothed_cross_entropy_scst.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The OFA-Sys Team.
|
2 |
+
# All rights reserved.
|
3 |
+
# This source code is licensed under the Apache 2.0 license
|
4 |
+
# found in the LICENSE file in the root directory.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import numpy as np
|
13 |
+
from fairseq import metrics, utils
|
14 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
15 |
+
from fairseq.dataclass import FairseqDataclass
|
16 |
+
from omegaconf import II
|
17 |
+
|
18 |
+
|
19 |
+
from mapcalc import calculate_map, calculate_map_range
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class AdjustLabelSmoothedCrossEntropySCSTCriterionConfig(FairseqDataclass):
|
23 |
+
label_smoothing: float = field(
|
24 |
+
default=0.0,
|
25 |
+
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
26 |
+
)
|
27 |
+
report_accuracy: bool = field(
|
28 |
+
default=False,
|
29 |
+
metadata={"help": "report accuracy metric"},
|
30 |
+
)
|
31 |
+
ignore_prefix_size: int = field(
|
32 |
+
default=0,
|
33 |
+
metadata={"help": "Ignore first N tokens"},
|
34 |
+
)
|
35 |
+
ignore_eos: bool = field(
|
36 |
+
default=False,
|
37 |
+
metadata={"help": "Ignore eos token"},
|
38 |
+
)
|
39 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
40 |
+
drop_worst_ratio: float = field(
|
41 |
+
default=0.0,
|
42 |
+
metadata={"help": "ratio for discarding bad samples"},
|
43 |
+
)
|
44 |
+
drop_worst_after: int = field(
|
45 |
+
default=0,
|
46 |
+
metadata={"help": "steps for discarding bad samples"},
|
47 |
+
)
|
48 |
+
use_rdrop: bool = field(
|
49 |
+
default=False, metadata={"help": "use R-Drop"}
|
50 |
+
)
|
51 |
+
reg_alpha: float = field(
|
52 |
+
default=1.0, metadata={"help": "weight for R-Drop"}
|
53 |
+
)
|
54 |
+
sample_patch_num: int = field(
|
55 |
+
default=196, metadata={"help": "sample patches for v1"}
|
56 |
+
)
|
57 |
+
constraint_range: Optional[str] = field(
|
58 |
+
default=None,
|
59 |
+
metadata={"help": "constraint range"}
|
60 |
+
)
|
61 |
+
acc_thresh: Optional[float] = field(
|
62 |
+
default=None, metadata={"help": "acc thresh for refcoco"}
|
63 |
+
)
|
64 |
+
metric: Optional[str] = field(
|
65 |
+
default='acc',
|
66 |
+
metadata={"help": "metric"}
|
67 |
+
)
|
68 |
+
|
69 |
+
max_area_size: Optional[float] = field(
|
70 |
+
default=None, metadata={"help": "max_area_size"}
|
71 |
+
)
|
72 |
+
|
73 |
+
min_area_size: Optional[float] = field(
|
74 |
+
default=None, metadata={"help": "min_area_size"}
|
75 |
+
)
|
76 |
+
logprob: Optional[bool] = field(
|
77 |
+
default=False, metadata={"help": "maximise log prob"}
|
78 |
+
)
|
79 |
+
|
80 |
+
pos_reward: Optional[float] = field(
|
81 |
+
default=None, metadata={"help": "pos_reward"}
|
82 |
+
)
|
83 |
+
|
84 |
+
neg_reward: Optional[float] = field(
|
85 |
+
default=None, metadata={"help": "neg_reward"}
|
86 |
+
)
|
87 |
+
|
88 |
+
reinforce: Optional[bool] = field(
|
89 |
+
default=False, metadata={"help": "reinforce"}
|
90 |
+
)
|
91 |
+
|
92 |
+
lambda_reinforce: Optional[float] = field(
|
93 |
+
default=0, metadata={"help": "lambda_reinforce"}
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
def construct_rdrop_sample(x):
|
99 |
+
if isinstance(x, dict):
|
100 |
+
for key in x:
|
101 |
+
x[key] = construct_rdrop_sample(x[key])
|
102 |
+
return x
|
103 |
+
elif isinstance(x, torch.Tensor):
|
104 |
+
return x.repeat(2, *([1] * (x.dim()-1)))
|
105 |
+
elif isinstance(x, int):
|
106 |
+
return x * 2
|
107 |
+
elif isinstance(x, np.ndarray):
|
108 |
+
return x.repeat(2)
|
109 |
+
else:
|
110 |
+
raise NotImplementedError
|
111 |
+
|
112 |
+
|
113 |
+
def kl_loss(p, q):
|
114 |
+
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
|
115 |
+
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
|
116 |
+
loss = (p_loss + q_loss) / 2
|
117 |
+
return loss
|
118 |
+
|
119 |
+
|
120 |
+
def label_smoothed_nll_loss(
|
121 |
+
lprobs, target, epsilon, update_num, reduce=True,
|
122 |
+
drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
|
123 |
+
constraint_masks=None, constraint_start=None, constraint_end=None
|
124 |
+
):
|
125 |
+
if target.dim() == lprobs.dim() - 1:
|
126 |
+
target = target.unsqueeze(-1)
|
127 |
+
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
|
128 |
+
if constraint_masks is not None:
|
129 |
+
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
|
130 |
+
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
|
131 |
+
elif constraint_start is not None and constraint_end is not None:
|
132 |
+
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
133 |
+
smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
|
134 |
+
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
|
135 |
+
else:
|
136 |
+
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
|
137 |
+
eps_i = epsilon / (lprobs.size(-1) - 1)
|
138 |
+
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
|
139 |
+
if drop_worst_ratio > 0 and update_num > drop_worst_after:
|
140 |
+
if use_rdrop:
|
141 |
+
true_batch_size = loss.size(0) // 2
|
142 |
+
_, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
|
143 |
+
loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
|
144 |
+
nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
|
145 |
+
lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
|
146 |
+
else:
|
147 |
+
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
|
148 |
+
nll_loss = nll_loss[indices]
|
149 |
+
lprobs = lprobs[indices]
|
150 |
+
|
151 |
+
ntokens = loss.numel()
|
152 |
+
nll_loss = nll_loss.sum()
|
153 |
+
# loss = loss.sum()
|
154 |
+
if use_rdrop:
|
155 |
+
true_batch_size = lprobs.size(0) // 2
|
156 |
+
p = lprobs[:true_batch_size]
|
157 |
+
q = lprobs[true_batch_size:]
|
158 |
+
if constraint_start is not None and constraint_end is not None:
|
159 |
+
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
160 |
+
p = p[:, constraint_range]
|
161 |
+
q = q[:, constraint_range]
|
162 |
+
loss = loss + ((kl_loss(p, q) * reg_alpha)/loss.shape[0])
|
163 |
+
|
164 |
+
return loss, nll_loss, ntokens
|
165 |
+
|
166 |
+
|
167 |
+
@register_criterion(
|
168 |
+
"adjust_label_smoothed_cross_entropy_scst", dataclass=AdjustLabelSmoothedCrossEntropySCSTCriterionConfig
|
169 |
+
)
|
170 |
+
class AdjustLabelSmoothedCrossEntropySCSTCriterion(FairseqCriterion):
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
task,
|
174 |
+
sentence_avg,
|
175 |
+
label_smoothing,
|
176 |
+
ignore_prefix_size=0,
|
177 |
+
ignore_eos=False,
|
178 |
+
report_accuracy=False,
|
179 |
+
drop_worst_ratio=0,
|
180 |
+
drop_worst_after=0,
|
181 |
+
use_rdrop=False,
|
182 |
+
reg_alpha=1.0,
|
183 |
+
sample_patch_num=196,
|
184 |
+
constraint_range=None,
|
185 |
+
acc_thresh=None,
|
186 |
+
metric='acc',
|
187 |
+
max_area_size=None,
|
188 |
+
min_area_size=None,
|
189 |
+
logprob=False,
|
190 |
+
pos_reward=None,
|
191 |
+
neg_reward=None,
|
192 |
+
reinforce=False,
|
193 |
+
lambda_reinforce=0,
|
194 |
+
):
|
195 |
+
super().__init__(task)
|
196 |
+
self.sentence_avg = sentence_avg
|
197 |
+
self.eps = label_smoothing
|
198 |
+
self.ignore_prefix_size = ignore_prefix_size
|
199 |
+
self.ignore_eos = ignore_eos
|
200 |
+
self.report_accuracy = report_accuracy
|
201 |
+
self.drop_worst_ratio = drop_worst_ratio
|
202 |
+
self.drop_worst_after = drop_worst_after
|
203 |
+
self.use_rdrop = use_rdrop
|
204 |
+
self.reg_alpha = reg_alpha
|
205 |
+
self.sample_patch_num = sample_patch_num
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
self.constraint_start = None
|
210 |
+
self.constraint_end = None
|
211 |
+
if constraint_range is not None:
|
212 |
+
constraint_start, constraint_end = constraint_range.split(',')
|
213 |
+
self.constraint_start = int(constraint_start)
|
214 |
+
self.constraint_end = int(constraint_end)
|
215 |
+
|
216 |
+
self.acc_thresh = acc_thresh
|
217 |
+
self.metric = metric
|
218 |
+
self.min_area_size = min_area_size
|
219 |
+
self.max_area_size = max_area_size
|
220 |
+
self.logprob = logprob
|
221 |
+
|
222 |
+
self.pos_reward = pos_reward
|
223 |
+
self.neg_reward = neg_reward
|
224 |
+
|
225 |
+
self.reinforce = reinforce
|
226 |
+
self.lambda_reinforce = lambda_reinforce
|
227 |
+
|
228 |
+
def get_generator_out(self, model, sample):
|
229 |
+
|
230 |
+
model.eval()
|
231 |
+
with torch.no_grad():
|
232 |
+
self.task.scst_generator.model.eval()
|
233 |
+
gen_out = self.task.scst_generator.generate([model], sample)
|
234 |
+
|
235 |
+
hyps, refs = [], []
|
236 |
+
for i in range(len(gen_out)):
|
237 |
+
hyps.append(gen_out[i][0]["tokens"][:-1] - len(self.task.src_dict) + self.task.cfg.num_bins)
|
238 |
+
refs.append(sample["target"][i][:-1] - len(self.task.src_dict) + self.task.cfg.num_bins)
|
239 |
+
|
240 |
+
return torch.stack(hyps, dim=0), torch.stack(refs, dim=0)
|
241 |
+
|
242 |
+
def _calculate_map_score(self, hyps, refs, thresh=0.5):
|
243 |
+
|
244 |
+
|
245 |
+
ground_truth = {
|
246 |
+
'boxes': refs.cpu().numpy().tolist(),
|
247 |
+
|
248 |
+
'labels': [1 for i in range(refs.shape[0])]
|
249 |
+
}
|
250 |
+
|
251 |
+
result_dict = {
|
252 |
+
'boxes': hyps.cpu().numpy().tolist(),
|
253 |
+
|
254 |
+
'labels': [1 for i in range(hyps.shape[0])],
|
255 |
+
}
|
256 |
+
|
257 |
+
score = calculate_map(ground_truth, result_dict, thresh)
|
258 |
+
|
259 |
+
score = torch.tensor(score).unsqueeze(0).repeat(refs.shape[0]).to(hyps.device)
|
260 |
+
return score
|
261 |
+
|
262 |
+
def _calculate_ap_score(self, hyps, refs, thresh=0.5, min_area_size=None, max_area_size=None):
|
263 |
+
interacts = torch.cat(
|
264 |
+
[torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
|
265 |
+
torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])],
|
266 |
+
dim=1
|
267 |
+
)
|
268 |
+
area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1]) ## x1, y1, x2, y2, x1 < x2
|
269 |
+
area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
|
270 |
+
interacts_w = interacts[:, 2] - interacts[:, 0]
|
271 |
+
interacts_h = interacts[:, 3] - interacts[:, 1]
|
272 |
+
area_interacts = interacts_w * interacts_h
|
273 |
+
ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)
|
274 |
+
|
275 |
+
|
276 |
+
if max_area_size is not None and min_area_size is not None:
|
277 |
+
ious = ious * (torch.logical_or(area_targets < max_area_size, area_targets > min_area_size).float())
|
278 |
+
|
279 |
+
elif min_area_size is not None:
|
280 |
+
ious = ious * (area_targets > min_area_size).float()
|
281 |
+
|
282 |
+
elif max_area_size is not None:
|
283 |
+
ious = ious * (area_targets < max_area_size).float()
|
284 |
+
|
285 |
+
if thresh is None:
|
286 |
+
return ious
|
287 |
+
else:
|
288 |
+
return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()
|
289 |
+
|
290 |
+
def reward_step(self, sample, model):
|
291 |
+
|
292 |
+
hyps, refs = self.get_generator_out(model, sample)
|
293 |
+
hyps = hyps / (self.task.cfg.num_bins - 1) * self.task.cfg.max_image_size
|
294 |
+
refs = refs / (self.task.cfg.num_bins - 1) * self.task.cfg.max_image_size
|
295 |
+
hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
|
296 |
+
hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
|
297 |
+
refs[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
|
298 |
+
refs[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
|
299 |
+
|
300 |
+
# scores = self._calculate_ap_score(hyps, refs)
|
301 |
+
if self.metric == 'acc':
|
302 |
+
scores = self._calculate_ap_score(hyps, sample['region_coords'].float(), thresh=self.acc_thresh,
|
303 |
+
min_area_size=self.min_area_size, max_area_size=self.max_area_size)
|
304 |
+
elif self.metric == 'map':
|
305 |
+
scores = self._calculate_map_score(hyps, sample['region_coords'].float(), thresh=self.acc_thresh)
|
306 |
+
else:
|
307 |
+
raise NotImplemented
|
308 |
+
|
309 |
+
# logging_output["_score_sum"] = scores.sum().item()
|
310 |
+
# logging_output["_score_cnt"] = scores.size(0)
|
311 |
+
|
312 |
+
if self.pos_reward:
|
313 |
+
scores = torch.where(scores > 0, self.pos_reward, scores)
|
314 |
+
if self.neg_reward:
|
315 |
+
scores = torch.where(scores == 0, self.neg_reward, scores)
|
316 |
+
|
317 |
+
|
318 |
+
return scores
|
319 |
+
|
320 |
+
def forward(self, model, sample, update_num=0, reduce=True):
|
321 |
+
"""Compute the loss for the given sample.
|
322 |
+
|
323 |
+
Returns a tuple with three elements:
|
324 |
+
1) the loss
|
325 |
+
2) the sample size, which is used as the denominator for the gradient
|
326 |
+
3) logging outputs to display while training
|
327 |
+
"""
|
328 |
+
if isinstance(sample, list):
|
329 |
+
if self.sample_patch_num > 0:
|
330 |
+
sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
|
331 |
+
# change to support len(samples) > 2
|
332 |
+
loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
|
333 |
+
loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
|
334 |
+
loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
|
335 |
+
sample_size = 1
|
336 |
+
logging_output = {
|
337 |
+
"loss": loss.data,
|
338 |
+
"loss_v1": loss_v1.data,
|
339 |
+
"loss_v2": loss_v2.data,
|
340 |
+
"nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
|
341 |
+
"ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
|
342 |
+
"nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
|
343 |
+
"sample_size": 1,
|
344 |
+
"sample_size_v1": sample_size_v1,
|
345 |
+
"sample_size_v2": sample_size_v2,
|
346 |
+
"reward": (logging_output_v1["reward"] + logging_output_v2["reward"])/2,
|
347 |
+
}
|
348 |
+
return loss, sample_size, logging_output
|
349 |
+
|
350 |
+
if self.use_rdrop:
|
351 |
+
construct_rdrop_sample(sample)
|
352 |
+
|
353 |
+
### scst
|
354 |
+
reward = self.reward_step(sample, model) # shape = bs
|
355 |
+
model.train()
|
356 |
+
net_output = model(**sample["net_input"])
|
357 |
+
loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce, reward=reward)
|
358 |
+
|
359 |
+
|
360 |
+
|
361 |
+
|
362 |
+
# loss = loss*reward
|
363 |
+
|
364 |
+
loss = loss.sum()
|
365 |
+
sample_size = (
|
366 |
+
sample["target"].size(0) if self.sentence_avg else ntokens
|
367 |
+
)
|
368 |
+
logging_output = {
|
369 |
+
"loss": loss.data,
|
370 |
+
"nll_loss": nll_loss.data,
|
371 |
+
"ntokens": sample["ntokens"],
|
372 |
+
"nsentences": sample["nsentences"],
|
373 |
+
"sample_size": sample_size,
|
374 |
+
"reward": reward.mean(),
|
375 |
+
}
|
376 |
+
if self.report_accuracy:
|
377 |
+
n_correct, total = self.compute_accuracy(model, net_output, sample)
|
378 |
+
logging_output["n_correct"] = utils.item(n_correct.data)
|
379 |
+
logging_output["total"] = utils.item(total.data)
|
380 |
+
|
381 |
+
return loss, sample_size, logging_output
|
382 |
+
|
383 |
+
def get_lprobs_and_target(self, model, net_output, sample, reward=None):
|
384 |
+
conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
|
385 |
+
constraint_masks = None
|
386 |
+
if "constraint_masks" in sample and sample["constraint_masks"] is not None:
|
387 |
+
constraint_masks = sample["constraint_masks"]
|
388 |
+
net_output[0].masked_fill_(~constraint_masks, -math.inf)
|
389 |
+
if self.constraint_start is not None and self.constraint_end is not None:
|
390 |
+
net_output[0][:, :, 4:self.constraint_start] = -math.inf
|
391 |
+
net_output[0][:, :, self.constraint_end:] = -math.inf
|
392 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
|
393 |
+
target = model.get_targets(sample, net_output)
|
394 |
+
if self.ignore_prefix_size > 0:
|
395 |
+
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
396 |
+
target = target[:, self.ignore_prefix_size :].contiguous()
|
397 |
+
if constraint_masks is not None:
|
398 |
+
constraint_masks = constraint_masks[:, self.ignore_prefix_size :, :].contiguous()
|
399 |
+
if self.ignore_eos:
|
400 |
+
bsz, seq_len, embed_dim = lprobs.size()
|
401 |
+
eos_indices = target.eq(self.task.tgt_dict.eos())
|
402 |
+
lprobs = lprobs[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
|
403 |
+
target = target[~eos_indices].reshape(bsz, seq_len-1)
|
404 |
+
if constraint_masks is not None:
|
405 |
+
constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len-1, embed_dim)
|
406 |
+
if constraint_masks is not None:
|
407 |
+
constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
|
408 |
+
|
409 |
+
if reward is not None:
|
410 |
+
reward = reward.unsqueeze(1).unsqueeze(1)
|
411 |
+
lprobs = lprobs*reward
|
412 |
+
return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks
|
413 |
+
|
414 |
+
def compute_loss(self, model, net_output, sample, update_num, reduce=True, reward=None):
|
415 |
+
lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample, reward=reward)
|
416 |
+
|
417 |
+
if constraint_masks is not None:
|
418 |
+
constraint_masks = constraint_masks[target != self.padding_idx]
|
419 |
+
# print(target.shape, self.padding_idx, lprobs.shape, target, lprobs)
|
420 |
+
lprobs = lprobs[target != self.padding_idx]
|
421 |
+
target = target[target != self.padding_idx]
|
422 |
+
|
423 |
+
|
424 |
+
loss, nll_loss, ntokens = label_smoothed_nll_loss(
|
425 |
+
lprobs,
|
426 |
+
target,
|
427 |
+
self.eps,
|
428 |
+
update_num,
|
429 |
+
reduce=reduce,
|
430 |
+
drop_worst_ratio=self.drop_worst_ratio,
|
431 |
+
drop_worst_after=self.drop_worst_after,
|
432 |
+
use_rdrop=self.use_rdrop,
|
433 |
+
reg_alpha=self.reg_alpha,
|
434 |
+
constraint_masks=constraint_masks,
|
435 |
+
constraint_start=self.constraint_start,
|
436 |
+
constraint_end=self.constraint_end
|
437 |
+
)
|
438 |
+
|
439 |
+
if self.logprob and self.reinforce:
|
440 |
+
# print(-lprobs.max(dim=-1)[0].squeeze(-1).sum(), loss)
|
441 |
+
if self.lambda_reinforce > 0:
|
442 |
+
lprobs_, target_, constraint_masks_ = self.get_lprobs_and_target(model, net_output, sample, reward=None)
|
443 |
+
|
444 |
+
loss_, _, ntokens = label_smoothed_nll_loss(
|
445 |
+
lprobs_,
|
446 |
+
target_,
|
447 |
+
self.eps,
|
448 |
+
update_num,
|
449 |
+
reduce=reduce,
|
450 |
+
drop_worst_ratio=self.drop_worst_ratio,
|
451 |
+
drop_worst_after=self.drop_worst_after,
|
452 |
+
use_rdrop=self.use_rdrop,
|
453 |
+
reg_alpha=self.reg_alpha,
|
454 |
+
constraint_masks=constraint_masks_,
|
455 |
+
constraint_start=self.constraint_start,
|
456 |
+
constraint_end=self.constraint_end
|
457 |
+
)
|
458 |
+
# print(-lprobs.max(dim=-1)[0].squeeze(-1).sum(), loss_)
|
459 |
+
# loss = -lprobs.max(dim=-1)[0].squeeze(-1).sum()*self.lambda_reinforce + loss_
|
460 |
+
|
461 |
+
loss = loss*self.lambda_reinforce + loss_ # only supervised with more weights via reward
|
462 |
+
|
463 |
+
else:
|
464 |
+
loss = -lprobs.max(dim=-1)[0].squeeze(-1).sum()
|
465 |
+
|
466 |
+
elif self.logprob:
|
467 |
+
loss = nll_loss
|
468 |
+
|
469 |
+
return loss, nll_loss, ntokens
|
470 |
+
|
471 |
+
def compute_accuracy(self, model, net_output, sample):
|
472 |
+
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
473 |
+
mask = target.ne(self.padding_idx)
|
474 |
+
n_correct = torch.sum(
|
475 |
+
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
|
476 |
+
)
|
477 |
+
total = torch.sum(mask)
|
478 |
+
return n_correct, total
|
479 |
+
|
480 |
+
@classmethod
|
481 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
482 |
+
"""Aggregate logging outputs from data parallel training."""
|
483 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
484 |
+
loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
|
485 |
+
loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
|
486 |
+
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
|
487 |
+
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
488 |
+
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
489 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
490 |
+
sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
|
491 |
+
sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)
|
492 |
+
|
493 |
+
|
494 |
+
reward = sum(log.get("reward", 0) for log in logging_outputs)
|
495 |
+
|
496 |
+
metrics.log_scalar(
|
497 |
+
"loss", loss_sum / sample_size, sample_size, round=3
|
498 |
+
)
|
499 |
+
metrics.log_scalar(
|
500 |
+
"loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
|
501 |
+
)
|
502 |
+
metrics.log_scalar(
|
503 |
+
"loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
|
504 |
+
)
|
505 |
+
metrics.log_scalar(
|
506 |
+
"nll_loss", nll_loss_sum / sample_size, ntokens, round=3
|
507 |
+
)
|
508 |
+
metrics.log_derived(
|
509 |
+
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
|
510 |
+
)
|
511 |
+
|
512 |
+
metrics.log_scalar(
|
513 |
+
"ntokens", ntokens, 1, round=3
|
514 |
+
)
|
515 |
+
metrics.log_scalar(
|
516 |
+
"nsentences", nsentences, 1, round=3
|
517 |
+
)
|
518 |
+
metrics.log_scalar(
|
519 |
+
"sample_size", sample_size, 1, round=3
|
520 |
+
)
|
521 |
+
metrics.log_scalar(
|
522 |
+
"sample_size_v1", sample_size_v1, 1, round=3
|
523 |
+
)
|
524 |
+
metrics.log_scalar(
|
525 |
+
"sample_size_v2", sample_size_v2, 1, round=3
|
526 |
+
)
|
527 |
+
|
528 |
+
metrics.log_scalar(
|
529 |
+
"reward", reward / sample_size, sample_size, round=3
|
530 |
+
)
|
531 |
+
|
532 |
+
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
|
533 |
+
if total > 0:
|
534 |
+
metrics.log_scalar("total", total)
|
535 |
+
n_correct = utils.item(
|
536 |
+
sum(log.get("n_correct", 0) for log in logging_outputs)
|
537 |
+
)
|
538 |
+
metrics.log_scalar("n_correct", n_correct)
|
539 |
+
metrics.log_derived(
|
540 |
+
"accuracy",
|
541 |
+
lambda meters: round(
|
542 |
+
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
|
543 |
+
)
|
544 |
+
if meters["total"].sum > 0
|
545 |
+
else float("nan"),
|
546 |
+
)
|
547 |
+
|
548 |
+
@staticmethod
|
549 |
+
def logging_outputs_can_be_summed() -> bool:
|
550 |
+
"""
|
551 |
+
Whether the logging outputs returned by `forward` can be summed
|
552 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
553 |
+
to True will improves distributed training speed.
|
554 |
+
"""
|
555 |
+
return True
|
criterions/label_smoothed_encouraging_loss.py
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import numpy as np
|
13 |
+
from fairseq import metrics, utils
|
14 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
15 |
+
from fairseq.dataclass import FairseqDataclass
|
16 |
+
from omegaconf import II
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class AdjustLabelSmoothedEncouragingLossConfig(FairseqDataclass):
|
21 |
+
label_smoothing: float = field(
|
22 |
+
default=0.0,
|
23 |
+
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
|
24 |
+
)
|
25 |
+
report_accuracy: bool = field(
|
26 |
+
default=False,
|
27 |
+
metadata={"help": "report accuracy metric"},
|
28 |
+
)
|
29 |
+
ignore_prefix_size: int = field(
|
30 |
+
default=0,
|
31 |
+
metadata={"help": "Ignore first N tokens"},
|
32 |
+
)
|
33 |
+
ignore_eos: bool = field(
|
34 |
+
default=False,
|
35 |
+
metadata={"help": "Ignore eos token"},
|
36 |
+
)
|
37 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
38 |
+
drop_worst_ratio: float = field(
|
39 |
+
default=0.0,
|
40 |
+
metadata={"help": "ratio for discarding bad samples"},
|
41 |
+
)
|
42 |
+
drop_worst_after: int = field(
|
43 |
+
default=0,
|
44 |
+
metadata={"help": "steps for discarding bad samples"},
|
45 |
+
)
|
46 |
+
use_rdrop: bool = field(
|
47 |
+
default=False, metadata={"help": "use R-Drop"}
|
48 |
+
)
|
49 |
+
reg_alpha: float = field(
|
50 |
+
default=1.0, metadata={"help": "weight for R-Drop"}
|
51 |
+
)
|
52 |
+
sample_patch_num: int = field(
|
53 |
+
default=196, metadata={"help": "sample patchs for v1"}
|
54 |
+
)
|
55 |
+
constraint_range: Optional[str] = field(
|
56 |
+
default=None,
|
57 |
+
metadata={"help": "constraint range"}
|
58 |
+
)
|
59 |
+
log_end: float = field(
|
60 |
+
default=0.75,
|
61 |
+
metadata={"help": "higher log_end is for cases with higher performance,"
|
62 |
+
" we recommend 0.75 or 0.5 as your first try."}
|
63 |
+
)
|
64 |
+
drop_best_ratio: float = field(
|
65 |
+
default=0.0,
|
66 |
+
metadata={"help": "ratio for discarding best samples"},
|
67 |
+
)
|
68 |
+
drop_best_after: int = field(
|
69 |
+
default=0,
|
70 |
+
metadata={"help": "steps for discarding best samples"},
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
def construct_rdrop_sample(x):
|
76 |
+
if isinstance(x, dict):
|
77 |
+
for key in x:
|
78 |
+
x[key] = construct_rdrop_sample(x[key])
|
79 |
+
return x
|
80 |
+
elif isinstance(x, torch.Tensor):
|
81 |
+
return x.repeat(2, *([1] * (x.dim()-1)))
|
82 |
+
elif isinstance(x, int):
|
83 |
+
return x * 2
|
84 |
+
elif isinstance(x, np.ndarray):
|
85 |
+
return x.repeat(2)
|
86 |
+
else:
|
87 |
+
raise NotImplementedError
|
88 |
+
|
89 |
+
|
90 |
+
def kl_loss(p, q):
|
91 |
+
p_loss = F.kl_div(p, torch.exp(q), reduction='sum')
|
92 |
+
q_loss = F.kl_div(q, torch.exp(p), reduction='sum')
|
93 |
+
loss = (p_loss + q_loss) / 2
|
94 |
+
return loss
|
95 |
+
|
96 |
+
|
97 |
+
def label_smoothed_nll_loss(
|
98 |
+
lprobs, target, epsilon, update_num, reduce=True,
|
99 |
+
drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
|
100 |
+
constraint_masks=None, constraint_start=None, constraint_end=None, drop_best_ratio=0.0,
|
101 |
+
drop_best_after=0,
|
102 |
+
):
|
103 |
+
if target.dim() == lprobs.dim() - 1:
|
104 |
+
target = target.unsqueeze(-1)
|
105 |
+
nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
|
106 |
+
if constraint_masks is not None:
|
107 |
+
smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1, keepdim=True).squeeze(-1)
|
108 |
+
eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
|
109 |
+
elif constraint_start is not None and constraint_end is not None:
|
110 |
+
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
111 |
+
smooth_loss = -lprobs[:, constraint_range].sum(dim=-1, keepdim=True).squeeze(-1)
|
112 |
+
eps_i = epsilon / (len(constraint_range) - 1 + 1e-6)
|
113 |
+
else:
|
114 |
+
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).squeeze(-1)
|
115 |
+
eps_i = epsilon / (lprobs.size(-1) - 1)
|
116 |
+
loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
|
117 |
+
if drop_worst_ratio > 0 and update_num > drop_worst_after:
|
118 |
+
if use_rdrop:
|
119 |
+
true_batch_size = loss.size(0) // 2
|
120 |
+
_, indices = torch.topk(loss[:true_batch_size], k=int(true_batch_size * (1 - drop_worst_ratio)), largest=False)
|
121 |
+
loss = torch.cat([loss[indices], loss[indices+true_batch_size]])
|
122 |
+
nll_loss = torch.cat([nll_loss[indices], nll_loss[indices+true_batch_size]])
|
123 |
+
lprobs = torch.cat([lprobs[indices], lprobs[indices+true_batch_size]])
|
124 |
+
else:
|
125 |
+
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
|
126 |
+
nll_loss = nll_loss[indices]
|
127 |
+
lprobs = lprobs[indices]
|
128 |
+
target = target[indices]
|
129 |
+
if update_num > drop_best_after:
|
130 |
+
loss, indices = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_best_ratio)), largest=True)
|
131 |
+
nll_loss = nll_loss[indices]
|
132 |
+
lprobs = lprobs[indices]
|
133 |
+
target = target[indices]
|
134 |
+
|
135 |
+
ntokens = loss.numel()
|
136 |
+
nll_loss = nll_loss.sum()
|
137 |
+
loss = loss.sum()
|
138 |
+
if use_rdrop:
|
139 |
+
true_batch_size = lprobs.size(0) // 2
|
140 |
+
p = lprobs[:true_batch_size]
|
141 |
+
q = lprobs[true_batch_size:]
|
142 |
+
if constraint_start is not None and constraint_end is not None:
|
143 |
+
constraint_range = [0, 1, 2, 3] + list(range(constraint_start, constraint_end))
|
144 |
+
p = p[:, constraint_range]
|
145 |
+
q = q[:, constraint_range]
|
146 |
+
loss += kl_loss(p, q) * reg_alpha
|
147 |
+
|
148 |
+
return loss, nll_loss, ntokens,lprobs,target
|
149 |
+
|
150 |
+
|
151 |
+
@register_criterion(
|
152 |
+
"adjust_label_smoothed_encouraging_loss", dataclass=AdjustLabelSmoothedEncouragingLossConfig
|
153 |
+
)
|
154 |
+
class AdjustLabelSmoothedEncouragingLossCriterion(FairseqCriterion):
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
task,
|
158 |
+
sentence_avg,
|
159 |
+
label_smoothing,
|
160 |
+
ignore_prefix_size=0,
|
161 |
+
ignore_eos=False,
|
162 |
+
report_accuracy=False,
|
163 |
+
drop_worst_ratio=0,
|
164 |
+
drop_worst_after=0,
|
165 |
+
use_rdrop=False,
|
166 |
+
reg_alpha=1.0,
|
167 |
+
sample_patch_num=196,
|
168 |
+
constraint_range=None,
|
169 |
+
log_end=0.75,
|
170 |
+
drop_best_ratio=0.0,
|
171 |
+
drop_best_after=0,
|
172 |
+
):
|
173 |
+
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.ignore_prefix_size = ignore_prefix_size
        self.ignore_eos = ignore_eos
        self.report_accuracy = report_accuracy
        self.drop_worst_ratio = drop_worst_ratio
        self.drop_worst_after = drop_worst_after
        self.use_rdrop = use_rdrop
        self.reg_alpha = reg_alpha
        self.sample_patch_num = sample_patch_num

        self.constraint_start = None
        self.constraint_end = None
        if constraint_range is not None:
            constraint_start, constraint_end = constraint_range.split(',')
            self.constraint_start = int(constraint_start)
            self.constraint_end = int(constraint_end)
        self.log_end = log_end
        self.drop_best_ratio = drop_best_ratio
        self.drop_best_after = drop_best_after
        print('el, self.log_end=', self.log_end)

    # @staticmethod
    # def add_args(parser):
    #     """Add criterion-specific arguments to the parser."""
    #     # fmt: off
    #     parser.add_argument('--log_end', type=float, default=1.0)

    def forward(self, model, sample, update_num=0, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        if isinstance(sample, list):
            if self.sample_patch_num > 0:
                sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
            loss_v1, sample_size_v1, logging_output_v1 = self.forward(model, sample[0], update_num, reduce)
            loss_v2, sample_size_v2, logging_output_v2 = self.forward(model, sample[1], update_num, reduce)
            loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
            sample_size = 1
            logging_output = {
                "loss": loss.data,
                "loss_v1": loss_v1.data,
                "loss_v2": loss_v2.data,
                "nll_loss": logging_output_v1["nll_loss"].data / sample_size_v1 + logging_output_v2["nll_loss"].data / sample_size_v2,
                "ntokens": logging_output_v1["ntokens"] + logging_output_v2["ntokens"],
                "nsentences": logging_output_v1["nsentences"] + logging_output_v2["nsentences"],
                "sample_size": 1,
                "sample_size_v1": sample_size_v1,
                "sample_size_v2": sample_size_v2,
            }
            return loss, sample_size, logging_output

        if self.use_rdrop:
            construct_rdrop_sample(sample)

        net_output = model(**sample["net_input"])
        loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else ntokens
        )
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output

    def get_lprobs_and_target(self, model, net_output, sample):
        conf = sample['conf'][:, None, None] if 'conf' in sample and sample['conf'] is not None else 1
        constraint_masks = None
        if "constraint_masks" in sample and sample["constraint_masks"] is not None:
            constraint_masks = sample["constraint_masks"]
            net_output[0].masked_fill_(~constraint_masks, -math.inf)
        if self.constraint_start is not None and self.constraint_end is not None:
            net_output[0][:, :, 4:self.constraint_start] = -math.inf
            net_output[0][:, :, self.constraint_end:] = -math.inf
        lprobs = model.get_normalized_probs(net_output, log_probs=True) * conf
        target = model.get_targets(sample, net_output)
        if self.ignore_prefix_size > 0:
            lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
            target = target[:, self.ignore_prefix_size:].contiguous()
            if constraint_masks is not None:
                constraint_masks = constraint_masks[:, self.ignore_prefix_size:, :].contiguous()
        if self.ignore_eos:
            bsz, seq_len, embed_dim = lprobs.size()
            eos_indices = target.eq(self.task.tgt_dict.eos())
            lprobs = lprobs[~eos_indices].reshape(bsz, seq_len - 1, embed_dim)
            target = target[~eos_indices].reshape(bsz, seq_len - 1)
            if constraint_masks is not None:
                constraint_masks = constraint_masks[~eos_indices].reshape(bsz, seq_len - 1, embed_dim)
        if constraint_masks is not None:
            constraint_masks = constraint_masks.view(-1, constraint_masks.size(-1))
        return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks

    def compute_loss(self, model, net_output, sample, update_num, reduce=True):
        lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)
        if constraint_masks is not None:
            constraint_masks = constraint_masks[target != self.padding_idx]
        lprobs = lprobs[target != self.padding_idx]
        target = target[target != self.padding_idx]
        loss, nll_loss, ntokens, lprobs, target = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            update_num,
            reduce=reduce,
            drop_worst_ratio=self.drop_worst_ratio,
            drop_worst_after=self.drop_worst_after,
            use_rdrop=self.use_rdrop,
            reg_alpha=self.reg_alpha,
            constraint_masks=constraint_masks,
            constraint_start=self.constraint_start,
            constraint_end=self.constraint_end
        )
        # for encouraging loss
        probs = torch.exp(lprobs)
        bonus = torch.log(torch.clamp(torch.ones_like(probs) - probs, min=1e-5))  # likelihood bonus
        log_end = self.log_end
        if log_end != 1.0:  # e.g. 0.9
            y_log_end = torch.log(torch.ones_like(probs) - log_end)
            bonus_after_log_end = 1 / (log_end - torch.ones_like(probs)) * (probs - log_end) + y_log_end
            # x: log_end, y: torch.log(torch.clamp(torch.ones_like(probs) - probs, min=self.cl_eps))
            bonus = torch.where(probs > log_end, bonus_after_log_end, bonus)
        c_loss = F.nll_loss(
            -bonus,
            target.view(-1),
            reduction='sum',
        )
        smoothing_c_loss = bonus.sum(dim=-1)
        smoothing_c_loss = smoothing_c_loss.sum()
        c_loss = c_loss * (1 - self.eps) + (self.eps / lprobs.size(-1)) * smoothing_c_loss
        loss = loss + c_loss
        # end for encouraging loss
        return loss, nll_loss, ntokens

    def compute_accuracy(self, model, net_output, sample):
        lprobs, target, _ = self.get_lprobs_and_target(model, net_output, sample)
        mask = target.ne(self.padding_idx)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
        )
        total = torch.sum(mask)
        return n_correct, total

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        loss_sum_v1 = sum(log.get("loss_v1", 0) for log in logging_outputs)
        loss_sum_v2 = sum(log.get("loss_v2", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        sample_size_v1 = sum(log.get("sample_size_v1", 0) for log in logging_outputs)
        sample_size_v2 = sum(log.get("sample_size_v2", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size, sample_size, round=3
        )
        metrics.log_scalar(
            "loss_v1", loss_sum_v1 / max(sample_size_v1, 1), max(sample_size_v1, 1), round=3
        )
        metrics.log_scalar(
            "loss_v2", loss_sum_v2 / max(sample_size_v2, 1), max(sample_size_v2, 1), round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / sample_size, ntokens, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

        metrics.log_scalar(
            "ntokens", ntokens, 1, round=3
        )
        metrics.log_scalar(
            "nsentences", nsentences, 1, round=3
        )
        metrics.log_scalar(
            "sample_size", sample_size, 1, round=3
        )
        metrics.log_scalar(
            "sample_size_v1", sample_size_v1, 1, round=3
        )
        metrics.log_scalar(
            "sample_size_v2", sample_size_v2, 1, round=3
        )

        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
        if total > 0:
            metrics.log_scalar("total", total)
            n_correct = utils.item(
                sum(log.get("n_correct", 0) for log in logging_outputs)
            )
            metrics.log_scalar("n_correct", n_correct)
            metrics.log_derived(
                "accuracy",
                lambda meters: round(
                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
                )
                if meters["total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
criterions/refcoco_scst_loss.py
ADDED
@@ -0,0 +1,427 @@
# Modified from OFA code.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import math
import string
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Optional

import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II

from data import data_utils
from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD


def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True, ce=False):
    if ce:
        loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    elif isinstance(reward, float):
        loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward
    else:
        loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze() * reward.unsqueeze(-1)

    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        loss.masked_fill_(pad_mask, 0.0)
        ntokens = (~pad_mask).sum()
    else:
        loss = loss.squeeze(-1)
        ntokens = target.numel()
    if reduce:
        loss = loss.sum()
    return loss, ntokens


@dataclass
class RefCOCOScstRewardCriterionConfig(FairseqDataclass):
    scst_cider_cached_tokens: Optional[str] = field(
        default="coco-train-words.p",
        metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
    )
    ignore_prefix_size: int = field(
        default=0,
        metadata={"help": "Ignore first N tokens"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")
    constraint_range: Optional[str] = field(
        default=None,
        metadata={"help": "constraint range"}
    )

    acc_thresh: Optional[float] = field(
        default=None, metadata={"help": "acc thresh for refcoco"}
    )
    metric: Optional[str] = field(
        default='acc',
        metadata={"help": "metric"}
    )

    max_area_size: Optional[float] = field(
        default=None, metadata={"help": "max_area_size"}
    )

    min_area_size: Optional[float] = field(
        default=None, metadata={"help": "min_area_size"}
    )
    logprob: Optional[bool] = field(
        default=False, metadata={"help": "maximise log prob"}
    )

    pos_reward: Optional[float] = field(
        default=None, metadata={"help": "pos_reward"}
    )

    neg_reward: Optional[float] = field(
        default=None, metadata={"help": "neg_reward"}
    )

    reinforce: Optional[bool] = field(
        default=False, metadata={"help": "reinforce"}
    )

    lambda_reinforce: Optional[float] = field(
        default=0, metadata={"help": "lambda_reinforce"}
    )

    medium_area: Optional[bool] = field(
        default=False, metadata={"help": "medium_area"}
    )


@register_criterion(
    "refcoco_scst_reward_criterion", dataclass=RefCOCOScstRewardCriterionConfig
)
class RefCOCOScstRewardCriterion(FairseqCriterion):
    CIDER_REWARD_WEIGHT = 1

    def __init__(
        self,
        task,
        scst_cider_cached_tokens,
        sentence_avg,
        ignore_prefix_size=0,
        constraint_range=None,
        acc_thresh=None,
        metric='acc',
        max_area_size=None,
        min_area_size=None,
        logprob=False,
        pos_reward=None,
        neg_reward=None,
        reinforce=False,
        lambda_reinforce=0,
        medium_area=False,
    ):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.ignore_prefix_size = ignore_prefix_size
        self.transtab = str.maketrans({key: None for key in string.punctuation})

        self.constraint_start = None
        self.constraint_end = None
        if constraint_range is not None:
            constraint_start, constraint_end = constraint_range.split(',')
            self.constraint_start = int(constraint_start)
            self.constraint_end = int(constraint_end)

        self.metric = metric
        print("metric", metric)

        self.acc_thresh = acc_thresh
        self.min_area_size = min_area_size
        self.max_area_size = max_area_size
        self.logprob = logprob

        self.pos_reward = pos_reward
        self.neg_reward = neg_reward

        self.reinforce = reinforce
        self.lambda_reinforce = lambda_reinforce

        self.medium_area = medium_area

    def forward(self, model, sample, update_num=0, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)

        sample_size = (
            nsentences if self.sentence_avg else ntokens
        )
        logging_output = {
            "loss": loss.data,
            "score": score,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
        '''
        gen_res: generated captions, list of str
        gt_idx: list of int, of the same length as gen_res
        gt_res: ground truth captions, list of list of str.
            gen_res[i] corresponds to gt_res[gt_idx[i]]
            Each image can have multiple ground truth captions
        '''
        gen_res_size = len(gen_res)

        res = OrderedDict()
        for i in range(gen_res_size):
            res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]

        gts = OrderedDict()
        gt_res_ = [
            [self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
            for i in range(len(gt_res))
        ]
        for i in range(gen_res_size):
            gts[i] = gt_res_[gt_idx[i]]

        res_ = [{'image_id': i, 'caption': res[i]} for i in range(len(res))]

        # replace with other metrics
        if self.metric != 'cider':
            predicts = [res[i][0] if isinstance(res[i], list) else res[i] for i in range(len(res))]
            answers = [gts[i] for i in range(gen_res_size)]
            results = self.evaluator.run_evaluation(predicts, answers)
            batch_cider_scores = results[self.metric]
            batch_cider_scores = torch.tensor(batch_cider_scores).repeat(gen_res_size)
        else:
            _, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)

        scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
        return scores

    @classmethod
    def _wrap_sentence(cls, s):
        # ensure the sentence ends with an <eos> token
        # in order to keep consistent with cider_cached_tokens
        r = s.strip()
        if r.endswith('.'):
            r = r[:-1]
        r += ' <eos>'
        return r

    def get_generator_out(self, model, sample):
        model.eval()
        with torch.no_grad():
            self.task.scst_generator.model.eval()
            gen_out = self.task.scst_generator.generate([model], sample)

        gen_target = []
        gen_res = []
        gt_res = []
        for i in range(len(gen_out)):
            gen_res.append(gen_out[i][0]["tokens"][:-1] - len(self.task.src_dict) + self.task.cfg.num_bins)
            gt_res.append(sample["target"][i][:-1] - len(self.task.src_dict) + self.task.cfg.num_bins)
            gen_target.append(gen_out[i][0]["tokens"][:-1].int().cpu())

        return gen_target, gen_res, gt_res

    def _calculate_ap_score(self, hyps, refs, thresh=0.5, min_area_size=None, max_area_size=None, medium_area=False):
        interacts = torch.cat(
            [torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
             torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])],
            dim=1
        )
        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])  # x1, y1, x2, y2 with x1 < x2
        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
        interacts_w = interacts[:, 2] - interacts[:, 0]
        interacts_h = interacts[:, 3] - interacts[:, 1]
        area_interacts = interacts_w * interacts_h
        ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)

        if max_area_size is not None and min_area_size is not None:
            if medium_area:
                ious = ious * (torch.logical_and(area_targets > max_area_size, area_targets < min_area_size).float())
            else:
                ious = ious * (torch.logical_or(area_targets < max_area_size, area_targets > min_area_size).float())
        elif min_area_size is not None:
            if medium_area:
                ious = ious * (area_targets < min_area_size).float()  # as max areas
            else:
                ious = ious * (area_targets > min_area_size).float()
        elif max_area_size is not None:
            if medium_area:
                ious = ious * (area_targets > max_area_size).float()
            else:
                ious = ious * (area_targets < max_area_size).float()

        if thresh is None:
            return ious
        else:
            return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float()

    def get_reward_and_scores(self, gen_res, gt_res, device, sample):
        hyps_, refs_ = torch.stack(gen_res, dim=0), torch.stack(gt_res, dim=0)

        hyps = hyps_ / (self.task.cfg.num_bins - 1) * self.task.cfg.max_image_size
        refs = refs_ / (self.task.cfg.num_bins - 1) * self.task.cfg.max_image_size

        hyps[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
        hyps[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)
        refs[:, ::2] /= sample['w_resize_ratios'].unsqueeze(1)
        refs[:, 1::2] /= sample['h_resize_ratios'].unsqueeze(1)

        if self.metric == 'acc':
            scores = self._calculate_ap_score(hyps, sample['region_coords'].float(), thresh=self.acc_thresh,
                min_area_size=self.min_area_size, max_area_size=self.max_area_size, medium_area=self.medium_area)
        else:
            raise NotImplementedError

        if self.pos_reward:
            scores = torch.where(scores > 0, self.pos_reward, scores)
        if self.neg_reward:
            scores = torch.where(scores == 0, self.neg_reward, scores)

        return scores, scores

    def get_net_output(self, model, sample, gen_target):
        def merge(sample_list, eos=self.task.tgt_dict.eos(), move_eos_to_beginning=False):
            return data_utils.collate_tokens(
                sample_list,
                pad_idx=self.padding_idx,
                eos_idx=eos,
                left_pad=False,
                move_eos_to_beginning=move_eos_to_beginning,
            )

        batch_size = len(sample["target"])
        gen_target_size = len(gen_target)
        seq_per_img = gen_target_size // batch_size

        model.train()
        sample_src_tokens = torch.repeat_interleave(
            sample['net_input']['src_tokens'], seq_per_img, dim=0
        )
        sample_src_lengths = torch.repeat_interleave(
            sample['net_input']['src_lengths'], seq_per_img, dim=0
        )
        sample_patch_images = torch.repeat_interleave(
            sample['net_input']['patch_images'], seq_per_img, dim=0
        )
        sample_patch_masks = torch.repeat_interleave(
            sample['net_input']['patch_masks'], seq_per_img, dim=0
        )
        gen_prev_output_tokens = torch.as_tensor(
            merge(gen_target, eos=self.task.tgt_dict.bos(), move_eos_to_beginning=True),
            device=sample["target"].device, dtype=torch.int64
        )
        gen_target_tokens = torch.as_tensor(
            merge(gen_target), device=sample["target"].device, dtype=torch.int64
        )

        net_output = model(
            src_tokens=sample_src_tokens, src_lengths=sample_src_lengths,
            patch_images=sample_patch_images, patch_masks=sample_patch_masks,
            prev_output_tokens=gen_prev_output_tokens
        )

        return net_output, gen_target_tokens

    def get_lprobs_and_target(self, model, net_output, gen_target):
        if self.constraint_start is not None and self.constraint_end is not None:
            net_output[0][:, :, 4:self.constraint_start] = -math.inf
            net_output[0][:, :, self.constraint_end:] = -math.inf
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        if self.ignore_prefix_size > 0:
            if getattr(lprobs, "batch_first", False):
                lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
                gen_target = gen_target[:, self.ignore_prefix_size:].contiguous()
            else:
                lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
                gen_target = gen_target[self.ignore_prefix_size:, :].contiguous()
        return lprobs, gen_target

    def compute_loss(self, model, sample, reduce=True):
        gen_target, gen_res, gt_res = self.get_generator_out(model, sample)
        reward, scores = self.get_reward_and_scores(gen_res, gt_res, device=sample["target"].device, sample=sample)

        net_output, gen_target_tokens = self.get_net_output(model, sample, gen_target)

        gen_lprobs, gen_target_tokens = self.get_lprobs_and_target(model, net_output, gen_target_tokens)
        loss, ntokens = scst_loss(gen_lprobs, gen_target_tokens, reward, ignore_index=self.padding_idx, reduce=reduce)
        nsentences = gen_target_tokens.size(0)

        if self.lambda_reinforce > 0:
            target = model.get_targets(sample, net_output)[:, :-1]  # ignore eos token
            if self.ignore_prefix_size > 0:
                target = target[:, self.ignore_prefix_size:].contiguous()

            loss_ce, ntokens_ = scst_loss(gen_lprobs, target, reward=1, ignore_index=self.padding_idx, reduce=reduce, ce=True)

            loss = loss_ce + self.lambda_reinforce * loss

        return loss, scores.sum(), ntokens, nsentences

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        score_sum = sum(log.get("score", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size, sample_size, round=3
        )
        metrics.log_scalar(
            "score", score_sum / nsentences, nsentences, round=3
        )

        metrics.log_scalar(
            "ntokens", ntokens, 1, round=3
        )
        metrics.log_scalar(
            "nsentences", nsentences, 1, round=3
        )
        metrics.log_scalar(
            "sample_size", sample_size, 1, round=3
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
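Note (reviewer sketch, not part of the committed file): with metric='acc', the reward above is a per-box IoU hit/miss signal. The criterion additionally rescales generated bin indices to image coordinates and can zero out rewards by target-box area; the standalone sketch below only shows the core IoU thresholding for [x1, y1, x2, y2] boxes, with hypothetical box values:

import torch

def box_iou_reward(hyps, refs, thresh=0.5):
    # intersection corners: max of the top-left corners, min of the bottom-right corners
    lt = torch.max(hyps[:, :2], refs[:, :2])
    rb = torch.min(hyps[:, 2:], refs[:, 2:])
    wh = (rb - lt).clamp(min=0)          # zero width/height when boxes do not overlap
    inter = wh[:, 0] * wh[:, 1]
    area_h = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
    area_r = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
    ious = inter / (area_h + area_r - inter + 1e-6)
    return (ious >= thresh).float()      # binary reward, as with metric='acc'

hyps = torch.tensor([[10., 10., 50., 50.], [0., 0., 20., 20.]])
refs = torch.tensor([[12., 12., 48., 52.], [30., 30., 60., 60.]])
print(box_iou_reward(hyps, refs))        # tensor([1., 0.])
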
data/.ipynb_checkpoints/file_dataset-checkpoint.py
ADDED
@@ -0,0 +1,107 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import os
import torch
import pickle


class FileDataset:
    def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
        self.file_path = file_path
        assert os.path.exists(self.file_path), "Error: The local datafile {} not exists!".format(self.file_path)

        self.separator = separator
        if selected_col_ids is None:
            # default to all fields
            self.selected_col_ids = list(
                range(len(open(self.file_path).readline().rstrip("\n").split(self.separator))))
        else:
            self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
        if dtypes is None:
            # default to str
            self.dtypes = [str for col_id in self.selected_col_ids]
        else:
            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
            assert len(self.dtypes) == len(self.selected_col_ids)

        self.data_cnt = 0
        try:
            self.slice_id = torch.distributed.get_rank()
            self.slice_count = torch.distributed.get_world_size()
        except Exception:
            self.slice_id = 0
            self.slice_count = 1
        self.cached_index = cached_index
        self._init_seek_index()
        self._reader = self._get_reader()
        print("file {} slice_id {} row count {} total row count {}".format(
            self.file_path, self.slice_id, self.row_count, self.total_row_count)
        )

    def _init_seek_index(self):
        if self.cached_index:
            cache_path = "{}.index".format(self.file_path)
            assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
            self.total_row_count, self.lineid_to_offset = pickle.load(open(cache_path, "rb"))
            print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
                self.file_path, self.slice_id))
        else:
            # make an iteration over the file to get row_count and line_idx-to-offset mapping
            fp = open(self.file_path, "r")
            print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
                self.file_path, self.slice_id))
            self.total_row_count = 0
            offset = 0
            self.lineid_to_offset = []
            for line in fp:
                self.lineid_to_offset.append(offset)
                self.total_row_count += 1
                offset += len(line.encode('utf-8'))
        self._compute_start_pos_and_row_count()
        print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
            self.file_path, self.slice_id))

    def _compute_start_pos_and_row_count(self):
        self.row_count = self.total_row_count // self.slice_count
        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
            self.row_count += 1
            self.start_pos = self.row_count * self.slice_id
        else:
            self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)

    def _get_reader(self):
        fp = open(self.file_path, "r")
        fp.seek(self.lineid_to_offset[self.start_pos])
        return fp

    def _seek(self, offset=0):
        try:
            print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
            self.data_cnt = offset
        except Exception:
            print("slice_id {} seek offset {}".format(self.slice_id, offset))
            self._reader.seek(self.lineid_to_offset[offset])
            self.data_cnt = offset

    def __del__(self):
        self._reader.close()

    def __len__(self):
        return self.row_count

    def get_total_row_count(self):
        return self.total_row_count

    def __getitem__(self, index):
        if self.data_cnt == self.row_count:
            print("reach the end of datafile, start a new reader")
            self.data_cnt = 0
            self._reader = self._get_reader()
        column_l = self._reader.readline().rstrip("\n").split(self.separator)
        self.data_cnt += 1
        column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
        return column_l
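Note (reviewer sketch, not part of the committed file): `FileDataset` reads one separator-delimited row per `__getitem__` call and splits rows across distributed ranks when torch.distributed is initialized. A hypothetical single-process usage example; the file name and column layout below are made up:

# build a tiny 3-column TSV, then read back columns 0 and 1 with int/str dtypes
with open("toy.tsv", "w") as f:
    f.write("0\ta dog on the grass\tdog.jpg\n")
    f.write("1\ta red car\tcar.jpg\n")

dataset = FileDataset("toy.tsv", selected_col_ids="0,1", dtypes="int,str")
print(len(dataset))   # 2 rows assigned to this (single) slice
print(dataset[0])     # [0, 'a dog on the grass']
print(dataset[1])     # [1, 'a red car']
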
data/__init__.py
ADDED
File without changes

data/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (124 Bytes)

data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (125 Bytes)

data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (151 Bytes)

data/__pycache__/audio_utils.cpython-37.pyc
ADDED
Binary file (4.95 kB)

data/__pycache__/audio_utils.cpython-39.pyc
ADDED
Binary file (5.04 kB)

data/__pycache__/data_utils.cpython-37.pyc
ADDED
Binary file (18.3 kB)

data/__pycache__/data_utils.cpython-38.pyc
ADDED
Binary file (18.5 kB)

data/__pycache__/data_utils.cpython-39.pyc
ADDED
Binary file (18.5 kB)