3morrrrr commited on
Commit
569596a
·
verified ·
1 Parent(s): 67df482

Upload 14 files

Browse files
Files changed (14) hide show
  1. data_frame.py +104 -0
  2. demo.py +62 -0
  3. drawing.py +216 -0
  4. hand.py +150 -0
  5. handwriting_api.py +52 -0
  6. lyrics.py +190 -0
  7. prepare_data.py +134 -0
  8. readme.md +68 -0
  9. requirements.txt +15 -0
  10. rnn.py +236 -0
  11. rnn_cell.py +185 -0
  12. rnn_ops.py +249 -0
  13. tf_base_model.py +408 -0
  14. tf_utils.py +91 -0
data_frame.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from sklearn.model_selection import train_test_split
6
+
7
+
8
class DataFrame(object):

    """Minimal pd.DataFrame analog for handling n-dimensional numpy matrices with additional
    support for shuffling, batching, and train/test splitting.

    Args:
        columns: List of names corresponding to the matrices in data.
        data: List of n-dimensional data matrices ordered in correspondence with columns.
            All matrices must have the same leading dimension.  Data can also be fed a list of
            instances of np.memmap, in which case RAM usage can be limited to the size of a
            single batch.
    """

    def __init__(self, columns, data):
        assert len(columns) == len(data), 'columns length does not match data length'

        lengths = [mat.shape[0] for mat in data]
        assert len(set(lengths)) == 1, 'all matrices in data must have same first dimension'

        self.length = lengths[0]
        # NOTE: columns and data are stored by reference, so __setitem__ also
        # mutates the lists the caller passed in.
        self.columns = columns
        self.data = data
        self.dict = dict(zip(self.columns, self.data))
        # Row order used by batching and integer indexing; shuffled in place.
        self.idx = np.arange(self.length)

    def shapes(self):
        """Return a pd.Series mapping each column name to its matrix shape."""
        return pd.Series(dict(zip(self.columns, [mat.shape for mat in self.data])))

    def dtypes(self):
        """Return a pd.Series mapping each column name to its matrix dtype."""
        return pd.Series(dict(zip(self.columns, [mat.dtype for mat in self.data])))

    def shuffle(self):
        """Shuffle the row order (self.idx) in place."""
        np.random.shuffle(self.idx)

    def train_test_split(self, train_size, random_state=None, stratify=None):
        """Split the rows into a train DataFrame and a test DataFrame.

        Args:
            train_size: Forwarded to sklearn's train_test_split.
            random_state: Seed forwarded to sklearn.  When None, a fresh seed in
                [0, 1000) is drawn on every call.  (Previously the default was drawn
                once at function-definition time, so every call in a process silently
                reused the same seed.)
            stratify: Forwarded to sklearn's train_test_split.

        Returns:
            Tuple (train_df, test_df).
        """
        if random_state is None:
            random_state = np.random.randint(1000)

        train_idx, test_idx = train_test_split(
            self.idx,
            train_size=train_size,
            random_state=random_state,
            stratify=stratify
        )
        train_df = DataFrame(copy.copy(self.columns), [mat[train_idx] for mat in self.data])
        test_df = DataFrame(copy.copy(self.columns), [mat[test_idx] for mat in self.data])
        return train_df, test_df

    def batch_generator(self, batch_size, shuffle=True, num_epochs=10000, allow_smaller_final_batch=False):
        """Yield DataFrames of batch_size rows for num_epochs passes over the data.

        Args:
            batch_size: Number of rows per batch.
            shuffle: If True, reshuffle the row order at the start of every epoch.
            num_epochs: Number of passes over the data.
            allow_smaller_final_batch: If True, the last batch of an epoch may contain
                fewer than batch_size rows; otherwise the incomplete batch is dropped.
        """
        epoch_num = 0
        while epoch_num < num_epochs:
            if shuffle:
                self.shuffle()

            # Stop at self.length (not self.length + 1): the extra step produced an
            # empty trailing batch whenever length was a multiple of batch_size and
            # smaller final batches were allowed.
            for i in range(0, self.length, batch_size):
                batch_idx = self.idx[i: i + batch_size]
                if not allow_smaller_final_batch and len(batch_idx) != batch_size:
                    break
                yield DataFrame(
                    columns=copy.copy(self.columns),
                    data=[mat[batch_idx].copy() for mat in self.data]
                )

            epoch_num += 1

    def iterrows(self):
        """Yield every row as a pd.Series, following the current row order (self.idx)."""
        # Iterate positions, not the values of self.idx: __getitem__ already maps a
        # position through self.idx, so passing idx values would index twice.
        for i in range(self.length):
            yield self[i]

    def mask(self, mask):
        """Return a new DataFrame with every matrix indexed by mask."""
        return DataFrame(copy.copy(self.columns), [mat[mask] for mat in self.data])

    def concat(self, other_df):
        """Return a new DataFrame with other_df's rows appended, column by column."""
        mats = [
            np.concatenate([self[column], other_df[column]], axis=0)
            for column in self.columns
        ]
        return DataFrame(copy.copy(self.columns), mats)

    def items(self):
        """Return the (column name, matrix) pairs."""
        return self.dict.items()

    def __iter__(self):
        return iter(self.dict.items())

    def __len__(self):
        return self.length

    def __getitem__(self, key):
        """Return a column matrix for a str key, or a row pd.Series for an integer key.

        Raises:
            TypeError: If key is neither a str nor an integer.
        """
        if isinstance(key, str):
            return self.dict[key]

        # Accept numpy integers too: np.int64 is not a subclass of int on Python 3,
        # so they previously fell through and returned None.
        if isinstance(key, (int, np.integer)):
            return pd.Series(dict(zip(self.columns, [mat[self.idx[key]] for mat in self.data])))

        raise TypeError('key must be a str or an integer, got {}'.format(type(key).__name__))

    def __setitem__(self, key, value):
        """Add a new column, or replace an existing one, with matrix value."""
        assert value.shape[0] == len(self), 'matrix first dimension does not match'
        if key not in self.columns:
            self.columns.append(key)
            self.data.append(value)
        else:
            # Keep self.data in sync with self.dict when overwriting a column.
            self.data[self.columns.index(key)] = value
        self.dict[key] = value
demo.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from hand import Hand
3
+
4
+ import lyrics
5
+
6
+
7
if __name__ == '__main__':
    # Hand() builds the model wrapper once; it is reused for every demo below.
    hand = Hand()

    # usage demo: a single line with explicit stroke colours and widths
    lines = ["what's up how are you marcos ?"]
    hand.write(
        filename='img/usage_demo.svg',
        lines=lines,
        biases=[.75] * len(lines),
        styles=[9] * len(lines),
        stroke_colors=['red', 'green', 'black', 'blue'],
        stroke_widths=[1, 2, 1, 2]
    )

    # demo number 1 - fixed bias, fixed style
    lines = lyrics.all_star.split("\n")
    hand.write(
        filename='img/all_star.svg',
        lines=lines,
        biases=[.75] * len(lines),
        styles=[12] * len(lines),
    )

    # demo number 2 - fixed bias, varying style
    # the style index is the running count of empty lines seen so far
    lines = lyrics.downtown.split("\n")
    is_blank = np.array([len(line) for line in lines]) == 0
    hand.write(
        filename='img/downtown.svg',
        lines=lines,
        biases=[.75] * len(lines),
        styles=np.cumsum(is_blank).astype(int),
    )

    # demo number 3 - varying bias, fixed style
    # the bias is 0.2 times the reversed running count of empty lines
    lines = lyrics.give_up.split("\n")
    blank_count = np.cumsum([len(line) == 0 for line in lines])
    hand.write(
        filename='img/give_up.svg',
        lines=lines,
        biases=.2*np.flip(blank_count, 0),
        styles=[7] * len(lines),
    )
drawing.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ from collections import defaultdict
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from scipy.signal import savgol_filter
7
+ from scipy.interpolate import interp1d
8
+
9
+
10
# Characters the model can encode; index 0 ('\x00') doubles as the code that
# alpha_to_num returns for any character missing from this list.
alphabet = list(
    '\x00 !"#\'(),-.'
    '0123456789:;?'
    'ABCDEFGHIJKLMNOPRSTUVWY'
    'abcdefghijklmnopqrstuvwxyz'
)
# Code point of each alphabet entry, in alphabet order.
alphabet_ord = [ord(char) for char in alphabet]
# character -> index in alphabet (unknown characters map to 0)
alpha_to_num = defaultdict(int, ((char, index) for index, char in enumerate(alphabet)))
# index in alphabet -> code point
num_to_alpha = dict(enumerate(alphabet_ord))

MAX_STROKE_LEN = 1200
MAX_CHAR_LEN = 75
25
+
26
+
27
def align(coords):
    """
    corrects for global slant/offset in handwriting strokes

    Fits the least-squares line y = offset + slope * x through the first two
    columns, multiplies those columns by a rotation matrix built from
    arctan(slope), then subtracts the fitted offset.  Works on a copy; the
    input array is left untouched.
    """
    aligned = np.copy(coords)
    xs = aligned[:, 0].reshape(-1, 1)
    ys = aligned[:, 1].reshape(-1, 1)

    # design matrix [1, x] for the normal-equation solve
    design = np.concatenate([np.ones([xs.shape[0], 1]), xs], axis=1)
    offset, slope = np.linalg.inv(design.T.dot(design)).dot(design.T).dot(ys).squeeze()

    theta = np.arctan(slope)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation_matrix = np.array([[cos_t, -sin_t], [sin_t, cos_t]])

    # note: the scalar offset is subtracted from both the x and y columns
    aligned[:, :2] = np.dot(aligned[:, :2], rotation_matrix) - offset
    return aligned
42
+
43
+
44
def skew(coords, degrees):
    """
    skews strokes by given degrees

    Returns a copy whose first two columns are multiplied by the matrix
    [[cos(-theta), 0], [sin(-theta), 1]], with theta = degrees in radians.
    """
    skewed = np.copy(coords)
    theta = degrees * np.pi/180
    shear = np.array([
        [np.cos(-theta), 0],
        [np.sin(-theta), 1],
    ])
    skewed[:, :2] = np.dot(skewed[:, :2], shear)
    return skewed
53
+
54
+
55
def stretch(coords, x_factor, y_factor):
    """
    stretches strokes along x and y axis

    Returns a copy with column 0 scaled by x_factor and column 1 scaled by
    y_factor; any further columns are left unchanged.
    """
    stretched = np.copy(coords)
    factors = np.array([x_factor, y_factor])
    stretched[:, :2] *= factors
    return stretched
62
+
63
+
64
def add_noise(coords, scale):
    """
    adds gaussian noise to strokes

    Returns a copy in which zero-mean normal noise with standard deviation
    `scale` is added to the first two columns of every row except the first.
    """
    noisy = np.copy(coords)
    target = noisy[1:, :2]
    noise = np.random.normal(loc=0.0, scale=scale, size=target.shape)
    noisy[1:, :2] += noise
    return noisy
71
+
72
+
73
def encode_ascii(ascii_string):
    """
    encodes ascii string to array of ints

    Each character is looked up in alpha_to_num (characters not in the
    alphabet map to 0) and a trailing 0 is appended.
    """
    codes = [alpha_to_num[char] for char in ascii_string]
    codes.append(0)
    return np.array(codes)
78
+
79
+
80
def denoise(coords):
    """
    smoothing filter to mitigate some artifacts of the data collection

    Splits the sequence after every row whose third column equals 1, applies
    a Savitzky-Golay filter (window 7, order 3, mode 'nearest') to the x and y
    columns of each non-empty piece, and stacks the pieces back together with
    their original third column.
    """
    split_after = np.where(coords[:, 2] == 1)[0] + 1
    smoothed = []
    for stroke in np.split(coords, split_after, axis=0):
        if len(stroke) == 0:
            continue
        x_new = savgol_filter(stroke[:, 0], 7, 3, mode='nearest').reshape(-1, 1)
        y_new = savgol_filter(stroke[:, 1], 7, 3, mode='nearest').reshape(-1, 1)
        pen_flags = stroke[:, 2].reshape(-1, 1)
        smoothed.append(np.concatenate([np.hstack([x_new, y_new]), pen_flags], axis=1))

    return np.vstack(smoothed)
