nlightcho commited on
Commit
f43b2e5
1 Parent(s): 8b1b908
Files changed (1) hide show
  1. app.py +72 -1
app.py CHANGED
@@ -1,6 +1,77 @@
1
  import os
2
  import gradio as gr
 
 
 
 
 
3
 
4
  hf_token = os.environ.get("HF_TOKEN", None)
 
5
 
6
- gr.load("models/neuratech-ai/person_segmentation_v3", hf_token=hf_token).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision.transforms.functional import to_tensor, normalize
7
+ from transformers import SegformerForSemanticSegmentation
8
 
9
  hf_token = os.environ.get("HF_TOKEN", None)
10
+ device = torch.device("cpu")
11
 
12
+ label2id = {"background": 0, "skin": 1, "hair": 2, "clothes": 3, "accessories": 4}
13
+ id2label = {v: k for k, v in label2id.items()}
14
+ colors = {
15
+ "background": (40, 40, 40),
16
+ "skin": (255, 178, 127),
17
+ "hair": (139, 69, 19),
18
+ "clothes": (100, 149, 237),
19
+ "accessories": (255, 215, 0),
20
+ }
21
+
22
+ model = SegformerForSemanticSegmentation.from_pretrained(
23
+ "neuratech-ai/person_segmentation_v3",
24
+ token=hf_token,
25
+ ignore_mismatched_sizes=True,
26
+ num_labels=len(label2id),
27
+ id2label=id2label,
28
+ label2id=label2id,
29
+ )
30
+ model.eval()
31
+ model.to(device)
32
+
33
+
34
+ def preds_to_rgb(preds):
35
+ preds_rgb = np.zeros((preds.shape[0], preds.shape[1], 3), dtype=np.uint8)
36
+
37
+ for class_name, class_id in label2id.items():
38
+ preds_rgb[preds == class_id] = colors[class_name]
39
+
40
+ return preds_rgb
41
+
42
+
43
+ def query_image(img):
44
+ img = Image.fromarray(img)
45
+ scale = 1024 / min(img.size)
46
+ img = img.resize(
47
+ (int(img.size[0] * scale), int(img.size[1] * scale)), Image.LANCZOS
48
+ )
49
+
50
+ img = normalize(
51
+ to_tensor(img),
52
+ mean=(0.485, 0.456, 0.406),
53
+ std=(0.229, 0.224, 0.225),
54
+ )
55
+
56
+ with torch.no_grad():
57
+ outputs = model(img.unsqueeze(0))
58
+
59
+ preds = outputs.logits.cpu()
60
+ w, h = preds.shape[-2:]
61
+ preds = torch.nn.functional.interpolate(
62
+ preds, size=(w * 4, h * 4), mode="bilinear", align_corners=False
63
+ )
64
+ results = torch.argmax(preds, dim=1).numpy()[0]
65
+ results = preds_to_rgb(results)
66
+ return Image.fromarray(results)
67
+
68
+
69
+ demo = gr.Interface(
70
+ query_image,
71
+ inputs=[gr.Image()],
72
+ outputs="image",
73
+ title="neuratech-ai person segmentation v3",
74
+ examples=[[]],
75
+ )
76
+
77
+ demo.launch()