import numpy as np
import onnxruntime as ort

from .config_train import onnx_path, tokenizer
from .DataProcessing import read_input
from .load_data import sorted_tags


class Key_Ner_ONNX_Predictor:
    def __init__(self, model_path, tokenizer, tag_map):
        """
        Initialize the ONNX predictor.
        Args:
            model_path (str): Path to the ONNX model.
            tokenizer (BertTokenizer): Tokenizer to process input sentences.
            tag_map (Dict[int, str]): Mapping of indices to tags.
        """
        self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        self.tokenizer = tokenizer
        self.tag_map = tag_map

    def predict(self, sentence):
        """
        Predict tags using the ONNX model.
        Args:
            sentence (str): Input sentence.
        Returns:
            Tuple[str, List[str]]: The decoded input sentence and the unique predicted tags,
                with '<pad>' removed and spaces replaced by underscores.
        """
        sentence = read_input(sentence)
        tokens = self.tokenizer(sentence, return_tensors="np", padding=True, truncation=True)

        # Cast to int64, the input dtype that exported BERT-style ONNX models typically expect
        input_ids = tokens["input_ids"].astype(np.int64)
        attention_mask = tokens["attention_mask"].astype(np.int64)

        # Run inference
        outputs = self.session.run(None, {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        })

        logits = outputs[0]
        predicted_tags = np.argmax(logits, axis=2)[0]

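        # Convert indices to tags, dropping duplicates while keeping first-seen order
        # (a plain set() would return the tags in arbitrary order)
        predicted_tags = [self.tag_map[idx] for idx in predicted_tags]
        predicted_tags = list(dict.fromkeys(predicted_tags))
        # Drop the padding tag and replace spaces with underscores
        predicted_tags = [tag.replace(" ", "_") for tag in predicted_tags if tag != '<pad>']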

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True), predicted_tags

# Initialize ONNX-based predictor
onnx_predictor = Key_Ner_ONNX_Predictor(
    model_path=onnx_path,
    tokenizer=tokenizer,
    tag_map=dict(enumerate(sorted_tags))
)