#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 19 13:56:57 2024

@author: selbl

Lyric-generation model runner.

Loads a fine-tuned GPT-2 checkpoint and exposes :func:`TextGeneration`,
which turns a short text prompt into formatted (and optionally
profanity-censored) song lyrics.  Because it is a script, the heavy
documentation lives in the GitHub repository.
"""
import re

import torch
import torch.hub
from torch.nn import functional as F
from transformers import GPT2Tokenizer, GPT2Config, GPT2Model, GPT2PreTrainedModel
from better_profanity import profanity

profanity.load_censor_words()

# The streamlit space does not allow mps, so only CUDA or CPU are considered.
# device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_built() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class GPT2_Model(GPT2PreTrainedModel):
    """GPT-2 backbone with a fresh LM head sized for the extended vocabulary.

    The tokenizer registers a dedicated ``'<|pad|>'`` token, so both the
    embedding matrix and the output projection must be resized to
    ``len(tokenizer)`` rather than the stock GPT-2 vocab size.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model.from_pretrained('gpt2')
        # A new unique pad token is added, so the embeddings must grow to match.
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
        self.transformer.resize_token_embeddings(len(tokenizer))
        self.lm_head = torch.nn.Linear(config.n_embd, len(tokenizer), bias=False)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return raw LM logits of shape (batch, seq, vocab)."""
        x = self.transformer(input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)[0]
        x = self.lm_head(x)
        return x


# Define the generation function
def generate(idx, max_new_tokens, context_size, tokenizer, model, top_k=10, top_p=0.95):
    """Autoregressively sample up to ``max_new_tokens`` tokens.

    Combines nucleus (top-p) and top-k filtering: the sorted probability
    mass is cut at the smaller of the two truncation points, renormalized,
    and sampled from.  Generation stops early if the last token already is
    the tokenizer's EOS token.

    Parameters
    ----------
    idx : torch.LongTensor of shape (1, seq) — running token ids; extended in place of return.
    max_new_tokens : int — maximum number of tokens to append.
    context_size : int — how many trailing tokens to condition on.
    tokenizer, model : tokenizer / LM pair.
    top_k, top_p : truncation parameters for sampling.

    Returns
    -------
    torch.LongTensor — ``idx`` with the sampled tokens appended.
    """
    # Hoist the EOS id out of the loop; re-encoding it per step is wasted work.
    eos_id = tokenizer.encode(tokenizer.eos_token)[0]
    # Inference only: no_grad avoids building an autograd graph per step.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if idx[:, -1].item() == eos_id:
                break
            # Crop the running sequence to the last context_size tokens.
            idx_cond = idx[:, -context_size:]
            # Logits for the final position only.
            logits = model(idx_cond)[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            # Sort probabilities in descending order and take the cumulative sum.
            sorted_probs, indices = torch.sort(probs, descending=True)
            probs_cumsum = torch.cumsum(sorted_probs, dim=1)
            # Nucleus cutoff: keep the smallest prefix whose mass reaches top_p
            # (computed once, instead of twice via deprecated bool-list indexing).
            nucleus = int((probs_cumsum < top_p).sum().item()) + 1
            keep = min(nucleus, top_k)
            sorted_probs = sorted_probs[:, :keep]
            indices = indices[:, :keep]
            # Renormalize the truncated distribution by its sum.  The original
            # applied softmax to values that were already probabilities, which
            # flattens/distorts the sampling distribution.
            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
            # Sample from the truncated distribution and map back to vocab ids.
            idx_next = indices[:, torch.multinomial(sorted_probs, num_samples=1)].squeeze(0)
            # Append the new token id.
            idx = torch.cat((idx, idx_next), dim=1)
    return idx


# Define capitalization functions
def custom_capitalize(match):
    """Capitalize the first capture group of a regex match object."""
    return match.group(1).capitalize()


def capitalize_string(input_string):
    """Apply the lyric capitalization conventions to ``input_string``."""
    # Capitalize every first letter after "\n\n "
    result = re.sub(r'\n\n\s*([a-zA-Z])',
                    lambda x: '\n\n' + x.group(1).upper(),
                    input_string)
    # Capitalize every first letter in every word inside brackets ([ ])
    result = re.sub(r'\[([^\]]*)\]',
                    lambda x: '[' + ' '.join(word.capitalize()
                                             for word in x.group(1).split()) + ']',
                    result)
    # Capitalize every instance of the letter i by itself
    result = re.sub(r'\bi\b', 'I', result)
    return result


# Format the string
def format_string(input_string):
    """Reflow raw generated text into bracket-delimited lyric sections."""
    # Strip existing newlines, then put a blank line after every ']' ...
    result = ''.join([char + ('\n\n' if char == ']' else '')
                      for char in input_string.replace('\n', '')])
    # ... and a blank line before every '[' so each section is isolated.
    result = result.replace('[', '\n\n[')
    # Capitalize for good measure
    result = capitalize_string(result)
    return result


def TextGeneration(prompt, prof=False, parts=True):
    """Generate lyrics from ``prompt`` using the fine-tuned GPT-2 model.

    Parameters
    ----------
    prompt : str — seed text (lower-cased internally).
    prof : bool — if False (default), profanity is censored from the output.
    parts : bool — if True, prepend a '[verse]' section tag and format the
        output into bracketed sections.

    Returns
    -------
    str — the generated (and optionally formatted/censored) lyric text.
    """
    # Load model
    configuration = GPT2Config()
    gpt_model = GPT2_Model(configuration).to(device)
    gpt_model.load_state_dict(torch.load('GPT-Trained-Model.pt', map_location=device))
    # state_dict = torch.hub.load_state_dict_from_url(r'https://github.com/Selbl/LyricGeneration/raw/main/GPT-Trained-Model-Prod.pt?download=',map_location=device)
    # state_dict = torch.hub.load_state_dict_from_url(r'https://huggingface.co/spaces/selbl/LyricGeneration/resolve/main/GPT-Trained-Model',map_location=device)
    # gpt_model.load_state_dict(state_dict)
    # Load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
    gpt_model.eval()
    # Set model parameters
    prefix = '[verse]'
    # Pre-process the prompt
    prompt = prompt.lower()
    if parts:
        prompt = prefix + ' ' + prompt
    # Generate
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    generated = generated.to(device)
    sample_outputs = generate(generated,
                              max_new_tokens=200,
                              context_size=450,
                              tokenizer=tokenizer,
                              model=gpt_model,
                              top_k=10,
                              top_p=0.95)
    # Store the lyrics and decode
    lyric = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Format if it has parts
    if parts:
        lyric = format_string(lyric)
    # Remove profanity
    if not prof:
        lyric = profanity.censor(lyric)
    return lyric