File size: 3,010 Bytes
5464cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import pickle
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.amp import autocast, GradScaler

from data_loader import DUTSDataset, MSRADataset
from model import U2Net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scaler = GradScaler()

def train_one_epoch(model, loader, criterion, optimizer):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Each batch is moved to the module-level ``device``, the forward pass runs
    under autocast (mixed precision, CUDA only), and the backward/step goes
    through the module-level ``scaler``.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable of ``(images, masks)`` batches with a ``len``.
        criterion: Loss applied to each element of the model's output.
        optimizer: Optimizer stepped once per batch (via ``scaler``).

    Returns:
        The sum of per-batch losses divided by ``len(loader)``.
    """
    model.train()
    # Key autocast on the device actually in use: hard-coding 'cuda' made
    # CPU runs warn "CUDA is not available" on every batch before disabling.
    use_amp = device.type == 'cuda'
    running_loss = 0.
    for images, masks in tqdm(loader, desc='Training', leave=False):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            # Every element of ``outputs`` is scored against the same masks
            # and the individual losses are summed into one objective.
            loss = sum(criterion(output, masks) for output in outputs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
    return running_loss / len(loader)

def validate(model, loader, criterion):
    """Score ``model`` on ``loader`` with gradients disabled.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, masks)`` batches with a ``len``.
        criterion: Loss applied to each element of the model's output.

    Returns:
        The summed criterion over all outputs, averaged over ``len(loader)``
        batches.
    """
    model.eval()
    total = 0.
    with torch.no_grad():
        for batch_images, batch_masks in tqdm(loader, desc='Validating', leave=False):
            batch_images = batch_images.to(device, non_blocking=True)
            batch_masks = batch_masks.to(device, non_blocking=True)
            predictions = model(batch_images)
            # Same objective as training: one loss term per output, summed.
            batch_loss = sum(criterion(pred, batch_masks) for pred in predictions)
            total += batch_loss.item()
    return total / len(loader)


if __name__ == '__main__':
    import os

    batch_size = 40
    valid_batch_size = 80
    epochs = 100

    lr = 1e-4
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')

    model_name = 'u2net-duts'
    model = U2Net()
    # NOTE: DataParallel wraps the network, so the state_dict keys saved below
    # carry a "module." prefix; loaders must account for that.
    model = torch.nn.DataParallel(model.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # Create the output directory up front. Without it the first checkpoint
    # save (epoch 10) raises FileNotFoundError after training has already run.
    os.makedirs('results', exist_ok=True)

    train_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
        batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=16, persistent_workers=True
    )
    valid_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
        batch_size=valid_batch_size, shuffle=False, pin_memory=True,
        num_workers=16, persistent_workers=True
    )

    losses = {'train': [], 'val': []}
    for epoch in tqdm(range(epochs), desc='Epochs'):
        torch.cuda.empty_cache()
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
        val_loss = validate(model, valid_loader, loss_fn)
        losses['train'].append(train_loss)
        losses['val'].append(val_loss)

        # Rolling checkpoint: the same file is overwritten every 10 epochs.
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f'results/inter-{model_name}.pt')

        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    torch.save(model.state_dict(), f'results/{model_name}.pt')
    # Despite the .txt extension, this file holds a pickled dict of per-epoch
    # losses ({'train': [...], 'val': [...]}); read it with pickle.load.
    with open('results/loss.txt', 'wb') as f:
        pickle.dump(losses, f)