Spaces:
Runtime error
Runtime error
File size: 3,825 Bytes
3a496ae 1cb9066 3a496ae de447f9 3a496ae 0ba1d16 b3402e9 42a01c3 de447f9 6ab097e 3a496ae cfe5653 3a496ae 42a01c3 3a496ae cfe5653 42a01c3 3a496ae cfe5653 b3402e9 0ba1d16 8b85fc6 0ba1d16 8b85fc6 a606624 890b2c5 0ba1d16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import csv
import json
import torch
import argparse
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from transformers.models.llama.tokenization_llama import LlamaTokenizer
from torch.utils.data import DataLoader
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
from peft import LoraConfig, get_peft_model
from data_utils.xgpt3_dataset import MultiModalDataset
from utils import batchify
from huggingface_hub import hf_hub_download
import gradio as gr
from entailment_inference import get_scores
from nle_inference import VideoCaptionDataset, get_nle
import re
def modify_keys(state_dict):
    """Rename LoRA-targeted projection weights to the ``base_layer`` layout.

    Every key of the form ``...language_model....<proj>.weight``, where
    ``<proj>`` is one of q/k/v/o/gate/down/up_proj, is rewritten to
    ``...<proj>.base_layer.weight``. This lets a checkpoint saved with plain
    projection keys be loaded into a PEFT-wrapped model that keeps the
    original Linear weight under ``base_layer``. All other keys, including
    LoRA adapter keys such as ``...q_proj.lora_A.weight``, pass through
    unchanged.

    Args:
        state_dict: Mapping of parameter name -> value (e.g. tensors).

    Returns:
        A new dict with the renamed keys. Insertion order is preserved.
    """
    # The dot before "weight" must be escaped, and the whole key must match.
    # An unescaped "." accepts any character (e.g. "q_projXweight"), and
    # re.match would also accept keys with trailing text after "weight".
    pattern = re.compile(
        r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)\.weight'
    )
    renamed = {}
    for key, value in state_dict.items():
        if pattern.fullmatch(key):
            # Insert "base_layer" just before the final ".weight" component.
            prefix, leaf = key.rsplit('.', 1)
            key = f'{prefix}.base_layer.{leaf}'
        renamed[key] = value
    return renamed
# Base mPLUG-Owl video model. The trailing slash is omitted deliberately:
# "namespace/name/" is not a valid Hub repo id and fails huggingface_hub's
# repo-id validation. A local directory of the same name still resolves
# without the slash.
pretrained_ckpt = "MAGAer13/mplug-owl-llama-7b-video"
# Fine-tuned Owl-Con weights. This returns the local cached path of the
# downloaded file, which is loaded below as a state dict.
trained_ckpt = hf_hub_download(repo_id="videocon/owl-con", filename="pytorch_model.bin", repo_type="model")
tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
processor = MplugOwlProcessor(image_processor, tokenizer)
# Instantiate the base model on CPU in bfloat16; it is moved to the GPU at the end.
model = MplugOwlForConditionalGeneration.from_pretrained(
    pretrained_ckpt,
    torch_dtype=torch.bfloat16,
    device_map={'': 'cpu'}
)
# LoRA adapters on the language model's attention/MLP projections.
# `r` determines adapter tensor shapes, so it must match the checkpoint's
# LoRA rank; otherwise the strict load_state_dict below fails on shape mismatch.
peft_config = LoraConfig(
    target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
    inference_mode=True,
    r=32,
    lora_alpha=16,
    lora_dropout=0.05
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code. This is acceptable only because the source repo is trusted.
# Consider weights_only=True if the installed torch supports it.
with open(trained_ckpt, 'rb') as f:
    ckpt = torch.load(f, map_location=torch.device("cpu"))
# Rename base projection weights to the `base_layer` keys of the PEFT-wrapped model.
ckpt = modify_keys(ckpt)
model.load_state_dict(ckpt)
# NOTE(review): hard-codes a CUDA device; the app will fail on a CPU-only host.
model = model.to("cuda:0").to(torch.bfloat16)
def inference(videopath, text):
    """Score whether a video entails a caption, and explain low scores.

    Args:
        videopath: Path of the input video, as passed by the Gradio video input.
        text: Candidate caption to test against the video.

    Returns:
        A ``(score, nle)`` tuple. ``score`` is whatever ``get_scores`` returns.
        ``nle`` is a natural-language explanation from ``get_nle`` when the
        score is below 0.5; otherwise it is a fixed placeholder message.
    """
    template = (
        "The following is a conversation between a curious human and AI assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
        "Human: <|video|>\n"
        'Human: Does this video entail the description: "{caption}"?\n'
        "AI: "
    )
    prompt = template.format(caption=text)

    entailment_data = MultiModalDataset(
        videopath,
        prompt,
        tokenizer,
        processor,
        max_length=256,
        loss_objective='sequential',
    )
    entailment_loader = DataLoader(entailment_data, pin_memory=True, collate_fn=batchify)
    score = get_scores(model, tokenizer, entailment_loader)

    # Only produce an explanation when the model leans towards "not entailed".
    if score < 0.5:
        explanation_loader = DataLoader(VideoCaptionDataset(videopath, text))
        return score, get_nle(model, processor, tokenizer, explanation_loader)
    return score, "None (NLE is only triggered when entailment score < 0.5)"
# Gradio UI: a video and a caption go in; the entailment score and an
# optional explanation come out. Example video paths are relative to the
# working directory.
demo = gr.Interface(
    fn=inference,
    title="Owl-Con Demo",
    description="Owl-Con Demo (Code: https://github.com/Hritikbansal/videocon | Paper: https://arxiv.org/abs/2311.10111)",
    inputs=[
        gr.Video(label='input_video'),
        gr.Textbox(label='input_caption'),
    ],
    outputs=[
        gr.Number(label='Entailment Score'),
        gr.Textbox(label='Natural Language Explanation'),
    ],
    examples=[
        ["examples/820.mp4", "We see the group making cookies."],
        ["examples/820.mp4", "We see the group eating cookies."],
        ["examples/244.mp4", "She throws a bowling ball while talking on the phone."],
        ["examples/244.mp4", "She throws a baseball while talking on the phone."],
    ],
)

if __name__ == "__main__":
    demo.launch()