import torch

from .utils import freeze


class BaseEmbedder:
    """Holds a language model used to embed text prompts."""

    def __init__(self, conf):
        self.checkpoint_path = conf.text_embedder.params.checkpoint_path
        self.tokenizer_path = conf.text_embedder.params.tokenizer_path
        self.max_length = conf.text_embedder.tokens_lenght  # key spelling matches the config
        self.llm = None

    def to(self, device='cpu', dtype=torch.float32):
        self.llm = self.llm.to(device=device, dtype=dtype)
        return self

    def freeze(self):
        # Disable gradients; the embedder is used for inference only.
        self.llm = freeze(self.llm)
        return self

    def compile(self):
        self.llm = torch.compile(self.llm)
        return self


class EmbedderWithTokenizer(BaseEmbedder):
    """Embedder that tokenizes raw text before running the language model."""

    def __init__(self, conf):
        super().__init__(conf)
        self.tokenizer = None

    def tokenize(self, text):
        # Pad/truncate every prompt to a fixed token length so batch shapes stay constant.
        model_input = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=True,
            padding='max_length',
            return_tensors='pt',
        )
        return model_input.input_ids.to(self.llm.device)

    def __call__(self, text):
        # Index [0] is the last hidden state: (batch, max_length, hidden_dim).
        return self.llm(self.tokenize(text), output_hidden_states=True)[0]


class T5TextEmbedder(EmbedderWithTokenizer):
    """Text embedder backed by a pretrained T5 encoder."""

    def __init__(self, conf):
        # Imported lazily so transformers is only required when this class is used.
        from transformers import T5EncoderModel, T5Tokenizer
        super().__init__(conf)
        self.llm = T5EncoderModel.from_pretrained(self.checkpoint_path)
        self.tokenizer = T5Tokenizer.from_pretrained(self.tokenizer_path, clean_up_tokenization_spaces=False)


def get_text_embedder(conf):
    return T5TextEmbedder(conf)
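

# Minimal usage sketch, assuming an OmegaConf-style config object exposing the
# fields read in BaseEmbedder.__init__; the checkpoint/tokenizer paths below
# are hypothetical placeholders, not values shipped with this repo.
#
#   from omegaconf import OmegaConf
#
#   conf = OmegaConf.create({
#       'text_embedder': {
#           'tokens_lenght': 128,  # spelling follows the config key read above
#           'params': {
#               'checkpoint_path': 'google/flan-t5-xl',  # placeholder
#               'tokenizer_path': 'google/flan-t5-xl',   # placeholder
#           },
#       },
#   })
#   embedder = get_text_embedder(conf).freeze().to('cuda', torch.float16)
#   hidden = embedder(['a photo of a cat'])  # (batch, tokens_lenght, hidden_dim)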