"""Gradio demo app exposing a toxic-text detector and a toxic-image detector.

Two tabs are provided: one takes free text, the other an image; each has a
button that runs the corresponding detector and shows the result in a
``gr.Label`` component.
"""
import logging

import gradio as gr
from PIL import Image
from toxic_detection import TextToxicDetector
from toxic_detection import ImgToxicDetector

logger = logging.getLogger(__name__)

# Both models are loaded once at import time and reused for every request.
text_model = TextToxicDetector()
text_model.load('szzzzz/xlm-roberta-base-text-toxic')
img_model = ImgToxicDetector()
img_model.load('./toxic_detection_res50.gz.tar')


def image_toxic_detect(im):
    """Run the image toxicity detector on an uploaded image.

    Args:
        im: The value of the ``gr.Image()`` input. With the component's
            default settings this is presumably a NumPy array (it is passed
            to ``Image.fromarray``) — TODO confirm. It is ``None`` when the
            button is clicked with no image uploaded.

    Returns:
        Whatever ``img_model.detect`` returns (presumably a label ->
        confidence mapping consumed by ``gr.Label`` — TODO confirm), or
        ``None`` when no image was provided, which clears the output label.
    """
    # Without this guard Image.fromarray(None) raises when nothing is uploaded.
    if im is None:
        return None
    return img_model.detect(Image.fromarray(im))


def text_toxic_detect(text):
    """Run the text toxicity detector on the given text.

    Args:
        text: The string from the ``gr.Textbox`` input.

    Returns:
        Whatever ``text_model.detect`` returns (presumably a label ->
        confidence mapping consumed by ``gr.Label`` — TODO confirm), or
        ``None`` for empty / whitespace-only input, which clears the label.
    """
    # Debug-level and lazily formatted: raw user text should not be written
    # to stdout on every request (the original used an unconditional print).
    logger.debug("text_toxic_detect input: %r", text)
    if not text or not text.strip():
        return None
    return text_model.detect(text)


with gr.Blocks() as app:
    gr.Markdown("Toxic Detection")
    with gr.Tab("Toxic Text Detector"):
        text_input_toxic = gr.Textbox()
        # Only the single top class is shown for text.
        text_output_toxic = gr.Label(num_top_classes=1)
        text_button_toxic = gr.Button("text_toxic")
    with gr.Tab("Toxic Image Detector"):
        image_input_toxic = gr.Image()
        # The top two classes are shown for images.
        image_output_toxic = gr.Label(num_top_classes=2)
        image_button_toxic = gr.Button("image_toxic")

    text_button_toxic.click(
        text_toxic_detect,
        inputs=text_input_toxic,
        outputs=text_output_toxic,
    )
    image_button_toxic.click(
        image_toxic_detect,
        inputs=image_input_toxic,
        outputs=image_output_toxic,
    )

if __name__ == "__main__":
    # Bind to all network interfaces (not just localhost) so the server is
    # reachable from outside the host/container it runs in.
    app.launch(server_name="0.0.0.0")