# ando55's picture
# Update solver.py
# 68a4358
import torch.optim as optim
import numpy as np
import torch
from torch.autograd import Variable
import random
from torch.nn.utils import clip_grad_norm
import copy
import os
import pickle
def _rows_to_batch_array(rows):
    """Pack per-sequence index arrays into a single NumPy array.

    Rows of equal length are stacked by ``np.array`` (a regular 2-D array).
    Rows of differing length are stored in a 1-D ``dtype=object`` array, because
    NumPy >= 1.24 raises ``ValueError`` when ``np.array`` is handed ragged input.
    """
    if len({len(row) for row in rows}) <= 1:
        return np.array(rows)
    packed = np.empty(len(rows), dtype=object)
    # Element-wise assignment keeps NumPy from trying to broadcast ragged rows.
    for position, row in enumerate(rows):
        packed[position] = row
    return packed


def get_decoder_index_XY(batchY):
    """Derive decoder start/end indices from per-sequence boundary labels.

    For each label sequence, the positions whose value equals 1 are the segment
    end indices. The matching segment start indices are 0 for the first segment
    and ``previous end + 1`` for every following segment.

    :param batchY: iterable of 1-D label sequences, like [0 0 1 0 0 0 0 1];
        each sequence may be a NumPy array or a plain list/tuple.
    :return: ``(returnX, returnY)`` where ``returnX[i]`` holds the start indices
        and ``returnY[i]`` the end indices of sequence ``i``. Both are 2-D arrays
        when every sequence has the same number of boundaries, otherwise 1-D
        object arrays of per-sequence index arrays.
    """
    starts_per_seq = []
    ends_per_seq = []
    for labels in batchY:
        # np.asarray makes the comparison element-wise even for plain lists
        # (a bare list compared with 1 is just the scalar False).
        ends = np.where(np.asarray(labels) == 1)[0]
        # With a single boundary ends[:-1] is empty, so this yields [0].
        starts = np.append([0], ends[:-1] + 1)
        starts_per_seq.append(starts)
        ends_per_seq.append(ends)
    return _rows_to_batch_array(starts_per_seq), _rows_to_batch_array(ends_per_seq)
def align_variable_numpy(X,maxL,paddingNumber):
    """Right-pad every sequence in ``X`` with ``paddingNumber`` up to ``maxL`` items.

    :param X: iterable of sized sequences.
    :param maxL: target length of each padded row.
    :param paddingNumber: value appended to rows shorter than ``maxL``.
    :return: ``np.array`` built from the padded rows. Rows already longer than
        ``maxL`` are copied unchanged (nothing is truncated).
    """
    padded_rows = [
        # A non-positive repeat count produces an empty list, so long rows pass through.
        list(row) + [paddingNumber] * (maxL - len(row))
        for row in X
    ]
    return np.array(padded_rows)
def sample_a_sorted_batch_from_numpy(numpyX,numpyY,batch_size,use_cuda):
    """Turn parallel input/label sequences into one padded batch.

    Despite the name, rows keep their input order: the previous implementation
    applied ``np.sort(np.argsort(lengths))``, which is always the identity
    permutation, so that step and the re-indexing it drove have been removed.
    The decoder index arrays are returned exactly as produced by
    ``get_decoder_index_XY`` instead of being rebuilt with ``np.array``, which
    fails on ragged rows with NumPy >= 1.24.

    :param numpyX: indexable collection of input id sequences.
    :param numpyY: indexable collection of label sequences; the first
        ``len(numpyX)`` entries are used.
    :param batch_size: unused; kept for interface compatibility.
    :param use_cuda: unused; kept for interface compatibility (the batch is
        not moved to another device here).
    :return: tuple ``(numpy_batch_x, batch_x, batch_y, index_decoder_X,
        index_decoder_Y, all_lens, maxL)`` where ``numpy_batch_x`` is the list of
        unpadded input copies, ``batch_x`` an int64 tensor padded with 2000001,
        ``batch_y`` a NumPy array padded with 2, ``all_lens`` the label-sequence
        lengths and ``maxL`` their maximum.
    :raises ValueError: from ``np.max`` when the batch is empty.
    """
    row_count = len(numpyX)
    # Deep copies keep the caller's data isolated from anything done to the batch.
    numpy_batch_x = [copy.deepcopy(numpyX[i]) for i in range(row_count)]
    label_rows = [copy.deepcopy(numpyY[i]) for i in range(row_count)]

    index_decoder_X, index_decoder_Y = get_decoder_index_XY(label_rows)

    all_lens = np.array([len(labels) for labels in label_rows])
    maxL = np.max(all_lens)

    padded_x = align_variable_numpy(numpy_batch_x, maxL, 2000001)
    batch_y = align_variable_numpy(label_rows, maxL, 2)
    # torch.autograd.Variable is a deprecated no-op wrapper; a plain tensor is equivalent.
    batch_x = torch.from_numpy(np.array(padded_x, dtype="int64"))
    return numpy_batch_x,batch_x,batch_y,index_decoder_X,index_decoder_Y,all_lens,maxL
