# NOTE: "Spaces: Running on Zero" badge text was captured during extraction;
# kept here as a comment so the module remains importable.
import torch
import numpy as np
import sys
import os

from .utils import freeze
class BaseEmbedder:
    """Base wrapper around a language model used for text embedding.

    Reads the checkpoint/tokenizer paths and the token budget from ``conf``
    and exposes fluent helpers (``to`` / ``freeze`` / ``compile``) that each
    return ``self`` so calls can be chained.
    """

    def __init__(self, conf):
        params = conf.text_embedder.params
        self.checkpoint_path = params.checkpoint_path
        self.tokenizer_path = params.tokenizer_path
        # Historical configs spell this key "tokens_lenght"; keep reading it
        # first for backward compatibility, but accept the corrected
        # "tokens_length" spelling as a fallback.
        self.max_length = getattr(conf.text_embedder, 'tokens_lenght', None)
        if self.max_length is None:
            self.max_length = conf.text_embedder.tokens_length
        # Subclasses are expected to assign the actual model here.
        self.llm = None

    def to(self, device='cpu', dtype=torch.float32):
        """Move the wrapped model to *device*/*dtype*; returns self."""
        self.llm = self.llm.to(device=device, dtype=dtype)
        return self

    def freeze(self):
        """Freeze the wrapped model's parameters (via utils.freeze); returns self."""
        self.llm = freeze(self.llm)
        return self

    def compile(self):
        """Wrap the model with ``torch.compile``; returns self."""
        self.llm = torch.compile(self.llm)
        return self
class EmbedderWithTokenizer(BaseEmbedder):
    """Embedder that owns a tokenizer and feeds tokenized text to the LLM."""

    def __init__(self, conf):
        super().__init__(conf)
        # Concrete subclasses are responsible for loading the tokenizer.
        self.tokenizer = None

    def tokenize(self, text):
        """Encode *text* into padded/truncated input ids on the model's device."""
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=True,
            padding='max_length',
            return_tensors='pt',
        )
        input_ids = encoding.input_ids
        return input_ids.to(self.llm.device)

    def __call__(self, text):
        """Run the LLM on *text* and return its first output."""
        token_ids = self.tokenize(text)
        outputs = self.llm(token_ids, output_hidden_states=True)
        return outputs[0]
class T5TextEmbedder(EmbedderWithTokenizer):
    """Text embedder backed by a T5 encoder and its matching tokenizer."""

    def __init__(self, conf):
        # Imported lazily so transformers is only required when T5 is used.
        from transformers import T5EncoderModel, T5Tokenizer

        super().__init__(conf)
        self.tokenizer = T5Tokenizer.from_pretrained(
            self.tokenizer_path, clean_up_tokenization_spaces=False
        )
        self.llm = T5EncoderModel.from_pretrained(self.checkpoint_path)
def get_text_embedder(conf):
    """Factory for the configured text embedder (currently always T5)."""
    embedder_cls = T5TextEmbedder
    return embedder_cls(conf)