ZJ committed on
Commit
0a1104e
·
1 Parent(s): fb200e8

first version

Browse files
.gitignore ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PyTorch
2
+ /.torch
3
+
4
+ # Data files
5
+ *.csv
6
+ *.json
7
+ *.tsv
8
+
9
+ # Model files
10
+ *.ckpt
11
+ *.pth
12
+ *.pkl
13
+
14
+ # Logs and checkpoints
15
+ logs/
16
+ checkpoints/
17
+
18
+ # Secondary files
19
+ *.pyc
20
+ __pycache__/
21
+ .DS_Store
22
+
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/nlp.iml" filepath="$PROJECT_DIR$/.idea/nlp.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/nlp.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="PLAIN" />
10
+ <option name="myDocStringFormat" value="Plain" />
11
+ </component>
12
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ChangeListManager">
4
+ <list default="true" id="975e88fb-d387-4d2c-9625-dc69f610d124" name="Changes" comment="">
5
+ <change afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" afterDir="false" />
6
+ <change afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
7
+ <change afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
8
+ <change afterPath="$PROJECT_DIR$/.idea/nlp.iml" afterDir="false" />
9
+ <change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
10
+ <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
11
+ <change beforePath="$PROJECT_DIR$/src/apis/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/apis/train.py" afterDir="false" />
12
+ <change beforePath="$PROJECT_DIR$/src/datasets/dataloader.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/datasets/dataloader.py" afterDir="false" />
13
+ <change beforePath="$PROJECT_DIR$/src/models/LSTM/model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/models/LSTM/model.py" afterDir="false" />
14
+ </list>
15
+ <option name="SHOW_DIALOG" value="false" />
16
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
17
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
18
+ <option name="LAST_RESOLUTION" value="IGNORE" />
19
+ </component>
20
+ <component name="Git.Settings">
21
+ <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
22
+ </component>
23
+ <component name="MarkdownSettingsMigration">
24
+ <option name="stateVersion" value="1" />
25
+ </component>
26
+ <component name="ProjectId" id="2Q8D9XoYiTKL5jiaHLTd3rsHf4Y" />
27
+ <component name="ProjectViewState">
28
+ <option name="hideEmptyMiddlePackages" value="true" />
29
+ <option name="showLibraryContents" value="true" />
30
+ </component>
31
+ <component name="PropertiesComponent">{
32
+ &quot;keyToString&quot;: {
33
+ &quot;RunOnceActivity.OpenProjectViewOnStart&quot;: &quot;true&quot;,
34
+ &quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
35
+ &quot;last_opened_file_path&quot;: &quot;D:/YOU/dasanxia/NLP/new0522/nlp&quot;,
36
+ &quot;settings.editor.selected.configurable&quot;: &quot;com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable&quot;
37
+ }
38
+ }</component>
39
+ <component name="RunManager" selected="Python.run_gradio">
40
+ <configuration name="inference" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
41
+ <module name="nlp" />
42
+ <option name="INTERPRETER_OPTIONS" value="" />
43
+ <option name="PARENT_ENVS" value="true" />
44
+ <envs>
45
+ <env name="PYTHONUNBUFFERED" value="1" />
46
+ </envs>
47
+ <option name="SDK_HOME" value="" />
48
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
49
+ <option name="IS_MODULE_SDK" value="true" />
50
+ <option name="ADD_CONTENT_ROOTS" value="true" />
51
+ <option name="ADD_SOURCE_ROOTS" value="true" />
52
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/inference.py" />
53
+ <option name="PARAMETERS" value="" />
54
+ <option name="SHOW_COMMAND_LINE" value="false" />
55
+ <option name="EMULATE_TERMINAL" value="false" />
56
+ <option name="MODULE_MODE" value="false" />
57
+ <option name="REDIRECT_INPUT" value="false" />
58
+ <option name="INPUT_FILE" value="" />
59
+ <method v="2" />
60
+ </configuration>
61
+ <configuration name="run_gradio" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
62
+ <module name="nlp" />
63
+ <option name="INTERPRETER_OPTIONS" value="" />
64
+ <option name="PARENT_ENVS" value="true" />
65
+ <envs>
66
+ <env name="PYTHONUNBUFFERED" value="1" />
67
+ </envs>
68
+ <option name="SDK_HOME" value="" />
69
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
70
+ <option name="IS_MODULE_SDK" value="true" />
71
+ <option name="ADD_CONTENT_ROOTS" value="true" />
72
+ <option name="ADD_SOURCE_ROOTS" value="true" />
73
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/run_gradio.py" />
74
+ <option name="PARAMETERS" value="" />
75
+ <option name="SHOW_COMMAND_LINE" value="false" />
76
+ <option name="EMULATE_TERMINAL" value="false" />
77
+ <option name="MODULE_MODE" value="false" />
78
+ <option name="REDIRECT_INPUT" value="false" />
79
+ <option name="INPUT_FILE" value="" />
80
+ <method v="2" />
81
+ </configuration>
82
+ <recent_temporary>
83
+ <list>
84
+ <item itemvalue="Python.run_gradio" />
85
+ <item itemvalue="Python.inference" />
86
+ </list>
87
+ </recent_temporary>
88
+ </component>
89
+ <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
90
+ <component name="TaskManager">
91
+ <task active="true" id="Default" summary="Default task">
92
+ <changelist id="975e88fb-d387-4d2c-9625-dc69f610d124" name="Changes" comment="" />
93
+ <created>1684726163448</created>
94
+ <option name="number" value="Default" />
95
+ <option name="presentableId" value="Default" />
96
+ <updated>1684726163448</updated>
97
+ </task>
98
+ <servers />
99
+ </component>
100
+ <component name="Vcs.Log.Tabs.Properties">
101
+ <option name="TAB_STATES">
102
+ <map>
103
+ <entry key="MAIN">
104
+ <value>
105
+ <State />
106
+ </value>
107
+ </entry>
108
+ </map>
109
+ </option>
110
+ </component>
111
+ </project>
app.py CHANGED
@@ -1,7 +1,14 @@
1
- import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Gradio front-end for the poem-generation demo.

