|
from model import DocBERT |
|
from models.lstm_model import DocumentBiLSTM |
|
from dataset import DataLoader, DocumentDataset |
|
from utils.word_segmentation_vi import word_segmentation_vi |
|
import numpy as np |
|
from transformers import AutoTokenizer |
|
import torch.nn.functional as F |
|
import torch |
|
|
|
args = { |
|
"bert_model": "vinai/phobert-base-v2", |
|
"model_path": "./vietnamese_hate_speech_detection_phobert/vinai_phobert-base-v2_finetuned.pth", |
|
"lstm_model_path": "./vietnamese_hate_speech_detection_phobert/distilled_lstm_model.pth", |
|
"max_seq_length": 250, |
|
"num_classes": 4, |
|
"num_categories": 5, |
|
} |
|
|
|
class_names = ["NORMAL", "CLEAN", "OFFENSIVE", "HATE"] |
|
|
|
def load_model_bert(): |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
model = DocBERT(bert_model_name=args["bert_model"], num_classes=args["num_classes"], num_categories=args["num_categories"]) |
|
model.load_state_dict(torch.load(args["model_path"], map_location=device)) |
|
model = model.to(device) |
|
return model, device |
|
|
|
def load_model_lstm(): |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(args["bert_model"]) |
|
vocab_size = tokenizer.vocab_size |
|
model = DocumentBiLSTM(vocab_size=vocab_size, |
|
embedding_dim=300, |
|
hidden_dim=256, |
|
n_layers=2, |
|
output_dim=args["num_classes"] * args["num_categories"]) |
|
model.load_state_dict(torch.load(args["lstm_model_path"], map_location=device)["model_state_dict"]) |
|
model = model.to(device) |
|
return model, device |
|
|
|
def inference(model, device, comments: str | list, threshold: float = 0.5): |
|
if isinstance(comments, str): |
|
comments = [comments] |
|
elif not isinstance(comments, list): |
|
raise ValueError("comment must be a string or a list of strings") |
|
|
|
comments = np.array([word_segmentation_vi(comment) for comment in comments]) |
|
data = DocumentDataset(texts=comments, labels=None, tokenizer_name=args["bert_model"], max_length=args["max_seq_length"]) |
|
inference_loader = DataLoader(data, batch_size=comments.shape[0], shuffle=False) |
|
|
|
batch = next(iter(inference_loader)) |
|
input_ids = batch['input_ids'] |
|
attention_mask = batch['attention_mask'] |
|
token_type_ids = batch['token_type_ids'] |
|
|
|
input_ids = input_ids.to(device) |
|
attention_mask = attention_mask.to(device) |
|
token_type_ids = token_type_ids.to(device) |
|
|
|
with torch.no_grad(): |
|
outputs = model(input_ids, attention_mask=attention_mask) |
|
if args["num_categories"] > 1: |
|
batch_size, total_classes = outputs.shape |
|
if total_classes % args["num_categories"] != 0: |
|
raise ValueError("Error: Number of total classes in the batch must of divisible by the number of categories.") |
|
|
|
classes_per_group = total_classes // args["num_categories"] |
|
|
|
reshaped = outputs.view(outputs.size(0), -1, classes_per_group) |
|
probs = F.softmax(reshaped, dim=1) |
|
|
|
|
|
probs = torch.where(probs > threshold, probs, 0.0) |
|
print("Probabilities: ", probs) |
|
|
|
predictions = probs.argmax(dim=-1) |
|
else: |
|
predictions = torch.argmax(outputs, dim=-1) |
|
|
|
preds_array = predictions.cpu().numpy() |
|
result = [] |
|
for i in range(preds_array.shape[0]): |
|
result.append( |
|
{ |
|
"Bình luận": comments[i], |
|
"Cá nhân": class_names[ preds_array[i, 0] ], |
|
"Nhóm/tổ chức": class_names[ preds_array[i, 1] ], |
|
"Tôn giáo/tín ngưỡng": class_names[ preds_array[i, 2] ], |
|
"Chủng tộc/sắc tộc": class_names[ preds_array[i, 3] ], |
|
"Chính trị": class_names[ preds_array[i, 4] ], |
|
}) |
|
return result |
|
|
|
if __name__ == "__main__": |
|
|
|
model, device = load_model_bert() |
|
'''comments = [ |
|
"Em ăn hoành thánh sáng bị khó chịu mắc ói quá bỏ ăn trưa luôn. Các thím thường hay uống gì cho đỡ vậy? Em tính làm gói gừng pha uống", |
|
"Quan trọng là năm nay có tham gia những lễ hội có tính chất, quy mô và bối cảnh y hệt vậy không? Chứ tôi nói thật, dù ở bất cứ đâu mà tập trung đông đến mức không tiến không lùi như này được thì đều nguy hiểm. Khoan nói về giẫm đạp, chỉ riêng việc có sự cố đột xuất xảy ra thì chuyện cấp cứu nó sẽ vô cùng khó khăn và mất rất nhiều thời gian. Bởi vậy, tôi từ chối tham gia tất cả lễ hội nơi mà số người vượt tải đến mức không thể nhúc nhích như thế này.", |
|
"Còn phải tốn hơn nữa mới được", |
|
"Mình k có ý kích dục fen nhé :v Có sao kể vậy thôi.", |
|
"Này là lúc trước khi gặp P hả bác? Em thắc mắc là bác có thể thẳng thừng chặn C - người bác yêu như vậy à?", |
|
"Thì mượt hơn là đúng thôi. Mới phát triển thì không có nhiều tính năng, không có nhiều app thì chả mượt", |
|
]''' |
|
comments = [ |
|
"đúng là vozer, nhiều thằng sống ngu và ích kỷ vcl, nếu như người yêu nó cần 1 trái thận, lúc đó bản thân suy nghĩ tính toán thì ok, này chạy xe có 40km mà tính toán chi ly, mua cái váy mà mặc đi", |
|
"Khác mẹ gì tàu khựa, bơm tiền cho đám NGO woke đi biểu tình phá lại bọn tây lông thôi. Chó chê mèo lắm lông. À mà acc Emma Roberts bị ban rồi à mày", |
|
"đùa, cái shop thế mà cũng bảo chính hãng, vả vỡ alo nó đi. ra trung tâm thương mại, hay cửa hàng chính hãng mà mua.", |
|
"qua thớt này của nó thì 90% là xiaolol rùi", |
|
"thằng này chuyên đăng bài để hả hê, khóa mõm nó đi mod", |
|
"Đm nhẫm vào đuổi con bò đỏ này nó giãy nảy cắn người kinh thật @@ Tao có hay ko liên quan lol gì mà mày có vẻ cay cú vkl nhỉ, chắc gato với tao hả ))", |
|
"Sao thế óc chó, bị chửi cho ngu người rồi à =]] thứ ngu học chả biết mẹ gì vào sủa như đúng rồi =]]", |
|
] |
|
predictions = inference(model, device, comments) |
|
print("BERT Predictions:") |
|
print(predictions) |
|
|
|
lstm_model, device = load_model_lstm() |
|
lstm_predictions = inference(lstm_model, device, comments) |
|
print("LSTM Predictions:") |
|
print(lstm_predictions) |