|
import io
import json
import logging
import multiprocessing
import os
import time
import uuid

import redis
import requests
import torch
import torch.nn as nn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
|
|
|
load_dotenv()

# Redis connection settings come from the environment (optionally a .env file).
# Unset or empty variables fall back to localhost:6379 instead of passing
# None / "" through to redis-py, which expects a host string and integer port.
REDIS_HOST = os.getenv('REDIS_HOST') or 'localhost'
REDIS_PORT = int(os.getenv('REDIS_PORT') or 6379)
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

app = FastAPI()

# Value used when a /process request carries no "language" field.
default_language = "es"
|
|
|
class ChatbotService:
    """Reply generator backed by a GPT-2 causal language model stored in Redis.

    Weights are read from ``model:response_model`` and extra tokenizer tokens
    from ``tokenizer:response_tokenizer`` once, when the service is created.
    While either key is missing, ``get_response`` returns a fixed
    "model not ready" message instead of generating text.
    """

    def __init__(self):
        # Text client: used for the JSON token list.
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)
        # Raw-bytes client: serialized weights are binary and cannot go through a
        # decode_responses=True client, which decodes every reply as text.
        self.redis_binary_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=False)
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"
        self.model = self.load_model_from_redis()
        self.tokenizer = self.load_tokenizer_from_redis()
        if self.model is not None and self.tokenizer is not None:
            self._fit_embeddings_to_tokenizer()

    def _fit_embeddings_to_tokenizer(self):
        """Grow the model's token embeddings when the tokenizer has more tokens."""
        # add_tokens() assigns ids past GPT-2's base vocabulary; without this,
        # encoding any added token would index outside the embedding matrix.
        vocab_size = len(self.tokenizer)
        if vocab_size > self.model.get_input_embeddings().num_embeddings:
            self.model.resize_token_embeddings(vocab_size)

    def get_response(self, user_id, message, language=default_language):
        """Return the model's reply to ``message``.

        ``user_id`` and ``language`` are accepted for interface compatibility
        but are not used by this implementation.
        """
        if self.model is None or self.tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
        input_text = f"Usuario: {message} Asistente:"
        encoded = self.tokenizer(input_text, return_tensors="pt")
        input_ids = encoded["input_ids"].to("cpu")
        attention_mask = encoded["attention_mask"].to("cpu")
        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.pad_token_id,
                max_length=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
            )
        # Decode only the newly generated tokens. String-replacing the prompt
        # breaks whenever decode() does not reproduce it character for character.
        generated = output[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(generated, skip_special_tokens=True).strip()

    def load_model_from_redis(self):
        """Return GPT-2 with the weights stored under ``model:<model_name>``, or None if absent."""
        model_data_bytes = self.redis_binary_client.get(f"model:{self.model_name}")
        if not model_data_bytes:
            return None
        # torch.load treats a str argument as a file path and needs a file-like
        # object for in-memory data, hence BytesIO.
        # NOTE(review): torch.load unpickles its input -- only use a trusted Redis.
        state_dict = torch.load(io.BytesIO(model_data_bytes), map_location="cpu")
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        # GPT-2 keeps its token embeddings under this key. Match a stored,
        # resized vocabulary before loading so shapes agree.
        stored_embeddings = state_dict.get("transformer.wte.weight")
        if stored_embeddings is not None and stored_embeddings.shape[0] != model.get_input_embeddings().num_embeddings:
            model.resize_token_embeddings(stored_embeddings.shape[0])
        model.load_state_dict(state_dict)
        return model

    def load_tokenizer_from_redis(self):
        """Return a GPT-2 tokenizer plus the JSON token list under ``tokenizer:<tokenizer_name>``, or None."""
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if tokenizer_data_bytes:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            tokenizer.add_tokens(json.loads(tokenizer_data_bytes))
            # GPT-2 ships without a padding token; reuse end-of-sequence.
            tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        return None


# Module-wide service instance; constructing it reads the model from Redis at import time.
chatbot_service = ChatbotService()
|
|
|
class UnifiedModel(nn.Module):
    """Ensemble head over several sequence-classification models.

    ``forward`` runs each wrapped model on its own ``input_ids`` /
    ``attention_mask`` pair and concatenates their logits along dim 1. It then
    projects the result to the first model's ``hidden_size`` and maps that to
    3 output logits.
    """

    def __init__(self, models):
        super(UnifiedModel, self).__init__()
        self.models = nn.ModuleList(models)
        hidden_size = self.models[0].config.hidden_size
        # Width of the concatenated logits: each model contributes num_labels
        # values. That is 3 for every model built in this file, so this equals
        # the former hard-coded len(models) * 3.
        fused_width = sum(model.config.num_labels for model in self.models)
        # Output hidden_size (was a hard-coded 768) so it always matches the
        # classifier's input width.
        self.projection = nn.Linear(fused_width, hidden_size)
        self.classifier = nn.Linear(hidden_size, 3)

    def forward(self, input_ids, attention_mask):
        """Return fused class logits.

        Args:
            input_ids: one tensor per wrapped model, in the same order.
            attention_mask: one tensor per wrapped model, in the same order.

        ``zip`` stops at the shortest of the three sequences, so surplus
        inputs are ignored.
        """
        hidden_states = []
        for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
            outputs = model(input_ids=input_id, attention_mask=attn_mask)
            hidden_states.append(outputs.logits)
        # Assumes each model returns 2-D logits (batch, num_labels) -- TODO confirm for other heads.
        concatenated_hidden_states = torch.cat(hidden_states, dim=1)
        projected_features = self.projection(concatenated_hidden_states)
        return self.classifier(projected_features)

    @staticmethod
    def load_model_from_redis(redis_client):
        """Build ``UnifiedModel([m, m])`` around a single GPT-2 classifier ``m``.

        Weights stored under ``model:unified_model`` are restored into ``m``
        when present. ``redis_client`` should be created with
        ``decode_responses=False`` so binary blobs come back as ``bytes``.
        A ``str`` value is still passed to ``torch.load`` unchanged, i.e.
        treated as a file path, as before.
        """
        model_name = "unified_model"
        model_data = redis_client.get(f"model:{model_name}")
        model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
        if model_data:
            # torch.load needs a file-like object for in-memory bytes.
            source = io.BytesIO(model_data) if isinstance(model_data, bytes) else model_data
            # NOTE(review): torch.load unpickles its input -- only use a trusted Redis.
            state_dict = torch.load(source, map_location="cpu")
            # A state dict saved from a whole UnifiedModel prefixes the wrapped
            # model's weights with "models.0."; keep just those for ``m``.
            prefix = "models.0."
            if any(key.startswith(prefix) for key in state_dict):
                state_dict = {
                    key[len(prefix):]: value
                    for key, value in state_dict.items()
                    if key.startswith(prefix)
                }
            model.load_state_dict(state_dict)
        # The same module twice: both branches share one set of weights.
        return UnifiedModel([model, model])
|
|
|
class SyntheticDataset(Dataset):
    """Map-style dataset that tokenizes every record with every tokenizer.

    ``tokenizers`` maps a name to a callable tokenizer; ``data`` is a sequence
    of ``{"text": ..., "label": ...}`` records. Each sample is a dict holding
    ``input_ids_<name>`` / ``attention_mask_<name>`` tensors (padded or
    truncated to 128 tokens) for every tokenizer, plus ``labels``.
    """

    def __init__(self, tokenizers, data):
        self.tokenizers = tokenizers
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        text, label = record['text'], record['label']
        sample = {}
        for tok_name, tok in self.tokenizers.items():
            encoding = tok(text, padding="max_length", truncation=True, max_length=128)
            sample.update({
                f"input_ids_{tok_name}": torch.tensor(encoding["input_ids"]),
                f"attention_mask_{tok_name}": torch.tensor(encoding["attention_mask"]),
            })
        sample["labels"] = torch.tensor(label)
        return sample
|
|
|
# Messages received per user, keyed by the "user_id" value sent by the client.
# A plain module-level dict: it lives only in this process and is never persisted.
conversation_history = dict()
|
|
|
logger = logging.getLogger(__name__)

# Redis keys shared by the request handler and the background trainer.
_TRAINING_QUEUE = "training_queue"
_UNIFIED_MODEL_NAME = "unified_model"
_UNIFIED_TOKENIZER_NAME = "unified_tokenizer"

# How many most recent messages per user are kept and used as context.
_CONTEXT_MESSAGES = 3


def _make_redis_client(decode_responses):
    """Return a Redis client for the configured server.

    Use ``decode_responses=False`` for binary values such as serialized model
    weights; ``True`` returns ``str`` and suits the JSON payloads.
    """
    return redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=decode_responses)


