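# owl-con-demo / app.py
# Author: Hritik · commit 8b85fc6 ("update app") · 3.24 kB
# Gradio demo for Owl-Con: scores whether a video entails a caption and, for
# low scores, produces a natural-language explanation (NLE).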
import os
import csv
import json
import torch
import argparse
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from torch.utils.data import DataLoader
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
from peft import LoraConfig, get_peft_model
from data_utils.xgpt3_dataset import MultiModalDataset
from utils import batchify
import gradio as gr
from entailment_inference import get_scores
from nle_inference import VideoCaptionDataset, get_nle
# Base mPLUG-Owl video checkpoint directory, and the fine-tuned Owl-Con state
# dict that is loaded on top of the LoRA-wrapped base model below.
pretrained_ckpt = "mplugowl7bvideo/"
trained_ckpt = "owl-con/checkpoint-5178/pytorch_model.bin"

# Text tokenizer and video/image preprocessor, bundled into one processor.
tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)

# Instantiate the base model on CPU in bfloat16; it is moved to the GPU only
# after the fine-tuned weights have been restored.
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
    device_map={'': 'cpu'},
)

# Wrap the language model's attention/MLP projections with LoRA adapters.
# load_state_dict below is strict by default, so this wrapping must produce
# exactly the parameter names stored in trained_ckpt.
peft_config = LoraConfig(
    target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
    inference_mode=True,
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# NOTE: torch.load unpickles the file -- only point trained_ckpt at a trusted
# checkpoint.
with open(trained_ckpt, 'rb') as f:
    ckpt = torch.load(f, map_location=torch.device("cpu"))
model.load_state_dict(ckpt)
# load_state_dict copies the tensors into the model's own parameters, so the
# loaded dict is now a redundant full copy of the weights. As a module global
# it would otherwise stay resident in host memory for the whole demo process.
del ckpt

# Move to the GPU and re-apply bfloat16. Presumably the second cast
# normalises parameters added by get_peft_model -- TODO confirm for the PEFT
# version in use.
model = model.to("cuda:0").to(torch.bfloat16)
# Serving only: explicitly disable every dropout layer, including the LoRA
# dropout modules that get_peft_model inserts.
model.eval()
# Entailment query sent to the model. "{caption}" is filled with the user's
# caption. "<|video|>" is presumably the placeholder the processor/dataset
# expands into video tokens -- confirm in MultiModalDataset.
_ENTAILMENT_PROMPT = (
    "The following is a conversation between a curious human and AI assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    "Human: <|video|>\n"
    'Human: Does this video entail the description: "{caption}"?\n'
    "AI: "
)


def inference(videopath, text):
    """Score whether a video entails a caption; explain low-scoring pairs.

    Args:
        videopath: Path of the uploaded video, as supplied by the ``gr.Video``
            input component (``None`` when nothing was uploaded).
        text: Caption to check against the video.

    Returns:
        ``(score, nle)``: the entailment score from ``get_scores`` and, when
        that score is below 0.5, the explanation produced by ``get_nle``;
        otherwise a placeholder string stating that no NLE was generated.

    Raises:
        gr.Error: If no video was uploaded or the caption is empty. Gradio
            then shows the message to the user instead of a generic error
            from deep inside the data pipeline.
    """
    if not videopath:
        raise gr.Error("Please upload a video.")
    if not text or not text.strip():
        raise gr.Error("Please enter a caption.")

    # Pure inference: no autograd graph is needed. get_scores/get_nle may
    # already disable gradients themselves; nesting no_grad is harmless.
    with torch.no_grad():
        valid_data = MultiModalDataset(
            videopath,
            _ENTAILMENT_PROMPT.format(caption=text),
            tokenizer,
            processor,
            max_length=256,
            loss_objective='sequential',
        )
        dataloader = DataLoader(valid_data, pin_memory=True, collate_fn=batchify)
        score = get_scores(model, tokenizer, dataloader)

        if score < 0.5:
            # Only spend a generation pass on an explanation when the
            # entailment score is low.
            nle_loader = DataLoader(VideoCaptionDataset(videopath, text))
            nle = get_nle(model, processor, tokenizer, nle_loader)
        else:
            nle = "None (NLE is only triggered when entailment score < 0.5)"
    return score, nle
# Gradio UI: a video plus a caption in, the entailment score and an optional
# natural-language explanation out. Example rows pair the same clip with a
# matching and a mismatching caption.
demo = gr.Interface(
    fn=inference,
    title="Owl-Con Demo",
    description="Owl-Con Demo (Code: https://github.com/Hritikbansal/videocon | Paper: https://arxiv.org/abs/2311.10111)",
    inputs=[
        gr.Video(label='input_video'),
        gr.Textbox(label='input_caption'),
    ],
    outputs=[
        gr.Number(label='Entailment Score'),
        gr.Textbox(label='Natural Language Explanation'),
    ],
    examples=[
        ["examples/820.mp4", "We see the group making cookies."],
        ["examples/820.mp4", "We see the group eating cookies."],
        ["examples/244.mp4", "She throws a bowling ball while talking on the phone."],
        ["examples/244.mp4", "She throws a baseball while talking on the phone."],
    ],
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()