File size: 3,529 Bytes
e289356
fb3c77c
2bbf92c
393b8fc
e289356
 
405f2d4
2bbf92c
 
 
690384a
393b8fc
 
 
 
 
 
 
 
 
 
 
 
 
 
919efff
 
 
393b8fc
 
 
 
 
 
 
 
 
c684eba
393b8fc
 
 
690384a
2bbf92c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690384a
2bbf92c
 
 
 
 
 
2c8f495
2bbf92c
 
 
 
405f2d4
2bbf92c
f4963f2
405f2d4
 
 
 
2bbf92c
 
690384a
 
2bbf92c
 
690384a
2bbf92c
 
 
690384a
2bbf92c
 
 
 
 
 
 
 
 
690384a
 
 
 
 
 
 
 
 
fb3c77c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
import os
import numpy as np
import streamlit as st
import plotly.express as px
import torch
from torchvision.io import read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import BertTokenizerFast

class Toc:
    """Render headings through Streamlit and collect a linked table of contents.

    Every heading method writes an HTML heading carrying an ``id`` anchor to
    the page and records a matching markdown bullet. ``generate`` joins the
    recorded bullets and renders them into the slot reserved by
    ``placeholder``.
    """

    def __init__(self):
        # Markdown bullet lines, one per heading, in emission order.
        self._items = []
        # Streamlit slot for the rendered TOC; stays None until placeholder().
        self._placeholder = None

    def title(self, text):
        """Emit *text* as an ``h1`` heading with an unindented TOC bullet."""
        self._markdown(text, "h1")

    def header(self, text):
        """Emit *text* as an ``h2`` heading; its TOC bullet is indented 2 spaces."""
        self._markdown(text, "h2", " " * 2)

    def subheader(self, text):
        """Emit *text* as an ``h3`` heading; its TOC bullet is indented 4 spaces."""
        self._markdown(text, "h3", " " * 4)

    def subsubheader(self, text):
        """Emit *text* as an ``h4`` heading; its TOC bullet is indented 8 spaces."""
        self._markdown(text, "h4", " " * 8)

    def placeholder(self, sidebar=False):
        """Reserve an empty Streamlit slot (in the sidebar if *sidebar*) for the TOC."""
        container = st.sidebar if sidebar else st
        self._placeholder = container.empty()

    def generate(self):
        """Render the collected bullets into the reserved slot, if one exists."""
        if not self._placeholder:
            return
        toc_markdown = "\n".join(self._items)
        self._placeholder.markdown(toc_markdown, unsafe_allow_html=True)

    def _markdown(self, text, level, space=""):
        """Write one heading tag and record its TOC bullet.

        The anchor id is *text* reduced to its alphanumeric characters and
        lower-cased; *level* is the HTML tag name and *space* the bullet's
        leading indentation.
        """
        alnum_only = "".join(filter(str.isalnum, text))
        # NOTE: no "+" can survive the alphanumeric filter above, so this
        # replace never changes the string; it is kept as in the original.
        key = alnum_only.lower().replace("+", "")

        heading_html = f"<{level} id='{key}'>{text}</{level}>"
        toc_entry = f"{space}* <a href='#{key}'>{text}</a>"

        st.markdown(heading_html, unsafe_allow_html=True)
        self._items.append(toc_entry)

