asahi417 committed on
Commit
429df62
1 Parent(s): 986454a
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sample_audio/* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
1
+ .idea
2
+ *.egg-info
3
+ build
4
+ dist
5
+ *.ipynb_checkpoints
6
+ .DS_Store
7
+ .python-version
8
+ *.pyc
9
+ __pycache__
10
+ *.nfs000*
11
+ .eggs
12
+
README.md CHANGED
@@ -1,199 +1,142 @@
1
  ---
 
2
  library_name: transformers
3
- tags: []
4
  ---
5
 
6
- # Model Card for Model ID
7
-
8
- <!-- Provide a quick summary of what the model is/does. -->
9
-
10
-
11
-
12
- ## Model Details
13
-
14
- ### Model Description
15
-
16
- <!-- Provide a longer summary of what this model is. -->
17
-
18
- This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
-
20
- - **Developed by:** [More Information Needed]
21
- - **Funded by [optional]:** [More Information Needed]
22
- - **Shared by [optional]:** [More Information Needed]
23
- - **Model type:** [More Information Needed]
24
- - **Language(s) (NLP):** [More Information Needed]
25
- - **License:** [More Information Needed]
26
- - **Finetuned from model [optional]:** [More Information Needed]
27
-
28
- ### Model Sources [optional]
29
-
30
- <!-- Provide the basic links for the model. -->
31
-
32
- - **Repository:** [More Information Needed]
33
- - **Paper [optional]:** [More Information Needed]
34
- - **Demo [optional]:** [More Information Needed]
35
-
36
- ## Uses
37
-
38
- <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
-
40
- ### Direct Use
41
-
42
- <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
-
44
- [More Information Needed]
45
-
46
- ### Downstream Use [optional]
47
-
48
- <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
-
50
- [More Information Needed]
51
-
52
- ### Out-of-Scope Use
53
-
54
- <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
-
56
- [More Information Needed]
57
-
58
- ## Bias, Risks, and Limitations
59
-
60
- <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
-
62
- [More Information Needed]
63
-
64
- ### Recommendations
65
-
66
- <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
-
68
- Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
-
70
- ## How to Get Started with the Model
71
-
72
- Use the code below to get started with the model.
73
-
74
- [More Information Needed]
75
-
76
- ## Training Details
77
-
78
- ### Training Data
79
-
80
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
-
82
- [More Information Needed]
83
-
84
- ### Training Procedure
85
-
86
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
-
88
- #### Preprocessing [optional]
89
-
90
- [More Information Needed]
91
-
92
-
93
- #### Training Hyperparameters
94
-
95
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
-
97
- #### Speeds, Sizes, Times [optional]
98
-
99
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
-
101
- [More Information Needed]
102
-
103
- ## Evaluation
104
-
105
- <!-- This section describes the evaluation protocols and provides the results. -->
106
-
107
- ### Testing Data, Factors & Metrics
108
-
109
- #### Testing Data
110
-
111
- <!-- This should link to a Dataset Card if possible. -->
112
-
113
- [More Information Needed]
114
-
115
- #### Factors
116
-
117
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
-
119
- [More Information Needed]
120
-
121
- #### Metrics
122
-
123
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
-
125
- [More Information Needed]
126
-
127
- ### Results
128
-
129
- [More Information Needed]
130
-
131
- #### Summary
132
-
133
-
134
-
135
- ## Model Examination [optional]
136
-
137
- <!-- Relevant interpretability work for the model goes here -->
138
-
139
- [More Information Needed]
140
-
141
- ## Environmental Impact
142
-
143
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
-
145
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
-
147
- - **Hardware Type:** [More Information Needed]
148
- - **Hours used:** [More Information Needed]
149
- - **Cloud Provider:** [More Information Needed]
150
- - **Compute Region:** [More Information Needed]
151
- - **Carbon Emitted:** [More Information Needed]
152
-
153
- ## Technical Specifications [optional]
154
-
155
- ### Model Architecture and Objective
156
-
157
- [More Information Needed]
158
-
159
- ### Compute Infrastructure
160
-
161
- [More Information Needed]
162
-
163
- #### Hardware
164
-
165
- [More Information Needed]
166
-
167
- #### Software
168
-
169
- [More Information Needed]
170
-
171
- ## Citation [optional]
172
-
173
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
-
175
- **BibTeX:**
176
-
177
- [More Information Needed]
178
-
179
- **APA:**
180
-
181
- [More Information Needed]
182
-
183
- ## Glossary [optional]
184
-
185
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
-
187
- [More Information Needed]
188
-
189
- ## More Information [optional]
190
-
191
- [More Information Needed]
192
-
193
- ## Model Card Authors [optional]
194
-
195
- [More Information Needed]
196
-
197
- ## Model Card Contact
198
-
199
- [More Information Needed]
 
1
  ---
2
+ language: ja
3
  library_name: transformers
4
+ license: apache-2.0
5
+ tags:
6
+ - audio
7
+ - automatic-speech-recognition
8
+ - hf-asr-leaderboard
9
+ widget:
10
+ - example_title: Sample 1
11
+ src: >-
12
+ https://huggingface.co/datasets/japanese-asr/ja_asr.common_voice_8_0/resolve/main/sample.flac
13
+ pipeline_tag: automatic-speech-recognition
14
  ---
15
 
16
+ # Kotoba-Whisper-v2.2
17
+ _Kotoba-Whisper-v2.2_ is a Japanese ASR model based on [kotoba-tech/kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0), with
18
+ additional postprocessing stacks integrated as [`pipeline`](https://huggingface.co/docs/transformers/en/main_classes/pipelines). The new features include
19
+ (i) improved timestamps achieved with [stable-ts](https://github.com/jianfch/stable-ts) and (ii) punctuation added with [punctuators](https://github.com/1-800-BAD-CODE/punctuators/tree/main).
20
+ These libraries are merged into Kotoba-Whisper-v2.2 via the pipeline and are applied seamlessly to the transcription predicted by [kotoba-tech/kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0).
21
+ The pipeline was developed through a collaboration between [Asahi Ushio](https://asahiushio.com) and [Kotoba Technologies](https://twitter.com/kotoba_tech).
22
+
23
+
24
+ The following table presents the raw CER (unlike the usual CER, punctuation is not removed before computing the metric; see the evaluation script [here](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.1/blob/main/run_short_form_eval.py))
25
+ for each evaluation dataset.
26
+
27
+
28
+ | model | [CommonVoice 8 (Japanese test set)](https://huggingface.co/datasets/japanese-asr/ja_asr.common_voice_8_0) | [JSUT Basic 5000](https://huggingface.co/datasets/japanese-asr/ja_asr.jsut_basic5000) | [ReazonSpeech (held out test set)](https://huggingface.co/datasets/japanese-asr/ja_asr.reazonspeech_test) |
29
+ |:--------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------:|----------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------:|
30
+ | [kotoba-tech/kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0) | 17.6 | 15.4 | 17.4 |
31
+ | [kotoba-tech/kotoba-whisper-v2.1](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.1) | 17.7 | 15.4 | 17 |
32
+ | [kotoba-tech/kotoba-whisper-v2.1](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.1) (punctuator + stable-ts) | 17.7 | 15.4 | 17 |
33
+ | [kotoba-tech/kotoba-whisper-v2.1](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.1) (punctuator) | 17.7 | 15.4 | 17 |
34
+ | [kotoba-tech/kotoba-whisper-v2.1](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.1) (stable-ts) | 17.7 | 15.4 | 17 |
35
+ | [kotoba-tech/kotoba-whisper-v1.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v1.0) | 17.8 | 15.2 | 17.8 |
36
+ | [kotoba-tech/kotoba-whisper-v1.1](https://huggingface.co/kotoba-tech/kotoba-whisper-v1.1) | 17.9 | 15 | 17.8 |
37
+ | [kotoba-tech/kotoba-whisper-v1.1](https://huggingface.co/kotoba-tech/kotoba-whisper-v1.1) (punctuator + stable-ts) | 17.9 | 15 | 17.8 |
38
+ | [kotoba-tech/kotoba-whisper-v1.1](https://huggingface.co/kotoba-tech/kotoba-whisper-v1.1) (punctuator) | 17.9 | 15 | 17.8 |
39
+ | [kotoba-tech/kotoba-whisper-v1.1](https://huggingface.co/kotoba-tech/kotoba-whisper-v1.1) (stable-ts) | 17.9 | 15 | 17.8 |
40
+ | [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) | 15.3 | 13.4 | 20.5 |
41
+ | [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) | 15.9 | 10.6 | 34.6 |
42
+ | [openai/whisper-large](https://huggingface.co/openai/whisper-large) | 16.6 | 11.3 | 40.7 |
43
+ | [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) | 17.9 | 13.1 | 39.3 |
44
+ | [openai/whisper-base](https://huggingface.co/openai/whisper-base) | 34.5 | 26.4 | 76 |
45
+ | [openai/whisper-small](https://huggingface.co/openai/whisper-small) | 21.5 | 18.9 | 48.1 |
46
+ | [openai/whisper-tiny](https://huggingface.co/openai/whisper-tiny) | 58.8 | 38.3 | 153.3 |
47
+
48
+
49
+ As for the normalized CER, since the updates introduced in v2.1 are removed by the normalization, kotoba-tech/kotoba-whisper-v2.1 marks the same CER values as [kotoba-tech/kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0).
50
+
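+ For reference, the difference between the raw CER above and the normalized CER can be illustrated as follows. This is a minimal sketch using the 🤗 `evaluate` CER metric, not the official evaluation script linked above; the regex-based punctuation stripper and the example sentences are assumptions for illustration only.
+
+ ```python
+ # Raw CER scores the hypothesis as-is (punctuation counts as characters),
+ # while the normalized CER strips Japanese punctuation before scoring.
+ # Requires: pip install evaluate jiwer
+ import re
+ from evaluate import load
+
+ cer = load("cer")
+ references = ["今日は良い天気ですね。"]
+ predictions = ["今日は良い天気です。ね"]
+
+ # raw CER: punctuation differences count as errors
+ raw_cer = cer.compute(predictions=predictions, references=references)
+
+ def strip_punctuation(text: str) -> str:
+     return re.sub(r"[、。!?]", "", text)
+
+ # normalized CER: punctuation is removed before scoring
+ normalized_cer = cer.compute(
+     predictions=[strip_punctuation(p) for p in predictions],
+     references=[strip_punctuation(r) for r in references],
+ )
+ print(f"raw CER: {raw_cer:.3f}, normalized CER: {normalized_cer:.3f}")
+ ```
+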
51
+ ### Latency
52
+ Please refer to the latency section of kotoba-whisper-v1.1 [here](https://huggingface.co/kotoba-tech/kotoba-whisper-v1.1#latency).
53
+
54
+ ## Transformers Usage
55
+ Kotoba-Whisper-v2.1 is supported in the Hugging Face 🤗 Transformers library from version 4.39 onwards. To run the model, first
56
+ install the latest version of Transformers.
57
+
58
+ ```bash
59
+ pip install --upgrade pip
60
+ pip install --upgrade transformers accelerate torchaudio
61
+ pip install stable-ts==2.16.0
62
+ pip install punctuators==0.0.5
63
+ ```
64
+
65
+ ### Transcription
66
+ The model can be used with the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline)
67
+ class to transcribe audio files as follows:
68
+
69
+ ```python
70
+ import torch
71
+ from transformers import pipeline
72
+ from datasets import load_dataset
73
+
74
+ # config
75
+ model_id = "kotoba-tech/kotoba-whisper-v2.1"
76
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
77
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
78
+ model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
79
+ generate_kwargs = {"language": "ja", "task": "transcribe"}
80
+
81
+ # load model
82
+ pipe = pipeline(
83
+ model=model_id,
84
+ torch_dtype=torch_dtype,
85
+ device=device,
86
+ model_kwargs=model_kwargs,
87
+ chunk_length_s=15,
88
+ batch_size=16,
89
+ trust_remote_code=True,
90
+ stable_ts=True,
91
+ punctuator=True
92
+ )
93
+
94
+ # load sample audio
95
+ dataset = load_dataset("japanese-asr/ja_asr.reazonspeech_test", split="test")
96
+ sample = dataset[0]["audio"]
97
+
98
+ # run inference
99
+ result = pipe(sample, return_timestamps=True, generate_kwargs=generate_kwargs)
100
+ print(result)
101
+ ```
102
+
103
+ - To transcribe a local audio file, simply pass the path to your audio file when you call the pipeline:
104
+ ```diff
105
+ - result = pipe(sample, return_timestamps=True, generate_kwargs=generate_kwargs)
106
+ + result = pipe("audio.mp3", return_timestamps=True, generate_kwargs=generate_kwargs)
107
+ ```
108
+
109
+ - To deactivate stable-ts:
110
+ ```diff
111
+ - stable_ts=True,
112
+ + stable_ts=False,
113
+ ```
114
+
115
+ - To deactivate punctuator:
116
+ ```diff
117
+ - punctuator=True,
118
+ + punctuator=False,
119
+ ```
120
+
121
+
122
+ ### Flash Attention 2
123
+ We recommend using [Flash-Attention 2](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#flashattention-2)
124
+ if your GPU allows for it. To do so, you first need to install [Flash Attention](https://github.com/Dao-AILab/flash-attention):
125
+
126
+ ```
127
+ pip install flash-attn --no-build-isolation
128
+ ```
129
+
130
+ Then pass `attn_implementation="flash_attention_2"` to `from_pretrained`:
131
+
132
+ ```diff
133
+ - model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
134
+ + model_kwargs = {"attn_implementation": "flash_attention_2"} if torch.cuda.is_available() else {}
135
+ ```
136
+
137
+
138
+ ## Acknowledgements
139
+ * [OpenAI](https://openai.com/) for the Whisper [model](https://huggingface.co/openai/whisper-large-v3).
140
+ * Hugging Face 🤗 [Transformers](https://github.com/huggingface/transformers) for the model integration.
141
+ * Hugging Face 🤗 for the [Distil-Whisper codebase](https://github.com/huggingface/distil-whisper).
142
+ * [Reazon Human Interaction Lab](https://research.reazon.jp/) for the [ReazonSpeech dataset](https://huggingface.co/datasets/reazon-research/reazonspeech).
pipeline/kotoba_whisper.py ADDED
@@ -0,0 +1,315 @@
1
+ from typing import Union, Optional, Dict, List, Any
2
+ import requests
3
+
4
+ import torch
5
+ import numpy as np
6
+
7
+ from transformers.pipelines.audio_utils import ffmpeg_read
8
+ from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline, chunk_iter
9
+ from transformers.utils import is_torchaudio_available
10
+ from transformers.modeling_utils import PreTrainedModel
11
+ from transformers.tokenization_utils import PreTrainedTokenizer
12
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
13
+ from pyannote.audio import Pipeline
14
+ from pyannote.core.annotation import Annotation
15
+ from punctuators.models import PunctCapSegModelONNX
16
+
17
+
18
+ class Punctuator:
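+ # Restores punctuation on each transcribed chunk with punctuators' PunctCapSegModelONNX.
+ # validate_punctuation falls back to the raw text when the prediction contains an unknown
+ # token or the chunk already has Japanese punctuation, and keeps at most one "。" per chunk.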
19
+
20
+ ja_punctuations = ["!", "?", "、", "。"]
21
+
22
+ def __init__(self, model: str = "pcs_47lang"):
23
+ self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)
24
+
25
+ def punctuate(self, pipeline_chunk: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
26
+
27
+ def validate_punctuation(raw: str, punctuated: str):
28
+ if 'unk' in punctuated.lower() or any(p in raw for p in self.ja_punctuations):
29
+ return raw
30
+ if punctuated.count("。") > 1:
31
+ ind = punctuated.rfind("。")
32
+ punctuated = punctuated.replace("。", "")
33
+ punctuated = punctuated[:ind] + "。" + punctuated[ind:]
34
+ return punctuated
35
+
36
+ text_edit = self.punctuation_model.infer([c['text'] for c in pipeline_chunk])
37
+ return [
38
+ {
39
+ 'timestamp': c['timestamp'],
40
+ 'text': validate_punctuation(c['text'], "".join(e))
41
+ } for c, e in zip(pipeline_chunk, text_edit)
42
+ ]
43
+
44
+
45
+
46
+ class SpeakerDiarization:
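+ # Thin wrapper around the pyannote.audio speaker-diarization pipeline: accepts a file path,
+ # torch.Tensor, or numpy array (with sampling_rate) and returns a pyannote Annotation of speaker turns.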
47
+
48
+ def __init__(self, model_id: str, device: torch.device):
49
+ self.device = device
50
+ self.pipeline = Pipeline.from_pretrained(model_id)
51
+ self.pipeline = self.pipeline.to(self.device)
52
+
53
+ def __call__(self,
54
+ audio: Union[str, torch.Tensor, np.ndarray],
55
+ sampling_rate: Optional[int] = None) -> Annotation:
56
+ if type(audio) is torch.Tensor or type(audio) is np.ndarray:
57
+ if sampling_rate is None:
58
+ raise ValueError("sampling_rate must be provided")
59
+ if type(audio) is np.ndarray:
60
+ audio = torch.as_tensor(audio)
61
+ audio = torch.as_tensor(audio, dtype=torch.float32)
62
+ if len(audio.shape) == 1:
63
+ audio = audio.unsqueeze(0)
64
+ elif len(audio.shape) > 2:
65
+ raise ValueError("audio shape must be (channel, time)")
66
+ audio = {"waveform": audio.to(self.device), "sample_rate": sampling_rate}
67
+ output = self.pipeline(audio)
68
+ return output
69
+
70
+
71
+ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
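+ # Custom ASR pipeline combining kotoba-whisper transcription with pyannote speaker diarization
+ # (model_diarizarization, "pyannote/speaker-diarization-3.1" by default) and optional punctuation
+ # restoration (punctuator=True). preprocess/_forward additionally carry the raw waveform so that
+ # diarization can run over the full audio in postprocess.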
72
+
73
+ def __init__(self,
74
+ model: "PreTrainedModel",
75
+ model_diarizarization: str="pyannote/speaker-diarization-3.1",
76
+ feature_extractor: Union["SequenceFeatureExtractor", str] = None,
77
+ tokenizer: Optional[PreTrainedTokenizer] = None,
78
+ device: Union[int, "torch.device"] = None,
79
+ device_diarizarization: Union[int, "torch.device"] = None,
80
+ torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
81
+ return_unique_speaker: bool = False,
82
+ punctuator: bool = False,
83
+ **kwargs):
84
+ self.type = "seq2seq_whisper"
85
+ if device is None:
86
+ device = "cpu"
87
+ if device_diarizarization is None:
88
+ device_diarizarization = device
89
+ if type(device_diarizarization) is str:
90
+ device_diarizarization = torch.device(device_diarizarization)
91
+ self.model_speaker_diarization = SpeakerDiarization(model_diarizarization, device_diarizarization)
92
+ self.return_unique_speaker = return_unique_speaker
93
+ if punctuator:
94
+ self.punctuator = Punctuator()
95
+ else:
96
+ self.punctuator = None
97
+ super().__init__(
98
+ model=model,
99
+ feature_extractor=feature_extractor,
100
+ tokenizer=tokenizer,
101
+ device=device,
102
+ torch_dtype=torch_dtype,
103
+ **kwargs
104
+ )
105
+
106
+ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
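+ # Mirrors the parent AutomaticSpeechRecognitionPipeline.preprocess, but every yielded item also
+ # carries the full (resampled) waveform under "audio_array" for speaker diarization in postprocess.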
107
+ if isinstance(inputs, str):
108
+ if inputs.startswith("http://") or inputs.startswith("https://"):
109
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
110
+ # like http_huggingface_co.png
111
+ inputs = requests.get(inputs).content
112
+ else:
113
+ with open(inputs, "rb") as f:
114
+ inputs = f.read()
115
+
116
+ if isinstance(inputs, bytes):
117
+ inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
118
+
119
+ stride = None
120
+ extra = {}
121
+ if isinstance(inputs, dict):
122
+ stride = inputs.pop("stride", None)
123
+ # Accepting `"array"` which is the key defined in `datasets` for
124
+ # better integration
125
+ if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
126
+ raise ValueError(
127
+ "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
128
+ '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
129
+ "containing the sampling_rate associated with that array"
130
+ )
131
+
132
+ _inputs = inputs.pop("raw", None)
133
+ if _inputs is None:
134
+ # Remove path which will not be used from `datasets`.
135
+ inputs.pop("path", None)
136
+ _inputs = inputs.pop("array", None)
137
+ in_sampling_rate = inputs.pop("sampling_rate")
138
+ extra = inputs
139
+ inputs = _inputs
140
+ if in_sampling_rate != self.feature_extractor.sampling_rate:
141
+ if is_torchaudio_available():
142
+ from torchaudio import functional as F
143
+ else:
144
+ raise ImportError(
145
+ "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
146
+ "The torchaudio package can be installed through: `pip install torchaudio`."
147
+ )
148
+
149
+ inputs = F.resample(
150
+ torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
151
+ ).numpy()
152
+ ratio = self.feature_extractor.sampling_rate / in_sampling_rate
153
+ else:
154
+ ratio = 1
155
+ if stride is not None:
156
+ if stride[0] + stride[1] > inputs.shape[0]:
157
+ raise ValueError("Stride is too large for input")
158
+
159
+ # Stride needs to get the chunk length here, it's going to get
160
+ # swallowed by the `feature_extractor` later, and then batching
161
+ # can add extra data in the inputs, so we need to keep track
162
+ # of the original length in the stride so we can cut properly.
163
+ stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
164
+ if not isinstance(inputs, np.ndarray):
165
+ raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
166
+ if len(inputs.shape) != 1:
167
+ raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
168
+
169
+ if chunk_length_s:
170
+ if stride_length_s is None:
171
+ stride_length_s = chunk_length_s / 6
172
+
173
+ if isinstance(stride_length_s, (int, float)):
174
+ stride_length_s = [stride_length_s, stride_length_s]
175
+
176
+ # XXX: Carefully, this variable will not exist in `seq2seq` setting.
177
+ # Currently chunking is not possible at this level for `seq2seq` so
178
+ # it's ok.
179
+ align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
180
+ chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
181
+ stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
182
+ stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
183
+
184
+ if chunk_len < stride_left + stride_right:
185
+ raise ValueError("Chunk length must be superior to stride length")
186
+
187
+ for item in chunk_iter(
188
+ inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
189
+ ):
190
+ item["audio_array"] = inputs
191
+ yield item
192
+ else:
193
+ if inputs.shape[0] > self.feature_extractor.n_samples:
194
+ processed = self.feature_extractor(
195
+ inputs,
196
+ sampling_rate=self.feature_extractor.sampling_rate,
197
+ truncation=False,
198
+ padding="longest",
199
+ return_tensors="pt",
200
+ )
201
+ else:
202
+ processed = self.feature_extractor(
203
+ inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
204
+ )
205
+
206
+ if self.torch_dtype is not None:
207
+ processed = processed.to(dtype=self.torch_dtype)
208
+ if stride is not None:
209
+ processed["stride"] = stride
210
+ yield {"is_last": True, "audio_array": inputs, **processed, **extra}
211
+
212
+ def _forward(self, model_inputs, **generate_kwargs):
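+ # Same as the parent _forward, except "audio_array" is popped from the model inputs and
+ # re-attached to the output so it remains available in postprocess.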
213
+ attention_mask = model_inputs.pop("attention_mask", None)
214
+ stride = model_inputs.pop("stride", None)
215
+ is_last = model_inputs.pop("is_last")
216
+ audio_array = model_inputs.pop("audio_array")
217
+ encoder = self.model.get_encoder()
218
+ # Consume values so we can let extra information flow freely through
219
+ # the pipeline (important for `partial` in microphone)
220
+ if "input_features" in model_inputs:
221
+ inputs = model_inputs.pop("input_features")
222
+ elif "input_values" in model_inputs:
223
+ inputs = model_inputs.pop("input_values")
224
+ else:
225
+ raise ValueError(
226
+ "Seq2Seq speech recognition model requires either a "
227
+ f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
228
+ )
229
+
230
+ # custom processing for Whisper timestamps and word-level timestamps
231
+ generate_kwargs["return_timestamps"] = True
232
+ if inputs.shape[-1] > self.feature_extractor.nb_max_frames:
233
+ generate_kwargs["input_features"] = inputs
234
+ else:
235
+ generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
236
+
237
+ tokens = self.model.generate(attention_mask=attention_mask, **generate_kwargs)
238
+ # whisper longform generation stores timestamps in "segments"
239
+ out = {"tokens": tokens}
240
+ if self.type == "seq2seq_whisper":
241
+ if stride is not None:
242
+ out["stride"] = stride
243
+
244
+ # Leftover
245
+ extra = model_inputs
246
+ return {"is_last": is_last, "audio_array": audio_array, **out, **extra}
247
+
248
+ def postprocess(self,
249
+ model_outputs,
250
+ decoder_kwargs: Optional[Dict] = None,
251
+ return_language=None,
252
+ *args,
253
+ **kwargs):
254
+ assert len(model_outputs) > 0
255
+ audio_array = list(model_outputs)[0]["audio_array"]
256
+ sd = self.model_speaker_diarization(audio_array, sampling_rate=self.feature_extractor.sampling_rate)
257
+ timelines = sd.get_timeline()
258
+ outputs = super().postprocess(
259
+ model_outputs=model_outputs,
260
+ decoder_kwargs=decoder_kwargs,
261
+ return_timestamps=True,
262
+ return_language=return_language
263
+ )
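+ # Align diarization with transcription: walk the diarization timeline and the ASR chunks in
+ # parallel, attaching the label(s) of every overlapping speaker segment to each chunk.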
264
+ pointer_ts = 0
265
+ pointer_chunk = 0
266
+ new_chunks = []
267
+ while True:
268
+ if pointer_ts == len(timelines):
269
+ ts = timelines[-1]
270
+ for chunk in outputs["chunks"][pointer_chunk:]:
271
+ chunk["speaker"] = sd.get_labels(ts)
272
+ new_chunks.append(chunk)
273
+ break
274
+ if pointer_chunk == len(outputs["chunks"]):
275
+ break
276
+ ts = timelines[pointer_ts]
277
+
278
+ chunk = outputs["chunks"][pointer_chunk]
279
+ if "speaker" not in chunk:
280
+ chunk["speaker"] = []
281
+
282
+ start, end = chunk["timestamp"]
283
+ if ts.end <= start:
284
+ pointer_ts += 1
285
+ elif end <= ts.start:
286
+ if len(chunk["speaker"]) == 0:
287
+ chunk["speaker"] += list(sd.get_labels(ts))
288
+ new_chunks.append(chunk)
289
+ pointer_chunk += 1
290
+ else:
291
+ chunk["speaker"] += list(sd.get_labels(ts))
292
+ if ts.end >= end:
293
+ new_chunks.append(chunk)
294
+ pointer_chunk += 1
295
+ else:
296
+ pointer_ts += 1
297
+ for i in new_chunks:
298
+ if "speaker" in i:
299
+ if self.return_unique_speaker:
300
+ i["speaker"] = [i["speaker"][0]]
301
+ else:
302
+ i["speaker"] = list(set(i["speaker"]))
303
+ else:
304
+ i["speaker"] = []
305
+ outputs["chunks"] = new_chunks
306
+ if self.punctuator:
307
+ outputs["chunks"] = self.punctuator.punctuate(outputs["chunks"])
308
+ outputs["text"] = "".join([c["text"] for c in outputs["chunks"]])
309
+ outputs["speakers"] = sd.labels()
310
+ outputs.pop("audio_array")
311
+ for s in outputs["speakers"]:
312
+ outputs[f"text/{s}"] = "".join([c["text"] for c in outputs["chunks"] if s in c["speaker"]])
313
+ outputs[f"chunks/{s}"] = [c for c in outputs["chunks"] if s in c["speaker"]]
314
+ return outputs
315
+
pipeline/push_pipeline.py ADDED
@@ -0,0 +1,23 @@
1
+ from pprint import pprint
2
+ from kotoba_whisper import KotobaWhisperPipeline
3
+ from transformers.pipelines import PIPELINE_REGISTRY, pipeline
4
+ from transformers import WhisperForConditionalGeneration, TFWhisperForConditionalGeneration
5
+
6
+
7
+ model_alias = "kotoba-tech/kotoba-whisper-v2.2"
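+ # Register the custom pipeline class under the "kotoba-whisper" task; pushing the pipeline below
+ # uploads kotoba_whisper.py so it can be loaded from the Hub with trust_remote_code=True.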
8
+ PIPELINE_REGISTRY.register_pipeline(
9
+ "kotoba-whisper",
10
+ pipeline_class=KotobaWhisperPipeline,
11
+ pt_model=WhisperForConditionalGeneration,
12
+ tf_model=TFWhisperForConditionalGeneration
13
+ )
14
+ test_audio = "/Users/asahiu/Desktop/speaker_diariazation_sample_1.wav"
15
+ pipe = pipeline(task="kotoba-whisper", model="kotoba-tech/kotoba-whisper-v2.0", chunk_length_s=15, batch_size=16, return_unique_speaker=True)
16
+ output = pipe(test_audio)
17
+ pprint(output)
18
+ pipe = pipeline(task="kotoba-whisper", model="kotoba-tech/kotoba-whisper-v2.0", chunk_length_s=15, batch_size=16)
19
+ output = pipe(test_audio)
20
+ pprint(output)
21
+ pipe.push_to_hub(model_alias)
22
+
23
+
pipeline/test_pipeline.py ADDED
@@ -0,0 +1,7 @@
1
+ from pprint import pprint
2
+ from transformers.pipelines import pipeline
3
+
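+ # Smoke test for the pushed pipeline: trust_remote_code=True downloads and runs the bundled
+ # kotoba_whisper.py from the Hub.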
4
+ test_audio = "/Users/asahiu/Desktop/speaker_diariazation_sample_1.wav"
5
+ pipe = pipeline(model="kotoba-tech/kotoba-whisper-v2.2", chunk_length_s=15, batch_size=16, trust_remote_code=True)
6
+ output = pipe(test_audio)
7
+ pprint(output)
pipeline/test_speaker_diarization.py ADDED
@@ -0,0 +1,48 @@
1
+ # Setup:
2
+ # pip install pyannote.audio>=3.1
3
+ # Requirement: Submit access request for the following models.
4
+ # https://huggingface.co/pyannote/speaker-diarization-3.1
5
+ # https://huggingface.co/pyannote/segmentation-3.0
6
+ import soundfile as sf
7
+ import numpy as np
8
+ from typing import Union, Optional, Dict, List
9
+
10
+ import torch
11
+ from pyannote.audio import Pipeline
12
+
13
+
14
+ class SpeakerDiarization:
15
+
16
+ def __init__(self, model_id: str):
17
+ self.pipeline = Pipeline.from_pretrained(model_id)
18
+
19
+ def __call__(self,
20
+ audio: Union[str, torch.Tensor, np.ndarray],
21
+ sampling_rate: Optional[int] = None) -> Dict[str, List[List[float]]]:
22
+ if type(audio) is torch.Tensor or type(audio) is np.ndarray:
23
+ if sampling_rate is None:
24
+ raise ValueError("sampling_rate must be provided")
25
+ if type(audio) is np.ndarray:
26
+ audio = torch.as_tensor(audio)
27
+ audio = torch.as_tensor(audio, dtype=torch.float32)
28
+ if len(audio.shape) == 1:
29
+ audio = audio.unsqueeze(0)
30
+ elif len(audio.shape) > 2:
31
+ raise ValueError("audio shape must be (channel, time)")
32
+ audio = {"waveform": audio, "sample_rate": sampling_rate}
33
+ output = self.pipeline(audio)
34
+ # dictionary: {speaker_id: [[start, end],...]}
35
+ return {s: [[i.start, i.end] for i in output.label_timeline(s)] for s in output.labels()}
36
+
37
+
38
+ pipeline = SpeakerDiarization("pyannote/speaker-diarization-3.1")
39
+ root_dir = "/Users/asahiu/Desktop"
40
+ sample_audio_files = ["speaker_diariazation_sample_1.wav", "speaker_diariazation_sample_2.wav"]
41
+
42
+ for sample_audio_file in sample_audio_files:
+     print(sample_audio_file)
43
+     a, sr = sf.read(f"{root_dir}/{sample_audio_file}")
44
+     output = pipeline(a, sampling_rate=sr)
45
+     print(output)
46
+     output = pipeline(f"{root_dir}/{sample_audio_file}")
47
+     print(output)
48
+     print()
sample_audio/sample_diarization_japanese.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7252359b53264c767da33a48e39ff57a8f31641c4a80a1702c6940f8914697b
3
+ size 780064