
CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ

This is the ViT-H-14 CLIP vision encoder, based on timm's vit_huge_patch14_clip_224.laion2b, quantized to 2 bits with Half-Quadratic Quantization (HQQ): https://mobiusml.github.io/hqq_blog/
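According to the blog post linked above, HQQ quantizes weights in small groups and needs no calibration data. Instead of calibrating, it fits each group's zero-point by minimizing a sparsity-promoting l_p (p < 1) reconstruction error with half-quadratic splitting. The code below is a simplified, pure-Python sketch of that idea for a single group. It is not the hqq library's implementation. Assumptions: a group size of 16 (a guess from the "g16" in the model name), a fixed min-max scale, and illustrative hyperparameters p=0.7, beta=10, kappa=1.01. All function names are hypothetical.

```python
import math

def shrink_lp(x, beta, p=0.7):
    """Generalized soft-thresholding: proximal step for the l_p 'norm' with p < 1."""
    out = []
    for v in x:
        a = abs(v)
        if a == 0.0:
            out.append(0.0)
        else:
            out.append(math.copysign(max(a - a ** (p - 1) / beta, 0.0), v))
    return out

def quantize(w, scale, zero, nbits=2):
    """Affine quantization to integer codes in [0, 2**nbits - 1]: round(w * scale + zero)."""
    qmax = 2 ** nbits - 1
    return [min(max(round(v * scale + zero), 0), qmax) for v in w]

def dequantize(codes, scale, zero):
    return [(c - zero) / scale for c in codes]

def mean_abs_err(w, scale, zero, nbits=2):
    w_r = dequantize(quantize(w, scale, zero, nbits), scale, zero)
    return sum(abs(a - b) for a, b in zip(w, w_r)) / len(w)

def hqq_group(w, nbits=2, iters=20, p=0.7, beta=10.0, kappa=1.01):
    """Quantize one weight group, refining only the zero-point (hypothetical helper)."""
    qmax = 2 ** nbits - 1
    lo, hi = min(w), max(w)
    scale = qmax / (hi - lo) if hi > lo else 1.0  # inverse step size, kept fixed
    zero = -lo * scale                            # plain min-max initialization
    best_zero, best_err = zero, mean_abs_err(w, scale, zero, nbits)
    for _ in range(iters):
        codes = quantize(w, scale, zero, nbits)
        w_r = dequantize(codes, scale, zero)
        # Sparse error term W_e absorbs outliers instead of letting them skew the fit
        w_e = shrink_lp([a - b for a, b in zip(w, w_r)], beta, p)
        # Closed-form zero-point update given the current codes and W_e
        zero = sum(c - (a - e) * scale for c, a, e in zip(codes, w, w_e)) / len(w)
        beta *= kappa
        err = mean_abs_err(w, scale, zero, nbits)
        if err >= best_err:
            break  # stop once the reconstruction error no longer improves
        best_zero, best_err = zero, err
    return quantize(w, scale, best_zero, nbits), scale, best_zero

# One group of 16 made-up weights with an outlier (0.90)
group = [0.12, -0.05, 0.33, -0.41, 0.02, 0.27, -0.19, 0.08,
         -0.30, 0.15, 0.01, -0.07, 0.44, -0.22, 0.05, 0.90]
codes, scale, zero = hqq_group(group)
print(codes)
```

The sketch only accepts a new zero-point when the reconstruction error strictly decreases, so the result is never worse than plain min-max rounding.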

This 2-bit model achieves 0.716 zero-shot top-1 accuracy on ImageNet, outperforming a full-precision ViT-B-32 (0.664).
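Top-1 accuracy is the fraction of evaluation images whose highest-scoring class is the correct one. For zero-shot CLIP, the scores are image-to-prompt similarities. Below is a minimal sketch of the metric. The function name and toy scores are illustrative only; they are not the evaluation code or data behind the numbers above.

```python
def top1_accuracy(scores, labels):
    """scores: one list of per-class scores per image; labels: true class index per image."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    correct = 0
    for row, label in zip(scores, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += pred == label
    return correct / len(labels)

# Made-up similarity scores for 3 images over 3 classes
scores = [[0.31, 0.12, 0.05], [0.10, 0.28, 0.22], [0.20, 0.25, 0.19]]
labels = [0, 1, 0]
print(top1_accuracy(scores, labels))  # → 0.6666666666666666
```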

Basic Usage

To run the model, install the HQQ library from https://github.com/mobiusml/hqq and use it as follows:

from hqq.engine.timm import HQQtimm
model = HQQtimm.from_quantized("mobiuslabsgmbh/CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ")

Zero-Shot Classification

Zero-shot classification also requires the text model. Here's a complete example:

!pip install open_clip_torch
!pip install Pillow

import torch
import numpy as np

import open_clip
# Full-precision OpenCLIP model: only its text encoder and image preprocessing are used below
orig_model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2B-s32B-b79K')
orig_model.eval()
tokenizer  = open_clip.get_tokenizer('ViT-H-14')
model_text = orig_model.encode_text

from hqq.engine.timm import HQQtimm
# 2-bit HQQ vision encoder, used in place of orig_model.visual
model_visual = HQQtimm.from_quantized("mobiuslabsgmbh/CLIP-ViT-H-14-laion2B-2bit_g16_s128-HQQ")

###############################################################
# Add your own templates here; a small set is provided below.
# See https://github.com/openai/CLIP/blob/main/data/prompts.md for the complete list.
TEMPLATES = (
    lambda c: f'itap of a {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'art of the {c}.',
)

@torch.no_grad()
def forward_image(img):
    # Preprocess a PIL image and embed it with the 2-bit vision encoder (fp16, GPU)
    x = preprocess(img).unsqueeze(0)
    f = model_visual(x.half().cuda())
    f /= torch.norm(f, p=2, dim=-1, keepdim=True)  # L2-normalize
    return f

@torch.no_grad()
def forward_text(text_batch_list, normalize=True):
    # Tokenize and embed a batch of prompts with the full-precision text encoder
    inputs  = tokenizer(text_batch_list)
    f       = model_text(inputs)
    if normalize:
        f  /= torch.norm(f, p=2, dim=-1, keepdim=True)
    return f.half().to('cuda')

def forward_text_with_templates(text, templates=TEMPLATES, normalize=True):
    # Prompt ensembling: average the embeddings of every template for one class name
    f = forward_text([t(text) for t in templates], normalize=False).mean(dim=0)
    if normalize:
        f  /= torch.norm(f, p=2, dim=-1, keepdim=True)
    return f

def classifier_zero_shot_with_pil(img, classes):
    # One normalized text embedding per class, stacked into a (num_classes, dim) matrix
    classifiers  = torch.cat([forward_text_with_templates(c).reshape([1, -1]) for c in classes], dim=0)
    img_features = forward_image(img)
    # Cosine similarity between the image and each class; return the best match
    scores       = torch.matmul(img_features, classifiers.T)[0].detach().cpu().numpy()
    out          = classes[np.argmax(scores)]
    return out
###############################################################
from PIL import Image
import requests
#img_path_or_url = 'https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-looking-at-camera-1593184780.jpg' #Cat
#img_path_or_url = 'https://www.shutterstock.com/image-photo/photo-cute-golden-retriever-running-600nw-2291249193.jpg' #Dog
img_path_or_url = "https://my-sweet-usa.de/cdn/shop/products/1727.jpg" #bag of chips

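# Load the image from a URL or from a local file path
if img_path_or_url.startswith(('http://', 'https://')):
    img = Image.open(requests.get(img_path_or_url, stream=True).raw)
else:
    img = Image.open(img_path_or_url)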
classes = ['cat', 'dog', 'car', 'tiger', 'bag of chips']
out     = classifier_zero_shot_with_pil(img, classes)
print("It's a picture of a " + out) #It's a picture of a bag of chips
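Without the model calls, the example above ranks classes by cosine similarity. Image and text embeddings are L2-normalized, each class's template embeddings are averaged and renormalized, and the class with the largest dot product wins. The sketch below reproduces only that arithmetic, with no dependencies. It uses made-up 3-dimensional vectors in place of real CLIP features, and its function names are hypothetical rather than part of open_clip or HQQ.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def class_embedding(template_embeddings):
    """Prompt ensembling: average one class's per-template embeddings, then renormalize."""
    n, dim = len(template_embeddings), len(template_embeddings[0])
    mean = [sum(e[i] for e in template_embeddings) / n for i in range(dim)]
    return l2_normalize(mean)

def zero_shot(image_embedding, per_class_templates, classes):
    """Return the best-matching class name and the cosine score for every class."""
    img = l2_normalize(image_embedding)
    scores = [sum(a * b for a, b in zip(img, class_embedding(t))) for t in per_class_templates]
    best = max(range(len(scores)), key=scores.__getitem__)  # argmax
    return classes[best], scores

# Toy 3-D vectors standing in for real CLIP embeddings (made-up values)
classes = ['cat', 'dog']
per_class_templates = [
    [[1.0, 0.0, 0.0], [0.8, 0.2, 0.0]],  # two "template" embeddings for 'cat'
    [[0.0, 1.0, 0.0], [0.1, 0.9, 0.1]],  # two "template" embeddings for 'dog'
]
label, scores = zero_shot([0.9, 0.1, 0.0], per_class_templates, classes)
print(label)  # → cat
```

Normalizing before the dot product makes the score depend only on direction, so text prompts and images of different embedding magnitude are compared on equal terms.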

Limitations:
- Only supports single-GPU runtime.
- Does not support fine-tuning the linear layers.
