from functools import lru_cache

import torch
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Candidate checkpoints: multilingual, English, and Chinese sentence encoders.
list_models = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]


class SBert:
    def __init__(self, path):
        logger.info(f'Start loading {self.__class__.__name__} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        logger.info(f'Loaded {self.__class__.__name__} from {path}')

    @lru_cache(maxsize=10000)
    def __call__(self, x: str) -> torch.Tensor:
        # lru_cache keys on (self, x), so a repeated string is encoded only once.
        y = self.model.encode(x, convert_to_tensor=True)
        return y
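

# Hedged usage sketch, not part of the original file: comparing two sentences
# by cosine similarity. `util.cos_sim` is the standard helper shipped with
# sentence-transformers; the model name is one of the entries in list_models
# and the sentences are illustrative.
def demo_sbert_similarity():
    from sentence_transformers import util

    sbert = SBert('sentence-transformers/all-MiniLM-L12-v2')
    a = sbert('A cat sits on the mat.')
    b = sbert('A kitten rests on the rug.')
    score = util.cos_sim(a, b)  # [1, 1] tensor; values in [-1, 1]
    print(score.item())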


class ModelWithPooling:
    def __init__(self, path):
        logger.info(f'Start loading {self.__class__.__name__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path).to(DEVICE)
        self.model.eval()
        logger.info(f'Loaded {self.__class__.__name__} from {path}')

    @lru_cache(maxsize=100)
    @torch.no_grad()
    def __call__(self, text: str, pooling: str = 'mean') -> torch.Tensor:
        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
        outputs = self.model(**inputs, output_hidden_states=True)

        if pooling == 'cls':
            # Embedding of the [CLS] token from the last layer.
            o = outputs.last_hidden_state[:, 0]  # [b, h]

        elif pooling == 'pooler':
            # The tanh-activated pooler head applied to [CLS].
            o = outputs.pooler_output  # [b, h]

        elif pooling in ['mean', 'last-avg']:
            # Mean of all token embeddings in the last layer.
            last = outputs.last_hidden_state.transpose(1, 2)  # [b, h, s]
            o = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [b, h]

        elif pooling == 'first-last-avg':
            # Mean-pool the first and last transformer layers, then average the two.
            first = outputs.hidden_states[1].transpose(1, 2)  # [b, h, s]
            last = outputs.hidden_states[-1].transpose(1, 2)  # [b, h, s]
            first_avg = torch.avg_pool1d(first, kernel_size=first.shape[-1]).squeeze(-1)  # [b, h]
            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [b, h]
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)  # [b, 2, h]
            o = torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)  # [b, h]

        else:
            raise ValueError(f'Unknown pooling: {pooling}')

        o = o.squeeze(0)  # drop the batch dimension -> [h]
        return o
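

# Hedged sketch, not in the original file: every pooling strategy reduces the
# [b, s, h] hidden states to a single [h] vector. This helper prints the shape
# of each and how close 'mean' and 'first-last-avg' are by cosine similarity.
# The model name is reused from list_models above; the input is illustrative.
def demo_pooling_strategies():
    m = ModelWithPooling('bert-base-chinese')
    vectors = {p: m('hello', pooling=p) for p in ['cls', 'pooler', 'mean', 'first-last-avg']}
    for name, v in vectors.items():
        print(name, v.size())
    cos = torch.nn.functional.cosine_similarity(vectors['mean'], vectors['first-last-avg'], dim=0)
    print(f'mean vs first-last-avg cosine: {cos.item():.4f}')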


def test_sbert():
    m = SBert('bert-base-chinese')
    o = m('hello')
    print(o.size())
    assert o.size() == (768,)


def test_hf_model():
    m = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    o = m('hello', pooling='cls')
    print(o.size())
    assert o.size() == (768,)
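

if __name__ == '__main__':
    # Convenience entry point (an assumption, not in the original file):
    # run the smoke tests directly without pytest.
    test_sbert()
    test_hf_model()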