Wires the `infer` entry point (model name + poem head -> generated poem)
into a two-text-input web interface.  Fix: the original shebang was
`# !/user/bin/env python3` — wrong path (`user` vs `usr`) and a space
after `#`, so it was never a valid shebang.
"""

import gradio
from inference import infer

# Two text inputs: the model name (lstm/GRU/Seq2Seq/Transformer/GPT-2)
# and the "head" characters the acrostic poem is built from.
INTERFACE = gradio.Interface(fn=infer, inputs=["text", "text"], outputs=["text"], title="Poem Generation",
                             description="model: lstm/GRU/Seq2Seq/Transformer/GPT-2",
                             thumbnail="https://github.com/gradio-app/gpt-2/raw/master/screenshots/interface.png?raw=true")

INTERFACE.launch(inbrowser=True)
data/org_poetry.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/poetry.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/split_poetry.txt ADDED
The diff for this file is too large to render. See raw diff
 
example.jpg ADDED
inference.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ import numpy as np
4
+ from src.models.LSTM.model import Poetry_Model_lstm
5
+ from src.datasets.dataloader import train_vec
6
+ from src.utils.utils import make_cuda
7
+ from src.models.Transformer.model import Poetry_Model_Transformer
8
+
9
def parse_arguments(argv=None):
    """Parse command-line arguments for inference.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``
            when ``None`` (the original CLI behaviour).  Passing a list makes
            the function testable without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the inference settings.
    """
    parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
    parser.add_argument('--model', type=str, default='lstm',
                        help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
    # NOTE(review): without type=..., a value given on the CLI arrives as a
    # string; the default True is only reliable when the flag is omitted.
    parser.add_argument('--Word2Vec', default=True)
    parser.add_argument('--strict_dataset', default=False, help="strict dataset")
    parser.add_argument('--n_hidden', type=int, default=128)
    parser.add_argument('--save_path', type=str, default='save_models/model_params.pth')
    return parser.parse_args(argv)
21
+
22
+
23
def generate_poetry(model, head_string, w1, word_2_index, index_2_word, args):
    """Generate an acrostic poem: one line per character of *head_string*.

    Returns the poem as a string (lines prefixed with '\n'), or None when
    any head character is outside the vocabulary.
    """
    print("藏头诗生成中...., {}".format(head_string))
    poem = ""
    for head in head_string:
        # Refuse head characters outside the vocabulary.
        if head not in word_2_index:
            print("抱歉,不能生成以{}开头的诗".format(head))
            return

        line = head
        limit = 20

        # Fresh zero recurrent states for each line (2 layers, batch of 1).
        state_h = torch.tensor(np.zeros((2, 1, args.n_hidden), dtype=np.float32))
        state_c = torch.tensor(np.zeros((2, 1, args.n_hidden), dtype=np.float32))

        token = word_2_index[head]
        for _ in range(limit):
            if args.Word2Vec:
                emb = torch.tensor(w1[token][None][None])
            else:
                emb = torch.tensor([token]).unsqueeze(dim=0)
            logits, (state_h, state_c) = model(emb, state_h, state_c)
            next_char = index_2_word[int(torch.argmax(logits))]

            if next_char == '。':
                break
            # Feed the generated character back in as the next input.
            token = word_2_index[next_char]
            line += next_char

        poem += '\n' + line

    return poem
56
+
57
def infer(model, poem_head):
    """Generate an acrostic poem — entry point for the Gradio UI and CLI.

    Args:
        model: Model name; one of lstm/GRU/Seq2Seq/Transformer/GPT-2.
        poem_head: Characters that start each line of the poem.

    Returns:
        The generated poem string (or None when a head character is out of
        vocabulary).

    Raises:
        ValueError: If *model* is not a recognised model name.
    """
    args = parse_arguments()
    args.model = model

    all_data, (w1, word_2_index, index_2_word) = train_vec()
    args.word_size, args.embedding_num = w1.shape
    string = poem_head

    # NOTE(review): GRU/Seq2Seq/GPT-2 currently fall back to the LSTM
    # architecture and its default checkpoint — presumably placeholders.
    if args.model == 'lstm':
        model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
        args.save_path = 'save_models/lstm_25.pth'
    elif args.model == 'GRU':
        model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
    elif args.model == 'Seq2Seq':
        model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
    elif args.model == 'Transformer':
        model = Poetry_Model_Transformer(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
        args.save_path = 'save_models/transformer.pth'
    elif args.model == 'GPT-2':
        model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
    else:
        # Fix: previously only printed a warning and then crashed later on
        # load_state_dict with an unrelated AttributeError; fail clearly.
        raise ValueError("Unknown model name: {!r}".format(args.model))

    model.load_state_dict(torch.load(args.save_path))
    model = make_cuda(model)
    poem = generate_poetry(model, string, w1, word_2_index, index_2_word, args)
    return poem
85
+
86
+
87
if __name__ == '__main__':
    # CLI entry point: read the head characters from stdin and delegate to
    # infer(), which performs the exact model-selection / checkpoint-loading
    # sequence this block previously duplicated inline.
    args = parse_arguments()
    string = input("诗头:")
    print(infer(args.model, string))
scripts/lstm_infer.sh ADDED
File without changes
scripts/lstm_train.sh ADDED
File without changes
src/__init__.py ADDED
File without changes
src/apis/__init__.py ADDED
File without changes
src/apis/train.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from src.utils.utils import make_cuda
7
+ from torch.nn import functional as F
8
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
9
+
10
+
11
def train(args, model, data_loader):
    """Train *model* on *data_loader* and return it.

    Optimises with Adam at ``args.learning_rate``, clips gradients to
    ``args.max_grad_norm``, and logs (then resets) the accumulated
    cross-entropy every ``args.log_step`` steps.  (A stale commented-out
    validation stanza was removed.)
    """
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    model.train()
    num_epochs = args.num_epochs

    for epoch in range(num_epochs):
        loss = 0
        for step, (features, targets) in enumerate(data_loader):
            features = make_cuda(features)
            targets = make_cuda(targets)

            optimizer.zero_grad()

            # Model returns (logits, recurrent_state); only logits are needed
            # for the loss.  Targets are flattened to match the
            # [batch*seq, vocab] logits layout.
            pre, _ = model(features)
            crs_loss = model.cross_entropy(pre, targets.reshape(-1))
            loss += crs_loss.item()
            crs_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            # print step info
            if (step + 1) % args.log_step == 0:
                print("Epoch [%.3d/%.3d] Step [%.3d/%.3d]: CROSS_loss=%.4f, RCROSS_loss=%.4f"
                      % (epoch + 1,
                         num_epochs,
                         step + 1,
                         len(data_loader),
                         loss / args.log_step,
                         math.sqrt(loss / args.log_step)))
                loss = 0

    return model
54
+
55
+
56
def evaluate(args, model, data_loader):
    """Evaluate *model* on *data_loader*; print and return the mean loss.

    Fixes: runs under ``torch.no_grad()`` (no autograd bookkeeping is needed
    here) and drops the stray ``clip_grad_norm_`` call — no backward pass
    runs during evaluation, so clipping only mutated leftover gradients
    from training.  Returning the mean is a backward-compatible addition
    (the original returned None).
    """
    model.eval()
    loss = []
    with torch.no_grad():
        for features, targets in data_loader:
            features = make_cuda(features)
            targets = make_cuda(targets)

            pre, _ = model(features)
            crs_loss = model.cross_entropy(pre, targets.reshape(-1))
            loss.append(crs_loss.item())

    mean_loss = np.mean(loss)
    print("loss=%.4f" % (mean_loss,))
    return mean_loss
src/datasets/__init__.py ADDED
File without changes
src/datasets/dataloader.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pickle
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ from gensim.models.word2vec import Word2Vec
7
+ from torch.utils.data import Dataset
8
+
9
+
10
def padding(poetries, maxlen, pad):
    """Right-pad every sequence in *poetries* to *maxlen* with *pad*."""
    padded = []
    for seq in poetries:
        padded.append(seq + pad * (maxlen - len(seq)))
    return padded
13
+
14
+
15
+ # 输入向后滑一字符为target,即预测下一个字
16
def split_input_target(seq):
    """Shift by one position: inputs are seq[:-1], targets seq[1:] (next-char prediction)."""
    return seq[:-1], seq[1:]
20
+
21
+
22
+ # 创建词汇表
23
def get_poetry(arg):
    """Read the raw poetry corpus, filter it, and write data/org_poetry.txt.

    Returns the kept poems, each wrapped as '[' + content + ']' (start/end
    markers) and sorted by length (shortest first).
    """
    poetrys = []
    if arg.Augmented_dataset:
        path = arg.Augmented_data
    else:
        path = arg.data  # dataset path, e.g. data/poetry.txt
    with open(path, "r", encoding='UTF-8') as f:
        for line in f:
            try:
                # line = line.decode('UTF-8')
                line = line.strip(u'\n')  # drop the trailing newline
                if arg.Augmented_dataset:
                    content = line.strip(u' ')
                else:
                    title, content = line.strip(u' ').split(u':')  # title and body are colon-separated
                content = content.replace(u' ', u'')  # remove internal spaces
                if u'_' in content or u'(' in content or u'(' in content or u'《' in content or u'[' in content:  # skip poems containing special symbols
                    continue
                if arg.strict_dataset:  # strict mode: tighter length bounds
                    if len(content) < 12 or len(content) > 79:
                        continue
                else:
                    if len(content) < 5 or len(content) > 79:
                        continue
                content = u'[' + content + u']'  # add start/end markers
                poetrys.append(content)
            except Exception as e:
                # Malformed lines (e.g. no colon for split) are skipped on
                # purpose.  NOTE(review): the broad except also hides any
                # unexpected error — consider narrowing to ValueError.
                pass

    # Sort poems by character count, shortest first.
    poetrys = sorted(poetrys, key=lambda line: len(line))

    # Persist a cleaned copy (markers/commas/quotes stripped) for train_vec.
    with open("data/org_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n'
            f.write(poetry)

    return poetrys
61
+
62
+
63
+ # 切分文档
64
def split_text(poetrys):
    """Write a character-spaced copy of *poetrys* to data/split_poetry.txt.

    Each poem is stripped of its '[' / ']' markers, commas and quotes, then
    every character is separated by a space (one token per character, the
    input format Word2Vec expects).  Returns the contents just written.
    """
    with open("data/split_poetry.txt", "w", encoding="utf-8") as f:
        for poetry in poetrys:
            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n '
            split_data = " ".join(poetry)
            f.write(split_data)
    # Fix: the original returned open(...).read(), leaking the file handle;
    # use a context manager for the read-back as well.
    with open("data/split_poetry.txt", "r", encoding='UTF-8') as f:
        return f.read()
71
+
72
+
73
+ # 训练词向量
74
def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"):
    """Load (or train and cache) char-level Word2Vec parameters.

    Returns:
        (org_data, (w1, word_2_index, index_2_word)) where org_data is the
        list of poem lines, w1 is a [vocab, 256] weight matrix used as the
        word vectors, word_2_index maps char -> id, index_2_word id -> char.
    """
    # Pickled cache from a previous run.
    param_file = "data/word_vec.pkl"
    org_data = open(org_file, "r", encoding="utf-8").read().split("\n")
    if os.path.exists(split_file):
        all_data_split = open(split_file, "r", encoding="utf-8").read().split("\n")
    else:
        # NOTE(review): split_text() requires a poetrys argument, so this
        # fallback raises TypeError if ever reached — confirm intended flow.
        all_data_split = split_text().split("\n")

    if os.path.exists(param_file):
        return org_data, pickle.load(open(param_file, "rb"))

    # Train 256-dim character vectors; min_count=1 keeps every character.
    models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
    # NOTE(review): syn1neg is the negative-sampling *output* matrix, not
    # models.wv.vectors — presumably deliberate; verify before changing.
    pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
88
+
89
+
90
class Poetry_Dataset(Dataset):
    """Torch dataset yielding (input embedding, next-character target) pairs."""

    def __init__(self, w1, word_2_index, all_data, Word2Vec):
        self.Word2Vec = Word2Vec
        self.w1 = w1
        self.word_2_index = word_2_index
        word_size, embedding_num = w1.shape
        # Kept for interface compatibility; __getitem__ indexes w1 directly.
        self.embedding = nn.Embedding(word_size, embedding_num)
        # Pad every poem (the trailing entry is dropped) to the longest length.
        longest = max(len(seq) for seq in all_data)
        self.all_data = padding(all_data[:-1], longest, ' ')

    def __getitem__(self, index):
        poem = self.all_data[index]

        ids = [self.word_2_index[ch] for ch in poem]
        xs, ys = split_input_target(ids)
        # Pre-embedded vectors when Word2Vec is on, raw token ids otherwise.
        xs_embedding = self.w1[xs] if self.Word2Vec else np.array(xs)

        return xs_embedding, np.array(ys).astype(np.int64)

    def __len__(self):
        return len(self.all_data)
src/models/LSTM/__init__.py ADDED
File without changes
src/models/LSTM/model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+
5
+
6
class Poetry_Model_lstm(nn.Module):
    """2-layer unidirectional LSTM language model for poem generation."""

    def __init__(self, hidden_num, word_size, embedding_num, Word2Vec):
        """
        Args:
            hidden_num: LSTM hidden size.
            word_size: Vocabulary size (output dimension).
            embedding_num: Embedding / input feature dimension.
            Word2Vec: When True, forward() receives pre-embedded vectors;
                when False, it receives token ids embedded by this module.
        """
        super().__init__()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.hidden_num = hidden_num
        self.Word2Vec = Word2Vec

        self.embedding = nn.Embedding(word_size, embedding_num)
        self.lstm = nn.LSTM(input_size=embedding_num, hidden_size=hidden_num, batch_first=True, num_layers=2,
                            bidirectional=False)
        self.dropout = nn.Dropout(0.3)
        self.flatten = nn.Flatten(0, 1)
        self.linear = nn.Linear(hidden_num, word_size)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, xs_embedding, h_0=None, c_0=None):
        # xs_embedding: [batch_size, max_seq_len, n_feature] n_feature=256
        # Fix: `h_0 == None` invokes tensor equality on a provided state;
        # identity (`is None`) is the correct check.  torch.zeros replaces
        # the needless numpy round-trip.
        if h_0 is None or c_0 is None:
            h_0 = torch.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=torch.float32)
            c_0 = torch.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=torch.float32)
        h_0 = h_0.to(self.device)
        c_0 = c_0.to(self.device)
        xs_embedding = xs_embedding.to(self.device)
        if not self.Word2Vec:
            xs_embedding = self.embedding(xs_embedding)
        hidden, (h_0, c_0) = self.lstm(xs_embedding, (h_0, c_0))
        hidden_drop = self.dropout(hidden)
        hidden_flatten = self.flatten(hidden_drop)
        pre = self.linear(hidden_flatten)
        # pre: [batch_size * max_seq_len, vocab_size]
        return pre, (h_0, c_0)
src/models/Transformer/__init__.py ADDED
File without changes
src/models/Transformer/model.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import math
5
+
6
class Poetry_Model_Transformer(nn.Module):
    """Transformer-encoder language model for poem generation.

    Keeps LSTM-style (h_0, c_0) arguments/returns so it is call-compatible
    with Poetry_Model_lstm; the states pass through untouched.
    """

    def __init__(self, hidden_num, word_size, embedding_num, Word2Vec):
        super().__init__()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.hidden_num = hidden_num
        self.Word2Vec = Word2Vec

        # Sinusoidal positional encoding applied before the encoder stack.
        self.pos_encoder = PositionalEncoding(d_model=embedding_num, dropout=0.5)
        self.embedding = nn.Embedding(word_size, embedding_num)
        # NOTE(review): this full nn.Transformer is never used by forward()
        # (only self.encoder is), but removing it would change the
        # state_dict and break existing checkpoints — confirm before pruning.
        self.transformer = nn.Transformer(d_model=embedding_num, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
                                          dim_feedforward=2048, dropout=0.5, activation='relu')

        # Encoder stack actually used by forward().
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_num, nhead=8, dim_feedforward=2048, dropout=0.5)
        self.encoder_norm = nn.LayerNorm(normalized_shape=embedding_num)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6, norm=self.encoder_norm)

        self.flatten = nn.Flatten(0, 1)
        self.linear1 = nn.Linear(embedding_num, hidden_num)
        self.linear2 = nn.Linear(hidden_num, word_size)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, xs_embedding, h_0=None, c_0=None):
        # Fix: `h_0 == None` invokes tensor equality on a provided state;
        # identity (`is None`) is the correct check.
        if h_0 is None or c_0 is None:
            h_0 = torch.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=torch.float32)
            c_0 = torch.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=torch.float32)
        h_0 = h_0.to(self.device)
        c_0 = c_0.to(self.device)
        xs_embedding = xs_embedding.to(self.device)

        if not self.Word2Vec:
            xs_embedding = self.embedding(xs_embedding)

        encoder_input = self.pos_encoder(xs_embedding)
        pre_encoded = self.encoder(encoder_input)
        pre = self.linear2(self.linear1(self.flatten(pre_encoded)))

        # pre: [batch_size * max_seq_len, vocab_size]
        return pre, (h_0, c_0)
51
+
52
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al. 2017) with dropout."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(positions * div_term)  # even dims: sine
        pe[:, 1::2] = torch.cos(positions * div_term)  # odd dims: cosine
        # Stored as [max_len, 1, d_model] so it broadcasts over axis 1.
        self.register_buffer('pe', pe.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        # NOTE(review): pe is indexed by x.size(0), i.e. x's FIRST axis is
        # treated as the sequence axis (seq-first layout) — confirm callers
        # feed [seq, batch, d_model] and not batch-first input.
        return self.dropout(x + self.pe[:x.size(0), :])
src/models/__init__.py ADDED
File without changes
src/utils/__init__.py ADDED
File without changes
src/utils/utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def make_cuda(tensor):
    """Move *tensor* (or an nn.Module) to the GPU when CUDA is available."""
    if not torch.cuda.is_available():
        return tensor
    return tensor.cuda()
9
+
10
+
11
def is_minimum(value, indiv_to_rmse):
    """Return True if *value* is strictly below every value in the mapping.

    An empty mapping vacuously counts as True.  (Idiom cleanup: direct
    boolean expression instead of `True if ... else False`, truthiness
    instead of `len(...) == 0`, and min over .values() without a temp list.)
    """
    if not indiv_to_rmse:
        return True
    return value < min(indiv_to_rmse.values())
train.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.utils.utils import make_cuda
2
+ from src.apis.train import train, evaluate
3
+ from src.models.LSTM.model import Poetry_Model_lstm
4
+ import argparse
5
+ import torch
6
+ import os
7
+ from src.datasets.dataloader import Poetry_Dataset, train_vec, get_poetry, split_text
8
+ from torch.utils.data import DataLoader
9
+ from src.models.Transformer.model import Poetry_Model_Transformer
10
+
11
+
12
def _str2bool(value):
    """argparse-friendly bool parser: accepts true/false-style strings."""
    if isinstance(value, bool):
        return value
    if value.lower() in ('yes', 'true', 't', '1'):
        return True
    if value.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % (value,))


def parse_arguments(argv=None):
    """Parse training CLI arguments.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``
            when ``None`` (the original CLI behaviour).

    Fix: ``--Word2Vec`` / ``--Augmented_dataset`` used ``type=bool``, so any
    non-empty value — even the string "False" — parsed as True; _str2bool
    parses the text properly while keeping the same defaults.
    """
    parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")

    parser.add_argument('--batch_size', type=int, default=64,
                        help="Specify batch size")
    parser.add_argument('--initial_epochs', type=int, default=25,
                        help="Specify the number of epochs for initial training")
    parser.add_argument('--num_epochs', type=int, default=50,
                        help="Specify the number of epochs for competitive search")
    parser.add_argument('--log_step', type=int, default=100,
                        help="Specify log step size for training")
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help="Learning rate")
    parser.add_argument('--data', type=str, default='data/poetry.txt',
                        help="Path to the dataset")
    parser.add_argument('--n_hidden', type=int, default=128)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)

    parser.add_argument('--save_path', type=str, default='save_models/transformer.pth')
    parser.add_argument('--strict_dataset', default=False, help="strict dataset")
    parser.add_argument('--Word2Vec', type=_str2bool, default=True)
    parser.add_argument("--Augmented_dataset", type=_str2bool, default=False)
    return parser.parse_args(argv)
38
+
39
+
40
def main():
    """End-to-end training: prepare the corpus, train the Transformer, evaluate."""
    args = parse_arguments()
    # Build the preprocessed corpus files once; reuse them on later runs.
    if os.path.exists("data/split_poetry.txt") and os.path.exists("data/org_poetry.txt"):
        print("pre_file exit!")
    else:
        split_text(get_poetry(args))  # split poetry
    all_data, (w1, word_2_index, index_2_word) = train_vec()
    args.word_size, args.embedding_num = w1.shape  # vocab size / word-vector dimension

    dataset = Poetry_Dataset(w1, word_2_index, all_data, Word2Vec=args.Word2Vec)
    # 70/30 random train/validation split.
    train_size = int(len(dataset) * 0.7)
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    valid_data_loader = DataLoader(test_dataset, batch_size=int(args.batch_size/4), shuffle=True)

    # best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num,args.Word2Vec)
    best_model = Poetry_Model_Transformer(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
    best_model = make_cuda(best_model)  # use gpu
    print("Initial training before competitive random search")
    best_model = train(args, best_model, train_data_loader)

    # Persist only the weights; inference.py reloads them via load_state_dict.
    torch.save(best_model.state_dict(), args.save_path)

    print('test evaluation:')
    evaluate(args, best_model, valid_data_loader)


if __name__ == '__main__':
    main()