import os
import re
import json
import threading
import subprocess

import torch
import torchaudio
import ffmpeg
import yt_dlp
import gradio as gr

from torch.utils.data import Dataset, DataLoader
from youtube_transcript_api import (
    YouTubeTranscriptApi,
    TranscriptsDisabled,
    NoTranscriptFound,
    CouldNotRetrieveTranscript,
    VideoUnavailable,
)
from youtube_transcript_api.formatters import TextFormatter
from transformers import (
    pipeline,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse


def is_youtube_url(url):
    return "youtube.com" in url or "youtu.be" in url


def is_web_url(url):
    return url.startswith("http://") or url.startswith("https://")


def get_video_id(url):
    # Extract the 11-character YouTube video ID from a watch or share URL.
    match = re.search(r'(?:v=|\/)([0-9A-Za-z_-]{11})', url)
    return match.group(1) if match else None
|
|
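# Illustration (hypothetical ID): both URL shapes resolve to the same ID.
#   get_video_id("https://www.youtube.com/watch?v=abcdefghijk")  -> "abcdefghijk"
#   get_video_id("https://youtu.be/abcdefghijk")                 -> "abcdefghijk"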
|
def try_download_transcript(video_id):
    # Prefer the official transcript API: it is fast and avoids downloading media.
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=["en"])
        return TextFormatter().format_transcript(transcript)
    except (TranscriptsDisabled, NoTranscriptFound, CouldNotRetrieveTranscript, VideoUnavailable):
        return None
    except Exception as e:
        print(f"Transcript error: {e}")
        return None
|
|
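# For reference: get_transcript returns a list of dicts such as
# {"text": "...", "start": 12.3, "duration": 4.5}, which TextFormatter
# flattens into plain text.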
|
def download_youtube_subtitles(url, lang='en', output_path="subtitles", cookies_path=None):
    ydl_opts = {
        "writesubtitles": True,
        "writeautomaticsub": True,
        "skip_download": True,
        "subtitleslangs": [lang],
        "subtitlesformat": "vtt",
        "outtmpl": output_path,
    }
    if cookies_path:
        ydl_opts["cookiefile"] = cookies_path

    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
        # yt-dlp appends the language code and extension to the output
        # template (e.g. "subtitles.en.vtt"), so return the path it wrote.
        subtitle_path = f"{output_path}.{lang}.vtt"
        return subtitle_path if os.path.exists(subtitle_path) else None
    except Exception as e:
        print(f"Subtitle download error: {e}")
        return None
|
|
|
def parse_subtitle_file(sub_path):
    # Keep only spoken text: drop the WEBVTT header, cue numbers,
    # timestamp lines, and blanks.
    try:
        with open(sub_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        text_lines = [
            line.strip() for line in lines
            if line.strip() and not re.match(r'^\d+$|^\d{2}:\d{2}|^WEBVTT', line)
        ]
        return " ".join(text_lines)
    except Exception as e:
        print(f"Subtitle parse error: {e}")
        return None
|
|
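# For reference, a typical VTT cue that parse_subtitle_file reduces to text:
#   00:00:01.000 --> 00:00:04.000
#   Welcome to today's lecture.
# The timestamp line matches ^\d{2}:\d{2} and is dropped; only the caption
# text survives.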
|
def get_subtitle_streams(video_path):
    # Use ffprobe to list embedded subtitle streams ("-select_streams s").
    cmd = [
        "ffprobe", "-v", "error", "-print_format", "json",
        "-show_streams", "-select_streams", "s", video_path
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    try:
        return json.loads(result.stdout).get("streams", [])
    except Exception as e:
        print(f"ffprobe parsing error: {e}")
        return []
|
|
|
def extract_subtitles_from_video(video_path, output_path="subtitles.srt"):
    subtitle_streams = get_subtitle_streams(video_path)
    if not subtitle_streams:
        return None
    try:
        # "-map 0:s:0" extracts the first subtitle track of the input.
        cmd = [
            "ffmpeg", "-y", "-i", video_path,
            "-map", "0:s:0", output_path
        ]
        subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        # ffmpeg's output is suppressed above, so verify the file was written.
        return output_path if os.path.exists(output_path) else None
    except Exception as e:
        print(f"Subtitle extraction error: {e}")
        return None
|
|
|
def download_audio_youtube(url, output_path="audio.wav", cookies_path=None):
    # Download the full video with yt-dlp, then strip the audio track
    # with ffmpeg via extract_audio_from_video.
    fallback_video_path = "fallback_video.mp4"
    video_id = get_video_id(url)

    ydl_opts = {
        "format": "best",
        "outtmpl": fallback_video_path,
        # Android-client user agent; helps with some restricted videos.
        "user_agent": "com.google.android.youtube/17.31.35 (Linux; U; Android 11)",
        "compat_opts": ["allow_unplayable_formats"]
    }

    if cookies_path:
        ydl_opts["cookiefile"] = cookies_path

    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
    except Exception as e:
        raise RuntimeError(
            f"\u26a0\ufe0f Could not download this YouTube video. Try this alternative:"
            f" https://youtubetotranscript.com/transcript?v={video_id}&current_language_code=en\n"
            f"Details: {e}"
        )

    return extract_audio_from_video(fallback_video_path, audio_path=output_path)
|
|
|
def download_video_direct(url, output_path="video.mp4"):
    ydl_opts = {
        "format": "best",
        "outtmpl": output_path
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([url])
    return output_path
|
|
|
def extract_audio_from_video(video_path, audio_path="audio.wav"):
    # Convert to 16 kHz mono WAV, the input format Whisper expects.
    ffmpeg.input(video_path).output(audio_path, ac=1, ar=16000).run(overwrite_output=True)
    return audio_path
|
|
|
def split_audio(input_path, chunk_length_sec=30, target_sr=16000):
    waveform, sr = torchaudio.load(input_path)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
    if waveform.shape[0] > 1:
        # Downmix stereo (or multi-channel) audio to mono.
        waveform = waveform.mean(dim=0, keepdim=True)
    chunk_samples = target_sr * chunk_length_sec
    chunks = [waveform[:, i:i + chunk_samples] for i in range(0, waveform.shape[1], chunk_samples)]
    return chunks, target_sr
|
|
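# Sizing example: a 10-minute lecture resampled to 16 kHz yields
# 600 s / 30 s = 20 chunks of 480,000 samples each (the last one shorter).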
|
class AudioChunksDataset(Dataset):
    def __init__(self, chunks):
        self.chunks = chunks

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        return self.chunks[idx].squeeze(0)
|
|
|
def collate_audio_batch(batch):
    # Zero-pad every chunk in the batch to the length of the longest one.
    max_len = max(b.shape[0] for b in batch)
    padded_batch = [torch.nn.functional.pad(b, (0, max_len - b.shape[0])) for b in batch]
    return torch.stack(padded_batch)
|
|
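# Padding example: chunks of 480,000 and 240,000 samples collate into a
# (2, 480000) batch, with the shorter waveform zero-padded on the right.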
|
def transcribe_chunks_dataset(chunks, sr, model_name="openai/whisper-small", batch_size=4):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
    model.eval()

    dataset = AudioChunksDataset(chunks)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_audio_batch)

    full_transcript = []
    for batch_waveforms in dataloader:
        wave_list = [waveform.numpy() for waveform in batch_waveforms]
        # "max_length" padding gives Whisper its fixed 30-second log-mel window.
        input_features = processor(
            wave_list, sampling_rate=sr, return_tensors="pt", padding="max_length"
        ).input_features.to(device)
        with torch.no_grad():
            predicted_ids = model.generate(input_features, language="en")
        transcriptions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        full_transcript.extend(transcriptions)

    return " ".join(full_transcript)
|
|
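# Minimal usage sketch (assumes a local file "lecture.wav" exists):
#   chunks, sr = split_audio("lecture.wav", chunk_length_sec=30)
#   text = transcribe_chunks_dataset(chunks, sr, model_name="openai/whisper-small")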
|
def summarize_with_bart(text, max_chars=1024):
    # BART's context window is 1024 tokens; splitting on a 1024-character
    # budget is a conservative approximation of that limit.
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn",
                          device=0 if torch.cuda.is_available() else -1)
    sentences = text.split(". ")
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        if len(current_chunk + sentence) <= max_chars:
            current_chunk += sentence + ". "
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence + ". "
    if current_chunk:
        chunks.append(current_chunk.strip())

    summary = ""
    for chunk in chunks:
        out = summarizer(chunk, max_length=150, min_length=30, do_sample=False)
        summary += out[0]['summary_text'] + " "

    return summary.strip()
|
|
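# Sizing example: a 5,000-character transcript splits into roughly five
# 1,024-character chunks; each is summarized to 30-150 tokens and the
# partial summaries are concatenated.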
|
def generate_questions_with_pipeline(text, num_questions=5):
    # valhalla/t5-base-qg-hl is trained on answer-highlighted inputs;
    # feeding it plain sentences still yields usable questions.
    question_generator = pipeline("text2text-generation", model="valhalla/t5-base-qg-hl",
                                  device=0 if torch.cuda.is_available() else -1)
    sentences = text.split(". ")
    questions = []

    # Scan twice as many sentences as needed, since some yield no question.
    for sentence in sentences[:num_questions * 2]:
        if not sentence.strip():
            continue
        input_text = f"generate question: {sentence.strip()}"
        out = question_generator(input_text, max_length=50, do_sample=True, temperature=0.9)
        question = out[0]["generated_text"].strip()
        if question:
            questions.append(question)

    return questions[:num_questions]
|
|
def process_input_gradio(url_input, file_input, text_input):
    try:
        transcript = ""

        if text_input:
            # Pasted text needs no extraction; use it directly.
            transcript = text_input.strip()

        elif file_input is not None:
            # Uploaded file: try embedded subtitles first, then fall back
            # to extracting audio and transcribing with Whisper.
            subtitle_path = extract_subtitles_from_video(file_input.name)
            if subtitle_path:
                transcript = parse_subtitle_file(subtitle_path)
            if not transcript:
                audio_path = extract_audio_from_video(file_input.name)
                chunks, sr = split_audio(audio_path, chunk_length_sec=15)
                transcript = transcribe_chunks_dataset(chunks, sr)

        elif url_input:
            if os.path.exists(url_input):
                # A local path, e.g. one returned by the /upload endpoint.
                subtitle_path = extract_subtitles_from_video(url_input)
                if subtitle_path:
                    transcript = parse_subtitle_file(subtitle_path)
                if not transcript:
                    audio_path = extract_audio_from_video(url_input)
                    chunks, sr = split_audio(audio_path, chunk_length_sec=15)
                    transcript = transcribe_chunks_dataset(chunks, sr)
            elif is_youtube_url(url_input):
                # Fallback chain: transcript API -> subtitles -> Whisper.
                video_id = get_video_id(url_input)
                transcript = try_download_transcript(video_id)
                if not transcript:
                    subtitle_path = download_youtube_subtitles(url_input)
                    if subtitle_path:
                        transcript = parse_subtitle_file(subtitle_path)
                if not transcript:
                    try:
                        audio_path = download_audio_youtube(url_input)
                        chunks, sr = split_audio(audio_path, chunk_length_sec=15)
                        transcript = transcribe_chunks_dataset(chunks, sr)
                    except Exception as e:
                        return (
                            f"\u26a0\ufe0f Could not download this YouTube video due to restrictions. "
                            "Please upload the video manually.\n"
                            f"Details: {e}", ""
                        )
            else:
                # Any other URL: download the video, then reuse the
                # subtitle/Whisper fallback chain.
                video_file = download_video_direct(url_input)
                subtitle_path = extract_subtitles_from_video(video_file)
                if subtitle_path:
                    transcript = parse_subtitle_file(subtitle_path)
                if not transcript:
                    audio_path = extract_audio_from_video(video_file)
                    chunks, sr = split_audio(audio_path, chunk_length_sec=15)
                    transcript = transcribe_chunks_dataset(chunks, sr)
        else:
            return "Please provide a URL, upload a video file, or paste text.", ""

        summary = summarize_with_bart(transcript)
        questions = generate_questions_with_pipeline(summary)
        return summary, "\n".join([f"{i + 1}. {q}" for i, q in enumerate(questions)])

    except Exception as e:
        return f"Error: {str(e)}", ""
|
|
app = FastAPI()
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)


@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    file_path = os.path.join(UPLOAD_DIR, file.filename)
    with open(file_path, "wb") as f:
        f.write(await file.read())

    # Auto-delete the upload after two minutes to keep the disk clean.
    threading.Timer(120, lambda: delete_file_if_exists(file_path)).start()
    return JSONResponse({"status": "ok", "path": file_path})
|
|
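# Example client call (hypothetical host/port):
#   curl -F "file=@lecture.mp4" http://localhost:7860/upload
#   -> {"status": "ok", "path": "uploads/lecture.mp4"}
# The returned path can then be pasted into the UI's URL textbox, which
# process_input_gradio treats as a local file (note the file is auto-deleted
# two minutes after upload).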
|
def delete_file_if_exists(path):
    try:
        if os.path.exists(path):
            os.remove(path)
            print(f"Deleted file: {path}")
    except Exception as e:
        print(f"Failed to delete {path}: {e}")
|
|
iface = gr.Interface(
    fn=process_input_gradio,
    inputs=[
        gr.Textbox(label="YouTube or Direct Video URL", placeholder="https://... or uploads/video.mp4"),
        gr.File(label="Or Upload a Video File", file_types=[".mp4", ".mkv", ".webm"]),
        gr.Textbox(label="Or Paste Transcript/Text Directly", lines=10, placeholder="Paste transcript or text here..."),
    ],
    outputs=[
        gr.Textbox(label="Summary", lines=10),
        gr.Textbox(label="Generated Questions", lines=10),
    ],
    title="Lecture Summary & Question Generator",
    description="Provide a YouTube/direct video URL, upload a video file, or paste text. If the video is restricted, upload the file directly."
)
|
|
|
# CORS is wide open so external front-ends can call /upload; wildcard origins
# together with credentials is permissive and worth tightening in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
gr.mount_gradio_app(app, iface, path="/") |
|
|
|
if __name__ == "__main__": |
|
import uvicorn |
|
uvicorn.run(app, host="0.0.0.0", port=7860) |
|
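# To run (assuming this file is saved as app.py):
#   python app.py
# or, equivalently:
#   uvicorn app:app --host 0.0.0.0 --port 7860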
|