| | import torch |
| | import torch.nn as nn |
| | from transformers import DebertaV2Tokenizer , DebertaV2Model |
| | from typing import Dict, Any |
| | import joblib |
| | import os |
| |
|
| | |
class EndpointHandler:
    """Inference endpoint for the multitask DeBERTa model.

    Loads the tokenizer, model weights, and the three fitted label encoders
    from ``model_path``, then serves batched predictions for the emotion,
    polarity, and hate-speech heads via ``__call__``.
    """

    def __init__(self, model_path=""):
        # Tokenizer saved alongside the model weights.
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(model_path)

        # Pick the device *before* loading so the checkpoint can be mapped
        # onto it directly; without map_location, a checkpoint saved on a
        # CUDA machine fails to load on a CPU-only host.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = MultitaskDebertaModel(
            num_emotion_labels=8,
            num_polarity_labels=4,
            num_hate_speech_labels=2,
        )
        state_dict = torch.load(
            os.path.join(model_path, 'pytorch_model.bin'),
            map_location=self.device,
        )
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Label encoders fitted at training time (presumably sklearn
        # LabelEncoders — TODO confirm); map class indices back to labels.
        self.emotion_encoder = joblib.load(os.path.join(model_path, 'emotion_encoder.pkl'))
        self.polarity_encoder = joblib.load(os.path.join(model_path, 'polarity_encoder.pkl'))
        self.hate_speech_encoder = joblib.load(os.path.join(model_path, 'hate_speech_encoder.pkl'))

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on ``data['inputs']``.

        Accepts a single string or a list of strings; returns a dict with
        decoded label lists under 'emotions', 'polarities', 'hate_speech'.
        """
        texts = data.get('inputs', [])
        # Accept a bare string payload; iterating it directly would
        # otherwise classify one character at a time.
        if isinstance(texts, str):
            texts = [texts]

        batch_size = 32
        results: Dict[str, Any] = {
            "emotions": [],
            "polarities": [],
            "hate_speech": [],
        }

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            inputs = self.tokenizer(
                batch_texts,
                return_tensors='pt',
                max_length=256,
                truncation=True,
                padding=True,
            )
            # The model's forward() only accepts input_ids/attention_mask.
            inputs.pop('token_type_ids', None)
            inputs = {key: val.to(self.device) for key, val in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            emotion_preds = torch.argmax(outputs['emotion'], dim=1).cpu().numpy().tolist()
            polarity_preds = torch.argmax(outputs['polarity'], dim=1).cpu().numpy().tolist()
            hate_speech_preds = torch.argmax(outputs['hate_speech'], dim=1).cpu().numpy().tolist()

            # Decode class indices back to the original string labels.
            results["emotions"].extend(self.emotion_encoder.inverse_transform(emotion_preds).tolist())
            results["polarities"].extend(self.polarity_encoder.inverse_transform(polarity_preds).tolist())
            results["hate_speech"].extend(self.hate_speech_encoder.inverse_transform(hate_speech_preds).tolist())

        return results

    def load_model(self, model_path):
        """Reload model weights from ``model_path`` into the wrapped model.

        Bug fix: the original called ``self.load_state_dict`` on the handler
        itself, which has no such method — the weights belong to
        ``self.model``. Also maps the checkpoint onto the handler's device.
        """
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device)
        )
| |
|
| | |
class MultitaskDebertaModel(nn.Module):
    """DeBERTa-v3 backbone with three task-specific heads.

    Each task (emotion, polarity, hate speech) gets its own BiLSTM +
    self-attention branch over the backbone's token states; the pooled
    branch features are fused with the [CLS] vector before classification.
    """

    def __init__(self, num_emotion_labels, num_polarity_labels, num_hate_speech_labels):
        super(MultitaskDebertaModel, self).__init__()
        # NOTE(review): fetches pretrained weights from the HF Hub on first use.
        self.deberta = DebertaV2Model.from_pretrained('microsoft/deberta-v3-base')

        # Freeze the first five encoder layers to reduce fine-tuning cost.
        for layer in self.deberta.encoder.layer[:5]:
            for weight in layer.parameters():
                weight.requires_grad = False

        # Per-task BiLSTMs over the 768-dim token states (2 * 128 = 256 out).
        self.emotion_lstm = nn.LSTM(768, 128, bidirectional=True, batch_first=True)
        self.polarity_lstm = nn.LSTM(768, 128, bidirectional=True, batch_first=True)
        self.hate_speech_lstm = nn.LSTM(768, 128, bidirectional=True, batch_first=True)

        # Per-task self-attention over the BiLSTM outputs.
        self.emotion_attention = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
        self.polarity_attention = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
        self.hate_speech_attention = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)

        # Per-task projections down to 128 dims.
        self.emotion_dense = nn.Linear(256, 128)
        self.polarity_dense = nn.Linear(256, 128)
        self.hate_speech_dense = nn.Linear(256, 128)

        # Fuse the three 128-dim branch features with the 768-dim [CLS] vector.
        self.fusion_dense = nn.Linear(128 + 128 + 128 + 768, 128)

        # One classifier head per task, all reading the fused representation.
        self.emotion_classifier = nn.Linear(128, num_emotion_labels)
        self.polarity_classifier = nn.Linear(128, num_polarity_labels)
        self.hate_speech_classifier = nn.Linear(128, num_hate_speech_labels)

        # Shared post-fusion regularization and activation.
        self.layer_norm = nn.LayerNorm(128)
        self.dropout = nn.Dropout(p=0.3)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask):
        """Return a dict of logits keyed 'emotion', 'polarity', 'hate_speech'."""
        backbone_out = self.deberta(input_ids, attention_mask=attention_mask)
        token_states = backbone_out.last_hidden_state
        cls_vector = token_states[:, 0, :]

        # Each branch: BiLSTM -> self-attention -> mean pool -> ReLU(dense).
        # NOTE(review): padding positions participate in the attention and
        # the mean pooling (no key_padding_mask) — this matches how the
        # checkpoint was trained, so it is deliberately left unchanged here.
        branch_features = []
        branches = (
            (self.emotion_lstm, self.emotion_attention, self.emotion_dense),
            (self.polarity_lstm, self.polarity_attention, self.polarity_dense),
            (self.hate_speech_lstm, self.hate_speech_attention, self.hate_speech_dense),
        )
        for lstm, attention, dense in branches:
            lstm_out, _ = lstm(token_states)
            attn_out, _ = attention(lstm_out, lstm_out, lstm_out)
            pooled = attn_out.mean(dim=1)
            branch_features.append(self.relu(dense(pooled)))

        # Fuse task features with the [CLS] vector, then normalize/regularize.
        fused = torch.cat(branch_features + [cls_vector], dim=-1)
        fused = self.relu(self.fusion_dense(fused))
        fused = self.dropout(self.layer_norm(fused))

        return {
            'emotion': self.emotion_classifier(fused),
            'polarity': self.polarity_classifier(fused),
            'hate_speech': self.hate_speech_classifier(fused),
        }
| |
|