andito HF staff commited on
Commit
c72e80d
·
verified ·
1 Parent(s): 6412783

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +13 -0
  2. Dockerfile.arm64 +13 -0
  3. LICENSE +201 -0
  4. LLM/__pycache__/chat.cpython-311.pyc +0 -0
  5. LLM/__pycache__/language_model.cpython-311.pyc +0 -0
  6. LLM/__pycache__/mlx_language_model.cpython-311.pyc +0 -0
  7. LLM/chat.py +25 -0
  8. LLM/language_model.py +144 -0
  9. LLM/mlx_language_model.py +107 -0
  10. README.md +244 -0
  11. STT/__pycache__/lightning_whisper_mlx_handler.cpython-311.pyc +0 -0
  12. STT/__pycache__/paraformer_handler.cpython-311.pyc +0 -0
  13. STT/__pycache__/whisper_stt_handler.cpython-311.pyc +0 -0
  14. STT/lightning_whisper_mlx_handler.py +85 -0
  15. STT/paraformer_handler.py +61 -0
  16. STT/whisper_stt_handler.py +140 -0
  17. TTS/__pycache__/chatTTS_handler.cpython-311.pyc +0 -0
  18. TTS/__pycache__/melo_handler.cpython-311.pyc +0 -0
  19. TTS/__pycache__/parler_handler.cpython-311.pyc +0 -0
  20. TTS/chatTTS_handler.py +82 -0
  21. TTS/melo_handler.py +109 -0
  22. TTS/parler_handler.py +191 -0
  23. VAD/__pycache__/vad_handler.cpython-311.pyc +0 -0
  24. VAD/__pycache__/vad_handler.cpython-312.pyc +0 -0
  25. VAD/__pycache__/vad_iterator.cpython-311.pyc +0 -0
  26. VAD/__pycache__/vad_iterator.cpython-312.pyc +0 -0
  27. VAD/vad_handler.py +92 -0
  28. VAD/vad_iterator.py +100 -0
  29. arguments_classes/__pycache__/chat_tts_arguments.cpython-311.pyc +0 -0
  30. arguments_classes/__pycache__/language_model_arguments.cpython-311.pyc +0 -0
  31. arguments_classes/__pycache__/melo_tts_arguments.cpython-311.pyc +0 -0
  32. arguments_classes/__pycache__/mlx_language_model_arguments.cpython-311.pyc +0 -0
  33. arguments_classes/__pycache__/module_arguments.cpython-311.pyc +0 -0
  34. arguments_classes/__pycache__/paraformer_stt_arguments.cpython-311.pyc +0 -0
  35. arguments_classes/__pycache__/parler_tts_arguments.cpython-311.pyc +0 -0
  36. arguments_classes/__pycache__/socket_receiver_arguments.cpython-311.pyc +0 -0
  37. arguments_classes/__pycache__/socket_sender_arguments.cpython-311.pyc +0 -0
  38. arguments_classes/__pycache__/vad_arguments.cpython-311.pyc +0 -0
  39. arguments_classes/__pycache__/whisper_stt_arguments.cpython-311.pyc +0 -0
  40. arguments_classes/chat_tts_arguments.py +21 -0
  41. arguments_classes/language_model_arguments.py +71 -0
  42. arguments_classes/melo_tts_arguments.py +23 -0
  43. arguments_classes/mlx_language_model_arguments.py +65 -0
  44. arguments_classes/module_arguments.py +46 -0
  45. arguments_classes/paraformer_stt_arguments.py +17 -0
  46. arguments_classes/parler_tts_arguments.py +62 -0
  47. arguments_classes/socket_receiver_arguments.py +24 -0
  48. arguments_classes/socket_sender_arguments.py +18 -0
  49. arguments_classes/vad_arguments.py +47 -0
  50. arguments_classes/whisper_stt_arguments.py +64 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel
2
+
3
+ ENV PYTHONUNBUFFERED 1
4
+
5
+ WORKDIR /usr/src/app
6
+
7
+ # Install packages
8
+ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
9
+
10
+ COPY requirements.txt ./
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ COPY . .
Dockerfile.arm64 ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3
2
+
3
+ ENV PYTHONUNBUFFERED 1
4
+
5
+ WORKDIR /usr/src/app
6
+
7
+ # Install packages
8
+ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
9
+
10
+ COPY requirements.txt ./
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ COPY . .
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2024] [The HuggingFace Inc. team]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LLM/__pycache__/chat.cpython-311.pyc ADDED
Binary file (1.59 kB). View file
 
LLM/__pycache__/language_model.cpython-311.pyc ADDED
Binary file (6.31 kB). View file
 
LLM/__pycache__/mlx_language_model.cpython-311.pyc ADDED
Binary file (5 kB). View file
 
