#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 19 13:56:57 2024
@author: selbl
"""
#This script runs the lyric generation model
#Because it is a script, I keep the commentary on each part brief
#For more info please refer to the script on GitHub
import torch
from transformers import GPT2Tokenizer, GPT2Config, GPT2Model, GPT2PreTrainedModel
from torch.nn import functional as F
from better_profanity import profanity
import re
import torch.hub
profanity.load_censor_words()
#It seems the Streamlit space does not allow mps
#device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_built() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class GPT2_Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model.from_pretrained('gpt2')
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
        # this is necessary since we add a new unique token for pad_token
        self.transformer.resize_token_embeddings(len(tokenizer))
        self.lm_head = torch.nn.Linear(config.n_embd, len(tokenizer), bias=False)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # hidden states from the GPT-2 backbone: (batch, seq_len, n_embd)
        x = self.transformer(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
        # project hidden states to vocabulary logits: (batch, seq_len, vocab_size)
        x = self.lm_head(x)
        return x
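#Quick shape check (illustrative sketch, assuming the GPT-2 defaults above;
#the example prompt is hypothetical):
#   tok = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
#   ids = torch.tensor(tok.encode('[verse] hello')).unsqueeze(0)
#   GPT2_Model(GPT2Config())(ids).shape  ->  (1, seq_len, len(tok))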
#Define generation function
def generate(idx, max_new_tokens, context_size, tokenizer, model, top_k=10, top_p=0.95):
    for _ in range(max_new_tokens):
        # stop once the model has emitted the end-of-sequence token
        if idx[:, -1].item() == tokenizer.encode(tokenizer.eos_token)[0]:
            break
        # crop idx to the last context_size tokens
        idx_cond = idx[:, -context_size:]
        # get the predictions
        logits = model(idx_cond)
        # focus only on the last time step
        logits = logits[:, -1, :]
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1)
        # sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, descending=True)
        # compute the cumulative sum of probabilities
        probs_cumsum = torch.cumsum(sorted_probs, dim=1)
        # keep only the nucleus of tokens whose cumulative probability stays below top_p
        cutoff = (probs_cumsum < top_p).sum().item() + 1
        sorted_probs, indices = sorted_probs[:, :cutoff], indices[:, :cutoff]
        # keep only the top_k tokens
        sorted_probs, indices = sorted_probs[:, :top_k], indices[:, :top_k]
        # re-normalize the kept probabilities with a softmax and sample from them
        sorted_probs = F.softmax(sorted_probs, dim=-1)
        idx_next = indices[:, torch.multinomial(sorted_probs, num_samples=1)].squeeze(0)
        # append the new token id
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
#Define capitalization functions
def custom_capitalize(match):
    return match.group(1).capitalize()

def capitalize_string(input_string):
    # Capitalize every first letter after "\n\n "
    result = re.sub(r'\n\n\s*([a-zA-Z])', lambda x: '\n\n' + x.group(1).upper(), input_string)
    # Capitalize every first letter in every word inside brackets ([ ])
    result = re.sub(r'\[([^\]]*)\]', lambda x: '[' + ' '.join(word.capitalize() for word in x.group(1).split()) + ']', result)
    # Capitalize every instance of the letter i by itself
    result = re.sub(r'\bi\b', 'I', result)
    return result
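#Illustrative check of capitalize_string (the example input is hypothetical):
#   capitalize_string('\n\n[verse] i walk alone\n\nand i sing')
#   -> '\n\n[Verse] I walk alone\n\nAnd I sing'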
#Format the string
def format_string(input_string):
    # strip existing newlines, then insert a blank line after every closing section tag
    result = ''.join([char + ('\n\n' if char == ']' else '') for char in input_string.replace('\n', '')])
    result = result.replace('[', '\n\n[')
    #Capitalize for good measure
    result = capitalize_string(result)
    return result
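#Illustrative check of format_string (the example input is hypothetical):
#   format_string('[verse] i walk alone [chorus] i sing')
#   -> '\n\n[Verse]\n\nI walk alone \n\n[Chorus]\n\nI sing'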
def TextGeneration(prompt, prof=False, parts=True):
    #Load everything
    #Load model
    configuration = GPT2Config()
    gpt_model = GPT2_Model(configuration).to(device)
    gpt_model.load_state_dict(torch.load('GPT-Trained-Model.pt', map_location=device))
    #state_dict = torch.hub.load_state_dict_from_url(r'https://github.com/Selbl/LyricGeneration/raw/main/GPT-Trained-Model-Prod.pt?download=',map_location=device)
    #state_dict = torch.hub.load_state_dict_from_url(r'https://huggingface.co/spaces/selbl/LyricGeneration/resolve/main/GPT-Trained-Model',map_location=device)
    #gpt_model.load_state_dict(state_dict)
    #Load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
    #Set the model to evaluation mode
    gpt_model.eval()
    #Set model parameters
    prefix = '[verse]'
    #Lowercase the input prompt
    prompt = prompt.lower()
    #Pre-process the prompt: prepend the section tag when generating with parts
    if parts:
        prompt = prefix + ' ' + prompt
    #Generate
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    generated = generated.to(device)
    sample_outputs = generate(generated,
                              max_new_tokens=200,
                              context_size=450,
                              tokenizer=tokenizer,
                              model=gpt_model,
                              top_k=10,
                              top_p=0.95)
    #Store the lyrics and decode
    lyric = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    #Format if it has parts
    if parts:
        lyric = format_string(lyric)
    #Remove profanity unless explicitly allowed
    if not prof:
        lyric = profanity.censor(lyric)
    return lyric
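#Example usage (illustrative sketch; assumes GPT-Trained-Model.pt is present
#in the working directory, as in the torch.load call above):
if __name__ == '__main__':
    lyrics = TextGeneration('love in the dark', prof=False, parts=True)
    print(lyrics)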