Ahsen Khaliq commited on
Commit
c9ec901
·
1 Parent(s): ed0238e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from PIL import Image
3
+ import clip
4
+ import torch
5
+ import math
6
+ import numpy as np
7
+ import torch
8
+ import datetime
9
+
10
+
11
+ # Load the open CLIP model
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ model, preprocess = clip.load("ViT-B/32", device=device)
14
+
15
+
16
+
17
+ def search_video(search_query, display_heatmap=True, display_results_count=1):
18
+
19
+ # Encode and normalize the search query using CLIP
20
+ with torch.no_grad():
21
+ text_features = model.encode_text(clip.tokenize(search_query).to(device))
22
+ text_features /= text_features.norm(dim=-1, keepdim=True)
23
+
24
+ # Compute the similarity between the search query and each frame using the Cosine similarity
25
+ similarities = (100.0 * video_features @ text_features.T)
26
+ values, best_photo_idx = similarities.topk(display_results_count, dim=0)
27
+
28
+
29
+ for frame_id in best_photo_idx:
30
+ frame = video_frames[frame_id]
31
+ # Find the timestamp in the video and display it
32
+ seconds = round(frame_id.cpu().numpy()[0] * N / fps)
33
+ return frame,f"Found at {str(datetime.timedelta(seconds=seconds))}"
34
+
35
+
36
+ def inference(video, text):
37
+ # The frame images will be stored in video_frames
38
+ video_frames = []
39
+ # Open the video file
40
+ capture = cv2.VideoCapture(video)
41
+ fps = capture.get(cv2.CAP_PROP_FPS)
42
+
43
+ current_frame = 0
44
+ while capture.isOpened():
45
+ # Read the current frame
46
+ ret, frame = capture.read()
47
+
48
+ # Convert it to a PIL image (required for CLIP) and store it
49
+ if ret == True:
50
+ video_frames.append(Image.fromarray(frame[:, :, ::-1]))
51
+ else:
52
+ break
53
+
54
+ # Skip N frames
55
+ current_frame += N
56
+ capture.set(cv2.CAP_PROP_POS_FRAMES, current_frame)
57
+
58
+ # Print some statistics
59
+ print(f"Frames extracted: {len(video_frames)}")
60
+
61
+
62
+ # You can try tuning the batch size for very large videos, but it should usually be OK
63
+ batch_size = 256
64
+ batches = math.ceil(len(video_frames) / batch_size)
65
+
66
+ # The encoded features will bs stored in video_features
67
+ video_features = torch.empty([0, 512], dtype=torch.float16).to(device)
68
+
69
+ # Process each batch
70
+ for i in range(batches):
71
+ print(f"Processing batch {i+1}/{batches}")
72
+
73
+ # Get the relevant frames
74
+ batch_frames = video_frames[i*batch_size : (i+1)*batch_size]
75
+
76
+ # Preprocess the images for the batch
77
+ batch_preprocessed = torch.stack([preprocess(frame) for frame in batch_frames]).to(device)
78
+
79
+ # Encode with CLIP and normalize
80
+ with torch.no_grad():
81
+ batch_features = model.encode_image(batch_preprocessed)
82
+ batch_features /= batch_features.norm(dim=-1, keepdim=True)
83
+
84
+ # Append the batch to the list containing all features
85
+ video_features = torch.cat((video_features, batch_features))
86
+
87
+ # Print some stats
88
+ print(f"Features: {video_features.shape}")
89
+
90
+ return search_video(text)
91
+
92
+
93
+