| | import numpy as np |
| | import onnxruntime as ort |
| | from huggingface_hub import hf_hub_download |
| | from omegaconf import OmegaConf |
| | from sentencepiece import SentencePieceProcessor |
| | from typing import List |
| |
|
def process_text(input_text: str) -> str:
    """Punctuate, capitalize, and sentence-split ``input_text``.

    Tokenizes the text with a SentencePiece model, runs an ONNX model that
    emits four prediction heads (pre-token punctuation, post-token
    punctuation, per-character capitalization, and sentence-boundary
    detection), then reconstructs the text with those predictions applied.

    Args:
        input_text: Raw (presumably unpunctuated) text to process.

    Returns:
        The processed text, with one detected sentence per line
        (sentences joined by ``"\\n"``).

    NOTE(review): the tokenizer, ONNX session, and config are reloaded on
    every call from hard-coded paths ("sp.model", "model.onnx",
    "config.yaml") — hoist/cache them if this is called repeatedly.
    """
    # SentencePiece tokenizer: used both to encode the input and to map
    # ids back to pieces during reconstruction.
    tokenizer: SentencePieceProcessor = SentencePieceProcessor("sp.model")

    # ONNX inference session producing the four prediction heads.
    ort_session: ort.InferenceSession = ort.InferenceSession("model.onnx")

    # Model config: label inventories plus the special-token names.
    config = OmegaConf.load("config.yaml")
    pre_labels: List[str] = config.pre_labels
    post_labels: List[str] = config.post_labels
    null_token = config.get("null_token", "<NULL>")
    acronym_token = config.get("acronym_token", "<ACRONYM>")
    # BUGFIX: dropped dead locals `max_len` / `languages` — they were
    # read from the config but never used.

    # Wrap with BOS/EOS ids; add a batch dimension of 1 for the model.
    input_ids = [tokenizer.bos_id()] + tokenizer.EncodeAsIds(input_text) + [tokenizer.eos_id()]
    # BUGFIX: annotation was `np.array` (a factory function, not a type);
    # the correct type is `np.ndarray`.
    input_ids_arr: np.ndarray = np.array([input_ids])

    pre_preds, post_preds, cap_preds, sbd_preds = ort_session.run(
        None, {"input_ids": input_ids_arr}
    )
    # Strip the batch dimension from every head.
    pre_preds = pre_preds[0].tolist()
    post_preds = post_preds[0].tolist()
    cap_preds = cap_preds[0].tolist()
    sbd_preds = sbd_preds[0].tolist()

    output_texts: List[str] = []   # one entry per detected sentence
    current_chars: List[str] = []  # characters of the sentence in progress

    # Skip BOS (index 0) and EOS (last index): range(1, len - 1).
    for token_idx in range(1, len(input_ids) - 1):
        token = tokenizer.IdToPiece(input_ids[token_idx])
        # "▁" marks the start of a new word in SentencePiece; emit a
        # space before it unless we are at the start of a sentence.
        if token.startswith("▁") and current_chars:
            current_chars.append(" ")

        pre_label = pre_labels[pre_preds[token_idx]]
        post_label = post_labels[post_preds[token_idx]]

        # Pre-token punctuation (e.g. inverted "¿" in Spanish).
        if pre_label != null_token:
            current_chars.append(pre_label)

        # Apply per-character capitalization, skipping the "▁" marker
        # itself (char_start keeps token_char_idx aligned with cap_preds).
        char_start = 1 if token.startswith("▁") else 0
        for token_char_idx, char in enumerate(token[char_start:], start=char_start):
            if cap_preds[token_idx][token_char_idx]:
                char = char.upper()
            current_chars.append(char)

        # Post-token punctuation. The two original `if`s were mutually
        # exclusive on acronym_token, so this is a plain if/elif.
        if post_label == acronym_token:
            current_chars.append(".")
        elif post_label != null_token:
            current_chars.append(post_label)

        # Sentence boundary predicted here → flush the current sentence.
        if sbd_preds[token_idx]:
            output_texts.append("".join(current_chars))
            current_chars.clear()

    # Flush any trailing text not terminated by a sentence boundary.
    output_texts.append("".join(current_chars))

    return "\n".join(output_texts)
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|