Spaces:
Runtime error
Runtime error
jefsnacker
committed on
Commit
•
7e62304
1
Parent(s):
3f362c0
rename gpt nano -> gpt micro
Browse files
app.py
CHANGED
@@ -26,13 +26,13 @@ wavenet_weights_path = huggingface_hub.hf_hub_download(
|
|
26 |
"jefsnacker/surname_generator",
|
27 |
"wavenet_weights.pt")
|
28 |
|
29 |
-
|
30 |
"jefsnacker/surname_generator",
|
31 |
-
"
|
32 |
|
33 |
-
|
34 |
"jefsnacker/surname_generator",
|
35 |
-
"
|
36 |
|
37 |
with open(mlp_config_path, 'r') as file:
|
38 |
mlp_config = yaml.safe_load(file)
|
@@ -40,8 +40,8 @@ with open(mlp_config_path, 'r') as file:
|
|
40 |
with open(wavenet_config_path, 'r') as file:
|
41 |
wavenet_config = yaml.safe_load(file)
|
42 |
|
43 |
-
with open(
|
44 |
-
|
45 |
|
46 |
##################################################################################
|
47 |
## MLP
|
@@ -310,9 +310,9 @@ class GPT(nn.Module):
|
|
310 |
probs = F.softmax(logits[:,-1,:], dim=1)
|
311 |
return torch.multinomial(probs, num_samples=1).item()
|
312 |
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
|
317 |
##################################################################################
|
318 |
## Gradio App
|
@@ -325,9 +325,9 @@ def generate_names(name_start, number_of_names, model):
|
|
325 |
elif model == "WaveNet":
|
326 |
stoi = wavenet_config['stoi']
|
327 |
window = wavenet_config['window']
|
328 |
-
elif model == "GPT
|
329 |
-
stoi =
|
330 |
-
window =
|
331 |
else:
|
332 |
raise Exception("Model not selected")
|
333 |
|
@@ -350,8 +350,8 @@ def generate_names(name_start, number_of_names, model):
|
|
350 |
ix = mlp.sample_char(x)
|
351 |
elif model == "WaveNet":
|
352 |
ix = wavenet.sample_char(x)
|
353 |
-
elif model == "GPT
|
354 |
-
ix =
|
355 |
else:
|
356 |
raise Exception("Model not selected")
|
357 |
|
@@ -370,7 +370,7 @@ demo = gr.Interface(
|
|
370 |
inputs=[
|
371 |
gr.Textbox(placeholder="Start name with..."),
|
372 |
gr.Number(value=5),
|
373 |
-
gr.Dropdown(["MLP", "WaveNet", "GPT
|
374 |
],
|
375 |
outputs="text",
|
376 |
)
|
|
|
26 |
"jefsnacker/surname_generator",
|
27 |
"wavenet_weights.pt")
|
28 |
|
29 |
+
gpt_micro_config_path = huggingface_hub.hf_hub_download(
|
30 |
"jefsnacker/surname_generator",
|
31 |
+
"micro_gpt_config.yaml")
|
32 |
|
33 |
+
gpt_micro_weights_path = huggingface_hub.hf_hub_download(
|
34 |
"jefsnacker/surname_generator",
|
35 |
+
"micro_gpt_weights.pt")
|
36 |
|
37 |
with open(mlp_config_path, 'r') as file:
|
38 |
mlp_config = yaml.safe_load(file)
|
|
|
40 |
with open(wavenet_config_path, 'r') as file:
|
41 |
wavenet_config = yaml.safe_load(file)
|
42 |
|
43 |
+
with open(gpt_micro_config_path, 'r') as file:
|
44 |
+
gpt_micro_config = yaml.safe_load(file)
|
45 |
|
46 |
##################################################################################
|
47 |
## MLP
|
|
|
310 |
probs = F.softmax(logits[:,-1,:], dim=1)
|
311 |
return torch.multinomial(probs, num_samples=1).item()
|
312 |
|
313 |
+
gpt_micro = GPT(gpt_micro_config)
|
314 |
+
gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
|
315 |
+
gpt_micro.eval()
|
316 |
|
317 |
##################################################################################
|
318 |
## Gradio App
|
|
|
325 |
elif model == "WaveNet":
|
326 |
stoi = wavenet_config['stoi']
|
327 |
window = wavenet_config['window']
|
328 |
+
elif model == "GPT Micro":
|
329 |
+
stoi = gpt_micro_config['stoi']
|
330 |
+
window = gpt_micro_config['window']
|
331 |
else:
|
332 |
raise Exception("Model not selected")
|
333 |
|
|
|
350 |
ix = mlp.sample_char(x)
|
351 |
elif model == "WaveNet":
|
352 |
ix = wavenet.sample_char(x)
|
353 |
+
elif model == "GPT Micro":
|
354 |
+
ix = gpt_micro.sample_char(x)
|
355 |
else:
|
356 |
raise Exception("Model not selected")
|
357 |
|
|
|
370 |
inputs=[
|
371 |
gr.Textbox(placeholder="Start name with..."),
|
372 |
gr.Number(value=5),
|
373 |
+
gr.Dropdown(["MLP", "WaveNet", "GPT Micro"], value="GPT Micro"),
|
374 |
],
|
375 |
outputs="text",
|
376 |
)
|