---
license: mit
train: false
inference: false
pipeline_tag: image-classification
---
## CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ
This is a version of the ViT-H-14 vision model, based on timm's `vit_huge_patch14_clip_224.laion2b`, quantized to 2-bit with Half-Quadratic Quantization (HQQ): https://mobiusml.github.io/hqq_blog/
This 2-bit model achieves 0.716 zero-shot top-1 accuracy on ImageNet, outperforming a full-precision ViT-B-32 (0.664).
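For readers new to low-bit weight quantization, here is a rough sketch of the general idea behind a 2-bit, group-wise scheme: each group of weights shares one scale and one zero-point, and every weight is stored as an integer in 0..3. This is plain round-to-nearest min-max quantization, not HQQ itself, which additionally optimizes the quantization parameters with a half-quadratic solver (see the blog post above). Reading `g16` in the model name as a group size of 16 is our assumption, the `s128` setting is not modeled, and the helper names are hypothetical.

```python
import numpy as np

def quantize_groupwise(w: np.ndarray, nbits: int = 2, group_size: int = 16):
    """Min-max (round-to-nearest) affine quantization, one scale/zero per group.

    Illustrative only: this is NOT the HQQ optimizer, just the basic
    group-wise affine scheme; HQQ additionally optimizes the parameters.
    """
    qmax = 2**nbits - 1                        # 2-bit -> integer levels 0..3
    groups = w.reshape(-1, group_size)         # assumes w.size is a multiple of group_size
    w_min = groups.min(axis=1, keepdims=True)
    w_max = groups.max(axis=1, keepdims=True)
    scale = (w_max - w_min) / qmax
    scale = np.where(scale == 0, 1.0, scale)   # constant group: avoid division by zero
    zero = -w_min / scale                      # float zero-point that maps w_min -> 0
    q = np.clip(np.round(groups / scale + zero), 0, qmax).astype(np.uint8)
    return q, scale, zero

def dequantize_groupwise(q, scale, zero, shape):
    # Map the stored integers back to approximate float weights
    return ((q.astype(np.float32) - zero) * scale).reshape(shape)

# Toy usage on a random weight matrix (hypothetical data)
rng = np.random.default_rng(0)
w = rng.standard_normal((64, 32)).astype(np.float32)
q, scale, zero = quantize_groupwise(w)
w_hat = dequantize_groupwise(q, scale, zero, w.shape)
print("levels used:", np.unique(q))
print("max abs error:", float(np.abs(w - w_hat).max()))
```

With round-to-nearest, each reconstructed weight is off by at most half its group's scale, which is why smaller groups (more scales to store) tend to give lower error at 2 bits.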
### Basic Usage
To run the model, install the HQQ library from https://github.com/mobiusml/hqq and use it as follows:
``` Python
from hqq.engine.timm import HQQtimm
model = HQQtimm.from_quantized("mobiuslabsgmbh/CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ")
```
### Zero-Shot Classification
For zero-shot classification you also need the text encoder, which is taken from the full-precision OpenCLIP model. Here is a complete example:
``` Python
# Install the dependencies first (prefix with "!" in a notebook):
# pip install open_clip_torch Pillow requests
import numpy as np
import requests
import torch
import open_clip
from PIL import Image
from hqq.engine.timm import HQQtimm

# Full-precision OpenCLIP model: only its text encoder and image preprocessing are used
orig_model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2B-s32B-b79K')
tokenizer = open_clip.get_tokenizer('ViT-H-14')
model_text = orig_model.encode_text

# 2-bit HQQ-quantized vision encoder
model_visual = HQQtimm.from_quantized("mobiuslabsgmbh/CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ")

###############################################################
# Add your own templates here; we provide simple ones below.
# See https://github.com/openai/CLIP/blob/main/data/prompts.md for the complete list.
TEMPLATES = (
    lambda c: f'itap of a {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'art of the {c}.',
)

@torch.no_grad()
def forward_image(img):
    # Preprocess a PIL image and return its L2-normalized embedding (1 x D)
    x = preprocess(img).unsqueeze(0)
    f = model_visual(x.half().cuda())
    f /= torch.norm(f, p=2, dim=-1, keepdim=True)
    return f

@torch.no_grad()
def forward_text(text_batch_list, normalize=True):
    # Encode a batch of prompts with the full-precision text model (on CPU)
    inputs = tokenizer(text_batch_list)
    f = model_text(inputs)
    if normalize:
        f /= torch.norm(f, p=2, dim=-1, keepdim=True)
    del inputs
    return f.half().to('cuda')

def forward_text_with_templates(text, templates=TEMPLATES, normalize=True):
    # Average the (unnormalized) embeddings of all templates for one class name
    f = forward_text([t(text) for t in templates], normalize=False).mean(dim=0)
    if normalize:
        f /= torch.norm(f, p=2, dim=-1, keepdim=True)
    return f

def classifier_zero_shot_with_pil(img, classes):
    # One averaged text embedding per class, stacked into a (num_classes x D) matrix
    classifiers = torch.cat([forward_text_with_templates(c).reshape([1, -1]) for c in classes], dim=0)
    img_features = forward_image(img)
    # Cosine similarity between the image and each class; the highest score wins
    scores = torch.matmul(img_features, classifiers.T)[0].detach().cpu().numpy()
    return classes[np.argmax(scores)]
###############################################################

#img_path_or_url = 'https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-looking-at-camera-1593184780.jpg' #Cat
#img_path_or_url = 'https://www.shutterstock.com/image-photo/photo-cute-golden-retriever-running-600nw-2291249193.jpg' #Dog
img_path_or_url = "https://my-sweet-usa.de/cdn/shop/products/1727.jpg" #Bag of chips
img = Image.open(requests.get(img_path_or_url, stream=True).raw)
classes = ['cat', 'dog', 'car', 'tiger', 'bag of chips']
out = classifier_zero_shot_with_pil(img, classes)
print("It's a picture of a " + out)  # It's a picture of a bag of chips
```
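The scoring in `classifier_zero_shot_with_pil` reduces to a few vector operations: average each class's prompt embeddings, L2-normalize, take dot products with the normalized image embedding (i.e. cosine similarities), and pick the argmax. The NumPy sketch below reproduces only that arithmetic on made-up toy vectors, so it runs without a GPU, open_clip, or HQQ. The helper names (`l2_normalize`, `class_embedding`, `zero_shot_predict`) and the toy data are hypothetical and not part of either library.

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Scale vectors along the last axis to unit L2 norm,
    # like torch.norm(f, p=2, dim=-1, keepdim=True) in the example above
    return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)

def class_embedding(template_embs: np.ndarray) -> np.ndarray:
    # Like forward_text_with_templates: average the unnormalized
    # prompt embeddings of one class (T x D), then normalize the mean
    return l2_normalize(template_embs.mean(axis=0))

def zero_shot_predict(img_emb, per_class_template_embs, classes):
    # Stack one embedding per class into a (C x D) classifier matrix
    classifiers = np.stack([class_embedding(t) for t in per_class_template_embs])
    # Dot products of unit vectors = cosine similarities, shape (C,)
    scores = classifiers @ l2_normalize(img_emb)
    return classes[int(np.argmax(scores))], scores

# Toy data (hypothetical): three orthogonal "concept" directions in 8-D,
# four noisy prompt embeddings per class, and an image close to class 2
rng = np.random.default_rng(0)
classes = ['cat', 'dog', 'bag of chips']
centers = 3.0 * np.eye(8)[:3]
template_embs = [c + 0.1 * rng.standard_normal((4, 8)) for c in centers]
img_emb = centers[2] + 0.05 * rng.standard_normal(8)

label, scores = zero_shot_predict(img_emb, template_embs, classes)
print(label, np.round(scores, 3))
```

Like the example above, this averages the raw template embeddings before normalizing; a common variant normalizes each prompt embedding first and then averages.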
*Limitations*:
- Only supports single-GPU runtime.
- Doesn't support fine-tuning the linear layers.