File size: 3,223 Bytes
b734d92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
diff --git a/dataset/caption_dataset.py b/dataset/caption_dataset.py
index 266fdda..0cc5d3f 100644
--- a/dataset/caption_dataset.py
+++ b/dataset/caption_dataset.py
@@ -50,7 +50,7 @@ class Caption(Dataset):
         elif self.dataset == 'demo':
             img_path_split = self.data_list[index]['image'].split('/')
             img_name = img_path_split[-2] + '/' + img_path_split[-1]
-            image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
+            image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
 
         experts = self.transform(image, labels)
         experts = post_label_process(experts, labels_info)
diff --git a/dataset/utils.py b/dataset/utils.py
index b368aac..418358c 100644
--- a/dataset/utils.py
+++ b/dataset/utils.py
@@ -5,6 +5,7 @@
 # https://github.com/NVlabs/prismer/blob/main/LICENSE
 
 import os
+import pathlib
 import re
 import json
 import torch
@@ -14,10 +15,12 @@ import torchvision.transforms as transforms
 import torchvision.transforms.functional as transforms_f
 from dataset.randaugment import RandAugment
 
-COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
-ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
-DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
-BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
+cur_dir = pathlib.Path(__file__).parent
+
+COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
+ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
+DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
+BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
 
 
 class Transform:
diff --git a/model/prismer.py b/model/prismer.py
index 080253a..02362a4 100644
--- a/model/prismer.py
+++ b/model/prismer.py
@@ -5,6 +5,7 @@
 # https://github.com/NVlabs/prismer/blob/main/LICENSE
 
 import json
+import pathlib
 import torch.nn as nn
 
 from model.modules.vit import load_encoder
@@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder
 from transformers import RobertaTokenizer, RobertaConfig
 
 
+cur_dir = pathlib.Path(__file__).parent
+
+
 class Prismer(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -26,7 +30,7 @@ class Prismer(nn.Module):
             elif exp in ['obj_detection', 'ocr_detection']:
                 self.experts[exp] = 64
 
-        prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
+        prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
         roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
 
         self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
@@ -35,7 +39,7 @@ class Prismer(nn.Module):
 
         self.prepare_to_train(config['freeze'])
         self.ignored_modules = self.get_ignored_modules(config['freeze'])
-    
+
     def prepare_to_train(self, mode='none'):
         for name, params in self.named_parameters():
             if mode == 'freeze_lang':