unikei committed on
Commit
637af1c
1 Parent(s): 86c57e0

Update README.md

Files changed (1)
README.md +138 -2
README.md CHANGED
@@ -2,5 +2,141 @@
  license: bigscience-openrail-m
  widget:
  - text: >-
- wnt signalling orchestrates a number of developmental programs in response to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized enabling downstream transcriptional activation by members of the lef/tcf family
- ---
+ wnt signalling orchestrates a number of developmental programs in response
+ to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized
+ enabling downstream transcriptional activation by members of the lef/tcf
+ family
+ datasets:
+ - bigbio/drugprot
+ - bigbio/ncbi_disease
+ language:
+ - en
+ pipeline_tag: token-classification
+ tags:
+ - biology
+ - medical
+ ---
+
+ # DistilBERT base model for restoring punctuation of medical/biotech speech-to-text transcripts
+ E.g., a raw transcript such as:
+ ```
+ EXAMPLE
+ ```
+ will be punctuated as follows:
+ ```
+ EXAMPLE
+ ```
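
The model is tagged for `token-classification`, so the raw casing/punctuation labels it predicts can also be inspected with the standard `transformers` pipeline API. This is only a minimal sketch for quick exploration, not the usage the card recommends; reassembling the labels into punctuated text is what the full script below does.

```python
from transformers import pipeline

# token-classification pipeline: returns one predicted label per wordpiece
punctuator = pipeline("token-classification", model="unikei/distilbert-base-re-punctuate")
preds = punctuator("wnt signalling orchestrates a number of developmental programs")
# each entry is a dict with 'word', 'entity' (the predicted label), 'score', 'start', 'end'
print(preds[:3])
```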
+
+ ## How to use it in your code:
+ ```python
+ import torch
+ import numpy as np
+ from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification
+
+ checkpoint = "unikei/distilbert-base-re-punctuate"
+ tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
+ model = DistilBertForTokenClassification.from_pretrained(checkpoint)
+ encoder_max_length = 256
+
+ #
+ # Split text into overlapping word segments: each slice holds `length` words
+ # plus `overlap` extra words shared with the next slice
+ #
+ def split_to_segments(wrds, length, overlap):
+     resp = []
+     i = 0
+     while True:
+         wrds_split = wrds[(length * i):((length * (i + 1)) + overlap)]
+         if not wrds_split:
+             break
+
+         resp_obj = {
+             "text": wrds_split,
+             "start_idx": length * i,
+             "end_idx": (length * (i + 1)) + overlap,
+         }
+
+         resp.append(resp_obj)
+         i += 1
+     return resp
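
# Illustration (not part of the original code): with length=3 and overlap=1,
# split_to_segments(['a','b','c','d','e','f','g'], 3, 1) returns slices whose
# 'text' fields are ['a','b','c','d'], ['d','e','f','g'], ['g']; consecutive
# slices share `overlap` words at the boundary so those words keep some context.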
+
+
+ #
+ # Punctuate wordpieces
+ #
+ def punctuate_wordpiece(wordpiece, label):
+     if label.startswith('UPPER'):
+         wordpiece = wordpiece.upper()
+     elif label.startswith('Upper'):
+         wordpiece = wordpiece[0].upper() + wordpiece[1:]
+     if label[-1] != '_' and label[-1] != wordpiece[-1]:
+         wordpiece += label[-1]
+     return wordpiece
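
# Illustration (the exact label strings are an assumption inferred from the
# branches above; the real label set lives in model.config.id2label): a label
# encodes the casing plus a trailing punctuation mark, with '_' meaning none, e.g.
#     punctuate_wordpiece('wnt', 'UPPER_')         -> 'WNT'
#     punctuate_wordpiece('signalling', 'Upper.')  -> 'Signalling.'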
+
+
+ #
+ # Punctuate text segments (200 words)
+ #
+ def punctuate_segment(wordpieces, word_ids, labels, start_word):
+     result = ''
+     for idx in range(0, len(wordpieces)):
+         if word_ids[idx] is None:
+             continue
+         if word_ids[idx] < start_word:
+             continue
+         wordpiece = punctuate_wordpiece(wordpieces[idx][2:] if wordpieces[idx].startswith('##') else wordpieces[idx],
+                                         labels[idx])
+         if idx > 0 and len(result) > 0 and word_ids[idx] != word_ids[idx - 1] and result[-1] != '-':
+             result += ' '
+         result += wordpiece
+     return result
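
# start_word drops every wordpiece whose word index is below it, so for each
# slice after the first, the `overlap` words already emitted by the previous
# slice are not written twice (punctuate() below passes start_word=overlap for
# those slices).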
+
+
+ #
+ # Tokenize, predict, punctuate text segments (200 words)
+ #
+ def process_segment(words, tokenizer, model, start_word):
+
+     tokens = tokenizer(words['text'],
+                        padding="max_length",
+                        # truncation=True,
+                        max_length=encoder_max_length,
+                        is_split_into_words=True, return_tensors='pt')
+
+     with torch.no_grad():
+         logits = model(**tokens).logits
+     logits = logits.cpu()
+     predictions = np.argmax(logits, axis=-1)
+
+     wordpieces = tokens.tokens()
+     word_ids = tokens.word_ids()
+     id2label = model.config.id2label
+     labels = [[id2label[p.item()] for p in prediction] for prediction in predictions][0]
+
+     return punctuate_segment(wordpieces, word_ids, labels, start_word)
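
# Illustration (shapes, not part of the original code): for one slice,
# tokens['input_ids'] has shape (1, encoder_max_length) and logits has shape
# (1, encoder_max_length, num_labels); np.argmax picks one label id per
# wordpiece, and id2label maps it back to strings like the ones shown above.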
+
+
+ #
+ # Punctuate text of any length
+ #
+ def punctuate(text, tokenizer, model):
+     text = text.lower()
+     text = text.replace('\n', ' ')
+     words = text.split(' ')
+
+     overlap = 50
+     slices = split_to_segments(words, 150, overlap)
+
+     result = ""
+     start_word = 0
+     for segment in slices:
+         corrected = process_segment(segment, tokenizer, model, start_word)
+         result += corrected + ' '
+         start_word = overlap
+     return result
+
+ #
+ # Example (the sample sentence is the widget example from the model card above)
+ #
+ text = "wnt signalling orchestrates a number of developmental programs in response to this stimulus cytoplasmic beta catenin (encoded by ctnnb1) is stabilized enabling downstream transcriptional activation by members of the lef/tcf family"
+ result = punctuate(text, tokenizer, model)
+ print(result)
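```

If a GPU is available, inference can be sped up by moving the model and the encoded inputs onto it. This is a small sketch layered on top of the script above, not part of the original snippet; `Module.to` and `BatchEncoding.to` are standard PyTorch/transformers methods.

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()

# Inside process_segment, move the encoded batch to the same device
# before calling the model:
#     tokens = tokens.to(device)
# The existing logits.cpu() call already brings the outputs back for numpy.
```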