ankush-003's picture
Update app.py
c5cc59e
raw
history blame
1.2 kB
import gradio as gr
import tensorflow as tf
# from transformers import AutoTokenizer
# from transformers import TFAutoModelForSequenceClassification
# Load model directly
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
# Tokenization uses the stock DistilBERT checkpoint, while the classification
# weights come from the project's own Hub repository.
# NOTE(review): this assumes the classifier was trained with the
# distilbert-base-uncased vocabulary; a mismatched vocabulary would silently
# produce meaningless predictions -- confirm against the training setup.
_TOKENIZER_CHECKPOINT = "distilbert-base-uncased"
_MODEL_CHECKPOINT = "ankush-003/nosqli_identifier"

tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
model = TFAutoModelForSequenceClassification.from_pretrained(_MODEL_CHECKPOINT)
def predict(payload, malitious):
    """Classify *payload* with the injection model and echo the expected label.

    Args:
        payload: Text to classify (the Gradio "text" input).
        malitious: Checkbox value; truthy means the user expects the payload
            to be malicious.

    Returns:
        A ``(prediction, expected)`` tuple of strings. ``prediction`` is the
        model's label, looked up in ``model.config.id2label``. ``expected`` is
        ``"Malitious"`` when *malitious* is truthy and ``"Benign"`` otherwise.
    """
    # Truncate to the tokenizer's maximum length. Without this, a payload
    # longer than the model's position-embedding limit makes the forward pass
    # fail instead of returning a prediction.
    # NOTE(review): only the leading tokens are classified, so anything past
    # the limit is not seen by the model.
    inputs = tokenizer(payload, truncation=True, return_tensors="tf")
    logits = model(**inputs).logits
    # A single string was tokenized, so the batch has exactly one row.
    predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
    # NOTE(review): "Malitious" is misspelled but kept as-is. It is
    # user-visible output displayed next to the model's label for comparison.
    expected = "Malitious" if malitious else "Benign"
    return model.config.id2label[predicted_class_id], expected
# Two text outputs: the model's label and the label the user expected
# (derived from the checkbox input).
prediction_box = gr.Textbox(label="Model Prediction")
expected_box = gr.Textbox(label="Expected")

demo = gr.Interface(
    fn=predict,
    inputs=["text", "checkbox"],
    outputs=[prediction_box, expected_box],
)

# Per Gradio's launch() docs, debug=True blocks the main thread and surfaces
# errors in the console.
demo.launch(debug=True)