Multilingual-VQA / utils.py
gchhablani's picture
Init app
2bbf92c
raw history blame
No virus
2.28 kB
from torchvision.io import read_image, ImageReadMode
import torch
import numpy as np
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import BertTokenizerFast
import plotly.express as px
import json
from PIL import Image
class Transform(torch.nn.Module):
    """Image preprocessing pipeline wrapped as a ``torch.nn.Module``.

    The pipeline applies, in order: bicubic ``Resize`` with a single-element
    size, ``CenterCrop`` to ``image_size``, conversion to ``torch.float``,
    and per-channel ``Normalize`` with fixed three-channel statistics.

    Args:
        image_size: Target size passed to both ``Resize`` and ``CenterCrop``.
    """

    def __init__(self, image_size):
        super().__init__()
        # Fixed three-channel normalization statistics.
        # NOTE(review): these look like the CLIP preprocessing mean/std -- confirm.
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)

        stages = [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(channel_mean, channel_std),
        ]
        self.transforms = torch.nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the preprocessing stages with autograd disabled."""
        with torch.no_grad():
            return self.transforms(x)
# Module-level preprocessing instance, built once at import time and reused by
# get_transformed_image for every call.
# NOTE(review): 224 presumably matches the input resolution the downstream
# vision model expects -- confirm against the model config.
transform = Transform(image_size=224)
def get_transformed_image(image):
    """Preprocess an image and return it as a batched NumPy array.

    If ``image`` is a NumPy array whose last axis has length 3, its axes are
    reordered with ``transpose(2, 0, 1)`` and it is converted to a torch
    tensor. Any other input is handed to the module-level ``transform``
    unchanged.

    Args:
        image: Object exposing ``.shape``. The shape is read before the type
            check, so inputs without that attribute raise ``AttributeError``.

    Returns:
        The transformed image with a leading axis of size 1 added and axes
        permuted with ``(0, 2, 3, 1)``, as a NumPy array.
    """
    is_channels_last_array = image.shape[-1] == 3 and isinstance(image, np.ndarray)
    if is_channels_last_array:
        # Move the trailing length-3 axis to the front before tensor conversion.
        image = torch.tensor(image.transpose(2, 0, 1))

    batched = transform(image).unsqueeze(0)
    # Move axis 1 to the end so the result is laid out channels-last again.
    return batched.permute(0, 2, 3, 1).numpy()
# Module-level tokenizer, loaded once at import time and reused by
# get_text_attributes for every call.
bert_tokenizer = BertTokenizerFast.from_pretrained(
    "bert-base-multilingual-uncased",
)
def get_text_attributes(text):
    """Tokenize a single text with the module-level ``bert_tokenizer``.

    The text is wrapped in a one-element list, so the tokenizer sees a batch
    of size one. Token type ids are requested and tensors are returned in
    NumPy format (``return_tensors="np"``).

    Args:
        text: The text to tokenize.

    Returns:
        Whatever ``bert_tokenizer`` returns for the one-element batch.
    """
    single_item_batch = [text]
    encoding = bert_tokenizer(
        single_item_batch,
        return_token_type_ids=True,
        return_tensors="np",
    )
    return encoding
def get_top_5_predictions(logits, answer_reverse_mapping):
    """Return labels and rounded scores for the five largest logits.

    Results are in ascending score order (the best prediction is last),
    because the indices come from the tail of an ascending ``np.argsort``.
    If ``logits`` has fewer than five entries, all of them are returned.

    Args:
        logits: Array of scores, indexable by an integer index array.
        answer_reverse_mapping: Mapping from the *string* form of an index
            (e.g. ``"3"``) to its answer label.

    Returns:
        A ``(labels, values)`` tuple: a list of labels and an array of the
        corresponding scores rounded to two decimals.
    """
    ascending_order = np.argsort(logits)
    top_indices = ascending_order[-5:]
    top_scores = np.round(logits[top_indices], decimals=2)
    # The mapping is keyed by strings, so each index is stringified first.
    top_labels = [answer_reverse_mapping[key] for key in map(str, top_indices)]
    return top_labels, top_scores
# Load the label translation table once at import time. translate_labels
# indexes it as translate_dict[label][lang_id].
# The encoding is pinned to UTF-8 so that non-ASCII translations decode the
# same way on every platform instead of depending on the locale default.
# The path is relative, so it is resolved against the current working
# directory of the process.
with open('translation_dict.json', encoding="utf-8") as f:
    translate_dict = json.load(f)
def translate_labels(labels, lang_id):
    """Translate answer labels into the language identified by ``lang_id``.

    The ``"<unk>"`` label is never translated, and ``lang_id == "en"`` returns
    labels unchanged. Every other label is looked up as
    ``translate_dict[label][lang_id]``.

    Args:
        labels: Iterable of label strings.
        lang_id: Language code used as the inner key of ``translate_dict``.

    Returns:
        A new list of translated labels, in the same order as ``labels``.

    Raises:
        KeyError: If a label or ``lang_id`` is missing from ``translate_dict``.
    """

    def _translate_one(label):
        # The unknown-token marker is passed through for every language.
        if label == "<unk>":
            return "<unk>"
        if lang_id == "en":
            return label
        return translate_dict[label][lang_id]

    return [_translate_one(label) for label in labels]
def plotly_express_horizontal_bar_plot(values, labels):
    """Build a horizontal Plotly Express bar chart of prediction scores.

    Args:
        values: Scores, used both as bar lengths (x) and as the bar text.
        labels: Answer labels, one per bar (y).

    Returns:
        The figure returned by ``px.bar``, titled "Top-5 Predictions" with
        the x axis labelled "Scores" and the y axis labelled "Answers".
    """
    axis_titles = {"x": "Scores", "y":"Answers"}
    return px.bar(
        x=values,
        y=labels,
        text=values,
        title="Top-5 Predictions",
        labels=axis_titles,
        orientation="h",
    )