# owl-con-demo/nle_inference.py
import torch
from tqdm import tqdm
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor

# Prompt template asking the model to explain the mismatch between a video and
# a candidate caption; <|video|> marks where the video features are inserted.
PROMPT_FEEDBACK = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: What is the misalignment between this video and the description: "{caption}"?
AI: '''
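
# For example (illustrative caption, not from the repo):
#   PROMPT_FEEDBACK.format(caption='A dog chases a ball across the yard.')
# produces the full conversation prompt with the caption spliced into the question.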

# Decoding settings: sample from the top-5 tokens at each step, up to 512 tokens total.
generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512
}


class VideoCaptionDataset(Dataset):
    """Dataset holding a single (video path, candidate caption) pair."""

    def __init__(self, videopath, text):
        self.videopath = videopath
        self.text = text

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return {'videopath': self.videopath, 'neg_caption': self.text}


def get_nle(model, processor, tokenizer, dataloader):
    """Generate a natural-language explanation (NLE) of the video-caption misalignment."""
    with torch.no_grad():
        for batch in tqdm(dataloader):
            videopaths = batch['videopath']
            # batch_size is 1, so take the single caption string out of the collated list.
            neg_caption = batch['neg_caption'][0]
            prompts = [PROMPT_FEEDBACK.format(caption=neg_caption)]
            # Sample 32 frames per video and tokenize the prompt.
            inputs = processor(text=prompts, videos=videopaths, num_frames=32, return_tensors='pt')
            # Cast float tensors to bfloat16 to match the model weights, then move everything to the model's device.
            inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            res = model.generate(**inputs, **generate_kwargs)
            generated_nle = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
    return generated_nle
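

# Example driver: a minimal sketch of how the pieces above fit together, not
# part of the original file. It loads a base mPLUG-Owl video checkpoint,
# attaches a LoRA adapter, and generates an explanation for one video/caption
# pair. The checkpoint name, LoRA hyperparameters, and the video path/caption
# below are illustrative assumptions.
if __name__ == '__main__':
    pretrained_ckpt = 'MAGAer13/mplug-owl-llama-7b-video'  # assumed base checkpoint
    model = MplugOwlForConditionalGeneration.from_pretrained(
        pretrained_ckpt, torch_dtype=torch.bfloat16)
    tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
    image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
    processor = MplugOwlProcessor(image_processor, tokenizer)

    # Wrap the language model with LoRA; the rank/alpha/target modules here are
    # assumptions, and real fine-tuned adapter weights would be loaded on top.
    peft_config = LoraConfig(
        r=32, lora_alpha=32, lora_dropout=0.05,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'])
    model = get_peft_model(model, peft_config)
    model.to('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()

    # Hypothetical inputs: one local video and one caption to check against it.
    dataset = VideoCaptionDataset('example.mp4', 'A dog chases a ball.')
    dataloader = DataLoader(dataset, batch_size=1)
    print(get_nle(model, processor, tokenizer, dataloader))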