# AIXD / app.py
# PolloXDDDDDD's picture
# AXD
# 8f11083 verified
# Asegúrate de ejecutar esto en un entorno de Google Colab
!pip install transformers torch
import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
class LiquidLayer(nn.Module):
    """Recurrent cell computing ``tanh(wx(x) + wh(prev_state))``.

    Args:
        input_size: Number of features of the input vector.
        hidden_size: Number of features of the recurrent state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        # Input projection and state-to-state projection, created in this
        # order so parameter registration matches the attribute names below.
        self.wx = nn.Linear(input_size, hidden_size)
        self.wh = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, x, prev_state):
        """Return the next state for input ``x`` and previous state."""
        input_term = self.wx(x)
        recurrent_term = self.wh(prev_state)
        return self.activation(input_term + recurrent_term)
class LiquidGPT2(nn.Module):
    """GPT-2 with a recurrent ``LiquidLayer`` between its body and LM head.

    On every call, the final-layer hidden state of the last input position
    is combined with a state remembered from the previous call, and the
    result is projected to vocabulary logits by GPT-2's own ``lm_head``.

    Args:
        gpt2_model: Model exposing ``config.n_embd``, ``lm_head`` and a
            forward call accepting ``output_hidden_states=True``
            (e.g. ``transformers.GPT2LMHeadModel``).
        liquid_size: Width of the recurrent state. It has to equal
            ``gpt2_model.config.n_embd`` because the state is fed directly
            into ``lm_head``.

    Raises:
        ValueError: If ``liquid_size`` differs from ``config.n_embd``.
    """

    def __init__(self, gpt2_model, liquid_size):
        super().__init__()
        n_embd = gpt2_model.config.n_embd
        if liquid_size != n_embd:
            # Fail early: otherwise the mismatch only surfaces as a shape
            # error inside lm_head during the first forward pass.
            raise ValueError(
                f"liquid_size ({liquid_size}) must equal the GPT-2 "
                f"embedding size ({n_embd})"
            )
        self.gpt2 = gpt2_model
        self.liquid_layer = LiquidLayer(n_embd, liquid_size)
        # Registered as a buffer so it follows .to()/.cuda()/.double();
        # non-persistent so the state_dict layout is unchanged.
        self.register_buffer(
            "memory", torch.zeros(1, liquid_size), persistent=False
        )

    def reset_memory(self):
        """Reset the recurrent state to zeros (shape ``(1, liquid_size)``)."""
        self.memory = self.memory.new_zeros(1, self.memory.shape[-1])

    def forward(self, input_ids, attention_mask=None):
        """Return next-token logits of shape ``(batch, vocab)``.

        Side effect: ``self.memory`` is replaced by the detached new state.
        """
        # output_hidden_states=True is required to read the last hidden state.
        gpt2_output = self.gpt2(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # NOTE(review): position -1 is taken regardless of attention_mask,
        # so right-padded batches would read a padding position - confirm.
        last_hidden_state = gpt2_output.hidden_states[-1][:, -1, :]

        # A state kept from a call with a different batch size would either
        # fail to broadcast or silently change the output batch size.
        batch_size = last_hidden_state.shape[0]
        if self.memory.shape[0] not in (1, batch_size):
            self.reset_memory()

        liquid_output = self.liquid_layer(last_hidden_state, self.memory)
        # Detach so the autograd graph does not grow across calls.
        self.memory = liquid_output.detach()
        return self.gpt2.lm_head(liquid_output)
# Load the pretrained GPT-2 model and its tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')
# Build the LiquidGPT2 model. The state width is read from the model config
# rather than hard-coded, because LiquidGPT2 feeds the state into lm_head and
# it therefore has to match the GPT-2 embedding size of the loaded checkpoint.
liquid_size = gpt2_model.config.n_embd
model = LiquidGPT2(gpt2_model, liquid_size)
# Función para generar respuestas
def generate_response(prompt, model, tokenizer, max_length=50):
    """Greedily generate a continuation of ``prompt``.

    The model is called once per generated token; the most likely token is
    appended to the input and the loop repeats until ``max_length`` tokens
    were produced or the tokenizer's end-of-sequence token is predicted.

    Args:
        prompt: Text to condition the generation on.
        model: Module called as ``model(input_ids, attention_mask=...)``
            that returns next-token logits of shape ``(1, vocab)``.
        tokenizer: Object providing ``encode(text, return_tensors='pt')``
            and ``decode(ids)``; ``eos_token_id`` is used when present.
        max_length: Maximum number of new tokens to generate.

    Returns:
        The decoded generated text (the prompt is not included). An empty
        string is returned when the prompt encodes to no tokens.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    if input_ids.shape[-1] == 0:
        # Nothing to condition on; the model indexes the last position
        # and would fail on an empty sequence.
        return ""

    # Run on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)

    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    new_tokens = []
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        for _ in range(max_length):
            attention_mask = torch.ones(
                input_ids.shape, dtype=torch.long, device=input_ids.device
            )
            logits = model(input_ids, attention_mask=attention_mask)
            next_token = torch.argmax(logits, dim=-1).reshape(1, 1)
            token_id = int(next_token)
            if eos_token_id is not None and token_id == eos_token_id:
                break
            new_tokens.append(token_id)
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.decode(new_tokens)
# Example usage. Guarded so that importing this module does not run
# inference and print as a side effect.
if __name__ == "__main__":
    prompt = "Hola, ¿cómo estás?"
    response = generate_response(prompt, model, tokenizer)
    print(f"Prompt: {prompt}")
    print(f"Respuesta: {response}")
# Función de chat interactivo
def chat():
    """Run an interactive console chat loop.

    Reads one line at a time from standard input, prints the reply produced
    by ``generate_response`` using the module-level ``model`` and
    ``tokenizer``, and stops when the user types ``salir``, sends
    end-of-file, or presses Ctrl+C.
    """
    print("¡Hola! Soy un chatbot basado en GPT-2 con una capa líquida. Escribe 'salir' para terminar.")
    while True:
        try:
            user_input = input("Tú: ")
        except (EOFError, KeyboardInterrupt):
            # stdin closed (e.g. non-interactive run) or Ctrl+C: leave
            # cleanly instead of crashing with a traceback.
            print()
            print("¡Hasta luego!")
            break

        command = user_input.strip()
        if not command:
            # A blank line carries no prompt to generate from.
            continue
        if command.lower() == 'salir':
            print("¡Hasta luego!")
            break

        response = generate_response(user_input, model, tokenizer)
        print(f"ChatBot: {response}")
# Start the interactive chat only when run as a script, so that importing
# this module does not block waiting on standard input.
if __name__ == "__main__":
    chat()