"""Gradio demo for the neuratech-ai/person_segmentation_v3 SegFormer model.

An uploaded image is resized, normalised and segmented into five classes
(background, skin, hair, clothes, accessories); the result is returned as a
colour-coded mask image.
"""

import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import normalize, to_tensor
from transformers import SegformerForSemanticSegmentation

# Optional Hugging Face access token read from the environment (None if unset);
# forwarded to from_pretrained when downloading the checkpoint.
hf_token = os.environ.get("HF_TOKEN", None)
device = torch.device("cpu")

label2id = {"background": 0, "skin": 1, "hair": 2, "clothes": 3, "accessories": 4}
id2label = {v: k for k, v in label2id.items()}

# RGB colour used to paint each class in the output mask.
colors = {
    "background": (40, 40, 40),
    "skin": (255, 178, 127),
    "hair": (139, 69, 19),
    "clothes": (100, 149, 237),
    "accessories": (255, 215, 0),
}

# Length (in pixels) the longest image side is resized to before inference.
MAX_SIDE = 1024
# Per-channel normalisation statistics (the standard ImageNet mean/std).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

model = SegformerForSemanticSegmentation.from_pretrained(
    "neuratech-ai/person_segmentation_v3",
    token=hf_token,
    # Allow loading even if the checkpoint's head size differs from num_labels.
    ignore_mismatched_sizes=True,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
model.eval()
model.to(device)


def preds_to_rgb(preds):
    """Convert a 2-D array of class ids into an (H, W, 3) uint8 RGB image.

    Each id found in ``label2id`` is painted with its colour from ``colors``;
    any other value stays black (0, 0, 0).
    """
    preds_rgb = np.zeros((preds.shape[0], preds.shape[1], 3), dtype=np.uint8)
    for class_name, class_id in label2id.items():
        preds_rgb[preds == class_id] = colors[class_name]
    return preds_rgb


def query_image(img):
    """Segment an image and return a colour-coded class mask.

    Args:
        img: Image array as delivered by ``gr.Image()`` (a numpy array,
            presumably uint8 H x W x C), or None when no image was given.

    Returns:
        A PIL RGB image whose size equals the resized inference input
        (longest side ``MAX_SIDE``), or None if ``img`` is None.
    """
    if img is None:
        return None

    # Force exactly 3 channels: grayscale or RGBA inputs would otherwise
    # break the 3-channel normalisation below.
    image = Image.fromarray(img).convert("RGB")

    scale = MAX_SIDE / max(image.size)
    # round() instead of int() so float error cannot shave a pixel off the
    # long side; max(1, ...) stops very thin images collapsing to 0 px.
    new_w = max(1, round(image.size[0] * scale))
    new_h = max(1, round(image.size[1] * scale))
    image = image.resize((new_w, new_h), Image.LANCZOS)

    pixel_values = normalize(to_tensor(image), mean=MEAN, std=STD)

    with torch.no_grad():
        # Move the batch to the model's device so changing `device` works.
        outputs = model(pixel_values.unsqueeze(0).to(device))
    logits = outputs.logits.cpu()

    # The logits come out at reduced resolution. Upsample straight to the
    # resized input's (H, W): a fixed "x4" always yields a multiple of 4 and
    # so cannot match sides that are not multiples of 4.
    logits = torch.nn.functional.interpolate(
        logits, size=(new_h, new_w), mode="bilinear", align_corners=False
    )
    class_ids = logits.argmax(dim=1)[0].numpy()
    return Image.fromarray(preds_to_rgb(class_ids))


demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="neuratech-ai person segmentation v3",
    examples=[["example1.jpg"], ["example2.jpg"]],
)

demo.launch()