gcptest / extractor.py
tom-lewis-code's picture
Upload folder using huggingface_hub
3b816e3 verified
import logging
from typing import Dict, List, Optional

import torch
from gliner import GLiNER
# Score cutoff handed to the extractor's ``predict_entities`` call as its
# ``threshold`` argument.
TAU = 0.3

# Set up root logging at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)
class EntityExtractor:
    """Wrap a GLiNER model and expose a simple entity-extraction call.

    The model is loaded once at construction time and moved to CUDA when
    ``torch.cuda.is_available()`` reports a GPU, otherwise it stays on CPU.
    """

    # Labels requested from the model when the caller passes no
    # ``entity_types``. Kept as an immutable tuple so it cannot be mutated
    # by accident; a fresh list is handed to the model on each call.
    DEFAULT_ENTITY_TYPES = (
        "brand",
        "color_finish",
        "style",
        "collection",
        "dimension",
        "feature",
        "product_type",
        "part_number",
    )

    def __init__(self, extractor_model: str):
        """
        Initializes the EntityExtractor class with an extractor model.

        Args:
            extractor_model (str): The model name for the entity extractor,
                passed to ``GLiNER.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.extractor = self.load_extractor(extractor_model).to(self.device)

    @staticmethod
    def load_extractor(model_name: str) -> "GLiNER":
        """Loads the entity extractor model.

        Args:
            model_name (str): Identifier handed to ``GLiNER.from_pretrained``.

        Returns:
            The loaded GLiNER model (not yet moved to a device).
        """
        return GLiNER.from_pretrained(model_name, load_tokenizer=True)

    def extract_entities(
        self,
        text: str,
        entity_types: Optional[List[str]] = None,
        threshold: Optional[float] = None,
    ) -> List[Dict[str, str]]:
        """Extracts entities of the requested types from ``text``.

        Args:
            text (str): The text to scan. Empty or whitespace-only text
                yields an empty result without calling the model.
            entity_types (List[str], optional): Labels to look for. When
                ``None``, ``DEFAULT_ENTITY_TYPES`` is used.
            threshold (float, optional): Value forwarded as the model's
                ``threshold`` argument. When ``None``, the module-level
                ``TAU`` constant is used.

        Returns:
            One dict per prediction returned by the model, with key
            ``"entity"`` holding the prediction's ``"text"`` value and
            ``"entity_type"`` holding its ``"label"`` value.
        """
        # Blank input cannot contain any entity, so skip the model call.
        if not text or not text.strip():
            return []

        labels = list(self.DEFAULT_ENTITY_TYPES) if entity_types is None else entity_types
        cutoff = TAU if threshold is None else threshold

        predictions = self.extractor.predict_entities(
            text, labels, threshold=cutoff, flat_ner=True, multi_label=False
        )
        return [
            {"entity": prediction["text"], "entity_type": prediction["label"]}
            for prediction in predictions
        ]