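# database / database.py
# (Hugging Face file-page header, kept as comments so the module parses)
# Author: leadr64 ("leadr64's picture")
# Commit 60e357d: "Ajouter le script Gradio et les dépendances"
#   (English: "Add the Gradio script and dependencies")
# Page links: raw | history | blame | contribute | delete
# Scan status: No virus
# Size: 4.14 kB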
import gc
import hashlib
import os
from glob import glob
from pathlib import Path
import librosa
import torch
from diskcache import Cache
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
from transformers import ClapModel, ClapProcessor
from s3_utils import s3_auth, upload_file_to_bucket
from dotenv import load_dotenv
load_dotenv()
# PARAMETERS #######################################################################################
# Local data locations. Each can be overridden through an environment variable of
# the same name (e.g. set in the .env file loaded above); the defaults are the
# paths this script was originally written against.
CACHE_FOLDER = os.environ.get(
    'CACHE_FOLDER',
    '/home/arthur/data/music/demo_audio_search/audio_embeddings_cache_individual/',
)
KAGGLE_DB_PATH = os.environ.get(
    'KAGGLE_DB_PATH',
    '/home/arthur/data/kaggle/park-spring-2023-music-genre-recognition/train/train',
)
# AWS credentials are mandatory: a missing variable raises KeyError here.
AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
# Destination bucket / key prefix / region for the uploaded audio files.
S3_BUCKET = "synthia-research"
S3_FOLDER = "huggingface_spaces_demo"
AWS_REGION = "eu-west-3"
# Authenticated S3 handle shared by the upload loop below.
s3 = s3_auth(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
# Functions utils ##################################################################################
def get_md5(fpath):
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    The file is read in binary mode in 8192-byte pieces, so its whole
    content is never held in memory at once.
    """
    digest = hashlib.md5()
    with open(fpath, "rb") as stream:
        # iter() with a b"" sentinel stops at end of file.
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()
def get_audio_embedding(model, audio_file, cache, audio_processor=None):
    """Return the CLAP audio embedding of ``audio_file``, memoised in ``cache``.

    The cache key is the model's ``config._name_or_path`` followed by the MD5
    of the file contents, so the same audio re-used under another path still
    hits the cache, while a different model gets its own entries.

    Args:
        model: a ``ClapModel`` (must expose ``config._name_or_path`` and
            ``get_audio_features``).
        audio_file: path of an audio file that ``librosa.load`` can read.
        cache: mapping-like store (a ``diskcache.Cache`` in this script) used
            to persist embeddings between runs.
        audio_processor: processor that turns the waveform into model inputs.
            Defaults to the module-level ``processor`` (the original behavior).

    Returns:
        The first row of ``model.get_audio_features(...)``, detached from
        autograd.
    """
    if audio_processor is None:
        # Backward-compatible fallback to the module-level processor.
        audio_processor = processor

    file_key = f"{model.config._name_or_path}" + get_md5(audio_file)
    if file_key in cache:
        embedding = cache[file_key]
    else:
        # Load (and resample) at 48 kHz, the rate handed to the processor.
        y, sr = librosa.load(audio_file, sr=48000)
        inputs = audio_processor(audios=y, sampling_rate=sr, return_tensors="pt")
        # Inference only: no_grad avoids building an autograd graph, saving
        # memory and keeping the cached tensor free of requires_grad.
        with torch.no_grad():
            embedding = model.get_audio_features(**inputs)[0]
        gc.collect()
        torch.cuda.empty_cache()
        cache[file_key] = embedding
    # Entries written by earlier versions may still require grad; detach so
    # callers always receive a plain tensor.
    return embedding.detach()
# ################## Loading the CLAP model ###################
# Pretrained CLAP checkpoint: the model computes the embeddings, the processor
# converts raw waveforms into model inputs.
print("[INFO] Loading the model...")
model_name = "laion/larger_clap_general"
model, processor = (
    ClapModel.from_pretrained(model_name),
    ClapProcessor.from_pretrained(model_name),
)
# On-disk embedding cache; make sure its directory exists before opening it.
Path(CACHE_FOLDER).mkdir(parents=True, exist_ok=True)
cache = Cache(CACHE_FOLDER)
# Creating a qdrant collection #####################################################################
# Connection settings come from the environment; missing ones raise KeyError.
qdrant_url = os.environ['QDRANT_URL']
qdrant_key = os.environ['QDRANT_KEY']
client = QdrantClient(qdrant_url, api_key=qdrant_key)
print("[INFO] Client created...")
print("[INFO] Creating qdrant data collection...")
# Create the collection only when it is absent. Vectors are sized to the
# model's projection dimension and compared with cosine distance.
if not client.collection_exists("demo_spaces_db"):
    vector_params = models.VectorParams(
        size=model.config.projection_dim,
        distance=models.Distance.COSINE,
    )
    client.create_collection(
        collection_name="demo_spaces_db",
        vectors_config=vector_params,
    )
# Embed the audio files !
# glob() returns files in arbitrary filesystem order; sorting makes the point
# ids assigned below deterministic across runs.
audio_files = sorted(glob(os.path.join(KAGGLE_DB_PATH, '*/*.wav')))
chunk_size, idx = 1, 0
# Number of chunks, rounding up so a trailing partial chunk is counted.
total_chunks = (len(audio_files) + chunk_size - 1) // chunk_size
# Use tqdm for a progress bar
print("Uploading on DB + S3")
try:
    for i in tqdm(range(0, len(audio_files), chunk_size),
                  desc="[INFO] Uploading data records to data collection..."):
        chunk = audio_files[i:i + chunk_size]  # Get a chunk of audio files
        records = []
        for audio_file in chunk:
            embedding = get_audio_embedding(model, audio_file, cache)
            audio_path = Path(audio_file)
            # NOTE(review): the key keeps only the file name, so same-named files
            # in different sub-folders would overwrite each other in S3 — confirm
            # names are unique in the dataset.
            s3key = f'{S3_FOLDER}/{audio_path.name}'
            # Close the handle once uploaded (the original leaked one per file).
            # Assumes upload_file_to_bucket consumes file_obj synchronously.
            with open(audio_file, 'rb') as file_obj:
                upload_file_to_bucket(s3, file_obj, S3_BUCKET, s3key)
            records.append(
                models.PointStruct(
                    id=idx,
                    # Qdrant expects a plain list of floats, not a torch tensor.
                    vector=embedding.tolist(),
                    payload={
                        "audio_path": audio_file,
                        "audio_s3url": f"https://{S3_BUCKET}.s3.amazonaws.com/{s3key}",
                        # Files sit one folder below KAGGLE_DB_PATH ('*/*.wav');
                        # that folder name is presumably the style/genre label.
                        # The original stored the file name here instead.
                        "style": audio_path.parent.name,
                    },
                )
            )
            idx += 1
        client.upload_points(
            collection_name="demo_spaces_db",
            points=records
        )
    print("[INFO] Successfully uploaded data records to data collection!")
finally:
    # Close the cache even if an upload fails part-way.
    cache.close()