| """ |
| === MIA · Agent Emotion Predict Classifier (Text/BETO Embedder + MLP) === |
| Objetivo: predecir la emoción del AGENTE (label_agent: 0..5) a partir de: |
| - el TEXTO del usuario |
| - la EMOCIÓN del texto (label del usuario: 0..5) |
| |
| Arquitectura: |
| Texto ──▶ Embedder (TextEmbedder ó BETOEmbedder) ─▶ h_text ∈ R^D |
| Label usuario (0..5) ─▶ one-hot(6) ─▶ (feature dropout opcional) |
| Concatenación [h_text ; onehot_label] ─▶ MLP ─▶ logits (6) |
| |
| Notas: |
| - Si usas BETOEmbedder, se recomienda congelarlo (freeze) para esta segunda red. |
| - El feature dropout en la one-hot del label obliga al modelo a mirar el TEXTO en los casos ambiguos. |
| """ |
|
|
| from typing import List, Optional |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from emotion_classifier_model import TextEmbedder, BETOEmbedder |
|
|
|
|
class FeatureDropout(nn.Module):
    """Randomly zero out a sample's whole label one-hot during training.

    Each sample (each index along dim 0) is dropped independently with
    probability ``p``. With ``p=0.2``, roughly 20% of the samples in every
    batch lose their label features and the model has to decide from the text
    alone. Surviving rows are NOT rescaled (unlike ``nn.Dropout``). In eval
    mode, or when ``p == 0``, the input is returned unchanged (same object).

    Args:
        p: Drop probability; must satisfy ``0 <= p < 1``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1)``.
    """

    def __init__(self, p: float = 0.0):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not 0.0 <= p < 1.0:
            raise ValueError(f"FeatureDropout: p must be in [0, 1), got {p}")
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply per-sample feature dropout to ``x`` (first dim = batch)."""
        if not self.training or self.p <= 0.0:
            return x
        # One Bernoulli draw per sample, shaped (B, 1, ..., 1) so it broadcasts
        # over all trailing dims whatever the rank of x (a fixed (B, 1) mask
        # would broadcast a 1-D input to (B, B) and fail for rank >= 3).
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        keep = torch.rand(mask_shape, device=x.device) > self.p
        # Cast to x's dtype so e.g. half-precision inputs are not upcast.
        return x * keep.to(x.dtype)
|
|
|
|
class MLP(nn.Module):
    """Classification head: (Linear -> ReLU -> Dropout) x 2, then a linear output.

    Args:
        input_dim: Size of the incoming feature vector.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.
        dropout: Dropout probability applied after each hidden activation
            (one shared ``nn.Dropout`` module is used for both layers).
    """

    def __init__(
        self,
        input_dim: int,
        hidden1: int = 256,
        hidden2: int = 64,
        num_classes: int = 6,
        dropout: float = 0.2,
    ):
        super().__init__()
        # Layer attribute names and creation order are part of the checkpoint
        # format (state_dict keys) and of seeded initialization; keep them.
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.out = nn.Linear(hidden2, num_classes)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map input features to unnormalized class logits."""
        hidden = x
        for linear in (self.fc1, self.fc2):
            hidden = self.drop(F.relu(linear(hidden)))
        return self.out(hidden)