LLM/chat.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Chat:
2
+ """
3
+ Handles the chat using to avoid OOM issues.
4
+ """
5
+
6
+ def __init__(self, size):
7
+ self.size = size
8
+ self.init_chat_message = None
9
+ # maxlen is necessary pair, since a each new step we add an prompt and assitant answer
10
+ self.buffer = []
11
+
12
+ def append(self, item):
13
+ self.buffer.append(item)
14
+ if len(self.buffer) == 2 * (self.size + 1):
15
+ self.buffer.pop(0)
16
+ self.buffer.pop(0)
17
+
18
+ def init_chat(self, init_chat_message):
19
+ self.init_chat_message = init_chat_message
20
+
21
+ def to_list(self):
22
+ if self.init_chat_message:
23
+ return [self.init_chat_message] + self.buffer
24
+ else:
25
+ return self.buffer
LLM/language_model.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+ from transformers import (
3
+ AutoModelForCausalLM,
4
+ AutoTokenizer,
5
+ pipeline,
6
+ TextIteratorStreamer,
7
+ )
8
+ import torch
9
+
10
+ from LLM.chat import Chat
11
+ from baseHandler import BaseHandler
12
+ from rich.console import Console
13
+ import logging
14
+ from nltk import sent_tokenize
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ console = Console()
19
+
20
+
21
+ WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
22
+ "en": "english",
23
+ "fr": "french",
24
+ "es": "spanish",
25
+ "zh": "chinese",
26
+ "ja": "japanese",
27
+ "ko": "korean",
28
+ }
29
+
30
+ class LanguageModelHandler(BaseHandler):
31
+ """
32
+ Handles the language model part.
33
+ """
34
+
35
+ def setup(
36
+ self,
37
+ model_name="microsoft/Phi-3-mini-4k-instruct",
38
+ device="cuda",
39
+ torch_dtype="float16",
40
+ gen_kwargs={},
41
+ user_role="user",
42
+ chat_size=1,
43
+ init_chat_role=None,
44
+ init_chat_prompt="You are a helpful AI assistant.",
45
+ ):
46
+ self.device = device
47
+ self.torch_dtype = getattr(torch, torch_dtype)
48
+
49
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
50
+ self.model = AutoModelForCausalLM.from_pretrained(
51
+ model_name, torch_dtype=torch_dtype, trust_remote_code=True
52
+ ).to(device)
53
+ self.pipe = pipeline(
54
+ "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
55
+ )
56
+ self.streamer = TextIteratorStreamer(
57
+ self.tokenizer,
58
+ skip_prompt=True,
59
+ skip_special_tokens=True,
60
+ )
61
+ self.gen_kwargs = {
62
+ "streamer": self.streamer,
63
+ "return_full_text": False,
64
+ **gen_kwargs,
65
+ }
66
+
67
+ self.chat = Chat(chat_size)
68
+ if init_chat_role:
69
+ if not init_chat_prompt:
70
+ raise ValueError(
71
+ "An initial promt needs to be specified when setting init_chat_role."
72
+ )
73
+ self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
74
+ self.user_role = user_role
75
+
76
+ self.warmup()
77
+
78
+ def warmup(self):
79
+ logger.info(f"Warming up {self.__class__.__name__}")
80
+
81
+ dummy_input_text = "Repeat the word 'home'."
82
+ dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
83
+ warmup_gen_kwargs = {
84
+ "min_new_tokens": self.gen_kwargs["min_new_tokens"],
85
+ "max_new_tokens": self.gen_kwargs["max_new_tokens"],
86
+ **self.gen_kwargs,
87
+ }
88
+
89
+ n_steps = 2
90
+
91
+ if self.device == "cuda":
92
+ start_event = torch.cuda.Event(enable_timing=True)
93
+ end_event = torch.cuda.Event(enable_timing=True)
94
+ torch.cuda.synchronize()
95
+ start_event.record()
96
+
97
+ for _ in range(n_steps):
98
+ thread = Thread(
99
+ target=self.pipe, args=(dummy_chat,), kwargs=warmup_gen_kwargs
100
+ )
101
+ thread.start()
102
+ for _ in self.streamer:
103
+ pass
104
+
105
+ if self.device == "cuda":
106
+ end_event.record()
107
+ torch.cuda.synchronize()
108
+
109
+ logger.info(
110
+ f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
111
+ )
112
+
113
+ def process(self, prompt):
114
+ logger.debug("infering language model...")
115
+ language_code = None
116
+ if isinstance(prompt, tuple):
117
+ prompt, language_code = prompt
118
+ prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
119
+
120
+ self.chat.append({"role": self.user_role, "content": prompt})
121
+ thread = Thread(
122
+ target=self.pipe, args=(self.chat.to_list(),), kwargs=self.gen_kwargs
123
+ )
124
+ thread.start()
125
+ if self.device == "mps":
126
+ generated_text = ""
127
+ for new_text in self.streamer:
128
+ generated_text += new_text
129
+ printable_text = generated_text
130
+ torch.mps.empty_cache()
131
+ else:
132
+ generated_text, printable_text = "", ""
133
+ for new_text in self.streamer:
134
+ generated_text += new_text
135
+ printable_text += new_text
136
+ sentences = sent_tokenize(printable_text)
137
+ if len(sentences) > 1:
138
+ yield (sentences[0], language_code)
139
+ printable_text = new_text
140
+
141
+ self.chat.append({"role": "assistant", "content": generated_text})
142
+
143
+ # don't forget last sentence
144
+ yield (printable_text, language_code)
LLM/mlx_language_model.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from LLM.chat import Chat
3
+ from baseHandler import BaseHandler
4
+ from mlx_lm import load, stream_generate, generate
5
+ from rich.console import Console
6
+ import torch
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ console = Console()
11
+
12
+ WHISPER_LANGUAGE_TO_LLM_LANGUAGE = {
13
+ "en": "english",
14
+ "fr": "french",
15
+ "es": "spanish",
16
+ "zh": "chinese",
17
+ "ja": "japanese",
18
+ "ko": "korean",
19
+ }
20
+
21
+ class MLXLanguageModelHandler(BaseHandler):
22
+ """
23
+ Handles the language model part.
24
+ """
25
+
26
+ def setup(
27
+ self,
28
+ model_name="microsoft/Phi-3-mini-4k-instruct",
29
+ device="mps",
30
+ torch_dtype="float16",
31
+ gen_kwargs={},
32
+ user_role="user",
33
+ chat_size=1,
34
+ init_chat_role=None,
35
+ init_chat_prompt="You are a helpful AI assistant.",
36
+ ):
37
+ self.model_name = model_name
38
+ self.model, self.tokenizer = load(self.model_name)
39
+ self.gen_kwargs = gen_kwargs
40
+
41
+ self.chat = Chat(chat_size)
42
+ if init_chat_role:
43
+ if not init_chat_prompt:
44
+ raise ValueError(
45
+ "An initial promt needs to be specified when setting init_chat_role."
46
+ )
47
+ self.chat.init_chat({"role": init_chat_role, "content": init_chat_prompt})
48
+ self.user_role = user_role
49
+
50
+ self.warmup()
51
+
52
+ def warmup(self):
53
+ logger.info(f"Warming up {self.__class__.__name__}")
54
+
55
+ dummy_input_text = "Repeat the word 'home'."
56
+ dummy_chat = [{"role": self.user_role, "content": dummy_input_text}]
57
+
58
+ n_steps = 2
59
+
60
+ for _ in range(n_steps):
61
+ prompt = self.tokenizer.apply_chat_template(dummy_chat, tokenize=False)
62
+ generate(
63
+ self.model,
64
+ self.tokenizer,
65
+ prompt=prompt,
66
+ max_tokens=self.gen_kwargs["max_new_tokens"],
67
+ verbose=False,
68
+ )
69
+
70
+ def process(self, prompt):
71
+ logger.debug("infering language model...")
72
+ language_code = None
73
+
74
+ if isinstance(prompt, tuple):
75
+ prompt, language_code = prompt
76
+ prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
77
+
78
+ self.chat.append({"role": self.user_role, "content": prompt})
79
+
80
+ # Remove system messages if using a Gemma model
81
+ if "gemma" in self.model_name.lower():
82
+ chat_messages = [
83
+ msg for msg in self.chat.to_list() if msg["role"] != "system"
84
+ ]
85
+ else:
86
+ chat_messages = self.chat.to_list()
87
+
88
+ prompt = self.tokenizer.apply_chat_template(
89
+ chat_messages, tokenize=False, add_generation_prompt=True
90
+ )
91
+ output = ""
92
+ curr_output = ""
93
+ for t in stream_generate(
94
+ self.model,
95
+ self.tokenizer,
96
+ prompt,
97
+ max_tokens=self.gen_kwargs["max_new_tokens"],
98
+ ):
99
+ output += t
100
+ curr_output += t
101
+ if curr_output.endswith((".", "?", "!", "<|end|>")):
102
+ yield (curr_output.replace("<|end|>", ""), language_code)
103
+ curr_output = ""
104
+ generated_text = output.replace("<|end|>", "")
105
+ torch.mps.empty_cache()
106
+
107
+ self.chat.append({"role": "assistant", "content": generated_text})
README.md ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <div>&nbsp;</div>
3
+ <img src="logo.png" width="600"/>
4
+ </div>
5
+
6
+ # Speech To Speech: an effort for an open-sourced and modular GPT4-o
7
+
8
+
9
+ ## 📖 Quick Index
10
+ * [Approach](#approach)
11
+ - [Structure](#structure)
12
+ - [Modularity](#modularity)
13
+ * [Setup](#setup)
14
+ * [Usage](#usage)
15
+ - [Docker Server approach](#docker-server)
16
+ - [Server/Client approach](#serverclient-approach)
17
+ - [Local approach](#local-approach-running-on-mac)
18
+ * [Command-line usage](#command-line-usage)
19
+ - [Model parameters](#model-parameters)
20
+ - [Generation parameters](#generation-parameters)
21
+ - [Notable parameters](#notable-parameters)
22
+
23
+ ## Approach
24
+
25
+ ### Structure
26
+ This repository implements a speech-to-speech cascaded pipeline with consecutive parts:
27
+ 1. **Voice Activity Detection (VAD)**: [silero VAD v5](https://github.com/snakers4/silero-vad)
28
+ 2. **Speech to Text (STT)**: Whisper checkpoints (including [distilled versions](https://huggingface.co/distil-whisper))
29
+ 3. **Language Model (LM)**: Any instruct model available on the [Hugging Face Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending)! 🤗
30
+ 4. **Text to Speech (TTS)**: [Parler-TTS](https://github.com/huggingface/parler-tts)🤗
31
+
32
+ ### Modularity
33
+ The pipeline aims to provide a fully open and modular approach, leveraging models available on the Transformers library via the Hugging Face hub. The level of modularity intended for each part is as follows:
34
+ - **VAD**: Uses the implementation from [Silero's repo](https://github.com/snakers4/silero-vad).
35
+ - **STT**: Uses Whisper models exclusively; however, any Whisper checkpoint can be used, enabling options like [Distil-Whisper](https://huggingface.co/distil-whisper/distil-large-v3) and [French Distil-Whisper](https://huggingface.co/eustlb/distil-large-v3-fr).
36
+ - **LM**: This part is fully modular and can be changed by simply modifying the Hugging Face hub model ID. Users need to select an instruct model since the usage here involves interacting with it.
37
+ - **TTS**: The mini architecture of Parler-TTS is standard, but different checkpoints, including fine-tuned multilingual checkpoints, can be used.
38
+
39
+ The code is designed to facilitate easy modification. Each component is implemented as a class and can be re-implemented to match specific needs.
40
+
41
+ ## Setup
42
+
43
+ Clone the repository:
44
+ ```bash
45
+ git clone https://github.com/huggingface/speech-to-speech.git
46
+ cd speech-to-speech
47
+ ```
48
+
49
+ Install the required dependencies using [uv](https://github.com/astral-sh/uv):
50
+ ```bash
51
+ uv pip install -r requirements.txt
52
+ ```
53
+
54
+ For Mac users, use the `requirements_mac.txt` file instead:
55
+ ```bash
56
+ uv pip install -r requirements_mac.txt
57
+ ```
58
+
59
+ If you want to use Melo TTS, you also need to run:
60
+ ```bash
61
+ python -m unidic download
62
+ ```
63
+
64
+
65
+ ## Usage
66
+
67
+ The pipeline can be run in two ways:
68
+ - **Server/Client approach**: Models run on a server, and audio input/output are streamed from a client.
69
+ - **Local approach**: Runs locally.
70
+
71
+ ### Docker Server
72
+
73
+ #### Install the NVIDIA Container Toolkit
74
+
75
+ https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
76
+
77
+ #### Start the docker container
78
+ ```docker compose up```
79
+
80
+ ### Server/Client Approach
81
+
82
+ 1. Run the pipeline on the server:
83
+ ```bash
84
+ python s2s_pipeline.py --recv_host 0.0.0.0 --send_host 0.0.0.0
85
+ ```
86
+
87
+ 2. Run the client locally to handle microphone input and receive generated audio:
88
+ ```bash
89
+ python listen_and_play.py --host <IP address of your server>
90
+ ```
91
+
92
+ ### Local Approach (Mac)
93
+
94
+ 1. For optimal settings on Mac:
95
+ ```bash
96
+ python s2s_pipeline.py --local_mac_optimal_settings
97
+ ```
98
+
99
+ This setting:
100
+ - Adds `--device mps` to use MPS for all models.
101
+ - Sets LightningWhisperMLX for STT
102
+ - Sets MLX LM for language model
103
+ - Sets MeloTTS for TTS
104
+
105
+ ### Recommended usage with Cuda
106
+
107
+ Leverage Torch Compile for Whisper and Parler-TTS:
108
+
109
+ ```bash
110
+ python s2s_pipeline.py \
111
+ --recv_host 0.0.0.0 \
112
+ --send_host 0.0.0.0 \
113
+ --lm_model_name microsoft/Phi-3-mini-4k-instruct \
114
+ --init_chat_role system \
115
+ --stt_compile_mode reduce-overhead \
116
+ --tts_compile_mode default
117
+ ```
118
+
119
+ For the moment, modes capturing CUDA Graphs are not compatible with streaming Parler-TTS (`reduce-overhead`, `max-autotune`).
120
+
121
+
122
+ ### Multi-language Support
123
+
124
+ The pipeline supports multiple languages, allowing for automatic language detection or specific language settings. Here are examples for both local (Mac) and server setups:
125
+
126
+ #### With the server version:
127
+
128
+
129
+ For automatic language detection:
130
+
131
+ ```bash
132
+ python s2s_pipeline.py \
133
+ --stt_model_name large-v3 \
134
+ --language zh \
135
+ --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
136
+ ```
137
+
138
+ Or for one language in particular, chinese in this example
139
+
140
+ ```bash
141
+ python s2s_pipeline.py \
142
+ --stt_model_name large-v3 \
143
+ --language zh \
144
+ --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct \
145
+ ```
146
+
147
+ #### Local Mac Setup
148
+
149
+ For automatic language detection:
150
+
151
+ ```bash
152
+ python s2s_pipeline.py \
153
+ --local_mac_optimal_settings \
154
+ --device mps \
155
+ --stt_model_name large-v3 \
156
+ --language zh \
157
+ --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
158
+ ```
159
+
160
+ Or for one language in particular, chinese in this example
161
+
162
+ ```bash
163
+ python s2s_pipeline.py \
164
+ --local_mac_optimal_settings \
165
+ --device mps \
166
+ --stt_model_name large-v3 \
167
+ --language zh \
168
+ --mlx_lm_model_name mlx-community/Meta-Llama-3.1-8B-Instruct-4bit \
169
+ ```
170
+
171
+
172
+ ## Command-line Usage
173
+
174
+ ### Model Parameters
175
+
176
+ `model_name`, `torch_dtype`, and `device` are exposed for each part leveraging the Transformers' implementations: Speech to Text, Language Model, and Text to Speech. Specify the targeted pipeline part with the corresponding prefix:
177
+ - `stt` (Speech to Text)
178
+ - `lm` (Language Model)
179
+ - `tts` (Text to Speech)
180
+
181
+ For example:
182
+ ```bash
183
+ --lm_model_name google/gemma-2b-it
184
+ ```
185
+
186
+ ### Generation Parameters
187
+
188
+ Other generation parameters of the model's generate method can be set using the part's prefix + `_gen_`, e.g., `--stt_gen_max_new_tokens 128`. These parameters can be added to the pipeline part's arguments class if not already exposed (see `LanguageModelHandlerArguments` for example).
189
+
190
+ ### Notable Parameters
191
+
192
+ #### VAD Parameters
193
+ - `--thresh`: Threshold value to trigger voice activity detection.
194
+ - `--min_speech_ms`: Minimum duration of detected voice activity to be considered speech.
195
+ - `--min_silence_ms`: Minimum length of silence intervals for segmenting speech, balancing sentence cutting and latency reduction.
196
+
197
+ #### Language Model
198
+ - `--init_chat_role`: Defaults to `None`. Sets the initial role in the chat template, if applicable. Refer to the model's card to set this value (e.g. for [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) you have to set `--init_chat_role system`)
199
+ - `--init_chat_prompt`: Defaults to `"You are a helpful AI assistant."` Required when setting `--init_chat_role`.
200
+
201
+ #### Speech to Text
202
+ - `--description`: Sets the description for Parler-TTS generated voice. Defaults to: `"A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."`
203
+
204
+ - `--play_steps_s`: Specifies the duration of the first chunk sent during streaming output from Parler-TTS, impacting readiness and decoding steps.
205
+
206
+ ## Citations
207
+
208
+ ### Silero VAD
209
+ ```bibtex
210
+ @misc{Silero VAD,
211
+ author = {Silero Team},
212
+ title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
213
+ year = {2021},
214
+ publisher = {GitHub},
215
+ journal = {GitHub repository},
216
+ howpublished = {\url{https://github.com/snakers4/silero-vad}},
217
+ commit = {insert_some_commit_here},
218
+ email = {hello@silero.ai}
219
+ }
220
+ ```
221
+
222
+ ### Distil-Whisper
223
+ ```bibtex
224
+ @misc{gandhi2023distilwhisper,
225
+ title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
226
+ author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
227
+ year={2023},
228
+ eprint={2311.00430},
229
+ archivePrefix={arXiv},
230
+ primaryClass={cs.CL}
231
+ }
232
+ ```
233
+
234
+ ### Parler-TTS
235
+ ```bibtex
236
+ @misc{lacombe-etal-2024-parler-tts,
237
+ author = {Yoach Lacombe and Vaibhav Srivastav and Sanchit Gandhi},
238
+ title = {Parler-TTS},
239
+ year = {2024},
240
+ publisher = {GitHub},
241
+ journal = {GitHub repository},
242
+ howpublished = {\url{https://github.com/huggingface/parler-tts}}
243
+ }
244
+ ```
STT/__pycache__/lightning_whisper_mlx_handler.cpython-311.pyc ADDED
Binary file (4.17 kB). View file
 
STT/__pycache__/paraformer_handler.cpython-311.pyc ADDED
Binary file (3.59 kB). View file
 
STT/__pycache__/whisper_stt_handler.cpython-311.pyc ADDED
Binary file (6.46 kB). View file
 
STT/lightning_whisper_mlx_handler.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from time import perf_counter
3
+ from baseHandler import BaseHandler
4
+ from lightning_whisper_mlx import LightningWhisperMLX
5
+ import numpy as np
6
+ from rich.console import Console
7
+ from copy import copy
8
+ import torch
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ console = Console()
13
+
14
+ SUPPORTED_LANGUAGES = [
15
+ "en",
16
+ "fr",
17
+ "es",
18
+ "zh",
19
+ "ja",
20
+ "ko",
21
+ ]
22
+
23
+
24
+ class LightningWhisperSTTHandler(BaseHandler):
25
+ """
26
+ Handles the Speech To Text generation using a Whisper model.
27
+ """
28
+
29
+ def setup(
30
+ self,
31
+ model_name="distil-large-v3",
32
+ device="mps",
33
+ torch_dtype="float16",
34
+ compile_mode=None,
35
+ language=None,
36
+ gen_kwargs={},
37
+ ):
38
+ if len(model_name.split("/")) > 1:
39
+ model_name = model_name.split("/")[-1]
40
+ self.device = device
41
+ self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
42
+ self.start_language = language
43
+ self.last_language = language
44
+
45
+ self.warmup()
46
+
47
+ def warmup(self):
48
+ logger.info(f"Warming up {self.__class__.__name__}")
49
+
50
+ # 2 warmup steps for no compile or compile mode with CUDA graphs capture
51
+ n_steps = 1
52
+ dummy_input = np.array([0] * 512)
53
+
54
+ for _ in range(n_steps):
55
+ _ = self.model.transcribe(dummy_input)["text"].strip()
56
+
57
+ def process(self, spoken_prompt):
58
+ logger.debug("infering whisper...")
59
+
60
+ global pipeline_start
61
+ pipeline_start = perf_counter()
62
+
63
+ if self.start_language != 'auto':
64
+ transcription_dict = self.model.transcribe(spoken_prompt, language=self.start_language)
65
+ else:
66
+ transcription_dict = self.model.transcribe(spoken_prompt)
67
+ language_code = transcription_dict["language"]
68
+ if language_code not in SUPPORTED_LANGUAGES:
69
+ logger.warning(f"Whisper detected unsupported language: {language_code}")
70
+ if self.last_language in SUPPORTED_LANGUAGES: # reprocess with the last language
71
+ transcription_dict = self.model.transcribe(spoken_prompt, language=self.last_language)
72
+ else:
73
+ transcription_dict = {"text": "", "language": "en"}
74
+ else:
75
+ self.last_language = language_code
76
+
77
+ pred_text = transcription_dict["text"].strip()
78
+ language_code = transcription_dict["language"]
79
+ torch.mps.empty_cache()
80
+
81
+ logger.debug("finished whisper inference")
82
+ console.print(f"[yellow]USER: {pred_text}")
83
+ logger.debug(f"Language Code Whisper: {language_code}")
84
+
85
+ yield (pred_text, language_code)
STT/paraformer_handler.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from time import perf_counter
3
+
4
+ from baseHandler import BaseHandler
5
+ from funasr import AutoModel
6
+ import numpy as np
7
+ from rich.console import Console
8
+ import torch
9
+
10
+ logging.basicConfig(
11
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
12
+ )
13
+ logger = logging.getLogger(__name__)
14
+
15
+ console = Console()
16
+
17
+
18
+ class ParaformerSTTHandler(BaseHandler):
19
+ """
20
+ Handles the Speech To Text generation using a Paraformer model.
21
+ The default for this model is set to Chinese.
22
+ This model was contributed by @wuhongsheng.
23
+ """
24
+
25
+ def setup(
26
+ self,
27
+ model_name="paraformer-zh",
28
+ device="cuda",
29
+ gen_kwargs={},
30
+ ):
31
+ print(model_name)
32
+ if len(model_name.split("/")) > 1:
33
+ model_name = model_name.split("/")[-1]
34
+ self.device = device
35
+ self.model = AutoModel(model=model_name, device=device)
36
+ self.warmup()
37
+
38
+ def warmup(self):
39
+ logger.info(f"Warming up {self.__class__.__name__}")
40
+
41
+ # 2 warmup steps for no compile or compile mode with CUDA graphs capture
42
+ n_steps = 1
43
+ dummy_input = np.array([0] * 512, dtype=np.float32)
44
+ for _ in range(n_steps):
45
+ _ = self.model.generate(dummy_input)[0]["text"].strip().replace(" ", "")
46
+
47
+ def process(self, spoken_prompt):
48
+ logger.debug("infering paraformer...")
49
+
50
+ global pipeline_start
51
+ pipeline_start = perf_counter()
52
+
53
+ pred_text = (
54
+ self.model.generate(spoken_prompt)[0]["text"].strip().replace(" ", "")
55
+ )
56
+ torch.mps.empty_cache()
57
+
58
+ logger.debug("finished paraformer inference")
59
+ console.print(f"[yellow]USER: {pred_text}")
60
+
61
+ yield pred_text
STT/whisper_stt_handler.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from time import perf_counter
2
+ from transformers import (
3
+ AutoProcessor,
4
+ AutoModelForSpeechSeq2Seq
5
+ )
6
+ import torch
7
+ from copy import copy
8
+ from baseHandler import BaseHandler
9
+ from rich.console import Console
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+ console = Console()
14
+
15
+ SUPPORTED_LANGUAGES = [
16
+ "en",
17
+ "fr",
18
+ "es",
19
+ "zh",
20
+ "ja",
21
+ "ko",
22
+ ]
23
+
24
+
25
+ class WhisperSTTHandler(BaseHandler):
26
+ """
27
+ Handles the Speech To Text generation using a Whisper model.
28
+ """
29
+
30
+ def setup(
31
+ self,
32
+ model_name="distil-whisper/distil-large-v3",
33
+ device="cuda",
34
+ torch_dtype="float16",
35
+ compile_mode=None,
36
+ language=None,
37
+ gen_kwargs={},
38
+ ):
39
+ self.device = device
40
+ self.torch_dtype = getattr(torch, torch_dtype)
41
+ self.compile_mode = compile_mode
42
+ self.gen_kwargs = gen_kwargs
43
+ if language == 'auto':
44
+ language = None
45
+ self.last_language = language
46
+ if self.last_language is not None:
47
+ self.gen_kwargs["language"] = self.last_language
48
+
49
+ self.processor = AutoProcessor.from_pretrained(model_name)
50
+ self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
51
+ model_name,
52
+ torch_dtype=self.torch_dtype,
53
+ ).to(device)
54
+
55
+ # compile
56
+ if self.compile_mode:
57
+ self.model.generation_config.cache_implementation = "static"
58
+ self.model.forward = torch.compile(
59
+ self.model.forward, mode=self.compile_mode, fullgraph=True
60
+ )
61
+ self.warmup()
62
+
63
+ def prepare_model_inputs(self, spoken_prompt):
64
+ input_features = self.processor(
65
+ spoken_prompt, sampling_rate=16000, return_tensors="pt"
66
+ ).input_features
67
+ input_features = input_features.to(self.device, dtype=self.torch_dtype)
68
+
69
+ return input_features
70
+
71
+ def warmup(self):
72
+ logger.info(f"Warming up {self.__class__.__name__}")
73
+
74
+ # 2 warmup steps for no compile or compile mode with CUDA graphs capture
75
+ n_steps = 1 if self.compile_mode == "default" else 2
76
+ dummy_input = torch.randn(
77
+ (1, self.model.config.num_mel_bins, 3000),
78
+ dtype=self.torch_dtype,
79
+ device=self.device,
80
+ )
81
+ if self.compile_mode not in (None, "default"):
82
+ # generating more tokens than previously will trigger CUDA graphs capture
83
+ # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
84
+ # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
85
+ warmup_gen_kwargs = {
86
+ "min_new_tokens": self.gen_kwargs[
87
+ "max_new_tokens"
88
+ ], # Yes, assign max_new_tokens to min_new_tokens
89
+ "max_new_tokens": self.gen_kwargs["max_new_tokens"],
90
+ **self.gen_kwargs,
91
+ }
92
+ else:
93
+ warmup_gen_kwargs = self.gen_kwargs
94
+
95
+ if self.device == "cuda":
96
+ start_event = torch.cuda.Event(enable_timing=True)
97
+ end_event = torch.cuda.Event(enable_timing=True)
98
+ torch.cuda.synchronize()
99
+ start_event.record()
100
+
101
+ for _ in range(n_steps):
102
+ _ = self.model.generate(dummy_input, **warmup_gen_kwargs)
103
+
104
+ if self.device == "cuda":
105
+ end_event.record()
106
+ torch.cuda.synchronize()
107
+
108
+ logger.info(
109
+ f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
110
+ )
111
+
112
+ def process(self, spoken_prompt):
113
+ logger.debug("infering whisper...")
114
+
115
+ global pipeline_start
116
+ pipeline_start = perf_counter()
117
+
118
+ input_features = self.prepare_model_inputs(spoken_prompt)
119
+ pred_ids = self.model.generate(input_features, **self.gen_kwargs)
120
+ language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
121
+
122
+ if language_code not in SUPPORTED_LANGUAGES: # reprocess with the last language
123
+ logger.warning("Whisper detected unsupported language:", language_code)
124
+ gen_kwargs = copy(self.gen_kwargs)
125
+ gen_kwargs['language'] = self.last_language
126
+ language_code = self.last_language
127
+ pred_ids = self.model.generate(input_features, **gen_kwargs)
128
+ else:
129
+ self.last_language = language_code
130
+
131
+ pred_text = self.processor.batch_decode(
132
+ pred_ids, skip_special_tokens=True, decode_with_timestamps=False
133
+ )[0]
134
+ language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2] # remove "<|" and "|>"
135
+
136
+ logger.debug("finished whisper inference")
137
+ console.print(f"[yellow]USER: {pred_text}")
138
+ logger.debug(f"Language Code Whisper: {language_code}")
139
+
140
+ yield (pred_text, language_code)
TTS/__pycache__/chatTTS_handler.cpython-311.pyc ADDED
Binary file (4.78 kB). View file
 
TTS/__pycache__/melo_handler.cpython-311.pyc ADDED
Binary file (4.98 kB). View file
 
TTS/__pycache__/parler_handler.cpython-311.pyc ADDED
Binary file (9.7 kB). View file
 
TTS/chatTTS_handler.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ChatTTS
2
+ import logging
3
+ from baseHandler import BaseHandler
4
+ import librosa
5
+ import numpy as np
6
+ from rich.console import Console
7
+ import torch
8
+
9
+ logging.basicConfig(
10
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
11
+ )
12
+ logger = logging.getLogger(__name__)
13
+
14
+ console = Console()
15
+
16
+
17
+ class ChatTTSHandler(BaseHandler):
18
+ def setup(
19
+ self,
20
+ should_listen,
21
+ device="cuda",
22
+ gen_kwargs={}, # Unused
23
+ stream=True,
24
+ chunk_size=512,
25
+ ):
26
+ self.should_listen = should_listen
27
+ self.device = device
28
+ self.model = ChatTTS.Chat()
29
+ self.model.load(compile=False) # Doesn't work for me with True
30
+ self.chunk_size = chunk_size
31
+ self.stream = stream
32
+ rnd_spk_emb = self.model.sample_random_speaker()
33
+ self.params_infer_code = ChatTTS.Chat.InferCodeParams(
34
+ spk_emb=rnd_spk_emb,
35
+ )
36
+ self.warmup()
37
+
38
+ def warmup(self):
39
+ logger.info(f"Warming up {self.__class__.__name__}")
40
+ _ = self.model.infer("text")
41
+
42
+ def process(self, llm_sentence):
43
+ console.print(f"[green]ASSISTANT: {llm_sentence}")
44
+ if self.device == "mps":
45
+ import time
46
+
47
+ start = time.time()
48
+ torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
49
+ torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
50
+ _ = (
51
+ time.time() - start
52
+ ) # Removing this line makes it fail more often. I'm looking into it.
53
+
54
+ wavs_gen = self.model.infer(
55
+ llm_sentence, params_infer_code=self.params_infer_code, stream=self.stream
56
+ )
57
+
58
+ if self.stream:
59
+ wavs = [np.array([])]
60
+ for gen in wavs_gen:
61
+ if gen[0] is None or len(gen[0]) == 0:
62
+ self.should_listen.set()
63
+ return
64
+ audio_chunk = librosa.resample(gen[0], orig_sr=24000, target_sr=16000)
65
+ audio_chunk = (audio_chunk * 32768).astype(np.int16)[0]
66
+ while len(audio_chunk) > self.chunk_size:
67
+ yield audio_chunk[: self.chunk_size] # 返回前 chunk_size 字节的数据
68
+ audio_chunk = audio_chunk[self.chunk_size :] # 移除已返回的数据
69
+ yield np.pad(audio_chunk, (0, self.chunk_size - len(audio_chunk)))
70
+ else:
71
+ wavs = wavs_gen
72
+ if len(wavs[0]) == 0:
73
+ self.should_listen.set()
74
+ return
75
+ audio_chunk = librosa.resample(wavs[0], orig_sr=24000, target_sr=16000)
76
+ audio_chunk = (audio_chunk * 32768).astype(np.int16)
77
+ for i in range(0, len(audio_chunk), self.chunk_size):
78
+ yield np.pad(
79
+ audio_chunk[i : i + self.chunk_size],
80
+ (0, self.chunk_size - len(audio_chunk[i : i + self.chunk_size])),
81
+ )
82
+ self.should_listen.set()
TTS/melo_handler.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from melo.api import TTS
2
+ import logging
3
+ from baseHandler import BaseHandler
4
+ import librosa
5
+ import numpy as np
6
+ from rich.console import Console
7
+ import torch
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ console = Console()
12
+
13
+ WHISPER_LANGUAGE_TO_MELO_LANGUAGE = {
14
+ "en": "EN",
15
+ "fr": "FR",
16
+ "es": "ES",
17
+ "zh": "ZH",
18
+ "ja": "JP",
19
+ "ko": "KR",
20
+ }
21
+
22
+ WHISPER_LANGUAGE_TO_MELO_SPEAKER = {
23
+ "en": "EN-BR",
24
+ "fr": "FR",
25
+ "es": "ES",
26
+ "zh": "ZH",
27
+ "ja": "JP",
28
+ "ko": "KR",
29
+ }
30
+
31
+
32
+ class MeloTTSHandler(BaseHandler):
33
+ def setup(
34
+ self,
35
+ should_listen,
36
+ device="mps",
37
+ language="en",
38
+ speaker_to_id="en",
39
+ gen_kwargs={}, # Unused
40
+ blocksize=512,
41
+ ):
42
+ self.should_listen = should_listen
43
+ self.device = device
44
+ self.language = language
45
+ self.model = TTS(
46
+ language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[self.language], device=device
47
+ )
48
+ self.speaker_id = self.model.hps.data.spk2id[
49
+ WHISPER_LANGUAGE_TO_MELO_SPEAKER[speaker_to_id]
50
+ ]
51
+ self.blocksize = blocksize
52
+ self.warmup()
53
+
54
+ def warmup(self):
55
+ logger.info(f"Warming up {self.__class__.__name__}")
56
+ _ = self.model.tts_to_file("text", self.speaker_id, quiet=True)
57
+
58
+ def process(self, llm_sentence):
59
+ language_code = None
60
+
61
+ if isinstance(llm_sentence, tuple):
62
+ llm_sentence, language_code = llm_sentence
63
+
64
+ console.print(f"[green]ASSISTANT: {llm_sentence}")
65
+
66
+ if language_code is not None and self.language != language_code:
67
+ try:
68
+ self.model = TTS(
69
+ language=WHISPER_LANGUAGE_TO_MELO_LANGUAGE[language_code],
70
+ device=self.device,
71
+ )
72
+ self.speaker_id = self.model.hps.data.spk2id[
73
+ WHISPER_LANGUAGE_TO_MELO_SPEAKER[language_code]
74
+ ]
75
+ self.language = language_code
76
+ except KeyError:
77
+ console.print(
78
+ f"[red]Language {language_code} not supported by Melo. Using {self.language} instead."
79
+ )
80
+
81
+ if self.device == "mps":
82
+ import time
83
+
84
+ start = time.time()
85
+ torch.mps.synchronize() # Waits for all kernels in all streams on the MPS device to complete.
86
+ torch.mps.empty_cache() # Frees all memory allocated by the MPS device.
87
+ _ = (
88
+ time.time() - start
89
+ ) # Removing this line makes it fail more often. I'm looking into it.
90
+
91
+ try:
92
+ audio_chunk = self.model.tts_to_file(
93
+ llm_sentence, self.speaker_id, quiet=True
94
+ )
95
+ except (AssertionError, RuntimeError) as e:
96
+ logger.error(f"Error in MeloTTSHandler: {e}")
97
+ audio_chunk = np.array([])
98
+ if len(audio_chunk) == 0:
99
+ self.should_listen.set()
100
+ return
101
+ audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
102
+ audio_chunk = (audio_chunk * 32768).astype(np.int16)
103
+ for i in range(0, len(audio_chunk), self.blocksize):
104
+ yield np.pad(
105
+ audio_chunk[i : i + self.blocksize],
106
+ (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
107
+ )
108
+
109
+ self.should_listen.set()
TTS/parler_handler.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Thread
2
+ from time import perf_counter
3
+ from baseHandler import BaseHandler
4
+ import numpy as np
5
+ import torch
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ )
9
+ from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
10
+ import librosa
11
+ import logging
12
+ from rich.console import Console
13
+ from utils.utils import next_power_of_2
14
+ from transformers.utils.import_utils import (
15
+ is_flash_attn_2_available,
16
+ )
17
+
18
+ torch._inductor.config.fx_graph_cache = True
19
+ # mind about this parameter ! should be >= 2 * number of padded prompt sizes for TTS
20
+ torch._dynamo.config.cache_size_limit = 15
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ console = Console()
25
+
26
+
27
+ if not is_flash_attn_2_available() and torch.cuda.is_available():
28
+ logger.warn(
29
+ """Parler TTS works best with flash attention 2, but is not installed
30
+ Given that CUDA is available in this system, you can install flash attention 2 with `uv pip install flash-attn --no-build-isolation`"""
31
+ )
32
+
33
+
34
+ class ParlerTTSHandler(BaseHandler):
35
+ def setup(
36
+ self,
37
+ should_listen,
38
+ model_name="ylacombe/parler-tts-mini-jenny-30H",
39
+ device="cuda",
40
+ torch_dtype="float16",
41
+ compile_mode=None,
42
+ gen_kwargs={},
43
+ max_prompt_pad_length=8,
44
+ description=(
45
+ "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
46
+ "She speaks very fast."
47
+ ),
48
+ play_steps_s=1,
49
+ blocksize=512,
50
+ ):
51
+ self.should_listen = should_listen
52
+ self.device = device
53
+ self.torch_dtype = getattr(torch, torch_dtype)
54
+ self.gen_kwargs = gen_kwargs
55
+ self.compile_mode = compile_mode
56
+ self.max_prompt_pad_length = max_prompt_pad_length
57
+ self.description = description
58
+
59
+ self.description_tokenizer = AutoTokenizer.from_pretrained(model_name)
60
+ self.prompt_tokenizer = AutoTokenizer.from_pretrained(model_name)
61
+ self.model = ParlerTTSForConditionalGeneration.from_pretrained(
62
+ model_name, torch_dtype=self.torch_dtype
63
+ ).to(device)
64
+
65
+ framerate = self.model.audio_encoder.config.frame_rate
66
+ self.play_steps = int(framerate * play_steps_s)
67
+ self.blocksize = blocksize
68
+
69
+ if self.compile_mode not in (None, "default"):
70
+ logger.warning(
71
+ "Torch compilation modes that captures CUDA graphs are not yet compatible with the TTS part. Reverting to 'default'"
72
+ )
73
+ self.compile_mode = "default"
74
+
75
+ if self.compile_mode:
76
+ self.model.generation_config.cache_implementation = "static"
77
+ self.model.forward = torch.compile(
78
+ self.model.forward, mode=self.compile_mode, fullgraph=True
79
+ )
80
+
81
+ self.warmup()
82
+
83
+ def prepare_model_inputs(
84
+ self,
85
+ prompt,
86
+ max_length_prompt=50,
87
+ pad=False,
88
+ ):
89
+ pad_args_prompt = (
90
+ {"padding": "max_length", "max_length": max_length_prompt} if pad else {}
91
+ )
92
+
93
+ tokenized_description = self.description_tokenizer(
94
+ self.description, return_tensors="pt"
95
+ )
96
+ input_ids = tokenized_description.input_ids.to(self.device)
97
+ attention_mask = tokenized_description.attention_mask.to(self.device)
98
+
99
+ tokenized_prompt = self.prompt_tokenizer(
100
+ prompt, return_tensors="pt", **pad_args_prompt
101
+ )
102
+ prompt_input_ids = tokenized_prompt.input_ids.to(self.device)
103
+ prompt_attention_mask = tokenized_prompt.attention_mask.to(self.device)
104
+
105
+ gen_kwargs = {
106
+ "input_ids": input_ids,
107
+ "attention_mask": attention_mask,
108
+ "prompt_input_ids": prompt_input_ids,
109
+ "prompt_attention_mask": prompt_attention_mask,
110
+ **self.gen_kwargs,
111
+ }
112
+
113
+ return gen_kwargs
114
+
115
+ def warmup(self):
116
+ logger.info(f"Warming up {self.__class__.__name__}")
117
+
118
+ if self.device == "cuda":
119
+ start_event = torch.cuda.Event(enable_timing=True)
120
+ end_event = torch.cuda.Event(enable_timing=True)
121
+
122
+ # 2 warmup steps for no compile or compile mode with CUDA graphs capture
123
+ n_steps = 1 if self.compile_mode == "default" else 2
124
+
125
+ if self.device == "cuda":
126
+ torch.cuda.synchronize()
127
+ start_event.record()
128
+ if self.compile_mode:
129
+ pad_lengths = [2**i for i in range(2, self.max_prompt_pad_length)]
130
+ for pad_length in pad_lengths[::-1]:
131
+ model_kwargs = self.prepare_model_inputs(
132
+ "dummy prompt", max_length_prompt=pad_length, pad=True
133
+ )
134
+ for _ in range(n_steps):
135
+ _ = self.model.generate(**model_kwargs)
136
+ logger.info(f"Warmed up length {pad_length} tokens!")
137
+ else:
138
+ model_kwargs = self.prepare_model_inputs("dummy prompt")
139
+ for _ in range(n_steps):
140
+ _ = self.model.generate(**model_kwargs)
141
+
142
+ if self.device == "cuda":
143
+ end_event.record()
144
+ torch.cuda.synchronize()
145
+ logger.info(
146
+ f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
147
+ )
148
+
149
+ def process(self, llm_sentence):
150
+ if isinstance(llm_sentence, tuple):
151
+ llm_sentence, _ = llm_sentence
152
+
153
+ console.print(f"[green]ASSISTANT: {llm_sentence}")
154
+ nb_tokens = len(self.prompt_tokenizer(llm_sentence).input_ids)
155
+
156
+ pad_args = {}
157
+ if self.compile_mode:
158
+ # pad to closest upper power of two
159
+ pad_length = next_power_of_2(nb_tokens)
160
+ logger.debug(f"padding to {pad_length}")
161
+ pad_args["pad"] = True
162
+ pad_args["max_length_prompt"] = pad_length
163
+
164
+ tts_gen_kwargs = self.prepare_model_inputs(
165
+ llm_sentence,
166
+ **pad_args,
167
+ )
168
+
169
+ streamer = ParlerTTSStreamer(
170
+ self.model, device=self.device, play_steps=self.play_steps
171
+ )
172
+ tts_gen_kwargs = {"streamer": streamer, **tts_gen_kwargs}
173
+ torch.manual_seed(0)
174
+ thread = Thread(target=self.model.generate, kwargs=tts_gen_kwargs)
175
+ thread.start()
176
+
177
+ for i, audio_chunk in enumerate(streamer):
178
+ global pipeline_start
179
+ if i == 0 and "pipeline_start" in globals():
180
+ logger.info(
181
+ f"Time to first audio: {perf_counter() - pipeline_start:.3f}"
182
+ )
183
+ audio_chunk = librosa.resample(audio_chunk, orig_sr=44100, target_sr=16000)
184
+ audio_chunk = (audio_chunk * 32768).astype(np.int16)
185
+ for i in range(0, len(audio_chunk), self.blocksize):
186
+ yield np.pad(
187
+ audio_chunk[i : i + self.blocksize],
188
+ (0, self.blocksize - len(audio_chunk[i : i + self.blocksize])),
189
+ )
190
+
191
+ self.should_listen.set()
VAD/__pycache__/vad_handler.cpython-311.pyc ADDED
Binary file (4.81 kB). View file
 
VAD/__pycache__/vad_handler.cpython-312.pyc ADDED
Binary file (4.46 kB). View file
 
VAD/__pycache__/vad_iterator.cpython-311.pyc ADDED
Binary file (4.4 kB). View file
 
VAD/__pycache__/vad_iterator.cpython-312.pyc ADDED
Binary file (4.24 kB). View file
 
VAD/vad_handler.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ from VAD.vad_iterator import VADIterator
3
+ from baseHandler import BaseHandler
4
+ import numpy as np
5
+ import torch
6
+ from rich.console import Console
7
+
8
+ from utils.utils import int2float
9
+ from df.enhance import enhance, init_df
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ console = Console()
15
+
16
+
17
+ class VADHandler(BaseHandler):
18
+ """
19
+ Handles voice activity detection. When voice activity is detected, audio will be accumulated until the end of speech is detected and then passed
20
+ to the following part.
21
+ """
22
+
23
+ def setup(
24
+ self,
25
+ should_listen,
26
+ thresh=0.3,
27
+ sample_rate=16000,
28
+ min_silence_ms=1000,
29
+ min_speech_ms=500,
30
+ max_speech_ms=float("inf"),
31
+ speech_pad_ms=30,
32
+ audio_enhancement=False,
33
+ ):
34
+ self.should_listen = should_listen
35
+ self.sample_rate = sample_rate
36
+ self.min_silence_ms = min_silence_ms
37
+ self.min_speech_ms = min_speech_ms
38
+ self.max_speech_ms = max_speech_ms
39
+ self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
40
+ self.iterator = VADIterator(
41
+ self.model,
42
+ threshold=thresh,
43
+ sampling_rate=sample_rate,
44
+ min_silence_duration_ms=min_silence_ms,
45
+ speech_pad_ms=speech_pad_ms,
46
+ )
47
+ self.audio_enhancement = audio_enhancement
48
+ if audio_enhancement:
49
+ self.enhanced_model, self.df_state, _ = init_df()
50
+
51
+ def process(self, audio_chunk):
52
+ audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
53
+ audio_float32 = int2float(audio_int16)
54
+ vad_output = self.iterator(torch.from_numpy(audio_float32))
55
+ if vad_output is not None and len(vad_output) != 0:
56
+ logger.debug("VAD: end of speech detected")
57
+ array = torch.cat(vad_output).cpu().numpy()
58
+ duration_ms = len(array) / self.sample_rate * 1000
59
+ if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
60
+ logger.debug(
61
+ f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
62
+ )
63
+ else:
64
+ self.should_listen.clear()
65
+ logger.debug("Stop listening")
66
+ if self.audio_enhancement:
67
+ if self.sample_rate != self.df_state.sr():
68
+ audio_float32 = torchaudio.functional.resample(
69
+ torch.from_numpy(array),
70
+ orig_freq=self.sample_rate,
71
+ new_freq=self.df_state.sr(),
72
+ )
73
+ enhanced = enhance(
74
+ self.enhanced_model,
75
+ self.df_state,
76
+ audio_float32.unsqueeze(0),
77
+ )
78
+ enhanced = torchaudio.functional.resample(
79
+ enhanced,
80
+ orig_freq=self.df_state.sr(),
81
+ new_freq=self.sample_rate,
82
+ )
83
+ else:
84
+ enhanced = enhance(
85
+ self.enhanced_model, self.df_state, audio_float32
86
+ )
87
+ array = enhanced.numpy().squeeze()
88
+ yield array
89
+
90
+ @property
91
+ def min_time_to_debug(self):
92
+ return 0.00001
VAD/vad_iterator.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class VADIterator:
5
+ def __init__(
6
+ self,
7
+ model,
8
+ threshold: float = 0.5,
9
+ sampling_rate: int = 16000,
10
+ min_silence_duration_ms: int = 100,
11
+ speech_pad_ms: int = 30,
12
+ ):
13
+ """
14
+ Mainly taken from https://github.com/snakers4/silero-vad
15
+ Class for stream imitation
16
+
17
+ Parameters
18
+ ----------
19
+ model: preloaded .jit/.onnx silero VAD model
20
+
21
+ threshold: float (default - 0.5)
22
+ Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
23
+ It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
24
+
25
+ sampling_rate: int (default - 16000)
26
+ Currently silero VAD models support 8000 and 16000 sample rates
27
+
28
+ min_silence_duration_ms: int (default - 100 milliseconds)
29
+ In the end of each speech chunk wait for min_silence_duration_ms before separating it
30
+
31
+ speech_pad_ms: int (default - 30 milliseconds)
32
+ Final speech chunks are padded by speech_pad_ms each side
33
+ """
34
+
35
+ self.model = model
36
+ self.threshold = threshold
37
+ self.sampling_rate = sampling_rate
38
+ self.is_speaking = False
39
+ self.buffer = []
40
+
41
+ if sampling_rate not in [8000, 16000]:
42
+ raise ValueError(
43
+ "VADIterator does not support sampling rates other than [8000, 16000]"
44
+ )
45
+
46
+ self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
47
+ self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
48
+ self.reset_states()
49
+
50
+ def reset_states(self):
51
+ self.model.reset_states()
52
+ self.triggered = False
53
+ self.temp_end = 0
54
+ self.current_sample = 0
55
+
56
+ @torch.no_grad()
57
+ def __call__(self, x):
58
+ """
59
+ x: torch.Tensor
60
+ audio chunk (see examples in repo)
61
+
62
+ return_seconds: bool (default - False)
63
+ whether return timestamps in seconds (default - samples)
64
+ """
65
+
66
+ if not torch.is_tensor(x):
67
+ try:
68
+ x = torch.Tensor(x)
69
+ except Exception:
70
+ raise TypeError("Audio cannot be casted to tensor. Cast it manually")
71
+
72
+ window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
73
+ self.current_sample += window_size_samples
74
+
75
+ speech_prob = self.model(x, self.sampling_rate).item()
76
+
77
+ if (speech_prob >= self.threshold) and self.temp_end:
78
+ self.temp_end = 0
79
+
80
+ if (speech_prob >= self.threshold) and not self.triggered:
81
+ self.triggered = True
82
+ return None
83
+
84
+ if (speech_prob < self.threshold - 0.15) and self.triggered:
85
+ if not self.temp_end:
86
+ self.temp_end = self.current_sample
87
+ if self.current_sample - self.temp_end < self.min_silence_samples:
88
+ return None
89
+ else:
90
+ # end of speak
91
+ self.temp_end = 0
92
+ self.triggered = False
93
+ spoken_utterance = self.buffer
94
+ self.buffer = []
95
+ return spoken_utterance
96
+
97
+ if self.triggered:
98
+ self.buffer.append(x)
99
+
100
+ return None
arguments_classes/__pycache__/chat_tts_arguments.cpython-311.pyc ADDED
Binary file (1.22 kB). View file
 
arguments_classes/__pycache__/language_model_arguments.cpython-311.pyc ADDED
Binary file (3.17 kB). View file
 
arguments_classes/__pycache__/melo_tts_arguments.cpython-311.pyc ADDED
Binary file (1.17 kB). View file
 
arguments_classes/__pycache__/mlx_language_model_arguments.cpython-311.pyc ADDED
Binary file (3.02 kB). View file
 
arguments_classes/__pycache__/module_arguments.cpython-311.pyc ADDED
Binary file (2.11 kB). View file
 
arguments_classes/__pycache__/paraformer_stt_arguments.cpython-311.pyc ADDED
Binary file (1.1 kB). View file
 
arguments_classes/__pycache__/parler_tts_arguments.cpython-311.pyc ADDED
Binary file (2.92 kB). View file
 
arguments_classes/__pycache__/socket_receiver_arguments.cpython-311.pyc ADDED
Binary file (1.27 kB). View file
 
arguments_classes/__pycache__/socket_sender_arguments.cpython-311.pyc ADDED
Binary file (1.06 kB). View file
 
arguments_classes/__pycache__/vad_arguments.cpython-311.pyc ADDED
Binary file (2.35 kB). View file
 
arguments_classes/__pycache__/whisper_stt_arguments.cpython-311.pyc ADDED
Binary file (2.9 kB). View file
 
arguments_classes/chat_tts_arguments.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class ChatTTSHandlerArguments:
6
+ chat_tts_stream: bool = field(
7
+ default=True,
8
+ metadata={"help": "The tts mode is stream Default is 'stream'."},
9
+ )
10
+ chat_tts_device: str = field(
11
+ default="cuda",
12
+ metadata={
13
+ "help": "The device to be used for speech synthesis. Default is 'cuda'."
14
+ },
15
+ )
16
+ chat_tts_chunk_size: int = field(
17
+ default=512,
18
+ metadata={
19
+ "help": "Sets the size of the audio data chunk processed per cycle, balancing playback latency and CPU load.. Default is 512。."
20
+ },
21
+ )
arguments_classes/language_model_arguments.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class LanguageModelHandlerArguments:
6
+ lm_model_name: str = field(
7
+ default="HuggingFaceTB/SmolLM-360M-Instruct",
8
+ metadata={
9
+ "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
10
+ },
11
+ )
12
+ lm_device: str = field(
13
+ default="cuda",
14
+ metadata={
15
+ "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
16
+ },
17
+ )
18
+ lm_torch_dtype: str = field(
19
+ default="float16",
20
+ metadata={
21
+ "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
22
+ },
23
+ )
24
+ user_role: str = field(
25
+ default="user",
26
+ metadata={
27
+ "help": "Role assigned to the user in the chat context. Default is 'user'."
28
+ },
29
+ )
30
+ init_chat_role: str = field(
31
+ default="system",
32
+ metadata={
33
+ "help": "Initial role for setting up the chat context. Default is 'system'."
34
+ },
35
+ )
36
+ init_chat_prompt: str = field(
37
+ default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
38
+ metadata={
39
+ "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
40
+ },
41
+ )
42
+ lm_gen_max_new_tokens: int = field(
43
+ default=128,
44
+ metadata={
45
+ "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
46
+ },
47
+ )
48
+ lm_gen_min_new_tokens: int = field(
49
+ default=0,
50
+ metadata={
51
+ "help": "Minimum number of new tokens to generate in a single completion. Default is 0."
52
+ },
53
+ )
54
+ lm_gen_temperature: float = field(
55
+ default=0.0,
56
+ metadata={
57
+ "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
58
+ },
59
+ )
60
+ lm_gen_do_sample: bool = field(
61
+ default=False,
62
+ metadata={
63
+ "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
64
+ },
65
+ )
66
+ chat_size: int = field(
67
+ default=2,
68
+ metadata={
69
+ "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
70
+ },
71
+ )
arguments_classes/melo_tts_arguments.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class MeloTTSHandlerArguments:
6
+ melo_language: str = field(
7
+ default="en",
8
+ metadata={
9
+ "help": "The language of the text to be synthesized. Default is 'EN_NEWEST'."
10
+ },
11
+ )
12
+ melo_device: str = field(
13
+ default="auto",
14
+ metadata={
15
+ "help": "The device to be used for speech synthesis. Default is 'auto'."
16
+ },
17
+ )
18
+ melo_speaker_to_id: str = field(
19
+ default="en",
20
+ metadata={
21
+ "help": "Mapping of speaker names to speaker IDs. Default is ['EN-Newest']."
22
+ },
23
+ )
arguments_classes/mlx_language_model_arguments.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class MLXLanguageModelHandlerArguments:
6
+ mlx_lm_model_name: str = field(
7
+ default="mlx-community/SmolLM-360M-Instruct",
8
+ metadata={
9
+ "help": "The pretrained language model to use. Default is 'microsoft/Phi-3-mini-4k-instruct'."
10
+ },
11
+ )
12
+ mlx_lm_device: str = field(
13
+ default="mps",
14
+ metadata={
15
+ "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
16
+ },
17
+ )
18
+ mlx_lm_torch_dtype: str = field(
19
+ default="float16",
20
+ metadata={
21
+ "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
22
+ },
23
+ )
24
+ mlx_lm_user_role: str = field(
25
+ default="user",
26
+ metadata={
27
+ "help": "Role assigned to the user in the chat context. Default is 'user'."
28
+ },
29
+ )
30
+ mlx_lm_init_chat_role: str = field(
31
+ default="system",
32
+ metadata={
33
+ "help": "Initial role for setting up the chat context. Default is 'system'."
34
+ },
35
+ )
36
+ mlx_lm_init_chat_prompt: str = field(
37
+ default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
38
+ metadata={
39
+ "help": "The initial chat prompt to establish context for the language model. Default is 'You are a helpful AI assistant.'"
40
+ },
41
+ )
42
+ mlx_lm_gen_max_new_tokens: int = field(
43
+ default=128,
44
+ metadata={
45
+ "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
46
+ },
47
+ )
48
+ mlx_lm_gen_temperature: float = field(
49
+ default=0.0,
50
+ metadata={
51
+ "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
52
+ },
53
+ )
54
+ mlx_lm_gen_do_sample: bool = field(
55
+ default=False,
56
+ metadata={
57
+ "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
58
+ },
59
+ )
60
+ mlx_lm_chat_size: int = field(
61
+ default=2,
62
+ metadata={
63
+ "help": "Number of interactions assitant-user to keep for the chat. None for no limitations."
64
+ },
65
+ )
arguments_classes/module_arguments.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+
4
+
5
+ @dataclass
6
+ class ModuleArguments:
7
+ device: Optional[str] = field(
8
+ default=None,
9
+ metadata={"help": "If specified, overrides the device for all handlers."},
10
+ )
11
+ mode: Optional[str] = field(
12
+ default="socket",
13
+ metadata={
14
+ "help": "The mode to run the pipeline in. Either 'local' or 'socket'. Default is 'socket'."
15
+ },
16
+ )
17
+ local_mac_optimal_settings: bool = field(
18
+ default=False,
19
+ metadata={
20
+ "help": "If specified, sets the optimal settings for Mac OS. Hence whisper-mlx, MLX LM and MeloTTS will be used."
21
+ },
22
+ )
23
+ stt: Optional[str] = field(
24
+ default="whisper",
25
+ metadata={
26
+ "help": "The STT to use. Either 'whisper', 'whisper-mlx', and 'paraformer'. Default is 'whisper'."
27
+ },
28
+ )
29
+ llm: Optional[str] = field(
30
+ default="transformers",
31
+ metadata={
32
+ "help": "The LLM to use. Either 'transformers' or 'mlx-lm'. Default is 'transformers'"
33
+ },
34
+ )
35
+ tts: Optional[str] = field(
36
+ default="parler",
37
+ metadata={
38
+ "help": "The TTS to use. Either 'parler', 'melo', or 'chatTTS'. Default is 'parler'"
39
+ },
40
+ )
41
+ log_level: str = field(
42
+ default="info",
43
+ metadata={
44
+ "help": "Provide logging level. Example --log_level debug, default=warning."
45
+ },
46
+ )
arguments_classes/paraformer_stt_arguments.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class ParaformerSTTHandlerArguments:
6
+ paraformer_stt_model_name: str = field(
7
+ default="paraformer-zh",
8
+ metadata={
9
+ "help": "The pretrained model to use. Default is 'paraformer-zh'. Can be choose from https://github.com/modelscope/FunASR"
10
+ },
11
+ )
12
+ paraformer_stt_device: str = field(
13
+ default="cuda",
14
+ metadata={
15
+ "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
16
+ },
17
+ )
arguments_classes/parler_tts_arguments.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class ParlerTTSHandlerArguments:
6
+ tts_model_name: str = field(
7
+ default="ylacombe/parler-tts-mini-jenny-30H",
8
+ metadata={
9
+ "help": "The pretrained TTS model to use. Default is 'ylacombe/parler-tts-mini-jenny-30H'."
10
+ },
11
+ )
12
+ tts_device: str = field(
13
+ default="cuda",
14
+ metadata={
15
+ "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
16
+ },
17
+ )
18
+ tts_torch_dtype: str = field(
19
+ default="float16",
20
+ metadata={
21
+ "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
22
+ },
23
+ )
24
+ tts_compile_mode: str = field(
25
+ default=None,
26
+ metadata={
27
+ "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
28
+ },
29
+ )
30
+ tts_gen_min_new_tokens: int = field(
31
+ default=64,
32
+ metadata={
33
+ "help": "Maximum number of new tokens to generate in a single completion. Default is 10, which corresponds to ~0.1 secs"
34
+ },
35
+ )
36
+ tts_gen_max_new_tokens: int = field(
37
+ default=512,
38
+ metadata={
39
+ "help": "Maximum number of new tokens to generate in a single completion. Default is 256, which corresponds to ~6 secs"
40
+ },
41
+ )
42
+ description: str = field(
43
+ default=(
44
+ "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. "
45
+ "She speaks very fast."
46
+ ),
47
+ metadata={
48
+ "help": "Description of the speaker's voice and speaking style to guide the TTS model."
49
+ },
50
+ )
51
+ play_steps_s: float = field(
52
+ default=1.0,
53
+ metadata={
54
+ "help": "The time interval in seconds for playing back the generated speech in steps. Default is 0.5 seconds."
55
+ },
56
+ )
57
+ max_prompt_pad_length: int = field(
58
+ default=8,
59
+ metadata={
60
+ "help": "When using compilation, the prompt as to be padded to closest power of 2. This parameters sets the maximun power of 2 possible."
61
+ },
62
+ )
arguments_classes/socket_receiver_arguments.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class SocketReceiverArguments:
6
+ recv_host: str = field(
7
+ default="localhost",
8
+ metadata={
9
+ "help": "The host IP ddress for the socket connection. Default is '0.0.0.0' which binds to all "
10
+ "available interfaces on the host machine."
11
+ },
12
+ )
13
+ recv_port: int = field(
14
+ default=12345,
15
+ metadata={
16
+ "help": "The port number on which the socket server listens. Default is 12346."
17
+ },
18
+ )
19
+ chunk_size: int = field(
20
+ default=1024,
21
+ metadata={
22
+ "help": "The size of each data chunk to be sent or received over the socket. Default is 1024 bytes."
23
+ },
24
+ )
arguments_classes/socket_sender_arguments.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class SocketSenderArguments:
6
+ send_host: str = field(
7
+ default="localhost",
8
+ metadata={
9
+ "help": "The host IP address for the socket connection. Default is '0.0.0.0' which binds to all "
10
+ "available interfaces on the host machine."
11
+ },
12
+ )
13
+ send_port: int = field(
14
+ default=12346,
15
+ metadata={
16
+ "help": "The port number on which the socket server listens. Default is 12346."
17
+ },
18
+ )
arguments_classes/vad_arguments.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class VADHandlerArguments:
6
+ thresh: float = field(
7
+ default=0.3,
8
+ metadata={
9
+ "help": "The threshold value for voice activity detection (VAD). Values typically range from 0 to 1, with higher values requiring higher confidence in speech detection."
10
+ },
11
+ )
12
+ sample_rate: int = field(
13
+ default=16000,
14
+ metadata={
15
+ "help": "The sample rate of the audio in Hertz. Default is 16000 Hz, which is a common setting for voice audio."
16
+ },
17
+ )
18
+ min_silence_ms: int = field(
19
+ default=250,
20
+ metadata={
21
+ "help": "Minimum length of silence intervals to be used for segmenting speech. Measured in milliseconds. Default is 250 ms."
22
+ },
23
+ )
24
+ min_speech_ms: int = field(
25
+ default=500,
26
+ metadata={
27
+ "help": "Minimum length of speech segments to be considered valid speech. Measured in milliseconds. Default is 500 ms."
28
+ },
29
+ )
30
+ max_speech_ms: float = field(
31
+ default=float("inf"),
32
+ metadata={
33
+ "help": "Maximum length of continuous speech before forcing a split. Default is infinite, allowing for uninterrupted speech segments."
34
+ },
35
+ )
36
+ speech_pad_ms: int = field(
37
+ default=500,
38
+ metadata={
39
+ "help": "Amount of padding added to the beginning and end of detected speech segments. Measured in milliseconds. Default is 250 ms."
40
+ },
41
+ )
42
+ audio_enhancement: bool = field(
43
+ default=False,
44
+ metadata={
45
+ "help": "improves sound quality by applying techniques like noise reduction, equalization, and echo cancellation. Default is False."
46
+ },
47
+ )
arguments_classes/whisper_stt_arguments.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+
4
+
5
+ @dataclass
6
+ class WhisperSTTHandlerArguments:
7
+ stt_model_name: str = field(
8
+ default="distil-whisper/distil-large-v3",
9
+ metadata={
10
+ "help": "The pretrained Whisper model to use. Default is 'distil-whisper/distil-large-v3'."
11
+ },
12
+ )
13
+ stt_device: str = field(
14
+ default="cuda",
15
+ metadata={
16
+ "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
17
+ },
18
+ )
19
+ stt_torch_dtype: str = field(
20
+ default="float16",
21
+ metadata={
22
+ "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
23
+ },
24
+ )
25
+ stt_compile_mode: str = field(
26
+ default=None,
27
+ metadata={
28
+ "help": "Compile mode for torch compile. Either 'default', 'reduce-overhead' and 'max-autotune'. Default is None (no compilation)"
29
+ },
30
+ )
31
+ stt_gen_max_new_tokens: int = field(
32
+ default=128,
33
+ metadata={
34
+ "help": "The maximum number of new tokens to generate. Default is 128."
35
+ },
36
+ )
37
+ stt_gen_num_beams: int = field(
38
+ default=1,
39
+ metadata={
40
+ "help": "The number of beams for beam search. Default is 1, implying greedy decoding."
41
+ },
42
+ )
43
+ stt_gen_return_timestamps: bool = field(
44
+ default=False,
45
+ metadata={
46
+ "help": "Whether to return timestamps with transcriptions. Default is False."
47
+ },
48
+ )
49
+ stt_gen_task: str = field(
50
+ default="transcribe",
51
+ metadata={
52
+ "help": "The task to perform, typically 'transcribe' for transcription. Default is 'transcribe'."
53
+ },
54
+ )
55
+ language: Optional[str] = field(
56
+ default='en',
57
+ metadata={
58
+ "help": """The language for the conversation.
59
+ Choose between 'en' (english), 'fr' (french), 'es' (spanish),
60
+ 'zh' (chinese), 'ko' (korean), 'ja' (japanese), or 'None'.
61
+ If using 'auto', the language is automatically detected and can
62
+ change during the conversation. Default is 'en'."""
63
+ },
64
+ )