jefsnacker commited on
Commit
243da15
1 Parent(s): e538149

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import yaml
8
+
9
+
10
+ config_path = huggingface_hub.hf_hub_download(
11
+ "jefsnacker/surname_mlp",
12
+ "torch_mlp_config.yaml")
13
+
14
+ weights_path = huggingface_hub.hf_hub_download(
15
+ "jefsnacker/surname_mlp",
16
+ "mlp_weights.pt")
17
+
18
+ with open(config_path, 'r') as file:
19
+ config = yaml.safe_load(file)
20
+
21
+ stoi = config['stoi']
22
+ itos = {s:i for i,s in stoi.items()}
23
+
24
+ class MLP(nn.Module):
25
+ def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
26
+ super(MLP, self).__init__()
27
+
28
+ self.window = window
29
+ self.hidden_nodes = hidden_nodes
30
+ self.embeddings = embeddings
31
+
32
+ self.C = nn.Parameter(torch.randn((num_char, embeddings)) * 0.1, requires_grad=True)
33
+
34
+ self.first = nn.Linear(embeddings*window, hidden_nodes)
35
+
36
+ self.layers = nn.Sequential()
37
+ for i in range(num_layers):
38
+ self.layers = self.layers.extend(nn.Sequential(
39
+ nn.Linear(hidden_nodes, hidden_nodes, bias=False),
40
+ nn.BatchNorm1d(hidden_nodes),
41
+ nn.Tanh()))
42
+
43
+ self.final = nn.Linear(hidden_nodes, num_char)
44
+
45
+ def forward(self, x):
46
+ x = self.C[x]
47
+ x = self.first(x.view(-1, self.window*self.embeddings))
48
+
49
+ x = self.layers(x)
50
+
51
+ x = self.final(x)
52
+ return x
53
+
54
+ def sample_char(self, x):
55
+ logits = self(x)
56
+ probs = F.softmax(logits, dim=1)
57
+ return torch.multinomial(probs, num_samples=1).item()
58
+
59
+ mlp = MLP(config['num_char'],
60
+ config['hidden_nodes'],
61
+ config['embeddings'],
62
+ config['window'],
63
+ config['num_layers'])
64
+
65
+ mlp.load_state_dict(torch.load(weights_path))
66
+ mlp.eval()
67
+
68
+ def generate_names(name_start, number_of_names):
69
+ names = ""
70
+ for _ in range((int)(number_of_names)):
71
+
72
+ # Initialize name with user input
73
+ name = ""
74
+ context = [0] * config['window']
75
+ for c in name_start.lower():
76
+ name += c
77
+ context = context[1:] + [stoi[c]]
78
+
79
+ # Run inference to finish off the name
80
+ while True:
81
+ ix = mlp.sample_char(context)
82
+
83
+ context = context[1:] + [ix]
84
+ name += itos[ix]
85
+
86
+ if ix == 0:
87
+ break
88
+
89
+ names += name + "\n"
90
+
91
+ return names
92
+
93
+ app = gr.Interface(
94
+ fn=generate_names,
95
+ inputs=[
96
+ gr.Textbox(placeholder="Start name with..."),
97
+ gr.Number(value=1)
98
+ ],
99
+ outputs="text",
100
+ )
101
+ app.launch()