Spaces:
Running
Running
Marcos12886
commited on
Commit
•
166aa6c
1
Parent(s):
abdf62b
Decibelios. Llamar modelos mejor. Mejorar botones...
Browse files- app.py +69 -72
- interfaz.py +2 -2
- model.py +9 -9
app.py
CHANGED
@@ -7,71 +7,63 @@ from interfaz import estilo, my_theme
|
|
7 |
|
8 |
token = os.getenv("HF_TOKEN")
|
9 |
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
|
10 |
-
|
|
|
|
|
11 |
|
12 |
-
def
|
13 |
-
if (model_path, dataset_path, filter_white_noise) not in model_cache:
|
14 |
-
model, _, _, id2label = predict_params(dataset_path, model_path, filter_white_noise)
|
15 |
-
model_cache[(model_path, dataset_path, filter_white_noise)] = (model, id2label)
|
16 |
-
return model_cache[(model_path, dataset_path, filter_white_noise)]
|
17 |
-
|
18 |
-
def predict(audio_path, model_path, dataset_path, filter_white_noise):
|
19 |
-
model, id2label = load_model_and_dataset(model_path, dataset_path, filter_white_noise)
|
20 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
model.to(device)
|
22 |
model.eval()
|
23 |
-
|
24 |
-
|
|
|
25 |
with torch.no_grad():
|
26 |
outputs = model(**inputs)
|
27 |
logits = outputs.logits
|
28 |
-
|
29 |
-
label = id2label[predicted_class_ids]
|
30 |
-
if dataset_path == "data/mixed_data":
|
31 |
-
label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
|
32 |
-
label = label_mapping.get(predicted_class_ids, label)
|
33 |
-
return label
|
34 |
|
35 |
-
def
|
36 |
-
model_mon, _ = load_model_and_dataset(
|
37 |
-
model_path="distilhubert-finetuned-cry-detector",
|
38 |
-
dataset_path="data/baby_cry_detection",
|
39 |
-
filter_white_noise=False
|
40 |
-
)
|
41 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42 |
-
model_mon.to(device)
|
43 |
-
model_mon.eval()
|
44 |
-
audio_dataset = AudioDataset(dataset_path="data/baby_cry_detection", label2id={}, filter_white_noise=False)
|
45 |
-
processed_audio = audio_dataset.preprocess_audio(audio_path)
|
46 |
-
inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
|
47 |
with torch.no_grad():
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
probabilities = torch.nn.functional.softmax(logits, dim=-1)
|
51 |
crying_probabilities = probabilities[:, 1]
|
52 |
-
avg_crying_probability = crying_probabilities.mean()
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
73 |
else:
|
74 |
-
return
|
75 |
|
76 |
def chatbot_config(message, history: list[tuple[str, str]]):
|
77 |
system_message = "You are a Chatbot specialized in baby health and care."
|
@@ -105,12 +97,12 @@ with gr.Blocks(theme=my_theme) as demo:
|
|
105 |
with gr.Row():
|
106 |
with gr.Column():
|
107 |
gr.Markdown("<h2>Predictor</h2>")
|
108 |
-
|
109 |
-
gr.Markdown("<p>Descubre por qué llora tu
|
110 |
with gr.Column():
|
111 |
gr.Markdown("<h2>Monitor</h2>")
|
112 |
-
|
113 |
-
gr.Markdown("<p>
|
114 |
with gr.Column(visible=False) as pag_predictor:
|
115 |
gr.Markdown("<h2>Predictor</h2>")
|
116 |
audio_input = gr.Audio(
|
@@ -119,14 +111,8 @@ with gr.Blocks(theme=my_theme) as demo:
|
|
119 |
label="Baby recorder",
|
120 |
type="filepath",
|
121 |
)
|
122 |
-
|
123 |
-
|
124 |
-
lambda audio: predict( # Mirar porque usar lambda
|
125 |
-
audio,
|
126 |
-
model_path="distilhubert-finetuned-mixed-data",
|
127 |
-
dataset_path="data/mixed_data",
|
128 |
-
filter_white_noise=True
|
129 |
-
),
|
130 |
inputs=audio_input,
|
131 |
outputs=gr.Textbox(label="Tu bebé llora por:")
|
132 |
)
|
@@ -134,18 +120,29 @@ with gr.Blocks(theme=my_theme) as demo:
|
|
134 |
with gr.Column(visible=False) as pag_monitor:
|
135 |
gr.Markdown("<h2>Monitor</h2>")
|
136 |
audio_stream = gr.Audio(
|
137 |
-
# min_length=1.0, # mirar por qué no va esto
|
138 |
format="wav",
|
139 |
label="Baby recorder",
|
140 |
type="filepath",
|
141 |
streaming=True
|
142 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
audio_stream.stream(
|
144 |
-
|
145 |
-
inputs=audio_stream,
|
146 |
-
outputs=gr.Textbox(label="Tu
|
147 |
)
|
148 |
gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
|
149 |
-
|
150 |
-
|
151 |
demo.launch(share=True)
|
|
|
7 |
|
8 |
token = os.getenv("HF_TOKEN")
|
9 |
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=token)
|
10 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
model_class, id2label_class = predict_params(model_path="distilhubert-finetuned-mixed-data", dataset_path="data/mixed_data", filter_white_noise=True)
|
12 |
+
model_mon, id2label_mon = predict_params(model_path="distilhubert-finetuned-cry-detector", dataset_path="data/baby_cry_detection", filter_white_noise=False)
|
13 |
|
14 |
+
def call(audiopath, model, dataset_path, filter_white_noise):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
model.to(device)
|
16 |
model.eval()
|
17 |
+
audio_dataset = AudioDataset(dataset_path, {}, filter_white_noise,)
|
18 |
+
processed_audio = audio_dataset.preprocess_audio(audiopath)
|
19 |
+
inputs = {"input_values": processed_audio.to(device).unsqueeze(0)}
|
20 |
with torch.no_grad():
|
21 |
outputs = model(**inputs)
|
22 |
logits = outputs.logits
|
23 |
+
return logits
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
+
def predict(audio_path_pred):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
with torch.no_grad():
|
27 |
+
logits = call(audio_path_pred, model=model_class, dataset_path="data/mixed_data", filter_white_noise=True)
|
28 |
+
predicted_class_ids_class = torch.argmax(logits, dim=-1).item()
|
29 |
+
label_class = id2label_class[predicted_class_ids_class]
|
30 |
+
label_mapping = {0: 'Hambre', 1: 'Problemas para respirar', 2: 'Dolor', 3: 'Cansancio/Incomodidad'}
|
31 |
+
label_class = label_mapping.get(predicted_class_ids_class, label_class)
|
32 |
+
return label_class
|
33 |
+
|
34 |
+
def predict_stream(audio_path_stream):
|
35 |
+
with torch.no_grad():
|
36 |
+
logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False)
|
37 |
probabilities = torch.nn.functional.softmax(logits, dim=-1)
|
38 |
crying_probabilities = probabilities[:, 1]
|
39 |
+
avg_crying_probability = crying_probabilities.mean()*100
|
40 |
+
if avg_crying_probability < 15:
|
41 |
+
label_class = predict(audio_path_stream)
|
42 |
+
return "Está llorando por:", f"{label_class}. Probabilidad: {avg_crying_probability:.1f}%"
|
43 |
+
else:
|
44 |
+
return "No está llorando.", f"Probabilidad: {avg_crying_probability:.1f}%"
|
45 |
+
|
46 |
+
def decibelios(audio_path_stream):
|
47 |
+
with torch.no_grad():
|
48 |
+
logits = call(audio_path_stream, model=model_mon, dataset_path="data/baby_cry_detection", filter_white_noise=False)
|
49 |
+
rms = torch.sqrt(torch.mean(torch.square(logits)))
|
50 |
+
db_level = 20 * torch.log10(rms + 1e-6).item()
|
51 |
+
return db_level
|
52 |
+
|
53 |
+
def mostrar_decibelios(audio_path_stream, visual_threshold):
|
54 |
+
db_level = decibelios(audio_path_stream)
|
55 |
+
if db_level < visual_threshold:
|
56 |
+
return f"Prediciendo. Decibelios: {db_level:.2f}"
|
57 |
+
elif db_level > visual_threshold:
|
58 |
+
return "No detectamos ruido..."
|
59 |
+
|
60 |
+
def predict_stream_decib(audio_path_stream, visual_threshold):
|
61 |
+
db_level = decibelios(audio_path_stream)
|
62 |
+
if db_level < visual_threshold:
|
63 |
+
llorando, probabilidad = predict_stream(audio_path_stream)
|
64 |
+
return f"{llorando} {probabilidad}"
|
65 |
else:
|
66 |
+
return ""
|
67 |
|
68 |
def chatbot_config(message, history: list[tuple[str, str]]):
|
69 |
system_message = "You are a Chatbot specialized in baby health and care."
|
|
|
97 |
with gr.Row():
|
98 |
with gr.Column():
|
99 |
gr.Markdown("<h2>Predictor</h2>")
|
100 |
+
boton_predictor = gr.Button("Prueba el predictor")
|
101 |
+
gr.Markdown("<p>Descubre por qué llora tu bebé</p>")
|
102 |
with gr.Column():
|
103 |
gr.Markdown("<h2>Monitor</h2>")
|
104 |
+
boton_monitor = gr.Button("Prueba el monitor")
|
105 |
+
gr.Markdown("<p>Monitoriza si tu hijo está llorando y por qué, sin levantarte del sofá</p>")
|
106 |
with gr.Column(visible=False) as pag_predictor:
|
107 |
gr.Markdown("<h2>Predictor</h2>")
|
108 |
audio_input = gr.Audio(
|
|
|
111 |
label="Baby recorder",
|
112 |
type="filepath",
|
113 |
)
|
114 |
+
gr.Button("¿Por qué llora?").click(
|
115 |
+
predict,
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
inputs=audio_input,
|
117 |
outputs=gr.Textbox(label="Tu bebé llora por:")
|
118 |
)
|
|
|
120 |
with gr.Column(visible=False) as pag_monitor:
|
121 |
gr.Markdown("<h2>Monitor</h2>")
|
122 |
audio_stream = gr.Audio(
|
|
|
123 |
format="wav",
|
124 |
label="Baby recorder",
|
125 |
type="filepath",
|
126 |
streaming=True
|
127 |
)
|
128 |
+
threshold_db = gr.Slider(
|
129 |
+
minimum=0,
|
130 |
+
maximum=100,
|
131 |
+
step=1,
|
132 |
+
value=30,
|
133 |
+
label="Umbral de dB para activar la predicción"
|
134 |
+
)
|
135 |
+
audio_stream.stream(
|
136 |
+
mostrar_decibelios,
|
137 |
+
inputs=[audio_stream, threshold_db],
|
138 |
+
outputs=gr.Textbox(value="Esperando...", label="Estado")
|
139 |
+
)
|
140 |
audio_stream.stream(
|
141 |
+
predict_stream_decib,
|
142 |
+
inputs=[audio_stream, threshold_db],
|
143 |
+
outputs=gr.Textbox(value="", label="Tu bebé:")
|
144 |
)
|
145 |
gr.Button("Volver a la pantalla inicial").click(cambiar_pestaña, outputs=[pag_monitor, chatbot])
|
146 |
+
boton_predictor.click(cambiar_pestaña, outputs=[chatbot, pag_predictor])
|
147 |
+
boton_monitor.click(cambiar_pestaña, outputs=[chatbot, pag_monitor])
|
148 |
demo.launch(share=True)
|
interfaz.py
CHANGED
@@ -93,9 +93,9 @@ def inicio():
|
|
93 |
with gr.Column():
|
94 |
gr.Markdown("<h2>Predictor</h2>")
|
95 |
boton_pagina_1 = gr.Button("Prueba el predictor")
|
96 |
-
gr.Markdown("<p>Descubre por qué llora tu
|
97 |
with gr.Column():
|
98 |
gr.Markdown("<h2>Monitor</h2>")
|
99 |
boton_pagina_2 = gr.Button("Prueba el monitor")
|
100 |
-
gr.Markdown("<p>
|
101 |
return boton_pagina_1, boton_pagina_2
|
|
|
93 |
with gr.Column():
|
94 |
gr.Markdown("<h2>Predictor</h2>")
|
95 |
boton_pagina_1 = gr.Button("Prueba el predictor")
|
96 |
+
gr.Markdown("<p>Descubre por qué llora tu bebé</p>")
|
97 |
with gr.Column():
|
98 |
gr.Markdown("<h2>Monitor</h2>")
|
99 |
boton_pagina_2 = gr.Button("Prueba el monitor")
|
100 |
+
gr.Markdown("<p>Detecta si tu hijo está llorando y por qué antes de que puedas levantarte del sofá</p>")
|
101 |
return boton_pagina_1, boton_pagina_2
|
model.py
CHANGED
@@ -5,8 +5,8 @@ import torch
|
|
5 |
import torchaudio
|
6 |
from torch.utils.data import Dataset, DataLoader
|
7 |
from huggingface_hub import upload_folder
|
8 |
-
from transformers.integrations import TensorBoardCallback
|
9 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
|
|
10 |
from transformers import (
|
11 |
Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
|
12 |
Trainer, TrainingArguments,
|
@@ -121,7 +121,7 @@ def create_dataloader(dataset_path, filter_white_noise, test_size=0.2, shuffle=T
|
|
121 |
)
|
122 |
return train_dataloader, test_dataloader, label2id, id2label
|
123 |
|
124 |
-
def load_model(model_path,
|
125 |
config = HubertConfig.from_pretrained(
|
126 |
pretrained_model_name_or_path=model_path,
|
127 |
num_labels=num_labels,
|
@@ -140,13 +140,13 @@ def load_model(model_path, num_labels, label2id, id2label):
|
|
140 |
|
141 |
def train_params(dataset_path, filter_white_noise):
|
142 |
train_dataloader, test_dataloader, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
|
143 |
-
model = load_model(
|
144 |
return model, train_dataloader, test_dataloader, id2label
|
145 |
|
146 |
def predict_params(dataset_path, model_path, filter_white_noise):
|
147 |
_, _, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
|
148 |
-
model = load_model(model_path,
|
149 |
-
return model,
|
150 |
|
151 |
def compute_metrics(eval_pred):
|
152 |
predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
|
@@ -187,10 +187,10 @@ def load_config(model_name):
|
|
187 |
return model_config
|
188 |
|
189 |
if __name__ == "__main__":
|
190 |
-
config = load_config(clasificador) # PARA CAMBIAR MODELOS
|
191 |
-
filter_white_noise = True
|
192 |
-
|
193 |
-
|
194 |
training_args = config["training_args"]
|
195 |
output_dir = config["output_dir"]
|
196 |
dataset_path = config["dataset_path"]
|
|
|
5 |
import torchaudio
|
6 |
from torch.utils.data import Dataset, DataLoader
|
7 |
from huggingface_hub import upload_folder
|
|
|
8 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
9 |
+
from transformers.integrations import TensorBoardCallback
|
10 |
from transformers import (
|
11 |
Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
|
12 |
Trainer, TrainingArguments,
|
|
|
121 |
)
|
122 |
return train_dataloader, test_dataloader, label2id, id2label
|
123 |
|
124 |
+
def load_model(model_path, label2id, id2label, num_labels):
|
125 |
config = HubertConfig.from_pretrained(
|
126 |
pretrained_model_name_or_path=model_path,
|
127 |
num_labels=num_labels,
|
|
|
140 |
|
141 |
def train_params(dataset_path, filter_white_noise):
|
142 |
train_dataloader, test_dataloader, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
|
143 |
+
model = load_model(MODEL, label2id, id2label, num_labels=len(id2label))
|
144 |
return model, train_dataloader, test_dataloader, id2label
|
145 |
|
146 |
def predict_params(dataset_path, model_path, filter_white_noise):
|
147 |
_, _, label2id, id2label = create_dataloader(dataset_path, filter_white_noise)
|
148 |
+
model = load_model(model_path, label2id, id2label, num_labels=len(id2label))
|
149 |
+
return model, id2label
|
150 |
|
151 |
def compute_metrics(eval_pred):
|
152 |
predictions = torch.argmax(torch.tensor(eval_pred.predictions), dim=-1)
|
|
|
187 |
return model_config
|
188 |
|
189 |
if __name__ == "__main__":
|
190 |
+
# config = load_config(clasificador) # PARA CAMBIAR MODELOS
|
191 |
+
# filter_white_noise = True
|
192 |
+
config = load_config(monitor) # PARA CAMBIAR MODELOS
|
193 |
+
filter_white_noise = False
|
194 |
training_args = config["training_args"]
|
195 |
output_dir = config["output_dir"]
|
196 |
dataset_path = config["dataset_path"]
|