pmp-h312 / README.md
rickltt's picture
Update README.md
a12574c verified
metadata
tags:
  - Chinese Medical
  - Punctuation Restoration
language:
  - zh
license: mit
pipeline_tag: token-classification
base_model: rickltt/pmp-h256

Example Usage

import torch
import jieba
import numpy as np
from classifier import BertForMaskClassification
from transformers import AutoTokenizer, AutoConfig, BertForTokenClassification

# Label inventory; a predicted class id is used as an index into this list.
label_list = [
    "O",
    "COMMA",
    "PERIOD",
    "COLON",
]

# Punctuation mark written out for each label other than "O".
label2punct = dict(
    COMMA=",",
    PERIOD="。",
    COLON=":",
)

# Identifier passed to both from_pretrained() calls below.
model_name_or_path = "pmp-h312"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = BertForMaskClassification.from_pretrained(model_name_or_path)

# NOTE(review): `device` is selected here but is not referenced anywhere else
# in this example.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def punct(text):
    """Restore punctuation in unpunctuated Chinese text.

    The text is segmented into words with jieba, a ``[MASK]`` slot is placed
    after every word, and the model's prediction for each slot is looked up
    in ``label_list``. Slots labelled ``"O"`` are dropped; every other label
    is replaced by its mark from ``label2punct``.

    Args:
        text: The unpunctuated text, either a string or an iterable of string
            pieces (the pieces are concatenated before segmentation).

    Returns:
        The text with the predicted punctuation marks inserted. Empty input
        yields an empty string without calling the tokenizer or the model.
    """
    raw_text = "".join(text)
    if not raw_text:
        return ""

    mask_tokens = []
    for word in jieba.lcut(raw_text):
        # extend() with a string adds one entry per character.
        mask_tokens.extend(word)
        # Candidate punctuation slot after each segmented word.
        mask_tokens.append("[MASK]")

    tokenized_inputs = tokenizer(
        mask_tokens, is_split_into_words=True, return_tensors="pt"
    )
    # Keep the inputs on whatever device the model's weights live on.
    tokenized_inputs = tokenized_inputs.to(next(model.parameters()).device)

    with torch.no_grad():
        logits = model(**tokenized_inputs).logits

    # Single-sequence batch: keep only the first row.
    predictions = logits.argmax(-1)[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(
        tokenized_inputs["input_ids"][0].tolist()
    )

    # NOTE(review): the output is rebuilt from the tokenizer's token strings,
    # not from the original characters, so characters the vocabulary cannot
    # represent presumably will not round-trip -- confirm against the vocab.
    result = []
    for token, prediction in zip(tokens, predictions):
        if token in ("[CLS]", "[SEP]"):
            continue
        if token != "[MASK]":
            result.append(token)
            continue
        label = label_list[prediction]
        if label != "O":
            result.append(label2punct[label])

    return "".join(result)

# Demo: run punctuation restoration on an unpunctuated sample sentence.
text = '肝浊音界正常肝上界位于锁骨中线第五肋间移动浊音阴性肾区无叩痛'
restored = punct(text)
print(restored)

# Restored text printed by the call above:
# 肝浊音界正常,肝上界位于锁骨中线第五肋间,移动浊音阴性,肾区无叩痛。

Acknowledgments

This work was supported in part by the Shenzhen Science and Technology Program (No. JCYJ20210324135809025).

Citations

Coming Soon

License

MIT