"""SpecX: a Gradio app that recommends a medical specialist from a symptom
sentence, using a fine-tuned RoBERTa classifier, with SHAP token-level
explanations of each prediction."""

import gradio as gr
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
import os
import numpy as np
import shap
import scipy.special

# Pretrained checkpoint used both for tokenization and as the transformer
# backbone referenced inside the fine-tuned Keras model.
model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): a model *instance* (not a class) is passed as the custom
# object here, which is unusual — confirm that model.h5 was saved with a
# layer named 'TFRobertaModel' that deserializes this way.
model = tf.keras.models.load_model(
    "model.h5",
    custom_objects={
        'TFRobertaModel': TFAutoModel.from_pretrained(model_name)
    }
)

# Output classes of the classifier head, in the order the model emits them.
labels = [
    'Cardiologist',
    'Dermatologist',
    'ENT Specialist',
    'Gastro-enterologist',
    'General-Physicians',
    'Neurologist/Gastro-enterologist',
    'Ophthalmologist',
    'Orthopedist',
    'Psychiatrist',
    'Respirologist',
    'Rheumatologist',
    'Rheumatologist/Gastro-enterologist',
    'Rheumatologist/Orthopedist',
    'Surgeon'
]

# Fixed sequence length every input is padded/truncated to
# (presumably the length the model was trained with — TODO confirm).
seq_len = 152


def prep_data(text):
    """Tokenize *text* into fixed-length TF tensors the model expects.

    Returns a dict with 'input_ids' and 'attention_mask', each of shape
    (1, seq_len).
    """
    tokens = tokenizer(
        text,
        max_length=seq_len,
        truncation=True,
        padding='max_length',
        add_special_tokens=True,
        return_tensors='tf'
    )
    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask']
    }


def inference(text):
    """Classify one sentence; return {specialist label: score} for gr.Label."""
    encoded_text = prep_data(text)
    probs = model.predict_on_batch(encoded_text)
    # Pair each label with the corresponding scalar of the flattened output row.
    probabilities = {i: j for i, j in zip(labels, list(probs.flatten()))}
    return probabilities


def predictor(x):
    """SHAP scoring function: logit of the class-1 probability for *x*.

    Returns a 1-D numpy array with one value per input row.
    """
    # Fix: the original tokenized the same input twice (once per tensor);
    # tokenize once and reuse both tensors from the single result.
    tokens = tokenizer(
        x,
        max_length=seq_len,
        truncation=True,
        padding='max_length',
        add_special_tokens=True,
        return_tensors='tf'
    )
    outputs = model.predict([tokens['input_ids'], tokens['attention_mask']])
    probas = tf.nn.softmax(outputs).numpy()
    # NOTE(review): only column 1 is explained; with 14 output labels this
    # always explains one fixed class — confirm this is intentional.
    val = scipy.special.logit(probas[:, 1])
    return val


def f_batch(x):
    """Batch wrapper over predictor() for shap.Explainer: one score per text."""
    # Fix: collect scores in a list and concatenate once, instead of the
    # original quadratic np.append-in-a-loop; result is the same flat array.
    scores = [predictor(text) for text in x]
    return np.concatenate(scores) if scores else np.array([])


explainer_roberta = shap.Explainer(f_batch, tokenizer)
# Warm-up run so the first user-triggered SHAP request is not the slow one.
shap_values = explainer_roberta(["When I remember her I feel down"])


def get_shap_data(input_text):
    """Run SHAP on *input_text* and return the explanation as embeddable HTML."""
    shap_values = explainer_roberta([input_text])
    html_shap_content = shap.plots.text(shap_values, display=False)
    return html_shap_content


# Transparent textbox background with a subtle indigo border.
css = """
textarea {
    background-color: #00000000;
    border: 1px solid #6366f160;
}
"""

with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo:
    with gr.Row():
        textmd = gr.Markdown('''

SpecX: Find the Right Specialist For Your Symptoms!

''')
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            text_box = gr.Textbox(label="Explain your problem in one sentence.")
            with gr.Row():
                submit_btn = gr.Button("Submit", elem_id="warningk", variant='primary')
                submit_btn_shap = gr.Button("Get SHAP Analysis", variant='secondary')
            examples = gr.Examples(examples=[
                "When I remember her I feel down",
                "The area around my heart doesn't feel good.",
                "I have a split on my thumb that will not heal."
            ], inputs=text_box)
        with gr.Column():
            label = gr.Label(num_top_classes=4, label="Recommended Specialist")
            gr.Markdown("## SHAP Analysis")
            shap_content_box = gr.HTML()
        # Wire the buttons to the model and to the SHAP explainer.
        submit_btn.click(inference, inputs=text_box, outputs=label)
        submit_btn_shap.click(get_shap_data, inputs=text_box, outputs=shap_content_box)
    with gr.Row():
        textmd1 = gr.Markdown('''
''')
    # Marketing stats row; each column is a separate Markdown widget.
    with gr.Row():
        with gr.Column():
            textmd1 = gr.Markdown('''

98% Success Rate

''')
        with gr.Column():
            textmd2 = gr.Markdown('''

99.9% Model uptime

''')
        with gr.Column():
            textmd3 = gr.Markdown('''

77,223 Monthly active users

''')
        with gr.Column():
            # Fix: was named textmd3 again, shadowing the previous widget.
            textmd4 = gr.Markdown('''

990,224 total model predictions made

''')

demo.launch()