# Created once and reused across requests instead of one new client per request.
_redis_text = _make_redis_client(decode_responses=True)
_redis_binary = _make_redis_client(decode_responses=False)


def _remember_message(user_id, text):
    """Append ``text`` to ``user_id``'s history, keep the last few, and return the list."""
    history = conversation_history.setdefault(user_id, [])
    history.append(text)
    # Without trimming, this module-level dict grows with every message forever.
    del history[:-_CONTEXT_MESSAGES]
    return history


def _load_unified_tokenizer(redis_client):
    """Return a GPT-2 tokenizer plus any tokens stored under ``tokenizer:unified_tokenizer``."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    extra_tokens = redis_client.get(f"tokenizer:{_UNIFIED_TOKENIZER_NAME}")
    if extra_tokens:
        tokenizer.add_tokens(json.loads(extra_tokens))
    # GPT-2 ships without a padding token; reuse end-of-sequence.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _load_unified_model(binary_client, tokenizer):
    """Return a CPU ``UnifiedModel`` wrapping one GPT-2 classifier, in eval mode.

    Weights stored under ``model:unified_model`` are restored when present.
    The value may be a whole-UnifiedModel state dict, as written by
    ``train_unified_model`` below, or a state dict of the wrapped classifier
    alone.
    """
    base_model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
    unified_model = UnifiedModel([base_model])
    blob = binary_client.get(f"model:{_UNIFIED_MODEL_NAME}")
    if blob:
        # NOTE(review): torch.load unpickles its input -- only use a trusted Redis.
        state_dict = torch.load(io.BytesIO(blob), map_location="cpu")
        if any(key.startswith("models.") for key in state_dict):
            unified_model.load_state_dict(state_dict)
        else:
            base_model.load_state_dict(state_dict)
    # Tokens added to the tokenizer get ids past GPT-2's base vocabulary;
    # grow the embeddings so those ids stay in range.
    if len(tokenizer) > base_model.get_input_embeddings().num_embeddings:
        base_model.resize_token_embeddings(len(tokenizer))
    unified_model.to(torch.device("cpu"))
    unified_model.eval()
    return unified_model


def _serialize_state_dict(module):
    """Return ``module``'s state dict serialized by ``torch.save`` as bytes."""
    # torch.save returns None; write into an in-memory buffer and take its bytes.
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getvalue()


