# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Neural GPU -- data generation and batching utilities.""" | |
import math | |
import os | |
import random | |
import sys | |
import time | |
import numpy as np | |
from six.moves import xrange | |
import tensorflow as tf | |
import program_utils | |
# Command-line flags object shared with the binaries that import this module.
FLAGS = tf.app.flags.FLAGS

# Bucket sizes for padded batching: one bin per length from 2 to 257.
bins = [2 + bin_idx_i for bin_idx_i in xrange(256)]

# Names of every task this module can generate data for.
all_tasks = ["sort", "kvsort", "id", "rev", "rev2", "incr", "add", "left",
             "right", "left-shift", "right-shift", "bmul", "mul", "dup",
             "badd", "qadd", "search", "progeval", "progsynth"]

# Path of the log file used by print_out; set by the main binary before use.
log_filename = ""

# Vocabularies used by accuracy printing; presumably id<->token maps set by
# the main binary (None until then) -- TODO confirm against the caller.
vocab, rev_vocab = None, None
def pad(l):
  """Return the smallest bin size that fits length l (last bin if none do)."""
  fitting = [b for b in bins if b >= l]
  return fitting[0] if fitting else bins[-1]
def bin_for(l):
  """Return the index of the smallest bin that fits length l."""
  for idx in range(len(bins)):
    if bins[idx] >= l:
      return idx
  return len(bins) - 1
# One bucket list per (task, bin-index) pair; filled lazily by init_data.
train_set = {}
test_set = {}
for some_task in all_tasks:
  train_set[some_task] = [[] for _ in xrange(10000)]
  test_set[some_task] = [[] for _ in xrange(10000)]
def read_tmp_file(name):
  """Read from a file with the given name in our log directory or above.

  Looks for <name>.txt next to the log file, then one and two directory
  levels up, logging each miss.

  Args:
    name: base name of the file (without the ".txt" extension).

  Returns:
    The list of stripped lines from the first location that exists, or
    None when the file is not found at any of the three locations.
  """
  dirname = os.path.dirname(log_filename)
  # Search the log directory, then its parent and grandparent, in order.
  for level_prefix in ["", "../", "../../"]:
    fname = os.path.join(dirname, level_prefix + name + ".txt")
    if tf.gfile.Exists(fname):
      print_out("== found file: " + fname)
      with tf.gfile.GFile(fname, mode="r") as f:
        return [line.strip() for line in f]
    print_out("== not found file: " + fname)
  return None
def write_tmp_file(name, lines):
  """Write the given lines, newline-terminated, to <name>.txt in the log dir."""
  out_path = os.path.join(os.path.dirname(log_filename), name + ".txt")
  with tf.gfile.GFile(out_path, mode="w") as out_file:
    out_file.write("".join([line + "\n" for line in lines]))
def add(n1, n2, base=10):
  """Add two numbers represented as lower-endian digit lists.

  Args:
    n1: first addend, least-significant digit first.
    n2: second addend, least-significant digit first.
    base: numeric base of the digits.

  Returns:
    The sum as a lower-endian digit list with no trailing zeros
    (the single digit [0] when the sum is zero).
  """
  width = max(len(n1), len(n2)) + 1
  result = []
  carry = 0
  for pos in range(width):
    digit1 = n1[pos] if pos < len(n1) else 0
    digit2 = n2[pos] if pos < len(n2) else 0
    # Digit sum is at most 2*(base-1)+1 < 2*base, so carry is 0 or 1.
    carry, digit = divmod(digit1 + digit2 + carry, base)
    result.append(digit)
  while result and result[-1] == 0:
    result.pop()
  return result if result else [0]
