# Spaces: Runtime error (page-status banner captured along with the file; not part of the program)
import argparse
import csv
import json
import os

import pandas as pd
import torch
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers.models.llama.tokenization_llama import LlamaTokenizer

from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
# Prompt template for the misalignment question. `{caption}` is the only
# placeholder and is filled in with `str.format(caption=...)`; the template
# deliberately ends with "AI: " so the model continues from the assistant turn.
PROMPT_FEEDBACK = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: What is the misalignment between this video and the description: "{caption}"?
AI: '''
# Keyword arguments unpacked into the `model.generate(...)` call made by this
# script; kept at module level so the decoding settings live in one place.
generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512,
}
class VideoCaptionDataset(Dataset):
    """Dataset holding exactly one (video path, caption) pair.

    The single item is a dict with the keys ``'videopath'`` and
    ``'neg_caption'``, carrying the two constructor arguments unchanged.

    Args:
        videopath: Path of the video; stored and returned as given.
        text: Caption/description to pair with the video; stored and
            returned as given.
    """

    def __init__(self, videopath, text):
        super().__init__()
        self.videopath = videopath
        self.text = text

    def __len__(self):
        """Return the number of items, which is always 1."""
        return 1

    def __getitem__(self, index):
        """Return the single item.

        Args:
            index: Position of the requested item; only ``0`` (or ``-1``)
                is valid because the dataset has one element.

        Returns:
            A dict ``{'videopath': ..., 'neg_caption': ...}``.

        Raises:
            IndexError: If ``index`` is outside the one-element range.
        """
        # Without this guard, plain iteration (`for item in dataset`), which
        # relies on IndexError to stop, would never terminate.
        if index not in (0, -1):
            raise IndexError(
                f"VideoCaptionDataset index out of range: {index!r}"
            )
        return {
            'videopath': self.videopath,
            'neg_caption': self.text,
        }
def get_nle(model, processor, tokenizer, dataloader):
    """Generate the model's textual answer to the misalignment prompt.

    For a batch from ``dataloader`` the function formats ``PROMPT_FEEDBACK``
    with the batch's first ``'neg_caption'``, runs ``processor`` on the prompt
    and the batch's ``'videopath'`` entries (32 frames, PyTorch tensors),
    calls ``model.generate`` with ``generate_kwargs`` and decodes the first
    generated sequence with ``tokenizer``.

    NOTE(review): the original indentation of the ``return`` statement was
    not recoverable, so it is unclear whether the result of the first or the
    last batch was intended. This version returns after the first batch;
    both readings agree for a one-batch dataloader. Confirm against callers
    if multi-batch dataloaders are ever passed.

    Args:
        model: Object exposing ``device`` and ``generate(**inputs, **kwargs)``.
        processor: Callable accepting ``text``, ``videos``, ``num_frames`` and
            ``return_tensors`` and returning a mapping of tensors.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        dataloader: Iterable of batches with ``'videopath'`` and
            ``'neg_caption'`` entries.

    Returns:
        The decoded text for the first batch, or ``None`` when ``dataloader``
        yields no batches.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Iterate the dataloader directly so tqdm can see its length; the
        # batch index was never used.
        for batch in tqdm(dataloader):
            videopaths = batch['videopath']
            neg_caption = batch['neg_caption'][0]
            prompts = [PROMPT_FEEDBACK.format(caption=neg_caption)]
            inputs = processor(
                text=prompts,
                videos=videopaths,
                num_frames=32,
                return_tensors='pt',
            )
            # Cast float32 tensors to bfloat16 and move everything to the
            # model's device in a single pass; other dtypes are left as-is.
            inputs = {
                key: (
                    value.bfloat16() if value.dtype == torch.float else value
                ).to(model.device)
                for key, value in inputs.items()
            }
            res = model.generate(**inputs, **generate_kwargs)
            return tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
    # Empty dataloader: nothing was generated.
    return None