|
import torch |
|
import torch.nn.functional as F |
|
import pickle |
|
import re |
|
|
|
|
|
# Load the serialized classifier onto CPU so inference works without a GPU.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# artifacts from a trusted source.
model = torch.load("models/model", map_location='cpu')

# Tokenizer was apparently saved with torch.save as well; loaded without
# map_location since it presumably holds no tensors — TODO confirm.
tokenizer = torch.load("models/tokenizer")



# Mapping from predicted class index -> label string (presumably
# int -> str; verify against how models/label_dict was written).
with open("models/label_dict", 'rb') as file:

    label_dict = pickle.load(file)
|
|
|
def preprocess_string(tweet: str) -> str:
    """Return *tweet* lowercased, whitespace-trimmed, and stripped of punctuation.

    Only word characters (letters, digits, underscore) and whitespace survive
    the regex pass; everything else is deleted.
    """
    lowered = tweet.lower().strip()
    return re.sub(r'[^\w\s]', '', lowered)
|
|
|
def predict_single(tweet: str) -> str:
    """Classify one tweet and return its label string.

    Args:
        tweet: Raw tweet text; it is normalized via ``preprocess_string``
            before tokenization.

    Returns:
        The label from ``label_dict`` for the highest-scoring class.
    """
    clean_tweet = preprocess_string(tweet)
    # `inputs` rather than `input` — avoid shadowing the builtin.
    inputs = tokenizer(clean_tweet, return_tensors='pt', truncation=True)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**inputs)
    # argmax over raw logits equals argmax over softmax(logits) (softmax is
    # monotonic), so the explicit softmax pass is unnecessary. `.item()` is
    # used directly; the old `.data` accessor is deprecated.
    pred = torch.argmax(output.logits, dim=-1).item()
    return label_dict[pred]
|
|
|
def predict_batch(tweets):
    """Classify a batch of tweets and return their label strings.

    Args:
        tweets: Iterable of raw tweet strings; each is normalized via
            ``preprocess_string`` before tokenization.

    Returns:
        List of labels from ``label_dict``, one per input tweet, in order.
    """
    clean_tweets = [preprocess_string(tweet) for tweet in tweets]
    # padding=True so variable-length tweets form one rectangular batch.
    inputs = tokenizer(clean_tweets, return_tensors='pt', padding=True, truncation=True)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    # argmax over raw logits equals argmax over softmax(logits) (softmax is
    # monotonic), so the explicit softmax pass is unnecessary.
    preds = torch.argmax(outputs.logits, dim=-1).tolist()
    return [label_dict[pred] for pred in preds]
|
|