dikarel commited on
Commit
04c2d74
0 Parent(s):

first working prototype

Browse files
Files changed (5) hide show
  1. .gitignore +3 -0
  2. app.py +52 -0
  3. lib/cloth_seg.py +30 -0
  4. lib/find_people.py +44 -0
  5. lib/redraw_image.py +22 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv
2
+ __pycache__
3
+ .DS_Store
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from random import choice
3
+ from lib.redraw_image import redraw_image
4
+ from lib.find_people import find_people
5
+ from PIL import Image
6
+ from PIL.Image import Image as PILImage
7
+
8
+ OUTFIT_SELECTION = [
9
+ "Summer dress",
10
+ "Winter coat",
11
+ "Fall jacket",
12
+ "Formal wear",
13
+ ]
14
+
15
+
16
+ def main():
17
+ with gr.Blocks() as demo:
18
+ with gr.Row():
19
+ with gr.Column():
20
+ img_input = gr.Image(label="Image of yourself")
21
+ drp_outfit = gr.Dropdown(
22
+ label="Select a new outfit",
23
+ choices=OUTFIT_SELECTION,
24
+ value=choice(OUTFIT_SELECTION),
25
+ )
26
+
27
+ with gr.Column():
28
+ btn_change = gr.Button(value="Change outfit")
29
+ img_output = gr.Image(label="Image of you wearing a dress")
30
+
31
+ btn_change.click(
32
+ generate_output, inputs=[img_input, drp_outfit], outputs=[img_output]
33
+ )
34
+
35
+ demo.queue().launch()
36
+
37
+
38
+ def generate_output(img_input: PILImage, drp_outfit: str) -> PILImage:
39
+ img_input = Image.fromarray(img_input)
40
+
41
+ people_mask = find_people(img_input)
42
+ img_output = redraw_image(
43
+ prompt=f"person wearing {drp_outfit}",
44
+ image=img_input,
45
+ mask=people_mask,
46
+ )
47
+
48
+ return img_output
49
+
50
+
51
+ if __name__ == "__main__":
52
+ main()
lib/cloth_seg.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import IntEnum
2
+
3
+
4
+ class ClothSeg(IntEnum):
5
+ BACKGROUND = 0
6
+ HAT = 1
7
+ HAIR = 2
8
+ SUNGLASSES = 3
9
+ UPPER_CLOTHES = 4
10
+ SKIRT = 5
11
+ PANTS = 6
12
+ DRESS = 7
13
+ BELT = 8
14
+ LEFT_SHOE = 9
15
+ RIGHT_SHOE = 10
16
+ FACE = 11
17
+ LEFT_LEG = 12
18
+ RIGHT_LEG = 13
19
+ LEFT_ARM = 14
20
+ RIGHT_ARM = 15
21
+ BAG = 16
22
+ SCARF = 17
23
+
24
+
25
+ def everyhing_but_background_face_and_hair() -> list[ClothSeg]:
26
+ return [
27
+ t
28
+ for t in ClothSeg
29
+ if t not in [ClothSeg.BACKGROUND, ClothSeg.HAIR, ClothSeg.FACE]
30
+ ]
lib/find_people.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
2
+ from PIL.Image import Image as PILImage
3
+ from PIL import Image
4
+ from lib.cloth_seg import everyhing_but_background_face_and_hair
5
+ from torch import zeros_like
6
+ from torch.nn.functional import interpolate
7
+ from functools import cache
8
+
9
+
10
+ def find_people(image: PILImage) -> PILImage:
11
+ processor = get_processor()
12
+ model = get_model()
13
+
14
+ inputs = processor(images=image, return_tensors="pt").to("cuda")
15
+ logits = model(**inputs).logits.cpu()
16
+
17
+ upsampled_logits = interpolate(
18
+ logits,
19
+ size=image.size[::-1],
20
+ mode="bilinear",
21
+ align_corners=False,
22
+ )
23
+
24
+ predictions = upsampled_logits.argmax(dim=1)[0]
25
+ mask = zeros_like(predictions)
26
+
27
+ for type in everyhing_but_background_face_and_hair():
28
+ mask += (predictions == type.value).long()
29
+
30
+ return Image.fromarray((mask * 255).byte().numpy(), "L")
31
+
32
+
33
+ @cache
34
+ def get_processor():
35
+ return SegformerImageProcessor.from_pretrained(
36
+ "mattmdjaga/segformer_b2_clothes", device="cuda"
37
+ )
38
+
39
+
40
+ @cache
41
+ def get_model():
42
+ return AutoModelForSemanticSegmentation.from_pretrained(
43
+ "mattmdjaga/segformer_b2_clothes"
44
+ ).to("cuda")
lib/redraw_image.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL.Image import Image as PILImage
2
+ from functools import cache
3
+ from diffusers import StableDiffusionInpaintPipeline
4
+
5
+
6
+ def redraw_image(prompt: str, image: PILImage, mask: PILImage) -> PILImage:
7
+ inpaint_model = get_inpaint_model()
8
+
9
+ return inpaint_model(
10
+ prompt=prompt,
11
+ image=image,
12
+ mask_image=mask,
13
+ width=image.width,
14
+ height=image.height,
15
+ ).images[0]
16
+
17
+
18
+ @cache
19
+ def get_inpaint_model():
20
+ return StableDiffusionInpaintPipeline.from_pretrained(
21
+ "runwayml/stable-diffusion-inpainting"
22
+ ).to("cuda")