import gc
import torch
import torchaudio
import numpy as np
from transformers import (
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
import spaces
import logging
from difflib import SequenceMatcher

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class AudioProcessor:
    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size    # chunk length in seconds
        self.overlap = overlap          # overlap between consecutive chunks in seconds
        self.sample_rate = sample_rate  # target sample rate in Hz
        self.previous_text = ""
        self.previous_lang = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_models(self):
        """Load all required models"""
        logger.info("Loading MMS models...")

        # Language identification model
        lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-256")
        lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-256")

        # Transcription model
        mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
        mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")

        # Translation model
        translation_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

        return {
            'lid': (lid_model, lid_processor),
            'mms': (mms_model, mms_processor),
            'translation': (translation_model, translation_tokenizer)
        }

    def identify_language(self, audio_chunk, models):
        """Identify language of audio chunk"""
        lid_model, lid_processor = models['lid']
        inputs = lid_processor(audio_chunk, sampling_rate=16000, return_tensors="pt")
        lid_model.to(self.device)

        with torch.no_grad():
            outputs = lid_model(inputs.input_values.to(self.device)).logits

        lang_id = torch.argmax(outputs, dim=-1)[0].item()
        detected_lang = lid_model.config.id2label[lang_id]
        return detected_lang

    def transcribe_chunk(self, audio_chunk, language, models):
        """Transcribe audio chunk"""
        mms_model, mms_processor = models['mms']

        # Switch the tokenizer and MMS adapter to the detected language (ISO 639-3 code)
        mms_processor.tokenizer.set_target_lang(language)
        mms_model.load_adapter(language)
        mms_model.to(self.device)

        inputs = mms_processor(audio_chunk, sampling_rate=16000, return_tensors="pt")

        with torch.no_grad():
            outputs = mms_model(inputs.input_values.to(self.device)).logits

        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = mms_processor.decode(ids)
        return transcription

    def translate_text(self, text, models):
        """Translate text to English"""
        translation_model, translation_tokenizer = models['translation']

        inputs = translation_tokenizer(text, return_tensors="pt")
        inputs = inputs.to(self.device)
        translation_model.to(self.device)

        with torch.no_grad():
            outputs = translation_model.generate(
                **inputs,
                forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
                max_length=100
            )

        translation = translation_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return translation
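
    # Note: the NLLB tokenizer is used here with its default source language. For
    # non-English input, mapping the detected MMS language code to an NLLB code and
    # setting translation_tokenizer.src_lang before encoding may improve results;
    # that mapping is not included in this sketch.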

    def preprocess_audio(self, audio):
        """
        Split audio into overlapping chunks, keeping both the nominal chunk
        boundaries and the overlap-widened windows used for transcription.
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)

        chunks_with_times = []
        start_idx = 0

        while start_idx < len(audio):
            end_idx = min(start_idx + chunk_samples, len(audio))

            if start_idx == 0:
                # Prepend one second of silence to the first chunk
                chunk = audio[start_idx:end_idx]
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, chunk])
            else:
                # Include overlap from previous chunk
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Pad if necessary
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Adjust time ranges to account for overlaps
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            # Move to next chunk with smaller step size for better continuity
            start_idx += (chunk_samples - overlap_samples)

        return chunks_with_times

    def process_audio(self, audio_path, translate=False):
        """Main processing function"""
        try:
            # Load audio
            waveform, sample_rate = torchaudio.load(audio_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0)
            else:
                waveform = waveform.squeeze(0)

            # Resample if necessary
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            # Load models
            models = self.load_models()

            # Process in chunks
            chunk_samples = int(self.chunk_size * self.sample_rate)
            overlap_samples = int(self.overlap * self.sample_rate)

            segments = []
            language_segments = []

            for i in range(0, len(waveform), chunk_samples - overlap_samples):
                chunk = waveform[i:i + chunk_samples]

                # Chunk boundaries in seconds, clamped to the audio length before padding
                start_time = i / self.sample_rate
                end_time = min(i + chunk_samples, len(waveform)) / self.sample_rate

                if len(chunk) < chunk_samples:
                    chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

                # Identify language
                language = self.identify_language(chunk, models)

                # Record language segment
                language_segments.append({
                    "language": language,
                    "start": start_time,
                    "end": end_time
                })

                # Transcribe
                transcription = self.transcribe_chunk(chunk, language, models)

                segment = {
                    "start": start_time,
                    "end": end_time,
                    "language": language,
                    "text": transcription,
                    "speaker": "Speaker"  # Simple speaker assignment
                }

                if translate:
                    translation = self.translate_text(transcription, models)
                    segment["translated"] = translation

                segments.append(segment)

                # Clean up GPU memory
                torch.cuda.empty_cache()
                gc.collect()

            # Merge nearby segments
            merged_segments = self.merge_segments(segments)
            return language_segments, merged_segments

        except Exception as e:
            logger.error(f"Error processing audio: {str(e)}")
            raise

    def merge_segments(self, segments, time_threshold=0.5, similarity_threshold=0.7):
        """Merge similar nearby segments"""
        if not segments:
            return segments

        merged = []
        current = segments[0]

        for next_segment in segments[1:]:
            if (next_segment['start'] - current['end'] <= time_threshold and
                    current['language'] == next_segment['language']):
                # Check text similarity
                matcher = SequenceMatcher(None, current['text'], next_segment['text'])
                similarity = matcher.ratio()

                if similarity > similarity_threshold:
                    # Merge segments
                    current['end'] = next_segment['end']
                    current['text'] = current['text'] + ' ' + next_segment['text']
                    if 'translated' in current and 'translated' in next_segment:
                        current['translated'] = current['translated'] + ' ' + next_segment['translated']
                else:
                    merged.append(current)
                    current = next_segment
            else:
                merged.append(current)
                current = next_segment

        merged.append(current)
        return merged