Spaces:
Runtime error
Runtime error
import torch | |
import matplotlib.pyplot as plt | |
from tqdm.autonotebook import tqdm | |
import pywt | |
import os | |
def display_eval(epoch, epochs, tlength, global_step, tcorrect, tsamples, t_valid_samples, average_train_loss, average_valid_loss, total_acc_val): | |
tqdm.write( | |
f'Epoch: [{epoch + 1}/{epochs}], Step [{global_step}/{epochs*tlength}] | Train Loss: {average_train_loss: .3f} \ | |
| Train Accuracy: {tcorrect / tsamples: .3f} \ | |
| Val Loss: {average_valid_loss: .3f} \ | |
| Val Accuracy: {total_acc_val / t_valid_samples: .3f}') | |
def save_model(model, optimizer, valid_loss, epoch, path='model.pt'): | |
torch.save({'valid_loss': valid_loss, | |
'model_state_dict': model.state_dict(), | |
'epoch': epoch + 1, | |
'optimizer': optimizer.state_dict() | |
}, path) | |
tqdm.write(f'Model saved to ==> {path}') | |
def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'): | |
torch.save({'train_loss_list': train_loss_list, | |
'valid_loss_list': valid_loss_list, | |
'global_steps_list': global_steps_list, | |
}, path) | |
def plot_losses(metrics_save_name='metrics', save_dir='./'): | |
path = f'{save_dir}metrics_{metrics_save_name}.pt' | |
state = torch.load(path) | |
train_loss_list = state['train_loss_list'] | |
valid_loss_list = state['valid_loss_list'] | |
global_steps_list = state['global_steps_list'] | |
plt.plot(global_steps_list, train_loss_list, label='Train') | |
plt.plot(global_steps_list, valid_loss_list, label='Valid') | |
plt.xlabel('Global Steps') | |
plt.ylabel('Loss') | |
plt.legend() | |
plt.show() | |
def train_RNN(epochs, train_loader, valid_loader, model, loss_fn, optimizer, eval_every=0.25, best_valid_loss=float("Inf"), device='cuda', model_save_name='', save_dir='./'): | |
model.train() | |
running_loss = 0.0 | |
valid_running_loss = 0.0 | |
global_step = 0 | |
train_loss_list = [] | |
valid_loss_list = [] | |
global_steps_list = [] | |
wavelet = 'db4' | |
level = 3 | |
for epoch in tqdm(range(epochs)): | |
running_loss = 0.0 | |
t_correct = 0 | |
t_samples = 0 | |
for images, labels, notes in train_loader: | |
optimizer.zero_grad() | |
coeffs = pywt.wavedec(images, wavelet, level=level, axis=1) | |
threshold = 0.1 * \ | |
torch.median(torch.abs(torch.from_numpy(coeffs[-1]))) | |
denoised_coeffs = [pywt.threshold( | |
data=c, mode='hard', value=threshold) for c in coeffs] | |
images = pywt.waverec(denoised_coeffs, wavelet, axis=1) | |
images = torch.tensor(images).float().to(device) | |
labels = labels.to(device) | |
notes = notes.to(device) | |
output = model(images, notes) | |
loss = loss_fn(output, labels.float()) | |
running_loss += loss.item()*len(labels) | |
loss.backward() | |
global_step += 1*len(images) | |
optimizer.step() | |
values, indices = torch.max(output, dim=1) | |
t_correct += sum(1 for s, i in enumerate(indices) | |
if labels[s][i] == 1) | |
t_samples += len(indices) | |
if (global_step % (int(eval_every*len(train_loader.dataset)))) < train_loader.batch_size: | |
model.eval() | |
valid_running_loss = 0.0 | |
total_acc_val = 0 | |
with torch.no_grad(): | |
for images, labels, notes in valid_loader: | |
coeffs = pywt.wavedec( | |
images, wavelet, level=level, axis=1) | |
threshold = 0.1 * \ | |
torch.median( | |
torch.abs(torch.from_numpy(coeffs[-1]))) | |
denoised_coeffs = [pywt.threshold( | |
data=c, mode='hard', value=threshold) for c in coeffs] | |
images = pywt.waverec(denoised_coeffs, wavelet, axis=1) | |
images = torch.tensor(images).float().to(device) | |
labels = labels.to(device) | |
notes = notes.to(device) | |
output = model(images, notes) | |
loss = loss_fn(output, labels.float()).item() | |
valid_running_loss += loss*len(images) | |
values, indices = torch.max(output, dim=1) | |
total_acc_val += sum(1 for s, | |
i in enumerate(indices) if labels[s][i] == 1) | |
# evaluation | |
average_train_loss = running_loss / t_samples | |
average_valid_loss = valid_running_loss / \ | |
len(valid_loader.dataset) | |
train_loss_list.append(average_train_loss) | |
valid_loss_list.append(average_valid_loss) | |
global_steps_list.append(global_step) | |
display_eval(epoch, epochs, len(train_loader.dataset), global_step, t_correct, t_samples, len( | |
valid_loader.dataset), average_train_loss, average_valid_loss, total_acc_val) | |
# resetting running values | |
model.train() | |
if best_valid_loss > average_valid_loss: | |
best_valid_loss = average_valid_loss | |
save_model(model, optimizer, best_valid_loss, epoch, | |
path=f'{save_dir}model_{model_save_name}.pt') | |
save_metrics(train_loss_list, valid_loss_list, | |
global_steps_list, path=f'{save_dir}metrics_{model_save_name}.pt') | |
save_metrics(train_loss_list, valid_loss_list, global_steps_list, | |
path=f'{save_dir}metrics_{model_save_name}.pt') | |
print("Training complete.") | |
return model | |
def evaluate_RNN(model, test_loader, device="cuda"): | |
model.eval() | |
y_pred = [] | |
y_true = [] | |
wavelet = 'db4' | |
level = 3 | |
total_acc_test = 0 | |
with torch.no_grad(): | |
for images, labels, notes in test_loader: | |
coeffs = pywt.wavedec(images, wavelet, level=level, axis=1) | |
threshold = 0.1 * \ | |
torch.median(torch.abs(torch.from_numpy(coeffs[-1]))) | |
denoised_coeffs = [pywt.threshold( | |
data=c, mode='hard', value=threshold) for c in coeffs] | |
images = pywt.waverec(denoised_coeffs, wavelet, axis=1) | |
images = torch.tensor(images).float().to(device) | |
labels = labels.to(device) | |
notes = notes.to(device) | |
output = model(images, notes) | |
values, indices = torch.max(output, dim=1) | |
y_pred.extend(indices.tolist()) | |
y_true.extend(labels.tolist()) | |
total_acc_test += sum(1 for s, | |
i in enumerate(indices) if labels[s][i] == 1) | |
test_accuracy = total_acc_test / len(test_loader.dataset) | |
print(f'Test Accuracy: {test_accuracy: .3f}') | |
return test_accuracy | |
def rename_with_acc(save_name, save_dir, acc): | |
acc = round(acc*100) | |
# Rename model | |
new_model_name = f'{save_dir}model_{save_name}_acc_{acc}.pt' | |
new_metrics_name = f'{save_dir}metrics_{save_name}_acc_{acc}.pt' | |
if os.path.isfile(new_model_name): | |
os.remove(new_model_name) | |
if os.path.isfile(new_metrics_name): | |
os.remove(new_metrics_name) | |
os.rename(f'{save_dir}model_{save_name}.pt', | |
f'{save_dir}model_{save_name}_acc_{acc}.pt') | |
# Rename metrics | |
os.rename(f'{save_dir}metrics_{save_name}.pt', | |
f'{save_dir}metrics_{save_name}_acc_{acc}.pt') | |