# NOTE(review): removed non-Python paste artifacts ("Spaces:", "Runtime error")
# that preceded the source; they were copy/paste residue, not code.
# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from tensorboardX import SummaryWriter
# Weight initialization (default: Xavier)
def init_network(model, method='xavier', exclude='embedding', seed=123):
    """Initialize ``model``'s parameters in place.

    Parameters whose names contain ``exclude`` (e.g. a pretrained embedding)
    are left untouched.  Weights are initialized according to ``method``
    ('xavier' / 'kaiming' / anything else falls back to normal init);
    biases are zeroed.  Other parameters are skipped.

    NOTE(review): ``seed`` is accepted but never used here — confirm whether
    deterministic seeding was intended by the original author.
    """
    for param_name, param in model.named_parameters():
        if exclude in param_name:
            # Keep excluded parameters (e.g. pretrained embeddings) as-is.
            continue
        if 'weight' in param_name:
            if method == 'kaiming':
                nn.init.kaiming_normal_(param)
            elif method == 'xavier':
                nn.init.xavier_normal_(param)
            else:
                nn.init.normal_(param)
        elif 'bias' in param_name:
            nn.init.constant_(param, 0)
def train(config, model, train_iter, dev_iter, test_iter):
    """Train ``model`` with Adam, checkpointing on best dev loss.

    Every 100 batches the model is evaluated on ``dev_iter``; whenever the
    dev loss improves, the weights are saved to ``config.save_path``.
    Training stops early once ``config.require_improvement`` batches pass
    without a dev-loss improvement.  Finishes by running ``test`` on the
    best saved weights.  Metrics are logged to TensorBoard under
    ``config.log_path``.
    """
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    # Exponential LR decay (lr = gamma * lr per epoch), kept for reference:
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    total_batch = 0                # number of batches processed so far
    dev_best_loss = float('inf')   # best dev loss seen so far
    last_improve = 0               # batch index of the last dev-loss improvement
    flag = False                   # set True when early stopping triggers
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    try:
        for epoch in range(config.num_epochs):
            print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
            # scheduler.step()  # enable together with the scheduler above
            for trains, labels in train_iter:
                # Clear stale gradients before this batch's backward pass.
                optimizer.zero_grad()
                outputs = model(trains)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()
                if total_batch % 100 == 0:
                    # Report train/dev metrics every 100 batches.
                    true = labels.data.cpu()
                    predic = torch.max(outputs.data, 1)[1].cpu()
                    train_acc = metrics.accuracy_score(true, predic)
                    dev_acc, dev_loss = evaluate(config, model, dev_iter)
                    if dev_loss < dev_best_loss:
                        dev_best_loss = dev_loss
                        torch.save(model.state_dict(), config.save_path)
                        improve = '*'  # marks a new best checkpoint in the log
                        last_improve = total_batch
                    else:
                        improve = ''
                    time_dif = get_time_dif(start_time)
                    msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                    print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                    writer.add_scalar("loss/train", loss.item(), total_batch)
                    writer.add_scalar("loss/dev", dev_loss, total_batch)
                    writer.add_scalar("acc/train", train_acc, total_batch)
                    writer.add_scalar("acc/dev", dev_acc, total_batch)
                    # evaluate() put the model in eval mode; switch back.
                    model.train()
                total_batch += 1
                if total_batch - last_improve > config.require_improvement:
                    # Dev loss has not improved for too many batches: stop early.
                    print("No optimization for a long time, auto-stopping...")
                    flag = True
                    break
            if flag:
                break
    finally:
        # Flush TensorBoard logs even if training raised.
        writer.close()
    test(config, model, test_iter)
def test(config, model, test_iter):
    """Evaluate the best saved checkpoint on the test set and print a report."""
    # Restore the weights saved at the best dev loss during training.
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    begin = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(
        config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    print("Time usage:", get_time_dif(begin))
def evaluate(config, model, data_iter, test=False):
    """Compute accuracy and mean cross-entropy loss of ``model`` over ``data_iter``.

    Returns ``(acc, mean_loss)`` as plain floats; with ``test=True`` it
    additionally returns a per-class classification report (labelled with
    ``config.class_list``) and a confusion matrix.  ``config`` is only read
    in the ``test=True`` branch.
    """
    model.eval()
    loss_total = 0.0
    label_batches = []
    predict_batches = []
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            # .item() detaches to a Python float so no graph/tensors accumulate.
            loss_total += loss.item()
            label_batches.append(labels.data.cpu().numpy())
            predict_batches.append(torch.max(outputs.data, 1)[1].cpu().numpy())
    # Concatenate once instead of np.append per batch (avoids O(n^2) copying).
    labels_all = np.concatenate(label_batches) if label_batches else np.array([], dtype=int)
    predict_all = np.concatenate(predict_batches) if predict_batches else np.array([], dtype=int)
    # Direct NumPy accuracy — identical to metrics.accuracy_score on int labels.
    acc = float((labels_all == predict_all).mean()) if labels_all.size else 0.0
    # Guard against an empty iterator (mean loss of zero batches is 0.0).
    avg_loss = loss_total / max(len(data_iter), 1)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, avg_loss, report, confusion
    return acc, avg_loss