# test / AuraBloom.py
# Uploaded by rimasalshehri ("Upload AuraBloom.py"), commit 2e6c004 (verified), 5.06 kB; virus scan: clean.
import streamlit as st
from PIL import Image
import numpy as np
from joblib import load
from skimage.transform import resize
import torch
import os
import sys
# Prerequisites (run these once in a terminal before starting the app):
#   pip install git+https://github.com/FacePerceiver/facer.git@main
#   pip install timm
#   git clone https://github.com/FacePerceiver/facer.git
# Also make a local clone of 'facer' (relative to the current working
# directory) importable.
sys.path.append('facer')
import facer

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_face_models():
    """Build the facer face detector and face parser on ``device``.

    Returns:
        A ``(face_detector, face_parser)`` tuple.
    """
    detector = facer.face_detector('retinaface/mobilenet', device=device)
    parser = facer.face_parser('farl/lapa/448', device=device)
    return detector, parser


# Streamlit re-executes this whole script on every widget interaction, so
# without caching both networks would be rebuilt on each click. Cache them
# with st.cache_resource when the installed Streamlit version provides it.
if hasattr(st, 'cache_resource'):
    _load_face_models = st.cache_resource(_load_face_models)
face_detector, face_parser = _load_face_models()
# Monk scale reference colours as (R, G, B) tuples, keyed 'Class2'..'Class10'
# (there is no 'Class1' entry). Each tuple is decoded from the hex code listed
# below, so the hex strings are the single source of truth.
monk_scale = {
    f'Class{level}': tuple(int(code[pos:pos + 2], 16) for pos in (0, 2, 4))
    for level, code in enumerate(
        (
            'f3e7db',  # Class2
            'f7ead0',  # Class3
            'eadaba',  # Class4
            'd7bd96',  # Class5
            'a07e56',  # Class6
            '825c43',  # Class7
            '604134',  # Class8
            '3a312a',  # Class9
            '292420',  # Class10
        ),
        start=2,
    )
}
def rgb_to_hex(rgb):
    """Format the first three channel values of ``rgb`` as ``#rrggbb``.

    Each channel is rendered as two zero-padded lowercase hex digits; any
    items beyond the third are ignored.
    """
    channels = list(rgb)
    red, green, blue = channels[0], channels[1], channels[2]
    return f'#{red:02x}{green:02x}{blue:02x}'
# Colour swatches (hex strings) shown for each Monk group, derived from
# monk_scale. Every value is a list so callers can always iterate it; the
# 'default' entry is a single grey swatch for unexpected classes. (A bare
# string here would be iterated character by character.)
monk_colors = {
    '1': [rgb_to_hex(monk_scale[name]) for name in ('Class2', 'Class3', 'Class4')],
    '2': [rgb_to_hex(monk_scale[name]) for name in ('Class5', 'Class6')],
    '3': [rgb_to_hex(monk_scale[name]) for name in ('Class7', 'Class8')],
    '4': [rgb_to_hex(monk_scale[name]) for name in ('Class9', 'Class10')],
    'default': ['#808080'],  # grey, for unexpected classes
}
# Translate the classifier's integer outputs 0-3 into the Monk group keys
# '1'-'4' used by monk_colors. Widen the range if the model gains more
# output classes.
class_mapping = {model_class: str(model_class + 1) for model_class in range(4)}
def load_model(model_path=r"C:\Users\ramam\svm_model3.joblib"):
    """Load the pre-trained skin-tone classifier from a joblib file.

    Args:
        model_path: Path to the serialized model. The default is the original
            author's local path; pass your own path when running elsewhere.

    Returns:
        The deserialized model object (this app calls its ``predict`` method).

    Note:
        joblib deserialization unpickles the file, which can execute arbitrary
        code. Only load model files from a trusted source.
    """
    return load(model_path)
def _has_detections(faces):
    """Return True if a facer detector result contains at least one face."""
    # NOTE(review): the facer detector result is assumed to be a dict whose
    # 'rects' entry has one row per detected face. A non-empty dict is truthy
    # even with zero detections, so inspect 'rects' when it is present.
    # Confirm against the installed facer version.
    if not faces:
        return False
    rects = faces.get('rects') if isinstance(faces, dict) else None
    return rects is None or len(rects) > 0


