# Spaces:
# Configuration error
# Configuration error
from src.utils.utils import make_cuda | |
from src.apis.train import train, evaluate | |
from src.models.LSTM.model import Poetry_Model_lstm | |
import argparse | |
import torch | |
import os | |
from src.datasets.dataloader import Poetry_Dataset, train_vec, get_poetry, split_text | |
from torch.utils.data import DataLoader | |
def _str2bool(value):
    """Parse a command-line boolean option value.

    ``argparse`` passes option values through as strings, so without a
    converter ``--Word2Vec False`` would produce the non-empty (and therefore
    truthy) string ``"False"``. Common spellings are accepted case-insensitively.

    Args:
        value: Raw option string (or an already-converted bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Build the experiment command-line interface and parse its options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as a plain ``parse_args()`` call does.

    Returns:
        argparse.Namespace holding the experiment settings.
    """
    parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
    # Restrict --model to the names main() knows how to build, so a typo fails
    # immediately instead of after data preprocessing.
    parser.add_argument('--model', type=str, default='lstm',
                        choices=['lstm', 'GRU', 'Seq2Seq', 'Transformer', 'GPT-2'],
                        help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
    # Boolean switches are converted explicitly so "False"/"0" really mean False.
    parser.add_argument('--Word2Vec', type=_str2bool, default=True)
    parser.add_argument('--Augmented_dataset', type=_str2bool, default=False, help="augmented dataset")
    parser.add_argument('--strict_dataset', type=_str2bool, default=False, help="strict dataset")
    parser.add_argument('--batch_size', type=int, default=64,
                        help="Specify batch size")
    parser.add_argument('--num_epochs', type=int, default=50,
                        help="Specify the number of epochs for competitive search")
    parser.add_argument('--log_step', type=int, default=100,
                        help="Specify log step size for training")
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help="Learning rate")
    parser.add_argument('--data', type=str, default='data/poetry.txt',
                        help="Path to the dataset")
    parser.add_argument('--Augmented_data', type=str, default='data/poetry_7.txt',
                        help="Path to the Augmented_dataset")
    parser.add_argument('--n_hidden', type=int, default=128)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--save_path', type=str, default='save_models/')
    return parser.parse_args(argv)
def main():
    """Prepare the poetry corpus, train the selected model, save it and evaluate it.

    Pipeline:
      1. Parse command-line options.
      2. Build ``data/split_poetry.txt`` / ``data/org_poetry.txt`` unless both exist.
      3. Obtain word vectors from ``train_vec`` and wrap the corpus in ``Poetry_Dataset``.
      4. Randomly split 80/20 into training and held-out sets.
      5. Train, save the weights under ``args.save_path``, then evaluate on the held-out split.

    Raises:
        SystemExit: If ``args.model`` is not one of the supported names.
    """
    args = parse_arguments()

    # Every supported name currently builds the LSTM model; the other entries
    # are placeholders until dedicated architectures are wired in.
    model_factories = {
        'lstm': Poetry_Model_lstm,
        'GRU': Poetry_Model_lstm,
        'Seq2Seq': Poetry_Model_lstm,
        'Transformer': Poetry_Model_lstm,
        'GPT-2': Poetry_Model_lstm,
    }
    model_cls = model_factories.get(args.model)
    if model_cls is None:
        # Fail fast, before the expensive preprocessing step. Previously this
        # path fell through and crashed on an unbound model variable.
        print("Please choose a model!\n")
        raise SystemExit(1)

    # The preprocessed corpus is cached on disk. To switch between the original
    # and the augmented data, delete 'split_poetry.txt' and 'org_poetry.txt'.
    if os.path.exists("data/split_poetry.txt") and os.path.exists("data/org_poetry.txt"):
        print("pre_file exit!")
    else:
        split_text(get_poetry(args))

    all_data, (w1, word_2_index, index_2_word) = train_vec()
    # w1 is 2-D; presumably (vocabulary size, embedding dim) per the attribute names.
    args.word_size, args.embedding_num = w1.shape
    dataset = Poetry_Dataset(w1, word_2_index, all_data, args.Word2Vec)

    train_size = int(len(dataset) * 0.8)
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    # The held-out loader uses a quarter of the training batch size. Clamp it to
    # 1 so a small --batch_size (< 4) does not yield an invalid batch size of 0.
    valid_batch_size = max(1, args.batch_size // 4)
    valid_data_loader = DataLoader(test_dataset, batch_size=valid_batch_size, shuffle=True)

    best_model = make_cuda(
        model_cls(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
    )
    best_model = train(args, best_model, train_data_loader)

    # Ensure the output directory exists so a long training run is not lost to
    # a missing folder. os.path.join also tolerates a save_path without a
    # trailing separator.
    if args.save_path:
        os.makedirs(args.save_path, exist_ok=True)
    save_file = os.path.join(args.save_path, args.model + '_' + str(args.num_epochs) + '.pth')
    torch.save(best_model.state_dict(), save_file)

    print('test evaluation:')
    evaluate(args, best_model, valid_data_loader)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()