# LyricGeneration / StreamlitModel.py
# NOTE: the lines below were Hugging Face file-viewer page chrome captured in
# the paste ("selbl's picture", "Update StreamlitModel.py", commit 3e231ed,
# "raw / history / blame", "No virus", "5.57 kB") — not part of the program.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 19 13:56:57 2024
@author: selbl
"""
#This code runs the lyric generation model
#Because it is a script, I do not spend time writing about what each part does
#For more info please refer to the Github script
import torch
from transformers import GPT2Tokenizer, GPT2Config, GPT2Model, GPT2PreTrainedModel
from torch.nn import functional as F
from better_profanity import profanity
import re
import torch.hub
# Populate the profanity filter's default word list once at import time;
# TextGeneration() uses it to censor generated lyrics.
profanity.load_censor_words()
#It seems the streamlit space does not allow mps
#device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_built() else 'cpu'
# Prefer GPU when present; fall back to CPU (MPS intentionally disabled, see above).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class GPT2_Model(GPT2PreTrainedModel):
    """Pretrained GPT-2 backbone with a linear LM head sized to a
    pad-token-augmented vocabulary."""

    def __init__(self, config):
        super().__init__(config)
        # Backbone: pretrained GPT-2 transformer weights.
        self.transformer = GPT2Model.from_pretrained('gpt2')
        # A tokenizer is instantiated here only to learn the vocabulary size
        # once the extra '<|pad|>' token has been registered.
        tok = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
        vocab_size = len(tok)
        # Grow the embedding table so the new pad token has an embedding row.
        self.transformer.resize_token_embeddings(vocab_size)
        # Project hidden states back onto the (extended) vocabulary.
        self.lm_head = torch.nn.Linear(config.n_embd, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return per-position vocabulary logits of shape (batch, seq, vocab)."""
        hidden = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]
        return self.lm_head(hidden)
#Define generation function
def generate(idx, max_new_tokens, context_size, tokenizer, model, top_k=10, top_p=0.95):
    """Autoregressively sample up to ``max_new_tokens`` token ids.

    Parameters
    ----------
    idx : torch.LongTensor of shape (1, T)
        Prompt token ids; sampled tokens are appended to it.
    max_new_tokens : int
        Maximum number of tokens to generate.
    context_size : int
        Number of trailing tokens fed to the model each step.
    tokenizer : object
        Needs ``.encode()`` and ``.eos_token`` (used only to detect EOS).
    model : callable
        Maps a (1, T) id tensor to (1, T, V) logits.
    top_k, top_p : int, float
        Top-k / nucleus truncation of the sampling distribution.

    Returns the extended ``idx``; generation stops early at the EOS token.
    """
    # Hoisted out of the loop: the EOS id never changes between iterations.
    eos_id = tokenizer.encode(tokenizer.eos_token)[0]
    for _ in range(max_new_tokens):
        # Stop as soon as the sequence already ends with EOS.
        if idx[:, -1].item() == eos_id:
            break
        # crop idx to the last context_size tokens
        idx_cond = idx[:, -context_size:]
        # get the predictions and focus only on the last time step
        logits = model(idx_cond)
        logits = logits[:, -1, :]
        # apply softmax to get probabilities, sorted in descending order
        probs = F.softmax(logits, dim=-1)
        sorted_probs, indices = torch.sort(probs, descending=True)
        # nucleus (top-p): keep the smallest prefix whose cumulative
        # probability reaches top_p (count computed once, not twice)
        probs_cumsum = torch.cumsum(sorted_probs, dim=1)
        keep = int((probs_cumsum < top_p).sum().item()) + 1
        sorted_probs, indices = sorted_probs[:, :keep], indices[:, :keep]
        # top-k truncation
        sorted_probs, indices = sorted_probs[:, :top_k], indices[:, :top_k]
        # BUG FIX: the original applied F.softmax to values that were already
        # probabilities, which exponentiates them again and flattens the
        # distribution toward uniform. Renormalizing by the sum preserves the
        # intended relative probabilities (torch.multinomial only requires
        # non-negative weights).
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        # sample from the truncated distribution and append the new token id
        idx_next = indices[:, torch.multinomial(sorted_probs, num_samples=1)].squeeze(0)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
#Define capitalization functions
def custom_capitalize(match):
    """Return the first capture group of a regex match object, capitalized."""
    captured = match.group(1)
    return captured.capitalize()
def capitalize_string(input_string):
    """Apply lyric-style capitalization rules and return the result.

    Rules: uppercase the first letter after a blank line (dropping any
    whitespace in between), title-case words inside square brackets, and
    uppercase the standalone pronoun "i".
    """
    # First letter following "\n\n" becomes uppercase; the intervening
    # whitespace is consumed by the pattern and not re-emitted.
    text = re.sub(
        r'\n\n\s*([a-zA-Z])',
        lambda m: '\n\n' + m.group(1).upper(),
        input_string,
    )
    # Capitalize every word inside [ ], e.g. "[verse one]" -> "[Verse One]".
    text = re.sub(
        r'\[([^\]]*)\]',
        lambda m: '[' + ' '.join(w.capitalize() for w in m.group(1).split()) + ']',
        text,
    )
    # The pronoun "i" standing alone is always uppercase.
    text = re.sub(r'\bi\b', 'I', text)
    return text
#Format the string
def format_string(input_string):
    """Reflow generated lyrics: blank lines around bracketed section tags,
    then re-capitalize via capitalize_string."""
    # Strip existing newlines, then emit a blank line after every ']'.
    flattened = input_string.replace('\n', '')
    pieces = []
    for ch in flattened:
        pieces.append(ch)
        if ch == ']':
            pieces.append('\n\n')
    result = ''.join(pieces)
    # A blank line before every '[' as well.
    result = result.replace('[', '\n\n[')
    #Capitalize for good measure
    return capitalize_string(result)
def TextGeneration(prompt, prof=False, parts=True):
    """Generate lyrics from *prompt* using the fine-tuned GPT-2 model.

    Parameters
    ----------
    prompt : str
        Seed text for the lyrics (lower-cased internally).
    prof : bool
        When False, profanity in the output is censored.
    parts : bool
        When True, a '[verse]' tag is prepended and section headers are
        reformatted in the output.

    Returns the generated lyric as a single string.
    """
    # Build the model and load the fine-tuned weights from disk.
    configuration = GPT2Config()
    gpt_model = GPT2_Model(configuration).to(device)
    gpt_model.load_state_dict(torch.load('GPT-Trained-Model.pt', map_location=device))
    gpt_model.eval()
    # Tokenizer must match the training-time vocabulary (extra pad token).
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
    # Pre-process the prompt.
    text = prompt.lower()
    if parts:
        text = '[verse]' + ' ' + text
    # Encode the prompt and sample a continuation.
    encoded = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(device)
    sample_outputs = generate(encoded,
                              max_new_tokens=200,
                              context_size=450,
                              tokenizer=tokenizer,
                              model=gpt_model,
                              top_k=10,
                              top_p=0.95)
    # Decode, then format and censor as requested.
    lyric = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    if parts:
        lyric = format_string(lyric)
    if not prof:
        lyric = profanity.censor(lyric)
    return lyric