|
import gradio as gr |
|
import onnxruntime as ort |
|
import numpy as np |
|
from PIL import Image |
|
import requests |
|
from torchvision import transforms |
|
|
|
|
|
# Build the ONNX Runtime session once at import time so every request served
# by the app reuses it. The model path is resolved relative to the current
# working directory.
ort_session = ort.InferenceSession(path_or_bytes="model.onnx")
|
|
|
|
|
def preprocess(image, size=(320, 320)):
    """Convert a PIL image into the input tensor fed to the ONNX model.

    The image is converted to RGB, resized to ``size``, scaled from
    ``[0, 255]`` to ``[0.0, 1.0]`` as ``float32``, reordered from HWC to CHW
    and given a leading batch axis.

    Args:
        image: PIL image in any mode Pillow can convert to ``"RGB"``.
        size: Target ``(width, height)`` passed to ``Image.resize``. Defaults
            to ``(320, 320)``; presumably this must match the model's input
            resolution -- TODO confirm against ``model.onnx``.

    Returns:
        ``np.ndarray`` of shape ``(1, 3, height, width)`` and dtype ``float32``.
    """
    # Convert *before* resizing: Pillow always uses NEAREST resampling for
    # mode "1" and "P" (palette) images, so resizing first would produce
    # blocky, aliased input for such uploads.
    rgb = image.convert("RGB").resize(size)
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    # HWC -> CHW, then prepend the batch dimension.
    return pixels.transpose(2, 0, 1)[np.newaxis, ...]
|
|
|
|
|
def segment_dress(image):
    """Run the ONNX segmentation model on ``image`` and darken the background.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when no image was provided.

    Returns:
        An RGB PIL image the same size as ``image`` in which every pixel is
        scaled by the predicted mask (low-mask pixels go towards black), or
        ``None`` when ``image`` is ``None``.
    """
    # Gradio may call the handler with None when nothing was uploaded;
    # without this guard ``preprocess`` would raise AttributeError.
    if image is None:
        return None

    input_tensor = preprocess(image)
    input_name = ort_session.get_inputs()[0].name
    outputs = ort_session.run(None, {input_name: input_tensor})

    # NOTE(review): assumes the first model output has shape (1, 1, H, W) and
    # that higher values mean "foreground" -- confirm against model.onnx.
    pred = outputs[0][0][0]

    # Min-max normalise to [0, 1]. A constant map would make the denominator
    # zero and turn the whole mask into NaN, whose uint8 cast is undefined;
    # treat it as "nothing detected" instead.
    pred_min = pred.min()
    pred_range = pred.max() - pred_min
    if pred_range > 0:
        pred = (pred - pred_min) / pred_range
    else:
        pred = np.zeros_like(pred)

    # Quantise to an 8-bit greyscale image so Pillow can scale it back to the
    # original resolution (``image.size`` is (width, height)).
    mask_img = Image.fromarray((pred * 255).astype(np.uint8)).resize(image.size)
    mask = np.asarray(mask_img, dtype=np.float32) / 255.0

    rgb = np.asarray(image.convert("RGB"))
    # Broadcast the (H, W) mask across the three colour channels.
    masked = (rgb * mask[..., None]).astype(np.uint8)
    return Image.fromarray(masked)
|
|
|
|
|
# Single-image UI: upload a picture, get back the masked result.
demo = gr.Interface(
    fn=segment_dress,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Dress"),
    title="Background Removal",
    description="Upload an image and Remove the Background",
)

# Start the server only when run as a script (``python app.py``), so the
# module can be imported (tests, Gradio's reload mode, which looks for a
# variable named ``demo``) without blocking inside ``launch()``.
if __name__ == "__main__":
    demo.launch()