96
+
97
+
98
def interpolate(coords, factor=2):
    """
    interpolates strokes using cubic spline

    The sequence is split after every row whose third column equals 1.  Pieces
    with more than 3 points are resampled to factor * len(piece) points with
    cubic interpolation; shorter pieces keep their points.  The third column
    of every piece is rebuilt as zeros with a 1 in the last row.
    """
    split_after = np.where(coords[:, 2] == 1)[0] + 1
    pieces = []
    for stroke in np.split(coords, split_after, axis=0):
        n_points = len(stroke)
        if n_points == 0:
            continue

        if n_points > 3:
            knots = np.arange(n_points)
            samples = np.linspace(0, n_points - 1, factor*n_points)
            x_new = interp1d(knots, stroke[:, 0], kind='cubic')(samples)
            y_new = interp1d(knots, stroke[:, 1], kind='cubic')(samples)
            xy_coords = np.hstack([x_new.reshape(-1, 1), y_new.reshape(-1, 1)])
        else:
            # too few points for a cubic fit: keep them unchanged
            xy_coords = stroke[:, :2]

        stroke_eos = np.zeros([len(xy_coords), 1])
        stroke_eos[-1] = 1.0
        pieces.append(np.concatenate([xy_coords, stroke_eos], axis=1))

    return np.vstack(pieces)
130
+
131
+
132
def normalize(offsets):
    """
    normalizes strokes to median unit norm

    Returns a copy whose first two columns are divided by the median of the
    row-wise Euclidean norms of those two columns.
    """
    scaled = np.copy(offsets)
    step_norms = np.linalg.norm(scaled[:, :2], axis=1)
    scaled[:, :2] /= np.median(step_norms)
    return scaled
139
+
140
+
141
def coords_to_offsets(coords):
    """
    convert from coordinates to offsets

    Row i (i >= 1) of the result holds the x/y difference between rows i and
    i - 1 of coords plus the third column of row i; row 0 is always [0, 0, 1].
    """
    deltas = coords[1:, :2] - coords[:-1, :2]
    pen_flags = coords[1:, 2:3]
    steps = np.concatenate([deltas, pen_flags], axis=1)
    first_row = np.array([[0, 0, 1]])
    return np.concatenate([first_row, steps], axis=0)
148
+
149
+
150
def offsets_to_coords(offsets):
    """
    convert from offsets to coordinates

    The first two columns are cumulatively summed down the rows; the third
    column is carried over unchanged.
    """
    positions = np.cumsum(offsets[:, :2], axis=0)
    pen_flags = offsets[:, 2:3]
    return np.concatenate([positions, pen_flags], axis=1)
155
+
156
+
157
def draw(
        offsets,
        ascii_seq=None,
        align_strokes=True,
        denoise_strokes=True,
        interpolation_factor=None,
        save_file=None
):
    """
    plots a stroke sequence with matplotlib

    Args:
        offsets: stroke offsets, converted with offsets_to_coords before plotting.
        ascii_seq: optional title; either a str or a sequence of code points
            that is converted with chr().
        align_strokes: if True, run align() on the x/y columns.
        denoise_strokes: if True, run denoise() on the coordinates.
        interpolation_factor: if not None, run interpolate() with this factor.
        save_file: if not None, save the figure to this path and close all
            figures; otherwise show the figure interactively.
    """
    strokes = offsets_to_coords(offsets)

    if denoise_strokes:
        strokes = denoise(strokes)

    if interpolation_factor is not None:
        strokes = interpolate(strokes, factor=interpolation_factor)

    if align_strokes:
        strokes[:, :2] = align(strokes[:, :2])

    fig, ax = plt.subplots(figsize=(12, 3))

    # Plot one line per pen stroke; a row with eos == 1 closes the current stroke.
    stroke = []
    for x, y, eos in strokes:
        stroke.append((x, y))
        if eos == 1:
            # zip() returns an iterator on Python 3 and cannot be indexed, so
            # unpack it into the x and y tuples directly.
            xs, ys = zip(*stroke)
            ax.plot(xs, ys, 'k')
            stroke = []
    if stroke:
        # trailing points that were never terminated by an eos flag
        xs, ys = zip(*stroke)
        ax.plot(xs, ys, 'k')
        stroke = []

    ax.set_xlim(-50, 600)
    ax.set_ylim(-40, 40)

    ax.set_aspect('equal')
    # tick_params expects booleans; the former 'off' strings are truthy and
    # therefore did not hide the ticks and labels.
    plt.tick_params(
        axis='both',
        left=False,
        top=False,
        right=False,
        bottom=False,
        labelleft=False,
        labeltop=False,
        labelright=False,
        labelbottom=False
    )

    if ascii_seq is not None:
        if not isinstance(ascii_seq, str):
            ascii_seq = ''.join(list(map(chr, ascii_seq)))
        plt.title(ascii_seq)

    if save_file is not None:
        plt.savefig(save_file)
        print('saved to {}'.format(save_file))
    else:
        plt.show()
    plt.close('all')
