lisaterumi committed on
Commit
57d44a7
1 Parent(s): ba4725d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import transformers
3
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
4
+ import torch
5
+
6
# model large
model_name = "pucpr/clinicalnerpt-chemical"
model_large = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer_large = AutoTokenizer.from_pretrained(model_name)

# model base
model_name_large = model_name
model_name = "pucpr/clinicalnerpt-chemical"
if model_name == model_name_large:
    # Both slots currently point at the same checkpoint: reuse the objects
    # already loaded instead of loading (and holding in memory) a second copy.
    # The models are only used for inference, so sharing them is safe.
    model_base = model_large
    tokenizer_base = tokenizer_large
else:
    model_base = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer_base = AutoTokenizer.from_pretrained(model_name)
15
+
16
# css
# Background colour of the highlighted entity text, keyed by entity label.
background_colors_entity_word = dict(ChemicalDrugs="#fae8ff")

# Background colour of the small label badge shown next to the entity text.
background_colors_entity_tag = dict(ChemicalDrugs="#d946ef")

# Inline-style templates; "#xxxxxx" is a placeholder that callers replace
# with the colour looked up in the dictionaries above.
css = dict(
    entity_word=(
        "color:#000000;background: #xxxxxx; padding: 0.45em 0.6em; "
        "margin: 0 0.25em; line-height: 2.5; border-radius: 0.35em;"
    ),
    entity_tag=(
        "color:#fff;background: #xxxxxx; font-size: 0.8em; font-weight: bold; "
        "line-height: 2.5; border-radius: 0.35em; text-transform: uppercase; "
        "vertical-align: middle; margin-left: 0.5em;"
    ),
)

# HTML badge for the ChemicalDrugs label, built from the entity_tag template.
list_EN = "<span style='{};padding:0.5em;'>ChemicalDrugs</span>".format(
    css["entity_tag"].replace("#xxxxxx", background_colors_entity_tag["ChemicalDrugs"])
)
33
+
34
# infos
title = "BioBERTpt - Chemical entities"
description = "BioBERTpt - Chemical entities"
# The gr.Interface(...) call at the bottom of the file passes `article=article`,
# but the name was never defined, so the app raised NameError at start-up.
# None is Gradio's own default (no article section); replace it with
# Markdown/HTML text to show an article below the demo.
article = None
allow_screenshot = False
allow_flagging = False
# One single-element list per example, matching the single text input.
examples = [
    ["Dispneia venoso central em subclavia D duplolumen recebendo solução salina e glicosada em BI."],
    ["Paciente com Sepse pulmonar em D8 tazocin (paciente não recebeu por 2 dias Atb)."],
    ["FOI REALIZADO CURSO DE ATB COM LEVOFLOXACINA POR 7 DIAS."],
]
44
+
45
def _group_entities(labels):
    """Group per-token BIO labels into entity spans.

    Args:
        labels: sequence of label strings, one per token ('O', 'B-<name>',
            'I-<name>'). Labels with any other prefix neither open, extend
            nor close a span.

    Returns:
        A list of ``{'indices': [token positions], 'en': '<name>'}`` dicts in
        order of appearance.
    """
    groups = []
    current = []        # token positions of the span being built
    current_label = ''

    for i, tag in enumerate(labels):
        if tag == 'O' or tag.startswith('B'):
            # Both 'O' and a new 'B-' close whatever span is open.
            if current:
                groups.append({'indices': current, 'en': current_label})
                current, current_label = [], ''
            if tag.startswith('B'):
                current, current_label = [i], tag.replace('B-', '')
        elif tag.startswith('I'):
            name = tag.replace('I-', '')
            if current and name == current_label:
                current.append(i)
            else:
                # An 'I-' with no open span, or with a different label,
                # starts a span of its own.
                if current:
                    groups.append({'indices': current, 'en': current_label})
                current, current_label = [i], name

    if current:
        groups.append({'indices': current, 'en': current_label})
    return groups


def _highlight_entities(tokenizer, input_ids, groups):
    """Rebuild the text from token ids, wrapping each entity span in styled HTML.

    Args:
        tokenizer: tokenizer whose ``decode`` turns id lists back into text.
        input_ids: list of token ids for the whole (possibly truncated) input.
        groups: non-empty output of ``_group_entities``.

    Returns:
        HTML string: decoded text with each entity wrapped in a coloured
        ``<span>`` followed by a label badge.
    """
    import html

    def decode(ids):
        # Decoded text originates from user input, so escape it before it is
        # embedded in markup. quote=False keeps the output free of '#'
        # sequences that the later '##' clean-up could mangle.
        return html.escape(tokenizer.decode(ids), quote=False) if ids else ''

    parts = []
    cursor = 0  # first token position not yet emitted
    for group in groups:
        en = group['en']
        first, last = group['indices'][0], group['indices'][-1]
        word_style = css["entity_word"].replace("#xxxxxx", background_colors_entity_word[en])
        tag_style = css["entity_tag"].replace("#xxxxxx", background_colors_entity_tag[en])

        parts.append(decode(input_ids[cursor:first]))
        parts.append(
            f'<span style="{word_style}">'
            + decode(input_ids[first:last + 1])
            + f'<span style="{tag_style}">' + en + '</span></span> '
        )
        cursor = last + 1

    parts.append(decode(input_ids[cursor:]))
    return ''.join(parts)


def ner(input_text):
    """Run both token-classification models on a text and render the entities.

    For each (tokenizer, model) pair -- "large" first, then "base" -- the text
    is tokenized (truncated to 512 tokens), labelled, and rendered as HTML in
    which every predicted entity is highlighted and followed by its label.
    When a model predicts no entity, the (escaped) input text is shown as is.

    Args:
        input_text: the sentence typed by the user.

    Returns:
        A tuple ``(output_large, output_base)`` of HTML strings.
    """
    import html

    rendered = []
    for tokenizer, model in ((tokenizer_large, model_large), (tokenizer_base, model_base)):

        # tokenization
        inputs = tokenizer(input_text, max_length=512, truncation=True, return_tensors="pt")

        # get predictions (inference only, so no autograd graph is needed)
        with torch.no_grad():
            logits = model(**inputs).logits
        label_ids = torch.argmax(logits, dim=2)[0].tolist()
        # Use the label map of the model that produced the logits (the
        # original always read model_base's map, even for model_large).
        preds = [model.config.id2label[label_id] for label_id in label_ids]

        # group the NEs
        groups_pred = _group_entities(preds)

        if groups_pred:
            input_ids = inputs['input_ids'][0].tolist()
            output = _highlight_entities(tokenizer, input_ids, groups_pred)
        else:
            output = html.escape(input_text, quote=False)

        # output: drop special tokens and word-piece markers left by decode
        output = output.replace('[CLS]', '').replace(' [SEP]', '').replace('##', '')
        output = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + output + "</div>"
        rendered.append(output)

    output_large, output_base = rendered
    return output_large, output_base
139
+
140
# interface gradio
# One text box in, two HTML panels out (one per model, in the order returned
# by ner). NOTE: `article` must be defined at module level before this call.
iface = gr.Interface(
    fn=ner,
    inputs=gr.inputs.Textbox(lines=5, placeholder="Digite uma frase aqui ou clique em um exemplo:"),
    outputs=[gr.outputs.HTML(label=panel_label) for panel_label in ("NER1", "NER2")],
    examples=examples,
    title=title,
    description=description,
    article=article,
    allow_screenshot=allow_screenshot,
    allow_flagging=allow_flagging,
)

iface.launch()