|
|
|
|
|
|
|
from __future__ import print_function |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
import argparse |
|
import time |
|
import os |
|
from six.moves import cPickle |
|
|
|
from utils import TextLoader |
|
from model import Model |
|
|
|
from six import text_type |
|
|
|
import re |
|
|
|
def main(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--save_dir', type=str, default='./save', |
|
help='model directory to store checkpointed models') |
|
parser.add_argument('-n', type=int, default=200, |
|
help='number of characters to sample') |
|
parser.add_argument('--prime', type=text_type, default=u' ', |
|
help='prime text') |
|
parser.add_argument('--sample', type=int, default=1, |
|
help='0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces') |
|
|
|
args = parser.parse_args() |
|
sample(args) |
|
|
|
def sample(args): |
|
with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f: |
|
saved_args = cPickle.load(f) |
|
with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f: |
|
chars, vocab = cPickle.load(f) |
|
model = Model(saved_args, True) |
|
with tf.Session() as sess: |
|
tf.initialize_all_variables().run() |
|
saver = tf.train.Saver(tf.all_variables()) |
|
ckpt = tf.train.get_checkpoint_state(args.save_dir) |
|
if ckpt and ckpt.model_checkpoint_path: |
|
saver.restore(sess, ckpt.model_checkpoint_path) |
|
|
|
sample_string = model.sample(sess, chars, vocab, args.n, args.prime, args.sample) |
|
sample_string = re.sub(u' ([^ ])', u'\\1', sample_string) |
|
sample_string = re.sub(u'[ ]+', u' ', sample_string) |
|
print(sample_string) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|