Spaces:
Runtime error
Runtime error
# coding: UTF-8
import time | |
import torch | |
import numpy as np | |
from train_eval import train, init_network | |
from importlib import import_module | |
import argparse | |
def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool(text)``, which is True
    for every non-empty string -- so ``--word False`` would enable word mode.
    This converter interprets the text instead.

    Args:
        value: The raw option value (a string from the command line, or an
            actual bool).

    Returns:
        True for true-like spellings, False for false-like spellings or an
        empty string.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') used to produce.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (True/False), got {!r}'.format(value))


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
# type=_str2bool (not bool) so that "--word False" really yields False.
parser.add_argument('--word', default=False, type=_str2bool, help='True for word, False for char')
args = parser.parse_args()

if __name__ == '__main__':
    # Script entry point: load the data, build the selected model, train it.
    dataset = 'THUCNews'  # dataset name handed to the model Config
    model_name = args.model  # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer

    # FastText uses its own data utilities and always a random embedding.
    is_fasttext = model_name == 'FastText'
    if is_fasttext:
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    # Embedding options -- Sogou News: embedding_SougouNews.npz,
    # Tencent: embedding_Tencent.npz, random initialisation: random.
    if is_fasttext or args.embedding == 'random':
        embedding = 'random'
    else:
        embedding = 'embedding_SougouNews.npz'

    # The model name selects the module under the `models` package.
    model_module = import_module('models.{}'.format(model_name))
    config = model_module.Config(dataset, embedding)

    # Fix every random seed so that each run gives the same result.
    seed = 1
    for set_seed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        set_seed(seed)
    torch.backends.cudnn.deterministic = True

    load_started = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    # Iterators are built in train, dev, test order.
    train_iter, dev_iter, test_iter = [
        build_iterator(split, config)
        for split in (train_data, dev_data, test_data)
    ]
    print("Time usage:", get_time_dif(load_started))

    # train
    config.n_vocab = len(vocab)
    model = model_module.Model(config).to(config.device)
    # Every model except the Transformer goes through init_network.
    if model_name != 'Transformer':
        init_network(model)
    # NOTE: prints the bound `parameters` method itself; it is not called.
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter)