Upload 53 files
- .idea/.gitignore +3 -0
- .idea/.name +1 -0
- .idea/ea_lstm.iml +8 -0
- .idea/inspectionProfiles/Project_Default.xml +20 -0
- .idea/misc.xml +1 -1
- .idea/modules.xml +1 -1
- .idea/workspace.xml +18 -85
- README.md +22 -11
- __pycache__/inference.cpython-38.pyc +0 -0
- app.py +1 -1
- data/poetry_7.txt +0 -0
- data/word_vec.pkl +3 -0
- inference.py +18 -22
- save_models/GRU_25.pth +3 -0
- save_models/GRU_50.pth +3 -0
- save_models/lstm_25.pth +1 -1
- save_models/lstm_50.pth +3 -0
- src/__pycache__/__init__.cpython-38.pyc +0 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/apis/__pycache__/__init__.cpython-39.pyc +0 -0
- src/apis/__pycache__/inference.cpython-39.pyc +0 -0
- src/apis/__pycache__/train.cpython-39.pyc +0 -0
- src/apis/evaluate.py +23 -0
- src/apis/train.py +2 -2
- src/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- src/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- src/datasets/__pycache__/dataloader.cpython-38.pyc +0 -0
- src/datasets/__pycache__/dataloader.cpython-39.pyc +0 -0
- src/datasets/dataloader.py +13 -13
- src/models/LSTM/__pycache__/__init__.cpython-38.pyc +0 -0
- src/models/LSTM/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/LSTM/__pycache__/algorithm.cpython-39.pyc +0 -0
- src/models/LSTM/__pycache__/model.cpython-38.pyc +0 -0
- src/models/LSTM/__pycache__/model.cpython-39.pyc +0 -0
- src/models/LSTM/model.py +1 -1
- src/models/__pycache__/__init__.cpython-38.pyc +0 -0
- src/models/__pycache__/__init__.cpython-39.pyc +0 -0
- src/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- src/utils/__pycache__/utils.cpython-38.pyc +0 -0
- src/utils/__pycache__/utils.cpython-39.pyc +0 -0
- train.py +32 -20
.idea/.gitignore
ADDED
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
.idea/.name
ADDED
@@ -0,0 +1 @@
+inference.py
.idea/ea_lstm.iml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,20 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="7">
+            <item index="0" class="java.lang.String" itemvalue="easydict" />
+            <item index="1" class="java.lang.String" itemvalue="pandas" />
+            <item index="2" class="java.lang.String" itemvalue="matplotlib" />
+            <item index="3" class="java.lang.String" itemvalue="pillow" />
+            <item index="4" class="java.lang.String" itemvalue="mindspore" />
+            <item index="5" class="java.lang.String" itemvalue="setuptools" />
+            <item index="6" class="java.lang.String" itemvalue="numpy" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
.idea/misc.xml
CHANGED
@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (pytorch)" project-jdk-type="Python SDK" />
 </project>
.idea/modules.xml
CHANGED
@@ -2,7 +2,7 @@
 <project version="4">
   <component name="ProjectModuleManager">
     <modules>
-      <module fileurl="file://$PROJECT_DIR$/.idea/
+      <module fileurl="file://$PROJECT_DIR$/.idea/ea_lstm.iml" filepath="$PROJECT_DIR$/.idea/ea_lstm.iml" />
     </modules>
   </component>
 </project>
.idea/workspace.xml
CHANGED
@@ -1,111 +1,44 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
   <component name="ChangeListManager">
-    <list default="true" id="
-      <change afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/.idea/nlp.iml" afterDir="false" />
-      <change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
-      <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
-      <change beforePath="$PROJECT_DIR$/src/apis/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/apis/train.py" afterDir="false" />
-      <change beforePath="$PROJECT_DIR$/src/datasets/dataloader.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/datasets/dataloader.py" afterDir="false" />
-      <change beforePath="$PROJECT_DIR$/src/models/LSTM/model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/models/LSTM/model.py" afterDir="false" />
-    </list>
+    <list default="true" id="276a53df-3cdd-4e96-95d3-c1e69d4e9b9f" name="Changes" comment="" />
     <option name="SHOW_DIALOG" value="false" />
     <option name="HIGHLIGHT_CONFLICTS" value="true" />
     <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
     <option name="LAST_RESOLUTION" value="IGNORE" />
   </component>
-  <component name="Git.Settings">
-    <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
-  </component>
   <component name="MarkdownSettingsMigration">
     <option name="stateVersion" value="1" />
   </component>
-  <component name="ProjectId" id="
+  <component name="ProjectId" id="2OyFWrJQpFYHFKgf87OgmRH5Jtu" />
   <component name="ProjectViewState">
     <option name="hideEmptyMiddlePackages" value="true" />
     <option name="showLibraryContents" value="true" />
   </component>
-  <component name="PropertiesComponent"
-
-
-
-
-    "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable"
+  <component name="PropertiesComponent"><![CDATA[{
+    "keyToString": {
+      "RunOnceActivity.OpenProjectViewOnStart": "true",
+      "RunOnceActivity.ShowReadmeOnStart": "true",
+      "last_opened_file_path": "C:/Users/LENOVO/PycharmProjects/lstm"
     }
-  }
-  <component name="RunManager" selected="Python.run_gradio">
-    <configuration name="inference" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
-      <module name="nlp" />
-      <option name="INTERPRETER_OPTIONS" value="" />
-      <option name="PARENT_ENVS" value="true" />
-      <envs>
-        <env name="PYTHONUNBUFFERED" value="1" />
-      </envs>
-      <option name="SDK_HOME" value="" />
-      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
-      <option name="IS_MODULE_SDK" value="true" />
-      <option name="ADD_CONTENT_ROOTS" value="true" />
-      <option name="ADD_SOURCE_ROOTS" value="true" />
-      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/inference.py" />
-      <option name="PARAMETERS" value="" />
-      <option name="SHOW_COMMAND_LINE" value="false" />
-      <option name="EMULATE_TERMINAL" value="false" />
-      <option name="MODULE_MODE" value="false" />
-      <option name="REDIRECT_INPUT" value="false" />
-      <option name="INPUT_FILE" value="" />
-      <method v="2" />
-    </configuration>
-    <configuration name="run_gradio" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
-      <module name="nlp" />
-      <option name="INTERPRETER_OPTIONS" value="" />
-      <option name="PARENT_ENVS" value="true" />
-      <envs>
-        <env name="PYTHONUNBUFFERED" value="1" />
-      </envs>
-      <option name="SDK_HOME" value="" />
-      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
-      <option name="IS_MODULE_SDK" value="true" />
-      <option name="ADD_CONTENT_ROOTS" value="true" />
-      <option name="ADD_SOURCE_ROOTS" value="true" />
-      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/run_gradio.py" />
-      <option name="PARAMETERS" value="" />
-      <option name="SHOW_COMMAND_LINE" value="false" />
-      <option name="EMULATE_TERMINAL" value="false" />
-      <option name="MODULE_MODE" value="false" />
-      <option name="REDIRECT_INPUT" value="false" />
-      <option name="INPUT_FILE" value="" />
-      <method v="2" />
-    </configuration>
-    <recent_temporary>
-      <list>
-        <item itemvalue="Python.run_gradio" />
-        <item itemvalue="Python.inference" />
-      </list>
-    </recent_temporary>
-  </component>
+  }]]></component>
   <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
   <component name="TaskManager">
     <task active="true" id="Default" summary="Default task">
-      <changelist id="
-      <created>
+      <changelist id="276a53df-3cdd-4e96-95d3-c1e69d4e9b9f" name="Changes" comment="" />
+      <created>1682524950142</created>
       <option name="number" value="Default" />
       <option name="presentableId" value="Default" />
-      <updated>
+      <updated>1682524950142</updated>
     </task>
     <servers />
   </component>
-  <component name="
-    <
-    <
-    <
-
-
-
-    </entry>
-  </map>
-  </option>
+  <component name="XDebuggerManager">
+    <watches-manager>
+      <configuration name="PythonConfigurationType">
+        <watch expression="input_eval" />
+        <watch expression="word_2_index" />
+      </configuration>
+    </watches-manager>
   </component>
 </project>
README.md
CHANGED
@@ -1,12 +1,23 @@
-
-title: Poem Generation
-emoji: 👁
-colorFrom: gray
-colorTo: red
-sdk: gradio
-sdk_version: 3.32.0
-app_file: app.py
-pinned: false
----
+# NLP Final Project
 
-
+```shell
+├── configs
+├── data
+│   └── poetry.txt
+├── inference.py
+├── src
+│   ├── apis
+│   │   ├── evaluate.py
+│   │   ├── inference.py
+│   │   └── train.py
+│   ├── datasets
+│   │   └── dataloader.py
+│   ├── models
+│   │   └── EA-LSTM
+│   │       ├── algorithm.py
+│   │       └── model.py
+│   └── utils
+│       └── utils.py
+├── test.py
+└── train.py
+```
__pycache__/inference.cpython-38.pyc
ADDED
Binary file (2.88 kB)
app.py
CHANGED
@@ -7,7 +7,7 @@ from inference import infer
 
 
 
-INTERFACE = gradio.Interface(fn=infer, inputs=[gradio.Radio(["lstm","GRU"
+INTERFACE = gradio.Interface(fn=infer, inputs=[gradio.Radio(["lstm","GRU"]),"text"], outputs=["text"], title="Poetry Generation",
                              description="Choose a model and input the poetic head to generate a acrostic",
                              thumbnail="https://github.com/gradio-app/gpt-2/raw/master/screenshots/interface.png?raw=true")
 
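The fixed `INTERFACE` line wires two inputs (a `gradio.Radio` model selector and a free-text poem head) into `infer`. Below is a minimal sketch of the same wiring, with a hypothetical `fake_infer` stub standing in for the repository's `infer` and assuming the Gradio 3.x API this Space declares:

```python
# Minimal sketch of the app.py wiring; fake_infer is a made-up stand-in for
# the repo's infer() so the snippet runs without the checkpoints.
import gradio as gr

def fake_infer(model_name, head):
    return f"[{model_name}] acrostic for: {head}"

demo = gr.Interface(
    fn=fake_infer,
    inputs=[gr.Radio(["lstm", "GRU"]), "text"],  # model choice + poem head
    outputs=["text"],
    title="Poetry Generation",
)

if __name__ == "__main__":
    demo.launch()
```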
data/poetry_7.txt
ADDED
The diff for this file is too large to render.
data/word_vec.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1164cfc2e28ef6ecbb1a04734e7268238b4841667f13d6cb4c42e27717dd4575
+size 6339344
inference.py
CHANGED
@@ -1,10 +1,10 @@
 import torch
 import argparse
 import numpy as np
+from src.models.LSTM.model import Poetry_Model_lstm
 from src.datasets.dataloader import train_vec
 from src.utils.utils import make_cuda
-
-from src.models.LSTM.model import Poetry_Model_lstm
+
 
 def parse_arguments():
     # argument parsing
@@ -15,12 +15,12 @@
     parser.add_argument('--strict_dataset', default=False, help="strict dataset")
     parser.add_argument('--n_hidden', type=int, default=128)
 
-    parser.add_argument('--save_path', type=str, default='save_models/
+    parser.add_argument('--save_path', type=str, default='save_models/lstm_50.pth')
 
     return parser.parse_args()
 
 
-def generate_poetry(model, head_string, w1, word_2_index, index_2_word
+def generate_poetry(model, head_string, w1, word_2_index, index_2_word):
     print("藏头诗生成中...., {}".format(head_string))
     poem = ""
     # 以句子的每一个字为开头生成诗句
@@ -54,33 +54,31 @@ def generate_poetry(model, head_string, w1, word_2_index, index_2_word,args):
 
     return poem
 
-def infer(model,
+def infer(model,string):
     args = parse_arguments()
-    args.model=model
-
     all_data, (w1, word_2_index, index_2_word) = train_vec()
     args.word_size, args.embedding_num = w1.shape
-    string =
+    # string = input("诗头:")
     # string = '自然语言'
-
+    args.model=model
     if args.model == 'lstm':
         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
-        args.save_path='save_models/
+        args.save_path = 'save_models/lstm_50.pth'
     elif args.model == 'GRU':
         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+        args.save_path = 'save_models/GRU_50.pth'
    elif args.model == 'Seq2Seq':
         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
     elif args.model == 'Transformer':
-        model =
-        args.save_path='save_models/transformer_100.pth'
+        model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
     elif args.model == 'GPT-2':
         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
     else:
         print("Please choose a model!\n")
 
-    model.load_state_dict(torch.load(args.save_path
+    model.load_state_dict(torch.load(args.save_path))
     model = make_cuda(model)
-    poem = generate_poetry(model, string, w1, word_2_index, index_2_word
+    poem = generate_poetry(model, string, w1, word_2_index, index_2_word)
     return poem
 
 
@@ -88,25 +86,23 @@ if __name__ == '__main__':
     args = parse_arguments()
     all_data, (w1, word_2_index, index_2_word) = train_vec()
     args.word_size, args.embedding_num = w1.shape
-    string = input("诗头:")
-
-
+    # string = input("诗头:")
+    string = '自然语言'
+
     if args.model == 'lstm':
         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
-        args.save_path='save_models/lstm_25.pth'
     elif args.model == 'GRU':
         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
     elif args.model == 'Seq2Seq':
         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
     elif args.model == 'Transformer':
-        model =
-        args.save_path='save_models/transformer_100.pth'
+        model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
     elif args.model == 'GPT-2':
         model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
     else:
         print("Please choose a model!\n")
 
-    model.load_state_dict(torch.load(args.save_path
+    model.load_state_dict(torch.load(args.save_path))
     model = make_cuda(model)
-    poem = generate_poetry(model, string, w1, word_2_index, index_2_word
+    poem = generate_poetry(model, string, w1, word_2_index, index_2_word)
     print(poem)
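With this change, `infer(model, string)` selects a checkpoint by model name (`save_models/lstm_50.pth` or `save_models/GRU_50.pth`) and returns the generated poem. A hedged usage sketch, assuming the checkpoints and the cached word vectors uploaded in this commit are available locally:

```python
# Hypothetical driver; assumes the repo layout above and the files under
# data/ and save_models/ from this commit (and a working torch install).
from inference import infer

if __name__ == "__main__":
    # infer() takes the model name first and the poem head second,
    # and picks the matching .pth checkpoint itself.
    poem = infer("lstm", "自然语言")
    print(poem)
```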
save_models/GRU_25.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bacf9a7ec329c6185098c1309ab28239b4c087b53832b3d18e5323831bfead23
+size 10727391
save_models/GRU_50.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a8e83a733c023b35c44020e014bb72e2c1d05698eb782669c0e4d5a76d4590d
+size 10727391
save_models/lstm_25.pth
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:2b064666ce02c63541dee4b6146d31ee8f7e784ee9c2811c9b9266aba6cc4193
 size 10727391
save_models/lstm_50.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa157d970149c32b53b024a23ef8428e7b7e1702ed72d44152b568b085b1bfaa
+size 10727391
src/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (166 Bytes)
src/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (146 Bytes)
src/apis/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (151 Bytes)
src/apis/__pycache__/inference.cpython-39.pyc
ADDED
Binary file (1.44 kB)
src/apis/__pycache__/train.cpython-39.pyc
ADDED
Binary file (1.68 kB)
src/apis/evaluate.py
ADDED
@@ -0,0 +1,23 @@
+import torch
+import numpy as np
+from src.models.EA_LSTM.model import weightedLSTM
+from src.datasets.dataloader import MyDataset, create_vocab
+
+
+def test(args):
+    vocab, poetrys = create_vocab(args.data)
+    # 词汇表长度
+    args.vocab_size = len(vocab)
+    int2char = np.array(vocab)
+    valid_dataset = MyDataset(vocab, poetrys, args, train=False)
+
+    model = weightedLSTM(6110, 256, 128, 2, [1.0] * 80, False)
+    model.load_state_dict(torch.load(args.save_path))
+
+    input_example_batch, target_example_batch = valid_dataset[0]
+    example_batch_predictions = model(input_example_batch)
+    predicted_id = torch.distributions.Categorical(example_batch_predictions).sample()
+    predicted_id = torch.squeeze(predicted_id, -1).numpy()
+    print("Input: \n", repr("".join(int2char[input_example_batch])))
+    print()
+    print("Predictions: \n", repr("".join(int2char[predicted_id])))
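The new `test()` samples one token id per position from the model output via `torch.distributions.Categorical` before mapping ids back to characters with `int2char`. A self-contained illustration of just that sampling step, with made-up tensor shapes:

```python
# Illustration of the sampling step used in test(): draw one token id per
# position from a categorical distribution over the vocabulary.
# The shapes here are invented for the example (seq_len=5, vocab=10).
import torch

probs = torch.softmax(torch.randn(5, 10), dim=-1)        # fake model output
predicted_id = torch.distributions.Categorical(probs).sample()
print(predicted_id.shape)  # torch.Size([5]) -> one sampled id per position
```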
src/apis/train.py
CHANGED
@@ -8,11 +8,11 @@ from torch.nn import functional as F
 from sklearn.metrics import mean_squared_error, mean_absolute_error
 
 
-def train(args, model, data_loader):
+def train(args, model, data_loader, initial=False):
     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
 
     model.train()
-    num_epochs = args.num_epochs
+    num_epochs = args.initial_epochs if initial else args.num_epochs
 
     for epoch in range(num_epochs):
         loss = 0
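The added `initial` flag lets the same `train()` run either a shorter warm-up (`args.initial_epochs`) or the full schedule (`args.num_epochs`). A toy sketch of that selection logic with example values; note that this same commit drops `--initial_epochs` from `train.py`'s parser, so `initial=True` would need that attribute set elsewhere:

```python
# Toy illustration of the epoch-selection logic added to train(); the
# Namespace values mirror the defaults seen in the diffs (25 and 50).
from argparse import Namespace

def pick_epochs(args, initial=False):
    return args.initial_epochs if initial else args.num_epochs

args = Namespace(initial_epochs=25, num_epochs=50)
print(pick_epochs(args, initial=True))  # 25 -> short warm-up run
print(pick_epochs(args))                # 50 -> full run
```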
src/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (175 Bytes)
src/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (155 Bytes)
src/datasets/__pycache__/dataloader.cpython-38.pyc
ADDED
Binary file (4.09 kB)
src/datasets/__pycache__/dataloader.cpython-39.pyc
ADDED
Binary file (4.12 kB)
src/datasets/dataloader.py
CHANGED
@@ -25,27 +25,27 @@ def get_poetry(arg):
     if arg.Augmented_dataset:
         path = arg.Augmented_data
     else:
-        path = arg.data
+        path = arg.data
     with open(path, "r", encoding='UTF-8') as f:
         for line in f:
             try:
                 # line = line.decode('UTF-8')
-                line = line.strip(u'\n')
+                line = line.strip(u'\n')
                 if arg.Augmented_dataset:
                     content = line.strip(u' ')
                 else:
-                    title, content = line.strip(u' ').split(u':')
-                content = content.replace(u' ', u'')
-                if u'_' in content or u'(' in content or u'(' in content or u'《' in content or u'[' in content:
+                    title, content = line.strip(u' ').split(u':')
+                content = content.replace(u' ', u'')
+                if u'_' in content or u'(' in content or u'(' in content or u'《' in content or u'[' in content:
                     continue
-                if arg.strict_dataset:
+                if arg.strict_dataset:
                     if len(content) < 12 or len(content) > 79:
                         continue
                 else:
                     if len(content) < 5 or len(content) > 79:
                         continue
-                content = u'[' + content + u']'
-                poetrys.append(content)
+                content = u'[' + content + u']'
+                poetrys.append(content)
             except Exception as e:
                 pass
 
@@ -82,9 +82,9 @@ def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"
     if os.path.exists(param_file):
         return org_data, pickle.load(open(param_file, "rb"))
 
-    models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
-    pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
-    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
+    models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
+    pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
+    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
 
 
 class Poetry_Dataset(Dataset):
@@ -93,11 +93,11 @@ class Poetry_Dataset(Dataset):
         self.w1 = w1
         self.word_2_index = word_2_index
         word_size, embedding_num = w1.shape
-        self.embedding = nn.Embedding(word_size, embedding_num)
+        self.embedding = nn.Embedding(word_size, embedding_num)
         # 最长句子长度
         maxlen = max([len(seq) for seq in all_data])
         pad = ' '
-        self.all_data = padding(all_data[:-1], maxlen, pad)
+        self.all_data = padding(all_data[:-1], maxlen, pad)
 
     def __getitem__(self, index):
         a_poetry = self.all_data[index]
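`train_vec()` caches the Word2Vec output weights and both vocabulary mappings in a pickle, presumably the `data/word_vec.pkl` tracked via LFS in this commit. A hedged sketch of reading that cache back, assuming it holds the same `[syn1neg, key_to_index, index_to_key]` triple:

```python
# Hedged sketch: load the cached Word2Vec parameters that train_vec() pickles.
# Assumes data/word_vec.pkl is that cache and is present in the working tree.
import pickle

with open("data/word_vec.pkl", "rb") as f:
    w1, word_2_index, index_2_word = pickle.load(f)

print(w1.shape)            # (vocab_size, 256), given vector_size=256 above
print(len(word_2_index))   # vocabulary size
```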
src/models/LSTM/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (178 Bytes)
src/models/LSTM/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (161 Bytes)
src/models/LSTM/__pycache__/algorithm.cpython-39.pyc
ADDED
Binary file (4.99 kB)
src/models/LSTM/__pycache__/model.cpython-38.pyc
ADDED
Binary file (1.58 kB)
src/models/LSTM/__pycache__/model.cpython-39.pyc
ADDED
Binary file (1.55 kB)
src/models/LSTM/model.py
CHANGED
@@ -20,7 +20,7 @@ class Poetry_Model_lstm(nn.Module):
         self.cross_entropy = nn.CrossEntropyLoss()
 
     def forward(self, xs_embedding, h_0=None, c_0=None):
-        # xs_embedding: [batch_size, max_seq_len, n_feature] n_feature=
+        # xs_embedding: [batch_size, max_seq_len, n_feature] n_feature=128
         if h_0 == None or c_0 == None:
             h_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=np.float32))
             c_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=np.float32))
src/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (173 Bytes)
src/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (153 Bytes)
src/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (172 Bytes)
src/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (152 Bytes)
src/utils/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (575 Bytes)
src/utils/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (555 Bytes)
train.py
CHANGED
@@ -6,19 +6,20 @@ import torch
 import os
 from src.datasets.dataloader import Poetry_Dataset, train_vec, get_poetry, split_text
 from torch.utils.data import DataLoader
-from src.models.Transformer.model import Poetry_Model_Transformer
 
 
 def parse_arguments():
     # argument parsing
     parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
 
+    parser.add_argument('--model', type=str, default='lstm',
+                        help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
+    parser.add_argument('--Word2Vec', default=True)
+    parser.add_argument('--Augmented_dataset', default=False, help="augmented dataset")
+    parser.add_argument('--strict_dataset', default=False, help="strict dataset")
+
     parser.add_argument('--batch_size', type=int, default=64,
                         help="Specify batch size")
-
-    parser.add_argument('--initial_epochs', type=int, default=25,
-                        help="Specify the number of epochs for initial training")
-
     parser.add_argument('--num_epochs', type=int, default=50,
                         help="Specify the number of epochs for competitive search")
     parser.add_argument('--log_step', type=int, default=100,
@@ -27,40 +28,51 @@ def parse_arguments():
                         help="Learning rate")
     parser.add_argument('--data', type=str, default='data/poetry.txt',
                         help="Path to the dataset")
+    parser.add_argument('--Augmented_data', type=str, default='data/poetry_7.txt',
+                        help="Path to the Augmented_dataset")
     parser.add_argument('--n_hidden', type=int, default=128)
     parser.add_argument('--max_grad_norm', type=float, default=1.0)
 
-    parser.add_argument('--save_path', type=str, default='save_models/
-
-    parser.add_argument('--Word2Vec',type=bool, default=True)
-    parser.add_argument("--Augmented_dataset", type=bool, default=False)
+    parser.add_argument('--save_path', type=str, default='save_models/')
+
     return parser.parse_args()
 
 
 def main():
     args = parse_arguments()
+    # if you want to change the data(org data or argument data), please delete file: 'split_poetry.txt' and 'org_poetry.txt'
     if os.path.exists("data/split_poetry.txt") and os.path.exists("data/org_poetry.txt"):
         print("pre_file exit!")
     else:
-        split_text(get_poetry(args))
+        split_text(get_poetry(args))
+
     all_data, (w1, word_2_index, index_2_word) = train_vec()
-    args.word_size, args.embedding_num = w1.shape
+    args.word_size, args.embedding_num = w1.shape
 
-    dataset = Poetry_Dataset(w1, word_2_index, all_data,
-    train_size = int(len(dataset) * 0.
+    dataset = Poetry_Dataset(w1, word_2_index, all_data, args.Word2Vec)
+    train_size = int(len(dataset) * 0.8)
     test_size = len(dataset) - train_size
     train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
-
     train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
     valid_data_loader = DataLoader(test_dataset, batch_size=int(args.batch_size/4), shuffle=True)
-
-
-
-
-
+
+    if args.model == 'lstm':
+        best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+    elif args.model == 'GRU':
+        best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+    elif args.model == 'Seq2Seq':
+        best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+    elif args.model == 'Transformer':
+        best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+    elif args.model == 'GPT-2':
+        best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+    else:
+        print("Please choose a model!\n")
+
+    best_model = make_cuda(best_model)
     best_model = train(args, best_model, train_data_loader)
 
-    torch.save(best_model.state_dict(), args.save_path)
+    torch.save(best_model.state_dict(), args.save_path + args.model + '_' + str(args.num_epochs)+'.pth')
 
     print('test evaluation:')
     evaluate(args, best_model, valid_data_loader)
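The new save line composes the checkpoint name from the save directory, the model name, and the epoch count. A quick check of the resulting path:

```python
# Reproduce the checkpoint naming introduced in main(): save_path is a
# directory prefix and the file name is built from the model and epochs.
save_path, model, num_epochs = 'save_models/', 'GRU', 50
print(save_path + model + '_' + str(num_epochs) + '.pth')
# -> save_models/GRU_50.pth, matching the checkpoints uploaded in this commit
```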