| import os |
| import torch |
| import numpy as np |
| from typing import BinaryIO, List |
|
|
| from imagebind import imagebind_model |
| from imagebind.models.imagebind_model import ModalityType |
| from imagebind.models.multimodal_preprocessors import SimpleTokenizer, TextPreprocessor |
|
|
|
|
# Download URL for the VideoBind v0.2 checkpoint (an ImageBind variant).
V2_URL = "https://huggingface.co/jondurbin/videobind-v0.2/resolve/main/videobind.pth"
# Local cache path where the v0.2 checkpoint is stored after download.
V2_PATH = "./.checkpoints/videobind-v0.2.pth"
# Byte-pair-encoding vocabulary file shared by both tokenizers below.
BPE_PATH = "./models/bpe_simple_vocab_16e6.txt.gz"
# Tokenizer with the default context length; used to build model inputs.
TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH)
# Long-context tokenizer used only for counting/splitting tokens on long texts.
LENGTH_TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH, context_length=1024)
# Maximum number of content tokens (padding and start/end markers excluded)
# allowed per text chunk fed to the model.
TOKEN_CHUNK_SIZE = 74
|
|
def get_imagebind_v2(path: str=V2_PATH):
    """Load the VideoBind v0.2 model, downloading the checkpoint if missing.

    Args:
        path: Filesystem location of the checkpoint. When the file does not
            exist it is downloaded from ``V2_URL`` first.

    Returns:
        The deserialized model object stored in the checkpoint.
    """
    if not os.path.isfile(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.hub.download_url_to_file(V2_URL, path, progress=True)
    # Renamed from `imagebind_model`: that name shadowed the imported
    # `imagebind_model` module inside this function.
    model = torch.load(path)
    return model
|
|
|
|
def load_and_transform_text(text, device):
    """Tokenize each string in *text* and concatenate the token rows into a
    single tensor placed on *device*. Returns ``None`` when *text* is ``None``.
    """
    if text is None:
        return None
    token_rows = []
    for item in text:
        token_rows.append(TOKENIZER(item).unsqueeze(0).to(device))
    return torch.cat(token_rows, dim=0)
|
|
def split_text_by_token_limit(text, tokenizer, max_tokens=TOKEN_CHUNK_SIZE):
    """Split *text* into segments of at most *max_tokens* content tokens.

    Splitting prefers natural boundaries -- newlines, sentence punctuation,
    commas, then spaces, in that order -- and only falls back to cutting at
    raw token boundaries when a piece contains no usable delimiter.

    Args:
        text: The string to split.
        tokenizer: Callable that maps a string to a token tensor and also
            provides a ``decode`` method. Token id 0 is treated as padding,
            and the first/last non-padding tokens (start/end markers) are
            excluded from the count.
        max_tokens: Maximum number of content tokens per segment.

    Returns:
        A list of text segments, each within the token budget.
    """
    def content_tokens(text_segment):
        # Strip padding (id 0) and the leading/trailing marker tokens.
        tokens = tokenizer(text_segment)
        return tokens[tokens != 0][1:-1].tolist()

    def fits_in_token_limit(text_segment):
        return len(content_tokens(text_segment)) <= max_tokens

    def recursive_split(segment, delimiters):
        if fits_in_token_limit(segment):
            return [segment]
        if not delimiters:
            return split_by_tokens(segment)
        delimiter = delimiters[0]
        parts = segment.split(delimiter)
        result = []
        current_segment = ""
        for part in parts:
            # Greedily pack parts back together (re-inserting the delimiter)
            # while the combined piece still fits the budget.
            candidate_segment = current_segment + (delimiter if current_segment else '') + part
            if fits_in_token_limit(candidate_segment):
                current_segment = candidate_segment
            else:
                if current_segment:
                    result.append(current_segment)
                current_segment = part
        if current_segment:
            result.append(current_segment)
        final_result = []
        for piece in result:
            if fits_in_token_limit(piece):
                final_result.append(piece)
            else:
                # A single part may itself be oversized; retry with the
                # next, finer-grained delimiter.
                final_result.extend(recursive_split(piece, delimiters[1:]))
        return final_result

    def split_by_tokens(segment):
        # Last resort: cut at raw token boundaries, ignoring delimiters.
        tokens = content_tokens(segment)
        # BUG FIX: use ceiling division for the chunk count. The previous
        # `int(len(tokens) / max_tokens) or 1` floored, so e.g. 150 tokens
        # with max_tokens=74 produced 2 chunks of 75 tokens each -- both
        # EXCEEDING the limit this function exists to enforce.
        num_chunks = -(-len(tokens) // max_tokens) or 1
        return [
            tokenizer.decode(segment_tokens)
            for segment_tokens in np.array_split(tokens, num_chunks)
        ]

    return recursive_split(text, ['\n', '.', '!', '?', ',', ' '])
|
|
def load_and_transform_text_chunks(text, device):
    """Tokenize *text* as a list of token-limited chunks.

    The text is first split with ``split_text_by_token_limit`` using the
    long-context ``LENGTH_TOKENIZER``, then each chunk is tokenized for
    model input via ``load_and_transform_text``.

    Args:
        text: Input string; falsy input yields an empty list.
        device: Device the resulting token tensors are moved to.

    Returns:
        A list with one token tensor per chunk (empty for empty input).
    """
    if not text:
        return []
    # Removed a dead full-text tokenization pass: the previous version
    # computed `all_tokens` here and never used the result.
    return [
        load_and_transform_text([segment], device)
        for segment in split_text_by_token_limit(text, LENGTH_TOKENIZER)
    ]
|
|
|
|
class ImageBind:
    """Thin wrapper around an ImageBind model for generating text embeddings.

    Supports both the original pretrained ImageBind-huge weights and the
    VideoBind v0.2 ("v2") checkpoint; the v2 path splits long texts into
    token-limited chunks and averages the per-chunk embeddings.
    """

    def __init__(self, device="cuda:0", v2=False):
        """Load the selected model and move it to *device* in eval mode.

        Args:
            device: Torch device string the model is moved to.
            v2: When True, load the VideoBind v0.2 checkpoint (downloading
                it if necessary); otherwise load pretrained ImageBind-huge.
        """
        self.device = device
        self.v2 = v2
        if v2:
            # Reuse the module-level helper instead of duplicating its
            # download-and-load logic inline (the previous version repeated
            # the same makedirs/download/torch.load sequence here).
            self.imagebind = get_imagebind_v2()
        else:
            self.imagebind = imagebind_model.imagebind_huge(pretrained=True)
        self.imagebind.eval()
        self.imagebind.to(self.device)

    def generate_text_embeddings(self, text: str):
        """Embed *text*, averaging chunk embeddings for the v2 model.

        v1 embeds the whole text in a single forward pass; v2 splits the
        text into token-limited chunks, embeds each chunk, and returns the
        mean over chunks.
        """
        if not self.v2:
            return self.imagebind({
                ModalityType.TEXT: load_and_transform_text([text], self.device)
            })[ModalityType.TEXT]
        # NOTE(review): empty text yields no chunks, and torch.stack([])
        # would raise -- confirm callers never pass empty strings here.
        chunks = load_and_transform_text_chunks(text, self.device)
        embeddings = [
            self.imagebind({ModalityType.TEXT: chunk})[ModalityType.TEXT]
            for chunk in chunks
        ]
        return torch.mean(torch.stack(embeddings), dim=0)

    """ Deactivating full embeddings as they are not used in the current implementation
    def get_inputs(self, video_file: BinaryIO) -> dict:
        audio_file = video_utils.copy_audio(video_file.name)
        try:
            duration = video_utils.get_video_duration(video_file.name)
            video_data = data.load_and_transform_video_data(
                [video_file.name],
                self.device,
            )
            audio_data = data.load_and_transform_audio_data(
                [audio_file.name],
                self.device,
            )
            inputs = {
                ModalityType.VISION: video_data,
                ModalityType.AUDIO: audio_data,
            }
            return inputs
        finally:
            audio_file.close()

    @torch.no_grad()
    def embed(self, descriptions: List[str], video_files: List[BinaryIO]) -> Embeddings:
        return_value = None
        for idx in range(len(descriptions)):
            inputs = self.get_inputs(video_files[idx])
            embeddings = self.imagebind(inputs)
            text_embeddings = self.generate_text_embeddings(descriptions[idx])
            if not return_value:
                return_value = Embeddings(
                    video=embeddings[ModalityType.VISION],
                    audio=embeddings[ModalityType.AUDIO],
                    description=text_embeddings,
                )
            else:
                return_value.video = torch.cat((return_value.video, embeddings[ModalityType.VISION]))
                return_value.audio = torch.cat((return_value.audio, embeddings[ModalityType.AUDIO]))
                return_value.description = torch.cat((return_value.description, text_embeddings))
        return return_value

    @torch.no_grad()
    def embed_only_video(self, video_files: List[BinaryIO]) -> Embeddings:
        video_filepaths = [video_file.name for video_file in video_files]
        durations = [video_utils.get_video_duration(f.name) for f in video_files]
        embeddings = self.imagebind({
            ModalityType.VISION: [
                data.load_and_transform_video_data(
                    [video_filepaths[idx]],
                    self.device,
                )[0]
                for idx in range(len(video_filepaths))
            ]
        })
        return Embeddings(
            video=embeddings[ModalityType.VISION],
        )

    @torch.no_grad()
    def embed_video_and_text(self, video_files: List[BinaryIO], descriptions: List[str]) -> Embeddings:
        video_filepaths = [video_file.name for video_file in video_files]
        durations = [video_utils.get_video_duration(f.name) for f in video_files]
        embeddings = self.imagebind({
            ModalityType.VISION: [
                data.load_and_transform_video_data(
                    [video_filepaths[idx]],
                    self.device,
                )[0]
                for idx in range(len(video_filepaths))
            ],
        })
        description_embeddings = torch.stack([
            self.generate_text_embeddings(description)
            for description in descriptions
        ])
        return Embeddings(
            video=embeddings[ModalityType.VISION],
            description=description_embeddings,
        )

    @torch.no_grad()
    def embed_text(self, texts: List[str]) -> torch.Tensor:
        return_value = None
        for text in texts:
            emb = self.generate_text_embeddings(text)
            if not return_value:
                return_value = emb
            else:
                return_value = torch.cat((return_value, emb))
        return return_value
    """

    @torch.no_grad()
    def embed_text(self, texts: List[str]) -> torch.Tensor:
        """Embed each string in *texts* and stack the per-text embeddings.

        Returns:
            A tensor whose first dimension indexes *texts*, or ``None``
            when *texts* is empty.
        """
        embeddings = [self.generate_text_embeddings(text) for text in texts]
        if not embeddings:
            return None
        return torch.stack(embeddings, dim=0)
|
|