import json

import gradio as gr
import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from utils import *

CLAMP_MODEL_NAME = 'clamp-small-512'
QUERY_MODAL = 'text'
KEY_MODAL = 'music'
TOP_N = 1
TEXT_MODEL_NAME = 'distilroberta-base'
TEXT_LENGTH = 128
device = torch.device("cpu")

# load CLaMP model
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME)
music_length = model.config.max_length
model = model.to(device)
model.eval()

# initialize patchilizer, tokenizer, and softmax
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
softmax = torch.nn.Softmax(dim=1)


def compute_values(Q_e, K_e, t=1):
    """
    Compute softmax-normalized similarity scores between the query and the keys.

    Args:
        Q_e (torch.Tensor): Query embeddings
        K_e (torch.Tensor): Key embeddings
        t (float): Temperature for scaling the logits

    Returns:
        values (torch.Tensor): Similarity scores over the keys
    """
    # normalize the feature representations
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)

    # scaled pairwise cosine similarities [1, n]
    logits = torch.mm(Q_e, K_e.T) * torch.exp(torch.tensor(t))
    values = softmax(logits)
    return values.squeeze()


def encoding_data(data, modal):
    """
    Encode the data into ids.

    Args:
        data (list): List of strings
        modal (str): "music" or "text"

    Returns:
        ids_list (list): List of id tensors
    """
    ids_list = []
    if modal == "music":
        for item in data:
            patches = patchilizer.encode(item, music_length=music_length, add_eos_patch=True)
            ids_list.append(torch.tensor(patches).reshape(-1))
    else:
        for item in data:
            text_encodings = tokenizer(item,
                                       return_tensors='pt',
                                       truncation=True,
                                       max_length=TEXT_LENGTH)
            ids_list.append(text_encodings['input_ids'].squeeze(0))
    return ids_list


def get_features(ids_list, modal):
    """
    Get the features from the CLaMP model.

    Args:
        ids_list (list): List of id tensors
        modal (str): "music" or "text"

    Returns:
        features_list (torch.Tensor): Tensor of features with a shape of (batch_size, hidden_size)
    """
    features_list = []
    print("Extracting " + modal + " features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            ids = ids.unsqueeze(0)
            if modal == "text":
                masks = torch.tensor([1] * len(ids[0])).unsqueeze(0)
                features = model.text_enc(ids.to(device),
                                          attention_mask=masks.to(device))['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # the music encoder attends over patches, so the mask has one entry per patch
                masks = torch.tensor([1] * (len(ids[0]) // PATCH_LENGTH)).unsqueeze(0)
                features = model.music_enc(ids.to(device), masks.to(device))['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)
            features_list.append(features[0])
    return torch.stack(features_list).to(device)


def semantic_music_search(query):
    """
    Semantic music search.

    Args:
        query (str): Query string

    Returns:
        output (str): Search result
    """
    # load the pre-computed key features and their filenames
    with open(KEY_MODAL + "_key_cache_" + str(music_length) + ".pth", 'rb') as f:
        key_cache = torch.load(f)

    # encode query
    query_ids = encoding_data([query], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)

    key_filenames = key_cache["filenames"]
    key_features = key_cache["features"]

    # compute similarities and pick the best match
    values = compute_values(query_feature, key_features)
    idx = torch.argsort(values)[-1]
    # strip the directory and the 4-character extension (e.g. ".abc")
    filename = key_filenames[idx].split('/')[-1][:-4]

    with open("wikimusictext.json", 'r') as f:
        wikimusictext = json.load(f)

    for item in wikimusictext:
        if item['title'] == filename:
            output = "Title:\n" + item['title'] + '\n\n'
            output += "Artist:\n" + item['artist'] + '\n\n'
            output += "Genre:\n" + item['genre'] + '\n\n'
            output += "Description:\n" + item['text'] + '\n\n'
            output += "ABC notation:\n" + item['music'] + '\n\n'
            return output

    # no entry in wikimusictext.json matched the retrieved filename
    return "No matching entry found for this query."


gr.Interface(
    fn=semantic_music_search,
    inputs=gr.Textbox(lines=2, placeholder="Describe the music you want to search..."),
    outputs="text",
).launch()
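
# The app above assumes a pre-computed key cache (e.g. "music_key_cache_512.pth")
# holding {"filenames": [...], "features": tensor}, which is exactly what
# semantic_music_search() loads. The sketch below shows one way such a cache could
# be built offline from a folder of ABC files. It is a minimal, hypothetical
# example: the helper name build_key_cache and the "abc/" directory are
# assumptions, not part of the original app, and it would normally be run as a
# separate preprocessing step before launching the interface.
import glob


def build_key_cache(abc_dir="abc"):
    # collect the corpus; the stems of these filenames double as lookup keys
    # into wikimusictext.json at search time
    filenames = sorted(glob.glob(abc_dir + "/*.abc"))
    data = []
    for path in filenames:
        with open(path, 'r') as f:
            data.append(f.read())

    # encode and embed the whole corpus with the music branch of CLaMP
    ids_list = encoding_data(data, KEY_MODAL)
    features = get_features(ids_list, KEY_MODAL)

    # store in the format that semantic_music_search() expects
    torch.save({"filenames": filenames, "features": features},
               KEY_MODAL + "_key_cache_" + str(music_length) + ".pth")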