patrickramos committed
Commit b991b4f • 1 Parent(s): b66e6ca

Update app.py

Files changed (1)
app.py +53 -179
app.py CHANGED
@@ -1,20 +1,14 @@
-from transformers import CLIPModel, CLIPProcessor
-
-MODEL_ID = 'openai/clip-vit-base-patch32' #@param {'type': 'string'}
-LOAD_IN_8BIT = False #@param {'type': 'boolean'}
-BATCH_SIZE = 1024 #@param {'type': 'integer'}
-REVISION = '' #@param {'type': 'string'}
-REVISION = None if not REVISION else REVISION
-
-from transformers import CLIPConfig
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-
 import os
 from huggingface_hub import login
 
 login(os.environ['hf_token'])
 
+
+from transformers import CLIPConfig, CLIPModel
+from torch import nn
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
 def load_distillclip(model_id, revision=None):
     ckpt_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", revision=revision)
     config = CLIPConfig.from_pretrained(model_id)
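The import block added above pairs `hf_hub_download` with safetensors' `load_file`, which is how `load_distillclip` pulls the distilled checkpoint off the Hub. A minimal sketch of that pattern (assuming the `Ramos-Ramos/distillclip` repo referenced later in this diff is publicly readable):

```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download one file from a Hub repo (cached locally on repeat calls), then
# read its tensors into a plain dict of parameter name -> torch.Tensor.
ckpt_path = hf_hub_download(repo_id='Ramos-Ramos/distillclip', filename='model.safetensors')
state_dict = load_file(ckpt_path)
print(list(state_dict)[:3])  # peek at a few parameter names
```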
@@ -27,34 +21,21 @@ def load_distillclip(model_id, revision=None):
         bias=True,
     )
     model.vision_model.pre_layrnorm = nn.Identity()
-    # model.vision_model.post_layernorm = nn.Identity()
     print(model.load_state_dict({k.removeprefix('student.'): v for k, v in load_file(ckpt_path).items()}))
-    # model.load_state_dict(load_file(ckpt_path))
     return model
+
 
+import torch
 from torch import nn
-from accelerate import init_empty_weights, infer_auto_device_map
-from transformers import CLIPModel, CLIPProcessor
 from einops import reduce
+from tqdm.auto import tqdm
 
 class ZeroShotCLIP(nn.Module):
-    def __init__(self, model_id=None, model=None, processor=None,classes=[], templates=[], load_in_8bit=False):
+    def __init__(self, model=None, processor=None, classes=[], templates=[], load_in_8bit=False):
         super().__init__()
 
-        self.load_in_8bit = load_in_8bit
-        if model is not None and processor is not None:
-            self.model = model.eval()
-            self.processor = processor
-        else:
-            if load_in_8bit:
-                with init_empty_weights():
-                    dummy = CLIPModel.from_pretrained(model_id)
-                device_map = infer_auto_device_map(dummy)
-                del dummy
-                self.model = CLIPModel.from_pretrained(model_id, load_in_8bit=True, device_map=device_map)
-            else:
-                self.model = CLIPModel.from_pretrained(model_id).eval()
-            self.processor = CLIPProcessor.from_pretrained(model_id)
+        self.model = model.eval()
+        self.processor = processor
         self.classes = classes
         self.templates = templates
         self._init_weights()
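The dict comprehension fed to `load_state_dict` above does the real work of the loader: the distillation run saved the model under a `student.` prefix, so every key is renamed before loading, and the return value of `load_state_dict` (its lists of missing and unexpected keys) is printed as a sanity check. A self-contained illustration of the renaming, with hypothetical key names:

```python
# Hypothetical keys, as a distillation training wrapper might save them:
ckpt = {
    'student.logit_scale': 0,
    'student.vision_model.embeddings.class_embedding': 1,
}
# str.removeprefix (Python 3.9+) strips the prefix only where it is present.
renamed = {k.removeprefix('student.'): v for k, v in ckpt.items()}
print(renamed)  # {'logit_scale': 0, 'vision_model.embeddings.class_embedding': 1}
```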
@@ -63,8 +44,6 @@ class ZeroShotCLIP(nn.Module):
     def _init_weights(self):
         self.model.eval()
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if not self.load_in_8bit:
-            self.model = self.model.to(device)
         weights = []
         for classname in tqdm(self.classes):
             prompts = [template.format(classname) for template in self.templates]
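The lines the diff elides here (old 71-75 / new 50-54, unchanged) build `embeddings` for each class; the first context line of the next hunk shows them being appended. `_init_weights` follows the standard CLIP zero-shot recipe: embed every filled-in template, normalize, average per class, and renormalize. Roughly, under the assumption that the elided code matches that recipe (the app's actual version averages with `einops.reduce`):

```python
import torch

@torch.no_grad()
def zero_shot_weights(model, processor, classes, templates, device='cpu'):
    # Build one unit-norm text embedding per class by averaging prompt variants.
    weights = []
    for classname in classes:
        prompts = [template.format(classname) for template in templates]
        inputs = processor(text=prompts, padding=True, return_tensors='pt').to(device)
        emb = model.get_text_features(**inputs)
        emb = emb / emb.norm(dim=-1, keepdim=True)  # normalize each prompt embedding
        emb = emb.mean(dim=0)                       # average over templates
        emb = emb / emb.norm()                      # renormalize the class prototype
        weights.append(emb)
    return torch.stack(weights)                     # shape: (num_classes, embed_dim)
```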
@@ -76,159 +55,54 @@ class ZeroShotCLIP(nn.Module):
             weights.append(embeddings)
         weights = torch.stack(weights)
         self.register_buffer('weights', weights)
-        if not self.load_in_8bit:
-            self.model = self.model.cpu()
 
     @torch.no_grad()
     def forward(self, pixel_values):
         x = self.model.get_image_features(pixel_values=pixel_values)
         x /= x.norm(dim=-1, keepdim=True)
-        return x.mm(self.weights.t())
+        return x.mm(self.weights.t()) * 100.00000762939453
 
     def preprocess_and_forward(self, x):
-        x = self.processor(images=x)
-        return self(x)
-
-    def to(self, *args, **kwargs):
-        if not self.load_in_8bit:
-            return super().to(*args, **kwargs)
-        else:
-            self.weights = self.weights.to(*args, **kwargs)
-            return self
-
-model = load_distillclip('Ramos-Ramos/distillclip-different-moon-37')
-processor = CLIPProcessor.from_pretrained(MODEL_ID)
-
-pipe = pipeline("zero-shot-image-classification", model=model, feature_extractor=processor.image_processor, tokenizer=processor.tokenizer)
-
-cifar_templates = [
-    'a photo of a {}.',
-    'a blurry photo of a {}.',
-    'a black and white photo of a {}.',
-    'a low contrast photo of a {}.',
-    'a high contrast photo of a {}.',
-    'a bad photo of a {}.',
-    'a good photo of a {}.',
-    'a photo of a small {}.',
-    'a photo of a big {}.',
-    'a photo of the {}.',
-    'a blurry photo of the {}.',
-    'a black and white photo of the {}.',
-    'a low contrast photo of the {}.',
-    'a high contrast photo of the {}.',
-    'a bad photo of the {}.',
-    'a good photo of the {}.',
-    'a photo of the small {}.',
-    'a photo of the big {}.',
-]
-
-imagenet_templates = [
-    'a bad photo of a {}.',
-    'a photo of many {}.',
-    'a sculpture of a {}.',
-    'a photo of the hard to see {}.',
-    'a low resolution photo of the {}.',
-    'a rendering of a {}.',
-    'graffiti of a {}.',
-    'a bad photo of the {}.',
-    'a cropped photo of the {}.',
-    'a tattoo of a {}.',
-    'the embroidered {}.',
-    'a photo of a hard to see {}.',
-    'a bright photo of a {}.',
-    'a photo of a clean {}.',
-    'a photo of a dirty {}.',
-    'a dark photo of the {}.',
-    'a drawing of a {}.',
-    'a photo of my {}.',
-    'the plastic {}.',
-    'a photo of the cool {}.',
-    'a close-up photo of a {}.',
-    'a black and white photo of the {}.',
-    'a painting of the {}.',
-    'a painting of a {}.',
-    'a pixelated photo of the {}.',
-    'a sculpture of the {}.',
-    'a bright photo of the {}.',
-    'a cropped photo of a {}.',
-    'a plastic {}.',
-    'a photo of the dirty {}.',
-    'a jpeg corrupted photo of a {}.',
-    'a blurry photo of the {}.',
-    'a photo of the {}.',
-    'a good photo of the {}.',
-    'a rendering of the {}.',
-    'a {} in a video game.',
-    'a photo of one {}.',
-    'a doodle of a {}.',
-    'a close-up photo of the {}.',
-    'a photo of a {}.',
-    'the origami {}.',
-    'the {} in a video game.',
-    'a sketch of a {}.',
-    'a doodle of the {}.',
-    'a origami {}.',
-    'a low resolution photo of a {}.',
-    'the toy {}.',
-    'a rendition of the {}.',
-    'a photo of the clean {}.',
-    'a photo of a large {}.',
-    'a rendition of a {}.',
-    'a photo of a nice {}.',
-    'a photo of a weird {}.',
-    'a blurry photo of a {}.',
-    'a cartoon {}.',
-    'art of a {}.',
-    'a sketch of the {}.',
-    'a embroidered {}.',
-    'a pixelated photo of a {}.',
-    'itap of the {}.',
-    'a jpeg corrupted photo of the {}.',
-    'a good photo of a {}.',
-    'a plushie {}.',
-    'a photo of the nice {}.',
-    'a photo of the small {}.',
-    'a photo of the weird {}.',
-    'the cartoon {}.',
-    'art of the {}.',
-    'a drawing of the {}.',
-    'a photo of the large {}.',
-    'a black and white photo of a {}.',
-    'the plushie {}.',
-    'a dark photo of a {}.',
-    'itap of a {}.',
-    'graffiti of the {}.',
-    'a toy {}.',
-    'itap of my {}.',
-    'a photo of a cool {}.',
-    'a photo of a small {}.',
-    'a tattoo of the {}.',
-]
-
-dashcam_templates = [
-    'a dashcam recording of {}.',
-    'a picture of {}.',
-    'a recording of {}.'
-]
-
-stl10_templates = [
-    'a photo of a {}.',
-    'a photo of the {}.',
-]
-
-oxfordpets_templates = [
-    'a photo of a {}, a type of pet.',
-]
-
-def predict(image, texts):
-    texts = texts.split(', ')
-    out = pipe(image, candidate_labels=texts)
-    return {d['label']: d['score'] for d in out}
+        x = self.processor(images=x, return_tensors='pt')
+        return self(x['pixel_values'])
+
+
+from transformers import CLIPProcessor
+
+model = load_distillclip('Ramos-Ramos/distillclip')
+processor = CLIPProcessor.from_pretrained('Ramos-Ramos/distillclip')
+
+
+def infer(image, classes, templates):
+    classes = [label.strip() for label in classes.split(',')]
+    print(classes)
+    templates = [template.strip() for template in templates.split(';')]
+    print(templates)
+    clip = ZeroShotCLIP(model=model, processor=processor, classes=classes, templates=templates)
+    preds = clip.preprocess_and_forward(image).softmax(dim=1).flatten()
+    return {label: score.item() for label, score in zip(classes, preds)}
+
+
+import gradio as gr
+
+title = 'DistillCLIP'
+description = 'Zero-shot image classification demo with DistillCLIP'
+article = '''DistillCLIP is a distilled version of [CLIP-ViT/B-32](https://huggingface.co/openai/clip-vit-base-patch32).
+
+Please refer to the [DistillCLIP model card](https://huggingface.co/Ramos-Ramos/distillclip) for more details on DistillCLIP.
+
+Note: As multiplying logits by a temperature prior to the softmax can better distinguish final scores, we multiply DistillCLIP's text-image similarity scores by the teacher CLIP's temperature.'''
 
 demo = gr.Interface(
-    fn=predict,
-    inputs=[gr.Image(type='pil'), gr.Textbox(label='comma separated labels'), gr.Dropwdown(['CIFAR', 'ImageNet','STL-10', 'Oxford Pets', 'Dashcam'], label='text templates')],
-    outputs='label',
+    fn=infer,
+    inputs=[
+        gr.Image(label='Image', type='pil'),
+        gr.Textbox(label='Classes', placeholder='cat, truck', info='Classes for classification. Separate classes with commas.'),
+        gr.Textbox(label='Prompt/s', placeholder='a photo of a {}.; a blurry photo of a {}.', info='Prompt templates. Use "{}" as placeholder for class. Separate prompts with semi-colons.')
+    ],
+    outputs=gr.Label(label='Class scores'),
+    title=title,
+    description=description,
+    article=article
 )
-
-demo.launch(debug=True, share=True)
+demo.launch()
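About the constant introduced in `forward`: 100.00000762939453 is the teacher's temperature mentioned in the article text. CLIP stores its temperature as a log-space `logit_scale` parameter whose exponential is approximately 100 for the teacher, so the commit presumably bakes in the number read off the teacher checkpoint along these lines (a sketch, not taken from this repo; assumes the openai/clip-vit-base-patch32 weights):

```python
from transformers import CLIPModel

teacher = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
# logit_scale is stored in log space; exp() recovers the multiplier applied
# to image-text cosine similarities before the softmax.
print(teacher.logit_scale.exp().item())  # ~100.00000762939453
```

Scaling cosine similarities (all in [-1, 1]) by roughly 100 before the softmax is what turns near-uniform scores into a peaked label distribution, which is the behavior the article's note describes.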