# Page header captured with this file: author mikecho,
# commit b1bd905 "Update app.py" (verified).
import streamlit as st
from transformers import pipeline
from PIL import Image
def extract_subimage(image, xmin, xmax, ymin, ymax):
    """Crop the rectangle bounded by the given pixel coordinates out of *image*.

    The arguments arrive as (xmin, xmax, ymin, ymax) and are reordered into
    the (left, upper, right, lower) box that ``image.crop`` expects.

    Returns:
        Whatever ``image.crop`` returns (a new PIL image for PIL inputs).
    """
    return image.crop((xmin, ymin, xmax, ymax))


def pipeline_1_final(image):
    """Detect people in one image and crop each detected person out.

    Args:
        image: The uploaded photo (a PIL image).

    Returns:
        ``(sub_images, person_count)``: one cropped image per detection
        labelled ``'person'``, and the number of such detections.
    """
    detector = pipeline("object-detection", model="hustvl/yolos-tiny")
    sub_image_lst = []
    for pred in detector(image):
        if pred["label"] != "person":
            continue
        box = pred["box"]
        # Index the box by key instead of relying on the dict's value order.
        sub_image_lst.append(
            extract_subimage(image, box["xmin"], box["xmax"], box["ymin"], box["ymax"])
        )
    return sub_image_lst, len(sub_image_lst)


def _top_labels(images, model):
    """Return the first-ranked label an image-classification model gives each image."""
    if not images:
        # Nothing to classify: skip loading (and possibly downloading) the model.
        return []
    classifier = pipeline("image-classification", model=model)
    return [classifier(img)[0]["label"] for img in images]


# Collapses the age classifier's labels into the report's age buckets
# ("0-2" and "3-9" both become "lower than 10").
_AGE_GROUPS = {
    "0-2": "lower than 10",
    "3-9": "lower than 10",
    "10-19": "10-19",
    "20-29": "20-29",
    "30-39": "30-39",
    "40-49": "40-49",
    "50-59": "50-59",
    "60-69": "60-69",
    "more than 70": "70 or above",
}


def pipeline_2_final(image_lst):
    """Predict an age bucket for each cropped person image.

    Labels missing from ``_AGE_GROUPS`` are passed through unchanged rather
    than raising KeyError.
    """
    labels = _top_labels(image_lst, "nateraw/vit-age-classifier")
    return [_AGE_GROUPS.get(label, label) for label in labels]


def pipeline_3_final(image_lst):
    """Predict a gender label for each cropped person image."""
    return _top_labels(image_lst, "mikecho/NTQAI_pedestrian_gender_recognition_v1")


def gender_prediciton_model_NTQAI_pedestrian_gender_recognition(image_lst):
    """Predict gender with the upstream NTQAI model (alternative to pipeline 3).

    Not called by ``main``; kept as an alternative gender predictor.
    """
    return _top_labels(image_lst, "NTQAI/pedestrian_gender_recognition")


def pipeline_4_final(image_lst):
    """Predict a facial-emotion label for each cropped person image."""
    return _top_labels(image_lst, "dima806/facial_emotions_image_detection")


def generate_gender_tables(gender_list, age_list, emotion_list,
                           happy_labels=("happy", "happiness")):
    """Tabulate head counts and happiness shares per age group and gender.

    The three lists are aligned per person; ``zip`` truncates to the
    shortest. Gender labels are lower-cased before use, so e.g. ``'Male'``
    and ``'male'`` count together. This assumes the gender model emits some
    casing of male/female (TODO confirm); any other label raises KeyError.

    Args:
        gender_list: Gender label per person.
        age_list: Age-bucket label per person.
        emotion_list: Emotion label per person.
        happy_labels: Emotion labels counted as "happy".

    Returns:
        ``(table1, table2)``, with rows in first-seen age order.
        - table1 rows: ``[age, male_count, female_count]``.
        - table2 rows: ``[age, male_pct, female_pct]``. Each value is the
          percentage of that gender in that age group whose emotion is in
          *happy_labels*, or 0 when the group has no one of that gender.
    """
    counts = {}
    happy = {}
    for gender, age, emotion in zip(gender_list, age_list, emotion_list):
        gender = gender.lower()
        counts.setdefault(age, {"male": 0, "female": 0})
        happy.setdefault(age, {"male": 0, "female": 0})
        counts[age][gender] += 1
        if emotion in happy_labels:
            happy[age][gender] += 1

    table1 = [[age, c["male"], c["female"]] for age, c in counts.items()]
    table2 = []
    for age, c in counts.items():
        h = happy[age]
        male_pct = h["male"] / c["male"] * 100 if c["male"] > 0 else 0
        female_pct = h["female"] / c["female"] * 100 if c["female"] > 0 else 0
        table2.append([age, male_pct, female_pct])
    return table1, table2


def main():
    """Render the Streamlit page and analyse an uploaded campaign photo.

    Detects the people in the photo and classifies each person's age bucket,
    gender and facial emotion. It then shows the person count, a gender/age
    distribution table, and a table of the share of happy people per group.
    """
    st.set_page_config(page_title="Unmasked the Target Customers", page_icon="🦜")
    st.header("Turn the photos taken in the campaign to useful marketing insights")
    uploaded_file = st.file_uploader("Select an Image...")
    if uploaded_file is None:
        return

    # Normalise to RGB so RGBA/palette uploads are not passed to the models as-is.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(uploaded_file, caption="Processing Image", use_column_width=True)

    people, person_count = pipeline_1_final(image)
    ages = pipeline_2_final(people)
    genders = pipeline_3_final(people)
    # Emotions come from the emotion model (pipeline 4), not the gender model.
    emotions = pipeline_4_final(people)
    table1, table2 = generate_gender_tables(genders, ages, emotions)

    # st.text takes a single body string, so values are formatted into it.
    st.text(f"The detected number of person: {person_count}")
    st.text("\nGender and Age Group Distribution")
    st.text("Age, Male, Female")
    for age, male, female in table1:
        st.text(f"{age}, {male}, {female}")
    st.text("\nShare of Happiness")
    st.text("Age, Male, Female")
    for age, male_pct, female_pct in table2:
        st.text(f"{age}, {male_pct:.1f}%, {female_pct:.1f}%")
# Build the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()