import os
import sys
import ast
import cv2
import subprocess
from tqdm import tqdm
from PIL import Image
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from transformers import pipeline
from scenedetect import SceneManager, open_video, ContentDetector
from sentence_transformers import SentenceTransformer, util

# ─── 1. AUTH & MODELS ────────────────────────────────────────────────────────────
# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not HF_TOKEN:
    print("❌ Error: HF_TOKEN not found in .env file")
    sys.exit(1)
if not GROQ_API_KEY:
    print("❌ Error: GROQ_API_KEY not found in .env file")
    sys.exit(1)
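# Expected .env contents (both values below are placeholders):
#   HF_TOKEN=hf_xxxxxxxxxxxx
#   GROQ_API_KEY=gsk_xxxxxxxxxxxx
# langchain_groq picks up GROQ_API_KEY from the environment when ChatGroq is built.
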
# Initialize models with proper configurations
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
    device="cpu"
)
vl_pipeline = pipeline(
    "visual-question-answering",
    model="Salesforce/blip-vqa-base",
    device="cpu"
)
# Note: this text-generation pipeline is initialized but never called by the
# pipeline in run() below; captioning and summarization are handled by BLIP
# and Groq respectively.
elaborator = pipeline(
    "text-generation",
    model="gpt2-medium",
    device="cpu",
    max_new_tokens=500,  # use max_new_tokens instead of max_length
    do_sample=True,
    top_p=0.9,
    temperature=0.7
)
embedder = SentenceTransformer("BAAI/bge-small-en-v1.5")

# ─── 2. HELPERS ──────────────────────────────────────────────────────────────────
def run_ffmpeg(cmd):
    """Run an ffmpeg command quietly, exiting the program on failure."""
    full = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"] + cmd
    p = subprocess.Popen(full, stderr=subprocess.PIPE)
    _, err = p.communicate()
    if p.returncode != 0:
        print("❌ FFmpeg error:\n", err.decode())
        sys.exit(1)

# ─── 3. SCENE DETECTION & KEYFRAMES ──────────────────────────────────────────────
def detect_scenes(video_path, thresh=15.0):
    """Detect scene cuts with PySceneDetect's ContentDetector."""
    v = open_video(video_path)
    mgr = SceneManager()
    mgr.add_detector(ContentDetector(threshold=thresh))
    mgr.detect_scenes(v)
    return mgr.get_scene_list()
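
# Usage sketch (illustrative; "input.mp4" is a placeholder path):
#   scenes = detect_scenes("input.mp4", thresh=15.0)
#   for start, end in scenes:
#       print(f"{start.get_seconds():.2f}s -> {end.get_seconds():.2f}s")
# Each scene is a (start, end) FrameTimecode pair. A lower threshold splits
# more aggressively; 15.0 is more sensitive than PySceneDetect's documented
# ContentDetector default of 27.0.
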
def get_removal_indices_groq(captions, query):
    """Ask the Groq LLM which scene indices (1-based) should be removed."""
    llm = ChatGroq(
        model="llama-3.1-8b-instant",
        temperature=0.2,
        max_tokens=500
    )
    prompt = ChatPromptTemplate.from_messages([
        (
            "system",
            "You are a helpful assistant for video analysis. The user will give you a list of scene captions, "
            "each labeled with an index like [1], [2], ..., and a filtering instruction like 'remove food scenes'.\n\n"
            "Return ONLY the list of indexes that should be removed — e.g., [2, 5, 9]\n"
            "⚠️ Do not explain, describe, or add any commentary. Your response MUST be a valid Python list of integers."
        ),
        (
            "human",
            "Filtering instruction: {query}\n\nCaptions:\n{captions}"
        )
    ])
    chain = prompt | llm
    captions_formatted = "\n".join(f"[{i+1}] {cap.strip()}" for i, cap in enumerate(captions))
    response = None
    try:
        response = chain.invoke({"query": query, "captions": captions_formatted})
        # ast.literal_eval is safer than eval() for parsing model output
        to_remove = ast.literal_eval(response.content.strip())
        if not isinstance(to_remove, list) or not all(isinstance(i, int) for i in to_remove):
            raise ValueError("Invalid format")
    except Exception:
        content = response.content if response is not None else "<no response>"
        print(f"❌ LLM returned invalid output: {content}")
        to_remove = []
    return to_remove

def groq_llm(prompt):
    """Thin wrapper: send a raw prompt string to Groq and return the reply text."""
    llm = ChatGroq(
        model="llama-3.1-8b-instant",
        temperature=0.2,
        max_tokens=500
    )
    return llm.invoke(prompt).content.strip()

def extract_keyframes(video_path, scenes):
    """Grab the middle frame of each scene as its keyframe."""
    cap, frames = cv2.VideoCapture(video_path), []
    for s, e in scenes:
        mid = (s.get_frames() + e.get_frames()) // 2
        cap.set(cv2.CAP_PROP_POS_FRAMES, mid)
        ok, img = cap.read()
        if ok:
            frames.append((mid, img))
    cap.release()
    return frames

# ─── 4. DESCRIPTIONS & SUMMARY ───────────────────────────────────────────────────
def generate_scene_caption(frame):
    """Caption a single BGR frame with BLIP (converted to RGB first)."""
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return captioner(img)[0]["generated_text"]
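
# Output sketch: the image-to-text pipeline returns a list of dicts such as
# [{"generated_text": "a man riding a horse"}] (caption text is illustrative),
# so each keyframe yields one plain caption string.
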
def generate_video_summary_groq(captions):
    """Generate a video summary using Groq LLM."""
    llm = ChatGroq(
        model="llama-3.1-8b-instant",
        temperature=0.2,
        max_tokens=500
    )
    prompt = ChatPromptTemplate.from_messages([
        (
            "system",
            "You are a helpful assistant for video analysis. The user will give you a list of scene captions from a video. "
            "Your task is to write a concise, narrative summary of what happens in the video, focusing only on the events shown. "
            "Make it engaging and easy to understand. Do not include any titles, links, or external references."
        ),
        (
            "human",
            "Here are the scene captions from the video in order:\n{captions}\n\nPlease provide a narrative summary."
        )
    ])
    chain = prompt | llm
    captions_formatted = "\n".join(f"[{i+1}] {cap.strip()}" for i, cap in enumerate(captions))
    try:
        response = chain.invoke({"captions": captions_formatted})
        summary = response.content.strip()
        # Format the final output
        return f"""🎬 Video Summary:
{summary}
📊 Total Scenes: {len(captions)}
🔍 Key Moments:
{chr(10).join(f"• {cap}" for cap in captions[:5])}
..."""
    except Exception as e:
        print(f"❌ Error generating summary with Groq: {e}")
        return "❌ Error: Failed to generate video summary"

def generate_video_summary(captions):
    """
    Generate a video summary using Groq LLM.
    """
    return generate_video_summary_groq(captions)

def filter_scenes_with_llm(captions, query, llm):
    """
    Uses an LLM to determine which scenes to remove based on captions and a user query.

    Args:
        captions (List[str]): List of scene/frame captions.
        query (str): User intent, e.g. "Remove scenes with Trump".
        llm (callable): Function to call your LLM, e.g. `llm(prompt)`.

    Returns:
        List[int]: List of 0-based scene indexes to remove.
    """
    formatted = "\n".join(f"{i+1}. {cap}" for i, cap in enumerate(captions))
    prompt = f"""
You're an intelligent video assistant.
The user wants to: **{query}**
Below are numbered captions for each scene in a video:
{formatted}
👉 Return a Python list of only the scene numbers that should be removed based on the user query.
👉 ONLY return the list like this: [3, 5, 11]. No explanation.
"""
    # Run LLM
    response = llm(prompt)
    try:
        result = ast.literal_eval(response.strip())
        if not isinstance(result, list) or not all(isinstance(i, int) for i in result):
            raise ValueError("Expected a list of integers")
        return [i - 1 for i in result]  # convert to 0-based indexes
    except (ValueError, SyntaxError):
        print("⚠️ Failed to parse LLM output:", response)
        return []
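
# Usage sketch (caption and query values are illustrative):
#   caps = ["a plate of pasta on a table", "a car driving at night"]
#   filter_scenes_with_llm(caps, "remove food scenes", groq_llm)
#   # -> e.g. [0] (0-based), or [] if the model's reply cannot be parsed
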
# ─── 5. FILTERING ────────────────────────────────────────────────────────────────
def group_indices(indices):
    """Group consecutive indices together as chunks."""
    if not indices:
        return []
    indices = sorted(indices)
    groups = [[indices[0]]]
    for i in indices[1:]:
        if i == groups[-1][-1] + 1:
            groups[-1].append(i)
        else:
            groups.append([i])
    return groups
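
# Example: group_indices([2, 3, 4, 9, 12, 13]) -> [[2, 3, 4], [9], [12, 13]]
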
def vqa_matches(keyframes, question):
    """Flag keyframes where BLIP-VQA answers 'yes' to the given question."""
    flags = []
    for _, frame in keyframes:
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        ans = vl_pipeline({"image": img, "question": question})
        flags.append("yes" in ans[0]["answer"].lower())
    return flags

def semantic_matches(captions, prompt, thresh=0.8):
    """Return caption indices whose embedding similarity to the prompt >= thresh."""
    embs = embedder.encode(captions, convert_to_tensor=True)
    q = embedder.encode(prompt, convert_to_tensor=True)
    sims = util.cos_sim(q, embs)[0]
    return [i for i, s in enumerate(sims) if s >= thresh], sims.tolist()
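
# Sketch of the embedding-based path (not wired into run() below; the prompt
# and threshold here are illustrative):
#   idxs, sims = semantic_matches(captions, "food scenes", thresh=0.6)
#   # idxs -> indices of captions with cosine similarity to the prompt >= 0.6
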
# ─── 6. TRIMMING ─────────────────────────────────────────────────────────────────
def remove_scenes(video_path, scenes, to_remove, out="trimmed.mp4"):
    """Re-encode and concatenate every scene not marked for removal."""
    times = [(float(s.get_seconds()), float(e.get_seconds())) for s, e in scenes]
    # Group deletions
    remove_groups = group_indices(to_remove)
    # Threshold: max N consecutive scenes to allow trimming
    MAX_REMOVE_GROUP_SIZE = 4
    # Adjust `to_remove`: only allow small groups or isolated removals
    filtered_remove = []
    for group in remove_groups:
        if len(group) <= MAX_REMOVE_GROUP_SIZE:
            filtered_remove.extend(group)
    # Never trim away the last three scenes of the video
    if len(scenes) > 3:
        last_scene_idx = len(scenes) - 1
        for i in range(last_scene_idx - 2, last_scene_idx + 1):
            if i in filtered_remove:
                filtered_remove.remove(i)
    print(f"🧩 Filtered scenes to remove (after capping long chunks): {filtered_remove}")
    # Final list of segments to keep
    keep = [t for i, t in enumerate(times) if i not in filtered_remove]
    # Create a temporary directory for segments
    os.makedirs("temp_segments", exist_ok=True)
    parts = []
    try:
        for i, (ss, tt) in enumerate(keep):
            fn = os.path.join("temp_segments", f"segment_{i}.mp4")
            # Use proper encoding settings to maintain frame integrity
            run_ffmpeg([
                "-i", video_path,
                "-ss", str(ss),
                "-to", str(tt),
                "-c:v", "libx264",          # Use H.264 codec
                "-preset", "medium",        # Balance between speed and quality
                "-crf", "23",               # Constant Rate Factor for quality
                "-c:a", "aac",              # Audio codec
                "-b:a", "128k",             # Audio bitrate
                "-movflags", "+faststart",  # Enable fast start for web playback
                fn
            ])
            parts.append(fn)
        # Create concat file
        with open("parts.txt", "w") as f:
            for p in parts:
                f.write(f"file '{p}'\n")
        # Concatenate segments with proper encoding
        run_ffmpeg([
            "-f", "concat",
            "-safe", "0",
            "-i", "parts.txt",
            "-c:v", "libx264",
            "-preset", "medium",
            "-crf", "23",
            "-c:a", "aac",
            "-b:a", "128k",
            "-movflags", "+faststart",
            out
        ])
    finally:
        # Cleanup
        for p in parts:
            if os.path.exists(p):
                os.remove(p)
        if os.path.exists("parts.txt"):
            os.remove("parts.txt")
        if os.path.exists("temp_segments"):
            os.rmdir("temp_segments")
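
# Note: re-encoding each segment keeps cut points frame-accurate; a stream copy
# ("-c", "copy") would be much faster but can only cut on keyframes, which tends
# to leave stray frames around scene boundaries.
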
# ─── 7. MAIN PIPELINE ────────────────────────────────────────────────────────────
def run(video, query):
    print(f"\n🎥 Video: {video}\n🔎 Query: '{query}'\n")
    scenes = detect_scenes(video)
    print(f"🔢 {len(scenes)} scenes detected.")
    keyframes = extract_keyframes(video, scenes)
    print(f"🖼️ {len(keyframes)} keyframes extracted.\n")
    captions = [generate_scene_caption(f) for _, f in tqdm(keyframes, desc="Generating captions")]
    summary = generate_video_summary(captions)
    print("\n--- Video Summary ---")
    print(summary)
    # 🧠 Let the LLM decide which scenes to remove based on captions
    to_remove = filter_scenes_with_llm(captions, query, groq_llm)
    print(f"\n🔴 Scenes to remove: {to_remove}")
    if to_remove:
        remove_scenes(video, scenes, to_remove)
        print("✅ Trimmed video saved as `trimmed.mp4`.")
    else:
        print("⚠️ No matching scenes found; no trimming done.")
    return to_remove  # optional: return for external use

# ─── 8. ENTRY POINT ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python main.py <video.mp4> \"your query here\"")
        sys.exit(1)
    run(sys.argv[1], sys.argv[2])