Pratik-hf's picture
Update app.py
f76ffff verified
raw
history blame
No virus
3.02 kB
import gradio as gr
from keras.models import load_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import pipeline
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import json
import os
# Load the saved tokenizer and the trained text model once, at import time,
# so that every classification request reuses the same objects.
# The file is opened explicitly as UTF-8 so that loading does not depend on
# the platform's default text encoding.
with open('tokenizer.json', encoding='utf-8') as f:
    # Whatever json.load returns is handed straight to tokenizer_from_json.
    data = json.load(f)
tokenizer = tokenizer_from_json(data)
loaded_model = load_model("BiLSTM_INAPPRO_TEXT_CLASSIFIER.h5")
# Function to classify text
def classify_text(text):
    """Classify a piece of text as appropriate or inappropriate.

    The text is turned into a sequence with the module-level ``tokenizer``,
    padded to length 128 and scored by the module-level ``loaded_model``.

    Args:
        text: The input string to classify.

    Returns:
        A two-item list ``[percentage, label]``. ``percentage`` is the
        model's score scaled to 0-100 and rounded to two decimals;
        ``label`` is "Inappropriate" when the raw score is at least 0.5,
        otherwise "Appropriate".
    """
    # Tokenize and pad sequences
    sequence = tokenizer.texts_to_sequences([text])
    padded_sequence = pad_sequences(sequence, maxlen=128)
    result = loaded_model.predict(padded_sequence)
    print(result)

    # Work with a plain Python float so the value shown in the UI is not a
    # NumPy scalar.
    score = float(result[0][0])
    label = "Inappropriate" if score >= 0.5 else "Appropriate"

    # Scale first, then round: rounding to 4 places and multiplying by 100
    # afterwards reintroduces floating-point noise (e.g. 56.779999999999994).
    return [round(score * 100, 2), label]
# Image-classification pipeline, built once at import time and reused by
# classify_image for every request.
model = pipeline(
    task="image-classification",
    model="Pratik-hf/Inappropriate-image-classification-using-ViT",
)
# Function to classify image
def classify_image(image):
    """Classify an image as safe or unsafe.

    The image is passed to the module-level image-classification pipeline
    (``model``) and the highest-scoring prediction is reported.

    Args:
        image: The image to classify, as supplied by the Gradio image input.

    Returns:
        A two-item list ``[percentage, label]``. ``percentage`` is the top
        prediction's score scaled to 0-100 and rounded to two decimals;
        ``label`` is "Safe" when the top prediction is "LABEL_0", otherwise
        "Unsafe".

    Raises:
        gr.Error: If ``image`` is None.
    """
    if image is None:
        # Presumably what the UI sends when the button is clicked before an
        # image is uploaded; fail with a readable message instead of letting
        # the pipeline raise an opaque error.
        raise gr.Error("Please upload an image before classifying.")

    print(image)

    # Forward pass
    with torch.no_grad():
        outputs = model(image)

    # Get the label with the highest probability
    prediction = max(outputs, key=lambda x: x['score'])
    prediction_label = "Safe" if prediction['label'] == "LABEL_0" else "Unsafe"

    # Print the top prediction (label and score)
    print("Predicted probabilities:", prediction)

    # Scale first, then round: rounding to 4 places and multiplying by 100
    # afterwards reintroduces floating-point noise in the displayed value.
    return [round(prediction['score'] * 100, 2), prediction_label]
# Define Gradio interface
def classify_inputs(text=None, image=None):
    """Run the text and/or image classifier on whichever inputs are given.

    Args:
        text: Text to classify, or None to skip text classification.
        image: Image to classify, or None to skip image classification.

    Returns:
        A ``(text_result, image_result)`` tuple. Each entry is the result of
        ``classify_text`` / ``classify_image`` for that input, or None when
        the corresponding input was not provided.
    """
    # Each result defaults to None; previously a missing input left its
    # variable unbound and the return statement raised UnboundLocalError.
    text_result = classify_text(text) if text is not None else None
    image_result = classify_image(image) if image is not None else None
    return text_result, image_result
# Two-tab UI: one tab per classifier. Each tab has an input, a button, and a
# row with the percentage and label outputs returned by the classifier.
with gr.Blocks() as demo:
    with gr.Tab("Text"):
        # Heading typo fixed: "Detction" -> "Detection".
        gr.Markdown(
            """
            # Inappropriate text Detection
            Give input below to see the output.
            """)
        text_input = gr.Textbox(label="Input Text", lines=5)
        btn1 = gr.Button("Classify Text")
        with gr.Row():
            output_text_percentage = gr.Text(label="Percentage")
            output_text_label = gr.Text(label="Label")
        # classify_text returns [percentage, label], matching the two outputs.
        btn1.click(
            fn=classify_text,
            inputs=text_input,
            outputs=[output_text_percentage, output_text_label],
        )
    with gr.Tab("Image"):
        # Heading typo fixed: "Detction" -> "Detection".
        gr.Markdown(
            """
            # Inappropriate Image Detection
            Give input below to see the output.
            """)
        # type="pil" is requested so classify_image receives a PIL image.
        image_input = gr.Image(type="pil")
        btn2 = gr.Button("Classify Image")
        with gr.Row():
            output_image_percentage = gr.Text(label="Percentage")
            output_image_label = gr.Text(label="Label")
        # classify_image returns [percentage, label], matching the two outputs.
        btn2.click(
            fn=classify_image,
            inputs=image_input,
            outputs=[output_image_percentage, output_image_label],
        )

demo.launch(share=True)