# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import Counter

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.utils.data import Dataset

# Maps the requested number of image embeddings to the (H, W) output grid used
# for adaptive average pooling, e.g. 3 -> (3, 1) yields three pooled vectors.
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}


class ImageEncoder(nn.Module):
    """Encodes an image into a fixed number of 2048-d embeddings using a
    pretrained ResNet-152 with its classification head removed."""

    def __init__(self, args):
        super().__init__()
        model = torchvision.models.resnet152(pretrained=True)
        # Drop the final average pool and fc layers to keep spatial feature maps.
        modules = list(model.children())[:-2]
        self.model = nn.Sequential(*modules)
        self.pool = nn.AdaptiveAvgPool2d(POOLING_BREAKDOWN[args.num_image_embeds])

    def forward(self, x):
        # Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048
        out = self.pool(self.model(x))
        out = torch.flatten(out, start_dim=2)
        out = out.transpose(1, 2).contiguous()
        return out  # BxNx2048
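

# Example usage (a minimal sketch; `Namespace` stands in for the argparse
# namespace the training script builds, it is not part of this module):
#
#   from argparse import Namespace
#   encoder = ImageEncoder(Namespace(num_image_embeds=3))
#   features = encoder(torch.randn(2, 3, 224, 224))
#   features.shape  # torch.Size([2, 3, 2048])
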
class JsonlDataset(Dataset):
    """MM-IMDB examples stored as JSON lines with `text`, `img`, and `label` fields."""

    def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
        with open(data_path) as f:
            self.data = [json.loads(line) for line in f]
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.labels = labels
        self.n_classes = len(labels)
        self.max_seq_length = max_seq_length
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
        # Split off the special tokens; MMBT reuses them to delimit the image embeddings.
        start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
        sentence = sentence[: self.max_seq_length]

        # Multi-hot label vector: a movie can belong to several genres at once.
        label = torch.zeros(self.n_classes)
        label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1

        image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
        image = self.transforms(image)

        return {
            "image_start_token": start_token,
            "image_end_token": end_token,
            "sentence": sentence,
            "image": image,
            "label": label,
        }

    def get_label_frequencies(self):
        label_freqs = Counter()
        for row in self.data:
            label_freqs.update(row["label"])
        return label_freqs


def collate_fn(batch):
    # Pad every sentence to the longest one in the batch and build the matching
    # attention mask (1 for real tokens, 0 for padding).
    lens = [len(row["sentence"]) for row in batch]
    bsz, max_seq_len = len(batch), max(lens)

    mask_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)
    text_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)
    for i_batch, (input_row, length) in enumerate(zip(batch, lens)):
        text_tensor[i_batch, :length] = input_row["sentence"]
        mask_tensor[i_batch, :length] = 1

    img_tensor = torch.stack([row["image"] for row in batch])
    tgt_tensor = torch.stack([row["label"] for row in batch])
    img_start_token = torch.stack([row["image_start_token"] for row in batch])
    img_end_token = torch.stack([row["image_end_token"] for row in batch])

    return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor
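

# Example usage (a minimal sketch; `tokenizer` stands in for any HuggingFace
# tokenizer, e.g. one returned by AutoTokenizer.from_pretrained, and
# "train.jsonl" is a hypothetical path):
#
#   dataset = JsonlDataset("train.jsonl", tokenizer, get_image_transforms(),
#                          get_mmimdb_labels(), max_seq_length=512)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
#   text, mask, img, img_start, img_end, tgt = next(iter(loader))
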
def get_mmimdb_labels():
    return [
        "Crime",
        "Drama",
        "Thriller",
        "Action",
        "Comedy",
        "Romance",
        "Documentary",
        "Short",
        "Mystery",
        "History",
        "Family",
        "Adventure",
        "Fantasy",
        "Sci-Fi",
        "Western",
        "Horror",
        "Sport",
        "War",
        "Music",
        "Musical",
        "Animation",
        "Biography",
        "Film-Noir",
    ]


def get_image_transforms():
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Dataset-specific channel statistics (note: not the ImageNet defaults).
            transforms.Normalize(
                mean=[0.46777044, 0.44531429, 0.40661017],
                std=[0.12221994, 0.12145835, 0.14380469],
            ),
        ]
    )
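

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original training pipeline): run the
    # transforms and the encoder on a dummy image and check the output shape.
    # Note: constructing ImageEncoder downloads the pretrained ResNet-152 weights.
    from argparse import Namespace

    pixel_values = get_image_transforms()(Image.new("RGB", (640, 480)))
    encoder = ImageEncoder(Namespace(num_image_embeds=3))
    with torch.no_grad():
        features = encoder(pixel_values.unsqueeze(0))
    print(features.shape)  # torch.Size([1, 3, 2048])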