# Hugging Face Spaces page header captured with this file (status: "Runtime error"); not part of the program.
import os
import subprocess
import sys

# The star import is kept ahead of the explicit third-party imports so that
# the explicit names below still take precedence over anything utils exports.
import gradio as gr
from utils import *
from unidecode import unidecode
from transformers import AutoTokenizer
# Markdown/HTML shown above the Gradio interface. Blank lines separate the
# HTML badge block from the Markdown headings so that the headings are parsed
# as Markdown instead of being swallowed by the preceding HTML block.
description = """
<div>
<a style="display:inline-block" href='https://github.com/microsoft/muzic/tree/main/clamp'><img src='https://img.shields.io/github/stars/microsoft/muzic?style=social' /></a>
<a style='display:inline-block' href='https://ai-muzic.github.io/clamp/'><img src='https://img.shields.io/badge/website-CLaMP-ff69b4.svg' /></a>
<a style="display:inline-block" href="https://huggingface.co/datasets/sander-wood/wikimusictext"><img src="https://img.shields.io/badge/huggingface-dataset-ffcc66.svg"></a>
<a style="display:inline-block" href="https://arxiv.org/pdf/2304.11029.pdf"><img src="https://img.shields.io/badge/arXiv-2304.11029-b31b1b.svg"></a>
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/sander-wood/clamp_zero_shot_music_classification?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-md-dark.svg" alt="Duplicate Space"></a>
</div>

## ℹ️ How to use this demo?
1. Select a music file in MusicXML (.mxl) format.
2. Enter the candidate classes (e.g., composers) with context (e.g., "This piece of music is composed by {composer}.") in the text box, up to 10 classes.
3. Click "Submit" and wait for the result.

## ❕Notice
- The demo only supports MusicXML (.mxl) files.
- The demo supports up to 10 classes for zero-shot classification.
- The text box is case-sensitive.
- You can enter longer text for the text box, but the demo will only use the first 128 tokens.
- The demo is based on CLaMP-S/512, a CLaMP model with 6-layer Transformer text/music encoders and a sequence length of 512.

## 🎵👉🔠 Zero-shot Music Classification
Zero-shot classification refers to the classification of new items into any desired label without the need for training data. It involves using a prompt template to provide context for the text encoder. For example, a prompt such as "This piece of music is composed by {composer}." is utilized to form input texts based on the names of candidate composers. The text encoder then outputs text features based on these input texts. Meanwhile, the music encoder extracts the music feature from the unlabelled target symbolic music. By calculating the similarity between each candidate text feature and the target music feature, the label with the highest similarity is chosen as the predicted one.
"""
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
CLAMP_MODEL_NAME = 'sander-wood/clamp-small-512'  # CLaMP checkpoint to load
QUERY_MODAL = 'music'                             # modality of the uploaded query
KEY_MODAL = 'text'                                # modality of the candidate classes
TOP_N = 1
TEXT_MODEL_NAME = 'distilroberta-base'            # tokenizer used for the text side
TEXT_LENGTH = 128                                 # max number of text tokens kept

# Inference runs on CPU.
device = torch.device("cpu")

# ---------------------------------------------------------------------------
# Model and pre-processing objects (created once at import time)
# ---------------------------------------------------------------------------
# Load the CLaMP model, place it on the target device and switch to eval mode.
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME).to(device)
model.eval()
# Maximum music sequence length, taken from the loaded model's config.
music_length = model.config.max_length

# Patchilizer for music input, tokenizer for text input, and a softmax applied
# along dim 1.
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
softmax = torch.nn.Softmax(dim=1)
def compute_values(Q_e, K_e, t=1):
    """
    Compute softmax-normalised similarity scores between queries and keys

    Both inputs are L2-normalised along dim 1, their pairwise cosine
    similarities are multiplied by exp(t), and a softmax is taken over the keys.

    Args:
        Q_e (torch.Tensor): Query embeddings, one row per query
        K_e (torch.Tensor): Key embeddings, one row per key
        t (float): Temperature; similarities are scaled by exp(t) before the softmax

    Returns:
        values (torch.Tensor): Softmax scores with size-1 dimensions squeezed out.
            The result is never 0-dimensional: a single query against a single
            key yields a tensor of shape (1,), so callers can always index it
            and take its length.
    """
    # Normalize the feature representations
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)

    # Build the scale from a float so exp() never receives an integer tensor.
    scale = torch.exp(torch.tensor(float(t)))

    # Scaled pairwise cosine similarities, one row per query and one column per key
    logits = torch.mm(Q_e, K_e.T) * scale
    values = torch.nn.functional.softmax(logits, dim=1).squeeze()

    # squeeze() turns a 1x1 result into a scalar tensor, which cannot be
    # iterated or passed to len(); keep at least one dimension.
    if values.dim() == 0:
        values = values.unsqueeze(0)
    return values
def encoding_data(data, modal):
    """
    Turn raw music or text strings into 1-D id tensors

    Args:
        data (list): List of strings to encode
        modal (str): "music" to use the patchilizer; any other value is
            handled as text and goes through the tokenizer

    Returns:
        ids_list (list): One id tensor per input item
    """
    if modal == "music":
        # Encode each piece into patches and flatten them into a single id sequence.
        return [
            torch.tensor(
                patchilizer.encode(piece, music_length=music_length, add_eos_patch=True)
            ).reshape(-1)
            for piece in data
        ]

    # Text: tokenize with truncation to TEXT_LENGTH and drop the batch dimension.
    return [
        tokenizer(sentence,
                  return_tensors='pt',
                  truncation=True,
                  max_length=TEXT_LENGTH)['input_ids'].squeeze(0)
        for sentence in data
    ]
