Spaces:
Configuration error
Configuration error
Update segmentation.py
Browse files- segmentation.py +22 -9
segmentation.py
CHANGED
@@ -4,6 +4,8 @@ from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation,
|
|
4 |
import numpy as np
|
5 |
import torch.nn as nn
|
6 |
from scipy.ndimage import binary_dilation
|
|
|
|
|
7 |
|
8 |
model = None
|
9 |
extractor = None
|
@@ -14,13 +16,9 @@ def init_body():
|
|
14 |
model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes").to("cuda")
|
15 |
|
16 |
|
17 |
-
def get_mask(img: Image, body_part_id: int, inverse=False
|
18 |
-
|
19 |
-
|
20 |
-
outputs = model_face(**inputs)
|
21 |
-
else:
|
22 |
-
inputs = extractor(images=img, return_tensors="pt").to("cuda")
|
23 |
-
outputs = model(**inputs)
|
24 |
logits = outputs.logits.cpu()
|
25 |
|
26 |
upsampled_logits = nn.functional.interpolate(
|
@@ -40,9 +38,9 @@ def get_mask(img: Image, body_part_id: int, inverse=False, face=False):
|
|
40 |
return pil_seg
|
41 |
|
42 |
|
43 |
-
def get_cropped(img: Image, body_part_id: int, inverse:bool
|
44 |
|
45 |
-
pil_seg = get_mask(img, body_part_id, inverse
|
46 |
crop_mask_np = np.array(pil_seg.convert('L'))
|
47 |
crop_mask_binary = crop_mask_np > 128
|
48 |
|
@@ -71,4 +69,19 @@ def get_blurred_mask(img: Image, body_part_id: int):
|
|
71 |
return dilated_mask_blurred
|
72 |
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
|
|
4 |
import numpy as np
|
5 |
import torch.nn as nn
|
6 |
from scipy.ndimage import binary_dilation
|
7 |
+
import cv2
|
8 |
+
|
9 |
|
10 |
model = None
|
11 |
extractor = None
|
|
|
16 |
model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes").to("cuda")
|
17 |
|
18 |
|
19 |
+
def get_mask(img: Image, body_part_id: int, inverse=False):
|
20 |
+
inputs = extractor(images=img, return_tensors="pt").to("cuda")
|
21 |
+
outputs = model(**inputs)
|
|
|
|
|
|
|
|
|
22 |
logits = outputs.logits.cpu()
|
23 |
|
24 |
upsampled_logits = nn.functional.interpolate(
|
|
|
38 |
return pil_seg
|
39 |
|
40 |
|
41 |
+
def get_cropped(img: Image, body_part_id: int, inverse:bool):
|
42 |
|
43 |
+
pil_seg = get_mask(img, body_part_id, inverse)
|
44 |
crop_mask_np = np.array(pil_seg.convert('L'))
|
45 |
crop_mask_binary = crop_mask_np > 128
|
46 |
|
|
|
69 |
return dilated_mask_blurred
|
70 |
|
71 |
|
72 |
+
def get_face_crop(pil_image : Image):
|
73 |
+
image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_BGR2RGB)
|
74 |
+
face_casc = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
|
75 |
+
|
76 |
+
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
77 |
+
faces = face_casc.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
|
78 |
+
x, y, w, h = faces[0]
|
79 |
+
|
80 |
+
cropped_face = np.ones_like(np.array(pil_image))*255
|
81 |
+
cropped_face[y:y+h, x:x+w] = image[y:y+h, x:x+w]
|
82 |
+
|
83 |
+
img = cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB)
|
84 |
+
return Image.fromarray(img.astype("uint8")).convert("RGB")
|
85 |
+
|
86 |
+
|
87 |
|