import gc
import hashlib
import os
from glob import glob
from pathlib import Path

import librosa
import torch
from diskcache import Cache
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
from transformers import ClapModel, ClapProcessor

from s3_utils import s3_auth, upload_file_to_bucket
from dotenv import load_dotenv
load_dotenv()

# PARAMETERS #######################################################################################
CACHE_FOLDER = '/home/arthur/data/music/demo_audio_search/audio_embeddings_cache_individual/'
KAGGLE_DB_PATH = '/home/arthur/data/kaggle/park-spring-2023-music-genre-recognition/train/train'
AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
S3_BUCKET = "synthia-research"
S3_FOLDER = "huggingface_spaces_demo"
AWS_REGION = "eu-west-3"

s3 = s3_auth(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)


# Functions utils ##################################################################################
def get_md5(fpath):
    with open(fpath, "rb") as f:
        file_hash = hashlib.md5()
        while chunk := f.read(8192):
            file_hash.update(chunk)
    return file_hash.hexdigest()


def get_audio_embedding(model, audio_file, cache):
    # Build a cache key from the model name and the file's MD5 hash
    file_key = f"{model.config._name_or_path}{get_md5(audio_file)}"
    if file_key in cache:
        # If the embedding for this file is cached, retrieve it
        embedding = cache[file_key]
    else:
        # Otherwise, compute the embedding and cache it
        # (uses the module-level `processor` defined below)
        y, sr = librosa.load(audio_file, sr=48000)
        inputs = processor(audios=y, sampling_rate=sr, return_tensors="pt")
        with torch.no_grad():
            embedding = model.get_audio_features(**inputs)[0]
        gc.collect()
        torch.cuda.empty_cache()
        cache[file_key] = embedding
    return embedding



# Loading the CLAP model ###########################################################################
print("[INFO] Loading the model...")
model_name = "laion/larger_clap_general"
model = ClapModel.from_pretrained(model_name)
processor = ClapProcessor.from_pretrained(model_name)

# Initialize the cache
os.makedirs(CACHE_FOLDER, exist_ok=True)
cache = Cache(CACHE_FOLDER)
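# The disk-backed cache persists across runs, so re-running the script skips embedding
# computation for any file that has already been processed (keys = model name + file MD5).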

# Creating a qdrant collection #####################################################################
client = QdrantClient(os.environ['QDRANT_URL'], api_key=os.environ['QDRANT_KEY'])
print("[INFO] Client created...")

print("[INFO] Creating qdrant data collection...")
if not client.collection_exists("demo_spaces_db"):
    client.create_collection(
        collection_name="demo_spaces_db",
        vectors_config=models.VectorParams(
            size=model.config.projection_dim,
            distance=models.Distance.COSINE
        ),
    )
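# The vector size comes from the model's projection_dim, so the collection's dimensionality
# matches the joint CLAP embedding space produced by get_audio_features / get_text_features.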

# Embed the audio files ############################################################################
audio_files = glob(os.path.join(KAGGLE_DB_PATH, '*/*.wav'))
chunk_size, idx = 1, 0
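# chunk_size controls how many points are sent per upload_points() call; 1 keeps memory low,
# while a larger value would batch more records per request.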

# Use tqdm for a progress bar
print("Uploading on DB + S3")
for i in tqdm(range(0, len(audio_files), chunk_size),
              desc="[INFO] Uploading data records to data collection..."):
    chunk = audio_files[i:i + chunk_size]  # Get a chunk of audio files
    records = []
    for audio_file in chunk:
        embedding = get_audio_embedding(model, audio_file, cache)
        s3key = f'{S3_FOLDER}/{Path(audio_file).name}'
        # Upload the raw audio to S3, closing the file handle when done
        with open(audio_file, 'rb') as file_obj:
            upload_file_to_bucket(s3, file_obj, S3_BUCKET, s3key)
        records.append(
            models.PointStruct(
                id=idx, vector=embedding.tolist(),
                payload={
                    "audio_path": audio_file,
                    "audio_s3url": f"https://{S3_BUCKET}.s3.amazonaws.com/{s3key}",
                    "style": audio_file.split('/')[-1]}
            )
        )
        f"Uploaded s3 file : {idx}"
        idx += 1
    client.upload_points(
        collection_name="demo_spaces_db",
        points=records
    )
print("[INFO] Successfully uploaded data records to data collection!")


# It's a good practice to close the cache when done
cache.close()
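

# Example (sketch, not executed as part of the ingestion above): querying the collection with a
# text prompt. This assumes the same CLAP model/processor and the "demo_spaces_db" collection
# created earlier; search_by_text is a hypothetical helper, not part of the original pipeline.
def search_by_text(query, top_k=5):
    # Embed the text query into the same joint space as the audio embeddings
    inputs = processor(text=[query], return_tensors="pt")
    with torch.no_grad():
        text_embedding = model.get_text_features(**inputs)[0]
    # Cosine-similarity search against the stored audio vectors
    hits = client.search(
        collection_name="demo_spaces_db",
        query_vector=text_embedding.tolist(),
        limit=top_k,
    )
    return [(hit.payload["audio_s3url"], hit.score) for hit in hits]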