Multilingual-VQA / utils.py
gchhablani's picture
Fix style
e289356
raw history blame
No virus
2.36 kB
import json
import numpy as np
import plotly.express as px
import torch
from PIL import Image
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import BertTokenizerFast
class Transform(torch.nn.Module):
    """Image preprocessing pipeline packaged as a ``torch.nn.Module``.

    The stages run in this order: bicubic ``Resize`` (given the size as a
    one-element list), ``CenterCrop`` to ``image_size``, conversion to
    ``torch.float``, and per-channel ``Normalize`` with fixed mean / std
    constants.
    """

    def __init__(self, image_size):
        """Build the pipeline for the requested ``image_size``."""
        super().__init__()
        # NOTE(review): these three-channel statistics look like the CLIP
        # normalization constants -- confirm against the model card.
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        stages = (
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(channel_mean, channel_std),
        )
        self.transforms = torch.nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every stage to ``x`` with gradient tracking disabled."""
        with torch.no_grad():
            return self.transforms(x)
# Shared preprocessing pipeline, built once with image_size=224 and used by
# get_transformed_image.
transform = Transform(224)
def get_transformed_image(image):
    """Preprocess ``image`` and return it as a batched NumPy array.

    When ``image`` is a NumPy array whose last axis has length 3, that axis is
    moved to the front (``transpose(2, 0, 1)``) and the result is wrapped in a
    torch tensor. Any other input is handed to the pipeline unchanged.

    The module-level ``transform`` is then applied, a leading axis of size 1
    is added, axis 1 is moved to the end, and the result is converted to NumPy.

    Note that ``image.shape`` is read before the type check, so the argument
    must expose a ``shape`` attribute.
    """
    # assumes a 3-D (H, W, 3) array when this flag is true -- TODO confirm
    is_channels_last_array = image.shape[-1] == 3 and isinstance(image, np.ndarray)
    if is_channels_last_array:
        channels_first = image.transpose(2, 0, 1)
        image = torch.tensor(channels_first)

    batched = transform(image).unsqueeze(0)
    return batched.permute(0, 2, 3, 1).numpy()
# Module-level tokenizer used by get_text_attributes; it is created once, when
# this module is imported, from the "bert-base-multilingual-uncased" name.
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")
def get_text_attributes(text):
    """Tokenize a single ``text`` with the module-level ``bert_tokenizer``.

    The text is passed as a one-element batch; token type ids are requested
    and the tokenizer is asked for NumPy (``"np"``) outputs. The tokenizer's
    result is returned as-is.
    """
    single_item_batch = [text]
    encoded = bert_tokenizer(
        single_item_batch,
        return_token_type_ids=True,
        return_tensors="np",
    )
    return encoded
def get_top_5_predictions(logits, answer_reverse_mapping):
    """Return the labels and scores of the (up to) five largest ``logits``.

    ``logits`` is arg-sorted in ascending order and the last five positions
    are kept, so results are ordered from the smallest of the top five to the
    largest. Each kept index is converted with ``str`` and looked up in
    ``answer_reverse_mapping`` to obtain its label.

    Returns:
        A ``(labels, values)`` pair: a list of labels and the matching entries
        of ``logits``, in the same order.
    """
    ascending_order = np.argsort(logits)
    top_indices = ascending_order[-5:]
    top_values = logits[top_indices]
    # The mapping is keyed by the string form of each index.
    top_labels = [answer_reverse_mapping[key] for key in map(str, top_indices)]
    return top_labels, top_values
# Load the label translation table once at import time. translate_labels
# indexes it as translate_dict[label][lang_id].
# The encoding is pinned to UTF-8 so that decoding does not depend on the
# platform's default locale encoding.
with open("translation_dict.json", encoding="utf-8") as f:
    translate_dict = json.load(f)
def translate_labels(labels, lang_id):
    """Translate each entry of ``labels`` into the language ``lang_id``.

    Rules, applied per label in this order:
      * ``"<unk>"`` is always emitted as ``"<unk>"``;
      * when ``lang_id`` is ``"en"`` the label is kept unchanged;
      * otherwise the label is looked up as ``translate_dict[label][lang_id]``
        (a missing key raises ``KeyError``).

    Returns:
        A new list with one translated entry per input label, in input order.
    """

    def _translate_one(label):
        # Resolve a single label according to the rules above.
        if label == "<unk>":
            return "<unk>"
        if lang_id == "en":
            return label
        return translate_dict[label][lang_id]

    return [_translate_one(label) for label in labels]
def plotly_express_horizontal_bar_plot(values, labels):
    """Build a horizontal Plotly Express bar chart of scores per answer.

    ``values`` go on the x axis and ``labels`` on the y axis. Each bar is
    annotated with its value formatted by the ``".3%"`` format spec. The chart
    is titled "Top-5 Predictions" and the axes are labelled "Scores" and
    "Answers".

    Returns:
        The figure object produced by ``px.bar``.
    """
    bar_annotations = [format(score, ".3%") for score in values]
    axis_titles = {"x": "Scores", "y": "Answers"}
    return px.bar(
        x=values,
        y=labels,
        text=bar_annotations,
        title="Top-5 Predictions",
        labels=axis_titles,
        orientation="h",
    )