| import os |
| import onnxruntime as ort |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
| from mentioned.model import ModelRegistry, LitMentionDetector |
| from mentioned.data import DataRegistry |
|
|
|
|
class InferenceMentionDetector(nn.Module):
    """Pure-tensor wrapper combining an encoder and a mention-detection head.

    Designed for ONNX export: forward() takes and returns only tensors.
    """

    def __init__(self, encoder, mention_detector):
        super().__init__()
        self.encoder = encoder
        self.mention_detector = mention_detector

    def forward(self, input_ids, attention_mask, word_ids):
        """
        Inputs (Tensors):
            input_ids: (B, Seq_Len)
            attention_mask: (B, Seq_Len)
            word_ids: (B, Seq_Len) -> Word index per token, -1 padding.

        Returns (Tensors):
            start_probs: (B, Num_Words)
            end_probs: (B, Num_Words, Num_Words)
        """
        embeddings = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            word_ids=word_ids,
        )
        start_logits, end_logits = self.mention_detector(embeddings)
        # Independent sigmoids: start and span scores are thresholded
        # separately by the downstream pipeline.
        return torch.sigmoid(start_logits), torch.sigmoid(end_logits)
|
|
|
|
class InferenceMentionLabeler(nn.Module):
    """Encoder + detection head + labeling head, combined for ONNX export."""

    def __init__(self, encoder, mention_detector, mention_labeler, id2label):
        super().__init__()
        self.encoder = encoder
        self.mention_detector = mention_detector
        self.mention_labeler = mention_labeler
        # Carried along for downstream pipelines; not used in forward().
        self.id2label = id2label

    def forward(self, input_ids, attention_mask, word_ids):
        """
        Pure tensor forward pass for ONNX export.

        Returns (Tensors):
            start_probs: (B, N)
            end_probs: (B, N, N)
            label_probs: (B, N, N, C)
        """
        embeddings = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            word_ids=word_ids,
        )
        start_logits, end_logits = self.mention_detector(embeddings)
        entity_logits = self.mention_labeler(embeddings)
        return (
            torch.sigmoid(start_logits),
            torch.sigmoid(end_logits),
            torch.softmax(entity_logits, dim=-1),
        )
|
|
|
|
class MentionProcessor:
    """Turns batches of pre-split documents into model-ready tensors."""

    def __init__(self, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, docs: list[list[str]]):
        """
        Converts raw word lists into tensors.

        Args:
            docs: List of documents, where each doc is a list of words.
                  Example: [["Hello", "world"], ["Testing", "this"]]
        """
        encoded = self.tokenizer(
            docs,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_attention_mask=True,
        )

        # Map each subword token back to its word index. Special/padding
        # tokens (word id None) become -1 so the model can mask them out.
        word_ids_tensor = torch.stack([
            torch.tensor([
                -1 if w is None else w
                for w in encoded.word_ids(batch_index=idx)
            ])
            for idx in range(len(docs))
        ])

        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "word_ids": word_ids_tensor,
        }
|
|
|
|
class ONNXMentionDetectorPipeline:
    """End-to-end inference over an exported mention-detector ONNX model.

    Tokenizes raw word lists, runs the ONNX session on CPU, and decodes
    the start/span probability grids into mention dicts.
    """

    def __init__(self, model_path: str, tokenizer, threshold: float = 0.5):
        self.session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider'],
        )
        self.tokenizer = tokenizer
        self.processor = MentionProcessor(tokenizer)
        self.threshold = threshold

    def predict(self, docs: list[list[str]]):
        """Return, per document, the spans whose probabilities clear the threshold."""
        tensors = self.processor(docs)
        feeds = {
            name: tensors[name].numpy()
            for name in ("input_ids", "attention_mask", "word_ids")
        }
        start_probs, end_probs = self.session.run(None, feeds)

        results = []
        for doc_idx, words in enumerate(docs):
            n_words = len(words)
            start_hits = start_probs[doc_idx, :n_words] > self.threshold
            span_hits = end_probs[doc_idx, :n_words, :n_words] > self.threshold
            # Keep only spans where start <= end and both scores pass.
            valid = span_hits & start_hits[:, None]
            valid &= np.triu(np.ones((n_words, n_words), dtype=bool))

            mentions = []
            for start, end in np.argwhere(valid):
                mentions.append({
                    "start": int(start),
                    "end": int(end),
                    "score": round(float(end_probs[doc_idx, start, end]), 4),
                    "text": " ".join(words[start:end + 1]),
                })
            results.append(mentions)
        return results
