import os

import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from img2art_search.data.data import get_data_from_local, split_train_val_test
from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.losses.contrastiveloss import ContrastiveLoss
from img2art_search.models.model import ViTImageSearchModel


def fine_tune_vit(epochs: int, batch_size: int) -> None:
    """Fine-tune ``ViTImageSearchModel`` with ``ContrastiveLoss``.

    The function loads the local data, splits it into train/validation/test
    parts, stores the test part under ``results/test_data`` (NumPy appends the
    ``.npy`` extension), trains for ``epochs`` epochs while logging to
    TensorBoard under ``./logs/``, and finally writes the model's state dict
    to ``results/model.pth``.

    Args:
        epochs: Number of passes over the training loader.
        batch_size: Batch size used for both the training and validation
            loaders.

    Raises:
        ZeroDivisionError: If the training or validation loader yields no
            batches (the per-epoch average divides by the loader length).
    """
    # Both np.save and torch.save write into "results/"; create it up front so
    # a missing directory fails fast instead of after a full training run.
    os.makedirs("results", exist_ok=True)

    data = get_data_from_local()
    # NOTE(review): 0.2 and 0.1 are presumably the validation and test
    # fractions -- confirm against split_train_val_test.
    train_data, val_data, test_data = split_train_val_test(data, 0.2, 0.1)
    np.save("results/test_data", test_data)

    train_dataset = ImageRetrievalDataset(train_data, transform=transform)
    val_dataset = ImageRetrievalDataset(val_data, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    model = ViTImageSearchModel()

    # logs
    log_dir = "./logs/"
    writer = SummaryWriter(log_dir=log_dir)

    # params
    criterion = ContrastiveLoss()
    optimizer = Adam(model.parameters(), lr=1e-4)

    try:
        for epoch in range(epochs):
            model.train()
            total_loss = 0
            for batch_idx, batch in enumerate(train_loader):
                # Each batch is a pair of tensors; both sides go through the
                # same model and the loss compares the two embeddings.
                inputs, labels = batch
                optimizer.zero_grad()
                input_embeddings = model(inputs)
                label_embeddings = model(labels)
                loss = criterion(input_embeddings, label_embeddings)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                # Global step = batches seen so far across all epochs.
                writer.add_scalar(
                    "Train Loss", loss.item(), epoch * len(train_loader) + batch_idx
                )
            avg_train_loss = total_loss / len(train_loader)
            writer.add_scalar("Average Train Loss", avg_train_loss, epoch)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_train_loss}")

            model.eval()
            with torch.no_grad():
                val_loss = 0
                for batch in val_loader:
                    inputs, labels = batch
                    input_embeddings = model(inputs)
                    label_embeddings = model(labels)
                    loss = criterion(input_embeddings, label_embeddings)
                    val_loss += loss.item()
            avg_val_loss = val_loss / len(val_loader)
            writer.add_scalar("Validation Loss", avg_val_loss, epoch)
            print(f"Validation Loss: {avg_val_loss}")

        torch.save(model.state_dict(), "results/model.pth")
    finally:
        # Flush pending events and release the event file even when training
        # is interrupted by an exception.
        writer.close()