Teapack1 committed
Commit 183ee92
1 Parent(s): 8c837c9

Initial commit

Files changed (7):
  1. README.md +2 -1
  2. _app.py +0 -50
  3. app.py +121 -48
  4. requirements.txt +2 -2
  5. static/script.js +49 -0
  6. static/styles.css +73 -0
  7. templates/index.html +55 -0
README.md CHANGED
@@ -1,10 +1,11 @@
 ---
-title: ASR W ZeroShotClassification Assistant
+title: Smart Assistant - Audio Intent Classification
 emoji: 🦀
 colorFrom: red
 colorTo: pink
 sdk: gradio
 sdk_version: 4.7.1
+python_version: 3.10.4
 app_file: app.py
 pinned: false
 license: apache-2.0
_app.py DELETED
@@ -1,50 +0,0 @@
-from transformers import pipeline
-from transformers.pipelines.audio_utils import ffmpeg_microphone_live
-import torch
-import gradio as gr
-
-asr_model = "openai/whisper-tiny.en"
-nlp_model = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
-
-pipe = pipeline("automatic-speech-recognition", model=asr_model, device=device)
-sampling_rate = pipe.feature_extractor.sampling_rate
-
-chunk_length_s = 10  # how often returns the text
-stream_chunk_s = 1  # how often the microphone is checked for new audio
-mic = ffmpeg_microphone_live(
-    sampling_rate=sampling_rate,
-    chunk_length_s=chunk_length_s,
-    stream_chunk_s=stream_chunk_s,
-)
-
-def listen_print_loop(responses):
-    for response in responses:
-        if response["text"]:
-            print(response["text"], end="\r")
-            return response["text"]
-        if not response["partial"]:
-            print("")
-
-
-classifier = pipeline("zero-shot-classification", model=nlp_model)
-candidate_labels = ["dim the light", "turn on light fully", "turn off light fully", "raise the light", "nothing about light"]
-
-
-while True:
-    context = listen_print_loop(pipe(mic))
-    print(context)
-    output = classifier(context, candidate_labels, multi_label=False)
-    top_label = output['labels'][0]
-    top_score = output['scores'][0]
-    print(f"Top Prediction: {top_label} with a score of {top_score:.2f}")
-
-
-iface = gr.Interface(
-    fn=transcribe,
-    inputs=gr.inputs.Audio(source="microphone", type="filepath"),
-    outputs="text",
-    title="Real-Time ASR Transcription",
-    description="Speak into the microphone and get the real-time transcription."
-)
-
-iface.launch()
app.py CHANGED
@@ -1,59 +1,132 @@
-import gradio as gr
-from transformers import pipeline
+from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
+from fastapi.staticfiles import StaticFiles
+from fastapi.responses import HTMLResponse
+from fastapi.templating import Jinja2Templates
+import os
+
 import numpy as np
-import time
+from transformers import pipeline
+import torch
+from transformers.pipelines.audio_utils import ffmpeg_microphone_live
 
-# Initialize the pipelines
-transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en")
-classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-candidate_labels = ["dim the light", "turn on light fully", "turn off light fully", "raise the light", "not about lighting"]
-last_update_time = time.time() - 5  # Initialize with a value to ensure immediate first update
+classifier = pipeline(
+    "audio-classification", model="MIT/ast-finetuned-speech-commands-v2", device=device
+)
+intent_class_pipe = pipeline(
+    "audio-classification", model="anton-l/xtreme_s_xlsr_minds14", device=device
+)
 
-# Buffer to hold the last updated values
-last_transcription = ""
-last_classification = ""
 
-def transcribe_and_classify(stream, new_chunk):
-    global last_update_time, last_transcription, last_classification
-    sr, y = new_chunk
-    y = y.astype(np.float32)
-    y /= np.max(np.abs(y))
+async def launch_fn(
+    wake_word="marvin",
+    prob_threshold=0.5,
+    chunk_length_s=2.0,
+    stream_chunk_s=0.25,
+    debug=False,
+):
+    if wake_word not in classifier.model.config.label2id.keys():
+        raise ValueError(
+            f"Wake word {wake_word} not in set of valid class labels, pick a wake word in the set {classifier.model.config.label2id.keys()}."
+        )
 
-    # Concatenate new audio chunk to the stream
-    if stream is not None:
-        stream = np.concatenate([stream, y])
-    else:
-        stream = y
+    sampling_rate = classifier.feature_extractor.sampling_rate
+
+    mic = ffmpeg_microphone_live(
+        sampling_rate=sampling_rate,
+        chunk_length_s=chunk_length_s,
+        stream_chunk_s=stream_chunk_s,
+    )
 
-    # Transcribe the last 10 seconds of audio
-    transcription = transcriber({"sampling_rate": sr, "task": "transcribe", "language": "english", "raw": stream})["text"]
-    last_transcription = transcription  # Update the buffer
+    print("Listening for wake word...")
+    for prediction in classifier(mic):
+        prediction = prediction[0]
+        if debug:
+            print(prediction)
+        if prediction["label"] == wake_word:
+            if prediction["score"] > prob_threshold:
+                return True
+
+
+async def listen(websocket, chunk_length_s=2.0, stream_chunk_s=2.0):
+    sampling_rate = intent_class_pipe.feature_extractor.sampling_rate
+
+    mic = ffmpeg_microphone_live(
+        sampling_rate=sampling_rate,
+        chunk_length_s=chunk_length_s,
+        stream_chunk_s=stream_chunk_s,
+    )
+    audio_buffer = []
 
-    # Classify the transcribed text
-    if transcription.strip():
-        output = classifier(transcription, candidate_labels, multi_label=False)
-        top_label = output['labels'][0]
-        top_score = output['scores'][0]
-        last_classification = f"{top_label.upper()}, score: {top_score:.2f}"
+    print("Listening")
+    for i in range(4):
+        audio_chunk = next(mic)
+        audio_buffer.append(audio_chunk["raw"])
+
+        prediction = intent_class_pipe(audio_chunk["raw"])
+        await websocket.send_text(f"chunk: {prediction[0]['label']} | {i+1} / 4")
 
-    # Return the last updated transcription and classification
-    return stream, last_transcription, last_classification
-
-# Define the Gradio interface
-demo = gr.Interface(
-    fn=transcribe_and_classify,
-    inputs=[
-        "state",
-        gr.Audio(sources=["microphone"])
-    ],
-    outputs=[
-        "state",
-        "text",
-        "text"
-    ],
+        if await is_silence(audio_chunk["raw"], threshold=0.7):
+            print("Silence detected, processing audio.")
+            break
 
-)
+    combined_audio = np.concatenate(audio_buffer)
+    prediction = intent_class_pipe(combined_audio)
+    top_3_predictions = prediction[:3]
+    formatted_predictions = "\n".join([f"{pred['label']}: {pred['score'] * 100:.2f}%" for pred in top_3_predictions])
+    await websocket.send_text(f"classes: \n{formatted_predictions}")
+    return
+
+
+async def is_silence(audio_chunk, threshold):
+    silence = intent_class_pipe(audio_chunk)
+    if silence[0]["label"] == "silence" and silence[0]["score"] > threshold:
+        return True
+    else:
+        return False
+
+
+# Initialize FastAPI app
+app = FastAPI()
+
+# Set up static file directory
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+# Jinja2 Template for HTML rendering
+templates = Jinja2Templates(directory="templates")
+
+
+@app.get("/", response_class=HTMLResponse)
+async def get_home(request: Request):
+    return templates.TemplateResponse("index.html", {"request": request})
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+    try:
+        process_active = False  # Flag to track the state of the process
+
+        while True:
+            message = await websocket.receive_text()
+
+            if message == "start" and not process_active:
+                process_active = True
+                await websocket.send_text("Listening for wake word...")
+                wake_word_detected = await launch_fn(debug=True)
+                if wake_word_detected:
+                    await websocket.send_text("Wake word detected. Listening for your query...")
+                    await listen(websocket)
+                    process_active = False  # Reset the process flag
+
+            elif message == "stop":
+                if process_active:
+                    # Implement logic to stop the ongoing process
+                    # This might involve setting a flag that your launch_fn and listen functions check
+                    process_active = False
+                await websocket.send_text("Process stopped. Ready to restart.")
+                break  # Or keep the loop running if you want to allow restarting without reconnecting
 
-# Launch the demo
-demo.launch(debug=True, share=True)
+    except WebSocketDisconnect:
+        print("Client disconnected.")
requirements.txt CHANGED
@@ -2,5 +2,5 @@ torch
 transformers
 torchaudio
 numpy
-sentencepiece
-gradio
+fastapi
+uvicorn[standard]
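The gradio dependency is dropped because the app is now a plain FastAPI application and needs an ASGI server, which uvicorn[standard] provides. A local-run sketch under that assumption, with app.py at the repository root exposing app and the port chosen to match the ws://localhost:8000/ws URL hard-coded in static/script.js:

    # run_local.py -- hypothetical helper, not part of this commit.
    import uvicorn

    if __name__ == "__main__":
        # Serve app.py's FastAPI instance on port 8000 so the
        # WebSocket URL in static/script.js resolves.
        uvicorn.run("app:app", host="0.0.0.0", port=8000)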
static/script.js ADDED
@@ -0,0 +1,49 @@
+let ws;
+let isRecording = false;
+
+function toggleRecording() {
+    if (!isRecording) {
+        ws = new WebSocket("ws://localhost:8000/ws");
+        ws.onopen = () => ws.send("start");
+        ws.onmessage = (event) => {
+            const serverMessage = event.data;
+
+            if (serverMessage.startsWith("chunk:")) {
+                const chunkText = serverMessage.substring(6); // Remove "chunk:" prefix
+                document.getElementById('audio-chunks').innerText = chunkText;
+            } else if (serverMessage === "Restarting system...") {
+                isRecording = false;
+                updateButton();
+            } else {
+                document.getElementById('results').innerText = serverMessage;
+            }
+        };
+        isRecording = true;
+    } else {
+        ws.send("stop");
+        isRecording = false;
+    }
+    updateButton();
+}
+
+function updateButton() {
+    const startButton = document.getElementById('startBtn');
+    if (isRecording) {
+        startButton.innerText = "Stop";
+        startButton.className = "stop-button";
+    } else {
+        startButton.innerText = "Start";
+        startButton.className = "start-button";
+    }
+}
+
+document.getElementById('startBtn').addEventListener('click', toggleRecording);
+
+document.getElementById('toggleClassListBtn').addEventListener('click', function() {
+    var classList = document.getElementById('class-list');
+    if (classList.style.display === "none") {
+        classList.style.display = "block";
+    } else {
+        classList.style.display = "none";
+    }
+});
static/styles.css ADDED
@@ -0,0 +1,73 @@
+body, html {
+    margin: 0;
+    padding: 0;
+    font-family: Arial, sans-serif;
+    background-color: #eaeff2;
+}
+
+.container {
+    text-align: center;
+    margin-top: 50px;
+}
+
+header {
+    background-color: #007bff;
+    color: white;
+    padding: 20px 0;
+}
+
+main {
+    background-color: #ffffff;
+    padding: 20px;
+    margin: 20px auto;
+    border-radius: 10px;
+    box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+    width: 80%;
+    max-width: 800px;
+}
+
+section {
+    margin-bottom: 20px;
+}
+
+h1, h2 {
+    margin-bottom: 10px;
+}
+
+button {
+    background-color: #28a745;
+    color: white;
+    border: none;
+    padding: 10px 15px;
+    font-size: 16px;
+    border-radius: 5px;
+    cursor: pointer;
+}
+
+button:hover {
+    background-color: #218838;
+}
+
+#results {
+    padding: 20px;
+    background-color: #f4f4f4;
+    border: 1px solid #cccccc;
+    border-radius: 5px;
+}
+
+.start-button {
+    background-color: #28a745; /* Green */
+    /* other styling */
+}
+
+.stop-button {
+    background-color: #dc3545; /* Red */
+    /* other styling */
+}
+
+#audio-chunks {
+    padding: 20px;
+    background-color: #f4f4f4;
+    border: 1px solid #cccccc;
+    border-radius: 5px;
+}
templates/index.html ADDED
@@ -0,0 +1,55 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <title>ML Audio Demo</title>
+    <link rel="stylesheet" type="text/css" href="/static/styles.css">
+</head>
+<body>
+    <div class="container">
+        <header>
+
+        </header>
+        <main>
+            <h1>Audio Intent Classification Demo</h1>
+            <section id="recording-section">
+                <p id="recording-instructions">The system is activated by pressing Start and calling the name <i>Marvin</i>.</p>
+                <p id="recording-instructions">After that, the system listens to an audio query and classifies its intent.</p>
+                <p id="recording-instructions">The model is trained on <i>PolyAI/minds14</i>. The dataset covers 14 intents extracted from a commercial system in the e-banking domain.</p>
+                <h3>Start the System:</h3>
+                <button id="startBtn">Start</button>
+                <div id="restart-button-container"></div>
+            </section>
+            <section id="audio-chunks-section">
+                <h3>Partial Predictions:</h3>
+                <div id="audio-chunks">Partial results will appear here...</div>
+            </section>
+            <section id="results-section">
+                <h3>Final Result:</h3>
+                <div id="results"><b>Intent classification will appear here...</b></div>
+            </section>
+
+            <section id="class-list-section">
+                <button id="toggleClassListBtn">See all classes</button>
+                <div id="class-list" style="display: none;">
+                    <p>1. abroad</p>
+                    <p>2. address</p>
+                    <p>3. app_error</p>
+                    <p>4. atm_limit</p>
+                    <p>5. balance</p>
+                    <p>6. business_loan</p>
+                    <p>7. card_issues</p>
+                    <p>8. cash_deposit</p>
+                    <p>9. direct_debit</p>
+                    <p>10. freeze</p>
+                    <p>11. high_value_payment</p>
+                    <p>12. joint_account</p>
+                    <p>13. latest_transactions</p>
+                    <p>14. pay_bill</p>
+                </div>
+            </section>
+        </main>
+    </div>
+    <script src="/static/script.js?v=8"></script>
+
+</body>
+</html>
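The 14 classes listed in index.html mirror the MINDS-14 intents the intent model was fine-tuned on. A quick sanity-check sketch (hypothetical, not part of the commit) that prints the labels the intent pipeline can actually emit, taken from the checkpoint's id2label config; the strings shown to the user ultimately come from this map:

    from transformers import AutoConfig

    # Same checkpoint as app.py's intent_class_pipe.
    config = AutoConfig.from_pretrained("anton-l/xtreme_s_xlsr_minds14")
    for idx, label in sorted(config.id2label.items()):
        print(idx, label)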