class Transform(torch.nn.Module):
    """Image preprocessing module: resize, center-crop, float conversion, normalize.

    The four steps are chained in a ``torch.nn.Sequential`` stored on
    ``self.transforms`` and applied in that order by ``forward``.
    """

    def __init__(self, image_size):
        """Build the pipeline for a target *image_size* (used for resize and crop)."""
        super().__init__()
        # Positional (mean, std) arguments handed to Normalize, one value per channel.
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        steps = [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(channel_mean, channel_std),
        ]
        self.transforms = torch.nn.Sequential(*steps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the preprocessing pipeline to *x* with gradient tracking disabled."""
        with torch.no_grad():
            return self.transforms(x)


# Shared preprocessing pipeline, built once at import time with image_size=224
# and reused by get_transformed_image.
transform = Transform(224)


def get_transformed_image(image):
    """Preprocess *image* and return it as a batched, channels-last NumPy array.

    A NumPy array whose last dimension is 3 is first transposed with
    ``(2, 0, 1)`` and wrapped in a tensor; any other input is passed to the
    module-level ``transform`` as-is (presumably already a channels-first
    tensor -- confirm against callers). The result gains a leading batch
    axis and has its channel axis moved to the end before conversion to NumPy.
    """
    needs_conversion = image.shape[-1] == 3 and isinstance(image, np.ndarray)
    if needs_conversion:
        channels_first = image.transpose(2, 0, 1)
        image = torch.tensor(channels_first)

    batched = transform(image).unsqueeze(0)
    return batched.permute(0, 2, 3, 1).numpy()


# Module-level tokenizer for the "bert-base-multilingual-uncased" checkpoint,
# created once at import time and shared by get_text_attributes and
# get_top_5_predictions.
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")


def get_text_attributes(text):
    """Tokenize *text* as a one-element batch with the module-level tokenizer.

    Token type ids are requested and the tokenizer is asked for NumPy
    (``"np"``) outputs; the tokenizer's return value is passed through.
    """
    single_item_batch = [text]
    return bert_tokenizer(
        single_item_batch,
        return_token_type_ids=True,
        return_tensors="np",
    )


def get_top_5_predictions(logits, answer_reverse_mapping=None):
    """Return ``(labels, values)`` for the five largest entries of *logits*.

    Entries are ordered by ascending score (the best prediction is last); if
    *logits* has fewer than five entries, all of them are returned. Labels
    are looked up in *answer_reverse_mapping* under the stringified index
    when a mapping is given, otherwise the indices are converted to tokens by
    the module-level tokenizer. Assumes *logits* is a 1-D NumPy array --
    confirm against callers.
    """
    top_indices = np.argsort(logits)[-5:]
    top_scores = logits[top_indices]

    if answer_reverse_mapping is None:
        return bert_tokenizer.convert_ids_to_tokens(top_indices), top_scores

    names = [answer_reverse_mapping[str(idx)] for idx in top_indices]
    return names, top_scores


# Label translation table, loaded once at import time and indexed as
# translate_dict[label][lang_id] by translate_labels.
# The file is decoded explicitly as UTF-8 so that non-ASCII translations load
# the same way on every platform instead of depending on the locale default.
with open("translation_dict.json", encoding="utf-8") as f:
    translate_dict = json.load(f)


def translate_labels(labels, lang_id):
    """Return *labels* translated into the language identified by *lang_id*.

    ``"<unk>"`` labels are never translated, and ``lang_id == "en"`` returns
    every label unchanged. Anything else is looked up as
    ``translate_dict[label][lang_id]``, so a missing label or language raises
    ``KeyError``.
    """

    def _translate(label):
        # The unknown-token marker has no entry to look up; pass it through.
        if label == "<unk>":
            return "<unk>"
        if lang_id == "en":
            return label
        return translate_dict[label][lang_id]

    return [_translate(label) for label in labels]


def plotly_express_horizontal_bar_plot(values, labels):
    """Build a horizontal Plotly Express bar chart of *values* against *labels*.

    Each bar is annotated with its value formatted as a percentage with three
    decimal places. The figure is returned, not displayed.
    """
    percent_text = [f"{score:.3%}" for score in values]
    axis_titles = {"x": "Scores", "y": "Answers"}
    return px.bar(
        x=values,
        y=labels,
        text=percent_text,
        title="Top-5 Predictions",
        labels=axis_titles,
        orientation="h",
    )


def read_markdown(path, parent="./sections/"):
    """Return the full text of the file *path* located under *parent*.

    The file is decoded explicitly as UTF-8 so that non-ASCII content reads
    the same on every platform instead of depending on the locale's default
    encoding.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    full_path = os.path.join(parent, path)
    with open(full_path, encoding="utf-8") as f:
        return f.read()