leadr64 commited on
Commit
36b8886
1 Parent(s): 8bc75c1

Ajouter le script Gradio et les dépendances*

Browse files
Files changed (3) hide show
  1. .env +3 -1
  2. app.py +51 -43
  3. database.py +3 -3
.env CHANGED
@@ -1,2 +1,4 @@
1
  QDRANT_URL=https://ebe79742-e3ac-4d09-a2c6-63946024cc7a.us-east4-0.gcp.cloud.qdrant.io
2
- QDRANT_KEY=_NnGLuSMH4Qwv-ancoFh88YvzuR7WbyidAorVOVQ_eMCbPhxTb2TSw
 
 
 
1
  QDRANT_URL=https://ebe79742-e3ac-4d09-a2c6-63946024cc7a.us-east4-0.gcp.cloud.qdrant.io
2
+ QDRANT_KEY=_NnGLuSMH4Qwv-ancoFh88YvzuR7WbyidAorVOVQ_eMCbPhxTb2TSw
3
+ AWS_ACCESS_KEY_ID=AKIAWOUASMWP5DM6RZG2
4
+ AWS_SECRET_ACCESS_KEY=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
app.py CHANGED
@@ -1,58 +1,66 @@
1
  import os
2
-
3
  import gradio as gr
4
  from qdrant_client import QdrantClient
5
  from transformers import ClapModel, ClapProcessor
6
  from dotenv import load_dotenv
7
- import os
8
 
9
  # Charger les variables d'environnement à partir du fichier .env
10
  load_dotenv()
11
 
12
- # Récupérer le mot de passe depuis les variables d'environnement
13
  QDRANT_URL = os.getenv('QDRANT_URL')
14
  QDRANT_KEY = os.getenv('QDRANT_KEY')
15
 
16
  # Vérifier les valeurs récupérées
17
  print(f"QDRANT_URL: {QDRANT_URL}")
18
  print(f"QDRANT_KEY: {QDRANT_KEY}")
19
- # Loading the Qdrant DB in local ###################################################################
20
- client = QdrantClient(QDRANT_URL, api_key=QDRANT_KEY)
21
- print("[INFO] Client created...")
22
-
23
- # loading the model
24
- print("[INFO] Loading the model...")
25
- model_name = "laion/larger_clap_general"
26
- model = ClapModel.from_pretrained(model_name)
27
- processor = ClapProcessor.from_pretrained(model_name)
28
-
29
- # Gradio Interface #################################################################################
30
- max_results = 10
31
-
32
-
33
- def sound_search(query):
34
- text_inputs = processor(text=query, return_tensors="pt")
35
- text_embed = model.get_text_features(**text_inputs)[0]
36
-
37
- hits = client.search(
38
- collection_name="demo_spaces_db",
39
- query_vector=text_embed,
40
- limit=max_results,
41
- )
42
- return [
43
- gr.Audio(
44
- hit.payload['audio_path'],
45
- label=f"style: {hit.payload['style']} -- score: {hit.score}")
46
- for hit in hits
47
- ]
48
-
49
-
50
- with gr.Blocks() as demo:
51
- gr.Markdown(
52
- """# Sound search database """
53
- )
54
- inp = gr.Textbox(placeholder="What sound are you looking for ?")
55
- out = [gr.Audio(label=f"{x}") for x in range(max_results)] # Necessary to have different objs
56
- inp.change(sound_search, inp, out)
57
-
58
- demo.launch()
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  import gradio as gr
3
  from qdrant_client import QdrantClient
4
  from transformers import ClapModel, ClapProcessor
5
  from dotenv import load_dotenv
6
+ import requests
7
 
8
  # Charger les variables d'environnement à partir du fichier .env
9
  load_dotenv()
10
 
11
+ # Récupérer les variables d'environnement
12
  QDRANT_URL = os.getenv('QDRANT_URL')
13
  QDRANT_KEY = os.getenv('QDRANT_KEY')
14
 
15
  # Vérifier les valeurs récupérées
