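"""Training entry point for the FeatureTransformer classifier.

Trains a binary classifier on CLIP-derived features (16 tokens of dimension 768
per sample), with TensorBoard logging, periodic checkpointing, and accuracy-based
early stopping that lowers the learning rate before stopping for good.
"""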
import os
import time
from datetime import datetime
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torch
import torchinfo
import numpy as np

import options
from validate import validate, calculate_acc
from datasetss import *
from utilss.logger import create_logger
from utilss.earlystop import EarlyStopping
from networks.trainer import Trainer



if __name__ == '__main__':
    train_opt = options.TrainOptions().parse()
    # val_opt = options.TestOptions().parse()
    
    # logger
    logger = create_logger(output_dir=train_opt.checkpoints_dir, name="FeatureTransformer")
    logger.info(f"working dir: {train_opt.checkpoints_dir}")

 
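    # build the Trainer, which wraps the trainable network (model.model) and a CLIP
    # model whose preprocessing transform is reused by the dataloaders below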
    model = Trainer(train_opt)
    # logger.info(opt.gpu_ids[0])
    logger.info(model.device)
    
    # extract_feature_model = model.extract_feature_model
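    # train/val dataloaders; k_split=0.8 presumably assigns 80% of the data to the train split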
    train_loader, val_loader = create_train_val_dataloader(train_opt, clip_model = None, transform = model.clip_model.preprocess, k_split=0.8)
    logger.info(f"train    {len(train_loader)}")
    logger.info(f"validate {len(val_loader)}")


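    # separate TensorBoard writers for training and validation metrics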
    train_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "train"))
    val_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "val"))
        
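    # early stopping on validation accuracy (called at the end of each epoch below)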
    early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir, patience=train_opt.earlystop_epoch, delta=-0.001, verbose=True)
    
    start_time = time.time()
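    # log a layer-by-layer summary of the trainable network; the dummy input mirrors
    # the expected feature shape of (batch_size, 16, 768)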
    logger.info(torchinfo.summary(model.model, input_size=(train_opt.batch_size, 16, 768), col_width=20,
                  col_names=['input_size', 'output_size', 'num_params', 'trainable'], row_settings=['var_names'], verbose=0))
    
    
    logger.info("Length of train loader: %d" %(len(train_loader)))
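    # main training loop: one pass over the train loader per epoch, up to niter epochs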
    for epoch in range(train_opt.niter):
        y_true, y_pred = [], []
        pbar = tqdm(train_loader)
        for i, data in enumerate(pbar):
            pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

            model.total_steps += 1

            model.set_input(data)
            model.optimize_parameters()

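            # collect sigmoid scores and ground-truth labels for epoch-level train accuracy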
            y_pred.extend(model.output.sigmoid().flatten().tolist())
            y_true.extend(data[1].flatten().tolist())

            if model.total_steps % train_opt.loss_freq == 0:
                logger.info("Train loss: {} at step: {}".format(model.loss, model.total_steps))
                train_writer.add_scalar('loss', model.loss, model.total_steps)
                logger.info("Iter time: {}".format((time.time()-start_time)/model.total_steps)  )

            if model.total_steps in [10, 30, 50, 100, 1000, 5000, 10000] and False:  # iteration-based checkpoints, currently disabled via 'and False'
                model.save_networks('model_iters_%s.pth' % model.total_steps)
            # logger.info("trained one batch")
            pbar.set_postfix_str(f"loss: {model.loss}")
        r_acc0, f_acc0, acc0 = calculate_acc(np.array(y_true), np.array(y_pred), 0.5)
        logger.info(f"TrainSet r_acc: {r_acc0}, f_acc: {f_acc0}, acc: {acc0}")

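        # periodic checkpoint every save_epoch_freq epochs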
        if epoch % train_opt.save_epoch_freq == 0:
            logger.info('saving the model at the end of epoch %d' % (epoch))
            model.save_networks( 'model_epoch_%s.pth' % epoch )

        # Validation
        model.eval()
        ap, r_acc, f_acc, acc = validate(model, val_loader, logger=logger)
        val_writer.add_scalar('accuracy', acc, model.total_steps)
        val_writer.add_scalar('ap', ap, model.total_steps)
        logger.info("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))

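        # when early stopping triggers, try lowering the learning rate and reset the
        # counter; stop training once adjust_learning_rate reports it cannot continue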
        early_stopping(acc, model.model)
        if early_stopping.early_stop:
            cont_train = model.adjust_learning_rate()
            if cont_train:
                logger.info("Learning rate dropped by 10, continue training...")
                early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir, patience=train_opt.earlystop_epoch, delta=-0.002, verbose=True)
            else:
                logger.info("Early stopping.")
                break
        
        model.train()