jefsnacker committed on
Commit
7e62304
1 Parent(s): 3f362c0

rename gpt nano -> gpt micro

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -26,13 +26,13 @@ wavenet_weights_path = huggingface_hub.hf_hub_download(
26
  "jefsnacker/surname_generator",
27
  "wavenet_weights.pt")
28
 
29
- gpt_nano_config_path = huggingface_hub.hf_hub_download(
30
  "jefsnacker/surname_generator",
31
- "gpt_config.yaml")
32
 
33
- gpt_nano_weights_path = huggingface_hub.hf_hub_download(
34
  "jefsnacker/surname_generator",
35
- "gpt_weights.pt")
36
 
37
  with open(mlp_config_path, 'r') as file:
38
  mlp_config = yaml.safe_load(file)
@@ -40,8 +40,8 @@ with open(mlp_config_path, 'r') as file:
40
  with open(wavenet_config_path, 'r') as file:
41
  wavenet_config = yaml.safe_load(file)
42
 
43
- with open(gpt_nano_config_path, 'r') as file:
44
- gpt_nano_config = yaml.safe_load(file)
45
 
46
  ##################################################################################
47
  ## MLP
@@ -310,9 +310,9 @@ class GPT(nn.Module):
310
  probs = F.softmax(logits[:,-1,:], dim=1)
311
  return torch.multinomial(probs, num_samples=1).item()
312
 
313
- gpt_nano = GPT(gpt_nano_config)
314
- gpt_nano.load_state_dict(torch.load(gpt_nano_weights_path))
315
- gpt_nano.eval()
316
 
317
  ##################################################################################
318
  ## Gradio App
@@ -325,9 +325,9 @@ def generate_names(name_start, number_of_names, model):
325
  elif model == "WaveNet":
326
  stoi = wavenet_config['stoi']
327
  window = wavenet_config['window']
328
- elif model == "GPT Nano":
329
- stoi = gpt_nano_config['stoi']
330
- window = gpt_nano_config['window']
331
  else:
332
  raise Exception("Model not selected")
333
 
@@ -350,8 +350,8 @@ def generate_names(name_start, number_of_names, model):
350
  ix = mlp.sample_char(x)
351
  elif model == "WaveNet":
352
  ix = wavenet.sample_char(x)
353
- elif model == "GPT Nano":
354
- ix = gpt_nano.sample_char(x)
355
  else:
356
  raise Exception("Model not selected")
357
 
@@ -370,7 +370,7 @@ demo = gr.Interface(
370
  inputs=[
371
  gr.Textbox(placeholder="Start name with..."),
372
  gr.Number(value=5),
373
- gr.Dropdown(["MLP", "WaveNet", "GPT Nano"], value="GPT Nano"),
374
  ],
375
  outputs="text",
376
  )
 
26
  "jefsnacker/surname_generator",
27
  "wavenet_weights.pt")
28
 
29
+ gpt_micro_config_path = huggingface_hub.hf_hub_download(
30
  "jefsnacker/surname_generator",
31
+ "micro_gpt_config.yaml")
32
 
33
+ gpt_micro_weights_path = huggingface_hub.hf_hub_download(
34
  "jefsnacker/surname_generator",
35
+ "micro_gpt_weights.pt")
36
 
37
  with open(mlp_config_path, 'r') as file:
38
  mlp_config = yaml.safe_load(file)
 
40
  with open(wavenet_config_path, 'r') as file:
41
  wavenet_config = yaml.safe_load(file)
42
 
43
+ with open(gpt_micro_config_path, 'r') as file:
44
+ gpt_micro_config = yaml.safe_load(file)
45
 
46
  ##################################################################################
47
  ## MLP
 
310
  probs = F.softmax(logits[:,-1,:], dim=1)
311
  return torch.multinomial(probs, num_samples=1).item()
312
 
313
+ gpt_micro = GPT(gpt_micro_config)
314
+ gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
315
+ gpt_micro.eval()
316
 
317
  ##################################################################################
318
  ## Gradio App
 
325
  elif model == "WaveNet":
326
  stoi = wavenet_config['stoi']
327
  window = wavenet_config['window']
328
+ elif model == "GPT Micro":
329
+ stoi = gpt_micro_config['stoi']
330
+ window = gpt_micro_config['window']
331
  else:
332
  raise Exception("Model not selected")
333
 
 
350
  ix = mlp.sample_char(x)
351
  elif model == "WaveNet":
352
  ix = wavenet.sample_char(x)
353
+ elif model == "GPT Micro":
354
+ ix = gpt_micro.sample_char(x)
355
  else:
356
  raise Exception("Model not selected")
357
 
 
370
  inputs=[
371
  gr.Textbox(placeholder="Start name with..."),
372
  gr.Number(value=5),
373
+ gr.Dropdown(["MLP", "WaveNet", "GPT Micro"], value="GPT Micro"),
374
  ],
375
  outputs="text",
376
  )