import gradio as gr
import json
from utils import *
from unidecode import unidecode
from transformers import AutoTokenizer
# Markdown help text for the demo page: usage steps, caveats, and a short
# explanation of semantic music search.
description = """
## ℹ️ How to use this demo?
1. Enter a query in the text box.
2. Click "Submit" and wait for the result.
3. It will return the most matching music score from the WikiMusictext dataset (1010 scores in total).
## ❕Notice
- The text box is case-sensitive.
- You can enter longer text for the text box, but the demo will only use the first 128 tokens.
- The returned results include the title, artist, genre, description, and the score in ABC notation.
- The genre and description may not be accurate, as they are collected from the web.
- The demo is based on CLaMP-S/512, a CLaMP model with 6-layer Transformer text/music encoders and a sequence length of 512.
## 🔠👉🎵 Semantic Music Search
Semantic search is a technique for retrieving music by open-domain queries, which differs from traditional keyword-based searches that depend on exact matches or meta-information. This involves two steps: 1) extracting music features from all scores in the library, and 2) transforming the query into a text feature. By calculating the similarities between the text feature and the music features, it can efficiently locate the score that best matches the user's query in the library.
"""
# Sample queries: three music-style descriptions and three scene descriptions.
examples = [
    "Jazz standard in Minor key with a swing feel.",
    "Jazz standard in Major key with a fast tempo.",
    "Jazz standard in Blues form with a soulfoul melody.",
    "a painting of a starry night with the moon in the sky",
    "a green field with a blue sky and clouds",
    "a beach with a castle on top of it",
]
# Checkpoint names and retrieval settings.
CLAMP_MODEL_NAME = 'sander-wood/clamp-small-512'
QUERY_MODAL = 'text'    # modality of the user's query
KEY_MODAL = 'music'     # modality of the cached library entries
TOP_N = 1
TEXT_MODEL_NAME = 'distilroberta-base'
device = torch.device("cpu")

# load CLaMP model
# NOTE(review): `CLaMP`, `MusicPatchilizer` and `torch` are not imported by
# name in this file; presumably they come from `from utils import *`.
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME)
music_length = model.config.max_length
model = model.to(device)
# Inference only: make sure dropout and similar layers are disabled.
model.eval()

# initialize patchilizer, tokenizer, and softmax
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
softmax = torch.nn.Softmax(dim=1)
def compute_values(Q_e, K_e, t=1):
    """
    Compute the values for the attention matrix

    Both inputs are L2-normalised along dim 1, so the logits are pairwise
    cosine similarities scaled by ``exp(t)``; a softmax is then taken over
    the keys.

    Args:
        Q_e (torch.Tensor): Query embeddings, 2-D (one row per query)
        K_e (torch.Tensor): Key embeddings, 2-D (one row per key)
        t (float): Temperature for the softmax (logits are multiplied by exp(t))

    Returns:
        values (torch.Tensor): Values for the attention matrix, with all
            size-1 dimensions squeezed out
    """
    # Normalize the feature representations
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)

    # Build the scale in the embeddings' dtype/device: with the int default,
    # `torch.tensor(t)` would be an int64 tensor.
    scale = torch.exp(torch.as_tensor(t, dtype=Q_e.dtype, device=Q_e.device))

    # Scaled pairwise cosine similarities [1, n]
    logits = torch.mm(Q_e, K_e.T) * scale
    values = torch.nn.functional.softmax(logits, dim=1)
    return values.squeeze()
def encoding_data(data, modal, text_length=128):
    """
    Encode the data into ids

    NOTE(review): parts of this function were truncated in the source (the
    list appends, the text branch header and the tokenizer arguments). They
    are reconstructed here from how `get_features` consumes the result
    (1-D id tensors; music ids whose length is a multiple of PATCH_LENGTH)
    and from the 128-token limit stated in the demo description. Please
    confirm against the upstream implementation.

    Args:
        data (list): List of strings
        modal (str): "music" or "text"
        text_length (int): Maximum number of tokens kept per text item

    Returns:
        ids_list (list): List of ids (one 1-D tensor per input item)
    """
    ids_list = []
    if modal == "music":
        for item in data:
            patches = patchilizer.encode(item, music_length=music_length, add_eos_patch=True)
            # Flatten the patches into a single 1-D id sequence per score.
            ids_list.append(torch.tensor(patches).reshape(-1))
    else:
        for item in data:
            text_encodings = tokenizer(item,
                                       return_tensors='pt',
                                       truncation=True,
                                       max_length=text_length)
            # Drop the batch dimension added by return_tensors='pt'.
            ids_list.append(text_encodings['input_ids'].squeeze(0))
    return ids_list
