YAML Metadata Warning: empty or missing YAML metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)
This is labelled as "poor-man's CLIP" in a very old Google Colab I found; it requires `transformers==4.21.0`.
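Some old, incomplete code I found that may be able to run it. Note that the snippet never imports the transformers classes it uses (`from transformers import DebertaV2Model, DebertaV2Tokenizer, ViTFeatureExtractor, ViTModel`):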
from torch import nn, optim
import torch
import os
import pandas as pd
import torch.nn.functional as F
class CFG:
    """Global hyper-parameters shared by the encoders and projection heads.

    Every setting is a plain class attribute, read as ``CFG.<name>``.
    """

    debug = False
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Image tower checkpoint and the width of its output embedding.
    model_name = 'facebook/dino-vitb16'
    image_embedding = 768
    size = 224  # image size

    # Text tower checkpoint, its tokenizer, embedding width and max length.
    text_encoder_model = "microsoft/deberta-v3-base"
    text_tokenizer = "microsoft/deberta-v3-base"
    text_embedding = 768
    max_length = 200

    # Flags meant for both the image encoder and the text encoder.
    pretrained = True
    trainable = True

    temperature = 1.0

    # Projection-head settings; used for both image and text encoders.
    projection_dim = 768
    projection_width = 768
    projection_heads = 16
class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector.

    Wraps a Hugging Face ``ViTModel`` and returns the last-layer hidden state
    of the first token (index 0, the [CLS] position).

    Args:
        model_name: Hub id of the ViT checkpoint whose config (and, when
            ``pretrained`` is True, weights) are used.
        pretrained: If True, load the checkpoint's pretrained weights;
            otherwise build the same architecture with random weights.
        trainable: Value assigned to ``requires_grad`` on every parameter.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        if pretrained:
            self.model = ViTModel.from_pretrained(model_name)
        else:
            # `pretrained` used to be ignored (weights were always loaded).
            # `config_class` is the model's own config type, so this builds the
            # checkpoint's architecture without its weights.
            config = ViTModel.config_class.from_pretrained(model_name)
            self.model = ViTModel(config)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        # presumably x is a pixel_values batch (batch, channels, H, W) — confirm with caller
        return self.model(x).last_hidden_state[:, 0, :]
class TextEncoder(nn.Module):
    """Encode token ids to a fixed size vector with a DeBERTa-v2/v3 model.

    Args:
        model_name: Hub id of the DeBERTa checkpoint whose config (and, when
            ``pretrained`` is True, weights) are used.
        pretrained: If True, load pretrained weights; otherwise build the same
            architecture with random weights.
        trainable: Value assigned to ``requires_grad`` on every parameter.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DebertaV2Model.from_pretrained(model_name)
        else:
            # Bug fix: this used DistilBertConfig(), which is never imported and
            # describes a different architecture. Use DeBERTa's own config
            # class, initialised from the same checkpoint id.
            config = DebertaV2Model.config_class.from_pretrained(model_name)
            self.model = DebertaV2Model(config)
        for p in self.model.parameters():
            p.requires_grad = trainable
        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the last-layer hidden state at ``target_token_idx`` for each sequence."""
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
from collections import OrderedDict
from torch import nn
import torch
class QuickGELU(nn.Module):
    """Element-wise activation ``x * sigmoid(1.702 * x)`` (a cheap GELU approximation)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ProjectionHead(nn.Module):
    """Residual MLP projection head.

    With ``h = Linear(embedding_dim -> projection_dim)(x)`` it returns
    ``LayerNorm(Dropout(fc(QuickGELU(h))) + h)``.

    Args:
        embedding_dim: Size of the input's last dimension.
        projection_dim: Size of the output's last dimension (default 384).
        dropout: Dropout probability applied before the residual add.
    """

    def __init__(self, embedding_dim, projection_dim=384, dropout=0.1):
        super().__init__()
        # Attribute names are the state_dict keys a checkpoint is loaded by;
        # keep them unchanged.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = QuickGELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        h = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(h)))
        return self.layer_norm(branch + h)
class CLIPModel(nn.Module):
    """Inference wrapper pairing a DINO ViT image tower with a DeBERTa-v3 text tower.

    Each tower's first-token hidden state is passed through its own
    ``ProjectionHead(768)``; ``encode_image`` / ``encode_text`` return those
    projections.

    Args:
        device: Where to place the sub-models. Defaults to ``CFG.device``
            (CUDA when available, else CPU) instead of the previously
            hard-coded ``'cuda'``, so the model can be built on CPU-only hosts.
    """

    def __init__(self, device=None):
        super().__init__()
        device = CFG.device if device is None else torch.device(device)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
        self.image_encoder = ViTModel.from_pretrained('facebook/dino-vitb16').to(device)
        self.text_encoder = DebertaV2Model.from_pretrained("microsoft/deberta-v3-base").to(device)
        self.image_projection = ProjectionHead(768).to(device)
        self.text_projection = ProjectionHead(768).to(device)
        # Not read by any method of this class.
        self.temperature = 1.0

    def _device(self):
        """Device of the model's parameters (follows any later ``.to()`` call)."""
        return next(self.parameters()).device

    def encode_image(self, x):
        """Project a batch of pixel values (e.g. the output of ``preprocess``)."""
        x = self.image_encoder(x.to(self._device())).last_hidden_state[:, 0, :]
        return self.image_projection(x)

    def encode_text(self, x):
        """Project tokenizer output (e.g. from ``tokenize``).

        ``x`` must provide ``input_ids``; ``attention_mask`` is used when present.
        """
        device = self._device()
        input_ids = x['input_ids'].to(device)
        attention_mask = x.get('attention_mask')
        if attention_mask is not None:
            # Bug fix: without the mask, padding added by tokenize(padding=True)
            # is attended to and alters embeddings of shorter strings in a batch.
            attention_mask = attention_mask.to(device)
        hidden = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        return self.text_projection(hidden)

    def tokenize(self, x):
        """Tokenize text with the module-level ``tokenizer`` (must exist before this is called)."""
        return tokenizer(x, return_tensors='pt', padding=True, truncation=True)

    def preprocess(self, x):
        """Run the ViT feature extractor and return ``pixel_values`` on the model's device."""
        x = self.feature_extractor(x, return_tensors='pt')
        return x.convert_to_tensors()['pixel_values'].to(self._device())
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (probability) targets.

    Args:
        preds: Logits; log-softmax is taken over the last dimension.
            Presumably 2-D ``(batch, classes)`` — the loss is summed over dim 1.
        targets: Target distribution, broadcastable against ``preds``.
        reduction: ``'none'`` (per-row loss), ``'mean'`` or ``'sum'``.

    Returns:
        The per-row losses, or their mean / sum.

    Raises:
        ValueError: if ``reduction`` is not recognised (this used to return
            ``None`` silently).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"unknown reduction {reduction!r}; expected 'none', 'mean' or 'sum'")
Loading the checkpoint (probably like this):
# NOTE: the classes above also need
#   from transformers import DebertaV2Model, DebertaV2Tokenizer, ViTFeatureExtractor, ViTModel
# which this snippet never imports.
from huggingface_hub import hf_hub_url, hf_hub_download

model_name = "model.pt"
model_url = hf_hub_url("crumb/pmclip-test-run-checkpoints-10", filename=model_name)
# hf_hub_download replaces cached_download(hf_hub_url(...)); cached_download is
# deprecated and has been removed from newer huggingface_hub releases.
file_path = hf_hub_download("crumb/pmclip-test-run-checkpoints-10", filename=model_name)

tokenizer = DebertaV2Tokenizer.from_pretrained(CFG.text_tokenizer)
model = CLIPModel().to(CFG.device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
# torch.load unpickles the file, so only load checkpoints from sources you trust.
model.load_state_dict(torch.load(file_path, map_location=CFG.device))
model = model.eval().to(CFG.device)
print("num params", sum(p.numel() for p in model.parameters()) // 10000 / 100, 'M')
# num params 271.1 M