hand.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import drawing
2
+ from rnn import rnn
3
+
4
+
5
+ import numpy as np
6
+ import svgwrite
7
+
8
+
9
+ import logging
10
+ import os
11
+
12
+
13
class Hand(object):
    """Handwriting synthesizer: samples strokes from a restored `rnn` model and writes SVG files."""

    def __init__(self):
        """Build the rnn model with the configuration below and call its restore()."""
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        self.nn = rnn(
            log_dir='logs',
            checkpoint_dir='checkpoints',
            prediction_dir='predictions',
            learning_rates=[.0001, .00005, .00002],
            batch_sizes=[32, 64, 64],
            patiences=[1500, 1000, 500],
            beta1_decays=[.9, .9, .9],
            validation_batch_size=32,
            optimizer='rms',
            num_training_steps=100000,
            warm_start_init_step=17900,
            regularization_constant=0.0,
            keep_prob=1.0,
            enable_parameter_averaging=False,
            min_steps_to_checkpoint=2000,
            log_interval=20,
            logging_level=logging.CRITICAL,
            grad_clip=10,
            lstm_size=400,
            output_mixture_components=20,
            attention_mixture_components=10
        )
        self.nn.restore()

    def write(self, filename, lines, biases=None, styles=None, stroke_colors=None, stroke_widths=None):
        """Validate `lines`, sample strokes for them and save the drawing to `filename`.

        Args:
            filename: path of the SVG file to write.
            lines: list of strings, one per handwritten line (each at most 75
                characters, using only characters from drawing.alphabet).
            biases: optional per-line bias values; defaults to 0.5 for every line.
            styles: optional per-line style ids used to load priming data from
                'styles/style-<id>-strokes.npy' and 'styles/style-<id>-chars.npy'.
            stroke_colors: optional per-line stroke colours; defaults to 'black'.
            stroke_widths: optional per-line stroke widths; defaults to 2.

        Raises:
            ValueError: if `lines` is empty, a line is too long, or a line
                contains a character outside drawing.alphabet.
        """
        # An empty list would otherwise fail later with an opaque max() error.
        if len(lines) == 0:
            raise ValueError("lines must contain at least one line")

        valid_char_set = set(drawing.alphabet)
        for line_num, line in enumerate(lines):
            if len(line) > 75:
                raise ValueError(
                    (
                        "Each line must be at most 75 characters. "
                        "Line {} contains {}"
                    ).format(line_num, len(line))
                )

            for char in line:
                if char not in valid_char_set:
                    raise ValueError(
                        (
                            "Invalid character {} detected in line {}. "
                            "Valid character set is {}"
                        ).format(char, line_num, valid_char_set)
                    )

        strokes = self._sample(lines, biases=biases, styles=styles)
        self._draw(strokes, lines, filename, stroke_colors=stroke_colors, stroke_widths=stroke_widths)

    def _sample(self, lines, biases=None, styles=None):
        """Run the model once for all lines and return one stroke array per line.

        Rows that are entirely zero are removed from each returned sample.
        """
        num_samples = len(lines)
        max_tsteps = 40*max([len(i) for i in lines])
        biases = biases if biases is not None else [0.5]*num_samples

        # Fixed-size, zero-padded buffers fed to the model.
        x_prime = np.zeros([num_samples, 1200, 3])
        x_prime_len = np.zeros([num_samples])
        chars = np.zeros([num_samples, 120])
        chars_len = np.zeros([num_samples])

        if styles is not None:
            for i, (cs, style) in enumerate(zip(lines, styles)):
                x_p = np.load('styles/style-{}-strokes.npy'.format(style))
                # tobytes() replaces the deprecated ndarray.tostring() alias.
                c_p = np.load('styles/style-{}-chars.npy'.format(style)).tobytes().decode('utf-8')

                # The style's own text is prepended so it primes the requested line.
                c_p = str(c_p) + " " + cs
                c_p = drawing.encode_ascii(c_p)
                c_p = np.array(c_p)

                if len(c_p) > chars.shape[1]:
                    # Previously surfaced as a numpy broadcasting error.
                    raise ValueError(
                        "Style {} text plus line {} needs {} characters; at most {} fit".format(
                            style, i, len(c_p), chars.shape[1]
                        )
                    )

                x_prime[i, :len(x_p), :] = x_p
                x_prime_len[i] = len(x_p)
                chars[i, :len(c_p)] = c_p
                chars_len[i] = len(c_p)

        else:
            for i in range(num_samples):
                encoded = drawing.encode_ascii(lines[i])
                chars[i, :len(encoded)] = encoded
                chars_len[i] = len(encoded)

        [samples] = self.nn.session.run(
            [self.nn.sampled_sequence],
            feed_dict={
                self.nn.prime: styles is not None,
                self.nn.x_prime: x_prime,
                self.nn.x_prime_len: x_prime_len,
                self.nn.num_samples: num_samples,
                self.nn.sample_tsteps: max_tsteps,
                self.nn.c: chars,
                self.nn.c_len: chars_len,
                self.nn.bias: biases
            }
        )
        samples = [sample[~np.all(sample == 0.0, axis=1)] for sample in samples]
        return samples

    def _draw(self, strokes, lines, filename, stroke_colors=None, stroke_widths=None):
        """Render the sampled strokes as one SVG path per non-empty line and save the file."""
        stroke_colors = stroke_colors or ['black']*len(lines)
        stroke_widths = stroke_widths or [2]*len(lines)

        line_height = 60
        view_width = 1000
        view_height = line_height*(len(strokes) + 1)

        dwg = svgwrite.Drawing(filename=filename)
        dwg.viewbox(width=view_width, height=view_height)
        dwg.add(dwg.rect(insert=(0, 0), size=(view_width, view_height), fill='white'))

        initial_coord = np.array([0, -(3*line_height / 4)])
        for offsets, line, color, width in zip(strokes, lines, stroke_colors, stroke_widths):

            # Empty text lines only advance the vertical position.
            if not line:
                initial_coord[1] -= line_height
                continue

            offsets[:, :2] *= 1.5
            # Named `coords` so the `strokes` argument is no longer shadowed.
            coords = drawing.offsets_to_coords(offsets)
            coords = drawing.denoise(coords)
            coords[:, :2] = drawing.align(coords[:, :2])

            coords[:, 1] *= -1
            coords[:, :2] -= coords[:, :2].min() + initial_coord
            # Centre the line horizontally in the view box.
            coords[:, 0] += (view_width - coords[:, 0].max()) / 2

            # 'M' starts a new sub-path after every end-of-stroke flag, 'L' draws on.
            prev_eos = 1.0
            segments = ["M{},{} ".format(0, 0)]
            for x, y, eos in zip(*coords.T):
                segments.append('{}{},{} '.format('M' if prev_eos == 1.0 else 'L', x, y))
                prev_eos = eos
            path = svgwrite.path.Path(''.join(segments))
            path = path.stroke(color=color, width=width, linecap='round').fill("none")
            dwg.add(path)

            initial_coord[1] -= line_height

        # Create the target directory if needed so saving does not fail on a fresh checkout.
        out_dir = os.path.dirname(filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        dwg.save()
handwriting_api.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI as FastAPIApp, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ import uvicorn
5
+ import os
6
+ from hand import Hand
7
+ from typing import List, Optional
8
+
9
+ app = FastAPIApp()
10
+ hand = Hand()
11
+
12
class InputData(BaseModel):
    # Request body of the /synthesize endpoint.
    # text: the text to render; it is split on '\n' into lines.
    text: str
    # bias and style are applied to every line of the text.
    bias: Optional[float] = 0.75
    style: Optional[int] = 9
    # One entry per line of text is required by validate_input.
    stroke_colors: Optional[List[str]] = ['black']
    stroke_widths: Optional[List[float]] = [2]
18
+
19
def validate_input(data: InputData):
    """Check a synthesis request and raise ValueError describing the first problem found.

    The 75-character limit is checked per line (the text is split on '\\n'
    into lines before synthesis), not on the whole text.  Fields that are
    declared Optional are rejected when they are None so the caller gets a
    ValueError instead of a TypeError from the comparisons below.

    Raises:
        ValueError: if a line is too long, bias or style is missing or out of
            range, or the number of stroke colors/widths differs from the
            number of lines.
    """
    lines = data.text.split('\n')
    for line_num, line in enumerate(lines):
        if len(line) > 75:
            raise ValueError(
                "Each line must be 75 characters or less (line {} has {})".format(line_num, len(line))
            )
    if data.bias is None or not (0.5 <= data.bias <= 1.0):
        raise ValueError("Bias must be between 0.5 and 1.0")
    if data.style is None or not (0 <= data.style <= 12):
        raise ValueError("Style must be between 0 and 12")
    if data.stroke_colors is None or len(data.stroke_colors) != len(lines):
        raise ValueError("Number of stroke colors must match number of lines")
    if data.stroke_widths is None or len(data.stroke_widths) != len(lines):
        raise ValueError("Number of stroke widths must match number of lines")
30
+
31
@app.post("/synthesize")
def synthesize(data: InputData):
    """Validate the request, render the text to 'img/output.svg' and report the result.

    Any ValueError raised by validation or by hand.write is returned to the
    client as an HTTP 400 with the error message as detail.
    """
    try:
        validate_input(data)
        lines = data.text.split('\n')
        line_count = len(lines)

        # the same bias and style are applied to every line
        hand.write(
            filename='img/output.svg',
            lines=lines,
            biases=[data.bias] * line_count,
            styles=[data.style] * line_count,
            stroke_colors=data.stroke_colors,
            stroke_widths=data.stroke_widths
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    return {"result": "Handwriting synthesized successfully", "output_file": "img/output.svg"}
50
+
51
if __name__ == "__main__":
    # Run the API with uvicorn on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
lyrics.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """lyrics taken from https://www.azlyrics.com/"""
2
+
3
+ all_star = """Somebody once told me the world is gonna roll me
4
+ I ain't the sharpest tool in the shed
5
+ She was looking kind of dumb with her finger and her thumb
6
+ In the shape of an "L" on her forehead
7
+
8
+ Well, the years start coming and they don't stop coming
9
+ Fed to the rules and I hit the ground running
10
+ Didn't make sense not to live for fun
11
+ Your brain gets smart but your head gets dumb
12
+
13
+ So much to do, so much to see
14
+ So what's wrong with taking the back streets?
15
+ You'll never know if you don't go
16
+ You'll never shine if you don't glow
17
+
18
+ Hey, now, you're an All Star, get your game on, go play
19
+ Hey, now, you're a Rock Star, get the show on, get paid
20
+ And all that glitters is gold
21
+ Only shooting stars break the mold
22
+
23
+ It's a cool place and they say it gets colder
24
+ You're bundled up now wait 'til you get older
25
+ But the meteor men beg to differ
26
+ Judging by the hole in the satellite picture
27
+
28
+ The ice we skate is getting pretty thin
29
+ The water's getting warm so you might as well swim
30
+ My world's on fire. How about yours?
31
+ That's the way I like it and I'll never get bored.
32
+
33
+ Somebody once asked could I spare some change for gas
34
+ I need to get myself away from this place
35
+ I said yep, what a concept
36
+ I could use a little fuel myself
37
+ And we could all use a little change
38
+
39
+ Well, the years start coming and they don't stop coming
40
+ Fed to the rules and I hit the ground running
41
+ Didn't make sense not to live for fun
42
+ Your brain gets smart but your head gets dumb
43
+
44
+ So much to do, so much to see
45
+ So what's wrong with taking the back streets?
46
+ You'll never know if you don't go
47
+ You'll never shine if you don't glow.
48
+
49
+ And all that glitters is gold
50
+ Only shooting stars break the mold"""
51
+
52
+ downtown = """Making my way downtown
53
+ Walking fast
54
+ Faces pass
55
+ And I'm home-bound
56
+
57
+ Staring blankly ahead
58
+ Just making my way
59
+ Making a way
60
+ Through the crowd
61
+
62
+ And I need you
63
+ And I miss you
64
+ And now I wonder
65
+
66
+ If I could fall into the sky
67
+ Do you think time would pass me by?
68
+ 'Cause you know I'd walk a thousand miles
69
+ If I could just see you tonight
70
+
71
+ It's always times like these
72
+ When I think of you
73
+ And I wonder if you ever think of me
74
+ 'Cause everything's so wrong
75
+ And I don't belong
76
+ Living in your precious memory
77
+
78
+ 'Cause I need you
79
+ And I miss you
80
+ And now I wonder
81
+
82
+ If I could fall into the sky
83
+ Do you think time would pass me by?
84
+ 'Cause you know I'd walk a thousand miles
85
+ If I could just see you tonight
86
+
87
+ And I, I don't wanna let you know
88
+ I, I drown in your memory
89
+ I, I don't wanna let this go
90
+ I, I don't
91
+
92
+ Making my way downtown
93
+ Walking fast
94
+ Faces pass
95
+ And I'm home-bound
96
+
97
+ Staring blankly ahead
98
+ Just making my way
99
+ Making a way
100
+ Through the crowd
101
+
102
+ And I still need you
103
+ And I still miss you
104
+ And now I wonder
105
+
106
+ If I could fall into the sky
107
+ Do you think time would pass us by?
108
+ 'Cause you know I'd walk a thousand miles
109
+ If I could just see you
110
+
111
+ If I could fall into the sky
112
+ Do you think time would pass me by?
113
+ 'Cause you know I'd walk a thousand miles
114
+ If I could just see you
115
+ If I could just hold you tonight"""
116
+
117
+ give_up = """We're no strangers to love
118
+ You know the rules and so do I
119
+ A full commitment's what I'm thinking of
120
+ You wouldn't get this from any other guy
121
+
122
+ I just wanna tell you how I'm feeling
123
+ Gotta make you understand
124
+
125
+ Never gonna give you up
126
+ Never gonna let you down
127
+ Never gonna run around and desert you
128
+ Never gonna make you cry
129
+ Never gonna say goodbye
130
+ Never gonna tell a lie and hurt you
131
+
132
+ We've known each other for so long
133
+ Your heart's been aching, but
134
+ You're too shy to say it
135
+ Inside, we both know what's been going on
136
+ We know the game and we're gonna play it
137
+
138
+ And if you ask me how I'm feeling
139
+ Don't tell me you're too blind to see
140
+
141
+ Never gonna give you up
142
+ Never gonna let you down
143
+ Never gonna run around and desert you
144
+ Never gonna make you cry
145
+ Never gonna say goodbye
146
+ Never gonna tell a lie and hurt you
147
+
148
+ Never gonna give you up
149
+ Never gonna let you down
150
+ Never gonna run around and desert you
151
+ Never gonna make you cry
152
+ Never gonna say goodbye
153
+ Never gonna tell a lie and hurt you
154
+
155
+ (Ooh, give you up)
156
+ (Ooh, give you up)
157
+ Never gonna give, never gonna give
158
+ (Give you up)
159
+ Never gonna give, never gonna give
160
+ (Give you up)
161
+
162
+ We've known each other for so long
163
+ Your heart's been aching, but
164
+ You're too shy to say it
165
+ Inside, we both know what's been going on
166
+ We know the game and we're gonna play it
167
+
168
+ I just wanna tell you how I'm feeling
169
+ Gotta make you understand
170
+
171
+ Never gonna give you up
172
+ Never gonna let you down
173
+ Never gonna run around and desert you
174
+ Never gonna make you cry
175
+ Never gonna say goodbye
176
+ Never gonna tell a lie and hurt you
177
+
178
+ Never gonna give you up
179
+ Never gonna let you down
180
+ Never gonna run around and desert you
181
+ Never gonna make you cry
182
+ Never gonna say goodbye
183
+ Never gonna tell a lie and hurt you
184
+
185
+ Never gonna give you up
186
+ Never gonna let you down
187
+ Never gonna run around and desert you
188
+ Never gonna make you cry
189
+ Never gonna say goodbye
190
+ Never gonna tell a lie and hurt you"""
prepare_data.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ from xml.etree import ElementTree
4
+
5
+ import numpy as np
6
+
7
+ import drawing
8
+
9
+
10
def get_stroke_sequence(filename):
    """Parse a stroke XML file and return its preprocessed offset sequence.

    Points are read from the first 'StrokeSet' element as [x, -y, eos], where
    eos is 1 on the last point of each stroke.  The coordinates are aligned,
    denoised, converted to offsets, truncated to drawing.MAX_STROKE_LEN rows
    and normalized.
    """
    root = ElementTree.parse(filename).getroot()
    stroke_set = [child for child in root if child.tag == 'StrokeSet'][0]

    points = []
    for stroke in stroke_set:
        last_index = len(stroke) - 1
        for index, point in enumerate(stroke):
            x = int(point.attrib['x'])
            y = -1*int(point.attrib['y'])
            points.append([x, y, int(index == last_index)])
    coords = np.array(points)

    offsets = drawing.coords_to_offsets(drawing.denoise(drawing.align(coords)))
    offsets = offsets[:drawing.MAX_STROKE_LEN]
    return drawing.normalize(offsets)
30
+
31
+
32
def get_ascii_sequences(filename):
    """Read a transcription file and return its lines as encoded integer arrays.

    Runs of '%%%%%%%%%%%' are treated as line breaks.  Only the non-empty
    lines found two or more entries after the 'CSR:' marker are kept; each is
    encoded with drawing.encode_ascii and truncated to drawing.MAX_CHAR_LEN.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the bare open(...).read() leaked it).
    with open(filename, 'r') as f:
        sequences = f.read()
    sequences = sequences.replace(r'%%%%%%%%%%%', '\n')
    sequences = [i.strip() for i in sequences.split('\n')]
    lines = sequences[sequences.index('CSR:') + 2:]
    lines = [line.strip() for line in lines if line.strip()]
    lines = [drawing.encode_ascii(line)[:drawing.MAX_CHAR_LEN] for line in lines]
    return lines
40
+
41
+
42
def collect_data():
    """Walk 'data/raw/ascii/' and pair every transcription line with its stroke file.

    Returns:
        Tuple (stroke_fnames, transcriptions, writer_ids) of equal-length lists.
    """
    # collect transcription files from leaf directories, skipping dotfiles
    fnames = []
    for dirpath, dirnames, filenames in os.walk('data/raw/ascii/'):
        if dirnames:
            continue
        fnames.extend(
            os.path.join(dirpath, filename)
            for filename in filenames
            if not filename.startswith('.')
        )

    # low quality samples (selected by collecting samples to
    # which the trained model assigned very low likelihood)
    blacklist = set(np.load('data/blacklist.npy'))

    stroke_fnames = []
    transcriptions = []
    writer_ids = []
    for i, fname in enumerate(fnames):
        print(i, fname)
        if fname == 'data/raw/ascii/z01/z01-000/z01-000z.txt':
            continue

        head, tail = os.path.split(fname)
        # a trailing letter on the file stem, if any, is part of the stroke file prefix
        last_letter = os.path.splitext(fname)[0][-1]
        if not last_letter.isalpha():
            last_letter = ''

        line_stroke_dir = head.replace('ascii', 'lineStrokes')
        line_stroke_fname_prefix = os.path.split(head)[-1] + last_letter + '-'

        if not os.path.isdir(line_stroke_dir):
            continue
        line_stroke_fnames = sorted(
            f for f in os.listdir(line_stroke_dir)
            if f.startswith(line_stroke_fname_prefix)
        )
        if not line_stroke_fnames:
            continue

        # the writer id is read from the matching file under 'original'
        original_dir = head.replace('ascii', 'original')
        original_xml = os.path.join(original_dir, 'strokes' + last_letter + '.xml')
        root = ElementTree.parse(original_xml).getroot()

        general = root.find('General')
        if general is None:
            writer_id = 0
        else:
            writer_id = int(general[0].attrib.get('writerID', '0'))

        ascii_sequences = get_ascii_sequences(fname)
        assert len(ascii_sequences) == len(line_stroke_fnames)

        for ascii_seq, line_stroke_fname in zip(ascii_sequences, line_stroke_fnames):
            if line_stroke_fname in blacklist:
                continue

            stroke_fnames.append(os.path.join(line_stroke_dir, line_stroke_fname))
            transcriptions.append(ascii_seq)
            writer_ids.append(writer_id)

    return stroke_fnames, transcriptions, writer_ids
99
+
100
+
101
if __name__ == '__main__':
    # Build the zero-padded training arrays and save them under data/processed.
    print('traversing data directory...')
    stroke_fnames, transcriptions, writer_ids = collect_data()

    print('dumping to numpy arrays...')
    x = np.zeros([len(stroke_fnames), drawing.MAX_STROKE_LEN, 3], dtype=np.float32)
    x_len = np.zeros([len(stroke_fnames)], dtype=np.int16)
    c = np.zeros([len(stroke_fnames), drawing.MAX_CHAR_LEN], dtype=np.int8)
    c_len = np.zeros([len(stroke_fnames)], dtype=np.int8)
    w_id = np.zeros([len(stroke_fnames)], dtype=np.int16)
    # The np.bool alias was removed in NumPy 1.24; the builtin bool is the
    # documented replacement and yields the same dtype.
    valid_mask = np.zeros([len(stroke_fnames)], dtype=bool)

    for i, (stroke_fname, c_i, w_id_i) in enumerate(zip(stroke_fnames, transcriptions, writer_ids)):
        if i % 200 == 0:
            print(i, '\t', '/', len(stroke_fnames))
        x_i = get_stroke_sequence(stroke_fname)
        # A sample is kept only if no single offset has a norm above 60.
        valid_mask[i] = ~np.any(np.linalg.norm(x_i[:, :2], axis=1) > 60)

        x[i, :len(x_i), :] = x_i
        x_len[i] = len(x_i)

        c[i, :len(c_i)] = c_i
        c_len[i] = len(c_i)

        w_id[i] = w_id_i

    if not os.path.isdir('data/processed'):
        os.makedirs('data/processed')

    np.save('data/processed/x.npy', x[valid_mask])
    np.save('data/processed/x_len.npy', x_len[valid_mask])
    np.save('data/processed/c.npy', c[valid_mask])
    np.save('data/processed/c_len.npy', c_len[valid_mask])
    np.save('data/processed/w_id.npy', w_id[valid_mask])
readme.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![](img/banner.svg)
2
+ # Handwriting Synthesis
3
+ Implementation of the handwriting synthesis experiments in the paper <a href="https://arxiv.org/abs/1308.0850">Generating Sequences with Recurrent Neural Networks</a> by Alex Graves. The implementation closely follows the original paper, with a few slight deviations, and the generated samples are of similar quality to those presented in the paper.
4
+
5
+ A web demo from the original GitHub repository (which I forked) is available <a href="https://seanvasquez.com/handwriting-generation/">here</a>.
6
+
7
+ ## Fork Notes
8
+ This fork makes the following changes to the original:
9
+ 1. Updated to work with TensorFlow 2.15.0 (current as of December 2023), but uses the v1 compat mechanism. Honestly I hacked and burned through many errors in an hour and a half (I forked at 7:05 PM and am writing this at 8:35 PM) and just verified by running the demo.py and looking through the output image files. It didn't create a banner.svg file (no idea if it was supposed to), and it gets a lot of deprecation warnings so use at your own risk. TensorFlow will likely abandon some of this V1 compat stuff in the relatively near future, but this should work as long as you can still get 2.15.0 for whatever python version you have.
10
+ 2. I'm going to split the Hand class into its own file (per the original author's suggestion).
11
+ 3. I left everything else below this text alone (except striking the "split Hand class" request).
12
+ 4. If you think my fork is super sloppy (you're right) and want to do it right, the major job is to replace the deprecated `tf.nn.rnn_cell.LSTMCell` with `tf.keras.layers.LSTMCell`. It is not a drop-in replacement. Have at it.
13
+
14
+ ## Usage
15
+ ```python
16
+ lines = [
17
+ "Now this is a story all about how",
18
+ "My life got flipped turned upside down",
19
+ "And I'd like to take a minute, just sit right there",
20
+ "I'll tell you how I became the prince of a town called Bel-Air",
21
+ ]
22
+ biases = [.75 for i in lines]
23
+ styles = [9 for i in lines]
24
+ stroke_colors = ['red', 'green', 'black', 'blue']
25
+ stroke_widths = [1, 2, 1, 2]
26
+
27
+ hand = Hand()
28
+ hand.write(
29
+ filename='img/usage_demo.svg',
30
+ lines=lines,
31
+ biases=biases,
32
+ styles=styles,
33
+ stroke_colors=stroke_colors,
34
+ stroke_widths=stroke_widths
35
+ )
36
+ ```
37
+ ![](img/usage_demo.svg)
38
+
39
+ ~~Currently, the `Hand` class must be imported from `demo.py`. If someone would like to package this project to make it more usable, please [contribute](#contribute).~~
40
+
41
+ A pretrained model is included, but if you'd like to train your own, read <a href='https://github.com/sjvasquez/handwriting-synthesis/tree/master/data/raw'>these instructions</a>.
42
+
43
+ ## Demonstrations
44
+ Below are a few hundred samples from the model, including some samples demonstrating the effect of priming and biasing the model. Loosely speaking, biasing controls the neatness of the samples and priming controls the style of the samples. The code for these demonstrations can be found in `demo.py`.
45
+
46
+ ### Demo #1
47
+ The following samples were generated with a fixed style and fixed bias.
48
+
49
+ **Smash Mouth – All Star (<a href="https://www.azlyrics.com/lyrics/smashmouth/allstar.html">lyrics</a>)**
50
+ ![](img/all_star.svg)
51
+
52
+ ### Demo #2
53
+ The following samples were generated with varying style and fixed bias. Each verse is generated in a different style.
54
+
55
+ **Vanessa Carlton – A Thousand Miles (<a href="https://www.azlyrics.com/lyrics/vanessacarlton/athousandmiles.html">lyrics</a>)**
56
+ ![](img/downtown.svg)
57
+
58
+ ### Demo #3
59
+ The following samples were generated with a fixed style and varying bias. Each verse has a lower bias than the previous, with the last verse being unbiased.
60
+
61
+ **Leonard Cohen – Hallelujah (<a href="https://www.youtube.com/watch?v=dQw4w9WgXcQ">lyrics</a>)**
62
+ ![](img/give_up.svg)
63
+
64
+ ## Contribute
65
+ This project was intended to serve as a reference implementation for a research paper, but since the results are of decent quality, it may be worthwhile to make the project more broadly usable. I plan to continue focusing on the machine learning side of things. That said, I'd welcome contributors who can:
66
+
67
+ - Package this, and otherwise make it look more like a usable software project and less like research code.
68
+ - Add support for more sophisticated drawing, animations, or anything else in this direction. Currently, the project only creates some simple svg files.
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ matplotlib==3.9.4
2
+ pandas==2.2.3
3
+ scikit-learn==1.6.1
4
+ scipy==1.13.0
5
+ svgwrite==1.1.12
6
+ tensorflow-probability==0.23.0
7
+ tensorflow==2.15.0
8
+ fastapi==0.109.0
9
+ uvicorn==0.27.0
10
+ pydantic==2.5.0
11
+ cairosvg==2.7.1
12
+ gradio==4.29.0
13
+ python-multipart==0.0.9
14
+ httpx==0.27.0
15
+ requests==2.31.0
rnn.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+
4
+ import numpy as np
5
+ import tensorflow.compat.v1 as tf
6
+ tf.disable_v2_behavior()
7
+
8
+ import drawing
9
+ from data_frame import DataFrame
10
+ from rnn_cell import LSTMAttentionCell
11
+ from rnn_ops import rnn_free_run
12
+ from tf_base_model import TFBaseModel
13
+ from tf_utils import time_distributed_dense_layer
14
+
15
+
16
class DataReader(object):
    """Loads the processed .npy arrays and serves train/val/test batches.

    The full dataset is kept as ``test_df``; ``train_df`` and ``val_df`` are a
    95/5 split of that same data (fixed random_state of 2018).
    """

    def __init__(self, data_dir):
        column_names = ['x', 'x_len', 'c', 'c_len']
        arrays = []
        for column in column_names:
            array_path = os.path.join(data_dir, '{}.npy'.format(column))
            arrays.append(np.load(array_path))

        self.test_df = DataFrame(columns=column_names, data=arrays)
        split = self.test_df.train_test_split(train_size=0.95, random_state=2018)
        self.train_df, self.val_df = split

        print('train size', len(self.train_df))
        print('val size', len(self.val_df))
        print('test size', len(self.test_df))

    def train_batch_generator(self, batch_size):
        """Shuffled batches from the training split, repeated for 10000 epochs."""
        return self.batch_generator(
            batch_size=batch_size, df=self.train_df, shuffle=True, num_epochs=10000, mode='train')

    def val_batch_generator(self, batch_size):
        """Shuffled batches from the validation split, repeated for 10000 epochs."""
        return self.batch_generator(
            batch_size=batch_size, df=self.val_df, shuffle=True, num_epochs=10000, mode='val')

    def test_batch_generator(self, batch_size):
        """A single ordered pass over the full dataset."""
        return self.batch_generator(
            batch_size=batch_size, df=self.test_df, shuffle=False, num_epochs=1, mode='test')

    def batch_generator(self, batch_size, df, shuffle=True, num_epochs=10000, mode='train'):
        """Yield batches trimmed to their longest sequence, with next-step targets.

        Each yielded batch has 'x_len' reduced by one, 'x' holding steps
        [0, max_len), 'y' holding the same strokes shifted one step ahead
        (steps [1, max_len]), and 'c' cut to the longest transcription in the
        batch. A short final batch is only allowed when mode == 'test'.
        """
        batches = df.batch_generator(
            batch_size=batch_size,
            shuffle=shuffle,
            num_epochs=num_epochs,
            allow_smaller_final_batch=(mode == 'test'),
        )
        for batch in batches:
            # One step is consumed by the input/target shift below.
            shifted_lens = batch['x_len'] - 1
            batch['x_len'] = shifted_lens
            longest_strokes = np.max(shifted_lens)
            longest_chars = np.max(batch['c_len'])

            strokes = batch['x']
            batch['y'] = strokes[:, 1:longest_strokes + 1, :]
            batch['x'] = strokes[:, :longest_strokes, :]
            batch['c'] = batch['c'][:, :longest_chars]
            yield batch
71
+
72
+
73
class rnn(TFBaseModel):
    """Handwriting-synthesis network: an attention LSTM cell plus a mixture-density output layer.

    The graph is assembled in ``calculate_loss``: it builds the teacher-forced
    training loss and, in the same graph, a free-running sampler that can
    optionally be primed with an example stroke sequence.

    Args:
        lstm_size: Number of units in each LSTM layer of the attention cell.
        output_mixture_components: Number of bivariate Gaussian components in
            the output mixture.
        attention_mixture_components: Number of components in the attention
            window mixture.
        **kwargs: Forwarded unchanged to ``TFBaseModel.__init__``.
    """

    def __init__(
        self,
        lstm_size,
        output_mixture_components,
        attention_mixture_components,
        **kwargs
    ):
        self.lstm_size = lstm_size
        self.output_mixture_components = output_mixture_components
        # Per component: 1 weight + 2 sigmas + 1 rho + 2 means; plus a single
        # extra unit for the binary third channel (see parse_parameters).
        self.output_units = self.output_mixture_components*6 + 1
        self.attention_mixture_components = attention_mixture_components
        # NOTE(review): attributes are set before the base constructor runs,
        # presumably because it builds the graph via calculate_loss() - confirm.
        super(rnn, self).__init__(**kwargs)

    def parse_parameters(self, z, eps=1e-8, sigma_eps=1e-4):
        """Split raw output-layer activations into constrained mixture parameters.

        Args:
            z: Tensor whose last axis has ``self.output_units`` entries.
            eps: Margin keeping rho inside (-1, 1) and es inside (0, 1).
            sigma_eps: Lower bound applied to the standard deviations.

        Returns:
            Tuple ``(pis, mus, sigmas, rhos, es)``: softmax-normalised mixture
            weights, unconstrained means, positive standard deviations,
            correlations in (-1, 1), and a probability in (0, 1).
        """
        pis, sigmas, rhos, mus, es = tf.split(
            z,
            [
                1*self.output_mixture_components,
                2*self.output_mixture_components,
                1*self.output_mixture_components,
                2*self.output_mixture_components,
                1
            ],
            axis=-1
        )
        pis = tf.nn.softmax(pis, axis=-1)
        sigmas = tf.clip_by_value(tf.exp(sigmas), sigma_eps, np.inf)
        rhos = tf.clip_by_value(tf.tanh(rhos), eps - 1.0, 1.0 - eps)
        es = tf.clip_by_value(tf.nn.sigmoid(es), eps, 1.0 - eps)
        return pis, mus, sigmas, rhos, es

    def NLL(self, y, lengths, pis, mus, sigmas, rho, es, eps=1e-8):
        """Negative log-likelihood of targets under the predicted mixture.

        Args:
            y: Targets of shape [batch, time, 3]; the first two channels are
                modelled by the Gaussian mixture, the third by a Bernoulli.
            lengths: Valid length of each sequence, shape [batch].
            pis, mus, sigmas, rho, es: Outputs of ``parse_parameters``.
            eps: Floor applied to the mixture likelihood before the log.

        Returns:
            Tuple ``(sequence_loss, element_loss)``: the mean NLL per sequence
            (shape [batch]) and the mean NLL over all valid steps (scalar).
            Padded steps and NaN entries are excluded from both averages.
        """
        sigma_1, sigma_2 = tf.split(sigmas, 2, axis=2)
        y_1, y_2, y_3 = tf.split(y, 3, axis=2)
        mu_1, mu_2 = tf.split(mus, 2, axis=2)

        # Bivariate normal density with correlation rho.
        norm = 1.0 / (2*np.pi*sigma_1*sigma_2 * tf.sqrt(1 - tf.square(rho)))
        Z = tf.square((y_1 - mu_1) / (sigma_1)) + \
            tf.square((y_2 - mu_2) / (sigma_2)) - \
            2*rho*(y_1 - mu_1)*(y_2 - mu_2) / (sigma_1*sigma_2)

        exp = -1.0*Z / (2*(1 - tf.square(rho)))
        gaussian_likelihoods = tf.exp(exp) * norm
        gmm_likelihood = tf.reduce_sum(pis * gaussian_likelihoods, 2)
        gmm_likelihood = tf.clip_by_value(gmm_likelihood, eps, np.inf)

        # Squeeze only the trailing channel axis. An unqualified squeeze would
        # also drop a size-1 batch or time axis, and the result would then
        # broadcast incorrectly against the [batch, time] mixture likelihood.
        bernoulli_likelihood = tf.squeeze(
            tf.where(tf.equal(tf.ones_like(y_3), y_3), es, 1 - es),
            axis=2
        )

        nll = -(tf.log(gmm_likelihood) + tf.log(bernoulli_likelihood))
        # Keep only steps that are inside the sequence and numerically valid.
        sequence_mask = tf.logical_and(
            tf.sequence_mask(lengths, maxlen=tf.shape(y)[1]),
            tf.logical_not(tf.is_nan(nll)),
        )
        nll = tf.where(sequence_mask, nll, tf.zeros_like(nll))
        num_valid = tf.reduce_sum(tf.cast(sequence_mask, tf.float32), axis=1)

        # tf.maximum guards against division by zero for fully masked rows.
        sequence_loss = tf.reduce_sum(nll, axis=1) / tf.maximum(num_valid, 1.0)
        element_loss = tf.reduce_sum(nll) / tf.maximum(tf.reduce_sum(num_valid), 1.0)
        return sequence_loss, element_loss

    def sample(self, cell):
        """Free-run the cell from a zero state and return the sampled outputs.

        The first input is [0, 0, 1] for every sample.
        """
        initial_state = cell.zero_state(self.num_samples, dtype=tf.float32)
        initial_input = tf.concat([
            tf.zeros([self.num_samples, 2]),
            tf.ones([self.num_samples, 1]),
        ], axis=1)
        # rnn_free_run returns (states, outputs, final_state); keep the outputs.
        return rnn_free_run(
            cell=cell,
            sequence_length=self.sample_tsteps,
            initial_state=initial_state,
            initial_input=initial_input,
            scope='rnn'
        )[1]

    def primed_sample(self, cell):
        """Run the cell over ``self.x_prime`` first, then free-run from that state."""
        initial_state = cell.zero_state(self.num_samples, dtype=tf.float32)
        # dynamic_rnn returns (outputs, final_state); only the state is needed.
        primed_state = tf.nn.dynamic_rnn(
            inputs=self.x_prime,
            cell=cell,
            sequence_length=self.x_prime_len,
            dtype=tf.float32,
            initial_state=initial_state,
            scope='rnn'
        )[1]
        return rnn_free_run(
            cell=cell,
            sequence_length=self.sample_tsteps,
            initial_state=primed_state,
            scope='rnn'
        )[1]

    def calculate_loss(self):
        """Create the placeholders, the training graph and the sampling graph.

        Returns:
            Scalar loss tensor (mean NLL over all valid timesteps).
        """
        # Training inputs.
        self.x = tf.placeholder(tf.float32, [None, None, 3])
        self.y = tf.placeholder(tf.float32, [None, None, 3])
        self.x_len = tf.placeholder(tf.int32, [None])
        self.c = tf.placeholder(tf.int32, [None, None])
        self.c_len = tf.placeholder(tf.int32, [None])

        # Sampling controls.
        self.sample_tsteps = tf.placeholder(tf.int32, [])
        self.num_samples = tf.placeholder(tf.int32, [])
        self.prime = tf.placeholder(tf.bool, [])
        self.x_prime = tf.placeholder(tf.float32, [None, None, 3])
        self.x_prime_len = tf.placeholder(tf.int32, [None])
        # Defaults to a zero bias per sample when nothing is fed.
        self.bias = tf.placeholder_with_default(
            tf.zeros([self.num_samples], dtype=tf.float32), [None])

        cell = LSTMAttentionCell(
            lstm_size=self.lstm_size,
            num_attn_mixture_components=self.attention_mixture_components,
            attention_values=tf.one_hot(self.c, len(drawing.alphabet)),
            attention_values_lengths=self.c_len,
            num_output_mixture_components=self.output_mixture_components,
            bias=self.bias
        )
        self.initial_state = cell.zero_state(tf.shape(self.x)[0], dtype=tf.float32)
        outputs, self.final_state = tf.nn.dynamic_rnn(
            inputs=self.x,
            cell=cell,
            sequence_length=self.x_len,
            dtype=tf.float32,
            initial_state=self.initial_state,
            scope='rnn'
        )
        params = time_distributed_dense_layer(outputs, self.output_units, scope='rnn/gmm')
        pis, mus, sigmas, rhos, es = self.parse_parameters(params)
        sequence_loss, self.loss = self.NLL(self.y, self.x_len, pis, mus, sigmas, rhos, es)

        # Both sampling branches share the same cell; `prime` picks one at run time.
        self.sampled_sequence = tf.cond(
            self.prime,
            lambda: self.primed_sample(cell),
            lambda: self.sample(cell)
        )
        return self.loss
208
+
209
+
210
if __name__ == '__main__':
    # Training entry point: load the processed arrays and fit the model.
    dr = DataReader(data_dir='data/processed/')

    # All keyword arguments except the last three are consumed by TFBaseModel
    # through **kwargs; the last three are the rnn-specific hyperparameters.
    # NOTE(review): the parallel three-element lists (learning_rates,
    # batch_sizes, patiences, beta1_decays) look like a staged schedule -
    # confirm against TFBaseModel.
    nn = rnn(
        reader=dr,
        log_dir='logs',
        checkpoint_dir='checkpoints',
        prediction_dir='predictions',
        learning_rates=[.0001, .00005, .00002],
        batch_sizes=[32, 64, 64],
        patiences=[1500, 1000, 500],
        beta1_decays=[.9, .9, .9],
        validation_batch_size=32,
        optimizer='rms',
        num_training_steps=100000,
        warm_start_init_step=0,
        regularization_constant=0.0,
        keep_prob=1.0,
        enable_parameter_averaging=False,
        min_steps_to_checkpoint=2000,
        log_interval=20,
        grad_clip=10,
        lstm_size=400,
        output_mixture_components=20,
        attention_mixture_components=10
    )
    nn.fit()
rnn_cell.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ import tensorflow.compat.v1 as tf
4
+ tf.disable_v2_behavior()
5
+ import tensorflow_probability as tfp
6
+ tfd = tfp.distributions
7
+ import numpy as np
8
+
9
+ from tf_utils import dense_layer, shape
10
+
11
+
12
# Recurrent state carried between timesteps by LSTMAttentionCell:
#   h1/c1, h2/c2, h3/c3 - hidden and cell states of the three LSTM layers
#   alpha, beta, kappa  - attention mixture parameters from the previous step
#   w                   - attention window vector from the previous step
#   phi                 - per-character attention weights from the previous step
LSTMAttentionCellState = namedtuple(
    'LSTMAttentionCellState',
    ['h1', 'c1', 'h2', 'c2', 'h3', 'c3', 'alpha', 'beta', 'kappa', 'w', 'phi']
)
16
+
17
+
18
class LSTMAttentionCell(tf.nn.rnn_cell.RNNCell):
    """Three stacked LSTM layers with a Gaussian-mixture attention window.

    At every step the first LSTM drives a soft attention window over
    ``attention_values``; the window vector is fed to the second and third
    LSTMs. The cell also provides ``output_function`` and
    ``termination_condition``, the two hooks ``rnn_free_run`` needs to sample
    from the cell.

    Args:
        lstm_size: Number of units in each of the three LSTM layers.
        num_attn_mixture_components: Number of components in the attention mixture.
        attention_values: Tensor attended over; indexed here as
            [batch, position, feature].
        attention_values_lengths: Number of valid positions per batch row.
        num_output_mixture_components: Number of components in the output
            mixture used when sampling.
        bias: Per-sample sampling bias applied in ``_parse_parameters``.
        reuse: Stored on the instance; not otherwise read in this class.
    """

    def __init__(
        self,
        lstm_size,
        num_attn_mixture_components,
        attention_values,
        attention_values_lengths,
        num_output_mixture_components,
        bias,
        reuse=None,
    ):
        self.reuse = reuse
        self.lstm_size = lstm_size
        self.num_attn_mixture_components = num_attn_mixture_components
        self.attention_values = attention_values
        self.attention_values_lengths = attention_values_lengths
        # Feature size of the attended values, via the project's `shape` helper.
        self.window_size = shape(self.attention_values, 2)
        # Dynamic (graph-time) sizes, taken from the attended tensor itself.
        self.char_len = tf.shape(attention_values)[1]
        self.batch_size = tf.shape(attention_values)[0]
        self.num_output_mixture_components = num_output_mixture_components
        # 6 parameters per output component plus one extra unit.
        self.output_units = 6*self.num_output_mixture_components + 1
        self.bias = bias

    @property
    def state_size(self):
        """Per-field state sizes, in LSTMAttentionCellState field order."""
        # Note: the last entry (char_len) is a tensor, not a Python int.
        return LSTMAttentionCellState(
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.num_attn_mixture_components,
            self.num_attn_mixture_components,
            self.num_attn_mixture_components,
            self.window_size,
            self.char_len,
        )

    @property
    def output_size(self):
        """Size of the cell output, which is the third LSTM's output."""
        return self.lstm_size

    def zero_state(self, batch_size, dtype):
        """Return an all-zeros state for ``batch_size`` rows.

        The ``dtype`` argument is accepted for interface compatibility but is
        not used; every tensor is created with tf.zeros' default dtype.
        """
        return LSTMAttentionCellState(
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.lstm_size]),
            tf.zeros([batch_size, self.num_attn_mixture_components]),
            tf.zeros([batch_size, self.num_attn_mixture_components]),
            tf.zeros([batch_size, self.num_attn_mixture_components]),
            tf.zeros([batch_size, self.window_size]),
            tf.zeros([batch_size, self.char_len]),
        )

    def __call__(self, inputs, state, scope=None):
        """Advance the cell by one timestep.

        Args:
            inputs: Input for this step, concatenated along axis 1 with the
                previous window and LSTM outputs.
            state: LSTMAttentionCellState from the previous step.
            scope: Optional variable scope name; defaults to the class name.

        Returns:
            Tuple ``(s3_out, new_state)``: the third LSTM's output and the
            updated LSTMAttentionCellState.
        """
        with tf.variable_scope(scope or type(self).__name__, reuse=tf.AUTO_REUSE):

            # lstm 1
            # Input is the previous attention window plus the current input.
            s1_in = tf.concat([state.w, inputs], axis=1)
            cell1 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            # The state tuple is passed in (c, h) order.
            s1_out, s1_state = cell1(s1_in, state=(state.c1, state.h1))

            # attention
            attention_inputs = tf.concat([state.w, inputs, s1_out], axis=1)
            attention_params = dense_layer(attention_inputs, 3*self.num_attn_mixture_components, scope='attention')
            # softplus keeps all three parameter groups positive.
            alpha, beta, kappa = tf.split(tf.nn.softplus(attention_params), 3, axis=1)
            # kappa accumulates across steps: a positive increment (scaled by
            # 1/25) is added to the previous kappa, so it never decreases.
            kappa = state.kappa + kappa / 25.0
            # Floor beta so the division below stays well-defined.
            beta = tf.clip_by_value(beta, .01, np.inf)

            # Keep the 2-D versions for the returned state; the expanded 3-D
            # versions broadcast against the character-position axis.
            kappa_flat, alpha_flat, beta_flat = kappa, alpha, beta
            kappa, alpha, beta = tf.expand_dims(kappa, 2), tf.expand_dims(alpha, 2), tf.expand_dims(beta, 2)

            # u enumerates positions 0..char_len-1 for every row and component.
            enum = tf.reshape(tf.range(self.char_len), (1, 1, self.char_len))
            u = tf.cast(tf.tile(enum, (self.batch_size, self.num_attn_mixture_components, 1)), tf.float32)
            # Per-position weight: sum over components of alpha*exp(-(kappa-u)^2/beta).
            phi_flat = tf.reduce_sum(alpha*tf.exp(-tf.square(kappa - u) / beta), axis=1)

            phi = tf.expand_dims(phi_flat, 2)
            # Zero out positions beyond each row's valid length.
            sequence_mask = tf.cast(tf.sequence_mask(self.attention_values_lengths, maxlen=self.char_len), tf.float32)
            sequence_mask = tf.expand_dims(sequence_mask, 2)
            # Window vector: phi-weighted sum of the attended values over positions.
            w = tf.reduce_sum(phi*self.attention_values*sequence_mask, axis=1)

            # lstm 2
            s2_in = tf.concat([inputs, s1_out, w], axis=1)
            cell2 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            s2_out, s2_state = cell2(s2_in, state=(state.c2, state.h2))

            # lstm 3
            s3_in = tf.concat([inputs, s2_out, w], axis=1)
            cell3 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            s3_out, s3_state = cell3(s3_in, state=(state.c3, state.h3))

            # Field order must match LSTMAttentionCellState (h before c).
            new_state = LSTMAttentionCellState(
                s1_state.h,
                s1_state.c,
                s2_state.h,
                s2_state.c,
                s3_state.h,
                s3_state.c,
                alpha_flat,
                beta_flat,
                kappa_flat,
                w,
                phi_flat,
            )

            return s3_out, new_state

    def output_function(self, state):
        """Sample the next cell input from the output mixture implied by ``state``.

        Returns:
            Tensor formed by concatenating, along axis 1, the coordinates
            sampled from the selected mixture component with the sampled
            Bernoulli value cast to float32.
        """
        params = dense_layer(state.h3, self.output_units, scope='gmm', reuse=tf.AUTO_REUSE)
        pis, mus, sigmas, rhos, es = self._parse_parameters(params)
        mu1, mu2 = tf.split(mus, 2, axis=1)
        mus = tf.stack([mu1, mu2], axis=2)
        sigma1, sigma2 = tf.split(sigmas, 2, axis=1)

        # 2x2 covariance per component: [[s1^2, rho*s1*s2], [rho*s1*s2, s2^2]].
        covar_matrix = [tf.square(sigma1), rhos*sigma1*sigma2,
                        rhos*sigma1*sigma2, tf.square(sigma2)]
        covar_matrix = tf.stack(covar_matrix, axis=2)
        covar_matrix = tf.reshape(covar_matrix, (self.batch_size, self.num_output_mixture_components, 2, 2))

        mvn = tfd.MultivariateNormalFullCovariance(loc=mus, covariance_matrix=covar_matrix)
        b = tfd.Bernoulli(probs=es)
        c = tfd.Categorical(probs=pis)

        # Sample from every component, then keep the one picked by the categorical.
        sampled_e = b.sample()
        sampled_coords = mvn.sample()
        sampled_idx = c.sample()

        idx = tf.stack([tf.range(self.batch_size), sampled_idx], axis=1)
        coords = tf.gather_nd(sampled_coords, idx)
        return tf.concat([coords, tf.cast(sampled_e, tf.float32)], axis=1)

    def termination_condition(self, state):
        """Return a boolean tensor marking rows whose sampling should stop.

        A row stops when its most-attended position is at or beyond the last
        valid position and a freshly sampled output has its third channel
        equal to 1, or when the most-attended position is past the valid length.
        """
        char_idx = tf.cast(tf.argmax(state.phi, axis=1), tf.int32)
        final_char = char_idx >= self.attention_values_lengths - 1
        past_final_char = char_idx >= self.attention_values_lengths
        # Note: this draws a new sample via output_function.
        output = self.output_function(state)
        es = tf.cast(output[:, 2], tf.int32)
        is_eos = tf.equal(es, tf.ones_like(es))
        return tf.logical_or(tf.logical_and(final_char, is_eos), past_final_char)

    def _parse_parameters(self, gmm_params, eps=1e-8, sigma_eps=1e-4):
        """Split raw output activations into biased, constrained sampling parameters.

        Args:
            gmm_params: Tensor whose last axis has ``self.output_units`` entries.
            eps: Margin keeping rho inside (-1, 1) and es inside (0, 1).
            sigma_eps: Lower bound applied to the standard deviations.

        Returns:
            Tuple ``(pis, mus, sigmas, rhos, es)``.
        """
        pis, sigmas, rhos, mus, es = tf.split(
            gmm_params,
            [
                1*self.num_output_mixture_components,
                2*self.num_output_mixture_components,
                1*self.num_output_mixture_components,
                2*self.num_output_mixture_components,
                1
            ],
            axis=-1
        )
        # Bias is applied before normalisation: it scales the mixture logits
        # and is subtracted from the log standard deviations.
        pis = pis*(1 + tf.expand_dims(self.bias, 1))
        sigmas = sigmas - tf.expand_dims(self.bias, 1)

        pis = tf.nn.softmax(pis, axis=-1)
        # Components with weight below .01 are zeroed (not renormalised here).
        pis = tf.where(pis < .01, tf.zeros_like(pis), pis)
        sigmas = tf.clip_by_value(tf.exp(sigmas), sigma_eps, np.inf)
        rhos = tf.clip_by_value(tf.tanh(rhos), eps - 1.0, 1.0 - eps)
        es = tf.clip_by_value(tf.nn.sigmoid(es), eps, 1.0 - eps)
        # Probabilities below .01 are treated as exactly zero.
        es = tf.where(es < .01, tf.zeros_like(es), es)

        return pis, mus, sigmas, rhos, es