def parse_face(image):
    """Segment the first detected face in ``image`` and keep only its skin.

    The image is converted to RGB and run through the module-level
    ``face_detector`` and ``face_parser`` (on ``device``). Every pixel outside
    the parser's class-1 region of the first face is set to zero.

    Args:
        image: A PIL image.

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)`` with non-skin pixels zeroed.
        Returns ``None`` if no face was detected or the parser returned no
        segmentation.

    Raises:
        ValueError: If the converted image does not have 3 channels.
    """
    # Ensure the image has 3 channels (RGB).
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image_data = np.array(image)
    if image_data.ndim != 3 or image_data.shape[2] != 3:
        raise ValueError("Image does not have 3 channels (RGB).")

    # Batched, channels-first tensor: 1 x 3 x H x W.
    image_tensor = (
        torch.from_numpy(image_data.astype('float32'))
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        faces = face_detector(image_tensor)
        if not _has_detections(faces):
            return None
        parsed_faces = face_parser(image_tensor, faces)
        if 'seg' not in parsed_faces:
            return None
        seg_logits = parsed_faces['seg']['logits']
        # The logits score competing classes per pixel, so normalise across the
        # class axis with softmax (facer's documented usage). Applying sigmoid
        # independently only tests logit > 0, which can include hair and
        # background pixels.
        seg_probs = seg_logits.softmax(dim=1)
        # First detected face, class index 1. That is presumably the 'face'
        # (skin) label of the LaPa label set; confirm via
        # parsed_faces['seg']['label_names'].
        binary_mask = (seg_probs[0, 1] > 0.5).cpu().numpy()

    # Broadcasting the H x W mask over the channel axis zeroes non-skin pixels.
    skin_region = image_data * binary_mask[:, :, np.newaxis]
    return skin_region.astype(np.uint8)
def classify_image(image, model):
    """Predict the skin-tone class of the face in ``image``.

    Steps:
      1. Take the masked skin image from ``parse_face``.
      2. Resize it to 128x128 (three channels kept) and flatten it into one
         feature row.
      3. Zero-pad the row on the right up to the width the model expects.
      4. Pass the row to ``model.predict``.

    Args:
        image: A PIL image.
        model: A fitted estimator with a ``predict`` method. If it exposes
            ``n_features_in_`` (scikit-learn convention), that width is used.
            Otherwise the original hard-coded 65536 is assumed.

    Returns:
        A ``(prediction, parsed_image)`` tuple: the model's first prediction
        and the masked skin image returned by ``parse_face``.

    Raises:
        ValueError: If face parsing failed or the features do not fit the
            model's expected width.
    """
    parsed_image = parse_face(image)
    if parsed_image is None:
        raise ValueError("Face parsing failed.")

    image_resized = resize(parsed_image, (128, 128), anti_aliasing=True)
    features = image_resized.reshape(1, -1)  # expect 128 * 128 * 3 = 49152 values

    expected = int(getattr(model, 'n_features_in_', 65536))
    if features.shape[1] != 49152 or features.shape[1] > expected:
        raise ValueError("Unexpected number of features after reshaping.")
    # NOTE(review): 65536 == 128 * 128 * 4, which suggests the default model
    # was trained on 4-channel images. Appending zeros after the RGB values
    # does not reproduce an interleaved RGBA layout; confirm against the
    # training pipeline.
    features = np.pad(features, ((0, 0), (0, expected - features.shape[1])), 'constant')

    prediction = model.predict(features)
    return prediction[0], parsed_image
# Load the model. Streamlit re-runs this script on every interaction, so cache
# the loaded model with st.cache_resource when the installed Streamlit version
# provides it. Otherwise the joblib file would be deserialized on each rerun.
if hasattr(st, 'cache_resource'):
    model = st.cache_resource(load_model)()
else:
    model = load_model()
def display_monk_class_color(prediction):
    """Render the predicted Monk class and its colour swatches in Streamlit.

    Args:
        prediction: A raw model output, looked up in ``class_mapping``.
            Values without a mapping fall back to the 'default' entry of
            ``monk_colors``.
    """
    st.write(f"Prediction: {prediction}")  # Debugging
    monk_class = class_mapping.get(prediction, 'default')
    colors = monk_colors.get(monk_class, monk_colors['default'])  # Default to gray if class not found
    # A bare colour string would otherwise be iterated character by character,
    # drawing one invalid swatch per character.
    if isinstance(colors, str):
        colors = [colors]
    st.write(f"Monk Class: {monk_class}")
    for color in colors:
        st.markdown(
            f"<div style='width:100px; height:50px; background-color:{color};'></div>",
            unsafe_allow_html=True,
        )
# Streamlit app
st.title('Skin Tone Classification')
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Force decoding now, so truncated or corrupt files fail here rather
        # than later inside st.image.
        image.load()
    except OSError as e:
        # PIL's UnidentifiedImageError is an OSError subclass.
        st.error(f"Error: could not read the uploaded image ({e}).")
    else:
        st.image(image, caption='Uploaded Image.', use_column_width=True)
        if st.button('Classify'):
            try:
                prediction, parsed_image = classify_image(image, model)
            except ValueError as e:
                st.error(f"Error: {e}")
            else:
                display_monk_class_color(prediction)
                st.image(parsed_image, caption='Parsed Image.', use_column_width=True)