import warnings

import torch
from termcolor import colored

from modeling_tara import TARA, read_frames_decord, read_images_decord

warnings.filterwarnings("ignore")


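# The demo walks through six steps: model loading, video encoding and
# captioning, text encoding, video-text similarity, a negation query, and
# composed video retrieval with an edit instruction.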
def main(model_path: str = "."):
    print(colored("=" * 60, 'yellow'))
    print(colored("TARA Model Demo", 'yellow', attrs=['bold']))
    print(colored("=" * 60, 'yellow'))

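    # device_map='auto' lets the loader place weights on whatever GPU/CPU
    # devices are available; bfloat16 roughly halves the memory footprint.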
print(colored("\n[1/6] Loading model...", 'cyan')) |
|
|
model = TARA.from_pretrained( |
|
|
model_path, |
|
|
device_map='auto', |
|
|
torch_dtype=torch.bfloat16, |
|
|
) |
|
|
|
|
|
n_params = sum(p.numel() for p in model.model.parameters()) |
|
|
print(colored(f"β Model loaded successfully!", 'green')) |
|
|
print(f"Number of parameters: {round(n_params/1e9, 3)}B") |
|
|
print("-" * 100) |
|
|
|
|
|
|
|
|
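    # Sample 16 frames, add a batch dimension, and move the clip to the
    # model's device. If the asset is missing, video_emb stays None so the
    # rest of the demo can still run.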
print(colored("\n[2/6] Testing video encoding and captioning ...", 'cyan')) |
|
|
video_path = "./assets/folding_paper.mp4" |
|
|
try: |
|
|
video_tensor = read_frames_decord(video_path, num_frames=16) |
|
|
video_tensor = video_tensor.unsqueeze(0) |
|
|
video_tensor = video_tensor.to(model.model.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
video_emb = model.encode_vision(video_tensor).cpu().squeeze(0).float() |
|
|
|
|
|
|
|
|
video_caption = model.describe(video_tensor)[0] |
|
|
|
|
|
print(colored("β Video encoded successfully!", 'green')) |
|
|
print(f"Video shape: {video_tensor.shape}") |
|
|
print(f"Video embedding shape: {video_emb.shape}") |
|
|
print(colored(f"Video caption: {video_caption}", 'magenta')) |
|
|
except FileNotFoundError: |
|
|
print(colored(f"β Video file not found: {video_path}", 'red')) |
|
|
print(colored(" Please add a video file or update the path in demo_usage.py", 'yellow')) |
|
|
video_emb = None |
|
|
print("-" * 100) |
|
|
|
|
|
|
|
|
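    # Encode three candidate captions for the folding clip: a match, a
    # different action, and the opposite action.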
print(colored("\n[3/6] Testing text encoding...", 'cyan')) |
|
|
text = ['someone is folding a paper', 'cutting a paper', 'someone is unfolding a paper'] |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
text_emb = model.encode_text(text).cpu().float() |
|
|
|
|
|
print(colored("β Text encoded successfully!", 'green')) |
|
|
print(f"Text: {text}") |
|
|
print(f"Text embedding shape: {text_emb.shape}") |
|
|
|
|
|
|
|
|
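    # Compare the single video embedding against every caption. Assuming each
    # embedding is one D-dimensional vector, the operands broadcast as
    # (1, 1, D) vs (1, 3, D), giving similarities of shape (1, 3).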
    if video_emb is not None:
        print(colored("\n[4/6] Computing video-text similarities...", 'cyan'))
        similarities = torch.cosine_similarity(
            video_emb.unsqueeze(0).unsqueeze(0),
            text_emb.unsqueeze(0),
            dim=-1
        )
        print(colored("✓ Similarities computed!", 'green'))
        for i, txt in enumerate(text):
            print(f"  '{txt}': {similarities[0, i].item():.4f}")
    print("-" * 100)

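    # Negation check: the query excluding the dog should score higher for
    # cat.png, while the "cat and a dog together" query should prefer
    # dog+cat.png.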
print(colored("\n[5/6] Testing negation example...", 'cyan')) |
|
|
image_paths = [ |
|
|
'./assets/cat.png', |
|
|
'./assets/dog+cat.png', |
|
|
] |
|
|
image_tensors = read_images_decord(image_paths) |
|
|
with torch.no_grad(): |
|
|
image_embs = model.encode_vision(image_tensors.to(model.model.device)).cpu().float() |
|
|
image_embs = torch.nn.functional.normalize(image_embs, dim=-1) |
|
|
print(f"Image embedding shape: {image_embs.shape}") |
|
|
|
|
|
texts = ['an image of a cat but there is no dog in it'] |
|
|
with torch.no_grad(): |
|
|
text_embs = model.encode_text(texts).cpu().float() |
|
|
text_embs = torch.nn.functional.normalize(text_embs, dim=-1) |
|
|
print("Text query: ", texts) |
|
|
sim = text_embs @ image_embs.t() |
|
|
print(f"Text-Image similarity: {sim}") |
|
|
print("- " * 50) |
|
|
|
|
|
texts = ['an image of a cat and a dog together'] |
|
|
with torch.no_grad(): |
|
|
text_embs = model.encode_text(texts).cpu().float() |
|
|
text_embs = torch.nn.functional.normalize(text_embs, dim=-1) |
|
|
print("Text query: ", texts) |
|
|
sim = text_embs @ image_embs.t() |
|
|
print(f"Text-Image similarity: {sim}") |
|
|
print("-" * 100) |
|
|
|
|
|
|
|
|
|
|
|
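    # Composed retrieval: encode_vision also takes an edit instruction, which
    # is meant to embed the source clip "as edited" so that it lands close to
    # the target clip's embedding.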
print(colored("\n[6/6] Testing composed video retrieval...", 'cyan')) |
|
|
|
|
|
|
|
|
|
|
|
source_video_path = "./assets/5369546.mp4" |
|
|
target_video_path = "./assets/1006630957.mp4" |
|
|
edit_text ="make the tree lit up" |
|
|
source_video_tensor = read_frames_decord(source_video_path, num_frames=4) |
|
|
target_video_tensor = read_frames_decord(target_video_path, num_frames=16) |
|
|
with torch.no_grad(): |
|
|
source_video_emb = model.encode_vision(source_video_tensor.unsqueeze(0), edit_text).cpu().squeeze(0).float() |
|
|
source_video_emb = torch.nn.functional.normalize(source_video_emb, dim=-1) |
|
|
target_video_emb = model.encode_vision(target_video_tensor.unsqueeze(0)).cpu().squeeze(0).float() |
|
|
target_video_emb = torch.nn.functional.normalize(target_video_emb, dim=-1) |
|
|
sim_with_edit = source_video_emb @ target_video_emb.t() |
|
|
print(f"Source-Target similarity with edit: {sim_with_edit}") |
|
|
|
|
|
|
|
|
print(colored("\n" + "="*60, 'yellow')) |
|
|
print(colored("Demo completed successfully! π", 'green', attrs=['bold'])) |
|
|
print(colored("="*60, 'yellow')) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import argparse |
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument("--model_path", type=str, default=".") |
|
|
args = parser.parse_args() |
|
|
|
|
|
main(args.model_path) |