SaraAlthubaiti committed
Commit 07d6770 · verified · 1 Parent(s): da9202d
Files changed (4)
  1. inference.py +43 -0
  2. requirements.txt +13 -0
  3. test_prompt.json +31 -0
  4. utils.py +158 -0
inference.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ from omegaconf import OmegaConf
+ from transformers import WhisperFeatureExtractor
+ from models.tinyoctopus import TINYOCTOPUS
+ from utils import prepare_one_sample
+
+ # Load model config
+ # NOTE: the committed file references `cfg` without defining it; the OmegaConf
+ # import and the "config.yaml" path below are assumed placeholders for the
+ # repo's actual config loading.
+ cfg = OmegaConf.load("config.yaml")
+
+ # Load model
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = TINYOCTOPUS.from_config(cfg.config.model)
+ model.to(device)
+ model.eval()
+
+ # Load processor
+ wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")
+
+ def transcribe(audio_path, task="dialect"):
+     """
+     Perform inference on an audio file.
+
+     Args:
+         audio_path (str): Path to the audio file.
+         task (str): Task to perform. Options: "dialect", "asr", "translation".
+
+     Returns:
+         str: The generated text.
+     """
+     task_prompts = {
+         "dialect": "What is the dialect of the speaker?",
+         "asr": "تعرف على الكلام وأعطني النص.",  # Arabic: "Recognize the speech and give me the transcription."
+         "translation": "الرجاء ترجمة هذا المقطع الصوتي إلى اللغة الإنجليزية."  # Arabic: "Please translate this audio clip into English."
+     }
+
+     if task not in task_prompts:
+         raise ValueError("Invalid task. Choose from: 'dialect', 'asr', or 'translation'.")
+
+     try:
+         prompt = task_prompts[task]
+         samples = prepare_one_sample(audio_path, wav_processor)
+         prompt = [f"<Speech><SpeechHere></Speech> {prompt.strip()}"]
+         generated_text = model.generate(samples, {"temperature": 0.7}, prompts=prompt)[0]
+         return generated_text.replace('<s>', '').replace('</s>', '').strip()
+
+     except Exception as e:
+         return f"Error: {e}"
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch==2.0.1
+ torchaudio==2.0.2
+ peft==0.3.0
+ soundfile
+ librosa
+ transformers==4.28.0
+ sentencepiece==0.1.97
+ accelerate==0.20.3
+ bitsandbytes==0.35.0
+ gradio==3.23.0
+ safetensors
+ tensorboardX
+ jiwer
test_prompt.json ADDED
@@ -0,0 +1,31 @@
+ {
+     "asr": "<Speech><SpeechHere></Speech> Recognize the speech and give me the transcription.",
+     "gender_recognition": "<Speech><SpeechHere></Speech> What is the gender of the speaker?",
+     "dialect_identification": "<Speech><SpeechHere></Speech> What is the dialect of the speaker?",
+     "asr_zh": "<Speech><SpeechHere></Speech> 请将语音中的内容写下来。",
+     "summarization": "<Speech><SpeechHere></Speech> Could you capture the main points of this audio in a short summary?",
+     "translation_ae": "<Speech><SpeechHere></Speech> Listen to the speech and translate it into English.",
+     "asr_de": "<Speech><SpeechHere></Speech> Hören Sie sich die Rede an und schreiben Sie ihren Inhalt auf.",
+     "translation_ec": "<Speech><SpeechHere></Speech> Listen to the speech and translate it into Chinese.",
+     "audiocaption": "<Speech><SpeechHere></Speech> Please describe the audio.",
+     "audiocaption_v2": "<Speech><SpeechHere></Speech> Please write down what you hear in the audio.",
+     "QA": "<Speech><SpeechHere></Speech> {}",
+     "gender_QA": "<Speech><SpeechHere></Speech> {}",
+     "phone_recognition": "<Speech><SpeechHere></Speech> Provide the phonetic transcription for the speech.",
+     "speech_query": "<Speech><SpeechHere></Speech> Please answer the question in detail.",
+     "emotion_recognition": "<Speech><SpeechHere></Speech> Describe the emotion of the speaker in one word.",
+     "lyrics_recognition": "<Speech><SpeechHere></Speech> Listen to the song and write down its content.",
+     "audio_speech_description": "<Speech><SpeechHere></Speech> Describe the speech and the background audio.",
+     "speaker_verification": "<Speech><SpeechHere></Speech> Do you only hear the same person talking? Answer yes or no.",
+     "fluent_speech_audio": "<Speech><SpeechHere></Speech> Describe the background audio and the speech in a fluent sentence.",
+     "speech_separation": "<Speech><SpeechHere></Speech> Please write down what you hear each person say.",
+     "audio_story_telling": "<Speech><SpeechHere></Speech> Based on the audio, write a story in detail. Your story should be highly related to the audio.",
+     "speech_audio_query": "<Speech><SpeechHere></Speech> Please answer the speaker's question in detail based on the background sound.",
+     "slot_filling": "<Speech><SpeechHere></Speech> According to the speech, what is the {}?",
+     "music_description": "<Speech><SpeechHere></Speech> Listen to this music clip and describe the music.",
+     "translation_en2ja": "<Speech><SpeechHere></Speech> Listen to the speech and translate it into Japanese.",
+     "translation_en2de": "<Speech><SpeechHere></Speech> Listen to the speech and translate it into German.",
+     "speech_audio_coreasoning": "<Speech><SpeechHere></Speech> Use your strong reasoning skills to answer the speaker's question in detail based on the background sound.",
+     "keywords": "<Speech><SpeechHere></Speech> Give me only three keywords of the text.",
+     "speaker_diarization_asr": "<Speech><SpeechHere></Speech> Please recognize each speaker and transcribe their speech content."
+ }
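
The <SpeechHere> tag marks where the model splices in the speech embeddings, and the entries containing a literal {} (e.g. "QA", "slot_filling") appear to be meant for run-time formatting. A small sketch of loading and filling one of these templates, where the question text is a made-up example:

import json

# Load the prompt templates shipped with this commit.
with open("test_prompt.json") as f:
    prompts = json.load(f)

# Fill the "{}" slot of the QA template with a hypothetical question.
prompt = prompts["QA"].format("What language is being spoken?")
print(prompt)  # "<Speech><SpeechHere></Speech> What language is being spoken?"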
utils.py ADDED
@@ -0,0 +1,158 @@
+ # Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import time
+
+ import torch
+ from torch.utils.data import DataLoader, DistributedSampler
+ import soundfile as sf
+ import numpy as np
+
+ from dist_utils import is_main_process, get_world_size, get_rank
+
+
+ def now():
+     from datetime import datetime
+
+     return datetime.now().strftime("%Y%m%d%H%M")
+
+
+ def setup_logger():
+     logging.basicConfig(
+         level=logging.INFO if is_main_process() else logging.WARN,
+         format="%(asctime)s [%(levelname)s] %(message)s",
+         handlers=[logging.StreamHandler()],
+     )
+
+
+ def get_dataloader(dataset, config, is_train=True, use_distributed=True):
+     if use_distributed:
+         sampler = DistributedSampler(
+             dataset,
+             shuffle=is_train,
+             num_replicas=get_world_size(),
+             rank=get_rank()
+         )
+     else:
+         sampler = None
+
+     loader = DataLoader(
+         dataset,
+         batch_size=config.batch_size_train if is_train else config.batch_size_eval,
+         num_workers=config.num_workers,
+         pin_memory=True,
+         sampler=sampler,
+         shuffle=sampler is None and is_train,
+         collate_fn=dataset.collater,
+         drop_last=is_train,
+     )
+
+     if is_train:
+         loader = IterLoader(loader, use_distributed=use_distributed)
+
+     return loader
+
+
+ def apply_to_sample(f, sample):
+     if len(sample) == 0:
+         return {}
+
+     def _apply(x):
+         if torch.is_tensor(x):
+             return f(x)
+         elif isinstance(x, dict):
+             return {key: _apply(value) for key, value in x.items()}
+         elif isinstance(x, list):
+             return [_apply(x) for x in x]
+         else:
+             return x
+
+     return _apply(sample)
+
+
+ def move_to_cuda(sample):
+     def _move_to_cuda(tensor):
+         return tensor.cuda()
+
+     return apply_to_sample(_move_to_cuda, sample)
+
+
+ def prepare_sample(samples, cuda_enabled=True):
+     if cuda_enabled:
+         samples = move_to_cuda(samples)
+
+     # TODO fp16 support
+
+     return samples
+
+
+ class IterLoader:
+     """
+     A wrapper that turns a DataLoader into an infinite iterator.
+
+     Modified from:
+         https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
+     """
+
+     def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
+         self._dataloader = dataloader
+         self.iter_loader = iter(self._dataloader)
+         self._use_distributed = use_distributed
+         self._epoch = 0
+
+     @property
+     def epoch(self) -> int:
+         return self._epoch
+
+     def __next__(self):
+         try:
+             data = next(self.iter_loader)
+         except StopIteration:
+             self._epoch += 1
+             if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
+                 self._dataloader.sampler.set_epoch(self._epoch)
+             time.sleep(2)  # Prevent possible deadlock during epoch transition
+             self.iter_loader = iter(self._dataloader)
+             data = next(self.iter_loader)
+
+         return data
+
+     def __iter__(self):
+         return self
+
+     def __len__(self):
+         return len(self._dataloader)
+
+
+ def prepare_one_sample(wav_path, wav_processor, cuda_enabled=True):
+     audio, sr = sf.read(wav_path)
+     if len(audio.shape) == 2:  # stereo to mono
+         audio = audio[:, 0]
+     if len(audio) < sr:  # pad audio to at least 1s
+         sil = np.zeros(sr - len(audio), dtype=float)
+         audio = np.concatenate((audio, sil), axis=0)
+     audio = audio[: sr * 30]  # truncate audio to at most 30s
+
+     spectrogram = wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]
+
+     samples = {
+         "spectrogram": spectrogram,
+         "raw_wav": torch.from_numpy(audio).unsqueeze(0),
+         "padding_mask": torch.zeros(len(audio), dtype=torch.bool).unsqueeze(0),
+     }
+     if cuda_enabled:
+         samples = move_to_cuda(samples)
+
+     return samples
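
A short sketch of prepare_one_sample on its own, assuming "sample.wav" is a placeholder 16 kHz mono file (WhisperFeatureExtractor rejects other sampling rates) and running on CPU via cuda_enabled=False; it only prints the shapes of the returned tensors:

from transformers import WhisperFeatureExtractor
from utils import prepare_one_sample

# "sample.wav" is a hypothetical local audio file.
wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")
samples = prepare_one_sample("sample.wav", wav_processor, cuda_enabled=False)

for name, tensor in samples.items():
    print(name, tuple(tensor.shape))
# Roughly: spectrogram (1, n_mels, n_frames), raw_wav (1, n_samples),
# padding_mask (1, n_samples)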