jefsnacker committed on
Commit
34a8736
1 Parent(s): 70d2f66

better error handling

Browse files
Files changed (1) hide show
  1. app.py +41 -24
app.py CHANGED
@@ -334,56 +334,73 @@ gpt_rev.eval()
334
  ##################################################################################
335
 
336
def generate_names(name_start, name_end, number_of_names, model):
    """Generate names with the selected character-level model.

    Args:
        name_start: Characters the generated names must begin with.
        name_end: Characters the names should end with (only honored by
            models whose config carries "num_final_chars_in_dataset").
        number_of_names: How many names to generate.
        model: Display name of the model chosen in the UI dropdown.

    Returns:
        Newline-separated generated names, or an error message string
        when the input cannot be handled.
    """
    # Resolve config and sampling function together so the model dispatch
    # is written once, not repeated inside the sampling loop below.
    if model == "MLP":
        config = mlp_config
        sample_fcn = mlp.sample_char
    elif model == "WaveNet":
        config = wavenet_config
        sample_fcn = wavenet.sample_char
    elif model == "GPT Micro":
        config = gpt_micro_config
        sample_fcn = gpt_micro.sample_char
    elif model == "GPT Rev":
        config = gpt_rev_config
        sample_fcn = gpt_rev.sample_char
    else:
        # Return a message instead of raising: this runs as a Gradio
        # callback, and an uncaught exception surfaces as a stack trace.
        return "Error: Model not selected"

    stoi = config['stoi']
    # Inverse vocabulary: token index -> character.
    itos = {idx: ch for ch, idx in stoi.items()}

    name_start = name_start.lower()
    name_end = name_end.lower()

    # Validate user input up front: a character outside the training
    # vocabulary would otherwise raise KeyError at stoi[c] below.
    for c in name_start + name_end:
        if c not in stoi:
            return "Error: \"" + c + "\" not included in the training set."

    names = ""
    for _ in range(int(number_of_names)):
        name = ""
        context = [0] * config['window']

        if "num_final_chars_in_dataset" in config:
            # Put final chars in context
            if len(name_end) > config["num_final_chars_in_dataset"]:
                name_end = name_end[-config["num_final_chars_in_dataset"]:]
                print("Only accepts up to " + str(config["num_final_chars_in_dataset"]) + " final chars. Using: " + name_end)

            for c in name_end:
                context = context[1:] + [stoi[c]]
            context = context[1:] + [stoi['.']]

        elif name_end != "":
            print("Final chars not used. Need to use a model trained with this feature.")

        # Initialize name with user input
        for c in name_start:
            name += c
            context = context[1:] + [stoi[c]]

        # Run inference to finish off the name
        while True:
            x = torch.tensor(context).view(1, -1)
            ix = sample_fcn(x)

            context = context[1:] + [ix]
            name += itos[ix]

            if ix == 0:  # end-of-name token terminates sampling
                break

        names += name + "\n"

    return names
397
 
398
  demo = gr.Interface(
399
  fn=generate_names,
@@ -401,7 +418,7 @@ demo = gr.Interface(
401
  gr.Textbox(placeholder="Start name with..."),
402
  gr.Textbox(placeholder="End name with... (only works for rev model)"),
403
  gr.Number(value=5),
404
- gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev"], value="GPT Rev"),
405
  ],
406
  outputs="text",
407
  )
 
334
  ##################################################################################
335
 
336
def generate_names(name_start, name_end, number_of_names, model):
    """Generate names with the selected character-level model.

    Args:
        name_start: Characters the generated names must begin with.
        name_end: Characters the names should end with (only honored by
            "Rev" models whose config carries "num_final_chars_in_dataset").
        number_of_names: How many names to generate (positive number).
        model: Display name of the model chosen in the UI dropdown.

    Returns:
        Newline-separated generated names, possibly preceded by warning
        lines, or an error message string on invalid input.
    """
    if number_of_names < 1:
        # Also rejects 0: the message asks for a positive count, and a
        # `< 0` check would silently let zero through.
        return "Error: Please enter a positive number of names to generate!"

    # Select model: map the dropdown label to (config, sampling function)
    # so the dispatch is written once instead of as an if/elif chain.
    model_registry = {
        "MLP": (mlp_config, mlp.sample_char),
        "WaveNet": (wavenet_config, wavenet.sample_char),
        "GPT Micro": (gpt_micro_config, gpt_micro.sample_char),
        "GPT Rev": (gpt_rev_config, gpt_rev.sample_char),
        # TODO: Change model! "GPT First Rev" still samples from gpt_rev.
        "GPT First Rev": (gpt_rev_config, gpt_rev.sample_char),
    }
    if model not in model_registry:
        return "Error: Model not selected"
    config, sample_fcn = model_registry[model]

    stoi = config['stoi']
    # Inverse vocabulary: token index -> character.
    itos = {idx: ch for ch, idx in stoi.items()}

    output = ""

    # Sanitize user inputs, and append errors to output
    name_end = name_end.lower()
    name_start = name_start.lower()

    for c in name_end:
        if c not in stoi:
            return "Please change name end. \"" + c + "\" not included in the training set."

    for c in name_start:
        if c not in stoi:
            return "Please change name start. \"" + c + "\" not included in the training set."

    if "num_final_chars_in_dataset" in config and len(name_end) > config["num_final_chars_in_dataset"]:
        # Truncate to the number of final characters the model was trained
        # with, and warn the user in the returned output.
        name_end = name_end[-config["num_final_chars_in_dataset"]:]
        output += "Only accepts up to " + str(config["num_final_chars_in_dataset"]) + " final chars. Using: " + str(name_end) + "\n"

    elif "num_final_chars_in_dataset" not in config and name_end != "":
        output += "Final chars not used. Need to use a \"Rev\" model trained with this feature.\n"

    ## Print requested names
    for _ in range(int(number_of_names)):
        name = ""
        context = [0] * config['window']

        if "num_final_chars_in_dataset" in config:
            # Seed the sliding context with the requested final characters
            # followed by the end-of-name token.
            for c in name_end:
                context = context[1:] + [stoi[c]]
            context = context[1:] + [stoi['.']]

        # Initialize name with user input
        for c in name_start:
            name += c
            context = context[1:] + [stoi[c]]

        # Run inference to finish off the name
        while True:
            x = torch.tensor(context).view(1, -1)
            ix = sample_fcn(x)

            context = context[1:] + [ix]
            name += itos[ix]

            if ix == 0:  # end-of-name token terminates sampling
                break

        output += name + "\n"

    return output
414
 
415
  demo = gr.Interface(
416
  fn=generate_names,
 
418
  gr.Textbox(placeholder="Start name with..."),
419
  gr.Textbox(placeholder="End name with... (only works for rev model)"),
420
  gr.Number(value=5),
421
+ gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev", "GPT First Rev"], value="GPT Rev"),
422
  ],
423
  outputs="text",
424
  )