"""Helpers that download, build and run the experts of the Prismer submodule."""
import os
import pathlib
import shlex
import shutil
import subprocess
import sys

import cv2
import torch

repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir / 'prismer'
# Make the vendored Prismer package importable for the imports below.
sys.path.insert(0, submodule_dir.as_posix())

from dataset import create_dataset, create_loader
from model.prismer_caption import PrismerCaption


def _run(args: list[str],
         cwd: pathlib.Path,
         env: dict[str, str] | None = None) -> None:
    """Run *args* in *cwd* (optionally with *env*), raising on non-zero exit.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    subprocess.run(args, cwd=cwd, env=env, check=True)


def download_models() -> None:
    """Download the Prismer expert weights and model checkpoints if missing.

    Each download is skipped when its target directory already exists inside
    the ``prismer`` submodule. The download script is run with the current
    interpreter (``sys.executable``) so it sees the same installed packages.

    Raises:
        subprocess.CalledProcessError: if a download script exits non-zero.
    """
    if not (submodule_dir / 'experts' / 'expert_weights').exists():
        _run([sys.executable, 'download_checkpoints.py',
              '--download_experts=True'],
             cwd=submodule_dir)

    model_names = [
        'vqa_prismer_base',
        'vqa_prismer_large',
        'pretrain_prismer_base',
        'pretrain_prismer_large',
    ]
    for model_name in model_names:
        if (submodule_dir / 'logging' / model_name).exists():
            continue
        _run([sys.executable, 'download_checkpoints.py',
              f'--download_models={model_name}'],
             cwd=submodule_dir)


def build_deformable_conv() -> None:
    """Build the Mask2Former pixel-decoder native ops by running ``make.sh``.

    Raises:
        subprocess.CalledProcessError: if ``make.sh`` exits non-zero.
    """
    ops_dir = (submodule_dir / 'experts' / 'segmentation' / 'mask2former' /
               'modeling' / 'pixel_decoder' / 'ops')
    _run(['sh', 'make.sh'], cwd=ops_dir)


def _expert_env() -> dict[str, str]:
    """Return a copy of ``os.environ`` with the submodule prepended to PYTHONPATH."""
    env = os.environ.copy()
    existing = env.get('PYTHONPATH')
    submodule = submodule_dir.as_posix()
    # os.pathsep keeps this portable (':' on POSIX, ';' on Windows).
    env['PYTHONPATH'] = (os.pathsep.join([submodule, existing])
                         if existing else submodule)
    return env


def run_experts(image_path: str) -> tuple[str | None, ...]:
    """Run the Prismer expert scripts on a single image.

    The ``prismer/helpers`` directory is deleted and recreated, the input is
    re-encoded as ``prismer/helpers/images/image.jpg``, and each
    ``experts/generate_<name>.py`` script is then run from the submodule
    directory with the submodule on PYTHONPATH.

    Args:
        image_path: Path to an image file readable by OpenCV.

    Returns:
        A 3-tuple of POSIX path strings to the depth, edge and normal label
        images, in that order.

    Raises:
        ValueError: if OpenCV cannot read *image_path*.
        OSError: if the re-encoded image cannot be written.
        subprocess.CalledProcessError: if an expert script exits non-zero.
    """
    helper_dir = submodule_dir / 'helpers'
    # Clear outputs from any previous run so stale labels are never returned.
    shutil.rmtree(helper_dir, ignore_errors=True)
    image_dir = helper_dir / 'images'
    image_dir.mkdir(parents=True, exist_ok=True)

    # cv2.imread signals failure by returning None rather than raising.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f'Could not read image: {image_path}')
    out_path = image_dir / 'image.jpg'
    if not cv2.imwrite(out_path.as_posix(), image):
        raise OSError(f'Could not write image: {out_path}')

    env = _expert_env()
    expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet',
                    'segmentation']
    for expert_name in expert_names:
        _run([sys.executable, f'experts/generate_{expert_name}.py'],
             cwd=submodule_dir, env=env)

    # Only the depth/edge/normal maps are returned. The segmentation, object
    # and OCR detection experts still run, presumably because their labels
    # are consumed elsewhere -- TODO confirm.
    label_dir = helper_dir / 'labels'
    keys = ['depth', 'edge', 'normal']
    return tuple((label_dir / key / 'helpers/images/image.png').as_posix()
                 for key in keys)