"""GPT model loader for local use."""

import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelLoader:
    def __init__(self):
        self.model = None
        self.tokenizer = None

        # Prefer Apple Silicon (MPS), then CUDA, falling back to CPU.
        if torch.backends.mps.is_available():
            self.device = "mps"
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

    def load_model(self, model_name="microsoft/DialoGPT-medium"):
        """Load a GPT model from the Hugging Face Hub.

        Args:
            model_name (str): Model name on the Hugging Face Hub.

        Returns:
            bool: True if the model loaded successfully, False otherwise.
        """
        try:
            logger.info(f"Loading model: {model_name}")
            logger.info(f"Using device: {self.device}")

            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

            # GPT-style tokenizers often ship without a pad token; reuse EOS.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Use half precision and automatic device placement only on CUDA.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
            )

            # Without device_map the weights load on CPU, so move them to the
            # selected device (required for MPS, a no-op on CPU).
            if self.device != "cuda":
                self.model = self.model.to(self.device)

            logger.info("Model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return False

    def get_model_info(self):
        """Return information about the currently loaded model."""
        if self.model is None:
            return {"status": "No model loaded"}

        return {
            "status": "Model loaded",
            "device": self.device,
            "model_type": type(self.model).__name__,
            "vocab_size": self.tokenizer.vocab_size if self.tokenizer else "N/A",
        }

    def is_loaded(self):
        """Return True if both the model and the tokenizer are loaded."""
        return self.model is not None and self.tokenizer is not None
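

if __name__ == "__main__":
    # Minimal usage sketch, added for illustration (not part of the original
    # module). Loading the default checkpoint downloads it from the Hugging
    # Face Hub on the first run, so network access is assumed.
    loader = ModelLoader()
    if loader.load_model():
        print(loader.get_model_info())
    else:
        print("Model failed to load; check the log output above.")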