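"""Train U2Net for salient-object detection on DUTS + MSRA.

Uses mixed-precision training (torch.amp) with deep supervision: the loss is
the sum of BCEWithLogitsLoss over every output map the model returns. A rolling
checkpoint is written every 10 epochs and the loss history is pickled at the end.
"""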
import os
import pickle
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torch.amp import autocast, GradScaler
from data_loader import DUTSDataset, MSRADataset
from model import U2Net
# Prefer the GPU when available; the GradScaler guards fp16 gradients against
# underflow during mixed-precision training.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scaler = GradScaler()

def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.
    for images, masks in tqdm(loader, desc='Training', leave=False):
        images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
        optimizer.zero_grad()
        # Mixed-precision forward pass. U2Net returns several maps (fused plus
        # side outputs); deep supervision sums the loss over all of them.
        with autocast(device_type='cuda'):
            outputs = model(images)
            loss = sum(criterion(output, masks) for output in outputs)
        # Scale the loss so small fp16 gradients do not underflow, then step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    return running_loss / len(loader)

def validate(model, loader, criterion):
    model.eval()
    running_loss = 0.
    # No gradients are needed for evaluation, which also cuts memory use.
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Validating', leave=False):
            images, masks = images.to(device, non_blocking=True), masks.to(device, non_blocking=True)
            outputs = model(images)
            loss = sum(criterion(output, masks) for output in outputs)
            running_loss += loss.item()
    return running_loss / len(loader)

if __name__ == '__main__':
    batch_size = 40
    valid_batch_size = 80
    epochs = 100
    lr = 1e-4
    # BCE on raw logits: the network heads emit unnormalized scores, so no
    # sigmoid is applied inside the model before the loss.
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    model_name = 'u2net-duts'

    model = U2Net()
    model = torch.nn.DataParallel(model.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # Train on the union of DUTS and MSRA; validation can use a larger batch
    # because no activations are retained for backprop.
    train_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='train'), MSRADataset(split='train')]),
        batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=16, persistent_workers=True
    )
    valid_loader = DataLoader(
        ConcatDataset([DUTSDataset(split='valid'), MSRADataset(split='valid')]),
        batch_size=valid_batch_size, shuffle=False, pin_memory=True,
        num_workers=16, persistent_workers=True
    )

    os.makedirs('results', exist_ok=True)  # checkpoints and loss history go here
    losses = {'train': [], 'val': []}
    for epoch in tqdm(range(epochs), desc='Epochs'):
        torch.cuda.empty_cache()
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
        val_loss = validate(model, valid_loader, loss_fn)
        losses['train'].append(train_loss)
        losses['val'].append(val_loss)
        # Overwrite a single rolling checkpoint every 10 epochs.
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f'results/inter-{model_name}.pt')
        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    # Note: saving through DataParallel keeps the 'module.' prefix on every key.
    torch.save(model.state_dict(), f'results/{model_name}.pt')
    with open('results/loss.txt', 'wb') as f:
        pickle.dump(losses, f)
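
# A minimal inference sketch (not part of the training run). It assumes the
# usual U2Net convention that the first returned map is the fused prediction,
# and relies on the fact that a DataParallel state_dict carries a 'module.'
# prefix on every key, which must be stripped to load into a bare model:
#
#   model = U2Net()
#   state = torch.load('results/u2net-duts.pt', map_location='cpu')
#   model.load_state_dict({k.removeprefix('module.'): v for k, v in state.items()})
#   model.eval()
#   with torch.no_grad():
#       prob = torch.sigmoid(model(image_batch)[0])  # saliency map in [0, 1]
#
# The pickled loss history can be read back the same way it was written:
#
#   with open('results/loss.txt', 'rb') as f:
#       losses = pickle.load(f)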