# Gradio app for SkimLit: splits a medical abstract into sentences, labels each
# sentence with its role (BACKGROUND, OBJECTIVE, METHODS, RESULTS, CONCLUSIONS)
# using a pre-trained model, and returns a skimmable version of the abstract.
import gradio as gr
import tensorflow as tf
import pandas as pd
import re
import string


model_path = "skimlit_max_trained_model"  # SavedModel folder with the trained model


# Text standardization used by the model's TextVectorization layers: lowercase,
# replace "<br />" tags with a space, then strip punctuation.
@tf.keras.utils.register_keras_serializable(package="Custom", name=None)
def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape(string.punctuation), ""
    )
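
# Illustrative example (assumed input, not from the dataset): standardizing
# tf.constant(["Hello, World! <br />"]) would yield roughly "hello world":
# lowercased, the tag replaced by a space, and punctuation stripped.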


# Load the saved model from the SavedModel folder. The custom_object_scope
# makes the registered custom_standardization available while the model's
# TextVectorization layers are deserialized.
loaded_model = None
with tf.keras.utils.custom_object_scope(
    {"custom_standardization": custom_standardization}
):
    loaded_model = tf.keras.models.load_model(model_path)

if loaded_model is not None:
    loaded_model.summary()
else:
    print("Model not loaded, sorry...")
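
# The model takes four parallel inputs per sentence (built in get_results
# below): a one-hot line-number feature, a one-hot total-lines feature, the
# raw sentence string, and the sentence split into space-separated characters.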


# Split an abstract into sentences, predict a section label for each sentence,
# and rebuild the abstract grouped under those labels.
def get_results(abs_text):
    labels = ["BACKGROUND", "CONCLUSIONS", "METHODS", "OBJECTIVE", "RESULTS"]

    # Naive sentence splitting on full stops; drop empty fragments.
    sent_list = [
        sentence.strip() for sentence in abs_text.split(sep=".") if sentence.strip()
    ]
    if not sent_list:
        return "No sentences found in the input."
    total_lines = len(sent_list)

    # Positional features the model was trained with.
    final_list = [
        {"text": line, "line_number": i, "total_lines": total_lines}
        for i, line in enumerate(sent_list)
    ]
    df = pd.DataFrame(final_list)

    # Character-level input: each sentence as space-separated characters.
    chars = [" ".join(list(sentence)) for sentence in sent_list]

    # One-hot encode the positional features (depths match the training setup).
    line_numbers_one_hot = tf.one_hot(df.line_number.to_numpy(), depth=15)
    lines_total_one_hot = tf.one_hot(df.total_lines.to_numpy(), depth=20)

    # Predict one label index per sentence.
    preds = tf.argmax(
        loaded_model.predict(
            x=(
                line_numbers_one_hot,
                lines_total_one_hot,
                tf.constant(sent_list),
                tf.constant(chars),
            ),
            verbose=0,
        ),
        axis=1,
    ).numpy()

    # Group consecutive sentences under their predicted section heading.
    text = ""
    for i in range(total_lines):
        if i == 0:
            text += f"{labels[preds[i]]}: {sent_list[i]}."
        elif preds[i] != preds[i - 1]:
            text += f"\n\n{labels[preds[i]]}: {sent_list[i]}."
        else:
            text += f" {sent_list[i]}."
    return text
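
# Example of the expected behaviour (illustrative input): an abstract such as
# "This study evaluates X. Patients were randomised to Y. Outcome Z improved."
# comes back as one string in which consecutive sentences sharing a predicted
# label are grouped under a heading like "OBJECTIVE:", "METHODS:" or
# "RESULTS:"; the exact grouping depends on the model's predictions.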


# Gradio callback: turn a raw abstract into its skimmable, labelled version.
def skimlit_analysis(text):
    return get_results(text)


# Build the Gradio UI: one textbox in, one read-only text field out.
intfc = gr.Interface(
    fn=skimlit_analysis,
    inputs=gr.Textbox(
        label="Input abstract here",
        lines=2,
        placeholder="Input your abstract here (the model works best with ones from medical research papers)",
    ),
    outputs=gr.Text(label="Skimmable Abstract", interactive=False),
)
intfc.launch()
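
# Note: launch() serves the app locally by default; if a public link is needed
# (e.g. when not hosted on Hugging Face Spaces), Gradio also supports
# intfc.launch(share=True).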