crackalamoo
commited on
Commit
•
ceea12e
1
Parent(s):
1427339
Upload new model.py
Browse files
model.py
CHANGED
@@ -314,35 +314,37 @@ if __name__ == '__main__':
|
|
314 |
't': 'inputs/transformer_train.npz',
|
315 |
'b': 'inputs/bard_train.npz'
|
316 |
}[MODEL_TYPE]
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
|
|
326 |
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
|
338 |
print("Initializing model")
|
339 |
models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
|
340 |
model = models[MODEL_TYPE]()
|
341 |
if MODEL_TYPE != 'b':
|
342 |
-
|
|
|
343 |
else:
|
344 |
-
x0 =
|
345 |
-
x1 =
|
346 |
res = model([x0, x1])
|
347 |
if VERBOSE:
|
348 |
print(model)
|
@@ -391,8 +393,6 @@ if __name__ == '__main__':
|
|
391 |
print(pretty_tokens(genTokens(model, 500)))
|
392 |
|
393 |
else:
|
394 |
-
del train_x
|
395 |
-
del train_y
|
396 |
print("Loading weights")
|
397 |
model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')
|
398 |
|
|
|
314 |
't': 'inputs/transformer_train.npz',
|
315 |
'b': 'inputs/bard_train.npz'
|
316 |
}[MODEL_TYPE]
|
317 |
+
if TRAINING:
|
318 |
+
print("Loading data from", fname)
|
319 |
+
loaded = np.load(fname)
|
320 |
+
train_x = loaded['x']
|
321 |
+
train_y = loaded['y']
|
322 |
+
if MODEL_TYPE == 'b':
|
323 |
+
train_x = [tf.convert_to_tensor(train_x), tf.convert_to_tensor(loaded['rm'])] # rhyme and syllables
|
324 |
+
if MODEL_TYPE == 'n':
|
325 |
+
train_x = tf.convert_to_tensor(train_x, tf.int32)
|
326 |
+
del loaded
|
327 |
|
328 |
+
if VERBOSE:
|
329 |
+
if MODEL_TYPE != 'b':
|
330 |
+
print("X:", train_x[10:14])
|
331 |
+
else:
|
332 |
+
print("X:", train_x[0][10:14])
|
333 |
+
print("RM:", train_x[1][10:14][1])
|
334 |
+
print("Y:", train_y[10:14])
|
335 |
+
if MODEL_TYPE != 'b':
|
336 |
+
print("X shape:", train_x.shape)
|
337 |
+
print("Y shape:", train_y.shape)
|
338 |
|
339 |
print("Initializing model")
|
340 |
models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
|
341 |
model = models[MODEL_TYPE]()
|
342 |
if MODEL_TYPE != 'b':
|
343 |
+
x0 = np.zeros((1,NGRAM_N-1 if MODEL_TYPE=='n' else TRANSFORMER_N))
|
344 |
+
res = model(x0)
|
345 |
else:
|
346 |
+
x0 = np.zeros((1,TRANSFORMER_N))
|
347 |
+
x1 = np.zeros((1,TRANSFORMER_N,RHYME_STACK_SIZE*2+METER_STACK_SIZE))
|
348 |
res = model([x0, x1])
|
349 |
if VERBOSE:
|
350 |
print(model)
|
|
|
393 |
print(pretty_tokens(genTokens(model, 500)))
|
394 |
|
395 |
else:
|
|
|
|
|
396 |
print("Loading weights")
|
397 |
model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')
|
398 |
|