gpt-local / models /text_generator.py
DRDELATV's picture
Upload folder using huggingface_hub
22ca508 verified
"""
Generador de texto usando modelos GPT locales
"""
import torch
from typing import List, Dict
import logging
logger = logging.getLogger(__name__)
class TextGenerator:
def __init__(self, model_loader):
self.model_loader = model_loader
self.chat_history_ids = None
def generate_response(self, user_input: str, **kwargs) -> str:
"""
Genera una respuesta basada en la entrada del usuario
Args:
user_input (str): Mensaje del usuario
**kwargs: Par谩metros de generaci贸n (max_length, temperature, etc.)
Returns:
str: Respuesta generada
"""
if not self.model_loader.is_loaded():
return "Error: No hay modelo cargado"
try:
# Par谩metros por defecto
max_length = kwargs.get('max_length', 512)
temperature = kwargs.get('temperature', 0.7)
top_p = kwargs.get('top_p', 0.9)
do_sample = kwargs.get('do_sample', True)
# Codificar la entrada del usuario
new_user_input_ids = self.model_loader.tokenizer.encode(
user_input + self.model_loader.tokenizer.eos_token,
return_tensors='pt'
).to(self.model_loader.device)
# Concatenar con el historial de chat
if self.chat_history_ids is not None:
bot_input_ids = torch.cat([self.chat_history_ids, new_user_input_ids], dim=-1)
else:
bot_input_ids = new_user_input_ids
# Generar respuesta
with torch.no_grad():
chat_history_ids = self.model_loader.model.generate(
bot_input_ids,
max_length=max_length,
num_beams=1,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
pad_token_id=self.model_loader.tokenizer.eos_token_id,
attention_mask=torch.ones(bot_input_ids.shape, device=self.model_loader.device)
)
# Actualizar historial
self.chat_history_ids = chat_history_ids
# Decodificar solo la nueva respuesta
response = self.model_loader.tokenizer.decode(
chat_history_ids[:, bot_input_ids.shape[-1]:][0],
skip_special_tokens=True
)
return str(response).strip()
except Exception as e:
logger.error(f"Error en la generaci贸n: {str(e)}")
return f"Error al generar respuesta: {str(e)}"
def generate_text(self, prompt: str, **kwargs) -> str:
"""
Genera texto continuando un prompt (sin historial de chat)
Args:
prompt (str): Texto inicial
**kwargs: Par谩metros de generaci贸n
Returns:
str: Texto generado
"""
if not self.model_loader.is_loaded():
return "Error: No hay modelo cargado"
try:
# Par谩metros por defecto
max_length = kwargs.get('max_length', 100)
temperature = kwargs.get('temperature', 0.8)
top_p = kwargs.get('top_p', 0.9)
do_sample = kwargs.get('do_sample', True)
# Codificar el prompt
input_ids = self.model_loader.tokenizer.encode(
prompt,
return_tensors='pt'
).to(self.model_loader.device)
# Generar texto
with torch.no_grad():
output = self.model_loader.model.generate(
input_ids,
max_length=input_ids.shape[1] + max_length,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
pad_token_id=self.model_loader.tokenizer.eos_token_id,
attention_mask=torch.ones(input_ids.shape, device=self.model_loader.device)
)
# Decodificar solo el texto generado
generated_text = self.model_loader.tokenizer.decode(
output[0][input_ids.shape[1]:],
skip_special_tokens=True
)
return str(generated_text.strip())
except Exception as e:
logger.error(f"Error en la generaci贸n: {str(e)}")
return f"Error al generar texto: {str(e)}"
def reset_chat_history(self):
"""Reinicia el historial de chat"""
self.chat_history_ids = None
logger.info("Historial de chat reiniciado")
def get_generation_stats(self) -> Dict:
"""Retorna estad铆sticas de generaci贸n"""
if self.chat_history_ids is not None:
return {
"history_length": self.chat_history_ids.shape[1],
"device": str(self.chat_history_ids.device)
}
return {"history_length": 0, "device": "N/A"}