# V2L-Alpha1-Example1 / decoder.py
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from transformers import T5Tokenizer
from sentence_transformers import SentenceTransformer
# ===== CONFIG =====
INPUT_FILE = "chat_1turn.csv"
EMB_FILE = "chat_embeddings.pt"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
EPOCHS = 80
BATCH_SIZE = 16
HIDDEN_DIM = 512
MAX_LEN = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# ===== Load CSV =====
df = pd.read_csv(INPUT_FILE)
sources = df["source"].fillna("").tolist()
targets = df["target"].fillna("").tolist()
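# chat_1turn.csv is expected to provide two text columns, "source" and "target",
# presumably one single-turn chat pair (user turn and reply) per row.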
# ===== Tokenizer =====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
target_enc = tokenizer(targets, padding=True, truncation=True,
                       return_tensors="pt", max_length=MAX_LEN)
input_ids = target_enc["input_ids"].to(device)
attention_mask = target_enc["attention_mask"].to(device)
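# Only the target texts need token ids: they are the labels for the decoder's
# cross-entropy loss. The source texts are not tokenized here; they are represented
# only by sentence embeddings.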
# ===== Load embeddings =====
emb_data = torch.load(EMB_FILE, map_location=device)
x_embeddings = emb_data["source"].to(device)  # source-side embeddings; not used directly in this training
y_embeddings = emb_data["target"].to(device)  # target-side embeddings; used to condition the decoder
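# chat_embeddings.pt is assumed to be a dict of precomputed sentence embeddings keyed by
# "source" and "target", e.g. produced by a companion script roughly along these lines
# (illustrative sketch, not the original code):
#   embedder = SentenceTransformer(MODEL_NAME, device=device)
#   torch.save({"source": embedder.encode(sources, convert_to_tensor=True).cpu(),
#               "target": embedder.encode(targets, convert_to_tensor=True).cpu()}, EMB_FILE)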
# ===== Decoder =====
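# The decoder inverts a fixed sentence embedding back into text: a linear "bridge" projects
# the embedding to the GRU hidden size and serves as the initial hidden state, then tokens
# are generated autoregressively, fed back either from the ground truth (teacher forcing
# during training) or from the model's own greedy predictions.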
class EmbeddingDecoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, vocab_size):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, emb_vec, target_ids=None, teacher_forcing_ratio=0.5, max_len=MAX_LEN):
        hidden = self.bridge(emb_vec).unsqueeze(0)   # [1, B, H] initial GRU hidden state
        B = emb_vec.size(0)
        outputs = []
        # start with the pad token (T5 uses pad=0 as the decoder start token, eos=1)
        inp = torch.full((B, 1), tokenizer.pad_token_id, device=emb_vec.device)
        for t in range(max_len):
            inp_emb = self.embed(inp)                # [B, 1, H]
            out, hidden = self.gru(inp_emb, hidden)  # [B, 1, H]
            logits = self.fc(out.squeeze(1))         # [B, V]
            outputs.append(logits.unsqueeze(1))
            if target_ids is not None and t < target_ids.size(1) and torch.rand(1).item() < teacher_forcing_ratio:
                inp = target_ids[:, t].unsqueeze(1)  # teacher forcing: feed the ground-truth token
            else:
                inp = torch.argmax(logits, dim=-1, keepdim=True)  # feed the model's own greedy prediction
        return torch.cat(outputs, dim=1)             # [B, max_len, V]
# ===== Train =====
decoder = EmbeddingDecoder(y_embeddings.shape[1], HIDDEN_DIM, tokenizer.vocab_size).to(device)
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
print("Training decoder...")
for epoch in range(EPOCHS):
    decoder.train()
    total_loss = 0.0
    for i in range(0, len(y_embeddings), BATCH_SIZE):
        xb = y_embeddings[i:i+BATCH_SIZE]  # conditioning embeddings for this mini-batch
        yb = input_ids[i:i+BATCH_SIZE]     # corresponding target token ids
        optimizer.zero_grad()
        logits = decoder(xb, target_ids=yb, teacher_forcing_ratio=0.7, max_len=yb.size(1))
        loss = criterion(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss:.4f}")
# ===== Inference =====
embedder = SentenceTransformer(MODEL_NAME, device=device)
def generate(text, max_len=30, use_mapper=False, mapper=None):
    decoder.eval()
    with torch.no_grad():
        # embed the new text with the same sentence-embedding model
        emb = embedder.encode([text], convert_to_tensor=True, device=device)
        if use_mapper and mapper is not None:
            emb = mapper(emb)  # optional hook: apply a separately trained embedding mapper
        logits = decoder(emb, target_ids=None, teacher_forcing_ratio=0.0, max_len=max_len)
        ids = torch.argmax(logits, dim=-1).squeeze(0).tolist()
        # greedy decoding keeps producing tokens after EOS, so cut the sequence at the first EOS
        if tokenizer.eos_token_id in ids:
            ids = ids[:ids.index(tokenizer.eos_token_id)]
        return tokenizer.decode(ids, skip_special_tokens=True)
# ===== Test =====
print("Hi ->", generate("Hi"))