def abc_filter(lines):
    """
    Filter out the metadata from the abc file

    Dropped lines: information fields listed in the header set below, blank
    lines, and full-line comments (lines starting with '%') other than
    '%%score' directives. For the remaining lines, a trailing '%' comment is
    cut off together with the single character in front of the '%'.

    Args:
        lines (list): List of lines in the abc file, with or without a
            trailing newline on blank lines

    Returns:
        music (str): Music string, one kept line per row, each terminated by a newline
    """
    # Information fields that carry metadata rather than music.
    header_fields = ('A:', 'B:', 'C:', 'D:', 'F:', 'G:', 'H:', 'N:', 'O:', 'R:',
                     'r:', 'S:', 'T:', 'W:', 'w:', 'X:', 'Z:')
    kept = []
    for line in lines:
        is_score = line.startswith('%%score')
        # Blank lines arrive as '' when the text was produced by split('\n')
        # and as '\n' when it was read line by line; skip both.
        if line[:2] in header_fields \
                or line in ('', '\n') \
                or (line.startswith('%') and not is_score):
            continue
        if "%" in line and not is_score:
            # Remove everything from the last '%' on, then the character that
            # preceded it (presumably the separating space).
            line = "%".join(line.split('%')[:-1])
            kept.append(line[:-1])
        else:
            kept.append(line)
    return "".join(item + '\n' for item in kept)
def load_music(filename):
    """
    Load the music from the xml file

    Runs the xml2abc.py script located next to this file on the given file,
    converts the script's standard output to ASCII with unidecode, and strips
    the metadata with abc_filter.

    Args:
        filename (str): Path of the MusicXML file to convert

    Returns:
        music (str): Music string
    """
    # Get absolute path of xml2abc.py
    script_dir = os.path.dirname(os.path.abspath(__file__))
    xml2abc_path = os.path.join(script_dir, 'xml2abc.py')

    # Run the converter with the interpreter that is running this app, so it
    # sees the same environment; fall back to 'python' on PATH if unknown.
    interpreter = sys.executable or 'python'
    completed = subprocess.run(
        [interpreter, xml2abc_path, '-m', '2', '-c', '6', '-x', filename],
        stdout=subprocess.PIPE,
    )

    # Normalize line endings before splitting into lines.
    output = completed.stdout.decode('utf-8').replace('\r', '')
    return abc_filter(unidecode(output).split('\n'))
def get_features(ids_list, modal):
    """
    Get the features from the CLaMP model

    Each item is run through the encoder for its modality, average-pooled
    with an all-ones mask, and projected with the matching projection layer.

    Args:
        ids_list (list): List of 1-D id tensors
        modal (str): "text" for the text encoder; any other value uses the
            music encoder

    Returns:
        features_list (torch.Tensor): Tensor of features with a shape of (batch_size, hidden_size)
    """
    features_list = []
    print("Extracting "+modal+" features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            # Add a batch dimension and keep inputs and masks on the model's
            # device for both modalities.
            ids = ids.unsqueeze(0).to(device)
            if modal=="text":
                # One mask entry per token.
                masks = torch.ones((1, ids.shape[1]), dtype=torch.long, device=device)
                features = model.text_enc(ids, attention_mask=masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # One mask entry per patch of PATCH_LENGTH ids.
                masks = torch.ones((1, ids.shape[1] // PATCH_LENGTH), dtype=torch.long, device=device)
                features = model.music_enc(ids, masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)
            features_list.append(features[0])
    return torch.stack(features_list).to(device)
def zero_shot_music_classification(file, class1, class2, class3, class4, class5, class6, class7, class8, class9, class10):
    """
    Classify music based on the given classes

    Args:
        file: Uploaded MusicXML file, either a path string or an object whose
            `name` attribute holds the path
        class1 ... class10 (str): Candidate class descriptions; empty or
            whitespace-only entries are ignored

    Returns:
        results (dict): Maps each candidate class to its score, inserted from
            highest to lowest score

    Raises:
        ValueError: If no file was uploaded or no candidate class was entered
    """
    if file is None:
        raise ValueError("Please upload a MusicXML (.mxl) file.")

    candidates = [class1, class2, class3, class4, class5, class6, class7, class8, class9, class10]
    keys = [key for key in candidates if key and key.strip()]
    if not keys:
        raise ValueError("Please enter at least one candidate class.")

    # The upload may arrive as a plain path or as a file object with `.name`.
    filename = file if isinstance(file, str) else file.name
    query = load_music(filename)
    print("\nQuery:\n" + query)

    # encode query
    query_ids = encoding_data([query], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)

    print("\nKeys:")
    for key in keys:
        print(key)
    key_features = get_features(encoding_data(keys, KEY_MODAL), KEY_MODAL)

    # compute values; flatten so that a single candidate still gives a 1-D tensor
    values = compute_values(query_feature, key_features).reshape(-1)
    ranking = torch.argsort(values, descending=True).tolist()

    results = {}
    print("\nResults:")
    for idx in ranking:
        score = values[idx].item()
        results[keys[idx]] = score
        print(keys[idx] + ": " + str(score))
    return results
# Build the Gradio UI: one file upload plus ten optional class text boxes,
# matching the ten class parameters of zero_shot_music_classification.
# Components are created from the top-level gradio namespace; the
# gr.inputs / gr.outputs namespaces are deprecated.
input_file = gr.File(label="Upload MusicXML file")
input_classes = [
    gr.Textbox(label="Class {}".format(i), placeholder="Description of class {}".format(i))
    for i in range(1, 11)
]

# output labels with their probabilities
output_class = gr.Label(num_top_classes=10, label="Predicted Results")

gr.Interface(zero_shot_music_classification,
             inputs=[input_file, *input_classes],
             outputs=output_class,
             title="🗜️ CLaMP: Zero-Shot Music Classification",
             description=description).launch()