carisackc committed on
Commit
df27934
1 Parent(s): 1aae30f

Upload 2 files

Files changed (2)
  1. Summarization_25Nov2022.py +357 -0
  2. requirements.txt +4 -0
Summarization_25Nov2022.py ADDED
@@ -0,0 +1,357 @@
+ import torch
+ import streamlit as st
+ from streamlit import components
+ import pandas as pd
+ from transformers import BartTokenizer, BartForConditionalGeneration
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+ import evaluate
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, LongT5ForConditionalGeneration
+ import numpy as np
+ from math import ceil
+ import en_core_web_lg
+ from collections import Counter
+ from string import punctuation
+ # Gensim (gensim.summarization requires gensim < 4.0)
+ import gensim
+ from gensim.summarization import summarize
+ import spacy
+ 
+ # Load the large English spaCy model (installed separately as en_core_web_lg)
+ nlp = en_core_web_lg.load()
+ 
+ st.set_page_config(page_title='Clinical Note Summarization',
+                    #page_icon="Notes",
+                    layout='wide')
+ st.title('Clinical Note Summarization')
+ st.sidebar.markdown('Using transformer models')
+ 
+ ## Load the dataset
+ #df = pd.read_csv('mtsamples_small.csv', index_col=0)
+ df = pd.read_csv("shpi_w_rouge21Nov.csv")
+ # HADM_ID is read in as a float; cast to string and drop the trailing '.0'
+ df['HADM_ID'] = df['HADM_ID'].astype(str).apply(lambda x: x.replace('.0', ''))
+ 
+ # Rename columns to friendlier names
+ #df.rename(columns={'patient id':'Patient_ID',
+ #                   'hospital admission id':'Admission_ID',
+ #                   'transcription':'Original_Text'}, inplace=True)
+ df.rename(columns={'SUBJECT_ID':'Patient_ID',
+                    'HADM_ID':'Admission_ID',
+                    'hpi_input_text':'Original_Text',
+                    'hpi_reference_summary':'Reference_text'}, inplace=True)
+ 
+ # Filter selection
+ st.sidebar.header("Search for Patient:")
+ 
+ patientid = df['Patient_ID']
+ patient = st.sidebar.selectbox('Select Patient ID:', patientid)
+ admissionid = df['Admission_ID'].loc[df['Patient_ID'] == patient]
+ HospitalAdmission = st.sidebar.selectbox('', admissionid)
+ 
+ # Another way to do the filter selection
+ #patient = st.sidebar.multiselect(
+ #    "Select Patient ID:",
+ #    options=df['Patient_ID'].unique(),
+ #    default=None
+ #)
+ 
+ #HospitalAdmission = st.sidebar.multiselect(
+ #    "Select Hospital Admission ID:",
+ #    options=df['Admission_ID'].unique(),
+ #    #default=df['Admission_ID'].unique()
+ #    default=None
+ #)
+ 
+ # List of available models
+ model = st.sidebar.selectbox('Select Model', ('BART','BERT','BertGPT2','Gensim','LexRank','Long T5','Luhn','Pysummarization','SBERT Summary Tokenizer','T5','T5 Seq2Seq','T5-Base','TextRank'))
+ 
+ 
+ if model == 'BART':
+     _num_beams = 4
+     _no_repeat_ngram_size = 3
+     _length_penalty = 1
+     _min_length = 12
+     _max_length = 128
+     _early_stopping = True
+ else:
+     _num_beams = 4
+     _no_repeat_ngram_size = 3
+     _length_penalty = 2
+     _min_length = 30
+     _max_length = 200
+     _early_stopping = True
+ 
+ 
+ 
+ col3, col4 = st.columns(2)
+ patientid = col3.write(f"Patient ID: {patient} ")
+ admissionid = col4.write(f"Admission ID: {HospitalAdmission} ")
+ 
+ col1, col2 = st.columns(2)
+ _min_length = col1.number_input("Minimum Length", value=_min_length)
+ _max_length = col2.number_input("Maximum Length", value=_max_length)
+ ##_early_stopping = col3.number_input("early_stopping", value=_early_stopping)
+ 
+ #text = st.text_area('Input Clinical Note here')
+ 
+ # Query the clinical note for the selected patient and admission
+ original_text = df.query(
+     "Patient_ID == @patient & Admission_ID == @HospitalAdmission"
+ )
+ 
+ original_text2 = original_text['Original_Text'].values
+ 
+ runtext = st.text_area('Input Clinical Note here:', str(original_text2), height=300)
+ 
+ reference_text = original_text['Reference_text'].values
+ 
+ 
+ ## ===== Highlight text =====
+ from IPython.core.display import HTML, display
+ 
+ def visualize(title, sentence_list, best_sentences):
+     # Wrap sentences that appear in best_sentences with a highlight span
+     text = ''
+     #display(HTML(f'<h1>Summary - {title}</h1>'))
+     for sentence in sentence_list:
+         if sentence in best_sentences:
+             #text += ' ' + str(sentence).replace(sentence, f"<mark>{sentence}</mark>")
+             text += ' ' + str(sentence).replace(sentence, f"<span class='highlight yellow'>{sentence}</span>")
+         else:
+             text += ' ' + sentence
+     display(HTML(f""" {text} """))
+     return text
+ # See https://discuss.streamlit.io/t/colored-boxes-around-sections-of-a-sentence/3201/2 for a Streamlit-native highlighting approach
+ 
+ #===== Pysummarization =====
+ from pysummarization.nlpbase.auto_abstractor import AutoAbstractor
+ from pysummarization.tokenizabledoc.simple_tokenizer import SimpleTokenizer
+ from pysummarization.abstractabledoc.top_n_rank_abstractor import TopNRankAbstractor
+ import regex as re
+ 
+ auto_abstractor = AutoAbstractor()
+ auto_abstractor.tokenizable_doc = SimpleTokenizer()
+ auto_abstractor.delimiter_list = [".", "\n"]
+ abstractable_doc = TopNRankAbstractor()
+ 
+ def pysummarizer(input_text):
+     # Extractive summary: keep the top-ranked sentences and normalise whitespace
+     summary = auto_abstractor.summarize(input_text, abstractable_doc)
+     best_sentences = []
+     for sentence in summary['summarize_result']:
+         best_sentences.append(re.sub(r'\s+', ' ', sentence).strip())
+     clean_summary = ' '.join(sentence for sentence in best_sentences)
+     return clean_summary
+ 
+ 
+ 
+ ##===== BERT extractive summarizer =====
+ def BertSummarizer(input_text):
+     from summarizer import Summarizer
+ 
+     # Extractive BERT summarizer; keep roughly 40% of the sentences
+     bertsummarizer = Summarizer()
+     result = bertsummarizer(input_text, ratio=0.4)
+     return result
+ 
+ 
+ ##===== SBERT =====
+ from summarizer.sbert import SBertSummarizer
+ 
+ Sbertmodel = SBertSummarizer('paraphrase-MiniLM-L6-v2')
+ 
+ def Sbert(input_text):
+     # Sbertresult = Sbertmodel(input_text, num_sentences=3)
+     Sbertresult = Sbertmodel(input_text, ratio=0.4)
+     return Sbertresult
+ 
+ 
+ 
+ ##===== T5 Seq2Seq =====
+ def t5seq2seq(input_text):
+     import torch
+     import torch.nn.functional as F
+     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ 
+     model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+     tokenizer = AutoTokenizer.from_pretrained("t5-base")
+ 
+     inputs = tokenizer("summarize: " + input_text, return_tensors="pt", max_length=512, truncation=True)
+     outputs = model.generate(inputs["input_ids"], max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
+ 
+     summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ 
+     return summary
+ 
+ def BertGPT2(input_text):
+     # BERT2GPT2 encoder-decoder: Bio_ClinicalBERT tokenizer on the encoder side, GPT2 decoder
+     from transformers import GPT2Tokenizer, EncoderDecoderModel
+     from transformers import AutoTokenizer
+ 
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
+     model.to(device)
+ 
+     #bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+     bert_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+ 
+     # CLS token will work as BOS token
+     bert_tokenizer.bos_token = bert_tokenizer.cls_token
+ 
+     # SEP token will work as EOS token
+     bert_tokenizer.eos_token = bert_tokenizer.sep_token
+ 
+     # make sure GPT2 appends EOS at the beginning and end
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+         return outputs
+ 
+     GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
+     gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     # set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id
+     gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token
+ 
+     # set decoding params
+     model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
+     model.config.eos_token_id = gpt2_tokenizer.eos_token_id
+     model.config.max_length = 142
+     model.config.min_length = 56
+     model.config.no_repeat_ngram_size = 3
+     model.early_stopping = True
+     model.length_penalty = 2.0
+     model.num_beams = 4
+ 
+     def generate_summary(batch):
+         # Tokenizer will automatically set [BOS] <text> [EOS]
+         # cut off at BERT max length 512
+         inputs = bert_tokenizer(batch, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
+         input_ids = inputs.input_ids.to(device)
+         attention_mask = inputs.attention_mask.to(device)
+ 
+         outputs = model.generate(input_ids, attention_mask=attention_mask)
+ 
+         # all special tokens will be removed
+         output_str = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+         return output_str
+ 
+     # Generate and return the summary for the input note
+     return generate_summary(input_text)[0]
+ 
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ 
+ 
+ def run_model(input_text):
+ 
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ 
+     if model == "BART":
+         bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
+         bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+         input_text = str(input_text)
+         input_text = ' '.join(input_text.split())
+         input_tokenized = bart_tokenizer.encode(input_text, return_tensors='pt').to(device)
+         summary_ids = bart_model.generate(input_tokenized,
+                                           num_beams=_num_beams,
+                                           no_repeat_ngram_size=_no_repeat_ngram_size,
+                                           length_penalty=_length_penalty,
+                                           min_length=_min_length,
+                                           max_length=_max_length,
+                                           early_stopping=_early_stopping)
+         output = [bart_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
+         st.write('Summary')
+         st.success(output[0])
+ 
+     elif model == "T5":
+         t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
+         t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
+         input_text = str(input_text).replace('\n', '')
+         input_text = ' '.join(input_text.split())
+         input_tokenized = t5_tokenizer.encode(input_text, return_tensors="pt").to(device)
+         # token ids for the "summarize: " task prefix
+         summary_task = torch.tensor([[21603, 10]]).to(device)
+         input_tokenized = torch.cat([summary_task, input_tokenized], dim=-1).to(device)
+         summary_ids = t5_model.generate(input_tokenized,
+                                         num_beams=_num_beams,
+                                         no_repeat_ngram_size=_no_repeat_ngram_size,
+                                         length_penalty=_length_penalty,
+                                         min_length=_min_length,
+                                         max_length=_max_length,
+                                         early_stopping=_early_stopping)
+         output = [t5_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]
+         st.write('Summary')
+         st.success(output[0])
+ 
+     elif model == "Gensim":
+         output = summarize(str(input_text))
+         #visualize('of text', input_text, output)
+         st.write('Summary')
+         st.success(output)
+ 
+     elif model == "Pysummarization":
+         output = pysummarizer(input_text)
+         st.write('Summary')
+         st.success(output)
+ 
+     elif model == "BERT":
+         output = BertSummarizer(input_text)
+         st.write('Summary')
+         st.success(output)
+ 
+     elif model == "SBERT Summary Tokenizer":
+         output = Sbert(input_text)
+         st.write('Summary')
+         st.success(output)
+ 
+     elif model == "T5 Seq2Seq":
+         output = t5seq2seq(input_text)
+         st.write('Summary')
+         st.success(output)
+ 
+     elif model == "BertGPT2":  # Not working correctly yet; to be revisited
+         output = BertGPT2(input_text)
+         st.write('Summary')
+         st.success(output)
+ 
+ 
+ 
+ if st.button('Submit'):
+     run_model(runtext)
+ 
+ # runtext2 = runtext.split('.')
+ # reference_text2 = reference_text.split('.')
+ 
+ st.write(visualize('of text', runtext, reference_text))
+ 
+ st.text_area('Reference text', str(reference_text))
+ 
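For reference, Summarization_25Nov2022.py is a Streamlit app, so once the dependencies are installed and shpi_w_rouge21Nov.csv sits next to the script, it would typically be launched with:

streamlit run Summarization_25Nov2022.py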
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ streamlit==1.14.0
+ pandas==1.3.5
+ numpy==1.20.0
+ regex==2022.9.13
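Note that requirements.txt pins only four packages, while the script above also imports torch, transformers, datasets, evaluate, spacy (with the en_core_web_lg model), gensim, pysummarization, the summarizer / summarizer.sbert modules, and IPython. A sketch of the additional entries implied by those imports follows; the package names come from the script's imports, but exact versions are assumptions and should be pinned against the tested environment:

torch
transformers
datasets
evaluate
spacy
gensim<4.0                    # gensim.summarization was removed in gensim 4.x
pysummarization
bert-extractive-summarizer    # provides the summarizer and summarizer.sbert modules
ipython
# en_core_web_lg is a spaCy model, typically installed via: python -m spacy download en_core_web_lg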