import os

import torch
from torch import nn
from einops import reduce
from tqdm.auto import tqdm
from huggingface_hub import login, hf_hub_download
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPModel, CLIPProcessor
import gradio as gr

# Authenticate with the Hugging Face Hub (token supplied via the hf_token environment variable).
login(os.environ['hf_token'])


def load_distillclip(model_id, revision=None):
    """Build a CLIPModel matching the DistillCLIP student architecture and load its weights."""
    ckpt_path = hf_hub_download(repo_id=model_id, filename="model.safetensors", revision=revision)
    config = CLIPConfig.from_pretrained(model_id)
    model = CLIPModel(config)
    # The student uses a biased patch embedding and no pre-layernorm, so swap those modules
    # before loading the checkpoint.
    model.vision_model.embeddings.patch_embedding = nn.Conv2d(
        in_channels=model.config.vision_config.num_channels,
        out_channels=model.vision_model.embeddings.embed_dim,
        kernel_size=model.vision_model.embeddings.patch_size,
        stride=model.vision_model.embeddings.patch_size,
        bias=True,
    )
    model.vision_model.pre_layrnorm = nn.Identity()  # attribute name as spelled in transformers
    # Strip the 'student.' prefix from checkpoint keys; print the missing/unexpected-key report.
    print(model.load_state_dict({k.removeprefix('student.'): v for k, v in load_file(ckpt_path).items()}))
    return model


class ZeroShotCLIP(nn.Module):
    """Zero-shot classifier: class weights are prompt-template-averaged text embeddings."""

    def __init__(self, model=None, processor=None, classes=[], templates=[], load_in_8bit=False):
        super().__init__()
        self.model = model.eval()
        self.processor = processor
        self.classes = classes
        self.templates = templates
        self._init_weights()

    @torch.no_grad()
    def _init_weights(self):
        self.model.eval()
        # Encode text on whatever device the model lives on.
        device = next(self.model.parameters()).device
        weights = []
        for classname in tqdm(self.classes):
            prompts = [template.format(classname) for template in self.templates]
            prompts = self.processor(text=prompts, truncation=True, padding=True, return_tensors='pt')
            embeddings = self.model.get_text_features(**{k: v.to(device) for k, v in prompts.items()}).cpu()
            # Normalize each prompt embedding, average over templates, then renormalize.
            embeddings /= embeddings.norm(dim=-1, keepdim=True)
            embeddings = reduce(embeddings, 'b d -> d', 'mean')
            embeddings /= embeddings.norm()
            weights.append(embeddings)
        weights = torch.stack(weights)
        self.register_buffer('weights', weights)

    @torch.no_grad()
    def forward(self, pixel_values):
        x = self.model.get_image_features(pixel_values=pixel_values)
        x /= x.norm(dim=-1, keepdim=True)
        # Scale cosine similarities by the teacher CLIP's temperature (exp of its logit scale).
        return x.mm(self.weights.t()) * 100.00000762939453

    def preprocess_and_forward(self, x):
        x = self.processor(images=x, return_tensors='pt')
        return self(x['pixel_values'])


model = load_distillclip('Ramos-Ramos/distillclip')
processor = CLIPProcessor.from_pretrained('Ramos-Ramos/distillclip')


def infer(image, classes, templates):
    classes = [label.strip() for label in classes.split(',')]
    print(classes)
    templates = [template.strip() for template in templates.split(';')]
    print(templates)
    clip = ZeroShotCLIP(model=model, processor=processor, classes=classes, templates=templates)
    preds = clip.preprocess_and_forward(image).softmax(dim=1).flatten()
    return {label: score.item() for label, score in zip(classes, preds)}


title = 'DistillCLIP'
description = 'Zero-shot image classification demo with DistillCLIP'
article = '''DistillCLIP is a distilled version of [CLIP-ViT/B-32](https://huggingface.co/openai/clip-vit-base-patch32). Please refer to the [DistillCLIP model card](https://huggingface.co/Ramos-Ramos/distillclip) for more details on DistillCLIP.

Note: Multiplying the logits by a temperature before the softmax better separates the final scores, so we scale DistillCLIP's text-image similarity scores by the teacher CLIP's temperature.'''

demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(label='Image', type='pil'),
        gr.Textbox(label='Classes', placeholder='cat, truck', info='Classes for classification. Separate classes with commas.'),
        gr.Textbox(label='Prompts', placeholder='a photo of a {}.; a blurry photo of a {}.', info='Prompt templates. Use "{}" as a placeholder for the class. Separate prompts with semicolons.')
    ],
    outputs=gr.Label(label='Class scores'),
    title=title,
    description=description,
    article=article
)

demo.launch()
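# A minimal sketch of exercising the classifier without the Gradio UI, kept commented out so
# it does not affect the Space. 'example.jpg' is a hypothetical local image path; the class
# and template strings mirror the demo placeholders.
#
# from PIL import Image
#
# scores = infer(
#     Image.open('example.jpg').convert('RGB'),
#     classes='cat, truck',
#     templates='a photo of a {}.; a blurry photo of a {}.',
# )
# print(scores)  # dict mapping each class name to its softmax score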