"""Sentence-embedding wrappers around sentence-transformers and Hugging Face models."""
from functools import lru_cache

import torch
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Candidate model identifiers (Hugging Face Hub names) for the wrappers below.
list_models = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]


class SBert:
    """Encode text with a ``SentenceTransformer`` model, memoising results per instance.

    Args:
        path: Model name or local path passed to ``SentenceTransformer``.
        cache_size: Maximum number of inputs whose embeddings are memoised.
    """

    def __init__(self, path, cache_size=10000):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        # Per-instance cache. Decorating the method with lru_cache would share one
        # cache across all instances and keep every instance (and its model) alive
        # for the lifetime of the process.
        self._cached_encode = lru_cache(maxsize=cache_size)(self._encode)
        logger.info(f'Load {self.__class__} from {path} ...')

    def _encode(self, x) -> torch.Tensor:
        """Run the underlying model without caching."""
        return self.model.encode(x, convert_to_tensor=True)

    def __call__(self, x) -> torch.Tensor:
        """Return the embedding tensor for ``x``.

        ``x`` must be hashable (e.g. a ``str``) because results are cached.
        The returned tensor is shared with the cache; do not modify it in place.
        """
        return self._cached_encode(x)


class ModelWithPooling:
    """Embed text with a Hugging Face ``AutoModel`` plus a selectable pooling strategy.

    Supported poolings:
        - ``'cls'``: hidden state of the first token of the last layer.
        - ``'pooler'``: the model's ``pooler_output``.
        - ``'mean'`` / ``'last-avg'``: mask-aware mean of the last layer.
        - ``'first-last-avg'``: average of the mask-aware means of
          ``hidden_states[1]`` and the last layer.

    Args:
        path: Model name or local path for ``AutoTokenizer``/``AutoModel``.
        cache_size: Maximum number of (text, pooling) results memoised.
    """

    POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')

    def __init__(self, path, cache_size=100):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Use the same device as SBert so both wrappers behave alike.
        self.model = AutoModel.from_pretrained(path).to(DEVICE)
        self.model.eval()
        # Per-instance cache (see SBert.__init__ for why not a method decorator).
        self._cached_embed = lru_cache(maxsize=cache_size)(self._embed)
        logger.info(f'Load {self.__class__} from {path} ...')

    def __call__(self, text: str, pooling='mean') -> torch.Tensor:
        """Return the pooled embedding of ``text``.

        For a single string the result has shape ``[h]``; for a tuple of strings
        (batched by the tokenizer) it has shape ``[b, h]``.

        Raises:
            ValueError: If ``pooling`` is unknown, or ``'pooler'`` is requested
                from a model that produces no pooler output.
        """
        if pooling not in self.POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')
        # Pass arguments positionally so keyword/positional calls share a cache key.
        return self._cached_embed(text, pooling)

    @torch.no_grad()
    def _embed(self, text, pooling):
        """Tokenize, run the model and pool, without caching."""
        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
        outputs = self.model(**inputs, output_hidden_states=True)
        mask = inputs['attention_mask']  # [b, s]; 0 marks padding positions

        if pooling == 'cls':
            o = outputs.last_hidden_state[:, 0]  # [b, h]
        elif pooling == 'pooler':
            if outputs.pooler_output is None:
                raise ValueError(f'Model provides no pooler output; cannot use pooling {pooling}')
            o = outputs.pooler_output  # [b, h]
        elif pooling in ('mean', 'last-avg'):
            o = self._masked_mean(outputs.last_hidden_state, mask)  # [b, h]
        else:  # 'first-last-avg' (validated in __call__)
            first_avg = self._masked_mean(outputs.hidden_states[1], mask)  # [b, h]
            last_avg = self._masked_mean(outputs.hidden_states[-1], mask)  # [b, h]
            o = (first_avg + last_avg) / 2  # [b, h]

        # Drop the batch axis for single-string input ([1, h] -> [h]).
        return o.squeeze(0)

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Average ``hidden`` ([b, s, h]) over the sequence axis, ignoring padding."""
        weights = mask.unsqueeze(-1).to(hidden.dtype)  # [b, s, 1]
        summed = (hidden * weights).sum(dim=1)  # [b, h]
        counts = weights.sum(dim=1).clamp(min=1e-9)  # [b, 1]; avoid division by zero
        return summed / counts


def test_sbert():
    """Smoke test; downloads ``bert-base-chinese`` on first run."""
    m = SBert('bert-base-chinese')
    o = m('hello')
    print(o.size())
    assert o.size() == (768,)


def test_hf_model():
    """Smoke test; downloads the Erlangshen SimCSE model on first run."""
    m = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    o = m('hello', pooling='cls')
    print(o.size())
    assert o.size() == (768,)