- inference.py +43 -0
- requirements.txt +13 -0
- test_prompt.json +31 -0
- utils.py +158 -0
inference.py
ADDED
@@ -0,0 +1,43 @@
import torch
from omegaconf import OmegaConf
from transformers import WhisperFeatureExtractor

from models.tinyoctopus import TINYOCTOPUS
from utils import prepare_one_sample

# Load model
# NOTE: the original snippet referenced an undefined `cfg`; loading it from a
# YAML file via OmegaConf is an assumption -- point this at your actual config.
cfg = OmegaConf.load("config.yaml")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TINYOCTOPUS.from_config(cfg.config.model)
model.to(device)
model.eval()

# Load processor
wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")

def transcribe(audio_path, task="dialect"):
    """
    Perform inference on an audio file.

    Args:
        audio_path (str): Path to the audio file.
        task (str): Task to perform. Options: "dialect", "asr", "translation".

    Returns:
        str: The generated text.
    """
    task_prompts = {
        "dialect": "What is the dialect of the speaker?",
        "asr": "تعرف على الكلام وأعطني النص.",  # "Recognize the speech and give me the text."
        "translation": "الرجاء ترجمة هذا المقطع الصوتي إلى اللغة الإنجليزية."  # "Please translate this audio clip into English."
    }

    if task not in task_prompts:
        raise ValueError("Invalid task. Choose from: 'dialect', 'asr', or 'translation'.")

    try:
        prompt = task_prompts[task]
        samples = prepare_one_sample(audio_path, wav_processor)
        prompt = [f"<Speech><SpeechHere></Speech> {prompt.strip()}"]
        generated_text = model.generate(samples, {"temperature": 0.7}, prompts=prompt)[0]
        return generated_text.replace('<s>', '').replace('</s>', '').strip()

    except Exception as e:
        return f"Error: {e}"
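For reference, a minimal usage sketch of `transcribe`; the audio path below is a placeholder:

    text = transcribe("samples/example.wav", task="asr")
    print(text)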
requirements.txt
ADDED
@@ -0,0 +1,13 @@
torch==2.0.1
torchaudio==2.0.2
peft==0.3.0
soundfile
librosa
transformers==4.28.0
sentencepiece==0.1.97
accelerate==0.20.3
bitsandbytes==0.35.0
gradio==3.23.0
safetensors
tensorboardX
jiwer
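The pinned versions above install in one step with `pip install -r requirements.txt`.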
test_prompt.json
ADDED
@@ -0,0 +1,31 @@
{
  "asr": "<Speech><SpeechHere></Speech> Recognize the speech and give me the transcription.",
  "gender_recognition": "<Speech><SpeechHere></Speech> What is the gender of the speaker?",
  "dialect_identification": "<Speech><SpeechHere></Speech> What is the dialect of the speaker?",
  "asr_zh": "<Speech><SpeechHere></Speech> 请将语音中的内容写下来。",
  "summarization": "<Speech><SpeechHere></Speech> Could you capture the main points of this audio in a short summary?",
  "translation_ae": "<Speech><SpeechHere></Speech> Listen to the speech and translate it into English.",
  "asr_de": "<Speech><SpeechHere></Speech> Hören Sie sich die Rede an und schreiben Sie ihren Inhalt auf.",
  "translation_ec": "<Speech><SpeechHere></Speech> Listen to the speech and translate it into Chinese.",
  "audiocaption": "<Speech><SpeechHere></Speech> Please describe the audio.",
  "audiocaption_v2": "<Speech><SpeechHere></Speech> Please write down what you hear in the audio.",
  "QA": "<Speech><SpeechHere></Speech> {}",
  "gender_QA": "<Speech><SpeechHere></Speech> {}",
  "phone_recognition": "<Speech><SpeechHere></Speech> Provide the phonetic transcription for the speech.",
  "speech_query": "<Speech><SpeechHere></Speech> Please answer the question in detail.",
  "emotion_recognition": "<Speech><SpeechHere></Speech> Describe the emotion of the speaker in one word.",
  "lyrics_recognition": "<Speech><SpeechHere></Speech> Listen to the song and write down its content.",
  "audio_speech_description": "<Speech><SpeechHere></Speech> Describe the speech and the background audio.",
  "speaker_verification": "<Speech><SpeechHere></Speech> Do you only hear the same person talking? Answer yes or no.",
  "fluent_speech_audio": "<Speech><SpeechHere></Speech> Describe the background audio and the speech in a fluent sentence.",
  "speech_separation": "<Speech><SpeechHere></Speech> Please write down what you hear each person say.",
  "audio_story_telling": "<Speech><SpeechHere></Speech> Based on the audio, write a story in detail. Your story should be highly related to the audio.",
  "speech_audio_query": "<Speech><SpeechHere></Speech> Please answer the speaker's question in detail based on the background sound.",
  "slot_filling": "<Speech><SpeechHere></Speech> According to the speech, what is the {}?",
  "music_description": "<Speech><SpeechHere></Speech> Listen to this music clip and describe the music.",
  "translation_en2ja": "<Speech><SpeechHere></Speech> Listen to the speech and translate it into Japanese.",
  "translation_en2de": "<Speech><SpeechHere></Speech> Listen to the speech and translate it into German.",
  "speech_audio_coreasoning": "<Speech><SpeechHere></Speech> Use your strong reasoning skills to answer the speaker's question in detail based on the background sound.",
  "keywords": "<Speech><SpeechHere></Speech> Give me only three keywords of the text.",
  "speaker_diarization_asr": "<Speech><SpeechHere></Speech> Please recognize each speaker and transcribe their speech content."
}
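Entries such as "QA" and "slot_filling" contain `{}` placeholders, which suggests they are filled at runtime with Python's str.format; a minimal sketch under that assumption (the slot name is hypothetical):

    import json

    with open("test_prompt.json") as f:
        prompts = json.load(f)

    # Hypothetical slot name -- real slot labels come from the dataset annotations.
    prompt = prompts["slot_filling"].format("departure city")
    # -> "<Speech><SpeechHere></Speech> According to the speech, what is the departure city?"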
utils.py
ADDED
@@ -0,0 +1,158 @@
# Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from datetime import datetime

import torch
from torch.utils.data import DataLoader, DistributedSampler
import soundfile as sf
import numpy as np

from dist_utils import is_main_process, get_world_size, get_rank


def now():
    return datetime.now().strftime("%Y%m%d%H%M")


def setup_logger():
    logging.basicConfig(
        level=logging.INFO if is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )


def get_dataloader(dataset, config, is_train=True, use_distributed=True):
    if use_distributed:
        sampler = DistributedSampler(
            dataset,
            shuffle=is_train,
            num_replicas=get_world_size(),
            rank=get_rank(),
        )
    else:
        sampler = None

    loader = DataLoader(
        dataset,
        batch_size=config.batch_size_train if is_train else config.batch_size_eval,
        num_workers=config.num_workers,
        pin_memory=True,
        sampler=sampler,
        shuffle=sampler is None and is_train,
        collate_fn=dataset.collater,
        drop_last=is_train,
    )

    if is_train:
        loader = IterLoader(loader, use_distributed=use_distributed)

    return loader


def apply_to_sample(f, sample):
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)


def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples


class IterLoader:
    """
    A wrapper that turns a DataLoader into an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)


def prepare_one_sample(wav_path, wav_processor, cuda_enabled=True):
    audio, sr = sf.read(wav_path)
    if len(audio.shape) == 2:  # stereo to mono: keep the first channel
        audio = audio[:, 0]
    if len(audio) < sr:  # pad audio to at least 1s
        sil = np.zeros(sr - len(audio), dtype=float)
        audio = np.concatenate((audio, sil), axis=0)
    audio = audio[: sr * 30]  # truncate audio to at most 30s

    spectrogram = wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]

    samples = {
        "spectrogram": spectrogram,
        "raw_wav": torch.from_numpy(audio).unsqueeze(0),
        "padding_mask": torch.zeros(len(audio), dtype=torch.bool).unsqueeze(0),
    }
    if cuda_enabled:
        samples = move_to_cuda(samples)

    return samples
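A short sketch of how `prepare_one_sample` feeds the inference path, assuming a Whisper-compatible feature extractor (the audio path is a placeholder):

    from transformers import WhisperFeatureExtractor
    from utils import prepare_one_sample

    extractor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")
    samples = prepare_one_sample("samples/example.wav", extractor, cuda_enabled=False)
    print(samples["spectrogram"].shape)  # e.g. torch.Size([1, 128, 3000]) for large-v3 features
    print(samples["raw_wav"].shape)      # (1, num_samples)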