from __future__ import annotations import os import pathlib import shlex import shutil import subprocess import sys import cv2 import torch from label_prettify import label_prettify repo_dir = pathlib.Path(__file__).parent submodule_dir = repo_dir / 'prismer' sys.path.insert(0, submodule_dir.as_posix()) from dataset import create_dataset, create_loader from dataset.utils import pre_question from model.prismer_caption import PrismerCaption from model.prismer_vqa import PrismerVQA def download_models() -> None: if not pathlib.Path('prismer/experts/expert_weights/').exists(): subprocess.run(shlex.split('python download_checkpoints.py --download_experts=True'), cwd='prismer') model_names = [ 'vqa_prismer_base', 'vqa_prismer_large', 'pretrain_prismer_base', 'pretrain_prismer_large', ] for model_name in model_names: if pathlib.Path(f'prismer/logging/{model_name}').exists(): continue subprocess.run(shlex.split(f'python download_checkpoints.py --download_models={model_name}'), cwd='prismer') def build_deformable_conv() -> None: subprocess.run(shlex.split('sh make.sh'), cwd='prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops') def run_experts(image_path: str) -> tuple[str | None, ...]: helper_dir = submodule_dir / 'helpers' shutil.rmtree(helper_dir, ignore_errors=True) image_dir = helper_dir / 'images' image_dir.mkdir(parents=True, exist_ok=True) out_path = image_dir / 'image.jpg' cv2.imwrite(out_path.as_posix(), cv2.imread(image_path)) expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation'] for expert_name in expert_names: env = os.environ.copy() if 'PYTHONPATH' in env: env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}' else: env['PYTHONPATH'] = submodule_dir.as_posix() subprocess.run(shlex.split(f'python experts/generate_{expert_name}.py'), cwd='prismer', env=env, check=True) keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection'] results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys] return tuple(path.as_posix() for path in results) class Model: def __init__(self): self.config = None self.model = None self.tokenizer = None self.exp_name = '' self.mode = '' def set_model(self, exp_name: str) -> None: if exp_name == self.exp_name: return # remap model name if self.exp_name == 'Prismer-Base': self.exp_name = 'prismer_base' elif self.exp_name == 'Prismer-Large': self.exp_name = 'prismer_large' # load checkpoints if self.mode == 'caption': config = { 'dataset': 'demo', 'data_path': 'prismer/helpers', 'label_path': 'prismer/helpers/labels', 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'], 'image_resolution': 480, 'prismer_model': self.exp_name, 'freeze': 'freeze_vision', 'prefix': '', } model = PrismerCaption(config) state_dict = torch.load(f'prismer/logging/pretrain_{self.exp_name}/pytorch_model.bin', map_location='cuda:0') elif self.mode == 'vqa': config = { 'dataset': 'demo', 'data_path': 'prismer/helpers', 'label_path': 'prismer/helpers/labels', 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'], 'image_resolution': 480, 'prismer_model': self.exp_name, 'freeze': 'freeze_vision', } model = PrismerVQA(config) state_dict = torch.load(f'prismer/logging/vqa_{self.exp_name}/pytorch_model.bin', map_location='cuda:0') model.load_state_dict(state_dict) model.eval() self.config = config self.model = model self.tokenizer = model.tokenizer self.exp_name = exp_name @torch.inference_mode() def run_caption_model(self, exp_name: str) -> str: self.set_model(exp_name) _, test_dataset = create_dataset('caption', self.config) test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False) experts, _ = next(iter(test_loader)) captions = self.model(experts, train=False, prefix=self.config['prefix']) captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids caption = captions.to(experts['rgb'].device)[0] caption = self.tokenizer.decode(caption, skip_special_tokens=True) caption = caption.capitalize() + '.' return caption def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]: out_paths = run_experts(image_path) caption = self.run_caption_model(model_name) label_prettify(image_path, out_paths) return caption, *out_paths @torch.inference_mode() def run_vqa_model(self, exp_name: str, question: str) -> str: self.set_model(exp_name) _, test_dataset = create_dataset('caption', self.config) test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False) experts, _ = next(iter(test_loader)) question = pre_question(question) answer = self.model(experts, question, train=False, inference='generate') answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids answer = answer.to(experts['rgb'].device)[0] answer = self.tokenizer.decode(answer, skip_special_tokens=True) answer = answer.capitalize() + '.' return answer def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]: out_paths = run_experts(image_path) answer = self.run_vqa_model(model_name, question) label_prettify(image_path, out_paths) return answer, *out_paths