shikunl committed
Commit 3cce9aa • 1 Parent(s): bd0d673
prismer/dataset/__init__.py CHANGED
@@ -6,18 +6,12 @@
 
 from torch.utils.data import DataLoader
 
-from dataset.pretrain_dataset import Pretrain
 from dataset.vqa_dataset import VQA
 from dataset.caption_dataset import Caption
-from dataset.classification_dataset import Classification
 
 
 def create_dataset(dataset, config):
-    if dataset == 'pretrain':
-        dataset = Pretrain(config)
-        return dataset
-
-    elif dataset == 'vqa':
+    if dataset == 'vqa':
         train_dataset = VQA(config, train=True)
         test_dataset = VQA(config, train=False)
         return train_dataset, test_dataset
@@ -26,11 +20,6 @@ def create_dataset(dataset, config):
         train_dataset = Caption(config, train=True)
         test_dataset = Caption(config, train=False)
         return train_dataset, test_dataset
-
-    elif dataset == 'classification':
-        train_dataset = Classification(config, train=True)
-        test_dataset = Classification(config, train=False)
-        return train_dataset, test_dataset
 
 
 def create_loader(dataset, batch_size, num_workers, train, collate_fn=None):
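With pretraining and classification gone, the factory now only builds the VQA and captioning splits. A minimal usage sketch follows; the config keys and values are assumptions inferred from the dataset classes, not part of this commit, and create_loader is presumed to wrap a split in a torch DataLoader.

from prismer.dataset import create_dataset, create_loader

# Hypothetical config -- keys inferred from the Caption/VQA constructors, values are placeholders.
config = {
    'data_path': 'data', 'label_path': 'data/labels', 'experts': ['depth', 'edge', 'normal'],
    'dataset': 'demo', 'prefix': '', 'image_resolution': 480,
}
train_set, test_set = create_dataset('caption', config)
test_loader = create_loader(test_set, batch_size=8, num_workers=4, train=False)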
prismer/dataset/caption_dataset.py CHANGED
@@ -7,7 +7,7 @@
 import glob
 
 from torch.utils.data import Dataset
-from dataset.utils import *
+from prismer.dataset.utils import *
 from PIL import ImageFile
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
@@ -50,7 +50,7 @@ class Caption(Dataset):
         elif self.dataset == 'demo':
             img_path_split = self.data_list[index]['image'].split('/')
             img_name = img_path_split[-2] + '/' + img_path_split[-1]
-            image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
+            image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
 
             experts = self.transform(image, labels)
             experts = post_label_process(experts, labels_info)
prismer/dataset/classification_dataset.py DELETED
@@ -1,72 +0,0 @@
-# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
-#
-# This work is made available under the Nvidia Source Code License-NC.
-# To view a copy of this license, visit
-# https://github.com/NVlabs/prismer/blob/main/LICENSE
-
-import glob
-from torch.utils.data import Dataset
-from dataset.utils import *
-
-
-class Classification(Dataset):
-    def __init__(self, config, train):
-        self.data_path = config['data_path']
-        self.label_path = config['label_path']
-        self.experts = config['experts']
-        self.dataset = config['dataset']
-        self.shots = config['shots']
-        self.prefix = config['prefix']
-
-        self.train = train
-        self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=True)
-
-        if train:
-            data_folders = glob.glob(f'{self.data_path}/imagenet_train/*/')
-            self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.JPEG')[:self.shots]]
-            self.answer_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_answer.json'))
-            self.class_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_class.json'))
-        else:
-            data_folders = glob.glob(f'{self.data_path}/imagenet/*/')
-            self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.JPEG')]
-            self.answer_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_answer.json'))
-            self.class_list = json.load(open(f'{self.data_path}/imagenet/' + 'imagenet_class.json'))
-
-    def __len__(self):
-        return len(self.data_list)
-
-    def __getitem__(self, index):
-        img_path = self.data_list[index]['image']
-        if self.train:
-            img_path_split = img_path.split('/')
-            img_name = img_path_split[-2] + '/' + img_path_split[-1]
-            class_name = img_path_split[-2]
-            image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, img_name, 'imagenet_train', self.experts)
-        else:
-            img_path_split = img_path.split('/')
-            img_name = img_path_split[-2] + '/' + img_path_split[-1]
-            class_name = img_path_split[-2]
-            image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, img_name, 'imagenet', self.experts)
-
-        experts = self.transform(image, labels)
-        experts = post_label_process(experts, labels_info)
-
-        if self.train:
-            caption = self.prefix + ' ' + self.answer_list[int(self.class_list[class_name])].lower()
-            return experts, caption
-        else:
-            return experts, self.class_list[class_name]
-
-
-
-
-
-# import os
-# import glob
-#
-# data_path = '/Users/shikunliu/Documents/dataset/mscoco/mscoco'
-#
-# data_folders = glob.glob(f'{data_path}/*/')
-# data_list = [data for f in data_folders for data in glob.glob(f + '*.jpg')]
-
-
prismer/dataset/pretrain_dataset.py DELETED
@@ -1,73 +0,0 @@
-# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
-#
-# This work is made available under the Nvidia Source Code License-NC.
-# To view a copy of this license, visit
-# https://github.com/NVlabs/prismer/blob/main/LICENSE
-
-import glob
-
-from torch.utils.data import Dataset
-from dataset.utils import *
-
-
-class Pretrain(Dataset):
-    def __init__(self, config):
-        self.cc12m_data_path = config['cc12m_data_path']
-        self.cc3m_data_path = config['cc3m_data_path']
-        self.coco_data_path = config['coco_data_path']
-        self.vg_data_path = config['vg_data_path']
-        self.label_path = config['label_path']
-        self.experts = config['experts']
-
-        self.data_list = []
-        if 'cc12m' in config['datasets']:
-            data_folders = glob.glob(f'{self.cc12m_data_path}/cc12m/*/')
-            self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
-        if 'cc3m_sgu' in config['datasets']:
-            data_folders = glob.glob(f'{self.cc3m_data_path}/cc3m_sgu/*/')
-            self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
-        if 'coco' in config['datasets']:
-            self.data_list += json.load(open(os.path.join(self.coco_data_path, 'coco_karpathy_train.json'), 'r'))
-        if 'vg' in config['datasets']:
-            self.data_list += json.load(open(os.path.join(self.vg_data_path, 'vg_caption.json'), 'r'))
-
-        self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.5], train=True)
-
-    def __len__(self):
-        return len(self.data_list)
-
-    def __getitem__(self, index):
-        img_path = self.data_list[index]['image']
-
-        if 'cc12m' in img_path:
-            img_path_split = img_path.split('/')
-            img_name = img_path_split[-2] + '/' + img_path_split[-1]
-            image, labels, labels_info = get_expert_labels(self.cc12m_data_path, self.label_path, img_name, 'cc12m', self.experts)
-
-            caption_path = img_path.replace('.jpg', '.txt')
-            with open(caption_path) as f:
-                caption = f.readlines()[0]
-
-        elif 'cc3m_sgu' in img_path:
-            img_path_split = img_path.split('/')
-            img_name = img_path_split[-2] + '/' + img_path_split[-1]
-            image, labels, labels_info = get_expert_labels(self.cc3m_data_path, self.label_path, img_name, 'cc3m_sgu', self.experts)
-
-            caption_path = img_path.replace('.jpg', '.txt')
-            with open(caption_path) as f:
-                caption = f.readlines()[0]
-
-        elif 'train2014' in img_path or 'val2014' in img_path:
-            image, labels, labels_info = get_expert_labels(self.coco_data_path, self.label_path, img_path, 'vqav2', self.experts)
-            caption = self.data_list[index]['caption']
-
-        elif 'visual-genome' in img_path:
-            img_path_split = img_path.split('/')
-            img_name = img_path_split[-2] + '/' + img_path_split[-1]
-            image, labels, labels_info = get_expert_labels(self.vg_data_path, self.label_path, img_name, 'vg', self.experts)
-            caption = self.data_list[index]['caption']
-
-        experts = self.transform(image, labels)
-        experts = post_label_process(experts, labels_info)
-        caption = pre_caption(caption, max_words=30)
-        return experts, caption
prismer/dataset/utils.py CHANGED
@@ -12,12 +12,16 @@ import PIL.Image as Image
 import numpy as np
 import torchvision.transforms as transforms
 import torchvision.transforms.functional as transforms_f
-from dataset.randaugment import RandAugment
+import pathlib
+from prismer.dataset.randaugment import RandAugment
 
-COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
-ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
-DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
-BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
+
+cur_dir = pathlib.Path(__file__).parent
+
+COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
+ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
+DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
+BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
 
 
 class Transform:
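The change above anchors the feature files to the module's own directory instead of the current working directory. The same pattern in isolation (the file name below is a placeholder, not one of the real feature files):

import pathlib

import torch

# Resolve a resource that sits next to this module; works regardless of os.getcwd().
cur_dir = pathlib.Path(__file__).parent
features = torch.load(cur_dir / 'example_features.pt')

Together with the prismer.* import rename, this lets the module be imported from the repository root without first changing into the prismer/ submodule.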
prismer/model/prismer.py CHANGED
@@ -5,12 +5,15 @@
 # https://github.com/NVlabs/prismer/blob/main/LICENSE
 
 import json
+import pathlib
 import torch.nn as nn
 
 from model.modules.vit import load_encoder
 from model.modules.roberta import load_decoder
 from transformers import RobertaTokenizer, RobertaConfig
 
+cur_dir = pathlib.Path(__file__).parent
+
 
 class Prismer(nn.Module):
     def __init__(self, config):
@@ -26,7 +29,7 @@ class Prismer(nn.Module):
            elif exp in ['obj_detection', 'ocr_detection']:
                self.experts[exp] = 64
 
-        prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
+        prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
         roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
 
         self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
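The patched line resolves configs/prismer.json relative to the model package before handing its contents to transformers. A standalone sketch of that flow; the 'prismer_base' key is an assumption, only the transformers calls are taken from the code above:

import json
import pathlib

from transformers import RobertaConfig, RobertaTokenizer

cur_dir = pathlib.Path(__file__).parent                    # assumed to be <repo>/prismer/model
with open(cur_dir.parent / 'configs' / 'prismer.json') as f:
    prismer_config = json.load(f)['prismer_base']          # 'prismer_base' is an assumed key name
roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])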
prismer_model.py CHANGED
@@ -7,12 +7,12 @@ import sys
 import cv2
 import torch
 
+from prismer.dataset import create_dataset, create_loader
+from prismer.model.prismer_caption import PrismerCaption
+
+
 repo_dir = pathlib.Path(__file__).parent
 submodule_dir = repo_dir / 'prismer'
-sys.path.insert(0, submodule_dir.as_posix())
-
-from dataset import create_dataset, create_loader
-from model.prismer_caption import PrismerCaption
 
 
 def download_models() -> None:
@@ -50,11 +50,11 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
     for expert_name in expert_names:
         env = os.environ.copy()
         if 'PYTHONPATH' in env:
-            env['PYTHONPATH'] = f'{submodule_dir.as_posix()}:{env["PYTHONPATH"]}'
+            env['PYTHONPATH'] = f'{repo_dir.as_posix()}:{env["PYTHONPATH"]}'
         else:
-            env['PYTHONPATH'] = submodule_dir.as_posix()
+            env['PYTHONPATH'] = repo_dir.as_posix()
 
-        subprocess.run(shlex.split(f'python experts/generate_{expert_name}.py'), cwd='prismer', env=env, check=True)
+        subprocess.run(shlex.split(f'python prismer/experts/generate_{expert_name}.py'), env=env, check=True)
 
     # keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
     keys = ['depth', 'edge', 'normal']
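The expert scripts are now launched from the repository root with that root prepended to PYTHONPATH, so their prismer.* imports resolve without changing the working directory. A reduced sketch of the launch logic for a single expert ('depth' is one of the keys listed above; the full expert_names list is not shown in this hunk):

import os
import pathlib
import shlex
import subprocess

repo_dir = pathlib.Path(__file__).parent

# Prepend the repo root so `import prismer.*` works inside the child process.
env = os.environ.copy()
env['PYTHONPATH'] = f'{repo_dir.as_posix()}:{env["PYTHONPATH"]}' if 'PYTHONPATH' in env else repo_dir.as_posix()
subprocess.run(shlex.split('python prismer/experts/generate_depth.py'), env=env, check=True)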