import cv2
import gradio as gr
import torch
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

model_id = "BridgeTower/bridgetower-large-itm-mlm"
processor = BridgeTowerProcessor.from_pretrained(model_id)
model = BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id)


# Process a frame
def process_frame(image, texts):
    """Score one image against each comma-separated query in ``texts``.

    Args:
        image: PIL image to score.
        texts: One or more text queries separated by commas. Surrounding
            whitespace is stripped from each query and empty entries are
            ignored.

    Returns:
        dict mapping each query to its score (``outputs.logits[0, 1]``)
        formatted with two decimals, e.g. ``"3.14"``, ordered from highest
        to lowest score.
    """
    queries = [q.strip() for q in texts.split(",") if q.strip()]
    if not queries:
        # Only commas/whitespace were given: score the raw string as-is.
        queries = [texts]

    raw_scores = {}
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for query in queries:
            encoding = processor(image, query, return_tensors="pt")
            outputs = model(**encoding)
            raw_scores[query] = outputs.logits[0, 1].item()

    # Sort on the numeric values; sorting the formatted strings would order
    # them lexicographically (e.g. "10.00" before "9.00" is lost, "-1" wins).
    ranked = sorted(raw_scores.items(), key=lambda item: item[1], reverse=True)
    return {query: "{:.2f}".format(value) for query, value in ranked}


# Process a video
def process(video, text, sample_rate, min_score):
    """Find time ranges of ``video`` whose sampled frames match ``text``.

    One frame is scored every ``sample_rate`` seconds. A clip starts at the
    first sampled frame whose best query score exceeds ``min_score`` and ends
    at the first later sampled frame that does not, or at the end of the
    video if it is still open then.

    Args:
        video: Path to the video file (``gr.Video`` passes a file path).
        text: Query text; comma-separated queries are allowed, and a frame's
            score is the best score among them.
        sample_rate: Seconds between sampled frames.
        min_score: Threshold a frame's score must exceed to count as a match.

    Returns:
        ``(clip_images, clips)``: the matching sampled frames as RGB PIL
        images, and a list of ``(start_score, start_time, end_time)`` tuples
        with times in seconds and ``start_score`` as a two-decimal string.

    Raises:
        gr.Error: if no video was given or its frame rate cannot be read.
    """
    if video is None:
        raise gr.Error("Please provide a video.")

    capture = cv2.VideoCapture(video)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # OpenCV reports 0 when the file could not be opened/decoded.
            raise gr.Error("Could not read the frame rate of this video.")
        frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"{fps:.2f} fps, {frames} frames, {frames / fps:.1f} seconds")

        # Score every `step`-th frame; never below 1 so the modulo is valid.
        step = max(1, round(fps * sample_rate))

        frame_count = 0
        clips = []
        clip_images = []
        clip_started = False
        start_score = start_time = None
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            if frame_count % step == 0:
                # OpenCV decodes frames as BGR; PIL and the model expect RGB.
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                scores = process_frame(image, text)
                # process_frame orders queries best-first.
                best_score = next(iter(scores.values()))
                if float(best_score) > min_score:
                    if not clip_started:
                        clip_started = True
                        start_time = frame_count / fps
                        start_score = best_score
                    clip_images.append(image)
                elif clip_started:
                    clip_started = False
                    clips.append((start_score, start_time, frame_count / fps))
            frame_count += 1
    finally:
        capture.release()

    # A clip still open when the video ends runs to the end of the video.
    if clip_started:
        clips.append((start_score, start_time, frame_count / fps))
    return clip_images, clips


# Inputs
video = gr.Video(label="Video")
text = gr.Text(label="Text query")
sample_rate = gr.Number(value=5, label="Sample rate (1 frame every 'n' seconds)")
min_score = gr.Number(value=3, label="Minimum score")

# Output
gallery = gr.Gallery(label="Images")
clips = gr.Text(label="Clips (score, start time, end time)")

description = "This Space lets you run semantic search on a video."
# Wire the search function to its Gradio components and start the app.
# NOTE(review): newer Gradio releases rename `allow_flagging` to
# `flagging_mode`; kept as-is for the Gradio version this file targets.
iface = gr.Interface(
    fn=process,
    inputs=[video, text, sample_rate, min_score],
    outputs=[gallery, clips],
    description=description,
    # One preset row: (video, text query, sample rate, minimum score).
    examples=[["video.mp4", "wild bears", 5, 3]],
    allow_flagging="never",
)

iface.launch()