import pickle
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from dataset import get_loader
from model import CNNtoRNN
|
|
def train(
    embed_size=256,
    hidden_size=256,
    num_layers=1,
    learning_rate=3e-4,
    num_epochs=5,
    batch_size=32,
    checkpoint_path="caption_model.pth",
    vocab_path="vocab.pkl",
):
    """Train the CNN-to-RNN image-captioning model and save the results.

    Loads the training split via ``get_loader`` (which also builds the
    vocabulary), trains ``CNNtoRNN`` with teacher forcing for
    ``num_epochs`` epochs, then saves the model weights to
    ``checkpoint_path`` and the vocabulary to ``vocab_path``.

    Args:
        embed_size: Dimensionality of the word/image embeddings.
        hidden_size: Hidden size of the RNN decoder.
        num_layers: Number of stacked RNN layers in the decoder.
        learning_rate: Adam learning rate.
        num_epochs: Number of full passes over the training set.
        batch_size: Mini-batch size for the data loader.
        checkpoint_path: Where to write the trained ``state_dict``.
        vocab_path: Where to pickle the dataset vocabulary.

    Exits the process with status 1 if the dataset cannot be loaded.
    """
    # Prefer CUDA, then Apple-Silicon MPS, then CPU.
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"Using device: {device}")

    # Resize larger than the crop so RandomCrop acts as light augmentation;
    # normalization maps pixel values from [0, 1] to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((356, 356)),
        transforms.RandomCrop((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    print("Loading dataset and building vocabulary...")
    try:
        train_loader, dataset = get_loader(
            transform=transform,
            batch_size=batch_size,
            split="train",
        )
    except Exception as e:
        # Dataset download/parse failure is unrecoverable for a training run.
        print(f"Failed to load dataset: {e}")
        print("Please ensure you have internet access and the Huggingface datasets library is installed.")
        sys.exit(1)

    vocab_size = len(dataset.vocab)
    print(f"Vocabulary size: {vocab_size}")

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    # Freeze the CNN backbone; only the decoder (and any trainable encoder
    # head) receive meaningful gradient updates.
    model.encoderCNN.fine_tune(False)

    # Padding positions must not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print("Starting training...")
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for idx, (imgs, captions) in enumerate(train_loader):
            imgs = imgs.to(device)
            captions = captions.to(device)

            outputs = model(imgs, captions)

            # Flatten the (time, batch, vocab) logits to (time*batch, vocab)
            # for CrossEntropyLoss against the flattened token ids.
            # NOTE(review): this uses the full `captions` tensor as both the
            # decoder input and the target — confirm the shift/alignment is
            # handled inside CNNtoRNN.forward.
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Step [{idx}/{len(train_loader)}] Loss: {loss.item():.4f}")

        print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {total_loss/len(train_loader):.4f}")

    print("Training Complete. Saving model...")
    torch.save(model.state_dict(), checkpoint_path)

    # Persist the vocabulary so inference can map token ids back to words.
    with open(vocab_path, "wb") as f:
        pickle.dump(dataset.vocab, f)
    print("Model and vocabulary saved locally.")
|
|
if __name__ == "__main__":
    # Script entry point: run a full training pass with default settings.
    train()
|
|