Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 19 13:56:57 2024
@author: selbl
"""
# This code runs the lyric generation model.
# Because it is a script, I do not spend time writing about what each part does.
# For more info, please refer to the GitHub script.
import functools
import re

import torch
import torch.hub
from better_profanity import profanity
from torch.nn import functional as F
from transformers import GPT2Tokenizer, GPT2Config, GPT2Model, GPT2PreTrainedModel
# Load the censor word list once at import time so profanity.censor() can be
# used later on.
profanity.load_censor_words()

# Pick the compute device: CUDA when available, otherwise CPU.
# MPS is deliberately not considered because the Streamlit space does not seem
# to allow it. The earlier selection that also tried MPS was:
#   'cuda' if torch.cuda.is_available()
#   else 'mps' if torch.backends.mps.is_built() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class GPT2_Model(GPT2PreTrainedModel):
    """Pretrained GPT-2 backbone followed by a bias-free linear output head.

    The 'gpt2' backbone is loaded from its pretrained weights and its token
    embeddings are resized to the length of a tokenizer that carries an extra
    '<|pad|>' token. The output head maps hidden states to that same
    vocabulary size.
    """

    def __init__(self, config):
        """Build the backbone and the output head.

        Args:
            config: Model configuration; ``config.n_embd`` is used as the
                input width of the output head.
        """
        super().__init__(config)
        self.transformer = GPT2Model.from_pretrained('gpt2')
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
        vocab_size = len(tokenizer)
        # Resizing is necessary since a new unique token is added for
        # pad_token, which makes the vocabulary larger than the stock one.
        self.transformer.resize_token_embeddings(vocab_size)
        self.lm_head = torch.nn.Linear(config.n_embd, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the backbone and project its first output through the head.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Optional mask forwarded to the backbone.
            token_type_ids: Optional token type ids forwarded to the backbone.

        Returns:
            The output of ``lm_head`` applied to element ``[0]`` of the
            backbone's output.
        """
        hidden_states = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]
        return self.lm_head(hidden_states)
#Define generation function
def generate(idx, max_new_tokens, context_size, tokenizer, model, top_k=10, top_p=0.95):
    """Extend ``idx`` token by token using combined top-p / top-k sampling.

    Generation stops early as soon as the last token of ``idx`` equals the
    tokenizer's end-of-sequence token.

    Args:
        idx: Tensor of token ids. The end-of-sequence check calls ``.item()``
            on the last column, so a single sequence (one row) is expected.
        max_new_tokens: Maximum number of tokens to append.
        context_size: Number of trailing tokens fed to the model at each step.
        tokenizer: Object providing ``encode`` and ``eos_token``.
        model: Callable mapping token ids to logits; its result is indexed as
            ``logits[:, -1, :]``.
        top_k: Upper bound on the number of candidate tokens kept.
        top_p: Cumulative-probability threshold for the nucleus filter.

    Returns:
        ``idx`` with the sampled token ids appended along dimension 1.
    """
    # The end-of-sequence id does not change between steps; look it up once.
    eos_id = tokenizer.encode(tokenizer.eos_token)[0]

    for _ in range(max_new_tokens):
        if idx[:, -1].item() == eos_id:
            break

        # Crop the context to the last context_size tokens.
        idx_cond = idx[:, -context_size:]
        # Only token ids are returned, so no autograd graph is needed.
        with torch.no_grad():
            logits = model(idx_cond)
        # Focus only on the last time step and turn logits into probabilities.
        probs = F.softmax(logits[:, -1, :], dim=-1)

        # Sort probabilities in descending order and accumulate them.
        sorted_probs, indices = torch.sort(probs, descending=True)
        probs_cumsum = torch.cumsum(sorted_probs, dim=1)

        # Top-p: keep every token whose cumulative probability is still below
        # top_p, plus the one token that crosses the threshold.
        nucleus_size = int((probs_cumsum < top_p).sum().item()) + 1
        sorted_probs = sorted_probs[:, :nucleus_size]
        indices = indices[:, :nucleus_size]

        # Top-k: cap the number of remaining candidates.
        sorted_probs = sorted_probs[:, :top_k]
        indices = indices[:, :top_k]

        # NOTE(review): this applies softmax to values that are already
        # probabilities instead of dividing by their sum, which flattens the
        # distribution. Kept as-is because changing it would alter what the
        # sampler produces.
        sorted_probs = F.softmax(sorted_probs, dim=-1)

        # Sample a position among the candidates and map it to its token id.
        choice = torch.multinomial(sorted_probs, num_samples=1)
        idx_next = torch.gather(indices, 1, choice)

        # Append the new token id.
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
#Define capitalization functions
def custom_capitalize(match):
    """Return the first capture group of ``match`` with ``str.capitalize`` applied.

    Args:
        match: A regular-expression match object with at least one group.

    Returns:
        Group 1 with its first character upper-cased and the rest lower-cased.
    """
    # NOTE(review): presumably meant as an re.sub replacement callback; it is
    # not called anywhere in the visible part of this file.
    return match.group(1).capitalize()
