|
import functools
import os

import gradio as gr
import numpy as np
import torch
from huggingface_hub import HfApi
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

from label_dicts import MANIFESTO_LABEL_NAMES
|
|
|
# Name of the environment variable that holds the Hugging Face access token.
_HF_TOKEN_ENV_VAR = "hf_read"

# Read once at import time; a missing variable raises KeyError immediately.
HF_TOKEN = os.environ[_HF_TOKEN_ENV_VAR]
|
|
|
# Choices offered in the "Language" dropdown of the interface.
languages = [
    "Armenian",
    "Bulgarian",
    "Croatian",
    "Czech",
    "Danish",
    "Dutch",
    "English",
    "Estonian",
    "Finnish",
    "French",
    "Georgian",
    "German",
    "Greek",
    "Hebrew",
    "Hungarian",
    "Icelandic",
    "Italian",
    "Japanese",
    "Korean",
    "Latvian",
    "Lithuanian",
    "Norwegian",
    "Polish",
    "Portuguese",
    "Romanian",
    "Russian",
    "Serbian",
    "Slovak",
    "Slovenian",
    "Spanish",
    "Swedish",
    "Turkish",
]
|
|
|
def build_huggingface_path(language: str):
    """Return the Hugging Face repository id of the classification model.

    The ``language`` argument is accepted but not used: every language
    resolves to the same repository id.
    """
    model_repo = "poltextlab/xlm-roberta-large-manifesto"
    return model_repo
|
|
|
@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_id, tokenizer_id):
    """Load the model and tokenizer once and reuse them on later calls.

    The model is moved to the CPU and switched to evaluation mode here so
    that this setup is not repeated for every prediction. ``maxsize=1``
    keeps at most one model resident in memory; asking for a different
    (model_id, tokenizer_id) pair evicts the previous one.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_id, token=HF_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    model.to(torch.device("cpu"))
    model.eval()
    return model, tokenizer


def predict(text: str, model_id: str, tokenizer_id: str):
    """Classify ``text`` and return per-label probabilities plus an info note.

    Args:
        text: Input text; it is truncated to at most 512 tokens.
        model_id: Hugging Face repository id of the sequence-classification
            model (loaded with the ``HF_TOKEN`` access token).
        tokenizer_id: Hugging Face repository id of the tokenizer.

    Returns:
        A ``(output_pred, output_info)`` tuple. ``output_pred`` maps
        ``"[<code>] <label name>"`` to its softmax probability (a built-in
        ``float``), inserted from most to least probable. ``output_info`` is
        an HTML snippet linking to the model that produced the prediction.
    """
    device = torch.device("cpu")
    # Previously both objects were re-instantiated on every request; the
    # cached loader makes that a one-time cost.
    model, tokenizer = _load_model_and_tokenizer(model_id, tokenizer_id)

    inputs = tokenizer(
        text,
        max_length=512,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()

    id2label = model.config.id2label
    output_pred = {}
    # argsort is ascending, so reverse it to list the most probable label first.
    for idx in np.argsort(probs)[::-1]:
        code = id2label[int(idx)]
        key = f"[{code}] {MANIFESTO_LABEL_NAMES[int(code)]}"
        # Convert the NumPy scalar to a built-in float for downstream consumers.
        output_pred[key] = float(probs[idx])

    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred, output_info
|
|
|
def predict_cap(text, language):
    """Run ``predict`` on ``text`` with the model chosen for ``language``.

    The model id comes from ``build_huggingface_path(language)``; the
    tokenizer is always ``"xlm-roberta-large"``. Returns whatever
    ``predict`` returns.
    """
    return predict(
        text,
        model_id=build_huggingface_path(language),
        tokenizer_id="xlm-roberta-large",
    )
|
|
|
# Gradio UI: a text box and a language dropdown feed predict_cap (in that
# order); the two results go to a label widget limited to the top five
# classes and to a Markdown component.
demo = gr.Interface(
    fn=predict_cap,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        gr.Dropdown(choices=languages, label="Language"),
    ],
    outputs=[
        gr.Label(num_top_classes=5, label="Output"),
        gr.Markdown(),
    ],
)