"""Image preprocessing, text tokenization, label translation and plotting helpers."""

import json

import numpy as np
import plotly.express as px
import torch
from PIL import Image
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import BertTokenizerFast


class Transform(torch.nn.Module):
    """Resize, center-crop, convert to float and normalize an image tensor."""

    def __init__(self, image_size):
        """Build the preprocessing pipeline for square crops of side ``image_size``."""
        super().__init__()
        self.transforms = torch.nn.Sequential(
            # A single-element size resizes the shorter edge to ``image_size``.
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            # Per-channel mean / std.
            # NOTE(review): these look like the CLIP preprocessing statistics -- confirm.
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pipeline to ``x`` with gradient tracking disabled."""
        with torch.no_grad():
            x = self.transforms(x)
        return x


transform = Transform(224)


def get_transformed_image(image):
    """Preprocess ``image`` and return it as a channels-last NumPy batch.

    Args:
        image: Either a ``torch.Tensor`` that ``transform`` accepts, or a
            ``numpy.ndarray``. A 3-D array whose last axis has length 3 is
            treated as channels-last and transposed to channels-first.

    Returns:
        The transformed image as a NumPy array with a leading batch axis of
        size 1 and the channel axis moved to the end.
    """
    # Check the type before touching ``.shape`` and convert every ndarray to a
    # tensor, so arrays that are already channels-first are handled too.
    if isinstance(image, np.ndarray):
        if image.shape[-1] == 3:
            image = image.transpose(2, 0, 1)
        image = torch.tensor(image)
    batch = transform(image).unsqueeze(0)
    return batch.permute(0, 2, 3, 1).numpy()


bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")


def get_text_attributes(text):
    """Tokenize ``text`` as a batch of one, returning NumPy arrays.

    Token type ids are requested explicitly alongside the tokenizer's other
    outputs.
    """
    return bert_tokenizer([text], return_token_type_ids=True, return_tensors="np")


def get_top_5_predictions(logits, answer_reverse_mapping):
    """Return the labels and scores of the five highest entries of ``logits``.

    Args:
        logits: Array of scores; indexed with the result of ``np.argsort``.
        answer_reverse_mapping: Mapping from the *string* form of an index to
            its answer label.

    Returns:
        ``(labels, values)``, both ordered by ascending score, so the best
        prediction comes last.
    """
    indices = np.argsort(logits)[-5:]
    values = logits[indices]
    labels = [answer_reverse_mapping[str(i)] for i in indices]
    return labels, values


# JSON text is UTF-8; being explicit avoids decoding the translation table
# with a locale-dependent default encoding.
with open('translation_dict.json', encoding="utf-8") as f:
    translate_dict = json.load(f)


def translate_labels(labels, lang_id):
    """Translate ``labels`` into the language ``lang_id``.

    Empty labels and every label when ``lang_id == "en"`` pass through
    unchanged; anything else is looked up as ``translate_dict[label][lang_id]``.

    Raises:
        KeyError: If a label or ``lang_id`` is missing from ``translate_dict``.
    """
    return [
        label if label == "" or lang_id == "en" else translate_dict[label][lang_id]
        for label in labels
    ]


def plotly_express_horizontal_bar_plot(values, labels):
    """Build a horizontal bar chart of ``values`` against ``labels``.

    Each bar is annotated with its value formatted as a percentage with three
    decimal places.

    Returns:
        The Plotly figure created by ``plotly.express.bar``.
    """
    fig = px.bar(
        x=values,
        y=labels,
        text=[f"{value:.3%}" for value in values],
        title="Top-5 Predictions",
        labels={"x": "Scores", "y": "Answers"},
        orientation="h",
    )
    return fig