"""Streamlit app that classifies the skin tone of an uploaded face photo.

The uploaded image is run through a face detector and a face parser, the
skin region is masked out, and the masked image is fed to a pre-trained
classifier whose output class is shown as Monk-scale color swatches.

Setup (run these commands in your terminal first):

    pip install git+https://github.com/FacePerceiver/facer.git@main
    pip install timm
    git clone https://github.com/FacePerceiver/facer.git
"""
import os
import sys
from typing import Optional

import numpy as np
import streamlit as st
import torch
from joblib import load
from PIL import Image
from skimage.transform import resize

# Set the path for the 'facer' module (the local clone made during setup).
sys.path.append('facer')
import facer  # noqa: E402  (must come after the sys.path change above)

# Load face parsing model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/lapa/448', device=device)

# Fallback location of the trained classifier; override it with the
# environment variable below or by passing a path to load_model().
DEFAULT_MODEL_PATH = r"C:\Users\ramam\svm_model3.joblib"
MODEL_PATH_ENV_VAR = "SKIN_TONE_MODEL_PATH"

# Geometry of the feature vector handed to the classifier.
RESIZE_SHAPE = (128, 128)
RESIZED_FEATURES = 128 * 128 * 3  # 49152 values after flattening
MODEL_FEATURES = 65536            # shorter vectors are zero-padded to this

# Channel of the parser output that this app treats as the skin mask, and
# the probability above which a pixel is kept.
SKIN_CLASS_INDEX = 1
SKIN_THRESHOLD = 0.5

# Define the monk scale colors
monk_scale = {
    'Class2': (243, 231, 219),   # f3e7db
    'Class3': (247, 234, 208),   # f7ead0
    'Class4': (234, 218, 186),   # eadaba
    'Class5': (215, 189, 150),   # d7bd96
    'Class6': (160, 126, 86),    # a07e56
    'Class7': (130, 92, 67),     # 825c43
    'Class8': (96, 65, 52),      # 604134
    'Class9': (58, 49, 42),      # 3a312a
    'Class10': (41, 36, 32),     # 292420
}


def rgb_to_hex(rgb):
    """Convert an ``(r, g, b)`` tuple of 0-255 ints to a ``#rrggbb`` string."""
    return '#{:02x}{:02x}{:02x}'.format(*rgb)


# Mapping of Monk classes to colors using monk_scale
monk_colors = {
    '1': [rgb_to_hex(monk_scale['Class2']),
          rgb_to_hex(monk_scale['Class3']),
          rgb_to_hex(monk_scale['Class4'])],
    '2': [rgb_to_hex(monk_scale['Class5']),
          rgb_to_hex(monk_scale['Class6'])],
    '3': [rgb_to_hex(monk_scale['Class7']),
          rgb_to_hex(monk_scale['Class8'])],
    '4': [rgb_to_hex(monk_scale['Class9']),
          rgb_to_hex(monk_scale['Class10'])],
    'default': '#808080',  # Default color for unexpected classes
}

# Mapping of model's output classes to monk classes
class_mapping = {
    0: '1',  # Map model class 0 to monk class 1
    1: '2',  # Map model class 1 to monk class 2
    2: '3',  # Map model class 2 to monk class 3
    3: '4',  # Map model class 3 to monk class 4
    # Add more mappings if needed
}


def load_model(model_path: Optional[str] = None):
    """Load the trained classifier from disk.

    Args:
        model_path: Path to the joblib file. When omitted, the value of the
            ``SKIN_TONE_MODEL_PATH`` environment variable is used, falling
            back to ``DEFAULT_MODEL_PATH``.

    Returns:
        The object stored in the joblib file.

    Any exception raised by ``load`` propagates to the caller.
    """
    if model_path is None:
        model_path = os.environ.get(MODEL_PATH_ENV_VAR, DEFAULT_MODEL_PATH)
    # NOTE: joblib files are pickles; only point this at files you trust.
    return load(model_path)


