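"""
Gradio app: speech-to-text transcription with a fine-tuned Whisper model.

Supports audio file upload and microphone recording (up to 3 minutes of audio),
and returns a formatted transcript plus downloadable JSON and SRT files.
"""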
import gradio as gr
import torch
import librosa
import numpy as np
import json
import os
import tempfile
import time
from datetime import datetime
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import warnings
warnings.filterwarnings("ignore")
# =============================================================================
# MODEL LOADING AND CONFIGURATION
# =============================================================================
# Configure your model path - update MODEL_NAME with your own HF model
MODEL_NAME = "AfroLogicInsect/whisper-finetuned-float32"
# Global variables for model and processor
model = None
processor = None
model_dtype = None
def load_model():
"""Load the Whisper model and processor"""
global model, processor, model_dtype
try:
print(f"πŸ”„ Loading model: {MODEL_NAME}")
# Load processor
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
        # Load with float32 for stability; keep the model on CPU here and move it
        # explicitly below, since combining device_map="auto" with a later .cuda()
        # call can conflict with accelerate's dispatch
        model = WhisperForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32
        )
model_dtype = torch.float32
# Move to GPU if available
if torch.cuda.is_available():
model = model.cuda()
print(f"βœ… Model loaded on GPU: {torch.cuda.get_device_name()}")
else:
print("βœ… Model loaded on CPU")
return True
except Exception as e:
print(f"❌ Error loading model: {e}")
# Fallback to base Whisper model
try:
print("πŸ”„ Falling back to base Whisper model...")
fallback_model = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(fallback_model)
model = WhisperForConditionalGeneration.from_pretrained(
fallback_model,
torch_dtype=torch.float32
)
model_dtype = torch.float32
if torch.cuda.is_available():
model = model.cuda()
print(f"βœ… Fallback model loaded: {fallback_model}")
return True
except Exception as e2:
print(f"❌ Fallback model loading failed: {e2}")
return False
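# Note: the fallback keeps the demo usable when the fine-tuned checkpoint is
# unavailable; "openai/whisper-small" is a reasonable default, and a smaller
# checkpoint such as "openai/whisper-tiny" could be swapped in if memory is tight.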
# Load model on startup
print("πŸš€ Initializing Whisper Transcription Service...")
model_loaded = load_model()
# =============================================================================
# CORE TRANSCRIPTION FUNCTIONS
# =============================================================================
def transcribe_audio_chunk(audio_chunk, sr=16000):
"""Transcribe a single audio chunk"""
try:
# Process with processor
inputs = processor(
audio_chunk,
sampling_rate=sr,
return_tensors="pt"
)
input_features = inputs.input_features
# Handle dtype matching
if model_dtype == torch.float16:
input_features = input_features.half()
else:
input_features = input_features.float()
# Move to same device as model
input_features = input_features.to(model.device)
# Generate transcription
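        # Greedy decoding (num_beams=1, do_sample=False) keeps latency low, and
        # no_repeat_ngram_size=2 damps the repeated-phrase loops Whisper can
        # fall into on noisy or near-silent audio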
with torch.no_grad():
try:
predicted_ids = model.generate(
input_features,
language="en",
task="transcribe",
max_length=448,
num_beams=1,
do_sample=False,
use_cache=True,
no_repeat_ngram_size=2
)
transcription = processor.batch_decode(
predicted_ids,
skip_special_tokens=True
)[0]
return transcription
except RuntimeError as gen_error:
if "Input type" in str(gen_error) and "bias type" in str(gen_error):
# Handle dtype mismatch
model.float()
input_features = input_features.float()
predicted_ids = model.generate(
input_features,
language="en",
task="transcribe",
max_length=448,
num_beams=1,
do_sample=False,
no_repeat_ngram_size=2
)
transcription = processor.batch_decode(
predicted_ids,
skip_special_tokens=True
)[0]
return transcription
else:
raise gen_error
except Exception as e:
print(f"❌ Chunk transcription failed: {e}")
return None
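# Example (sanity check with hypothetical input): one second of silence at 16 kHz
# should come back empty or near-empty:
#   silence = np.zeros(16000, dtype=np.float32)
#   print(transcribe_audio_chunk(silence))  # typically "" or a short filler phrase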
def process_audio_with_timestamps(audio_array, sr=16000, chunk_length=15):
"""Process audio with timestamps using robust chunking"""
try:
        processing_start = time.time()  # used to report processing time in metadata
        total_duration = len(audio_array) / sr
# Check duration limit (3 minutes = 180 seconds)
if total_duration > 180:
return {
"error": f"⚠️ Audio too long ({total_duration:.1f}s). Maximum allowed: 3 minutes (180s)",
"success": False
}
chunk_samples = chunk_length * sr
overlap_samples = int(2 * sr) # 2-second overlap
all_segments = []
start = 0
chunk_index = 0
progress_updates = []
while start < len(audio_array):
# Define chunk boundaries
end = min(start + chunk_samples, len(audio_array))
# Add overlap for better transcription
chunk_start_with_overlap = max(0, start - overlap_samples // 2)
chunk_end_with_overlap = min(len(audio_array), end + overlap_samples // 2)
chunk_audio = audio_array[chunk_start_with_overlap:chunk_end_with_overlap]
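            # Timestamps below use the un-padded [start, end) boundaries; words
            # duplicated by the overlap padding are stripped later by
            # remove_segment_overlaps()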
# Calculate time boundaries
start_time = start / sr
end_time = end / sr
            # Log progress to the server console
            total_chunks = max(1, int(np.ceil(len(audio_array) / chunk_samples)))
            progress = (chunk_index + 1) / total_chunks * 100
            progress_updates.append(f"Processing chunk {chunk_index + 1}: {start_time:.1f}s - {end_time:.1f}s ({progress:.0f}%)")
            print(progress_updates[-1])
# Transcribe chunk
transcription = transcribe_audio_chunk(chunk_audio, sr)
if transcription and transcription.strip():
clean_text = transcription.strip()
segment = {
"start": round(start_time, 2),
"end": round(end_time, 2),
"text": clean_text,
"duration": round(end_time - start_time, 2)
}
all_segments.append(segment)
# Move to next chunk
start = end
chunk_index += 1
# Remove overlaps between segments
cleaned_segments = remove_segment_overlaps(all_segments)
if cleaned_segments:
full_text = " ".join([seg["text"] for seg in cleaned_segments])
result = {
"success": True,
"text": full_text,
"segments": cleaned_segments,
"metadata": {
"total_duration": round(total_duration, 2),
"num_segments": len(cleaned_segments),
"chunk_length": chunk_length,
"processing_time": time.time()
}
}
return result
else:
return {
"error": "❌ No transcription could be generated",
"success": False
}
except Exception as e:
return {
"error": f"❌ Processing failed: {str(e)}",
"success": False
}
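# Example: process_audio_with_timestamps(audio, sr=16000, chunk_length=15) returns
# {"success": True, "text": "...", "segments": [{"start": 0.0, "end": 15.0,
#  "text": "...", "duration": 15.0}, ...], "metadata": {...}} on success.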
def remove_segment_overlaps(segments):
"""Remove overlapping text between segments"""
if len(segments) <= 1:
return segments
cleaned_segments = [segments[0]]
for i in range(1, len(segments)):
current_segment = segments[i].copy()
previous_text = cleaned_segments[-1]["text"]
current_text = current_segment["text"]
# Simple overlap detection
prev_words = previous_text.lower().split()
curr_words = current_text.lower().split()
overlap_length = 0
max_check = min(8, len(prev_words), len(curr_words))
for j in range(1, max_check + 1):
if prev_words[-j:] == curr_words[:j]:
overlap_length = j
if overlap_length > 0:
remaining_words = current_text.split()[overlap_length:]
if remaining_words:
current_segment["text"] = " ".join(remaining_words)
cleaned_segments.append(current_segment)
else:
cleaned_segments.append(current_segment)
return cleaned_segments
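# Example: if the previous segment ends "...meet me at the" and the next starts
# "at the station today", the shared words ["at", "the"] are dropped from the
# next segment, leaving "station today".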
# =============================================================================
# GRADIO INTERFACE FUNCTIONS
# =============================================================================
def transcribe_file(audio_file):
"""Handle file upload transcription"""
if not model_loaded:
return "❌ Model not loaded. Please refresh the page.", None, None
if audio_file is None:
return "⚠️ Please upload an audio file.", None, None
try:
# Load audio file
audio_array, sr = librosa.load(audio_file, sr=16000)
# Check duration
duration = len(audio_array) / sr
if duration > 180: # 3 minutes
return f"⚠️ Audio too long ({duration:.1f}s). Maximum allowed: 3 minutes.", None, None
# Process with timestamps
result = process_audio_with_timestamps(audio_array, sr)
if result["success"]:
# Format output
formatted_text = format_transcription_output(result)
# Create downloadable files
json_file = create_json_download(result, audio_file)
srt_file = create_srt_download(result, audio_file)
return formatted_text, json_file, srt_file
else:
return result["error"], None, None
except Exception as e:
return f"❌ Error processing file: {str(e)}", None, None
def transcribe_microphone(audio_data):
"""Handle microphone recording transcription"""
if not model_loaded:
return "❌ Model not loaded. Please refresh the page.", None, None
if audio_data is None:
return "⚠️ No audio recorded. Please record something first.", None, None
try:
        # Gradio's numpy Audio component returns a (sample_rate, np.ndarray) tuple
        sr, audio_array = audio_data
        # Convert integer PCM (e.g. int16 from the browser) to float32 in [-1, 1].
        # Checking the dtype is more reliable than inspecting peak amplitude,
        # which silently skips normalization on quiet recordings
        if np.issubdtype(audio_array.dtype, np.integer):
            audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
        else:
            audio_array = audio_array.astype(np.float32)
        # Downmix stereo recordings to mono
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
# Resample to 16kHz if needed
if sr != 16000:
audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000)
sr = 16000
# Check duration
duration = len(audio_array) / sr
if duration > 180: # 3 minutes
return f"⚠️ Recording too long ({duration:.1f}s). Maximum allowed: 3 minutes.", None, None
if duration < 0.5: # Less than 0.5 seconds
return "⚠️ Recording too short. Please record for at least 0.5 seconds.", None, None
# Process with timestamps
result = process_audio_with_timestamps(audio_array, sr)
if result["success"]:
# Format output
formatted_text = format_transcription_output(result)
# Create downloadable files
json_file = create_json_download(result, "microphone_recording")
srt_file = create_srt_download(result, "microphone_recording")
return formatted_text, json_file, srt_file
else:
return result["error"], None, None
except Exception as e:
return f"❌ Error processing recording: {str(e)}", None, None
def format_transcription_output(result):
"""Format transcription result for display"""
output = []
# Header
output.append("🎯 TRANSCRIPTION RESULTS")
output.append("=" * 50)
# Metadata
metadata = result["metadata"]
output.append(f"πŸ“Š Duration: {metadata['total_duration']}s")
output.append(f"πŸ“ Segments: {metadata['num_segments']}")
output.append("")
# Full text
output.append("πŸ“„ FULL TRANSCRIPT:")
output.append("-" * 30)
output.append(result["text"])
output.append("")
# Timestamped segments
output.append("πŸ• TIMESTAMPED SEGMENTS:")
output.append("-" * 30)
for i, segment in enumerate(result["segments"], 1):
start_min = int(segment["start"] // 60)
start_sec = int(segment["start"] % 60)
end_min = int(segment["end"] // 60)
end_sec = int(segment["end"] % 60)
time_str = f"{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}"
output.append(f"{i:2d}. [{time_str}] {segment['text']}")
return "\n".join(output)
def create_json_download(result, source_name):
    """Create a JSON file for download"""
    try:
        # Attach provenance metadata before serializing
        result["metadata"]["source"] = os.path.basename(str(source_name))
        result["metadata"]["generated_at"] = datetime.now().isoformat()
        result["metadata"]["model"] = MODEL_NAME
        # Write to a timestamped path in the temp directory so the
        # downloaded file has an identifiable name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(tempfile.gettempdir(), f"transcription_{timestamp}.json")
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        return filepath
    except Exception as e:
        print(f"Error creating JSON download: {e}")
        return None
def create_srt_download(result, source_name):
    """Create an SRT subtitle file for download"""
    try:
        srt_content = []
        for i, segment in enumerate(result["segments"], 1):
            start_time = format_time_srt(segment["start"])
            end_time = format_time_srt(segment["end"])
            srt_content.extend([
                str(i),
                f"{start_time} --> {end_time}",
                segment["text"],
                ""
            ])
        # Write to a timestamped path in the temp directory so the
        # downloaded file has an identifiable name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(tempfile.gettempdir(), f"subtitles_{timestamp}.srt")
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write("\n".join(srt_content))
        return filepath
    except Exception as e:
        print(f"Error creating SRT download: {e}")
        return None
def format_time_srt(seconds):
"""Format seconds to SRT time format"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
millis = int((seconds % 1) * 1000)
return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
def create_gradio_interface():
"""Create the Gradio interface"""
# Custom CSS for better styling
css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.title {
text-align: center;
color: #2d3748;
margin-bottom: 2rem;
}
.subtitle {
text-align: center;
color: #4a5568;
margin-bottom: 1rem;
}
.output-text {
font-family: 'Courier New', monospace;
background-color: #f7fafc;
padding: 1rem;
border-radius: 8px;
border: 1px solid #e2e8f0;
}
.warning {
background-color: #fff3cd;
border: 1px solid #ffeaa7;
color: #856404;
padding: 10px;
border-radius: 4px;
margin: 10px 0;
}
"""
with gr.Blocks(css=css, title="πŸŽ™οΈ Whisper Speech Transcription") as interface:
# Header
gr.HTML("""
<div class="title">
<h1>πŸŽ™οΈ Whisper Speech Transcription</h1>
<p class="subtitle">Upload an audio file or record your voice to get an AI-powered transcription with timestamps</p>
</div>
""")
# Warning about limits
gr.HTML("""
<div class="warning">
<strong>⚠️ Important:</strong> Maximum audio length is 3 minutes (180 seconds).
Longer files will be rejected to ensure fair usage for all users.
</div>
""")
# Model status
status_color = "green" if model_loaded else "red"
status_text = "βœ… Model loaded and ready" if model_loaded else "❌ Model loading failed"
gr.HTML(f'<p style="color: {status_color}; text-align: center;"><strong>{status_text}</strong></p>')
with gr.Tabs():
# Tab 1: File Upload
with gr.TabItem("πŸ“ Upload Audio File"):
with gr.Row():
with gr.Column():
audio_file_input = gr.Audio(
label="Upload Audio File",
type="filepath",
sources=["upload"]
)
file_transcribe_btn = gr.Button(
"πŸš€ Transcribe File",
variant="primary",
size="lg"
)
with gr.Row():
file_output = gr.Textbox(
label="Transcription Results",
lines=15,
placeholder="Your transcription will appear here...",
elem_classes=["output-text"]
)
with gr.Row():
with gr.Column():
json_download = gr.File(
label="πŸ“„ Download JSON",
visible=False
)
with gr.Column():
srt_download = gr.File(
label="πŸ“„ Download SRT Subtitles",
visible=False
)
# Tab 2: Voice Recording
with gr.TabItem("🎀 Record Voice"):
with gr.Row():
with gr.Column():
audio_mic_input = gr.Audio(
label="Record Your Voice",
sources=["microphone"],
type="numpy"
)
mic_transcribe_btn = gr.Button(
"πŸš€ Transcribe Recording",
variant="primary",
size="lg"
)
with gr.Row():
mic_output = gr.Textbox(
label="Transcription Results",
lines=15,
placeholder="Your transcription will appear here...",
elem_classes=["output-text"]
)
with gr.Row():
with gr.Column():
json_download_mic = gr.File(
label="πŸ“„ Download JSON",
visible=False
)
with gr.Column():
srt_download_mic = gr.File(
label="πŸ“„ Download SRT Subtitles",
visible=False
)
# Footer
gr.HTML("""
<div style="text-align: center; margin-top: 2rem; padding: 1rem; background-color: #f8f9fa; border-radius: 8px;">
<h3>πŸ“‹ Output Formats</h3>
<p><strong>JSON:</strong> Complete transcription data with timestamps and metadata</p>
<p><strong>SRT:</strong> Standard subtitle format for video players</p>
<p><strong>Display:</strong> Formatted text with timestamped segments</p>
<br>
<p style="color: #6c757d; font-size: 0.9em;">
            Powered by Whisper AI | Maximum 3 minutes per audio | Optimized for English
</p>
</div>
""")
# Event handlers
def update_file_outputs(result_text, json_file, srt_file):
json_visible = json_file is not None
srt_visible = srt_file is not None
return (
result_text,
gr.update(value=json_file, visible=json_visible),
gr.update(value=srt_file, visible=srt_visible)
)
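        # Two-step wiring: .click() runs the transcription and returns the text and
        # file paths; the chained .then() re-reads those outputs and unhides the
        # download components once files exist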
file_transcribe_btn.click(
fn=transcribe_file,
inputs=[audio_file_input],
outputs=[file_output, json_download, srt_download]
).then(
fn=update_file_outputs,
inputs=[file_output, json_download, srt_download],
outputs=[file_output, json_download, srt_download]
)
mic_transcribe_btn.click(
fn=transcribe_microphone,
inputs=[audio_mic_input],
outputs=[mic_output, json_download_mic, srt_download_mic]
).then(
fn=update_file_outputs,
inputs=[mic_output, json_download_mic, srt_download_mic],
outputs=[mic_output, json_download_mic, srt_download_mic]
)
return interface
# =============================================================================
# LAUNCH APPLICATION
# =============================================================================
if __name__ == "__main__":
# Create and launch the interface
interface = create_gradio_interface()
# Launch configuration
interface.launch(
        share=True,  # Creates a public URL (ignored with a warning on HF Spaces)
server_name="0.0.0.0", # Allows external access
server_port=7860, # Standard Gradio port
show_error=True,
        # enable_queue=True,  # Removed in Gradio 4.x; queuing is on by default
max_threads=10 # Limit concurrent processing
)