Model Underlying the Program
StreamlitModel.py +146 -0
StreamlitModel.py
ADDED
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 19 13:56:57 2024

@author: selbl
"""

# This code runs the lyric generation model.
# Because it is a script, I do not explain what each part does in detail;
# for more info please refer to the GitHub script.

import re

import torch
import torch.hub
from torch import nn
from torch.nn import functional as F
from transformers import GPT2Tokenizer, GPT2Config, GPT2Model, GPT2PreTrainedModel
from better_profanity import profanity

profanity.load_censor_words()

# Prefer CUDA, then Apple Silicon (MPS), then CPU
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

class GPT2_Model(GPT2PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.transformer = GPT2Model.from_pretrained('gpt2')
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')

        # this is necessary since we add a new unique token for pad_token
        self.transformer.resize_token_embeddings(len(tokenizer))

        # language-modeling head: project hidden states to vocabulary logits
        self.lm_head = nn.Linear(config.n_embd, len(tokenizer), bias=False)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        # last hidden states of the transformer, then per-token logits
        x = self.transformer(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
        x = self.lm_head(x)
        return x

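# Shape sketch: for input_ids of shape (batch, seq_len), self.transformer
# returns hidden states of shape (batch, seq_len, n_embd), and lm_head maps
# them to logits of shape (batch, seq_len, vocab_size), where vocab_size is
# len(tokenizer) after adding the pad token.
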
# Define generation function (top-k + nucleus/top-p sampling)
def generate(idx, max_new_tokens, context_size, tokenizer, model, top_k=10, top_p=0.95):

    for _ in range(max_new_tokens):
        # stop once the model has emitted the end-of-sequence token
        if idx[:, -1].item() == tokenizer.encode(tokenizer.eos_token)[0]:
            break
        # crop idx to the last context_size tokens
        idx_cond = idx[:, -context_size:]
        # get the predictions
        logits = model(idx_cond)
        # focus only on the last time step
        logits = logits[:, -1, :]
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1)
        # sort probabilities in descending order
        sorted_probs, indices = torch.sort(probs, descending=True)
        # compute cumulative sum of probabilities
        probs_cumsum = torch.cumsum(sorted_probs, dim=1)
        # keep only the tokens whose cumulative probability stays below top_p (the nucleus)
        cutoff = (probs_cumsum < top_p).sum().item() + 1
        sorted_probs, indices = sorted_probs[:, :cutoff], indices[:, :cutoff]
        # keep only the top_k most likely of those tokens
        sorted_probs, indices = sorted_probs[:, :top_k], indices[:, :top_k]
        # renormalize the truncated distribution and sample from it
        sorted_probs = F.softmax(sorted_probs, dim=-1)
        idx_next = indices[:, torch.multinomial(sorted_probs, num_samples=1)].squeeze(0)
        # append the new token id
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

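# Worked example of the nucleus cutoff above: with sorted probabilities
# (0.5, 0.3, 0.1, 0.06, 0.04) and top_p = 0.95, the cumulative sums are
# (0.5, 0.8, 0.9, 0.96, 1.0); three of them fall below 0.95, so cutoff = 4
# and the four most likely tokens survive before the top_k cut is applied.
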
# Define capitalization functions
def custom_capitalize(match):
    return match.group(1).capitalize()

def capitalize_string(input_string):
    # Capitalize the first letter after every "\n\n"
    result = re.sub(r'\n\n\s*([a-zA-Z])', lambda x: '\n\n' + x.group(1).upper(), input_string)
    # Capitalize the first letter of every word inside brackets ([ ])
    result = re.sub(r'\[([^\]]*)\]', lambda x: '[' + ' '.join(word.capitalize() for word in x.group(1).split()) + ']', result)
    # Capitalize every instance of the letter i by itself
    result = re.sub(r'\bi\b', 'I', result)
    return result

# Format the string: put each [section] tag on its own line, then capitalize
def format_string(input_string):
    result = ''.join([char + ('\n\n' if char == ']' else '') for char in input_string.replace('\n', '')])
    result = result.replace('[', '\n\n[')
    # Capitalize for good measure
    result = capitalize_string(result)
    return result

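# Worked example: format_string('[verse] hello i am here') yields
# '\n\n[Verse]\n\nHello I am here' -- the section tag is moved onto its own
# line, the first letter after it is capitalized, and bare 'i' becomes 'I'.
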
def TextGeneration(prompt, prof=False, parts=True):
    # Load the model
    configuration = GPT2Config()
    gpt_model = GPT2_Model(configuration).to(device)
    # gpt_model.load_state_dict(torch.load('GPT-Trained-Model-Prod.pt'))
    state_dict = torch.hub.load_state_dict_from_url(
        r'https://github.com/Selbl/LyricGeneration/raw/main/GPT-Trained-Model-Prod.pt?download=',
        map_location=device)
    gpt_model.load_state_dict(state_dict)
    # Load the tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
    # Set evaluation mode
    gpt_model.eval()
    # Section tag prepended to the prompt when generating with song parts
    prefix = '[verse]'

    # Normalize the input prompt
    prompt = prompt.lower()

    # Pre-process the prompt
    if parts:
        prompt = prefix + ' ' + prompt
    # Encode the prompt and generate
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    generated = generated.to(device)

    sample_outputs = generate(generated,
                              max_new_tokens=200,
                              context_size=400,
                              tokenizer=tokenizer,
                              model=gpt_model,
                              top_k=10,
                              top_p=0.95)

    # Decode the generated lyrics
    lyric = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    # Format if the lyrics have song parts
    if parts:
        lyric = format_string(lyric)
    # Censor profanity unless it is explicitly allowed
    if not prof:
        lyric = profanity.censor(lyric)
    return lyric
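
For context, a minimal sketch of how TextGeneration might be wired into the Space's Streamlit front end. The page title, widget labels, and default prompt below are illustrative assumptions; the actual app file is not part of this diff.

import streamlit as st
from StreamlitModel import TextGeneration  # assumes this file is importable as StreamlitModel

st.title('Lyric Generator')  # hypothetical page title
prompt = st.text_input('Prompt', 'dancing in the moonlight')  # hypothetical label and default
prof = st.checkbox('Allow profanity', value=False)
parts = st.checkbox('Tag song parts (e.g. [verse])', value=True)

if st.button('Generate'):
    with st.spinner('Generating lyrics...'):
        lyric = TextGeneration(prompt, prof=prof, parts=parts)
    st.text(lyric)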