class TrainSolver(object):
    """Hold a model together with its data/hyper-parameters and run inference.

    Only construction and the inference helpers are visible in this part of the
    file; the stored training settings are not used by the methods shown here.
    """

    def __init__(self, model,train_x,train_y,dev_x,dev_y,save_path,batch_size,eval_size,epoch, lr,lr_decay_epoch,weight_decay,use_cuda):
        """Store every constructor argument on the instance under the same name."""
        self.lr = lr
        self.model = model
        self.epoch = epoch
        self.train_x = train_x
        self.train_y = train_y
        self.use_cuda = use_cuda
        self.batch_size = batch_size
        self.lr_decay_epoch = lr_decay_epoch
        self.eval_size = eval_size
        self.dev_x, self.dev_y = dev_x, dev_y
        self.save_path = save_path
        self.weight_decay = weight_decay

    def get_batch_micro_metric(self,pre_b, ground_b, x,index2word, fukugen, nloop):
        """Rebuild text for a batch, cutting it at the extra predicted boundaries.

        For each sequence, the last reference boundary (last position labelled 1)
        is dropped from both the reference and the prediction. Every remaining
        predicted index that is not a reference boundary becomes a cut point.
        The pieces in ``fukugen[i]`` are concatenated in order and a new output
        string is started after each cut point.

        :param pre_b: per-sequence predicted boundary indices.
        :param ground_b: per-sequence label rows; each must contain at least one
            1, otherwise ``IndexError`` is raised.
        :param x: per-sequence token ids (tensor or array rows supporting
            ``.tolist()``).
        :param index2word: iterable of words; position ``n`` is the word for id ``n``.
        :param fukugen: per-sequence list of text pieces aligned with the tokens.
        :param nloop: unused; kept for interface compatibility.
        :return: flat list of strings covering all sequences of the batch.
        """
        tokendic = dict(enumerate(index2word))
        sents = []
        for row, cur_seq_y in enumerate(ground_b):
            pieces = fukugen[row]
            index_of_1 = np.where(np.asarray(cur_seq_y) == 1)[0]
            end_b = index_of_1[-1]
            index_pre = np.array(pre_b[row])
            reference = set(index_of_1[index_of_1 != end_b].tolist())
            predicted = index_pre[index_pre != end_b].tolist()
            # Predicted boundaries missing from the reference are where text is cut.
            cut_after = {k for k in predicted if k not in reference}
            # One bulk conversion instead of a tensor -> numpy round trip per token.
            tokens = [tokendic[int(token_id)] for token_id in x[row].tolist()]

            current = ""
            for position, (token, piece) in enumerate(zip(tokens, pieces)):
                # The original compared the (token, piece) tuple with "<pad>",
                # which could never match; the token itself is what must be tested.
                if token == "<pad>":
                    continue
                current += piece
                if position in cut_after:
                    sents.append(current)
                    current = ""
            # NOTE(review): the file's indentation was lost; flushing the remainder
            # once per sequence is the assumed original placement - confirm.
            sents.append(current)
        return sents

    def check_accuracy(self,data2X,data2Y,index2word, fukugen2):
        """Run the model over the first data set and return the rebuilt texts.

        Only index 0 of ``data2X``/``data2Y``/``fukugen2`` is processed. The data
        is walked in chunks of ``self.batch_size``; texts from every chunk are
        collected in order. Previously only a single chunk's texts reached the
        caller and an empty data set raised ``UnboundLocalError``.

        :param data2X: indexable whose first item is the collection of input sequences.
        :param data2Y: indexable whose first item is the collection of label sequences.
        :param index2word: forwarded to ``get_batch_micro_metric``.
        :param fukugen2: indexable whose first item holds the per-sequence text pieces.
        :return: list of strings; empty when there is no data.
        """
        nloop = 0
        dataY = data2Y[nloop]
        dataX = data2X[nloop]
        fukugen = fukugen2[nloop]
        total = len(dataY)

        output_texts = []
        for startN in range(0, total, self.batch_size):
            endN = min(startN + self.batch_size, total)
            fukuge = fukugen[startN:endN]
            numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
                dataX[startN:endN], dataY[startN:endN], None, self.use_cuda)
            batch_ave_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.model.predict(batch_x,index_decoder_Y,all_lens)
            # NOTE(review): accumulating across batches assumes callers want the
            # texts of the whole data set - confirm against the call site.
            output_texts.extend(
                self.get_batch_micro_metric(batch_boundary,batch_y,batch_x,index2word, fukuge, nloop))
        return output_texts