skimlit / app.py
prosekutor's picture
fixed import
31cbd81
import itertools
import re
import string

import gradio as gr
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization
import pandas as pd
# Location of the saved Keras model (a folder, loaded below with load_model).
model_path = "skimlit_max_trained_model"


@tf.keras.utils.register_keras_serializable(package="Custom", name=None)
def custom_standardization(input_data):
    """Lower-case text, replace HTML line breaks with spaces, and drop punctuation.

    Registered with Keras so a saved model that refers to this function can be
    deserialized (presumably it is a TextVectorization ``standardize``
    callable inside the saved model -- TODO confirm).

    Args:
        input_data: A string tensor.

    Returns:
        A string tensor with the same shape, lower-cased, with "<br />" replaced
        by a space and every ``string.punctuation`` character removed.
    """
    lowercase = tf.strings.lower(input_data)
    # Must be a real pattern: an empty regex matches at every position and
    # would insert a space between every character of the text.
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    # Character class of all ASCII punctuation, escaped for the regex engine.
    return tf.strings.regex_replace(
        stripped_html, f"[{re.escape(string.punctuation)}]", ""
    )
# Load the trained model exactly once. The custom_object_scope makes the
# `custom_standardization` function resolvable by name while deserializing.
# load_model raises on failure (it never returns None), so no None check.
with tf.keras.utils.custom_object_scope(
    {"custom_standardization": custom_standardization}
):
    model = tf.keras.models.load_model(model_path)  # Folder
model.summary()

# get_results() predicts with `loaded_model`; reuse the already-loaded model
# instead of reading the same folder from disk a second time.
loaded_model = model
def get_results(abs_text):
    """Label each sentence of an abstract and return a skimmable version.

    The text is split on "." into sentences and blank fragments are discarded.
    Each sentence is classified by the module-level ``loaded_model``. Runs of
    consecutive sentences that share a predicted label are grouped under a
    "LABEL: " prefix; sentences inside a group are separated by a space and
    groups are separated by a blank line.

    Args:
        abs_text: The abstract as plain text.

    Returns:
        The grouped, labelled text, or "" when the input contains no
        non-blank sentences.
    """
    labels = ["BACKGROUND", "CONCLUSIONS", "METHODS", "OBJECTIVE", "RESULTS"]

    # Keep every non-blank fragment. An unconditional pop() of the last piece
    # assumed the text ended with "." and otherwise lost the final sentence;
    # it also let empty fragments (e.g. from "..") reach the model.
    sent_list = [piece.strip() for piece in abs_text.split(".")]
    sent_list = [sentence for sentence in sent_list if sentence]
    if not sent_list:
        # Nothing to classify; also avoids building an empty model batch.
        return ""

    total_lines = len(sent_list)
    # Model inputs, in the order the model is called with: one-hot sentence
    # position (depth 15), one-hot sentence count (depth 20), the sentence
    # text, and the sentence as space-separated characters.
    # NOTE(review): confirm total_lines matches the training-time encoding of
    # this feature (some SkimLit pipelines use len(lines) - 1).
    line_numbers_one_hot = tf.one_hot(list(range(total_lines)), depth=15)
    lines_total_one_hot = tf.one_hot([total_lines] * total_lines, depth=20)
    chars = [" ".join(sentence) for sentence in sent_list]

    probs = loaded_model.predict(
        x=(
            line_numbers_one_hot,
            lines_total_one_hot,
            tf.constant(sent_list),
            tf.constant(chars),
        ),
        verbose=0,
    )
    # Convert once to plain Python ints instead of indexing a tensor per item.
    preds = tf.argmax(probs, axis=1).numpy().tolist()

    # groupby merges only *consecutive* equal labels, which is exactly the
    # "start a new section when the label changes" behaviour we want.
    sections = []
    for label_idx, group in itertools.groupby(
        zip(preds, sent_list), key=lambda pair: pair[0]
    ):
        body = " ".join(f"{sentence}." for _, sentence in group)
        sections.append(f"{labels[label_idx]}: {body}")
    return "\n\n".join(sections)
def skimlit_analysis(text):
    """Gradio callback: turn a raw abstract into its labelled, skimmable form."""
    skimmable = get_results(text)
    return skimmable
# Gradio UI: a text box for the raw abstract and a read-only box for the
# labelled result produced by skimlit_analysis.
abstract_input = gr.Textbox(
    label="Input abstract here",
    lines=2,
    placeholder="Input your abstract here (the model works best with ones from medical research papers)",
)
skimmable_output = gr.Text(label="Skimmable Abstract", interactive=False)

intfc = gr.Interface(
    fn=skimlit_analysis,
    inputs=abstract_input,
    outputs=skimmable_output,
)
intfc.launch()