prismer / patch
shikunl's picture
Reset again!
b734d92
raw history blame
No virus
3.22 kB
diff --git a/dataset/caption_dataset.py b/dataset/caption_dataset.py
index 266fdda..0cc5d3f 100644
--- a/dataset/caption_dataset.py
+++ b/dataset/caption_dataset.py
@@ -50,7 +50,7 @@ class Caption(Dataset):
elif self.dataset == 'demo':
img_path_split = self.data_list[index]['image'].split('/')
img_name = img_path_split[-2] + '/' + img_path_split[-1]
- image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
+ image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
experts = self.transform(image, labels)
experts = post_label_process(experts, labels_info)
diff --git a/dataset/utils.py b/dataset/utils.py
index b368aac..418358c 100644
--- a/dataset/utils.py
+++ b/dataset/utils.py
@@ -5,6 +5,7 @@
# https://github.com/NVlabs/prismer/blob/main/LICENSE
import os
+import pathlib
import re
import json
import torch
@@ -14,10 +15,12 @@ import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
from dataset.randaugment import RandAugment
-COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
-ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
-DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
-BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
+cur_dir = pathlib.Path(__file__).parent
+
+COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
+ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
+DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
+BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
class Transform:
diff --git a/model/prismer.py b/model/prismer.py
index 080253a..02362a4 100644
--- a/model/prismer.py
+++ b/model/prismer.py
@@ -5,6 +5,7 @@
# https://github.com/NVlabs/prismer/blob/main/LICENSE
import json
+import pathlib
import torch.nn as nn
from model.modules.vit import load_encoder
@@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder
from transformers import RobertaTokenizer, RobertaConfig
+cur_dir = pathlib.Path(__file__).parent
+
+
class Prismer(nn.Module):
def __init__(self, config):
super().__init__()
@@ -26,7 +30,7 @@ class Prismer(nn.Module):
elif exp in ['obj_detection', 'ocr_detection']:
self.experts[exp] = 64
- prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
+ prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
@@ -35,7 +39,7 @@ class Prismer(nn.Module):
self.prepare_to_train(config['freeze'])
self.ignored_modules = self.get_ignored_modules(config['freeze'])
-
+
def prepare_to_train(self, mode='none'):
for name, params in self.named_parameters():
if mode == 'freeze_lang':