File size: 4,361 Bytes
0a1104e
 
 
0666c69
0a1104e
 
0666c69
0a1104e
 
 
 
 
 
 
 
 
 
0666c69
0a1104e
 
 
 
0666c69
0a1104e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0666c69
0a1104e
 
 
0666c69
0a1104e
0666c69
0a1104e
 
0666c69
0a1104e
 
0666c69
0a1104e
 
 
0666c69
0a1104e
 
 
 
 
0666c69
0a1104e
0666c69
0a1104e
 
 
 
 
 
 
0666c69
 
 
0a1104e
 
 
 
 
 
 
0666c69
0a1104e
 
 
 
 
0666c69
0a1104e
0666c69
0a1104e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import argparse
import numpy as np
from src.models.LSTM.model import Poetry_Model_lstm
from src.datasets.dataloader import train_vec
from src.utils.utils import make_cuda


def parse_arguments():
    # argument parsing
    parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
    parser.add_argument('--model', type=str, default='lstm',
                        help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
    parser.add_argument('--Word2Vec', default=True)
    parser.add_argument('--strict_dataset', default=False, help="strict dataset")
    parser.add_argument('--n_hidden', type=int, default=128)

    parser.add_argument('--save_path', type=str, default='save_models/lstm_50.pth')

    return parser.parse_args()


def generate_poetry(model, head_string, w1, word_2_index, index_2_word,
                    n_hidden=None, use_word2vec=None):
    """Generate an acrostic poem: one line per character of *head_string*.

    Args:
        model: recurrent model called as ``model(embedding, h, c)`` and
            returning ``(logits, (h, c))``.
        head_string: characters used to start each line of the poem.
        w1: embedding matrix indexed by word id (rows are word vectors).
        word_2_index: mapping from character to vocabulary index.
        index_2_word: mapping from vocabulary index to character.
        n_hidden: hidden-state size; when None falls back to the module-level
            ``args.n_hidden`` (backward compatible with the old global use,
            which raised NameError when this module was imported elsewhere).
        use_word2vec: feed embedding vectors instead of raw indices; when
            None falls back to ``args.Word2Vec``.

    Returns:
        The poem as a single string (each line prefixed with '\\n'), or None
        if any head character is out of vocabulary.
    """
    if n_hidden is None:
        n_hidden = args.n_hidden
    if use_word2vec is None:
        use_word2vec = args.Word2Vec

    print("藏头诗生成中...., {}".format(head_string))
    poem = ""
    # 以句子的每一个字为开头生成诗句 (grow one sentence per head character)
    for head in head_string:
        if head not in word_2_index:
            # NOTE(review): preserves original behavior — the whole poem is
            # discarded (None returned) when any head character is unknown.
            print("抱歉,不能生成以{}开头的诗".format(head))
            return

        sentence = head
        max_sent_len = 20

        # Fresh zeroed recurrent state for each line: (num_layers=2, batch=1, n_hidden).
        h_0 = torch.tensor(np.zeros((2, 1, n_hidden), dtype=np.float32))
        c_0 = torch.tensor(np.zeros((2, 1, n_hidden), dtype=np.float32))

        input_eval = word_2_index[head]
        for _ in range(max_sent_len):
            if use_word2vec:
                # (embedding_dim,) -> (1, 1, embedding_dim)
                word_embedding = torch.tensor(w1[input_eval][None][None])
            else:
                # raw index path: (1,) -> (1, 1)
                word_embedding = torch.tensor([input_eval]).unsqueeze(dim=0)
            pre, (h_0, c_0) = model(word_embedding, h_0, c_0)
            char_generated = index_2_word[int(torch.argmax(pre))]

            # The full stop '。' terminates the line.
            if char_generated == '。':
                break
            # 以新生成的字为输入继续向下生成 (feed the new char back as input)
            input_eval = word_2_index[char_generated]
            sentence += char_generated

        poem += '\n' + sentence

    return poem

def infer(model, string):
    """Load the trained checkpoint for *model* and generate an acrostic poem.

    Args:
        model: model name string ('lstm'/'GRU'/'Seq2Seq'/'Transformer'/'GPT-2').
        string: head characters; one poem line is generated per character.

    Returns:
        The generated poem string (or None when a head character is out of
        vocabulary — see generate_poetry).

    Raises:
        ValueError: if *model* is not a recognised model name.
    """
    # generate_poetry reads the module-level `args`; the original code bound
    # it locally, which raised NameError when infer() was imported and called
    # from another module. Publish it as a module global instead.
    global args
    args = parse_arguments()
    all_data, (w1, word_2_index, index_2_word) = train_vec()
    args.word_size, args.embedding_num = w1.shape
    args.model = model

    valid_models = ('lstm', 'GRU', 'Seq2Seq', 'Transformer', 'GPT-2')
    if args.model not in valid_models:
        # The original `else` only printed a warning and then crashed on the
        # string parameter; fail fast with an explicit error instead.
        raise ValueError("Please choose a model!\n")

    # Every name currently maps onto the same LSTM backbone; only the
    # checkpoint path differs (others keep the args.save_path default).
    checkpoints = {
        'lstm': 'save_models/lstm_50.pth',
        'GRU': 'save_models/GRU_50.pth',
    }
    args.save_path = checkpoints.get(args.model, args.save_path)
    net = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)

    net.load_state_dict(torch.load(args.save_path))
    net = make_cuda(net)
    poem = generate_poetry(net, string, w1, word_2_index, index_2_word)
    return poem


if __name__ == '__main__':
    args = parse_arguments()
    all_data, (w1, word_2_index, index_2_word) = train_vec()
    args.word_size, args.embedding_num = w1.shape
    # string = input("诗头:")
    string = '自然语言'

    # All five model names currently share the same LSTM backbone, so a
    # single membership check replaces five identical if/elif branches.
    if args.model not in ('lstm', 'GRU', 'Seq2Seq', 'Transformer', 'GPT-2'):
        # The original `else` printed a warning and then fell through to an
        # unbound `model` (NameError); exit with an explicit message instead.
        raise SystemExit("Please choose a model!\n")
    model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)

    model.load_state_dict(torch.load(args.save_path))
    model = make_cuda(model)
    poem = generate_poetry(model, string, w1, word_2_index, index_2_word)
    print(poem)