sukh28's picture
Update app.py
d4ecc33
raw
history blame contribute delete
No virus
2.05 kB
# -*- coding: utf-8 -*-
"""gradio_app.py
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1OQvi3I_q3WfavYBpjovCYfv2SPYt__pF
"""
"""gradio_app.py
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1OQvi3I_q3WfavYBpjovCYfv2SPYt__pF
"""
import json
import gradio as gr
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json
import tensorflow_addons as tfa
# Load the pre-trained Keras classifier stored in HDF5 format.
# NOTE(review): `tensorflow_addons` is imported above but never referenced
# directly; presumably the import is kept so any TFA objects the model was
# compiled with are registered before loading — confirm before removing it.
model = tf.keras.models.load_model('baseline.h5')


def _load_tokenizer(path):
    """Rebuild the Keras ``Tokenizer`` whose configuration is stored at *path*.

    ``tokenizer_from_json`` needs the JSON text produced by
    ``Tokenizer.to_json()``. Two on-disk layouts are accepted:

    * that text wrapped in a second JSON encoding, as written by
      ``json.dump(tokenizer.to_json(), f)``; decoding the file once yields
      the inner JSON string, which is passed through unchanged;
    * that text written verbatim; decoding the file yields a dict, so the
      raw file contents are passed instead.
    """
    with open(path, 'r', encoding='utf-8') as f:
        raw_text = f.read()
    decoded = json.loads(raw_text)
    tokenizer_json = decoded if isinstance(decoded, str) else raw_text
    return tf.keras.preprocessing.text.tokenizer_from_json(tokenizer_json)


# 'tokenizer.json' is resolved relative to the current working directory.
tokenizer = _load_tokenizer('tokenizer.json')
# Classification labels, one per model output unit, in output order.
# NOTE(review): this order is assumed to match the target columns used when
# `baseline.h5` was trained — confirm against the training code.
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Every token sequence is padded/truncated to this length before inference.
# Presumably this must equal the input length the model was trained with — verify.
MAX_SEQUENCE_LENGTH = 200


def classify_comment(comment):
    """Return the highest-scoring label for *comment*, capitalized.

    The comment is tokenized with the module-level ``tokenizer``, padded to
    ``MAX_SEQUENCE_LENGTH`` tokens and scored with ``model``. The label with
    the largest score is returned via ``str.capitalize`` (e.g. ``"Insult"``,
    ``"Severe_toxic"``). When several labels share the top score, the one
    listed first in ``labels`` wins.

    There is no threshold: some label is always returned, even for a comment
    whose scores are all low.

    Args:
        comment: Text to classify. ``None`` is treated as an empty string.

    Returns:
        The winning label name, capitalized.
    """
    # The tokenizer works on strings; treat a missing value as empty text.
    text = "" if comment is None else comment
    sequences = tokenizer.texts_to_sequences([text])
    padded = tf.keras.preprocessing.sequence.pad_sequences(
        sequences, maxlen=MAX_SEQUENCE_LENGTH
    )
    # verbose=0: the interface is live, so this runs on every input change;
    # keep the per-call progress bar out of the server log.
    scores = model.predict(padded, verbose=0)[0]
    # max() returns the first maximal item, so ties resolve in label order.
    best_label, _ = max(zip(labels, scores), key=lambda pair: pair[1])
    return best_label.capitalize()
# Create the Gradio interface: one text box in, one text box out.
# `gr.Textbox` replaces the `gr.inputs` / `gr.outputs` namespaces, which were
# deprecated in Gradio 3 and removed in Gradio 4.
comment_input = gr.Textbox(label="Enter your comment here")
output_text = gr.Textbox(label="Classification Results")
iface = gr.Interface(
    fn=classify_comment,
    inputs=comment_input,
    outputs=output_text,
    # Re-run classify_comment whenever the input changes, instead of waiting
    # for the user to press Submit.
    live=True,
)
# Launch the Gradio app (blocks while the server is running).
iface.launch()