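import json

import gradio as gr
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from unidecode import unidecode

# Local module; provides CLaMP, MusicPatchilizer, and PATCH_LENGTH
from utils import *

description = """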
## ℹ️ How to use this demo?
1. Enter a query in the text box.
2. Click "Submit" and wait for the result.
3. The demo returns the best-matching music score from the WikiMusicText dataset (1010 scores in total).

## ❕ Notice
- The text box is case-sensitive.
- You can enter longer text, but the demo only uses the first 128 tokens.
- The returned results include the title, artist, genre, description, and the score in ABC notation.
- The genre and description may not be accurate, as they were collected from the web.
- The demo is based on CLaMP-S/512, a CLaMP model with 6-layer Transformer text/music encoders and a sequence length of 512.

## 🔍 👉🎵 Semantic Music Search
Semantic search retrieves music with open-domain queries, unlike traditional keyword-based search, which depends on exact matches or meta-information. It involves two steps: 1) extracting music features from all scores in the library (precomputed and cached), and 2) transforming the query into a text feature. By calculating the similarities between the text feature and the music features, it can efficiently locate the score in the library that best matches the user's query.
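
The similarity step can be sketched in plain Python. This is a minimal illustration, not the demo's own code. It assumes the text and music features are already available as equal-length lists of floats. The names `cosine_similarity` and `best_match`, the toy features, and the file names are all hypothetical.

```python
import math

def cosine_similarity(a, b):
    # Dot product divided by the product of the vector norms
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def best_match(text_feature, music_features, filenames):
    # Compare the query's text feature with every precomputed music feature
    scores = [cosine_similarity(text_feature, m) for m in music_features]
    # Index of the highest similarity
    best = max(range(len(scores)), key=scores.__getitem__)
    return filenames[best], scores[best]

# Hypothetical toy features for a three-score library
music_features = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
filenames = ["score_a", "score_b", "score_c"]
print(best_match([0.9, 0.1, 0.0], music_features, filenames)[0])  # → score_a
```

The demo multiplies the cosine similarities by a positive scale and then applies a softmax before it picks the top score. Both operations preserve order, so the top-ranked score is the same one that raw cosine similarity would select.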
"""

examples = [
    "Jazz standard in Minor key with a swing feel.",
    "Jazz standard in Major key with a fast tempo.",
    "Jazz standard in Blues form with a soulful melody.",
    "a painting of a starry night with the moon in the sky",
    "a green field with a blue sky and clouds",
    "a beach with a castle on top of it",
]

CLAMP_MODEL_NAME = 'sander-wood/clamp-small-512'
QUERY_MODAL = 'text'
KEY_MODAL = 'music'
TOP_N = 1
TEXT_MODEL_NAME = 'distilroberta-base'
TEXT_LENGTH = 128
device = torch.device("cpu")

# load CLaMP model
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME)
music_length = model.config.max_length
model = model.to(device)
model.eval()

# initialize patchilizer, tokenizer, and softmax
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
softmax = torch.nn.Softmax(dim=1)


def compute_values(Q_e, K_e, t=1):
    """
    Compute softmax-normalized similarity scores between queries and keys

    Args:
        Q_e (torch.Tensor): Query embeddings
        K_e (torch.Tensor): Key embeddings
        t (float): Log of the logit scale; cosine similarities are multiplied by exp(t)

    Returns:
        values (torch.Tensor): Softmax scores over the keys
    """
    # Normalize the feature representations
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)

    # Scaled pairwise cosine similarities [1, n]
    logits = torch.mm(Q_e, K_e.T) * torch.exp(torch.tensor(float(t)))
    values = softmax(logits)
    return values.squeeze()


def encoding_data(data, modal):
    """
    Encode the data into ids

    Args:
        data (list): List of strings
        modal (str): "music" or "text"

    Returns:
        ids_list (list): List of ids
    """
    ids_list = []
    if modal == "music":
        for item in data:
            patches = patchilizer.encode(item, music_length=music_length, add_eos_patch=True)
            ids_list.append(torch.tensor(patches).reshape(-1))
    else:
        for item in data:
            text_encodings = tokenizer(item, return_tensors='pt', truncation=True, max_length=TEXT_LENGTH)
            ids_list.append(text_encodings['input_ids'].squeeze(0))

    return ids_list


def get_features(ids_list, modal):
    """
    Get the features from the CLaMP model

    Args:
        ids_list (list): List of ids
        modal (str): "music" or "text"

    Returns:
        features_list (torch.Tensor): Tensor of features with a shape of (batch_size, hidden_size)
    """
    features_list = []
    print("Extracting " + modal + " features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            ids = ids.unsqueeze(0).to(device)
            if modal == "text":
                # Single unpadded sequence: every token is attended to
                masks = torch.tensor([1] * len(ids[0])).unsqueeze(0).to(device)
                features = model.text_enc(ids, attention_mask=masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # One mask entry per patch of PATCH_LENGTH ids
                masks = torch.tensor([1] * (len(ids[0]) // PATCH_LENGTH)).unsqueeze(0).to(device)
                features = model.music_enc(ids, masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)

            features_list.append(features[0])

    return torch.stack(features_list).to(device)


def semantic_music_search(query):
    """
    Semantic music search

    Args:
        query (str): Query string

    Returns:
        output (tuple): Title, artist, genre, description, and ABC notation of the best match
    """
    # Music features are precomputed and cached on disk
    with open(KEY_MODAL + "_key_cache_" + str(music_length) + ".pth", 'rb') as f:
        key_cache = torch.load(f)

    print("\nQuery: " + query + "\n")

    # encode query
    query_ids = encoding_data([unidecode(query)], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)

    key_filenames = key_cache["filenames"]
    key_features = key_cache["features"]

    # compute values and pick the highest-scoring key
    values = compute_values(query_feature, key_features)
    idx = torch.argmax(values).item()
    # Drop the directory and the 4-character file extension
    filename = key_filenames[idx].split('/')[-1][:-4]

    with open("wikimusictext.json", 'r', encoding='utf-8') as f:
        wikimusictext = json.load(f)

    for item in wikimusictext:
        if item['title'] == filename:
            print("Title: " + item['title'])
            print("Artist: " + item['artist'])
            print("Genre: " + item['genre'])
            print("Description: " + item['text'])
            print("ABC notation:\n" + item['music'])
            return item["title"], item["artist"], item["genre"], item["text"], item["music"]

    # Gradio expects one value per output box, even if no entry matches
    return filename, "", "", "", ""


output_title = gr.Textbox(label="Title")
output_artist = gr.Textbox(label="Artist")
output_genre = gr.Textbox(label="Genre")
output_description = gr.Textbox(label="Description")
output_abc = gr.Textbox(label="ABC notation")

gr.Interface(
    fn=semantic_music_search,
    inputs=gr.Textbox(lines=2, placeholder="Describe the music you want to search...", label="Query"),
    outputs=[output_title, output_artist, output_genre, output_description, output_abc],
    title="🗜️ CLaMP: Semantic Music Search",
    description=description,
    examples=examples,
).launch()