nlightcho's picture
fix image resize
0136d74 verified
raw
history blame contribute delete
No virus
2.02 kB
import os
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision.transforms.functional import to_tensor, normalize
from transformers import SegformerForSemanticSegmentation
# Optional Hugging Face access token; ``None`` when HF_TOKEN is not set.
hf_token = os.getenv("HF_TOKEN")

# Inference is pinned to the CPU.
device = torch.device("cpu")

# Segmentation classes; the position of each name is its class id.
label2id = {
    name: index
    for index, name in enumerate(
        ("background", "skin", "hair", "clothes", "accessories")
    )
}
id2label = dict(zip(label2id.values(), label2id))

# RGB colour used to paint each class in the rendered mask.
colors = {
    "background": (40, 40, 40),
    "skin": (255, 178, 127),
    "hair": (139, 69, 19),
    "clothes": (100, 149, 237),
    "accessories": (255, 215, 0),
}

# Load the checkpoint with the five-class label maps above, then switch it to
# evaluation mode and place it on ``device`` (both calls return the module).
model = (
    SegformerForSemanticSegmentation.from_pretrained(
        "neuratech-ai/person_segmentation_v3",
        token=hf_token,
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )
    .eval()
    .to(device)
)
def preds_to_rgb(preds):
    """Render a map of class ids as an RGB image array.

    Args:
        preds: Array of class ids; its first two dimensions give the
            height and width of the output.

    Returns:
        A ``uint8`` array of shape ``(preds.shape[0], preds.shape[1], 3)``
        where every pixel carries the colour of its class. Pixels whose id
        does not appear in ``label2id`` stay black.
    """
    height, width = preds.shape[0], preds.shape[1]
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for name in label2id:
        # Boolean mask selects every pixel predicted as this class.
        canvas[preds == label2id[name]] = colors[name]
    return canvas
def query_image(img):
    """Segment an image and return its colour-coded class mask.

    The image is scaled so that its longer side is 1024 pixels, normalised,
    passed through the module-level ``model``, and the per-pixel argmax of
    the logits is rendered with ``preds_to_rgb``.

    Args:
        img: Image as a NumPy array accepted by ``PIL.Image.fromarray``,
            or ``None`` when no image was provided.

    Returns:
        A ``PIL.Image`` containing the RGB mask at the resized resolution,
        or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        return None

    # The normalisation constants below are per-RGB-channel, so force three
    # channels; grayscale or RGBA arrays would otherwise fail in normalize().
    image = Image.fromarray(img).convert("RGB")

    scale = 1024 / max(image.size)
    # Clamp to at least one pixel: for extreme aspect ratios the short side
    # would otherwise truncate to 0 and make resize() raise.
    new_width = max(1, int(image.size[0] * scale))
    new_height = max(1, int(image.size[1] * scale))
    image = image.resize((new_width, new_height), Image.LANCZOS)

    pixels = normalize(
        to_tensor(image),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )

    with torch.no_grad():
        # Keep the input on the same device as the model.
        outputs = model(pixels.unsqueeze(0).to(device))
    logits = outputs.logits.cpu()

    # Upsample the logits to the resized input's (height, width) instead of
    # assuming a fixed x4 factor, so the mask matches the resized image even
    # when a side is not a multiple of 4.
    logits = torch.nn.functional.interpolate(
        logits,
        size=(new_height, new_width),
        mode="bilinear",
        align_corners=False,
    )

    class_map = torch.argmax(logits, dim=1).numpy()[0]
    return Image.fromarray(preds_to_rgb(class_map))
# Gradio UI: one image input, one image output, served by query_image.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="neuratech-ai person segmentation v3",
    # Each inner list supplies the value for the single image input.
    examples=[
        ["example1.jpg"],
        ["example2.jpg"],
    ],
)

# Start the web server for the demo.
demo.launch()