import os
import csv
import json
import torch
import argparse
import pandas as pd
from tqdm import tqdm
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor

# Prompt template for natural-language-explanation (NLE) feedback: the model is
# asked to describe the misalignment between a video and a (negative) caption.
PROMPT_FEEDBACK = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: What is the misalignment between this video and the description: "{caption}"?
AI: '''

generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512
}


class VideoCaptionDataset(Dataset):
    """Wraps a single (video path, negative caption) pair so it can be fed through a DataLoader."""

    def __init__(self, videopath, text):
        self.videopath = videopath
        self.text = text

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return {
            'videopath': self.videopath,
            'neg_caption': self.text,
        }


def get_nle(model, processor, tokenizer, dataloader):
    """Generate a natural-language explanation of the video-caption misalignment."""
    generated_nle = None
    with torch.no_grad():
        for batch in tqdm(dataloader):
            videopaths = batch['videopath']
            neg_caption = batch['neg_caption'][0]
            prompts = [PROMPT_FEEDBACK.format(caption=neg_caption)]
            # Sample 32 frames from the video and tokenize the prompt.
            inputs = processor(text=prompts, videos=videopaths, num_frames=32, return_tensors='pt')
            # Cast float tensors to bfloat16 to match the model weights, then move everything to the model's device.
            inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            res = model.generate(**inputs, **generate_kwargs)
            generated_nle = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
    # The dataset holds exactly one video-caption pair, so this is the single decoded explanation.
    return generated_nle
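

# --- Hypothetical usage sketch (not part of the original script) ---
# A minimal example of wiring the pieces above together, which also explains the
# otherwise-unused peft and DataLoader imports. The checkpoint name, the LoRA
# hyperparameters, and the example video/caption are assumptions: they mirror
# the public mPLUG-Owl video release and a typical PEFT inference setup, not
# values confirmed by this file.

def build_pipeline(videopath, neg_caption,
                   pretrained_ckpt='MAGAer13/mplug-owl-llama-7b-video'):
    # Load the base video-language model in bfloat16, plus its processor stack.
    model = MplugOwlForConditionalGeneration.from_pretrained(
        pretrained_ckpt, torch_dtype=torch.bfloat16)
    image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
    tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
    processor = MplugOwlProcessor(image_processor, tokenizer)

    # Wrap the language model's attention projections with LoRA adapters for
    # inference; r / lora_alpha / lora_dropout are assumed values. A fine-tuned
    # LoRA checkpoint would be loaded into `model` at this point.
    peft_config = LoraConfig(
        r=32, lora_alpha=32, lora_dropout=0.05, inference_mode=True,
        target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)')
    model = get_peft_model(model, peft_config)
    model.eval()
    model.to('cuda' if torch.cuda.is_available() else 'cpu')

    # One video-caption pair, batch size 1, matching what get_nle expects.
    dataset = VideoCaptionDataset(videopath, neg_caption)
    dataloader = DataLoader(dataset, batch_size=1)
    return model, processor, tokenizer, dataloader


if __name__ == '__main__':
    model, processor, tokenizer, dataloader = build_pipeline(
        'example.mp4', 'A dog chases a ball across the yard.')
    print(get_nle(model, processor, tokenizer, dataloader))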