AhdCompnay commited on
Commit
5199291
1 Parent(s): 84b31d5

Update kpe_ranker.py

Browse files
Files changed (1) hide show
  1. kpe_ranker.py +5 -2
kpe_ranker.py CHANGED
@@ -10,6 +10,9 @@ class KpeRanker:
10
  model_name = os.environ.get("MODEL_NAME")
11
  model_repo = os.environ.get("MODEL_REPO")
12
  model_token = os.environ.get("MODEL_TOKEN")
 
 
 
13
 
14
  local_dir = "./"
15
  model_path = os.path.join(local_dir, model_name)
@@ -17,8 +20,8 @@ class KpeRanker:
17
  hf_hub_download(repo_id=model_repo, filename=model_name, local_dir=local_dir, token=model_token)
18
  TRAINED_MODEL_ADDR = model_path
19
  # TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
20
- self.kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
21
- self.ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
22
 
23
 
24
  def extract(self, text, count, using_ner, return_sorted):
 
10
  model_name = os.environ.get("MODEL_NAME")
11
  model_repo = os.environ.get("MODEL_REPO")
12
  model_token = os.environ.get("MODEL_TOKEN")
13
+ ner_model = os.environ.get("NER_MODEL")
14
+ transformer_model = os.environ.get("TRANSFORMER_MODEL")
15
+
16
 
17
  local_dir = "./"
18
  model_path = os.path.join(local_dir, model_name)
 
20
  hf_hub_download(repo_id=model_repo, filename=model_name, local_dir=local_dir, token=model_token)
21
  TRAINED_MODEL_ADDR = model_path
22
  # TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
23
+ self.kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model= ner_model , device='cpu')
24
+ self.ranker_transformer = SentenceTransformer(transformer_model, device='cpu')
25
 
26
 
27
  def extract(self, text, count, using_ner, return_sorted):