import os
import pickle
import random

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.nn.functional import one_hot, softmax
from torch.utils.data import Dataset, DataLoader

from read_graph import get_graph

# Ratio for teacher forcing during seq2seq training (not used in this file).
teacher_forcing_ratio = 0.5

def get_data_set(mode="train"):
    # Load the pickled normal (NT) and abnormal (AT) trajectory sets.
    NT, AT = None, None
    with open('dataset/NT.pkl', 'rb') as f1:
        NT = pickle.load(f1)
    with open('dataset/AT.pkl', 'rb') as f2:
        AT = pickle.load(f2)

    # The original script defined this helper twice; the second, guarded
    # version shadowed the first, so only that one is kept here.
    def onehot_encode(char, vocab):
        # One-hot encode a node id against the vocabulary; the padding
        # value 0 maps to the all-zero vector.
        encoded = [0 for _ in range(len(vocab))]
        if char != 0:
            encoded[vocab.index(char)] = 1
        return encoded

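    # Quick sanity checks for the encoder (added examples, not part of the
    # original script): the padding value 0 yields an all-zero vector.
    assert onehot_encode('b', ['a', 'b', 'c']) == [0, 1, 0]
    assert onehot_encode(0, ['a', 'b', 'c']) == [0, 0, 0]
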
    # Build the node vocabulary from the graph and cache it to nodes.pkl;
    # the cached copy is immediately reloaded for use below.
    G_new = get_graph()
    voc = list(G_new)
    with open('nodes.pkl', 'wb') as f:
        pickle.dump(voc, f, pickle.HIGHEST_PROTOCOL)

    with open('nodes.pkl', 'rb') as f:
        voc = pickle.load(f)

    # Special vocabulary entries: 0 is the padding value; 's' and 'e'
    # serve as start- and end-of-sequence markers.
    voc.append(0)
    voc.append('s')
    voc.append('e')

    total_word_count = len(voc)

    # Assemble samples with binary labels: normal trajectories (NT) get
    # label 1; abnormal trajectories (AT) get label 0 and are included
    # only outside training mode.
    samples = []
    labels = []
    if mode == "train":
        for tr in NT:
            samples.append(tr)
            labels.append(1)
    else:
        for tr in NT:
            samples.append(tr)
            labels.append(1)
        for tr in AT:
            samples.append(tr)
            labels.append(0)

    def padding(x, max_length):
        # Truncate trajectories longer than max_length; right-pad shorter
        # ones with [0, 0] placeholder points (node id 0, timestamp 0).
        if len(x) > max_length:
            text = x[:max_length]
        else:
            text = x + [[0, 0]] * (max_length - len(x))
        return text

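    # Quick demonstrations of the padder (added examples, not part of the
    # original script): short inputs are padded, long ones truncated.
    assert padding([[5, 1]], 3) == [[5, 1], [0, 0], [0, 0]]
    assert padding([[5, 1], [7, 2]], 1) == [[5, 1]]
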
    # Pad every trajectory to the length of the longest one (at least 10).
    max_len = 10
    for tr in samples:
        max_len = max(max_len, len(tr))

    samples_padded = []
    for tr in samples:
        samples_padded.append(padding(tr, max_len))

    # Encode each padded trajectory both as a sequence of one-hot vectors
    # and as a sequence of vocabulary indices. The temporal component
    # pt[1] is read but not used in either representation.
    samples_one_hot = []
    samples_index = []
    for tr in samples_padded:
        tr_rep = []
        tr_rep_index = []
        for pt in tr:
            spatial = onehot_encode(pt[0], voc)
            temporal = int(pt[1])  # currently unused
            tr_rep.append(spatial)
            tr_rep_index.append(voc.index(pt[0]))
        samples_one_hot.append(tr_rep)
        samples_index.append(tr_rep_index)

    sampletensor = torch.Tensor(samples_one_hot)
    # Index sequences are typically fed to nn.Embedding or a loss that
    # expects class indices, both of which require an integral dtype.
    sampletensor_index = torch.LongTensor(samples_index)
    labeltensor = torch.Tensor(labels)

    return sampletensor, sampletensor_index, labeltensor, max_len

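# Example call (added note, not part of the original script): the "train"
# split returns only normal trajectories; any other mode mixes in the
# abnormal ones.
#   X, X_idx, y, L = get_data_set("test")
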
# Select the compute device, preferring the first GPU when available.
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda:0")
    torch.cuda.set_device(0)
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    print("Working on GPU")
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")

if __name__ == '__main__':
    sampletensor, sampletensor_index, labeltensor, max_len = get_data_set("train")

    batch_size = 2
    train_set = data.TensorDataset(sampletensor, sampletensor_index, labeltensor)
    # shuffle=False keeps trajectories in file order; enable shuffling if
    # this loader is used for actual training.
    train_iter = data.DataLoader(train_set, batch_size, shuffle=False, drop_last=False)
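
    # Minimal smoke test (an added sketch, not part of the original script):
    # pull one batch to confirm the tensor shapes before wiring up a model.
    for x_onehot, x_index, y in train_iter:
        # x_onehot: (batch, max_len, vocab_size) float one-hot sequences
        # x_index:  (batch, max_len) integer node indices
        # y:        (batch,) binary normal/abnormal labels
        print(x_onehot.shape, x_index.shape, y.shape)
        break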