merve (HF staff) committed
Commit f0e0019
1 Parent(s): cacd489

Upload 5 files

app.py ADDED
@@ -0,0 +1,211 @@
+ from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
+ from typing import List
+ import os
+ import supervision as sv
+ import uuid
+ from tqdm import tqdm
+ import gradio as gr
+ import torch
+ from PIL import Image
+ import spaces
+ import flax.linen as nn
+ import jax
+ import string
+ import functools
+ import jax.numpy as jnp
+ import numpy as np
+ import re
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model_id = "google/paligemma-3b-mix-448"
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
+ processor = PaliGemmaProcessor.from_pretrained(model_id)
+
+ BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
+ MASK_ANNOTATOR = sv.MaskAnnotator()
+ LABEL_ANNOTATOR = sv.LabelAnnotator()
+
+
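+ # Only the first ~2 seconds of the clip (fps * 2 frames, capped at the total frame count)
+ # are processed.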
+ def calculate_end_frame_index(source_video_path):
+     video_info = sv.VideoInfo.from_video_path(source_video_path)
+     return min(
+         video_info.total_frames,
+         video_info.fps * 2
+     )
+
+
+ def annotate_image(
+     input_image,
+     detections,
+     labels
+ ) -> np.ndarray:
+     output_image = MASK_ANNOTATOR.annotate(input_image, detections)
+     output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
+     output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
+     return output_image
+
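+ # Run detection on every retained frame and write the annotated frames to a new mp4.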
+ @spaces.GPU
+ def process_video(
+     input_video,
+     labels,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     video_info = sv.VideoInfo.from_video_path(input_video)
+     total = calculate_end_frame_index(input_video)
+     frame_generator = sv.get_video_frames_generator(
+         source_path=input_video,
+         end=total
+     )
+
+     result_file_name = f"{uuid.uuid4()}.mp4"
+     result_file_path = os.path.join("./", result_file_name)
+     with sv.VideoSink(result_file_path, video_info=video_info) as sink:
+         for _ in tqdm(range(total), desc="Processing video..."):
+             frame = next(frame_generator)
+             # list of dict of {"box": box, "mask": mask, "score": score, "label": label}
+             results, input_list = parse_detection(frame, labels)
+             detections = sv.Detections.from_transformers(results[0])
+             final_labels = []
+
+             for label_id in results[0]["labels"]:
+                 final_labels.append(input_list[label_id])
+             frame = annotate_image(
+                 input_image=frame,
+                 detections=detections,
+                 labels=final_labels,
+             )
+             sink.write_frame(frame)
+     return result_file_path
+
+ @spaces.GPU
+ def infer(
+     image: Image.Image,
+     text: str,
+     max_new_tokens: int
+ ) -> str:
+     inputs = processor(text=text, images=image, return_tensors="pt").to(device)
+     with torch.inference_mode():
+         generated_ids = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=False
+         )
+     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
+     return result[0][len(text):].lstrip("\n")
+
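+ # For a "detect <labels>" prompt, the PaliGemma mix checkpoints typically reply with a string
+ # such as "<loc0123><loc0456><loc0789><loc0999> cat ; <loc0000><loc0100><loc0200><loc0300> dog",
+ # where the four <locYYYY> tokens encode y1, x1, y2, x2 on a 0-1023 grid
+ # (this is what _SEGMENT_DETECT_RE and extract_objs parse below).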
+ def parse_detection(input_image, input_text):
+     prompt = f"detect {input_text}"
+     out = infer(input_image, prompt, max_new_tokens=100)
+     # the frame is a (height, width, channels) array; extract_objs expects (width, height)
+     objs = extract_objs(out.lstrip("\n"), input_image.shape[1], input_image.shape[0], unique_labels=True)
+
+     labels = list(obj.get('name') for obj in objs if obj.get('name'))
+     print("labels", labels)
+     input_list = input_text.split(";")
+     for ind, candidate in enumerate(input_list):
+         input_list[ind] = remove_special_characters(candidate).strip("\n")
+     label_indices = []
+     for label in labels:
+         label = remove_special_characters(label)
+         label_indices.append(input_list.index(label))
+     label_indices = torch.tensor(label_indices).to(device)
+     boxes = torch.tensor([list(obj["xyxy"]) for obj in objs])
+     return [{"boxes": boxes, "scores": torch.tensor([0.99 for _ in range(len(boxes))]).to(device), "labels": label_indices}], input_list
+
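+ # The entries below come from PaliGemma's segmentation decoding path: vae-oid.npz holds the
+ # checkpoint used to turn <seg> tokens into masks. This demo only uses the detection boxes,
+ # so _quantized_values_from_codebook_indices is defined but never called and masks stay None.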
+ _MODEL_PATH = 'vae-oid.npz'
+
+ _SEGMENT_DETECT_RE = re.compile(
+     r'(.*?)' +
+     r'<loc(\d{4})>' * 4 + r'\s*' +
+     '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+     r'\s*([^;<>]+)? ?(?:; )?',
+ )
+
+
+ def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+     batch_size, num_tokens = codebook_indices.shape
+     assert num_tokens == 16, codebook_indices.shape
+     unused_num_embeddings, embedding_dim = embeddings.shape
+
+     encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+     encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+     return encodings
+
+
+ def remove_special_characters(word):
+     return re.sub(r'^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$', '', word)
+
+
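+ # Parse the raw generation into a list of dicts with pixel-space `xyxy` boxes and label names;
+ # <seg> indices are matched by the regex but ignored here (mask is always None).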
+ def extract_objs(text, width, height, unique_labels=False):
+     """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+     objs = []
+     seen = set()
+     while text:
+         m = _SEGMENT_DETECT_RE.match(text)
+         if not m:
+             break
+         gs = list(m.groups())
+         before = gs.pop(0)
+         name = gs.pop()
+         y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+         y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+         seg_indices = gs[4:20]
+         mask = None
+         content = m.group()
+         if before:
+             objs.append(dict(content=before))
+             content = content[len(before):]
+         while unique_labels and name in seen:
+             name = (name or '') + "'"
+         seen.add(name)
+         objs.append(dict(
+             content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+         text = text[len(before) + len(content):]
+
+     if text:
+         objs.append(dict(content=text))
+
+     return objs
+
+
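+ # Gradio UI: upload a short video, enter semicolon-separated labels, get back an annotated clip.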
+ with gr.Blocks() as demo:
+     gr.Markdown("## Zero-shot Object Detection with PaliGemma")
+     gr.Markdown("This is a demo for zero-shot object detection over short video clips using the [PaliGemma](https://huggingface.co/google/paligemma-3b-mix-448) model by Google.")
+     gr.Markdown("Simply upload a video and enter the candidate labels separated by semicolons, or try the example below. 👇")
+     with gr.Tab(label="Video"):
+         with gr.Row():
+             input_video = gr.Video(
+                 label='Input Video'
+             )
+             output_video = gr.Video(
+                 label='Output Video'
+             )
+         with gr.Row():
+             candidate_labels = gr.Textbox(
+                 label='Labels',
+                 placeholder='Labels separated by a semicolon',
+             )
+             submit = gr.Button()
+         gr.Examples(
+             fn=process_video,
+             examples=[["./cats.mp4", "bird ; cat"]],
+             inputs=[
+                 input_video,
+                 candidate_labels,
+             ],
+             outputs=output_video
+         )
+
+     submit.click(
+         fn=process_video,
+         inputs=[input_video, candidate_labels],
+         outputs=output_video
+     )
+
+ demo.launch(debug=False, show_error=True)
c0175c7c-4f1c-4a23-8ad0-3f67fa2f9d3b.mp4 ADDED
Binary file (258 Bytes).
 
cats.mp4 ADDED
Binary file (115 kB).
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ git+https://github.com/huggingface/transformers.git
+ supervision
+ spaces
+ jax
+ pillow
+ flax
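+ # Note: gradio, tqdm and numpy are imported by app.py but not pinned here; on a Hugging Face
+ # Space the Gradio SDK provides gradio, and the others arrive as transitive dependencies.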
vae-oid.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5586010257b8536dddefab65e7755077f21d5672d5674dacf911f73ae95a4447
+ size 8479556