File size: 3,825 Bytes
3a496ae
 
 
1cb9066
3a496ae
 
 
 
 
 
 
 
 
 
 
 
de447f9
3a496ae
 
 
0ba1d16
b3402e9
42a01c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de447f9
 
 
6ab097e
3a496ae
 
 
 
 
 
 
 
cfe5653
3a496ae
 
 
 
 
 
 
 
 
 
 
42a01c3
3a496ae
cfe5653
42a01c3
 
3a496ae
cfe5653
b3402e9
0ba1d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b85fc6
 
0ba1d16
8b85fc6
a606624
890b2c5
0ba1d16
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import csv
import json
import torch
import argparse
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from torch.utils.data import DataLoader
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
from peft import LoraConfig, get_peft_model
from data_utils.xgpt3_dataset import MultiModalDataset
from utils import batchify
from huggingface_hub import hf_hub_download

import gradio as gr
from entailment_inference import get_scores
from nle_inference import VideoCaptionDataset, get_nle

import re

def modify_keys(state_dict):
    """Return a copy of ``state_dict`` with projection-weight keys renamed.

    Any key that names the ``weight`` of a language-model projection layer
    (``q_proj``, ``v_proj``, ``k_proj``, ``o_proj``, ``gate_proj``,
    ``down_proj`` or ``up_proj``) gets a ``base_layer`` component inserted
    immediately before its final ``weight`` component, e.g.
    ``...q_proj.weight`` becomes ``...q_proj.base_layer.weight``.
    Every other key is copied through unchanged. Values are passed through
    as-is (not copied or modified).

    Args:
        state_dict: Mapping from parameter name (dot-separated string) to
            its value.

    Returns:
        A new ``dict`` holding the same values under the remapped keys, in
        the input's iteration order. The input mapping is not mutated.
    """
    # The dot before "weight" is escaped and the whole key must match, so
    # only keys whose last component is exactly "weight" directly under one
    # of the listed projections are rewritten (a key such as
    # "...q_proj_weight" or "...q_proj.weight_extra" is left alone).
    pattern = re.compile(
        r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)\.weight'
    )

    new_state_dict = {}
    for key, value in state_dict.items():
        if pattern.fullmatch(key):
            # Split off the trailing "weight" and re-join with "base_layer"
            # placed just in front of it.
            prefix, _, leaf = key.rpartition('.')
            key = '.'.join((prefix, 'base_layer', leaf))
        new_state_dict[key] = value

    return new_state_dict

# Identifier handed to every `from_pretrained` call below (tokenizer, image
# processor and base model).
# NOTE(review): the value ends with a trailing slash — confirm the installed
# transformers / huggingface_hub versions accept it as a repo id.
pretrained_ckpt = "MAGAer13/mplug-owl-llama-7b-video/"
# Fetch `pytorch_model.bin` from the `videocon/owl-con` model repo; the result
# is opened as a local file further down.
trained_ckpt = hf_hub_download(repo_id="videocon/owl-con", filename="pytorch_model.bin", repo_type="model")


# Text tokenizer + image processor, combined into the single `processor`
# object that `inference` passes to the dataset / NLE helpers.
tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)

# Instantiate the base model in bfloat16 with a CPU device map; it is only
# moved to the GPU after the fine-tuned weights have been loaded.
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
    device_map={'': 'cpu'}
)

# LoRA configuration: the regex targets the language model's attention and MLP
# projection layers; `inference_mode=True` because this script only runs
# inference.
peft_config = LoraConfig(
    target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)', 
    inference_mode=True, 
    r=32, 
    lora_alpha=16, 
    lora_dropout=0.05
)
# Wrap the base model with the LoRA adapters and print the parameter summary.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Load the downloaded checkpoint on CPU and remap its keys with `modify_keys`
# (inserts `base_layer` before `weight` on the targeted projection layers).
# Presumably this aligns the checkpoint with the parameter names of the
# PEFT-wrapped model — confirm against the installed peft version.
# NOTE(review): `torch.load` unpickles the file, so this is only safe for a
# trusted checkpoint.
with open(trained_ckpt, 'rb') as f:
    ckpt = torch.load(f, map_location = torch.device("cpu"))
    ckpt = modify_keys(ckpt)
    
# Apply the fine-tuned weights, then move the model to the first CUDA device
# and cast it to bfloat16 (a CUDA device must be available at import time).
model.load_state_dict(ckpt)
model = model.to("cuda:0").to(torch.bfloat16)

def inference(videopath, text):
    """Score a video/caption pair and explain low-scoring pairs.

    The caption is placed into an entailment prompt, the pair is scored with
    `get_scores`, and only when that score is below 0.5 is `get_nle` asked
    for a natural-language explanation.

    Args:
        videopath: Video reference forwarded unchanged to `MultiModalDataset`
            and `VideoCaptionDataset` (supplied by the `gr.Video` input).
        text: Caption to check against the video.

    Returns:
        A `(score, nle)` tuple. `nle` is the output of `get_nle` when
        `score < 0.5`, otherwise a fixed placeholder message.
    """
    # Prompt template, spelled out piece by piece. The continuation lines
    # start with four spaces and the template ends with "AI: " (trailing
    # space, no newline).
    # NOTE(review): confirm the leading spaces match the prompt format the
    # checkpoint was trained with.
    prompt_template = (
        "The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
        "    Human: <|video|>\n"
        '    Human: Does this video entail the description: "{caption}"?\n'
        "    AI: "
    )
    entailment_prompt = prompt_template.format(caption=text)

    # Stage 1: entailment score for the (video, caption) pair.
    entailment_data = MultiModalDataset(
        videopath,
        entailment_prompt,
        tokenizer,
        processor,
        max_length=256,
        loss_objective='sequential',
    )
    entailment_loader = DataLoader(entailment_data, pin_memory=True, collate_fn=batchify)
    score = get_scores(model, tokenizer, entailment_loader)

    # Stage 2: generate an explanation only for low-scoring pairs.
    if score < 0.5:
        caption_loader = DataLoader(VideoCaptionDataset(videopath, text))
        return score, get_nle(model, processor, tokenizer, caption_loader)

    return score, "None (NLE is only triggered when entailment score < 0.5)"

# Gradio UI: a video and a caption go in; the entailment score and the
# natural-language explanation returned by `inference` come out.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Video(label='input_video'),
        gr.Textbox(label='input_caption'),
    ],
    outputs=[
        gr.Number(label='Entailment Score'),
        gr.Textbox(label='Natural Language Explanation'),
    ],
    # Each example clip appears twice, with two captions that differ in a
    # single detail.
    examples=[
        ["examples/820.mp4", "We see the group making cookies."],
        ["examples/820.mp4", "We see the group eating cookies."],
        ["examples/244.mp4", "She throws a bowling ball while talking on the phone."],
        ["examples/244.mp4", "She throws a baseball while talking on the phone."],
    ],
    title="Owl-Con Demo",
    description="Owl-Con Demo (Code: https://github.com/Hritikbansal/videocon | Paper: https://arxiv.org/abs/2311.10111)",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()