Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import concurrent.futures | |
| import os | |
| import pathlib | |
| import shlex | |
| import shutil | |
| import subprocess | |
| import sys | |
| import hashlib | |
| from typing import Tuple | |
| try: | |
| import ruamel_yaml as yaml | |
| except ModuleNotFoundError: | |
| import ruamel.yaml as yaml | |
| import cv2 | |
| import torch | |
| from label_prettify import label_prettify | |
| repo_dir = pathlib.Path(__file__).parent | |
| submodule_dir = repo_dir / 'prismer' | |
| sys.path.insert(0, submodule_dir.as_posix()) | |
| from dataset import create_dataset, create_loader | |
| from dataset.utils import pre_question | |
| from model.prismer_caption import PrismerCaption | |
| from model.prismer_vqa import PrismerVQA | |
| from model.modules.utils import interpolate_pos_embed | |
def download_models() -> None:
    """Download the expert weights and Prismer checkpoints missing from disk.

    Every download is delegated to ``download_checkpoints.py``, started with
    ``prismer`` as its working directory. A download is skipped when the
    directory it is expected to populate already exists.
    """

    def run_in_submodule(command: str) -> None:
        # Exit status is not checked, so a failed download does not raise.
        subprocess.run(shlex.split(command), cwd='prismer')

    expert_weights_dir = pathlib.Path('prismer/experts/expert_weights/')
    if not expert_weights_dir.exists():
        run_in_submodule('python download_checkpoints.py --download_experts=True')

    checkpoints = (
        'vqa_prismer_base',
        'vqa_prismer_large',
        'pretrain_prismer_base',
        'pretrain_prismer_large',
    )
    for checkpoint in checkpoints:
        already_present = pathlib.Path(f'prismer/logging/{checkpoint}').exists()
        if not already_present:
            run_in_submodule(f'python download_checkpoints.py --download_models={checkpoint}')
def build_deformable_conv() -> None:
    """Run ``sh make.sh`` inside the Mask2Former pixel-decoder ``ops`` directory.

    The exit status of the build script is not checked.
    """
    ops_dir = 'prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops'
    build_command = shlex.split('sh make.sh')
    subprocess.run(build_command, cwd=ops_dir)
