# Retrieved from the Hugging Face file viewer (author: Hansimov; HF scan: no virus).
# Last commit 93e2e2a: ":boom: [Fix] Permission denied when saving onnx model" (file size 3.77 kB).
import os
import numpy as np
import torch
from pathlib import Path
from typing import Union
from huggingface_hub import hf_hub_download
from numpy.linalg import norm
from onnxruntime import InferenceSession
from tclogger import logger
from transformers import AutoTokenizer, AutoModel
from configs.envs import ENVS
from configs.constants import AVAILABLE_MODELS
# Export optional Hugging Face settings from the project config into the
# process environment. os.environ only accepts str values, so each variable is
# exported only when it is configured (truthy); an unset/None token would
# otherwise raise TypeError at import time.
# NOTE(review): huggingface_hub is imported above, before HF_ENDPOINT is set
# here; if the library captures the endpoint at import time, this assignment
# may not affect hub downloads — confirm.
if ENVS["HF_ENDPOINT"]:
    os.environ["HF_ENDPOINT"] = ENVS["HF_ENDPOINT"]
if ENVS["HF_TOKEN"]:
    os.environ["HF_TOKEN"] = ENVS["HF_TOKEN"]
def cosine_similarity(a, b):
    """Return ``a @ b.T`` divided by the product of the norms of ``a`` and ``b``.

    For two 1-D vectors (or two single-row matrices) this is their cosine
    similarity. NOTE: ``norm`` is called without ``axis``, so for 2-D inputs it
    is the norm of the whole matrix; with several rows the result is scaled by
    whole-matrix norms rather than giving per-row cosine similarities.
    """
    numerator = a @ b.T
    denominator = norm(a) * norm(b)
    return numerator / denominator
class JinaAIOnnxEmbedder:
    """Embed text with the quantized ONNX export of jina-embeddings-v2-base-zh.

    The ONNX file is downloaded once into the ``.cache`` folder under
    ``Path(__file__).parents[2]`` (presumably the project root) and executed
    with ONNX Runtime. Token embeddings are mean-pooled over the attention mask.

    Based on:
    https://huggingface.co/jinaai/jina-embeddings-v2-base-zh/discussions/6#65bc55a854ab5eb7b6300893
    """

    def __init__(self):
        self.repo_name = "jinaai/jina-embeddings-v2-base-zh"
        self.download_model()
        self.load_model()

    def download_model(self):
        """Ensure the quantized ONNX model exists locally, downloading it if missing.

        Sets ``self.onnx_folder``, ``self.onnx_filename`` and ``self.onnx_path``.
        """
        self.onnx_folder = Path(__file__).parents[2] / ".cache"
        self.onnx_folder.mkdir(parents=True, exist_ok=True)
        # Path of the file inside the hub repo. The existence check below assumes
        # hf_hub_download places it at local_dir / filename.
        self.onnx_filename = "onnx/model_quantized.onnx"
        self.onnx_path = self.onnx_folder / self.onnx_filename
        if not self.onnx_path.exists():
            logger.note("> Downloading ONNX model")
            hf_hub_download(
                repo_id=self.repo_name,
                filename=self.onnx_filename,
                local_dir=self.onnx_folder,
                # Store a real file in local_dir rather than a symlink into the HF cache.
                local_dir_use_symlinks=False,
            )
            logger.success(f"+ ONNX model downloaded: {self.onnx_path}")
        else:
            logger.success(f"+ ONNX model loaded: {self.onnx_path}")

    def load_model(self):
        """Load the tokenizer from the hub repo and open an ONNX Runtime session."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.repo_name, trust_remote_code=True
        )
        # Pass a str: older onnxruntime releases only accept str/bytes for the
        # model path and raise TypeError for pathlib.Path objects.
        self.session = InferenceSession(str(self.onnx_path))

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions where the mask is set.

        Args:
            model_output: token embeddings tensor; assumed (batch, seq, dim).
            attention_mask: mask tensor with one fewer trailing dim than
                ``model_output``; assumed (batch, seq) of 0/1.

        Returns:
            Tensor of pooled embeddings, summed over dim 1 (assumed (batch, dim)).
        """
        token_embeddings = model_output
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # The clamp prevents division by zero when a row has no unmasked tokens.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def encode(self, text: str):
        """Return the mean-pooled embedding of ``text`` as a torch tensor."""
        encoded = self.tokenizer(text, return_tensors="np")
        # Cast explicitly to int64, which the exported graph is assumed to expect.
        feed = {
            name: np.array(tensor, dtype=np.int64) for name, tensor in encoded.items()
        }
        # Feed only the inputs the graph declares: tokenizers may emit extra keys
        # (e.g. token_type_ids) and session.run rejects unknown feed names.
        graph_inputs = {node.name for node in self.session.get_inputs()}
        outputs = self.session.run(
            output_names=["last_hidden_state"],
            input_feed={
                name: value for name, value in feed.items() if name in graph_inputs
            },
        )
        return self.mean_pooling(
            torch.from_numpy(outputs[0]), torch.from_numpy(feed["attention_mask"])
        )
class JinaAIEmbedder:
    """Text embedder backed by a Hugging Face ``AutoModel`` loaded with remote code.

    Model names outside ``AVAILABLE_MODELS`` fall back to ``AVAILABLE_MODELS[0]``.
    """

    def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
        self.model_name = model_name
        self.load_model()

    def check_model_name(self):
        """Replace an unknown ``self.model_name`` with ``AVAILABLE_MODELS[0]``.

        Always returns True.
        """
        if self.model_name not in AVAILABLE_MODELS:
            self.model_name = AVAILABLE_MODELS[0]
        return True

    def load_model(self):
        """(Re)load ``self.model`` for the current, validated model name."""
        self.check_model_name()
        self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)

    def switch_model(self, model_name: str):
        """Switch to ``model_name``, reloading only when the effective model changes.

        The name is resolved against ``AVAILABLE_MODELS`` *before* comparing.
        Previously an unknown name always differed from the current name and
        forced a full reload, even when it resolved to the model already loaded.
        """
        if model_name not in AVAILABLE_MODELS:
            model_name = AVAILABLE_MODELS[0]
        if model_name != self.model_name:
            self.model_name = model_name
            self.load_model()

    def encode(self, text: Union[str, list[str]]):
        """Encode one string or a list of strings with the loaded model.

        A single string is wrapped in a one-element list before calling
        ``self.model.encode``. The return type is whatever the remote-code model
        produces (presumably an array of embeddings; confirm).
        """
        if isinstance(text, str):
            text = [text]
        return self.model.encode(text)
if __name__ == "__main__":
    # Demo: embed an English and a Chinese sentence and compare them.
    # Swap in JinaAIEmbedder() to use the full PyTorch model instead of ONNX.
    embedder = JinaAIOnnxEmbedder()
    texts = ["How is the weather today?", "今天天气怎么样?"]
    embeddings = [embedder.encode(text) for text in texts]
    logger.success(embeddings)
    print(cosine_similarity(embeddings[0], embeddings[1]))

    # Run from the project root with: python -m transforms.embed