# Hugging Face Spaces status at capture time: Runtime error
from typing import Dict, Tuple

import streamlit as st
import torch
from annotated_text import annotated_text
from dante_tokenizer import DanteTokenizer
from dante_tokenizer.data.preprocessing import expand_contractions
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Resolve Streamlit's resource-caching decorator once, newest API first:
# ``st.cache_resource``, then ``st.experimental_singleton``, then the legacy
# ``st.cache``. ``or`` short-circuits, so the legacy call is only evaluated
# when neither newer API exists in the installed Streamlit version.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or st.cache(allow_output_mutation=True)
)


@_cache_resource
def get_pos_tag_model(
    model_name: str = "Emanuel/autonlp-pos-tag-bosque",
) -> Tuple[AutoModelForTokenClassification, AutoTokenizer]:
    """Load the token-classification model and its tokenizer.

    Streamlit re-executes the whole script on every widget interaction. The
    cache decorator keeps the loaded objects in memory, so the model is
    downloaded and instantiated once rather than on every rerun.

    Args:
        model_name: Hugging Face Hub id (or local path) passed to
            ``from_pretrained`` for both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple for ``model_name``.
    """
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
# Background color for each Universal Dependencies UPOS tag.
# Built once at import time instead of on every lookup (it is queried per word).
_UPOS_PALETTE = {
    "ADJ": "#2E4C6D",
    "ADP": "#FBE7C6",
    "ADV": "#DADDFC",
    "AUX": "#FC997C",
    "CCONJ": "#544179",
    "DET": "#A0E7E5",
    "INTJ": "#32C1CD",
    "NOUN": "#17D7A0",
    "PART": "#C85C5C",
    "PRON": "#F9975D",
    "PROPN": "#FBD148",
    "PUNCT": "#B2EA70",
    "SCONJ": "#AA14F0",
    "SYM": "#34BE82",
    "VERB": "#FFBF86",
    "X": "#2F86A6",
    "NUM": "#F39B6D",
}


def get_tag_color(tag: str, default: str = "#CCCCCC") -> str:
    """Return the color for a part-of-speech tag from the Universal Dependencies tagset.

    See: https://universaldependencies.org/u/pos/

    Args:
        tag: A UPOS tag such as ``"NOUN"`` or ``"VERB"`` (case-sensitive).
        default: Color returned for any tag outside the UPOS palette, so an
            unexpected model label renders in a neutral color instead of
            raising ``KeyError`` and crashing the app.

    Returns:
        A hex color string.
    """
    return _UPOS_PALETTE.get(tag, default)
def main():
    """Render the Streamlit part-of-speech tagging demo.

    Reads text from a text area, splits it into words with ``DanteTokenizer``,
    predicts one tag per word with the model from ``get_pos_tag_model`` and
    renders the words color-coded by tag via ``annotated_text``.
    """
    text = st.text_area("Digite seu texto de entrada!")
    dt = DanteTokenizer()
    model, tokenizer = get_pos_tag_model()
    if not text:
        return

    # The displayed words and the model input must come from the same string,
    # otherwise tags cannot be aligned to words.
    # NOTE(review): assumes expand_contractions returns a str and that the
    # tagger expects contraction-expanded input (the original computed this
    # text but never used it) — confirm against the training data.
    words = dt.tokenize(expand_contractions(text))
    if not words:
        return

    # Tokenize the pre-split words so every subword can be traced back to the
    # word it came from. Truncation keeps long inputs within the model's
    # maximum sequence length instead of raising.
    # NOTE(review): word_ids() needs a fast tokenizer, and BPE tokenizers
    # (e.g. RoBERTa-style) would also need add_prefix_space=True — confirm.
    inputs = tokenizer(
        words, is_split_into_words=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = model(**inputs).logits[0]  # first (only) sequence in the batch
    label_ids = logits.argmax(dim=-1).tolist()

    # A word may be split into several subwords; tag it with the prediction of
    # its first subword. word_ids() yields None for tokenizer special tokens.
    word_labels: Dict[int, str] = {}
    for position, word_index in enumerate(inputs.word_ids()):
        if word_index is not None and word_index not in word_labels:
            word_labels[word_index] = model.config.id2label[label_ids[position]]

    answer = []
    for index, word in enumerate(words):
        label = word_labels.get(index)
        if label is None:
            # Word was cut off by truncation: show it, but without a tag.
            answer.append(f" {word} ")
        else:
            answer.append((word, label, get_tag_color(label)))
    annotated_text(*answer)
if __name__ == "__main__":
    # Entry point when the file is executed as a script (e.g. via `streamlit run`).
    main()