# sentiment/word2tensor.py
import pickle

import torch
import torch.utils.data as data

teacher_forcing_ratio = 0.5  # fraction of steps to use teacher forcing; unused in this file

def get_data_set(mode="train"):
    # Load the normal (NT) and anomalous (AT) trajectory sets.
    NT, AT = None, None
    with open(r'dataset/' + "NT.pkl", "rb") as f1:
        NT = pickle.load(f1)
    with open(r'dataset/' + "AT.pkl", "rb") as f2:
        AT = pickle.load(f2)
    # Build the vocabulary from the road-network graph nodes.
    from read_graph import get_graph
    G_new = get_graph()
    voc = list(G_new)
    # Persist the node list to nodes.pkl, then reload it.
    with open('nodes.pkl', 'wb') as f:
        pickle.dump(voc, f, pickle.HIGHEST_PROTOCOL)
    voc = None
    with open('nodes.pkl', 'rb') as f:
        voc = pickle.load(f)
    voc.append(0)    # padding symbol
    voc.append('s')  # START
    voc.append('e')  # EOF
    total_word_count = len(voc)
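    # After the three appends the vocabulary tail reads [..., 0, 's', 'e'],
    # so total_word_count is the number of graph nodes plus 3 special symbols.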
    # Trajectory labels
    samples = []
    labels = []
    if mode == "train":
        for tr in NT:
            samples.append(tr)
            labels.append(1)  # normal
    else:
        for tr in NT:
            samples.append(tr)
            labels.append(1)  # normal
        for tr in AT:
            samples.append(tr)
            labels.append(0)  # anomalous
    def padding(x, max_length):
        # Truncate to max_length, or right-pad with [0, 0] placeholder points.
        if len(x) > max_length:
            text = x[:max_length]
        else:
            text = x + [[0, 0]] * (max_length - len(x))
        return text
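    # Illustration with hypothetical values: padding([[5, 1.2]], 3) returns
    # [[5, 1.2], [0, 0], [0, 0]]; a trajectory longer than max_length is
    # truncated to its first max_length points.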
    # Find the longest trajectory
    max_len = 10
    for tr in samples:
        max_len = max(max_len, len(tr))
    samples_padded = []
    # Pad every trajectory to that length
    for tr in samples:
        tr = padding(tr, max_len)
        samples_padded.append(tr)
    # One-hot encoding
    def onehot_encode(char, vocab):
        # One-hot encode a single symbol; the padding symbol 0 maps to the
        # all-zero vector.
        encoded = [0 for _ in range(len(vocab))]
        if char != 0:
            encoded[vocab.index(char)] = 1
        return encoded
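    # Illustration with a hypothetical vocabulary: given voc = ['a', 'b', 0],
    # onehot_encode('b', voc) -> [0, 1, 0], while the padding symbol yields
    # onehot_encode(0, voc) -> [0, 0, 0].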
    samples_one_hot = []
    samples_index = []
    for tr in samples_padded:
        tr_rep = []
        tr_rep_index = []
        for pt in tr:
            spatial = onehot_encode(pt[0], voc)
            temporal = int(pt[1])  # timestamp component; computed but unused
            tr_rep.append(spatial)
            tr_rep_index.append(voc.index(pt[0]))
        samples_one_hot.append(tr_rep)
        samples_index.append(tr_rep_index)
    sampletensor = torch.Tensor(samples_one_hot)
    sampletensor_index = torch.Tensor(samples_index)
    labeltensor = torch.Tensor(labels)
    # print("sampletensor.shape", sampletensor.shape)
    # print("labeltensor.shape", labeltensor.shape)
    return sampletensor, sampletensor_index, labeltensor, max_len
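
# Returned shapes from get_data_set, with N trajectories and V = len(voc):
#   sampletensor:       (N, max_len, V)  one-hot spatial features
#   sampletensor_index: (N, max_len)     vocabulary index of each point
#   labeltensor:        (N,)             1 = normal, 0 = anomalous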

# Select the compute device once at import time.
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda:0")
    torch.cuda.set_device(0)
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    print("Working on GPU")
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")
import torch.nn as nn
# from VAE import AE,RNN

if __name__ == '__main__':
    sampletensor, sampletensor_index, labeltensor, max_len = get_data_set("train")
    batch_size = 2
    train_set = data.TensorDataset(sampletensor, sampletensor_index, labeltensor)
    train_iter = data.DataLoader(train_set, batch_size, shuffle=False, drop_last=False)
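    # A minimal sanity check (added for illustration, not part of the original
    # pipeline): pull one batch and confirm the shapes line up with the
    # construction above -- x: (batch_size, max_len, len(voc)),
    # x_label: (batch_size, max_len), y: (batch_size,).
    for x, x_label, y in train_iter:
        print("x:", x.shape, "x_label:", x_label.shape, "y:", y.shape)
        break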
    # rnn = RNN(input_size=2694, hidden_size=64, batch_size=2, maxlen=max_len)
    # loss = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adamax(rnn.parameters(), lr=1e-2)
    #
    # net = rnn.to(device)
    # num_epochs = 120
    #
    # h_hat_avg = None
    #
    # from tqdm import tqdm
    # for epoch in tqdm(range(num_epochs)):
    #     epoch_total_loss = 0
    #     for x, x_label, y in train_iter:
    #         # RNN forward pass
    #         xhat, kld, h_hat = net(x, x, "train", None)
    #         # print(xhat.shape)
    #         # print(x_label.shape)
    #         len_all = x_label.shape[0] * x_label.shape[1]
    #         xhat = xhat.reshape(len_all, -1)
    #         x_label = x_label.reshape(len_all).long().to(device)
    #         # print(x_label)
    #         # print("xhat", xhat.shape)
    #         # print("x_label", x_label.shape)
    #         l = loss(xhat, x_label)
    #         # print("reconstruction loss:", l, "kld loss:", kld)
    #         total_loss = l + kld
    #         epoch_total_loss += total_loss
    #         optimizer.zero_grad()
    #         total_loss.backward()
    #         optimizer.step()
    #         if epoch == num_epochs - 1:
    #             if h_hat_avg is None:
    #                 h_hat_avg = h_hat / torch.full(h_hat.shape, len(sampletensor)).to(device)
    #             else:
    #                 h_hat_avg += h_hat / torch.full(h_hat.shape, len(sampletensor)).to(device)
    #             print(">>> h_hat_avg", h_hat_avg.shape)
    #     print(" epoch_total_loss = ", epoch_total_loss)
    #
    # print("training ends")
    # torch.save(net, "LSTM-VAE.pth")
    # torch.save(h_hat_avg, 'h_hat_avg.pt')