|
|
|
|
class ONNXMentionLabelerPipeline:
    """End-to-end inference over an exported mention-labeler ONNX model.

    Tokenizes raw word lists, runs the ONNX session on CPU, and decodes the
    start/span probability grids plus per-span class distribution into
    labeled mention dicts.
    """

    def __init__(self, model_path: str, tokenizer, id2label: dict = None, threshold: float = 0.5):
        self.session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider']
        )
        self.tokenizer = tokenizer
        self.processor = MentionProcessor(tokenizer)
        self.threshold = threshold
        # May be None (the documented default); predict() falls back to the
        # raw label id string in that case.
        self.id2label = id2label

    def predict(self, docs: list[list[str]]):
        """Return, per document, labeled spans clearing the detection threshold.

        Each mention dict carries: start, end, text, score (span detection
        probability), label (mapped via id2label, or the id as a string),
        and label_score (argmax class probability).
        """
        batch = self.processor(docs)
        onnx_inputs = {
            "input_ids": batch["input_ids"].numpy(),
            "attention_mask": batch["attention_mask"].numpy(),
            "word_ids": batch["word_ids"].numpy()
        }
        start_probs, end_probs, label_probs = self.session.run(None, onnx_inputs)

        # BUG FIX: id2label defaults to None, which previously crashed on
        # .get() below; normalize to an empty mapping here.
        id2label = self.id2label if self.id2label is not None else {}

        results = []
        for i in range(len(docs)):
            doc_mentions = []
            doc_len = len(docs[i])
            is_start = start_probs[i, :doc_len] > self.threshold
            is_span = end_probs[i, :doc_len, :doc_len] > self.threshold
            # Upper triangle keeps only spans with start <= end.
            upper_tri = np.triu(np.ones((doc_len, doc_len), dtype=bool))
            combined_mask = is_span & is_start[:, None] & upper_tri
            final_indices = np.argwhere(combined_mask)

            for s_idx, e_idx in final_indices:
                det_score = float(end_probs[i, s_idx, e_idx])
                class_probs = label_probs[i, s_idx, e_idx]
                label_id = int(np.argmax(class_probs))
                label_score = float(class_probs[label_id])

                doc_mentions.append({
                    "start": int(s_idx),
                    "end": int(e_idx),
                    "text": " ".join(docs[i][s_idx : e_idx + 1]),
                    "score": round(det_score, 4),
                    "label": id2label.get(label_id, str(label_id)),
                    "label_score": round(label_score, 4),
                })
            results.append(doc_mentions)

        return results
|
|
|
|
def create_inference_model(
    repo_id: str,
    encoder_id: str,
    model_factory: str,
    data_factory: str,
    device: str = "cpu",
):
    """
    Factory to load a trained model from HF Hub and wrap it for ONNX/Inference.
    """
    data = DataRegistry.get(data_factory)()
    # Build a fresh (untrained) bundle to obtain architecture + tokenizer.
    bundle = ModelRegistry.get(model_factory)(data, encoder_id)
    bundle_labeler = getattr(bundle, "mention_labeler", None)
    label2id = getattr(bundle, "label2id", None)

    lit_model = LitMentionDetector.from_pretrained(
        repo_id,
        tokenizer=bundle.tokenizer,
        encoder=bundle.encoder,
        mention_detector=bundle.mention_detector,
        label2id=label2id,
        mention_labeler=bundle_labeler,
    )
    lit_model.to(device)
    lit_model.eval()

    if label2id is None:
        # Detection-only checkpoint: no labeling head to wrap.
        wrapped = InferenceMentionDetector(
            encoder=lit_model.encoder,
            mention_detector=lit_model.mention_detector,
        )
    else:
        wrapped = InferenceMentionLabeler(
            encoder=lit_model.encoder,
            mention_detector=lit_model.mention_detector,
            mention_labeler=lit_model.mention_labeler,
            id2label={idx: name for name, idx in label2id.items()},
        )

    # Attach preprocessing metadata so export helpers can save the tokenizer.
    wrapped.tokenizer = lit_model.tokenizer
    wrapped.max_length = lit_model.encoder.max_length
    return wrapped.eval()