def init_data(task, length, nbr_cases, nclass):
  """Generate about nbr_cases train/test cases of the given length for task.

  Generated (input, target) pairs are appended in place into the
  module-level train_set / test_set buckets (indexed by bin_for of the
  instance length).

  Args:
    task: name of the task, one of all_tasks.
    length: target total length of generated instances.
    nbr_cases: how many cases to generate.
    nclass: number of symbol classes (0 is reserved for padding).
  """
  def rand_pair(l, task):
    """Random data pair for an arithmetic task. Total length should be <= l."""
    k = int((l-1)/2)
    base = 10
    if task[0] == "b": base = 2   # binary task variants
    if task[0] == "q": base = 4   # quaternary task variants
    d1 = [np.random.randint(base) for _ in xrange(k)]
    d2 = [np.random.randint(base) for _ in xrange(k)]
    if task in ["add", "badd", "qadd"]:
      res = add(d1, d2, base)
    elif task in ["mul", "bmul"]:
      d1n = sum([d * (base ** i) for i, d in enumerate(d1)])
      d2n = sum([d * (base ** i) for i, d in enumerate(d2)])
      if task == "bmul":
        # Reverse bin()'s string and drop the reversed "0b" prefix to get
        # lower-endian binary digits.
        res = [int(x) for x in list(reversed(str(bin(d1n * d2n))))[:-2]]
      else:
        res = [int(x) for x in list(reversed(str(d1n * d2n)))]
    else:
      sys.exit()
    sep = [12]  # "*" separator symbol; "+" (11) for addition tasks.
    if task in ["add", "badd", "qadd"]: sep = [11]
    inp = [d + 1 for d in d1] + sep + [d + 1 for d in d2]
    return inp, [r + 1 for r in res]

  def rand_dup_pair(l):
    """Random data pair for duplication task. Total length should be <= l."""
    k = int(l/2)
    x = [np.random.randint(nclass - 1) + 1 for _ in xrange(k)]
    inp = x + [0 for _ in xrange(l - k)]
    res = x + x + [0 for _ in xrange(l - 2*k)]
    return inp, res

  def rand_rev2_pair(l):
    """Random data pair for reverse2 task. Total length should be <= l."""
    # int() keeps the count integral under Python 3 true division too.
    inp = [(np.random.randint(nclass - 1) + 1,
            np.random.randint(nclass - 1) + 1) for _ in xrange(int(l/2))]
    res = [i for i in reversed(inp)]
    return [x for p in inp for x in p], [x for p in res for x in p]

  def rand_search_pair(l):
    """Random data pair for search task. Total length should be <= l."""
    # Bug fix: this previously used xrange(l - 1/2), which by operator
    # precedence is xrange(l) and generated 2*l+1 symbols, violating the
    # <= l contract. The intended pair count is (l-1)/2 so 2*k + 1 <= l.
    inp = [(np.random.randint(nclass - 1) + 1,
            np.random.randint(nclass - 1) + 1)
           for _ in xrange(int((l - 1) / 2))]
    q = np.random.randint(nclass - 1) + 1
    res = 0
    # Iterate reversed so the FIRST matching key in inp wins.
    for (k, v) in reversed(inp):
      if k == q:
        res = v
    return [x for p in inp for x in p] + [q], [res]

  def rand_kvsort_pair(l):
    """Random data pair for key-value sort. Total length should be <= l."""
    keys = [(np.random.randint(nclass - 1) + 1, i) for i in xrange(int(l/2))]
    vals = [np.random.randint(nclass - 1) + 1 for _ in xrange(int(l/2))]
    kv = [(k, vals[i]) for (k, i) in keys]
    sorted_kv = [(k, vals[i]) for (k, i) in sorted(keys)]
    return [x for p in kv for x in p], [x for p in sorted_kv for x in p]

  def prog_io_pair(prog, max_len, counter=0):
    """Random (input tokens, output tokens) pair for a program; retries on
    ValueError with a growing counter that shrinks the input value range."""
    try:
      ilen = np.random.randint(max_len - 3) + 1
      bound = max(15 - int(counter / 20), 1)
      inp = [random.choice(range(-bound, bound)) for _ in range(ilen)]
      inp_toks = [program_utils.prog_rev_vocab[t]
                  for t in program_utils.tokenize(str(inp)) if t != ","]
      out = program_utils.evaluate(prog, {"a": inp})
      out_toks = [program_utils.prog_rev_vocab[t]
                  for t in program_utils.tokenize(str(out)) if t != ","]
      if counter > 400:
        out_toks = []  # Give up on constraining this program's output.
      if (out_toks and out_toks[0] == program_utils.prog_rev_vocab["["] and
          len(out_toks) != len([o for o in out if o == ","]) + 3):
        raise ValueError("generated list with too long ints")
      if (out_toks and out_toks[0] != program_utils.prog_rev_vocab["["] and
          len(out_toks) > 1):
        raise ValueError("generated one int but tokenized it to many")
      if len(out_toks) > max_len:
        raise ValueError("output too long")
      return (inp_toks, out_toks)
    except ValueError:
      return prog_io_pair(prog, max_len, counter+1)

  def spec(inp):
    """Return the target given the input for some tasks."""
    if task == "sort":
      return sorted(inp)
    elif task == "id":
      return inp
    elif task == "rev":
      return [i for i in reversed(inp)]
    elif task == "incr":
      carry = 1
      res = []
      for i in xrange(len(inp)):
        if inp[i] + carry < nclass:
          res.append(inp[i] + carry)
          carry = 0
        else:
          res.append(1)
          carry = 1
      return res
    elif task == "left":
      return [inp[0]]
    elif task == "right":
      return [inp[-1]]
    elif task == "left-shift":
      return [inp[l-1] for l in xrange(len(inp))]
    elif task == "right-shift":
      # Bug fix: inp[l+1] raised IndexError at the last position; wrap
      # cyclically, mirroring left-shift's negative-index wraparound.
      return [inp[(l + 1) % len(inp)] for l in xrange(len(inp))]
    else:
      print_out("Unknown spec for task " + str(task))
      sys.exit()

  l = length
  cur_time = time.time()
  total_time = 0.0
  is_prog = task in ["progeval", "progsynth"]
  if is_prog:
    inputs_per_prog = 5
    program_utils.make_vocab()
    progs = read_tmp_file("programs_len%d" % (l / 10))
    if not progs:
      progs = program_utils.gen(l / 10, 1.2 * nbr_cases / inputs_per_prog)
      write_tmp_file("programs_len%d" % (l / 10), progs)
    prog_ios = read_tmp_file("programs_len%d_io" % (l / 10))
    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
    if not prog_ios:
      # Generate program io data.
      prog_ios = []
      for pidx, prog in enumerate(progs):
        if pidx % 500 == 0:
          print_out("== generating io pairs for program %d" % pidx)
        if pidx * inputs_per_prog > nbr_cases * 1.2:
          break
        ptoks = [program_utils.prog_rev_vocab[t]
                 for t in program_utils.tokenize(prog)]
        ptoks.append(program_utils.prog_rev_vocab["_EOS"])
        plen = len(ptoks)
        for _ in xrange(inputs_per_prog):
          if task == "progeval":
            inp, out = prog_io_pair(prog, plen)
            prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
          elif task == "progsynth":
            plen = max(len(ptoks), 8)
            for _ in xrange(3):
              inp, out = prog_io_pair(prog, plen / 2)
              prog_ios.append(str(inp) + "\t" + str(out) + "\t" + prog)
      write_tmp_file("programs_len%d_io" % (l / 10), prog_ios)
    prog_ios_dict = {}
    for s in prog_ios:
      i, o, p = s.split("\t")
      # Strip brackets/signs, keeping only digits and spaces before parsing.
      i_clean = "".join([c for c in i if c.isdigit() or c == " "])
      o_clean = "".join([c for c in o if c.isdigit() or c == " "])
      inp = [int(x) for x in i_clean.split()]
      out = [int(x) for x in o_clean.split()]
      if inp and out:
        if p in prog_ios_dict:
          prog_ios_dict[p].append([inp, out])
        else:
          prog_ios_dict[p] = [[inp, out]]
    # Use prog_ios_dict to create data.
    progs = []
    for prog in prog_ios_dict:
      if len([c for c in prog if c == ";"]) <= (l / 10):
        progs.append(prog)
    nbr_cases = min(nbr_cases, len(progs) * inputs_per_prog) / 1.2
    print_out("== %d training cases on %d progs" % (nbr_cases, len(progs)))
    for pidx, prog in enumerate(progs):
      if pidx * inputs_per_prog > nbr_cases * 1.2:
        break
      ptoks = [program_utils.prog_rev_vocab[t]
               for t in program_utils.tokenize(prog)]
      ptoks.append(program_utils.prog_rev_vocab["_EOS"])
      plen = len(ptoks)
      # Early programs go to the training set, the rest to the test set.
      dset = train_set if pidx < nbr_cases / inputs_per_prog else test_set
      for _ in xrange(inputs_per_prog):
        if task == "progeval":
          inp, out = prog_ios_dict[prog].pop()
          dset[task][bin_for(plen)].append([[ptoks, inp, [], []], [out]])
        elif task == "progsynth":
          plen, ilist = max(len(ptoks), 8), [[]]
          for _ in xrange(3):
            inp, out = prog_ios_dict[prog].pop()
            ilist.append(inp + out)
          dset[task][bin_for(plen)].append([ilist, [ptoks]])

  for case in xrange(0 if is_prog else nbr_cases):
    total_time += time.time() - cur_time
    cur_time = time.time()
    if l > 10000 and case % 100 == 1:
      print_out("  avg gen time %.4f s" % (total_time / float(case)))
    if task in ["add", "badd", "qadd", "bmul", "mul"]:
      i, t = rand_pair(l, task)
      train_set[task][bin_for(len(i))].append([[[], i, [], []], [t]])
      i, t = rand_pair(l, task)
      test_set[task][bin_for(len(i))].append([[[], i, [], []], [t]])
    elif task == "dup":
      i, t = rand_dup_pair(l)
      train_set[task][bin_for(len(i))].append([[i], [t]])
      i, t = rand_dup_pair(l)
      test_set[task][bin_for(len(i))].append([[i], [t]])
    elif task == "rev2":
      i, t = rand_rev2_pair(l)
      train_set[task][bin_for(len(i))].append([[i], [t]])
      i, t = rand_rev2_pair(l)
      test_set[task][bin_for(len(i))].append([[i], [t]])
    elif task == "search":
      i, t = rand_search_pair(l)
      train_set[task][bin_for(len(i))].append([[i], [t]])
      i, t = rand_search_pair(l)
      test_set[task][bin_for(len(i))].append([[i], [t]])
    elif task == "kvsort":
      i, t = rand_kvsort_pair(l)
      train_set[task][bin_for(len(i))].append([[i], [t]])
      i, t = rand_kvsort_pair(l)
      test_set[task][bin_for(len(i))].append([[i], [t]])
    elif task not in ["progeval", "progsynth"]:
      inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
      target = spec(inp)
      train_set[task][bin_for(l)].append([[inp], [target]])
      inp = [np.random.randint(nclass - 1) + 1 for i in xrange(l)]
      target = spec(inp)
      test_set[task][bin_for(l)].append([[inp], [target]])
