|
|
"""
|
|
|
Module định nghĩa các mô hình cho spam review detection
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from transformers import AutoModel, AutoConfig, AutoModelForSequenceClassification
|
|
|
from .custom_models import TextCNN, BiLSTM, RoBERTaGRU, SPhoBERT
|
|
|
|
|
|
class TransformerForSpamDetection(nn.Module):
    """Pretrained encoder plus a linear classification head for spam review detection.

    The checkpoint ``model_name`` is loaded with ``AutoModel``. The hidden
    state at position 0 of ``last_hidden_state`` is used as the sequence
    representation; it goes through dropout and then a linear classifier.

    Args:
        model_name: Hugging Face model id or local path.
        num_labels: Number of output classes.
        dropout: Dropout probability applied to the pooled representation
            before the classifier. Defaults to 0.1, the value that was
            previously hard-coded.
    """

    # Keyword arguments that forward() drops instead of passing to the encoder.
    _IGNORED_FORWARD_KWARGS = frozenset({"num_items_in_batch", "position_ids"})

    def __init__(self, model_name: str, num_labels: int, dropout: float = 0.1):
        super().__init__()
        config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
        # Submodules are registered in the original order (encoder,
        # classifier, dropout), so state_dict and parameter ordering is unchanged.
        self.encoder = AutoModel.from_pretrained(model_name, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Encode a batch and return classification logits, plus a loss when labels are given.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            labels: Optional class targets. When provided, the cross-entropy
                loss between ``logits`` and ``labels`` is computed.
            **kwargs: Extra encoder inputs. ``num_items_in_batch`` and
                ``position_ids`` are discarded; everything else is forwarded.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"``.
        """
        # NOTE(review): because this signature accepts **kwargs, recent
        # Hugging Face Trainer versions may pass num_items_in_batch and change
        # their own gradient-accumulation loss scaling in response. The loss
        # below is a plain per-batch mean, so verify loss scaling when
        # gradient_accumulation_steps > 1.
        encoder_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key not in self._IGNORED_FORWARD_KWARGS
        }
        outputs = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, **encoder_kwargs
        )

        # Position 0 (typically the CLS/<s> token for BERT-style tokenizers).
        pooled = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}
