File size: 1,396 Bytes
5464cad
 
 
 
 
ecf0440
5464cad
 
 
 
 
 
 
 
ecf0440
5464cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ecf0440
5464cad
ecf0440
 
5464cad
 
 
ecf0440
5464cad
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from safetensors.torch import load_file

from data_loader import PASCALSDataset
from model import U2Net

# Use the CUDA device when torch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

def load_model(model, model_path):
    """Load safetensors weights from ``model_path`` into ``model`` in place.

    The tensors are read with ``load_file`` onto the device type of the
    module-level ``device``, copied into ``model`` through
    ``load_state_dict``, and the model is then switched to evaluation mode.
    Nothing is returned; ``model`` itself is modified.
    """
    model.load_state_dict(load_file(model_path, device=device.type))
    model.eval()

def eval(model, loader, criterion):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    The model is put in evaluation mode and every batch is processed under
    ``torch.no_grad()``. For each batch, ``criterion`` is applied to every
    output produced by ``model(images)`` against the same ``masks`` and the
    results are summed into a single batch loss.

    NOTE: the name shadows the builtin ``eval``; it is kept so existing
    callers continue to work.

    Args:
        model: Module whose call ``model(images)`` returns an iterable of
            outputs.
        loader: Iterable yielding ``(images, masks)`` batches. It does not
            need to define ``__len__``; batches are counted while iterating.
        criterion: Callable ``criterion(output, masks)`` returning a value
            that supports ``+`` and ``.item()``.

    Returns:
        The sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``loader`` yields no batches, so no mean exists.
    """
    model.eval()
    running_loss = 0.
    num_batches = 0
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Testing'):
            images, masks = images.to(device), masks.to(device)
            # One loss term per model output, all against the same target.
            loss = sum(criterion(output, masks) for output in model(images))
            running_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError('cannot compute a mean loss: loader yielded no batches')
    # Average over the batches actually seen rather than len(loader), so
    # loaders without a length are supported as well.
    return running_loss / num_batches


if __name__ == '__main__':
    # One sample per evaluation step.
    batch_size = 1

    # Answering 'b' picks the 'best-' prefixed checkpoint; any other answer
    # picks the un-prefixed one.
    model_type = input('Model type [b,f]: ')
    if model_type == 'b':
        model_name = 'best-u2net-duts-msra.safetensors'
    else:
        model_name = 'u2net-duts-msra.safetensors'

    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')

    # The network is wrapped in DataParallel *before* the weights are loaded,
    # so the checkpoint presumably stores DataParallel-style keys -- TODO confirm.
    model = nn.DataParallel(U2Net().to(device))
    load_model(model, f'results/{model_name}')

    # Fixed order (no shuffling) over the 'eval' split.
    dataset = PASCALSDataset(split='eval')
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    loss = eval(model, loader, loss_fn)
    print(f'Loss: {loss:.4f}')