# Source snapshot: revision 0efb5b8 (2,262 bytes).
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from models.resnet import resnet18
from models.openmax import OpenMax
from models.metamax import MetaMax
from train import GameDataset
from utils.data_stats import load_dataset_stats
from utils.eval_utils import evaluate_known_classes, evaluate_openmax, evaluate_metamax
import os
from pprint import pprint
def test_models(
    *,
    data_dir='jk_zfls/round0_eval',
    checkpoint_path='models/best_model_99.92_02.pth',
    openmax_path='models/best_openmax_94.71_01.pth',
    batch_size=400,
    num_workers=4,
):
    """Evaluate the ResNet-18 classifier alone and combined with OpenMax.

    Builds the evaluation loader via ``GameDataset``, restores a 20-class
    ResNet-18 from ``checkpoint_path`` and a pickled OpenMax object from
    ``openmax_path``. It then prints known-class accuracy (plus any errors
    reported by ``evaluate_known_classes``) and runs ``evaluate_openmax`` in
    verbose mode.

    All arguments are keyword-only and default to the values this script has
    always used, so a plain ``test_models()`` call behaves as before.

    Args:
        data_dir: Evaluation split passed to ``GameDataset``.
        checkpoint_path: Checkpoint file whose dict holds a
            ``'model_state_dict'`` entry for the ResNet-18.
        openmax_path: File containing the whole pickled OpenMax object.
        batch_size: Evaluation batch size.
        num_workers: Number of DataLoader worker processes.

    Returns:
        None. If the OpenMax object cannot be loaded, the error is printed
        and the function returns before running any evaluation.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # Normalize with the statistics returned by load_dataset_stats().
    mean, std = load_dataset_stats()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Evaluation split. NOTE(review): num_labels=21 vs. num_classes=20 below
    # presumably means 20 known classes + 1 "unknown" label -- TODO confirm.
    test_dataset = GameDataset(data_dir, num_labels=21, transform=transform)
    # Pinned memory only helps host->GPU copies; on CPU it merely warns.
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_cuda,
    )

    # Base classifier. map_location lets a GPU-saved checkpoint load on a
    # CPU-only machine; the weights are copied into `model` and moved to
    # `device` right after, so this does not change the GPU path.
    model = resnet18(num_classes=20)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # OpenMax is stored as a full pickled object (not a state dict), so its
    # class must be importable here. Remap storages to CPU only when no GPU
    # is available; on GPU the saved placement is kept as before.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # trusted files. On PyTorch >= 2.6 the default weights_only=True rejects
    # pickled objects; pass weights_only=False for this trusted file if the
    # load below reports an UnpicklingError.
    try:
        openmax = torch.load(openmax_path, map_location=None if use_cuda else "cpu")
        # metamax = torch.load('models/best_metamax.pth')
    except Exception as e:
        print(f"Error loading models: {e}")
        return
    else:
        print("Successfully loaded OpenMax model")

    # Base ResNet on the known classes only.
    print("\n=== Testing ResNet (Known Classes Only) ===")
    _, accuracy, errors = evaluate_known_classes(
        model, test_loader, torch.nn.CrossEntropyLoss(), device
    )
    print(f"Known Classes Accuracy: {accuracy:.2f}%")
    if errors:
        print("\nErrors in known classes:")
        pprint(errors)

    # ResNet + OpenMax.
    print("\n=== Testing ResNet + OpenMax ===")
    evaluate_openmax(openmax, model, test_loader, device,
                     multiplier=0.5, fraction=0.2, verbose=True)

    # ResNet + MetaMax evaluation is currently disabled; re-enable together
    # with the metamax load above.
    # print("\n=== Testing ResNet + MetaMax ===")
    # evaluate_metamax(metamax, model, test_loader, device, threshold=0.5, verbose=True)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    test_models()
|