import os
import sys
import time
import shutil
import hashlib
import tempfile
import warnings
import subprocess
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple

import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
import gradio as gr
import uvicorn
from fastapi import FastAPI
from PIL import Image
from decord import VideoReader, cpu
from tqdm import tqdm
from transformers import TextStreamer

from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from lit_gpt.utils import lazy_load
from lit_llama import Tokenizer, LLaMA, LLaMAConfig
from lit_llama.lora import lora
from lit_llama.utils import EmptyInitOnDevice
from generate import generate as generate_
from scripts.video_dataset.prepare_video_dataset_video_llava import generate_prompt_mlp
from options import option
from models.multimodal_encoder.builder import build_image_tower, build_video_tower
from models.multimodal_projector.builder import build_vision_projector

title_markdown = ("""<div class="embed_hidden" style="text-align: center;">
    <h1>MotionLLM: Understanding Human Behaviors from Human Motions and Videos</h1>
    <h3>
        <a href="https://lhchen.top" target="_blank" rel="noopener noreferrer">Ling-Hao Chen</a><sup>😎 1, 3</sup>,
        <a href="https://shunlinlu.github.io" target="_blank" rel="noopener noreferrer">Shunlin Lu</a><sup>😎 2, 3</sup>,
        <br>
        <a href="https://ailingzeng.sit" target="_blank" rel="noopener noreferrer">Ailing Zeng</a><sup>3</sup>,
        <a href="https://haozhang534.github.io/" target="_blank" rel="noopener noreferrer">Hao Zhang</a><sup>3, 4</sup>,
        <a href="https://wabyking.github.io/old.html" target="_blank" rel="noopener noreferrer">Benyou Wang</a><sup>2</sup>,
        <a href="http://zhangruimao.site" target="_blank" rel="noopener noreferrer">Ruimao Zhang</a><sup>2</sup>,
        <a href="https://leizhang.org" target="_blank" rel="noopener noreferrer">Lei Zhang</a><sup>🤗 3</sup>
    </h3>
    <h3><sup>😎</sup><i>Co-first author. Listing order is random.</i>&nbsp;&nbsp;<sup>🤗</sup><i>Corresponding author.</i></h3>
    <h3>
        <sup>1</sup>THU&nbsp;&nbsp;
        <sup>2</sup>CUHK (SZ)&nbsp;&nbsp;
        <sup>3</sup>IDEA Research&nbsp;&nbsp;
        <sup>4</sup>HKUST
    </h3>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    <img src="https://lhchen.top/MotionLLM/assets/img/highlight.png" alt="MotionLLM" style="width:60%; height: auto; align-items: center;">
</div>
""")
block_css = """ | |
#buttons button { | |
min-width: min(120px,100%); | |
} | |
""" | |
tos_markdown = (""" | |
*We are now working to support the motion branch of the MotionLLM model. | |
### Terms of use | |
By using this service, users are required to agree to the following terms: | |
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. | |
It is forbidden to use the service to generate content that is illegal, harmful, violent, racist, or sexual | |
The usage of this service is subject to the IDEA License. | |
""") | |
learn_more_markdown = (""" | |
### License | |
License for Non-commercial Scientific Research Purposes | |
IDEA grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under IDEAβs copyright interests to reproduce, distribute, and create derivative works of the text, videos, codes solely for your non-commercial research purposes. | |
Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. | |
Text and visualization results are owned by International Digital Economy Academy (IDEA). | |
You also need to obey the original license of the dependency models/data used in this service. | |
""") | |
class LlavaMetaModel:
    def __init__(self, config, pretrained_checkpoint):
        super(LlavaMetaModel, self).__init__()
        if hasattr(config, "mm_image_tower") or hasattr(config, "image_tower"):
            self.image_tower = build_image_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)
        if hasattr(config, "mm_video_tower") or hasattr(config, "video_tower"):
            self.video_tower = build_video_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)
            self.load_video_tower_pretrained(pretrained_checkpoint)

    def get_image_tower(self):
        image_tower = getattr(self, 'image_tower', None)
        if type(image_tower) is list:
            image_tower = image_tower[0]
        return image_tower

    def get_video_tower(self):
        video_tower = getattr(self, 'video_tower', None)
        if type(video_tower) is list:
            video_tower = video_tower[0]
        return video_tower

    def get_all_tower(self, keys):
        # Call the per-modality getters so the dict maps key -> tower module
        # (the original returned the bound getter functions themselves).
        tower = {key: getattr(self, f'get_{key}_tower')() for key in keys}
        return tower

    def load_video_tower_pretrained(self, pretrained_checkpoint):
        self.mm_projector.load_state_dict(pretrained_checkpoint, strict=True)

    def initialize_image_modules(self, model_args, fsdp=None):
        image_tower = model_args.image_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_image_tower = image_tower
        image_tower = build_image_tower(model_args)
        if fsdp is not None and len(fsdp) > 0:
            self.image_tower = [image_tower]
        else:
            self.image_tower = image_tower

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = image_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        self.mm_projector = build_vision_projector(self.config)
        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))

    def initialize_video_modules(self, model_args, fsdp=None):
        video_tower = model_args.video_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_video_tower = video_tower
        video_tower = build_video_tower(model_args)
        if fsdp is not None and len(fsdp) > 0:
            self.video_tower = [video_tower]
        else:
            self.video_tower = video_tower

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = video_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        self.mm_projector = build_vision_projector(self.config)
        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))

    def encode_images(self, images):
        image_features = self.get_image_tower()(images)
        image_features = self.mm_projector(image_features)
        return image_features

    def encode_videos(self, videos):
        # videos: (1, 3, 8, 224, 224)
        video_features = self.get_video_tower()(videos)             # (1, 2048, 1024)
        video_features = self.mm_projector(video_features.float())  # (1, 2048, 4096)
        return video_features

    def get_multimodal_embeddings(self, X_modalities):
        Xs, keys = X_modalities
        X_features = getattr(self, f'encode_{keys[0]}s')(Xs)  # e.g. encode_videos when keys[0] == 'video'
        return X_features

class Projection(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_proj = nn.Linear(512, 4096)

    def forward(self, x):
        return self.linear_proj(x)


class ProjectionNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(512, 4096),
            nn.GELU(),
            nn.Linear(4096, 4096)
        )

    def forward(self, x):
        return self.proj(x)

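
# Note: Projection and ProjectionNN above are simple heads that map 512-d
# features into the 4096-d Vicuna embedding space. They are not referenced in
# the rest of this demo; presumably they are kept for the motion branch that
# the terms-of-use banner says is still being wired up.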
class Conversation():
    def __init__(self, output=None, input_prompt=None, prompt=None):
        self.messages = []
        if output is not None:
            self.append_message(output, input_prompt, prompt, None)

    def append_message(self, output, input_prompt, prompt, show_images):
        self.messages.append((output, input_prompt, prompt, show_images))

    def to_gradio_chatbot(self, show_images=None, output_text=None):
        # Return the latest turn as a single [user, assistant] pair for gr.Chatbot.
        if show_images is None:
            show_images = self.messages[-1][3]
            output_text = self.messages[-1][0]
        return [
            [show_images, output_text]
        ]

    def get_info(self):
        return self.messages[-1][0], self.messages[-1][1]


class ConversationBuffer():
    def __init__(self, input_text):
        self.buffer_ = []
        self.buffer_.append(input_text)


def init_conv():
    conv = Conversation()
    return conv

def get_processor(X, config, device, pretrained_checkpoint_tower, model_path='LanguageBind/MotionLLM-7B'):
    mm_backbone_mlp_model = LlavaMetaModel(config, pretrained_checkpoint_tower)
    processor = {}
    if 'Image' in X:
        image_tower = mm_backbone_mlp_model.get_image_tower()  # LanguageBindImageTower
        if not image_tower.is_loaded:
            image_tower.load_model()
        image_tower.to(device=device, dtype=torch.float16)
        image_processor = image_tower.image_processor
        processor['image'] = image_processor
    if 'Video' in X:
        video_tower = mm_backbone_mlp_model.get_video_tower()
        if not video_tower.is_loaded:
            video_tower.load_model()
        video_tower.to(device=device, dtype=torch.float16)
        video_processor = video_tower.video_processor
        processor['video'] = video_processor
    return mm_backbone_mlp_model, processor

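
# ---------------------------------------------------------------------------
# Inference path (a rough sketch of what motionllm() below does per request):
#   1. video_processor turns the uploaded clip into a (1, 3, 8, 224, 224)
#      pixel tensor, which the video tower + mm_projector embed into
#      LLM-space features.
#   2. The user instruction is wrapped by generate_prompt_mlp() and tokenized;
#      the video features are spliced in at the "INPUT_VIDEO: " marker,
#      followed by ". ASSISTANT: ".
#   3. generate_() runs the LoRA-tuned Vicuna, and everything after
#      "ASSISTANT:" in the decoded sequence is returned as the answer.
# `video_processor`, `mm_backbone_mlp_model`, `model`, and `tokenizer` are
# module-level globals initialized further down in this file.
# ---------------------------------------------------------------------------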
def motionllm(
        args,
        input_video_path: str,
        text_en_in: str,
        quantize: Optional[str] = None,
        dtype: str = "float32",
        max_new_tokens: int = 200,
        top_k: int = 200,
        temperature: float = 0.8,
        accelerator: str = "auto"):
    video_tensor = video_processor(input_video_path, return_tensors='pt')['pixel_values']

    if type(video_tensor) is list:
        tensor = [video.to('cuda', dtype=torch.float16) for video in video_tensor]
    else:
        tensor = video_tensor.to('cuda', dtype=torch.float16)  # (1, 3, 8, 224, 224)

    X_modalities = [tensor, ['video']]
    video_feature = mm_backbone_mlp_model.get_multimodal_embeddings(X_modalities)

    prompt = text_en_in
    input_prompt = prompt

    sample = {"instruction": prompt, "input": input_video_path}
    prefix = generate_prompt_mlp(sample)
    pre = torch.cat(
        (
            tokenizer.encode(prefix.split('INPUT_VIDEO: ')[0] + "\n", bos=True, eos=False, device=model.device).view(1, -1),
            tokenizer.encode("INPUT_VIDEO: ", bos=False, eos=False, device=model.device).view(1, -1),
        ),
        dim=1,
    )
    prompt = (pre, ". ASSISTANT: ")

    encoded = (
        prompt[0],
        video_feature[0],
        tokenizer.encode(prompt[1], bos=False, eos=False, device=model.device).view(1, -1),
    )

    t0 = time.perf_counter()
    output_seq = generate_(
        model,
        idx=encoded,
        max_seq_length=4096,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
        tokenizer=tokenizer,
    )
    outputfull = tokenizer.decode(output_seq)
    output = outputfull.split("ASSISTANT:")[-1].strip()
    print("================================")
    print(output)
    print("================================")

    return output, input_prompt, prompt

def save_image_to_local(image):
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
    image = Image.open(image)
    image.save(filename)
    return filename


def save_video_to_local(video_path):
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename

def generate(image1, video, textbox_in, first_run, state, images_tensor):
    flag = 1
    image1 = image1 if image1 else "none"
    video = video if video else "none"

    if type(state) is not Conversation:
        state = init_conv()
        images_tensor = [[], []]

    first_run = False if len(state.messages) > 0 else True
    text_en_in = textbox_in.replace("picture", "image")

    output, input_prompt, prompt = motionllm(args, video, text_en_in)

    text_en_out = output
    textbox_out = text_en_out

    show_images = ""
    if os.path.exists(image1):
        filename = save_image_to_local(image1)
        show_images += f'<img src="./file={filename}" style="display: inline-block;width: 250px;max-height: 400px;">'
    if os.path.exists(video):
        filename = save_video_to_local(video)
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
    show_images = textbox_in + "\n" + show_images

    state.append_message(output, input_prompt, prompt, show_images)

    torch.cuda.empty_cache()

    return (state,
            state.to_gradio_chatbot(show_images, output),
            False,
            gr.update(value=None, interactive=True),
            images_tensor,
            gr.update(value=image1 if os.path.exists(image1) else None, interactive=True),
            gr.update(value=video if os.path.exists(video) else None, interactive=True))

def regenerate(state):
    if len(state.messages) > 0:
        tobot = state.to_gradio_chatbot()
        tobot[-1][1] = None
        textbox = state.messages[-1][1]
        state.messages.pop(-1)
        return state, tobot, False, textbox
    # No history yet: return a value for the textbox too, so the output
    # signature matches the branch above.
    return state, [], True, None


def clear_history(state):
    state = init_conv()
    try:
        tgt = state.to_gradio_chatbot()
    except Exception:
        tgt = [None, None]
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            True, state, tgt, [[], []])

def get_md5(file_path):
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def logging_up(video, state):
    try:
        state.get_info()
    except Exception:
        return False
    action = "upvote"
    # Name the feedback files after the video hash plus an md5 of the current time.
    current_time = str(time.time())
    hash_object = hashlib.md5(current_time.encode())
    md5_hash = get_md5(video) + "-" + hash_object.hexdigest()
    command = f"cp {video} ./feedback/{action}/mp4/{md5_hash}.mp4"
    os.system(command)
    with open(f"./feedback/{action}/txt/{md5_hash}.txt", "w") as f:
        out, prp = state.get_info()
        f.write(f"==========\nPrompt: {prp}\n==========\nOutput: {out}==========\n")
    return True


def logging_down(video, state):
    try:
        state.get_info()
    except Exception:
        return False
    action = "downvote"
    # Same naming scheme as logging_up.
    current_time = str(time.time())
    hash_object = hashlib.md5(current_time.encode())
    md5_hash = get_md5(video) + "-" + hash_object.hexdigest()
    command = f"cp {video} ./feedback/{action}/mp4/{md5_hash}.mp4"
    os.system(command)
    with open(f"./feedback/{action}/txt/{md5_hash}.txt", "w") as f:
        out, prp = state.get_info()
        f.write(f"==========\nPrompt: {prp}\n==========\nOutput: {out}==========\n")
    return True

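
# ---------------------------------------------------------------------------
# Module-level setup. Everything below runs once at import time on the Space:
# parse options, build the LoRA-wrapped Vicuna-7B (lit-gpt Config/GPT), load
# the Video-LLaVA video tower and projector MLP, merge the base and LoRA
# checkpoints, and finally mount the Gradio UI on a FastAPI app served by
# uvicorn.
# ---------------------------------------------------------------------------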
torch.set_float32_matmul_precision("high") | |
warnings.filterwarnings('ignore') | |
args = option.get_args_parser() | |
conv_mode = "llava_v1" | |
model_path = 'LanguageBind/Video-LLaVA-7B' | |
device = 'cuda' | |
load_8bit = False | |
load_4bit = True | |
dtype = torch.float16 | |
if not os.path.exists("temp"): | |
os.makedirs("temp") | |
lora_path = Path(args.lora_path) | |
pretrained_llm_path = Path(f"./checkpoints/vicuna-7b-v1.5/lit_model.pth") | |
tokenizer_llm_path = Path("./checkpoints/vicuna-7b-v1.5/tokenizer.model") | |
# assert lora_path.is_file() | |
assert pretrained_llm_path.is_file() | |
assert tokenizer_llm_path.is_file() | |
accelerator = "auto" | |
fabric = L.Fabric(accelerator=accelerator, devices=1) | |
dtype = "float32" | |
dt = getattr(torch, dtype, None) | |
if not isinstance(dt, torch.dtype): | |
raise ValueError(f"{dtype} is not a valid dtype.") | |
dtype = dt | |
quantize = None | |
t0 = time.time() | |
with EmptyInitOnDevice( | |
device=fabric.device, dtype=dtype, quantization_mode=quantize | |
), lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True): | |
checkpoint_dir = Path("checkpoints/vicuna-7b-v1.5") | |
lora_query = True | |
lora_key = False | |
lora_value = True | |
lora_projection = False | |
lora_mlp = False | |
lora_head = False | |
config = Config.from_name( | |
name=checkpoint_dir.name, | |
r=args.lora_r, | |
alpha=args.lora_alpha, | |
dropout=args.lora_dropout, | |
to_query=lora_query, | |
to_key=lora_key, | |
to_value=lora_value, | |
to_projection=lora_projection, | |
to_mlp=lora_mlp, | |
to_head=lora_head, | |
) | |
model = GPT(config).bfloat16() | |
mlp_path = args.mlp_path | |
pretrained_checkpoint_mlp = torch.load(mlp_path) | |
X = ['Video'] | |
mm_backbone_mlp_model, processor = get_processor(X, args, 'cuda', pretrained_checkpoint_mlp, model_path = 'LanguageBind/Video-LLaVA-7B') | |
video_processor = processor['video'] | |
linear_proj = mm_backbone_mlp_model.mm_projector | |
# 1. Load the pretrained weights | |
pretrained_llm_checkpoint = lazy_load(pretrained_llm_path) | |
# 2. Load the fine-tuned LoRA weights | |
lora_checkpoint = lazy_load(lora_path) | |
# 3. merge the two checkpoints | |
model_state_dict = {**pretrained_llm_checkpoint, **lora_checkpoint} | |
model.load_state_dict(model_state_dict, strict=True) | |
print('Load llm base model from', pretrained_llm_path) | |
print('Load lora model from', lora_path) | |
# load mlp again, to en sure, not neccessary actually | |
linear_proj.load_state_dict(pretrained_checkpoint_mlp) | |
linear_proj = linear_proj.cuda() | |
print('Load mlp model again from', mlp_path) | |
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) | |
model.eval() | |
model = fabric.setup_module(model) | |
linear_proj.eval() | |
tokenizer = Tokenizer(tokenizer_llm_path) | |
print('Load tokenizer from', tokenizer_llm_path) | |
print(torch.cuda.memory_allocated()) | |
print(torch.cuda.max_memory_allocated()) | |
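
# ---------------------------------------------------------------------------
# Web UI: a FastAPI app with the Gradio Blocks demo mounted at "/".
# The chatbot shows one turn at a time (see Conversation.to_gradio_chatbot),
# and the upvote/downvote buttons copy the current video plus prompt/output
# pair into ./feedback/ for later inspection.
# ---------------------------------------------------------------------------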
app = FastAPI()

textbox = gr.Textbox(
    show_label=False, placeholder="Enter text and press ENTER", container=False
)

with gr.Blocks(title='MotionLLM', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    state = gr.State()
    buffer_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()

    with gr.Row():
        with gr.Column(scale=3):
            image1 = gr.State()
            video = gr.Video(label="Input Video")

            cur_dir = os.path.dirname(os.path.abspath(__file__))
            gr.Examples(
                examples=[
                    [
                        f"{cur_dir}/examples/Play_Electric_guitar_16_clip1.mp4",
                        "why is the girl so happy",
                    ],
                    [
                        f"{cur_dir}/examples/guoyoucai.mov",
                        "what is the feeling of him",
                    ],
                    [
                        f"{cur_dir}/examples/sprint_run_18_clip1.mp4",
                        "Why is the man running so fast?",
                    ],
                    [
                        f"{cur_dir}/examples/lift_weight.mp4",
                        "Assume you are a fitness coach, refer to the video of the professional athlete, please analyze specific action essentials in steps and give detailed instruction.",
                    ],
                    [
                        f"{cur_dir}/examples/Shaolin_Kung_Fu_Wushu_Selfdefense_Sword_Form_Session_22_clip3.mp4",
                        "wow, can you teach me the motion, step by step in detail",
                    ],
                    [
                        f"{cur_dir}/examples/mabaoguo.mp4",
                        "why is the video funny?",
                    ],
                    [
                        f"{cur_dir}/examples/COBRA_PUSH_UPS_clip2.mp4",
                        "describe the body movement of the woman",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_1.mp4",
                        "Why is this video interesting?",
                    ],
                ],
                inputs=[video, textbox],
            )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="MotionLLM", bubble_full_width=True).style(height=875)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(
                        value="Send", variant="primary", interactive=True
                    )
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
                # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    tmp = gr.State()

    upvote_btn.click(logging_up, [video, state], [tmp])

    downvote_btn.click(logging_down, [video, state], [tmp])

    submit_btn.click(generate, [image1, video, textbox, first_run, state, images_tensor],
                     [state, chatbot, first_run, textbox, images_tensor, image1, video])

    regenerate_btn.click(regenerate, [state], [state, chatbot, first_run, textbox]).then(
        generate, [image1, video, textbox, first_run, state, images_tensor],
        [state, chatbot, first_run, textbox, images_tensor, image1, video])

    clear_btn.click(clear_history, [state],
                    [image1, video, textbox, first_run, state, chatbot, images_tensor])

app = gr.mount_gradio_app(app, demo, path="/")
uvicorn.run(app, host="0.0.0.0", port=6657)