# Provenance: uploaded by Param20h to the Hugging Face Hub
# ("Upload folder using huggingface_hub"), commit d31183e (verified).
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from dataset import get_loader
from model import CNNtoRNN
import sys
def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _build_transform():
    """Return the image preprocessing pipeline applied to training images."""
    return transforms.Compose([
        # Resize to a fixed square, then take a random 299x299 crop; the crop
        # position varies per sample, which acts as light data augmentation.
        transforms.Resize((356, 356)),
        transforms.RandomCrop((299, 299)),
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] output
        # (for PIL / uint8 inputs) to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def _run_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Prints the current batch loss every 10 steps.

    Raises:
        RuntimeError: if ``loader`` yields no batches, since no mean loss exists.
    """
    model.train()
    steps_per_epoch = len(loader)
    total_loss = 0.0
    num_batches = 0
    for idx, (imgs, captions) in enumerate(loader):
        imgs = imgs.to(device)
        captions = captions.to(device)
        # NOTE(review): the full `captions` tensor is used both as model input
        # and as the loss target. CrossEntropyLoss requires `outputs` to flatten
        # to exactly one row per caption element, and this code assumes
        # CNNtoRNN.forward (model.py) already handles the next-token shift so
        # the flattened positions line up. Confirm against model.py.
        outputs = model(imgs, captions)
        # Flatten to (N, vocab_size) logits and (N,) integer targets.
        loss = criterion(
            outputs.reshape(-1, outputs.shape[-1]), captions.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single host sync, reused for logging
        total_loss += batch_loss
        num_batches += 1
        if idx % 10 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] Step [{idx}/{steps_per_epoch}] "
                f"Loss: {batch_loss:.4f}"
            )

    if num_batches == 0:
        raise RuntimeError(
            "Training loader produced no batches; check the dataset split."
        )
    return total_loss / num_batches


def _save_artifacts(model, vocab, model_path, vocab_path):
    """Write the model's state_dict and the pickled vocabulary to disk."""
    import pickle

    torch.save(model.state_dict(), model_path)
    # The vocabulary is saved so inference can reuse the same token mapping.
    # Unpickling can execute arbitrary code: only load this file if you trust it.
    with open(vocab_path, "wb") as f:
        pickle.dump(vocab, f)


def train(
    *,
    embed_size=256,
    hidden_size=256,
    num_layers=1,
    learning_rate=3e-4,
    num_epochs=5,
    batch_size=32,
    model_path="caption_model.pth",
    vocab_path="vocab.pkl",
):
    """Train the CNN-to-RNN captioning model, then save its weights and vocabulary.

    All arguments are keyword-only and default to the values this script has
    always used, so calling ``train()`` with no arguments behaves as before.

    Args:
        embed_size: Embedding size passed to ``CNNtoRNN``.
        hidden_size: Hidden size passed to ``CNNtoRNN``.
        num_layers: Number of layers passed to ``CNNtoRNN``.
        learning_rate: Learning rate for the Adam optimizer.
        num_epochs: Number of passes over the training split.
        batch_size: Batch size passed to ``get_loader``.
        model_path: File that ``model.state_dict()`` is written to via ``torch.save``.
        vocab_path: File that ``dataset.vocab`` is pickled to.

    Exits the process with status 1 if the dataset cannot be loaded.

    Raises:
        RuntimeError: if an epoch yields no training batches.
    """
    device = _select_device()
    print(f"Using device: {device}")

    transform = _build_transform()

    print("Loading dataset and building vocabulary...")
    try:
        # Loading may take some time and network bandwidth.
        train_loader, dataset = get_loader(
            transform=transform,
            batch_size=batch_size,
            split="train",
        )
    except Exception as e:  # top-level boundary: report the failure and exit
        print(f"Failed to load dataset: {e}", file=sys.stderr)
        print(
            "Please ensure you have internet access and the Huggingface datasets library is installed.",
            file=sys.stderr,
        )
        sys.exit(1)

    vocab_size = len(dataset.vocab)
    print(f"Vocabulary size: {vocab_size}")

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    # Freeze the CNN encoder layers (see model.py for what fine_tune(False) does).
    model.encoderCNN.fine_tune(False)

    # Padding positions carry no signal, so they are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    # All parameters are given to Adam. The optimizer step skips parameters
    # whose .grad is None, which presumably includes the frozen ones.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print("Starting training...")
    for epoch in range(num_epochs):
        avg_loss = _run_epoch(
            model, train_loader, criterion, optimizer, device, epoch, num_epochs
        )
        print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")

    print("Training Complete. Saving model...")
    _save_artifacts(model, dataset.vocab, model_path, vocab_path)
    print("Model and vocabulary saved locally.")
if __name__ == "__main__":
    # Launch training only when this file is executed directly, never on import.
    train()