nsfw_classifier / app.py
jayasuriyaK's picture
Update app.py
df33048 verified
raw
history blame contribute delete
No virus
3.84 kB
# Run the app with Streamlit, e.g.:
#   python -m streamlit run d:/NSFW/Project/test1.py
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import math, keras_ocr
# Module-level state shared by the rest of the script.
# The OCR pipeline starts out as None and is created on demand by
# initialize(), which assigns to this global.
pipeline = None
# Location passed to from_pretrained() for the sequence-classification model.
model_path="CustomModel"
# Tokenizer comes from the stock 'bert-base-uncased' checkpoint, not from
# model_path.
# NOTE(review): presumably CustomModel was fine-tuned from bert-base-uncased so
# the vocabularies match -- confirm.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# Classifier whose logits are compared at index 0 vs index 1 further below.
model = AutoModelForSequenceClassification.from_pretrained(model_path)
import streamlit as st
def get_distance(predictions):
    """
    Compute positional metadata for each OCR detection.

    ``predictions`` is an iterable of (text, box) pairs where ``box`` is a
    non-empty sequence of (x, y) corner points. The center is the mean of all
    corner points, so both layouts below give the true center of the box:

    * two points  : (top-left, bottom-right)
    * four points : (top-left, top-right, bottom-right, bottom-left)

    Function returns a list of dictionaries with (key,value):
    * text : detected text in image
    * center_x : center of bounding box (x)
    * center_y : center of bounding box (y)
    * distance_from_origin : hypotenuse
    * distance_y : distance between y and origin (0,0)
    """
    # Point of origin
    x0, y0 = 0, 0
    # Generate dictionary
    detections = []
    for group in predictions:
        text, box = group[0], group[1]
        # Average every corner instead of assuming box[1] is the bottom-right
        # corner: in a four-corner box index 1 is the top-right corner, which
        # would pin center_y to the top edge of the word.
        xs = [point[0] for point in box]
        ys = [point[1] for point in box]
        center_x, center_y = sum(xs) / len(xs), sum(ys) / len(ys)
        # Use the Pythagorean Theorem to solve for distance from origin
        distance_from_origin = math.dist([x0, y0], [center_x, center_y])
        # Calculate difference between y and origin to get unique rows
        distance_y = center_y - y0
        # Append all results
        detections.append({
            'text': text,
            'center_x': center_x,
            'center_y': center_y,
            'distance_from_origin': distance_from_origin,
            'distance_y': distance_y,
        })
    return detections
def distinguish_rows(lst, thresh=10):
    """Group consecutive detections into rows of text.

    Walks ``lst`` in order and starts a new row whenever the ``distance_y`` of
    an item exceeds the ``distance_y`` of the previous item by more than
    ``thresh``. The input is expected to already be ordered top-to-bottom,
    since only neighbouring items are compared.

    Args:
        lst: iterable of dicts, each with a numeric ``'distance_y'`` key.
        thresh: largest vertical gap between neighbours that still counts as
            the same row.

    Yields:
        Non-empty lists of detections, one list per row. Nothing is yielded
        for empty input.
    """
    row = []
    for item in lst:
        # A gap larger than thresh closes the current row. Seeding the row
        # with the very first item (instead of only adding it when it has a
        # close neighbour) keeps single-word rows and single-item inputs.
        if row and item['distance_y'] - row[-1]['distance_y'] > thresh:
            yield row
            row = []
        row.append(item)
    if row:
        yield row
# Page title shown at the top of the Streamlit app.
st.title("NSFW Content Detector")
# File uploader widget restricted to jpg/png/jpeg files. The result is checked
# against None further down before any processing happens.
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
def initialize():
    """Create the keras-ocr pipeline on first call and store it in the
    module-level ``pipeline``; later calls leave the existing pipeline as is.
    """
    global pipeline
    # Compare against None by identity rather than equality.
    if pipeline is None:
        pipeline = keras_ocr.pipeline.Pipeline()
# Main flow: once an image is uploaded, OCR it, rebuild the text in reading
# order, and classify that text with the sequence-classification model.
if uploaded_file is not None:
    st.image(uploaded_file, caption='Uploaded Image', width=200)
    initialize()
    # Read in image
    read_image = keras_ocr.tools.read(uploaded_file)
    # prediction_groups holds one list of (word, box) tuples per input image
    prediction_groups = pipeline.recognize([read_image])
    detections = get_distance(prediction_groups[0])
    # distinguish_rows only compares neighbouring items, so put the
    # detections in top-to-bottom order before grouping them into rows.
    detections = sorted(detections, key=lambda det: det['distance_y'])
    # Set thresh higher for text further apart; drop any empty rows
    rows = [row for row in distinguish_rows(detections, thresh=10) if row]
    # Order text detections in human readable format (left-to-right per row)
    ordered_preds = []
    for row in rows:
        ordered = sorted(row, key=lambda det: det['distance_from_origin'])
        ordered_preds.extend(det['text'] for det in ordered)
    # Join detections into sentence
    input_text = ' '.join(ordered_preds)
    print(input_text)
    if not input_text:
        # Nothing was recognised, so there is no text to classify.
        print('no text')
        st.write('No text detected in the image')
    else:
        # Truncate so long OCR output cannot exceed the model's maximum
        # sequence length.
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        probs = outputs.logits.softmax(dim=-1)
        print(probs[0][0], probs[0][1])
        # Index 0 is treated as the "safe" class, index 1 as "not safe".
        if probs[0][0] > probs[0][1]:
            print('safe')
            st.write('Safe for Work')
        else:
            print('Not safe')
            st.write('Not Safe for Work')