import os

import gradio as gr
import torch
from einops import reduce
from huggingface_hub import hf_hub_download, login
from safetensors.torch import load_file
from torch import nn
from tqdm.auto import tqdm
from transformers import CLIPConfig, CLIPModel, CLIPProcessor

# Authenticate with the Hugging Face Hub before downloading the DistillCLIP checkpoint
login(os.environ['hf_token'])


def load_distillclip(model_id, revision=None):
    """Load a DistillCLIP checkpoint into a CLIPModel, applying the student's architectural tweaks."""
    ckpt_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", revision=revision)
    config = CLIPConfig.from_pretrained(model_id)
    model = CLIPModel(config)
    # DistillCLIP's patch embedding has a bias, unlike the original CLIP's
    model.vision_model.embeddings.patch_embedding = nn.Conv2d(
        in_channels=model.config.vision_config.num_channels,
        out_channels=model.vision_model.embeddings.embed_dim,
        kernel_size=model.vision_model.embeddings.patch_size,
        stride=model.vision_model.embeddings.patch_size,
        bias=True,
    )
    # DistillCLIP drops the pre-layernorm ("pre_layrnorm" is the attribute's actual spelling in transformers)
    model.vision_model.pre_layrnorm = nn.Identity()
    # Checkpoint keys carry a 'student.' prefix from distillation training; strip it before loading
    print(model.load_state_dict({k.removeprefix('student.'): v for k, v in load_file(ckpt_path).items()}))
    return model
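

# A minimal smoke test, left commented out so the Space doesn't run it on startup; it assumes
# the default CLIP ViT-B/32 input resolution (3x224x224) and projection dim (512):
#
#   model = load_distillclip('Ramos-Ramos/distillclip')
#   with torch.no_grad():
#       feats = model.get_image_features(pixel_values=torch.randn(1, 3, 224, 224))
#   print(feats.shape)  # expected: torch.Size([1, 512])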


class ZeroShotCLIP(nn.Module):
    """Wraps a CLIP model with precomputed per-class text embeddings for zero-shot classification."""

    def __init__(self, model=None, processor=None, classes=None, templates=None):
        super().__init__()
        self.model = model.eval()
        self.processor = processor
        self.classes = classes if classes is not None else []
        self.templates = templates if templates is not None else []
        self._init_weights()

    @torch.no_grad()
    def _init_weights(self):
        # Embed every template.format(classname) prompt and average the normalized
        # embeddings per class to form the zero-shot classifier weights
        self.model.eval()
        # Run text encoding on whatever device the model already lives on
        device = next(self.model.parameters()).device
        weights = []
        for classname in tqdm(self.classes):
            prompts = [template.format(classname) for template in self.templates]
            prompts = self.processor(text=prompts, truncation=True, padding=True, return_tensors='pt')
            embeddings = self.model.get_text_features(**{k: v.to(device) for k, v in prompts.items()}).cpu()
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
            embeddings = reduce(embeddings, 'b d -> d', 'mean')
            embeddings /= embeddings.norm()  # re-normalize the averaged embedding
            weights.append(embeddings)
        weights = torch.stack(weights)
        self.register_buffer('weights', weights)

    @torch.no_grad()
    def forward(self, pixel_values):
        x = self.model.get_image_features(pixel_values=pixel_values)
        x /= x.norm(dim=-1, keepdim=True)
        # Scale cosine similarities by the teacher CLIP's learned temperature (logit_scale.exp())
        return x.mm(self.weights.t()) * 100.00000762939453

    def preprocess_and_forward(self, x):
        x = self.processor(images=x, return_tensors='pt')
        return self(x['pixel_values'])


# Load DistillCLIP and its processor from the Hub
model = load_distillclip('Ramos-Ramos/distillclip')
processor = CLIPProcessor.from_pretrained('Ramos-Ramos/distillclip')


def infer(image, classes, templates):
    """Classify an image against comma-separated classes using semicolon-separated prompt templates."""
    classes = [label.strip() for label in classes.split(',')]
    print(classes)
    templates = [template.strip() for template in templates.split(';')]
    print(templates)
    clip = ZeroShotCLIP(model=model, processor=processor, classes=classes, templates=templates)
    preds = clip.preprocess_and_forward(image).softmax(dim=1).flatten()
    return {label: score.item() for label, score in zip(classes, preds)}


title = 'DistillCLIP'
description = 'Zero-shot image classification demo with DistillCLIP'
article = '''DistillCLIP is a distilled version of [CLIP ViT-B/32](https://huggingface.co/openai/clip-vit-base-patch32).
Please refer to the [DistillCLIP model card](https://huggingface.co/Ramos-Ramos/distillclip) for more details.
Note: Because scaling the logits by a temperature before the softmax makes the final scores easier to tell apart, we multiply DistillCLIP's text-image similarity scores by the teacher CLIP's temperature.'''
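
# For intuition on the temperature note above (hypothetical numbers): softmax over raw cosine
# similarities like [0.31, 0.28] is nearly uniform (~[0.51, 0.49]), while softmax over the
# temperature-scaled logits [31.0, 28.0] is decisive (~[0.95, 0.05]).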

demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label='Image', type='pil'),
        gr.Textbox(label='Classes', placeholder='cat, truck', info='Classes for classification. Separate classes with commas.'),
        gr.Textbox(label='Prompt/s', placeholder='a photo of a {}.; a blurry photo of a {}.', info='Prompt templates. Use "{}" as a placeholder for the class. Separate prompts with semicolons.')
    ],
    outputs=gr.Label(label='Class scores'),
    title=title,
    description=description,
    article=article
)

demo.launch()