hoyoGPT / utils.py
MuGeminorum
upl base
653dc95
raw
history blame
No virus
12.8 kB
import os
import re
import torch
import random
from config import *
from tqdm import tqdm
from unidecode import unidecode
from torch.utils.data import Dataset
from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def template(region):
    """
    Return the fixed prompt text with ``region`` filled into the ``A:`` line.

    :param region: the value interpolated after ``A:`` on the first line
    :return: the header lines followed by the opening ``de |"D"`` fragment,
        joined with newlines (no trailing newline)
    """
    prompt_lines = (
        f'A:{region}',
        'S:2',
        'B:9',
        'E:4',
        'B:9',
        'L:1/8',
        'M:3/4',
        'K:D',
        'de |"D"',
    )
    return '\n'.join(prompt_lines)
def create_dir(dir_path):
    """
    Create ``dir_path`` (including missing parents) if it does not exist yet.

    Calling this on an existing directory is a no-op. ``exist_ok=True`` is
    used instead of a separate existence check so that a directory created
    by another process between the check and the creation does not raise.

    :param dir_path: path of the directory to create
    :raises FileExistsError: if ``dir_path`` exists but is not a directory
    """
    os.makedirs(dir_path, exist_ok=True)
def download(filename=WEIGHT_PATH, url=WEIGHT_URL, max_retries=3):
    """
    Download the model weights to ``filename`` while showing a progress bar.

    The first attempt uses ``url``. If it fails, the error is printed, the
    function waits three seconds and retries against ``ZH_WEIGHT_URL``.

    :param filename: destination path of the downloaded file
    :param url: URL tried on the first attempt
    :param max_retries: number of additional attempts (against
        ``ZH_WEIGHT_URL``) allowed after the first failure
    :raises requests.exceptions.RequestException: when every attempt failed
    """
    import time
    import requests

    chunk_size = 1024
    failures = 0
    while True:
        try:
            # The context manager releases the connection even on errors.
            with requests.get(url, stream=True, timeout=30) as response:
                # Without this check an HTTP error page would be written to
                # disk as if it were the weights file.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(filename, 'wb') as file, tqdm(
                    desc=f"Downloading weights to '{filename}'...",
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(chunk_size=chunk_size):
                        bar.update(file.write(data))
            return
        # requests' own ConnectionError does not derive from the builtin
        # ConnectionError, so both families have to be listed here.
        except (ConnectionError, requests.exceptions.RequestException) as e:
            print(f"Error: {e}")
            failures += 1
            if failures > max_retries:
                # Give up instead of retrying forever.
                raise
            time.sleep(3)
            url = ZH_WEIGHT_URL
