mohammadkrb commited on
Commit
07cbdb5
1 Parent(s): 8cb9a5d

set local dir

Browse files
Files changed (1) hide show
  1. kpe_ranker.py +5 -7
kpe_ranker.py CHANGED
@@ -7,13 +7,11 @@ from huggingface_hub import hf_hub_download
7
 
8
  class KpeRanker:
9
  def __init__(self):
10
- model_path = "/root/.cache/huggingface/hub/models--ahdsoft--persian-keyphrase-extraction-model/trained_model_10000.pt"
11
- if os.path.isfile(model_path):
12
- TRAINED_MODEL_ADDR = model_path
13
- else:
14
- hf_hub_download(repo_id="ahdsoft/persian-keyphrase-extraction-model", filename="trained_model_10000.pt", token="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")
15
- TRAINED_MODEL_ADDR = model_path
16
-
17
  # TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
18
  self.kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
19
  self.ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
 
7
 
8
  class KpeRanker:
9
  def __init__(self):
10
+ local_dir = "./"
11
+ model_path = os.path.join(local_dir, 'trained_model_10000.pt')
12
+ if not os.path.isfile(model_path):
13
+ hf_hub_download(repo_id="ahdsoft/persian-keyphrase-extraction-model", filename="trained_model_10000.pt", local_dir=local_dir, token="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")
14
+ TRAINED_MODEL_ADDR = model_path
 
 
15
  # TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
16
  self.kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
17
  self.ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')