Hansimov commited on
Commit
485cae6
1 Parent(s): f466164

:gem: [Feature] Embedder: check_model_name and switch_model, and separate constants

Browse files
Files changed (2) hide show
  1. configs/constants.py +3 -0
  2. transforms/embed.py +13 -1
configs/constants.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ AVAILABLE_MODELS = [
2
+ "jinaai/jina-embeddings-v2-base-zh",
3
+ ]
transforms/embed.py CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoModel
7
  from numpy.linalg import norm
8
 
9
  from configs.envs import ENVS
 
10
 
11
  os.environ["HF_ENDPOINT"] = ENVS["HF_ENDPOINT"]
12
  os.environ["HF_TOKEN"] = ENVS["HF_TOKEN"]
@@ -17,13 +18,24 @@ def cosine_similarity(a, b):
17
 
18
 
19
  class JinaAIEmbedder:
20
- def __init__(self, model_name: str = "jinaai/jina-embeddings-v2-base-zh"):
21
  self.model_name = model_name
22
  self.load_model()
23
 
 
 
 
 
 
24
  def load_model(self):
 
25
  self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
26
 
 
 
 
 
 
27
  def encode(self, text: Union[str, list[str]]):
28
  if isinstance(text, str):
29
  text = [text]
 
7
  from numpy.linalg import norm
8
 
9
  from configs.envs import ENVS
10
+ from configs.constants import AVAILABLE_MODELS
11
 
12
  os.environ["HF_ENDPOINT"] = ENVS["HF_ENDPOINT"]
13
  os.environ["HF_TOKEN"] = ENVS["HF_TOKEN"]
 
18
 
19
 
20
  class JinaAIEmbedder:
21
+ def __init__(self, model_name: str = AVAILABLE_MODELS[0]):
22
  self.model_name = model_name
23
  self.load_model()
24
 
25
+ def check_model_name(self):
26
+ if self.model_name not in AVAILABLE_MODELS:
27
+ self.model_name = AVAILABLE_MODELS[0]
28
+ return True
29
+
30
  def load_model(self):
31
+ self.check_model_name()
32
  self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
33
 
34
+ def switch_model(self, model_name: str):
35
+ if model_name != self.model_name:
36
+ self.model_name = model_name
37
+ self.load_model()
38
+
39
  def encode(self, text: Union[str, list[str]]):
40
  if isinstance(text, str):
41
  text = [text]