|
import torch |
|
from transformers import DistilBertTokenizer, DistilBertModel |
|
from torch import nn |
|
|
|
|
|
# Hugging Face checkpoint shared by the tokenizer and the encoder, kept in one
# place so the two can never drift apart.
_PRETRAINED_NAME = 'distilbert-base-uncased'

# Location of the saved classifier weights (a state dict, loaded further down).
model_path = './best_model_state.bin'

# Tokenizer and pretrained encoder for the same checkpoint.
tokenizer = DistilBertTokenizer.from_pretrained(_PRETRAINED_NAME)
bert_model = DistilBertModel.from_pretrained(_PRETRAINED_NAME)
|
|
|
class SentimentClassifier(nn.Module):
    """Transformer encoder followed by a small feed-forward classification head.

    The head takes the encoder's hidden state at position 0 of each sequence,
    passes it through Linear -> ReLU -> Dropout, and projects it to
    ``n_classes`` logits.

    Args:
        n_classes: Number of output classes (size of the logit vector).
        bert: Optional encoder module. Defaults to the module-level
            ``bert_model``. It must expose ``config.hidden_size`` and, when
            called with ``input_ids``/``attention_mask``, return an indexable
            output whose first element is the last hidden state.
    """

    def __init__(self, n_classes, bert=None):
        super().__init__()
        # Fall back to the shared pretrained encoder so existing callers such as
        # ``SentimentClassifier(n_classes=2)`` behave exactly as before.
        self.bert = bert_model if bert is None else bert
        hidden_size = self.bert.config.hidden_size
        # These attribute names form the state-dict keys; they must match the
        # keys of any checkpoint loaded with load_state_dict.
        self.pre_classifier = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(p=0.3)
        self.out = nn.Linear(hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits.

        Args:
            input_ids: Token-id tensor, forwarded unchanged to the encoder.
            attention_mask: Attention-mask tensor, forwarded unchanged to the
                encoder.

        Returns:
            Tensor of logits with ``n_classes`` entries per example.
        """
        # Element 0 of the encoder output is the last hidden state.
        last_hidden_state = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )[0]
        # Summarize each sequence by its first position (where the tokenizer
        # places the special start token when add_special_tokens=True).
        pooled_output = self.pre_classifier(last_hidden_state[:, 0])
        # Functional ReLU: no need to construct a fresh nn.ReLU module per call.
        pooled_output = nn.functional.relu(pooled_output)
        pooled_output = self.dropout(pooled_output)
        return self.out(pooled_output)
|
|
|
|
|
# Device every saved tensor is remapped onto when the checkpoint is read.
_CPU = torch.device('cpu')

# Build the two-class model, restore its saved weights onto the CPU, and switch
# it to inference mode (eval() disables the dropout layer).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source; on recent torch versions weights_only=True restricts this.
model = SentimentClassifier(n_classes=2)
model.load_state_dict(torch.load(model_path, map_location=_CPU))
model.eval()
|
|
|
|
|
def preprocess(text, max_len=256):
    """Tokenize ``text`` into fixed-length model inputs.

    Args:
        text: The sentence to encode.
        max_len: Target sequence length. Longer inputs are truncated and
            shorter ones padded to exactly this many tokens. Defaults to 256.

    Returns:
        tuple: ``(input_ids, attention_mask)`` as PyTorch tensors.
    """
    # Calling the tokenizer directly is the recommended API; encode_plus is a
    # legacy entry point that accepts the same arguments for a single text.
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_len,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return encoding['input_ids'], encoding['attention_mask']
|
|
|
|
|
def predict_sentiment(text):
    """Classify ``text`` and return the index of the highest-scoring class.

    Args:
        text: Input sentence.

    Returns:
        int: Index of the predicted class in the model's output.
    """
    input_ids, attention_mask = preprocess(text)
    # Inference only: no_grad skips building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
    # softmax is monotonic, so the arg-max over raw logits picks the same class;
    # normalizing first would be wasted work.
    prediction = torch.argmax(logits, dim=1)
    # A single text is encoded, so there is one row and .item() is valid.
    return prediction.item()
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Interactive prompt: only runs when executed as a script, not on import.

    # Position i names class index i returned by predict_sentiment.
    # NOTE(review): confirm this order matches the label encoding used when the
    # checkpoint was trained; it cannot be verified from this file.
    class_names = ["positive", "negative"]

    while True:
        try:
            input_text = input("Enter a sentence: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly
            # instead of dumping a traceback.
            print()
            break

        prediction = predict_sentiment(input_text)

        print(f'Text: {input_text}')
        print(f'Sentiment: {class_names[prediction]}')
|
|