Bhanushray commited on
Commit
7f7b446
·
verified ·
1 Parent(s): eef6783

Update model_handler.py

Browse files
Files changed (1) hide show
  1. model_handler.py +50 -3
model_handler.py CHANGED
@@ -1,6 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from sentence_transformers.cross_encoder import CrossEncoder
2
  import torch
3
 
 
 
 
 
4
  class SimilarityModelHandler:
5
  # HOLDING THE MODEL INSTANCE TO PREVENT RELOADING
6
  SIMILARITY_MODEL_INSTANCE = None
@@ -14,9 +58,12 @@ class SimilarityModelHandler:
14
  print(f"SERVICE IS RUNNING ON DEVICE: {device}")
15
 
16
  # LOADING THE PRE-TRAINED CROSS-ENCODER MODEL
17
- model_Name = 'cross-encoder/stsb-roberta-base'
18
- #cross-encoder/stsb-roberta-large'
19
- SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE = CrossEncoder(model_Name, device=device)
 
 
 
20
  print("MODEL LOADED SUCCESSFULLY.")
21
 
22
  def calculate_Similarity(self, text_One: str, text_Two: str) -> float:
 
1
+ # from sentence_transformers.cross_encoder import CrossEncoder
2
+ # import torch
3
+
4
+ # class SimilarityModelHandler:
5
+ # # HOLDING THE MODEL INSTANCE TO PREVENT RELOADING
6
+ # SIMILARITY_MODEL_INSTANCE = None
7
+
8
+ # def __init__(self):
9
+ # # CONSTRUCTOR: LOADING THE MODEL IF IT DOESN'T EXIST
10
+ # if not SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE:
11
+ # print("INITIALIZING AND LOADING THE MODEL...")
12
+ # # CHECKING FOR GPU, FALLBACK TO CPU
13
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+ # print(f"SERVICE IS RUNNING ON DEVICE: {device}")
15
+
16
+ # # LOADING THE PRE-TRAINED CROSS-ENCODER MODEL
17
+ # model_Name = 'cross-encoder/stsb-roberta-base'
18
+ # #cross-encoder/stsb-roberta-large'
19
+ # SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE = CrossEncoder(model_Name, device=device)
20
+ # print("MODEL LOADED SUCCESSFULLY.")
21
+
22
+ # def calculate_Similarity(self, text_One: str, text_Two: str) -> float:
23
+ # """
24
+ # CALCULATES THE SIMILARITY SCORE BETWEEN TWO TEXTS.
25
+ # """
26
+ # # GETTING THE SCORE FROM THE MODEL( 0-1 )
27
+ # finalScore = self.SIMILARITY_MODEL_INSTANCE.predict([(text_One, text_Two)])
28
+
29
+ # # CONVERTING FROM NUMPY ARRAY TO A SIMPLE FLOAT
30
+ # return finalScore.item()
31
+
32
+
33
+ # # CREATING A SINGLE INSTANCE TO BE USED BY THE API
34
+ # MODEL_HANDLER = SimilarityModelHandler()
35
+
36
+
37
+
38
+
39
+
40
+ import os
41
  from sentence_transformers.cross_encoder import CrossEncoder
42
  import torch
43
 
44
+ # SET CACHE DIRECTORY TO A WRITABLE LOCATION
45
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
46
+ os.environ['HF_HOME'] = '/tmp/hf_home'
47
+
48
  class SimilarityModelHandler:
49
  # HOLDING THE MODEL INSTANCE TO PREVENT RELOADING
50
  SIMILARITY_MODEL_INSTANCE = None
 
58
  print(f"SERVICE IS RUNNING ON DEVICE: {device}")
59
 
60
  # LOADING THE PRE-TRAINED CROSS-ENCODER MODEL
61
+ model_Name = 'cross-encoder/stsb-roberta-large'
62
+ SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE = CrossEncoder(
63
+ model_Name,
64
+ device=device,
65
+ cache_folder='/tmp/transformers_cache' # EXPLICIT CACHE FOLDER
66
+ )
67
  print("MODEL LOADED SUCCESSFULLY.")
68
 
69
  def calculate_Similarity(self, text_One: str, text_Two: str) -> float: