import gradio as gr import tensorflow as tf from transformers import TFAutoModel, AutoTokenizer import os import numpy as np import shap import scipy.special model_name = 'cardiffnlp/twitter-roberta-base-sentiment-latest' tokenizer = AutoTokenizer.from_pretrained(model_name) model = tf.keras.models.load_model( "model.h5", custom_objects={ 'TFRobertaModel': TFAutoModel.from_pretrained(model_name) } ) labels = [ 'Cardiologist', 'Dermatologist', 'ENT Specialist', 'Gastro-enterologist', 'General-Physicians', 'Neurologist/Gastro-enterologist', 'Ophthalmologist', 'Orthopedist', 'Psychiatrist', 'Respirologist', 'Rheumatologist', 'Rheumatologist/Gastro-enterologist', 'Rheumatologist/Orthopedist', 'Surgeon' ] seq_len = 152 def prep_data(text): tokens = tokenizer( text, max_length=seq_len, truncation=True, padding='max_length', add_special_tokens=True, return_tensors='tf' ) return { 'input_ids': tokens['input_ids'], 'attention_mask': tokens['attention_mask'] } def inference(text): encoded_text = prep_data(text) probs = model.predict_on_batch(encoded_text) probabilities = {i:j for i,j in zip(labels, list(probs.flatten()))} return probabilities def predictor(x): input_ids = tokenizer(x, max_length=seq_len, truncation=True, padding='max_length', add_special_tokens=True, return_tensors='tf')['input_ids'] attention_mask = tokenizer(x, max_length=seq_len, truncation=True, padding='max_length', add_special_tokens=True, return_tensors='tf')['attention_mask'] outputs = model.predict([input_ids, attention_mask]) probas = tf.nn.softmax(outputs).numpy() val = scipy.special.logit(probas[:,1]) return val def f_batch(x): val = np.array([]) for i in x: val = np.append(val, predictor(i)) return val explainer_roberta = shap.Explainer(f_batch, tokenizer) shap_values = explainer_roberta(["When I remember her I feel down"]) def get_shap_data(input_text): shap_values = explainer_roberta([input_text]) html_shap_content = shap.plots.text(shap_values, display=False) return html_shap_content css = """ textarea { background-color: #00000000; border: 1px solid #6366f160; } """ with gr.Blocks(title="SpecX", css=css, theme=gr.themes.Soft()) as demo: with gr.Row(): textmd = gr.Markdown('''