def get_features(ids_list, modal):
    """
    Get the features from the CLaMP model

    NOTE(review): the text-encoder call arguments, the music branch header
    and the per-item append were truncated in the source and are
    reconstructed here; please confirm against the upstream implementation.

    Args:
        ids_list (list): List of ids
        modal (str): "music" or "text"

    Returns:
        features_list (torch.Tensor): Tensor of features with a shape of (batch_size, hidden_size)
    """
    features_list = []
    print("Extracting "+modal+" features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            # Add a batch dimension and keep inputs on the model's device.
            ids = ids.unsqueeze(0).to(device)
            if modal == "text":
                # One mask entry per token; every position is attended to.
                masks = torch.ones((1, ids.shape[1]), dtype=torch.long, device=device)
                features = model.text_enc(ids, attention_mask=masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # One mask entry per patch of PATCH_LENGTH ids.
                masks = torch.ones((1, ids.shape[1] // PATCH_LENGTH), dtype=torch.long, device=device)
                features = model.music_enc(ids, masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)

            # assumes the projection output is (1, hidden_size) — TODO confirm
            features_list.append(features[0])

    return torch.stack(features_list).to(device)
def semantic_music_search(query):
    """
    Semantic music search

    Encodes the query, scores it against the cached key features and looks
    up the best-scoring entry in ``wikimusictext.json`` by title.

    Args:
        query (str): Query string

    Returns:
        tuple: (title, artist, genre, description, ABC notation) of the
            best-matching entry

    Raises:
        LookupError: If the best-matching cache entry has no record with the
            same title in ``wikimusictext.json``.
    """
    cache_path = KEY_MODAL+"_key_cache_"+str(music_length)+".pth"
    # NOTE(security): torch.load unpickles the file; only load trusted caches.
    with open(cache_path, 'rb') as f:
        key_cache = torch.load(f, map_location=device)

    print("\nQuery: "+query+"\n")

    # encode query
    query_ids = encoding_data([unidecode(query)], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)

    key_filenames = key_cache["filenames"]
    key_features = key_cache["features"]

    # compute values
    values = compute_values(query_feature, key_features)
    # Index of the highest-scoring key, as a plain int for list indexing.
    best = int(torch.argmax(values))
    # Base name without its last four characters (presumably a 3-letter
    # extension plus the dot — verify against the cache contents).
    filename = key_filenames[best].split('/')[-1][:-4]

    with open("wikimusictext.json", 'r', encoding='utf-8') as f:
        wikimusictext = json.load(f)

    match = next((item for item in wikimusictext if item['title'] == filename), None)
    if match is None:
        # Fail loudly instead of implicitly returning None to a caller that
        # unpacks five values.
        raise LookupError("No entry titled " + repr(filename) + " in wikimusictext.json")

    print("Title: " + match['title'])
    print("Artist: " + match['artist'])
    print("Genre: " + match['genre'])
    print("Description: " + match['text'])
    print("ABC notation:\n" + match['music'])
    return match["title"], match["artist"], match["genre"], match["text"], match["music"]
# Output components, one per field returned by semantic_music_search.
# Uses the top-level gr.Textbox component (as the query input already does)
# instead of the deprecated gr.outputs namespace.
output_title = gr.Textbox(label="Title")
output_artist = gr.Textbox(label="Artist")
output_genre = gr.Textbox(label="Genre")
output_description = gr.Textbox(label="Description")
output_abc = gr.Textbox(label="ABC notation")
inputs=gr.Textbox(lines=2, placeholder="Describe the music you want to search...", label="Query"),
outputs=[output_title, output_artist, output_genre, output_description, output_abc],
title="🗜️ CLaMP: Semantic Music Search",