Bhanushray commited on
Commit
eef6783
·
verified ·
1 Parent(s): 3fabf66

Update model_handler.py

Browse files
Files changed (1) hide show
  1. model_handler.py +33 -32
model_handler.py CHANGED
@@ -1,33 +1,34 @@
1
- from sentence_transformers.cross_encoder import CrossEncoder
2
- import torch
3
-
4
- class SimilarityModelHandler:
5
- # HOLDING THE MODEL INSTANCE TO PREVENT RELOADING
6
- SIMILARITY_MODEL_INSTANCE = None
7
-
8
- def __init__(self):
9
- # CONSTRUCTOR: LOADING THE MODEL IF IT DOESN'T EXIST
10
- if not SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE:
11
- print("INITIALIZING AND LOADING THE MODEL...")
12
- # CHECKING FOR GPU, FALLBACK TO CPU
13
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
- print(f"SERVICE IS RUNNING ON DEVICE: {device}")
15
-
16
- # LOADING THE PRE-TRAINED CROSS-ENCODER MODEL
17
- model_Name = 'cross-encoder/stsb-roberta-large'
18
- SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE = CrossEncoder(model_Name, device=device)
19
- print("MODEL LOADED SUCCESSFULLY.")
20
-
21
- def calculate_Similarity(self, text_One: str, text_Two: str) -> float:
22
- """
23
- CALCULATES THE SIMILARITY SCORE BETWEEN TWO TEXTS.
24
- """
25
- # GETTING THE SCORE FROM THE MODEL( 0-1 )
26
- finalScore = self.SIMILARITY_MODEL_INSTANCE.predict([(text_One, text_Two)])
27
-
28
- # CONVERTING FROM NUMPY ARRAY TO A SIMPLE FLOAT
29
- return finalScore.item()
30
-
31
-
32
- # CREATING A SINGLE INSTANCE TO BE USED BY THE API
 
33
  MODEL_HANDLER = SimilarityModelHandler()
 
1
+ from sentence_transformers.cross_encoder import CrossEncoder
2
+ import torch
3
+
4
+ class SimilarityModelHandler:
5
+ # HOLDING THE MODEL INSTANCE TO PREVENT RELOADING
6
+ SIMILARITY_MODEL_INSTANCE = None
7
+
8
+ def __init__(self):
9
+ # CONSTRUCTOR: LOADING THE MODEL IF IT DOESN'T EXIST
10
+ if not SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE:
11
+ print("INITIALIZING AND LOADING THE MODEL...")
12
+ # CHECKING FOR GPU, FALLBACK TO CPU
13
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+ print(f"SERVICE IS RUNNING ON DEVICE: {device}")
15
+
16
+ # LOADING THE PRE-TRAINED CROSS-ENCODER MODEL
17
+ model_Name = 'cross-encoder/stsb-roberta-base'
18
+ #cross-encoder/stsb-roberta-large'
19
+ SimilarityModelHandler.SIMILARITY_MODEL_INSTANCE = CrossEncoder(model_Name, device=device)
20
+ print("MODEL LOADED SUCCESSFULLY.")
21
+
22
+ def calculate_Similarity(self, text_One: str, text_Two: str) -> float:
23
+ """
24
+ CALCULATES THE SIMILARITY SCORE BETWEEN TWO TEXTS.
25
+ """
26
+ # GETTING THE SCORE FROM THE MODEL( 0-1 )
27
+ finalScore = self.SIMILARITY_MODEL_INSTANCE.predict([(text_One, text_Two)])
28
+
29
+ # CONVERTING FROM NUMPY ARRAY TO A SIMPLE FLOAT
30
+ return finalScore.item()
31
+
32
+
33
+ # CREATING A SINGLE INSTANCE TO BE USED BY THE API
34
  MODEL_HANDLER = SimilarityModelHandler()