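"""Semantic video search with BridgeTower, served as a Gradio app.

The app samples one frame every few seconds from an uploaded video, scores
each sampled frame against a text query with BridgeTower's image-text
matching head, and returns the matching frames plus (score, start time,
end time) tuples for each contiguous clip above a minimum score.
"""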
import cv2
import gradio as gr
import torch
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

model_id = "BridgeTower/bridgetower-large-itm-mlm"
processor = BridgeTowerProcessor.from_pretrained(model_id)
model = BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id)

# Score a single frame against one or more comma-separated text queries
def process_frame(image, texts):
    scores = {}
    texts = [t.strip() for t in texts.split(",")]
    for t in texts:
        encoding = processor(image, t, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoding)
        # The logit at index 1 is the image-text matching score
        scores[t] = "{:.2f}".format(outputs.logits[0, 1].item())
    # Sort scores in descending order, comparing as numbers rather than strings
    scores = dict(sorted(scores.items(), key=lambda item: float(item[1]), reverse=True))
    return scores


# Process a video: sample one frame every `sample_rate` seconds and keep
# the frames whose matching score exceeds `min_score`
def process(video, text, sample_rate, min_score):
    text = text.strip()
    video = cv2.VideoCapture(video)
    fps = round(video.get(cv2.CAP_PROP_FPS))
    frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    length = frames // fps
    print(f"{fps} fps, {frames} frames, {length} seconds")

    frame_count = 0
    clips = []
    clip_images = []
    clip_started = False
    while True:
        ret, frame = video.read()
        if not ret:
            break

        if frame_count % (fps * int(sample_rate)) == 0:
            # OpenCV decodes frames as BGR; convert to RGB before scoring
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            score = process_frame(frame, text)
            # print(f"{frame_count} {score}")

            if float(score[text]) > min_score:
                if clip_started:
                    end_time = frame_count / fps
                else:
                    clip_started = True
                    start_time = frame_count / fps
                    end_time = start_time
                    start_score = score[text]
                    clip_images.append(frame)
            elif clip_started:
                clip_started = False
                end_time = frame_count / fps
                clips.append((start_score, start_time, end_time))
        frame_count += 1

    # Close the clip that is still open when the video ends
    if clip_started:
        clips.append((start_score, start_time, end_time))
    video.release()
    return clip_images, clips


# Inputs
video = gr.Video(label="Video")
text = gr.Text(label="Text query")
sample_rate = gr.Number(value=5, label="Sample rate (1 frame every 'n' seconds)")
min_score = gr.Number(value=3, label="Minimum score")

# Output
gallery = gr.Gallery(label="Images")
clips = gr.Text(label="Clips (score, start time, end time)")

description = "This Space lets you run semantic search on a video with BridgeTower."

iface = gr.Interface(
    description=description,
    fn=process,
    inputs=[video, text, sample_rate, min_score],
    outputs=[gallery, clips],
    examples=[
        [
            "video.mp4",
            "wild bears",
            5,
            3,
        ]
    ],
    allow_flagging="never",
)

iface.launch()