nevmenandr commited on
Commit
67a151a
1 Parent(s): 27496b8

Upload 3 files

Browse files
Files changed (3) hide show
  1. model.py +100 -0
  2. sample.py +54 -0
  3. utils.py +72 -0
model.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.python.ops import rnn_cell
3
+ from tensorflow.python.ops import seq2seq
4
+ import random
5
+ import numpy as np
6
+
7
+ class Model():
8
+ def __init__(self, args, infer=False):
9
+ self.args = args
10
+ if infer:
11
+ args.batch_size = 1
12
+ args.seq_length = 1
13
+
14
+ if args.model == 'rnn':
15
+ cell_fn = rnn_cell.BasicRNNCell
16
+ elif args.model == 'gru':
17
+ cell_fn = rnn_cell.GRUCell
18
+ elif args.model == 'lstm':
19
+ cell_fn = rnn_cell.BasicLSTMCell
20
+ else:
21
+ raise Exception("model type not supported: {}".format(args.model))
22
+
23
+ cell = cell_fn(args.rnn_size)
24
+
25
+ self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)
26
+
27
+ self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
28
+ self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
29
+ self.initial_state = cell.zero_state(args.batch_size, tf.float32)
30
+
31
+ with tf.variable_scope('rnnlm'):
32
+ softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
33
+ softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
34
+ with tf.device("/cpu:0"):
35
+ embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
36
+ inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
37
+ inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
38
+
39
+ def loop(prev, _):
40
+ prev = tf.matmul(prev, softmax_w) + softmax_b
41
+ prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
42
+ return tf.nn.embedding_lookup(embedding, prev_symbol)
43
+
44
+ outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
45
+ output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
46
+ self.logits = tf.matmul(output, softmax_w) + softmax_b
47
+ self.probs = tf.nn.softmax(self.logits)
48
+ loss = seq2seq.sequence_loss_by_example([self.logits],
49
+ [tf.reshape(self.targets, [-1])],
50
+ [tf.ones([args.batch_size * args.seq_length])],
51
+ args.vocab_size)
52
+ self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
53
+ self.final_state = last_state
54
+ self.lr = tf.Variable(0.0, trainable=False)
55
+ tvars = tf.trainable_variables()
56
+ grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
57
+ args.grad_clip)
58
+ optimizer = tf.train.AdamOptimizer(self.lr)
59
+ self.train_op = optimizer.apply_gradients(zip(grads, tvars))
60
+
61
+ def sample(self, sess, words, vocab, num=200, prime='first all', sampling_type=1):
62
+ state = sess.run(self.cell.zero_state(1, tf.float32))
63
+ if not len(prime) or prime == " ":
64
+ prime = random.choice(list(vocab.keys()))
65
+ print (prime)
66
+ for word in prime.split()[:-1]:
67
+ print (word)
68
+ x = np.zeros((1, 1))
69
+ x[0, 0] = vocab.get(word,0)
70
+ feed = {self.input_data: x, self.initial_state:state}
71
+ [state] = sess.run([self.final_state], feed)
72
+
73
+ def weighted_pick(weights):
74
+ t = np.cumsum(weights)
75
+ s = np.sum(weights)
76
+ return(int(np.searchsorted(t, np.random.rand(1)*s)))
77
+
78
+ ret = prime
79
+ word = prime.split()[-1]
80
+ for n in range(num):
81
+ x = np.zeros((1, 1))
82
+ x[0, 0] = vocab.get(word,0)
83
+ feed = {self.input_data: x, self.initial_state:state}
84
+ [probs, state] = sess.run([self.probs, self.final_state], feed)
85
+ p = probs[0]
86
+
87
+ if sampling_type == 0:
88
+ sample = np.argmax(p)
89
+ elif sampling_type == 2:
90
+ if word == '\n':
91
+ sample = weighted_pick(p)
92
+ else:
93
+ sample = np.argmax(p)
94
+ else: # sampling_type == 1 default:
95
+ sample = weighted_pick(p)
96
+
97
+ pred = words[sample]
98
+ ret += ' ' + pred
99
+ word = pred
100
+ return ret
sample.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import print_function
5
+ import numpy as np
6
+ import tensorflow as tf
7
+
8
+ import argparse
9
+ import time
10
+ import os
11
+ from six.moves import cPickle
12
+
13
+ from utils import TextLoader
14
+ from model import Model
15
+
16
+ from six import text_type
17
+
18
+ import re
19
+
20
+ def main():
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument('--save_dir', type=str, default='./save',
23
+ help='model directory to store checkpointed models')
24
+ parser.add_argument('-n', type=int, default=800,
25
+ help='number of characters to sample')
26
+ parser.add_argument('--prime', type=text_type, default=u'Промхимия ',
27
+ help='prime text')
28
+ parser.add_argument('--sample', type=int, default=1,
29
+ help='0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces')
30
+
31
+ args = parser.parse_args()
32
+ sample(args)
33
+
34
+ def sample(args):
35
+ with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
36
+ saved_args = cPickle.load(f)
37
+ with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
38
+ chars, vocab = cPickle.load(f)
39
+ model = Model(saved_args, True)
40
+ with tf.Session() as sess:
41
+ tf.initialize_all_variables().run()
42
+ saver = tf.train.Saver(tf.all_variables())
43
+ ckpt = tf.train.get_checkpoint_state(args.save_dir)
44
+ if ckpt and ckpt.model_checkpoint_path:
45
+ saver.restore(sess, ckpt.model_checkpoint_path)
46
+ #print(model.sample(sess, chars, vocab, args.n, args.prime, args.sample))
47
+ sample_string = model.sample(sess, chars, vocab, args.n, args.prime, args.sample)
48
+ sample_string = re.sub(u' ([^ ])', u'\\1', sample_string)
49
+ sample_string = re.sub(u'[ ]+', u' ', sample_string)
50
+ print(sample_string)
51
+
52
+
53
+ if __name__ == '__main__':
54
+ main()
utils.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import codecs
2
+ import os
3
+ import collections
4
+ from six.moves import cPickle
5
+ import numpy as np
6
+
7
+ class TextLoader():
8
+ def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
9
+ self.data_dir = data_dir
10
+ self.batch_size = batch_size
11
+ self.seq_length = seq_length
12
+ self.encoding = encoding
13
+
14
+ input_file = os.path.join(data_dir, "input.txt")
15
+ vocab_file = os.path.join(data_dir, "vocab.pkl")
16
+ tensor_file = os.path.join(data_dir, "data.npy")
17
+
18
+ if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
19
+ print("reading text file")
20
+ self.preprocess(input_file, vocab_file, tensor_file)
21
+ else:
22
+ print("loading preprocessed files")
23
+ self.load_preprocessed(vocab_file, tensor_file)
24
+ self.create_batches()
25
+ self.reset_batch_pointer()
26
+
27
+ def preprocess(self, input_file, vocab_file, tensor_file):
28
+ with codecs.open(input_file, "r", encoding=self.encoding) as f:
29
+ data = f.read()
30
+ counter = collections.Counter(data)
31
+ count_pairs = sorted(counter.items(), key=lambda x: -x[1])
32
+ self.chars, _ = zip(*count_pairs)
33
+ self.vocab_size = len(self.chars)
34
+ self.vocab = dict(zip(self.chars, range(len(self.chars))))
35
+ with open(vocab_file, 'wb') as f:
36
+ cPickle.dump(self.chars, f)
37
+ self.tensor = np.array(list(map(self.vocab.get, data)))
38
+ np.save(tensor_file, self.tensor)
39
+
40
+ def load_preprocessed(self, vocab_file, tensor_file):
41
+ with open(vocab_file, 'rb') as f:
42
+ self.chars = cPickle.load(f)
43
+ self.vocab_size = len(self.chars)
44
+ self.vocab = dict(zip(self.chars, range(len(self.chars))))
45
+ self.tensor = np.load(tensor_file)
46
+ self.num_batches = int(self.tensor.size / (self.batch_size *
47
+ self.seq_length))
48
+
49
+ def create_batches(self):
50
+ self.num_batches = int(self.tensor.size / (self.batch_size *
51
+ self.seq_length))
52
+
53
+ # When the data (tesor) is too small, let's give them a better error message
54
+ if self.num_batches==0:
55
+ assert False, "Not enough data. Make seq_length and batch_size small."
56
+
57
+ self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
58
+ xdata = self.tensor
59
+ ydata = np.copy(self.tensor)
60
+ ydata[:-1] = xdata[1:]
61
+ ydata[-1] = xdata[0]
62
+ self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
63
+ self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)
64
+
65
+
66
+ def next_batch(self):
67
+ x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
68
+ self.pointer += 1
69
+ return x, y
70
+
71
+ def reset_batch_pointer(self):
72
+ self.pointer = 0