# Spaces: Running
import logging
from typing import Dict, List, Optional

import torch
from gliner import GLiNER
# Configure root logging at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)

# Score threshold forwarded to GLiNER's ``predict_entities``; predicted
# spans scoring below it are not returned.
TAU = 0.3
class EntityExtractor:
    """Extract typed entity spans from text with a GLiNER model."""

    # Labels searched for when the caller does not pass ``entity_types``.
    # Stored as a tuple so no call can mutate the shared default.
    DEFAULT_ENTITY_TYPES = (
        "brand",
        "color_finish",
        "style",
        "collection",
        "dimension",
        "feature",
        "product_type",
        "part_number",
    )

    def __init__(self, extractor_model: str, device: Optional[str] = None):
        """
        Initializes the EntityExtractor class with an extractor model.

        Args:
            extractor_model (str): The model name for the entity extractor.
            device (str, optional): Torch device to run on (e.g. "cpu",
                "cuda:0"). Defaults to CUDA when available, otherwise CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        logging.getLogger(__name__).info(
            "Loading entity extractor %r on %s", extractor_model, self.device
        )
        self.extractor = self.load_extractor(extractor_model).to(self.device)

    @staticmethod
    def load_extractor(model_name: str) -> "GLiNER":
        """Load the GLiNER model (and its tokenizer) named ``model_name``.

        This must be a staticmethod: it takes no ``self``. As a plain method,
        ``self.load_extractor(name)`` would bind ``self`` to ``model_name``
        and raise ``TypeError``.
        """
        return GLiNER.from_pretrained(model_name, load_tokenizer=True)

    def extract_entities(
        self,
        text: str,
        entity_types: Optional[List[str]] = None,
        threshold: Optional[float] = None,
    ) -> List[Dict[str, str]]:
        """Return the entities the model predicts in ``text``.

        Args:
            text: Input text to scan.
            entity_types: Labels to search for. Defaults to
                ``DEFAULT_ENTITY_TYPES``.
            threshold: Score threshold forwarded to GLiNER. Defaults to the
                module-level ``TAU``, read at call time.

        Returns:
            One ``{"entity": <span text>, "entity_type": <label>}`` dict per
            predicted span, in the order the model returns them. Returns an
            empty list without calling the model when ``text`` is empty or
            ``entity_types`` is an empty list.
        """
        if entity_types is None:
            entity_types = list(self.DEFAULT_ENTITY_TYPES)
        # Nothing to scan, or nothing to look for: skip the model call.
        if not text or not entity_types:
            return []
        if threshold is None:
            threshold = TAU
        # Per GLiNER's predict_entities API: flat_ner disallows overlapping
        # spans, and multi_label=False keeps a single label per span.
        output = self.extractor.predict_entities(
            text, entity_types, threshold=threshold, flat_ner=True, multi_label=False
        )
        return [
            {"entity": entity["text"], "entity_type": entity["label"]}
            for entity in output
        ]