File size: 3,772 Bytes
4bc1e6a
 
26cd288
 
 
 
4bc1e6a
 
26cd288
4bc1e6a
26cd288
 
 
4bc1e6a
 
485cae6
4bc1e6a
0ebbe50
 
4bc1e6a
 
 
 
 
 
 
26cd288
 
 
 
 
 
 
 
 
93e2e2a
 
26cd288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bc1e6a
485cae6
4bc1e6a
 
 
485cae6
 
 
 
 
4bc1e6a
485cae6
4bc1e6a
 
485cae6
 
 
 
 
4bc1e6a
 
 
 
 
 
 
26cd288
 
 
 
 
 
4bc1e6a
26cd288
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os

import numpy as np
import torch

from pathlib import Path
from typing import Union

from huggingface_hub import hf_hub_download
from numpy.linalg import norm
from onnxruntime import InferenceSession
from tclogger import logger
from transformers import AutoTokenizer, AutoModel

from configs.envs import ENVS
from configs.constants import AVAILABLE_MODELS

# Export the Hugging Face settings from the project configuration into the
# process environment. Each value is exported only when it is truthy:
# assigning None to os.environ raises TypeError, and an empty value would
# overwrite a token/endpoint that is already present in the environment.
# NOTE(review): huggingface_hub is imported above this point and appears to
# read HF_ENDPOINT when it is first imported -- confirm that setting it here
# still takes effect, or move this block ahead of those imports.
if ENVS["HF_ENDPOINT"]:
    os.environ["HF_ENDPOINT"] = ENVS["HF_ENDPOINT"]
if ENVS["HF_TOKEN"]:
    os.environ["HF_TOKEN"] = ENVS["HF_TOKEN"]


def cosine_similarity(a, b):
    """Return the cosine similarity between ``a`` and ``b``.

    The product ``a @ b.T`` is divided by the product of the two norms as
    computed by ``numpy.linalg.norm`` with its default arguments (each input
    is reduced to a single scalar norm).
    """
    dot_product = a @ b.T
    magnitude = norm(a) * norm(b)
    return dot_product / magnitude


class JinaAIOnnxEmbedder:
    """Text embedder that runs the quantized ONNX export of a Jina AI model.

    The tokenizer is loaded from the Hugging Face repo ``repo_name``, the
    model itself is executed through an ONNX Runtime session, and the
    per-token outputs are mean-pooled into the final embedding.

    Reference:
    https://huggingface.co/jinaai/jina-embeddings-v2-base-zh/discussions/6#65bc55a854ab5eb7b6300893
    """

    def __init__(self):
        """Download the ONNX model if it is missing, then load it."""
        self.repo_name = "jinaai/jina-embeddings-v2-base-zh"
        self.download_model()
        self.load_model()

    def download_model(self):
        """Make sure the ONNX model file is present in the ``.cache`` folder.

        The cache folder is ``.cache`` under the directory two levels above
        the directory containing this file. The file is fetched with
        ``hf_hub_download`` only when it does not exist yet.
        """
        self.onnx_folder = Path(__file__).parents[2] / ".cache"
        self.onnx_folder.mkdir(parents=True, exist_ok=True)
        self.onnx_filename = "onnx/model_quantized.onnx"
        self.onnx_path = self.onnx_folder / self.onnx_filename
        if not self.onnx_path.exists():
            logger.note("> Downloading ONNX model")
            hf_hub_download(
                repo_id=self.repo_name,
                filename=self.onnx_filename,
                local_dir=self.onnx_folder,
                local_dir_use_symlinks=False,
            )
            logger.success(f"+ ONNX model downloaded: {self.onnx_path}")
        else:
            logger.success(f"+ ONNX model loaded: {self.onnx_path}")

    def load_model(self):
        """Create the tokenizer and the ONNX Runtime inference session."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.repo_name, trust_remote_code=True
        )
        # Pass a plain string: older onnxruntime releases only accept
        # str/bytes here and raise TypeError for pathlib.Path objects.
        self.session = InferenceSession(str(self.onnx_path))

    def mean_pooling(self, model_output, attention_mask):
        """Return the attention-mask-weighted mean of the token embeddings.

        Args:
            model_output: torch tensor of token embeddings; the reduction
                runs over dimension 1 (presumably ``(batch, seq, dim)`` --
                TODO confirm against the model output).
            attention_mask: torch tensor with one entry per token; it is
                expanded along a new last axis to the shape of
                ``model_output``.

        Returns:
            torch tensor with dimension 1 reduced away.
        """
        token_embeddings = model_output
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # The clamp keeps the denominator non-zero when every position of a
        # row is masked out.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode(self, text: str):
        """Embed ``text`` and return the mean-pooled embedding.

        Args:
            text: the input text to embed.

        Returns:
            torch tensor produced by ``mean_pooling`` from the session's
            ``last_hidden_state`` output.
        """
        inputs = self.tokenizer(text, return_tensors="np")
        # Every tokenizer output is cast to int64 before being fed to the
        # session.
        inputs = {
            name: np.array(tensor, dtype=np.int64) for name, tensor in inputs.items()
        }
        outputs = self.session.run(
            output_names=["last_hidden_state"], input_feed=inputs
        )
        embeddings = self.mean_pooling(
            torch.from_numpy(outputs[0]), torch.from_numpy(inputs["attention_mask"])
        )
        return embeddings


class JinaAIEmbedder:
    """Embedder that loads one of ``AVAILABLE_MODELS`` through ``AutoModel``.

    Unknown model names fall back to ``AVAILABLE_MODELS[0]``.
    """

    def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
        """Remember ``model_name`` and load the corresponding model."""
        self.model_name = model_name
        self.load_model()

    def check_model_name(self):
        """Replace an unknown ``self.model_name`` with the default model.

        Returns:
            Always ``True``.
        """
        if self.model_name not in AVAILABLE_MODELS:
            self.model_name = AVAILABLE_MODELS[0]
        return True

    def load_model(self):
        """Validate ``self.model_name`` and load the model into ``self.model``."""
        self.check_model_name()
        self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)

    def switch_model(self, model_name: str):
        """Switch to ``model_name``, reloading only when the model changes.

        The requested name is normalised with the same fallback rule as
        ``check_model_name`` before it is compared, so an unknown name that
        resolves to the already-loaded default model does not trigger a
        redundant reload.
        """
        if model_name not in AVAILABLE_MODELS:
            model_name = AVAILABLE_MODELS[0]
        if model_name == self.model_name:
            return
        self.model_name = model_name
        self.load_model()

    def encode(self, text: Union[str, list[str]]):
        """Encode one text or a list of texts with the loaded model.

        A single string is wrapped into a one-element list, so the model's
        ``encode`` method always receives a list.
        """
        if isinstance(text, str):
            text = [text]
        return self.model.encode(text)


if __name__ == "__main__":
    # Demo: embed an English and a Chinese sentence and print how similar
    # the two embeddings are.
    # Alternative backend: embedder = JinaAIEmbedder()
    embedder = JinaAIOnnxEmbedder()
    samples = ["How is the weather today?", "今天天气怎么样?"]
    vectors = [embedder.encode(sample) for sample in samples]
    logger.success(vectors)
    first, second = vectors[0], vectors[1]
    print(cosine_similarity(first, second))

    # python -m transforms.embed