nevmenandr
commited on
Commit
•
67a151a
1
Parent(s):
27496b8
Upload 3 files
Browse files
model.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.python.ops import rnn_cell
|
3 |
+
from tensorflow.python.ops import seq2seq
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class Model():
|
8 |
+
def __init__(self, args, infer=False):
|
9 |
+
self.args = args
|
10 |
+
if infer:
|
11 |
+
args.batch_size = 1
|
12 |
+
args.seq_length = 1
|
13 |
+
|
14 |
+
if args.model == 'rnn':
|
15 |
+
cell_fn = rnn_cell.BasicRNNCell
|
16 |
+
elif args.model == 'gru':
|
17 |
+
cell_fn = rnn_cell.GRUCell
|
18 |
+
elif args.model == 'lstm':
|
19 |
+
cell_fn = rnn_cell.BasicLSTMCell
|
20 |
+
else:
|
21 |
+
raise Exception("model type not supported: {}".format(args.model))
|
22 |
+
|
23 |
+
cell = cell_fn(args.rnn_size)
|
24 |
+
|
25 |
+
self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)
|
26 |
+
|
27 |
+
self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
|
28 |
+
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
|
29 |
+
self.initial_state = cell.zero_state(args.batch_size, tf.float32)
|
30 |
+
|
31 |
+
with tf.variable_scope('rnnlm'):
|
32 |
+
softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
|
33 |
+
softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
|
34 |
+
with tf.device("/cpu:0"):
|
35 |
+
embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
|
36 |
+
inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
|
37 |
+
inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
|
38 |
+
|
39 |
+
def loop(prev, _):
|
40 |
+
prev = tf.matmul(prev, softmax_w) + softmax_b
|
41 |
+
prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
|
42 |
+
return tf.nn.embedding_lookup(embedding, prev_symbol)
|
43 |
+
|
44 |
+
outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
|
45 |
+
output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
|
46 |
+
self.logits = tf.matmul(output, softmax_w) + softmax_b
|
47 |
+
self.probs = tf.nn.softmax(self.logits)
|
48 |
+
loss = seq2seq.sequence_loss_by_example([self.logits],
|
49 |
+
[tf.reshape(self.targets, [-1])],
|
50 |
+
[tf.ones([args.batch_size * args.seq_length])],
|
51 |
+
args.vocab_size)
|
52 |
+
self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
|
53 |
+
self.final_state = last_state
|
54 |
+
self.lr = tf.Variable(0.0, trainable=False)
|
55 |
+
tvars = tf.trainable_variables()
|
56 |
+
grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
|
57 |
+
args.grad_clip)
|
58 |
+
optimizer = tf.train.AdamOptimizer(self.lr)
|
59 |
+
self.train_op = optimizer.apply_gradients(zip(grads, tvars))
|
60 |
+
|
61 |
+
def sample(self, sess, words, vocab, num=200, prime='first all', sampling_type=1):
|
62 |
+
state = sess.run(self.cell.zero_state(1, tf.float32))
|
63 |
+
if not len(prime) or prime == " ":
|
64 |
+
prime = random.choice(list(vocab.keys()))
|
65 |
+
print (prime)
|
66 |
+
for word in prime.split()[:-1]:
|
67 |
+
print (word)
|
68 |
+
x = np.zeros((1, 1))
|
69 |
+
x[0, 0] = vocab.get(word,0)
|
70 |
+
feed = {self.input_data: x, self.initial_state:state}
|
71 |
+
[state] = sess.run([self.final_state], feed)
|
72 |
+
|
73 |
+
def weighted_pick(weights):
|
74 |
+
t = np.cumsum(weights)
|
75 |
+
s = np.sum(weights)
|
76 |
+
return(int(np.searchsorted(t, np.random.rand(1)*s)))
|
77 |
+
|
78 |
+
ret = prime
|
79 |
+
word = prime.split()[-1]
|
80 |
+
for n in range(num):
|
81 |
+
x = np.zeros((1, 1))
|
82 |
+
x[0, 0] = vocab.get(word,0)
|
83 |
+
feed = {self.input_data: x, self.initial_state:state}
|
84 |
+
[probs, state] = sess.run([self.probs, self.final_state], feed)
|
85 |
+
p = probs[0]
|
86 |
+
|
87 |
+
if sampling_type == 0:
|
88 |
+
sample = np.argmax(p)
|
89 |
+
elif sampling_type == 2:
|
90 |
+
if word == '\n':
|
91 |
+
sample = weighted_pick(p)
|
92 |
+
else:
|
93 |
+
sample = np.argmax(p)
|
94 |
+
else: # sampling_type == 1 default:
|
95 |
+
sample = weighted_pick(p)
|
96 |
+
|
97 |
+
pred = words[sample]
|
98 |
+
ret += ' ' + pred
|
99 |
+
word = pred
|
100 |
+
return ret
|
sample.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from __future__ import print_function
|
5 |
+
import numpy as np
|
6 |
+
import tensorflow as tf
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import time
|
10 |
+
import os
|
11 |
+
from six.moves import cPickle
|
12 |
+
|
13 |
+
from utils import TextLoader
|
14 |
+
from model import Model
|
15 |
+
|
16 |
+
from six import text_type
|
17 |
+
|
18 |
+
import re
|
19 |
+
|
20 |
+
def main():
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument('--save_dir', type=str, default='./save',
|
23 |
+
help='model directory to store checkpointed models')
|
24 |
+
parser.add_argument('-n', type=int, default=800,
|
25 |
+
help='number of characters to sample')
|
26 |
+
parser.add_argument('--prime', type=text_type, default=u'Промхимия ',
|
27 |
+
help='prime text')
|
28 |
+
parser.add_argument('--sample', type=int, default=1,
|
29 |
+
help='0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces')
|
30 |
+
|
31 |
+
args = parser.parse_args()
|
32 |
+
sample(args)
|
33 |
+
|
34 |
+
def sample(args):
|
35 |
+
with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
|
36 |
+
saved_args = cPickle.load(f)
|
37 |
+
with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
|
38 |
+
chars, vocab = cPickle.load(f)
|
39 |
+
model = Model(saved_args, True)
|
40 |
+
with tf.Session() as sess:
|
41 |
+
tf.initialize_all_variables().run()
|
42 |
+
saver = tf.train.Saver(tf.all_variables())
|
43 |
+
ckpt = tf.train.get_checkpoint_state(args.save_dir)
|
44 |
+
if ckpt and ckpt.model_checkpoint_path:
|
45 |
+
saver.restore(sess, ckpt.model_checkpoint_path)
|
46 |
+
#print(model.sample(sess, chars, vocab, args.n, args.prime, args.sample))
|
47 |
+
sample_string = model.sample(sess, chars, vocab, args.n, args.prime, args.sample)
|
48 |
+
sample_string = re.sub(u' ([^ ])', u'\\1', sample_string)
|
49 |
+
sample_string = re.sub(u'[ ]+', u' ', sample_string)
|
50 |
+
print(sample_string)
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
main()
|
utils.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import codecs
|
2 |
+
import os
|
3 |
+
import collections
|
4 |
+
from six.moves import cPickle
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class TextLoader():
|
8 |
+
def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
|
9 |
+
self.data_dir = data_dir
|
10 |
+
self.batch_size = batch_size
|
11 |
+
self.seq_length = seq_length
|
12 |
+
self.encoding = encoding
|
13 |
+
|
14 |
+
input_file = os.path.join(data_dir, "input.txt")
|
15 |
+
vocab_file = os.path.join(data_dir, "vocab.pkl")
|
16 |
+
tensor_file = os.path.join(data_dir, "data.npy")
|
17 |
+
|
18 |
+
if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
|
19 |
+
print("reading text file")
|
20 |
+
self.preprocess(input_file, vocab_file, tensor_file)
|
21 |
+
else:
|
22 |
+
print("loading preprocessed files")
|
23 |
+
self.load_preprocessed(vocab_file, tensor_file)
|
24 |
+
self.create_batches()
|
25 |
+
self.reset_batch_pointer()
|
26 |
+
|
27 |
+
def preprocess(self, input_file, vocab_file, tensor_file):
|
28 |
+
with codecs.open(input_file, "r", encoding=self.encoding) as f:
|
29 |
+
data = f.read()
|
30 |
+
counter = collections.Counter(data)
|
31 |
+
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
|
32 |
+
self.chars, _ = zip(*count_pairs)
|
33 |
+
self.vocab_size = len(self.chars)
|
34 |
+
self.vocab = dict(zip(self.chars, range(len(self.chars))))
|
35 |
+
with open(vocab_file, 'wb') as f:
|
36 |
+
cPickle.dump(self.chars, f)
|
37 |
+
self.tensor = np.array(list(map(self.vocab.get, data)))
|
38 |
+
np.save(tensor_file, self.tensor)
|
39 |
+
|
40 |
+
def load_preprocessed(self, vocab_file, tensor_file):
|
41 |
+
with open(vocab_file, 'rb') as f:
|
42 |
+
self.chars = cPickle.load(f)
|
43 |
+
self.vocab_size = len(self.chars)
|
44 |
+
self.vocab = dict(zip(self.chars, range(len(self.chars))))
|
45 |
+
self.tensor = np.load(tensor_file)
|
46 |
+
self.num_batches = int(self.tensor.size / (self.batch_size *
|
47 |
+
self.seq_length))
|
48 |
+
|
49 |
+
def create_batches(self):
|
50 |
+
self.num_batches = int(self.tensor.size / (self.batch_size *
|
51 |
+
self.seq_length))
|
52 |
+
|
53 |
+
# When the data (tesor) is too small, let's give them a better error message
|
54 |
+
if self.num_batches==0:
|
55 |
+
assert False, "Not enough data. Make seq_length and batch_size small."
|
56 |
+
|
57 |
+
self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
|
58 |
+
xdata = self.tensor
|
59 |
+
ydata = np.copy(self.tensor)
|
60 |
+
ydata[:-1] = xdata[1:]
|
61 |
+
ydata[-1] = xdata[0]
|
62 |
+
self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
|
63 |
+
self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)
|
64 |
+
|
65 |
+
|
66 |
+
def next_batch(self):
|
67 |
+
x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
|
68 |
+
self.pointer += 1
|
69 |
+
return x, y
|
70 |
+
|
71 |
+
def reset_batch_pointer(self):
|
72 |
+
self.pointer = 0
|