from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from typing import List
import os
import supervision as sv
import uuid
from tqdm import tqdm
import gradio as gr
import torch
from PIL import Image
import spaces
import flax.linen as nn
import jax
import string
import functools
import jax.numpy as jnp
import numpy as np
import re

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "google/paligemma-3b-mix-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id)

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()


def calculate_end_frame_index(source_video_path):
    # Limit processing to the first two seconds of the video.
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    return min(
        video_info.total_frames,
        video_info.fps * 2
    )


def annotate_image(
    input_image,
    detections,
    labels
) -> np.ndarray:
    output_image = MASK_ANNOTATOR.annotate(input_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image


@spaces.GPU
def process_video(
    input_video,
    labels,
    progress=gr.Progress(track_tqdm=True)
):
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )

    result_file_name = f"{uuid.uuid4()}.mp4"
    result_file_path = os.path.join("./", result_file_name)

    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total), desc="Processing video..."):
            frame = next(frame_generator)
            # results is a single-element list of {"boxes": ..., "scores": ..., "labels": ...}
            results, input_list = parse_detection(frame, labels)
            detections = sv.Detections.from_transformers(results[0])

            final_labels = []
            for label_id in results[0]["labels"]:
                final_labels.append(input_list[label_id])
            frame = annotate_image(
                input_image=frame,
                detections=detections,
                labels=final_labels,
            )
            sink.write_frame(frame)
    return result_file_path


@spaces.GPU
def infer(
    image: Image.Image,
    text: str,
    max_new_tokens: int
) -> str:
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Strip the echoed prompt from the decoded output.
    return result[0][len(text):].lstrip("\n")


def parse_detection(input_image, input_text):
    prompt = f"detect {input_text}"
    out = infer(input_image, prompt, max_new_tokens=100)
    # numpy frames are (height, width, channels); extract_objs expects (width, height).
    objs = extract_objs(out.lstrip("\n"), input_image.shape[1], input_image.shape[0], unique_labels=True)
    labels = list(obj.get('name') for obj in objs if obj.get('name'))
    print("labels", labels)

    # Candidate labels are entered as a ";"-separated string.
    input_list = input_text.split(";")
    for ind, candidate in enumerate(input_list):
        input_list[ind] = remove_special_characters(candidate).lstrip("\n").rstrip("\n")

    label_indices = []
    for label in labels:
        label = remove_special_characters(label)
        label_indices.append(input_list.index(label))
    label_indices = torch.tensor(label_indices).to(device)
    boxes = torch.tensor([list(obj["xyxy"]) for obj in objs])
    return [{"boxes": boxes,
             "scores": torch.tensor([0.99 for _ in range(len(boxes))]).to(device),
             "labels": label_indices}], input_list


# Checkpoint for the VQ-VAE mask decoder (only needed when decoding <seg> tokens).
_MODEL_PATH = 'vae-oid.npz'

# Matches PaliGemma detection/segmentation output: four <loc####> box tokens,
# an optional run of sixteen <seg###> mask tokens, then the object name.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)


def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
    batch_size, num_tokens = codebook_indices.shape
    assert num_tokens == 16, codebook_indices.shape
    unused_num_embeddings, embedding_dim = embeddings.shape

    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
    return encodings


def remove_special_characters(word):
    # Strip leading and trailing non-alphanumeric characters.
    return re.sub(r'^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$', '', word)


def extract_objs(text, width, height, unique_labels=False):
    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # <loc> values are normalized to a 1024-bin grid; rescale to pixel coordinates.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
        seg_indices = gs[4:20]
        mask = None

        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content,
            xyxy=(x1, y1, x2, y2),
            mask=mask,
            name=name))
        text = text[len(before) + len(content):]

    if text:
        objs.append(dict(content=text))
    return objs


with gr.Blocks() as demo:
    gr.Markdown("## Zero-shot Object Tracking with PaliGemma")
    gr.Markdown("This is a demo for zero-shot object tracking using the [PaliGemma](https://huggingface.co/google/paligemma-3b-mix-448) vision-language model by Google.")
    gr.Markdown("Simply upload a video and enter the candidate labels, or try the example below. Labels should be separated by `;`. 👇")
    with gr.Tab(label="Video"):
        with gr.Row():
            input_video = gr.Video(
                label='Input Video'
            )
            output_video = gr.Video(
                label='Output Video'
            )
        with gr.Row():
            candidate_labels = gr.Textbox(
                label='Labels',
                placeholder='Labels separated by a semicolon',
            )
            submit = gr.Button()
        gr.Examples(
            fn=process_video,
            examples=[["./cats.mp4", "bird ; cat"]],
            inputs=[
                input_video,
                candidate_labels,
            ],
            outputs=output_video
        )

    submit.click(
        fn=process_video,
        inputs=[input_video, candidate_labels],
        outputs=output_video
    )

demo.launch(debug=False, show_error=True)