selbl commited on
Commit
0f1198e
1 Parent(s): 85d7874

Model Underlying the Program

Browse files
Files changed (1) hide show
  1. StreamlitModel.py +146 -0
StreamlitModel.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Mon Feb 19 13:56:57 2024
5
+
6
+ @author: selbl
7
+ """
8
+
9
+ #This code runs the lyric generation model
10
+ #Because it is a script, I do not spend time writing about what each part does
11
+ #For more info please refer to the Github script
12
+
13
+ import torch
14
+ import numpy as np
15
+ from torch import nn
16
+ from transformers import GPT2Tokenizer, GPT2Config, GPT2Model, GPT2PreTrainedModel
17
+ from torch.optim import AdamW
18
+ from datasets import load_dataset
19
+ from tqdm import tqdm
20
+ from torch.nn import functional as F
21
+ import pandas as pd
22
+ from better_profanity import profanity
23
+ import re
24
+ import torch.hub
25
+
26
+ profanity.load_censor_words()
27
+
28
+ device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_built() else 'cpu'
29
+
30
+ class GPT2_Model(GPT2PreTrainedModel):
31
+
32
+ def __init__(self, config):
33
+
34
+ super().__init__(config)
35
+
36
+ self.transformer = GPT2Model.from_pretrained('gpt2')
37
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
38
+
39
+ # this is necessary since we add a new unique token for pad_token
40
+ self.transformer.resize_token_embeddings(len(tokenizer))
41
+
42
+ self.lm_head = nn.Linear(config.n_embd, len(tokenizer), bias=False)
43
+
44
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None):
45
+
46
+ x = self.transformer(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
47
+ x = self.lm_head(x)
48
+
49
+ return x
50
+
51
+
52
+ #Dedfine generation function
53
+ def generate(idx, max_new_tokens, context_size, tokenizer, model, top_k=10, top_p=0.95):
54
+
55
+ for _ in range(max_new_tokens):
56
+ if idx[:,-1].item() != tokenizer.encode(tokenizer.eos_token)[0]:
57
+ # crop idx to the last block_size tokens
58
+ idx_cond = idx[:, -context_size:]
59
+ # get the predictions
60
+ logits = model(idx_cond)
61
+ # focus only on the last time step
62
+ logits = logits[:, -1, :]
63
+ # apply softmax to get probabilities
64
+ probs = F.softmax(logits, dim=-1)
65
+ # sort probabilities in descending order
66
+ sorted_probs, indices = torch.sort(probs, descending=True)
67
+ # compute cumsum of probabilities
68
+ probs_cumsum = torch.cumsum(sorted_probs, dim=1)
69
+ # choose only top_p tokens
70
+ sorted_probs, indices = sorted_probs[:, :probs_cumsum[[probs_cumsum < top_p]].size()[0] + 1], indices[:, :probs_cumsum[[probs_cumsum < top_p]].size()[0] +1]
71
+ # choose only top_k tokens
72
+ sorted_probs, indices = sorted_probs[:,:top_k], indices[:,:top_k]
73
+ # sample from the distribution
74
+ sorted_probs = F.softmax(sorted_probs, dim=-1)
75
+ idx_next = indices[:, torch.multinomial(sorted_probs, num_samples=1)].squeeze(0)
76
+ # append new token ids
77
+ idx = torch.cat((idx, idx_next), dim=1)
78
+ else:
79
+ break
80
+
81
+ return idx
82
+
83
+ #Define capitalization functions
84
+ def custom_capitalize(match):
85
+ return match.group(1).capitalize()
86
+
87
+ def capitalize_string(input_string):
88
+ # Capitalize every first letter after "\n\n "
89
+ result = re.sub(r'\n\n\s*([a-zA-Z])', lambda x: '\n\n' + x.group(1).upper(), input_string)
90
+ # Capitalize every first letter in every word inside brackets ([ ])
91
+ result = re.sub(r'\[([^\]]*)\]', lambda x: '[' + ' '.join(word.capitalize() for word in x.group(1).split()) + ']', result)
92
+ # Capitalize every instance of the letter i by itself
93
+ result = re.sub(r'\bi\b', 'I', result)
94
+ return result
95
+
96
+ #Format the string
97
+ def format_string(input_string):
98
+ result = ''.join([char + ('\n\n' if char == ']' else '') for char in input_string.replace('\n', '')])
99
+ result = result.replace('[', '\n\n[')
100
+ #Capitalize for good measure
101
+ result = capitalize_string(result)
102
+ return result
103
+
104
+
105
+ def TextGeneration(prompt,prof=False,parts=True):
106
+ #Load everything
107
+ #Load model
108
+ configuration = GPT2Config()
109
+ gpt_model = GPT2_Model(configuration).to(device)
110
+ #gpt_model.load_state_dict(torch.load('GPT-Trained-Model-Prod.pt'))
111
+ state_dict = torch.hub.load_state_dict_from_url(r'https://github.com/Selbl/LyricGeneration/raw/main/GPT-Trained-Model-Prod.pt?download=')
112
+ gpt_model.load_state_dict(state_dict)
113
+ #Load tokenizer
114
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|pad|>')
115
+ #Call loader function
116
+ gpt_model.eval()
117
+ #Set model parameters
118
+ prefix = '[verse]'
119
+
120
+ #Input prompt
121
+ prompt = prompt.lower()
122
+
123
+ #Pre-process prompt
124
+ if parts:
125
+ prompt = prefix + ' ' + prompt
126
+ #Generate
127
+ generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
128
+ generated = generated.to(device)
129
+
130
+ sample_outputs = generate(generated,
131
+ max_new_tokens=200,
132
+ context_size=400,
133
+ tokenizer=tokenizer,
134
+ model=gpt_model,
135
+ top_k=10,
136
+ top_p=0.95)
137
+
138
+ #Store the lyrics and decode
139
+ lyric = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
140
+ #Format if it has parts
141
+ if parts:
142
+ lyric = format_string(lyric)
143
+ #Remove profanity
144
+ if not prof:
145
+ lyric = profanity.censor(lyric)
146
+ return lyric