MiaPredict / inference_agent_emotion_classifier.py
RustyLinux's picture
Update name of model
7d730d3 verified
# -*- coding: utf-8 -*-
"""
Inferencia para el AgentEmotionPredictClassifier (MIA · segunda red)
- Busca 'best_model.pt' y 'config_agent.json' en local; si no están y hay
huggingface_hub instalado, los descarga del repo indicado.
- La config DEBE incluir, como mínimo:
{
"base_model_id": "dccuchile/bert-base-spanish-wwm-cased",
"max_length": 128,
"hidden1": 256,
"hidden2": 64,
"num_classes": 2,
"dropout": 0.4,
"label_feature_dropout": 0.5,
"pretrained_encoder": "beto",
"present_classes": [1, 2], # ids originales (0..5) presentes en train
"class_names": ["alegría","amor"] # nombres en el mismo orden del mapeo 0..K-1
}
- Uso:
from inference_agent_emotion import predict
y = predict("No me siento bien", user_label=0) # 0..5 (tristeza..sorpresa)
"""
from __future__ import annotations
import os
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional, Union
import torch
# Opcional: descarga desde HF si no hay archivos locales
try:
from huggingface_hub import hf_hub_download
except Exception:
hf_hub_download = None
from agent_emotion_predict_classifier import AgentEmotionPredictClassifier
# ---------------- Config ----------------
REPO_ID = "RustyLinux/MiaPredict" # cambia por tu repo si usas el Hub
LOCAL_CKPT = Path("best_model_agent.pt") # checkpoint de la segunda red
LOCAL_CFG = Path("config_agent.json") # config de la segunda red
# Mapa global de emociones (usuario y también nombres canónicos)
EMOTION_ID2NAME = {
0: "tristeza",
1: "alegría",
2: "amor",
3: "ira",
4: "miedo",
5: "sorpresa",
}
EMOTION_NAME2ID = {v: k for k, v in EMOTION_ID2NAME.items()}
_device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model: Optional[AgentEmotionPredictClassifier] = None
_cfg: Optional[Dict[str, Any]] = None
_label_map_fwd: Optional[Dict[int, int]] = None # original_id -> idx(0..K-1) usado en entrenamiento
_label_map_inv: Optional[Dict[int, int]] = None # idx(0..K-1) -> original_id (para devolver nombre global)
# ---------------- Utilidades internas ----------------
def _resolve_paths() -> Tuple[str, str]:
"""
Retorna (ckpt_path, cfg_path). Prefiere local; si no, intenta descarga HF.
"""
if LOCAL_CKPT.exists() and LOCAL_CFG.exists():
print("✅ Cargando archivos desde local.")
return str(LOCAL_CKPT.resolve()), str(LOCAL_CFG.resolve())
if hf_hub_download is None:
raise RuntimeError(
"No se encontraron 'best_model_agent.pt' y 'config_agent.json' en local, "
"y 'huggingface_hub' no está instalado para descargarlos."
)
print("⬇️ Descargando archivos desde Hugging Face Hub...")
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="best_model_agent.pt")
cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config_agent.json")
return ckpt_path, cfg_path
def _prepare_label_maps(cfg: Dict[str, Any]) -> Tuple[Dict[int, int], Dict[int, int]]:
"""
Construye los mapeos entre ids originales (0..5) y los índices 0..K-1 usados por la head.
"""
present = cfg.get("present_classes", None)
if not present:
# Por compatibilidad: si no viene, asumimos [0..num_classes-1], pero se recomienda guardarlo.
k = int(cfg.get("num_classes", 2))
present = list(range(k))
present = list(sorted(int(x) for x in present))
fwd = {orig: i for i, orig in enumerate(present)}
inv = {i: orig for orig, i in fwd.items()}
return fwd, inv
def _load_config(cfg_path: str) -> Dict[str, Any]:
global _cfg, _label_map_fwd, _label_map_inv
if _cfg is not None:
return _cfg
with open(cfg_path, "r", encoding="utf-8") as f:
_cfg = json.load(f)
_label_map_fwd, _label_map_inv = _prepare_label_maps(_cfg)
return _cfg
def _build_model(cfg: Dict[str, Any]) -> AgentEmotionPredictClassifier:
model = AgentEmotionPredictClassifier(
model_name=cfg.get("base_model_id", "dccuchile/bert-base-spanish-wwm-cased"),
pretrained_encoder=cfg.get("pretrained_encoder", "beto"),
emb_dim=cfg.get("emb_dim", 300),
max_length=cfg.get("max_length", 128),
hidden1=cfg.get("hidden1", 256),
hidden2=cfg.get("hidden2", 64),
num_classes=cfg.get("num_classes", 2),
dropout=cfg.get("dropout", 0.4),
label_feature_dropout=cfg.get("label_feature_dropout", 0.0), # en inferencia no se usa
device=_device,
)
# aseguramos eval()
model.eval()
return model
def _load_model() -> AgentEmotionPredictClassifier:
global _model
if _model is not None:
return _model
ckpt_path, cfg_path = _resolve_paths()
cfg = _load_config(cfg_path)
model = _build_model(cfg)
state = torch.load(ckpt_path, map_location=_device)
if isinstance(state, dict) and "model_state_dict" in state:
model.load_state_dict(state["model_state_dict"])
else:
model.load_state_dict(state)
model.eval()
_model = model
print(f"✅ Modelo cargado en {_device} | num_classes={cfg.get('num_classes')} | "
f"present_classes={cfg.get('present_classes')}")
return _model
def _coerce_user_label(label: Union[int, str]) -> int:
"""
Convierte un label de usuario a id 0..5.
- Si llega string ("alegría"), lo mapea.
- Valida rango si llega int.
"""
if isinstance(label, str):
label = label.strip().lower()
if label not in EMOTION_NAME2ID:
raise ValueError(f"Label de usuario desconocido: {label}. Esperado uno de {list(EMOTION_NAME2ID.keys())}")
return EMOTION_NAME2ID[label]
if isinstance(label, int):
if label < 0 or label > 5:
raise ValueError("El user_label debe estar en 0..5.")
return label
raise TypeError("user_label debe ser int (0..5) o str (nombre de emoción).")
def _map_agent_idx_to_original(idx: int) -> int:
"""
Convierte el índice 0..K-1 (head) al id original 0..5 para reportar el nombre global.
"""
if _label_map_inv is None:
raise RuntimeError("Mapeos de etiquetas no inicializados.")
return _label_map_inv[int(idx)]
def _agent_class_names() -> List[str]:
"""
Nombres de clases del agente en el mismo orden que la head (0..K-1).
"""
if _cfg is None:
raise RuntimeError("Config no cargada.")
names = _cfg.get("class_names", None)
if names:
return list(names)
# fallback: usar nombres globales segun present_classes
present = sorted(_cfg.get("present_classes", []))
return [EMOTION_ID2NAME[p] for p in present]
# ---------------- API de inferencia ----------------
@torch.inference_mode()
def predict(text: str, user_label: Union[int, str], return_probs: bool = False) -> Any:
"""
Predice la emoción CON LA QUE DEBE RESPONDER EL AGENTE.
Args:
text: str
user_label: int(0..5) o nombre ("tristeza", "alegría", "amor", "ira", "miedo", "sorpresa")
return_probs: si True devuelve (pred_name, probs_dict)
Returns:
- Si return_probs=False: str con el nombre de la emoción objetivo del agente (en nombres globales 0..5).
- Si return_probs=True: (pred_name:str, probs:Dict[str,float]) usando los nombres en orden de la head.
"""
model = _load_model()
cfg = _cfg # ya cargada
assert cfg is not None
# 1) preparar entrada
u = _coerce_user_label(user_label)
user_tensor = torch.tensor([u], dtype=torch.long, device=_device)
texts = [text]
# 2) forward
logits = model(texts, user_tensor) # [1, K]
probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
pred_idx = int(probs.argmax())
# 3) mapear idx(0..K-1) -> id original (0..5) y nombre canónico
orig_id = _map_agent_idx_to_original(pred_idx)
pred_name = EMOTION_ID2NAME[orig_id]
if not return_probs:
return pred_name
# nombres amistosos en el orden de la head
names_head = _agent_class_names()
probs_dict = {names_head[i]: float(probs[i]) for i in range(len(names_head))}
return pred_name, probs_dict
@torch.inference_mode()
def predict_batch(texts: List[str], user_labels: List[Union[int, str]], return_probs: bool = False):
"""
Batch de inferencia.
- user_labels: lista paralela a texts con ids (0..5) o nombres de emoción.
"""
if len(texts) != len(user_labels):
raise ValueError("texts y user_labels deben tener la misma longitud.")
model = _load_model()
# preparar
u_ids = [ _coerce_user_label(u) for u in user_labels ]
user_tensor = torch.tensor(u_ids, dtype=torch.long, device=_device)
logits = model(texts, user_tensor) # [B, K]
probs = torch.softmax(logits, dim=-1).cpu().numpy()
pred_idxs = probs.argmax(axis=1)
results = []
names_head = _agent_class_names()
for i, idx in enumerate(pred_idxs):
orig_id = _map_agent_idx_to_original(int(idx))
pred_name = EMOTION_ID2NAME[orig_id]
if return_probs:
pvec = probs[i]
probs_dict = {names_head[j]: float(pvec[j]) for j in range(len(names_head))}
results.append((pred_name, probs_dict))
else:
results.append(pred_name)
return results
# ---------------- CLI rápido ----------------
if __name__ == "__main__":
# Ejemplos rápidos
txts = [
"Tuve ese tipo de sentimiento pero lo ignoré",
"Estoy muy feliz con la noticia",
"Me molesta lo que pasó",
]
# user_label puede ser int o str
for t, ulab in zip(txts, [0, "alegría", "ira"]):
out = predict(t, user_label=ulab, return_probs=True)
print(f"\nTexto: {t}\nUser label: {ulab}\nPredicción agente: {out[0]}\nProbs: {out[1]}")