jvcanavarro committed
Commit 8b843d9
Parent: 31916e9

Add app v0.1

Files changed (1)
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
+ import gradio as gr
+ import torch
+ import librosa
+ import time
+ import pandas as pd
+ from datetime import datetime
+ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
+
+ DESCRIPTION = "Stores a record of previous calls to verify whether a client has called before. Pretrained on `https://huggingface.co/datasets/superb` using the [S3PRL recipe](https://github.com/s3prl/s3prl/tree/master/s3prl/downstream/voxceleb1)."
+
+ # COLUMNS = ["call_id", "date", "client_id", "duration", "new"]
+ model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-large-superb-sid")
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-large-superb-sid")
+
+ def file_to_array(path):
+     speech, _ = librosa.load(path, sr=16000, mono=True)
+     duration = librosa.get_duration(y=speech, sr=16000)
+     return speech, duration
+
+
+ def handler(audio_path):
+     calls = pd.read_csv("call_records.csv")
+     speech, duration = file_to_array(audio_path)
+
+     # compute attention masks and normalize the waveform if needed
+     inputs = feature_extractor(speech, sampling_rate=16000, padding=True, return_tensors="pt")
+
+     logits = model(**inputs).logits
+     predicted_ids = torch.argmax(logits, dim=-1)
+     labels = [model.config.id2label[_id] for _id in predicted_ids.tolist()]
+
+     client_id = labels[0]
+     call_id = str(int(time.time()))
+
+     date = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
+
+     n_of_calls = len(calls.loc[calls.client_id == client_id])
+     new = n_of_calls == 0
+
+     # add new call record
+     record = [call_id, date, client_id, duration, new]
+     calls.loc[len(calls)] = record
+
+     calls.to_csv("call_records.csv", index=False)
+
+     if new:
+         return f"New client call: Client ID {client_id}"
+
+     return f"Client {client_id} calling again: {n_of_calls} previous calls"
+
+
+ first = gr.Interface(
+     fn=handler,
+     inputs=gr.Audio(label="Speech Audio", type="filepath"),
+
+     outputs=gr.Text(label="Output", value="..."),
+     description=DESCRIPTION
+
+ )
+
+ second = gr.Interface(
+     fn=handler,
+     inputs=gr.Audio(label="Microphone Input", source="microphone", type="filepath"),
+     outputs=gr.Text(label="Output", value="..."),
+     description=DESCRIPTION
+ )
+
+ app = gr.TabbedInterface(
+     [first, second],
+     title="Speaker Call Verification 🎤",
+     tab_names=["Audio Upload", "Microphone"],
+ )
+ app.launch()
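
Note: `handler()` reads `call_records.csv`, which is not added by this commit (only `app.py` is in the changed files). Below is a minimal sketch of how the file could be bootstrapped before the first launch, assuming the schema hinted at by the commented-out `COLUMNS` list in `app.py`; the helper name `bootstrap_records.py` is hypothetical and not part of the commit.

```python
# bootstrap_records.py -- hypothetical helper, not part of this commit.
# Writes an empty call_records.csv with the columns that handler() in app.py
# appends to, so the first pd.read_csv("call_records.csv") has a file to read.
import pandas as pd

# Column names taken from the commented-out COLUMNS list in app.py.
COLUMNS = ["call_id", "date", "client_id", "duration", "new"]

if __name__ == "__main__":
    pd.DataFrame(columns=COLUMNS).to_csv("call_records.csv", index=False)
```

Running it once in the app's working directory would let the first `pd.read_csv("call_records.csv")` in `handler()` succeed instead of raising `FileNotFoundError`.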