---
license: mit
train: false
inference: false
pipeline_tag: image-classification
---

## CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ
This is a version of the ViT-H-14 vision model, based on timm's ```vit_huge_patch14_clip_224.laion2b```, quantized to 2-bit via Half-Quadratic Quantization (HQQ): https://mobiusml.github.io/hqq_blog/
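
For reference, quantizing a timm backbone with HQQ yourself follows roughly the pattern below. This is a minimal sketch based on the HQQ repository; the wrapper methods and config fields are assumptions that may differ across library versions, and the group sizes simply mirror the ```2bit_g16_s128``` tag in the model name.
``` Python
from hqq.engine.timm import HQQtimm
from hqq.core.quantize import BaseQuantizeConfig

#2-bit weights with group-size 16, plus quantized scales/zero-points,
#mirroring the "2bit_g16_s128" tag in the model name - treat these values as assumptions.
quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_scale=True, quant_zero=True)

model = HQQtimm.create_model('vit_huge_patch14_clip_224.laion2b', pretrained=True)
model.quantize_model(quant_config=quant_config)
```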

This 2-bit model achieves a zero-shot top-1 accuracy of 0.716 on ImageNet, outperforming a full-precision ViT-B-32 (0.664).

### Basic Usage
To run the model, install the HQQ library from https://github.com/mobiusml/hqq and use it as follows:
``` Python
from hqq.engine.timm import HQQtimm
model = HQQtimm.from_quantized("mobiuslabsgmbh/CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ")
```
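
The returned object behaves like a regular timm vision model, so a quick sanity check with a dummy input looks as follows (a minimal sketch; the fp16/CUDA placement mirrors the zero-shot example further down):
``` Python
import torch

#Dummy 224x224 RGB batch, cast to fp16 and moved to the GPU,
#matching how the zero-shot example below feeds the quantized model.
x = torch.randn(1, 3, 224, 224).half().cuda()
with torch.no_grad():
	f = model(x) #image embeddings, shape [1, embedding_dim]
print(f.shape)
```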

### Zero-Shot Classification
For zero-shot classification you also need the text model; here's a complete example:
``` Python
!pip install open_clip_torch
!pip install Pillow

import torch
import numpy as np

import open_clip
orig_model, _ , preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2B-s32B-b79K')
tokenizer  = open_clip.get_tokenizer('ViT-H-14')
model_text = orig_model.encode_text

from hqq.engine.timm import HQQtimm
model_visual = HQQtimm.from_quantized("mobiuslabsgmbh/CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ")

###############################################################
#Add your own templates here; we provide simple ones below.
#See https://github.com/openai/CLIP/blob/main/data/prompts.md for the complete list.
TEMPLATES = (
    lambda c: f'itap of a {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'art of the {c}.',
)

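#Encode a PIL image into a normalized fp16 image embedding with the 2-bit vision model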
@torch.no_grad()
def forward_image(img):
	x = preprocess(img).unsqueeze(0)
	f = model_visual(x.half().cuda())
	f /= torch.norm(f, p=2, dim=-1, keepdim=True)
	return f

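#Encode a list of prompts with the original full-precision text encoder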
@torch.no_grad()
def forward_text(text_batch_list, normalize=True):
	inputs  = tokenizer(text_batch_list)
	f       = model_text(inputs)
	if(normalize):
		f  /= torch.norm(f, p=2, dim=-1, keepdim=True)
	del inputs
	return f.half().to('cuda')

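#Build a class embedding by averaging its prompt-template embeddings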
def forward_text_with_templates(text, templates=TEMPLATES, normalize=True):
	f = forward_text([t(text) for t in templates], normalize=False).mean(axis=0)
	if(normalize):
		f  /= torch.norm(f, p=2, dim=-1, keepdim=True)
	return f

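#Pick the class whose text embedding best matches the image embedding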
def classifier_zero_shot_with_pil(img, classes):
	classifiers  = torch.cat([forward_text_with_templates(c).reshape([1, -1]) for c in classes], axis=0)
	img_features = forward_image(img)
	scores       = torch.matmul(img_features, classifiers.T)[0].detach().cpu().numpy()
	out          = classes[np.argmax(scores)]
	return out
###############################################################
from PIL import Image
import requests
#img_path_or_url = 'https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-looking-at-camera-1593184780.jpg' #Cat
#img_path_or_url = 'https://www.shutterstock.com/image-photo/photo-cute-golden-retriever-running-600nw-2291249193.jpg' #Dog
img_path_or_url = "https://my-sweet-usa.de/cdn/shop/products/1727.jpg" #bag of chips

img     = Image.open(requests.get(img_path_or_url, stream=True).raw)
classes = ['cat', 'dog', 'car', 'tiger', 'bag of chips']
out     = classifier_zero_shot_with_pil(img, classes)
print("It's a picture of a " + out) #It's a picture of a bag of chips
```
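
If you want a score for every class rather than only the top prediction, you can reuse the helpers from the example above and apply a softmax over the image-text similarities. This is a small sketch that continues the snippet above; the temperature of 100 is the usual CLIP logit scale and is an assumption here, not a value shipped with this model.
``` Python
#Continues the example above: forward_image, forward_text_with_templates,
#img and classes are the ones defined there.
def classifier_zero_shot_probs(img, classes, temperature=100.0):
	classifiers  = torch.cat([forward_text_with_templates(c).reshape([1, -1]) for c in classes], axis=0)
	img_features = forward_image(img)
	logits       = temperature*torch.matmul(img_features, classifiers.T)[0]
	probs        = torch.softmax(logits.float(), dim=-1).detach().cpu().numpy()
	return dict(zip(classes, probs.tolist()))

print(classifier_zero_shot_probs(img, classes)) #'bag of chips' should get the highest probability
```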

*Limitations*:
- Only supports a single-GPU runtime.
- Does not support fine-tuning the linear layers.