import os
import random

import numpy as np
import torch

from config.GlobalVariables import *

# Fixed seeds so batch sampling is reproducible (next_batch draws from both
# numpy.random and Python's random module).
np.random.seed(0)
random.seed(0)

class DataLoader():
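	"""Batches hierarchical handwriting data: sentence-, word- and segment-level
	stroke sequences with character and terminal-flag labels, stored as one
	directory of numbered .npy samples per writer under `datadir`."""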
	def __init__(self, num_writer=2, num_samples=5, divider=10.0, datadir='./data/writers'):
		self.device			= 'cuda' if torch.cuda.is_available() else 'cpu'
		self.num_writer		= num_writer	# writers per batch
		self.num_samples	= num_samples	# texts sampled per writer
		self.divider		= divider		# scale factor for the stroke (x, y) offsets
		self.datadir		= datadir
		print('self.datadir:', self.datadir)
		self.total_writers	= len(os.listdir(datadir))

	def next_batch(self, TYPE='TRAIN', uid=-1, tids=None):
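		"""Assemble one batch.

		Draws `num_writer` writers (random unless `uid` >= 0; TYPE='TRAIN' uses
		the training writers, anything else the validation writers 150-169) and
		`num_samples` texts per writer (random unless `tids` is given). Returns
		an 18-element list: stroke_in, stroke_out, stroke_length, term, char and
		char_length, repeated for the sentence, word and segment levels.
		"""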
		all_sentence_level_stroke_in		= []
		all_sentence_level_stroke_out		= []
		all_sentence_level_stroke_length	= []
		all_sentence_level_term				= []
		all_sentence_level_char				= []
		all_sentence_level_char_length		= []
		all_word_level_stroke_in			= []
		all_word_level_stroke_out			= []
		all_word_level_stroke_length		= []
		all_word_level_term					= []
		all_word_level_char					= []
		all_word_level_char_length			= []
		all_segment_level_stroke_in			= []
		all_segment_level_stroke_out		= []
		all_segment_level_stroke_length		= []
		all_segment_level_term				= []
		all_segment_level_char				= []
		all_segment_level_char_length		= []

		# Note whether the caller pinned a writer / text ids; otherwise a fresh
		# writer and fresh texts are drawn on every pass, so the batch is not
		# filled with copies of the first random choice.
		fixed_uid	= uid >= 0
		fixed_tids	= tids is not None and len(tids) > 0

		while len(all_sentence_level_stroke_in) < self.num_writer:
			if not fixed_uid:
				if TYPE == 'TRAIN':
					if self.datadir in ('./data/NEW_writers', './data/writers'):
						uid = np.random.randint(150)
					elif self.device == 'cpu':
						uid = np.random.randint(20)
					else:
						uid = np.random.randint(294)
				else:
					uid = np.random.randint(150, 170)

			total_texts = len(os.listdir(self.datadir + '/' + str(uid)))
			if not fixed_tids:
				tids = random.sample(range(total_texts), self.num_samples)

			user_sentence_level_stroke_in		= []
			user_sentence_level_stroke_out		= []
			user_sentence_level_stroke_length	= []
			user_sentence_level_term			= []
			user_sentence_level_char			= []
			user_sentence_level_char_length		= []
			user_word_level_stroke_in			= []
			user_word_level_stroke_out			= []
			user_word_level_stroke_length		= []
			user_word_level_term				= []
			user_word_level_char				= []
			user_word_level_char_length			= []
			user_segment_level_stroke_in		= []
			user_segment_level_stroke_out		= []
			user_segment_level_stroke_length	= []
			user_segment_level_term				= []
			user_segment_level_char				= []
			user_segment_level_char_length		= []

			for tid in tids:
				# Each sample file stores parallel sequences at sentence, word and
				# segment granularity; the field order differs between dataset layouts.
				if self.datadir == './data/NEW_writers':
					[sentence_level_raw_stroke, sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_term, sentence_level_char,
					word_level_raw_stroke, word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char,
					segment_level_raw_stroke, segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char] = \
						np.load(self.datadir+'/'+str(uid)+'/'+str(tid)+'.npy', allow_pickle=True, encoding='bytes')

				elif self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
					[sentence_level_raw_stroke, sentence_level_char, sentence_level_term, sentence_level_stroke_in, sentence_level_stroke_out,
					word_level_raw_stroke, word_level_char, word_level_term, word_level_stroke_in, word_level_stroke_out,
					segment_level_raw_stroke, segment_level_char, segment_level_term, segment_level_stroke_in, segment_level_stroke_out, _] = \
						np.load(self.datadir+'/'+str(uid)+'/'+str(tid)+'.npy', allow_pickle=True, encoding='bytes')

				else:
					[sentence_level_raw_stroke, sentence_level_stroke_in, sentence_level_stroke_out, sentence_level_term, sentence_level_char,
					word_level_raw_stroke, word_level_stroke_in, word_level_stroke_out, word_level_term, word_level_char,
					segment_level_raw_stroke, segment_level_stroke_in, segment_level_stroke_out, segment_level_term, segment_level_char, _] = \
						np.load(self.datadir+'/'+str(uid)+'/'+str(tid)+'.npy', allow_pickle=True, encoding='bytes')

				# The DW-style layouts carry one extra leading char/term entry; skip it.
				if self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
					sentence_level_char	= sentence_level_char[1:]
					sentence_level_term	= sentence_level_term[1:]

				# Trim trailing points until the sequence ends on a terminal flag
				# (term == 1); the same clean-up is repeated at word and segment level.
				while len(sentence_level_term) > 0 and sentence_level_term[-1] != 1.0:
					sentence_level_raw_stroke = sentence_level_raw_stroke[:-1]
					sentence_level_char = sentence_level_char[:-1]
					sentence_level_term = sentence_level_term[:-1]
					sentence_level_stroke_in = sentence_level_stroke_in[:-1]
					sentence_level_stroke_out = sentence_level_stroke_out[:-1]

				# Characters aligned with a terminal flag form the label sequence.
				tmp = []
				for i, t in enumerate(sentence_level_term):
					if t == 1:
						tmp.append(sentence_level_char[i])

				# Scale matrix: the first two columns (the point offsets) are divided
				# by `divider`; any remaining columns (pen state) are left at 1.
				a = np.ones_like(sentence_level_stroke_in)
				a[:,:2] /= self.divider

				# Keep the sample only if strokes and term flags align and at least
				# one character survived trimming.
				if len(sentence_level_stroke_in) == len(sentence_level_term) and len(tmp) > 0 and len(sentence_level_stroke_in) > 0:
					user_sentence_level_stroke_in.append(np.asarray(sentence_level_stroke_in) * a)
					user_sentence_level_stroke_out.append(np.asarray(sentence_level_stroke_out) * a)
					user_sentence_level_stroke_length.append(len(sentence_level_stroke_in))
					user_sentence_level_char.append(np.asarray(tmp))
					user_sentence_level_term.append(np.asarray(sentence_level_term))
					user_sentence_level_char_length.append(len(tmp))

				for wid in range(len(word_level_stroke_in)):
					each_word_level_stroke_in		= word_level_stroke_in[wid]
					each_word_level_stroke_out		= word_level_stroke_out[wid]

					if self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
						each_word_level_term			= word_level_term[wid][1:]
						each_word_level_char			= word_level_char[wid][1:]
					else:
						each_word_level_term			= word_level_term[wid]
						each_word_level_char			= word_level_char[wid]

					# Trim trailing points until the word ends on a terminal flag.
					while len(each_word_level_term) > 0 and each_word_level_term[-1] != 1.0:
						each_word_level_char = each_word_level_char[:-1]
						each_word_level_term = each_word_level_term[:-1]
						each_word_level_stroke_in = each_word_level_stroke_in[:-1]
						each_word_level_stroke_out = each_word_level_stroke_out[:-1]

					tmp = []
					for i, t in enumerate(each_word_level_term):
						if t == 1:
							tmp.append(each_word_level_char[i])

					b = np.ones_like(each_word_level_stroke_in)
					b[:,:2] /= self.divider

					if len(each_word_level_stroke_in) == len(each_word_level_term) and len(tmp) > 0 and len(each_word_level_stroke_in) > 0:
						user_word_level_stroke_in.append(np.asarray(each_word_level_stroke_in) * b)
						user_word_level_stroke_out.append(np.asarray(each_word_level_stroke_out) * b)
						user_word_level_stroke_length.append(len(each_word_level_stroke_in))
						user_word_level_char.append(np.asarray(tmp))
						user_word_level_term.append(np.asarray(each_word_level_term))
						user_word_level_char_length.append(len(tmp))

					segment_level_stroke_in_list		= []
					segment_level_stroke_out_list		= []
					segment_level_stroke_length_list	= []
					segment_level_char_list				= []
					segment_level_term_list				= []
					segment_level_char_length_list		= []

					for sid in range(len(segment_level_stroke_in[wid])):
						each_segment_level_stroke_in	= segment_level_stroke_in[wid][sid]
						each_segment_level_stroke_out	= segment_level_stroke_out[wid][sid]

						if self.datadir in ('./data/DW_writers', './data/VALID_DW_writers'):
							each_segment_level_term			= segment_level_term[wid][sid][1:]
							each_segment_level_char			= segment_level_char[wid][sid][1:]
						else:
							each_segment_level_term			= segment_level_term[wid][sid]
							each_segment_level_char			= segment_level_char[wid][sid]

						while len(each_segment_level_term) > 0 and each_segment_level_term[-1] != 1.0:
							each_segment_level_char = each_segment_level_char[:-1]
							each_segment_level_term = each_segment_level_term[:-1]
							each_segment_level_stroke_in = each_segment_level_stroke_in[:-1]
							each_segment_level_stroke_out = each_segment_level_stroke_out[:-1]

						tmp = []
						for i, t in enumerate(each_segment_level_term):
							if t == 1:
								tmp.append(each_segment_level_char[i])

						c = np.ones_like(each_segment_level_stroke_in)
						c[:,:2] /= self.divider

						if len(each_segment_level_stroke_in) == len(each_segment_level_term) and len(tmp) > 0 and len(each_segment_level_stroke_in) > 0:
							segment_level_stroke_in_list.append(np.asarray(each_segment_level_stroke_in) * c)
							segment_level_stroke_out_list.append(np.asarray(each_segment_level_stroke_out) * c)
							segment_level_stroke_length_list.append(len(each_segment_level_stroke_in))
							segment_level_char_list.append(np.asarray(tmp))
							segment_level_term_list.append(np.asarray(each_segment_level_term))
							segment_level_char_length_list.append(len(tmp))

					# Pad this word's segments to the longest one so they stack.
					if len(segment_level_stroke_length_list) > 0:
						SEGMENT_MAX_STROKE_LENGTH		= np.max(segment_level_stroke_length_list)
						SEGMENT_MAX_CHARACTER_LENGTH	= np.max(segment_level_char_length_list)

						new_segment_level_stroke_in_list 	= np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH-len(a)), (0, 0)), 'constant') for a in segment_level_stroke_in_list])
						new_segment_level_stroke_out_list 	= np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH-len(a)), (0, 0)), 'constant') for a in segment_level_stroke_out_list])
						new_segment_level_term_list 		= np.asarray([np.pad(a, ((0, SEGMENT_MAX_STROKE_LENGTH-len(a))), 'constant') for a in segment_level_term_list])
						new_segment_level_char_list 		= np.asarray([np.pad(a, ((0, SEGMENT_MAX_CHARACTER_LENGTH-len(a))), 'constant') for a in segment_level_char_list])

						user_segment_level_stroke_in.append(new_segment_level_stroke_in_list)
						user_segment_level_stroke_out.append(new_segment_level_stroke_out_list)
						user_segment_level_stroke_length.append(segment_level_stroke_length_list)
						user_segment_level_char.append(new_segment_level_char_list)
						user_segment_level_term.append(new_segment_level_term_list)
						user_segment_level_char_length.append(segment_level_char_length_list)

			# Pad this writer's sentence- and word-level samples to the longest
			# sample so they stack into rectangular arrays.
			WORD_MAX_STROKE_LENGTH			= np.max(user_word_level_stroke_length)
			WORD_MAX_CHARACTER_LENGTH		= np.max(user_word_level_char_length)
			SENTENCE_MAX_STROKE_LENGTH		= np.max(user_sentence_level_stroke_length)
			SENTENCE_MAX_CHARACTER_LENGTH	= np.max(user_sentence_level_char_length)

			new_sentence_level_stroke_in	= np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH-len(a)), (0,0)), 'constant') for a in user_sentence_level_stroke_in])
			new_sentence_level_stroke_out	= np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH-len(a)), (0,0)), 'constant') for a in user_sentence_level_stroke_out])
			new_sentence_level_term			= np.asarray([np.pad(a, ((0, SENTENCE_MAX_STROKE_LENGTH-len(a))), 'constant') for a in user_sentence_level_term])
			new_sentence_level_char			= np.asarray([np.pad(a, ((0, SENTENCE_MAX_CHARACTER_LENGTH-len(a))), 'constant') for a in user_sentence_level_char])
			new_word_level_stroke_in		= np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH-len(a)), (0,0)), 'constant') for a in user_word_level_stroke_in])
			new_word_level_stroke_out		= np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH-len(a)), (0,0)), 'constant') for a in user_word_level_stroke_out])
			new_word_level_term				= np.asarray([np.pad(a, ((0, WORD_MAX_STROKE_LENGTH-len(a))), 'constant') for a in user_word_level_term])
			new_word_level_char				= np.asarray([np.pad(a, ((0, WORD_MAX_CHARACTER_LENGTH-len(a))), 'constant') for a in user_word_level_char])

			all_sentence_level_stroke_in.append(new_sentence_level_stroke_in)
			all_sentence_level_stroke_out.append(new_sentence_level_stroke_out)
			all_sentence_level_stroke_length.append(user_sentence_level_stroke_length)
			all_sentence_level_term.append(new_sentence_level_term)
			all_sentence_level_char.append(new_sentence_level_char)
			all_sentence_level_char_length.append(user_sentence_level_char_length)
			all_word_level_stroke_in.append(new_word_level_stroke_in)
			all_word_level_stroke_out.append(new_word_level_stroke_out)
			all_word_level_stroke_length.append(user_word_level_stroke_length)
			all_word_level_term.append(new_word_level_term)
			all_word_level_char.append(new_word_level_char)
			all_word_level_char_length.append(user_word_level_char_length)
			all_segment_level_stroke_in.append(user_segment_level_stroke_in)
			all_segment_level_stroke_out.append(user_segment_level_stroke_out)
			all_segment_level_stroke_length.append(user_segment_level_stroke_length)
			all_segment_level_term.append(user_segment_level_term)
			all_segment_level_char.append(user_segment_level_char)
			all_segment_level_char_length.append(user_segment_level_char_length)

		return [all_sentence_level_stroke_in, all_sentence_level_stroke_out, all_sentence_level_stroke_length,
				all_sentence_level_term, all_sentence_level_char, all_sentence_level_char_length,
				all_word_level_stroke_in, all_word_level_stroke_out, all_word_level_stroke_length,
				all_word_level_term, all_word_level_char, all_word_level_char_length,
				all_segment_level_stroke_in, all_segment_level_stroke_out, all_segment_level_stroke_length,
				all_segment_level_term, all_segment_level_char, all_segment_level_char_length]
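
# Minimal usage sketch (assumes ./data/writers exists with one subdirectory per
# writer, each holding numbered .npy samples in the layout expected above; the
# directory name and printed shapes here are illustrative, not part of the loader).
if __name__ == '__main__':
	loader = DataLoader(num_writer=2, num_samples=5, datadir='./data/writers')
	batch = loader.next_batch(TYPE='TRAIN')
	sentence_stroke_in = batch[0]	# one padded array per writer
	print('writers in batch:', len(sentence_stroke_in))
	print('first writer stroke array shape:', sentence_stroke_in[0].shape)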