crackalamoo commited on
Commit
ceea12e
1 Parent(s): 1427339

Upload new model.py

Browse files
Files changed (1) hide show
  1. model.py +24 -24
model.py CHANGED
@@ -314,35 +314,37 @@ if __name__ == '__main__':
314
  't': 'inputs/transformer_train.npz',
315
  'b': 'inputs/bard_train.npz'
316
  }[MODEL_TYPE]
317
- print("Loading data from", fname)
318
- loaded = np.load(fname)
319
- train_x = loaded['x']
320
- train_y = loaded['y']
321
- if MODEL_TYPE == 'b':
322
- train_x = [tf.convert_to_tensor(train_x), tf.convert_to_tensor(loaded['rm'])] # rhyme and syllables
323
- if MODEL_TYPE == 'n':
324
- train_x = tf.convert_to_tensor(train_x, tf.int32)
325
- del loaded
 
326
 
327
- if TRAINING and VERBOSE:
328
- if MODEL_TYPE != 'b':
329
- print("X:", train_x[10:14])
330
- else:
331
- print("X:", train_x[0][10:14])
332
- print("RM:", train_x[1][10:14][1])
333
- print("Y:", train_y[10:14])
334
- if MODEL_TYPE != 'b':
335
- print("X shape:", train_x.shape)
336
- print("Y shape:", train_y.shape)
337
 
338
  print("Initializing model")
339
  models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
340
  model = models[MODEL_TYPE]()
341
  if MODEL_TYPE != 'b':
342
- res = model(train_x[:1])
 
343
  else:
344
- x0 = train_x[0][:1]
345
- x1 = train_x[1][:1]
346
  res = model([x0, x1])
347
  if VERBOSE:
348
  print(model)
@@ -391,8 +393,6 @@ if __name__ == '__main__':
391
  print(pretty_tokens(genTokens(model, 500)))
392
 
393
  else:
394
- del train_x
395
- del train_y
396
  print("Loading weights")
397
  model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')
398
 
 
314
  't': 'inputs/transformer_train.npz',
315
  'b': 'inputs/bard_train.npz'
316
  }[MODEL_TYPE]
317
+ if TRAINING:
318
+ print("Loading data from", fname)
319
+ loaded = np.load(fname)
320
+ train_x = loaded['x']
321
+ train_y = loaded['y']
322
+ if MODEL_TYPE == 'b':
323
+ train_x = [tf.convert_to_tensor(train_x), tf.convert_to_tensor(loaded['rm'])] # rhyme and syllables
324
+ if MODEL_TYPE == 'n':
325
+ train_x = tf.convert_to_tensor(train_x, tf.int32)
326
+ del loaded
327
 
328
+ if VERBOSE:
329
+ if MODEL_TYPE != 'b':
330
+ print("X:", train_x[10:14])
331
+ else:
332
+ print("X:", train_x[0][10:14])
333
+ print("RM:", train_x[1][10:14][1])
334
+ print("Y:", train_y[10:14])
335
+ if MODEL_TYPE != 'b':
336
+ print("X shape:", train_x.shape)
337
+ print("Y shape:", train_y.shape)
338
 
339
  print("Initializing model")
340
  models = {'n': LinearModel, 't': TransformerModel, 'b': BardModel}
341
  model = models[MODEL_TYPE]()
342
  if MODEL_TYPE != 'b':
343
+ x0 = np.zeros((1,NGRAM_N-1 if MODEL_TYPE=='n' else TRANSFORMER_N))
344
+ res = model(x0)
345
  else:
346
+ x0 = np.zeros((1,TRANSFORMER_N))
347
+ x1 = np.zeros((1,TRANSFORMER_N,RHYME_STACK_SIZE*2+METER_STACK_SIZE))
348
  res = model([x0, x1])
349
  if VERBOSE:
350
  print(model)
 
393
  print(pretty_tokens(genTokens(model, 500)))
394
 
395
  else:
 
 
396
  print("Loading weights")
397
  model.load_weights('saved_models/'+MODEL_TYPE+'_model.h5')
398