# Source: agnixcode's Hugging Face Space — "Update app.py", commit e34ac27 (verified)
# app.py
import os
import re
import gradio as gr
import numpy as np
import faiss
from youtube_transcript_api import (
YouTubeTranscriptApi,
TranscriptsDisabled,
NoTranscriptFound,
VideoUnavailable,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
# Sentence-embedding model, used both for indexing chunks and encoding queries.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# FAISS index over the transcript chunks; None until a video has been processed.
faiss_index = None
# Parallel text store: chunk_store[i] is the text behind FAISS vector i.
chunk_store: list = []
# Raw transcript of the most recently processed video (shown in the UI).
full_transcript: str = ""
# Hugging Face API token, read from the environment (set as a Space secret).
HF_TOKEN: str = os.environ.get("HF_TOKEN", "")
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
# Hosted-inference client for the chat LLM; token=None falls back to anonymous.
inference_client = InferenceClient(model=LLM_MODEL, token=HF_TOKEN or None)
# ---------------------------------------------------------------------------
# Helper – extract video id
# ---------------------------------------------------------------------------
def _extract_video_id(url: str) -> str:
patterns = [
r"(?:v=)([A-Za-z0-9_-]{11})",
r"(?:youtu\.be/)([A-Za-z0-9_-]{11})",
r"(?:embed/)([A-Za-z0-9_-]{11})",
r"(?:shorts/)([A-Za-z0-9_-]{11})",
]
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
raise ValueError(f"Could not extract a valid video ID from: {url}")
# ---------------------------------------------------------------------------
# 1. Fetch transcript
# Confirmed from source: ALL methods are CLASS methods.
# get_transcript() returns list of dicts: [{"text": str, "start": float, "duration": float}]
# Access text with snippet["text"] not snippet.text
# ---------------------------------------------------------------------------
def get_transcript(url: str) -> str:
    """Fetch a video's transcript text as a single space-joined string.

    Strategy: first try to fetch an English transcript directly; if none is
    found (or transcripts are disabled in that path), fall back to listing
    all available transcripts, preferring any English variant, otherwise
    taking the first listed one.

    NOTE(review): this code assumes the pre-1.0 youtube-transcript-api API,
    where ``get_transcript``/``list_transcripts`` are classmethods and
    snippets are dicts accessed via ``s["text"]`` (per the comments above).
    In youtube-transcript-api >= 1.0 these became instance methods and
    ``fetch()`` returns snippet objects — confirm the pinned version.

    Args:
        url: YouTube URL (passed through to ``_extract_video_id``).

    Returns:
        The full transcript with snippet texts joined by single spaces.

    Raises:
        ValueError: If the video ID cannot be extracted, the video is
            unavailable/private, transcripts are disabled, no transcripts
            exist, or retrieval fails for any other reason.
    """
    video_id = _extract_video_id(url)
    # Primary: try English directly
    try:
        snippets = YouTubeTranscriptApi.get_transcript(
            video_id, languages=["en", "en-US", "en-GB"]
        )
        return " ".join(s["text"] for s in snippets)
    except (NoTranscriptFound, TranscriptsDisabled):
        pass
    except VideoUnavailable:
        raise ValueError("This video is unavailable or private.")
    except Exception:
        # Deliberate best-effort: any other failure here falls through to
        # the list_transcripts fallback below rather than aborting.
        pass
    # Fallback: list all, pick first available, fetch it
    try:
        transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
        transcript = None
        # prefer any english variant
        for t in transcript_list:
            if t.language_code.startswith("en"):
                transcript = t
                break
        # if no english, take the first one
        if transcript is None:
            for t in transcript_list:
                transcript = t
                break
        if transcript is None:
            raise ValueError("No transcripts are available for this video.")
        # fetch() returns list of dicts [{"text":..., "start":..., "duration":...}]
        snippets = transcript.fetch()
        return " ".join(s["text"] for s in snippets)
    except ValueError:
        # Re-raise our own user-facing errors unchanged.
        raise
    except TranscriptsDisabled:
        raise ValueError("Transcripts are disabled for this video.")
    except Exception as exc:
        # Wrap anything else (network errors, parsing, API changes) so the
        # UI layer only ever has to handle ValueError.
        raise ValueError(f"Could not retrieve transcript: {exc}")
# ---------------------------------------------------------------------------
# 2. Process video
# ---------------------------------------------------------------------------
def process_video(url: str):
    """Fetch, chunk, embed, and index a video's transcript for RAG.

    Resets all global RAG state, then builds a fresh FAISS L2 index over
    500-character transcript chunks (50-char overlap).

    Args:
        url: YouTube URL entered by the user.

    Returns:
        A ``(status_message, transcript_text)`` pair for the two UI textboxes;
        on failure the status carries an error marker and the transcript slot
        may be empty.
    """
    global faiss_index, chunk_store, full_transcript

    # Drop any previously indexed video before starting over.
    faiss_index = None
    chunk_store = []
    full_transcript = ""

    if not url.strip():
        return "⚠️ Please enter a YouTube URL.", ""

    try:
        transcript = get_transcript(url)
    except ValueError as exc:
        return f"❌ {exc}", ""
    except Exception as exc:
        return f"❌ Unexpected error: {exc}", ""

    if not transcript.strip():
        return "❌ Transcript is empty for this video.", ""

    full_transcript = transcript

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
    )
    pieces = text_splitter.split_text(transcript)
    if not pieces:
        return "❌ Could not split transcript into chunks.", transcript
    chunk_store = pieces

    # Encode all chunks in one batch; FAISS requires float32 input.
    vectors = np.array(
        embedding_model.encode(pieces, show_progress_bar=False),
        dtype="float32",
    )
    dim = vectors.shape[1]
    new_index = faiss.IndexFlatL2(dim)
    new_index.add(vectors)
    faiss_index = new_index

    status = (
        f"✅ Video processed successfully!\n"
        f" • Chunks created : {len(pieces)}\n"
        f" • Embedding dim : {dim}\n"
        f" • FAISS vectors : {new_index.ntotal}\n\n"
        f"Switch to the 💬 Chat with Video tab to ask questions."
    )
    return status, transcript
