TCN_UL_acitivity / testing_interface.py
liesdillen's picture
Upload 4 files
da7cd93 verified
raw
history blame
No virus
2.78 kB
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from data_loader_interface import load_imu
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report, average_precision_score
from torch.utils.data import Dataset, DataLoader, TensorDataset, ConcatDataset
from einops import rearrange
from torch.optim.lr_scheduler import StepLR
def test(model, test_loader, output_file):
running_loss = 0.0
predlist = torch.zeros(0, dtype=torch.long, device='cpu')
lbllist = torch.zeros(0, dtype=torch.long, device='cpu')
number_of_nan_loss = 0
with torch.no_grad(), open(output_file, 'w') as f:
for batch_idx, (data) in enumerate(test_loader):
data = data.float().to('cpu')
outputs = model(data)
_, predicted = torch.max(outputs.data, 1)
predlist = torch.cat([predlist, predicted.view(-1).cpu()])
# lbllist = torch.cat([lbllist, labels.view(-1).cpu()])
for i, prediction in enumerate(predicted):
f.write(f"{prediction.item()}\n")
return running_loss / len(test_loader), predlist
class MyDataset_labeled(Dataset):
    """Map-style dataset serving input samples as float32 tensors.

    Despite the name, no labels are stored: ``__getitem__`` returns only the
    input sample at ``idx``.
    """

    def __init__(self, x):
        """Copy ``x`` into a float32 tensor.

        Args:
            x: Array-like of samples (NumPy array, tensor or nested sequence);
                the first axis indexes samples.
        """
        # np.array copies, so the dataset never aliases the caller's buffer.
        self.x_data = torch.from_numpy(np.array(x)).to(torch.float)
        # Read the length from the converted tensor so inputs without a
        # ``.shape`` attribute (e.g. plain lists) are accepted as well.
        self.len = self.x_data.shape[0]

    def __getitem__(self, idx):
        """Return the sample at ``idx``."""
        return self.x_data[idx]

    def __len__(self):
        """Return the number of samples."""
        return self.len
def model_defining(test_data_array, name_model, output_file):
    """Load a saved model, run it on the test data and write its predictions.

    Args:
        test_data_array: Passed unchanged to ``load_imu``; presumably a list
            of subject IDs such as ``['104']`` -- confirm against load_imu.
        name_model: Path of the saved model without the ``.pth`` suffix.
        output_file: Text file that receives one predicted class per line.

    Returns:
        ``(test_loss, predlist)`` exactly as returned by ``test``.
    """
    # NOTE(review): torch.load unpickles a full nn.Module, which can execute
    # arbitrary code from the file -- only load trusted checkpoints. On
    # torch >= 2.6 the default ``weights_only=True`` rejects full-module
    # pickles; pass ``weights_only=False`` there if loading fails.
    model = torch.load(f"{name_model}.pth", map_location=torch.device('cpu'))
    model.eval()

    test_data = load_imu(test_data_array)
    # Swap the last two axes ('n t c' -> 'n c t'); assumes load_imu returns
    # (samples, time, channels) and the model wants channels first -- confirm.
    test_data = rearrange(test_data, 'n t c -> n c t')
    test_dataset = MyDataset_labeled(test_data)
    # shuffle=False keeps predictions in input order; drop_last=False keeps
    # the final partial batch so every sample is predicted.
    test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, drop_last=False)
    print("Data loaded ...")

    test_loss, predlist = test(model, test_dataloader, output_file)
    return test_loss, predlist
def main():
    """Run the hard-coded model checkpoint on subject '104' and save predictions."""
    name_model = "./models_chadwick/training_3S_model_validation/S3_101_102_103_validation_epoch_10"
    test_subjects = ['104']
    output_file = "S3_101_102_103_validation_epoch_10_tested_104_predicted.txt"
    # model_defining accepts exactly (test_data_array, name_model, output_file);
    # passing extra positional arguments raises TypeError before any work is done.
    model_defining(test_subjects, name_model, output_file)


if __name__ == "__main__":
    main()