# Train / main.py
from dotenv import load_dotenv
import io
import os
import json
import redis
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
import multiprocessing
import time
import uuid

load_dotenv()

REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT', 6379))  # redis-py expects an int port
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

app = FastAPI()
default_language = "es"
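
# Chat response service: generation uses a causal LM whose weights and extended
# tokenizer are cached in Redis; until trained weights exist it answers with a
# "not ready" message.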
class ChatbotService:
    def __init__(self):
        # decode_responses must stay False: model weights are stored as raw bytes.
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
        self.model_name = "response_model"
        self.tokenizer_name = "response_tokenizer"
        self.model = self.load_model_from_redis()
        self.tokenizer = self.load_tokenizer_from_redis()

    def get_response(self, user_id, message, language=default_language):
        if self.model is None or self.tokenizer is None:
            return "El modelo aún no está listo. Por favor, inténtelo de nuevo más tarde."
        input_text = f"Usuario: {message} Asistente:"
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to("cpu")
        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                max_length=100,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True,
                pad_token_id=self.tokenizer.eos_token_id,  # GPT-2 has no native pad token
            )
        response = self.tokenizer.decode(output[0], skip_special_tokens=True)
        response = response.replace(input_text, "").strip()
        return response
    def load_model_from_redis(self):
        model_data_bytes = self.redis_client.get(f"model:{self.model_name}")
        if model_data_bytes:
            model = AutoModelForCausalLM.from_pretrained("gpt2")
            # torch.load needs a file-like object, not raw bytes.
            model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
            return model
        return None

    def load_tokenizer_from_redis(self):
        tokenizer_data_bytes = self.redis_client.get(f"tokenizer:{self.tokenizer_name}")
        if tokenizer_data_bytes:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
            stored = json.loads(tokenizer_data_bytes)
            # The stored payload may be a vocab dict; add_tokens expects a list of tokens.
            tokenizer.add_tokens(list(stored.keys()) if isinstance(stored, dict) else stored)
            tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        return None
chatbot_service = ChatbotService()
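
# Ensemble wrapper: each sub-model's 3 class logits are concatenated, projected
# to the transformer hidden size, and re-classified into 3 labels.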
class UnifiedModel(nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)
        hidden_size = self.models[0].config.hidden_size
        # Each sub-model emits 3 class logits, so the concatenated feature size is len(models) * 3.
        self.projection = nn.Linear(len(models) * 3, hidden_size)
        self.classifier = nn.Linear(hidden_size, 3)

    def forward(self, input_ids, attention_mask):
        hidden_states = []
        for model, input_id, attn_mask in zip(self.models, input_ids, attention_mask):
            outputs = model(input_ids=input_id, attention_mask=attn_mask)
            hidden_states.append(outputs.logits)
        concatenated_hidden_states = torch.cat(hidden_states, dim=1)
        projected_features = self.projection(concatenated_hidden_states)
        logits = self.classifier(projected_features)
        return logits
    @staticmethod
    def load_model_from_redis(redis_client):
        model_name = "unified_model"
        model_data_bytes = redis_client.get(f"model:{model_name}")
        base = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
        # Wrap a single sub-model to match what train_unified_model saves; the
        # stored weights are a UnifiedModel state dict, so load them after wrapping
        # (torch.load needs a file-like object, not raw bytes).
        unified = UnifiedModel([base])
        if model_data_bytes:
            unified.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
        return unified
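
# Tokenizes each {"text", "label"} record once per registered tokenizer, keyed
# as input_ids_<name> / attention_mask_<name> to match UnifiedModel's forward().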
class SyntheticDataset(Dataset):
    def __init__(self, tokenizers, data):
        self.tokenizers = tokenizers
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        label = item['label']
        tokenized = {}
        for name, tokenizer in self.tokenizers.items():
            tokens = tokenizer(text, padding="max_length", truncation=True, max_length=128)
            tokenized[f"input_ids_{name}"] = torch.tensor(tokens["input_ids"])
            tokenized[f"attention_mask_{name}"] = torch.tensor(tokens["attention_mask"])
        tokenized["labels"] = torch.tensor(label)
        return tokenized
conversation_history = {}
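
# Single JSON endpoint: {"train": true, "user_data": [...]} enqueues training
# samples, while {"message": ..., "user_id": ...} classifies the recent context,
# generates a reply, and feeds the (text, predicted label) pair back into the
# training queue. Example request (assuming the server runs on localhost:8000):
#   curl -X POST localhost:8000/process -H 'Content-Type: application/json' \
#        -d '{"message": "Hola", "user_id": "demo"}'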
@app.post("/process")
async def process(request: Request):
    data = await request.json()
    # decode_responses must stay False so stored model weights round-trip as bytes.
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    tokenizers = {}
    models = {}
    model_name = "unified_model"
    tokenizer_name = "unified_tokenizer"
    model_data_bytes = redis_client.get(f"model:{model_name}")
    tokenizer_data_bytes = redis_client.get(f"tokenizer:{tokenizer_name}")
    model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
    model.config.pad_token_id = model.config.eos_token_id  # GPT-2 has no native pad token
    models[model_name] = model
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer_data_bytes:
        stored = json.loads(tokenizer_data_bytes)
        # The stored payload may be a vocab dict; add_tokens expects a list of tokens.
        tokenizer.add_tokens(list(stored.keys()) if isinstance(stored, dict) else stored)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizers[tokenizer_name] = tokenizer
    unified_model = UnifiedModel(list(models.values()))
    if model_data_bytes:
        # The trainer stores a UnifiedModel state dict, so load it after wrapping;
        # torch.load needs a file-like object, not raw bytes.
        unified_model.load_state_dict(torch.load(io.BytesIO(model_data_bytes)))
    unified_model.to(torch.device("cpu"))
    if data.get("train"):
        user_data = data.get("user_data", [])
        if not user_data:
            user_data = [
                {"text": "Hola", "label": 1},
                {"text": "Necesito ayuda", "label": 2},
                {"text": "No entiendo", "label": 0}
            ]
        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": user_data
        }))
        return {"message": "Training data received. Model will be updated asynchronously."}
    elif data.get("message"):
        user_id = data.get("user_id")
        text = data['message']
        language = data.get("language", default_language)
        if user_id not in conversation_history:
            conversation_history[user_id] = []
        conversation_history[user_id].append(text)
        contextualized_text = " ".join(conversation_history[user_id][-3:])
        tokenized_inputs = [tokenizers[name](contextualized_text, return_tensors="pt") for name in tokenizers]
        input_ids = [tokens['input_ids'] for tokens in tokenized_inputs]
        attention_mask = [tokens['attention_mask'] for tokens in tokenized_inputs]
        unified_model.eval()
        with torch.no_grad():
            logits = unified_model(input_ids=input_ids, attention_mask=attention_mask)
            predicted_class = torch.argmax(logits, dim=-1).item()
        response = chatbot_service.get_response(user_id, contextualized_text, language)
        # Feed the classified exchange back into the queue for asynchronous retraining.
        redis_client.rpush("training_queue", json.dumps({
            "tokenizers": {tokenizer_name: tokenizer.get_vocab()},
            "data": [{"text": contextualized_text, "label": predicted_class}]
        }))
        return {"answer": response}
    else:
        raise HTTPException(status_code=400, detail="Request must contain 'train' or 'message'.")
# Helper mirroring the /process flow; not currently called by any route.
def get_chatbot_response(user_id, question, predicted_class, language):
    if user_id not in conversation_history:
        conversation_history[user_id] = []
    conversation_history[user_id].append(question)
    return chatbot_service.get_response(user_id, question, language)
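
# Serves a self-contained HTML/JS chat page; each visit gets a fresh UUID
# user_id that the embedded script sends with every /process call.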
@app.get("/")
async def get_home():
    user_id = str(uuid.uuid4())
    html_code = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="UTF-8">
        <title>Chatbot</title>
        <style>
            body {{
                font-family: 'Arial', sans-serif;
                background-color: #f4f4f9;
                margin: 0;
                padding: 0;
                display: flex;
                align-items: center;
                justify-content: center;
                min-height: 100vh;
            }}
            .container {{
                background-color: #fff;
                border-radius: 10px;
                box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
                overflow: hidden;
                width: 400px;
                max-width: 90%;
            }}
            h1 {{
                color: #333;
                text-align: center;
                padding: 20px;
                margin: 0;
                background-color: #f8f9fa;
                border-bottom: 1px solid #eee;
            }}
            #chatbox {{
                height: 300px;
                overflow-y: auto;
                padding: 10px;
                border-bottom: 1px solid #eee;
            }}
            .message {{
                margin-bottom: 10px;
                padding: 10px;
                border-radius: 5px;
            }}
            .message.user {{
                background-color: #e1f5fe;
                text-align: right;
            }}
            .message.bot {{
                background-color: #f1f1f1;
            }}
            #input {{
                display: flex;
                padding: 10px;
            }}
            #input textarea {{
                flex: 1;
                padding: 10px;
                border: 1px solid #ddd;
                border-radius: 4px;
                margin-right: 10px;
            }}
            #input button {{
                padding: 10px 20px;
                border: none;
                border-radius: 4px;
                background-color: #007bff;
                color: #fff;
                cursor: pointer;
            }}
            #input button:hover {{
                background-color: #0056b3;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>Chatbot</h1>
            <div id="chatbox"></div>
            <div id="input">
                <textarea id="message" rows="3" placeholder="Escribe tu mensaje aquí..."></textarea>
                <button id="send">Enviar</button>
            </div>
        </div>
        <script>
            const chatbox = document.getElementById('chatbox');
            const messageInput = document.getElementById('message');
            const sendButton = document.getElementById('send');

            function appendMessage(text, sender) {{
                const messageDiv = document.createElement('div');
                messageDiv.classList.add('message', sender);
                messageDiv.textContent = text;
                chatbox.appendChild(messageDiv);
                chatbox.scrollTop = chatbox.scrollHeight;
            }}

            async function sendMessage() {{
                const message = messageInput.value;
                if (!message.trim()) return;
                appendMessage(message, 'user');
                messageInput.value = '';
                const response = await fetch('/process', {{
                    method: 'POST',
                    headers: {{
                        'Content-Type': 'application/json'
                    }},
                    body: JSON.stringify({{
                        message: message,
                        user_id: '{user_id}'
                    }})
                }});
                const data = await response.json();
                appendMessage(data.answer, 'bot');
            }}

            sendButton.addEventListener('click', sendMessage);
            messageInput.addEventListener('keypress', (e) => {{
                if (e.key === 'Enter' && !e.shiftKey) {{
                    e.preventDefault();
                    sendMessage();
                }}
            }});
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code)
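
# Background worker: polls the Redis training queue every 60 s, retrains the
# unified classifier on any queued samples, and writes the weights back to Redis.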
def train_unified_model():
    # decode_responses must stay False so model weights round-trip as bytes.
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    while True:
        training_queue = redis_client.lrange("training_queue", 0, -1)
        if training_queue:
            for item in training_queue:
                item_data = json.loads(item)
                tokenizers = {name: AutoTokenizer.from_pretrained("gpt2") for name in item_data["tokenizers"]}
                for tokenizer in tokenizers.values():
                    tokenizer.pad_token = tokenizer.eos_token
                data = item_data["data"]
                dataset = SyntheticDataset(tokenizers, data)
                dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
                base_model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=3)
                base_model.config.pad_token_id = base_model.config.eos_token_id  # required for padded batches
                model = UnifiedModel([base_model])
                optimizer = AdamW(model.parameters(), lr=1e-5)
                criterion = nn.CrossEntropyLoss()
                for epoch in range(3):
                    model.train()
                    for batch in dataloader:
                        input_ids = [batch[f"input_ids_{name}"] for name in tokenizers]
                        attention_mask = [batch[f"attention_mask_{name}"] for name in tokenizers]
                        labels = batch["labels"]
                        optimizer.zero_grad()
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                        loss = criterion(outputs, labels)
                        loss.backward()
                        optimizer.step()
                # torch.save returns None, so serialize into an in-memory buffer
                # and store the raw bytes in Redis.
                buffer = io.BytesIO()
                torch.save(model.state_dict(), buffer)
                redis_client.set("model:unified_model", buffer.getvalue())
            # Trim only the items just processed so entries queued meanwhile survive.
            redis_client.ltrim("training_queue", len(training_queue), -1)
        time.sleep(60)
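
# Entry point: the trainer runs in a separate process so training never blocks
# the FastAPI event loop served by uvicorn.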
if __name__ == "__main__":
    training_process = multiprocessing.Process(target=train_unified_model)
    training_process.start()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)