diff --git a/dataset/caption_dataset.py b/dataset/caption_dataset.py index 266fdda..0cc5d3f 100644 --- a/dataset/caption_dataset.py +++ b/dataset/caption_dataset.py @@ -50,7 +50,7 @@ class Caption(Dataset): elif self.dataset == 'demo': img_path_split = self.data_list[index]['image'].split('/') img_name = img_path_split[-2] + '/' + img_path_split[-1] - image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts) + image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts) experts = self.transform(image, labels) experts = post_label_process(experts, labels_info) diff --git a/dataset/utils.py b/dataset/utils.py index b368aac..418358c 100644 --- a/dataset/utils.py +++ b/dataset/utils.py @@ -5,6 +5,7 @@ # https://github.com/NVlabs/prismer/blob/main/LICENSE import os +import pathlib import re import json import torch @@ -14,10 +15,12 @@ import torchvision.transforms as transforms import torchvision.transforms.functional as transforms_f from dataset.randaugment import RandAugment -COCO_FEATURES = torch.load('dataset/coco_features.pt')['features'] -ADE_FEATURES = torch.load('dataset/ade_features.pt')['features'] -DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features'] -BACKGROUND_FEATURES = torch.load('dataset/background_features.pt') +cur_dir = pathlib.Path(__file__).parent + +COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features'] +ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features'] +DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features'] +BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt') class Transform: diff --git a/model/prismer.py b/model/prismer.py index 080253a..02362a4 100644 --- a/model/prismer.py +++ b/model/prismer.py @@ -5,6 +5,7 @@ # https://github.com/NVlabs/prismer/blob/main/LICENSE import json +import pathlib import torch.nn as nn from model.modules.vit import load_encoder @@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder from transformers import RobertaTokenizer, RobertaConfig +cur_dir = pathlib.Path(__file__).parent + + class Prismer(nn.Module): def __init__(self, config): super().__init__() @@ -26,7 +30,7 @@ class Prismer(nn.Module): elif exp in ['obj_detection', 'ocr_detection']: self.experts[exp] = 64 - prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']] + prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']] roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model']) self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name']) @@ -35,7 +39,7 @@ class Prismer(nn.Module): self.prepare_to_train(config['freeze']) self.ignored_modules = self.get_ignored_modules(config['freeze']) - + def prepare_to_train(self, mode='none'): for name, params in self.named_parameters(): if mode == 'freeze_lang':