class Patchilizer:
    """
    A class for converting music bars to patches and vice versa.

    A patch is a fixed-length list of integer token ids: a BOS id, the
    character codes of one bar (or one header line), an EOS id and padding.
    """

    def __init__(self):
        # Multi-character delimiters are listed before the bare "|" so the
        # regex alternation tries them first.
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2

    def split_bars(self, body):
        """
        Split a body of music into individual bars.

        Each returned bar is the text of the bar followed by the delimiter
        that closes it. A delimiter at the very start is attached to the
        first bar. Text after the last delimiter is not returned.

        :param body: the music body (a string or an iterable of strings)
        :return: the list of bars; empty when no bar can be formed
        """
        # The capturing group keeps the delimiters in the split result;
        # empty strings produced by the split are dropped.
        parts = [part for part in re.split(self.regexPattern, ''.join(body)) if part]
        # Guard on the length so that an empty body, or a body consisting
        # of a single delimiter, does not raise IndexError.
        if len(parts) > 1 and parts[0] in self.delimiters:
            parts[1] = parts[0] + parts[1]
            parts = parts[1:]
        # Pair every piece of text with the delimiter that follows it.
        return [parts[i * 2] + parts[i * 2 + 1] for i in range(len(parts) // 2)]

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        """
        Convert a bar into a patch of specified length.

        :param bar: the text to convert
        :param patch_size: the exact length of the returned patch
        :return: BOS + character codes + EOS, truncated to ``patch_size``
            and right-padded with the pad id
        """
        patch = [self.bos_token_id] + \
            [ord(c) for c in bar] + [self.eos_token_id]
        # Truncation may cut off the EOS id when the bar is too long.
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch

    def patch2bar(self, patch):
        """
        Convert a patch into a bar.

        :param patch: an iterable of token ids
        :return: the characters of every id above the EOS id; the pad, BOS
            and EOS ids produce no output
        """
        return ''.join(chr(idx) for idx in patch if idx > self.eos_token_id)

    def encode(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=False):
        """
        Encode music into patches of specified length.

        Lines that look like ``X:...`` fields (a letter followed by a colon)
        or start with ``%%score`` become one patch each; all other lines are
        accumulated as body text and split into one patch per bar.

        :param abc_code: the music text to encode
        :param patch_length: the maximum number of patches returned
        :param patch_size: the length of each patch
        :param add_special_patches: whether to surround the result with a
            BOS patch and an EOS patch
        :return: a list of at most ``patch_length`` patches
        """
        lines = unidecode(abc_code).split('\n')
        lines = list(filter(None, lines))  # remove empty lines

        body = ""
        patches = []

        for line in lines:
            is_field = len(line) > 1 and (
                (line[0].isalpha() and line[1] == ':') or line.startswith('%%score')
            )
            if not is_field:
                body += line + '\n'
                continue

            if body:
                # Flush the body collected so far; the newline is restored
                # on its last bar only.
                bars = self.split_bars(body)
                last = len(bars) - 1
                patches.extend(
                    self.bar2patch(bar + '\n' if idx == last else bar, patch_size)
                    for idx, bar in enumerate(bars)
                )
                body = ""
            patches.append(self.bar2patch(line + '\n', patch_size))

        if body:
            patches.extend(
                self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
            )

        if add_special_patches:
            bos_patch = [self.bos_token_id] * \
                (patch_size - 1) + [self.eos_token_id]
            eos_patch = [self.bos_token_id] + \
                [self.eos_token_id] * (patch_size - 1)
            patches = [bos_patch] + patches + [eos_patch]

        return patches[:patch_length]

    def decode(self, patches):
        """
        Decode patches into music.

        :param patches: an iterable of patches
        :return: the concatenation of every decoded patch
        """
        return ''.join(self.patch2bar(patch) for patch in patches)
class PatchLevelDecoder(PreTrainedModel):
    """
    A patch-level decoder that turns a sequence of patches into a sequence
    of patch features using a GPT-2 backbone.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        # Each patch is fed in as PATCH_SIZE one-hot vectors of width 128,
        # flattened into a single vector.
        flat_patch_dim = PATCH_SIZE * 128
        self.patch_embedding = torch.nn.Linear(flat_patch_dim, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """
        Embed the patches and run them through the GPT-2 backbone.
        :param patches: integer token ids of the patches to be encoded
        :return: whatever ``self.base`` returns for the embedded patches
            (callers in this file read its ``"last_hidden_state"`` entry)
        """
        batch = patches.shape[0]
        one_hot = torch.nn.functional.one_hot(patches, num_classes=128).float()
        # Merge the per-character one-hot vectors of each patch into one row.
        flattened = one_hot.reshape(batch, -1, PATCH_SIZE * 128)
        embedded = self.patch_embedding(flattened.to(self.device))
        return self.base(inputs_embeds=embedded)
class CharLevelDecoder(PreTrainedModel):
    """
    A char-level decoder that produces the characters inside each bar patch
    one after another, conditioned on the feature of that patch.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.base = GPT2LMHeadModel(config)

    def forward(self, encoded_patches: torch.Tensor, target_patches: torch.Tensor, patch_sampling_batch_size: int):
        """
        Run the char-level decoder on target patches for training.
        :param encoded_patches: one feature vector per target patch
        :param target_patches: the token ids of the patches to predict
        :param patch_sampling_batch_size: when non-zero and smaller than the
            number of patches, only that many randomly chosen patches are used
        :return: the output of the underlying language model, called with labels
        """
        # Padding positions are replaced by -100 in the labels (presumably
        # the ignore index of the language-model loss; confirm against
        # GPT2LMHeadModel) and get a 0 in the attention mask.
        is_padding = target_patches == self.pad_token_id
        labels = target_patches.clone().masked_fill_(is_padding, -100)
        attention = torch.ones_like(labels).masked_fill_(labels == -100, 0)

        # Optionally keep a random subset of the patches, in original order.
        patch_count = target_patches.shape[0]
        if patch_sampling_batch_size != 0 and patch_sampling_batch_size < patch_count:
            order = list(range(len(target_patches)))
            random.shuffle(order)
            keep = sorted(order[:patch_sampling_batch_size])

            target_patches = target_patches[keep, :]
            attention = attention[keep, :]
            encoded_patches = encoded_patches[keep, :]
            labels = labels[keep, :]

        # Look up the token embeddings in the language model's own table.
        char_embeds = torch.nn.functional.embedding(
            target_patches,
            self.base.transformer.wte.weight
        )

        # The patch feature takes the place of the first token's embedding.
        patch_feature = encoded_patches.unsqueeze(1)
        inputs_embeds = torch.cat((patch_feature, char_embeds[:, 1:, :]), dim=1)

        return self.base(
            inputs_embeds=inputs_embeds,
            attention_mask=attention,
            labels=labels
        )

    def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
        """
        Compute the distribution of the next token of a patch.
        :param encoded_patch: the feature of the patch being generated
        :param tokens: the tokens already generated in the patch
        :return: the probability distribution of the next token
        """
        patch_feature = encoded_patch.reshape(1, 1, -1)
        token_ids = tokens.reshape(1, -1)

        # Look up the token embeddings in the language model's own table.
        token_embeds = torch.nn.functional.embedding(
            token_ids,
            self.base.transformer.wte.weight
        )

        # The patch feature takes the place of the first token's embedding.
        model_input = torch.cat((patch_feature, token_embeds[:, 1:, :]), dim=1)
        outputs = self.base(inputs_embeds=model_input)

        # Only the logits of the last position describe the next token.
        last_logits = outputs.logits.squeeze(0)[-1]
        return torch.nn.functional.softmax(last_logits, dim=-1)
class TunesFormer(PreTrainedModel):
    """
    TunesFormer is a hierarchical music generation model based on bar patching.
    It includes a patch-level decoder and a character-level decoder.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, encoder_config, decoder_config, share_weights=False):
        """
        :param encoder_config: the configuration of the patch-level decoder
        :param decoder_config: the configuration of the char-level decoder
        :param share_weights: when true, both configurations are widened to
            their common maximum (modified in place) and the patch-level
            decoder reuses the transformer of the char-level decoder
        """
        super().__init__(encoder_config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2

        if share_weights:
            # Both decoders must agree on these sizes to share one backbone.
            for attr in ("num_hidden_layers", "max_length", "max_position_embeddings"):
                shared_value = max(
                    getattr(encoder_config, attr),
                    getattr(decoder_config, attr)
                )
                setattr(encoder_config, attr, shared_value)
                setattr(decoder_config, attr, shared_value)

        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

        if share_weights:
            self.patch_level_decoder.base = self.char_level_decoder.base.transformer

    def forward(self, patches: torch.Tensor, patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE):
        """
        The forward pass of the TunesFormer model.
        :param patches: the patches to be both encoded and decoded
        :param patch_sampling_batch_size: forwarded to the char-level decoder
        :return: the decoded patches
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(
            patches)["last_hidden_state"]

        # The feature of patch i is used to predict patch i + 1, hence the
        # shift by one between the features and the targets.
        return self.char_level_decoder(
            encoded_patches.squeeze(0)[:-1, :],
            patches.squeeze(0)[1:, :],
            patch_sampling_batch_size
        )

    def generate(
        self,
        patches: torch.Tensor,
        tokens: torch.Tensor,
        top_p: float = 1,
        top_k: int = 0,
        temperature: float = 1,
        seed: int = None
    ):
        """
        The generate function for generating the next patch based on patches.
        :param patches: the patches to be encoded
        :param tokens: the tokens the new patch starts with, or None to
            start from the BOS token alone
        :param top_p: passed to ``top_p_sampling``
        :param top_k: passed to ``top_k_sampling``
        :param temperature: passed to ``temperature_sampling``
        :param seed: seed for Python's global ``random`` module; None
            leaves the sampling unseeded
        :return: a tuple of the generated tokens and the seed used for the
            last sampled token (None when ``seed`` is None)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(
            patches)["last_hidden_state"]

        # Identity check: ``==`` on a tensor is not a plain None test.
        if tokens is None:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        generated_patch = []
        # Note: this reseeds the process-wide random generator.
        random.seed(seed)

        while True:
            if seed is not None:
                # Derive a fresh seed for every token from the seeded stream.
                n_seed = random.randint(0, 1000000)
                random.seed(n_seed)
            else:
                n_seed = None

            # Condition on the feature of the last patch of the first sequence.
            prob = self.char_level_decoder.generate(
                encoded_patches[0][-1],
                tokens
            ).cpu().detach().numpy()
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
            token = temperature_sampling(
                prob,
                temperature=temperature,
                seed=n_seed
            )
            generated_patch.append(token)

            # Stop at EOS or once the patch has reached its maximum size.
            if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
                break

            tokens = torch.cat(
                (tokens, torch.tensor([token], device=self.device)),
                dim=0
            )

        return generated_patch, n_seed
class PatchilizedData(Dataset):
    """
    A dataset holding one tensor of encoded patches per input item.
    """

    def __init__(self, items, patchilizer):
        """
        :param items: an iterable of mappings with the keys
            ``'control code'`` and ``'abc notation'``
        :param patchilizer: the object whose ``encode`` method turns the
            text into patches
        """
        self.texts = []

        for entry in tqdm(items):
            # Drop the first line of the notation and put the control code
            # in front of what remains.
            _, _, remainder = entry['abc notation'].partition('\n')
            tune_text = entry['control code'] + remainder

            encoded = torch.tensor(
                patchilizer.encode(tune_text, add_special_patches=True)
            )
            # Items whose encoding sums to zero are skipped.
            if torch.sum(encoded) != 0:
                self.texts.append(encoded)

    def __len__(self):
        """Return the number of stored items."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the patch tensor stored at position ``idx``."""
        return self.texts[idx]