zhangj726 committed
Commit 0666c69 · 1 Parent(s): e46ebb5

Upload 53 files

Files changed (42)
  1. .idea/.gitignore +3 -0
  2. .idea/.name +1 -0
  3. .idea/ea_lstm.iml +8 -0
  4. .idea/inspectionProfiles/Project_Default.xml +20 -0
  5. .idea/misc.xml +1 -1
  6. .idea/modules.xml +1 -1
  7. .idea/workspace.xml +18 -85
  8. README.md +22 -11
  9. __pycache__/inference.cpython-38.pyc +0 -0
  10. app.py +1 -1
  11. data/poetry_7.txt +0 -0
  12. data/word_vec.pkl +3 -0
  13. inference.py +18 -22
  14. save_models/GRU_25.pth +3 -0
  15. save_models/GRU_50.pth +3 -0
  16. save_models/lstm_25.pth +1 -1
  17. save_models/lstm_50.pth +3 -0
  18. src/__pycache__/__init__.cpython-38.pyc +0 -0
  19. src/__pycache__/__init__.cpython-39.pyc +0 -0
  20. src/apis/__pycache__/__init__.cpython-39.pyc +0 -0
  21. src/apis/__pycache__/inference.cpython-39.pyc +0 -0
  22. src/apis/__pycache__/train.cpython-39.pyc +0 -0
  23. src/apis/evaluate.py +23 -0
  24. src/apis/train.py +2 -2
  25. src/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  26. src/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  27. src/datasets/__pycache__/dataloader.cpython-38.pyc +0 -0
  28. src/datasets/__pycache__/dataloader.cpython-39.pyc +0 -0
  29. src/datasets/dataloader.py +13 -13
  30. src/models/LSTM/__pycache__/__init__.cpython-38.pyc +0 -0
  31. src/models/LSTM/__pycache__/__init__.cpython-39.pyc +0 -0
  32. src/models/LSTM/__pycache__/algorithm.cpython-39.pyc +0 -0
  33. src/models/LSTM/__pycache__/model.cpython-38.pyc +0 -0
  34. src/models/LSTM/__pycache__/model.cpython-39.pyc +0 -0
  35. src/models/LSTM/model.py +1 -1
  36. src/models/__pycache__/__init__.cpython-38.pyc +0 -0
  37. src/models/__pycache__/__init__.cpython-39.pyc +0 -0
  38. src/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  39. src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  40. src/utils/__pycache__/utils.cpython-38.pyc +0 -0
  41. src/utils/__pycache__/utils.cpython-39.pyc +0 -0
  42. train.py +32 -20
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
.idea/.name ADDED
@@ -0,0 +1 @@
+ inference.py
.idea/ea_lstm.iml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+ <component name="NewModuleRootManager">
+ <content url="file://$MODULE_DIR$" />
+ <orderEntry type="inheritedJdk" />
+ <orderEntry type="sourceFolder" forTests="false" />
+ </component>
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,20 @@
+ <component name="InspectionProjectProfileManager">
+ <profile version="1.0">
+ <option name="myName" value="Project Default" />
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+ <option name="ignoredPackages">
+ <value>
+ <list size="7">
+ <item index="0" class="java.lang.String" itemvalue="easydict" />
+ <item index="1" class="java.lang.String" itemvalue="pandas" />
+ <item index="2" class="java.lang.String" itemvalue="matplotlib" />
+ <item index="3" class="java.lang.String" itemvalue="pillow" />
+ <item index="4" class="java.lang.String" itemvalue="mindspore" />
+ <item index="5" class="java.lang.String" itemvalue="setuptools" />
+ <item index="6" class="java.lang.String" itemvalue="numpy" />
+ </list>
+ </value>
+ </option>
+ </inspection_tool>
+ </profile>
+ </component>
.idea/misc.xml CHANGED
@@ -1,4 +1,4 @@
  <?xml version="1.0" encoding="UTF-8"?>
  <project version="4">
- <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (pytorch)" project-jdk-type="Python SDK" />
  </project>
.idea/modules.xml CHANGED
@@ -2,7 +2,7 @@
  <project version="4">
  <component name="ProjectModuleManager">
  <modules>
- <module fileurl="file://$PROJECT_DIR$/.idea/nlp.iml" filepath="$PROJECT_DIR$/.idea/nlp.iml" />
+ <module fileurl="file://$PROJECT_DIR$/.idea/ea_lstm.iml" filepath="$PROJECT_DIR$/.idea/ea_lstm.iml" />
  </modules>
  </component>
  </project>
.idea/workspace.xml CHANGED
@@ -1,111 +1,44 @@
  <?xml version="1.0" encoding="UTF-8"?>
  <project version="4">
  <component name="ChangeListManager">
- <list default="true" id="975e88fb-d387-4d2c-9625-dc69f610d124" name="Changes" comment="">
- <change afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/.idea/modules.xml" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/.idea/nlp.iml" afterDir="false" />
- <change afterPath="$PROJECT_DIR$/.idea/vcs.xml" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/src/apis/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/apis/train.py" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/src/datasets/dataloader.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/datasets/dataloader.py" afterDir="false" />
- <change beforePath="$PROJECT_DIR$/src/models/LSTM/model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/models/LSTM/model.py" afterDir="false" />
- </list>
+ <list default="true" id="276a53df-3cdd-4e96-95d3-c1e69d4e9b9f" name="Changes" comment="" />
  <option name="SHOW_DIALOG" value="false" />
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
  <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
  <option name="LAST_RESOLUTION" value="IGNORE" />
  </component>
- <component name="Git.Settings">
- <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
- </component>
  <component name="MarkdownSettingsMigration">
  <option name="stateVersion" value="1" />
  </component>
- <component name="ProjectId" id="2Q8D9XoYiTKL5jiaHLTd3rsHf4Y" />
+ <component name="ProjectId" id="2OyFWrJQpFYHFKgf87OgmRH5Jtu" />
  <component name="ProjectViewState">
  <option name="hideEmptyMiddlePackages" value="true" />
  <option name="showLibraryContents" value="true" />
  </component>
- <component name="PropertiesComponent">{
- &quot;keyToString&quot;: {
- &quot;RunOnceActivity.OpenProjectViewOnStart&quot;: &quot;true&quot;,
- &quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
- &quot;last_opened_file_path&quot;: &quot;D:/YOU/dasanxia/NLP/new0522/nlp&quot;,
- &quot;settings.editor.selected.configurable&quot;: &quot;com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable&quot;
+ <component name="PropertiesComponent"><![CDATA[{
+ "keyToString": {
+ "RunOnceActivity.OpenProjectViewOnStart": "true",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "last_opened_file_path": "C:/Users/LENOVO/PycharmProjects/lstm"
  }
- }</component>
- <component name="RunManager" selected="Python.run_gradio">
- <configuration name="inference" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
- <module name="nlp" />
- <option name="INTERPRETER_OPTIONS" value="" />
- <option name="PARENT_ENVS" value="true" />
- <envs>
- <env name="PYTHONUNBUFFERED" value="1" />
- </envs>
- <option name="SDK_HOME" value="" />
- <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
- <option name="IS_MODULE_SDK" value="true" />
- <option name="ADD_CONTENT_ROOTS" value="true" />
- <option name="ADD_SOURCE_ROOTS" value="true" />
- <option name="SCRIPT_NAME" value="$PROJECT_DIR$/inference.py" />
- <option name="PARAMETERS" value="" />
- <option name="SHOW_COMMAND_LINE" value="false" />
- <option name="EMULATE_TERMINAL" value="false" />
- <option name="MODULE_MODE" value="false" />
- <option name="REDIRECT_INPUT" value="false" />
- <option name="INPUT_FILE" value="" />
- <method v="2" />
- </configuration>
- <configuration name="run_gradio" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
- <module name="nlp" />
- <option name="INTERPRETER_OPTIONS" value="" />
- <option name="PARENT_ENVS" value="true" />
- <envs>
- <env name="PYTHONUNBUFFERED" value="1" />
- </envs>
- <option name="SDK_HOME" value="" />
- <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
- <option name="IS_MODULE_SDK" value="true" />
- <option name="ADD_CONTENT_ROOTS" value="true" />
- <option name="ADD_SOURCE_ROOTS" value="true" />
- <option name="SCRIPT_NAME" value="$PROJECT_DIR$/run_gradio.py" />
- <option name="PARAMETERS" value="" />
- <option name="SHOW_COMMAND_LINE" value="false" />
- <option name="EMULATE_TERMINAL" value="false" />
- <option name="MODULE_MODE" value="false" />
- <option name="REDIRECT_INPUT" value="false" />
- <option name="INPUT_FILE" value="" />
- <method v="2" />
- </configuration>
- <recent_temporary>
- <list>
- <item itemvalue="Python.run_gradio" />
- <item itemvalue="Python.inference" />
- </list>
- </recent_temporary>
- </component>
+ }]]></component>
  <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
  <component name="TaskManager">
  <task active="true" id="Default" summary="Default task">
- <changelist id="975e88fb-d387-4d2c-9625-dc69f610d124" name="Changes" comment="" />
- <created>1684726163448</created>
+ <changelist id="276a53df-3cdd-4e96-95d3-c1e69d4e9b9f" name="Changes" comment="" />
+ <created>1682524950142</created>
  <option name="number" value="Default" />
  <option name="presentableId" value="Default" />
- <updated>1684726163448</updated>
+ <updated>1682524950142</updated>
  </task>
  <servers />
  </component>
- <component name="Vcs.Log.Tabs.Properties">
- <option name="TAB_STATES">
- <map>
- <entry key="MAIN">
- <value>
- <State />
- </value>
- </entry>
- </map>
- </option>
+ <component name="XDebuggerManager">
+ <watches-manager>
+ <configuration name="PythonConfigurationType">
+ <watch expression="input_eval" />
+ <watch expression="word_2_index" />
+ </configuration>
+ </watches-manager>
  </component>
  </project>
README.md CHANGED
@@ -1,12 +1,23 @@
- ---
- title: Poem Generation
- emoji: 👁
- colorFrom: gray
- colorTo: red
- sdk: gradio
- sdk_version: 3.32.0
- app_file: app.py
- pinned: false
- ---
+ # NLP Final Project
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ```shell
+ ├── configs
+ ├── data
+ │   └── poetry.txt
+ ├── inference.py
+ ├── src
+ │   ├── apis
+ │   │   ├── evaluate.py
+ │   │   ├── inference.py
+ │   │   └── train.py
+ │   ├── datasets
+ │   │   └── dataloader.py
+ │   ├── models
+ │   │   └── EA-LSTM
+ │   │       ├── algorithm.py
+ │   │       └── model.py
+ │   └── utils
+ │       └── utils.py
+ ├── test.py
+ └── train.py
+ ```
__pycache__/inference.cpython-38.pyc ADDED
Binary file (2.88 kB)
 
app.py CHANGED
@@ -7,7 +7,7 @@ from inference import infer
 
 
 
- INTERFACE = gradio.Interface(fn=infer, inputs=[gradio.Radio(["lstm","GRU","Seq2Seq","Transformer","GPT-2"]),"text"], outputs=["text"], title="Poetry Generation",
+ INTERFACE = gradio.Interface(fn=infer, inputs=[gradio.Radio(["lstm","GRU"]),"text"], outputs=["text"], title="Poetry Generation",
  description="Choose a model and input the poetic head to generate a acrostic",
  thumbnail="https://github.com/gradio-app/gpt-2/raw/master/screenshots/interface.png?raw=true")
 
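The hunk above narrows the Gradio radio group to the two model choices this commit actually ships checkpoints for (lstm and GRU). As a rough sketch of how the pieces fit together, the snippet below wires a stand-in callback into the same kind of Interface; the placeholder infer() and the launch() call are assumptions, since neither appears in this hunk.

```python
# Minimal sketch, assuming app.py serves the interface with a standard launch().
# The infer() below is a placeholder for the real callback imported from inference.py.
import gradio

def infer(model: str, poem_head: str) -> str:
    # Stand-in: the real infer() loads a checkpoint and generates an acrostic poem.
    return f"[{model}] acrostic for: {poem_head}"

INTERFACE = gradio.Interface(
    fn=infer,
    inputs=[gradio.Radio(["lstm", "GRU"]), "text"],
    outputs=["text"],
    title="Poetry Generation",
    description="Choose a model and input the poetic head to generate an acrostic",
)

if __name__ == "__main__":
    INTERFACE.launch()  # assumption: launch call is not shown in the diff
```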
 
data/poetry_7.txt ADDED
The diff for this file is too large to render.
 
data/word_vec.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1164cfc2e28ef6ecbb1a04734e7268238b4841667f13d6cb4c42e27717dd4575
+ size 6339344
inference.py CHANGED
@@ -1,10 +1,10 @@
  import torch
  import argparse
  import numpy as np
+ from src.models.LSTM.model import Poetry_Model_lstm
  from src.datasets.dataloader import train_vec
  from src.utils.utils import make_cuda
- from src.models.Transformer.model import Poetry_Model_Transformer
- from src.models.LSTM.model import Poetry_Model_lstm
+ 
 
  def parse_arguments():
  # argument parsing
@@ -15,12 +15,12 @@ def parse_arguments():
  parser.add_argument('--strict_dataset', default=False, help="strict dataset")
  parser.add_argument('--n_hidden', type=int, default=128)
 
- parser.add_argument('--save_path', type=str, default='save_models/model_params.pth')
+ parser.add_argument('--save_path', type=str, default='save_models/lstm_50.pth')
 
  return parser.parse_args()
 
 
- def generate_poetry(model, head_string, w1, word_2_index, index_2_word,args):
+ def generate_poetry(model, head_string, w1, word_2_index, index_2_word):
  print("藏头诗生成中...., {}".format(head_string))
  poem = ""
  # 以句子的每一个字为开头生成诗句
@@ -54,33 +54,31 @@ def generate_poetry(model, head_string, w1, word_2_index, index_2_word,args):
 
  return poem
 
- def infer(model,poem_head):
+ def infer(model,string):
  args = parse_arguments()
- args.model=model
- 
  all_data, (w1, word_2_index, index_2_word) = train_vec()
  args.word_size, args.embedding_num = w1.shape
- string = poem_head
+ # string = input("诗头:")
  # string = '自然语言'
- 
+ args.model=model
  if args.model == 'lstm':
  model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
- args.save_path='save_models/lstm_25.pth'
+ args.save_path = 'save_models/lstm_50.pth'
  elif args.model == 'GRU':
  model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+ args.save_path = 'save_models/GRU_50.pth'
  elif args.model == 'Seq2Seq':
  model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
  elif args.model == 'Transformer':
- model = Poetry_Model_Transformer(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
- args.save_path='save_models/transformer_100.pth'
+ model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
  elif args.model == 'GPT-2':
  model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
  else:
  print("Please choose a model!\n")
 
- model.load_state_dict(torch.load(args.save_path,map_location=torch.device('cpu')))
+ model.load_state_dict(torch.load(args.save_path))
  model = make_cuda(model)
- poem = generate_poetry(model, string, w1, word_2_index, index_2_word,args)
+ poem = generate_poetry(model, string, w1, word_2_index, index_2_word)
  return poem
 
 
@@ -88,25 +86,23 @@ if __name__ == '__main__':
  args = parse_arguments()
  all_data, (w1, word_2_index, index_2_word) = train_vec()
  args.word_size, args.embedding_num = w1.shape
- string = input("诗头:")
- # string = '自然语言'
- 
+ # string = input("诗头:")
+ string = '自然语言'
+ 
  if args.model == 'lstm':
  model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
- args.save_path='save_models/lstm_25.pth'
  elif args.model == 'GRU':
  model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
  elif args.model == 'Seq2Seq':
  model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
  elif args.model == 'Transformer':
- model = Poetry_Model_Transformer(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
- args.save_path='save_models/transformer_100.pth'
+ model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
  elif args.model == 'GPT-2':
  model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
  else:
  print("Please choose a model!\n")
 
- model.load_state_dict(torch.load(args.save_path,map_location=torch.device('cpu')))
+ model.load_state_dict(torch.load(args.save_path))
  model = make_cuda(model)
- poem = generate_poetry(model, string, w1, word_2_index, index_2_word,args)
+ poem = generate_poetry(model, string, w1, word_2_index, index_2_word)
  print(poem)
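The rewritten infer() now maps the radio choice onto a concrete checkpoint (lstm_50.pth or GRU_50.pth) before loading weights, and generate_poetry() no longer takes args. The sketch below condenses that selection logic into a lookup table; the helper name load_generator, the fallback to lstm for the unhandled choices, and the map_location="cpu" argument are assumptions for illustration, not code from the diff.

```python
# Sketch of the checkpoint selection introduced in this diff, written as a lookup
# table. Imports mirror the new import block above; load_generator, the lstm
# fallback, and map_location="cpu" are assumptions rather than repository code.
import torch

from src.datasets.dataloader import train_vec
from src.models.LSTM.model import Poetry_Model_lstm
from src.utils.utils import make_cuda

CHECKPOINTS = {
    "lstm": "save_models/lstm_50.pth",
    "GRU": "save_models/GRU_50.pth",
}

def load_generator(model_name, n_hidden=128, word2vec=True):
    """Build the Poetry_Model_lstm backbone and load the matching checkpoint."""
    _, (w1, word_2_index, index_2_word) = train_vec()
    word_size, embedding_num = w1.shape
    model = Poetry_Model_lstm(n_hidden, word_size, embedding_num, word2vec)
    path = CHECKPOINTS.get(model_name, CHECKPOINTS["lstm"])
    model.load_state_dict(torch.load(path, map_location="cpu"))
    return make_cuda(model), (w1, word_2_index, index_2_word)
```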
save_models/GRU_25.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bacf9a7ec329c6185098c1309ab28239b4c087b53832b3d18e5323831bfead23
+ size 10727391
save_models/GRU_50.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a8e83a733c023b35c44020e014bb72e2c1d05698eb782669c0e4d5a76d4590d
+ size 10727391
save_models/lstm_25.pth CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:1501b5e3e6d9aa864857c8036f27d09c2489da832e616916b9633092b0ed3df5
+ oid sha256:2b064666ce02c63541dee4b6146d31ee8f7e784ee9c2811c9b9266aba6cc4193
  size 10727391
save_models/lstm_50.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa157d970149c32b53b024a23ef8428e7b7e1702ed72d44152b568b085b1bfaa
+ size 10727391
src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (166 Bytes)
 
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes)
 
src/apis/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (151 Bytes)
 
src/apis/__pycache__/inference.cpython-39.pyc ADDED
Binary file (1.44 kB)
 
src/apis/__pycache__/train.cpython-39.pyc ADDED
Binary file (1.68 kB)
 
src/apis/evaluate.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ import numpy as np
+ from src.models.EA_LSTM.model import weightedLSTM
+ from src.datasets.dataloader import MyDataset, create_vocab
+
+
+ def test(args):
+     vocab, poetrys = create_vocab(args.data)
+     # 词汇表长度
+     args.vocab_size = len(vocab)
+     int2char = np.array(vocab)
+     valid_dataset = MyDataset(vocab, poetrys, args, train=False)
+
+     model = weightedLSTM(6110, 256, 128, 2, [1.0] * 80, False)
+     model.load_state_dict(torch.load(args.save_path))
+
+     input_example_batch, target_example_batch = valid_dataset[0]
+     example_batch_predictions = model(input_example_batch)
+     predicted_id = torch.distributions.Categorical(example_batch_predictions).sample()
+     predicted_id = torch.squeeze(predicted_id, -1).numpy()
+     print("Input: \n", repr("".join(int2char[input_example_batch])))
+     print()
+     print("Predictions: \n", repr("".join(int2char[predicted_id])))
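The core of the new test() routine is its sampling step: one token id is drawn per position from a categorical distribution over the model output, then mapped back to characters through int2char. The snippet below reproduces just that step on a toy vocabulary and random probabilities so it runs without the EA-LSTM checkpoint; every name and value in it is a stand-in.

```python
# Standalone sketch of the sampling step in test(): draw one id per position from
# a categorical distribution and join the corresponding characters. The toy
# vocabulary and the random "predictions" are stand-ins for real model output.
import numpy as np
import torch

int2char = np.array(list("春江花月夜"))  # toy vocabulary of 5 characters
seq_len, vocab_size = 4, len(int2char)

# Stand-in for model(input_example_batch): one probability row per position.
example_batch_predictions = torch.softmax(torch.randn(seq_len, vocab_size), dim=-1)

predicted_id = torch.distributions.Categorical(example_batch_predictions).sample()
predicted_id = torch.squeeze(predicted_id, -1).numpy()

print("Predictions: \n", repr("".join(int2char[predicted_id])))
```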
src/apis/train.py CHANGED
@@ -8,11 +8,11 @@ from torch.nn import functional as F
  from sklearn.metrics import mean_squared_error, mean_absolute_error
 
 
- def train(args, model, data_loader):
+ def train(args, model, data_loader, initial=False):
  optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
 
  model.train()
- num_epochs = args.num_epochs
+ num_epochs = args.initial_epochs if initial else args.num_epochs
 
  for epoch in range(num_epochs):
  loss = 0
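With the new initial flag, the same train() routine can serve both a short warm-up phase and the full run by switching the epoch budget. The sketch below shows the intended call pattern; note that it relies on args.initial_epochs being supplied by the caller, and the root train.py hunk later in this commit removes the --initial_epochs argument, so the warm-up path would need that attribute set some other way.

```python
# Sketch of the two-phase pattern enabled by the new `initial` flag. The loop body
# is not part of this hunk, so only the skeleton is shown; argparse.Namespace and
# the toy nn.Linear model are stand-ins for the real args object and poetry model.
from argparse import Namespace

import torch.nn as nn
import torch.optim as optim

def train(args, model, data_loader, initial=False):
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    model.train()
    num_epochs = args.initial_epochs if initial else args.num_epochs
    for epoch in range(num_epochs):
        loss = 0
        # per-batch forward/backward pass elided in this hunk
    return model

args = Namespace(learning_rate=1e-3, initial_epochs=2, num_epochs=5)
model = nn.Linear(4, 4)                                     # toy stand-in model
model = train(args, model, data_loader=[], initial=True)    # warm-up: 2 epochs
model = train(args, model, data_loader=[])                  # full run: 5 epochs
```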
src/datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (175 Bytes)
 
src/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (155 Bytes)
 
src/datasets/__pycache__/dataloader.cpython-38.pyc ADDED
Binary file (4.09 kB)
 
src/datasets/__pycache__/dataloader.cpython-39.pyc ADDED
Binary file (4.12 kB)
 
src/datasets/dataloader.py CHANGED
@@ -25,27 +25,27 @@ def get_poetry(arg):
  if arg.Augmented_dataset:
  path = arg.Augmented_data
  else:
- path = arg.data # 数据集路径,data/poetry.txt
+ path = arg.data
  with open(path, "r", encoding='UTF-8') as f:
  for line in f:
  try:
  # line = line.decode('UTF-8')
- line = line.strip(u'\n') # 去掉换行符
+ line = line.strip(u'\n')
  if arg.Augmented_dataset:
  content = line.strip(u' ')
  else:
- title, content = line.strip(u' ').split(u':') # 标题和内容以冒号分隔
- content = content.replace(u' ', u'') # 去掉空格
- if u'_' in content or u'(' in content or u'(' in content or u'《' in content or u'[' in content: # 去掉特殊符号的古诗
+ title, content = line.strip(u' ').split(u':')
+ content = content.replace(u' ', u'')
+ if u'_' in content or u'(' in content or u'(' in content or u'《' in content or u'[' in content:
  continue
- if arg.strict_dataset: # 严格模式
+ if arg.strict_dataset:
  if len(content) < 12 or len(content) > 79:
  continue
  else:
  if len(content) < 5 or len(content) > 79:
  continue
- content = u'[' + content + u']' # 开头加上开始符,结尾加上结束符
- poetrys.append(content) # 保存到poetrys列表中
+ content = u'[' + content + u']'
+ poetrys.append(content)
  except Exception as e:
  pass
 
@@ -82,9 +82,9 @@ def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"
  if os.path.exists(param_file):
  return org_data, pickle.load(open(param_file, "rb"))
 
- models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1) # 训练词向量,输入参数分别是:分词后的文本,词向量维度,线程数,最小词频
- pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb")) # 保存词向量,key_to_index是词汇表,index_to_key是词向量,dump的作用是将数据序列化到文件中
- return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key) # syn1neg是词向量,key_to_index是词汇表,index_to_key是词向量
+ models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
+ pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
+ return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
 
 
  class Poetry_Dataset(Dataset):
@@ -93,11 +93,11 @@ class Poetry_Dataset(Dataset):
  self.w1 = w1
  self.word_2_index = word_2_index
  word_size, embedding_num = w1.shape
- self.embedding = nn.Embedding(word_size, embedding_num) # 词嵌入层
+ self.embedding = nn.Embedding(word_size, embedding_num)
  # 最长句子长度
  maxlen = max([len(seq) for seq in all_data])
  pad = ' '
- self.all_data = padding(all_data[:-1], maxlen, pad)
+ self.all_data = padding(all_data[:-1], maxlen, pad)
 
  def __getitem__(self, index):
  a_poetry = self.all_data[index]
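Apart from stripping inline comments, the behaviour of train_vec() is unchanged: train a gensim Word2Vec model once, pickle its output weights plus both vocabulary mappings, and reuse the pickle on later runs. Below is a self-contained sketch of that caching pattern; the function name, cache filename, and toy corpus are placeholders, and treating each poem as a list of single characters is an assumption about what split_text() produces.

```python
# Self-contained sketch of the Word2Vec caching done in train_vec(): train once,
# pickle (syn1neg weights, key_to_index, index_to_key), reload on later runs.
# Function name, cache path, and the two-line toy corpus are placeholders.
import os
import pickle

from gensim.models import Word2Vec

def load_or_train_vectors(corpus, cache_file="word_vec_demo.pkl"):
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)
    model = Word2Vec(corpus, vector_size=256, workers=7, min_count=1)
    params = [model.syn1neg, model.wv.key_to_index, model.wv.index_to_key]
    with open(cache_file, "wb") as f:
        pickle.dump(params, f)
    return params

corpus = [list("春眠不觉晓"), list("处处闻啼鸟")]  # assumed character-level split
w1, word_2_index, index_2_word = load_or_train_vectors(corpus)
print(w1.shape)  # (vocab_size, 256)
```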
src/models/LSTM/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (178 Bytes)
 
src/models/LSTM/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (161 Bytes)
 
src/models/LSTM/__pycache__/algorithm.cpython-39.pyc ADDED
Binary file (4.99 kB)
 
src/models/LSTM/__pycache__/model.cpython-38.pyc ADDED
Binary file (1.58 kB)
 
src/models/LSTM/__pycache__/model.cpython-39.pyc ADDED
Binary file (1.55 kB)
 
src/models/LSTM/model.py CHANGED
@@ -20,7 +20,7 @@ class Poetry_Model_lstm(nn.Module):
  self.cross_entropy = nn.CrossEntropyLoss()
 
  def forward(self, xs_embedding, h_0=None, c_0=None):
- # xs_embedding: [batch_size, max_seq_len, n_feature] n_feature=256
+ # xs_embedding: [batch_size, max_seq_len, n_feature] n_feature=128
  if h_0 == None or c_0 == None:
  h_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=np.float32))
  c_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=np.float32))
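Only the shape comment changes here; the forward pass still builds zero initial states when none are provided. The sketch below checks the tensor shapes involved, assuming the underlying module is a 2-layer, unidirectional, batch-first nn.LSTM (which is what the (2, batch, hidden) state shape suggests) and 256-dimensional embeddings as produced by the Word2Vec step; neither assumption is stated in this hunk.

```python
# Shape check for the zero-state initialisation shown above, under the assumption
# of a 2-layer unidirectional batch-first nn.LSTM and 256-dimensional embeddings.
import numpy as np
import torch
import torch.nn as nn

batch_size, seq_len, embedding_num, hidden_num = 4, 7, 256, 128

lstm = nn.LSTM(input_size=embedding_num, hidden_size=hidden_num,
               num_layers=2, batch_first=True)

xs_embedding = torch.randn(batch_size, seq_len, embedding_num)
h_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], hidden_num), dtype=np.float32))
c_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], hidden_num), dtype=np.float32))

hidden, (h_n, c_n) = lstm(xs_embedding, (h_0, c_0))
print(hidden.shape)  # torch.Size([4, 7, 128])
```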
src/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes)
 
src/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (153 Bytes)
 
src/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (172 Bytes)
 
src/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (152 Bytes)
 
src/utils/__pycache__/utils.cpython-38.pyc ADDED
Binary file (575 Bytes)
 
src/utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (555 Bytes)
 
train.py CHANGED
@@ -6,19 +6,20 @@ import torch
  import os
  from src.datasets.dataloader import Poetry_Dataset, train_vec, get_poetry, split_text
  from torch.utils.data import DataLoader
- from src.models.Transformer.model import Poetry_Model_Transformer
 
 
  def parse_arguments():
  # argument parsing
  parser = argparse.ArgumentParser(description="Specify Params for Experimental Setting")
 
+ parser.add_argument('--model', type=str, default='lstm',
+ help="lstm/GRU/Seq2Seq/Transformer/GPT-2")
+ parser.add_argument('--Word2Vec', default=True)
+ parser.add_argument('--Augmented_dataset', default=False, help="augmented dataset")
+ parser.add_argument('--strict_dataset', default=False, help="strict dataset")
+ 
  parser.add_argument('--batch_size', type=int, default=64,
  help="Specify batch size")
- 
- parser.add_argument('--initial_epochs', type=int, default=25,
- help="Specify the number of epochs for initial training")
- 
  parser.add_argument('--num_epochs', type=int, default=50,
  help="Specify the number of epochs for competitive search")
  parser.add_argument('--log_step', type=int, default=100,
@@ -27,40 +28,51 @@
  help="Learning rate")
  parser.add_argument('--data', type=str, default='data/poetry.txt',
  help="Path to the dataset")
+ parser.add_argument('--Augmented_data', type=str, default='data/poetry_7.txt',
+ help="Path to the Augmented_dataset")
  parser.add_argument('--n_hidden', type=int, default=128)
  parser.add_argument('--max_grad_norm', type=float, default=1.0)
 
- parser.add_argument('--save_path', type=str, default='save_models/transformer.pth')
- parser.add_argument('--strict_dataset', default=False, help="strict dataset")
- parser.add_argument('--Word2Vec',type=bool, default=True)
- parser.add_argument("--Augmented_dataset", type=bool, default=False)
+ parser.add_argument('--save_path', type=str, default='save_models/')
+ 
  return parser.parse_args()
 
 
  def main():
  args = parse_arguments()
+ # if you want to change the data(org data or argument data), please delete file: 'split_poetry.txt' and 'org_poetry.txt'
  if os.path.exists("data/split_poetry.txt") and os.path.exists("data/org_poetry.txt"):
  print("pre_file exit!")
  else:
- split_text(get_poetry(args)) # split poetry
+ split_text(get_poetry(args))
+ 
  all_data, (w1, word_2_index, index_2_word) = train_vec()
- args.word_size, args.embedding_num = w1.shape # 词向量的维度
+ args.word_size, args.embedding_num = w1.shape
 
- dataset = Poetry_Dataset(w1, word_2_index, all_data, Word2Vec=args.Word2Vec)
- train_size = int(len(dataset) * 0.7)
+ dataset = Poetry_Dataset(w1, word_2_index, all_data, args.Word2Vec)
+ train_size = int(len(dataset) * 0.8)
  test_size = len(dataset) - train_size
  train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
- 
  train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
  valid_data_loader = DataLoader(test_dataset, batch_size=int(args.batch_size/4), shuffle=True)
- 
- # best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num,args.Word2Vec)
- best_model = Poetry_Model_Transformer(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
- best_model = make_cuda(best_model) # use gpu
- print("Initial training before competitive random search")
+ 
+ if args.model == 'lstm':
+ best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+ elif args.model == 'GRU':
+ best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+ elif args.model == 'Seq2Seq':
+ best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+ elif args.model == 'Transformer':
+ best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+ elif args.model == 'GPT-2':
+ best_model = Poetry_Model_lstm(args.n_hidden, args.word_size, args.embedding_num, args.Word2Vec)
+ else:
+ print("Please choose a model!\n")
+ 
+ best_model = make_cuda(best_model)
  best_model = train(args, best_model, train_data_loader)
 
- torch.save(best_model.state_dict(), args.save_path)
+ torch.save(best_model.state_dict(), args.save_path + args.model + '_' + str(args.num_epochs)+'.pth')
 
  print('test evaluation:')
  evaluate(args, best_model, valid_data_loader)
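The rewritten main() now picks the architecture from --model and derives the checkpoint filename from the save directory, model name, and epoch count, which matches how the lstm_50.pth and GRU_50.pth files elsewhere in this commit are named. A quick sketch of that naming rule, extracted into a hypothetical helper for illustration:

```python
# Checkpoint naming used in the torch.save(...) call above:
# save_path + model + '_' + str(num_epochs) + '.pth'
def checkpoint_path(save_path, model, num_epochs):
    return save_path + model + '_' + str(num_epochs) + '.pth'

print(checkpoint_path('save_models/', 'lstm', 50))  # save_models/lstm_50.pth
print(checkpoint_path('save_models/', 'GRU', 50))   # save_models/GRU_50.pth
```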