import math

import gradio as gr
import huggingface_hub
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml

# Download model configs and weights from the Hugging Face Hub.
mlp_config_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "torch_mlp_config.yaml")
mlp_weights_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "mlp_weights.pt")

wavenet_config_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "wavenet_config.yaml")
wavenet_weights_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "wavenet_weights.pt")

gpt_micro_config_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "micro_gpt_config.yaml")
gpt_micro_weights_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "micro_gpt_weights.pt")

gpt_rev_config_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "rev_gpt_config.yaml")
gpt_rev_weights_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "rev_gpt_weights.pt")

gpt_first_rev_config_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "first_name_gpt_config.yaml")
gpt_first_rev_weights_path = huggingface_hub.hf_hub_download(
    "jefsnacker/surname_generator", "first_name_gpt_weights.pt")

with open(mlp_config_path, 'r') as file:
    mlp_config = yaml.safe_load(file)

with open(wavenet_config_path, 'r') as file:
    wavenet_config = yaml.safe_load(file)

with open(gpt_micro_config_path, 'r') as file:
    gpt_micro_config = yaml.safe_load(file)

with open(gpt_rev_config_path, 'r') as file:
    gpt_rev_config = yaml.safe_load(file)

with open(gpt_first_rev_config_path, 'r') as file:
    gpt_first_rev_config = yaml.safe_load(file)

##################################################################################
## MLP
##################################################################################

class MLP(nn.Module):
    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        super(MLP, self).__init__()
        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings

        # Character embedding table, learned alongside the rest of the network.
        self.C = nn.Parameter(torch.randn((num_char, embeddings)) * 0.1,
                              requires_grad=True)
        self.first = nn.Linear(embeddings * window, hidden_nodes)

        self.layers = nn.Sequential()
        for i in range(num_layers):
            self.layers.extend(nn.Sequential(
                nn.Linear(hidden_nodes, hidden_nodes, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh()))

        self.final = nn.Linear(hidden_nodes, num_char)

    def forward(self, x):
        # x: (B, window) character indices -> (B, window, embeddings).
        x = self.C[x]
        x = self.first(x.view(-1, self.window * self.embeddings))
        x = self.layers(x)
        x = self.final(x)
        return x

    def sample_char(self, x):
        logits = self(x)
        probs = F.softmax(logits, dim=1)
        return torch.multinomial(probs, num_samples=1).item()

mlp = MLP(mlp_config['num_char'],
          mlp_config['hidden_nodes'],
          mlp_config['embeddings'],
          mlp_config['window'],
          mlp_config['num_layers'])

mlp.load_state_dict(torch.load(mlp_weights_path))
mlp.eval()
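# Hedged usage sketch (not part of the original app): drive any of the models
# in this file character by character, mirroring the generation loop used in
# generate_names() at the bottom. Assumes each config carries the 'stoi'
# vocabulary and 'window' size, as the downloaded configs do; `max_len` is an
# illustrative parameter introduced here.
def sample_name(model, config, max_len=30):
    """Sample a single name from a fixed-window character model such as `mlp`."""
    stoi = config['stoi']
    itos = {i: s for s, i in stoi.items()}
    context = [0] * config['window']  # index 0 is the '.' padding/terminator
    name = ""
    for _ in range(max_len):
        ix = model.sample_char(torch.tensor(context).view(1, -1))
        if ix == 0:  # sampled the terminator, so the name is complete
            break
        name += itos[ix]
        context = context[1:] + [ix]
    return name

# Example: sample_name(mlp, mlp_config) returns one sampled surname string.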
##################################################################################
## WaveNet
##################################################################################

class WaveNet(nn.Module):
    def __init__(self, num_char, hidden_nodes, embeddings, window, num_layers):
        super(WaveNet, self).__init__()
        self.window = window
        self.hidden_nodes = hidden_nodes
        self.embeddings = embeddings

        self.layers = nn.Sequential(
            nn.Embedding(num_char, embeddings)
        )

        # The embedding output is (B, window, embeddings), so the first conv
        # treats the window positions as input channels and the embedding
        # dimension as the sequence length; each Conv1d shrinks that length by 1.
        for i in range(num_layers):
            if i == 0:
                nodes = window
            else:
                nodes = hidden_nodes

            self.layers.extend(nn.Sequential(
                nn.Conv1d(nodes, hidden_nodes, kernel_size=2, stride=1, bias=False),
                nn.BatchNorm1d(hidden_nodes),
                nn.Tanh()))

        self.layers.extend(nn.Sequential(
            nn.Flatten(),
            nn.Linear(hidden_nodes * (embeddings - num_layers), num_char)
        ))

    def forward(self, x):
        return self.layers(x)

    def sample_char(self, x):
        logits = self(x)
        probs = F.softmax(logits, dim=1)
        return torch.multinomial(probs, num_samples=1).item()

wavenet = WaveNet(wavenet_config['num_char'],
                  wavenet_config['hidden_nodes'],
                  wavenet_config['embeddings'],
                  wavenet_config['window'],
                  wavenet_config['num_layers'])

wavenet.load_state_dict(torch.load(wavenet_weights_path))
wavenet.eval()

##################################################################################
## Transformer
##################################################################################

class NewGELU(nn.Module):
    """
    Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

class GptAttention(nn.Module):
    """
    Causal self-attention where q, k, and v are all projected from the same
    input, as in decoder-only transformers.
    """
    def __init__(self, config):
        super(GptAttention, self).__init__()
        self.config = config
        assert self.config["d_model"] % self.config["heads"] == 0
        self.heads = self.config["heads"]

        self.w_attn = nn.Linear(self.config["d_model"], 3*self.config["d_model"])
        self.head = nn.Linear(self.config["d_model"], self.config["d_model"])

        self.attn_dropout = nn.Dropout(config["attn_pdrop"])
        self.resid_dropout = nn.Dropout(config["resid_pdrop"])

        # Causal mask to ensure that attention is only applied to the left in
        # the input sequence.
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones(self.config["window"], self.config["window"])
            ).view(1, 1, self.config["window"], self.config["window"])
        )

    def forward(self, x):
        B, window, embs = x.shape

        # The (q, v, k) split order is kept as-is to match the layout of the
        # pretrained checkpoint weights.
        q, v, k = self.w_attn(x).split(self.config["d_model"], dim=2)

        # (B, window, embs) -> (B, heads, window, head_dim)
        q = q.view(B, window, self.config["heads"], embs // self.config["heads"]).transpose(1, 2)
        k = k.view(B, window, self.config["heads"], embs // self.config["heads"]).transpose(1, 2)
        v = v.view(B, window, self.config["heads"], embs // self.config["heads"]).transpose(1, 2)

        # Self-attend: (B, heads, window, head_dim) x (B, heads, head_dim, window)
        # -> (B, heads, window, window)
        scores = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
        masked_scores = scores.masked_fill(self.bias[:, :, :window, :window] == 0, float('-inf'))
        probs = F.softmax(masked_scores, dim=-1)

        # Bug fix: apply attention dropout to the probabilities before weighting
        # v (the dropped-out tensor was previously discarded). A no-op here,
        # since all models run in eval mode.
        probs = self.attn_dropout(probs)
        attn = probs @ v
        attn = attn.transpose(1, 2).contiguous().view(B, window, embs)

        return self.resid_dropout(self.head(attn))

class FeedForward(nn.Module):
    def __init__(self, config):
        super(FeedForward, self).__init__()
        self.l1 = nn.Linear(config["d_model"], 4*config["d_model"])
        self.l2 = nn.Linear(4*config["d_model"], config["d_model"])
        self.dropout = nn.Dropout(config["resid_pdrop"])

    def forward(self, x):
        x = NewGELU()(self.l1(x))
        return self.dropout(self.l2(x))

class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()
        self.attn = GptAttention(config)
        self.norm1 = nn.LayerNorm(config["d_model"])
        self.ff = FeedForward(config)
        self.norm2 = nn.LayerNorm(config["d_model"])

    def forward(self, x):
        x = self.norm1(x + self.attn(x))
        x = self.norm2(x + self.ff(x))
        return x
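# Hedged illustration (not used by the app): the causal mask registered in
# GptAttention is lower-triangular, so position t can only attend to
# positions <= t. For window=4:
#
#   torch.tril(torch.ones(4, 4))
#   # tensor([[1., 0., 0., 0.],
#   #         [1., 1., 0., 0.],
#   #         [1., 1., 1., 0.],
#   #         [1., 1., 1., 1.]])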
class GPT(nn.Module):
    def __init__(self, config):
        super(GPT, self).__init__()
        self.config = config

        self.vocab_emb = nn.Embedding(self.config["vocab"], self.config["d_model"])
        self.pos_emb = nn.Embedding(self.config["window"], self.config["d_model"])
        self.emb_dropout = nn.Dropout(config["embd_pdrop"])

        self.blocks = nn.ModuleList([Block(self.config) for _ in range(self.config["blocks"])])

        self.head_layer_norm = nn.LayerNorm(config["d_model"])
        self.head = nn.Linear(self.config["d_model"], self.config["vocab"])

    def forward(self, x):
        vocab_emb = self.vocab_emb(x)
        pos_emb = self.pos_emb(torch.arange(0, x.shape[1], dtype=torch.long, device=x.device))
        x = self.emb_dropout(vocab_emb + pos_emb)

        for b in self.blocks:
            x = b(x)

        x = self.head_layer_norm(x)
        x = self.head(x)
        return x

    def configure_opt(self):
        p_decay = set()
        p_no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    p_no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    p_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    p_no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = p_decay & p_no_decay
        union_params = p_decay | p_no_decay
        assert len(inter_params) == 0, \
            "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, \
            "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(p_decay))],
             "weight_decay": self.config["weight_decay"]},
            {"params": [param_dict[pn] for pn in sorted(list(p_no_decay))],
             "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(
            optim_groups,
            lr=self.config["lr"],
            betas=(self.config["b1"], self.config["b2"])
        )
        return optimizer

    def sample_char(self, x):
        logits = self(x)
        probs = F.softmax(logits[:, -1, :], dim=1)
        return torch.multinomial(probs, num_samples=1).item()

gpt_micro = GPT(gpt_micro_config)
gpt_micro.load_state_dict(torch.load(gpt_micro_weights_path))
gpt_micro.eval()

gpt_rev = GPT(gpt_rev_config)
gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
gpt_rev.eval()

gpt_first_rev = GPT(gpt_first_rev_config)
gpt_first_rev.load_state_dict(torch.load(gpt_first_rev_weights_path))
gpt_first_rev.eval()
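# Note: the GPT variants expose the same sample_char() interface as the MLP
# and WaveNet (sampling only the final position's logits), so the hedged
# sample_name() sketch above applies to them too, e.g.
# sample_name(gpt_micro, gpt_micro_config).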
##################################################################################
## Gradio App
##################################################################################

def generate_names(name_start, name_end, number_of_names, model):
    if number_of_names < 0:
        return "Error: Please enter a positive number of names to generate!"

    # Select model
    if model == "MLP":
        config = mlp_config
        sample_fcn = mlp.sample_char
    elif model == "WaveNet":
        config = wavenet_config
        sample_fcn = wavenet.sample_char
    elif model == "GPT Micro":
        config = gpt_micro_config
        sample_fcn = gpt_micro.sample_char
    elif model == "GPT Rev":
        config = gpt_rev_config
        sample_fcn = gpt_rev.sample_char
    elif model == "GPT First Rev":
        config = gpt_first_rev_config
        sample_fcn = gpt_first_rev.sample_char
    else:
        return "Error: Model not selected"

    stoi = config['stoi']
    itos = {i: s for s, i in stoi.items()}

    output = ""

    # Sanitize user inputs, and append errors to output
    name_end = name_end.lower()
    name_start = name_start.lower()

    for c in name_end:
        if c not in stoi:
            return "Please change name end. \"" + c + "\" not included in the training set."

    for c in name_start:
        if c not in stoi:
            return "Please change name start. \"" + c + "\" not included in the training set."

    if "num_final_chars_in_dataset" in config and len(name_end) > config["num_final_chars_in_dataset"]:
        name_end = name_end[-config["num_final_chars_in_dataset"]:]
        output += "Only accepts up to " + str(config["num_final_chars_in_dataset"]) + \
                  " final chars. Using: " + str(name_end) + "\n"
    elif "num_final_chars_in_dataset" not in config and name_end != "":
        output += "Final chars not used. Need to use a \"Rev\" model trained with this feature.\n"

    ## Print requested names
    for _ in range(int(number_of_names)):
        name = ""
        context = [0] * config['window']

        # For "Rev" models, prime the context with the requested final
        # characters followed by the '.' terminator.
        if "num_final_chars_in_dataset" in config:
            for c in name_end:
                context = context[1:] + [stoi[c]]
            context = context[1:] + [stoi['.']]

        # Initialize name with user input
        for c in name_start:
            name += c
            context = context[1:] + [stoi[c]]

        # Run inference to finish off the name
        while True:
            x = torch.tensor(context).view(1, -1)
            ix = sample_fcn(x)

            context = context[1:] + [ix]
            name += itos[ix]

            if ix == 0:
                break

        output += name + "\n"

    return output

demo = gr.Interface(
    fn=generate_names,
    inputs=[
        gr.Textbox(placeholder="Start name with..."),
        gr.Textbox(placeholder="End name with... (only works for rev model)"),
        gr.Number(value=5),
        gr.Dropdown(["MLP", "WaveNet", "GPT Micro", "GPT Rev", "GPT First Rev"],
                    value="GPT Rev"),
    ],
    outputs="text",
)

demo.launch()
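# Hedged usage note (illustrative, not part of the original app): the handler
# can also be exercised directly, bypassing the UI, e.g.
#
#   print(generate_names("ka", "", 3, "GPT Micro"))
#
# Any such smoke test must run before demo.launch(), which blocks the script.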