def to_symbol(i):
  """Convert an id to its text symbol (0 is padding, 11 is +, 12 is *)."""
  special = {0: "", 11: "+", 12: "*"}
  if i in special:
    return special[i]
  return str(i - 1)
def to_id(s):
  """Convert a text symbol to its id (+ is 11, * is 12, digit d is d+1)."""
  special = {"+": 11, "*": 12}
  if s in special:
    return special[s]
  return int(s) + 1
def get_batch(bin_id, batch_size, data_set, height, offset=None, preset=None):
  """Assemble one padded batch of data, training or testing.

  Elements are drawn at random from data_set[bin_id] unless offset selects
  them sequentially, or preset supplies a single fixed element; every row
  is zero-padded to the bin's length.
  """
  pad_length = bins[bin_id]
  inputs, targets = [], []
  for b in range(batch_size):
    if preset is None:
      elem = random.choice(data_set[bin_id])
      if offset is not None and offset + b < len(data_set[bin_id]):
        elem = data_set[bin_id][offset + b]
    else:
      elem = preset
    inp_rows, target_rows = elem[0], elem[1]
    padded_inp = [row + [0] * (pad_length - len(row)) for row in inp_rows]
    # A single input row is expanded with all-zero rows up to `height`.
    if len(padded_inp) == 1:
      padded_inp += [[0] * pad_length for _ in range(height - 1)]
    padded_tgt = [row + [0] * (pad_length - len(row)) for row in target_rows]
    inputs.append(padded_inp)
    targets.append(padded_tgt)
  res_input = np.array(inputs, dtype=np.int32)
  res_target = np.array(targets, dtype=np.int32)
  assert list(res_input.shape) == [batch_size, height, pad_length]
  assert list(res_target.shape) == [batch_size, 1, pad_length]
  return res_input, res_target
