# Prismer captioning demo: generate expert labels, then caption every test image.
import argparse
import os

import torch
from tqdm import tqdm

try:
    import ruamel_yaml as yaml
except ModuleNotFoundError:
    import ruamel.yaml as yaml

from dataset import create_dataset, create_loader
from model.prismer_caption import PrismerCaption
# Command-line interface (unchanged: --mode, --port, --exp_name).
# --mode/--port are accepted for launcher compatibility but unused here;
# --exp_name selects which trained checkpoint directory to load from.
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='')
parser.add_argument('--port', default='')
parser.add_argument('--exp_name', default='', type=str)
args = parser.parse_args()

# Load the 'demo' section of the captioning config.
# `with` ensures the file handle is closed (the original leaked it).
with open('configs/caption.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)['demo']
# Generate expert labels by running each expert's generation script in turn.
# NOTE(review): all six experts run whenever *any* expert is configured —
# presumably each script consults the config itself; confirm against the
# individual expert scripts.
if len(config['experts']) > 0:
    expert_jobs = [
        ('generate_depth.py', 'Depth'),
        ('generate_edge.py', 'Edge'),
        ('generate_normal.py', 'Surface Normals'),
        ('generate_objdet.py', 'Object Detection Labels'),
        ('generate_ocrdet.py', 'OCR Detection Labels'),
        ('generate_segmentation.py', 'Segmentation Labels'),
    ]
    for script, label in expert_jobs:
        # Commands and progress messages are byte-identical to the original
        # six copy-pasted stanzas.
        os.system(f'python experts/{script}')
        print(f'***** Generated {label} *****')
# Load the caption test split; batch_size=1 so each step handles one image.
_, test_dataset = create_dataset('caption', config)
test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)

# Load the pre-trained model weights from the experiment's checkpoint.
# NOTE(review): map_location='cuda:0' requires a GPU to be present — confirm
# this demo is only ever run on GPU hosts.
model = PrismerCaption(config)
state_dict = torch.load(f'logging/caption_{args.exp_name}/pytorch_model.bin', map_location='cuda:0')
model.load_state_dict(state_dict)
tokenizer = model.tokenizer
# Inference: caption every test image and write each caption next to its image.
model.eval()
with torch.no_grad():
    for step, (experts, data_ids) in enumerate(tqdm(test_loader)):
        captions = model(experts, train=False, prefix=config['prefix'])

        # Re-tokenize to a fixed length, take the first (only) sample in the
        # batch, and decode it back to text.
        captions = tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
        caption = captions.to(experts['rgb'].device)[0]
        caption = tokenizer.decode(caption, skip_special_tokens=True)
        caption = caption.capitalize() + '.'

        # Save the caption alongside the image, swapping only the extension.
        # The original `.replace('jpg', 'txt')` rewrote the first 'jpg'
        # substring anywhere in the path (e.g. 'jpgs/img.jpg' -> 'txts/img.jpg')
        # and left non-jpg extensions untouched, which would overwrite the
        # image itself; splitext is equivalent for '.jpg' paths and safe
        # for the rest.
        image_path = test_loader.dataset.data_list[data_ids[0]]['image']
        save_path = os.path.splitext(image_path)[0] + '.txt'
        with open(save_path, 'w') as f:
            f.write(caption)

print('All Done.')