Hritik committed · Commit 0ba1d16 · 1 parent: cfe5653
add app and nle code

Files changed:
- app.py: +27 -20
- nle_inference.py: +9 -83
app.py
CHANGED
@@ -17,11 +17,7 @@ from utils import batchify
 
 import gradio as gr
 from entailment_inference import get_scores
-
-print(f"Is CUDA available: {torch.cuda.is_available()}")
-# True
-print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
-# Tesla T4
+from nle_inference import VideoCaptionDataset, get_nle
 
 pretrained_ckpt = "mplugowl7bvideo/"
 trained_ckpt = "owl-con/checkpoint-5178/pytorch_model.bin"
@@ -30,19 +26,13 @@ tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
 image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
 processor = MplugOwlProcessor(image_processor, tokenizer)
 
-
 # Instantiate model
 model = MplugOwlForConditionalGeneration.from_pretrained(
     pretrained_ckpt,
     torch_dtype=torch.bfloat16,
     device_map={'': 'cpu'}
-    # device_map={'':0}
 )
 
-# for name, param in model.named_parameters():
-#     print(param.device)
-#     break
-
 peft_config = LoraConfig(
     target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
     inference_mode=True,
@@ -56,14 +46,31 @@ with open(trained_ckpt, 'rb') as f:
     ckpt = torch.load(f, map_location = torch.device("cpu"))
 model.load_state_dict(ckpt)
 model = model.to("cuda:0").to(torch.bfloat16)
-print('Model Loaded')
 
-
-
-
-
+def inference(videopath, text):
+
+    PROMPT = """The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+    Human: <|video|>
+    Human: Does this video entail the description: "{caption}"?
+    AI: """
+
+    valid_data = MultiModalDataset(videopath, PROMPT.format(caption = text), tokenizer, processor, max_length = 256, loss_objective = 'sequential')
+    dataloader = DataLoader(valid_data, pin_memory=True, collate_fn=batchify)
+    score = get_scores(model, tokenizer, dataloader)
+
+    if score < 0.5:
+        dataset = VideoCaptionDataset(videopath, text)
+        dataloader = DataLoader(dataset)
+        nle = get_nle(model, processor, tokenizer, dataloader)
+    else:
+        nle = "None (NLE is only triggered when entailment score < 0.5)"
+
+    return score, nle
+
+demo = gr.Interface(inference,
+                    title="Owl-Con Demo (Code: https://github.com/Hritikbansal/videocon | Paper: https://arxiv.org/abs/2311.10111)",
+                    inputs=[gr.Video(label='input_video'), gr.Textbox(label='input_caption')],
+                    outputs=[gr.Number(label='Entailment Score'), gr.Textbox(label='Natural Language Explanation')])
 
-
-
-score = get_scores(model, tokenizer, dataloader)
-print(score)
+if __name__ == "__main__":
+    demo.launch()
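For reference, a minimal, self-contained sketch of the Gradio wiring that the new app.py sets up, with the model-dependent inference replaced by a stub so it can run without the mPLUG-Owl checkpoints. The stub function and its placeholder outputs are illustrative assumptions, not part of the commit:

import gradio as gr

def stub_inference(videopath, text):
    # Placeholder for the real inference(): app.py computes an entailment score
    # with get_scores and, when the score is below 0.5, an explanation with get_nle.
    score = 0.42
    nle = "Example explanation (placeholder)."
    return score, nle

demo = gr.Interface(stub_inference,
                    title="Owl-Con Demo (sketch)",
                    inputs=[gr.Video(label='input_video'), gr.Textbox(label='input_caption')],
                    outputs=[gr.Number(label='Entailment Score'), gr.Textbox(label='Natural Language Explanation')])

if __name__ == "__main__":
    demo.launch()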
nle_inference.py
CHANGED
@@ -11,19 +11,6 @@ from transformers.models.llama.tokenization_llama import LlamaTokenizer
 from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration
 from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor
 
-parser = argparse.ArgumentParser()
-
-parser.add_argument('--input_file', type = str, required = True, help = 'input csv file')
-parser.add_argument('--output_file', type = str, help = 'output csv file')
-parser.add_argument('--pretrained_ckpt', type = str, required = True, help = 'pretrained ckpt')
-parser.add_argument('--trained_ckpt', type = str, help = 'trained ckpt')
-parser.add_argument('--lora_r', type = int, default = 32)
-parser.add_argument('--use_lora', action = 'store_true', help = 'lora model')
-parser.add_argument('--all_params', action = 'store_true', help = 'all params')
-parser.add_argument('--batch_size', type = int, default = 1)
-parser.add_argument('--num_frames', type = int, default = 32)
-
-args = parser.parse_args()
 
 PROMPT_FEEDBACK = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
 Human: <|video|>
@@ -38,89 +25,28 @@ generate_kwargs = {
 
 class VideoCaptionDataset(Dataset):
 
-    def __init__(self,
-        self.
+    def __init__(self, videopath, text):
+        self.videopath = videopath
+        self.text = text
 
     def __len__(self):
-        return
+        return 1
 
     def __getitem__(self, index):
         item = {}
-        item['videopath'] = self.
-        item['neg_caption'] = self.
+        item['videopath'] = self.videopath
+        item['neg_caption'] = self.text
         return item
 
-def get_nle(
-
+def get_nle(model, processor, tokenizer, dataloader):
     with torch.no_grad():
         for _, batch in tqdm(enumerate(dataloader)):
             videopaths = batch['videopath']
             neg_caption = batch['neg_caption'][0]
             prompts = [PROMPT_FEEDBACK.format(caption = neg_caption)]
-            inputs = processor(text=prompts, videos=videopaths, num_frames=
+            inputs = processor(text=prompts, videos=videopaths, num_frames=32, return_tensors='pt')
             inputs = {k: v.bfloat16() if v.dtype == torch.float else v for k, v in inputs.items()}
             inputs = {k: v.to(model.device) for k, v in inputs.items()}
             res = model.generate(**inputs, **generate_kwargs)
             generated_nle = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
-
-            with open(args.output_file, 'a') as f:
-                writer = csv.writer(f)
-                writer.writerow([videopaths[0], neg_caption, generated_nle])
-
-def main():
-
-    # Create dataloader
-    dataset = VideoCaptionDataset(args.input_file)
-    dataloader = DataLoader(dataset, batch_size = args.batch_size)
-
-    pretrained_ckpt = args.pretrained_ckpt
-
-    # Processors
-    tokenizer = LlamaTokenizer.from_pretrained(pretrained_ckpt)
-    image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
-    processor = MplugOwlProcessor(image_processor, tokenizer)
-
-    # Instantiate model
-    model = MplugOwlForConditionalGeneration.from_pretrained(
-        pretrained_ckpt,
-        torch_dtype=torch.bfloat16,
-        device_map={'':0}
-    )
-
-    if args.use_lora:
-        for name, param in model.named_parameters():
-            param.requires_grad = False
-        if args.all_params:
-            peft_config = LoraConfig(
-                target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj|gate_proj|down_proj|up_proj)',
-                inference_mode=True,
-                r=args.lora_r,
-                lora_alpha=16,
-                lora_dropout=0.05
-            )
-        else:
-            peft_config = LoraConfig(
-                target_modules=r'.*language_model.*\.(q_proj|v_proj|k_proj|o_proj)',
-                inference_mode=True,
-                r=args.lora_r,
-                lora_alpha=16,
-                lora_dropout=0.05
-            )
-
-        model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
-        with open(args.trained_ckpt, 'rb') as f:
-            ckpt = torch.load(f, map_location = torch.device(f"cuda:0"))
-        model.load_state_dict(ckpt)
-        model = model.to(torch.bfloat16)
-        print('Model Loaded')
-
-    model.eval()
-
-    # get nle
-    get_nle(args, model, processor, tokenizer, dataloader)
-
-
-
-if __name__ == "__main__":
-    main()
+            return generated_nle
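With the argparse-driven main() removed, nle_inference.py is now a small importable module, which is how the new app.py consumes it. A minimal usage sketch, assuming model, processor, and tokenizer have already been loaded as in app.py above; the video path and caption in the commented example are hypothetical:

from torch.utils.data import DataLoader
from nle_inference import VideoCaptionDataset, get_nle

def explain(model, processor, tokenizer, videopath, caption):
    # Wrap a single (video, caption) pair as a one-item dataset, as app.py does.
    dataset = VideoCaptionDataset(videopath, caption)
    dataloader = DataLoader(dataset)  # default batch size of 1
    # get_nle generates a natural language explanation of why the video does
    # not entail the caption.
    return get_nle(model, processor, tokenizer, dataloader)

# Example (hypothetical inputs):
# nle = explain(model, processor, tokenizer, "examples/clip.mp4",
#               "A dog catches a frisbee in the park.")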