altayavci commited on
Commit
3500945
·
1 Parent(s): 9f6cca4

Update segmentation.py

Browse files
Files changed (1) hide show
  1. segmentation.py +22 -9
segmentation.py CHANGED
@@ -4,6 +4,8 @@ from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation,
4
  import numpy as np
5
  import torch.nn as nn
6
  from scipy.ndimage import binary_dilation
 
 
7
 
8
  model = None
9
  extractor = None
@@ -14,13 +16,9 @@ def init_body():
14
  model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes").to("cuda")
15
 
16
 
17
- def get_mask(img: Image, body_part_id: int, inverse=False, face=False):
18
- if face:
19
- inputs = extractor_face(images=img, return_tensors="pt").to("cuda")
20
- outputs = model_face(**inputs)
21
- else:
22
- inputs = extractor(images=img, return_tensors="pt").to("cuda")
23
- outputs = model(**inputs)
24
  logits = outputs.logits.cpu()
25
 
26
  upsampled_logits = nn.functional.interpolate(
@@ -40,9 +38,9 @@ def get_mask(img: Image, body_part_id: int, inverse=False, face=False):
40
  return pil_seg
41
 
42
 
43
- def get_cropped(img: Image, body_part_id: int, inverse:bool, face:bool):
44
 
45
- pil_seg = get_mask(img, body_part_id, inverse, face)
46
  crop_mask_np = np.array(pil_seg.convert('L'))
47
  crop_mask_binary = crop_mask_np > 128
48
 
@@ -71,4 +69,19 @@ def get_blurred_mask(img: Image, body_part_id: int):
71
  return dilated_mask_blurred
72
 
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
 
4
  import numpy as np
5
  import torch.nn as nn
6
  from scipy.ndimage import binary_dilation
7
+ import cv2
8
+
9
 
10
  model = None
11
  extractor = None
 
16
  model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes").to("cuda")
17
 
18
 
19
+ def get_mask(img: Image, body_part_id: int, inverse=False):
20
+ inputs = extractor(images=img, return_tensors="pt").to("cuda")
21
+ outputs = model(**inputs)
 
 
 
 
22
  logits = outputs.logits.cpu()
23
 
24
  upsampled_logits = nn.functional.interpolate(
 
38
  return pil_seg
39
 
40
 
41
+ def get_cropped(img: Image, body_part_id: int, inverse:bool):
42
 
43
+ pil_seg = get_mask(img, body_part_id, inverse)
44
  crop_mask_np = np.array(pil_seg.convert('L'))
45
  crop_mask_binary = crop_mask_np > 128
46
 
 
69
  return dilated_mask_blurred
70
 
71
 
72
+ def get_face_crop(pil_image : Image):
73
+ image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_BGR2RGB)
74
+ face_casc = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
75
+
76
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
77
+ faces = face_casc.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
78
+ x, y, w, h = faces[0]
79
+
80
+ cropped_face = np.ones_like(np.array(pil_image))*255
81
+ cropped_face[y:y+h, x:x+w] = image[y:y+h, x:x+w]
82
+
83
+ img = cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB)
84
+ return Image.fromarray(img.astype("uint8")).convert("RGB")
85
+
86
+
87