def parse_face(image):
    """Mask an image down to the skin region of the first parsed face.

    Args:
        image: A PIL image in any mode; it is converted to RGB if needed.

    Returns:
        A ``uint8`` array with the same height/width as the image in which
        pixels outside the skin mask are zero, or ``None`` when the detector
        result is empty or the parser output has no ``'seg'`` entry.
    """
    # Ensure the image has 3 channels (RGB); after this the array is always
    # (height, width, 3), so no separate channel check is needed.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image_data = np.array(image)

    # Channels-first float tensor with a leading batch axis.
    image_tensor = (
        torch.from_numpy(image_data.astype('float32'))
        .permute(2, 0, 1)
        .unsqueeze(0)
        .to(device)
    )

    # Pure inference: no autograd bookkeeping is needed for these calls.
    with torch.inference_mode():
        faces = face_detector(image_tensor)
        if not faces:
            return None
        parsed_faces = face_parser(image_tensor, faces)
        if 'seg' not in parsed_faces:
            return None
        seg_probs = torch.sigmoid(parsed_faces['seg']['logits'])
        # Only the first face (index 0) is used.
        binary_mask = seg_probs[0, SKIN_CLASS_INDEX, :, :] > SKIN_THRESHOLD

    binary_mask = binary_mask.cpu().numpy()
    binary_mask_3d = np.repeat(binary_mask[:, :, np.newaxis], 3, axis=2)
    skin_region = image_data * binary_mask_3d
    return skin_region.astype(np.uint8)


def classify_image(image, model):
    """Predict the skin-tone class of the face in ``image``.

    Args:
        image: A PIL image containing a face.
        model: Object exposing ``predict`` that accepts a
            ``(1, MODEL_FEATURES)`` array.

    Returns:
        A ``(prediction, parsed_image)`` tuple: the first element of the
        model's prediction and the skin-masked image from ``parse_face``.

    Raises:
        ValueError: If face parsing fails or the flattened image does not
            have the expected number of features.
    """
    parsed_image = parse_face(image)
    if parsed_image is None:
        raise ValueError("Face parsing failed.")

    # Resize to 128x128 and flatten to a single row to match the model input.
    image_resized = resize(parsed_image, RESIZE_SHAPE, anti_aliasing=True)
    image_reshaped = image_resized.reshape(1, -1)
    if image_reshaped.shape[1] != RESIZED_FEATURES:
        raise ValueError("Unexpected number of features after reshaping.")

    # Zero-pad the row up to the feature count the model is given.
    image_padded = np.pad(
        image_reshaped,
        ((0, 0), (0, MODEL_FEATURES - RESIZED_FEATURES)),
        'constant',
    )
    prediction = model.predict(image_padded)
    return prediction[0], parsed_image


# Load the model
model = load_model()


def display_monk_class_color(prediction):
    """Show the Monk class for ``prediction`` and one swatch per color.

    Args:
        prediction: Model output class; looked up in ``class_mapping``.
            Unmapped values fall back to the ``'default'`` entry.
    """
    st.write(f"Prediction: {prediction}")  # Debugging
    monk_class = class_mapping.get(prediction, 'default')
    # Default to gray if class not found
    colors = monk_colors.get(monk_class, monk_colors['default'])
    # The default entry is a single hex string, not a list; wrap it so the
    # loop below does not iterate over its individual characters.
    if isinstance(colors, str):
        colors = [colors]
    st.write(f"Monk Class: {monk_class}")
    for color in colors:
        # NOTE(review): the original swatch markup could not be recovered;
        # this div is a minimal reconstruction - adjust styling as needed.
        st.markdown(
            f"<div style='background-color: {color}; "
            f"width: 100px; height: 100px;'></div>",
            unsafe_allow_html=True,
        )


# Streamlit app
st.title('Skin Tone Classification')

uploaded_file = st.file_uploader(
    "Choose an image...", type=["jpg", "jpeg", "png"]
)

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image.', use_column_width=True)

    if st.button('Classify'):
        try:
            prediction, parsed_image = classify_image(image, model)
            display_monk_class_color(prediction)
            st.image(
                parsed_image, caption='Parsed Image.', use_column_width=True
            )
        except ValueError as e:
            st.error(f"Error: {e}")