|
|
|
|
|
|
class ViT5ForSpamDetection(nn.Module):
    """ViT5 (T5) encoder plus a linear classification head for spam review detection.

    Only the T5 encoder stack is loaded (``T5EncoderModel``); no decoder is used.

    Args:
        model_name: Hugging Face model id or local path of a T5 checkpoint.
        num_labels: Number of output classes.
        dropout: Dropout probability applied before the classifier.
            Defaults to 0.1, the value that was previously hard-coded.
        pooling: How the encoder states are reduced to one vector per
            sequence. ``"first"`` (default, the original behaviour) takes
            position 0. ``"mean"`` averages the positions where
            ``attention_mask`` is non-zero. T5 tokenizers do not prepend a
            CLS-style token, so position 0 is an ordinary input token and
            ``"mean"`` may fit better; evaluate before switching.

    Raises:
        ValueError: If ``pooling`` is not ``"first"`` or ``"mean"``.
    """

    # Keyword arguments that forward() drops instead of passing to the encoder.
    _IGNORED_FORWARD_KWARGS = frozenset({"num_items_in_batch", "position_ids"})
    _POOLING_MODES = ("first", "mean")

    def __init__(
        self,
        model_name: str,
        num_labels: int,
        dropout: float = 0.1,
        pooling: str = "first",
    ):
        super().__init__()
        # Validate before downloading or loading any weights.
        if pooling not in self._POOLING_MODES:
            raise ValueError(
                f"Unknown pooling mode: {pooling!r}. Expected one of {self._POOLING_MODES}"
            )

        from transformers import T5EncoderModel, T5Config

        config = T5Config.from_pretrained(model_name)
        # Submodules are registered in the original order (t5_encoder,
        # classifier, dropout), so state_dict and parameter ordering is unchanged.
        self.t5_encoder = T5EncoderModel.from_pretrained(model_name, config=config)
        self.classifier = nn.Linear(config.d_model, num_labels)
        self.dropout = nn.Dropout(dropout)
        self.pooling = pooling

    def _pool(self, hidden, attention_mask):
        """Reduce ``hidden`` over its sequence axis (dim 1) according to ``self.pooling``."""
        # getattr keeps instances that were created without a ``pooling``
        # attribute (e.g. pickled earlier) on the original behaviour.
        if getattr(self, "pooling", "first") == "mean":
            if attention_mask is None:
                return hidden.mean(dim=1)
            # Assumes attention_mask is (batch, seq), aligned with hidden's first two axes.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # The clamp avoids division by zero for fully padded rows.
            return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return hidden[:, 0]

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Encode a batch with the T5 encoder and return logits, plus a loss when labels are given.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder; also used
                for ``"mean"`` pooling.
            labels: Optional class targets. When provided, the cross-entropy
                loss between ``logits`` and ``labels`` is computed.
            **kwargs: Extra encoder inputs. ``num_items_in_batch`` and
                ``position_ids`` are discarded; everything else is forwarded.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"``.
        """
        encoder_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key not in self._IGNORED_FORWARD_KWARGS
        }
        encoder_outputs = self.t5_encoder(
            input_ids=input_ids, attention_mask=attention_mask, **encoder_kwargs
        )

        pooled = self.dropout(self._pool(encoder_outputs.last_hidden_state, attention_mask))
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}
|
|
|
|
|
|
def get_model(model_name: str, num_labels: int, vocab_size: int = None):
    """Build a spam-detection model from its short name.

    Args:
        model_name: Short model name. The pretrained-transformer keys are
            "phobert-v1", "phobert-v2", "bartpho", "visobert", "xlm-r",
            "mbert" and "vit5". The custom architectures are "textcnn",
            "bilstm", "roberta-gru", "sphobert" and "bilstm-crf".
        num_labels: Number of output classes.
        vocab_size: Accepted for backward compatibility but currently not
            used by any branch.

    Returns:
        The model instance.

    Raises:
        ValueError: If ``model_name`` is not one of the names above.
    """
    model_mapping = {
        "phobert-v1": "vinai/phobert-base",
        "phobert-v2": "vinai/phobert-base-v2",
        "bartpho": "vinai/bartpho-syllable",
        "visobert": "uitnlp/visobert",
        "xlm-r": "xlm-roberta-large",
        "mbert": "bert-base-multilingual-cased",
        "vit5": "VietAI/vit5-base",
    }
    custom_model_names = ("textcnn", "bilstm", "roberta-gru", "sphobert", "bilstm-crf")
    # Every custom architecture is constructed from the PhoBERT v2 checkpoint id.
    custom_backbone = "vinai/phobert-base-v2"

    if model_name == "vit5":
        # ViT5 uses its own wrapper, which is built on T5EncoderModel.
        return ViT5ForSpamDetection(model_mapping[model_name], num_labels)
    if model_name in model_mapping:
        return TransformerForSpamDetection(model_mapping[model_name], num_labels)
    if model_name not in custom_model_names:
        raise ValueError(
            f"Unknown model name: {model_name}. Available models: "
            f"{list(model_mapping.keys()) + list(custom_model_names)}"
        )

    if model_name == "textcnn":
        return TextCNN(custom_backbone, num_labels)
    if model_name == "roberta-gru":
        return RoBERTaGRU(custom_backbone, num_labels)
    if model_name == "sphobert":
        return SPhoBERT(custom_backbone, num_labels)
    if model_name == "bilstm-crf":
        import warnings

        # No CRF variant is wired up: this name builds exactly the same model
        # as "bilstm". Warn instead of substituting it silently.
        warnings.warn(
            "'bilstm-crf' currently builds the same BiLSTM model as 'bilstm'; "
            "no CRF layer is added by get_model().",
            stacklevel=2,
        )
    # Reached for both "bilstm" and "bilstm-crf".
    return BiLSTM(custom_backbone, num_labels)
|
|
|
|
|
|
def get_model_config(model_name: str):
    """Return the training configuration registered for a model.

    Args:
        model_name: Short model name, e.g. "phobert-v2" or "textcnn".

    Returns:
        A newly built dict with the keys "model_name" (checkpoint id),
        "description", "max_length" and "learning_rate". Custom
        architectures also carry "custom_model": True.

    Raises:
        ValueError: If ``model_name`` has no registered configuration.
    """
    phobert_v2 = "vinai/phobert-base-v2"
    # Columns: short name, checkpoint id, description, learning rate, custom flag.
    rows = (
        ("phobert-v1", "vinai/phobert-base", "PhoBERT v1 - Pre-trained BERT for Vietnamese", 5e-5, False),
        ("phobert-v2", phobert_v2, "PhoBERT v2 - Improved PhoBERT for Vietnamese", 5e-5, False),
        ("bartpho", "vinai/bartpho-syllable", "BART Pho - Vietnamese BART model", 5e-5, False),
        ("visobert", "uitnlp/visobert", "ViSoBERT - Vietnamese Social BERT", 5e-5, False),
        ("xlm-r", "xlm-roberta-large", "XLM-RoBERTa Large - Multilingual model", 3e-5, False),
        ("mbert", "bert-base-multilingual-cased", "mBERT - Multilingual BERT model", 5e-5, False),
        ("vit5", "VietAI/vit5-base", "ViT5 - Vietnamese T5", 5e-5, False),
        ("textcnn", phobert_v2, "TextCNN - Convolutional Neural Network for text", 1e-3, True),
        ("bilstm", phobert_v2, "BiLSTM - Bidirectional LSTM for text classification", 1e-3, True),
        ("roberta-gru", phobert_v2, "RoBERTa-GRU - Hybrid RoBERTa + GRU model", 5e-5, True),
        ("sphobert", phobert_v2, "SPhoBERT - PhoBERT + SentenceBERT embedding fusion", 5e-5, True),
        ("bilstm-crf", phobert_v2, "BiLSTM-CRF - Bidirectional LSTM with CRF", 1e-3, True),
    )

    configs = {}
    for short_name, checkpoint, description, lr, is_custom in rows:
        entry = {
            "model_name": checkpoint,
            "description": description,
            "max_length": 256,
            "learning_rate": lr,
        }
        if is_custom:
            entry["custom_model"] = True
        configs[short_name] = entry

    selected = configs.get(model_name)
    if selected is None:
        raise ValueError(f"Model {model_name} not found. Available models: {list(configs.keys())}")
    return selected