# prismer/prismer_model.py
from __future__ import annotations

import concurrent.futures
import hashlib
import os
import pathlib
import shlex
import subprocess
import sys

try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml

import cv2
import torch

from label_prettify import label_prettify

repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir / 'prismer'
sys.path.insert(0, submodule_dir.as_posix())

from dataset import create_dataset, create_loader
from dataset.utils import pre_question
from model.prismer_caption import PrismerCaption
from model.prismer_vqa import PrismerVQA
from model.modules.utils import interpolate_pos_embed


def download_models() -> None:
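    """Fetch the expert weights and every demo checkpoint not already present locally."""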
if not pathlib.Path('prismer/experts/expert_weights/').exists():
subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer')
model_names = [
'vqa_prismer_base',
'vqa_prismer_large',
'pretrain_prismer_base',
'pretrain_prismer_large',
]
for model_name in model_names:
if pathlib.Path(f'prismer/logging/{model_name}').exists():
continue
subprocess.run(shlex.split(f'python download_checkpoints.py --download_models={model_name}'), cwd='prismer')


def build_deformable_conv() -> None:
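    """Compile the deformable-attention CUDA ops used by the Mask2Former segmentation expert."""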
subprocess.run(shlex.split('sh make.sh'), cwd='prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops')


def run_expert(expert_name: str) -> None:
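    """Run one expert's generate_*.py script in a subprocess, with the submodule on PYTHONPATH."""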
env = os.environ.copy()
if 'PYTHONPATH' in env:
env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
else:
env['PYTHONPATH'] = submodule_dir.as_posix()
subprocess.run(shlex.split(f'python experts/generate_{expert_name}.py'),
cwd='prismer',
env=env,
check=True)


def compute_md5(image_path: str) -> str:
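    """Return the MD5 hex digest of the image file; used as its cache key."""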
with open(image_path, 'rb') as f:
s = f.read()
return hashlib.md5(s).hexdigest()


def run_experts(image_path: str) -> tuple[str, tuple[str, ...]]:
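    """Run all six experts on the image (skipped when cached by MD5) and
    return the cache key plus the paths of the prettified label visualisations."""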
im_name = compute_md5(image_path)
image_dir = submodule_dir / 'helpers' / 'images'
out_path = image_dir / f'{im_name}.jpg'
image_dir.mkdir(parents=True, exist_ok=True)
keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
results = [pathlib.Path('prismer/helpers/labels') / key / f'helpers/images/{im_name}.png' for key in keys]
results_pretty = [pathlib.Path('prismer/helpers/labels') / key / f'helpers/images/{im_name}_p.png' for key in keys]
out_paths = tuple(path.as_posix() for path in results)
pretty_paths = tuple(path.as_posix() for path in results_pretty)
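    # the expert scripts read their target image from the shared YAML config,
    # so write the cache key into it before launching them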
    with open('prismer/configs/experts.yaml') as yaml_file:
        config = yaml.load(yaml_file, Loader=yaml.Loader)
    config['im_name'] = im_name
    with open('prismer/configs/experts.yaml', 'w') as yaml_file:
        yaml.dump(config, yaml_file, default_flow_style=False)
if not os.path.exists(out_paths[0]):
cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
        # parallelised inference: run the depth expert on its own first,
        # then the remaining experts concurrently in separate processes
        expert_names = ['edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
        run_expert('depth')
        with concurrent.futures.ProcessPoolExecutor() as executor:
            # consume the map() iterator so any failure inside run_expert is
            # re-raised here; the context manager already waits for shutdown
            list(executor.map(run_expert, expert_names))
# no parallelization just to be safe
# expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
# for exp in expert_names:
# run_expert(exp)
label_prettify(image_path, out_paths)
return im_name, pretty_paths


class Model:
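    """Lazily loads a Prismer captioning or VQA checkpoint and runs inference
    on the expert labels produced by run_experts()."""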
def __init__(self):
self.config = None
self.model = None
self.tokenizer = None
self.model_name = ''
self.exp_name = ''
self.mode = ''

    def set_model(self, exp_name: str, mode: str) -> None:
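        """Load the requested checkpoint, skipping the reload if it is already active."""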
if exp_name == self.exp_name and mode == self.mode:
return
# load checkpoints
model_name = exp_name.lower().replace('-', '_')
        # both tasks share the demo config; only the model class and the
        # checkpoint prefix differ
        config = {
            'dataset': 'demo',
            'data_path': 'prismer/helpers',
            'label_path': 'prismer/helpers/labels',
            'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
            'image_resolution': 480,
            'prismer_model': model_name,
            'freeze': 'freeze_vision',
            'prefix': '',
        }
        if mode == 'caption':
            model = PrismerCaption(config)
            ckpt_prefix = 'pretrain'
        elif mode == 'vqa':
            model = PrismerVQA(config)
            ckpt_prefix = 'vqa'
        else:
            raise ValueError(f'unsupported mode: {mode}')
        state_dict = torch.load(f'prismer/logging/{ckpt_prefix}_{model_name}/pytorch_model.bin',
                                map_location='cuda:0')
        state_dict['expert_encoder.positional_embedding'] = interpolate_pos_embed(
            state_dict['expert_encoder.positional_embedding'],
            len(model.expert_encoder.positional_embedding))
model.load_state_dict(state_dict)
model = model.half()
model.eval()
self.config = config
self.model = model.to('cuda:0')
self.tokenizer = model.tokenizer
self.exp_name = exp_name
self.mode = mode

    @torch.inference_mode()
def run_caption_model(self, exp_name: str, im_name: str) -> str:
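        """Generate a caption for the pre-processed image `im_name`."""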
self.set_model(exp_name, 'caption')
self.config['im_name'] = im_name
_, test_dataset = create_dataset('caption', self.config)
test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
experts, _ = next(iter(test_loader))
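        # move every expert label onto the GPU; obj_detection holds a nested dict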
for exp in experts:
if exp == 'obj_detection':
experts[exp]['label'] = experts['obj_detection']['label'].to('cuda:0')
experts[exp]['instance'] = experts['obj_detection']['instance'].to('cuda:0')
else:
experts[exp] = experts[exp].to('cuda:0')
captions = self.model(experts, train=False, prefix=self.config['prefix'])
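        # re-tokenise to pad the caption to a fixed 30-token length, then decode it back to plain text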
captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
caption = captions.to(experts['rgb'].device)[0]
caption = self.tokenizer.decode(caption, skip_special_tokens=True)
caption = caption.capitalize() + '.'
return caption

    def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
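        """Run the experts, caption the image, and return the caption plus the label maps."""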
im_name, pretty_paths = run_experts(image_path)
caption = self.run_caption_model(model_name, im_name)
return caption, *pretty_paths

    @torch.inference_mode()
def run_vqa_model(self, exp_name: str, im_name: str, question: str) -> str:
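        """Answer `question` about the pre-processed image `im_name`."""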
self.set_model(exp_name, 'vqa')
self.config['im_name'] = im_name
_, test_dataset = create_dataset('caption', self.config)
test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
experts, _ = next(iter(test_loader))
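        # move every expert label onto the GPU; obj_detection holds a nested dict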
for exp in experts:
if exp == 'obj_detection':
experts[exp]['label'] = experts['obj_detection']['label'].to('cuda:0')
experts[exp]['instance'] = experts['obj_detection']['instance'].to('cuda:0')
else:
experts[exp] = experts[exp].to('cuda:0')
question = pre_question(question)
answer = self.model(experts, [question], train=False, inference='generate')
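        # re-tokenise to pad the answer to a fixed 30-token length, then decode it back to plain text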
answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
answer = answer.to(experts['rgb'].device)[0]
answer = self.tokenizer.decode(answer, skip_special_tokens=True)
answer = answer.capitalize() + '.'
return answer

    def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
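        """Run the experts, answer the question, and return the answer plus the label maps."""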
im_name, pretty_paths = run_experts(image_path)
answer = self.run_vqa_model(model_name, im_name, question)
return answer, *pretty_paths
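

# Minimal usage sketch (assumptions: a CUDA device is available, checkpoints were
# fetched via download_models(), the deformable-attention ops were built via
# build_deformable_conv(), and the image path and 'Prismer-Base' name are
# illustrative values, not fixed API strings):
#
#   model = Model()
#   caption, *label_maps = model.run_caption('example.jpg', 'Prismer-Base')
#   answer, *label_maps = model.run_vqa('example.jpg', 'Prismer-Base',
#                                       'What is in the picture?')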