import re
import os
import torch
import requests
from tqdm import tqdm
from unidecode import unidecode
from transformers import AutoModel, AutoConfig, BertModel, PreTrainedModel

# Constants for patch length and number of features in a patch
PATCH_LENGTH = 64
PATCH_FEATURES = 98


class MusicPatchilizer:
    """
    Class for converting music data to patches and vice versa.

    Attributes:
        delimiters (tuple): A tuple of strings containing the delimiters used for splitting bars.
        regexPattern (str): A regular expression pattern for splitting bars.
        pad_id (int): The id of the padding token.
        mask_id (int): The id of the mask token.
        eos_id (int): The id of the end-of-sequence token.

    Methods:
        split_bars(body): Splits a body of music into individual bars using the delimiters specified in `self.delimiters`.
        bar2patch(bar, patch_length): Encodes a single bar as a patch of specified length.
        patch2bar(patch): Converts a patch to a bar string.
        encode(music, music_length, patch_length=PATCH_LENGTH, add_eos_patch=False): Encodes the input music string as a list of patches.
        decode(patches): Decodes a sequence of patches into a music score.
    """
    def __init__(self):
        # Delimiters used for splitting bars
        self.delimiters = "|:", "::", ":|", "[|", "||", "|]", "|"
        # Regular expression pattern for splitting bars
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        # Padding, mask, and end-of-sequence token ids
        self.pad_id = 0
        self.mask_id = 96
        self.eos_id = 97

    def split_bars(self, body):
        """
        Splits a body of music into individual bars using the delimiters specified in `self.delimiters`.

        Args:
            body (str): A string containing the body of music to be split into bars.

        Returns:
            list: A list of strings containing the individual bars.
        """
        body = "".join(body)
        bars = re.split(self.regexPattern, body)
        # Remove empty strings left over by the split
        while "" in bars:
            bars.remove("")
        # If the body starts with a delimiter, merge it into the first bar
        if bars[0] in self.delimiters:
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        # Pair each bar with its trailing delimiter
        bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
        return bars

    def bar2patch(self, bar, patch_length):
        """
        Encodes a single bar as a patch of specified length.

        Args:
            bar (str): A string containing the bar to be encoded.
            patch_length (int): An integer indicating the length of the patch to be returned.

        Returns:
            list: A list of integer-encoded musical tokens.
        """
        patch = [self.pad_id] * patch_length
        i = -1  # guard so an empty bar still gets an EOS marker
        for i in range(min(patch_length, len(bar))):
            idx = ord(bar[i])
            # Shift printable ASCII (32-126) into the token range 1-95
            if idx >= 32 and idx < 127:
                patch[i] = idx - 31
        # Mark the end of the bar with an EOS token if there is room
        if i + 1 < patch_length:
            patch[i + 1] = self.eos_id
        return patch

    def patch2bar(self, patch):
        """
        Converts a patch to a bar string.

        Args:
            patch (list): A list of integer-encoded musical tokens.

        Returns:
            str: A string containing the decoded bar.
        """
        bar = ""
        for idx in patch:
            # Tokens 1-95 map back to printable ASCII; stop at padding, mask, or EOS
            if idx > 0 and idx < 96:
                bar += chr(idx + 31)
            else:
                break
        return bar

    def encode(self, music, music_length, patch_length=PATCH_LENGTH, add_eos_patch=False):
        """
        Encodes the input music string as a list of patches.

        Args:
            music (str): A string containing the music to be encoded.
            music_length (int): An integer indicating the maximum number of patches to be returned.
            patch_length (int): An integer indicating the length of each patch.
            add_eos_patch (bool): A boolean indicating whether to add an extra patch consisting of all EOS tokens at the end of the encoded music.

        Returns:
            list: A list of integer-encoded patches.
        """
        # Convert to ASCII and split into lines
        music = unidecode(music)
        lines = music.split('\n')
        try:
            lines.remove('')
        except ValueError:
            pass

        body = ""
        patches = []

        # Iterate over lines, splitting bars and encoding each one as a patch
        for line in lines:
            # Check if the line is a music score line or not
            if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%score')):
                # If the current line is a music score line, encode the previous body as patches
                if body != "":
                    bars = self.split_bars(body)
                    for bar in bars:
                        # Encode each bar in the body as a patch and append to the patches list
                        patch = self.bar2patch(bar, patch_length)
                        patches.append(patch)
                    # Reset the body variable
                    body = ""
                # Encode the current line as a patch and append to the patches list
                patch = self.bar2patch(line, patch_length)
                patches.append(patch)
            else:
                # If the line is not a music score line, append to the body variable
                body += line

        if body != "":
            bars = self.split_bars(body)
            for bar in bars:
                # Encode each bar in the body as a patch and append to the patches list
                patch = self.bar2patch(bar, patch_length)
                patches.append(patch)

        # Add an extra patch consisting of all EOS tokens, if required
        if add_eos_patch:
            eos_patch = [self.eos_id] * patch_length
            patches = patches + [eos_patch]

        return patches[:music_length]

    def decode(self, patches):
        """
        Decodes a sequence of patches into a music score.

        Args:
            patches (list): A list of integer-encoded patches.

        Returns:
            str: A string containing the decoded music score.
        """
        music = ""
        for patch in patches:
            music += self.patch2bar(patch) + '\n'
        return music
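
# A minimal usage sketch (illustrative, not part of the original module):
# round-trips a short, made-up ABC fragment through the patchilizer. Header
# lines ("X:", "L:", ...) each become one patch; the body is split into bars.
def _demo_patchilizer():
    patchilizer = MusicPatchilizer()
    abc = "X:1\nL:1/8\nM:4/4\nK:C\nCDEF GABc|cBAG FEDC|]"
    patches = patchilizer.encode(abc, music_length=128, add_eos_patch=True)
    print(len(patches), len(patches[0]))  # 7 64: 4 headers + 2 bars + 1 EOS patch
    print(patchilizer.decode(patches))    # one header line or bar per line
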
""" # Convert to ASCII and split into lines music = unidecode(music) lines = music.split('\n') try: lines.remove('') except: pass body = "" patches = [] # Iterate over lines, splitting bars and encoding each one as a patch for line in lines: # check if the line is a music score line or not if len(line)>1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%score')): # if the current line is a music score line, encode the previous body as patches if body!="": bars = self.split_bars(body) for bar in bars: # encode each bar in the body as a patch and append to the patches list patch = self.bar2patch(bar, patch_length) patches.append(patch) # reset the body variable body = "" # encode the current line as a patch and append to the patches list patch = self.bar2patch(line, patch_length) patches.append(patch) else: # if the line is not a music score line, append to the body variable body += line if body!="": bars = self.split_bars(body) for bar in bars: # encode each bar in the body as a patch and append to the patches list patch = self.bar2patch(bar, patch_length) patches.append(patch) # add an extra patch consisting of all EOS tokens, if required if add_eos_patch: eos_patch = [self.eos_id] * patch_length patches = patches + [eos_patch] return patches[:music_length] def decode(self, patches): """ Decodes a sequence of patches into a music score. Args: patches (list): A list of integer-encoded patches. Returns: str: A string containing the decoded music score. """ music = "" for patch in patches: music += self.patch2bar(patch)+'\n' return music class MusicEncoder(PreTrainedModel): """ MusicEncoder model for encoding music patches into a sequence of hidden states. Args: config (:obj:`BertConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. Attributes: patch_embedding (:obj:`torch.nn.Linear`): A linear layer to convert the one-hot encoded patches to the hidden size of the model. enc (:obj:`BertModel`): The BERT model used to encode the patches. """ def __init__(self, config): super(MusicEncoder, self).__init__(config) self.patch_embedding = torch.nn.Linear(PATCH_LENGTH*PATCH_FEATURES, config.hidden_size) torch.nn.init.normal_(self.patch_embedding.weight, std=0.02) self.enc = BertModel(config=config) def forward(self, input_musics, music_masks): """ Args: input_musics (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length, patch_length)`): Tensor containing the integer-encoded music patches. music_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length)`): Tensor containing the attention masks for the music patches. Returns: :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, music_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. 
""" # One-hot encode the input music patches input_musics = torch.nn.functional.one_hot(input_musics, num_classes=PATCH_FEATURES) # Reshape the input music patches to feed into the linear layer input_musics = input_musics.reshape(len(input_musics), -1, PATCH_LENGTH*PATCH_FEATURES).type(torch.FloatTensor) # Apply the linear layer to convert the one-hot encoded patches to hidden features input_musics = self.patch_embedding(input_musics.to(self.device)) # Apply the BERT model to encode the music data output = self.enc(inputs_embeds=input_musics, attention_mask=music_masks.to(self.device)) return output class CLaMP(PreTrainedModel): """ CLaMP model for joint text and music encoding. Args: config (:obj:`BertConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. text_model_name (:obj:`str`, `optional`, defaults to :obj:`"distilroberta-base"`): The name of the pre-trained text model to be used for text encoding. Attributes: text_enc (:obj:`AutoModel`): The pre-trained text model used for text encoding. text_proj (:obj:`torch.nn.Linear`): A linear layer to project the text encoding to the hidden size of the model. music_enc (:obj:`MusicEncoder`): The music encoder model used for music encoding. music_proj (:obj:`torch.nn.Linear`): A linear layer to project the music encoding to the hidden size of the model. """ def __init__(self, config, text_model_name="distilroberta-base"): super(CLaMP, self).__init__(config) self.text_enc = AutoModel.from_pretrained(text_model_name) self.text_proj = torch.nn.Linear(config.hidden_size, config.hidden_size) torch.nn.init.normal_(self.text_proj.weight, std=0.02) self.music_enc = MusicEncoder(config=config) self.music_proj = torch.nn.Linear(config.hidden_size, config.hidden_size) torch.nn.init.normal_(self.music_proj.weight, std=0.02) def forward(self, input_texts, text_masks, input_musics, music_masks): """ Args: input_texts (:obj:`torch.LongTensor` of shape :obj:`(batch_size, text_length)`): Tensor containing the integer-encoded text. text_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, text_length)`): Tensor containing the attention masks for the text. input_musics (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length, patch_length)`): Tensor containing the integer-encoded music patches. music_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length)`): Tensor containing the attention masks for the music patches. Returns: :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: music_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): The music features extracted from the music encoder. text_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): The text features extracted from the text encoder. 
""" # Encode input texts text_features = self.text_enc(input_texts.to(self.device), attention_mask=text_masks.to(self.device))['last_hidden_state'] text_features = self.avg_pooling(text_features, text_masks) text_features = self.text_proj(text_features) # Encode input musics music_features = self.music_enc(input_musics, music_masks)['last_hidden_state'] music_features = self.avg_pooling(music_features, music_masks) music_features = self.music_proj(music_features) return music_features, text_features def avg_pooling(self, input_features, input_masks): """ Applies average pooling to the input features. Args: input_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_length, hidden_size)`): Tensor containing the input features. input_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, seq_length)`): Tensor containing the attention masks for the input features. Returns: :obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`: The pooled features. """ input_masks = input_masks.unsqueeze(-1).to(self.device) input_features = input_features * input_masks avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1) return avg_pool @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """ Instantiate a CLaMP model from a pre-trained model configuration. Args: pretrained_model_name_or_path (:obj:`str`): This can be either: "clamp-small-512" for the small CLaMP model with 512 max sequence length. "clamp-small-1024" for the small CLaMP model with 1024 max sequence length. Returns: :class:`~transformers.CLaMP`: The CLaMP model. """ model_dir = pretrained_model_name_or_path # If the pre-trained model is not found locally, download it from Hugging Face if not os.path.exists(model_dir): # Create the model directory and download the config and pytorch model files os.makedirs(model_dir) config_url = f"https://huggingface.co/{pretrained_model_name_or_path}/raw/main/config.json" model_url = f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/pytorch_model.bin" chunk_size = 1024 * 1024 # 1MB # download config file with requests.get(config_url, stream=True) as r: r.raise_for_status() total_size = int(r.headers.get('content-length', 0)) with open(model_dir+"/config.json", 'wb') as f: with tqdm(total=total_size, unit='B', unit_scale=True, desc='Downloading config') as pbar: for chunk in r.iter_content(chunk_size=chunk_size): f.write(chunk) pbar.update(len(chunk)) # download pytorch model file with requests.get(model_url, stream=True) as r: r.raise_for_status() total_size = int(r.headers.get('content-length', 0)) with open(model_dir+"/pytorch_model.bin", 'wb') as f: with tqdm(total=total_size, unit='B', unit_scale=True, desc='Downloading model') as pbar: for chunk in r.iter_content(chunk_size=chunk_size): f.write(chunk) pbar.update(len(chunk)) # Load the model weights and configuration config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) model = cls(config) model.load_state_dict(torch.load(pretrained_model_name_or_path+str('/pytorch_model.bin'))) return model