---
license: bigscience-openrail-m
widget:
- text: >-
    wnt signalling orchestrates a number of developmental programs in response
    to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized
    enabling downstream transcriptional activation by members of the lef/tcf
    family
datasets:
- bigbio/drugprot
- bigbio/ncbi_disease
language:
- en
pipeline_tag: token-classification
tags:
- biology
- medical
---

# DistilBERT base model for restoring punctuation of medical/biotech speech-to-text transcripts

E.g. this raw transcript:

```
wnt signalling orchestrates a number of developmental programs in response to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized enabling downstream transcriptional activation by members of the lef/tcf family
```

will be punctuated as follows:

```
Wnt signalling orchestrates a number of developmental programs. In response to this stimulus, cytoplasmic beta catenin (encoded by CTNNB1) is stabilized, enabling downstream transcriptional activation by members of the LEF/TCF family.
```

## How to use it in your code:

```python
import torch
import numpy as np
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification

checkpoint = "unikei/distilbert-base-re-punctuate"
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
model = DistilBertForTokenClassification.from_pretrained(checkpoint)
encoder_max_length = 256

#
# Split the word list into overlapping segments of 200 words:
# segment i covers words[length*i : length*(i+1) + overlap], so
# consecutive segments share `overlap` words.
#
def split_to_segments(wrds, length, overlap):
    resp = []
    i = 0
    while True:
        wrds_split = wrds[(length * i):((length * (i + 1)) + overlap)]
        if not wrds_split:
            break

        resp_obj = {
            "text": wrds_split,
            "start_idx": length * i,
            "end_idx": (length * (i + 1)) + overlap,
        }

        resp.append(resp_obj)
        i += 1
    return resp

#
# Restore case and punctuation of a single wordpiece from its predicted label
#
def punctuate_wordpiece(wordpiece, label):
    if label.startswith('UPPER'):
        wordpiece = wordpiece.upper()
    elif label.startswith('Upper'):
        wordpiece = wordpiece[0].upper() + wordpiece[1:]
    if label[-1] != '_' and label[-1] != wordpiece[-1]:
        wordpiece += label[-1]
    return wordpiece

#
# Punctuate one text segment (200 words), skipping the words already
# emitted by the previous, overlapping segment
#
def punctuate_segment(wordpieces, word_ids, labels, start_word):
    result = ''
    for idx in range(0, len(wordpieces)):
        if word_ids[idx] is None:
            continue
        if word_ids[idx] < start_word:
            continue
        wordpiece = punctuate_wordpiece(
            wordpieces[idx][2:] if wordpieces[idx].startswith('##') else wordpieces[idx],
            labels[idx])
        if idx > 0 and len(result) > 0 and word_ids[idx] != word_ids[idx - 1] and result[-1] != '-':
            result += ' '
        result += wordpiece
    return result

#
# Tokenize, predict and punctuate one text segment (200 words)
#
def process_segment(words, tokenizer, model, start_word):
    tokens = tokenizer(words['text'],
                       padding="max_length",
                       # truncation=True,
                       max_length=encoder_max_length,
                       is_split_into_words=True,
                       return_tensors='pt')

    with torch.no_grad():
        logits = model(**tokens).logits
    logits = logits.cpu()
    predictions = np.argmax(logits, axis=-1)
    wordpieces = tokens.tokens()
    word_ids = tokens.word_ids()
    id2label = model.config.id2label
    labels = [[id2label[p.item()] for p in prediction] for prediction in predictions][0]
    return punctuate_segment(wordpieces, word_ids, labels, start_word)

#
# Punctuate text of any length
#
def punctuate(text, tokenizer, model):
    text = text.lower()
    text = text.replace('\n', ' ')
    words = text.split(' ')

    overlap = 50
    slices = split_to_segments(words, 150, overlap)

    result = ""
    start_word = 0
    for segment in slices:
        corrected = process_segment(segment, tokenizer, model, start_word)
        result += corrected + ' '
        # every segment after the first shares its first `overlap`
        # words with the previous one, so skip them on output
        start_word = overlap
    return result

#
# Example
#
text = "wnt signalling orchestrates a number of developmental programs in response to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized enabling downstream transcriptional activation by members of the lef/tcf family"
result = punctuate(text, tokenizer, model)
print(result)
```
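
The chunking done by `split_to_segments` is easiest to see on toy numbers. A minimal check (the sizes here are illustrative; `punctuate` itself uses length 150 and overlap 50):

```python
# Toy illustration of split_to_segments: 10 words, length 4, overlap 2.
# Each segment holds length + overlap words and steps forward by length,
# so consecutive segments share their last/first `overlap` words.
words = [f"w{i}" for i in range(10)]
for seg in split_to_segments(words, length=4, overlap=2):
    print(seg["start_idx"], seg["text"])
# 0 ['w0', 'w1', 'w2', 'w3', 'w4', 'w5']
# 4 ['w4', 'w5', 'w6', 'w7', 'w8', 'w9']
# 8 ['w8', 'w9']
```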
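
As `punctuate_wordpiece` implies, each predicted label encodes both case and trailing punctuation: a case prefix (`UPPER`, `Upper`, or anything else for lowercase) plus a final character, where `_` means no punctuation is appended. The exact label strings below are illustrative assumptions; the real inventory comes from `model.config.id2label`:

```python
# Hypothetical labels used only to demonstrate the decoding rules;
# inspect model.config.id2label for the model's actual label set.
print(punctuate_wordpiece("wnt", "Upper_"))       # Wnt
print(punctuate_wordpiece("ctnnb1", "UPPER_"))    # CTNNB1
print(punctuate_wordpiece("programs", "lower."))  # programs.
print(punctuate_wordpiece("stimulus", "lower,"))  # stimulus,
```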
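
The usage snippet above runs on CPU. If a GPU is available, a minimal sketch of the changes, assuming the rest of the code stays as written:

```python
# Move the model to the GPU once; inputs must follow per segment.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Inside process_segment, move the tokenized inputs before the forward pass:
#     tokens = tokens.to(device)
# The existing logits.cpu() call already brings the output back
# to the CPU for the numpy argmax.
```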