File size: 5,390 Bytes
0f1198e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a37fb4
d508cb1
 
0f1198e
 
 
 
 
 
 
 
 
 
 
 
 
97a0dbb
0f1198e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae40d11
835c0d8
0f1198e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 19 13:56:57 2024

@author: selbl
"""

#This script runs the lyric generation model
#Because it is a script, it does not describe each part in detail
#For more information, please refer to the GitHub repository (Selbl/LyricGeneration)

import torch
from transformers import GPT2Tokenizer, GPT2Config, GPT2Model, GPT2PreTrainedModel
from torch.nn import functional as F
from better_profanity import profanity
import re
import torch.hub

# Initialise better_profanity's censor word list once, at import time, so the
# later profanity.censor(...) call in this file has a list to work with.
profanity.load_censor_words()    

# Device used for the model and the input tensors: CUDA when PyTorch reports
# it as available, otherwise the CPU.
# MPS is intentionally not selected: it seems the streamlit space does not allow mps.
# Previous selection, kept for reference:
#device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_built() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
class GPT2_Model(GPT2PreTrainedModel):
    """Pretrained GPT-2 backbone followed by a bias-free linear output head.

    The backbone is loaded from the ``'gpt2'`` checkpoint and its token
    embeddings are resized to the length of a GPT-2 tokenizer that carries an
    extra ``<|pad|>`` token. The output head projects ``config.n_embd``
    features onto that same vocabulary length.
    """

    def __init__(self, config):
        """Build the backbone and the output head from ``config``."""
        super().__init__(config)

        self.transformer = GPT2Model.from_pretrained('gpt2')

        # The tokenizer is only needed here for its length: adding the pad
        # token grows the vocabulary, so embeddings and head must match it.
        vocab_size = len(GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>'))
        self.transformer.resize_token_embeddings(vocab_size)

        self.lm_head = torch.nn.Linear(config.n_embd, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return the output-head scores for every position of ``input_ids``.

        The first element of the backbone's output is fed through
        ``lm_head`` and the result is returned unchanged.
        """
        backbone_output = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.lm_head(backbone_output[0])
        

#Define generation function
def generate(idx, max_new_tokens, context_size, tokenizer, model, top_k=10, top_p=0.95):
    """Extend ``idx`` token by token using combined top-p / top-k sampling.

    Generation stops after ``max_new_tokens`` new tokens, or earlier when the
    last token of ``idx`` is the tokenizer's end-of-sequence token.

    Args:
        idx: Integer tensor of token ids with shape ``(1, T)``. Only a single
            sequence is supported because the last token is read with
            ``.item()``.
        max_new_tokens: Upper bound on the number of tokens to append.
        context_size: Number of trailing tokens passed to ``model`` each step.
        tokenizer: Object exposing ``eos_token`` and ``encode``.
        model: Callable mapping a ``(1, t)`` id tensor to ``(1, t, vocab)``
            scores.
        top_k: Maximum number of candidate tokens kept per step.
        top_p: Cumulative-probability threshold for the candidate set.

    Returns:
        The input tensor with the sampled token ids appended along dim 1.
    """
    # Loop-invariant: resolve the end-of-sequence id once instead of
    # re-encoding the eos string on every step.
    eos_id = tokenizer.encode(tokenizer.eos_token)[0]

    # Pure inference: without no_grad every step would record an autograd
    # graph that is never used.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if idx[:, -1].item() == eos_id:
                break

            # crop idx to the last context_size tokens
            idx_cond = idx[:, -context_size:]
            # scores for the last time step only
            logits = model(idx_cond)[:, -1, :]
            probs = F.softmax(logits, dim=-1)

            # candidates ordered from most to least likely
            sorted_probs, indices = torch.sort(probs, descending=True)
            probs_cumsum = torch.cumsum(sorted_probs, dim=-1)

            # top-p: keep every token whose cumulative mass is still below
            # top_p, plus the one that crosses the threshold
            nucleus_size = int((probs_cumsum < top_p).sum().item()) + 1
            sorted_probs, indices = sorted_probs[:, :nucleus_size], indices[:, :nucleus_size]
            # top-k: never keep more than top_k candidates
            sorted_probs, indices = sorted_probs[:, :top_k], indices[:, :top_k]

            # Renormalise the surviving probabilities so they sum to one.
            # (A second softmax here would treat probabilities as logits and
            # flatten the distribution towards uniform.)
            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

            # sample a position among the candidates and map it back to a
            # vocabulary id; the result has shape (1, 1)
            choice = torch.multinomial(sorted_probs, num_samples=1)
            idx_next = torch.gather(indices, 1, choice)

            # append the new token id
            idx = torch.cat((idx, idx_next), dim=1)

    return idx

