gchhablani's picture
Fix ToC imports
393b8fc
raw
history blame
3.43 kB
import json
import os
import numpy as np
import streamlit as st
import plotly.express as px
import torch
from torchvision.io import read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import BertTokenizerFast
class Toc:
    """Build a linked table of contents for headings rendered on a Streamlit page.

    ``title``/``header``/``subheader`` render a heading immediately through
    ``st.markdown`` (as raw HTML carrying an ``id`` anchor) and record a matching
    bullet-list entry. ``placeholder`` reserves the spot where the ToC will appear,
    and ``generate`` fills that spot with every entry recorded so far.

    NOTE(review): heading text is interpolated into HTML unescaped with
    ``unsafe_allow_html=True``; only pass trusted strings.
    """

    def __init__(self):
        self._items = []  # markdown bullet lines, one per rendered heading
        self._placeholder = None  # st.empty() slot set by placeholder()
        self._key_counts = {}  # base anchor key -> occurrences, keeps ids unique

    def title(self, text):
        """Render *text* as an ``<h1>`` and add an unindented ToC entry."""
        self._markdown(text, "h1")

    def header(self, text):
        """Render *text* as an ``<h2>`` and add a ToC entry indented by 2 spaces."""
        self._markdown(text, "h2", " " * 2)

    def subheader(self, text):
        """Render *text* as an ``<h3>`` and add a ToC entry indented by 4 spaces."""
        self._markdown(text, "h3", " " * 4)

    def placeholder(self, sidebar=False):
        """Reserve an empty slot (in the sidebar if *sidebar*) for the ToC."""
        self._placeholder = st.sidebar.empty() if sidebar else st.empty()

    def generate(self):
        """Write the collected entries into the reserved slot; no-op if none was reserved."""
        if self._placeholder is not None:
            self._placeholder.markdown("\n".join(self._items), unsafe_allow_html=True)

    def _anchor_key(self, text):
        """Return a unique HTML id for *text*.

        The base key keeps only the alphanumeric characters of *text*, lowercased
        (falling back to ``"section"`` when none remain, instead of an empty id).
        Repeated base keys get a ``-1``, ``-2``, ... suffix so every ToC link
        targets its own heading rather than the first one with that key.
        """
        base = "".join(filter(str.isalnum, text)).lower() or "section"
        seen = self._key_counts.get(base, 0)
        self._key_counts[base] = seen + 1
        # Base keys are purely alphanumeric, so a hyphenated suffix cannot collide
        # with another heading's base key.
        return base if seen == 0 else f"{base}-{seen}"

    def _markdown(self, text, level, space=""):
        """Render one heading with HTML tag *level* and record its ToC entry.

        *space* is the leading indentation of the bullet, which nests
        lower-level headings under higher-level ones in the markdown list.
        """
        key = self._anchor_key(text)
        st.markdown(f"<{level} id='{key}'>{text}</{level}>", unsafe_allow_html=True)
        self._items.append(f"{space}* <a href='#{key}'>{text}</a>")
