Spaces:
Runtime error
Runtime error
File size: 3,529 Bytes
e289356 fb3c77c 2bbf92c 393b8fc e289356 405f2d4 2bbf92c 690384a 393b8fc 919efff 393b8fc c684eba 393b8fc 690384a 2bbf92c 690384a 2bbf92c 2c8f495 2bbf92c 405f2d4 2bbf92c f4963f2 405f2d4 2bbf92c 690384a 2bbf92c 690384a 2bbf92c 690384a 2bbf92c 690384a fb3c77c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import json
import os
import numpy as np
import streamlit as st
import plotly.express as px
import torch
from torchvision.io import read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from transformers import BertTokenizerFast
class Toc:
    """Render headings through Streamlit and build a linked table of contents.

    Every heading method writes an HTML heading carrying an anchor id to the
    page and records a markdown bullet that links to that anchor. Calling
    ``generate`` writes all recorded bullets into the slot that was reserved
    earlier with ``placeholder``.
    """

    def __init__(self):
        # Markdown bullet lines, in the order the headings were emitted.
        self._items = []
        # Slot that will receive the table of contents; None until reserved.
        self._placeholder = None

    def title(self, text):
        """Emit ``text`` as an ``h1`` heading with an unindented TOC entry."""
        self._markdown(text, "h1")

    def header(self, text):
        """Emit ``text`` as an ``h2`` heading; its TOC entry is indented 2 spaces."""
        self._markdown(text, "h2", " " * 2)

    def subheader(self, text):
        """Emit ``text`` as an ``h3`` heading; its TOC entry is indented 4 spaces."""
        self._markdown(text, "h3", " " * 4)

    def subsubheader(self, text):
        """Emit ``text`` as an ``h4`` heading; its TOC entry is indented 8 spaces."""
        self._markdown(text, "h4", " " * 8)

    def placeholder(self, sidebar=False):
        """Reserve an empty Streamlit slot (sidebar or main page) for the TOC."""
        if sidebar:
            self._placeholder = st.sidebar.empty()
        else:
            self._placeholder = st.empty()

    def generate(self):
        """Write the collected entries into the reserved slot, if there is one."""
        if not self._placeholder:
            return
        toc_body = "\n".join(self._items)
        self._placeholder.markdown(toc_body, unsafe_allow_html=True)

    def _markdown(self, text, level, space=""):
        """Emit one heading and record its TOC bullet.

        ``level`` is the HTML tag name (for example ``"h2"``) and ``space`` is
        the leading whitespace that nests the bullet in the rendered list.
        """
        # The anchor id keeps only the alphanumeric characters, lower-cased.
        anchor = "".join(ch for ch in text if ch.isalnum()).lower()
        heading_html = f"<{level} id='{anchor}'>{text}</{level}>"
        st.markdown(heading_html, unsafe_allow_html=True)
        self._items.append(f"{space}* <a href='#{anchor}'>{text}</a>")
