import gradio as gr
import argparse
import numpy as np
import torch
from torch import nn
from data.process_image import load_and_transform_image, get_image_transform
from main import SET_GLOBAL_VALUE
from model.build_model import create_vat_model
from data.process_audio import load_and_transform_audio, get_audio_transform
from data.process_video import load_and_transform_video, get_video_transform
from data.process_depth import load_and_transform_depth, get_depth_transform
from data.process_thermal import load_and_transform_thermal, get_thermal_transform
from data.process_text import load_and_transform_text
from open_clip import get_tokenizer
from open_clip.factory import HF_HUB_PREFIX

class LanguageBind(nn.Module):
    def __init__(self, args, no_temp=False):
        super(LanguageBind, self).__init__()
        self.no_temp = no_temp
        # Shared CLIP backbones; ViT-L/14 is selected below.
        MODEL_DICT = {"ViT-L-14": "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K",
                      "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"}
        args.pretrained = False
        args.model = MODEL_DICT["ViT-L-14"]
        args.cache_dir = 'D:/Omni-modal-valdt-audio'
        args.video_decode_backend = 'decord'
        # args.device = 'cpu'
        args.device = 'cuda:0'
        device = torch.device(args.device)
        args.precision = None
        args.init_temp = 0
        args.force_patch_dropout = 0.0
        args.add_time_attn = False
        # LoRA defaults; the per-modality loop below overrides lora_r where needed.
        args.convert_to_lora = True
        args.lora_r = 2
        args.lora_alpha = 16
        args.lora_dropout = 0.0  # 0.1?
        args.num_frames = 8
        args.clip_type = 'vl'
        # Audio preprocessing (fbank) settings.
        args.num_mel_bins = 1008
        args.target_length = 112
        args.audio_sample_rate = 16000
        # fbank normalization statistics (negative mean, positive std).
        args.audio_mean = -4.2677393
        args.audio_std = 4.5689974
        args.max_depth = 10
        args.image_size = 224
        args.rank = 0
        SET_GLOBAL_VALUE('PATCH_DROPOUT', args.force_patch_dropout)
        SET_GLOBAL_VALUE('NUM_FRAMES', args.num_frames)
        args.clip_type = ['il', 'vl', 'al', 'dl', 'tl']
        temp_clip_type = args.clip_type
        self.modality_encoder = {}
        self.modality_proj = {}
        self.modality_scale = {}
        for c in temp_clip_type:
            args.clip_type = c
            if c == 'il':
                args.convert_to_lora = False
                model = create_vat_model(args)
                args.convert_to_lora = True
            elif c == 'vl':
                args.lora_r = 64
                args.add_time_attn = True
                model = create_vat_model(args)
                args.add_time_attn = False
                args.lora_r = 2
            elif c == 'al':
                args.lora_r = 8
                model = create_vat_model(args)
                args.lora_r = 2
            else:
                model = create_vat_model(args)
            '''
            state_dict = torch.load(f'model_zoo/{c}.pt', map_location='cpu')
            if state_dict.get('state_dict', None) is not None:
                state_dict = state_dict['state_dict']
            if next(iter(state_dict.items()))[0].startswith('module'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}
            msg = model.load_state_dict(state_dict, strict=False)
            print(f'load {c}, {msg}')
            '''
            if c == 'vl':
                self.modality_encoder['video'] = model.vision_model
                self.modality_proj['video'] = model.visual_projection
                self.modality_scale['video'] = model.logit_scale
            elif c == 'al':
                self.modality_encoder['audio'] = model.vision_model
                self.modality_proj['audio'] = model.visual_projection
                self.modality_scale['audio'] = model.logit_scale
            elif c == 'dl':
                self.modality_encoder['depth'] = model.vision_model
                self.modality_proj['depth'] = model.visual_projection
                self.modality_scale['depth'] = model.logit_scale
            elif c == 'tl':
                self.modality_encoder['thermal'] = model.vision_model
                self.modality_proj['thermal'] = model.visual_projection
                self.modality_scale['thermal'] = model.logit_scale
            elif c == 'il':
                self.modality_encoder['image'] = model.vision_model
                self.modality_proj['image'] = model.visual_projection
                self.modality_scale['image'] = model.logit_scale
            else:
                raise NameError(f'No clip_type of {c}')
        self.modality_encoder['language'] = model.text_model
        self.modality_proj['language'] = model.text_projection
        self.modality_encoder = nn.ModuleDict(self.modality_encoder)
        self.modality_proj = nn.ModuleDict(self.modality_proj)
    def forward(self, inputs):
        outputs = {}
        for key, value in inputs.items():
            # Encode each modality, project to the shared space, and L2-normalize.
            value = self.modality_encoder[key](**value)[1]
            value = self.modality_proj[key](value)
            value = value / value.norm(p=2, dim=-1, keepdim=True)
            if not self.no_temp:
                # Non-language embeddings are scaled by the learned temperature.
                if key != 'language':
                    value = value * self.modality_scale[key].exp()
            outputs[key] = value
        return outputs

def stack_dict(x, device):
    # Stack a list of per-sample dicts of tensors into batched tensors on `device`.
    if len(x) == 0:
        return None
    out_dict = {}
    keys = list(x[0].keys())
    for key in keys:
        out_dict[key] = torch.stack([i[key] for i in x]).to(device)
    return out_dict
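
# A rough usage sketch (not part of the original app): it shows how LanguageBind and
# stack_dict above can be wired together for a single image-text pair. The example
# image path and the exact signatures of the load_and_transform_* / get_*_transform
# helpers are assumptions; with the checkpoint loading above commented out and
# args.pretrained = False, the towers carry randomly initialized weights.
if __name__ == '__main__':
    args = argparse.Namespace()
    # no_temp=True skips the per-modality temperature scaling, so the plain-dict
    # logit scales (which are not registered modules) are not needed here.
    model = LanguageBind(args, no_temp=True)   # __init__ fills in the remaining args fields
    model = model.to(torch.device(args.device)).eval()
    tokenizer = get_tokenizer(HF_HUB_PREFIX + args.model)

    # Each helper is assumed to return a dict of tensors (e.g. {'pixel_values': ...}).
    image = load_and_transform_image('assets/example.jpg', get_image_transform(args))  # assumed path
    language = load_and_transform_text('a photo of a dog', tokenizer)

    inputs = {
        'image': stack_dict([image], args.device),
        'language': stack_dict([language], args.device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    # Dot product of the L2-normalized embeddings, i.e. cosine similarity.
    print(embeddings['image'] @ embeddings['language'].T)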