Spaces:
Running
on
A10G
Running
on
A10G
File size: 2,624 Bytes
f8ea2c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
from torch import nn
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
import os
# Checkpoints whose sentence embedding is built by mean-pooling the token
# embeddings (followed by L2 normalisation) instead of taking the first token.
POOL_MODELS = {
    "sentence-transformers/all-MiniLM-L6-v2",
    "TaylorAI/bge-micro-v2",
}
def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dim 1, counting only unmasked positions.

    Args:
        model_output: Indexable model output whose element ``0`` holds the
            per-token embeddings.
        attention_mask: Mask with one entry per token; it is broadcast over
            the embedding dimension and used as the averaging weight.

    Returns:
        The mask-weighted sum of the token embeddings over dim 1, divided by
        the mask total over dim 1.
    """
    token_embeddings = model_output[0]

    # Give the mask a trailing axis and stretch it to the embeddings' shape so
    # it can be multiplied element-wise.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

    weighted_sum = (token_embeddings * mask).sum(1)
    # Clamp the denominator so a fully masked row does not divide by zero.
    mask_total = mask.sum(1).clamp(min=1e-9)
    return weighted_sum / mask_total
class LanguageModel(nn.Module):
    """Frozen pre-trained text encoder that maps raw text to sentence embeddings.

    The tokenizer and encoder are loaded with ``AutoTokenizer`` / ``AutoModel``.
    All encoder parameters have ``requires_grad`` disabled and the encoder is
    kept in evaluation mode, so it acts as a fixed feature extractor.

    How the sentence embedding is produced depends on ``model_name``:

    * names containing ``"clip"``: ``get_text_features`` of the encoder;
    * names containing an entry of ``POOL_MODELS``: attention-mask-weighted
      mean of the token embeddings, then L2-normalised;
    * anything else: hidden state of the first token of ``last_hidden_state``.
    """

    def __init__(self, model='distilbert-base-uncased'):
        """Load tokenizer and encoder for checkpoint ``model`` and freeze the encoder.

        Args:
            model: Checkpoint name or path handed to ``from_pretrained``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model_name = model

        # Remove the CLIP vision tower; only text features are requested in forward().
        if "clip" in self.model_name:
            self.model.vision_model = None

        # Freeze the pre-trained parameters (very important)
        for param in self.model.parameters():
            param.requires_grad = False

        # Make sure to set evaluation mode (also important)
        self.model.eval()

    def train(self, mode=True):
        """Set the training flag on this module while keeping the encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        switch the frozen encoder back to training mode whenever this module
        (or any parent containing it) is put into training mode.

        Args:
            mode: Training flag forwarded to ``nn.Module.train``.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, text_batch):
        """Tokenize ``text_batch`` and return one embedding per input text.

        Args:
            text_batch: Text input passed straight to the tokenizer
                (presumably a string or a list of strings; confirm with callers).

        Returns:
            Tensor of sentence embeddings, located on the encoder's device.
        """
        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")

        # The tokenizer does not know where the encoder lives; move its tensors
        # there so the module also works after ``.to(device)`` / ``.cuda()``.
        device = next(self.model.parameters()).device
        inputs = {
            key: (value.to(device) if torch.is_tensor(value) else value)
            for key, value in inputs.items()
        }

        with torch.no_grad():  # Ensure no gradients are computed for this forward pass
            if "clip" in self.model_name:
                return self.model.get_text_features(**inputs)
            outputs = self.model(**inputs)

        if any(name in self.model_name for name in POOL_MODELS):
            pooled = mean_pooling(outputs, inputs['attention_mask'])
            # Normalize embeddings
            return F.normalize(pooled, p=2, dim=1)

        return outputs.last_hidden_state[:, 0, :]
class LMHead(nn.Module):
    """Two-layer linear head: project, L2-normalise, then classify.

    ``fc1`` maps the input embedding to ``hidden_dim``; the result is
    L2-normalised along dim 1 and passed to ``fc2``, which produces
    ``num_classes`` outputs. No activation is applied between the layers.
    """

    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        """Create the projection (``fc1``) and classification (``fc2``) layers."""
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return ``(normalised projection, class scores)`` for input ``x``."""
        projected = F.normalize(self.fc1(x), p=2, dim=1)
        return projected, self.fc2(projected)