---
tags:
- Chinese Medical
- Punctuation Restoration
language:
- zh
license: mit
pipeline_tag: token-classification
base_model: rickltt/pmp-h256
---

## Example Usage

The example below requires `torch`, `transformers` and `jieba`. It also requires the `classifier.py` module that defines `BertForMaskClassification`, which must be importable from your working directory.

```python
import jieba
import torch
from transformers import AutoTokenizer

# Custom model class with the masked-slot classification head.
from classifier import BertForMaskClassification

label_list = ["O", "COMMA", "PERIOD", "COLON"]
# NOTE: COMMA and COLON map to ASCII characters while PERIOD maps to the
# full-width "。"; use "，" / "：" instead if you want full-width output.
label2punct = {
    "COMMA": ",",
    "PERIOD": "。",
    "COLON": ":",
}

model_name_or_path = "pmp-h312"  # local path or Hugging Face Hub id
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = BertForMaskClassification.from_pretrained(model_name_or_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def punct(text):
    """Restore punctuation in unpunctuated Chinese text.

    The text is segmented into words with jieba, and a ``[MASK]`` slot is
    inserted after every word. The model assigns each slot one of
    ``label_list``. Slots labelled ``"O"`` are dropped; every other slot is
    replaced by the matching character from ``label2punct``.

    Args:
        text: The input string (an iterable of strings is joined first).

    Returns:
        The input text with the predicted punctuation inserted.
    """
    words = jieba.lcut("".join(text))

    # Every character becomes its own "word" for the tokenizer, and each
    # jieba word is followed by a [MASK] slot. The slot's predicted label is
    # the punctuation (if any) that belongs after that word.
    mask_tokens = []
    for word in words:
        mask_tokens.extend(word)
        mask_tokens.append("[MASK]")

    inputs = tokenizer(mask_tokens, is_split_into_words=True, return_tensors="pt")
    # word_ids() maps every model position back to its index in mask_tokens;
    # it is None for special tokens such as [CLS] and [SEP].
    word_ids = inputs.word_ids(0)
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = logits.argmax(-1)[0].tolist()

    result = []
    previous_word_id = None
    for word_id, prediction in zip(word_ids, predictions):
        # Skip special tokens and any extra sub-tokens of the same input piece.
        if word_id is None or word_id == previous_word_id:
            continue
        previous_word_id = word_id

        piece = mask_tokens[word_id]
        if piece == "[MASK]":
            label = label_list[prediction]
            if label != "O":
                result.append(label2punct[label])
        else:
            # Copy the original character rather than the decoded token, so
            # characters the tokenizer may alter (e.g. map to [UNK]) survive.
            result.append(piece)
    return "".join(result)


text = "肝浊音界正常肝上界位于锁骨中线第五肋间移动浊音阴性肾区无叩痛"
print(punct(text))
# 肝浊音界正常,肝上界位于锁骨中线第五肋间,移动浊音阴性,肾区无叩痛。
```

## Acknowledgments

This work was supported in part by the Shenzhen Science and Technology Program (No. JCYJ20210324135809025).

## Citation

Coming soon.

## License

MIT