|
import cv2
import gradio as gr
import torch

from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor
|
|
|
# BridgeTower large checkpoint pre-trained for image-text matching (ITM) and
# masked language modeling (MLM); the ITM head is what scores image/text pairs.
model_id = "BridgeTower/bridgetower-large-itm-mlm"

# Processor bundles the image transform and the text tokenizer for this model.
processor = BridgeTowerProcessor.from_pretrained(model_id)

# Retrieval head model: forward() returns logits whose [0, 1] entry is the
# image-text match score (see process_frame below).
model = BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id)
|
|
|
|
|
def process_frame(image, texts):
    """Score *image* against each comma-separated query in *texts*.

    Args:
        image: A PIL image (the processor expects RGB channel order).
        texts: Comma-separated text queries, e.g. "wild bears,forest".

    Returns:
        Dict mapping each query string to its image-text matching score
        formatted with two decimals, ordered best score first.
    """
    scores = {}

    for query in texts.split(","):
        encoding = processor(image, query, return_tensors="pt")
        # Inference only — disable autograd tracking to save memory and time.
        with torch.no_grad():
            outputs = model(**encoding)
        # logits[0, 1] is the "image and text match" score of the ITM head.
        scores[query] = "{:.2f}".format(outputs.logits[0, 1].item())

    # Bug fix: the values are formatted *strings*, so the original sort was
    # lexicographic ("9.00" ranked above "10.00"). Sort on the numeric value.
    return dict(
        sorted(scores.items(), key=lambda item: float(item[1]), reverse=True)
    )
|
|
|
|
|
|
|
def process(video, text, sample_rate, min_score):
    """Scan *video* for segments that match *text*.

    Samples one frame every *sample_rate* seconds, scores it with
    ``process_frame``, and groups consecutive above-threshold samples
    into clips.

    Args:
        video: Path to the video file.
        text: Text query (comma-separated queries are scored individually;
            the first one drives the clip threshold).
        sample_rate: Seconds between sampled frames.
        min_score: Minimum matching score for a frame to belong to a clip.

    Returns:
        Tuple ``(clip_images, clips)`` where ``clip_images`` is the list of
        matching PIL frames and ``clips`` is a list of
        ``(start_score, start_time, end_time)`` tuples in seconds.
    """
    capture = cv2.VideoCapture(video)
    try:
        fps = round(capture.get(cv2.CAP_PROP_FPS))
        frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps <= 0:
            # Unreadable/corrupt metadata — nothing sensible to sample.
            return [], []
        length = frames // fps
        print(f"{fps} fps, {frames} frames, {length} seconds")

        # process_frame keys its result by each comma-separated query, so a
        # multi-query string would make score[text] raise KeyError. Use the
        # first query for thresholding (identical behavior for a single query).
        query = text.split(",")[0]
        # Frames between samples; never 0, to avoid a modulo-by-zero below.
        step = max(int(fps * sample_rate), 1)

        frame_count = 0
        clips = []
        clip_images = []
        clip_started = False
        start_time = end_time = 0.0
        start_score = None

        while True:
            ret, frame = capture.read()
            if not ret:
                break

            if frame_count % step == 0:
                # Bug fix: OpenCV decodes to BGR, but the model (and the
                # gallery output) expect RGB — convert before wrapping in PIL.
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                score = process_frame(image, text)

                if float(score[query]) > min_score:
                    if clip_started:
                        end_time = frame_count / fps
                    else:
                        clip_started = True
                        start_time = frame_count / fps
                        end_time = start_time
                        start_score = score[query]
                    clip_images.append(image)
                elif clip_started:
                    # The matching streak just ended — close the clip.
                    clip_started = False
                    end_time = frame_count / fps
                    clips.append((start_score, start_time, end_time))

            frame_count += 1

        # Bug fix: a clip still open when the video ends was silently dropped.
        if clip_started:
            clips.append((start_score, start_time, end_time))
    finally:
        # Bug fix: release the decoder handle instead of leaking it.
        capture.release()

    return clip_images, clips
|
|
|
|
|
|
|
# --- Gradio UI wiring -------------------------------------------------------

description = "This Space lets you run semantic search on a video."

# Inputs/outputs are built inline: video + query + sampling controls in,
# matched frames + clip boundaries out.
iface = gr.Interface(
    description=description,
    fn=process,
    inputs=[
        gr.Video(label="Video"),
        gr.Text(label="Text query"),
        gr.Number(value=5, label="Sample rate (1 frame every 'n' seconds)"),
        gr.Number(value=3, label="Minimum score"),
    ],
    outputs=[
        gr.Gallery(label="Images"),
        gr.Text(label="Clips (score, start time, end time)"),
    ],
    examples=[["video.mp4", "wild bears", 5, 3]],
    allow_flagging="never",
)

iface.launch()
|
|