File size: 2,262 Bytes
0efb5b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from models.resnet import resnet18
from models.openmax import OpenMax
from models.metamax import MetaMax
from train import GameDataset
from utils.data_stats import load_dataset_stats
from utils.eval_utils import evaluate_known_classes, evaluate_openmax, evaluate_metamax
import os
from pprint import pprint

def test_models():
    """Evaluate the trained ResNet and its OpenMax wrapper on the eval split.

    Loads the dataset normalization statistics, builds the evaluation
    ``DataLoader``, restores the ResNet-18 checkpoint and the serialized
    OpenMax object, then prints the results of ``evaluate_known_classes``
    and ``evaluate_openmax``. The MetaMax evaluation is currently disabled.

    Returns:
        None. Results are printed. If the OpenMax file cannot be loaded,
        the error is printed and the function returns early.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # On a CPU-only host, tensors that were saved from a GPU must be remapped
    # while deserializing, otherwise torch.load fails. When CUDA is available
    # the saved placement is kept, exactly as before.
    map_location = None if use_cuda else device

    # Load the dataset statistics used for input normalization.
    mean, std = load_dataset_stats()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Load the evaluation set.
    # NOTE(review): num_labels=21 here vs. num_classes=20 for the network;
    # presumably 20 known classes plus one "unknown" label -- confirm.
    test_dataset = GameDataset('jk_zfls/round0_eval', num_labels=21, transform=transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=400,
        shuffle=False,
        num_workers=4,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=use_cuda,
    )

    # Load the base model.
    model = resnet18(num_classes=20)
    checkpoint = torch.load('models/best_model_99.92_02.pth', map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Load the OpenMax model (MetaMax loading is currently disabled).
    # NOTE(review): this unpickles a whole Python object; only load files
    # from a trusted source.
    try:
        openmax = torch.load('models/best_openmax_94.71_01.pth', map_location=map_location)
        # metamax = torch.load('models/best_metamax.pth')
        print("Successfully loaded OpenMax model")
    except Exception as e:
        print(f"Error loading models: {e}")
        return

    # Evaluate the plain ResNet.
    print("\n=== Testing ResNet (Known Classes Only) ===")
    _, accuracy, errors = evaluate_known_classes(model, test_loader, torch.nn.CrossEntropyLoss(), device)
    print(f"Known Classes Accuracy: {accuracy:.2f}%")
    if errors:
        print("\nErrors in known classes:")
        pprint(errors)

    # Evaluate ResNet + OpenMax.
    print("\n=== Testing ResNet + OpenMax ===")
    evaluate_openmax(openmax, model, test_loader, device, multiplier=0.5, fraction=0.2, verbose=True)

    # Evaluate ResNet + MetaMax (disabled).
    # print("\n=== Testing ResNet + MetaMax ===")
    # evaluate_metamax(metamax, model, test_loader, device, threshold=0.5, verbose=True)

# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    test_models()