leadr64 committed
Commit 60e357d
1 Parent(s): a51a160

Add the Gradio script and dependencies
Files changed (6)
  1. Dockerfile +0 -8
  2. README.md +0 -12
  3. database.py +89 -63
  4. docker-compose.yml +0 -6
  5. s3_utils.py +66 -0
  6. setup.sh +0 -2
Dockerfile DELETED
@@ -1,8 +0,0 @@
- # Use a base image for Qdrant
- FROM qdrant/qdrant
-
- # Expose Qdrant's default port
- EXPOSE 6333
-
- # Command to start Qdrant
- CMD ["qdrant"]
README.md DELETED
@@ -1,12 +0,0 @@
- ---
- title: Database
- emoji: 🦀
- colorFrom: yellow
- colorTo: red
- sdk: gradio
- sdk_version: 4.36.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
database.py CHANGED
@@ -1,18 +1,34 @@
  import hashlib
  import os
  from glob import glob

- import laion_clap
  from diskcache import Cache
  from qdrant_client import QdrantClient
  from qdrant_client.http import models
  from tqdm import tqdm

- # Use environment variables for configuration
- QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
- QDRANT_PORT = int(os.getenv('QDRANT_PORT', 6333))

- # Utility functions
  def get_md5(fpath):
      with open(fpath, "rb") as f:
          file_hash = hashlib.md5()
@@ -20,73 +36,83 @@ def get_md5(fpath):
              file_hash.update(chunk)
      return file_hash.hexdigest()

- # PARAMETERS
- CACHE_FOLDER = '/home/nahia/data/audio/'
- KAGGLE_TRAIN_PATH = '/home/nahia/Documents/audio/actor/Actor_01/'

- # Load the CLAP model
  print("[INFO] Loading the model...")
- model_name = 'music_speech_epoch_15_esc_89.25.pt'
- model = laion_clap.CLAP_Module(enable_fusion=False)
- model.load_ckpt()  # download the default pretrained checkpoint

- # Initialize the cache
  os.makedirs(CACHE_FOLDER, exist_ok=True)
  cache = Cache(CACHE_FOLDER)

- # Embed the audio files
- audio_files = [p for p in glob(os.path.join(KAGGLE_TRAIN_PATH, '*.wav'))]
- audio_embeddings = []
- chunk_size = 100
- total_chunks = int(len(audio_files) / chunk_size)
-
- # Use tqdm for a progress bar
- for i in tqdm(range(0, len(audio_files), chunk_size), total=total_chunks):
-     chunk = audio_files[i:i + chunk_size]  # Get a chunk of audio files
-     chunk_embeddings = []
-
-     for audio_file in chunk:
-         # Compute a unique hash for the audio file
-         file_key = get_md5(audio_file)
-
-         if file_key in cache:
-             # If the embedding for this file is cached, retrieve it
-             embedding = cache[file_key]
-         else:
-             # Otherwise, compute the embedding and cache it
-             embedding = model.get_audio_embedding_from_filelist(x=[audio_file], use_tensor=False)[0]  # Assume the model returns a list
-             cache[file_key] = embedding
-         chunk_embeddings.append(embedding)
-     audio_embeddings.extend(chunk_embeddings)
-
- # Close the cache when done
- cache.close()
-
- # Create a qdrant collection
- client = QdrantClient(QDRANT_HOST, port=QDRANT_PORT)
  print("[INFO] Client created...")

  print("[INFO] Creating qdrant data collection...")
- client.create_collection(
-     collection_name="demo_db7",
-     vectors_config=models.VectorParams(
-         size=audio_embeddings[0].shape[0],
-         distance=models.Distance.COSINE
-     ),
- )
-
- # Create Qdrant records from the embeddings
- records = []
- for idx, (audio_path, embedding) in enumerate(zip(audio_files, audio_embeddings)):
-     record = models.PointStruct(
-         id=idx,
-         vector=embedding,
-         payload={"audio_path": audio_path, "style": audio_path.split('/')[-2]}
      )
-     records.append(record)

- # Upload the records to the Qdrant collection
- print("[INFO] Uploading data records to data collection...")
- client.upload_points(collection_name="demo_db7", points=records)
  print("[INFO] Successfully uploaded data records to data collection!")
+ import gc
  import hashlib
  import os
  from glob import glob
+ from pathlib import Path

+ import librosa
+ import torch
  from diskcache import Cache
  from qdrant_client import QdrantClient
  from qdrant_client.http import models
  from tqdm import tqdm
+ from transformers import ClapModel, ClapProcessor

+ from s3_utils import s3_auth, upload_file_to_bucket
+ from dotenv import load_dotenv
+ load_dotenv()

+ # PARAMETERS #######################################################################################
+ CACHE_FOLDER = '/home/arthur/data/music/demo_audio_search/audio_embeddings_cache_individual/'
+ KAGGLE_DB_PATH = '/home/arthur/data/kaggle/park-spring-2023-music-genre-recognition/train/train'
+ AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
+ AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
+ S3_BUCKET = "synthia-research"
+ S3_FOLDER = "huggingface_spaces_demo"
+ AWS_REGION = "eu-west-3"
+
+ s3 = s3_auth(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
+
+
+ # Utility functions ################################################################################
  def get_md5(fpath):
      with open(fpath, "rb") as f:
          file_hash = hashlib.md5()
          while chunk := f.read(8192):
              file_hash.update(chunk)
      return file_hash.hexdigest()


+ def get_audio_embedding(model, audio_file, cache):
+     # Compute a unique cache key: model name plus the file's MD5 hash
+     file_key = f"{model.config._name_or_path}" + get_md5(audio_file)
+     if file_key in cache:
+         # If the embedding for this file is cached, retrieve it
+         embedding = cache[file_key]
+     else:
+         # Otherwise, compute the embedding and cache it
+         y, sr = librosa.load(audio_file, sr=48000)
+         inputs = processor(audios=y, sampling_rate=sr, return_tensors="pt")
+         with torch.no_grad():  # no gradients needed at inference time
+             embedding = model.get_audio_features(**inputs)[0].tolist()
+         gc.collect()
+         torch.cuda.empty_cache()
+         cache[file_key] = embedding
+     return embedding


+ # Loading the CLAP model ###########################################################################
  print("[INFO] Loading the model...")
+ model_name = "laion/larger_clap_general"
+ model = ClapModel.from_pretrained(model_name)
+ processor = ClapProcessor.from_pretrained(model_name)

+ # Initialize the cache
  os.makedirs(CACHE_FOLDER, exist_ok=True)
  cache = Cache(CACHE_FOLDER)

+ # Creating a qdrant collection #####################################################################
+ client = QdrantClient(os.environ['QDRANT_URL'], api_key=os.environ['QDRANT_KEY'])
  print("[INFO] Client created...")

  print("[INFO] Creating qdrant data collection...")
+ if not client.collection_exists("demo_spaces_db"):
+     client.create_collection(
+         collection_name="demo_spaces_db",
+         vectors_config=models.VectorParams(
+             size=model.config.projection_dim,
+             distance=models.Distance.COSINE
+         ),
      )

+ # Embed the audio files
+ audio_files = [p for p in glob(os.path.join(KAGGLE_DB_PATH, '*/*.wav'))]
+ chunk_size, idx = 1, 0
+ total_chunks = int(len(audio_files) / chunk_size)
+
+ # Use tqdm for a progress bar
+ print("Uploading on DB + S3")
+ for i in tqdm(range(0, len(audio_files), chunk_size),
+               desc="[INFO] Uploading data records to data collection..."):
+     chunk = audio_files[i:i + chunk_size]  # Get a chunk of audio files
+     records = []
+     for audio_file in chunk:
+         embedding = get_audio_embedding(model, audio_file, cache)
+         s3key = f'{S3_FOLDER}/{Path(audio_file).name}'
+         with open(audio_file, 'rb') as file_obj:  # close the handle after upload
+             upload_file_to_bucket(s3, file_obj, S3_BUCKET, s3key)
+         records.append(
+             models.PointStruct(
+                 id=idx, vector=embedding,
+                 payload={
+                     "audio_path": audio_file,
+                     "audio_s3url": f"https://{S3_BUCKET}.s3.amazonaws.com/{s3key}",
+                     "style": audio_file.split('/')[-2]}  # parent directory carries the genre label
+             )
+         )
+         tqdm.write(f"Uploaded s3 file: {idx}")
+         idx += 1
+     client.upload_points(
+         collection_name="demo_spaces_db",
+         points=records
+     )
  print("[INFO] Successfully uploaded data records to data collection!")
+
+ # Close the cache when done
+ cache.close()
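Note: the commit message mentions a Gradio script, but app.py itself does not appear in this diff. As a minimal sketch of how the collection built above could be queried for text-to-audio search, the following assumes the same model and collection name as database.py; the search_audio helper and its top_k parameter are illustrative, not the actual app.py:

import os

import torch
from qdrant_client import QdrantClient
from transformers import ClapModel, ClapProcessor

model = ClapModel.from_pretrained("laion/larger_clap_general")
processor = ClapProcessor.from_pretrained("laion/larger_clap_general")
client = QdrantClient(os.environ['QDRANT_URL'], api_key=os.environ['QDRANT_KEY'])

def search_audio(query, top_k=5):
    # Embed the text query with the same CLAP model used for indexing
    inputs = processor(text=[query], return_tensors="pt")
    with torch.no_grad():
        text_vector = model.get_text_features(**inputs)[0].tolist()
    # Cosine-similarity search against the stored audio embeddings
    hits = client.search(collection_name="demo_spaces_db",
                         query_vector=text_vector, limit=top_k)
    return [(hit.payload["audio_s3url"], hit.score) for hit in hits]

print(search_audio("upbeat jazz with saxophone"))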
docker-compose.yml DELETED
@@ -1,6 +0,0 @@
- version: '3.8'
- services:
-   qdrant:
-     build: .
-     ports:
-       - "6333:6333"
s3_utils.py ADDED
@@ -0,0 +1,66 @@
+ import hashlib
+ from enum import Enum
+
+ import boto3
+ from botocore.client import BaseClient
+
+
+ # S3 HANDLING ######################################################################################
+ def get_md5(fpath):
+     with open(fpath, "rb") as f:
+         file_hash = hashlib.md5()
+         while chunk := f.read(8192):
+             file_hash.update(chunk)
+     return file_hash.hexdigest()
+
+
+ def upload_file_to_bucket(s3_client, file_obj, bucket, s3key):
+     """Upload a file to an S3 bucket
+     :param s3_client: boto3 S3 client
+     :param file_obj: file-like object to upload
+     :param bucket: bucket to upload to
+     :param s3key: destination key in the bucket
+     :return: None; raises botocore.exceptions.ClientError on failure
+     """
+     # Upload the file, publicly readable, served as audio
+     return s3_client.upload_fileobj(
+         file_obj, bucket, s3key,
+         ExtraArgs={"ACL": "public-read", "ContentType": "audio/mpeg"}
+     )
+
+
+ def s3_auth(aws_access_key_id, aws_secret_access_key, region_name) -> BaseClient:
+     s3 = boto3.client(
+         service_name='s3',
+         aws_access_key_id=aws_access_key_id,
+         aws_secret_access_key=aws_secret_access_key,
+         region_name=region_name
+     )
+     return s3
+
+
+ def get_list_of_buckets(s3: BaseClient):
+     response = s3.list_buckets()
+     buckets = {}
+
+     for bucket in response['Buckets']:
+         buckets[bucket['Name']] = bucket['Name']
+
+     BucketName = Enum('BucketName', buckets)
+     return BucketName
+
+
+ if __name__ == '__main__':
+     import os
+
+     AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
+     AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
+     S3_BUCKET = "synthia-research"
+     S3_FOLDER = "huggingface_spaces_demo"
+     AWS_REGION = "eu-west-3"
+
+     s3 = s3_auth(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
+     print(s3.list_buckets())
+
+     s3key = f'{S3_FOLDER}/015.WAV'
+     with open('015.WAV', 'rb') as file_obj:  # assumption: a local test file named 015.WAV
+         upload_file_to_bucket(s3, file_obj, S3_BUCKET, s3key)
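Because upload_file_to_bucket sets the ACL to public-read, the audio_s3url values stored in the Qdrant payloads should be directly fetchable over HTTPS, provided the bucket permits ACLs. A quick sketch of a reachability check, assuming the requests package and the test key from the __main__ block above:

import requests

url = "https://synthia-research.s3.amazonaws.com/huggingface_spaces_demo/015.WAV"
resp = requests.head(url)
print(resp.status_code)  # expect 200 once the object exists and is publicly readable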
setup.sh DELETED
@@ -1,2 +0,0 @@
- python database.py
- python app.py