# jamino30's picture
# Upload folder using huggingface_hub
# ecf0440 verified
# raw
# history blame
# 1.4 kB
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from safetensors.torch import load_file
from data_loader import PASCALSDataset
from model import U2Net
# Pick the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU. The choice is echoed to stdout.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)
def load_model(model, model_path):
    """Load weights from a safetensors file into ``model`` and set eval mode.

    Args:
        model: Object exposing ``load_state_dict`` and ``eval``; it is
            modified in place.
        model_path: Path of the safetensors file handed to ``load_file``.

    Returns:
        None. The model is updated in place.
    """
    # Tensors are materialised on the module-level device's type
    # ('cuda' or 'cpu').
    model.load_state_dict(load_file(model_path, device=device.type))
    model.eval()
def eval(model, loader, criterion):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    NOTE: this name shadows the builtin ``eval`` within the module; it is
    kept unchanged because callers refer to it by this name.

    Args:
        model: Network to evaluate. It is switched to eval mode and called
            on every image batch; the value it returns is iterated and each
            element is scored against the masks (presumably one prediction
            tensor per U2Net output -- confirm against ``model.py``).
        loader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss callable invoked as ``criterion(output, masks)``.

    Returns:
        float: The summed per-batch loss divided by the number of batches
        that were actually iterated.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches, since the mean
            is undefined. The exception type matches what the bare division
            raised before; only the message is new.
    """
    model.eval()
    running_loss = 0.
    # Count batches as they are consumed instead of trusting len(loader):
    # the length may be unavailable or only an estimate for iterable-style
    # loaders, which would skew the average.
    num_batches = 0
    with torch.no_grad():
        for images, masks in tqdm(loader, desc='Testing'):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            # One loss term per model output, all against the same masks.
            loss = sum(criterion(output, masks) for output in outputs)
            running_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ZeroDivisionError(
            'cannot average the loss: loader yielded no batches'
        )
    return running_loss / num_batches
def checkpoint_name(model_type):
    """Map the interactive model-type answer to a checkpoint file name.

    Args:
        model_type: Raw text typed at the ``Model type [b,f]`` prompt.

    Returns:
        str: ``'best-u2net-duts-msra.safetensors'`` when the answer is ``b``
        (surrounding whitespace and letter case are ignored), otherwise
        ``'u2net-duts-msra.safetensors'``. Any answer other than ``b`` keeps
        selecting the second file, exactly as the inline expression did.
    """
    # Normalise first so that 'B' or ' b' does not silently pick the other
    # checkpoint.
    if model_type.strip().lower() == 'b':
        return 'best-u2net-duts-msra.safetensors'
    return 'u2net-duts-msra.safetensors'


if __name__ == '__main__':
    # Script entry point: ask which checkpoint to score, load it, run the
    # evaluation split and print the mean loss.
    batch_size = 1
    model_type = input('Model type [b,f]: ')
    model_name = checkpoint_name(model_type)
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    model = U2Net().to(device)
    # NOTE(review): weights are loaded through the DataParallel wrapper, so
    # the checkpoint keys presumably carry the 'module.' prefix -- confirm
    # against the script that saved them.
    model = nn.DataParallel(model)
    load_model(model, f'results/{model_name}')
    loader = DataLoader(PASCALSDataset(split='eval'), batch_size=batch_size, shuffle=False)
    loss = eval(model, loader, loss_fn)
    print(f'Loss: {loss:.4f}')