import os.path

# from src.music.data_collection.is_audio_solo_piano import calculate_piano_solo_prob
from src.music.utils import load_audio
from src.music.config import FPS
import pretty_midi as pm
import numpy as np
from src.music.config import MUSIC_REP_PATH, MUSIC_NN_PATH
from sklearn.neighbors import NearestNeighbors
from src.cocktails.config import FULL_COCKTAIL_REP_PATH, COCKTAIL_NN_PATH, COCKTAILS_CSV_DATA
# from src.cocktails.pipeline.get_affect2affective_cluster import get_affective_cluster_centers
from src.cocktails.utilities.other_scrubbing_utilities import print_recipe
from src.music.utils import get_all_subfiles_with_extension
import os
import pickle
import pandas as pd
import time

keyword = 'b256_r128_represented'


def load_reps(rep_path, sample_size=None):
    """Load pickled music representations from ``rep_path`` and derive their playlists.

    Args:
        rep_path: directory prefix; the pickle file name is appended to it by plain
            string concatenation, so it is expected to end with a path separator.
        sample_size: when truthy, load ``all_reps_unnormalized_sample{sample_size}.pickle``;
            otherwise load ``music_reps_unnormalized.pickle``.

    Returns:
        (reps, paths, playlists, n_data, dim_data) where ``reps`` is the stored 2-D array,
        ``paths`` the stored list of representation file paths, ``playlists`` the path
        component that follows the first occurrence of ``keyword`` in each path, and
        ``n_data``/``dim_data`` the two dimensions of ``reps``.
    """
    if sample_size:
        file_name = f'all_reps_unnormalized_sample{sample_size}.pickle'
    else:
        file_name = 'music_reps_unnormalized.pickle'
    # NOTE(review): pickle.load is only acceptable because these files are produced locally.
    with open(rep_path + file_name, 'rb') as handle:
        loaded = pickle.load(handle)

    reps = loaded['reps']
    rep_paths = loaded['paths']
    playlists = []
    for rep_file in rep_paths:
        # Text between the first and second occurrence of ``keyword``; its first
        # '/'-separated component (after the leading separator) is the playlist.
        after_keyword = rep_file.split(keyword)[1]
        playlists.append(after_keyword.split('/')[1])

    n_data, dim_data = reps.shape
    return reps, rep_paths, playlists, n_data, dim_data

