"""Gradio app that labels each sentence of an abstract with a SkimLit model.

The abstract is split into sentences on ".", each sentence is classified as
one of BACKGROUND / CONCLUSIONS / METHODS / OBJECTIVE / RESULTS by the saved
model, and consecutive sentences with the same label are grouped into one
"LABEL: ..." paragraph.
"""

import itertools
import re
import string

import gradio as gr
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization

model_path = "skimlit_max_trained_model"

# Class names indexed by the model's argmax prediction.
LABELS = ["BACKGROUND", "CONCLUSIONS", "METHODS", "OBJECTIVE", "RESULTS"]

# One-hot depths for the positional inputs. tf.one_hot encodes indices >= depth
# as an all-zero vector, so sentences beyond these positions lose that signal.
LINE_NUMBER_DEPTH = 15
TOTAL_LINES_DEPTH = 20


@tf.keras.utils.register_keras_serializable(package="Custom", name=None)
def custom_standardization(input_data):
    """Lowercase text, replace "<br />" tags with spaces and strip punctuation.

    Registered as Keras-serializable (and passed as a custom object when the
    model is loaded below) so the saved model can resolve it by name;
    presumably it is the ``standardize`` callable of a text-vectorization
    layer inside the model.
    """
    lowercase = tf.strings.lower(input_data)
    # Replace HTML line breaks with a space. An empty pattern here would match
    # between every character and space-separate the entire string.
    # NOTE(review): "<br />" mirrors the standard Keras text example; confirm
    # it matches the function used when the model was trained.
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
        stripped_html, "[%s]" % re.escape(string.punctuation), ""
    )


model = None
with tf.keras.utils.custom_object_scope(
    {"custom_standardization": custom_standardization}
):
    model = tf.keras.models.load_model(model_path)  # Folder

if model is not None:
    model.summary()
else:
    print("Model not loaded, sorry...")

# Alias kept for backward compatibility: reuse the model loaded above instead
# of loading the same SavedModel a second time.
loaded_model = model


def _split_sentences(abs_text):
    """Split text on "." into stripped sentences, dropping empty segments."""
    return [segment.strip() for segment in abs_text.split(".") if segment.strip()]


def _group_by_label(sentences, label_ids):
    """Merge consecutive same-label sentences into "LABEL: ..." paragraphs."""
    paragraphs = []
    pairs = zip(label_ids, sentences)
    for label_id, group in itertools.groupby(pairs, key=lambda pair: pair[0]):
        body = " ".join(f"{sentence}." for _, sentence in group)
        paragraphs.append(f"{LABELS[label_id]}: {body}")
    return "\n\n".join(paragraphs)


def get_results(abs_text):
    """Return *abs_text* rewritten as labelled, skimmable paragraphs.

    Args:
        abs_text: Abstract text; sentences are delimited by ".".

    Returns:
        Paragraphs of the form "LABEL: sentence. sentence." separated by
        blank lines, or "" when the input contains no sentences.
    """
    sentences = _split_sentences(abs_text or "")
    if not sentences:
        # Nothing to classify; avoid calling the model with empty inputs.
        return ""

    total_lines = len(sentences)
    # Character-level input: each sentence's characters separated by spaces.
    chars = [" ".join(sentence) for sentence in sentences]
    # Positional inputs: 0-based index of each sentence, and the sentence count
    # repeated per sentence. NOTE(review): total_lines is the raw count (not
    # count - 1); confirm this matches the encoding used during training.
    line_numbers_one_hot = tf.one_hot(
        list(range(total_lines)), depth=LINE_NUMBER_DEPTH
    )
    lines_total_one_hot = tf.one_hot(
        [total_lines] * total_lines, depth=TOTAL_LINES_DEPTH
    )

    probabilities = model.predict(
        x=(
            line_numbers_one_hot,
            lines_total_one_hot,
            tf.constant(sentences),
            tf.constant(chars),
        ),
        verbose=0,
    )
    # Convert to plain ints so they can index LABELS and group cleanly.
    label_ids = tf.argmax(probabilities, axis=1).numpy().tolist()
    return _group_by_label(sentences, label_ids)


def skimlit_analysis(text):
    """Gradio callback: format *text* as a skimmable, labelled abstract."""
    return get_results(text)


intfc = gr.Interface(
    fn=skimlit_analysis,
    inputs=gr.Textbox(
        label="Input abstract here",
        lines=2,
        placeholder="Input your abstract here (the model works best with ones from medical research papers)",
    ),
    outputs=gr.Text(label="Skimmable Abstract", interactive=False),
)
intfc.launch()