import numpy as np
import torch
from torch import nn

from data_loader import DataLoader
from helper import ValTest
from modality_lstm import ModalityLSTM
batch_size = 32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_on_gpu = torch.cuda.is_available()

# Model hyperparameters
output_size = 5
hidden_dim = 128
trip_dim = 7  # features per trajectory point
n_layers = 2
drop_prob = 0.2

net = ModalityLSTM(trip_dim, output_size, batch_size, hidden_dim, n_layers,
                   train_on_gpu, drop_prob, lstm_drop_prob=0.2)

lr = 0.001
# Padded time steps are labelled -1 below, so the loss skips them via ignore_index
loss_function = nn.CrossEntropyLoss(ignore_index=-1)
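# Illustrative only: how ignore_index=-1 masks padded steps. With four time
# steps of which the last two are padding, the loss averages over the first
# two targets and ignores the rest:
#
#   logits = torch.randn(4, output_size)
#   targets = torch.tensor([2, 0, -1, -1])
#   loss_function(logits, targets)  # the -1 entries contribute nothing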
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

epochs = 6
print_every = 5
log_every = 1
evaluate_every = 100
clip = 0.2  # max norm for gradient clipping
if train_on_gpu:
    net.cuda()
net.train()

dl = DataLoader(batchsize=batch_size, read_from_pickle=True)
dl.prepare_data()
def pad_trajs(trajs, lengths):
    """Pad every trajectory with rows of -1 up to the longest length.

    Assumes the batch is sorted by length in descending order, so lengths[0]
    is the maximum. Padding rows use -1 to match the loss's ignore_index.
    """
    for elem in trajs:
        while len(elem) < lengths[0]:
            elem.append([-1] * trip_dim)
    return trajs
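
# Illustrative sketch of pad_trajs on a toy batch (values made up): two
# trajectories of lengths 3 and 1 come back both with length 3, the shorter
# one extended with rows of -1:
#
#   batch = [[[0.1] * trip_dim, [0.2] * trip_dim, [0.3] * trip_dim],
#            [[0.5] * trip_dim]]
#   padded = pad_trajs(batch, [3, 1])
#   # padded[1] == [[0.5] * trip_dim, [-1] * trip_dim, [-1] * trip_dim]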
losses, avg_losses = [], []

validator = ValTest(dl.val_batches, net, trip_dim, batch_size, device,
                    loss_function, output_size, dl.get_val_size())
test = ValTest(dl.test_batches, net, trip_dim, batch_size, device,
               loss_function, output_size, dl.get_test_size())
for e in range(1, epochs + 1):
    print("epoch", e)
    # NOTE: `hidden` is never passed to net(...) below; this assumes
    # ModalityLSTM manages its own hidden state internally.
    hidden = net.init_hidden()
    counter = 0
    torch.cuda.empty_cache()
    for train_sorted, labels_sorted in dl.batches():
        counter += 1
        lengths = [len(x) for x in train_sorted]
        print("Lengths are", lengths)
        print("Sum of lengths", sum(lengths))

        # Pad all trajectories in the batch to the longest one
        train_sorted = pad_trajs(train_sorted, lengths)
        # np.float was removed in NumPy 1.24; float32 matches torch's
        # default parameter dtype
        X = np.asarray(train_sorted, dtype=np.float32)
        input_tensor = torch.from_numpy(X)
        print("Input tensor is", input_tensor.shape)
        input_tensor = input_tensor.to(device)

        net.zero_grad()
        output, max_padding_for_this_batch = net(input_tensor, lengths)
        print("Output is", output.shape)
        # Pad the label sequences with -1 (ignored by the loss) to the same length
        for labelz in labels_sorted:
            while len(labelz) < max_padding_for_this_batch:
                labelz.append(-1)
        labels_for_loss = (torch.tensor(labels_sorted)
                           .view(max_padding_for_this_batch * batch_size)
                           .long()
                           .to(device))
        print("Labels for loss is", len(labels_for_loss))

        loss = loss_function(
            output.view(max_padding_for_this_batch * batch_size, -1),
            labels_for_loss)
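        # Why the views line up (under the assumption that output comes back
        # batch-first, i.e. (batch_size, max_padding_for_this_batch, output_size)):
        # both the logits and the targets are flattened to length B * T in the
        # same batch-major order, so logit row i matches target i.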
        loss.backward()
        # Clip gradients to mitigate exploding gradients in the LSTM
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

        if counter % log_every == 0:
            losses.append(loss.item())
        if counter % print_every == 0:
            # Average over at most the last 50 logged losses
            recent = losses[-50:]
            avg_losses.append(sum(recent) / len(recent))
            print(f'Epoch: {e:2d}. Batch {counter:d} of '
                  f'{int(dl.get_train_size() / batch_size):d}. '
                  f'Avg loss: {avg_losses[-1]:f} Loss: {loss.item():.4f}')
        if counter % evaluate_every == 0:
            validator.run()
    torch.save(net.state_dict(), "Model_Weights")

print("Testing")
test.run()
torch.save(net.state_dict(), "Model_Weights")
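
# A minimal sketch of reloading the saved weights for inference (assumes the
# same constructor arguments as above):
#
#   net = ModalityLSTM(trip_dim, output_size, batch_size, hidden_dim, n_layers,
#                      train_on_gpu, drop_prob, lstm_drop_prob=0.2)
#   net.load_state_dict(torch.load("Model_Weights", map_location=device))
#   net.eval()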