Spaces:
Running
Running
Upload 14 files
Browse files- data_frame.py +104 -0
- demo.py +62 -0
- drawing.py +216 -0
- hand.py +150 -0
- handwriting_api.py +52 -0
- lyrics.py +190 -0
- prepare_data.py +134 -0
- readme.md +68 -0
- requirements.txt +15 -0
- rnn.py +236 -0
- rnn_cell.py +185 -0
- rnn_ops.py +249 -0
- tf_base_model.py +408 -0
- tf_utils.py +91 -0
data_frame.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
from sklearn.model_selection import train_test_split
|
6 |
+
|
7 |
+
|
8 |
+
class DataFrame(object):

    """Minimal pd.DataFrame analog for handling n-dimensional numpy matrices with additional
    support for shuffling, batching, and train/test splitting.

    Row order is tracked by ``self.idx``, an index array over the leading dimension that
    ``shuffle`` permutes in place; the matrices themselves are never reordered.

    Args:
        columns: List of names corresponding to the matrices in data.
        data: List of n-dimensional data matrices ordered in correspondence with columns.
            All matrices must have the same leading dimension.  Data can also be fed a list of
            instances of np.memmap, in which case RAM usage can be limited to the size of a
            single batch.
    """

    def __init__(self, columns, data):
        assert len(columns) == len(data), 'columns length does not match data length'

        lengths = [mat.shape[0] for mat in data]
        assert len(set(lengths)) == 1, 'all matrices in data must have same first dimension'

        self.length = lengths[0]
        self.columns = columns
        self.data = data
        self.dict = dict(zip(self.columns, self.data))
        self.idx = np.arange(self.length)

    def shapes(self):
        """Return a pd.Series mapping each column name to its matrix shape."""
        return pd.Series(dict(zip(self.columns, [mat.shape for mat in self.data])))

    def dtypes(self):
        """Return a pd.Series mapping each column name to its matrix dtype."""
        return pd.Series(dict(zip(self.columns, [mat.dtype for mat in self.data])))

    def shuffle(self):
        """Permute the row index in place (uses the global numpy RNG)."""
        np.random.shuffle(self.idx)

    def train_test_split(self, train_size, random_state=None, stratify=None):
        """Split the rows into two new DataFrames.

        Args:
            train_size: Forwarded to sklearn's train_test_split.
            random_state: Seed forwarded to sklearn.  When None, a fresh seed in [0, 1000)
                is drawn on every call (previously it was drawn once, when the method was
                defined, so every call in a process silently reused the same seed).
            stratify: Forwarded to sklearn's train_test_split.

        Returns:
            (train_df, test_df) tuple of DataFrames.
        """
        if random_state is None:
            random_state = np.random.randint(1000)

        train_idx, test_idx = train_test_split(
            self.idx,
            train_size=train_size,
            random_state=random_state,
            stratify=stratify
        )
        train_df = DataFrame(copy.copy(self.columns), [mat[train_idx] for mat in self.data])
        test_df = DataFrame(copy.copy(self.columns), [mat[test_idx] for mat in self.data])
        return train_df, test_df

    def batch_generator(self, batch_size, shuffle=True, num_epochs=10000, allow_smaller_final_batch=False):
        """Yield DataFrames of up to batch_size rows for num_epochs passes over the data.

        Args:
            batch_size: Number of rows per batch.
            shuffle: If True, reshuffle the row index at the start of every epoch.
            num_epochs: Number of passes over the data.
            allow_smaller_final_batch: If False, a trailing batch with fewer than
                batch_size rows is dropped; if True it is yielded.

        Yields:
            DataFrame holding copies of the selected rows.
        """
        epoch_num = 0
        while epoch_num < num_epochs:
            if shuffle:
                self.shuffle()

            # Stop at self.length (not self.length + 1): the extra step produced an empty
            # batch whenever the length was an exact multiple of batch_size.
            for i in range(0, self.length, batch_size):
                batch_idx = self.idx[i: i + batch_size]
                if not allow_smaller_final_batch and len(batch_idx) != batch_size:
                    break
                yield DataFrame(
                    columns=copy.copy(self.columns),
                    data=[mat[batch_idx].copy() for mat in self.data]
                )

            epoch_num += 1

    def iterrows(self):
        """Yield each row as a pd.Series keyed by column name.

        The values of self.idx are passed to __getitem__, which looks them up in self.idx
        again; every row is still visited exactly once, though not in self.idx order.
        """
        for i in self.idx:
            yield self[i]

    def mask(self, mask):
        """Return a new DataFrame with every matrix indexed by mask."""
        return DataFrame(copy.copy(self.columns), [mat[mask] for mat in self.data])

    def concat(self, other_df):
        """Return a new DataFrame with other_df's rows appended, column by column."""
        mats = []
        for column in self.columns:
            mats.append(np.concatenate([self[column], other_df[column]], axis=0))
        return DataFrame(copy.copy(self.columns), mats)

    def items(self):
        """Return the (column name, matrix) pairs."""
        return self.dict.items()

    def __iter__(self):
        return iter(self.dict.items())

    def __len__(self):
        return self.length

    def __getitem__(self, key):
        """Return a column matrix for a str key, or a row pd.Series for an integer key.

        Any other key type returns None, as before.
        """
        if isinstance(key, str):
            return self.dict[key]

        # np.integer is accepted as well: values taken from self.idx are numpy integers,
        # which are not instances of the built-in int on Python 3.
        elif isinstance(key, (int, np.integer)):
            return pd.Series(dict(zip(self.columns, [mat[self.idx[key]] for mat in self.data])))

    def __setitem__(self, key, value):
        """Add a new column, or replace an existing one, keeping data and dict in sync."""
        assert value.shape[0] == len(self), 'matrix first dimension does not match'
        if key in self.columns:
            # Replace in self.data too; previously only self.dict saw the new matrix.
            self.data[self.columns.index(key)] = value
        else:
            self.columns.append(key)
            self.data.append(value)
        self.dict[key] = value
|
demo.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from hand import Hand
|
3 |
+
|
4 |
+
import lyrics
|
5 |
+
|
6 |
+
|
7 |
+
if __name__ == '__main__':
    # Building a Hand restores the trained model, so do it once and reuse it.
    hand = Hand()

    # usage demo: a single line with explicit stroke colors and widths
    demo_lines = ["what's up how are you marcos ?"]
    hand.write(
        filename='img/usage_demo.svg',
        lines=demo_lines,
        biases=[.75] * len(demo_lines),
        styles=[9] * len(demo_lines),
        stroke_colors=['red', 'green', 'black', 'blue'],
        stroke_widths=[1, 2, 1, 2]
    )

    # demo number 1 - fixed bias, fixed style
    all_star_lines = lyrics.all_star.split("\n")
    hand.write(
        filename='img/all_star.svg',
        lines=all_star_lines,
        biases=[.75] * len(all_star_lines),
        styles=[12] * len(all_star_lines),
    )

    # demo number 2 - fixed bias, varying style
    # the style index goes up by one at every blank line
    downtown_lines = lyrics.downtown.split("\n")
    downtown_blank = np.array([len(line) == 0 for line in downtown_lines])
    hand.write(
        filename='img/downtown.svg',
        lines=downtown_lines,
        biases=[.75] * len(downtown_lines),
        styles=np.cumsum(downtown_blank).astype(int),
    )

    # demo number 3 - varying bias, fixed style
    # the bias is 0.2 times the reversed running count of blank lines
    give_up_lines = lyrics.give_up.split("\n")
    blank_count = np.cumsum([len(line) == 0 for line in give_up_lines])
    hand.write(
        filename='img/give_up.svg',
        lines=give_up_lines,
        biases=.2*np.flip(blank_count, 0),
        styles=[7] * len(give_up_lines),
    )
|
drawing.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
from scipy.signal import savgol_filter
|
7 |
+
from scipy.interpolate import interp1d
|
8 |
+
|
9 |
+
|
10 |
+
# Characters that can be encoded; the list position of a character is its integer code.
# '\x00' sits at position 0. Some capitals ('Q', 'X', 'Z') are not in the list.
alphabet = [
    '\x00', ' ', '!', '"', '#', "'", '(', ')', ',', '-', '.',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';',
    '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
    'L', 'M', 'N', 'O', 'P', 'R', 'S', 'T', 'U', 'V', 'W', 'Y',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
    'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
    'y', 'z'
]
# Unicode code point of every alphabet entry, in the same order.
alphabet_ord = [ord(char) for char in alphabet]
# character -> code; being a defaultdict(int), a lookup of an unknown character
# returns 0 (and inserts that character into the mapping).
alpha_to_num = defaultdict(int, ((char, code) for code, char in enumerate(alphabet)))
# code -> Unicode code point (not the character itself).
num_to_alpha = {code: point for code, point in enumerate(alphabet_ord)}

# Truncation limits applied to stroke sequences and encoded text respectively.
MAX_STROKE_LEN = 1200
MAX_CHAR_LEN = 75
|
25 |
+
|
26 |
+
|
27 |
+
def align(coords):
    """
    corrects for global slant/offset in handwriting strokes

    Fits a least-squares line y = offset + slope*x through the points, multiplies the
    first two columns by a rotation matrix built from arctan(slope), then subtracts the
    fitted offset from both of those columns. Works on a copy; the input is not modified.
    """
    aligned = np.copy(coords)
    xs = aligned[:, 0].reshape(-1, 1)
    ys = aligned[:, 1].reshape(-1, 1)

    # normal-equations solution of the line fit: design matrix columns are [1, x]
    design = np.concatenate([np.ones([xs.shape[0], 1]), xs], axis=1)
    offset, slope = np.linalg.inv(design.T.dot(design)).dot(design.T).dot(ys).squeeze()

    theta = np.arctan(slope)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation_matrix = np.array(
        [[cos_t, -sin_t],
         [sin_t, cos_t]]
    )
    aligned[:, :2] = np.dot(aligned[:, :2], rotation_matrix) - offset
    return aligned
|
42 |
+
|
43 |
+
|
44 |
+
def skew(coords, degrees):
    """
    skews strokes by given degrees

    The first two columns of a copy of coords are multiplied by
    [[cos(-theta), 0], [sin(-theta), 1]] with theta = degrees in radians.
    """
    skewed = np.copy(coords)
    theta = degrees * np.pi/180
    shear = np.array([
        [np.cos(-theta), 0],
        [np.sin(-theta), 1],
    ])
    skewed[:, :2] = skewed[:, :2].dot(shear)
    return skewed
|
53 |
+
|
54 |
+
|
55 |
+
def stretch(coords, x_factor, y_factor):
    """
    stretches strokes along x and y axis

    Scales column 0 by x_factor and column 1 by y_factor on a copy of coords;
    any further columns are left as they are.
    """
    factors = np.array([x_factor, y_factor])
    stretched = np.copy(coords)
    stretched[:, :2] *= factors
    return stretched
|
62 |
+
|
63 |
+
|
64 |
+
def add_noise(coords, scale):
    """
    adds gaussian noise to strokes

    Zero-mean normal noise with standard deviation scale is added to the first two
    columns of every row except the first, on a copy of coords.
    """
    noisy = np.copy(coords)
    # basic slicing gives a view, so the in-place add below writes into noisy
    jittered = noisy[1:, :2]
    jittered += np.random.normal(loc=0.0, scale=scale, size=jittered.shape)
    return noisy
|
71 |
+
|
72 |
+
|
73 |
+
def encode_ascii(ascii_string):
    """
    encodes ascii string to array of ints

    Each character is looked up in alpha_to_num and a trailing 0 is appended,
    so the result has len(ascii_string) + 1 entries.
    """
    codes = [alpha_to_num[char] for char in ascii_string]
    codes.append(0)
    return np.array(codes)
|
78 |
+
|
79 |
+
|
80 |
+
def denoise(coords):
    """
    smoothing filter to mitigate some artifacts of the data collection

    coords is split into strokes after every row whose third column equals 1. The x and
    y columns of each non-empty stroke are smoothed independently with a Savitzky-Golay
    filter (window 7, polynomial order 3, 'nearest' edge mode); the third column is
    carried over unchanged.
    """
    split_points = np.where(coords[:, 2] == 1)[0] + 1
    smoothed = []
    for segment in np.split(coords, split_points, axis=0):
        # np.split leaves an empty trailing piece when the last row ends a stroke
        if len(segment) == 0:
            continue
        x_new = savgol_filter(segment[:, 0], 7, 3, mode='nearest')
        y_new = savgol_filter(segment[:, 1], 7, 3, mode='nearest')
        smoothed.append(np.column_stack([x_new, y_new, segment[:, 2]]))

    return np.vstack(smoothed)
