import subprocess import os import gradio as gr import json from utils import * from unidecode import unidecode from transformers import AutoTokenizer description = """
# Checkpoint name passed to CLaMP.from_pretrained below
CLAMP_MODEL_NAME = 'clamp-small-512'
# Modality of the user query; passed to encoding_data/get_features in semantic_music_search
QUERY_MODAL = 'text'
# Modality of the precomputed key cache; also the prefix of the key-cache filename
KEY_MODAL = 'music'
# NOTE(review): not referenced in the visible code; the search returns only the single best match
TOP_N = 1
# Tokenizer checkpoint used to turn the query text into input ids
TEXT_MODEL_NAME = 'distilroberta-base'
# Maximum number of text tokens kept per query (the tokenizer truncates longer inputs)
TEXT_LENGTH = 128

# Device that get_features moves its stacked output to
device = torch.device("cpu")

# load CLaMP model
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME)
# Used as the music length budget when patchilizing, and embedded in the key-cache filename
music_length = model.config.max_length
# Inference only: put the model in evaluation mode
model = model.eval()

# initialize patchilizer, tokenizer, and softmax
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
# Softmax over dim 1, i.e. across keys for each query row of the logits
softmax = torch.nn.Softmax(dim=1)

def compute_values(Q_e, K_e, t=1):
    """
    Score every key embedding against the query embedding(s).

    Both inputs are L2-normalised along dim 1, so their matrix product gives
    pairwise cosine similarities. The similarities are multiplied by exp(t)
    and turned into a probability distribution over the keys with a softmax
    along dim 1.

    Args:
        Q_e (torch.Tensor): Query embeddings, shape (num_queries, dim)
        K_e (torch.Tensor): Key embeddings, shape (num_keys, dim)
        t (float): Log-scale for the logits; similarities are multiplied by
            exp(t), so a larger t gives a sharper distribution
    Returns:
        values (torch.Tensor): Softmax scores. The shape is (num_keys,) when a
            single query is given (even if there is only one key), otherwise
            (num_queries, num_keys).
    """
    # Normalize the feature representations so dot products are cosine similarities
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)
    # float() so exp() never receives an integer tensor (default t is the int 1)
    scale = torch.exp(torch.tensor(float(t)))
    # Scaled pairwise cosine similarities [num_queries, num_keys]
    logits = torch.mm(Q_e, K_e.T) * scale
    values = torch.softmax(logits, dim=1)
    # squeeze(0) rather than squeeze(): with a single key the result stays 1-D
    # instead of collapsing to a 0-d tensor that callers cannot index
    return values.squeeze(0)

def encoding_data(data, modal):
    """
    Turn raw inputs into flat id tensors for one modality.

    Args:
        data (list): Strings to encode (ABC notation for music, free text otherwise)
        modal (str): "music" selects the music patchilizer; any other value
            selects the text tokenizer
    Returns:
        list: One 1-D id tensor per input item, in input order
    """
    if modal == "music":
        def encode_one(item):
            # patches are flattened so every item becomes a single 1-D tensor
            patch_ids = patchilizer.encode(item, music_length=music_length, add_eos_patch=True)
            return torch.tensor(patch_ids).reshape(-1)
    else:
        def encode_one(item):
            # drop the batch dimension the tokenizer adds for return_tensors='pt'
            encoded = tokenizer(item, return_tensors='pt', truncation=True, max_length=TEXT_LENGTH)
            return encoded['input_ids'].squeeze(0)

    return [encode_one(item) for item in data]

def get_features(ids_list, modal):
    """
    Get the features from the CLaMP model.

    Each id sequence is encoded on its own (batch of 1) with an all-ones
    attention mask. The encoder output is pooled with ``model.avg_pooling``
    and passed through the projection head for that modality.

    Args:
        ids_list (list): Non-empty list of 1-D id tensors, as produced by
            ``encoding_data``
        modal (str): "text" selects the text encoder; any other value uses
            the music encoder
    Returns:
        features_list (torch.Tensor): Tensor of features with a shape of
            (len(ids_list), hidden_size), on ``device``
    """
    features_list = []
    print("Extracting "+modal+" features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            # add a batch dimension of 1
            ids = ids.unsqueeze(0).to(device)
            if modal == "text":
                # one mask entry per token; no padding, so every position is attended
                masks = torch.ones_like(ids)
                features = model.text_enc(ids, attention_mask=masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # music ids are flattened patches: one mask entry per PATCH_LENGTH ids
                num_patches = ids.size(1) // PATCH_LENGTH
                masks = torch.ones(1, num_patches, dtype=torch.long, device=device)
                features = model.music_enc(ids, masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)
            # drop the batch dimension before stacking
            features_list.append(features[0])
    return torch.stack(features_list).to(device)

def semantic_music_search(query):
    """
    Semantic music search: find the cached piece that best matches a text query.

    Args:
        query (str): Free-text description of the wanted music
    Returns:
        tuple: (title, artist, genre, description, ABC notation) of the best
            match, taken from ``wikimusictext.json``. If the best match has no
            metadata entry, the title slot holds the matched filename stem and
            the other four fields are empty strings. Gradio always receives one
            value per output box.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted, locally built caches.
    cache_path = KEY_MODAL + "_key_cache_" + str(music_length) + ".pth"
    with open(cache_path, 'rb') as f:
        key_cache = torch.load(f)
    key_filenames = key_cache["filenames"]
    key_features = key_cache["features"]

    # encode query
    query_ids = encoding_data([query], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)

    # pick the best-scoring key; argmax also works if the scores are a 0-d tensor
    values = compute_values(query_feature, key_features)
    idx = int(torch.argmax(values))
    # strip directory and extension: ".../Some Title.abc" -> "Some Title"
    filename = os.path.splitext(os.path.basename(key_filenames[idx]))[0]

    with open("wikimusictext.json", 'r', encoding='utf-8') as f:
        wikimusictext = json.load(f)
    # first entry whose title matches, as in a linear scan
    match = next((item for item in wikimusictext if item['title'] == filename), None)
    if match is None:
        return filename, "", "", "", ""
    return match["title"], match["artist"], match["genre"], match["text"], match["music"]

# Output boxes, in the same order as the tuple returned by semantic_music_search.
# gr.Textbox (already used for the input below) replaces the deprecated
# gr.outputs.Textbox namespace, which newer Gradio releases removed.
output_title = gr.Textbox(label="Title")
output_artist = gr.Textbox(label="Artist")
output_genre = gr.Textbox(label="Genre")
output_description = gr.Textbox(label="Description")
output_abc = gr.Textbox(label="ABC notation")

# NOTE(review): `description` and `article` must be module-level strings defined
# elsewhere in this file; they are not visible in this section.
gr.Interface(
    fn=semantic_music_search,
    inputs=gr.Textbox(lines=2, placeholder="Describe the music you want to search..."),
    outputs=[output_title, output_artist, output_genre, output_description, output_abc],
    title="🗜️ CLaMP: Semantic Music Search",
    description=description,
    article=article).launch()