|
|
|
|
def compile_detector(model, output_dir="model_v1_onnx"):
    """Export a detector-only model to ONNX with dynamic batch/sequence axes.

    Saves the tokenizer alongside ``model.onnx`` in ``output_dir`` and prints
    the resulting input shapes as a sanity check.
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    model.tokenizer.save_pretrained(output_dir)
    onnx_path = os.path.join(output_dir, "model.onnx")

    # Batch and sequence/word dimensions stay symbolic so the exported
    # graph accepts variable-sized inputs.
    dynamic_axes = {
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "word_ids": {0: "batch", 1: "sequence"},
        "start_probs": {0: "batch", 1: "num_words"},
        "end_probs": {0: "batch", 1: "num_words", 2: "num_words"},
    }

    # Minimal tracing inputs: a single 16-token document.
    dummy_inputs = (
        torch.randint(0, 100, (1, 16), dtype=torch.long),
        torch.ones((1, 16), dtype=torch.long),
        torch.arange(16, dtype=torch.long).unsqueeze(0),
    )

    print("🚀 Re-exporting with legacy engine (dynamo=False)...")
    torch.onnx.export(
        model,
        dummy_inputs,
        onnx_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["input_ids", "attention_mask", "word_ids"],
        output_names=["start_probs", "end_probs"],
        dynamic_axes=dynamic_axes,
        dynamo=False,
    )
    print(f"✅ Exported to {output_dir}! Checking dimensions...")

    # Reload the exported graph and show the (symbolic) input shapes.
    sess = ort.InferenceSession(onnx_path)
    for input_meta in sess.get_inputs():
        print(f"Input '{input_meta.name}' shape: {input_meta.shape}")
|
|
|
|
def compile_labeler(model, output_dir="labeler_onnx"):
    """Export a detector+labeler model to ONNX with dynamic axes.

    Moves the model to CPU, saves the tokenizer next to ``model.onnx`` in
    ``output_dir``, and exports with the legacy (non-dynamo) engine.
    """
    model.cpu().eval()
    os.makedirs(output_dir, exist_ok=True)
    model.tokenizer.save_pretrained(output_dir)
    onnx_path = os.path.join(output_dir, "model.onnx")

    print(f"🛠️ Exporting {model.__class__.__name__} to {onnx_path}...")

    # Trace with a batch of two 16-token documents.
    dummy_inputs = (
        torch.randint(0, 50000, (2, 16), dtype=torch.long),
        torch.ones((2, 16), dtype=torch.long),
        torch.arange(16, dtype=torch.long).unsqueeze(0).repeat(2, 1),
    )

    # All batch/sequence/word/class dimensions stay symbolic.
    dynamic_axes = {
        "input_ids": {0: "batch", 1: "seq_ids"},
        "attention_mask": {0: "batch", 1: "seq_mask"},
        "word_ids": {0: "batch", 1: "seq_words"},
        "start_probs": {0: "batch", 1: "num_words"},
        "end_probs": {0: "batch", 1: "num_words", 2: "num_words"},
        "label_probs": {0: "batch", 1: "num_words", 2: "num_words", 3: "num_classes"},
    }

    torch.onnx.export(
        model,
        dummy_inputs,
        onnx_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask', 'word_ids'],
        output_names=['start_probs', 'end_probs', 'label_probs'],
        dynamic_axes=dynamic_axes,
        training=torch.onnx.TrainingMode.EVAL,
        dynamo=False,
    )
    print("✅ Export finished successfully!")
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load a trained checkpoint, export it to ONNX, and run a
    # small batch of documents through the resulting pipeline.
    model_factory = "model_v2"
    data_factory = "litbank_entities"
    inference_model_path = "model_v2_onnx"
    repo_id = "kadarakos/entity-labeler-poc"
    encoder_id = "distilroberta-base"

    inference_model = create_inference_model(
        repo_id,
        encoder_id,
        model_factory,
        data_factory,
    )

    if isinstance(inference_model, InferenceMentionDetector):
        compile_detector(inference_model, inference_model_path)
        pipeline = ONNXMentionDetectorPipeline(
            model_path=os.path.join(inference_model_path, "model.onnx"),
            tokenizer=inference_model.tokenizer,
            # Lower threshold for detection-only output.
            threshold=0.3,
        )
    else:
        print(inference_model)
        compile_labeler(inference_model, inference_model_path)
        pipeline = ONNXMentionLabelerPipeline(
            model_path=os.path.join(inference_model_path, "model.onnx"),
            tokenizer=inference_model.tokenizer,
            threshold=0.5,
            id2label=inference_model.id2label,
        )

    docs = [
        "Does this model actually work?".split(),
        "The name of the mage is Bubba.".split(),
        "It was quite a sunny day when the model finally started working.".split(),
        "Albert Einstein was a theoretical physicist who developed the theory of relativity".split(),
        "Apple Inc. and Microsoft are competing in the cloud computing market".split(),
        "New York City is often called the Big Apple".split(),
        "The Great Barrier Reef is the world's largest coral reef system".split(),
        "Marie Curie was the first woman to win a Nobel Prize".split(),
    ]

    batch_mentions = pipeline.predict(docs)
    for i, mentions in enumerate(batch_mentions):
        print(" ".join(docs[i]))
        # BUG FIX: detector-only pipelines emit mentions without a "label"
        # key; mention["label"] raised KeyError on that path. Fall back to a
        # generic tag instead.
        preds = [(mention["text"], mention.get("label", "MENTION")) for mention in mentions]
        print(preds)
|
|