import os
import re
import torch
import random
from config import *
from tqdm import tqdm
from unidecode import unidecode
from torch.utils.data import Dataset
from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
from samplings import top_p_sampling, top_k_sampling, temperature_sampling

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def template(region):
    return f'''A:{region}
S:2
B:9
E:4
B:9
L:1/8
M:3/4
K:D
de |"D"'''


def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
    import time
    import requests

    try:
        response = requests.get(url, stream=True)
        total_size = int(response.headers.get("content-length", 0))
        chunk_size = 1024

        with open(filename, "wb") as file, tqdm(
            desc=f"Downloading weights to '{filename}'...",
            total=total_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)

    except Exception as e:
        # on failure, wait briefly and retry with the mirror URL
        print(f"Error: {e}")
        time.sleep(3)
        download(filename, ZH_WEIGHT_URL)


class Patchilizer:
    """
    A class for converting music bars to patches and vice versa.
    """

    def __init__(self):
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2

    def split_bars(self, body):
        """
        Split a body of music into individual bars.
        """
        bars = re.split(self.regexPattern, "".join(body))
        bars = list(filter(None, bars))  # remove empty strings
        if bars[0] in self.delimiters:
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
        return bars
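
    # Illustrative note (not part of the original module): split_bars keeps each
    # delimiter attached to the bar text that precedes it, e.g.
    #   Patchilizer().split_bars('de |"D" f3 e |]')
    #   -> ['de |', '"D" f3 e |]']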

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        """
        Convert a bar into a patch of specified length.
        """
        patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch
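
    # Illustrative note (not part of the original module): characters are mapped to their
    # code points and framed with BOS/EOS, then padded up to patch_size, e.g. with
    # patch_size=8:
    #   Patchilizer().bar2patch("de |", patch_size=8)
    #   -> [1, 100, 101, 32, 124, 2, 0, 0]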

    def patch2bar(self, patch):
        """
        Convert a patch into a bar.
        """
        return "".join(
            chr(idx) if idx > self.eos_token_id else ""
            for idx in patch
            if idx != self.eos_token_id
        )

    def encode(
        self,
        abc_code,
        patch_length=PATCH_LENGTH,
        patch_size=PATCH_SIZE,
        add_special_patches=False,
    ):
        """
        Encode music into patches of specified length.
        """
        lines = unidecode(abc_code).split("\n")
        lines = list(filter(None, lines))  # remove empty lines

        body = ""
        patches = []

        for line in lines:
            if len(line) > 1 and (
                (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
            ):
                if body:
                    bars = self.split_bars(body)
                    patches.extend(
                        self.bar2patch(
                            bar + "\n" if idx == len(bars) - 1 else bar, patch_size
                        )
                        for idx, bar in enumerate(bars)
                    )
                    body = ""
                patches.append(self.bar2patch(line + "\n", patch_size))
            else:
                body += line + "\n"

        if body:
            patches.extend(
                self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
            )

        if add_special_patches:
            bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
            eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
            patches = [bos_patch] + patches + [eos_patch]

        return patches[:patch_length]

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return "".join(self.patch2bar(patch) for patch in patches)


class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded
        :return: the encoded patches
        """
        patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
        patches = self.patch_embedding(patches.to(self.device))

        return self.base(inputs_embeds=patches)
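

# Illustrative sketch (not part of the original module): the patch-level decoder turns each
# patch into one GPT-2 hidden state. Given an instance `patch_decoder` (hypothetical name):
#
#   patches = torch.randint(0, 128, (1, 16, PATCH_SIZE))  # (batch, num_patches, PATCH_SIZE)
#   hidden = patch_decoder(patches)["last_hidden_state"]  # (batch, num_patches, n_embd)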


class CharLevelDecoder(PreTrainedModel):
    """
    A Char-level Decoder model for generating the characters within each bar patch sequentially.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.base = GPT2LMHeadModel(config)

    def forward(
        self,
        encoded_patches: torch.Tensor,
        target_patches: torch.Tensor,
        patch_sampling_batch_size: int,
    ):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: the encoded patches
        :param target_patches: the target patches
        :return: the decoded patches
        """
        # preparing the labels for model training
        target_masks = target_patches == self.pad_token_id
        labels = target_patches.clone().masked_fill_(target_masks, -100)

        # masking the labels for model training
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # select patches
        if (
            patch_sampling_batch_size != 0
            and patch_sampling_batch_size < target_patches.shape[0]
        ):
            indices = list(range(len(target_patches)))
            random.shuffle(indices)
            selected_indices = sorted(indices[:patch_sampling_batch_size])

            target_patches = target_patches[selected_indices, :]
            target_masks = target_masks[selected_indices, :]
            encoded_patches = encoded_patches[selected_indices, :]
            labels = labels[selected_indices, :]

        # get input embeddings
        inputs_embeds = torch.nn.functional.embedding(
            target_patches, self.base.transformer.wte.weight
        )

        # concatenate the encoded patches with the input embeddings
        inputs_embeds = torch.cat(
            (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
        )

        return self.base(
            inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
        )

    def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
        """
        The generate function for generating a patch based on the encoded patch and already generated tokens.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of the next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)
        tokens = tokens.reshape(1, -1)

        # Get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # Concatenate the encoded patch with the input embeddings
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)

        # Get output from model
        outputs = self.base(inputs_embeds=tokens)

        # Get probabilities of next token
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs
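

# Illustrative sketch (not part of the original module): `generate` returns a distribution over
# the character vocabulary (decoder_config.vocab_size entries, matching the 128-way one-hot
# encoding above). Given an instance `char_decoder` and tensors `encoded_patch` / `tokens`
# (hypothetical names):
#
#   probs = char_decoder.generate(encoded_patch, tokens)  # shape: (vocab_size,)
#   next_token = int(probs.argmax())                      # greedy choice, for example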


class TunesFormer(PreTrainedModel):
    """
    TunesFormer is a hierarchical music generation model based on bar patching.
    It includes a patch-level decoder and a character-level decoder.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, encoder_config, decoder_config, share_weights=False):
        super().__init__(encoder_config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2

        if share_weights:
            max_layers = max(
                encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
            )
            max_context_size = max(encoder_config.max_length, decoder_config.max_length)
            max_position_embeddings = max(
                encoder_config.max_position_embeddings,
                decoder_config.max_position_embeddings,
            )
            encoder_config.num_hidden_layers = max_layers
            encoder_config.max_length = max_context_size
            encoder_config.max_position_embeddings = max_position_embeddings
            decoder_config.num_hidden_layers = max_layers
            decoder_config.max_length = max_context_size
            decoder_config.max_position_embeddings = max_position_embeddings

        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

        if share_weights:
            self.patch_level_decoder.base = self.char_level_decoder.base.transformer

    def forward(
        self,
        patches: torch.Tensor,
        patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE,
    ):
        """
        The forward pass of the TunesFormer model.
        :param patches: the patches to be both encoded and decoded
        :return: the decoded patches
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]

        return self.char_level_decoder(
            encoded_patches.squeeze(0)[:-1, :],
            patches.squeeze(0)[1:, :],
            patch_sampling_batch_size,
        )

    def generate(
        self,
        patches: torch.Tensor,
        tokens: torch.Tensor,
        top_p: float = 1,
        top_k: int = 0,
        temperature: float = 1,
        seed: int = None,
    ):
        """
        The generate function for generating patches based on patches.
        :param patches: the patches to be encoded
        :return: the generated patches
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]

        if tokens is None:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        generated_patch = []
        random.seed(seed)

        while True:
            if seed is not None:
                n_seed = random.randint(0, 1000000)
                random.seed(n_seed)
            else:
                n_seed = None

            prob = (
                self.char_level_decoder.generate(encoded_patches[0][-1], tokens)
                .cpu()
                .detach()
                .numpy()
            )
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
            token = temperature_sampling(prob, temperature=temperature, seed=n_seed)

            generated_patch.append(token)
            if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
                break
            else:
                tokens = torch.cat(
                    (tokens, torch.tensor([token], device=self.device)), dim=0
                )

        return generated_patch, n_seed
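

# Illustrative sketch (not part of the original module, and not the project's actual driver
# script): wiring up the two decoders and generating one patch to continue a prompt. The
# GPT2Config hyperparameters below are hypothetical placeholders; the real values live in
# config.py.
#
#   from transformers import GPT2Config
#   patch_config = GPT2Config(vocab_size=1, n_embd=768, n_layer=9, n_head=12)
#   char_config = GPT2Config(vocab_size=128, n_embd=768, n_layer=3, n_head=12)
#   model = TunesFormer(patch_config, char_config).to(device)
#   prompt = Patchilizer().encode(template("Ireland"), add_special_patches=True)[:-1]
#   prompt = torch.tensor([prompt], device=device)  # drop the EOS patch so generation continues the tune
#   next_patch, _ = model.generate(prompt, tokens=None, top_p=0.8, temperature=1.2)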


class PatchilizedData(Dataset):
    def __init__(self, items, patchilizer):
        self.texts = []

        for item in tqdm(items):
            text = item["control code"] + "\n".join(
                item["abc notation"].split("\n")[1:]
            )
            input_patch = patchilizer.encode(text, add_special_patches=True)
            input_patch = torch.tensor(input_patch)
            if torch.sum(input_patch) != 0:
                self.texts.append(input_patch)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]
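

# Illustrative sketch (not part of the original module): each dataset item is a tensor of shape
# (num_patches, PATCH_SIZE), so batching requires a collate_fn that pads along the patch
# dimension. `items` and `collate_batch` below are hypothetical names.
#
#   from torch.nn.utils.rnn import pad_sequence
#   from torch.utils.data import DataLoader
#
#   def collate_batch(batch):
#       return pad_sequence(batch, batch_first=True, padding_value=0)
#
#   loader = DataLoader(
#       PatchilizedData(items, Patchilizer()), batch_size=4, collate_fn=collate_batch
#   )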