File size: 4,173 Bytes
70884da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49f4666
70884da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49f4666
70884da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import argparse
from torch import optim
from torch.utils.data import DataLoader
from scripts.genomic_plip_model import GenomicPLIPModel
from scripts.tile_file_dataloader import FlatTileDataset
from transformers import CLIPVisionModel

def _sum_sample_losses(model, criterion, batch_images, batch_scores, backward=False):
    """Run the model one sample at a time and sum the per-sample losses.

    Each (image, score) pair is given a leading batch dimension of 1 via
    ``unsqueeze(0)``, passed through ``model``, and scored as
    ``1 - criterion(score_features, vision_features).mean()``.

    Args:
        model: Callable returning ``(vision_features, score_features)``.
        criterion: Similarity function applied to the two feature tensors.
        batch_images: Iterable of per-sample image tensors.
        batch_scores: Iterable of per-sample score tensors.
        backward: When true, call ``loss.backward()`` for every sample so
            gradients accumulate (sum) across the batch.

    Returns:
        Tuple ``(loss_sum, sample_count)`` with the summed scalar losses and
        the number of samples processed.
    """
    loss_sum = 0.0
    sample_count = 0
    for img, score in zip(batch_images, batch_scores):
        vision_features, score_features = model(img.unsqueeze(0), score.unsqueeze(0))
        cos_sim = criterion(score_features, vision_features)
        loss = 1 - cos_sim.mean()

        loss_sum += loss.item()
        sample_count += 1
        if backward:
            # Backpropagating per sample frees each graph immediately; the
            # gradients of all samples are summed before the optimizer step.
            loss.backward()
    return loss_sum, sample_count


def train_model(
    data_dir: str,
    model_save_path: str,
    pretrained_model_path: str,
    lr: float,
    num_epochs: int,
    train_batch_size: int,
    validation_batch_size: int,
    num_workers: int,
) -> None:
    """Train a GenomicPLIPModel on tile/score pairs and save its weights.

    Datasets are read from ``{data_dir}/train`` and ``{data_dir}/validate``.
    The loss for a sample is ``1 - cosine_similarity(score_features,
    vision_features)``. Within a batch, samples are processed one at a time
    and their gradients are accumulated before a single Adam step.

    Reported numbers are mean per-sample losses (lower is better), so they
    are comparable across different batch sizes.

    Args:
        data_dir: Directory containing the ``train`` and ``validate`` subdirectories.
        model_save_path: Destination file for the trained ``state_dict``.
        pretrained_model_path: Path or identifier passed to
            ``CLIPVisionModel.from_pretrained``.
        lr: Learning rate for the Adam optimizer.
        num_epochs: Number of passes over the training set.
        train_batch_size: Batch size of the training data loader.
        validation_batch_size: Batch size of the validation data loader.
        num_workers: ``num_workers`` value given to both data loaders.
    """
    # Load datasets
    train_dataset = FlatTileDataset(data_dir=f'{data_dir}/train')
    train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)

    validation_dataset = FlatTileDataset(data_dir=f'{data_dir}/validate')
    validation_data_loader = DataLoader(validation_dataset, batch_size=validation_batch_size, shuffle=False, num_workers=num_workers)

    # Initialize the model
    # NOTE(review): no explicit device placement happens in this function;
    # confirm whether GPU training is expected to be configured elsewhere.
    base_model = CLIPVisionModel.from_pretrained(pretrained_model_path)
    custom_model = GenomicPLIPModel(base_model)

    criterion = torch.nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(custom_model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # Training loop
        custom_model.train()
        train_loss = 0.0
        train_samples = 0

        for batch_images, batch_scores in train_data_loader:
            optimizer.zero_grad()
            batch_loss, batch_count = _sum_sample_losses(
                custom_model, criterion, batch_images, batch_scores, backward=True
            )
            optimizer.step()

            train_loss += batch_loss
            train_samples += batch_count
            # max(..., 1) guards the division if a batch yields no samples.
            print(f"Batch Loss (1 - cosine similarity): {batch_loss / max(batch_count, 1):.4f}")

        # Average over samples, not batches, and tolerate an empty split.
        avg_train_loss = train_loss / max(train_samples, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss (1 - cosine similarity): {avg_train_loss:.4f}")

        # Validation loop
        custom_model.eval()
        validation_loss = 0.0
        validation_samples = 0

        with torch.no_grad():
            for batch_images, batch_scores in validation_data_loader:
                batch_loss, batch_count = _sum_sample_losses(
                    custom_model, criterion, batch_images, batch_scores
                )
                validation_loss += batch_loss
                validation_samples += batch_count
                print(f"Validation Batch Loss (1 - cosine similarity): {batch_loss / max(batch_count, 1):.4f}")

        avg_validation_loss = validation_loss / max(validation_samples, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss (1 - cosine similarity): {avg_validation_loss:.4f}")

    # Save the trained model
    torch.save(custom_model.state_dict(), model_save_path)

if __name__ == "__main__":
    # Command-line entry point: collect the training options and hand them
    # to train_model by keyword so each value is visibly paired with its
    # parameter.
    arg_parser = argparse.ArgumentParser(description='Train the Genomic PLIP Model')

    # Filesystem locations.
    arg_parser.add_argument(
        '--data_dir',
        type=str,
        default='Datasets/train_03',
        help='Directory containing the train, validate, and test datasets.',
    )
    arg_parser.add_argument(
        '--model_save_path',
        type=str,
        default='genomic_plip.pth',
        help='Path to save the trained model.',
    )
    arg_parser.add_argument(
        '--pretrained_model_path',
        type=str,
        default='./plip',
        help='Path to the pretrained CLIP model.',
    )

    # Optimisation and data-loading hyperparameters.
    arg_parser.add_argument(
        '--lr',
        type=float,
        default=0.00001,
        help='Learning rate for the optimizer.',
    )
    arg_parser.add_argument(
        '--num_epochs',
        type=int,
        default=1,
        help='Number of epochs to train for.',
    )
    arg_parser.add_argument(
        '--train_batch_size',
        type=int,
        default=128,
        help='Batch size for the training data loader.',
    )
    arg_parser.add_argument(
        '--validation_batch_size',
        type=int,
        default=128,
        help='Batch size for the validation data loader.',
    )
    arg_parser.add_argument(
        '--num_workers',
        type=int,
        default=32,
        help='Number of worker threads for data loading.',
    )

    cli_args = arg_parser.parse_args()

    train_model(
        data_dir=cli_args.data_dir,
        model_save_path=cli_args.model_save_path,
        pretrained_model_path=cli_args.pretrained_model_path,
        lr=cli_args.lr,
        num_epochs=cli_args.num_epochs,
        train_batch_size=cli_args.train_batch_size,
        validation_batch_size=cli_args.validation_batch_size,
        num_workers=cli_args.num_workers,
    )