rnn_ops.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow.compat.v1 as tf
2
+ tf.disable_v2_behavior()
3
+ from tensorflow.python.framework import constant_op
4
+ from tensorflow.python.framework import dtypes
5
+ from tensorflow.python.framework import ops
6
+ from tensorflow.python.ops import array_ops
7
+ from tensorflow.python.ops import control_flow_ops
8
+ from tensorflow.python.ops import cond as control_flow_ops_cond
9
+ from tensorflow.python.ops import math_ops
10
+ from tensorflow.python.ops import tensor_array_ops
11
+ from tensorflow.python.ops import variable_scope as vs
12
+ from tensorflow.python.ops.rnn_cell_impl import _concat, assert_like_rnncell
13
+ from tensorflow.python.ops.rnn import _maybe_tensor_shape_from_tensor
14
+ from tensorflow.python.util import nest
15
+ from tensorflow.python.framework import tensor_shape
16
+
17
+
18
+
19
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the time axis)
    in addition to the outputs at each timestep and the final state

    Args:
        cell: RNN cell, called as ``cell(current_input, state)``.
        loop_fn: Callable ``(time, cell_output, cell_state, loop_state)`` returning
            ``(elements_finished, next_input, next_cell_state, emit_output,
            next_loop_state)``. It is first called with ``time=0`` and ``None``
            for the other three arguments to obtain the initial values.
        parallel_iterations: Forwarded to tf.while_loop; defaults to 32.
        swap_memory: Forwarded to tf.while_loop.
        scope: Variable scope name; defaults to "rnn".

    returns (
        states for all timesteps,
        outputs for all timesteps,
        final cell state,
    )
    """
    assert_like_rnncell("dummy_name", cell)
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if not tf.executing_eagerly():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # Initial call: loop_fn supplies the first input, the initial state and
        # (optionally) a template for what is emitted at each step.
        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None
                      else constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        # const_batch_size keeps the static value (possibly None) for the
        # TensorArray element shapes; batch_size falls back to a dynamic tensor.
        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            # Prefer the static shape; fall back to a dynamic shape tensor.
            flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                              array_ops.shape(emit) for emit in flat_emit_structure]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_state_size = [s.shape if s.shape.is_fully_defined() else
                           array_ops.shape(s) for s in flat_state]
        flat_state_dtypes = [s.dtype for s in flat_state]

        # One growable TensorArray per emitted tensor, written once per step.
        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)
        # Zero-valued emit, written for batch rows that have already finished.
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]

        zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)

        # Addition over stock raw_rnn: a TensorArray per state tensor, so the
        # full state can be recorded at every step.
        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            # Keep looping until every batch row is finished.
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)
                return nest.map_structure(copy_fn, current, candidate)

            # Finished rows emit zeros and keep their previous state.
            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = tf.while_loop(
            condition, body, loop_vars=[
                time, elements_finished, next_input, state_ta,
                emit_ta, state, loop_state],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory
        )

        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        # Stack along time, then swap the first two axes so the batch axis
        # leads; the fixed (1, 0, 2) permutation requires rank-3 stacked tensors.
        flat_states = nest.flatten(state_ta)
        flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states]
        states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs]
        outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)

        return (states, outputs, final_state)
182
+
183
+
184
def rnn_teacher_force(inputs, cell, sequence_length, initial_state, scope='dynamic-rnn-teacher-force'):
    """
    Run `cell` over ground-truth `inputs` (teacher forcing) through raw_rnn.

    Called like tf.dynamic_rnn, but returns the per-timestep states as well:
    (states, outputs, final_state).
    """
    # Time-major layout so each TensorArray slot holds one timestep.
    inputs = array_ops.transpose(inputs, (1, 0, 2))
    num_steps = array_ops.shape(inputs)[0]
    step_inputs = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=num_steps).unstack(inputs)

    def loop_fn(time, cell_output, cell_state, loop_state):
        # On the very first call raw_rnn passes None for the cell output.
        if cell_output is None:
            state_for_step = initial_state
        else:
            state_for_step = cell_state

        done_mask = time >= sequence_length
        all_done = math_ops.reduce_all(done_mask)

        def _zero_step():
            # Placeholder input once every sequence is exhausted.
            step_shape = [array_ops.shape(inputs)[1], inputs.shape.as_list()[2]]
            return array_ops.zeros(step_shape, dtype=dtypes.float32)

        def _read_step():
            return step_inputs.read(time)

        upcoming_input = control_flow_ops_cond.cond(all_done, _zero_step, _read_step)

        # The emitted value is the raw cell output; no loop state is carried.
        return (done_mask, upcoming_input, state_for_step, cell_output, None)

    return raw_rnn(cell, loop_fn, scope=scope)
211
+
212
+
213
def rnn_free_run(cell, initial_state, sequence_length, initial_input=None, scope='dynamic-rnn-free-run'):
    """
    Implementation of an rnn which feeds its predictions back to itself at the next timestep.

    cell must implement two methods:

        cell.output_function(state) which takes in the state at timestep t and returns
        the cell input at timestep t+1.

        cell.termination_condition(state) which returns a boolean tensor of shape
        [batch_size] denoting which sequences no longer need to be sampled.

    Args:
        cell: RNN cell providing the two methods above.
        initial_state: State the first step starts from.
        sequence_length: Maximum number of steps; a row is finished once
            time >= sequence_length or its termination condition is met.
        initial_input: First input to the cell. When None it is produced by
            cell.output_function(initial_state).
        scope: Variable scope name, also passed on to raw_rnn.

    Returns:
        (states, outputs, final_state) as returned by raw_rnn.
    """
    # reuse=True: the variables used by output_function must already exist
    # under this scope.
    with vs.variable_scope(scope, reuse=True):
        if initial_input is None:
            initial_input = cell.output_function(initial_state)

    def loop_fn(time, cell_output, cell_state, loop_state):
        # raw_rnn's first call passes cell_output=None to request initial values.
        next_cell_state = initial_state if cell_output is None else cell_state

        elements_finished = math_ops.logical_or(
            time >= sequence_length,
            cell.termination_condition(next_cell_state)
        )
        finished = math_ops.reduce_all(elements_finished)

        # Feed zeros once every row is finished; otherwise feed the initial
        # input (first call) or a new sample drawn from the current state.
        next_input = control_flow_ops_cond.cond(
            finished,
            lambda: array_ops.zeros_like(initial_input),
            lambda: initial_input if cell_output is None else cell.output_function(next_cell_state)
        )
        # On the first call the emit value only serves as a shape/dtype template
        # for raw_rnn, so a single row (without the leading axis) is returned.
        emit_output = next_input[0] if cell_output is None else next_input

        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)

    states, outputs, final_state = raw_rnn(cell, loop_fn, scope=scope)
    return states, outputs, final_state
tf_base_model.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ from collections import deque
3
+ from datetime import datetime
4
+ import logging
5
+ import os
6
+ import pprint as pp
7
+ import time
8
+
9
+ import numpy as np
10
+ import tensorflow.compat.v1 as tf
11
+ tf.disable_v2_behavior()
12
+
13
+ from tf_utils import shape
14
+
15
+
16
class TFBaseModel(object):

    """Interface containing some boilerplate code for training tensorflow models.

    Subclassing models must implement self.calculate_loss(), which returns a tensor for the batch loss.
    Code for the training loop, parameter updates, checkpointing, and inference are implemented here and
    subclasses are mainly responsible for building the computational graph beginning with the placeholders
    and ending with the loss tensor.

    Training runs in phases. Phase i uses batch_sizes[i], learning_rates[i], beta1_decays[i] and
    patiences[i]. When the early stopping metric has not improved for patiences[i] steps, the best
    checkpoint is restored and the next phase starts; after the last phase training ends.

    Subclasses may optionally define these attributes, which are picked up via hasattr/getattr:
    metrics, early_stopping_metric, monitor_tensors, prediction_tensors, parameter_tensors,
    keep_prob and is_training.

    Args:
        reader: Class with attributes train_batch_generator, val_batch_generator, and test_batch_generator
            that yield dictionaries mapping tf.placeholder names (as strings) to batch data (numpy arrays).
        batch_sizes: Minibatch size for each training phase.
        num_training_steps: Maximum number of training steps.
        learning_rates: Learning rate for each training phase.
        beta1_decays: Adam beta1 (or RMSProp decay) for each training phase.
        optimizer: 'adam' for Adam, 'gd' (or 'sgd') for gradient descent, 'rms' for RMSProp.
        grad_clip: Gradients are clipped elementwise to the range [-grad_clip, grad_clip].
        regularization_constant: Weight of the regularization term added to the loss when nonzero.
        keep_prob: 1 - p, where p is the dropout probability. Fed to self.keep_prob during training.
        patiences: For each training phase, number of steps to continue training after the early
            stopping metric has stopped decreasing.
        warm_start_init_step: If nonzero, model will resume training a restored model beginning
            at warm_start_init_step.
        enable_parameter_averaging: If true, model saves exponential weighted averages of parameters
            to separate checkpoint file.
        min_steps_to_checkpoint: Model only saves after min_steps_to_checkpoint training steps
            have passed.
        log_interval: Train and validation losses are logged every log_interval training steps.
        logging_level: Level set on the root logger.
        loss_averaging_window: Train/validation losses are averaged over the last loss_averaging_window
            training steps.
        validation_batch_size: Batch size requested from the validation generator.
        log_dir: Directory where logs are written.
        checkpoint_dir: Directory where checkpoints are saved.
        prediction_dir: Directory where predictions/outputs are saved.

    Raises:
        ValueError: If the four per-phase lists are empty or differ in length.
    """

    # Root-logger handlers installed by the most recent init_logging() call, so a
    # later call can replace them instead of stacking duplicates.
    _active_log_handlers = []

    def __init__(
        self,
        reader=None,
        batch_sizes=[128],
        num_training_steps=20000,
        learning_rates=[.01],
        beta1_decays=[.99],
        optimizer='adam',
        grad_clip=5,
        regularization_constant=0.0,
        keep_prob=1.0,
        patiences=[3000],
        warm_start_init_step=0,
        enable_parameter_averaging=False,
        min_steps_to_checkpoint=100,
        log_interval=20,
        logging_level=logging.INFO,
        loss_averaging_window=100,
        validation_batch_size=64,
        log_dir='logs',
        checkpoint_dir='checkpoints',
        prediction_dir='predictions',
    ):

        # All four schedules are indexed by restart_idx, so they must line up.
        schedule_lengths = {
            len(batch_sizes), len(learning_rates), len(beta1_decays), len(patiences)
        }
        if len(schedule_lengths) != 1 or 0 in schedule_lengths:
            raise ValueError(
                'batch_sizes, learning_rates, beta1_decays and patiences must be '
                'non-empty and have the same length'
            )

        # Copy so the shared default lists are never aliased by an instance.
        self.batch_sizes = list(batch_sizes)
        self.learning_rates = list(learning_rates)
        self.beta1_decays = list(beta1_decays)
        self.patiences = list(patiences)
        self.num_restarts = len(self.batch_sizes) - 1
        self.restart_idx = 0
        self.update_train_params()

        self.reader = reader
        self.num_training_steps = num_training_steps
        self.optimizer = optimizer
        self.grad_clip = grad_clip
        self.regularization_constant = regularization_constant
        self.warm_start_init_step = warm_start_init_step
        self.keep_prob_scalar = keep_prob
        self.enable_parameter_averaging = enable_parameter_averaging
        self.min_steps_to_checkpoint = min_steps_to_checkpoint
        self.log_interval = log_interval
        self.loss_averaging_window = loss_averaging_window
        self.validation_batch_size = validation_batch_size

        self.log_dir = log_dir
        self.logging_level = logging_level
        self.prediction_dir = prediction_dir
        self.checkpoint_dir = checkpoint_dir
        if self.enable_parameter_averaging:
            self.checkpoint_dir_averaged = checkpoint_dir + '_avg'

        self.init_logging(self.log_dir)
        logging.info('\nnew run with parameters:\n{}'.format(pp.pformat(self.__dict__)))

        self.graph = self.build_graph()
        self.session = tf.Session(graph=self.graph)
        logging.info('built graph')

    def update_train_params(self):
        """Load the hyperparameters of the current training phase (self.restart_idx)."""
        self.batch_size = self.batch_sizes[self.restart_idx]
        self.learning_rate = self.learning_rates[self.restart_idx]
        self.beta1_decay = self.beta1_decays[self.restart_idx]
        self.early_stopping_steps = self.patiences[self.restart_idx]

    def calculate_loss(self):
        """Build the graph and return the batch loss tensor. Must be overridden."""
        raise NotImplementedError('subclass must implement this')

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of a non-empty sequence."""
        return sum(values) / len(values)

    def _build_feed_dict(self, batch, is_training, include_optimizer_params=True):
        """Map a batch onto the placeholders this model defines.

        Batch entries whose name is not an attribute of the model are ignored.
        Dropout keep probability is self.keep_prob_scalar when training and 1.0
        otherwise; both it and is_training are only fed if the model defines them.
        """
        feed_dict = {
            getattr(self, placeholder_name): data
            for placeholder_name, data in batch.items() if hasattr(self, placeholder_name)
        }
        if include_optimizer_params:
            feed_dict[self.learning_rate_var] = self.learning_rate
            feed_dict[self.beta1_decay_var] = self.beta1_decay
        if hasattr(self, 'keep_prob'):
            feed_dict[self.keep_prob] = self.keep_prob_scalar if is_training else 1.0
        if hasattr(self, 'is_training'):
            feed_dict[self.is_training] = is_training
        return feed_dict

    def _print_monitor_tensors(self, feed_dict):
        """Print summary statistics of every tensor in self.monitor_tensors."""
        for name, tensor in self.monitor_tensors.items():
            [np_val] = self.session.run([tensor], feed_dict=feed_dict)
            print(name)
            print('min', np_val.min())
            print('max', np_val.max())
            print('mean', np_val.mean())
            print('std', np_val.std())
            print('nans', np.isnan(np_val).sum())
            print()
        print()
        print()

    def fit(self):
        """Run the training loop.

        Each iteration evaluates one validation batch, then runs one training
        step. Every log_interval steps the windowed averages are logged and used
        for checkpointing, early stopping and phase restarts.
        """
        with self.session.as_default():

            if self.warm_start_init_step:
                self.restore(self.warm_start_init_step)
                step = self.warm_start_init_step
            else:
                self.session.run(self.init)
                step = 0

            train_generator = self.reader.train_batch_generator(self.batch_size)
            val_generator = self.reader.val_batch_generator(self.validation_batch_size)

            window = self.loss_averaging_window
            train_loss_history = deque(maxlen=window)
            val_loss_history = deque(maxlen=window)
            train_time_history = deque(maxlen=window)
            val_time_history = deque(maxlen=window)
            if not hasattr(self, 'metrics'):
                self.metrics = {}

            metric_names = list(self.metrics.keys())
            metric_tensors = [self.metrics[name] for name in metric_names]
            metric_histories = {name: deque(maxlen=window) for name in metric_names}
            # Subclasses may name one of their metrics to drive early stopping;
            # without it the averaged validation loss is used.
            early_stopping_metric_name = getattr(self, 'early_stopping_metric', None)
            best_validation_loss, best_validation_tstep = float('inf'), 0

            while step < self.num_training_steps:

                # validation evaluation
                val_start = time.time()
                val_batch_df = next(val_generator)
                val_feed_dict = self._build_feed_dict(val_batch_df, is_training=False)

                # dict.values() is a view on Python 3, so build the fetch list explicitly.
                results = self.session.run(
                    fetches=[self.loss] + metric_tensors,
                    feed_dict=val_feed_dict
                )
                val_loss_history.append(results[0])
                val_time_history.append(time.time() - val_start)
                for name, value in zip(metric_names, results[1:]):
                    metric_histories[name].append(value)

                if hasattr(self, 'monitor_tensors'):
                    self._print_monitor_tensors(val_feed_dict)

                # train step
                train_start = time.time()
                train_batch_df = next(train_generator)
                train_feed_dict = self._build_feed_dict(train_batch_df, is_training=True)

                train_loss, _ = self.session.run(
                    fetches=[self.loss, self.step],
                    feed_dict=train_feed_dict
                )
                train_loss_history.append(train_loss)
                train_time_history.append(time.time() - train_start)

                if step % self.log_interval == 0:
                    avg_train_loss = self._mean(train_loss_history)
                    avg_val_loss = self._mean(val_loss_history)
                    avg_train_time = self._mean(train_time_history)
                    avg_val_time = self._mean(val_time_history)
                    metric_log = (
                        "[[step {:>8}]] "
                        "[[train {:>4}s]] loss: {:<12} "
                        "[[val {:>4}s]] loss: {:<12} "
                    ).format(
                        step,
                        round(avg_train_time, 4),
                        round(avg_train_loss, 8),
                        round(avg_val_time, 4),
                        round(avg_val_loss, 8),
                    )
                    early_stopping_metric = avg_val_loss
                    for metric_name, metric_history in metric_histories.items():
                        metric_val = self._mean(metric_history)
                        metric_log += '{}: {:<4} '.format(metric_name, round(metric_val, 4))
                        if metric_name == early_stopping_metric_name:
                            early_stopping_metric = metric_val

                    logging.info(metric_log)

                    if early_stopping_metric < best_validation_loss:
                        best_validation_loss = early_stopping_metric
                        best_validation_tstep = step
                        if step > self.min_steps_to_checkpoint:
                            self.save(step)
                            if self.enable_parameter_averaging:
                                self.save(step, averaged=True)

                    if step - best_validation_tstep > self.early_stopping_steps:

                        if self.num_restarts is None or self.restart_idx >= self.num_restarts:
                            logging.info('best validation loss of {} at training step {}'.format(
                                best_validation_loss, best_validation_tstep))
                            logging.info('early stopping - ending training.')
                            return

                        if self.restart_idx < self.num_restarts:
                            # Roll back to the best checkpoint and move to the next phase.
                            self.restore(best_validation_tstep)
                            step = best_validation_tstep
                            self.restart_idx += 1
                            self.update_train_params()
                            train_generator = self.reader.train_batch_generator(self.batch_size)

                step += 1

            if step <= self.min_steps_to_checkpoint:
                best_validation_tstep = step
                self.save(step)
                if self.enable_parameter_averaging:
                    self.save(step, averaged=True)

            logging.info('num_training_steps reached - ending training')

    def predict(self, chunk_size=256):
        """Evaluate prediction/parameter tensors and save each as <name>.npy in prediction_dir.

        Tensors in self.prediction_tensors are evaluated over every batch of the
        reader's test generator and concatenated along axis 0. Tensors in
        self.parameter_tensors are evaluated once without a feed dict.
        """
        if not os.path.isdir(self.prediction_dir):
            os.makedirs(self.prediction_dir)

        if hasattr(self, 'prediction_tensors'):
            tensor_names = list(self.prediction_tensors.keys())
            tf_tensors = [self.prediction_tensors[name] for name in tensor_names]
            prediction_dict = {tensor_name: [] for tensor_name in tensor_names}

            test_generator = self.reader.test_batch_generator(chunk_size)
            for i, test_batch_df in enumerate(test_generator):
                if i % 10 == 0:
                    print(i*len(test_batch_df))

                test_feed_dict = self._build_feed_dict(
                    test_batch_df, is_training=False, include_optimizer_params=False
                )
                np_tensors = self.session.run(
                    fetches=tf_tensors,
                    feed_dict=test_feed_dict
                )
                for tensor_name, tensor in zip(tensor_names, np_tensors):
                    prediction_dict[tensor_name].append(tensor)

            for tensor_name, chunks in prediction_dict.items():
                if not chunks:
                    # np.concatenate rejects an empty list; nothing to save.
                    logging.warning('no test batches produced for {}'.format(tensor_name))
                    continue
                np_tensor = np.concatenate(chunks, 0)
                save_file = os.path.join(self.prediction_dir, '{}.npy'.format(tensor_name))
                logging.info('saving {} with shape {} to {}'.format(tensor_name, np_tensor.shape, save_file))
                np.save(save_file, np_tensor)

        if hasattr(self, 'parameter_tensors'):
            for tensor_name, tensor in self.parameter_tensors.items():
                # session.run works for both variables and plain tensors, whereas
                # Tensor.eval takes feed_dict as its first positional argument.
                np_tensor = self.session.run(tensor)

                save_file = os.path.join(self.prediction_dir, '{}.npy'.format(tensor_name))
                logging.info('saving {} with shape {} to {}'.format(tensor_name, np_tensor.shape, save_file))
                np.save(save_file, np_tensor)

    def save(self, step, averaged=False):
        """Write a checkpoint named 'model-<step>' to the (averaged) checkpoint directory."""
        saver = self.saver_averaged if averaged else self.saver
        checkpoint_dir = self.checkpoint_dir_averaged if averaged else self.checkpoint_dir
        if not os.path.isdir(checkpoint_dir):
            logging.info('creating checkpoint directory {}'.format(checkpoint_dir))
            # makedirs so nested checkpoint paths work, matching log/prediction dirs.
            os.makedirs(checkpoint_dir)

        model_path = os.path.join(checkpoint_dir, 'model')
        logging.info('saving model to {}'.format(model_path))
        saver.save(self.session, model_path, global_step=step)

    def restore(self, step=None, averaged=False):
        """Restore parameters from a checkpoint.

        Args:
            step: Training step of the checkpoint to load. If falsy (None or 0),
                the latest checkpoint in the directory is used.
            averaged: Load from the parameter-averaged checkpoint directory.
        """
        saver = self.saver_averaged if averaged else self.saver
        checkpoint_dir = self.checkpoint_dir_averaged if averaged else self.checkpoint_dir
        if not step:
            model_path = tf.train.latest_checkpoint(checkpoint_dir)
            logging.info('restoring model parameters from {}'.format(model_path))
            saver.restore(self.session, model_path)
        else:
            # save() always uses the 'model' prefix, in both directories.
            model_path = os.path.join(checkpoint_dir, 'model-{}'.format(step))
            logging.info('restoring model from {}'.format(model_path))
            saver.restore(self.session, model_path)

    def init_logging(self, log_dir):
        """Send root-logger output to a timestamped file in log_dir and to the console.

        Handlers installed by a previous call are removed first, so creating
        several models in one process neither duplicates console output nor keeps
        writing to the previous model's log file.
        """
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        date_str = datetime.now().strftime('%Y-%m-%d_%H-%M')
        log_file = 'log_{}.txt'.format(date_str)

        root_logger = logging.getLogger()
        installed = TFBaseModel._active_log_handlers
        for handler in installed:
            root_logger.removeHandler(handler)
            handler.close()
        del installed[:]

        file_handler = logging.FileHandler(os.path.join(log_dir, log_file))
        file_handler.setFormatter(logging.Formatter(
            fmt='[[%(asctime)s]] %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S %p'
        ))
        stream_handler = logging.StreamHandler()

        root_logger.setLevel(self.logging_level)
        for handler in (file_handler, stream_handler):
            root_logger.addHandler(handler)
            installed.append(handler)

    def update_parameters(self, loss):
        """Create the training op (self.step) that minimizes loss.

        Adds the optional regularization term, clips gradients elementwise,
        applies them after any pending UPDATE_OPS, and, when parameter averaging
        is enabled, updates the moving averages after the optimizer step.
        """
        if self.regularization_constant != 0:
            # NOTE(review): sqrt has no finite gradient at 0, so all-zero variables
            # (e.g. zero-initialised biases) may produce NaN gradients - confirm
            # before enabling regularization.
            l2_norm = tf.reduce_sum([tf.sqrt(tf.reduce_sum(tf.square(param))) for param in tf.trainable_variables()])
            loss = loss + self.regularization_constant*l2_norm

        optimizer = self.get_optimizer(self.learning_rate_var, self.beta1_decay_var)
        grads = optimizer.compute_gradients(loss)
        # Variables that do not influence the loss get a None gradient, which
        # clip_by_value cannot handle; leave those variables untouched.
        clipped = [
            (tf.clip_by_value(g, -self.grad_clip, self.grad_clip), v_)
            for g, v_ in grads if g is not None
        ]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            step = optimizer.apply_gradients(clipped, global_step=self.global_step)

        if self.enable_parameter_averaging:
            # Create the averaging op under the dependency so the averages are
            # always computed from the freshly updated parameters.
            with tf.control_dependencies([step]):
                self.step = self.ema.apply(tf.trainable_variables())
        else:
            self.step = step

        logging.info('all parameters:')
        logging.info(pp.pformat([(var.name, shape(var)) for var in tf.global_variables()]))

        logging.info('trainable parameters:')
        logging.info(pp.pformat([(var.name, shape(var)) for var in tf.trainable_variables()]))

        logging.info('trainable parameter count:')
        logging.info(str(sum(int(np.prod(shape(var))) for var in tf.trainable_variables())))

    def get_optimizer(self, learning_rate, beta1_decay):
        """Return the optimizer selected by self.optimizer.

        Raises:
            ValueError: If self.optimizer is not 'adam', 'gd', 'sgd' or 'rms'.
        """
        if self.optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate, beta1=beta1_decay)
        if self.optimizer in ('gd', 'sgd'):
            return tf.train.GradientDescentOptimizer(learning_rate)
        if self.optimizer == 'rms':
            return tf.train.RMSPropOptimizer(learning_rate, decay=beta1_decay, momentum=0.9)
        raise ValueError(
            'optimizer must be adam, gd (or sgd), or rms; got {!r}'.format(self.optimizer)
        )

    def build_graph(self):
        """Build the full training graph in a fresh tf.Graph and return it."""
        with tf.Graph().as_default() as graph:
            self.ema = tf.train.ExponentialMovingAverage(decay=0.99)
            self.global_step = tf.Variable(0, trainable=False)
            # Fed through the feed dict on every step by fit().
            self.learning_rate_var = tf.Variable(0.0, trainable=False)
            self.beta1_decay_var = tf.Variable(0.0, trainable=False)

            self.loss = self.calculate_loss()
            self.update_parameters(self.loss)

            self.saver = tf.train.Saver(max_to_keep=1)
            if self.enable_parameter_averaging:
                self.saver_averaged = tf.train.Saver(self.ema.variables_to_restore(), max_to_keep=1)

            self.init = tf.global_variables_initializer()
            return graph
