File size: 2,293 Bytes
2bbf92c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4963f2
2bbf92c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4963f2
2bbf92c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from torchvision.io import read_image, ImageReadMode
import torch
import numpy as np
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import BertTokenizerFast
import plotly.express as px
import json
from PIL import Image
class Transform(torch.nn.Module):
    """Image preprocessing pipeline wrapped as a ``torch.nn.Module``.

    The pipeline chains, in order: a bicubic ``Resize``, a ``CenterCrop``,
    a ``ConvertImageDtype`` to ``torch.float`` and a per-channel
    ``Normalize``. It is stored on ``self.transforms``.
    """

    def __init__(self, image_size):
        """Assemble the pipeline.

        Args:
            image_size: Size handed to both ``Resize`` (as a one-element
                list) and ``CenterCrop``.
        """
        super().__init__()
        # NOTE(review): these look like the CLIP image mean/std statistics;
        # confirm against the checkpoint this preprocessing feeds.
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        stages = [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(channel_mean, channel_std),
        ]
        self.transforms = torch.nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the pipeline with gradient tracking disabled."""
        with torch.no_grad():
            return self.transforms(x)


# Shared module-level preprocessing pipeline, built with an image size of 224.
transform = Transform(224)


def get_transformed_image(image):
    """Preprocess one image and return it as a batched NumPy array.

    A NumPy array whose last axis has length 3 is treated as channel-last:
    it is transposed with ``(2, 0, 1)`` and converted to a tensor first.
    Any other input is handed to the module-level ``transform`` unchanged.

    Args:
        image: Object exposing ``.shape``; either a NumPy array or something
            ``transform`` accepts directly (presumably a channel-first
            tensor -- confirm against callers).

    Returns:
        The transformed image with a leading axis of size 1 added and axis 1
        moved to the last position, converted to a NumPy array.
    """
    # The shape test is deliberately evaluated before the type test, exactly
    # as before, so inputs without ``.shape`` fail in the same way.
    is_channel_last_array = image.shape[-1] == 3 and isinstance(image, np.ndarray)
    if is_channel_last_array:
        channel_first = image.transpose(2, 0, 1)
        image = torch.tensor(channel_first)

    batched = transform(image).unsqueeze(0)
    return batched.permute(0, 2, 3, 1).numpy()


# Module-level tokenizer, loaded once at import time from the
# "bert-base-multilingual-uncased" checkpoint.
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")


def get_text_attributes(text):
    """Tokenize a single text with the module-level ``bert_tokenizer``.

    Args:
        text: The text to tokenize; it is wrapped in a one-element list
            before being passed to the tokenizer.

    Returns:
        The tokenizer output, requested with token type ids and with
        ``return_tensors="np"``.
    """
    single_item_batch = [text]
    encoded = bert_tokenizer(
        single_item_batch,
        return_token_type_ids=True,
        return_tensors="np",
    )
    return encoded


def get_top_5_predictions(logits, answer_reverse_mapping):
    """Pick the highest-scoring answers from ``logits``.

    Args:
        logits: Array of scores that supports index-array lookup
            (assumed 1-D -- confirm against callers).
        answer_reverse_mapping: Mapping from the stringified index
            (``str(i)``) to the answer label.

    Returns:
        A ``(labels, values)`` pair holding at most five entries, ordered
        from the lowest to the highest score among the selected ones.
    """
    # argsort is ascending, so the trailing slice holds the largest scores.
    ascending_order = np.argsort(logits)
    top_indices = ascending_order[-5:]

    top_scores = logits[top_indices]
    top_labels = []
    for index in top_indices:
        top_labels.append(answer_reverse_mapping[str(index)])
    return top_labels, top_scores

# Load the label translation table once at import time. The encoding is
# pinned to UTF-8 (the standard encoding for JSON) so that non-ASCII
# translations decode identically regardless of the platform's locale
# default, which open() would otherwise fall back to.
with open('translation_dict.json', encoding="utf-8") as f:
    translate_dict = json.load(f)

def translate_labels(labels, lang_id):
    """Translate answer labels into the language identified by ``lang_id``.

    Args:
        labels: Iterable of labels to translate.
        lang_id: Language key; ``"en"`` returns labels untouched, any other
            value is looked up as ``translate_dict[label][lang_id]``.

    Returns:
        A list with one entry per input label. ``"<unk>"`` is never
        translated.

    Raises:
        KeyError: If a label or language is missing from ``translate_dict``.
    """

    def _translate_one(label):
        # The unknown-token marker is passed through for every language.
        if label == "<unk>":
            return "<unk>"
        if lang_id == "en":
            return label
        return translate_dict[label][lang_id]

    return [_translate_one(label) for label in labels]


def plotly_express_horizontal_bar_plot(values, labels):
    """Build a horizontal bar chart titled "Top-5 Predictions".

    Args:
        values: Scores used as bar lengths (x axis, titled "Scores").
        labels: Answer labels, one per bar (y axis, titled "Answers").

    Returns:
        The figure produced by ``px.bar``; each bar is annotated with its
        value formatted as a percentage with three decimals.
    """
    percent_annotations = [f"{score:.3%}" for score in values]
    axis_titles = {"x": "Scores", "y": "Answers"}
    return px.bar(
        x=values,
        y=labels,
        text=percent_annotations,
        title="Top-5 Predictions",
        labels=axis_titles,
        orientation="h",
    )