Spaces:
Runtime error
Runtime error
import itertools
import re
import string

import gradio as gr
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization
model_path = "skimlit_max_trained_model"


def custom_standardization(input_data):
    """Lower-case text, turn HTML ``<br />`` breaks into spaces, drop punctuation.

    Passed as a custom object (under this exact name) when the saved model is
    loaded, so the name must not change.

    Args:
        input_data: String tensor to normalise (presumably the batch fed to a
            TextVectorization layer inside the saved model — confirm).

    Returns:
        String tensor of the same shape with the normalisation applied
        element-wise.
    """
    lowercase = tf.strings.lower(input_data)
    # An empty pattern would match at every position and insert a space
    # between every character. The intent (see the variable name) is to
    # replace HTML line breaks with a space.
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    # Remove every ASCII punctuation character; re.escape keeps characters
    # such as "]" and "\\" literal inside the character class.
    return tf.strings.regex_replace(
        stripped_html, f"[{re.escape(string.punctuation)}]", ""
    )
# Load the saved Keras model exactly once. custom_standardization is
# registered for the duration of the load so Keras can resolve that name if
# the saved model references it.
try:
    with tf.keras.utils.custom_object_scope(
        {"custom_standardization": custom_standardization}
    ):
        model = tf.keras.models.load_model(model_path)  # model_path is a folder
except Exception:
    # load_model raises instead of returning None on failure, so report here
    # and re-raise: the app cannot serve predictions without the model.
    print("Model not loaded, sorry...")
    raise
model.summary()

# get_results() reads `loaded_model`. Reuse the instance loaded above rather
# than calling load_model again: a second load doubles start-up cost and,
# outside the custom-object scope, may fail to resolve custom_standardization.
loaded_model = model
def _split_sentences(abs_text):
    """Split *abs_text* on "." into stripped, non-empty sentences.

    A final sentence without a closing period is kept, and empty fragments
    (e.g. from ".." or trailing whitespace) are discarded. NOTE: every "."
    splits, including decimal points ("p < 0.05") and abbreviations ("e.g.").
    """
    if not abs_text:
        return []
    pieces = (piece.strip() for piece in abs_text.split("."))
    return [piece for piece in pieces if piece]


def _format_sections(labels, sentences, pred_ids):
    """Merge consecutive sentences sharing a predicted label into sections.

    Each section is rendered as ``"LABEL: s1. s2."``; sections are separated
    by a blank line.
    """
    sections = []
    for label_id, group in itertools.groupby(
        zip(pred_ids, sentences), key=lambda pair: pair[0]
    ):
        body = " ".join(f"{sentence}." for _, sentence in group)
        sections.append(f"{labels[label_id]}: {body}")
    return "\n\n".join(sections)


def get_results(abs_text):
    """Classify each sentence of an abstract and group sentences by label.

    The abstract is split on "." into sentences. Each sentence is passed to
    the module-level ``loaded_model`` together with one-hot position features
    and a space-separated character string. Consecutive sentences that receive
    the same predicted label are merged into one section headed by that label.

    Args:
        abs_text: Raw abstract text (from the Gradio textbox).

    Returns:
        Sections of the form ``"LABEL: sentence. sentence."`` separated by
        blank lines, or ``""`` when the input contains no sentences.
    """
    # Index order presumably matches the label encoding used at training
    # time (alphabetical) — confirm against the training pipeline.
    labels = ["BACKGROUND", "CONCLUSIONS", "METHODS", "OBJECTIVE", "RESULTS"]

    sent_list = _split_sentences(abs_text)
    if not sent_list:
        # Nothing to classify; avoid building an empty model batch.
        return ""

    total_lines = len(sent_list)
    # Positional features with fixed depths 15 and 20. tf.one_hot encodes
    # indices >= depth as all-zero rows, so long abstracts do not crash.
    line_numbers_one_hot = tf.one_hot(list(range(total_lines)), depth=15)
    # NOTE(review): confirm this total_lines value matches the training-time
    # encoding (some SkimLit pipelines store len(lines) - 1).
    lines_total_one_hot = tf.one_hot([total_lines] * total_lines, depth=20)
    # Character-level input: the sentence's characters separated by spaces.
    chars = [" ".join(sentence) for sentence in sent_list]

    probs = loaded_model.predict(
        x=(
            line_numbers_one_hot,
            lines_total_one_hot,
            tf.constant(sent_list),
            tf.constant(chars),
        ),
        verbose=0,
    )
    # Plain Python ints index `labels` and compare cleanly in groupby.
    pred_ids = tf.argmax(probs, axis=1).numpy().tolist()
    return _format_sections(labels, sent_list, pred_ids)
def skimlit_analysis(text):
    """Gradio callback: convert the abstract in *text* into skimmable sections.

    Thin adapter so the UI wiring does not depend on get_results directly.
    """
    skimmable = get_results(text)
    return skimmable
# Gradio UI: one free-text input for the abstract and one read-only output.
# Components are built in the same order as before (input, then output).
_abstract_input = gr.Textbox(
    label="Input abstract here",
    lines=2,
    placeholder="Input your abstract here (the model works best with ones from medical research papers)",
)
_skimmable_output = gr.Text(label="Skimmable Abstract", interactive=False)

intfc = gr.Interface(
    fn=skimlit_analysis,
    inputs=_abstract_input,
    outputs=_skimmable_output,
)
intfc.launch()