# app.py — Hugging Face Space by nicholasKluge (commit 6774d3c, verified; 3.2 kB).
# (The lines above were scraped page chrome — author/commit/size — kept here as a
# comment so the file parses as Python.)
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from transformers import TextClassificationPipeline
import tensorflow as tf
import gradio as gr
# Load the model and tokenizer
# Both are fetched from the Hugging Face Hub at import time (network access
# required on first run; cached afterwards). The TF sequence-classification
# head and the tokenizer must come from the same checkpoint so that token ids
# line up with the embedding matrix used in get_gradients below.
model = TFAutoModelForSequenceClassification.from_pretrained("AiresPucrs/distilbert-base-cased-sentiment-classifier")
tokenizer = AutoTokenizer.from_pretrained("AiresPucrs/distilbert-base-cased-sentiment-classifier")
def get_gradients(text, model, tokenizer):
    """Compute per-token gradient saliency scores for *text* under *model*.

    Token ids are re-expressed as one-hot vectors so the (otherwise
    non-differentiable) embedding lookup becomes a matmul we can take a
    gradient through.

    Returns a tuple ``(scores, tokens, predicted_class)`` where ``scores``
    are per-token gradient L2 norms scaled by their maximum.
    """
    # Word-embedding matrix of the DistilBERT encoder; rows index the vocab.
    embeddings = model.distilbert.embeddings.weights[0]
    n_vocab = embeddings.shape[0]

    encoded = tokenizer(text, return_tensors="tf")
    ids = list(encoded["input_ids"].numpy()[0])

    # matmul(one_hot, E) reproduces the embedding lookup but is
    # differentiable w.r.t. the one-hot tensor.
    one_hot_ids = tf.one_hot(tf.constant([ids], dtype='int32'), n_vocab)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(one_hot_ids)
        embedded = tf.matmul(one_hot_ids, embeddings)
        logits = model({"inputs_embeds": embedded, "attention_mask": encoded["attention_mask"]}).logits
        predicted = tf.argmax(logits, axis=1).numpy()[0]
        top_logit = logits[0][predicted]

    # Gradient of the winning logit w.r.t. each one-hot token, reduced to an
    # L2 norm per token and scaled into [0, 1] by the maximum.
    raw = tf.norm(tape.gradient(top_logit, one_hot_ids), axis=2)
    scores = (raw / tf.reduce_max(raw))[0].numpy().tolist()

    return scores, tokenizer.convert_ids_to_tokens(ids), predicted
def interpet_DistilBERT(text):
    """Gradio handler: classify *text* and explain the prediction.

    Returns an HTML snippet naming the predicted label, plus a mapping of
    each input token to its share of the total gradient mass (suitable for
    a gr.Label widget).
    """
    grads, words, pred = get_gradients(text, model, tokenizer)

    # Drop the first and last entries — presumably the tokenizer's special
    # [CLS]/[SEP] tokens, which carry no user-visible meaning.
    words, grads = words[1:-1], grads[1:-1]

    # Normalize so the scores sum to 1 across the remaining tokens.
    total = sum(grads)
    shares = [g / total for g in grads]

    label = model.config.id2label[int(pred)]
    html = f"<b>Predicted Answer:</b><br> <i>{label}</i><br><br><b>Gradient Scores:</b>"
    return html, dict(zip(words, shares))
# HTML blurb rendered above the demo: title, one-paragraph summary, and a
# link to the accompanying tutorial notebook.
description = "".join(
    (
        "<center>",
        "<h1>Explaining DistilBERT with integrated gradients 🏰</h1>",
        "<br>This app was built to provide insight into how a DistilBERT model, fine-tuned for text classification, operates via integrated gradient explanations. Enter a text and see the gradient scores for each word in the input.<br>",
        "To learn more, visit this <a href='https://github.com/Nkluge-correa/teeny-tiny_castle/blob/master/ML%20Explainability/NLP%20Interpreter/gradient_explanations_BERT.ipynb'>tutorial</a><br>",
        "</center>",
    )
)
# HTML footer rendered below the demo, linking back to the parent repo.
article = "".join(
    (
        "<center>",
        "Return to the <a href='https://github.com/Nkluge-correa/teeny-tiny_castle'>castle</a>.",
        "</center>",
    )
)
# Create the Gradio interface
# Wire the explainer into a UI: free text in; an HTML verdict plus a label
# widget (showing the top per-token gradient shares) out. Output order must
# match interpet_DistilBERT's return tuple.
text_box = gr.Textbox(placeholder="Enter text here...")
score_label = gr.Label(num_top_classes=10)

interface = gr.Interface(
    fn=interpet_DistilBERT,
    inputs=text_box,
    outputs=["html", score_label],
    description=description,
    article=article,
    allow_flagging="never",
)

# Launch the Gradio interface
interface.launch()