import math
import yaml
import gradio as gr
import huggingface_hub
import torch
import torch.nn as nn
import torch.nn.functional as F
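# Download model configs and weights from the jefsnacker/surname_generator repo on the Hugging Face Hub.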
mlp_config_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"torch_mlp_config.yaml")
mlp_weights_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"mlp_weights.pt")
wavenet_config_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"wavenet_config.yaml")
wavenet_weights_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"wavenet_weights.pt")
gpt_micro_config_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"micro_gpt_config.yaml")
gpt_micro_weights_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"micro_gpt_weights.pt")
gpt_rev_config_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"rev_gpt_config.yaml")
gpt_rev_weights_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"rev_gpt_weights.pt")
gpt_first_rev_config_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"first_name_gpt_config.yaml")
gpt_first_rev_weights_path = huggingface_hub.hf_hub_download(
"jefsnacker/surname_generator",
"first_name_gpt_weights.pt")
with open(mlp_config_path, 'r') as file:
mlp_config = yaml.safe_load(file)
with open(wavenet_config_path, 'r') as file:
wavenet_config = yaml.safe_load(file)
with open(gpt_micro_config_path, 'r') as file:
gpt_micro_config = yaml.safe_load(file)
with open(gpt_rev_config_path, 'r') as file:
gpt_rev_config = yaml.safe_load(file)
with open(gpt_first_rev_config_path, 'r') as file:
gpt_first_rev_config = yaml.safe_load(file)
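# Each config carries the hyperparameters needed to rebuild its model, plus the
# stoi character-to-index vocabulary (and, for the "Rev" models, the
# num_final_chars_in_dataset setting) used during sampling below.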
##################################################################################
## MLP
##################################################################################
class MLP(nn.Module):
def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
super(MLP, self).__init__()
self.window = window
self.hidden_nodes = hidden_nodes
self.embeddings = embeddings
        self.C = nn.Parameter(torch.randn((num_char, embeddings)) * 0.1)
self.first = nn.Linear(embeddings*window, hidden_nodes)
self.layers = nn.Sequential()
        for _ in range(num_layers):
            self.layers.extend(nn.Sequential(
                nn.Linear(hidden_nodes, hidden_nodes, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh()))
self.final = nn.Linear(hidden_nodes, num_char)
def forward(self, x):
x = self.C[x]
x = self.first(x.view(-1, self.window*self.embeddings))
x = self.layers(x)
x = self.final(x)
return x
def sample_char(self, x):
logits = self(x)
probs = F.softmax(logits, dim=1)
return torch.multinomial(probs, num_samples=1).item()
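# Shape walkthrough (illustrative numbers, not the trained config): with window=3 and
# embeddings=10, a (B, 3) batch of character indices is embedded via C to (B, 3, 10),
# flattened to (B, 30) for self.first, and pushed through the Linear/BatchNorm/Tanh
# stack to produce (B, num_char) logits.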
mlp = MLP(mlp_config['num_char'],
mlp_config['hidden_nodes'],
mlp_config['embeddings'],
mlp_config['window'],
mlp_config['num_layers'])
mlp.load_state_dict(torch.load(mlp_weights_path))
mlp.eval()
##################################################################################
## WaveNet
##################################################################################
class WaveNet(nn.Module):
def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
super(WaveNet, self).__init__()
self.window = window
self.hidden_nodes = hidden_nodes
self.embeddings = embeddings
self.layers = nn.Sequential(
nn.Embedding(num_char, embeddings)
)
        for i in range(num_layers):
            # the first conv consumes the context window as its input channels
            nodes = window if i == 0 else hidden_nodes
            self.layers.extend(nn.Sequential(
                nn.Conv1d(nodes, hidden_nodes, kernel_size=2, stride=1, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh()))
        self.layers.extend(nn.Sequential(
            nn.Flatten(),
            nn.Linear(hidden_nodes*(embeddings-num_layers), num_char)
        ))
def forward(self, x):
return self.layers(x)
def sample_char(self, x):
logits = self(x)
probs = F.softmax(logits, dim=1)
return torch.multinomial(probs, num_samples=1).item()
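# Note the unconventional Conv1d layout: the embedded input (B, window, embeddings) is
# consumed as (batch, channels=window, length=embeddings), and each kernel_size=2 conv
# shortens the length axis by one, which is why the final Linear expects
# hidden_nodes*(embeddings-num_layers) features.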
wavenet = WaveNet(wavenet_config['num_char'],
wavenet_config['hidden_nodes'],
wavenet_config['embeddings'],
wavenet_config['window'],
wavenet_config['num_layers'])
wavenet.load_state_dict(torch.load(wavenet_weights_path))
wavenet.eval()
##################################################################################
## Transformer
##################################################################################
class NewGELU(nn.Module):
"""
Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
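    # On recent PyTorch (>= 1.12) this matches F.gelu(x, approximate='tanh').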
class GptAttention(nn.Module):
"""
    Causal self-attention: q, k, and v are all projections of the same input.
    Intended for decoder-only (GPT-style) transformers.
"""
def __init__(self, config):
super(GptAttention, self).__init__()
self.config = config
assert self.config["d_model"] % self.config["heads"] == 0
self.heads = self.config["heads"]
self.w_attn = nn.Linear(self.config["d_model"], 3*self.config["d_model"])
self.head = nn.Linear(self.config["d_model"], self.config["d_model"])
self.attn_dropout = nn.Dropout(config["attn_pdrop"])
self.resid_dropout = nn.Dropout(config["resid_pdrop"])
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"bias",
torch.tril(
torch.ones(
self.config["window"],
self.config["window"])
).view(1, 1, self.config["window"], self.config["window"])
)
def forward(self, x):
B, window, embs = x.shape
        q, v, k = self.w_attn(x).split(self.config["d_model"], dim=2)  # split order (q, v, k) must match training
        # reshape each to (B, heads, window, embs // heads)
q = q.view(
B,
window,
self.config["heads"],
embs // self.config["heads"]
).transpose(1, 2)
k = k.view(
B,
window,
self.config["heads"],
embs // self.config["heads"]
).transpose(1, 2)
v = v.view(
B,
window,
self.config["heads"],
embs // self.config["heads"]
).transpose(1, 2)
        # Self-attend: (B, heads, window, hs) x (B, heads, hs, window) -> (B, heads, window, window), hs = embs // heads
        scores = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
        scores = scores.masked_fill(self.bias[:, :, :window, :window] == 0, float('-inf'))
        probs = F.softmax(scores, dim=-1)
        probs = self.attn_dropout(probs)
        attn = probs @ v
        attn = attn.transpose(1, 2).contiguous().view(B, window, embs)
        return self.resid_dropout(self.head(attn))
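    # On PyTorch 2.0+, the scores/mask/softmax/dropout sequence in forward() could be
    # fused into F.scaled_dot_product_attention(q, k, v, is_causal=True,
    # dropout_p=self.attn_dropout.p if self.training else 0.0); kept explicit here for readability.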
class FeedForward(nn.Module):
    def __init__(self, config):
        super(FeedForward, self).__init__()
        self.l1 = nn.Linear(config["d_model"], 4*config["d_model"])
        self.l2 = nn.Linear(4*config["d_model"], config["d_model"])
        self.act = NewGELU()  # stateless activation module
        self.dropout = nn.Dropout(config["resid_pdrop"])
    def forward(self, x):
        x = self.act(self.l1(x))
        return self.dropout(self.l2(x))
class Block(nn.Module):
def __init__(self, config):
super(Block, self).__init__()
self.attn = GptAttention(config)
self.norm1 = nn.LayerNorm(config["d_model"])
self.ff = FeedForward(config)
self.norm2 = nn.LayerNorm(config["d_model"])
def forward(self, x):
x = self.norm1(x + self.attn(x))
x = self.norm2(x + self.ff(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super(GPT, self).__init__()
self.config = config
self.vocab_emb = nn.Embedding(self.config["vocab"], self.config["d_model"])
self.pos_emb = nn.Embedding(self.config["window"], self.config["d_model"])
self.emb_dropout = nn.Dropout(config["embd_pdrop"])
self.blocks = nn.ModuleList([Block(self.config) for _ in range(self.config["blocks"])])
self.head_layer_norm = nn.LayerNorm(config["d_model"])
self.head = nn.Linear(self.config["d_model"], self.config["vocab"])
def forward(self, x):
vocab_emb = self.vocab_emb(x)
pos_emb = self.pos_emb(torch.arange(0, x.shape[1], dtype=torch.long, device=x.device))
x = self.emb_dropout(vocab_emb + pos_emb)
for b in self.blocks:
x = b(x)
x = self.head_layer_norm(x)
x = self.head(x)
return x
def configure_opt(self):
p_decay = set()
p_no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
# random note: because named_modules and named_parameters are recursive
# we will see the same tensors p many many times. but doing it this way
# allows us to know which parent module any tensor p belongs to...
if pn.endswith('bias'):
# all biases will not be decayed
p_no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
p_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
p_no_decay.add(fpn)
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = p_decay & p_no_decay
union_params = p_decay | p_no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(p_decay))], "weight_decay": self.config["weight_decay"]},
{"params": [param_dict[pn] for pn in sorted(list(p_no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(
optim_groups,
lr=self.config["lr"],
betas=(self.config["b1"], self.config["b2"])
)
return optimizer
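    # configure_opt() builds an AdamW optimizer with decay/no-decay parameter groups;
    # it is only needed for training and goes unused in this inference-only app.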
def sample_char(self, x):
logits = self(x)
probs = F.softmax(logits[:,-1,:], dim=1)
return torch.multinomial(probs, num_samples=1).item()
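# Unlike MLP/WaveNet, GPT returns logits for every position in the window, so
# sample_char samples from the final position's logits only ([:, -1, :]).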
gpt_micro = GPT(gpt_micro_config)
gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
gpt_micro.eval()
gpt_rev = GPT(gpt_rev_config)
gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
gpt_rev.eval()
gpt_first_rev = GPT(gpt_first_rev_config)
gpt_first_rev.load_state_dict(torch.load(gpt_first_rev_weights_path))
gpt_first_rev.eval()
##################################################################################
## Gradio App
##################################################################################
def generate_names(name_start, name_end, number_of_names, model):
    if number_of_names < 0:
        return "Error: Please enter a non-negative number of names to generate!"
# Select model
if model == "MLP":
config = mlp_config
sample_fcn = mlp.sample_char
elif model == "WaveNet":
config = wavenet_config
sample_fcn = wavenet.sample_char
elif model == "GPT Micro":
config = gpt_micro_config
sample_fcn = gpt_micro.sample_char
elif model == "GPT Rev":
config = gpt_rev_config
sample_fcn = gpt_rev.sample_char
elif model == "GPT First Rev":
config = gpt_first_rev_config
sample_fcn = gpt_first_rev.sample_char
else:
return "Error: Model not selected"
stoi = config['stoi']
    itos = {i: s for s, i in stoi.items()}
output = ""
# Sanitize user inputs, and append errors to output
name_end = name_end.lower()
name_start = name_start.lower()
    for c in name_end:
        if c not in stoi:
            return "Error: Please change the name end. \"" + c + "\" is not in the training set."
    for c in name_start:
        if c not in stoi:
            return "Error: Please change the name start. \"" + c + "\" is not in the training set."
if "num_final_chars_in_dataset" in config and len(name_end) > config["num_final_chars_in_dataset"]:
name_end = name_end[-config["num_final_chars_in_dataset"]:]
output += "Only accepts up to " + str(config["num_final_chars_in_dataset"]) + " final chars. Using: " + str(name_end) + "\n"
elif "num_final_chars_in_dataset" not in config and name_end != "":
output += "Final chars not used. Need to use a \"Rev\" model trained with this feature.\n"
    # Generate the requested names
    for _ in range(int(number_of_names)):
name = ""
context = [0] * config['window']
if "num_final_chars_in_dataset" in config:
for c in name_end:
context = context[1:] + [stoi[c]]
context = context[1:] + [stoi['.']]
# Initialize name with user input
for c in name_start:
name += c
context = context[1:] + [stoi[c]]
# Run inference to finish off the name
while True:
x = torch.tensor(context).view(1, -1)
ix = sample_fcn(x)
context = context[1:] + [ix]
name += itos[ix]
            if ix == 0:  # index 0 marks end-of-name
                break
output += name + "\n"
return output
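# Example of calling the generator directly, bypassing the UI (inputs are illustrative):
#   print(generate_names("mc", "", 3, "GPT Rev"))
# prints three names starting with "mc", each ending with the terminator character
# (index 0, '.' in these vocabularies).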
demo = gr.Interface(
fn=generate_names,
inputs=[
gr.Textbox(placeholder="Start name with..."),
gr.Textbox(placeholder="End name with... (only works for rev model)"),
gr.Number(value=5),
gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev", "GPT First Rev"], value="GPT Rev"),
],
outputs="text",
)
demo.launch()