#Define capitalization functions
def custom_capitalize(match):
    """Return the first captured group of ``match`` with ``str.capitalize`` applied."""
    captured = match.group(1)
    return captured.capitalize()

def capitalize_string(input_string):
    """Apply the lyric capitalization rules and return the new string.

    Three passes run in order:
      1. The first letter following a blank line is upper-cased; any
         whitespace between the two newlines and that letter is dropped.
      2. Every word inside square brackets is capitalized and the words are
         re-joined with single spaces.
      3. Every stand-alone lower-case ``i`` becomes ``I``.
    """
    def _upper_after_blank_line(found):
        return '\n\n' + found.group(1).upper()

    def _capitalize_bracket_words(found):
        words = [piece.capitalize() for piece in found.group(1).split()]
        return '[' + ' '.join(words) + ']'

    text = re.sub(r'\n\n\s*([a-zA-Z])', _upper_after_blank_line, input_string)
    text = re.sub(r'\[([^\]]*)\]', _capitalize_bracket_words, text)
    return re.sub(r'\bi\b', 'I', text)

#Format the string
def format_string(input_string):
    """Lay out bracketed section tags on their own lines, then capitalize.

    All existing newlines are removed, a blank line is inserted after every
    ``]`` and before every ``[``, and the result is passed through
    ``capitalize_string``.
    """
    single_line = input_string.replace('\n', '')
    spaced = single_line.replace(']', ']\n\n').replace('[', '\n\n[')
    #Capitalize for good measure
    return capitalize_string(spaced)


# Process-wide cache: building the model, fetching the checkpoint and loading
# the tokenizer are expensive, so they are done on first use only.
_GENERATION_ASSETS = {}


def _load_generation_assets():
    """Return the ``(model, tokenizer)`` pair, creating it on the first call."""
    if 'model' not in _GENERATION_ASSETS:
        gpt_model = GPT2_Model(GPT2Config()).to(device)
        state_dict = torch.hub.load_state_dict_from_url(r'https://github.com/Selbl/LyricGeneration/raw/main/GPT-Trained-Model-Prod.pt?download=',map_location=device)
        gpt_model.load_state_dict(state_dict)
        gpt_model.eval()
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
        # Publish both entries only after every step succeeded, so a failed
        # load never leaves a half-filled cache behind.
        _GENERATION_ASSETS['tokenizer'] = tokenizer
        _GENERATION_ASSETS['model'] = gpt_model
    return _GENERATION_ASSETS['model'], _GENERATION_ASSETS['tokenizer']


def TextGeneration(prompt,prof=False,parts=True):
    """Generate lyrics that continue ``prompt``.

    Args:
        prompt: Seed text. It is lower-cased before being tokenized.
        prof: When False (the default) the result is passed through
            ``profanity.censor`` before being returned.
        parts: When True (the default) the prompt is prefixed with
            ``'[verse] '`` and the output is laid out with ``format_string``.

    Returns:
        The generated lyric as a string.

    Raises:
        ValueError: If the prompt tokenizes to nothing (only possible when
            ``parts`` is False), since generation needs at least one token.
    """
    #Load (or reuse) the model and tokenizer
    gpt_model, tokenizer = _load_generation_assets()

    #Set model parameters
    prefix = '[verse]'

    #Input prompt
    prompt = prompt.lower()

    #Pre-process prompt
    if parts:
        prompt = prefix + ' ' + prompt

    token_ids = tokenizer.encode(prompt)
    if not token_ids:
        # An empty id list would become an empty float tensor and fail deep
        # inside generate() with an unhelpful IndexError.
        raise ValueError('prompt must contain at least one token')

    #Generate
    generated = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
    generated = generated.to(device)

    # Inference only: make sure no autograd graph is recorded
    with torch.no_grad():
        sample_outputs = generate(generated,
                                  max_new_tokens=200,
                                  context_size=400,
                                  tokenizer=tokenizer,
                                  model=gpt_model,
                                  top_k=10,
                                  top_p=0.95)

    #Store the lyrics and decode
    lyric = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    #Format if it has parts
    if parts:
        lyric = format_string(lyric)
    #Remove profanity
    if not prof:
        lyric = profanity.censor(lyric)
    return lyric