tf_utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow.compat.v1 as tf
2
+ tf.disable_v2_behavior()
3
+
4
+
5
def dense_layer(inputs, output_units, bias=True, activation=None, batch_norm=None,
                dropout=None, scope='dense-layer', reuse=False):
    """
    Applies a dense layer to a 2D tensor of shape [batch_size, input_units]
    to produce a tensor of shape [batch_size, output_units].

    Args:
        inputs: Tensor of shape [batch size, input_units].
        output_units: Number of output units.
        bias: If truthy, a 'biases' variable is added to the product.
        activation: Optional activation function applied after the optional
            batch normalization.
        batch_norm: If not None, batch normalization is applied and this value
            is passed as its `training` argument.
        dropout: Dropout keep prob; dropout is skipped when None.
        scope: Variable scope that holds the layer's variables.
        reuse: Passed to the variable scope and to batch normalization.

    Returns:
        Tensor of shape [batch size, output_units].
    """
    with tf.variable_scope(scope, reuse=reuse):
        input_units = shape(inputs, -1)
        weights = tf.get_variable(
            name='weights',
            initializer=tf.compat.v1.variance_scaling_initializer(),
            shape=[input_units, output_units]
        )
        outputs = tf.matmul(inputs, weights)

        if bias:
            biases = tf.get_variable(
                name='biases',
                initializer=tf.constant_initializer(),
                shape=[output_units]
            )
            outputs = outputs + biases

        if batch_norm is not None:
            outputs = tf.layers.batch_normalization(outputs, training=batch_norm, reuse=reuse)

        if activation:
            outputs = activation(outputs)
        if dropout is not None:
            outputs = tf.nn.dropout(outputs, dropout)
        return outputs