class Transform(torch.nn.Module):
    """Image preprocessing pipeline wrapped as a ``torch.nn.Module``.

    Steps, in order: bicubic resize so the shorter edge equals ``image_size``
    (``Resize`` given a one-element size list), center crop to
    ``image_size`` x ``image_size``, conversion to a float tensor via
    ``ConvertImageDtype``, and per-channel normalization.

    NOTE(review): the mean/std triples look like the CLIP image-normalization
    statistics; confirm they match the vision encoder actually in use.
    """

    def __init__(self, image_size):
        super().__init__()
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        pipeline = [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(channel_mean, channel_std),
        ]
        self.transforms = torch.nn.Sequential(*pipeline)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pipeline to *x* with gradient tracking disabled."""
        return self.transforms(x)
# Shared preprocessing pipeline (224-pixel resize + center crop, float conversion,
# normalization); applied by get_transformed_image.
transform = Transform(image_size=224)
def get_transformed_image(image):
    """Preprocess one image with ``transform`` and return a channels-last NumPy batch.

    Args:
        image: A NumPy array or a torch tensor. A NumPy array whose last axis has
            size 3 is treated as channels-last (H, W, 3) and moved to channels-first;
            every other input is assumed to already be channels-first (C, H, W),
            such as the uint8 tensor ``torchvision.io.read_image`` returns.

    Returns:
        np.ndarray with a leading batch axis of 1 and channels moved last,
        i.e. (1, H', W', C) after the module-level ``transform``.
    """
    if isinstance(image, np.ndarray):
        # Check the type first so `.shape` is only read on arrays.
        if image.shape[-1] == 3:
            image = image.transpose(2, 0, 1)
        image = torch.tensor(image)
    elif not isinstance(image, torch.Tensor):
        image = torch.tensor(image)
    # Tensors are used as-is: torch.tensor(<tensor>) would copy and emit a
    # UserWarning, and the transform pipeline does not modify its input in place.
    return transform(image).unsqueeze(0).permute(0, 2, 3, 1).numpy()
# Shared tokenizer for the "bert-base-multilingual-uncased" checkpoint, loaded once
# at import time (from_pretrained may fetch files from the Hugging Face Hub).
bert_tokenizer = BertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path="bert-base-multilingual-uncased"
)
def get_text_attributes(text):
    """Tokenize a single string with the module-level ``bert_tokenizer``.

    The string is wrapped in a one-element batch and encoded with NumPy outputs
    (``return_tensors="np"``), explicitly requesting ``token_type_ids``.
    Returns the tokenizer's encoding object unchanged.
    """
    single_item_batch = [text]
    encoding = bert_tokenizer(
        single_item_batch,
        return_tensors="np",
        return_token_type_ids=True,
    )
    return encoding
def get_top_5_predictions(logits, answer_reverse_mapping=None):
    """Return the labels and scores of the five highest entries of a score vector.

    Args:
        logits: 1-D array-like of scores (anything ``np.asarray`` accepts).
        answer_reverse_mapping: Optional dict from a stringified index to an
            answer label. When omitted, indices are treated as vocabulary ids and
            decoded with the module-level ``bert_tokenizer``.

    Returns:
        ``(labels, values)``: labels as a list and the matching scores as a NumPy
        array, both ordered ascending (5th-highest first, highest last), as
        produced by ``np.argsort``. Fewer than five entries are returned when
        *logits* is shorter than five.
    """
    logits = np.asarray(logits)
    # argsort is ascending, so the last five indices are the top five, lowest first.
    indices = np.argsort(logits)[-5:]
    values = logits[indices]
    # Plain Python ints: mapping keys are str(index), and convert_ids_to_tokens is
    # documented for int / list[int] rather than NumPy integer arrays.
    index_list = indices.tolist()
    if answer_reverse_mapping is not None:
        labels = [answer_reverse_mapping[str(i)] for i in index_list]
    else:
        labels = bert_tokenizer.convert_ids_to_tokens(index_list)
    return labels, values
# Answer-label translations, looked up as translate_dict[label][lang_id] by
# translate_labels. Decoded as UTF-8 explicitly: JSON text is UTF-8 by spec,
# while open()'s default encoding depends on the platform locale.
with open("translation_dict.json", encoding="utf-8") as f:
    translate_dict = json.load(f)
def translate_labels(labels, lang_id):
    """Translate each answer label into the language *lang_id*.

    ``"<unk>"`` is kept as-is, and every label is returned unchanged when
    *lang_id* is ``"en"``. Otherwise each label is looked up as
    ``translate_dict[label][lang_id]``, so a missing label or language raises
    ``KeyError``.
    """

    def _translate_one(label):
        if label == "<unk>":
            return "<unk>"
        if lang_id == "en":
            return label
        return translate_dict[label][lang_id]

    return [_translate_one(label) for label in labels]
def plotly_express_horizontal_bar_plot(values, labels):
    """Build a horizontal Plotly Express bar chart of prediction scores.

    Args:
        values: Numeric scores, one per bar; plotted on the x axis and printed on
            each bar as a percentage with three decimals.
        labels: Bar labels aligned with *values*; plotted on the y axis.

    Returns:
        The figure object created by ``px.bar``.
    """
    bar_text = [f"{score:.3%}" for score in values]
    axis_titles = {"x": "Scores", "y": "Answers"}
    figure = px.bar(
        x=values,
        y=labels,
        text=bar_text,
        title="Top-5 Predictions",
        labels=axis_titles,
        orientation="h",
    )
    return figure
def read_markdown(path, parent="./sections/"):
    """Return the full text of the file *path*, resolved relative to *parent*.

    The file is decoded as UTF-8 explicitly so the result does not depend on the
    platform's default locale encoding.
    """
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()