|
import torch |
|
import torch.nn as nn |
|
|
|
from textCNN_data import textCNN_data, textCNN_param, dataLoader_param |
|
from torch.utils.data import DataLoader |
|
from multihead_attention import my_model |
|
import os |
|
from torch.nn import functional as F |
|
|
|
# Restrict CUDA to device index 0 for this process (standard
# CUDA_VISIBLE_DEVICES semantics).
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
|
|
|
|
|
def validation(model, val_dataLoader, device):
    """Return the classification accuracy of ``model`` over ``val_dataLoader``.

    The model is put in eval mode and every batch is evaluated with gradient
    tracking disabled. Predictions are the argmax over dim 1 of the model
    output and are compared element-wise with the batch labels.

    Args:
        model: Callable module invoked as ``model(sentences)``.
        val_dataLoader: Iterable yielding ``(clas, sentences)`` tensor pairs,
            labels first.
        device: Device that labels and inputs are moved to before use.

    Returns:
        A 0-dim float tensor holding ``correct / total``. When the loader
        yields no samples the result is ``0.0`` rather than a division error.

    Raises:
        IndexError: Re-raised after the offending batch has been printed, so
            the caller sees the real traceback and a non-zero exit status.
    """
    model.eval()
    total = 0
    # Keep the running count as a tensor on ``device`` so no per-batch
    # host/device transfer is needed.
    correct = torch.zeros((), dtype=torch.long, device=device)
    with torch.no_grad():
        for i, (clas, sentences) in enumerate(val_dataLoader):
            try:
                out = model(sentences.to(device))
                pred = torch.argmax(out, dim=1)
                correct += (pred == clas.to(device)).sum()
                total += clas.size()[0]
            except IndexError:
                # Dump the batch that triggered the failure, then propagate
                # the original exception instead of exiting silently.
                print(i)
                print('clas', clas)
                print('clas size', clas.size())
                print('sentence', sentences)
                print('sentences size', sentences.size())
                raise

    if total == 0:
        # Nothing was evaluated; avoid 0 / 0.
        return torch.zeros((), device=device)

    acc = correct / total
    return acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Choose the compute device. The GPU name is only queried once CUDA is known
# to be available, because torch.cuda.get_device_name() raises on machines
# without a usable CUDA device.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())
    device = 'cuda:0'
else:
    device = 'cpu'
|
|
|
|
|
|
|
|
|
|
|
print('init dataset...')
trainDataFile, valDataFile = 'traindata_vec.txt', 'devdata_vec.txt'

# Training split: batches are drawn in shuffled order.
train_dataset = textCNN_data(trainDataFile)
train_dataLoader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=dataLoader_param['batch_size'],
)

# Validation split: same batch size, fixed order.
val_dataset = textCNN_data(valDataFile)
val_dataLoader = DataLoader(
    dataset=val_dataset,
    shuffle=False,
    batch_size=dataLoader_param['batch_size'],
)
|
|
|
if __name__ == "__main__":
    # Train for 100 epochs, validate after each epoch, and keep the weights
    # of the best-scoring epoch in "model.bin".

    # Apply the seed to torch's global random number generators; it was
    # previously assigned but never used.
    seed = 3407
    torch.manual_seed(seed)

    print('init net...')
    model = my_model()
    model.to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    criterion = nn.CrossEntropyLoss()

    print("training...")

    best_dev_acc = 0

    for epoch in range(100):
        model.train()
        for i, (clas, sentences) in enumerate(train_dataLoader):
            out = model(sentences.to(device))
            try:
                loss = criterion(out, clas.to(device))
            except Exception:
                # Show the tensors that broke the loss computation, then
                # re-raise: continuing would call backward() on an undefined
                # or stale loss.
                print(out.size(), out)
                print(clas.size(), clas)
                raise
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 10 == 0:
                print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())

        # End of epoch: score the model on the validation set.
        model.eval()
        dev_acc = validation(model=model, val_dataLoader=val_dataLoader,
                             device=device)

        # Checkpoint only when validation accuracy improves.
        if best_dev_acc < dev_acc:
            best_dev_acc = dev_acc
            print("save model...")
            torch.save(model.state_dict(), "model.bin")
            print("epoch:", epoch + 1, "step:", i + 1, "loss:", loss.item())
        print("best dev acc %.4f dev acc %.4f" % (best_dev_acc, dev_acc))
|
|