class Transform(torch.nn.Module):
    """Image preprocessing module: resize, center-crop, float conversion, normalize.

    The four steps are chained in a ``torch.nn.Sequential`` stored on
    ``self.transforms`` and are applied with gradient tracking disabled.
    """

    def __init__(self, image_size):
        """Build the pipeline for a target size of ``image_size``."""
        super().__init__()
        # Per-channel statistics handed to Normalize as (mean, std).
        # NOTE(review): these look like the CLIP preprocessing constants — confirm.
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        steps = [
            # Resize receives a one-element size list; see the torchvision
            # Resize documentation for how a single value is interpreted.
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(channel_mean, channel_std),
        ]
        self.transforms = torch.nn.Sequential(*steps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the pipeline without recording gradients."""
        with torch.no_grad():
            return self.transforms(x)
# Shared preprocessing pipeline, built once at import time with a target size of 224.
transform = Transform(224)
def get_transformed_image(image):
    """Preprocess one image with the module-level ``transform`` pipeline.

    A NumPy array whose last axis has length 3 is moved to channel-first
    layout and copied into a torch tensor first. Any other input is handed to
    ``transform`` unchanged (presumably an already channel-first tensor —
    confirm against callers).

    Returns:
        A NumPy array with a leading axis of size 1 added and the channel axis
        moved to the last position.
    """
    is_channels_last_array = image.shape[-1] == 3 and isinstance(image, np.ndarray)
    if is_channels_last_array:
        channels_first = image.transpose(2, 0, 1)
        image = torch.tensor(channels_first)
    preprocessed = transform(image)
    batched = preprocessed.unsqueeze(0)
    return batched.permute(0, 2, 3, 1).numpy()
# Module-level tokenizer loaded from the "bert-base-multilingual-uncased" checkpoint.
# It is created once at import time and reused by the text helpers in this module.
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")
def get_text_attributes(text):
    """Tokenize a single string with the module-level ``bert_tokenizer``.

    The string is wrapped in a one-element list before tokenizing, and the
    tokenizer is asked for NumPy outputs that include token type ids.
    """
    single_item_batch = [text]
    encoding = bert_tokenizer(
        single_item_batch,
        return_tensors="np",
        return_token_type_ids=True,
    )
    return encoding
def get_top_5_predictions(logits, answer_reverse_mapping=None):
    """Return the labels and scores of the five largest entries of ``logits``.

    Results are ordered by ascending score, so the best prediction comes last.

    Args:
        logits: One-dimensional array of scores that supports indexing with an
            index array.
        answer_reverse_mapping: Optional mapping from the *string* form of an
            index to its label. When omitted, indices are converted to tokens
            with the module-level ``bert_tokenizer``.

    Returns:
        A ``(labels, values)`` tuple; ``values`` is ``logits`` indexed by the
        selected positions.
    """
    # argsort is ascending, so the last five positions hold the largest scores.
    top_ids = np.argsort(logits)[-5:]
    top_scores = logits[top_ids]
    if answer_reverse_mapping is None:
        names = bert_tokenizer.convert_ids_to_tokens(top_ids)
    else:
        names = [answer_reverse_mapping[str(idx)] for idx in top_ids]
    return names, top_scores
# Load the label translation table once at import time. translate_labels indexes
# it as translate_dict[label][lang_id]. The encoding is pinned to UTF-8 so that
# non-ASCII translations decode the same way on every platform instead of
# depending on the locale's default encoding.
with open("translation_dict.json", encoding="utf-8") as f:
    translate_dict = json.load(f)
def translate_labels(labels, lang_id):
    """Translate each label into the language identified by ``lang_id``.

    The ``"<unk>"`` label is never translated, and labels are returned
    unchanged when ``lang_id`` is ``"en"``. Every other label is looked up as
    ``translate_dict[label][lang_id]``; a missing entry raises ``KeyError``.

    Returns:
        A new list of translated labels, in the same order as ``labels``.
    """

    def _localize(label):
        # Pass-through cases first: the unknown marker and English output.
        if label == "<unk>" or lang_id == "en":
            return label
        return translate_dict[label][lang_id]

    return [_localize(label) for label in labels]
def plotly_express_horizontal_bar_plot(values, labels):
    """Build a horizontal Plotly Express bar chart with one bar per label.

    Args:
        values: Scores plotted along the x axis; each bar is also annotated
            with its score formatted as a percentage with three decimals.
        labels: Answer names plotted along the y axis.

    Returns:
        The figure produced by ``px.bar``.
    """
    percent_text = [f"{score:.3%}" for score in values]
    axis_names = {"x": "Scores", "y": "Answers"}
    return px.bar(
        x=values,
        y=labels,
        text=percent_text,
        title="Top-5 Predictions",
        labels=axis_names,
        orientation="h",
    )
def read_markdown(path, parent="./sections/"):
    """Return the full text of the file ``path`` located under ``parent``.

    Args:
        path: File name (or relative path) joined onto ``parent``.
        parent: Directory that holds the section files.

    Returns:
        The file contents as a single string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8: without an encoding, open() falls back to the
    # locale default, which can mis-decode non-ASCII text on some platforms.
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()