shikunl committed
Commit
61a0078
1 Parent(s): 818a4f8
Files changed (6)
  1. .gitignore +1 -0
  2. .pre-commit-config.yaml +0 -1
  3. app.py +1 -7
  4. patch +0 -82
  5. prismer_model.py +12 -36
  6. style.css +0 -3
.gitignore CHANGED
@@ -1,4 +1,5 @@
 cache/
+.idea
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
.pre-commit-config.yaml CHANGED
@@ -1,4 +1,3 @@
-exclude: patch
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v4.2.0
app.py CHANGED
@@ -5,14 +5,8 @@ from __future__ import annotations
 import os
 import shutil
 import subprocess
-
 import gradio as gr
 
-if os.getenv('SYSTEM') == 'spaces':
-    with open('patch') as f:
-        subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
-    shutil.copytree('prismer/helpers/images', 'prismer/images', dirs_exist_ok=True)
-
 from app_caption import create_demo as create_demo_caption
 from prismer_model import build_deformable_conv, download_models
 
@@ -32,7 +26,7 @@ if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
     description += f'For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'
 
 
-with gr.Blocks(css='style.css') as demo:
+with gr.Blocks() as demo:
     gr.Markdown(description)
     with gr.Tabs():
         with gr.TabItem('Zero-shot Image Captioning'):
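
Note: the block removed above was a Spaces-only bootstrap step that patched the vendored prismer checkout at launch. For reference, a minimal sketch of that env-gated setup pattern, using only the names and paths from the removed code:

import os
import shutil
import subprocess

# Sketch of the removed bootstrap: run one-time setup only on Hugging Face
# Spaces, where the SYSTEM env var is set to 'spaces'; local runs skip it.
if os.getenv('SYSTEM') == 'spaces':
    with open('patch') as f:
        # Apply the repo-local `patch` file to the vendored prismer checkout.
        subprocess.run(['patch', '-p1'], cwd='prismer', stdin=f)
    # Copy the demo images to where the patched dataset code looked for them.
    shutil.copytree('prismer/helpers/images', 'prismer/images', dirs_exist_ok=True)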
patch DELETED
@@ -1,82 +0,0 @@
-diff --git a/dataset/caption_dataset.py b/dataset/caption_dataset.py
-index 266fdda..0cc5d3f 100644
---- a/dataset/caption_dataset.py
-+++ b/dataset/caption_dataset.py
-@@ -50,7 +50,7 @@ class Caption(Dataset):
-         elif self.dataset == 'demo':
-             img_path_split = self.data_list[index]['image'].split('/')
-             img_name = img_path_split[-2] + '/' + img_path_split[-1]
--            image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
-+            image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
- 
-         experts = self.transform(image, labels)
-         experts = post_label_process(experts, labels_info)
-diff --git a/dataset/utils.py b/dataset/utils.py
-index b368aac..418358c 100644
---- a/dataset/utils.py
-+++ b/dataset/utils.py
-@@ -5,6 +5,7 @@
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
- 
- import os
-+import pathlib
- import re
- import json
- import torch
-@@ -14,10 +15,12 @@ import torchvision.transforms as transforms
- import torchvision.transforms.functional as transforms_f
- from dataset.randaugment import RandAugment
- 
--COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
--ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
--DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
--BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')
-+cur_dir = pathlib.Path(__file__).parent
-+
-+COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
-+ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
-+DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
-+BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
- 
- 
- class Transform:
-diff --git a/model/prismer.py b/model/prismer.py
-index 080253a..02362a4 100644
---- a/model/prismer.py
-+++ b/model/prismer.py
-@@ -5,6 +5,7 @@
- # https://github.com/NVlabs/prismer/blob/main/LICENSE
- 
- import json
-+import pathlib
- import torch.nn as nn
- 
- from model.modules.vit import load_encoder
-@@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder
- from transformers import RobertaTokenizer, RobertaConfig
- 
- 
-+cur_dir = pathlib.Path(__file__).parent
-+
-+
- class Prismer(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-@@ -26,7 +30,7 @@ class Prismer(nn.Module):
-             elif exp in ['obj_detection', 'ocr_detection']:
-                 self.experts[exp] = 64
- 
--        prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
-+        prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
-         roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
- 
-         self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
-@@ -35,7 +39,7 @@
- 
-         self.prepare_to_train(config['freeze'])
-         self.ignored_modules = self.get_ignored_modules(config['freeze'])
--
-+
-     def prepare_to_train(self, mode='none'):
-         for name, params in self.named_parameters():
-             if mode == 'freeze_lang':
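The deleted patch existed only to make the upstream prismer code importable from outside its own directory: every hunk either prefixes a relative path or re-anchors it with pathlib.Path(__file__).parent. A minimal sketch of that pattern in isolation (the file names are illustrative, borrowed from the patch's hunks):

import pathlib

# Re-anchor resource paths to the module's own location rather than the
# process working directory, so 'dataset/...'-style lookups stop breaking
# when the app is launched from the repo root.
cur_dir = pathlib.Path(__file__).parent

features_file = cur_dir / 'coco_features.pt'                # dataset/utils.py hunk
config_file = cur_dir.parent / 'configs' / 'prismer.json'   # model/prismer.py hunk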
prismer_model.py CHANGED
@@ -79,28 +79,14 @@ class Model:
         if exp_name == self.exp_name:
             return
         config = {
-            'dataset':
-            'demo',
-            'data_path':
-            'prismer/helpers',
-            'label_path':
-            'prismer/helpers/labels',
-            'experts': [
-                'depth',
-                'normal',
-                'seg_coco',
-                'edge',
-                'obj_detection',
-                'ocr_detection',
-            ],
-            'image_resolution':
-            480,
-            'prismer_model':
-            'prismer_base',
-            'freeze':
-            'freeze_vision',
-            'prefix':
-            'A picture of',
+            'dataset': 'demo',
+            'data_path': 'prismer/helpers',
+            'label_path': 'prismer/helpers/labels',
+            'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
+            'image_resolution': 480,
+            'prismer_model': 'prismer_base',
+            'freeze': 'freeze_vision',
+            'prefix': 'A picture of',
         }
         model = PrismerCaption(config)
         state_dict = torch.load(
@@ -118,27 +104,17 @@ class Model:
     @torch.inference_mode()
     def run_caption_model(self, exp_name: str) -> str:
         self.set_model(exp_name)
-
         _, test_dataset = create_dataset('caption', self.config)
-        test_loader = create_loader(test_dataset,
-                                    batch_size=1,
-                                    num_workers=4,
-                                    train=False)
+        test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
         experts, _ = next(iter(test_loader))
-        captions = self.model(experts,
-                              train=False,
-                              prefix=self.config['prefix'])
-        captions = self.tokenizer(captions,
-                                  max_length=30,
-                                  padding='max_length',
-                                  return_tensors='pt').input_ids
+        captions = self.model(experts, train=False, prefix=self.config['prefix'])
+        captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
         caption = captions.to(experts['rgb'].device)[0]
         caption = self.tokenizer.decode(caption, skip_special_tokens=True)
         caption = caption.capitalize() + '.'
         return caption
 
-    def run_caption(self, image_path: str,
-                    model_name: str) -> tuple[str | None, ...]:
+    def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
         out_paths = run_experts(image_path)
         caption = self.run_caption_model(model_name)
         return caption, *out_paths
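
For context on the reflowed run_caption_model body: the generated caption is re-tokenized to a fixed length of 30 and decoded back to text before display. A standalone sketch of that round trip, assuming the public 'roberta-base' checkpoint stands in for the tokenizer name that Prismer reads from its config:

from transformers import RobertaTokenizer

# Pad the caption to a fixed 30-token tensor, then decode it back to a
# display string, mirroring run_caption_model above.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
ids = tokenizer(['a dog riding a skateboard'], max_length=30,
                padding='max_length', return_tensors='pt').input_ids
caption = tokenizer.decode(ids[0], skip_special_tokens=True)
print(caption.capitalize() + '.')  # -> A dog riding a skateboard.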
style.css DELETED
@@ -1,3 +0,0 @@
-h1 {
-    text-align: center;
-}
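
With style.css gone, gr.Blocks() falls back to Gradio's default styling, matching the css='style.css' removal in app.py above. If the centered heading were still wanted, the same rule could be passed inline; a sketch assuming Gradio's css parameter accepts a raw CSS string (the '# Prismer' title is illustrative):

import gradio as gr

# Inline equivalent of the deleted stylesheet's single rule.
with gr.Blocks(css='h1 { text-align: center; }') as demo:
    gr.Markdown('# Prismer')  # renders as an h1, centered by the rule above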