"""Embed local audio files with CLAP and publish them to Qdrant and S3.

For every ``*.wav`` file found one directory below ``KAGGLE_DB_PATH`` the
script computes (or reads from an on-disk cache) a CLAP audio embedding,
uploads the raw file to an S3 bucket and upserts a point holding the
embedding plus a small payload into a Qdrant collection.
"""
import gc
import hashlib
import os
from glob import glob
from pathlib import Path

import librosa
import torch
from diskcache import Cache
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
from transformers import ClapModel, ClapProcessor

from s3_utils import s3_auth, upload_file_to_bucket

# Populate os.environ from a local .env file before any credential is read.
load_dotenv()

# PARAMETERS #######################################################################################
CACHE_FOLDER = '/home/arthur/data/music/demo_audio_search/audio_embeddings_cache_individual/'
KAGGLE_DB_PATH = '/home/arthur/data/kaggle/park-spring-2023-music-genre-recognition/train/train'
AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
S3_BUCKET = "synthia-research"
S3_FOLDER = "huggingface_spaces_demo"
AWS_REGION = "eu-west-3"
COLLECTION_NAME = "demo_spaces_db"

s3 = s3_auth(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)


# Functions utils ##################################################################################
def get_md5(fpath):
    """Return the hexadecimal MD5 digest of the file at *fpath*.

    The file is read in 8 KiB chunks so it is never loaded into memory as a
    whole. The digest is only used as a cache key, not for security.
    """
    file_hash = hashlib.md5()
    with open(fpath, "rb") as f:
        while chunk := f.read(8192):
            file_hash.update(chunk)
    return file_hash.hexdigest()


def get_audio_embedding(model, audio_file, cache):
    """Return the CLAP audio embedding of *audio_file*, using *cache* if possible.

    Args:
        model: Model exposing ``config._name_or_path`` and
            ``get_audio_features(**inputs)`` (a ``ClapModel`` in this script).
        audio_file: Path of the audio file to embed.
        cache: Mapping-like store (a ``diskcache.Cache`` in this script) keyed
            by model name + file MD5.

    Returns:
        The first row of ``model.get_audio_features(...)`` for the file.

    Note:
        Relies on the module-level ``processor`` defined further down, so it
        must only be called after the model-loading section has run.
    """
    # The key combines the model identifier with the file content hash, so a
    # different model or a modified file never hits a stale entry.
    file_key = f"{model.config._name_or_path}{get_md5(audio_file)}"

    embedding = cache.get(file_key)
    if embedding is None:
        # Cache miss: load the audio resampled to 48 kHz and run the model.
        y, sr = librosa.load(audio_file, sr=48000)
        inputs = processor(audios=y, sampling_rate=sr, return_tensors="pt")
        # Inference only: skip autograd bookkeeping so the cached tensor does
        # not keep a computation graph alive.
        with torch.no_grad():
            embedding = model.get_audio_features(**inputs)[0]
        gc.collect()
        torch.cuda.empty_cache()
        cache[file_key] = embedding
    return embedding


# ################## Loading the CLAP model ###################
print("[INFO] Loading the model...")
model_name = "laion/larger_clap_general"
model = ClapModel.from_pretrained(model_name)
processor = ClapProcessor.from_pretrained(model_name)

# Initialize the cache
os.makedirs(CACHE_FOLDER, exist_ok=True)
cache = Cache(CACHE_FOLDER)

# Creating a qdrant collection #####################################################################
client = QdrantClient(os.environ['QDRANT_URL'], api_key=os.environ['QDRANT_KEY'])
print("[INFO] Client created...")

print("[INFO] Creating qdrant data collection...")
if not client.collection_exists(COLLECTION_NAME):
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(
            size=model.config.projection_dim,
            distance=models.Distance.COSINE
        ),
    )

# Embed the audio files!
audio_files = glob(os.path.join(KAGGLE_DB_PATH, '*/*.wav'))
chunk_size, idx = 1, 0

print("Uploading on DB + S3")
try:
    for i in tqdm(range(0, len(audio_files), chunk_size),
                  desc="[INFO] Uploading data records to data collection..."):
        chunk = audio_files[i:i + chunk_size]  # Get a chunk of audio files
        records = []
        for audio_file in chunk:
            embedding = get_audio_embedding(model, audio_file, cache)

            s3key = f'{S3_FOLDER}/{Path(audio_file).name}'
            # Close the handle once the upload is done instead of leaking one
            # descriptor per file (assumes upload_file_to_bucket has finished
            # reading the file when it returns).
            with open(audio_file, 'rb') as file_obj:
                upload_file_to_bucket(s3, file_obj, S3_BUCKET, s3key)

            records.append(
                models.PointStruct(
                    id=idx,
                    # Qdrant expects a plain list of floats, not a tensor.
                    vector=embedding.detach().cpu().tolist(),
                    payload={
                        "audio_path": audio_file,
                        "audio_s3url": f"https://{S3_BUCKET}.s3.amazonaws.com/{s3key}",
                        # NOTE(review): this stores the file name; if the genre
                        # (parent directory) was intended, confirm and switch
                        # to Path(audio_file).parent.name.
                        "style": audio_file.split('/')[-1]}
                )
            )
            # tqdm.write keeps the message from garbling the progress bar.
            tqdm.write(f"Uploaded s3 file : {idx}")
            idx += 1

        client.upload_points(
            collection_name=COLLECTION_NAME,
            points=records
        )
    print("[INFO] Successfully uploaded data records to data collection!")
finally:
    # Close the cache even when an upload fails part-way through.
    cache.close()