from typing import List

import torch
from transformers import AutoTokenizer

from model import PersonEmbeddings


class CustomEmbeddingPipeline:
    def __init__(self, model_id="answerdotai/ModernBERT-base"):
        # Load the base tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Build the PersonEmbeddings model and restore the fine-tuned weights
        self.model = PersonEmbeddings(model_id)
        ckpt_path = "pytorch_model.bin"
        state_dict = torch.load(ckpt_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, text: str) -> List[float]:
        # Tokenize a single text as a batch of one
        inputs = self.tokenizer(
            [text], padding=True, truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            emb = self.model(inputs["input_ids"], inputs["attention_mask"])
        # Return the embedding of shape (1, 1536) as a Python list
        return emb[0].tolist()


def pipeline(*args, **kwargs):
    return CustomEmbeddingPipeline()
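

# Minimal usage sketch. Assumes pytorch_model.bin and model.py sit next to
# this file, as they do in this repo; the input string is a hypothetical example.
if __name__ == "__main__":
    pipe = pipeline()
    embedding = pipe("Jane Doe, software engineer in Sydney")  # hypothetical input
    print(len(embedding))  # 1536, per the shape noted in __call__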