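"""Video transcription app.

Downloads audio from a video URL, transcribes it in chunks with OpenAI's
Whisper (via transformers), optionally labels speakers with pyannote.audio,
and serves the result through a PyWebIO web page.
"""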
import io
import re
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import requests
from bs4 import BeautifulSoup
import tempfile
import os
from spellchecker import SpellChecker
from pydub import AudioSegment
import librosa
import numpy as np
from pyannote.audio import Pipeline
from pywebio import start_server, config
from pywebio.input import input
from pywebio.output import put_text, put_markdown, put_file
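# NOTE (environment assumption, not in the original code): "pyannote/speaker-diarization"
# is a gated model on the Hugging Face Hub, so loading it may require an access
# token, e.g. Pipeline.from_pretrained(..., use_auth_token=os.environ.get("HF_TOKEN")).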
# Initialize the speaker diarization pipeline
try:
    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")
    print("Speaker diarization pipeline initialized successfully")
except Exception as e:
    print(f"Error initializing speaker diarization pipeline: {str(e)}")
    pipeline = None
# Check if CUDA is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the Whisper model and processor
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
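# Whisper auto-detects the spoken language from each chunk. To pin the
# language/task explicitly (an optional tweak, not in the original), one could
# pass forced decoder ids at generation time, e.g.:
#   forced_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
#   model.generate(input_features, forced_decoder_ids=forced_ids)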
spell = SpellChecker()
def download_audio_from_url(url):
    try:
        if "share" in url:
            print("Processing shareable link...")
            response = requests.get(url)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            video_tag = soup.find('video')
            if video_tag and 'src' in video_tag.attrs:
                video_url = video_tag['src']
                print(f"Extracted video URL: {video_url}")
            else:
                raise ValueError("Direct video URL not found in the shareable link.")
        else:
            video_url = url
        print(f"Downloading video from URL: {video_url}")
        response = requests.get(video_url)
        response.raise_for_status()
        audio_bytes = response.content
        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise
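# NOTE (suggestion, not in the original): requests.get buffers the whole file
# in memory; for long videos, requests.get(url, stream=True) together with
# response.iter_content(chunk_size=...) would keep memory use bounded.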
def correct_spelling(text):
    # Process line by line so the paragraph/speaker formatting added earlier
    # survives the correction pass
    corrected_lines = []
    for line in text.splitlines():
        # spell.correction() returns None for unknown words, so fall back to the original word
        corrected_words = [spell.correction(word) or word for word in line.split()]
        corrected_lines.append(' '.join(corrected_words))
    return '\n'.join(corrected_lines)
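# CAUTION: pyspellchecker lowercases its suggestions and has no context, so
# running it over a formatted transcript can mangle names, casing, and the
# "Speaker X:" labels; treat this pass as a rough cleanup.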
def format_transcript_with_speakers(transcript, diarization, total_duration):
    # The transcript carries no word timestamps, so map each diarization
    # segment's start/end times to character offsets proportionally. (The
    # original sliced the string with float second values, which raises a
    # TypeError; the total_duration parameter was added to make this work.)
    formatted_transcript = []
    current_speaker = None
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        start = int(len(transcript) * segment.start / total_duration)
        end = int(len(transcript) * segment.end / total_duration)
        if speaker != current_speaker:
            if current_speaker is not None:
                formatted_transcript.append("\n")  # Add a blank line between speakers
            formatted_transcript.append(f"Speaker {speaker}:\n")
            current_speaker = speaker
        segment_text = transcript[start:end].strip()
        if segment_text:
            formatted_transcript.append(f"{segment_text}\n")
    return "".join(formatted_transcript)
def transcribe_audio(audio_file):
    try:
        print("Loading audio file...")
        # Whisper expects 16 kHz mono audio; librosa resamples on load
        audio_input, sr = librosa.load(audio_file, sr=16000)
        audio_input = audio_input.astype(np.float32)
        total_duration = len(audio_input) / sr
        print(f"Audio duration: {total_duration:.2f} seconds")
        # Apply speaker diarization
        if pipeline:
            print("Applying speaker diarization...")
            diarization = pipeline(audio_file)
            print("Speaker diarization complete.")
        else:
            diarization = None
        # Transcribe in 30-second chunks with a 5-second overlap. Note that the
        # overlapping audio is transcribed twice, so the joined text can repeat
        # a few words at chunk boundaries.
        chunk_length = 30 * sr
        overlap = 5 * sr
        transcriptions = []
        print("Starting transcription...")
        for i in range(0, len(audio_input), chunk_length - overlap):
            chunk = audio_input[i:i + chunk_length]
            input_features = processor(chunk, sampling_rate=16000, return_tensors="pt").input_features.to(device)
            predicted_ids = model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            transcriptions.extend(transcription)
            print(f"Processed {i / sr:.2f} to {min(i + chunk_length, len(audio_input)) / sr:.2f} seconds")
        full_transcription = " ".join(transcriptions)
        print(f"Transcription complete. Full transcription length: {len(full_transcription)} characters")
        if diarization:
            print("Applying formatting with speaker diarization...")
            formatted_transcription = format_transcript_with_speakers(full_transcription, diarization, total_duration)
        else:
            print("Applying formatting without speaker diarization...")
            formatted_transcription = format_transcript_with_breaks(full_transcription)
        return formatted_transcription
    except Exception as e:
        print(f"Error in transcribe_audio: {str(e)}")
        raise
def format_transcript_with_breaks(transcript):
    sentences = re.split('(?<=[.!?]) +', transcript)
    paragraphs = []
    current_paragraph = []
    for sentence in sentences:
        current_paragraph.append(sentence)
        if len(current_paragraph) >= 3:  # Adjust this number to control paragraph size
            paragraphs.append(' '.join(current_paragraph))
            current_paragraph = []
    if current_paragraph:
        paragraphs.append(' '.join(current_paragraph))
    return '\n\n'.join(paragraphs)
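# Example: format_transcript_with_breaks("A. B. C. D.") returns "A. B. C.\n\nD."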
def transcribe_video(url):
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")
        # Convert audio bytes to an AudioSegment (pydub/ffmpeg handles the container format)
        audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
        print(f"Audio duration: {len(audio) / 1000} seconds")
        # Save as a WAV file for librosa and pyannote to consume
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            audio.export(temp_audio.name, format="wav")
            temp_audio_path = temp_audio.name
        print("Starting audio transcription...")
        transcript = transcribe_audio(temp_audio_path)
        print(f"Transcription completed. Transcript length: {len(transcript)} characters")
        # Clean up the temporary file
        os.unlink(temp_audio_path)
        # Apply spelling correction
        transcript = correct_spelling(transcript)
        return transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
def video_transcription():
    put_markdown("# Video Transcription")
    video_url = input(label="Video URL")
    if video_url:
        put_text("Transcribing video...")
        transcript = transcribe_video(video_url)
        if transcript:
            put_text(transcript)
            put_file('transcript.txt', content=transcript.encode('utf-8'), label="Download Transcript")
        else:
            put_text("Failed to transcribe video.")
if __name__ == '__main__':
    config(title="Video Transcription", description="Transcribe audio from a video URL using Whisper and PyAnnote")
    start_server(video_transcription, port=7860, debug=True)
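# Port 7860 is the default port that Hugging Face Spaces expects apps to serve on.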