leadr64 committed on
Commit f2865dc
1 Parent(s): 1cfb19d

Add the Gradio script and dependencies

Files changed (2)
  1. app.py +42 -0
  2. database.py +90 -0
app.py ADDED
@@ -0,0 +1,42 @@
+ import gradio as gr
+ import laion_clap
+ from qdrant_client import QdrantClient
+
+ # Connect to the local Qdrant instance ############################################################
+ client = QdrantClient("localhost", port=6333)
+ print("[INFO] Client created...")
+
+ # Load the CLAP model
+ print("[INFO] Loading the model...")
+ model_name = "laion/larger_clap_music"  # NOTE: currently unused; load_ckpt() fetches the default checkpoint
+ model = laion_clap.CLAP_Module(enable_fusion=False)
+ model.load_ckpt()  # download the default pretrained checkpoint
+
+ # Gradio interface #################################################################################
+ max_results = 10
+
+
+ def sound_search(query):
+     # Pad the batch with an empty string: get_text_embedding cannot handle a single-item list
+     text_embed = model.get_text_embedding([query, ''])[0]
+     hits = client.search(
+         collection_name="demo_db7",
+         query_vector=text_embed,
+         limit=max_results,
+     )
+     return [
+         gr.Audio(
+             hit.payload['audio_path'],
+             label=f"style: {hit.payload['style']} -- score: {hit.score}")
+         for hit in hits
+     ]
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """# Sound search database"""
+     )
+     inp = gr.Textbox(placeholder="What sound are you looking for?")
+     out = [gr.Audio(label=f"{x}") for x in range(max_results)]  # distinct Audio components are needed, one per result
+     inp.change(sound_search, inp, out)
+
+ demo.launch()
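The app assumes a Qdrant server is already listening on localhost:6333 and that the demo_db7 collection has been populated by database.py. A minimal sanity check before launching (a sketch, not part of the commit):

    from qdrant_client import QdrantClient

    client = QdrantClient("localhost", port=6333)
    info = client.get_collection("demo_db7")  # raises if the collection does not exist
    print(f"demo_db7 holds {info.points_count} points")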
database.py ADDED
@@ -0,0 +1,90 @@
+ import hashlib
+ import os
+ from glob import glob
+
+ import laion_clap
+ from diskcache import Cache
+ from qdrant_client import QdrantClient
+ from qdrant_client.http import models
+ from tqdm import tqdm
+
+
+ # Utility functions ################################################################################
+ def get_md5(fpath):
+     with open(fpath, "rb") as f:
+         file_hash = hashlib.md5()
+         while chunk := f.read(8192):
+             file_hash.update(chunk)
+     return file_hash.hexdigest()
+
+
+ # PARAMETERS #######################################################################################
+ CACHE_FOLDER = '/home/nahia/data/audio/'
+ KAGGLE_TRAIN_PATH = '/home/nahia/Documents/audio/actor/Actor_01/'
+
+ # ################## Load the CLAP model ###################
+ print("[INFO] Loading the model...")
+ model_name = 'music_speech_epoch_15_esc_89.25.pt'  # NOTE: currently unused; load_ckpt() fetches the default checkpoint
+ model = laion_clap.CLAP_Module(enable_fusion=False)
+ model.load_ckpt()  # download the default pretrained checkpoint
+
+ # Initialize the cache
+ os.makedirs(CACHE_FOLDER, exist_ok=True)
+ cache = Cache(CACHE_FOLDER)
+
+ # Embed the audio files
+ audio_files = glob(os.path.join(KAGGLE_TRAIN_PATH, '*.wav'))
+ audio_embeddings = []
+ chunk_size = 100
+ total_chunks = (len(audio_files) + chunk_size - 1) // chunk_size  # ceiling division, so tqdm counts the last partial chunk
+
+ # Use tqdm for a progress bar
+ for i in tqdm(range(0, len(audio_files), chunk_size), total=total_chunks):
+     chunk = audio_files[i:i + chunk_size]  # get a chunk of audio files
+     chunk_embeddings = []
+
+     for audio_file in chunk:
+         # Compute a unique hash for the audio file
+         file_key = get_md5(audio_file)
+
+         if file_key in cache:
+             # If the embedding for this file is cached, retrieve it
+             embedding = cache[file_key]
+         else:
+             # Otherwise, compute the embedding and cache it
+             embedding = model.get_audio_embedding_from_filelist(x=[audio_file], use_tensor=False)[0]  # the model returns a batch; take the first item
+             cache[file_key] = embedding
+         chunk_embeddings.append(embedding)
+     audio_embeddings.extend(chunk_embeddings)
+
+ # Close the cache when done
+ cache.close()
+
+ # Create a Qdrant collection ######################################################################
+ client = QdrantClient("localhost", port=6333)
+ print("[INFO] Client created...")
+
+ print("[INFO] Creating qdrant data collection...")
+ client.create_collection(
+     collection_name="demo_db7",
+     vectors_config=models.VectorParams(
+         size=audio_embeddings[0].shape[0],
+         distance=models.Distance.COSINE
+     ),
+ )
+
+ # Create Qdrant records from the embeddings
+ records = []
+ for idx, (audio_path, embedding) in enumerate(zip(audio_files, audio_embeddings)):
+     record = models.PointStruct(
+         id=idx,
+         vector=embedding.tolist(),  # convert the numpy array to a plain list for the client
+         payload={"audio_path": audio_path, "style": audio_path.split('/')[-2]}  # the parent directory name is used as the style
+     )
+     records.append(record)
+
+ # Upload the records to the Qdrant collection
+ print("[INFO] Uploading data records to data collection...")
+ client.upload_points(collection_name="demo_db7", points=records)
+ print("[INFO] Successfully uploaded data records to data collection!")
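Once the upload completes, the collection can be sanity-checked with a text query, mirroring what app.py does at serving time (a sketch reusing the model and client already in scope; the query string is an arbitrary example):

    # Same batch-padding trick as in app.py: the model cannot embed a single-item list
    query_embed = model.get_text_embedding(["a calm speaking voice", ""])[0]
    hits = client.search(collection_name="demo_db7", query_vector=query_embed, limit=3)
    for hit in hits:
        print(f"{hit.score:.3f}  {hit.payload['audio_path']}")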