File size: 1,911 Bytes
67a151a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import numpy as np
import tensorflow as tf

import argparse
import time
import os
from six.moves import cPickle

from utils import TextLoader
from model import Model

from six import text_type

import re

def main():
    """Command-line entry point: parse the sampling options and run ``sample``.

    Recognised options (defaults in brackets):
        --save_dir  model directory to read checkpoints from ['./save']
        -n          number of characters to generate [800]
        --prime     text used to seed generation
        --sample    0 = take the max at each timestep, 1 = sample at each
                    timestep, 2 = sample on spaces [1]
    """
    option_table = [
        (['--save_dir'], {'type': str, 'default': './save',
                          'help': 'model directory to store checkpointed models'}),
        (['-n'], {'type': int, 'default': 800,
                  'help': 'number of characters to sample'}),
        (['--prime'], {'type': text_type, 'default': u'Промхимия ',
                       'help': 'prime text'}),
        (['--sample'], {'type': int, 'default': 1,
                        'help': '0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces'}),
    ]

    cli = argparse.ArgumentParser()
    # Register in table order so --help lists the options in a stable order.
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    sample(cli.parse_args())

def sample(args):
    """Restore the latest checkpoint from ``args.save_dir`` and print a sample.

    Loads ``config.pkl`` (the saved model arguments) and ``chars_vocab.pkl``
    (the ``chars, vocab`` pair) from ``args.save_dir``, builds ``Model``,
    restores the most recent checkpoint, generates text with
    ``model.sample`` and prints it after normalising its spacing.

    Args:
        args: parsed command-line namespace with ``save_dir``, ``n``,
            ``prime`` and ``sample`` attributes.

    Raises:
        IOError: if ``args.save_dir`` contains no usable checkpoint.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only point
    # --save_dir at directories you trust.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    # The second argument presumably switches Model into sampling/inference
    # mode -- TODO confirm against model.py.
    model = Model(saved_args, True)
    with tf.Session() as sess:
        # NOTE(review): initialize_all_variables/all_variables are deprecated
        # in later TF 1.x releases (global_variables_initializer /
        # global_variables); kept for the TF version this script targets.
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        # Fail loudly instead of silently printing nothing when there is
        # nothing to restore.
        if not (ckpt and ckpt.model_checkpoint_path):
            raise IOError('No checkpoint found in %s' % args.save_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)
        sample_string = model.sample(sess, chars, vocab, args.n, args.prime, args.sample)
        # The substitutions below suggest the sampled text separates characters
        # with single spaces and words with runs of 2+ spaces -- TODO confirm
        # against the training-data preprocessing. First drop every space that
        # directly precedes a non-space character, then collapse any remaining
        # run of spaces to a single space.
        sample_string = re.sub(u' ([^ ])', u'\\1', sample_string)
        sample_string = re.sub(u' +', u' ', sample_string)
        print(sample_string)
            

# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()