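# Hugging Face file-viewer header (not part of the program):
# uploaded by wildoctopus, commit 896437a ("Upload 5 files"), 904 Bytes.
# Page links captured with the file: "raw", "history", "blame".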
import PIL
import torch
import gradio as gr
from process import load_seg_model, get_palette, generate_mask
# Device string passed to both load_seg_model and generate_mask; hard-coded to CPU.
device = 'cpu'
def initialize_and_load_models():
    """Load the cloth-segmentation network from its checkpoint file.

    The checkpoint path is fixed to ``model/cloth_segm.pth`` and the model is
    loaded for the module-level ``device``.

    Returns:
        Whatever ``load_seg_model`` returns for that checkpoint.
    """
    return load_seg_model('model/cloth_segm.pth', device=device)
# Load the segmentation network once at import time; `run` reuses this
# module-level instance on every call instead of reloading the checkpoint.
net = initialize_and_load_models()
# Palette handed to generate_mask. The argument 4 is presumably the number of
# segmentation classes — TODO confirm against process.get_palette.
palette = get_palette(4)
def run(img):
    """Gradio callback: produce the cloth-segmentation result for ``img``.

    Args:
        img: The image supplied by the Gradio input component.

    Returns:
        The result of ``generate_mask`` called with the module-level ``net``,
        ``palette`` and ``device``.
    """
    return generate_mask(img, net=net, palette=palette, device=device)
# Define the input and output components.
# NOTE: `gr.Image` is used for both directions. The older `gr.inputs.Image` /
# `gr.outputs.Image` namespaces are deprecated in Gradio 3.x and removed in
# Gradio 4.x, where accessing them raises AttributeError at import time.
# type="pil" makes Gradio pass a PIL image to `run` and accept one back.
input_image = gr.Image(label="Input Image", type="pil")
cloth_seg_image = gr.Image(label="Cloth Segmentation", type="pil")

# Text shown at the top of the demo page.
title = "Demo for Cloth Segmentation"
description = "An app for Cloth Segmentation"

inputs = [input_image]
outputs = [cloth_seg_image]

# Define the Gradio interface and start serving it. launch() is called at
# module level on purpose: this file is run directly as the app entry point.
gr.Interface(fn=run, inputs=inputs, outputs=outputs, title=title, description=description).launch()