OnlyBiggg commited on
Commit
765f020
·
1 Parent(s): e41377f

fix model dir

Browse files
Files changed (2) hide show
  1. app/ner/services/ner.py +4 -4
  2. core/conf.py +1 -1
app/ner/services/ner.py CHANGED
@@ -1,8 +1,8 @@
1
  from core.conf import settings
2
 
3
  class NER:
4
- def __init__(self, model_name: str = settings.NER_MODEL_NAME):
5
- self.model_name = model_name
6
  self.model = None
7
  self.tokenizer = None
8
  self.pipeline = None
@@ -10,8 +10,8 @@ class NER:
10
  def load_model(self):
11
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
12
 
13
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
14
- self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
15
  self.pipeline = pipeline(settings.TASK_NAME, model=self.model, tokenizer=self.tokenizer)
16
 
17
  def predict(self, text: str, entity_tag: str = None):
 
1
  from core.conf import settings
2
 
3
  class NER:
4
+ def __init__(self, model_dir: str = settings.NER_MODEL_DIR):
5
+ self.model_dir = model_dir
6
  self.model = None
7
  self.tokenizer = None
8
  self.pipeline = None
 
10
  def load_model(self):
11
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
12
 
13
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
14
+ self.model = AutoModelForTokenClassification.from_pretrained(self.model_dir)
15
  self.pipeline = pipeline(settings.TASK_NAME, model=self.model, tokenizer=self.tokenizer)
16
 
17
  def predict(self, text: str, entity_tag: str = None):
core/conf.py CHANGED
@@ -32,7 +32,7 @@ class Settings(BaseSettings):
32
  # DATABASE_PASSWORD: str
33
 
34
  # MODEl NER
35
- NER_MODEL_NAME: str = 'ner'
36
  TASK_NAME: str = 'ner'
37
 
38
  # FastAPI
 
32
  # DATABASE_PASSWORD: str
33
 
34
  # MODEl NER
35
+ NER_MODEL_DIR: str = '/app/ner/models/ner'
36
  TASK_NAME: str = 'ner'
37
 
38
  # FastAPI