Spaces:
Running
on
A10G
Running
on
A10G
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer | |
import os | |
# Checkpoints (matched by substring against the model name) whose sentence
# embedding is the attention-masked mean of the token embeddings, followed by
# L2 normalisation, instead of the hidden state at token position 0.
POOL_MODELS = {
    "sentence-transformers/all-MiniLM-L6-v2",
    "TaylorAI/bge-micro-v2",
}
#Mean Pooling - Take attention mask into account for correct averaging | |
def mean_pooling(model_output, attention_mask): | |
token_embeddings = model_output[0] #First element of model_output contains all token embeddings | |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | |
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | |
class LanguageModel(nn.Module):
    """Frozen Hugging Face text encoder mapping a batch of texts to sentence embeddings.

    The checkpoint is loaded with ``AutoTokenizer`` / ``AutoModel``, all of its
    parameters are frozen, and the encoder is kept in evaluation mode even when
    this module (or a parent) is switched to training mode.

    Args:
        model: Hugging Face model id or local path passed to ``from_pretrained``.
    """

    def __init__(self, model='distilbert-base-uncased'):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model_name = model
        # Only the text side of a CLIP checkpoint is used; drop the vision tower
        # so its weights are no longer referenced by this module.
        if self._is_clip():
            self.model.vision_model = None
        # Freeze the pre-trained parameters (very important): the encoder is a
        # fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False
        # Evaluation mode disables dropout; ``train`` below keeps it that way.
        self.model.eval()

    def _is_clip(self):
        """Return True when ``model_name`` contains "clip", ignoring case."""
        # Case-insensitive so ids such as "laion/CLIP-ViT-B-32-..." are routed
        # to ``get_text_features`` instead of a text-only ``forward`` call.
        return "clip" in self.model_name.lower()

    def _device(self):
        """Return the device of the encoder's first parameter (CPU if it has none)."""
        first_param = next(self.model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def train(self, mode=True):
        """Set training mode on this module while keeping the encoder in eval mode.

        ``nn.Module.train`` recurses into every child, so without this override a
        ``.train()`` call on this module or any parent would re-enable dropout
        inside the frozen encoder and make its embeddings non-deterministic.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, text_batch):
        """Encode ``text_batch`` into one embedding per text, without gradients.

        Args:
            text_batch: input accepted by the tokenizer (e.g. a list of strings);
                it is padded and truncated to a common length.

        Returns:
            * CLIP checkpoints: the output of ``get_text_features``.
            * Model names containing an entry of ``POOL_MODELS``: the
              attention-masked mean of the token embeddings, L2-normalised
              along dim 1.
            * Otherwise: ``last_hidden_state`` at token position 0 (presumably
              the [CLS] token for BERT-style tokenizers), not normalised.
        """
        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")
        # Tokenizer tensors are created on the CPU; move them to wherever the
        # encoder lives (e.g. after ``.to("cuda")``) to avoid a device mismatch.
        device = self._device()
        inputs = {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in inputs.items()
        }
        with torch.no_grad():  # frozen encoder: no autograd graph needed
            if self._is_clip():
                return self.model.get_text_features(**inputs)
            outputs = self.model(**inputs)
            if any(name in self.model_name for name in POOL_MODELS):
                pooled = mean_pooling(outputs, inputs['attention_mask'])
                return F.normalize(pooled, p=2, dim=1)
            return outputs.last_hidden_state[:, 0, :]
class LMHead(nn.Module):
    """Two-layer head applied on top of sentence embeddings.

    ``fc1`` projects the input to ``hidden_dim`` features, which are
    L2-normalised and returned as an embedding; ``fc2`` maps that normalised
    embedding to ``num_classes`` output scores.

    Args:
        embedding_dim: size of the last dimension of the input.
        hidden_dim: size of the normalised projection returned as ``embd``.
        num_classes: number of output scores produced by ``fc2``.
    """

    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Project ``x``, L2-normalise the projection, and score it.

        Args:
            x: tensor whose last dimension is ``embedding_dim``, e.g.
                ``(batch, embedding_dim)`` or a single unbatched vector.

        Returns:
            ``(embd, deg_pred)``: the unit-L2-norm projection (last dim
            ``hidden_dim``) and the ``fc2`` output (last dim ``num_classes``).
        """
        # Normalise over the feature (last) axis, the axis nn.Linear acts on.
        # This equals dim=1 for 2-D input but is also correct for unbatched or
        # higher-rank input. It is the only non-linearity between fc1 and fc2.
        embd = F.normalize(self.fc1(x), p=2, dim=-1)
        deg_pred = self.fc2(embd)
        return embd, deg_pred