Spaces:
Runtime error
Runtime error
File size: 4,589 Bytes
55f37e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import functools
import subprocess
import os
import gradio as gr
import json
from utils import *
from unidecode import unidecode
from transformers import AutoTokenizer
# Name of the pretrained CLaMP checkpoint passed to CLaMP.from_pretrained below
CLAMP_MODEL_NAME = 'clamp-small-512'
# Modality of the user query (routes encoding_data/get_features to the text path)
QUERY_MODAL = 'text'
# Modality of the searchable keys; also part of the key-cache filename read in semantic_music_search
KEY_MODAL = 'music'
# NOTE(review): not referenced in the visible code; the search returns only the single best match
TOP_N = 1
# Tokenizer checkpoint used to tokenize text queries
TEXT_MODEL_NAME = 'distilroberta-base'
# Maximum number of tokens kept per text query (longer input is truncated)
TEXT_LENGTH = 128
# All inference runs on this device (hard-coded to CPU)
device = torch.device("cpu")
# load CLaMP model
# NOTE(review): torch, CLaMP and MusicPatchilizer are presumably provided by `from utils import *` — confirm
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME)
# Maximum music sequence length from the model config; used when patch-encoding
# music and to pick the matching key-cache file ("<KEY_MODAL>_key_cache_<music_length>.pth")
music_length = model.config.max_length
model = model.to(device)
# Switch the model to evaluation mode for inference
model.eval()
# initialize patchilizer, tokenizer, and softmax
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
# Softmax over dim 1, i.e. across keys of a [num_queries, num_keys] logit matrix
softmax = torch.nn.Softmax(dim=1)
def compute_values(Q_e, K_e, t=1):
    """
    Compute softmax-normalized similarity scores between queries and keys.

    Both inputs are L2-normalized along dim 1, so ``Q_e @ K_e.T`` holds the
    pairwise cosine similarities. These are scaled by ``exp(t)`` and turned
    into a probability distribution over the keys with a softmax.

    Args:
        Q_e (torch.Tensor): Query embeddings, shape (num_queries, dim)
        K_e (torch.Tensor): Key embeddings, shape (num_keys, dim)
        t (float): Log of the logit scale; similarities are multiplied by
            exp(t), so a larger t gives a sharper distribution

    Returns:
        values (torch.Tensor): Softmax scores over the keys. For a single
            query the shape is (num_keys,), including when num_keys == 1;
            otherwise (num_queries, num_keys).
    """
    # Normalize the feature representations
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)
    # Scaled pairwise cosine similarities [num_queries, num_keys]
    logits = torch.mm(Q_e, K_e.T) * torch.exp(torch.tensor(float(t)))
    values = torch.softmax(logits, dim=1)
    # Drop only the query axis: a bare squeeze() would also collapse a
    # single-key result into a 0-d tensor, which cannot be indexed downstream.
    return values.squeeze(0)
def encoding_data(data, modal):
    """
    Turn raw inputs into flat id tensors for the CLaMP encoders.

    Args:
        data (list): Strings to encode (ABC music or natural-language text)
        modal (str): "music" selects the music patchilizer; any other value
            is treated as text and tokenized

    Returns:
        ids_list (list): One 1-D id tensor per element of ``data``, in order
    """
    if modal == "music":
        # Patch-encode each piece (with a trailing EOS patch), then flatten
        return [
            torch.tensor(
                patchilizer.encode(item, music_length=music_length, add_eos_patch=True)
            ).reshape(-1)
            for item in data
        ]

    # Text path: tokenize, truncate to TEXT_LENGTH, drop the batch axis
    return [
        tokenizer(
            item,
            return_tensors='pt',
            truncation=True,
            max_length=TEXT_LENGTH,
        )['input_ids'].squeeze(0)
        for item in data
    ]
def get_features(ids_list, modal):
    """
    Get the features from the CLaMP model

    Each id tensor is run through the modality-specific encoder one at a time
    (batch size 1), pooled with ``model.avg_pooling`` under an all-ones mask,
    and projected with the matching projection head.

    Args:
        ids_list (list): List of 1-D id tensors (presumably from encoding_data)
        modal (str): "music" or "text"; any value other than "text" takes the
            music path

    Returns:
        features_list (torch.Tensor): Tensor of features with a shape of
            (len(ids_list), hidden_size)

    Raises:
        RuntimeError: from torch.stack when ``ids_list`` is empty
    """
    features_list = []
    print("Extracting "+modal+" features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            # Add the batch axis and move the ids to the model's device once, so
            # that the encoder and the pooling step see tensors on the same device.
            ids = ids.unsqueeze(0).to(device)
            if modal == "text":
                # One mask entry per token; built directly on `device`
                masks = torch.ones(1, ids.size(1), dtype=torch.long, device=device)
                features = model.text_enc(ids, attention_mask=masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # Music ids are flattened patches: one mask entry per whole patch
                masks = torch.ones(1, ids.size(1) // PATCH_LENGTH, dtype=torch.long, device=device)
                features = model.music_enc(ids, masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)
            features_list.append(features[0])
    return torch.stack(features_list).to(device)
@functools.lru_cache(maxsize=1)
def _load_key_cache():
    """Load the precomputed key cache (a dict with "filenames" and "features") once per process."""
    with open(KEY_MODAL+"_key_cache_"+str(music_length)+".pth", 'rb') as f:
        # NOTE: torch.load unpickles the file; only load trusted, locally built caches.
        return torch.load(f)


@functools.lru_cache(maxsize=1)
def _load_wikimusictext_index():
    """Load wikimusictext.json once and index its entries by 'title'.

    If several entries share a title, the last one wins. This matches the
    previous linear scan, which overwrote its result on every match.
    """
    with open("wikimusictext.json", 'r', encoding='utf-8') as f:
        wikimusictext = json.load(f)
    return {item['title']: item for item in wikimusictext}


def semantic_music_search(query):
    """
    Semantic music search

    Encodes the text query, scores it against the cached music-key features
    and returns the wikimusictext.json metadata of the best-scoring piece.

    Args:
        query (str): Query string

    Returns:
        output (str): Title, artist, genre, description and ABC notation of the
            best match, or a short message when the best match has no entry in
            wikimusictext.json
    """
    key_cache = _load_key_cache()
    # encode query
    query_ids = encoding_data([query], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)
    key_filenames = key_cache["filenames"]
    key_features = key_cache["features"]
    # compute values
    values = compute_values(query_feature, key_features)
    # Index of the highest-scoring key; reshape(-1) keeps this valid even if
    # the scores come back as a 0-d tensor (single key).
    idx = int(torch.argmax(values.reshape(-1)))
    # Metadata titles are the key file names without directory or extension
    filename = os.path.splitext(os.path.basename(key_filenames[idx]))[0]

    item = _load_wikimusictext_index().get(filename)
    if item is None:
        # Previously this path raised UnboundLocalError on `output`
        return "No metadata found for the best match: " + filename

    output = "Title:\n" + item['title']+'\n\n'
    output += "Artist:\n" + item['artist']+ '\n\n'
    output += "Genre:\n" + item['genre']+ '\n\n'
    output += "Description:\n" + item['text']+ '\n\n'
    output += "ABC notation:\n" + item['music']+ '\n\n'
    return output
# Web UI: a two-line text box feeds semantic_music_search; the result is shown as plain text
demo = gr.Interface(
    fn=semantic_music_search,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Describe the music you want to search...",
    ),
    outputs="text",
)
demo.launch()
|