@app.post("/process")
async def process(request: Request):
    """Handle a chat message or a training submission.

    The body must be a JSON object with either:

    * a truthy ``"train"`` and optional ``"user_data"`` (a list of
      ``{"text": ..., "label": ...}`` examples queued for the trainer), or
    * a ``"message"`` string plus optional ``"user_id"`` and ``"language"``.
      It is answered with ``{"answer": ...}``, and the conversation context is
      queued for training, labelled with the unified classifier's prediction.

    Raises:
        HTTPException: 400 when the body is not a JSON object, when
            ``message`` is not a string, or when neither field is present.
    """
    # NOTE(review): model loading and inference below are synchronous and run on
    # the event loop, so they block other requests while they execute.
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")
    if not data.get("train") and not data.get("message"):
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")

    tokenizer = _load_unified_tokenizer(_redis_text)
    # NOTE(review): train_unified_model only reads the tokenizer *names* from this
    # payload; the full vocabulary makes each queue entry large.
    tokenizer_payload = {_UNIFIED_TOKENIZER_NAME: tokenizer.get_vocab()}

    if data.get("train"):
        user_data = data.get("user_data", [])
        if not user_data:
            user_data = [
                {"text": "Hola", "label": 1},
                {"text": "Necesito ayuda", "label": 2},
                {"text": "No entiendo", "label": 0},
            ]
        _redis_text.rpush(_TRAINING_QUEUE, json.dumps({"tokenizers": tokenizer_payload, "data": user_data}))
        return {"message": "Training data received. Model will be updated asynchronously."}

    text = data["message"]
    if not isinstance(text, str):
        raise HTTPException(status_code=400, detail="'message' must be a string.")
    user_id = data.get("user_id")
    language = data.get("language", default_language)
    contextualized_text = " ".join(_remember_message(user_id, text))

    unified_model = _load_unified_model(_redis_binary, tokenizer)
    # truncation=True caps the input at the tokenizer's model_max_length.
    tokens = tokenizer(contextualized_text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = unified_model(input_ids=[tokens["input_ids"]], attention_mask=[tokens["attention_mask"]])
    predicted_class = torch.argmax(logits, dim=-1).item()

    response = chatbot_service.get_response(user_id, contextualized_text, language)
    _redis_text.rpush(_TRAINING_QUEUE, json.dumps({
        "tokenizers": tokenizer_payload,
        "data": [{"text": contextualized_text, "label": predicted_class}],
    }))
    return {"answer": response}


def get_chatbot_response(user_id, question, predicted_class, language):
    """Record ``question`` in the user's history and return the chatbot's reply.

    ``predicted_class`` is accepted for interface compatibility but unused.
    """
    _remember_message(user_id, question)
    return chatbot_service.get_response(user_id, question, language)


@app.get("/")
async def get_home():
    """Serve the single-page chat UI.

    Each page load gets a fresh random ``user_id``, which the page sends with
    every message, so server-side history is per page load.
    """
    user_id = str(uuid.uuid4())
    html_code = f"""
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Chatbot</title>
    <style>
        body {{
            font-family: 'Arial', sans-serif;
            background-color: #f4f4f9;
            margin: 0;
            padding: 0;
            display: flex;
            align-items: center;
            justify-content: center;
            min-height: 100vh;
        }}
        .container {{
            background-color: #fff;
            border-radius: 10px;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
            overflow: hidden;
            width: 400px;
            max-width: 90%;
        }}
        h1 {{
            color: #333;
            text-align: center;
            padding: 20px;
            margin: 0;
            background-color: #f8f9fa;
            border-bottom: 1px solid #eee;
        }}
        #chatbox {{
            height: 300px;
            overflow-y: auto;
            padding: 10px;
            border-bottom: 1px solid #eee;
        }}
        .message {{
            margin-bottom: 10px;
            padding: 10px;
            border-radius: 5px;
        }}
        .message.user {{
            background-color: #e1f5fe;
            text-align: right;
        }}
        .message.bot {{
            background-color: #f1f1f1;
        }}
        #input {{
            display: flex;
            padding: 10px;
        }}
        #input textarea {{
            flex: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
            margin-right: 10px;
        }}
        #input button {{
            padding: 10px 20px;
            border: none;
            border-radius: 4px;
            background-color: #007bff;
            color: #fff;
            cursor: pointer;
        }}
        #input button:hover {{
            background-color: #0056b3;
        }}
    </style>
</head>
<body>
    <div class="container">
        <h1>Chatbot</h1>
        <div id="chatbox"></div>
        <div id="input">
            <textarea id="message" rows="3" placeholder="Escribe tu mensaje aquí..."></textarea>
            <button id="send">Enviar</button>
        </div>
    </div>
    <script>
        const chatbox = document.getElementById('chatbox');
        const messageInput = document.getElementById('message');
        const sendButton = document.getElementById('send');

        function appendMessage(text, sender) {{
            const messageDiv = document.createElement('div');
            messageDiv.classList.add('message', sender);
            messageDiv.textContent = text;
            chatbox.appendChild(messageDiv);
            chatbox.scrollTop = chatbox.scrollHeight;
        }}

        async function sendMessage() {{
            const message = messageInput.value;
            if (!message.trim()) return;

            appendMessage(message, 'user');
            messageInput.value = '';

            try {{
                const response = await fetch('/process', {{
                    method: 'POST',
                    headers: {{
                        'Content-Type': 'application/json'
                    }},
                    body: JSON.stringify({{
                        message: message,
                        user_id: '{user_id}'
                    }})
                }});
                // Error replies carry no "answer"; report the status instead.
                if (!response.ok) {{
                    appendMessage('Error ' + response.status + ': no se pudo obtener una respuesta.', 'bot');
                    return;
                }}
                const data = await response.json();
                appendMessage(data.answer, 'bot');
            }} catch (err) {{
                // Network failure or malformed reply: tell the user instead of failing silently.
                appendMessage('No se pudo conectar con el servidor.', 'bot');
            }}
        }}

        sendButton.addEventListener('click', sendMessage);
        // keydown replaces the deprecated keypress; ignore Enter during IME composition.
        messageInput.addEventListener('keydown', (e) => {{
            if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) {{
                e.preventDefault();
                sendMessage();
            }}
        }});
    </script>
</body>
</html>
"""
    return HTMLResponse(content=html_code)


def _drain_training_queue(redis_client):
    """Atomically fetch and clear every queued training item (JSON strings)."""
    # MULTI/EXEC makes LRANGE + DELETE one step, so items pushed by the API
    # while we read are not deleted without ever being trained on.
    with redis_client.pipeline(transaction=True) as pipe:
        pipe.lrange(_TRAINING_QUEUE, 0, -1)
        pipe.delete(_TRAINING_QUEUE)
        items, _ = pipe.execute()
    return items


def _new_training_model(tokenizer):
    """Return a fresh ``UnifiedModel`` around a pretrained GPT-2 classifier."""
    base_model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
    # GPT-2's classification head finds each row's last real token via
    # config.pad_token_id and rejects batches > 1 when it is unset. Training
    # batches are padded with the tokenizer's pad (= eos) token.
    base_model.config.pad_token_id = tokenizer.pad_token_id
    return UnifiedModel([base_model])


def _build_training_model(tokenizer, binary_client):
    """Return the model to fine-tune, resumed from the stored weights when they load."""
    model = _new_training_model(tokenizer)
    blob = binary_client.get(f"model:{_UNIFIED_MODEL_NAME}")
    if blob:
        try:
            # NOTE(review): torch.load unpickles its input -- only use a trusted Redis.
            model.load_state_dict(torch.load(io.BytesIO(blob), map_location="cpu"))
        except Exception:
            logger.warning("Could not restore the stored unified model; starting from GPT-2 weights.", exc_info=True)
            # load_state_dict may have copied some tensors before failing.
            model = _new_training_model(tokenizer)
    return model


def _train_on_item(model, tokenizer, optimizer, criterion, item_data):
    """Run three epochs over one queued item's examples."""
    # Every queued tokenizer name maps to the same plain GPT-2 tokenizer.
    tokenizers = {name: tokenizer for name in item_data["tokenizers"]}
    dataloader = DataLoader(SyntheticDataset(tokenizers, item_data["data"]), batch_size=8, shuffle=True)
    model.train()
    for _ in range(3):
        for batch in dataloader:
            input_ids = [batch[f"input_ids_{name}"] for name in tokenizers]
            attention_mask = [batch[f"attention_mask_{name}"] for name in tokenizers]
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, batch["labels"])
            loss.backward()
            optimizer.step()


