import numpy as np
import librosa
import pickle
import os
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import zipfile
import json
from transformers import ClapModel, ClapProcessor
import torch

# Dataset locations, as paths relative to the current working directory.
dataset_zip = "dataset/all_sounds.zip"
extracted_folder = "dataset/all_sounds"
metadata_path = "dataset/licenses.txt"
audio_embeddings_path = "dataset/audio_embeddings.pkl"


def _ensure_extracted(archive_path, destination):
    """Unpack the zip at *archive_path* into *destination* unless it already exists."""
    if os.path.exists(destination):
        return  # already unpacked on an earlier run
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(destination)


_ensure_extracted(dataset_zip, extracted_folder)

# Hugging Face CLAP checkpoint shared by the text and audio embedding helpers.
_CLAP_CHECKPOINT = "laion/clap-htsat-fused"
processor = ClapProcessor.from_pretrained(_CLAP_CHECKPOINT)
model = ClapModel.from_pretrained(_CLAP_CHECKPOINT)

# Load dataset metadata. JSON text is UTF-8, so decode it explicitly rather than
# relying on the platform's locale encoding (which differs on e.g. Windows).
with open(metadata_path, "r", encoding="utf-8") as file:
    data = json.load(file)

# One row per top-level JSON key (presumably a sample ID -- confirm against the
# dataset). The index gets a ".wav" suffix so it matches the audio file names
# that are later joined onto extracted_folder.
metadata = pd.DataFrame.from_dict(data, orient="index")
metadata.index = metadata.index.astype(str) + '.wav'

# Filename keywords for each instrument category. Categories are tried in this
# insertion order and the first one with a matching keyword wins, so e.g. a
# sample named "808 clap" is filed under "Kick".
instrument_categories = {
    "Kick": ["kick", "bd", "bass", "808", "kd"],
    "Snare": ["snare", "sd", "sn"],
    "Hi-Hat": ["hihat", "hh", "hi_hat", "hi-hat"],
    "Tom": ["tom"],
    "Cymbal": ["crash", "ride", "splash", "cymbal"],
    "Clap": ["clap"],
    "Percussion": [
        "shaker", "perc", "tamb", "cowbell", "bongo", "conga", "egg",
    ],
}


def categorize_instrument(filename):
    """Return the instrument category for *filename*, or ``"Other"``.

    A category matches when any of its keywords occurs anywhere in the
    lower-cased file name (plain substring test). The first matching category
    in ``instrument_categories`` order is returned.

    NOTE(review): short keywords such as "sd", "sn", "kd" and "tom" also match
    inside unrelated words (e.g. "custom" -> "Tom"); consider token-based
    matching if misclassification shows up.
    """
    haystack = filename.lower()
    return next(
        (
            category
            for category, keywords in instrument_categories.items()
            if any(word in haystack for word in keywords)
        ),
        "Other",
    )

# Tag each sample with an instrument category inferred from its "name" field.
metadata["Instrument"] = metadata["name"].apply(categorize_instrument)

# Load precomputed audio embeddings once at import time (to avoid recomputing
# them on every request).
# NOTE(review): pickle.load can execute arbitrary code; only load this file if
# it comes from a trusted source.
with open(audio_embeddings_path, "rb") as f:
    audio_embeddings = pickle.load(f)

def get_clap_embeddings_from_text(text):
    """Embed user text with the CLAP text encoder.

    The text is tokenized by the shared CLAP processor and projected with
    ``model.get_text_features`` without tracking gradients. For a single
    prompt string, the leading batch dimension of size 1 is squeezed away.
    The result is returned as a NumPy array.
    """
    tokenized = processor(text=text, return_tensors="pt")
    with torch.no_grad():
        features = model.get_text_features(**tokenized)
    embedding = features.squeeze(0)
    return embedding.numpy()

def get_clap_embeddings_from_audio(audio_path):
    """Embed the audio file at *audio_path* with the CLAP audio encoder.

    The file is decoded and resampled to the CLAP feature extractor's sampling
    rate. That same rate is passed to the processor, so the declared rate
    matches the actual samples. The result is returned as a NumPy array, with
    the leading batch dimension of size 1 squeezed away.
    """
    # librosa.load resamples to 22050 Hz by default; the CLAP feature extractor
    # expects its own rate (48 kHz for this checkpoint), so resample to it.
    target_sr = processor.feature_extractor.sampling_rate
    audio, _ = librosa.load(audio_path, sr=target_sr)
    inputs = processor(audios=[audio], return_tensors="pt", sampling_rate=target_sr)
    with torch.no_grad():
        return model.get_audio_features(**inputs).squeeze(0).numpy()

def find_top_sounds(text_embed, instrument, top_N=4):
    """Find the sounds of one instrument whose embeddings best match *text_embed*.

    Args:
        text_embed: 1-D query embedding (e.g. from get_clap_embeddings_from_text).
        instrument: Category label as stored in ``metadata["Instrument"]``.
        top_N: Maximum number of files to return.

    Returns:
        Up to ``top_N`` file paths under ``extracted_folder``, ordered from most
        to least cosine-similar. The list is empty if ``top_N <= 0`` or no sound
        of this instrument has a precomputed embedding.
    """
    # A set makes the membership test O(1) per embedding instead of O(n).
    valid_sounds = set(metadata[metadata["Instrument"] == instrument].index)

    # Keep names and embeddings in the same order so similarity indices map
    # back to the correct file name.
    sound_names = [name for name in audio_embeddings if name in valid_sounds]
    if not sound_names or top_N <= 0:
        # Nothing to rank. Also avoids cosine_similarity on an empty matrix and
        # the `[-0:]` slice that would otherwise return every sound.
        return []

    all_embeds = np.array([audio_embeddings[name] for name in sound_names])
    similarities = cosine_similarity([text_embed], all_embeds)[0]

    # argsort is ascending: take the last top_N and reverse for best-first.
    top_indices = np.argsort(similarities)[-top_N:][::-1]
    return [os.path.join(extracted_folder, sound_names[i]) for i in top_indices]

def generate_drum_kit(prompt, kit_size=4):
    """Generate a drum kit dictionary from a free-text *prompt*.

    The prompt is embedded once with CLAP. Every instrument category (each key
    of ``instrument_categories`` in order, followed by ``"Other"``) is then
    mapped to the list of up to ``kit_size`` file paths that
    ``find_top_sounds`` returns for it.
    """
    text_embed = get_clap_embeddings_from_text(prompt)

    # Derive the instrument list from the category table so it cannot drift
    # out of sync; "Other" is categorize_instrument's fallback category.
    instruments = [*instrument_categories, "Other"]
    return {
        instrument: find_top_sounds(text_embed, instrument, top_N=kit_size)
        for instrument in instruments
    }