class Debugger:
    """Inspect intermediate results of the music-to-cocktail pipeline.

    On construction it loads two nearest-neighbour models, or builds them and
    caches them to disk:

    * one over the music representations under ``MUSIC_REP_PATH``, cached at
      ``MUSIC_NN_PATH``;
    * one over the cocktail representations in ``FULL_COCKTAIL_REP_PATH``,
      cached at ``COCKTAIL_NN_PATH``.

    It also reads the cocktail metadata CSV ``COCKTAILS_CSV_DATA``.

    ``extract_info`` gathers statistics for one pipeline run into
    ``self.debug_dict``. ``print_debug`` pretty-prints the keys listed in
    ``self.keys_to_print``.
    """

    def __init__(self, verbose=True):
        if verbose:
            print('Setting up debugger.')
        self._setup_music_nn()
        if verbose:
            print(f'  {len(self.rep_paths)} songs, representation dim: {self.dim_rep_music}')
        # Array form allows fancy indexing with the neighbour indexes in get_nearest_songs.
        self.rep_paths = np.array(self.rep_paths)
        self._setup_cocktail_nn()
        if verbose:
            print(f'  {self.n_cocktails} cocktails, representation dim: {self.dim_rep_cocktail}')

        self.cocktail_data = pd.read_csv(COCKTAILS_CSV_DATA)
        # Keys of ``self.debug_dict`` shown by ``print_debug``, in display order.
        # (Affective-cluster entries are currently disabled.)
        self.keys_to_print = ['mse_reconstruction', 'nearest_cocktail_recipes', 'nearest_cocktail_urls',
                              'nn_music_dists', 'nn_music', 'dim_rep', 'nb_notes', 'audio_len', 'piano_solo_prob',
                              'recipe_score', 'cocktail_rep']

    def _setup_music_nn(self):
        """Load the cached music NN model, or fit one on all music reps and cache it.

        Sets ``self.nn_model_music``, ``self.rep_paths`` (row-aligned with the fitted
        reps) and ``self.dim_rep_music``.
        """
        if os.path.exists(MUSIC_NN_PATH):
            # NOTE(review): pickle is only safe because this cache is written locally below.
            with open(MUSIC_NN_PATH, 'rb') as f:
                data = pickle.load(f)
            self.nn_model_music = data['nn_model']
            self.rep_paths = data['rep_paths']
            self.dim_rep_music = data['dim_rep_music']
            return

        reps_path = MUSIC_REP_PATH + 'music_reps_unnormalized.pickle'
        if not os.path.exists(reps_path):
            self._collect_music_reps(reps_path)

        reps, self.rep_paths, _, _, self.dim_rep_music = load_reps(MUSIC_REP_PATH)
        self.nn_model_music = NearestNeighbors(n_neighbors=6, metric='cosine')
        self.nn_model_music.fit(reps)
        to_save = dict(nn_model=self.nn_model_music,
                       rep_paths=self.rep_paths,
                       dim_rep_music=self.dim_rep_music)
        with open(MUSIC_NN_PATH, 'wb') as f:
            pickle.dump(to_save, f)

    @staticmethod
    def _collect_music_reps(reps_path):
        """Stack every per-song ``.txt`` representation under MUSIC_REP_PATH and pickle it to ``reps_path``.

        Raises:
            ValueError: if a representation file does not contain exactly 128 values.
        """
        all_rep_path = get_all_subfiles_with_extension(MUSIC_REP_PATH, max_depth=3, extension='.txt',
                                                       current_depth=0)
        # Files whose path contains 'mean_std' are skipped; presumably they hold
        # normalisation statistics rather than song representations (TODO confirm).
        rep_paths = [r for r in all_rep_path if 'mean_std' not in r]
        all_data = []
        for r in rep_paths:
            rep = np.loadtxt(r)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if len(rep) != 128:
                raise ValueError(f'Expected 128 values in music representation {r}, got {len(rep)}.')
            all_data.append(rep)
        to_save = dict(reps=np.array(all_data), paths=rep_paths)
        with open(reps_path, 'wb') as f:
            pickle.dump(to_save, f)

    def _setup_cocktail_nn(self):
        """Load the cached cocktail NN model, or fit one on FULL_COCKTAIL_REP_PATH and cache it.

        Sets ``self.nn_model_cocktail``, ``self.dim_rep_cocktail`` and ``self.n_cocktails``.
        """
        if os.path.exists(COCKTAIL_NN_PATH):
            # NOTE(review): pickle is only safe because this cache is written locally below.
            with open(COCKTAIL_NN_PATH, 'rb') as f:
                data = pickle.load(f)
            self.nn_model_cocktail = data['nn_model']
            self.dim_rep_cocktail = data['dim_rep_cocktail']
            self.n_cocktails = data['n_cocktails']
            return

        cocktail_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
        self.nn_model_cocktail = NearestNeighbors(n_neighbors=6)
        self.nn_model_cocktail.fit(cocktail_reps)
        self.n_cocktails, self.dim_rep_cocktail = cocktail_reps.shape[0], cocktail_reps.shape[1]
        to_save = dict(nn_model=self.nn_model_cocktail,
                       dim_rep_cocktail=self.dim_rep_cocktail,
                       n_cocktails=self.n_cocktails)
        with open(COCKTAIL_NN_PATH, 'wb') as f:
            pickle.dump(to_save, f)

    def get_nearest_songs(self, music_rep, n_results=5):
        """Return the file names and distances of the dataset songs closest to ``music_rep``.

        Args:
            music_rep: one music representation; it is reshaped into a single query row.
            n_results: number of neighbours to report. The model built by this class uses
                ``n_neighbors=6``, so at most 6 are available.

        Returns:
            (names, dists): base names of the nearest representation files and their
            distances under the model's metric (cosine when built by this class).
            Both lists have the same length and are index-aligned.
        """
        dists, indexes = self.nn_model_music.kneighbors(music_rep.reshape(1, -1))
        # Truncate distances together with indexes so the two lists stay aligned.
        indexes = indexes.flatten()[:n_results]
        dists = dists.flatten()[:n_results]
        names = [r.split('/')[-1] for r in self.rep_paths[indexes]]
        return names, dists.tolist()

    def get_nearest_cocktails(self, cocktail_rep):
        """Look up the cocktails closest to ``cocktail_rep`` in the cocktail dataset.

        Returns:
            (indexes, names, urls, recipes, ingredient_strs). ``indexes`` is a 1-D numpy
            array of row indexes into ``self.cocktail_data``. The other items are lists
            aligned with it. ``recipes`` are the results of ``print_recipe(..., to_print=False)``.
        """
        _, indexes = self.nn_model_cocktail.kneighbors(cocktail_rep.reshape(1, -1))
        indexes = indexes.flatten()
        nn_names = np.array(self.cocktail_data['names'])[indexes].tolist()
        nn_urls = np.array(self.cocktail_data['urls'])[indexes].tolist()
        nn_ing_strs = np.array(self.cocktail_data['ingredients_str'])[indexes].tolist()
        nn_recipes = [print_recipe(ingredient_str=ing_str, to_print=False) for ing_str in nn_ing_strs]
        return indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs

    def extract_info(self, all_paths, affective_cluster_id, affect, cocktail_rep, music_reconstruction, recipe_score,
                     verbose=False, level=0):
        """Collect debug statistics for one pipeline run into ``self.debug_dict``.

        Args:
            all_paths: dict read for the keys 'audio_path' (may be None), 'processed_path'
                (a MIDI file) and 'representation_path' (a text file readable by ``np.loadtxt``).
            affective_cluster_id: currently unused; kept for interface compatibility.
            affect: currently unused; kept for interface compatibility.
            cocktail_rep: numpy cocktail representation used to query the cocktail NN model.
            music_reconstruction: numpy array compared with the loaded music representation
                by mean squared error. It is assumed to have the same shape (TODO confirm).
            recipe_score: stored as-is under 'recipe_score'.
            verbose: print progress messages.
            level: indentation of the progress messages.

        Returns:
            The populated debug dict, which is also stored on ``self.debug_dict``.
        """
        if verbose:
            print(' ' * level + 'Extracting debug info..')
        init_time = time.time()
        debug_dict = dict(all_paths=all_paths, recipe_score=recipe_score)

        # Piano-solo detection is disabled (its import is commented out), so always None.
        debug_dict['piano_solo_prob'] = None
        audio_path = all_paths['audio_path']
        if audio_path is not None:
            # Audio is loaded at sr=FPS, so len / FPS is its duration in (truncated) seconds.
            audio, _ = load_audio(audio_path, sr=FPS, mono=True)
            debug_dict['audio_len'] = int(len(audio) / FPS)
        else:
            debug_dict['audio_len'] = None

        # Notes in the first instrument of the processed MIDI (0 when it has no instruments).
        midi = pm.PrettyMIDI(all_paths['processed_path'])
        debug_dict['nb_notes'] = len(midi.instruments[0].notes) if midi.instruments else 0

        representation = np.loadtxt(all_paths['representation_path'])
        debug_dict['dim_rep'] = representation.shape[0]

        # Closest songs in the dataset.
        debug_dict['nn_music'], debug_dict['nn_music_dists'] = self.get_nearest_songs(representation)

        indexes, nn_names, nn_urls, nn_recipes, nn_ing_strs = self.get_nearest_cocktails(cocktail_rep)
        debug_dict['cocktail_rep'] = cocktail_rep.tolist()
        debug_dict['nearest_cocktail_indexes'] = indexes.tolist()
        debug_dict['nn_ing_strs'] = nn_ing_strs
        debug_dict['nearest_cocktail_names'] = nn_names
        debug_dict['nearest_cocktail_urls'] = nn_urls
        debug_dict['nearest_cocktail_recipes'] = nn_recipes

        debug_dict['music_reconstruction'] = music_reconstruction.tolist()
        debug_dict['mse_reconstruction'] = float(((music_reconstruction - representation) ** 2).mean())
        self.debug_dict = debug_dict
        if verbose:
            print(' ' * (level + 2) + f'Debug info extracted in {int(time.time() - init_time)} seconds.')
        return self.debug_dict

    def print_debug(self, level=0):
        """Print the entries of ``self.debug_dict`` listed in ``self.keys_to_print``.

        Floats and non-empty lists of floats are shown with two decimals. Recipes are
        collapsed onto one line, and song names have their file-name decoration removed.
        ``self.debug_dict`` itself is not modified.
        """
        print(' ' * level + '__DEBUGGING INFO__')
        for k in self.keys_to_print:
            to_print = self.debug_dict[k]
            if k == 'nearest_cocktail_recipes':
                to_print = [r.replace('\n', '').replace('\t', '').replace('()', '') for r in to_print]
            elif k == 'nn_music':
                to_print = [p.replace('encoded_new_structured_', '').replace('_represented.txt', '')
                            for p in to_print]

            if isinstance(to_print, float):
                to_print_str = f'{to_print:.2f}'
            elif isinstance(to_print, list) and to_print and isinstance(to_print[0], float):
                # Guarding on non-emptiness avoids an IndexError on empty lists.
                to_print_str = '[' + ', '.join(f'{element:.2f}' for element in to_print) + ']'
            else:
                to_print_str = f'{to_print}'
            print(' ' * (level + 2) + f'{k} : ' + to_print_str)