|
|
|
|
class AgentEmotionPredictClassifier(nn.Module):
    """Second-stage network: predict the AGENT's emotion from (text, user label).

    Pipeline: the text is encoded by an embedder into ``h_text``; the user's
    emotion label is one-hot encoded (and, during training, possibly zeroed by
    ``FeatureDropout``); both are concatenated and fed to an ``MLP`` that
    outputs ``num_classes`` logits.

    Key parameters:
        model_name: Model identifier passed to whichever embedder is built.
        pretrained_encoder: ``"beto"`` builds a ``BETOEmbedder`` (treated as
            768-d); any other value (e.g. ``None``) builds a ``TextEmbedder``
            of size ``emb_dim``.
        emb_dim: Output size of ``TextEmbedder`` (unused for ``"beto"``).
        max_length: Forwarded to the embedder.
        hidden1, hidden2, dropout: Forwarded to the ``MLP`` head.
        num_classes: Number of agent-emotion classes (output logits).
        label_feature_dropout: Probability of zeroing a sample's label one-hot
            during training, forcing the model to use the text in ambiguous
            cases.
        device: Target device; defaults to CUDA when available, else CPU.
        num_user_labels: Number of distinct user-emotion labels, i.e. the
            one-hot width. Defaults to 6 (labels 0..5).
    """

    def __init__(
        self,
        model_name: str = "dccuchile/bert-base-spanish-wwm-cased",
        pretrained_encoder: Optional[str] = "beto",
        emb_dim: int = 300,
        max_length: int = 128,
        hidden1: int = 256,
        hidden2: int = 64,
        num_classes: int = 6,
        dropout: float = 0.2,
        label_feature_dropout: float = 0.15,
        device: Optional[torch.device] = None,
        num_user_labels: int = 6,
    ):
        super().__init__()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if pretrained_encoder == "beto":
            self.embedder = BETOEmbedder(model_name=model_name, max_length=max_length, device=self.device)
            # NOTE(review): assumes BETOEmbedder.embed_batch yields 768-d vectors
            # (BERT-base hidden size) — TODO confirm against emotion_classifier_model.
            embed_dim = 768
        else:
            self.embedder = TextEmbedder(model_name=model_name, emb_dim=emb_dim, max_length=max_length, device=self.device)
            embed_dim = emb_dim

        # Width of the user-label one-hot appended to the text embedding.
        self.label_dim = num_user_labels
        self.feat_drop = FeatureDropout(p=label_feature_dropout)
        self.classifier = MLP(
            input_dim=embed_dim + self.label_dim,
            hidden1=hidden1,
            hidden2=hidden2,
            num_classes=num_classes,
            dropout=dropout,
        )
        self.to(self.device)

    @staticmethod
    def _one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Return a float one-hot encoding of integer ``labels`` (last dim = num_classes)."""
        return F.one_hot(labels.long(), num_classes=num_classes).float()

    def freeze_encoder(self):
        """Disable gradients for all embedder parameters.

        This only sets ``requires_grad``; it does not switch the embedder to
        eval mode, so any dropout layers inside it stay active while the
        module is in training mode.
        """
        for p in self.embedder.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        """Re-enable gradients for all embedder parameters."""
        for p in self.embedder.parameters():
            p.requires_grad = True

    def forward(self, texts: List[str], user_labels: torch.Tensor) -> torch.Tensor:
        """Compute agent-emotion logits.

        Args:
            texts: Batch of user texts (length B).
            user_labels: User-emotion labels in ``[0, label_dim)``, one per
                text. Normally a 1-D tensor of shape [B]; anything accepted by
                ``torch.as_tensor`` (e.g. a list of ints) also works, and a
                [B, 1] column is flattened.

        Returns:
            Logits with one row per text and ``num_classes`` columns.

        Raises:
            ValueError: If the number of labels differs from the number of
                embedded texts.
        """
        h_text = self.embedder.embed_batch(texts)
        labels = torch.as_tensor(user_labels, device=h_text.device).reshape(-1)
        if labels.numel() != h_text.size(0):
            raise ValueError(
                f"Expected {h_text.size(0)} user labels (one per text), got {labels.numel()}"
            )
        # Match h_text's dtype so the concatenation never relies on implicit promotion.
        onehot = self._one_hot(labels, self.label_dim).to(h_text.dtype)
        onehot = self.feat_drop(onehot)
        x = torch.cat([h_text, onehot], dim=-1)
        return self.classifier(x)

    @torch.inference_mode()
    def predict(self, texts: List[str], user_labels: torch.Tensor):
        """Return ``(preds, probs)``: argmax class ids and softmax probabilities.

        Runs the forward pass in eval mode (dropout and label feature-dropout
        disabled) under ``torch.inference_mode``. The module's previous
        train/eval mode is restored afterwards, so calling this during a
        training loop does not silently leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            logits = self.forward(texts, user_labels)
        finally:
            self.train(was_training)
        probs = logits.softmax(dim=-1)
        preds = probs.argmax(dim=-1)
        return preds, probs
|
|
|
|
|
|
|
|