def capitalize_string(input_string):
    """Apply the lyric capitalization rules to ``input_string``.

    Three passes are made, in this order:

    1. The first letter following a blank line ("\\n\\n") is upper-cased; any
       whitespace between the blank line and that letter is dropped.
    2. Every word inside square brackets is capitalized; runs of whitespace
       inside the brackets collapse to single spaces.
    3. Every stand-alone lowercase "i" becomes "I".

    Args:
        input_string: Text to capitalize.

    Returns:
        The capitalized text.
    """
    def _capitalize_bracket_words(match):
        # Rebuild "[...]" with each whitespace-separated word capitalized.
        words = match.group(1).split()
        return '[' + ' '.join(word.capitalize() for word in words) + ']'

    # Capitalize every first letter after "\n\n ".
    result = re.sub(
        r'\n\n\s*([a-zA-Z])',
        lambda m: '\n\n' + m.group(1).upper(),
        input_string,
    )
    # Capitalize every first letter in every word inside brackets ([ ]).
    result = re.sub(r'\[([^\]]*)\]', _capitalize_bracket_words, result)
    # Capitalize every instance of the letter i by itself.
    return re.sub(r'\bi\b', 'I', result)
#Format the string
def format_string(input_string):
    """Lay out generated lyrics so each bracketed tag sits on its own line.

    All existing newlines are removed, then a blank line is inserted after
    every ']' and before every '[', and finally the capitalization rules of
    ``capitalize_string`` are applied.

    Args:
        input_string: Raw generated lyric text.

    Returns:
        The re-spaced and capitalized text.
    """
    flattened = input_string.replace('\n', '')
    # Blank line after each closing bracket, then one before each opening one.
    result = flattened.replace(']', ']\n\n').replace('[', '\n\n[')
    # Capitalize for good measure.
    return capitalize_string(result)
@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Build the model and tokenizer once; later calls reuse the same objects."""
    configuration = GPT2Config()
    gpt_model = GPT2_Model(configuration).to(device)
    # The trained weights are fetched from the project's GitHub repository
    # rather than from a local 'GPT-Trained-Model-Prod.pt' file.
    state_dict = torch.hub.load_state_dict_from_url(
        r'https://github.com/Selbl/LyricGeneration/raw/main/GPT-Trained-Model-Prod.pt?download=',
        map_location=device,
    )
    gpt_model.load_state_dict(state_dict)
    gpt_model.eval()
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
    return gpt_model, tokenizer


def TextGeneration(prompt, prof=False, parts=True):
    """Generate lyrics that continue ``prompt``.

    Args:
        prompt: Seed text; it is lower-cased before being tokenized.
        prof: When False (the default) the result is passed through
            ``profanity.censor``; when True it is returned uncensored.
        parts: When True (the default) the prompt is prefixed with
            '[verse] ' and the output is laid out with ``format_string``.

    Returns:
        The generated lyric as a string.
    """
    # Building the model and loading its weights is expensive and identical on
    # every call, so it is done once and cached.
    gpt_model, tokenizer = _load_model_and_tokenizer()

    # Pre-process the prompt.
    prompt = prompt.lower()
    if parts:
        prompt = '[verse]' + ' ' + prompt

    # Encode the prompt as a single-row batch on the model's device.
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

    # Only token ids are needed, so gradient tracking is switched off.
    with torch.no_grad():
        sample_outputs = generate(
            generated,
            max_new_tokens=200,
            context_size=400,
            tokenizer=tokenizer,
            model=gpt_model,
            top_k=10,
            top_p=0.95,
        )

    # Decode the generated ids back into text.
    lyric = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)

    # Format if it has parts.
    if parts:
        lyric = format_string(lyric)
    # Remove profanity unless the caller asked to keep it.
    if not prof:
        lyric = profanity.censor(lyric)
    return lyric