Spaces:
Runtime error
Runtime error
Delete model_tools.py
Browse files- model_tools.py +0 -99
model_tools.py
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
import re
|
2 |
-
import torch
|
3 |
-
from torch import cuda
|
4 |
-
from nltk.tokenize.casual import TweetTokenizer
|
5 |
-
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
6 |
-
|
7 |
-
pretrained_name = "ml6team/xlm-roberta-base-nl-emoji-ner"
|
8 |
-
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
|
9 |
-
model = AutoModelForTokenClassification.from_pretrained(pretrained_name)
|
10 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
-
|
12 |
-
emoji_list = ['π¨', 'π₯', 'π', 'π ', 'π€―', 'π', 'πΎ', 'π', 'β', 'π°']
|
13 |
-
emoji_names = ['B-AFRAID', 'B-SAD', 'B-LOVE', 'B-ANGRY', 'B-SHOCKED', 'B-LAUGH', 'B-CHAMP', 'B-CAR', 'B-COFFEE', 'B-MONEY']
|
14 |
-
|
15 |
-
from collections import defaultdict
|
16 |
-
|
17 |
-
def default_value():
|
18 |
-
return 0
|
19 |
-
|
20 |
-
emoji2tag = {emo:tag for emo, tag in zip(emoji_list, emoji_names)}
|
21 |
-
emoji2tag['O'] = 'O'
|
22 |
-
tag2emoji = {val:key for key,val in emoji2tag.items()}
|
23 |
-
|
24 |
-
tag2idx = defaultdict(default_value)
|
25 |
-
|
26 |
-
# mapping of emojis and categories
|
27 |
-
for i, e in enumerate(emoji_names):
|
28 |
-
tag2idx[e] = i+1
|
29 |
-
|
30 |
-
tag2idx['O'] = 0
|
31 |
-
|
32 |
-
|
33 |
-
idx2tag = {val:key for key,val in tag2idx.items()}
|
34 |
-
|
35 |
-
|
36 |
-
def tag_text_sample(text):
|
37 |
-
|
38 |
-
t = TweetTokenizer()
|
39 |
-
|
40 |
-
words = t.tokenize(text)
|
41 |
-
# Get tokens with special characters
|
42 |
-
tokens = tokenizer(words, is_split_into_words=True).tokens()
|
43 |
-
|
44 |
-
# Encode the sequence into IDs
|
45 |
-
input_ids = tokenizer(
|
46 |
-
words,
|
47 |
-
add_special_tokens=True,
|
48 |
-
is_split_into_words=True,
|
49 |
-
return_tensors="pt").input_ids.to(device)
|
50 |
-
|
51 |
-
# Get predictions as distribution over 7 possible classes
|
52 |
-
model.eval()
|
53 |
-
outputs = model(input_ids)[0]
|
54 |
-
|
55 |
-
# Take argmax to get most likely class per token
|
56 |
-
predictions = torch.argmax(outputs, dim=2)
|
57 |
-
|
58 |
-
# Convert to DataFrame
|
59 |
-
preds = [idx2tag[p] for p in predictions[0].cpu().numpy()]
|
60 |
-
|
61 |
-
# Get word id's
|
62 |
-
word_ids = tokenizer(words, is_split_into_words=True).word_ids()
|
63 |
-
|
64 |
-
# Keep non-special tokens and labels
|
65 |
-
preds = preds[1:-1]
|
66 |
-
tokens = tokens[1:-1]
|
67 |
-
word_ids = word_ids[1:-1]
|
68 |
-
|
69 |
-
# Full sentence reconstruction
|
70 |
-
previous_word_idx = None
|
71 |
-
current_emoji = ''
|
72 |
-
|
73 |
-
word_emojis = {}
|
74 |
-
|
75 |
-
for i, (word_idx, pred) in enumerate(zip(word_ids, preds)):
|
76 |
-
|
77 |
-
if previous_word_idx is None or word_idx != previous_word_idx:
|
78 |
-
# Add emoji from previous word to sequence
|
79 |
-
word_emojis[word_idx] = current_emoji
|
80 |
-
|
81 |
-
# Check new emoji
|
82 |
-
current_emoji = tag2emoji[pred] if not pred == "O" else ""
|
83 |
-
|
84 |
-
previous_word_idx = word_idx
|
85 |
-
|
86 |
-
# If final word: add emoji
|
87 |
-
if i == len(word_ids) - 1:
|
88 |
-
word_emojis[word_idx] = current_emoji
|
89 |
-
|
90 |
-
# Reconstruct
|
91 |
-
full_text = []
|
92 |
-
for i, word in enumerate(words):
|
93 |
-
|
94 |
-
full_text.append(word)
|
95 |
-
|
96 |
-
if word_emojis[i]:
|
97 |
-
full_text.append(word_emojis[i])
|
98 |
-
|
99 |
-
return tokenizer.clean_up_tokenization(" ".join(full_text))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|