def _train_batch(raw_items, binary_client):
    """Fine-tune one model on all drained items and store it once at the end."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = _build_training_model(tokenizer, binary_client)
    optimizer = AdamW(model.parameters(), lr=1e-5)
    criterion = nn.CrossEntropyLoss()
    trained_any = False
    for raw_item in raw_items:
        try:
            _train_on_item(model, tokenizer, optimizer, criterion, json.loads(raw_item))
        except Exception:
            # One malformed item (bad JSON, missing keys, out-of-range label)
            # must not discard the rest of the batch.
            logger.exception("Skipping unusable training item: %.200s", raw_item)
        else:
            trained_any = True
    if trained_any:
        binary_client.set(f"model:{_UNIFIED_MODEL_NAME}", _serialize_state_dict(model))


def train_unified_model():
    """Background worker loop: train on queued items, store the model, pause 60 s, repeat.

    Runs forever. A failing cycle is logged and the loop continues; items
    drained in a failed cycle are not re-queued.
    """
    redis_client = _make_redis_client(decode_responses=True)
    binary_client = _make_redis_client(decode_responses=False)
    while True:
        try:
            raw_items = _drain_training_queue(redis_client)
            if raw_items:
                _train_batch(raw_items, binary_client)
        except Exception:
            logger.exception("Training cycle failed; will try again after the next pause.")
        time.sleep(60)
|
|
|
if __name__ == "__main__":
    # The trainer loops forever. As a non-daemon child it would be joined at
    # interpreter exit, so the program would hang after uvicorn stops.
    # daemon=True lets it be terminated along with the parent.
    training_process = multiprocessing.Process(target=train_unified_model, daemon=True)
    training_process.start()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|