leadr64 committed
Commit a51a160
Parent: 1689e46

Add the Gradio script and dependencies

Files changed (5):
  1. .env +2 -0
  2. Dockerfile +8 -0
  3. app.py +11 -6
  4. database.py +18 -16
  5. docker-compose.yml +6 -0
.env ADDED
@@ -0,0 +1,2 @@
+QDRANT_HOST=localhost
+QDRANT_PORT=6333
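These two values match the fallbacks hard-coded in app.py and database.py. Note that os.getenv() does not read a .env file on its own: docker-compose picks the file up automatically, but a bare local run needs to load it explicitly. A minimal sketch using python-dotenv (an assumed extra dependency, not part of this commit):

from dotenv import load_dotenv
import os

load_dotenv()  # reads .env from the working directory, if present
print(os.getenv("QDRANT_HOST", "localhost"), os.getenv("QDRANT_PORT", "6333"))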
Dockerfile ADDED
@@ -0,0 +1,8 @@
+# Use a base image for Qdrant
+FROM qdrant/qdrant
+
+# Expose Qdrant's default port
+EXPOSE 6333
+
+# Command to start Qdrant
+CMD ["qdrant"]
app.py CHANGED
@@ -1,18 +1,23 @@
 import gradio as gr
 import laion_clap
 from qdrant_client import QdrantClient
+import os
 
-# Loading the Qdrant DB in local ###################################################################
-client = QdrantClient("localhost", port=6333)
+# Use environment variables for configuration
+QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
+QDRANT_PORT = int(os.getenv('QDRANT_PORT', 6333))
+
+# Connect to Qdrant
+client = QdrantClient(QDRANT_HOST, port=QDRANT_PORT)
 print("[INFO] Client created...")
 
-# loading the model
+# Load the model
 print("[INFO] Loading the model...")
 model_name = "laion/larger_clap_music"
 model = laion_clap.CLAP_Module(enable_fusion=False)
-model.load_ckpt()  # download the default pretrained checkpoint.
+model.load_ckpt()  # download the default pretrained checkpoint
 
-# Gradio Interface #################################################################################
+# Gradio interface
 max_results = 10
 
 def sound_search(query):
@@ -34,7 +39,7 @@ with gr.Blocks() as demo:
     """# Sound search database """
     )
    inp = gr.Textbox(placeholder="What sound are you looking for ?")
-    out = [gr.Audio(label=f"{x}") for x in range(max_results)]  # Necessary to have different objs
+    out = [gr.Audio(label=f"{x}") for x in range(max_results)]  # Necessary to have different objects
     inp.change(sound_search, inp, out)
 
 demo.launch()
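The body of sound_search falls outside the diff context above. A minimal sketch of what such a handler could look like with laion_clap and qdrant-client, assuming a collection named "audio_collection" whose payload carries a "path" field (both are assumptions, not shown in this commit):

def sound_search(query):
    # The extra empty string keeps laion_clap's text encoder on a batch > 1
    text_embed = model.get_text_embedding([query, ""], use_tensor=False)[0]
    hits = client.search(
        collection_name="audio_collection",  # assumed name, not in the diff
        query_vector=text_embed.tolist(),
        limit=max_results,
    )
    # One value per gr.Audio slot; pad with None when fewer results come back
    paths = [hit.payload.get("path") for hit in hits]
    return paths + [None] * (max_results - len(paths))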
database.py CHANGED
@@ -8,8 +8,11 @@ from qdrant_client import QdrantClient
 from qdrant_client.http import models
 from tqdm import tqdm
 
-# Functions utils ##################################################################################
+# Use environment variables for configuration
+QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
+QDRANT_PORT = int(os.getenv('QDRANT_PORT', 6333))
 
+# Utility functions
 def get_md5(fpath):
     with open(fpath, "rb") as f:
         file_hash = hashlib.md5()
@@ -17,52 +20,51 @@ def get_md5(fpath):
             file_hash.update(chunk)
     return file_hash.hexdigest()
 
-
-# PARAMETERS #######################################################################################
+# PARAMETERS
 CACHE_FOLDER = '/home/nahia/data/audio/'
 KAGGLE_TRAIN_PATH = '/home/nahia/Documents/audio/actor/Actor_01/'
 
-# ################## Loading the CLAP model ###################
+# Load the CLAP model
 print("[INFO] Loading the model...")
 model_name = 'music_speech_epoch_15_esc_89.25.pt'
 model = laion_clap.CLAP_Module(enable_fusion=False)
-model.load_ckpt()  # download the default pretrained checkpoint.
+model.load_ckpt()  # download the default pretrained checkpoint
 
 # Initialize the cache
 os.makedirs(CACHE_FOLDER, exist_ok=True)
 cache = Cache(CACHE_FOLDER)
 
-# Embed the audio files !
+# Embed the audio files
 audio_files = [p for p in glob(os.path.join(KAGGLE_TRAIN_PATH, '*.wav'))]
 audio_embeddings = []
 chunk_size = 100
 total_chunks = int(len(audio_files) / chunk_size)
 
 # Use tqdm for a progress bar
 for i in tqdm(range(0, len(audio_files), chunk_size), total=total_chunks):
     chunk = audio_files[i:i + chunk_size]  # Get a chunk of audio files
     chunk_embeddings = []
 
     for audio_file in chunk:
         # Compute a unique hash for the audio file
         file_key = get_md5(audio_file)
 
         if file_key in cache:
             # If the embedding for this file is cached, retrieve it
             embedding = cache[file_key]
         else:
             # Otherwise, compute the embedding and cache it
             embedding = model.get_audio_embedding_from_filelist(x=[audio_file], use_tensor=False)[
                 0]  # Assuming the model returns a list
             cache[file_key] = embedding
         chunk_embeddings.append(embedding)
     audio_embeddings.extend(chunk_embeddings)
 
-# It's a good practice to close the cache when done
+# Close the cache when done
 cache.close()
 
-# Creating a qdrant collection #####################################################################
-client = QdrantClient("localhost", port=6333)
+# Create a qdrant collection
+client = QdrantClient(QDRANT_HOST, port=QDRANT_PORT)
 print("[INFO] Client created...")
 
 print("[INFO] Creating qdrant data collection...")
docker-compose.yml ADDED
@@ -0,0 +1,6 @@
+version: '3.8'
+services:
+  qdrant:
+    build: .
+    ports:
+      - "6333:6333"