File size: 904 Bytes
896437a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import PIL
import torch
import gradio as gr
from process import load_seg_model, get_palette, generate_mask


# Device identifier handed to the model loader and to mask generation below.
device = "cpu"



def initialize_and_load_models():
    """Load the cloth-segmentation network from its checkpoint file.

    The checkpoint path ``model/cloth_segm.pth`` is passed to
    ``load_seg_model`` together with the module-level ``device``.

    Returns:
        Whatever ``load_seg_model`` returns (presumably the ready-to-use
        network -- confirm against ``process.load_seg_model``).
    """
    return load_seg_model('model/cloth_segm.pth', device=device)

# Load the network once at import time so every request served by ``run``
# reuses the same object.
net = initialize_and_load_models()
# Palette passed to ``generate_mask``.
# NOTE(review): the argument 4 is presumably the number of segmentation
# classes -- confirm against ``process.get_palette``.
palette = get_palette(4)


def run(img):
    """Produce the cloth-segmentation result for a single input image.

    Args:
        img: Image supplied by the Gradio interface (the input component is
            configured with ``type="pil"``).

    Returns:
        The value returned by ``generate_mask`` for ``img``, computed with the
        module-level ``net``, ``palette`` and ``device``.
    """
    return generate_mask(img, net=net, palette=palette, device=device)

# Input and output components. The legacy ``gr.inputs`` / ``gr.outputs``
# namespaces were deprecated in Gradio 3 and removed in Gradio 4, so the
# unified ``gr.Image`` component is used for both directions.
input_image = gr.Image(label="Input Image", type="pil")
cloth_seg_image = gr.Image(label="Cloth Segmentation", type="pil")

# Text shown on the demo page.
title = "Demo for Cloth Segmentation"
description = "An app for Cloth Segmentation"
inputs = [input_image]
outputs = [cloth_seg_image]

# Build the interface around ``run`` and start serving it.
demo = gr.Interface(
    fn=run,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
)
demo.launch()