# Hugging Face Spaces status: Runtime error
import re | |
import torch | |
from torch import cuda | |
from nltk.tokenize.casual import TweetTokenizer | |
from transformers import AutoTokenizer, AutoModelForTokenClassification | |
# Pretrained token-classification checkpoint whose labels are emoji tags.
pretrained_name = "ml6team/xlm-roberta-base-nl-emoji-ner"
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
model = AutoModelForTokenClassification.from_pretrained(pretrained_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the selected device so inference actually runs there
# (otherwise GPU-resident inputs would meet CPU-resident weights), and switch
# it to inference mode once at load time.
model.to(device)
model.eval()
# Emojis predicted by the model, aligned index-by-index with `emoji_names`.
# Written as escape sequences so a wrong file encoding cannot corrupt them.
# NOTE(review): reconstructed from a mis-encoded copy of this file; confirm
# each emoji against the model card for the checkpoint.
emoji_list = [
    '\U0001F628',  # fearful face            -> B-AFRAID
    '\U0001F625',  # sad but relieved face   -> B-SAD
    '\U0001F60D',  # smiling face heart-eyes -> B-LOVE
    '\U0001F620',  # angry face              -> B-ANGRY
    '\U0001F92F',  # exploding head          -> B-SHOCKED
    '\U0001F602',  # face with tears of joy  -> B-LAUGH
    '\U0001F37E',  # bottle with popping cork -> B-CHAMP
    '\U0001F697',  # automobile              -> B-CAR
    '\u2615',      # hot beverage            -> B-COFFEE
    '\U0001F4B0',  # money bag               -> B-MONEY
]
# NER tag for each emoji above, in the same order.
emoji_names = ['B-AFRAID', 'B-SAD', 'B-LOVE', 'B-ANGRY', 'B-SHOCKED', 'B-LAUGH', 'B-CHAMP', 'B-CAR', 'B-COFFEE', 'B-MONEY']
from collections import defaultdict | |
def default_value():
    """Return 0, the fallback index handed out by the ``tag2idx`` defaultdict."""
    return int()
# Emoji -> NER tag; the 'O' (no emoji) label maps to itself.
emoji2tag = {emo: tag for emo, tag in zip(emoji_list, emoji_names)}
emoji2tag['O'] = 'O'
# NER tag -> emoji. Built directly from the paired lists rather than by
# inverting emoji2tag: if two emoji strings ever coincide (e.g. after an
# encoding mishap), inverting would silently drop tags and make later
# tag2emoji lookups raise KeyError.
tag2emoji = {tag: emo for emo, tag in zip(emoji_list, emoji_names)}
tag2emoji['O'] = 'O'
# NER tag -> class index: emoji tags take 1..N in list order and 'O' is 0.
# Tags not listed fall back to default_value() (0).
# NOTE(review): presumably mirrors the checkpoint's id2label; confirm.
tag2idx = defaultdict(default_value)
for i, e in enumerate(emoji_names):
    tag2idx[e] = i + 1
tag2idx['O'] = 0
# Class index -> NER tag (inverse of tag2idx).
idx2tag = {val: key for key, val in tag2idx.items()}
def tag_text_sample(text):
    """Return *text* re-assembled with model-predicted emojis inserted.

    The text is split into words with NLTK's ``TweetTokenizer`` and the words
    are classified by ``model``. Each word's tag is taken from its first
    sub-token; when that tag is not 'O', the matching emoji from
    ``tag2emoji`` is inserted right after the word.

    Args:
        text: Raw input string.

    Returns:
        The words (plus inserted emojis) joined by single spaces and passed
        through ``tokenizer.clean_up_tokenization``; "" when the text
        contains no words.
    """
    words = TweetTokenizer().tokenize(text)
    if not words:
        # Nothing to tag; avoid running the model on an empty sequence.
        return ""

    # Encode once and reuse the result for both the input ids and the
    # sub-token -> word alignment. truncation=True keeps long inputs within
    # the model's maximum length; words past the cut just get no emoji.
    encoding = tokenizer(
        words,
        add_special_tokens=True,
        is_split_into_words=True,
        truncation=True,
        return_tensors="pt",
    )
    # Send inputs to wherever the model's weights live, so this works whether
    # or not the model was moved to `device`.
    input_ids = encoding["input_ids"].to(model.device)

    model.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(input_ids)[0]
    # Most likely class per sub-token for the single sequence in the batch.
    pred_ids = torch.argmax(logits, dim=2)[0].cpu().tolist()

    # One emoji (or "") per word, keyed by that word's own index and taken
    # from its first sub-token. word_ids() yields None for special tokens
    # added by the tokenizer; those are skipped.
    word_emojis = {}
    for word_idx, pred_id in zip(encoding.word_ids(), pred_ids):
        if word_idx is None or word_idx in word_emojis:
            continue
        tag = idx2tag[pred_id]
        word_emojis[word_idx] = "" if tag == "O" else tag2emoji[tag]

    # Reconstruct the sentence: each word followed by its emoji, if any.
    full_text = []
    for i, word in enumerate(words):
        full_text.append(word)
        emoji = word_emojis.get(i, "")
        if emoji:
            full_text.append(emoji)
    return tokenizer.clean_up_tokenization(" ".join(full_text))