|
96 |
+
|
97 |
+
|
98 |
+
def interpolate(coords, factor=2):
    """
    interpolates strokes using cubic spline

    coords is split into strokes after every row whose third column equals 1. A stroke
    with more than 3 points is resampled to factor*len(stroke) evenly spaced points
    using cubic interpolation of x and y; shorter strokes keep their points. The third
    column is rebuilt for each stroke: 0 everywhere except 1.0 on its last row.
    """
    boundaries = np.where(coords[:, 2] == 1)[0] + 1
    pieces = []
    for segment in np.split(coords, boundaries, axis=0):
        n_points = len(segment)
        if n_points == 0:
            continue

        if n_points > 3:
            t_old = np.arange(n_points)
            t_new = np.linspace(0, n_points - 1, factor*n_points)
            x_new = interp1d(t_old, segment[:, 0], kind='cubic')(t_new)
            y_new = interp1d(t_old, segment[:, 1], kind='cubic')(t_new)
            xy_coords = np.hstack([x_new.reshape(-1, 1), y_new.reshape(-1, 1)])
        else:
            # too few points for the cubic fit; pass them through
            xy_coords = segment[:, :2]

        end_flags = np.zeros([len(xy_coords), 1])
        end_flags[-1] = 1.0
        pieces.append(np.concatenate([xy_coords, end_flags], axis=1))

    return np.vstack(pieces)
|
130 |
+
|
131 |
+
|
132 |
+
def normalize(offsets):
    """
    normalizes strokes to median unit norm

    Divides the first two columns of a copy of offsets by the median Euclidean norm
    of those rows.
    """
    scaled = np.copy(offsets)
    median_norm = np.median(np.linalg.norm(scaled[:, :2], axis=1))
    scaled[:, :2] /= median_norm
    return scaled
|
139 |
+
|
140 |
+
|
141 |
+
def coords_to_offsets(coords):
    """
    convert from coordinates to offsets

    Every output row after the first holds the x/y difference to the previous point
    together with that point's third-column value; the first row is always [0, 0, 1].
    """
    deltas = coords[1:, :2] - coords[:-1, :2]
    flags = coords[1:, 2:3]
    first_row = np.array([[0, 0, 1]])
    return np.concatenate([first_row, np.concatenate([deltas, flags], axis=1)], axis=0)
|
148 |
+
|
149 |
+
|
150 |
+
def offsets_to_coords(offsets):
    """
    convert from offsets to coordinates

    The first two columns are cumulatively summed down the rows; the third column is
    passed through unchanged.
    """
    positions = np.cumsum(offsets[:, :2], axis=0)
    flags = offsets[:, 2:3]
    return np.concatenate([positions, flags], axis=1)
|
155 |
+
|
156 |
+
|
157 |
+
def draw(
        offsets,
        ascii_seq=None,
        align_strokes=True,
        denoise_strokes=True,
        interpolation_factor=None,
        save_file=None
):
    """
    plots a stroke sequence with matplotlib and either saves or shows the figure

    Args:
        offsets: array of (dx, dy, end-of-stroke) rows; converted to coordinates first.
        ascii_seq: optional title, either a str or a sequence of code points that is
            joined with chr().
        align_strokes: if True, run align() on the x/y columns.
        denoise_strokes: if True, run denoise() on the coordinates.
        interpolation_factor: if not None, run interpolate() with this factor.
        save_file: if not None, the figure is saved to this path; otherwise it is shown.
    """
    strokes = offsets_to_coords(offsets)

    if denoise_strokes:
        strokes = denoise(strokes)

    if interpolation_factor is not None:
        strokes = interpolate(strokes, factor=interpolation_factor)

    if align_strokes:
        strokes[:, :2] = align(strokes[:, :2])

    _fig, ax = plt.subplots(figsize=(12, 3))

    # Plot one polyline per stroke; a row with eos == 1 closes the current stroke.
    stroke = []
    for x, y, eos in strokes:
        stroke.append((x, y))
        if eos == 1:
            # zip() returns an iterator on Python 3 and cannot be indexed, so unpack it.
            xs, ys = zip(*stroke)
            ax.plot(xs, ys, 'k')
            stroke = []
    if stroke:
        # trailing points that were never closed by an end-of-stroke row
        xs, ys = zip(*stroke)
        ax.plot(xs, ys, 'k')
        stroke = []

    ax.set_xlim(-50, 600)
    ax.set_ylim(-40, 40)

    ax.set_aspect('equal')
    # Booleans rather than 'off': a non-empty string is truthy, so current Matplotlib
    # releases would draw the ticks and labels instead of hiding them.
    plt.tick_params(
        axis='both',
        left=False,
        top=False,
        right=False,
        bottom=False,
        labelleft=False,
        labeltop=False,
        labelright=False,
        labelbottom=False
    )

    if ascii_seq is not None:
        if not isinstance(ascii_seq, str):
            ascii_seq = ''.join(list(map(chr, ascii_seq)))
        plt.title(ascii_seq)

    if save_file is not None:
        plt.savefig(save_file)
        print('saved to {}'.format(save_file))
    else:
        plt.show()
    plt.close('all')
|
hand.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import drawing
|
2 |
+
from rnn import rnn
|
3 |
+
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import svgwrite
|
7 |
+
|
8 |
+
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
|
12 |
+
|
13 |
+
class Hand(object):
    """Handwriting synthesizer: samples pen strokes from the rnn model and writes them as SVG."""

    def __init__(self):
        """Build the rnn model with fixed hyperparameters and restore its saved state."""
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        self.nn = rnn(
            log_dir='logs',
            checkpoint_dir='checkpoints',
            prediction_dir='predictions',
            learning_rates=[.0001, .00005, .00002],
            batch_sizes=[32, 64, 64],
            patiences=[1500, 1000, 500],
            beta1_decays=[.9, .9, .9],
            validation_batch_size=32,
            optimizer='rms',
            num_training_steps=100000,
            warm_start_init_step=17900,
            regularization_constant=0.0,
            keep_prob=1.0,
            enable_parameter_averaging=False,
            min_steps_to_checkpoint=2000,
            log_interval=20,
            logging_level=logging.CRITICAL,
            grad_clip=10,
            lstm_size=400,
            output_mixture_components=20,
            attention_mixture_components=10
        )
        self.nn.restore()

    def write(self, filename, lines, biases=None, styles=None, stroke_colors=None, stroke_widths=None):
        """Validate the lines, sample strokes for them and save the result as an SVG file.

        Args:
            filename: path of the SVG file to write.
            lines: list of strings, one per handwritten line.
            biases: optional per-line sampling bias; defaults to 0.5 for every line.
            styles: optional per-line style id used to prime the model.
            stroke_colors: optional per-line stroke color; defaults to 'black'.
            stroke_widths: optional per-line stroke width; defaults to 2.

        Raises:
            ValueError: if a line is longer than 75 characters or contains a character
                that is not in drawing.alphabet.
        """
        valid_char_set = set(drawing.alphabet)
        for line_num, line in enumerate(lines):
            if len(line) > 75:
                raise ValueError(
                    (
                        "Each line must be at most 75 characters. "
                        "Line {} contains {}"
                    ).format(line_num, len(line))
                )

            for char in line:
                if char not in valid_char_set:
                    raise ValueError(
                        (
                            "Invalid character {} detected in line {}. "
                            "Valid character set is {}"
                        ).format(char, line_num, valid_char_set)
                    )

        strokes = self._sample(lines, biases=biases, styles=styles)
        self._draw(strokes, lines, filename, stroke_colors=stroke_colors, stroke_widths=stroke_widths)

    def _sample(self, lines, biases=None, styles=None):
        """Run the model once for all lines and return one stroke-offset array per line.

        When styles is given, each line is primed with the strokes and characters loaded
        from 'styles/style-<id>-strokes.npy' and 'styles/style-<id>-chars.npy'.
        """
        num_samples = len(lines)
        max_tsteps = 40*max([len(i) for i in lines])
        biases = biases if biases is not None else [0.5]*num_samples

        x_prime = np.zeros([num_samples, 1200, 3])
        x_prime_len = np.zeros([num_samples])
        chars = np.zeros([num_samples, 120])
        chars_len = np.zeros([num_samples])

        if styles is not None:
            for i, (cs, style) in enumerate(zip(lines, styles)):
                x_p = np.load('styles/style-{}-strokes.npy'.format(style))
                # tobytes() replaces ndarray.tostring(), which is deprecated and was
                # removed in NumPy 2.0; both return the same raw bytes.
                c_p = np.load('styles/style-{}-chars.npy'.format(style)).tobytes().decode('utf-8')

                # priming text followed by the text to write, encoded as one sequence
                c_p = str(c_p) + " " + cs
                c_p = drawing.encode_ascii(c_p)
                c_p = np.array(c_p)

                x_prime[i, :len(x_p), :] = x_p
                x_prime_len[i] = len(x_p)
                chars[i, :len(c_p)] = c_p
                chars_len[i] = len(c_p)

        else:
            for i in range(num_samples):
                encoded = drawing.encode_ascii(lines[i])
                chars[i, :len(encoded)] = encoded
                chars_len[i] = len(encoded)

        [samples] = self.nn.session.run(
            [self.nn.sampled_sequence],
            feed_dict={
                self.nn.prime: styles is not None,
                self.nn.x_prime: x_prime,
                self.nn.x_prime_len: x_prime_len,
                self.nn.num_samples: num_samples,
                self.nn.sample_tsteps: max_tsteps,
                self.nn.c: chars,
                self.nn.c_len: chars_len,
                self.nn.bias: biases
            }
        )
        # drop rows that are entirely zero
        samples = [sample[~np.all(sample == 0.0, axis=1)] for sample in samples]
        return samples

    def _draw(self, strokes, lines, filename, stroke_colors=None, stroke_widths=None):
        """Render the sampled strokes, one SVG path per non-empty line, and save the file.

        Note: each offsets array in strokes is scaled by 1.5 in place.
        """
        stroke_colors = stroke_colors or ['black']*len(lines)
        stroke_widths = stroke_widths or [2]*len(lines)

        line_height = 60
        view_width = 1000
        view_height = line_height*(len(strokes) + 1)

        dwg = svgwrite.Drawing(filename=filename)
        dwg.viewbox(width=view_width, height=view_height)
        dwg.add(dwg.rect(insert=(0, 0), size=(view_width, view_height), fill='white'))

        initial_coord = np.array([0, -(3*line_height / 4)])
        for offsets, line, color, width in zip(strokes, lines, stroke_colors, stroke_widths):

            if not line:
                # an empty line only advances the vertical position
                initial_coord[1] -= line_height
                continue

            offsets[:, :2] *= 1.5
            # local name kept distinct from the `strokes` argument being iterated
            coords = drawing.offsets_to_coords(offsets)
            coords = drawing.denoise(coords)
            coords[:, :2] = drawing.align(coords[:, :2])

            coords[:, 1] *= -1
            coords[:, :2] -= coords[:, :2].min() + initial_coord
            coords[:, 0] += (view_width - coords[:, 0].max()) / 2

            # 'M' (move) starts a new sub-path after an end-of-stroke point, 'L' draws a line
            prev_eos = 1.0
            commands = ["M{},{} ".format(0, 0)]
            for x, y, eos in zip(*coords.T):
                commands.append('{}{},{} '.format('M' if prev_eos == 1.0 else 'L', x, y))
                prev_eos = eos
            path = svgwrite.path.Path(''.join(commands))
            path = path.stroke(color=color, width=width, linecap='round').fill("none")
            dwg.add(path)

            initial_coord[1] -= line_height

        dwg.save()
|
handwriting_api.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI as FastAPIApp, HTTPException
|
2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
from pydantic import BaseModel
|
4 |
+
import uvicorn
|
5 |
+
import os
|
6 |
+
from hand import Hand
|
7 |
+
from typing import List, Optional
|
8 |
+
|
9 |
+
# Application object that the route decorator in this module registers on.
app = FastAPIApp()
# Single Hand instance created at import time and shared by all requests.
hand = Hand()
|
11 |
+
|
12 |
+
class InputData(BaseModel):
    # Request body accepted by the /synthesize endpoint.
    # Text to render; the endpoint splits it on '\n' into lines.
    text: str
    # Sampling bias applied to every line; validate_input requires 0.5 <= bias <= 1.0.
    bias: Optional[float] = 0.75
    # Style id applied to every line; validate_input requires 0 <= style <= 12.
    style: Optional[int] = 9
    # One entry per line of text (validate_input compares the lengths).
    stroke_colors: Optional[List[str]] = ['black']
    # One entry per line of text (validate_input compares the lengths).
    stroke_widths: Optional[List[float]] = [2]