def print_out(s, newline=True):
  """Print a message to stdout and best-effort append it to the log file.

  Args:
    s: the message string.
    newline: whether to terminate the message with a newline.
  """
  suffix = "\n" if newline else ""
  if log_filename:
    try:
      with tf.gfile.GFile(log_filename, mode="a") as f:
        f.write(s + suffix)
    # Logging is best-effort and must never kill the run; unlike the
    # previous bare `except:`, this no longer swallows KeyboardInterrupt
    # or SystemExit.
    # pylint: disable=broad-except
    except Exception:
      sys.stderr.write("Error appending to %s\n" % log_filename)
  sys.stdout.write(s + suffix)
  sys.stdout.flush()
def decode(output):
  """Map each per-step logit array to its argmax ids along axis 1."""
  return [np.argmax(step_logits, axis=1) for step_logits in output]
def accuracy(inpt_t, output, target_t, batch_size, nprint,
             beam_out=None, beam_scores=None):
  """Calculate output accuracy given target.

  Args:
    inpt_t: int array indexed as [batch, height, length] of input ids.
    output: list of per-step logit arrays, each indexed [batch, nclass].
    target_t: int array indexed as [batch, 1, length] of target ids.
    batch_size: number of sequences in the batch.
    nprint: how many example sequences to pretty-print (must be <= batch_size).
    beam_out: optional beam output array indexed [batch, 0, step]; replaces
      the greedy decode for a batch element whose beam score is >= 10.0.
    beam_scores: optional per-batch-element beam scores (used with beam_out).

  Returns:
    A triple (errors, total, seq_errors): number of wrong non-padding
    symbols, number of non-padding target symbols, and number of sequences
    with at least one symbol error.
  """
  assert nprint < batch_size + 1
  # Flatten the input to a list over (height, position) of per-batch columns.
  inpt = []
  for h in xrange(inpt_t.shape[1]):
    inpt.extend([inpt_t[:, h, l] for l in xrange(inpt_t.shape[2])])
  target = [target_t[:, 0, l] for l in xrange(target_t.shape[2])]
  def tok(i):
    # Use the reverse vocabulary when one is set; else render the raw id
    # shifted down by 1 (matching to_symbol's id convention).
    if rev_vocab and i < len(rev_vocab):
      return rev_vocab[i]
    return str(i - 1)
  def task_print(inp, output, target):
    # Print one example's input/output/target, truncating output and target
    # at the first padding (<= 0) symbol of the target.
    stop_bound = 0
    print_len = 0
    while print_len < len(target) and target[print_len] > stop_bound:
      print_len += 1
    print_out("    i: " + " ".join([tok(i) for i in inp if i > 0]))
    print_out("    o: " +
              " ".join([tok(output[l]) for l in xrange(print_len)]))
    print_out("    t: " +
              " ".join([tok(target[l]) for l in xrange(print_len)]))
  decoded_target = target
  decoded_output = decode(output)
  # Use beam output if given and score is high enough.
  if beam_out is not None:
    for b in xrange(batch_size):
      if beam_scores[b] >= 10.0:
        for l in xrange(min(len(decoded_output), beam_out.shape[2])):
          decoded_output[l][b] = int(beam_out[b, 0, l])
  total = 0
  errors = 0
  # seq[b] is 1 iff batch element b has at least one symbol error.
  seq = [0 for b in xrange(batch_size)]
  for l in xrange(len(decoded_output)):
    for b in xrange(batch_size):
      if decoded_target[l][b] > 0:  # Score only non-padding target symbols.
        total += 1
        if decoded_output[l][b] != decoded_target[l][b]:
          seq[b] = 1
          errors += 1
  e = 0  # Previous error index
  # Print up to nprint sequences that contain errors.
  for _ in xrange(min(nprint, sum(seq))):
    while seq[e] == 0:
      e += 1
    task_print([inpt[l][e] for l in xrange(len(inpt))],
               [decoded_output[l][e] for l in xrange(len(decoded_target))],
               [decoded_target[l][e] for l in xrange(len(decoded_target))])
    e += 1
  # NOTE(review): `nprint - errors` counts symbol errors, not erroneous
  # sequences; it looks like `nprint - sum(seq)` was intended (fill the
  # nprint quota with error-free examples). Confirm before changing.
  for b in xrange(nprint - errors):
    task_print([inpt[l][b] for l in xrange(len(inpt))],
               [decoded_output[l][b] for l in xrange(len(decoded_target))],
               [decoded_target[l][b] for l in xrange(len(decoded_target))])
  return errors, total, sum(seq)
def safe_exp(x):
  """Exponentiate x, capping the result at 10000 and avoiding overflow."""
  x = float(x)
  if x >= 100:
    # exp(100) already exceeds the cap; skip the call to avoid overflow.
    return 10000
  return min(math.exp(x), 10000)