Spaces:
Runtime error
Runtime error
jefsnacker
committed on
Commit
•
70d2f66
1
Parent(s):
7e62304
adds reverse model and feature that allows you to pick last letters in the name
Browse files
app.py
CHANGED
@@ -34,6 +34,14 @@ gpt_micro_weights_path = huggingface_hub.hf_hub_download(
|
|
34 |
"jefsnacker/surname_generator",
|
35 |
"micro_gpt_weights.pt")
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
with open(mlp_config_path, 'r') as file:
|
38 |
mlp_config = yaml.safe_load(file)
|
39 |
|
@@ -43,6 +51,9 @@ with open(wavenet_config_path, 'r') as file:
|
|
43 |
with open(gpt_micro_config_path, 'r') as file:
|
44 |
gpt_micro_config = yaml.safe_load(file)
|
45 |
|
|
|
|
|
|
|
46 |
##################################################################################
|
47 |
## MLP
|
48 |
##################################################################################
|
@@ -314,31 +325,48 @@ gpt_micro = GPT(gpt_micro_config)
|
|
314 |
gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
|
315 |
gpt_micro.eval()
|
316 |
|
|
|
|
|
|
|
|
|
317 |
##################################################################################
|
318 |
## Gradio App
|
319 |
##################################################################################
|
320 |
|
321 |
-
def generate_names(name_start, number_of_names, model):
|
322 |
if model == "MLP":
|
323 |
-
|
324 |
-
window = mlp_config['window']
|
325 |
elif model == "WaveNet":
|
326 |
-
|
327 |
-
window = wavenet_config['window']
|
328 |
elif model == "GPT Micro":
|
329 |
-
|
330 |
-
|
|
|
331 |
else:
|
332 |
raise Exception("Model not selected")
|
333 |
-
|
|
|
334 |
itos = {s:i for i,s in stoi.items()}
|
335 |
|
336 |
names = ""
|
337 |
for _ in range((int)(number_of_names)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
|
339 |
# Initialize name with user input
|
340 |
-
name = ""
|
341 |
-
context = [0] * window
|
342 |
for c in name_start.lower():
|
343 |
name += c
|
344 |
context = context[1:] + [stoi[c]]
|
@@ -352,6 +380,8 @@ def generate_names(name_start, number_of_names, model):
|
|
352 |
ix = wavenet.sample_char(x)
|
353 |
elif model == "GPT Micro":
|
354 |
ix = gpt_micro.sample_char(x)
|
|
|
|
|
355 |
else:
|
356 |
raise Exception("Model not selected")
|
357 |
|
@@ -369,8 +399,9 @@ demo = gr.Interface(
|
|
369 |
fn=generate_names,
|
370 |
inputs=[
|
371 |
gr.Textbox(placeholder="Start name with..."),
|
|
|
372 |
gr.Number(value=5),
|
373 |
-
gr.Dropdown(["MLP", "WaveNet", "GPT Micro"], value="GPT Micro"),
|
374 |
],
|
375 |
outputs="text",
|
376 |
)
|
|
|
34 |
"jefsnacker/surname_generator",
|
35 |
"micro_gpt_weights.pt")
|
36 |
|
37 |
+
gpt_rev_config_path = huggingface_hub.hf_hub_download(
|
38 |
+
"jefsnacker/surname_generator",
|
39 |
+
"rev_gpt_config.yaml")
|
40 |
+
|
41 |
+
gpt_rev_weights_path = huggingface_hub.hf_hub_download(
|
42 |
+
"jefsnacker/surname_generator",
|
43 |
+
"rev_gpt_weights.pt")
|
44 |
+
|
45 |
with open(mlp_config_path, 'r') as file:
|
46 |
mlp_config = yaml.safe_load(file)
|
47 |
|
|
|
51 |
with open(gpt_micro_config_path, 'r') as file:
|
52 |
gpt_micro_config = yaml.safe_load(file)
|
53 |
|
54 |
+
with open(gpt_rev_config_path, 'r') as file:
|
55 |
+
gpt_rev_config = yaml.safe_load(file)
|
56 |
+
|
57 |
##################################################################################
|
58 |
## MLP
|
59 |
##################################################################################
|
|
|
325 |
gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
|
326 |
gpt_micro.eval()
|
327 |
|
328 |
+
gpt_rev = GPT(gpt_rev_config)
|
329 |
+
gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
|
330 |
+
gpt_rev.eval()
|
331 |
+
|
332 |
##################################################################################
|
333 |
## Gradio App
|
334 |
##################################################################################
|
335 |
|
336 |
+
def generate_names(name_start, name_end, number_of_names, model):
|
337 |
if model == "MLP":
|
338 |
+
config = mlp_config
|
|
|
339 |
elif model == "WaveNet":
|
340 |
+
config = wavenet_config
|
|
|
341 |
elif model == "GPT Micro":
|
342 |
+
config = gpt_micro_config
|
343 |
+
elif model == "GPT Rev":
|
344 |
+
config = gpt_rev_config
|
345 |
else:
|
346 |
raise Exception("Model not selected")
|
347 |
+
|
348 |
+
stoi = config['stoi']
|
349 |
itos = {s:i for i,s in stoi.items()}
|
350 |
|
351 |
names = ""
|
352 |
for _ in range((int)(number_of_names)):
|
353 |
+
name = ""
|
354 |
+
context = [0] * config['window']
|
355 |
+
|
356 |
+
if "num_final_chars_in_dataset" in config:
|
357 |
+
# Put final chars in context
|
358 |
+
if len(name_end) > config["num_final_chars_in_dataset"]:
|
359 |
+
name_end = name_end[-config["num_final_chars_in_dataset"]:]
|
360 |
+
print("Only accepts up to " + str(config["num_final_chars_in_dataset"]) + " final chars. Using: " + name_end)
|
361 |
+
|
362 |
+
for c in name_end:
|
363 |
+
context = context[1:] + [stoi[c]]
|
364 |
+
context = context[1:] + [stoi['.']]
|
365 |
+
|
366 |
+
elif (name_end != ""):
|
367 |
+
print("Final chars not used. Need to use a model trained with this feature.")
|
368 |
|
369 |
# Initialize name with user input
|
|
|
|
|
370 |
for c in name_start.lower():
|
371 |
name += c
|
372 |
context = context[1:] + [stoi[c]]
|
|
|
380 |
ix = wavenet.sample_char(x)
|
381 |
elif model == "GPT Micro":
|
382 |
ix = gpt_micro.sample_char(x)
|
383 |
+
elif model == "GPT Rev":
|
384 |
+
ix = gpt_rev.sample_char(x)
|
385 |
else:
|
386 |
raise Exception("Model not selected")
|
387 |
|
|
|
399 |
fn=generate_names,
|
400 |
inputs=[
|
401 |
gr.Textbox(placeholder="Start name with..."),
|
402 |
+
gr.Textbox(placeholder="End name with... (only works for rev model)"),
|
403 |
gr.Number(value=5),
|
404 |
+
gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev"], value="GPT Rev"),
|
405 |
],
|
406 |
outputs="text",
|
407 |
)
|