|
18 |
+
|
19 |
+
def validate_input(data: InputData):
    """Check a request against the limits of the synthesizer.

    bias and style are declared Optional on the model, so a client may send null for
    them; that is reported as an out-of-range ValueError instead of failing later with
    a TypeError in the comparison. Null stroke_colors / stroke_widths skip the length
    check.

    Raises:
        ValueError: if the text is longer than 75 characters, bias is missing or outside
            [0.5, 1.0], style is missing or outside [0, 12], or the number of stroke
            colors / widths differs from the number of lines.
    """
    if len(data.text) > 75:
        raise ValueError("Text must be 75 characters or less")
    if data.bias is None or not (0.5 <= data.bias <= 1.0):
        raise ValueError("Bias must be between 0.5 and 1.0")
    if data.style is None or not (0 <= data.style <= 12):
        raise ValueError("Style must be between 0 and 12")

    num_lines = len(data.text.split('\n'))
    if data.stroke_colors is not None and len(data.stroke_colors) != num_lines:
        raise ValueError("Number of stroke colors must match number of lines")
    if data.stroke_widths is not None and len(data.stroke_widths) != num_lines:
        raise ValueError("Number of stroke widths must match number of lines")
|
30 |
+
|
31 |
+
@app.post("/synthesize")
def synthesize(data: InputData):
    """Validate the request, render the text to 'img/output.svg' and report the path.

    The same bias and style are applied to every line of the text. A ValueError raised
    by validation or by the writer is returned to the client as HTTP 400.
    """
    output_path = 'img/output.svg'
    try:
        validate_input(data)
        text_lines = data.text.split('\n')
        line_count = len(text_lines)

        hand.write(
            filename=output_path,
            lines=text_lines,
            biases=[data.bias] * line_count,
            styles=[data.style] * line_count,
            stroke_colors=data.stroke_colors,
            stroke_widths=data.stroke_widths
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"result": "Handwriting synthesized successfully", "output_file": output_path}
|
50 |
+
|
51 |
+
if __name__ == "__main__":
    # When run as a script, serve the app with uvicorn on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
lyrics.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""lyrics taken from https://www.azlyrics.com/"""
|
2 |
+
|
3 |
+
all_star = """Somebody once told me the world is gonna roll me
|
4 |
+
I ain't the sharpest tool in the shed
|
5 |
+
She was looking kind of dumb with her finger and her thumb
|
6 |
+
In the shape of an "L" on her forehead
|
7 |
+
|
8 |
+
Well, the years start coming and they don't stop coming
|
9 |
+
Fed to the rules and I hit the ground running
|
10 |
+
Didn't make sense not to live for fun
|
11 |
+
Your brain gets smart but your head gets dumb
|
12 |
+
|
13 |
+
So much to do, so much to see
|
14 |
+
So what's wrong with taking the back streets?
|
15 |
+
You'll never know if you don't go
|
16 |
+
You'll never shine if you don't glow
|
17 |
+
|
18 |
+
Hey, now, you're an All Star, get your game on, go play
|
19 |
+
Hey, now, you're a Rock Star, get the show on, get paid
|
20 |
+
And all that glitters is gold
|
21 |
+
Only shooting stars break the mold
|
22 |
+
|
23 |
+
It's a cool place and they say it gets colder
|
24 |
+
You're bundled up now wait 'til you get older
|
25 |
+
But the meteor men beg to differ
|
26 |
+
Judging by the hole in the satellite picture
|
27 |
+
|
28 |
+
The ice we skate is getting pretty thin
|
29 |
+
The water's getting warm so you might as well swim
|
30 |
+
My world's on fire. How about yours?
|
31 |
+
That's the way I like it and I'll never get bored.
|
32 |
+
|
33 |
+
Somebody once asked could I spare some change for gas
|
34 |
+
I need to get myself away from this place
|
35 |
+
I said yep, what a concept
|
36 |
+
I could use a little fuel myself
|
37 |
+
And we could all use a little change
|
38 |
+
|
39 |
+
Well, the years start coming and they don't stop coming
|
40 |
+
Fed to the rules and I hit the ground running
|
41 |
+
Didn't make sense not to live for fun
|
42 |
+
Your brain gets smart but your head gets dumb
|
43 |
+
|
44 |
+
So much to do, so much to see
|
45 |
+
So what's wrong with taking the back streets?
|
46 |
+
You'll never know if you don't go
|
47 |
+
You'll never shine if you don't glow.
|
48 |
+
|
49 |
+
And all that glitters is gold
|
50 |
+
Only shooting stars break the mold"""
|
51 |
+
|
52 |
+
downtown = """Making my way downtown
|
53 |
+
Walking fast
|
54 |
+
Faces pass
|
55 |
+
And I'm home-bound
|
56 |
+
|
57 |
+
Staring blankly ahead
|
58 |
+
Just making my way
|
59 |
+
Making a way
|
60 |
+
Through the crowd
|
61 |
+
|
62 |
+
And I need you
|
63 |
+
And I miss you
|
64 |
+
And now I wonder
|
65 |
+
|
66 |
+
If I could fall into the sky
|
67 |
+
Do you think time would pass me by?
|
68 |
+
'Cause you know I'd walk a thousand miles
|
69 |
+
If I could just see you tonight
|
70 |
+
|
71 |
+
It's always times like these
|
72 |
+
When I think of you
|
73 |
+
And I wonder if you ever think of me
|
74 |
+
'Cause everything's so wrong
|
75 |
+
And I don't belong
|
76 |
+
Living in your precious memory
|
77 |
+
|
78 |
+
'Cause I need you
|
79 |
+
And I miss you
|
80 |
+
And now I wonder
|
81 |
+
|
82 |
+
If I could fall into the sky
|
83 |
+
Do you think time would pass me by?
|
84 |
+
'Cause you know I'd walk a thousand miles
|
85 |
+
If I could just see you tonight
|
86 |
+
|
87 |
+
And I, I don't wanna let you know
|
88 |
+
I, I drown in your memory
|
89 |
+
I, I don't wanna let this go
|
90 |
+
I, I don't
|
91 |
+
|
92 |
+
Making my way downtown
|
93 |
+
Walking fast
|
94 |
+
Faces pass
|
95 |
+
And I'm home-bound
|
96 |
+
|
97 |
+
Staring blankly ahead
|
98 |
+
Just making my way
|
99 |
+
Making a way
|
100 |
+
Through the crowd
|
101 |
+
|
102 |
+
And I still need you
|
103 |
+
And I still miss you
|
104 |
+
And now I wonder
|
105 |
+
|
106 |
+
If I could fall into the sky
|
107 |
+
Do you think time would pass us by?
|
108 |
+
'Cause you know I'd walk a thousand miles
|
109 |
+
If I could just see you
|
110 |
+
|
111 |
+
If I could fall into the sky
|
112 |
+
Do you think time would pass me by?
|
113 |
+
'Cause you know I'd walk a thousand miles
|
114 |
+
If I could just see you
|
115 |
+
If I could just hold you tonight"""
|
116 |
+
|
117 |
+
give_up = """We're no strangers to love
|
118 |
+
You know the rules and so do I
|
119 |
+
A full commitment's what I'm thinking of
|
120 |
+
You wouldn't get this from any other guy
|
121 |
+
|
122 |
+
I just wanna tell you how I'm feeling
|
123 |
+
Gotta make you understand
|
124 |
+
|
125 |
+
Never gonna give you up
|
126 |
+
Never gonna let you down
|
127 |
+
Never gonna run around and desert you
|
128 |
+
Never gonna make you cry
|
129 |
+
Never gonna say goodbye
|
130 |
+
Never gonna tell a lie and hurt you
|
131 |
+
|
132 |
+
We've known each other for so long
|
133 |
+
Your heart's been aching, but
|
134 |
+
You're too shy to say it
|
135 |
+
Inside, we both know what's been going on
|
136 |
+
We know the game and we're gonna play it
|
137 |
+
|
138 |
+
And if you ask me how I'm feeling
|
139 |
+
Don't tell me you're too blind to see
|
140 |
+
|
141 |
+
Never gonna give you up
|
142 |
+
Never gonna let you down
|
143 |
+
Never gonna run around and desert you
|
144 |
+
Never gonna make you cry
|
145 |
+
Never gonna say goodbye
|
146 |
+
Never gonna tell a lie and hurt you
|
147 |
+
|
148 |
+
Never gonna give you up
|
149 |
+
Never gonna let you down
|
150 |
+
Never gonna run around and desert you
|
151 |
+
Never gonna make you cry
|
152 |
+
Never gonna say goodbye
|
153 |
+
Never gonna tell a lie and hurt you
|
154 |
+
|
155 |
+
(Ooh, give you up)
|
156 |
+
(Ooh, give you up)
|
157 |
+
Never gonna give, never gonna give
|
158 |
+
(Give you up)
|
159 |
+
Never gonna give, never gonna give
|
160 |
+
(Give you up)
|
161 |
+
|
162 |
+
We've known each other for so long
|
163 |
+
Your heart's been aching, but
|
164 |
+
You're too shy to say it
|
165 |
+
Inside, we both know what's been going on
|
166 |
+
We know the game and we're gonna play it
|
167 |
+
|
168 |
+
I just wanna tell you how I'm feeling
|
169 |
+
Gotta make you understand
|
170 |
+
|
171 |
+
Never gonna give you up
|
172 |
+
Never gonna let you down
|
173 |
+
Never gonna run around and desert you
|
174 |
+
Never gonna make you cry
|
175 |
+
Never gonna say goodbye
|
176 |
+
Never gonna tell a lie and hurt you
|
177 |
+
|
178 |
+
Never gonna give you up
|
179 |
+
Never gonna let you down
|
180 |
+
Never gonna run around and desert you
|
181 |
+
Never gonna make you cry
|
182 |
+
Never gonna say goodbye
|
183 |
+
Never gonna tell a lie and hurt you
|
184 |
+
|
185 |
+
Never gonna give you up
|
186 |
+
Never gonna let you down
|
187 |
+
Never gonna run around and desert you
|
188 |
+
Never gonna make you cry
|
189 |
+
Never gonna say goodbye
|
190 |
+
Never gonna tell a lie and hurt you"""
|
prepare_data.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
from xml.etree import ElementTree
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import drawing
|
8 |
+
|
9 |
+
|
10 |
+
def get_stroke_sequence(filename):
    """Parse a stroke XML file into an array of normalized (dx, dy, end-of-stroke) offsets.

    Points are read from the first 'StrokeSet' child of the XML root; y is negated and
    the last point of each stroke is flagged with 1. The coordinates are then aligned,
    denoised, converted to offsets, truncated to drawing.MAX_STROKE_LEN rows and
    normalized, in that order.
    """
    root = ElementTree.parse(filename).getroot()
    stroke_set = [child for child in root if child.tag == 'StrokeSet'][0]

    points = []
    for stroke in stroke_set:
        last_index = len(stroke) - 1
        for position, point in enumerate(stroke):
            x = int(point.attrib['x'])
            y = -1*int(point.attrib['y'])
            points.append([x, y, int(position == last_index)])
    coords = np.array(points)

    coords = drawing.denoise(drawing.align(coords))
    offsets = drawing.coords_to_offsets(coords)[:drawing.MAX_STROKE_LEN]
    return drawing.normalize(offsets)