16
  print(f"QDRANT_URL: {QDRANT_URL}")
17
  print(f"QDRANT_KEY: {QDRANT_KEY}")
18
+
19
+ try:
20
+ # Tester la connexion à l'URL de Qdrant
21
+ response = requests.get(QDRANT_URL)
22
+ print(f"Test de la connexion à Qdrant: {response.status_code}")
23
+
24
+ # Vérifier que les variables sont correctement récupérées
25
+ if not QDRANT_URL or not QDRANT_KEY:
26
+ raise ValueError("Les variables d'environnement QDRANT_URL ou QDRANT_KEY ne sont pas définies")
27
+
28
+ # Connexion au client Qdrant
29
+ client = QdrantClient(QDRANT_URL, api_key=QDRANT_KEY)
30
+ print("[INFO] Client created...")
31
+
32
+ # Chargement du modèle
33
+ print("[INFO] Loading the model...")
34
+ model_name = "laion/larger_clap_general"
35
+ model = ClapModel.from_pretrained(model_name)
36
+ processor = ClapProcessor.from_pretrained(model_name)
37
+
38
+ # Interface Gradio
39
+ max_results = 10
40
+
41
+ def sound_search(query):
42
+ text_inputs = processor(text=query, return_tensors="pt")
43
+ text_embed = model.get_text_features(**text_inputs)[0]
44
+
45
+ hits = client.search(
46
+ collection_name="demo_spaces_db",
47
+ query_vector=text_embed,
48
+ limit=max_results,
49
+ )
50
+ return [
51
+ gr.Audio(
52
+ hit.payload['audio_path'],
53
+ label=f"style: {hit.payload['style']} -- score: {hit.score}")
54
+ for hit in hits
55
+ ]
56
+
57
+ with gr.Blocks() as demo:
58
+ gr.Markdown("# Sound search database")
59
+ inp = gr.Textbox(placeholder="What sound are you looking for ?")
60
+ out = [gr.Audio(label=f"{x}") for x in range(max_results)]
61
+ inp.change(sound_search, inp, out)
62
+
63
+ demo.launch()
64
+
65
+ except Exception as e:
66
+ print(f"[ERROR] Failed to create Qdrant client: {e}")
database.py CHANGED
@@ -17,8 +17,8 @@ from dotenv import load_dotenv
17
  load_dotenv()
18
 
19
  # PARAMETERS #######################################################################################
20
- CACHE_FOLDER = '/home/arthur/data/music/demo_audio_search/audio_embeddings_cache_individual/'
21
- KAGGLE_DB_PATH = '/home/arthur/data/kaggle/park-spring-2023-music-genre-recognition/train/train'
22
  AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
23
  AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
24
  S3_BUCKET = "synthia-research"
@@ -77,7 +77,7 @@ os.makedirs(CACHE_FOLDER, exist_ok=True)
77
  cache = Cache(CACHE_FOLDER)
78
 
79
  # Creating a qdrant collection #####################################################################
80
- client = QdrantClient(QDRANT_URL,QDRANT_KEY)
81
  print("[INFO] Client created...")
82
 
83
  print("[INFO] Creating qdrant data collection...")
 
17
  load_dotenv()
18
 
19
  # PARAMETERS #######################################################################################
20
+ CACHE_FOLDER = '/home/nahia/audio'
21
+ KAGGLE_DB_PATH = '/home/nahia/Documents/audio/actor/Actor_01'
22
  AWS_ACCESS_KEY_ID = os.environ['AWS_ACCESS_KEY_ID']
23
  AWS_SECRET_ACCESS_KEY = os.environ['AWS_SECRET_ACCESS_KEY']
24
  S3_BUCKET = "synthia-research"
 
77
  cache = Cache(CACHE_FOLDER)
78
 
79
  # Creating a qdrant collection #####################################################################
80
+ client = QdrantClient(QDRANT_URL,api_key=QDRANT_KEY)
81
  print("[INFO] Client created...")
82
 
83
  print("[INFO] Creating qdrant data collection...")