def run_expert(expert_name: str):
    """Run ``experts/generate_<expert_name>.py`` in a child process.

    The script is started from the ``prismer`` directory with the submodule
    directory placed at the front of ``PYTHONPATH``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    submodule_path = submodule_dir.as_posix()
    child_env = os.environ.copy()

    # Prepend to an inherited PYTHONPATH rather than replacing it.
    inherited = child_env.get('PYTHONPATH')
    child_env['PYTHONPATH'] = submodule_path if inherited is None else f'{submodule_path}:{inherited}'

    command = shlex.split(f'python experts/generate_{expert_name}.py')
    subprocess.run(command, cwd='prismer', env=child_env, check=True)
def compute_md5(image_path: str) -> str:
    """Return the hexadecimal MD5 digest of the raw bytes of ``image_path``."""
    content = pathlib.Path(image_path).read_bytes()
    digest = hashlib.md5()
    digest.update(content)
    return digest.hexdigest()
def run_experts(image_path: str) -> Tuple[str, Tuple[str, ...]]:
    """Run the expert scripts on an image and return its name and label paths.

    The image is named after the MD5 digest of its bytes. That name is
    written to ``prismer/configs/experts.yaml`` under ``im_name``. When the
    first label file for that name does not exist yet, the image is copied
    into the submodule's ``helpers/images`` directory as a JPEG, all expert
    scripts are run, and ``label_prettify`` is called on the resulting label
    paths.

    Args:
        image_path: Path of the input image.

    Returns:
        A pair ``(im_name, pretty_paths)``: the MD5-based image name and the
        ``*_p.png`` path for each expert key, in the order depth, edge,
        normal, seg_coco, obj_detection, ocr_detection.

    Raises:
        subprocess.CalledProcessError: If any expert script exits with a
            non-zero status.
    """
    im_name = compute_md5(image_path)
    image_dir = submodule_dir / 'helpers' / 'images'
    out_path = image_dir / f'{im_name}.jpg'
    image_dir.mkdir(parents=True, exist_ok=True)

    keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
    label_root = pathlib.Path('prismer/helpers/labels')
    out_paths = tuple((label_root / key / f'helpers/images/{im_name}.png').as_posix() for key in keys)
    # NOTE(review): the `_p.png` files are presumably written by
    # label_prettify -- confirm against that module.
    pretty_paths = tuple((label_root / key / f'helpers/images/{im_name}_p.png').as_posix() for key in keys)

    # Record the current image name in the experts config; the file is opened
    # in context managers so the handles are always closed.
    config_path = 'prismer/configs/experts.yaml'
    with open(config_path, 'r') as yaml_file:
        config = yaml.load(yaml_file, Loader=yaml.Loader)
    config['im_name'] = im_name
    with open(config_path, 'w') as yaml_file:
        yaml.dump(config, yaml_file, default_flow_style=False)

    # NOTE(review): indentation was lost in the copy this was reviewed from;
    # the whole expert run is assumed to sit under this existence check so
    # that an already-processed image is not recomputed -- confirm.
    if not os.path.exists(out_paths[0]):
        cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))

        # Depth runs on its own first; the remaining experts run in parallel.
        run_expert('depth')
        expert_names = ['edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
        with concurrent.futures.ProcessPoolExecutor() as executor:
            # Draining the iterator re-raises any worker exception here;
            # an unconsumed map() result would silently drop failures.
            list(executor.map(run_expert, expert_names))

        label_prettify(image_path, out_paths)

    return im_name, pretty_paths
class Model:
    """Holds one Prismer model (captioning or VQA) and runs it on single images.

    The model is loaded on demand by ``set_model`` and replaced whenever a
    different experiment name or mode is requested.
    """

    # Device used for the checkpoint weights, the model and the expert inputs.
    _DEVICE = 'cuda:0'

    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None
        self.model_name = ''
        self.exp_name = ''
        self.mode = ''

    def set_model(self, exp_name: str, mode: str) -> None:
        """Load the checkpoint for ``exp_name`` in the given mode, unless already loaded.

        Args:
            exp_name: Experiment name; it is lower-cased and ``-`` is replaced
                by ``_`` to form the checkpoint directory name.
            mode: Either ``'caption'`` or ``'vqa'``.

        Raises:
            ValueError: If ``mode`` is neither ``'caption'`` nor ``'vqa'``.
        """
        if exp_name == self.exp_name and mode == self.mode:
            return

        # Validate before any expensive work so an unknown mode fails clearly.
        if mode == 'caption':
            checkpoint_prefix = 'pretrain'
        elif mode == 'vqa':
            checkpoint_prefix = 'vqa'
        else:
            raise ValueError(f"unsupported mode {mode!r}; expected 'caption' or 'vqa'")

        model_name = exp_name.lower().replace('-', '_')
        config = {
            'dataset': 'demo',
            'data_path': 'prismer/helpers',
            'label_path': 'prismer/helpers/labels',
            'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
            'image_resolution': 480,
            'prismer_model': model_name,
            'freeze': 'freeze_vision',
            'prefix': '',
        }
        model = PrismerCaption(config) if mode == 'caption' else PrismerVQA(config)

        state_dict = torch.load(f'prismer/logging/{checkpoint_prefix}_{model_name}/pytorch_model.bin',
                                map_location=self._DEVICE)
        # Resize the checkpoint's positional embedding to the length the
        # freshly built model expects before loading the weights.
        pos_key = 'expert_encoder.positional_embedding'
        state_dict[pos_key] = interpolate_pos_embed(state_dict[pos_key],
                                                    len(model.expert_encoder.positional_embedding))
        model.load_state_dict(state_dict)
        model = model.half()
        model.eval()

        self.config = config
        self.model = model.to(self._DEVICE)
        self.tokenizer = model.tokenizer
        self.model_name = model_name
        self.exp_name = exp_name
        self.mode = mode

    def _load_experts(self, im_name: str):
        """Load the first batch for ``im_name`` and move its tensors to the device."""
        self.config['im_name'] = im_name
        _, test_dataset = create_dataset('caption', self.config)
        test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
        experts, _ = next(iter(test_loader))
        for exp in experts:
            if exp == 'obj_detection':
                # This entry is a dict of two tensors rather than one tensor.
                experts[exp]['label'] = experts['obj_detection']['label'].to(self._DEVICE)
                experts[exp]['instance'] = experts['obj_detection']['instance'].to(self._DEVICE)
            else:
                experts[exp] = experts[exp].to(self._DEVICE)
        return experts

    def _to_sentence(self, texts, device) -> str:
        """Re-tokenize the model output, decode the first item, and format it as a sentence."""
        token_ids = self.tokenizer(texts, max_length=30, padding='max_length', return_tensors='pt').input_ids
        first = token_ids.to(device)[0]
        text = self.tokenizer.decode(first, skip_special_tokens=True)
        return text.capitalize() + '.'

    def run_caption_model(self, exp_name: str, im_name: str) -> str:
        """Return a capitalized, period-terminated caption for the image named ``im_name``."""
        self.set_model(exp_name, 'caption')
        experts = self._load_experts(im_name)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            captions = self.model(experts, train=False, prefix=self.config['prefix'])
        return self._to_sentence(captions, experts['rgb'].device)

    def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
        """Run the experts on ``image_path``, then caption it.

        Returns:
            The caption followed by the prettified label paths.
        """
        im_name, pretty_paths = run_experts(image_path)
        caption = self.run_caption_model(model_name, im_name)
        return caption, *pretty_paths

    def run_vqa_model(self, exp_name: str, im_name: str, question: str) -> str:
        """Return a capitalized, period-terminated answer to ``question`` about the image."""
        self.set_model(exp_name, 'vqa')
        experts = self._load_experts(im_name)
        question = pre_question(question)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            answer = self.model(experts, [question], train=False, inference='generate')
        return self._to_sentence(answer, experts['rgb'].device)

    def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
        """Run the experts on ``image_path``, then answer ``question`` about it.

        Returns:
            The answer followed by the prettified label paths.
        """
        im_name, pretty_paths = run_experts(image_path)
        answer = self.run_vqa_model(model_name, im_name, question)
        return answer, *pretty_paths