File size: 4,635 Bytes
2940390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse, os, torch, time
import torch.optim

from utils.utils import load_embedder_ckpt_with_optim, adjust_learning_rate, freeze_text_embedder, AverageMeter
from utils.utils_data import init_embedding_data



def _as_float(value):
    """Return ``value`` as a plain Python number, unwrapping scalar tensors.

    Feeding raw tensors into a running-average meter keeps every iteration's
    autograd graph reachable through the accumulated sum; converting to a
    Python number first avoids that.
    """
    return value.item() if torch.is_tensor(value) else value


def _to_device(batch, device):
    """Return a list holding every element of ``batch`` moved to ``device``."""
    return [item.to(device) for item in batch]


def train_embedding(cur_epoch, model, optimizer, trainloader, testloader, device, cfg_em):
    """Train ``model`` and evaluate it on ``testloader`` after every epoch.

    For each epoch from ``cur_epoch`` to ``cfg_em.epoch`` (inclusive) the
    function updates the learning rate, runs one optimisation pass over
    ``trainloader``, runs one gradient-free pass over ``testloader``, prints
    running averages of loss / accuracy / forward time, and writes a
    checkpoint into ``cfg_em.check_dir``.

    Args:
        cur_epoch: First epoch number to run (used when resuming).
        model: Callable as ``model(batch, 'train')`` returning a mapping with
            the keys ``'loss_total'`` and ``'acc_type'``.
        optimizer: Optimizer for ``model``; its state is saved in checkpoints.
        trainloader: Iterable of training batches (sequences of tensors).
        testloader: Iterable of evaluation batches (sequences of tensors).
        device: Device the batches are moved to. When ``None``, CUDA is used
            if available, otherwise the CPU.
        cfg_em: Configuration object providing ``epoch``, ``lr_decay`` and
            ``check_dir``.
    """
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    acc_train_meter = AverageMeter()
    acc_test_meter = AverageMeter()
    loss_train_meter = AverageMeter()
    loss_test_meter = AverageMeter()
    time_train_meter = AverageMeter()
    time_test_meter = AverageMeter()
    all_meters = (
        acc_train_meter, acc_test_meter,
        loss_train_meter, loss_test_meter,
        time_train_meter, time_test_meter,
    )

    # Presumably disables gradients for the text branch (behaviour defined in
    # utils.utils); the dump below shows which parameters remain trainable.
    freeze_text_embedder(model)
    for k, v in model.named_parameters():
        print('{}: {}'.format(k, v.requires_grad))

    for epoch in range(cur_epoch, cfg_em.epoch + 1):

        optimizer = adjust_learning_rate(optimizer, epoch - 1, cfg_em.lr_decay)
        lr = optimizer.param_groups[-1]['lr']

        # ---- training pass ----
        model.train()
        for idx, batch in enumerate(trainloader):
            batch = _to_device(batch, device)
            time_start = time.time()
            out = model(batch, 'train')
            loss = out['loss_total']
            acc = out['acc_type']
            # Only the forward pass is timed, as before.
            time_train_meter.update(time.time() - time_start)

            # Log detached numbers so the meters do not hold autograd graphs.
            acc_train_meter.update(_as_float(acc))
            loss_train_meter.update(_as_float(loss))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f'Epoch:{epoch}|Iter:{idx+1}/{len(trainloader)}|lr:{lr},'
                f'Loss: {loss_train_meter.avg:.3f},'
                f'Acc: {acc_train_meter.avg:.3f},'
                f'Time: {time_train_meter.avg:.3f},', flush=True)

        # ---- evaluation pass ----
        model.eval()
        # No gradients are needed here; skipping graph construction saves
        # memory and time.
        with torch.no_grad():
            for idx, batch in enumerate(testloader):
                batch = _to_device(batch, device)

                time_start = time.time()
                # The model is still invoked with the 'train' mode string so it
                # returns 'loss_total' / 'acc_type'.
                # NOTE(review): confirm whether a dedicated eval mode exists.
                out = model(batch, 'train')
                loss = out['loss_total']
                acc = out['acc_type']
                time_test_meter.update(time.time() - time_start)

                acc_test_meter.update(_as_float(acc))
                loss_test_meter.update(_as_float(loss))
                print(f'Epoch:{epoch}|Iter:{idx+1}/{len(testloader)}|lr:{lr},'
                    f'Loss: {loss_test_meter.avg:.3f},'
                    f'Acc: {acc_test_meter.avg:.3f},'
                    f'Time: {time_test_meter.avg:.3f},', flush=True)

        # The checkpoint name encodes epoch plus train/test accuracy and loss.
        torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer' : optimizer.state_dict()},
                         f'{cfg_em.check_dir}/embedder_model_epoch{epoch}_{acc_train_meter.avg:.3f}_{loss_train_meter.avg:.3f}_{acc_test_meter.avg:.3f}_{loss_test_meter.avg:.3f}.tar')

        # Start the next epoch with fresh statistics.
        for meter in all_meters:
            meter.reset()
    print('Done!')

# Restrict this process to the first GPU, enumerated in PCI bus order.
# These are set at import time, before any CUDA call is made in this file.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def build_arg_parser():
    """Build the command-line parser for the embedder training script.

    Returns:
        argparse.ArgumentParser: Parser exposing the options for seeding,
        checkpoint loading, learning rate, degradation type names, dataset
        locations, loader settings, schedule and checkpoint directory.
    """
    parser = argparse.ArgumentParser()
    # load model
    parser.add_argument("--seed", type=int, default=124)
    parser.add_argument("--pre_weight", type=str, default='')
    parser.add_argument("--lr", type=float, default=0.0001)
    # ``type=list`` would split a single argument into characters
    # ("haze" -> ['h', 'a', 'z', 'e']); ``nargs='+'`` collects one string per
    # name instead, while the default stays the same list of names.
    parser.add_argument("--type_name", nargs='+', type=str, default=[
        'clear', 'low', 'haze', 'rain',
        'snow', 'low_haze', 'low_rain', 'low_snow', 'haze_rain',
        'haze_snow', 'low_haze_rain', 'low_haze_snow',
    ])
    parser.add_argument("--train-dir", type=str, default='./data/CDD-11_train/')
    parser.add_argument("--test-dir", type=str, default='./data/CDD-11_test/')
    parser.add_argument("--batch", type=int, default=128)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--lr-decay", type=int, default=50)
    parser.add_argument("--check-dir", type=str, default="./ckpts")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Make sure the checkpoint directory exists before training writes to it.
    os.makedirs(args.check_dir, exist_ok=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    embedder, optimizer, cur_epoch, device = load_embedder_ckpt_with_optim(device, args)
    trainloader, testloader = init_embedding_data(args, 'train')
    train_embedding(cur_epoch, embedder, optimizer, trainloader, testloader, device, args)