| from typing import Dict, Any |
|
|
| import spacy |
| from environs import Env |
| from huggingface_hub import hf_hub_download |
| from joblib import load |
|
|
| from src.dtos.output.basic import BasicOutput |
|
|
| from src.format import format_model_name_from_path |
| from src.models.bag_of_words.extractor import BagOfWordsExtractor |
| from src.models.bag_of_words.formatter import BagOfWordsFormatter |
| from src.models.bag_of_words.model import BagOfWordsModelContainer |
| from src.models.bag_of_words.predictor import RelevancePredictor |
|
|
# spaCy "en_core_web_trf" pipeline, loaded once when this module is imported,
# with its "parser" component disabled.
# NOTE(review): not referenced anywhere in this file; presumably imported by
# other modules — confirm before removing, since loading it is expensive.
SPACY_MODEL = spacy.load(
    "en_core_web_trf",
    disable=["parser"],
)
|
|
|
|
class EndpointHandler:
    """Request handler that scores one HTML document for URL relevance.

    On construction it downloads a joblib-serialized bag-of-words model
    container from the Hugging Face Hub. It then builds the extractor,
    formatter and predictor from that container. Each call runs the HTML in
    ``inputs["inputs"]`` through extractor -> formatter -> predictor and
    returns a JSON-compatible dict.
    """

    # Hub repository holding the serialized model containers; the
    # sub-folder within it is chosen by the MODEL_PATH environment variable.
    MODEL_REPO_ID: str = "PDAP/url-relevance-models"

    def __init__(self, path: str = ""):
        """Download and wire up the model selected by ``MODEL_PATH``.

        Args:
            path: Path passed in by the endpoint runtime. It is not used:
                the model is always fetched from ``MODEL_REPO_ID`` instead.

        Raises:
            An ``environs`` error if ``MODEL_PATH`` is not set; any error
            raised by ``hf_hub_download`` or ``joblib.load``.
        """
        env = Env()
        # Populate the environment from a .env file, if one is present.
        env.read_env()

        model_path = env.str("MODEL_PATH")
        # Name reported back to callers in every prediction.
        self.model_name = format_model_name_from_path(model_path)

        downloaded_model_path = hf_hub_download(
            repo_id=self.MODEL_REPO_ID,
            subfolder=model_path,
            filename="model.joblib",
        )
        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code. This is only acceptable because MODEL_REPO_ID is a trusted,
        # project-controlled repository — never point it at untrusted repos.
        self.model_container: BagOfWordsModelContainer = load(downloaded_model_path)

        # Each pipeline stage is configured from a piece of the container.
        self.extractor = BagOfWordsExtractor(self.model_container.permitted_terms)
        self.formatter = BagOfWordsFormatter(self.model_container.term_label_encoder)
        self.predictor = RelevancePredictor(self.model_container.model)

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Predict whether the HTML in ``inputs["inputs"]`` is relevant.

        Args:
            inputs: Request payload whose ``"inputs"`` entry is passed to the
                bag-of-words extractor (presumably the page HTML as a string).

        Returns:
            ``BasicOutput(...).model_dump(mode="json")``, built from the
            predictor's ``is_relevant`` (as ``annotation``) and ``probability``
            (as ``confidence``) plus this handler's ``model_name`` (as
            ``model``). Values are not all strings, hence ``Dict[str, Any]``.

        Raises:
            KeyError: if ``inputs`` has no ``"inputs"`` key.
        """
        try:
            html = inputs["inputs"]
        except KeyError:
            raise KeyError(
                "request payload must contain an 'inputs' key holding the page HTML"
            ) from None

        bag_of_words = self.extractor.extract_bag_of_words(html)
        # Convert the extracted terms into the matrix the model consumes.
        csr = self.formatter.format_bag_of_words(bag_of_words)
        output = self.predictor.predict_relevance(csr)
        return BasicOutput(
            annotation=output.is_relevant,
            confidence=output.probability,
            model=self.model_name,
        ).model_dump(mode="json")
|
|