import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer

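# Paths, embedding model, and training hyperparameters. MODEL_NAME should be the
# same embedding model that produced EMB_FILE, so the embedding dimension seen at
# inference matches what the decoder was trained on.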
INPUT_FILE = "chat_1turn.csv"
EMB_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic embed-l-v2.0".replace(" ", "-") if False else "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS = 80
BATCH_SIZE = 16
HIDDEN_DIM = 512
MAX_LEN = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

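# Load the paired chat turns. Only the target side is tokenized and used as
# decoder labels below; `sources` is read here but not used further in this script.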
df = pd.read_csv(INPUT_FILE)
sources = df["source"].fillna("").tolist()
targets = df["target"].fillna("").tolist()

tokenizer = T5Tokenizer.from_pretrained("t5-small")
target_enc = tokenizer(targets, padding=True, truncation=True,
                       return_tensors="pt", max_length=MAX_LEN)
input_ids = target_enc["input_ids"].to(device)
attention_mask = target_enc["attention_mask"].to(device)  # not used by the GRU decoder below

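# EMB_FILE is expected to be a dict with "source" and "target" embedding tensors
# precomputed offline. A minimal sketch of how such a file could be produced
# (an assumption, not part of this script):
#
#   embedder = SentenceTransformer(MODEL_NAME)
#   torch.save({
#       "source": embedder.encode(sources, convert_to_tensor=True).cpu(),
#       "target": embedder.encode(targets, convert_to_tensor=True).cpu(),
#   }, EMB_FILE)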
emb_data = torch.load(EMB_FILE)
x_embeddings = emb_data["source"].to(device)  # loaded but not used in this script
y_embeddings = emb_data["target"].to(device)  # the decoder learns to invert these

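# GRU decoder that inverts a single sentence embedding back into a token sequence:
# the embedding is projected into the GRU's initial hidden state and tokens are
# generated autoregressively, with optional teacher forcing during training.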
class EmbeddingDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, vocab_size):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)      # sentence embedding -> initial GRU hidden state
        self.embed = nn.Embedding(vocab_size, hidden_dim)   # token embeddings for decoder inputs
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)         # hidden state -> vocabulary logits

    def forward(self, emb_vec, target_ids=None, teacher_forcing_ratio=0.5, max_len=MAX_LEN):
        # Project the sentence embedding into the GRU's initial hidden state (1, B, H).
        hidden = self.bridge(emb_vec).unsqueeze(0)
        B = emb_vec.size(0)
        outputs = []

        # Start token: the module-level tokenizer's pad id (T5's decoder start token).
        inp = torch.full((B, 1), tokenizer.pad_token_id, device=emb_vec.device)

        for t in range(max_len):
            inp_emb = self.embed(inp)
            out, hidden = self.gru(inp_emb, hidden)
            logits = self.fc(out.squeeze(1))
            outputs.append(logits.unsqueeze(1))

            # Teacher forcing: with probability teacher_forcing_ratio, feed the
            # ground-truth token for this step; otherwise feed the model's own prediction.
            if target_ids is not None and t < target_ids.size(1) and torch.rand(1).item() < teacher_forcing_ratio:
                inp = target_ids[:, t].unsqueeze(1)
            else:
                inp = torch.argmax(logits, dim=-1, keepdim=True)

        return torch.cat(outputs, dim=1)

decoder = EmbeddingDecoder(y_embeddings.shape[1], HIDDEN_DIM, tokenizer.vocab_size).to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

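# Train the decoder to reconstruct the tokenized targets from their own embeddings.
# Batches are taken in file order; there is no shuffling or held-out split here.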
print("Training decoder...")
|
|
|
for epoch in range(EPOCHS):
|
|
|
decoder.train()
|
|
|
total_loss = 0.0
|
|
|
for i in range(0, len(y_embeddings), BATCH_SIZE):
|
|
|
xb = y_embeddings[i:i+BATCH_SIZE]
|
|
|
yb = input_ids[i:i+BATCH_SIZE]
|
|
|
|
|
|
optimizer.zero_grad()
|
|
|
logits = decoder(xb, target_ids=yb, teacher_forcing_ratio=0.7, max_len=yb.size(1))
|
|
|
loss = criterion(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
|
|
|
loss.backward()
|
|
|
optimizer.step()
|
|
|
total_loss += loss.item()
|
|
|
|
|
|
print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")
|
|
|
|
|
|
|
|
|
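# Sentence embedder used at inference time to embed new prompts. Its output
# dimension must match the embeddings the decoder was trained on.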
embedder = SentenceTransformer(MODEL_NAME, device=device)

def generate(text, max_len=30, use_mapper=False, mapper=None):
    decoder.eval()
    with torch.no_grad():
        emb = embedder.encode([text], convert_to_tensor=True, device=device)
        # Optional: map the prompt embedding into the target embedding space
        # with a separately trained mapper before decoding.
        if use_mapper and mapper is not None:
            emb = mapper(emb)
        logits = decoder(emb, target_ids=None, teacher_forcing_ratio=0.0, max_len=max_len)
        ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()
        return tokenizer.decode(ids, skip_special_tokens=True)

print("Hi ->", generate("Hi"))
|
|
|
|
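
# The use_mapper path assumes a separately trained module that maps source (prompt)
# embeddings into the target embedding space the decoder was trained on. A
# hypothetical call, assuming such a module named `prompt_to_reply`:
#
#   print("Hi ->", generate("Hi", use_mapper=True, mapper=prompt_to_reply))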