# nsfw_classifier / app.py
# Source: jayasuriyaK's Hugging Face Space, commit 664a6cd ("Upload 5 files"), 3.64 kB
# Run the app with:
#   python -m streamlit run d:/NSFW/Project/test1.py
import math

import keras_ocr
import streamlit as st
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# keras-ocr detection + recognition pipeline (downloads pretrained weights on first use).
pipeline = keras_ocr.pipeline.Pipeline()

# BERT tokenizer plus a locally fine-tuned sequence-classification head
# saved under the "CustomModel" directory.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model_2 = BertForSequenceClassification.from_pretrained("CustomModel")
model_2.to('cpu')  # inference is done on CPU throughout
def get_distance(predictions):
    """
    Convert keras-ocr predictions into a list of dicts, one per detection:
    * text : detected text in image
    * center_x : center of bounding box (x)
    * center_y : center of bounding box (y)
    * distance_from_origin : distance of the box center from the origin (0, 0)
    * distance_y : distance between center_y and the origin's y (0)
    """
    # Point of origin
    x0, y0 = 0, 0
    detections = []
    for text, box in predictions:
        # keras-ocr boxes list four corners clockwise from the top-left:
        # [top-left, top-right, bottom-right, bottom-left].
        # The original code read box[1] (the TOP-RIGHT corner) as "bottom_right",
        # which made center_y the y of the top edge instead of the true center.
        top_left_x, top_left_y = box[0]
        bottom_right_x, bottom_right_y = box[2]
        center_x = (top_left_x + bottom_right_x) / 2
        center_y = (top_left_y + bottom_right_y) / 2
        # Pythagorean distance of the box center from the origin
        distance_from_origin = math.dist([x0, y0], [center_x, center_y])
        # y-offset from the origin; used downstream to group detections into rows
        distance_y = center_y - y0
        detections.append({
            'text': text,
            'center_x': center_x,
            'center_y': center_y,
            'distance_from_origin': distance_from_origin,
            'distance_y': distance_y
        })
    return detections
def distinguish_rows(lst, thresh=15):
    """Group consecutive detections into rows of text.

    Yields lists of detection dicts whose consecutive 'distance_y' values
    differ by at most `thresh` pixels.  Assumes `lst` is ordered top-to-bottom
    (keras-ocr's usual output order) -- TODO confirm; input is not re-sorted here.

    Fixes over the original implementation:
    * the first element is no longer dropped when the first gap exceeds `thresh`
      (the old code yielded an empty row and seeded the next row with lst[1]);
    * a single-element list now yields that element instead of an empty row;
    * an empty list yields nothing instead of one empty row.
    """
    if not lst:
        return
    row = [lst[0]]
    for prev, cur in zip(lst, lst[1:]):
        if cur['distance_y'] - prev['distance_y'] <= thresh:
            row.append(cur)
        else:
            yield row
            row = [cur]
    yield row
# ----- Streamlit UI -----
st.title("Image Input App")

# File uploader widget
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    # Read the uploaded image and run OCR.
    read_image = keras_ocr.tools.read(uploaded_file)
    # prediction_groups is a list (one entry per input image) of (word, box) tuples.
    prediction_groups = pipeline.recognize([read_image])
    detections = get_distance(prediction_groups[0])
    # Set thresh higher for text further apart.
    rows = list(distinguish_rows(detections, thresh=10))
    # Remove any empty rows.
    rows = [row for row in rows if row]
    # Order detections into human-readable order: within each row,
    # left-to-right by distance from the origin.
    ordered_preds = []
    for row in rows:
        for det in sorted(row, key=lambda d: d['distance_from_origin']):
            ordered_preds.append(det['text'])
    # Join detections into one sentence and classify it.
    sentence = ' '.join(ordered_preds)
    print(sentence)
    inputs = tokenizer(sentence, padding=True, truncation=True,
                       return_tensors='pt').to('cpu')
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model_2(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    probs = probs.cpu().numpy()
    print(probs[0][0], probs[0][1])
    # NOTE(review): assumes label index 0 = safe, 1 = not safe —
    # confirm against CustomModel's id2label mapping.
    if probs[0][0] > probs[0][1]:
        print('safe')
        st.write('safe')
    else:
        print('Not safe')
        st.write('Not safe')  # fixed: previously displayed the garbled string 'n safe'