jefsnacker commited on
Commit
70d2f66
1 Parent(s): 7e62304

adds reverse model and feature that allows you to pick last letters in the name

Browse files
Files changed (1) hide show
  1. app.py +42 -11
app.py CHANGED
@@ -34,6 +34,14 @@ gpt_micro_weights_path = huggingface_hub.hf_hub_download(
34
  "jefsnacker/surname_generator",
35
  "micro_gpt_weights.pt")
36
 
 
 
 
 
 
 
 
 
37
  with open(mlp_config_path, 'r') as file:
38
  mlp_config = yaml.safe_load(file)
39
 
@@ -43,6 +51,9 @@ with open(wavenet_config_path, 'r') as file:
43
  with open(gpt_micro_config_path, 'r') as file:
44
  gpt_micro_config = yaml.safe_load(file)
45
 
 
 
 
46
  ##################################################################################
47
  ## MLP
48
  ##################################################################################
@@ -314,31 +325,48 @@ gpt_micro = GPT(gpt_micro_config)
314
  gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
315
  gpt_micro.eval()
316
 
 
 
 
 
317
  ##################################################################################
318
  ## Gradio App
319
  ##################################################################################
320
 
321
- def generate_names(name_start, number_of_names, model):
322
  if model == "MLP":
323
- stoi = mlp_config['stoi']
324
- window = mlp_config['window']
325
  elif model == "WaveNet":
326
- stoi = wavenet_config['stoi']
327
- window = wavenet_config['window']
328
  elif model == "GPT Micro":
329
- stoi = gpt_micro_config['stoi']
330
- window = gpt_micro_config['window']
 
331
  else:
332
  raise Exception("Model not selected")
333
-
 
334
  itos = {s:i for i,s in stoi.items()}
335
 
336
  names = ""
337
  for _ in range((int)(number_of_names)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
  # Initialize name with user input
340
- name = ""
341
- context = [0] * window
342
  for c in name_start.lower():
343
  name += c
344
  context = context[1:] + [stoi[c]]
@@ -352,6 +380,8 @@ def generate_names(name_start, number_of_names, model):
352
  ix = wavenet.sample_char(x)
353
  elif model == "GPT Micro":
354
  ix = gpt_micro.sample_char(x)
 
 
355
  else:
356
  raise Exception("Model not selected")
357
 
@@ -369,8 +399,9 @@ demo = gr.Interface(
369
  fn=generate_names,
370
  inputs=[
371
  gr.Textbox(placeholder="Start name with..."),
 
372
  gr.Number(value=5),
373
- gr.Dropdown(["MLP", "WaveNet", "GPT Micro"], value="GPT Micro"),
374
  ],
375
  outputs="text",
376
  )
 
34
  "jefsnacker/surname_generator",
35
  "micro_gpt_weights.pt")
36
 
37
+ gpt_rev_config_path = huggingface_hub.hf_hub_download(
38
+ "jefsnacker/surname_generator",
39
+ "rev_gpt_config.yaml")
40
+
41
+ gpt_rev_weights_path = huggingface_hub.hf_hub_download(
42
+ "jefsnacker/surname_generator",
43
+ "rev_gpt_weights.pt")
44
+
45
  with open(mlp_config_path, 'r') as file:
46
  mlp_config = yaml.safe_load(file)
47
 
 
51
  with open(gpt_micro_config_path, 'r') as file:
52
  gpt_micro_config = yaml.safe_load(file)
53
 
54
+ with open(gpt_rev_config_path, 'r') as file:
55
+ gpt_rev_config = yaml.safe_load(file)
56
+
57
  ##################################################################################
58
  ## MLP
59
  ##################################################################################
 
325
  gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
326
  gpt_micro.eval()
327
 
328
+ gpt_rev = GPT(gpt_rev_config)
329
+ gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
330
+ gpt_rev.eval()
331
+
332
  ##################################################################################
333
  ## Gradio App
334
  ##################################################################################
335
 
336
+ def generate_names(name_start, name_end, number_of_names, model):
337
  if model == "MLP":
338
+ config = mlp_config
 
339
  elif model == "WaveNet":
340
+ config = wavenet_config
 
341
  elif model == "GPT Micro":
342
+ config = gpt_micro_config
343
+ elif model == "GPT Rev":
344
+ config = gpt_rev_config
345
  else:
346
  raise Exception("Model not selected")
347
+
348
+ stoi = config['stoi']
349
  itos = {s:i for i,s in stoi.items()}
350
 
351
  names = ""
352
  for _ in range((int)(number_of_names)):
353
+ name = ""
354
+ context = [0] * config['window']
355
+
356
+ if "num_final_chars_in_dataset" in config:
357
+ # Put final chars in context
358
+ if len(name_end) > config["num_final_chars_in_dataset"]:
359
+ name_end = name_end[-config["num_final_chars_in_dataset"]:]
360
+ print("Only accepts up to " + str(config["num_final_chars_in_dataset"]) + " final chars. Using: " + name_end)
361
+
362
+ for c in name_end:
363
+ context = context[1:] + [stoi[c]]
364
+ context = context[1:] + [stoi['.']]
365
+
366
+ elif (name_end != ""):
367
+ print("Final chars not used. Need to use a model trained with this feature.")
368
 
369
  # Initialize name with user input
 
 
370
  for c in name_start.lower():
371
  name += c
372
  context = context[1:] + [stoi[c]]
 
380
  ix = wavenet.sample_char(x)
381
  elif model == "GPT Micro":
382
  ix = gpt_micro.sample_char(x)
383
+ elif model == "GPT Rev":
384
+ ix = gpt_rev.sample_char(x)
385
  else:
386
  raise Exception("Model not selected")
387
 
 
399
  fn=generate_names,
400
  inputs=[
401
  gr.Textbox(placeholder="Start name with..."),
402
+ gr.Textbox(placeholder="End name with... (only works for rev model)"),
403
  gr.Number(value=5),
404
+ gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev"], value="GPT Rev"),
405
  ],
406
  outputs="text",
407
  )