import torch import argparse import numpy as np from src.models.LSTM.model import Poetry_Model_lstm from src.datasets.dataloader import train_vec from src.utils.utils import make_cuda def parse_arguments(): # argument parsing parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting") parser.add_argument('--model', type=str, default='lstm', help="lstm/GRU/Seq2Seq/Transformer/GPT-2") parser.add_argument('--Word2Vec', default=True) parser.add_argument('--strict_dataset', default=False, help="strict dataset") parser.add_argument('--n_hidden', type=int, default=128) parser.add_argument('--save_path', type=str, default='save_models/lstm_50.pth') return parser.parse_args() def generate_poetry(model, head_string, w1, word_2_index, index_2_word): print("藏头诗生成中...., {}".format(head_string)) poem = "" # 以句子的每一个字为开头生成诗句 for head in head_string: if head not in word_2_index: print("抱歉,不能生成以{}开头的诗".format(head)) return sentence = head max_sent_len = 20 h_0 = torch.tensor(np.zeros((2, 1, args.n_hidden), dtype=np.float32)) c_0 = torch.tensor(np.zeros((2, 1, args.n_hidden), dtype=np.float32)) input_eval = word_2_index[head] for i in range(max_sent_len): if args.Word2Vec: word_embedding = torch.tensor(w1[input_eval][None][None]) else: word_embedding = torch.tensor([input_eval]).unsqueeze(dim=0) pre, (h_0, c_0) = model(word_embedding, h_0, c_0) char_generated = index_2_word[int(torch.argmax(pre))] if char_generated == '。': break # 以新生成的字为输入继续向下生成 input_eval = word_2_index[char_generated] sentence += char_generated poem += '\n' + sentence return poem def infer(model,string): args = parse_arguments() all_data, (w1, word_2_index, index_2_word) = train_vec() args.word_size, args.embedding_num = w1.shape # string = input("诗头:") # string = '自然语言' args.model=model if args.model == 'lstm': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) args.save_path = 'save_models/lstm_50.pth' elif args.model == 'GRU': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) args.save_path = 'save_models/GRU_50.pth' elif args.model == 'Seq2Seq': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) elif args.model == 'Transformer': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) elif args.model == 'GPT-2': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) else: print("Please choose a model!\n") model.load_state_dict(torch.load(args.save_path)) model = make_cuda(model) poem = generate_poetry(model, string, w1, word_2_index, index_2_word) return poem if __name__ == '__main__': args = parse_arguments() all_data, (w1, word_2_index, index_2_word) = train_vec() args.word_size, args.embedding_num = w1.shape # string = input("诗头:") string = '自然语言' if args.model == 'lstm': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) elif args.model == 'GRU': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) elif args.model == 'Seq2Seq': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) elif args.model == 'Transformer': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) elif args.model == 'GPT-2': model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec) else: print("Please choose a model!\n") model.load_state_dict(torch.load(args.save_path)) model = make_cuda(model) poem = generate_poetry(model, string, w1, word_2_index, index_2_word) print(poem)