Transport_Mode_Detector / classifier.py
agueroooooooooo's picture
First Commit
3d75a04
raw
history blame
3.08 kB
import torch
import numpy as np
from torch import nn
from data_loader import DataLoader
from helper import ValTest
from modality_lstm import ModalityLSTM
batch_size = 32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_on_gpu = True
output_size = 5
hidden_dim = 128
trip_dim = 7
n_layers = 2
drop_prob = 0.2
net = ModalityLSTM(trip_dim, output_size, batch_size, hidden_dim, n_layers, train_on_gpu, drop_prob, lstm_drop_prob=0.2)
lr=0.001
loss_function = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
epochs = 6
print_every = 5
log_every = 1
evaluate_every = 100
clip = 0.2 # gradient clipping
if train_on_gpu:
net.cuda()
net.train()
dl = DataLoader(batchsize=batch_size, read_from_pickle=True)
dl.prepare_data()
def pad_trajs(trajs, lengths):
for w, elem in enumerate(trajs):
while len(elem) < lengths[0]:
elem.append([-1] * trip_dim)
return trajs
losses, avg_losses = [], []
validator = ValTest(dl.val_batches, net, trip_dim, batch_size, device, loss_function, output_size, dl.get_val_size())
test = ValTest(dl.test_batches, net, trip_dim, batch_size, device, loss_function, output_size, dl.get_test_size())
for e in range(1,epochs+1):
print("epoch ",e)
hidden = net.init_hidden()
counter = 0
torch.cuda.empty_cache()
for train_sorted, labels_sorted in dl.batches():
counter += 1
lengths = [len(x) for x in train_sorted]
print("Lengths are ", lengths)
print("SUm of lengths",sum(lengths))
train_sorted = pad_trajs(train_sorted, lengths)
X = np.asarray(train_sorted, dtype=np.float)
input_tensor = torch.from_numpy(X)
print("Input tensor is ",input_tensor.shape)
input_tensor = input_tensor.to(device)
net.zero_grad()
output, max_padding_for_this_batch = net(input_tensor, lengths)
print("Output is",output.shape)
for labelz in labels_sorted:
while len(labelz) < max_padding_for_this_batch:
labelz.append(-1)
labels_for_loss = torch.tensor(labels_sorted).view(max_padding_for_this_batch * batch_size, -1).squeeze(
1).long().to(device)
print("Labels for loss is",len(labels_for_loss))
loss = loss_function(output.view(
max_padding_for_this_batch*batch_size, -1),
labels_for_loss)
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), clip)
optimizer.step()
if counter % log_every == 0:
losses.append(loss.item())
if counter % print_every == 0:
avg_losses.append(sum(losses[-50:]) / 50)
print(
f'Epoch: {e:2d}. {counter:d} of {int(dl.get_train_size() / batch_size):d} {avg_losses[len(avg_losses) - 1]:f} Loss: {loss.item():.4f}')
if counter % evaluate_every == 0:
validator.run()
torch.save(net.state_dict(),"Model_Wieghts")
print("Testing")
test.run()
torch.save(net.state_dict(),"Model_Wieghts")