---
license: bigscience-openrail-m
widget:
  - text: >-
      wnt signalling orchestrates a number of developmental programs in response
      to this stimulus  cytoplasmic beta catenin (encoded by ctnnb1) is
      stabilized enabling downstream transcriptional activation by members of
      the lef/tcf family
datasets:
  - bigbio/drugprot
  - bigbio/ncbi_disease
language:
  - en
pipeline_tag: token-classification
tags:
  - biology
  - medical
---

DistilBERT base model for restoring punctuation and capitalization of medical/biotech speech-to-text transcripts.

E.g.:

EXAMPLE

will be punctuated as follows:

EXAMPLE
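
For a quick smoke test, the checkpoint can also be run through the generic `token-classification` pipeline to inspect the raw per-wordpiece labels (a minimal sketch; the label names come from the model config, and the output is not yet reassembled into punctuated text, which is what the full script below does):

```python
from transformers import pipeline

# Minimal sketch: prints raw per-wordpiece label predictions,
# without reassembling them into punctuated text.
punctuator = pipeline("token-classification", model="unikei/distilbert-base-re-punctuate")
print(punctuator("wnt signalling orchestrates a number of developmental programs"))
```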

How to use it in your code:

```python
import torch
import numpy as np
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification

checkpoint = "unikei/distilbert-base-re-punctuate"
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
model = DistilBertForTokenClassification.from_pretrained(checkpoint)
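# model.config.id2label maps each predicted class id to a label encoding the
# capitalization and trailing punctuation of a wordpiece (used further below).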
encoder_max_length = 256  # maximum number of wordpieces per tokenized segment

#
# Split the word list into overlapping segments of (length + overlap) words
#
def split_to_segments(wrds, length, overlap):
    resp = []
    i = 0
    while True:
        wrds_split = wrds[(length * i):((length * (i + 1)) + overlap)]
        if not wrds_split:
            break

        resp_obj = {
            "text": wrds_split,
            "start_idx": length * i,
            "end_idx": (length * (i + 1)) + overlap,
        }

        resp.append(resp_obj)
        i += 1
    return resp
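
# For illustration: split_to_segments(words, 150, 50) yields windows covering
# words [0:200], [150:350], [300:500], ...; each window repeats the previous
# window's last 50 words, so segment boundaries receive overlapping predictions.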


#
# Punctuate wordpieces
#
def punctuate_wordpiece(wordpiece, label):
    if label.startswith('UPPER'):
        wordpiece = wordpiece.upper()
    elif label.startswith('Upper'):
        wordpiece = wordpiece[0].upper() + wordpiece[1:]
    if label[-1] != '_' and label[-1] != wordpiece[-1]:
        wordpiece += label[-1]
    return wordpiece
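
# Labels encode case plus trailing punctuation: a label starting with 'UPPER'
# uppercases the whole wordpiece, 'Upper' capitalizes its first letter, and a
# label whose last character is not '_' appends that character (e.g. '.' or ',').
# The exact label names are defined by model.config.id2label.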


#
# Punctuate text segments (200 words)
#
def punctuate_segment(wordpieces, word_ids, labels, start_word):
    result = ''
    for idx in range(len(wordpieces)):
        # Skip special tokens ([CLS], [SEP], padding) and words already emitted
        # by the previous, overlapping segment.
        if word_ids[idx] is None or word_ids[idx] < start_word:
            continue
        wordpiece = punctuate_wordpiece(
            wordpieces[idx][2:] if wordpieces[idx].startswith('##') else wordpieces[idx],
            labels[idx])
        # Add a space between words, but glue '##' continuation pieces and
        # hyphenated compounds back together.
        if idx > 0 and len(result) > 0 and word_ids[idx] != word_ids[idx - 1] and result[-1] != '-':
            result += ' '
        result += wordpiece
    return result


#
# Tokenize, predict, punctuate text segments (200 words)
#
def process_segment(words, tokenizer, model, start_word):
    # Tokenize the pre-split word list. Padding to encoder_max_length gives a
    # fixed-size input; truncation is left off so long segments are not cut.
    tokens = tokenizer(words['text'],
                       padding="max_length",
                       max_length=encoder_max_length,
                       is_split_into_words=True,
                       return_tensors='pt')

    # Predict one punctuation/case label per wordpiece.
    with torch.no_grad():
        logits = model(**tokens).logits
    logits = logits.cpu()
    predictions = np.argmax(logits, axis=-1)

    wordpieces = tokens.tokens()
    word_ids = tokens.word_ids()
    id2label = model.config.id2label
    labels = [id2label[p.item()] for p in predictions[0]]

    return punctuate_segment(wordpieces, word_ids, labels, start_word)
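
# Note: start_word is 0 for the first segment and equal to the overlap (50) for
# every later segment, so words already covered by the previous segment are
# skipped; see punctuate() below.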


#
# Punctuate text of any length
#
def punctuate(text, tokenizer, model):
    # The model expects lowercase, unpunctuated input on a single line.
    text = text.lower()
    text = text.replace('\n', ' ')
    words = text.split(' ')

    overlap = 50
    slices = split_to_segments(words, 150, overlap)

    result = ""
    start_word = 0
    for segment in slices:
        corrected = process_segment(segment, tokenizer, model, start_word)
        result += corrected + ' '
        # After the first segment, skip the words overlapping the previous one.
        start_word = overlap
    return result

#
# Example (same text as the widget example above)
#
text = ("wnt signalling orchestrates a number of developmental programs "
        "in response to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) "
        "is stabilized enabling downstream transcriptional activation by members "
        "of the lef/tcf family")
result = punctuate(text, tokenizer, model)
print(result)
```