import os
import csv
import json
import torch
import argparse
import pandas as pd
from tqdm import tqdm
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
# Prompt template asking the model to explain the misalignment between a video
# and a (possibly incorrect) caption. <|video|> is the video placeholder token
# consumed by the mPLUG-Owl processor.
PROMPT_FEEDBACK = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: What is the misalignment between this video and the description: "{caption}"?
AI: '''
# Decoding settings for generating the feedback text.
generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512,
}
class VideoCaptionDataset(Dataset):
    """Wraps a single (video path, caption) pair so it can be served by a DataLoader."""

    def __init__(self, videopath, text):
        self.videopath = videopath
        self.text = text

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return {'videopath': self.videopath, 'neg_caption': self.text}
def get_nle(model, processor, tokenizer, dataloader):
    """Generate a natural-language explanation (NLE) of the video-caption misalignment."""
    generated_nle = None
    with torch.no_grad():
        for batch in tqdm(dataloader):
            videopaths = batch['videopath']
            neg_caption = batch['neg_caption'][0]
            prompts = [PROMPT_FEEDBACK.format(caption=neg_caption)]
            # Tokenize the prompt and sample 32 frames from the video.
            inputs = processor(text=prompts, videos=videopaths, num_frames=32, return_tensors='pt')
            # Cast float tensors to bfloat16 to match the model weights, then move to the model's device.
            inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            res = model.generate(**inputs, **generate_kwargs)
            generated_nle = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
    # The dataset holds exactly one item, so this is the single generated explanation.
    return generated_nle
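
# --- Usage sketch (not part of the original file) ---
# A minimal driver showing how the pieces above could fit together. The default
# checkpoint name, the LoRA hyperparameters, and the target-module pattern are
# assumptions for illustration, not confirmed values from this script.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', default='MAGAer13/mplug-owl-llama-7b-video')  # assumed base checkpoint
    parser.add_argument('--videopath', required=True)
    parser.add_argument('--caption', required=True)
    args = parser.parse_args()

    # Build the tokenizer/processor pair and load the video-language model in bfloat16.
    tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
    image_processor = MplugOwlImageProcessor.from_pretrained(args.checkpoint)
    processor = MplugOwlProcessor(image_processor, tokenizer)
    model = MplugOwlForConditionalGeneration.from_pretrained(args.checkpoint, torch_dtype=torch.bfloat16)

    # Wrap the language model's attention projections with inference-only LoRA
    # adapters; rank, alpha, and the module regex here are assumed settings.
    peft_config = LoraConfig(
        r=32,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=r'.*language_model.*\.(q_proj|v_proj)',
        inference_mode=True,
    )
    model = get_peft_model(model, peft_config)
    model.eval()
    if torch.cuda.is_available():
        model = model.to('cuda')

    # One (video, caption) pair in, one explanation out.
    dataset = VideoCaptionDataset(args.videopath, args.caption)
    dataloader = DataLoader(dataset, batch_size=1)
    print(get_nle(model, processor, tokenizer, dataloader))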