|
30 |
+
|
31 |
+
|
32 |
+
def get_ascii_sequences(filename):
    """Read a transcription file and return its encoded text lines.

    The file content is split into lines (a run of eleven '%' characters is
    treated as a line break as well). Everything up to and including the line
    after the 'CSR:' marker is discarded, blank lines are dropped, and each
    remaining line is encoded with drawing.encode_ascii and truncated to
    drawing.MAX_CHAR_LEN entries.

    Args:
        filename: Path of the transcription text file.

    Returns:
        A list with one encoded, truncated sequence per non-empty text line.

    Raises:
        ValueError: If the file contains no line equal to 'CSR:'.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with open(filename, 'r') as f:
        sequences = f.read()
    sequences = sequences.replace(r'%%%%%%%%%%%', '\n')
    sequences = [i.strip() for i in sequences.split('\n')]
    lines = sequences[sequences.index('CSR:') + 2:]
    # Entries are already stripped above, so a truthiness check drops blanks.
    lines = [line for line in lines if line]
    lines = [drawing.encode_ascii(line)[:drawing.MAX_CHAR_LEN] for line in lines]
    return lines
|
40 |
+
|
41 |
+
|
42 |
+
def collect_data():
    """Walk the raw dataset and pair stroke files with transcriptions and writer ids.

    Transcription files are collected from directories under 'data/raw/ascii/'
    that contain no subdirectories. For each one, the matching 'lineStrokes'
    directory and 'original' strokes XML are derived by path substitution.
    Line-stroke files listed in 'data/blacklist.npy' are skipped.

    Returns:
        A tuple (stroke_fnames, transcriptions, writer_ids) of three parallel
        lists: path of each line-stroke file, its encoded transcription as
        produced by get_ascii_sequences, and the integer writer id (0 when the
        XML has no 'General' element or no 'writerID' attribute).

    Raises:
        AssertionError: If a transcription file yields a different number of
            lines than there are matching line-stroke files.
    """
    fnames = []
    for dirpath, dirnames, filenames in os.walk('data/raw/ascii/'):
        # skip any directory that still has subdirectories
        if dirnames:
            continue
        fnames.extend(
            os.path.join(dirpath, filename)
            for filename in filenames
            if not filename.startswith('.')
        )

    # low quality samples (selected by collecting samples to
    # which the trained model assigned very low likelihood)
    blacklist = set(np.load('data/blacklist.npy'))

    stroke_fnames, transcriptions, writer_ids = [], [], []
    for i, fname in enumerate(fnames):
        print(i, fname)
        if fname == 'data/raw/ascii/z01/z01-000/z01-000z.txt':
            continue

        head = os.path.dirname(fname)
        # a trailing letter in the file stem (if any) selects the matching stroke files
        last_letter = os.path.splitext(fname)[0][-1]
        last_letter = last_letter if last_letter.isalpha() else ''

        line_stroke_dir = head.replace('ascii', 'lineStrokes')
        line_stroke_fname_prefix = os.path.split(head)[-1] + last_letter + '-'

        if not os.path.isdir(line_stroke_dir):
            continue
        # sorted so that zip() below pairs files with transcription lines in order
        line_stroke_fnames = sorted([f for f in os.listdir(line_stroke_dir)
                                     if f.startswith(line_stroke_fname_prefix)])
        if not line_stroke_fnames:
            continue

        original_dir = head.replace('ascii', 'original')
        original_xml = os.path.join(original_dir, 'strokes' + last_letter + '.xml')
        tree = ElementTree.parse(original_xml)
        root = tree.getroot()

        general = root.find('General')
        if general is not None:
            writer_id = int(general[0].attrib.get('writerID', '0'))
        else:
            writer_id = 0

        ascii_sequences = get_ascii_sequences(fname)
        assert len(ascii_sequences) == len(line_stroke_fnames)

        for ascii_seq, line_stroke_fname in zip(ascii_sequences, line_stroke_fnames):
            if line_stroke_fname in blacklist:
                continue

            stroke_fnames.append(os.path.join(line_stroke_dir, line_stroke_fname))
            transcriptions.append(ascii_seq)
            writer_ids.append(writer_id)

    return stroke_fnames, transcriptions, writer_ids
|
99 |
+
|
100 |
+
|
101 |
+
if __name__ == '__main__':
    # Build fixed-size, zero-padded arrays for every collected sample and save
    # the ones that pass the offset-magnitude check to data/processed/.
    print('traversing data directory...')
    stroke_fnames, transcriptions, writer_ids = collect_data()

    print('dumping to numpy arrays...')
    x = np.zeros([len(stroke_fnames), drawing.MAX_STROKE_LEN, 3], dtype=np.float32)
    x_len = np.zeros([len(stroke_fnames)], dtype=np.int16)
    c = np.zeros([len(stroke_fnames), drawing.MAX_CHAR_LEN], dtype=np.int8)
    c_len = np.zeros([len(stroke_fnames)], dtype=np.int8)
    w_id = np.zeros([len(stroke_fnames)], dtype=np.int16)
    # The `np.bool` alias was removed in NumPy 1.24; the builtin `bool` is the
    # equivalent dtype and works on every NumPy version.
    valid_mask = np.zeros([len(stroke_fnames)], dtype=bool)

    for i, (stroke_fname, c_i, w_id_i) in enumerate(zip(stroke_fnames, transcriptions, writer_ids)):
        if i % 200 == 0:
            print(i, '\t', '/', len(stroke_fnames))
        x_i = get_stroke_sequence(stroke_fname)
        # a sample is valid only if no row's first-two-column norm exceeds 60
        valid_mask[i] = ~np.any(np.linalg.norm(x_i[:, :2], axis=1) > 60)

        x[i, :len(x_i), :] = x_i
        x_len[i] = len(x_i)

        c[i, :len(c_i)] = c_i
        c_len[i] = len(c_i)

        w_id[i] = w_id_i

    if not os.path.isdir('data/processed'):
        os.makedirs('data/processed')

    # only rows flagged valid are written out
    np.save('data/processed/x.npy', x[valid_mask])
    np.save('data/processed/x_len.npy', x_len[valid_mask])
    np.save('data/processed/c.npy', c[valid_mask])
    np.save('data/processed/c_len.npy', c_len[valid_mask])
    np.save('data/processed/w_id.npy', w_id[valid_mask])
|
readme.md
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+

|
2 |
+
# Handwriting Synthesis
|
3 |
+
Implementation of the handwriting synthesis experiments in the paper <a href="https://arxiv.org/abs/1308.0850">Generating Sequences with Recurrent Neural Networks</a> by Alex Graves. The implementation closely follows the original paper, with a few slight deviations, and the generated samples are of similar quality to those presented in the paper.
|
4 |
+
|
5 |
+
Web demo from original github (which I forked) is available <a href="https://seanvasquez.com/handwriting-generation/">here</a>.
|
6 |
+
|
7 |
+
## Fork Notes
|
8 |
+
This fork has the following changes from the original:
|
9 |
+
1. Updated to work with TensorFlow 2.15.0 (current as of December 2023), but uses the v1 compat mechanism. Honestly I hacked and burned through many errors in an hour and a half (I forked at 7:05 PM and am writing this at 8:35 PM) and just verified by running the demo.py and looking through the output image files. It didn't create a banner.svg file (no idea if it was supposed to), and it gets a lot of deprecation warnings so use at your own risk. TensorFlow will likely abandon some of this V1 compat stuff in the relatively near future, but this should work as long as you can still get 2.15.0 for whatever python version you have.
|
10 |
+
2. I'm going to split the Hand class into its own file (per the original author's suggestion).
|
11 |
+
3. I left everything else below this text alone (except striking the "split Hand class" request).
|
12 |
+
4. If you think my fork is super sloppy (you're right) and want to do it right, the major job is to replace the deprecated `tf.nn.rnn_cell.LSTMCell` with `tf.keras.layers.LSTMCell`. It is not a drop-in replacement. Have at it.
|
13 |
+
|
14 |
+
## Usage
|
15 |
+
```python
|
16 |
+
lines = [
|
17 |
+
"Now this is a story all about how",
|
18 |
+
"My life got flipped turned upside down",
|
19 |
+
"And I'd like to take a minute, just sit right there",
|
20 |
+
"I'll tell you how I became the prince of a town called Bel-Air",
|
21 |
+
]
|
22 |
+
biases = [.75 for i in lines]
|
23 |
+
styles = [9 for i in lines]
|
24 |
+
stroke_colors = ['red', 'green', 'black', 'blue']
|
25 |
+
stroke_widths = [1, 2, 1, 2]
|
26 |
+
|
27 |
+
hand = Hand()
|
28 |
+
hand.write(
|
29 |
+
filename='img/usage_demo.svg',
|
30 |
+
lines=lines,
|
31 |
+
biases=biases,
|
32 |
+
styles=styles,
|
33 |
+
stroke_colors=stroke_colors,
|
34 |
+
stroke_widths=stroke_widths
|
35 |
+
)
|
36 |
+
```
|
37 |
+

|
38 |
+
|
39 |
+
~~Currently, the `Hand` class must be imported from `demo.py`. If someone would like to package this project to make it more usable, please [contribute](#contribute).~~
|
40 |
+
|
41 |
+
A pretrained model is included, but if you'd like to train your own, read <a href='https://github.com/sjvasquez/handwriting-synthesis/tree/master/data/raw'>these instructions</a>.
|
42 |
+
|
43 |
+
## Demonstrations
|
44 |
+
Below are a few hundred samples from the model, including some samples demonstrating the effect of priming and biasing the model. Loosely speaking, biasing controls the neatness of the samples and priming controls the style of the samples. The code for these demonstrations can be found in `demo.py`.
|
45 |
+
|
46 |
+
### Demo #1:
|
47 |
+
The following samples were generated with a fixed style and fixed bias.
|
48 |
+
|
49 |
+
**Smash Mouth – All Star (<a href="https://www.azlyrics.com/lyrics/smashmouth/allstar.html">lyrics</a>)**
|
50 |
+

|
51 |
+
|
52 |
+
### Demo #2
|
53 |
+
The following samples were generated with varying style and fixed bias. Each verse is generated in a different style.
|
54 |
+
|
55 |
+
**Vanessa Carlton – A Thousand Miles (<a href="https://www.azlyrics.com/lyrics/vanessacarlton/athousandmiles.html">lyrics</a>)**
|
56 |
+

|
57 |
+
|
58 |
+
### Demo #3
|
59 |
+
The following samples were generated with a fixed style and varying bias. Each verse has a lower bias than the previous, with the last verse being unbiased.
|
60 |
+
|
61 |
+
**Leonard Cohen – Hallelujah (<a href="https://www.youtube.com/watch?v=dQw4w9WgXcQ">lyrics</a>)**
|
62 |
+

|
63 |
+
|
64 |
+
## Contribute
|
65 |
+
This project was intended to serve as a reference implementation for a research paper, but since the results are of decent quality, it may be worthwhile to make the project more broadly usable. I plan to continue focusing on the machine learning side of things. That said, I'd welcome contributors who can:
|
66 |
+
|
67 |
+
- Package this, and otherwise make it look more like a usable software project and less like research code.
|
68 |
+
- Add support for more sophisticated drawing, animations, or anything else in this direction. Currently, the project only creates some simple svg files.
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib==3.9.4
|
2 |
+
pandas==2.2.3
|
3 |
+
scikit-learn==1.6.1
|
4 |
+
scipy==1.13.0
|
5 |
+
svgwrite==1.1.12
|
6 |
+
tensorflow-probability==0.23.0
|
7 |
+
tensorflow==2.15.0
|
8 |
+
fastapi==0.109.0
|
9 |
+
uvicorn==0.27.0
|
10 |
+
pydantic==2.5.0
|
11 |
+
cairosvg==2.7.1
|
12 |
+
gradio==4.29.0
|
13 |
+
python-multipart==0.0.9
|
14 |
+
httpx==0.27.0
|
15 |
+
requests==2.31.0
|
rnn.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import tensorflow.compat.v1 as tf
|
6 |
+
tf.disable_v2_behavior()
|
7 |
+
|
8 |
+
import drawing
|
9 |
+
from data_frame import DataFrame
|
10 |
+
from rnn_cell import LSTMAttentionCell
|
11 |
+
from rnn_ops import rnn_free_run
|
12 |
+
from tf_base_model import TFBaseModel
|
13 |
+
from tf_utils import time_distributed_dense_layer
|
14 |
+
|
15 |
+
|
16 |
+
class DataReader(object):
    """Loads the 'x', 'x_len', 'c' and 'c_len' arrays and serves batch generators.

    Args:
        data_dir: Directory containing one '<name>.npy' file per column.
    """

    def __init__(self, data_dir):
        column_names = ['x', 'x_len', 'c', 'c_len']
        arrays = []
        for name in column_names:
            arrays.append(np.load(os.path.join(data_dir, '{}.npy'.format(name))))

        # the full dataset is kept as test_df; train/val are split off from it
        self.test_df = DataFrame(columns=column_names, data=arrays)
        split = self.test_df.train_test_split(train_size=0.95, random_state=2018)
        self.train_df, self.val_df = split

        print('train size', len(self.train_df))
        print('val size', len(self.val_df))
        print('test size', len(self.test_df))

    def train_batch_generator(self, batch_size):
        """Return a shuffled, 10000-epoch batch generator over the training split."""
        options = dict(df=self.train_df, shuffle=True, num_epochs=10000, mode='train')
        return self.batch_generator(batch_size=batch_size, **options)

    def val_batch_generator(self, batch_size):
        """Return a shuffled, 10000-epoch batch generator over the validation split."""
        options = dict(df=self.val_df, shuffle=True, num_epochs=10000, mode='val')
        return self.batch_generator(batch_size=batch_size, **options)

    def test_batch_generator(self, batch_size):
        """Return an unshuffled, single-epoch batch generator over the full dataset."""
        options = dict(df=self.test_df, shuffle=False, num_epochs=1, mode='test')
        return self.batch_generator(batch_size=batch_size, **options)

    def batch_generator(self, batch_size, df, shuffle=True, num_epochs=10000, mode='train'):
        """Yield batches from df with next-step targets added and padding trimmed.

        For each batch, 'x_len' is decremented by one, 'y' is set to 'x' shifted
        one step forward along axis 1, and 'x', 'y' and 'c' are cut down to the
        longest length present in the batch.

        Args:
            batch_size: Number of rows per batch.
            df: Frame to draw batches from.
            shuffle: Forwarded to df.batch_generator.
            num_epochs: Forwarded to df.batch_generator.
            mode: A smaller final batch is allowed only when this is 'test'.

        Yields:
            The batch objects produced by df.batch_generator, modified in place.
        """
        allow_partial = (mode == 'test')
        batches = df.batch_generator(
            batch_size=batch_size,
            shuffle=shuffle,
            num_epochs=num_epochs,
            allow_smaller_final_batch=allow_partial
        )
        for batch in batches:
            # inputs are one step shorter than the stored sequences
            batch['x_len'] = batch['x_len'] - 1
            x_limit = np.max(batch['x_len'])
            c_limit = np.max(batch['c_len'])

            strokes = batch['x']
            batch['y'] = strokes[:, 1:x_limit + 1, :]
            batch['x'] = strokes[:, :x_limit, :]
            batch['c'] = batch['c'][:, :c_limit]
            yield batch
|
71 |
+
|
72 |
+
|
73 |
+
class rnn(TFBaseModel):
    """Model built from an LSTMAttentionCell followed by a mixture-density output layer.

    Args:
        lstm_size: Forwarded to LSTMAttentionCell as lstm_size.
        output_mixture_components: Number of components in the output mixture.
        attention_mixture_components: Forwarded to LSTMAttentionCell as
            num_attn_mixture_components.
        **kwargs: Forwarded unchanged to TFBaseModel.__init__.
    """

    def __init__(
        self,
        lstm_size,
        output_mixture_components,
        attention_mixture_components,
        **kwargs
    ):
        self.lstm_size = lstm_size
        self.output_mixture_components = output_mixture_components
        # per component: 1 weight + 2 sigmas + 1 rho + 2 mus, plus a single extra unit (es)
        self.output_units = self.output_mixture_components*6 + 1
        self.attention_mixture_components = attention_mixture_components
        super(rnn, self).__init__(**kwargs)

    def parse_parameters(self, z, eps=1e-8, sigma_eps=1e-4):
        """Split raw output units into constrained mixture parameters.

        Args:
            z: Tensor whose last axis has self.output_units entries.
            eps: Margin keeping rhos inside (-1, 1) and es inside (0, 1).
            sigma_eps: Lower bound applied to sigmas.

        Returns:
            Tuple (pis, mus, sigmas, rhos, es): softmaxed weights, unconstrained
            means, exponentiated sigmas, tanh-squashed rhos and sigmoid es.
        """
        pis, sigmas, rhos, mus, es = tf.split(
            z,
            [
                1*self.output_mixture_components,
                2*self.output_mixture_components,
                1*self.output_mixture_components,
                2*self.output_mixture_components,
                1
            ],
            axis=-1
        )
        pis = tf.nn.softmax(pis, axis=-1)
        sigmas = tf.clip_by_value(tf.exp(sigmas), sigma_eps, np.inf)
        rhos = tf.clip_by_value(tf.tanh(rhos), eps - 1.0, 1.0 - eps)
        es = tf.clip_by_value(tf.nn.sigmoid(es), eps, 1.0 - eps)
        return pis, mus, sigmas, rhos, es

    def NLL(self, y, lengths, pis, mus, sigmas, rho, es, eps=1e-8):
        """Compute the masked negative log-likelihood of y under the mixture.

        Args:
            y: Targets; split into three equal parts along axis 2 (two coordinates
                and a binary flag).
            lengths: Per-sequence number of valid steps.
            pis, mus, sigmas, rho, es: Parameters as returned by parse_parameters.
            eps: Lower bound on the mixture likelihood before taking the log.

        Returns:
            Tuple (sequence_loss, element_loss): mean NLL per sequence, and mean NLL
            over all valid (unmasked, non-NaN) steps in the batch.
        """
        sigma_1, sigma_2 = tf.split(sigmas, 2, axis=2)
        y_1, y_2, y_3 = tf.split(y, 3, axis=2)
        mu_1, mu_2 = tf.split(mus, 2, axis=2)

        norm = 1.0 / (2*np.pi*sigma_1*sigma_2 * tf.sqrt(1 - tf.square(rho)))
        Z = tf.square((y_1 - mu_1) / (sigma_1)) + \
            tf.square((y_2 - mu_2) / (sigma_2)) - \
            2*rho*(y_1 - mu_1)*(y_2 - mu_2) / (sigma_1*sigma_2)

        exp = -1.0*Z / (2*(1 - tf.square(rho)))
        gaussian_likelihoods = tf.exp(exp) * norm
        gmm_likelihood = tf.reduce_sum(pis * gaussian_likelihoods, 2)
        gmm_likelihood = tf.clip_by_value(gmm_likelihood, eps, np.inf)

        # Squeeze only the trailing size-1 axis. An unqualified squeeze would also
        # collapse a batch or time dimension of size 1 and then broadcast wrongly
        # against gmm_likelihood (which keeps both leading dimensions).
        bernoulli_likelihood = tf.squeeze(
            tf.where(tf.equal(tf.ones_like(y_3), y_3), es, 1 - es), axis=2)

        nll = -(tf.log(gmm_likelihood) + tf.log(bernoulli_likelihood))
        # ignore padded steps and any step whose NLL came out NaN
        sequence_mask = tf.logical_and(
            tf.sequence_mask(lengths, maxlen=tf.shape(y)[1]),
            tf.logical_not(tf.is_nan(nll)),
        )
        nll = tf.where(sequence_mask, nll, tf.zeros_like(nll))
        num_valid = tf.reduce_sum(tf.cast(sequence_mask, tf.float32), axis=1)

        # tf.maximum(..., 1.0) guards against division by zero for empty sequences
        sequence_loss = tf.reduce_sum(nll, axis=1) / tf.maximum(num_valid, 1.0)
        element_loss = tf.reduce_sum(nll) / tf.maximum(tf.reduce_sum(num_valid), 1.0)
        return sequence_loss, element_loss

    def sample(self, cell):
        """Free-run the cell from a zero state for self.sample_tsteps steps.

        The first input is [0, 0, 1] for each of the self.num_samples rows.

        Returns:
            Element [1] of the result of rnn_free_run.
        """
        initial_state = cell.zero_state(self.num_samples, dtype=tf.float32)
        initial_input = tf.concat([
            tf.zeros([self.num_samples, 2]),
            tf.ones([self.num_samples, 1]),
        ], axis=1)
        return rnn_free_run(
            cell=cell,
            sequence_length=self.sample_tsteps,
            initial_state=initial_state,
            initial_input=initial_input,
            scope='rnn'
        )[1]

    def primed_sample(self, cell):
        """Run the cell over self.x_prime, then free-run from the resulting state.

        Returns:
            Element [1] of the result of rnn_free_run.
        """
        initial_state = cell.zero_state(self.num_samples, dtype=tf.float32)
        # [1] selects the final state returned by dynamic_rnn
        primed_state = tf.nn.dynamic_rnn(
            inputs=self.x_prime,
            cell=cell,
            sequence_length=self.x_prime_len,
            dtype=tf.float32,
            initial_state=initial_state,
            scope='rnn'
        )[1]
        return rnn_free_run(
            cell=cell,
            sequence_length=self.sample_tsteps,
            initial_state=primed_state,
            scope='rnn'
        )[1]

    def calculate_loss(self):
        """Create placeholders, build the training graph and the sampling op.

        Sets self.loss (element-wise NLL), self.initial_state, self.final_state and
        self.sampled_sequence (primed or unprimed depending on self.prime).

        Returns:
            self.loss.
        """
        self.x = tf.placeholder(tf.float32, [None, None, 3])
        self.y = tf.placeholder(tf.float32, [None, None, 3])
        self.x_len = tf.placeholder(tf.int32, [None])
        self.c = tf.placeholder(tf.int32, [None, None])
        self.c_len = tf.placeholder(tf.int32, [None])

        self.sample_tsteps = tf.placeholder(tf.int32, [])
        self.num_samples = tf.placeholder(tf.int32, [])
        self.prime = tf.placeholder(tf.bool, [])
        self.x_prime = tf.placeholder(tf.float32, [None, None, 3])
        self.x_prime_len = tf.placeholder(tf.int32, [None])
        # bias defaults to zeros (one per sample) when not fed
        self.bias = tf.placeholder_with_default(
            tf.zeros([self.num_samples], dtype=tf.float32), [None])

        cell = LSTMAttentionCell(
            lstm_size=self.lstm_size,
            num_attn_mixture_components=self.attention_mixture_components,
            attention_values=tf.one_hot(self.c, len(drawing.alphabet)),
            attention_values_lengths=self.c_len,
            num_output_mixture_components=self.output_mixture_components,
            bias=self.bias
        )
        self.initial_state = cell.zero_state(tf.shape(self.x)[0], dtype=tf.float32)
        outputs, self.final_state = tf.nn.dynamic_rnn(
            inputs=self.x,
            cell=cell,
            sequence_length=self.x_len,
            dtype=tf.float32,
            initial_state=self.initial_state,
            scope='rnn'
        )
        params = time_distributed_dense_layer(outputs, self.output_units, scope='rnn/gmm')
        pis, mus, sigmas, rhos, es = self.parse_parameters(params)
        sequence_loss, self.loss = self.NLL(self.y, self.x_len, pis, mus, sigmas, rhos, es)

        self.sampled_sequence = tf.cond(
            self.prime,
            lambda: self.primed_sample(cell),
            lambda: self.sample(cell)
        )
        return self.loss
|
208 |
+
|
209 |
+
|
210 |
+
if __name__ == '__main__':
    # Training configuration, forwarded as keyword arguments to the rnn constructor.
    hyperparameters = dict(
        log_dir='logs',
        checkpoint_dir='checkpoints',
        prediction_dir='predictions',
        learning_rates=[.0001, .00005, .00002],
        batch_sizes=[32, 64, 64],
        patiences=[1500, 1000, 500],
        beta1_decays=[.9, .9, .9],
        validation_batch_size=32,
        optimizer='rms',
        num_training_steps=100000,
        warm_start_init_step=0,
        regularization_constant=0.0,
        keep_prob=1.0,
        enable_parameter_averaging=False,
        min_steps_to_checkpoint=2000,
        log_interval=20,
        grad_clip=10,
        lstm_size=400,
        output_mixture_components=20,
        attention_mixture_components=10,
    )

    dr = DataReader(data_dir='data/processed/')
    nn = rnn(reader=dr, **hyperparameters)
    nn.fit()
|
rnn_cell.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
|
3 |
+
import tensorflow.compat.v1 as tf
|
4 |
+
tf.disable_v2_behavior()
|
5 |
+
import tensorflow_probability as tfp
|
6 |
+
tfd = tfp.distributions
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from tf_utils import dense_layer, shape
|
10 |
+
|
11 |
+
|
12 |
+
# Recurrent state of LSTMAttentionCell: (h, c) for each of the three LSTM layers,
# then the attention parameters alpha/beta/kappa, the window vector w and the
# attention weights phi.
LSTMAttentionCellState = namedtuple(
    'LSTMAttentionCellState',
    'h1 c1 h2 c2 h3 c3 alpha beta kappa w phi'
)
|
16 |
+
|
17 |
+
|
18 |
+
class LSTMAttentionCell(tf.nn.rnn_cell.RNNCell):
    """Three stacked LSTM cells with a mixture-based soft window over attention_values.

    Args:
        lstm_size: Number of units in each of the three LSTM cells.
        num_attn_mixture_components: Number of components in the attention window.
        attention_values: Tensor attended over; dims 0, 1 and 2 are read as batch
            size, sequence length and window size respectively.
        attention_values_lengths: Per-example valid length of attention_values,
            used to mask the window and to decide termination.
        num_output_mixture_components: Number of components in the output mixture
            sampled by output_function.
        bias: Per-example bias applied to the mixture weights and sigmas when
            sampling (see _parse_parameters).
        reuse: Stored on the instance; not read elsewhere in this class.
    """

    def __init__(
        self,
        lstm_size,
        num_attn_mixture_components,
        attention_values,
        attention_values_lengths,
        num_output_mixture_components,
        bias,
        reuse=None,
    ):
        self.reuse = reuse
        self.lstm_size = lstm_size
        self.num_attn_mixture_components = num_attn_mixture_components
        self.attention_values = attention_values
        self.attention_values_lengths = attention_values_lengths
        self.window_size = shape(self.attention_values, 2)
        # char_len and batch_size are dynamic (graph-time) values
        self.char_len = tf.shape(attention_values)[1]
        self.batch_size = tf.shape(attention_values)[0]
        self.num_output_mixture_components = num_output_mixture_components
        self.output_units = 6*self.num_output_mixture_components + 1
        self.bias = bias

    @property
    def state_size(self):
        """Per-field sizes of the state, in LSTMAttentionCellState field order."""
        return LSTMAttentionCellState(
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.lstm_size,
            self.num_attn_mixture_components,
            self.num_attn_mixture_components,
            self.num_attn_mixture_components,
            self.window_size,
            self.char_len,
        )

    @property
    def output_size(self):
        """Size of the cell output (the third LSTM's output)."""
        return self.lstm_size

    def zero_state(self, batch_size, dtype):
        """Return an all-zeros LSTMAttentionCellState for batch_size rows.

        Args:
            batch_size: Number of rows in every state tensor.
            dtype: Element type of the state tensors; None falls back to
                tf.float32 (the type that was previously always used).
        """
        # Honor the requested dtype instead of silently ignoring the argument.
        dtype = tf.float32 if dtype is None else dtype
        return LSTMAttentionCellState(
            tf.zeros([batch_size, self.lstm_size], dtype=dtype),
            tf.zeros([batch_size, self.lstm_size], dtype=dtype),
            tf.zeros([batch_size, self.lstm_size], dtype=dtype),
            tf.zeros([batch_size, self.lstm_size], dtype=dtype),
            tf.zeros([batch_size, self.lstm_size], dtype=dtype),
            tf.zeros([batch_size, self.lstm_size], dtype=dtype),
            tf.zeros([batch_size, self.num_attn_mixture_components], dtype=dtype),
            tf.zeros([batch_size, self.num_attn_mixture_components], dtype=dtype),
            tf.zeros([batch_size, self.num_attn_mixture_components], dtype=dtype),
            tf.zeros([batch_size, self.window_size], dtype=dtype),
            tf.zeros([batch_size, self.char_len], dtype=dtype),
        )

    def __call__(self, inputs, state, scope=None):
        """Advance the cell by one step.

        Args:
            inputs: Input for this step; concatenated with state tensors along axis 1.
            state: LSTMAttentionCellState from the previous step.
            scope: Optional variable scope name; defaults to the class name.

        Returns:
            Tuple (output of the third LSTM, new LSTMAttentionCellState).
        """
        # NOTE: the order in which the LSTM cells and dense layer are built is kept
        # as-is; presumably variable names (and thus checkpoints) depend on it.
        with tf.variable_scope(scope or type(self).__name__, reuse=tf.AUTO_REUSE):

            # lstm 1
            s1_in = tf.concat([state.w, inputs], axis=1)
            cell1 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            s1_out, s1_state = cell1(s1_in, state=(state.c1, state.h1))

            # attention
            attention_inputs = tf.concat([state.w, inputs, s1_out], axis=1)
            attention_params = dense_layer(attention_inputs, 3*self.num_attn_mixture_components, scope='attention')
            alpha, beta, kappa = tf.split(tf.nn.softplus(attention_params), 3, axis=1)
            # kappa accumulates across steps (softplus keeps each increment positive)
            kappa = state.kappa + kappa / 25.0
            beta = tf.clip_by_value(beta, .01, np.inf)

            kappa_flat, alpha_flat, beta_flat = kappa, alpha, beta
            kappa, alpha, beta = tf.expand_dims(kappa, 2), tf.expand_dims(alpha, 2), tf.expand_dims(beta, 2)

            # u enumerates positions 0..char_len-1 for every batch row and component
            enum = tf.reshape(tf.range(self.char_len), (1, 1, self.char_len))
            u = tf.cast(tf.tile(enum, (self.batch_size, self.num_attn_mixture_components, 1)), tf.float32)
            phi_flat = tf.reduce_sum(alpha*tf.exp(-tf.square(kappa - u) / beta), axis=1)

            phi = tf.expand_dims(phi_flat, 2)
            # zero out positions beyond each example's valid length
            sequence_mask = tf.cast(tf.sequence_mask(self.attention_values_lengths, maxlen=self.char_len), tf.float32)
            sequence_mask = tf.expand_dims(sequence_mask, 2)
            w = tf.reduce_sum(phi*self.attention_values*sequence_mask, axis=1)

            # lstm 2
            s2_in = tf.concat([inputs, s1_out, w], axis=1)
            cell2 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            s2_out, s2_state = cell2(s2_in, state=(state.c2, state.h2))

            # lstm 3
            s3_in = tf.concat([inputs, s2_out, w], axis=1)
            cell3 = tf.compat.v1.nn.rnn_cell.LSTMCell(self.lstm_size)
            s3_out, s3_state = cell3(s3_in, state=(state.c3, state.h3))

            new_state = LSTMAttentionCellState(
                s1_state.h,
                s1_state.c,
                s2_state.h,
                s2_state.c,
                s3_state.h,
                s3_state.c,
                alpha_flat,
                beta_flat,
                kappa_flat,
                w,
                phi_flat,
            )

            return s3_out, new_state

    def output_function(self, state):
        """Sample one output row per example from the mixture implied by state.h3.

        Returns:
            Tensor formed by concatenating, along axis 1, the sampled 2-D
            coordinates of the chosen mixture component and the sampled
            Bernoulli value cast to float32.
        """
        params = dense_layer(state.h3, self.output_units, scope='gmm', reuse=tf.AUTO_REUSE)
        pis, mus, sigmas, rhos, es = self._parse_parameters(params)
        mu1, mu2 = tf.split(mus, 2, axis=1)
        mus = tf.stack([mu1, mu2], axis=2)
        sigma1, sigma2 = tf.split(sigmas, 2, axis=1)

        # build a 2x2 covariance matrix for every mixture component
        covar_matrix = [tf.square(sigma1), rhos*sigma1*sigma2,
                        rhos*sigma1*sigma2, tf.square(sigma2)]
        covar_matrix = tf.stack(covar_matrix, axis=2)
        covar_matrix = tf.reshape(covar_matrix, (self.batch_size, self.num_output_mixture_components, 2, 2))

        mvn = tfd.MultivariateNormalFullCovariance(loc=mus, covariance_matrix=covar_matrix)
        b = tfd.Bernoulli(probs=es)
        c = tfd.Categorical(probs=pis)

        sampled_e = b.sample()
        sampled_coords = mvn.sample()
        sampled_idx = c.sample()

        # pick, for each batch row, the coordinates of its sampled component
        idx = tf.stack([tf.range(self.batch_size), sampled_idx], axis=1)
        coords = tf.gather_nd(sampled_coords, idx)
        return tf.concat([coords, tf.cast(sampled_e, tf.float32)], axis=1)

    def termination_condition(self, state):
        """Return a per-example boolean: True once sampling should stop.

        An example terminates when the attention peak (argmax of state.phi) is on
        the last valid position and the sampled third output column equals 1, or
        when the peak has moved past the last valid position.
        """
        char_idx = tf.cast(tf.argmax(state.phi, axis=1), tf.int32)
        final_char = char_idx >= self.attention_values_lengths - 1
        past_final_char = char_idx >= self.attention_values_lengths
        output = self.output_function(state)
        es = tf.cast(output[:, 2], tf.int32)
        is_eos = tf.equal(es, tf.ones_like(es))
        return tf.logical_or(tf.logical_and(final_char, is_eos), past_final_char)

    def _parse_parameters(self, gmm_params, eps=1e-8, sigma_eps=1e-4):
        """Split raw output units into constrained, bias-adjusted mixture parameters.

        Args:
            gmm_params: Tensor whose last axis has self.output_units entries.
            eps: Margin keeping rhos inside (-1, 1) and es inside (0, 1).
            sigma_eps: Lower bound applied to sigmas.

        Returns:
            Tuple (pis, mus, sigmas, rhos, es).
        """
        pis, sigmas, rhos, mus, es = tf.split(
            gmm_params,
            [
                1*self.num_output_mixture_components,
                2*self.num_output_mixture_components,
                1*self.num_output_mixture_components,
                2*self.num_output_mixture_components,
                1
            ],
            axis=-1
        )
        # bias scales the pre-softmax weights up and shifts the log-sigmas down
        pis = pis*(1 + tf.expand_dims(self.bias, 1))
        sigmas = sigmas - tf.expand_dims(self.bias, 1)

        pis = tf.nn.softmax(pis, axis=-1)
        # weights and es below .01 are zeroed out
        pis = tf.where(pis < .01, tf.zeros_like(pis), pis)
        sigmas = tf.clip_by_value(tf.exp(sigmas), sigma_eps, np.inf)
        rhos = tf.clip_by_value(tf.tanh(rhos), eps - 1.0, 1.0 - eps)
        es = tf.clip_by_value(tf.nn.sigmoid(es), eps, 1.0 - eps)
        es = tf.where(es < .01, tf.zeros_like(es), es)

        return pis, mus, sigmas, rhos, es
|
rnn_ops.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow.compat.v1 as tf
|
2 |
+
tf.disable_v2_behavior()
|
3 |
+
from tensorflow.python.framework import constant_op
|
4 |
+
from tensorflow.python.framework import dtypes
|
5 |
+
from tensorflow.python.framework import ops
|
6 |
+
from tensorflow.python.ops import array_ops
|
7 |
+
from tensorflow.python.ops import control_flow_ops
|
8 |
+
from tensorflow.python.ops import cond as control_flow_ops_cond
|
9 |
+
from tensorflow.python.ops import math_ops
|
10 |
+
from tensorflow.python.ops import tensor_array_ops
|
11 |
+
from tensorflow.python.ops import variable_scope as vs
|
12 |
+
from tensorflow.python.ops.rnn_cell_impl import _concat, assert_like_rnncell
|
13 |
+
from tensorflow.python.ops.rnn import _maybe_tensor_shape_from_tensor
|
14 |
+
from tensorflow.python.util import nest
|
15 |
+
from tensorflow.python.framework import tensor_shape
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the time axis)
    in addition to the outputs at each timestep and the final state

    Args:
        cell: RNN cell; called as cell(input, state) and expected to expose
            state_size and output_size.
        loop_fn: callable invoked as loop_fn(time, cell_output, cell_state, loop_state)
            returning (elements_finished, next_input, next_cell_state, emit_output,
            next_loop_state). It is first called with time 0 and None for the other
            three arguments to obtain the initial input, initial state and the
            structure of the emitted output.
        parallel_iterations: forwarded to tf.while_loop; defaults to 32 when falsy.
        swap_memory: forwarded to tf.while_loop.
        scope: variable scope name; "rnn" is used when None.

    returns (
        states for all timesteps,
        outputs for all timesteps,
        final cell state,
    )

    The per-timestep states and outputs are stacked and then transposed with
    perm (1, 0, 2), so every stacked tensor must be rank 3.

    Raises:
        TypeError: if loop_fn is not callable.
    """
    assert_like_rnncell("dummy_name", cell)
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if not tf.executing_eagerly():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # Initial call to loop_fn: only `time` is meaningful, the rest are None.
        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None
                      else constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            # (merge_with is called for its compatibility check; the result is discarded).
            static_batch_size.merge_with(input_shape_i[0])

        # const_batch_size keeps the static value (possibly None) for TensorArray
        # element shapes; batch_size falls back to a dynamic tensor for array_ops.zeros.
        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_structure is not None:
            # loop_fn supplied an example emit: take sizes and dtypes from it.
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                              array_ops.shape(emit) for emit in flat_emit_structure]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            # No example emit: fall back to the cell's output_size and the dtype
            # of the first state tensor.
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        # NOTE(review): s.shape is the full shape of each state tensor, and a batch
        # dimension is prepended again when element_shape is built for flat_state_ta
        # below - confirm this is intended when the state shape is fully static.
        flat_state_size = [s.shape if s.shape.is_fully_defined() else
                           array_ops.shape(s) for s in flat_state]
        flat_state_dtypes = [s.dtype for s in flat_state]

        # One dynamically sized TensorArray per flattened emit component.
        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)
        # Zero emits are written for batch entries that have already finished.
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]

        zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)

        # One dynamically sized TensorArray per flattened state component; this is
        # what lets the function return the state at every timestep.
        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            # Keep looping until every batch entry reports finished.
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)
                return nest.map_structure(copy_fn, current, candidate)

            # Finished entries emit zeros and keep their previous state.
            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            # Record this timestep's emit and state at index `time`.
            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = tf.while_loop(
            condition, body, loop_vars=[
                time, elements_finished, next_input, state_ta,
                emit_ta, state, loop_state],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory
        )

        # Only the last four loop variables are needed; final_loop_state is not returned.
        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        # Stack over time, then swap the first two axes.
        flat_states = nest.flatten(state_ta)
        flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states]
        states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs]
        outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)

        return (states, outputs, final_state)
|
182 |
+
|
183 |
+
|
184 |
+
def rnn_teacher_force(inputs, cell, sequence_length, initial_state, scope='dynamic-rnn-teacher-force'):
    """
    Implementation of an rnn with teacher forcing inputs provided.

    Args:
        inputs: rank-3 tensor. It is transposed with perm (1, 0, 2) and then
            unstacked along the new leading axis, one slice per timestep
            (presumably [batch, time, features] on entry - confirm with callers).
        cell: RNN cell passed through to raw_rnn.
        sequence_length: compared against the current timestep to decide which
            batch entries are finished.
        initial_state: cell state used for the first timestep.
        scope: variable scope name passed to raw_rnn.

    Returns:
        (states, outputs, final_state) exactly as returned by raw_rnn. Note that
        this is three values, unlike the two returned by tf.nn.dynamic_rnn.
    """
    inputs = array_ops.transpose(inputs, (1, 0, 2))
    # Follow the dtype of the inputs instead of assuming float32 so that the
    # TensorArray and the zero padding input agree for any input dtype.
    input_dtype = inputs.dtype
    inputs_ta = tensor_array_ops.TensorArray(dtype=input_dtype, size=array_ops.shape(inputs)[0])
    inputs_ta = inputs_ta.unstack(inputs)

    def loop_fn(time, cell_output, cell_state, loop_state):
        # cell_output is None only on the initial call made by raw_rnn.
        emit_output = cell_output
        next_cell_state = initial_state if cell_output is None else cell_state

        elements_finished = time >= sequence_length
        finished = math_ops.reduce_all(elements_finished)

        # Once every sequence is finished there is nothing left to read from
        # inputs_ta, so feed zeros of the right shape instead.
        next_input = control_flow_ops_cond.cond(
            finished,
            lambda: array_ops.zeros(
                [array_ops.shape(inputs)[1], inputs.shape.as_list()[2]], dtype=input_dtype),
            lambda: inputs_ta.read(time)
        )

        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)

    return raw_rnn(cell, loop_fn, scope=scope)
|
211 |
+
|
212 |
+
|
213 |
+
def rnn_free_run(cell, initial_state, sequence_length, initial_input=None, scope='dynamic-rnn-free-run'):
    """
    Implementation of an rnn which feeds its predictions back to itself at the next timestep.

    cell must implement two methods:

        cell.output_function(state) which takes in the state at timestep t and returns
        the cell input at timestep t+1.

        cell.termination_condition(state) which returns a boolean tensor of shape
        [batch_size] denoting which sequences no longer need to be sampled.

    When initial_input is None it is computed as cell.output_function(initial_state)
    inside the given variable scope with reuse=True.

    Returns:
        (states, outputs, final_state) as returned by raw_rnn.
    """
    with vs.variable_scope(scope, reuse=True):
        if initial_input is None:
            initial_input = cell.output_function(initial_state)

    def loop_fn(time, cell_output, cell_state, loop_state):
        # raw_rnn makes its very first call with cell_output=None.
        is_first_call = cell_output is None
        next_cell_state = initial_state if is_first_call else cell_state

        timed_out = time >= sequence_length
        terminated = cell.termination_condition(next_cell_state)
        elements_finished = math_ops.logical_or(timed_out, terminated)
        all_done = math_ops.reduce_all(elements_finished)

        def _sample_next_input():
            if is_first_call:
                return initial_input
            return cell.output_function(next_cell_state)

        next_input = control_flow_ops_cond.cond(
            all_done,
            lambda: array_ops.zeros_like(initial_input),
            _sample_next_input
        )
        # On the first call only the first row is handed back as the emit value.
        emit_output = next_input[0] if is_first_call else next_input

        return (elements_finished, next_input, next_cell_state, emit_output, None)

    return raw_rnn(cell, loop_fn, scope=scope)
|
tf_base_model.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
from collections import deque
|
3 |
+
from datetime import datetime
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import pprint as pp
|
7 |
+
import time
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import tensorflow.compat.v1 as tf
|
11 |
+
tf.disable_v2_behavior()
|
12 |
+
|
13 |
+
from tf_utils import shape
|
14 |
+
|
15 |
+
|
16 |
+
class TFBaseModel(object):

    """Interface containing some boilerplate code for training tensorflow models.

    Subclassing models must implement self.calculate_loss(), which returns a tensor for the batch loss.
    Code for the training loop, parameter updates, checkpointing, and inference are implemented here and
    subclasses are mainly responsible for building the computational graph beginning with the placeholders
    and ending with the loss tensor.

    Subclasses may optionally define these attributes, which are picked up when present:
    keep_prob / is_training (placeholders fed during fit and predict), metrics (dict of
    name -> tensor evaluated on validation batches), early_stopping_metric (name of the
    metric used for early stopping instead of the validation loss), monitor_tensors,
    prediction_tensors and parameter_tensors (dicts of name -> tensor).

    Args:
        reader: Class with attributes train_batch_generator, val_batch_generator, and test_batch_generator
            that yield dictionaries mapping tf.placeholder names (as strings) to batch data (numpy arrays).
        batch_sizes: List of minibatch sizes, one per training phase (initial run plus each restart).
        num_training_steps: Training stops once this many steps have been taken.
        learning_rates: List of learning rates, one per training phase.
        beta1_decays: List of decay values, one per training phase. Passed as beta1 to Adam and
            as decay to RMSProp.
        optimizer: 'rms' for RMSProp, 'adam' for Adam, 'gd' for gradient descent.
        grad_clip: Gradients are clipped elementwise to the range [-grad_clip, grad_clip].
        regularization_constant: Multiplier for the sum of L2 norms of all trainable parameters,
            which is added to the loss when nonzero.
        keep_prob: 1 - p, where p is the dropout probability
        patiences: List of early stopping patiences, one per training phase: the number of steps to
            continue training after the validation metric has stopped decreasing.
        warm_start_init_step: If nonzero, model will resume training a restored model beginning
            at warm_start_init_step.
        enable_parameter_averaging: If true, model saves exponential weighted averages of parameters
            to separate checkpoint file.
        min_steps_to_checkpoint: Model only saves after min_steps_to_checkpoint training steps
            have passed.
        log_interval: Train and validation losses are logged every log_interval training steps.
        logging_level: Level passed to logging.basicConfig.
        loss_averaging_window: Train/validation losses are averaged over the last loss_averaging_window
            training steps.
        validation_batch_size: Batch size requested from the validation batch generator.
        log_dir: Directory where logs are written.
        checkpoint_dir: Directory where checkpoints are saved.
        prediction_dir: Directory where predictions/outputs are saved.

    The number of restarts is len(batch_sizes) - 1: each time the patience of the current phase
    runs out, the best checkpoint is restored and training continues with the next entry of
    batch_sizes / learning_rates / beta1_decays / patiences.
    """

    def __init__(
        self,
        reader=None,
        batch_sizes=[128],
        num_training_steps=20000,
        learning_rates=[.01],
        beta1_decays=[.99],
        optimizer='adam',
        grad_clip=5,
        regularization_constant=0.0,
        keep_prob=1.0,
        patiences=[3000],
        warm_start_init_step=0,
        enable_parameter_averaging=False,
        min_steps_to_checkpoint=100,
        log_interval=20,
        logging_level=logging.INFO,
        loss_averaging_window=100,
        validation_batch_size=64,
        log_dir='logs',
        checkpoint_dir='checkpoints',
        prediction_dir='predictions',
    ):

        assert len(batch_sizes) == len(learning_rates) == len(patiences)
        self.batch_sizes = batch_sizes
        self.learning_rates = learning_rates
        self.beta1_decays = beta1_decays
        self.patiences = patiences
        self.num_restarts = len(batch_sizes) - 1
        self.restart_idx = 0
        self.update_train_params()

        self.reader = reader
        self.num_training_steps = num_training_steps
        self.optimizer = optimizer
        self.grad_clip = grad_clip
        self.regularization_constant = regularization_constant
        self.warm_start_init_step = warm_start_init_step
        self.keep_prob_scalar = keep_prob
        self.enable_parameter_averaging = enable_parameter_averaging
        self.min_steps_to_checkpoint = min_steps_to_checkpoint
        self.log_interval = log_interval
        self.loss_averaging_window = loss_averaging_window
        self.validation_batch_size = validation_batch_size

        self.log_dir = log_dir
        self.logging_level = logging_level
        self.prediction_dir = prediction_dir
        self.checkpoint_dir = checkpoint_dir
        if self.enable_parameter_averaging:
            self.checkpoint_dir_averaged = checkpoint_dir + '_avg'

        self.init_logging(self.log_dir)
        logging.info('\nnew run with parameters:\n{}'.format(pp.pformat(self.__dict__)))

        self.graph = self.build_graph()
        self.session = tf.Session(graph=self.graph)
        logging.info('built graph')

    def update_train_params(self):
        """Select the hyperparameters of the current training phase (indexed by restart_idx)."""
        self.batch_size = self.batch_sizes[self.restart_idx]
        self.learning_rate = self.learning_rates[self.restart_idx]
        self.beta1_decay = self.beta1_decays[self.restart_idx]
        self.early_stopping_steps = self.patiences[self.restart_idx]

    def calculate_loss(self):
        """Build the graph and return the batch loss tensor; must be overridden."""
        raise NotImplementedError('subclass must implement this')

    def _build_feed_dict(self, batch, keep_prob, is_training):
        """Map a batch dict onto same-named model attributes and add dropout/training flags."""
        feed_dict = {
            getattr(self, placeholder_name, None): data
            for placeholder_name, data in batch.items() if hasattr(self, placeholder_name)
        }
        if hasattr(self, 'keep_prob'):
            feed_dict[self.keep_prob] = keep_prob
        if hasattr(self, 'is_training'):
            feed_dict[self.is_training] = is_training
        return feed_dict

    def fit(self):
        """Run the training loop with periodic logging, checkpointing and early stopping."""
        with self.session.as_default():

            if self.warm_start_init_step:
                self.restore(self.warm_start_init_step)
                step = self.warm_start_init_step
            else:
                self.session.run(self.init)
                step = 0

            train_generator = self.reader.train_batch_generator(self.batch_size)
            val_generator = self.reader.val_batch_generator(self.validation_batch_size)

            train_loss_history = deque(maxlen=self.loss_averaging_window)
            val_loss_history = deque(maxlen=self.loss_averaging_window)
            train_time_history = deque(maxlen=self.loss_averaging_window)
            val_time_history = deque(maxlen=self.loss_averaging_window)
            if not hasattr(self, 'metrics'):
                self.metrics = {}

            metric_histories = {
                metric_name: deque(maxlen=self.loss_averaging_window) for metric_name in self.metrics
            }
            best_validation_loss, best_validation_tstep = float('inf'), 0

            # early_stopping_metric is optional and only set by subclasses; without it
            # the averaged validation loss drives early stopping.
            early_stopping_metric_name = getattr(self, 'early_stopping_metric', None)
            optimizer_feed = lambda: {
                self.learning_rate_var: self.learning_rate,
                self.beta1_decay_var: self.beta1_decay,
            }

            while step < self.num_training_steps:

                # validation evaluation
                val_start = time.time()
                val_batch_df = next(val_generator)
                val_feed_dict = self._build_feed_dict(val_batch_df, keep_prob=1.0, is_training=False)
                val_feed_dict.update(optimizer_feed())

                # dict.values() is a view on Python 3 and cannot be added to a list directly.
                results = self.session.run(
                    fetches=[self.loss] + list(self.metrics.values()),
                    feed_dict=val_feed_dict
                )
                val_loss = results[0]
                val_metrics = results[1:] if len(results) > 1 else []
                val_metrics = dict(zip(self.metrics.keys(), val_metrics))
                val_loss_history.append(val_loss)
                val_time_history.append(time.time() - val_start)
                for key in val_metrics:
                    metric_histories[key].append(val_metrics[key])

                if hasattr(self, 'monitor_tensors'):
                    for name, tensor in self.monitor_tensors.items():
                        [np_val] = self.session.run([tensor], feed_dict=val_feed_dict)
                        print(name)
                        print('min', np_val.min())
                        print('max', np_val.max())
                        print('mean', np_val.mean())
                        print('std', np_val.std())
                        print('nans', np.isnan(np_val).sum())
                        print()
                    print()
                    print()

                # train step
                train_start = time.time()
                train_batch_df = next(train_generator)
                train_feed_dict = self._build_feed_dict(
                    train_batch_df, keep_prob=self.keep_prob_scalar, is_training=True)
                train_feed_dict.update(optimizer_feed())

                train_loss, _ = self.session.run(
                    fetches=[self.loss, self.step],
                    feed_dict=train_feed_dict
                )
                train_loss_history.append(train_loss)
                train_time_history.append(time.time() - train_start)

                if step % self.log_interval == 0:
                    avg_train_loss = sum(train_loss_history) / len(train_loss_history)
                    avg_val_loss = sum(val_loss_history) / len(val_loss_history)
                    avg_train_time = sum(train_time_history) / len(train_time_history)
                    avg_val_time = sum(val_time_history) / len(val_time_history)
                    metric_log = (
                        "[[step {:>8}]]     "
                        "[[train {:>4}s]]     loss: {:<12}     "
                        "[[val {:>4}s]]     loss: {:<12}     "
                    ).format(
                        step,
                        round(avg_train_time, 4),
                        round(avg_train_loss, 8),
                        round(avg_val_time, 4),
                        round(avg_val_loss, 8),
                    )
                    early_stopping_metric = avg_val_loss
                    for metric_name, metric_history in metric_histories.items():
                        metric_val = sum(metric_history) / len(metric_history)
                        metric_log += '{}: {:<4}     '.format(metric_name, round(metric_val, 4))
                        if metric_name == early_stopping_metric_name:
                            early_stopping_metric = metric_val

                    logging.info(metric_log)

                    if early_stopping_metric < best_validation_loss:
                        best_validation_loss = early_stopping_metric
                        best_validation_tstep = step
                        if step > self.min_steps_to_checkpoint:
                            self.save(step)
                            if self.enable_parameter_averaging:
                                self.save(step, averaged=True)

                    if step - best_validation_tstep > self.early_stopping_steps:

                        if self.num_restarts is None or self.restart_idx >= self.num_restarts:
                            logging.info('best validation loss of {} at training step {}'.format(
                                best_validation_loss, best_validation_tstep))
                            logging.info('early stopping - ending training.')
                            return

                        if self.restart_idx < self.num_restarts:
                            # Roll back to the best checkpoint and move to the next phase.
                            self.restore(best_validation_tstep)
                            step = best_validation_tstep
                            self.restart_idx += 1
                            self.update_train_params()
                            train_generator = self.reader.train_batch_generator(self.batch_size)

                step += 1

            # Training was too short to ever checkpoint: save the final parameters.
            if step <= self.min_steps_to_checkpoint:
                best_validation_tstep = step
                self.save(step)
                if self.enable_parameter_averaging:
                    self.save(step, averaged=True)

            logging.info('num_training_steps reached - ending training')

    def predict(self, chunk_size=256):
        """Evaluate prediction_tensors on the test set and dump them (and parameter_tensors) as .npy."""
        if not os.path.isdir(self.prediction_dir):
            os.makedirs(self.prediction_dir)

        if hasattr(self, 'prediction_tensors'):
            prediction_dict = {tensor_name: [] for tensor_name in self.prediction_tensors}

            test_generator = self.reader.test_batch_generator(chunk_size)
            for i, test_batch_df in enumerate(test_generator):
                if i % 10 == 0:
                    print(i*len(test_batch_df))

                test_feed_dict = self._build_feed_dict(test_batch_df, keep_prob=1.0, is_training=False)

                tensor_names, tf_tensors = zip(*self.prediction_tensors.items())
                np_tensors = self.session.run(
                    fetches=tf_tensors,
                    feed_dict=test_feed_dict
                )
                for tensor_name, tensor in zip(tensor_names, np_tensors):
                    prediction_dict[tensor_name].append(tensor)

            for tensor_name, tensor in prediction_dict.items():
                np_tensor = np.concatenate(tensor, 0)
                save_file = os.path.join(self.prediction_dir, '{}.npy'.format(tensor_name))
                logging.info('saving {} with shape {} to {}'.format(tensor_name, np_tensor.shape, save_file))
                np.save(save_file, np_tensor)

        if hasattr(self, 'parameter_tensors'):
            for tensor_name, tensor in self.parameter_tensors.items():
                np_tensor = tensor.eval(self.session)

                save_file = os.path.join(self.prediction_dir, '{}.npy'.format(tensor_name))
                logging.info('saving {} with shape {} to {}'.format(tensor_name, np_tensor.shape, save_file))
                np.save(save_file, np_tensor)

    def save(self, step, averaged=False):
        """Write a checkpoint named 'model-<step>' to the (averaged) checkpoint directory."""
        saver = self.saver_averaged if averaged else self.saver
        checkpoint_dir = self.checkpoint_dir_averaged if averaged else self.checkpoint_dir
        if not os.path.isdir(checkpoint_dir):
            logging.info('creating checkpoint directory {}'.format(checkpoint_dir))
            # makedirs so that nested checkpoint paths work as well.
            os.makedirs(checkpoint_dir)

        model_path = os.path.join(checkpoint_dir, 'model')
        logging.info('saving model to {}'.format(model_path))
        saver.save(self.session, model_path, global_step=step)

    def restore(self, step=None, averaged=False):
        """Restore the latest checkpoint, or the one written by save() at `step` if given."""
        saver = self.saver_averaged if averaged else self.saver
        checkpoint_dir = self.checkpoint_dir_averaged if averaged else self.checkpoint_dir
        if not step:
            model_path = tf.train.latest_checkpoint(checkpoint_dir)
            logging.info('restoring model parameters from {}'.format(model_path))
            saver.restore(self.session, model_path)
        else:
            # save() always uses the prefix 'model' (the averaged variant only differs
            # in its directory), so the same prefix must be used here.
            model_path = os.path.join(checkpoint_dir, 'model-{}'.format(step))
            logging.info('restoring model from {}'.format(model_path))
            saver.restore(self.session, model_path)

    def init_logging(self, log_dir):
        """Configure logging to a timestamped file in log_dir and to the console."""
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        date_str = datetime.now().strftime('%Y-%m-%d_%H-%M')
        log_file = 'log_{}.txt'.format(date_str)

        # `reload` is not a builtin on Python 3, so the NameError branch simply
        # (re)imports logging there.
        try:  # Python 2
            reload(logging)  # bad
        except NameError:  # Python 3
            import logging
        logging.basicConfig(
            filename=os.path.join(log_dir, log_file),
            level=self.logging_level,
            format='[[%(asctime)s]] %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S %p'
        )
        logging.getLogger().addHandler(logging.StreamHandler())

    def update_parameters(self, loss):
        """Create the training op (self.step) for `loss`, with clipping and optional averaging."""
        if self.regularization_constant != 0:
            l2_norm = tf.reduce_sum([tf.sqrt(tf.reduce_sum(tf.square(param))) for param in tf.trainable_variables()])
            loss = loss + self.regularization_constant*l2_norm

        optimizer = self.get_optimizer(self.learning_rate_var, self.beta1_decay_var)
        grads = optimizer.compute_gradients(loss)
        # compute_gradients yields None for variables the loss does not depend on;
        # those cannot be clipped and are skipped.
        clipped = [
            (tf.clip_by_value(g, -self.grad_clip, self.grad_clip), v_)
            for g, v_ in grads if g is not None
        ]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            step = optimizer.apply_gradients(clipped, global_step=self.global_step)

        if self.enable_parameter_averaging:
            maintain_averages_op = self.ema.apply(tf.trainable_variables())
            with tf.control_dependencies([step]):
                self.step = tf.group(maintain_averages_op)
        else:
            self.step = step

        logging.info('all parameters:')
        logging.info(pp.pformat([(var.name, shape(var)) for var in tf.global_variables()]))

        logging.info('trainable parameters:')
        logging.info(pp.pformat([(var.name, shape(var)) for var in tf.trainable_variables()]))

        logging.info('trainable parameter count:')
        # Builtin sum: calling np.sum on a generator is deprecated.
        logging.info(str(sum(int(np.prod(shape(var))) for var in tf.trainable_variables())))

    def get_optimizer(self, learning_rate, beta1_decay):
        """Return the optimizer selected by self.optimizer ('adam', 'gd' or 'rms')."""
        if self.optimizer == 'adam':
            return tf.train.AdamOptimizer(learning_rate, beta1=beta1_decay)
        elif self.optimizer == 'gd':
            return tf.train.GradientDescentOptimizer(learning_rate)
        elif self.optimizer == 'rms':
            return tf.train.RMSPropOptimizer(learning_rate, decay=beta1_decay, momentum=0.9)
        else:
            # Raised explicitly (same exception type as before) so the check also
            # survives running Python with -O.
            raise AssertionError('optimizer must be adam, gd, or rms')

    def build_graph(self):
        """Build the full training graph, savers and initializer; return the tf.Graph."""
        with tf.Graph().as_default() as graph:
            self.ema = tf.train.ExponentialMovingAverage(decay=0.99)
            self.global_step = tf.Variable(0, trainable=False)
            self.learning_rate_var = tf.Variable(0.0, trainable=False)
            self.beta1_decay_var = tf.Variable(0.0, trainable=False)

            self.loss = self.calculate_loss()
            self.update_parameters(self.loss)

            self.saver = tf.train.Saver(max_to_keep=1)
            if self.enable_parameter_averaging:
                self.saver_averaged = tf.train.Saver(self.ema.variables_to_restore(), max_to_keep=1)

            self.init = tf.global_variables_initializer()
            return graph
|
tf_utils.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow.compat.v1 as tf
|
2 |
+
tf.disable_v2_behavior()
|
3 |
+
|
4 |
+
|
5 |
+
def dense_layer(inputs, output_units, bias=True, activation=None, batch_norm=None,
                dropout=None, scope='dense-layer', reuse=False):
    """
    Applies a dense layer to a 2D tensor of shape [batch_size, input_units]
    to produce a tensor of shape [batch_size, output_units].

    Args:
        inputs: Tensor of shape [batch size, input_units].
        output_units: Number of output units.
        bias: whether a bias vector is added after the matrix product.
        activation: activation function.
        batch_norm: when not None, batch normalization is applied and this value
            is passed as its `training` argument.
        dropout: dropout keep prob.
        scope: variable scope holding the 'weights' and 'biases' variables.
        reuse: passed to the variable scope and to batch normalization.
    Returns:
        Tensor of shape [batch size, output_units].
    """
    with tf.variable_scope(scope, reuse=reuse):
        weights = tf.get_variable(
            name='weights',
            initializer=tf.compat.v1.variance_scaling_initializer(),
            shape=[shape(inputs, -1), output_units]
        )
        out = tf.matmul(inputs, weights)

        if bias:
            biases = tf.get_variable(
                name='biases',
                initializer=tf.constant_initializer(),
                shape=[output_units]
            )
            out = out + biases

        if batch_norm is not None:
            out = tf.layers.batch_normalization(out, training=batch_norm, reuse=reuse)

        if activation:
            out = activation(out)
        if dropout is not None:
            out = tf.nn.dropout(out, dropout)
        return out
|
39 |
+
|
40 |
+
|
41 |
+
def time_distributed_dense_layer(
        inputs, output_units, bias=True, activation=None, batch_norm=None,
        dropout=None, scope='time-distributed-dense-layer', reuse=False):
    """
    Applies a shared dense layer to each timestep of a tensor of shape
    [batch_size, max_seq_len, input_units] to produce a tensor of shape
    [batch_size, max_seq_len, output_units].

    Args:
        inputs: Tensor of shape [batch size, max sequence length, ...].
        output_units: Number of output units.
        bias: whether a bias vector is added after the matrix product.
        activation: activation function.
        batch_norm: when not None, batch normalization is applied and this value
            is passed as its `training` argument.
        dropout: dropout keep prob.
        scope: variable scope holding the 'weights' and 'biases' variables.
        reuse: passed to the variable scope and to batch normalization.

    Returns:
        Tensor of shape [batch size, max sequence length, output_units].
    """
    with tf.variable_scope(scope, reuse=reuse):
        weights = tf.get_variable(
            name='weights',
            initializer=tf.compat.v1.variance_scaling_initializer(),
            shape=[shape(inputs, -1), output_units]
        )
        # Contract the last axis of inputs with the weight matrix at every timestep.
        out = tf.einsum('ijk,kl->ijl', inputs, weights)

        if bias:
            biases = tf.get_variable(
                name='biases',
                initializer=tf.constant_initializer(),
                shape=[output_units]
            )
            out = out + biases

        if batch_norm is not None:
            out = tf.layers.batch_normalization(out, training=batch_norm, reuse=reuse)

        if activation:
            out = activation(out)
        if dropout is not None:
            out = tf.nn.dropout(out, dropout)
        return out
|
79 |
+
|
80 |
+
|
81 |
+
def shape(tensor, dim=None):
    """Return the static shape of `tensor` as a list, or only entry `dim` of it when given."""
    dims = tensor.shape.as_list()
    return dims if dim is None else dims[dim]
|
87 |
+
|
88 |
+
|
89 |
+
def rank(tensor):
    """Return the rank of `tensor`, i.e. the length of its static shape list, as an int."""
    dims = tensor.shape.as_list()
    return len(dims)
|