dhanushreddy29 commited on
Commit
2395fbc
1 Parent(s): c14578e

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -162
main.py DELETED
@@ -1,162 +0,0 @@
1
- import gradio as gr
2
- import os
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from huggingface_hub import hf_hub_download
8
- from torch.autograd import Variable
9
- from PIL import Image
10
-
11
-
12
- def build_model(hypar, device):
13
- net = hypar["model"] # GOSNETINC(3,1)
14
-
15
- # convert to half precision
16
- if hypar["model_digit"] == "half":
17
- net.half()
18
- for layer in net.modules():
19
- if isinstance(layer, nn.BatchNorm2d):
20
- layer.float()
21
-
22
- net.to(device)
23
-
24
- if hypar["restore_model"] != "":
25
- net.load_state_dict(
26
- torch.load(
27
- hypar["model_path"] + "/" + hypar["restore_model"],
28
- map_location=device,
29
- )
30
- )
31
- net.to(device)
32
- net.eval()
33
- return net
34
-
35
-
36
- if not os.path.exists("saved_models"):
37
- os.mkdir("saved_models")
38
- os.mkdir("git")
39
- os.system("git clone https://github.com/xuebinqin/DIS git/xuebinqin/DIS")
40
- hf_hub_download(
41
- repo_id="NimaBoscarino/IS-Net_DIS-general-use",
42
- filename="isnet-general-use.pth",
43
- local_dir="saved_models",
44
- )
45
- os.system("rm -r git/xuebinqin/DIS/IS-Net/__pycache__")
46
- os.system("mv git/xuebinqin/DIS/IS-Net/* .")
47
-
48
- import data_loader_cache
49
- import models
50
-
51
- device = "cpu"
52
- ISNetDIS = models.ISNetDIS
53
- normalize = data_loader_cache.normalize
54
- im_preprocess = data_loader_cache.im_preprocess
55
-
56
- # Set Parameters
57
- hypar = {} # paramters for inferencing
58
-
59
- # load trained weights from this path
60
- hypar["model_path"] = "./saved_models"
61
- # name of the to-be-loaded weights
62
- hypar["restore_model"] = "isnet-general-use.pth"
63
- # indicate if activate intermediate feature supervision
64
- hypar["interm_sup"] = False
65
-
66
- # choose floating point accuracy --
67
- # indicates "half" or "full" accuracy of float number
68
- hypar["model_digit"] = "full"
69
- hypar["seed"] = 0
70
-
71
- # cached input spatial resolution, can be configured into different size
72
- hypar["cache_size"] = [1024, 1024]
73
-
74
- # data augmentation parameters ---
75
- # mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
76
- hypar["input_size"] = [1024, 1024]
77
- # random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
78
- hypar["crop_size"] = [1024, 1024]
79
-
80
- hypar["model"] = ISNetDIS()
81
-
82
- # Build Model
83
- net = build_model(hypar, device)
84
-
85
-
86
- def predict(net, inputs_val, shapes_val, hypar, device):
87
- """
88
- Given an Image, predict the mask
89
- """
90
- net.eval()
91
-
92
- if hypar["model_digit"] == "full":
93
- inputs_val = inputs_val.type(torch.FloatTensor)
94
- else:
95
- inputs_val = inputs_val.type(torch.HalfTensor)
96
-
97
- inputs_val_v = Variable(inputs_val, requires_grad=False).to(
98
- device
99
- ) # wrap inputs in Variable
100
-
101
- ds_val = net(inputs_val_v)[0] # list of 6 results
102
-
103
- # B x 1 x H x W # we want the first one which is the most accurate prediction
104
- pred_val = ds_val[0][0, :, :, :]
105
-
106
- # recover the prediction spatial size to the orignal image size
107
- pred_val = torch.squeeze(
108
- F.upsample(
109
- torch.unsqueeze(pred_val, 0),
110
- (shapes_val[0][0], shapes_val[0][1]),
111
- mode="bilinear",
112
- )
113
- )
114
-
115
- ma = torch.max(pred_val)
116
- mi = torch.min(pred_val)
117
- pred_val = (pred_val - mi) / (ma - mi) # max = 1
118
-
119
- if device == "cpu":
120
- torch.cpu.empty_cache()
121
- # it is the mask we need
122
- return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
123
-
124
-
125
- def load_image(im_pil, hypar):
126
- im = np.array(im_pil)
127
- im, im_shp = im_preprocess(im, hypar["cache_size"])
128
- im = torch.divide(im, 255.0)
129
- shape = torch.from_numpy(np.array(im_shp))
130
- # make a batch of image, shape
131
- aa = normalize(im, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
132
- return aa.unsqueeze(0), shape.unsqueeze(0)
133
-
134
-
135
- def remove_background(image):
136
- image_tensor, orig_size = load_image(image, hypar)
137
- mask = predict(net, image_tensor, orig_size, hypar, "cpu")
138
-
139
- mask = Image.fromarray(mask).convert("L")
140
- im_rgb = image.convert("RGB")
141
-
142
- cropped = im_rgb.copy()
143
- cropped.putalpha(mask)
144
- return cropped
145
-
146
-
147
- inputs = gr.inputs.Image()
148
- outputs = gr.outputs.Image(type="pil")
149
- interface = gr.Interface(
150
- fn=remove_background,
151
- inputs=inputs,
152
- outputs=outputs,
153
- title="Remove Background",
154
- description="This App removes the background from an image",
155
- examples=[
156
- "examples/input/1.jpeg",
157
- "examples/input/2.jpeg",
158
- "examples/input/3.jpeg",
159
- ],
160
- cache_examples=True,
161
- )
162
- interface.launch(enable_queue=True)