rimasalshehri commited on
Commit
2e6c004
1 Parent(s): 05002be

Upload AuraBloom.py

Browse files
Files changed (1) hide show
  1. AuraBloom.py +134 -0
AuraBloom.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import numpy as np
4
+ from joblib import load
5
+ from skimage.transform import resize
6
+ import torch
7
+ import os
8
+ import sys
9
+
10
+ # Ensure to run these commands in your terminal first:
11
+ # pip install git+https://github.com/FacePerceiver/facer.git@main
12
+ # pip install timm
13
+ # git clone https://github.com/FacePerceiver/facer.git
14
+
15
+ # Set the path for the 'facer' module
16
+ sys.path.append('facer')
17
+
18
+ import facer
19
+
20
+ # Load face parsing model
21
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
+ face_detector = facer.face_detector('retinaface/mobilenet', device=device)
23
+ face_parser = facer.face_parser('farl/lapa/448', device=device)
24
+
25
+ # Define the monk scale colors
26
+ monk_scale = {
27
+ 'Class2': (243, 231, 219), # f3e7db
28
+ 'Class3': (247, 234, 208), # f7ead0
29
+ 'Class4': (234, 218, 186), # eadaba
30
+ 'Class5': (215, 189, 150), # d7bd96
31
+ 'Class6': (160, 126, 86), # a07e56
32
+ 'Class7': (130, 92, 67), # 825c43
33
+ 'Class8': (96, 65, 52), # 604134
34
+ 'Class9': (58, 49, 42), # 3a312a
35
+ 'Class10': (41, 36, 32), # 292420
36
+ }
37
+
38
+ # Function to convert RGB tuple to hex color code
39
+ def rgb_to_hex(rgb):
40
+ return '#{:02x}{:02x}{:02x}'.format(*rgb)
41
+
42
+ # Mapping of Monk classes to colors using monk_scale
43
+ monk_colors = {
44
+ '1': [rgb_to_hex(monk_scale['Class2']), rgb_to_hex(monk_scale['Class3']), rgb_to_hex(monk_scale['Class4'])],
45
+ '2': [rgb_to_hex(monk_scale['Class5']), rgb_to_hex(monk_scale['Class6'])],
46
+ '3': [rgb_to_hex(monk_scale['Class7']), rgb_to_hex(monk_scale['Class8'])],
47
+ '4': [rgb_to_hex(monk_scale['Class9']), rgb_to_hex(monk_scale['Class10'])],
48
+ 'default': '#808080' # Default color for unexpected classes
49
+ }
50
+
51
+ # Mapping of model's output classes to monk classes
52
+ class_mapping = {
53
+ 0: '1', # Map model class 0 to monk class 1
54
+ 1: '2', # Map model class 1 to monk class 2
55
+ 2: '3', # Map model class 2 to monk class 3
56
+ 3: '4', # Map model class 3 to monk class 4
57
+ # Add more mappings if needed
58
+ }
59
+
60
+ # Function to load the model
61
+ def load_model():
62
+ model_path = r"C:\Users\ramam\svm_model3.joblib" # Adjust the path to your model
63
+ model = load(model_path)
64
+ return model
65
+
66
+ # Function to parse face and extract skin region
67
+ def parse_face(image):
68
+ # Ensure the image has 3 channels (RGB)
69
+ if image.mode != 'RGB':
70
+ image = image.convert('RGB')
71
+
72
+ image_data = np.array(image)
73
+
74
+ # Check if the image has 3 channels
75
+ if image_data.shape[2] != 3:
76
+ raise ValueError("Image does not have 3 channels (RGB).")
77
+
78
+ image_tensor = torch.from_numpy(image_data.astype('float32')).permute(2, 0, 1).unsqueeze(0).to(device)
79
+ faces = face_detector(image_tensor)
80
+
81
+ if faces:
82
+ parsed_faces = face_parser(image_tensor, faces)
83
+ if 'seg' in parsed_faces:
84
+ seg_logits = parsed_faces['seg']['logits']
85
+ seg_probs = torch.sigmoid(seg_logits)
86
+ binary_mask = seg_probs[0, 1, :, :] > 0.5
87
+ binary_mask = binary_mask.cpu().numpy()
88
+ binary_mask_3d = np.repeat(binary_mask[:, :, np.newaxis], 3, axis=2)
89
+ skin_region = image_data * binary_mask_3d
90
+ return skin_region.astype(np.uint8)
91
+ return None
92
+
93
+ # Function to make predictions
94
+ def classify_image(image, model):
95
+ parsed_image = parse_face(image)
96
+ if parsed_image is not None:
97
+ image_resized = resize(parsed_image, (128, 128), anti_aliasing=True) # Resize to 128x128
98
+ image_reshaped = image_resized.reshape(1, -1) # Reshape to match the model input
99
+ if image_reshaped.shape[1] == 49152: # Check if resizing is correct
100
+ image_padded = np.pad(image_reshaped, ((0, 0), (0, 65536 - 49152)), 'constant')
101
+ else:
102
+ raise ValueError("Unexpected number of features after reshaping.")
103
+ prediction = model.predict(image_padded)
104
+ return prediction[0], parsed_image
105
+ else:
106
+ raise ValueError("Face parsing failed.")
107
+
108
+ # Load the model
109
+ model = load_model()
110
+
111
+ # Function to display the Monk class color
112
+ def display_monk_class_color(prediction):
113
+ st.write(f"Prediction: {prediction}") # Debugging
114
+ monk_class = class_mapping.get(prediction, 'default')
115
+ colors = monk_colors.get(monk_class, monk_colors['default']) # Default to gray if class not found
116
+ st.write(f"Monk Class: {monk_class}")
117
+ for color in colors:
118
+ st.markdown(f"<div style='width:100px; height:50px; background-color:{color};'></div>", unsafe_allow_html=True)
119
+
120
+ # Streamlit app
121
+ st.title('Skin Tone Classification')
122
+
123
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
124
+ if uploaded_file is not None:
125
+ image = Image.open(uploaded_file)
126
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
127
+
128
+ if st.button('Classify'):
129
+ try:
130
+ prediction, parsed_image = classify_image(image, model)
131
+ display_monk_class_color(prediction)
132
+ st.image(parsed_image, caption='Parsed Image.', use_column_width=True)
133
+ except ValueError as e:
134
+ st.error(f"Error: {e}")