Upload 31 files
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-38.pyc +0 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/apis/__init__.py +0 -0
- src/apis/__pycache__/__init__.cpython-39.pyc +0 -0
- src/apis/__pycache__/inference.cpython-39.pyc +0 -0
- src/apis/__pycache__/train.cpython-39.pyc +0 -0
- src/apis/evaluate.py +23 -0
- src/apis/train.py +68 -0
- src/datasets/__init__.py +0 -0
- src/datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- src/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- src/datasets/__pycache__/dataloader.cpython-38.pyc +0 -0
- src/datasets/__pycache__/dataloader.cpython-39.pyc +0 -0
- src/datasets/dataloader.py +115 -0
- src/models/LSTM/__init__.py +0 -0
- src/models/LSTM/__pycache__/__init__.cpython-38.pyc +0 -0
- src/models/LSTM/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/LSTM/__pycache__/algorithm.cpython-39.pyc +0 -0
- src/models/LSTM/__pycache__/model.cpython-38.pyc +0 -0
- src/models/LSTM/__pycache__/model.cpython-39.pyc +0 -0
- src/models/LSTM/model.py +37 -0
- src/models/__init__.py +0 -0
- src/models/__pycache__/__init__.cpython-38.pyc +0 -0
- src/models/__pycache__/__init__.cpython-39.pyc +0 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- src/utils/__pycache__/utils.cpython-38.pyc +0 -0
- src/utils/__pycache__/utils.cpython-39.pyc +0 -0
- src/utils/utils.py +15 -0
src/__init__.py
ADDED
File without changes
src/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (166 Bytes)
src/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (146 Bytes)
src/apis/__init__.py
ADDED
File without changes
src/apis/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (151 Bytes)
src/apis/__pycache__/inference.cpython-39.pyc
ADDED
Binary file (1.44 kB)
src/apis/__pycache__/train.cpython-39.pyc
ADDED
Binary file (1.68 kB)
src/apis/evaluate.py
ADDED
@@ -0,0 +1,23 @@
+import torch
+import numpy as np
+from src.models.EA_LSTM.model import weightedLSTM
+from src.datasets.dataloader import MyDataset, create_vocab
+
+
+def test(args):
+    vocab, poetrys = create_vocab(args.data)
+    # vocabulary size
+    args.vocab_size = len(vocab)
+    int2char = np.array(vocab)
+    valid_dataset = MyDataset(vocab, poetrys, args, train=False)
+
+    model = weightedLSTM(6110, 256, 128, 2, [1.0] * 80, False)
+    model.load_state_dict(torch.load(args.save_path))
+
+    input_example_batch, target_example_batch = valid_dataset[0]
+    example_batch_predictions = model(input_example_batch)
+    predicted_id = torch.distributions.Categorical(example_batch_predictions).sample()
+    predicted_id = torch.squeeze(predicted_id, -1).numpy()
+    print("Input: \n", repr("".join(int2char[input_example_batch])))
+    print()
+    print("Predictions: \n", repr("".join(int2char[predicted_id])))
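For reference, a minimal sketch of how this test entry point might be driven. The flag names and default paths below are assumptions, and note that weightedLSTM, MyDataset, and create_vocab are imported from modules that are not part of this upload:

# Hypothetical driver for test(); --data and --save_path defaults are placeholders.
import argparse
from src.apis.evaluate import test

parser = argparse.ArgumentParser()
parser.add_argument("--data", default="data/poetry.txt")            # raw poem corpus
parser.add_argument("--save_path", default="checkpoints/model.pt")  # trained weights
args = parser.parse_args()

test(args)  # prints one input sample and the characters sampled from the model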
src/apis/train.py
ADDED
@@ -0,0 +1,68 @@
+import math
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.optim as optim
+from src.utils.utils import make_cuda
+from torch.nn import functional as F
+from sklearn.metrics import mean_squared_error, mean_absolute_error
+
+
+def train(args, model, data_loader, initial=False):
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    model.train()
+    num_epochs = args.initial_epochs if initial else args.num_epochs
+
+    for epoch in range(num_epochs):
+        loss = 0
+        for step, (features, targets) in enumerate(data_loader):
+            features = make_cuda(features)
+            targets = make_cuda(targets)
+
+            optimizer.zero_grad()
+
+            pre, _ = model(features)
+            crs_loss = model.cross_entropy(pre, targets.reshape(-1))
+            loss += crs_loss.item()
+            crs_loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+            optimizer.step()
+
+            # print step info
+            if (step + 1) % args.log_step == 0:
+                print("Epoch [%.3d/%.3d] Step [%.3d/%.3d]: CROSS_loss=%.4f, RCROSS_loss=%.4f"
+                      % (epoch + 1,
+                         num_epochs,
+                         step + 1,
+                         len(data_loader),
+                         loss / args.log_step,
+                         math.sqrt(loss / args.log_step)))
+                loss = 0
+
+    # Loss = []
+    # for step, (features, targets) in enumerate(valid_data_loader):
+    #     features = make_cuda(features)
+    #     targets = make_cuda(targets)
+    #     model.eval()
+    #     preds = model(features)
+    #     valid_loss = CrossLoss(preds, targets)
+    #     Loss.append(valid_loss)
+    # print("Valid loss: %.3d\n" % (np.mean(Loss)))
+
+    return model
+
+
+def evaluate(args, model, data_loader):
+    model.eval()
+    loss = []
+    for step, (features, targets) in enumerate(data_loader):
+        features = make_cuda(features)
+        targets = make_cuda(targets)
+
+        pre, _ = model(features)
+        crs_loss = model.cross_entropy(pre, targets.reshape(-1))
+        loss.append(crs_loss.item())
+        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+    print("loss=%.4f" % (np.mean(loss)))
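A hedged sketch of how train() and evaluate() might be wired up; the dataset, batch size, and args fields shown are assumptions inferred from what the functions read:

# Hypothetical wiring; `dataset` is any Dataset yielding (features, targets)
# and `model` exposes .cross_entropy, as Poetry_Model_lstm in this commit does.
from torch.utils.data import DataLoader
from src.apis.train import train, evaluate

loader = DataLoader(dataset, batch_size=64, shuffle=True)
# train() reads args.learning_rate, args.num_epochs / args.initial_epochs,
# args.log_step, and args.max_grad_norm from the namespace.
model = train(args, model, loader)
evaluate(args, model, loader)  # prints mean cross-entropy over the loader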
src/datasets/__init__.py
ADDED
File without changes
src/datasets/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (175 Bytes)
src/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (155 Bytes)
src/datasets/__pycache__/dataloader.cpython-38.pyc
ADDED
Binary file (4.09 kB)
src/datasets/__pycache__/dataloader.cpython-39.pyc
ADDED
Binary file (4.12 kB)
src/datasets/dataloader.py
ADDED
@@ -0,0 +1,115 @@
+import numpy as np
+import pickle
+import os
+import torch
+import torch.nn as nn
+from gensim.models.word2vec import Word2Vec
+from torch.utils.data import Dataset
+
+
+def padding(poetries, maxlen, pad):
+    batch_seq = [poetry + pad * (maxlen - len(poetry)) for poetry in poetries]
+    return batch_seq
+
+
+# Shift the input one character ahead to form the target, i.e. predict the next character
+def split_input_target(seq):
+    inputs = seq[:-1]
+    targets = seq[1:]
+    return inputs, targets
+
+
+# Build the vocabulary
+def get_poetry(arg):
+    poetrys = []
+    if arg.Augmented_dataset:
+        path = arg.Augmented_data
+    else:
+        path = arg.data
+    with open(path, "r", encoding='UTF-8') as f:
+        for line in f:
+            try:
+                # line = line.decode('UTF-8')
+                line = line.strip(u'\n')
+                if arg.Augmented_dataset:
+                    content = line.strip(u' ')
+                else:
+                    title, content = line.strip(u' ').split(u':')
+                content = content.replace(u' ', u'')
+                if u'_' in content or u'(' in content or u'（' in content or u'《' in content or u'[' in content:
+                    continue
+                if arg.strict_dataset:
+                    if len(content) < 12 or len(content) > 79:
+                        continue
+                else:
+                    if len(content) < 5 or len(content) > 79:
+                        continue
+                content = u'[' + content + u']'
+                poetrys.append(content)
+            except Exception as e:
+                pass
+
+    # Sort poems by number of characters
+    poetrys = sorted(poetrys, key=lambda line: len(line))
+
+    with open("data/org_poetry.txt", "w", encoding="utf-8") as f:
+        for poetry in poetrys:
+            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n'
+            f.write(poetry)
+
+    return poetrys
+
+
+# Split the corpus into space-separated characters
+def split_text(poetrys):
+    with open("data/split_poetry.txt", "w", encoding="utf-8") as f:
+        for poetry in poetrys:
+            poetry = str(poetry).strip('[').strip(']').replace(',', '').replace('\'', '') + '\n '
+            split_data = " ".join(poetry)
+            f.write(split_data)
+    return open("data/split_poetry.txt", "r", encoding='UTF-8').read()
+
+
+# Train the word vectors
+def train_vec(split_file="data/split_poetry.txt", org_file="data/org_poetry.txt"):
+    param_file = "data/word_vec.pkl"
+    org_data = open(org_file, "r", encoding="utf-8").read().split("\n")
+    if os.path.exists(split_file):
+        all_data_split = open(split_file, "r", encoding="utf-8").read().split("\n")
+    else:
+        all_data_split = split_text().split("\n")
+
+    if os.path.exists(param_file):
+        return org_data, pickle.load(open(param_file, "rb"))
+
+    models = Word2Vec(all_data_split, vector_size=256, workers=7, min_count=1)
+    pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, "wb"))
+    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)
+
+
+class Poetry_Dataset(Dataset):
+    def __init__(self, w1, word_2_index, all_data, Word2Vec):
+        self.Word2Vec = Word2Vec
+        self.w1 = w1
+        self.word_2_index = word_2_index
+        word_size, embedding_num = w1.shape
+        self.embedding = nn.Embedding(word_size, embedding_num)
+        # Longest sentence length
+        maxlen = max([len(seq) for seq in all_data])
+        pad = ' '
+        self.all_data = padding(all_data[:-1], maxlen, pad)
+
+    def __getitem__(self, index):
+        a_poetry = self.all_data[index]
+
+        a_poetry_index = [self.word_2_index[i] for i in a_poetry]
+        xs, ys = split_input_target(a_poetry_index)
+        if self.Word2Vec:
+            xs_embedding = self.w1[xs]
+        else:
+            xs_embedding = np.array(xs)
+
+        return xs_embedding, np.array(ys).astype(np.int64)
+
+    def __len__(self):
+        return len(self.all_data)
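A minimal sketch of the intended preprocessing pipeline, under a few assumptions: args is the same hypothetical argparse namespace as above with the fields get_poetry() reads, and split_text() is called explicitly first because train_vec()'s fallback branch invokes it without the poem list it requires:

# Hypothetical end-to-end wiring of this module's helpers.
from torch.utils.data import DataLoader
from src.datasets.dataloader import get_poetry, split_text, train_vec, Poetry_Dataset

poetrys = get_poetry(args)     # filters poems, writes data/org_poetry.txt
split_text(poetrys)            # writes the space-separated data/split_poetry.txt
org_data, (w1, word_2_index, index_2_word) = train_vec()
dataset = Poetry_Dataset(w1, word_2_index, org_data, Word2Vec=True)
loader = DataLoader(dataset, batch_size=32, shuffle=False)
xs_embedding, ys = dataset[0]  # per sample: embedded inputs, next-char targets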
src/models/LSTM/__init__.py
ADDED
File without changes
src/models/LSTM/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (178 Bytes)
src/models/LSTM/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (161 Bytes)
src/models/LSTM/__pycache__/algorithm.cpython-39.pyc
ADDED
Binary file (4.99 kB)
src/models/LSTM/__pycache__/model.cpython-38.pyc
ADDED
Binary file (1.58 kB)
src/models/LSTM/__pycache__/model.cpython-39.pyc
ADDED
Binary file (1.55 kB)
src/models/LSTM/model.py
ADDED
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+import torch.nn as nn
+
+
+class Poetry_Model_lstm(nn.Module):
+    def __init__(self, hidden_num, word_size, embedding_num, Word2Vec):
+        super().__init__()
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.hidden_num = hidden_num
+        self.Word2Vec = Word2Vec
+
+        self.embedding = nn.Embedding(word_size, embedding_num)
+        self.lstm = nn.LSTM(input_size=embedding_num, hidden_size=hidden_num, batch_first=True, num_layers=2,
+                            bidirectional=False)
+        self.dropout = nn.Dropout(0.3)
+        self.flatten = nn.Flatten(0, 1)
+        self.linear = nn.Linear(hidden_num, word_size)
+        self.cross_entropy = nn.CrossEntropyLoss()
+
+    def forward(self, xs_embedding, h_0=None, c_0=None):
+        # xs_embedding: [batch_size, max_seq_len, n_feature] n_feature=128
+        if h_0 is None or c_0 is None:
+            h_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=np.float32))
+            c_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=np.float32))
+        h_0 = h_0.to(self.device)
+        c_0 = c_0.to(self.device)
+        xs_embedding = xs_embedding.to(self.device)
+        if not self.Word2Vec:
+            xs_embedding = self.embedding(xs_embedding)
+        hidden, (h_0, c_0) = self.lstm(xs_embedding, (h_0, c_0))
+        hidden_drop = self.dropout(hidden)
+        hidden_flatten = self.flatten(hidden_drop)
+        pre = self.linear(hidden_flatten)
+        # pre: [batch_size*max_seq_len, vocab_size]
+        return pre, (h_0, c_0)
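A quick forward-pass shape check; the sizes (vocab 6110, 256-dim embeddings, hidden size 128) are taken from values used elsewhere in this commit, and the batch and sequence lengths are arbitrary:

# Hypothetical smoke test for Poetry_Model_lstm with index (non-Word2Vec) inputs.
import torch
from src.models.LSTM.model import Poetry_Model_lstm

model = Poetry_Model_lstm(hidden_num=128, word_size=6110, embedding_num=256, Word2Vec=False)
tokens = torch.randint(0, 6110, (4, 31))  # [batch_size=4, seq_len=31] character ids
pre, (h_n, c_n) = model(tokens)
print(pre.shape)  # torch.Size([124, 6110]): rows are flattened batch*seq positions
print(h_n.shape)  # torch.Size([2, 4, 128]): (num_layers, batch, hidden)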
src/models/__init__.py
ADDED
File without changes
src/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (173 Bytes)
src/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (153 Bytes)
src/utils/__init__.py
ADDED
File without changes
src/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (172 Bytes)
src/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (152 Bytes)
src/utils/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (575 Bytes)
src/utils/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (555 Bytes)
src/utils/utils.py
ADDED
@@ -0,0 +1,15 @@
+import torch
+
+
+def make_cuda(tensor):
+    """Use CUDA if it's available."""
+    if torch.cuda.is_available():
+        tensor = tensor.cuda()
+    return tensor
+
+
+def is_minimum(value, indiv_to_rmse):
+    if len(indiv_to_rmse) == 0:
+        return True
+    temp = list(indiv_to_rmse.values())
+    return True if value < min(temp) else False
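For illustration, assumed behaviour of the helpers above; is_minimum reports whether a new score beats every value recorded so far, and the dict contents here are made up:

# Hypothetical usage of make_cuda and is_minimum.
import torch
from src.utils.utils import make_cuda, is_minimum

scores = {"indiv_a": 1.42, "indiv_b": 1.17}
print(is_minimum(1.05, scores))  # True: 1.05 < min(1.42, 1.17)
print(is_minimum(1.30, scores))  # False: 1.17 is already lower
print(is_minimum(0.99, {}))      # True: an empty history always accepts

x = make_cuda(torch.zeros(3))    # moves to GPU only when one is available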