39
+
40
+
41
def time_distributed_dense_layer(
        inputs, output_units, bias=True, activation=None, batch_norm=None,
        dropout=None, scope='time-distributed-dense-layer', reuse=False):
    """
    Applies a shared dense layer to each timestep of a tensor of shape
    [batch_size, max_seq_len, input_units] to produce a tensor of shape
    [batch_size, max_seq_len, output_units].

    Args:
        inputs: Tensor of shape [batch size, max sequence length, ...].
        output_units: Number of output units.
        bias: If truthy, a 'biases' variable is added to the projection.
        activation: Optional activation function applied after the optional
            batch normalization.
        batch_norm: If not None, batch normalization is applied and this value
            is passed as its `training` argument.
        dropout: Dropout keep prob; dropout is skipped when None.
        scope: Variable scope that holds the layer's variables.
        reuse: Passed to the variable scope and to batch normalization.

    Returns:
        Tensor of shape [batch size, max sequence length, output_units].
    """
    with tf.variable_scope(scope, reuse=reuse):
        input_units = shape(inputs, -1)
        weights = tf.get_variable(
            name='weights',
            initializer=tf.compat.v1.variance_scaling_initializer(),
            shape=[input_units, output_units]
        )
        # Contract the last axis of the inputs with the weight matrix at every timestep.
        outputs = tf.einsum('ijk,kl->ijl', inputs, weights)

        if bias:
            biases = tf.get_variable(
                name='biases',
                initializer=tf.constant_initializer(),
                shape=[output_units]
            )
            outputs = outputs + biases

        if batch_norm is not None:
            outputs = tf.layers.batch_normalization(outputs, training=batch_norm, reuse=reuse)

        if activation:
            outputs = activation(outputs)
        if dropout is not None:
            outputs = tf.nn.dropout(outputs, dropout)
        return outputs
79
+
80
+
81
def shape(tensor, dim=None):
    """Return the tensor's static shape as a list, or a single entry of it.

    Args:
        tensor: Object exposing `shape.as_list()`.
        dim: Index into the shape list. When None the whole list is returned.

    Returns:
        The full shape list if dim is None, otherwise the entry at index dim.
    """
    dims = tensor.shape.as_list()
    if dim is not None:
        return dims[dim]
    return dims
87
+
88
+
89
def rank(tensor):
    """Return the tensor's rank (the length of its static shape list) as an int."""
    dims = tensor.shape.as_list()
    return len(dims)