# ---------------------------------------------------------------------------
# 3. Retrieve top-k chunks
# ---------------------------------------------------------------------------
def retrieve_context(query: str, top_k: int = 3) -> str:
    """Return the ``top_k`` most similar transcript chunks for ``query``.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to retrieve (capped by store size).

    Returns:
        The retrieved chunks joined by blank lines, or ``""`` when no video
        has been indexed yet.
    """
    if faiss_index is None or not chunk_store:
        return ""

    # Encode the query the same way the chunks were encoded (float32 for FAISS).
    encoded = np.array(
        embedding_model.encode([query], show_progress_bar=False),
        dtype="float32",
    )

    n_results = min(top_k, len(chunk_store))
    _, hit_ids = faiss_index.search(encoded, n_results)

    # Guard against FAISS returning -1 (or any out-of-range id) for padding.
    hits = []
    for idx in hit_ids[0]:
        if 0 <= idx < len(chunk_store):
            hits.append(chunk_store[idx])
    return "\n\n".join(hits)
# ---------------------------------------------------------------------------
# 4. Generate answer
# ---------------------------------------------------------------------------
def generate_answer(query: str) -> str:
    """Answer ``query`` via RAG: retrieve transcript context, ask the LLM.

    Args:
        query: The user's question about the processed video.

    Returns:
        The model's answer, or a user-facing warning/error string when no
        video is indexed, no context is found, or inference fails.
    """
    if faiss_index is None:
        return (
            "⚠️ No video processed yet. "
            "Go to 📥 Process Video tab first."
        )

    context = retrieve_context(query, top_k=3)
    if not context:
        return "⚠️ Could not retrieve relevant context for your question."

    # Ground the model strictly in the retrieved transcript chunks.
    system_prompt = (
        "You are a helpful assistant that answers questions strictly "
        "based on the provided video transcript context. "
        "If the answer is not in the context, say: "
        "'I could not find this information in the video transcript.' "
        "Do NOT hallucinate or make up information."
    )
    user_prompt = (
        f"Context from the video transcript:\n"
        f"---\n{context}\n---\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    try:
        reply = inference_client.chat_completion(
            messages=conversation,
            max_tokens=512,
            temperature=0.2,
            top_p=0.9,
        )
        return reply.choices[0].message.content.strip()
    except Exception as exc:
        # Most common cause on Spaces: missing/invalid HF_TOKEN secret.
        return (
            f"❌ Inference failed: {exc}\n"
            "Check that HF_TOKEN is set correctly as a Space secret."
        )
# ---------------------------------------------------------------------------
# 5. Chat helper
# Gradio 6.x Chatbot uses list of [user, bot] pairs (list of lists)
# ---------------------------------------------------------------------------
def chat(user_message: str, history: list):
    """Handle one chat turn for the Gradio Chatbot (pairs format).

    Args:
        user_message: Text from the question textbox.
        history: Existing conversation as a list of ``[user, bot]`` pairs.

    Returns:
        ``(new_history, "")`` — the extended history plus an empty string
        that clears the input textbox. ``history`` itself is not mutated.
    """
    # Blank input: append a warning turn without calling the LLM.
    if not user_message.strip():
        return history + [["", "⚠️ Please enter a question."]], ""

    bot_reply = generate_answer(user_message)
    return history + [[user_message, bot_reply]], ""
# ---------------------------------------------------------------------------
# 6. Gradio UI — fully compatible with Gradio 6.13
# ---------------------------------------------------------------------------
# Build the two-tab UI: (1) process/index a video, (2) chat over it.
with gr.Blocks(title="YouTube RAG Chatbot") as app:
    gr.Markdown(
        """
        # 🎬 YouTube RAG Chatbot
        **Fetch any YouTube transcript and chat with it using RAG + Mistral-7B.**
        > 🔑 Add your `HF_TOKEN` in Space **Settings → Secrets** for the LLM to work.
        """
    )
    with gr.Tabs():
        # ── Tab 1: Process ─────────────────────────────────────────────────
        with gr.TabItem("📥 Process Video"):
            gr.Markdown("Enter a YouTube URL and click **Process** to index the transcript.")
            with gr.Row():
                url_input = gr.Textbox(
                    label="YouTube URL",
                    placeholder="https://www.youtube.com/watch?v=...",
                    scale=5,
                )
                process_btn = gr.Button("⚙️ Process", variant="primary", scale=1)
            status_output = gr.Textbox(
                label="Status",
                lines=6,
                interactive=False,
            )
            transcript_output = gr.Textbox(
                label="Transcript",
                lines=15,
                interactive=False,
            )
            # process_video returns (status, transcript) matching these outputs.
            process_btn.click(
                fn=process_video,
                inputs=[url_input],
                outputs=[status_output, transcript_output],
            )
        # ── Tab 2: Chat ────────────────────────────────────────────────────
        with gr.TabItem("💬 Chat with Video"):
            gr.Markdown("Ask questions about the video. Answers are grounded in the transcript.")
            # NOTE(review): assumes this Gradio version's Chatbot accepts the
            # list-of-[user, bot]-pairs format that chat() produces — confirm
            # against the pinned Gradio release (newer ones default to
            # "messages" dicts).
            chatbot = gr.Chatbot(label="Conversation", height=450)
            with gr.Row():
                query_input = gr.Textbox(
                    label="Your question",
                    placeholder="What is the main topic of this video?",
                    scale=5,
                )
                send_btn = gr.Button("Send 🚀", variant="primary", scale=1)
            clear_btn = gr.Button("🗑️ Clear", variant="secondary")
            # gr.State stores the history list between interactions
            chat_history = gr.State([])
            # After chat() updates the chatbot, copy the displayed history
            # back into gr.State so the next turn sees it.
            send_btn.click(
                fn=chat,
                inputs=[query_input, chat_history],
                outputs=[chatbot, query_input],
            ).then(
                fn=lambda h: h,
                inputs=[chatbot],
                outputs=[chat_history],
            )
            # Pressing Enter in the textbox mirrors the Send button.
            query_input.submit(
                fn=chat,
                inputs=[query_input, chat_history],
                outputs=[chatbot, query_input],
            ).then(
                fn=lambda h: h,
                inputs=[chatbot],
                outputs=[chat_history],
            )
            # Clear both the visible conversation and the stored history.
            clear_btn.click(
                fn=lambda: ([], []),
                outputs=[chatbot, chat_history],
            )
# ---------------------------------------------------------------------------
# Launch — theme passed here in Gradio 6.x
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hugging Face Spaces executes app.py directly; launch() starts the server.
    app.launch()