griffin.b committed on
Commit
66d6249
1 Parent(s): 656bf4f

working demo

Browse files
Files changed (6) hide show
  1. abbey.jpg +0 -0
  2. app.py +197 -0
  3. julia.jpg +0 -0
  4. newman.jpg +0 -0
  5. newman_mask.jpg +0 -0
  6. requirements.txt +8 -0
abbey.jpg ADDED
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image, ImageFilter
3
+ import numpy as np
4
+ import torch
5
+ from torch.autograd import Variable
6
+ from torchvision import transforms
7
+ import torch.nn.functional as F
8
+ import gdown
9
+ import os
10
+
11
# Fetch the upstream DIS sources and flatten DIS/IS-Net into the working
# directory so that the project imports (data_loader_cache, models) resolve.
# The work is skipped when the sources are already in place (e.g. after a
# restart), because `git clone` fails on an existing checkout and the `mv`
# glob matches nothing once the files have been moved.
if not (os.path.exists("data_loader_cache.py") or os.path.isdir("data_loader_cache")):
    if not os.path.isdir("DIS"):
        if os.system("git clone https://github.com/xuebinqin/DIS") != 0:
            raise RuntimeError("could not clone https://github.com/xuebinqin/DIS")
    # The shell is needed here for the `*` glob expansion.
    if os.system("mv DIS/IS-Net/* .") != 0:
        raise RuntimeError("could not move DIS/IS-Net/* into the working directory")
13
+
14
+ # project imports
15
+ from data_loader_cache import normalize, im_reader, im_preprocess
16
+ from models import *
17
+
18
# Helpers
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download official weights
MODEL_PATH_URL = "https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn"
os.makedirs("saved_models", exist_ok=True)
# Key the download on the weights file itself rather than on the directory:
# the directory may exist without the file (e.g. after an interrupted start),
# and an existing file does not need to be fetched again.
if not os.path.exists("saved_models/isnet.pth"):
    gdown.download(MODEL_PATH_URL, "saved_models/isnet.pth", use_cookies=False)
26
+
27
class GOSNormalize(object):
    '''
    Callable transform that normalizes an image with a fixed mean and std.

    The actual work is delegated to `normalize(image, mean, std)` as imported
    from `data_loader_cache`.

    Args:
        mean: per-channel means; defaults to (0.485, 0.456, 0.406).
        std: per-channel standard deviations; defaults to (0.229, 0.224, 0.225).
    '''
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Immutable defaults plus a per-instance copy: with list defaults every
        # instance would share (and could mutate) the same list objects.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        '''Return `image` normalized with this instance's mean and std.'''
        return normalize(image, self.mean, self.std)
38
+
39
+
40
# Preprocessing applied to every input image: GOSNormalize with a mean of 0.5
# and a std of 1.0 on each of the three channels.
transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])


def load_image(im_path, hypar):
    '''
    Read an image from disk and turn it into a one-element batch.

    Args:
        im_path: path handed to `im_reader`.
        hypar: inference parameters; only hypar["cache_size"] is read here and
            passed to `im_preprocess`.

    Returns:
        (image_batch, shape_batch): the transformed image and the shape value
        returned by `im_preprocess`, each with a leading batch dimension of 1.
    '''
    raw_image = im_reader(im_path)
    prepared, source_shape = im_preprocess(raw_image, hypar["cache_size"])

    # Scale by 1/255 before normalizing (assumes 8-bit pixel values - TODO confirm).
    scaled = torch.divide(prepared, 255.0)
    shape_tensor = torch.from_numpy(np.array(source_shape))

    image_batch = transform(scaled).unsqueeze(0)
    shape_batch = shape_tensor.unsqueeze(0)
    return image_batch, shape_batch
48
+
49
+
50
def build_model(hypar, device):
    '''
    Prepare the network stored in `hypar` for inference.

    Args:
        hypar: inference parameters. Reads "model" (the network instance),
            "model_digit" ("half" or "full"), "restore_model" (weights file
            name, "" to skip loading) and, when weights are loaded,
            "model_path" (directory containing the weights file).
        device: device the network is moved to and weights are mapped to.

    Returns:
        The same network object, on `device` and in eval mode.
    '''
    net = hypar["model"]

    # convert to half precision
    if hypar["model_digit"] == "half":
        net.half()
        # Batch-norm layers are switched back to float32 (presumably for
        # numerical stability). `torch.nn` is referenced explicitly instead of
        # relying on a bare `nn` leaking in through `from models import *`.
        for layer in net.modules():
            if isinstance(layer, torch.nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"] != "":
        weights_path = os.path.join(hypar["model_path"], hypar["restore_model"])
        net.load_state_dict(torch.load(weights_path, map_location=device))

    net.eval()
    return net
67
+
68
+
69
def predict(net, inputs_val, shapes_val, hypar, device):
    '''
    Given an image batch, predict the mask.

    Args:
        net: the network; `net(inputs)[0]` must be a sequence of predictions
            whose first entry is a B x 1 x H x W tensor.
        inputs_val: input image batch (only the first item is used).
        shapes_val: batch of (height, width) targets; shapes_val[0] gives the
            size the prediction is resized to.
        hypar: inference parameters; hypar["model_digit"] selects float32
            ("full") or float16 (anything else) inputs.
        device: device the inputs are moved to.

    Returns:
        A uint8 numpy array of the target size with values in [0, 255].
        A constant prediction yields an all-zero mask.
    '''
    net.eval()

    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.float()
    else:
        inputs_val = inputs_val.half()
    inputs_val = inputs_val.to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        ds_val = net(inputs_val)[0]  # list of 6 results

        # B x 1 x H x W; take the first result, which is the most accurate prediction
        pred_val = ds_val[0][0, :, :, :]

        # Recover the prediction spatial size to the original image size.
        # F.interpolate replaces the deprecated F.upsample; align_corners=False
        # is the behaviour F.upsample defaulted to for bilinear mode.
        target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))
        pred_val = torch.squeeze(
            F.interpolate(
                torch.unsqueeze(pred_val, 0),
                size=target_size,
                mode='bilinear',
                align_corners=False,
            )
        )

        # Min-max rescale to [0, 1]; a flat prediction would divide by zero
        # and fill the mask with NaN, so it maps to an empty mask instead.
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        value_range = ma - mi
        if value_range > 0:
            pred_val = (pred_val - mi) / value_range  # max = 1
        else:
            pred_val = torch.zeros_like(pred_val)

    mask = (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)  # it is the mask we need

    # Also matches torch.device objects and indexed names such as "cuda:0".
    if str(device).startswith('cuda'):
        torch.cuda.empty_cache()
    return mask
96
+
97
# Set Parameters
# Parameters for inference, collected in a single literal.
hypar = {
    # Directory the trained weights are loaded from.
    "model_path": "./saved_models",
    # File name of the weights to load.
    "restore_model": "isnet.pth",
    # Whether intermediate feature supervision is activated.
    "interm_sup": False,
    # Floating point accuracy: "half" or "full".
    "model_digit": "full",
    "seed": 0,
    # Cached input spatial resolution; can be configured to a different size.
    "cache_size": [1024, 1024],
    # Data augmentation parameters.
    # Model input spatial size; usually the same as "cache_size", which means
    # the images are not resized any further.
    "input_size": [1024, 1024],
    # Random crop size from the input; usually smaller than "cache_size",
    # e.g. [920, 920], when used for data augmentation.
    "crop_size": [1024, 1024],
    "model": ISNetDIS(),
}

# Build Model
net = build_model(hypar, device)
119
+
120
+
121
def infer_mask(image: str):
    '''
    Run the segmentation network on an image file and return its mask.

    Args:
        image: path of the image file. (The parameter was previously
            annotated with the PIL `Image` module, but it is handed straight
            to `load_image`, which reads from a path.)

    Returns:
        The predicted mask as a PIL image in mode "L".
    '''
    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)

    return Image.fromarray(mask).convert("L")
128
+
129
def blur(image_set: list, blur_amount: int):
    '''
    Blur the background of an image while keeping the masked subject sharp.

    Args:
        image_set: two-item sequence `[image, mask]` of PIL images. Where the
            mask is bright the original image is kept; elsewhere the blurred
            copy shows through.
        blur_amount: radius passed to `ImageFilter.GaussianBlur`.

    Returns:
        The composited PIL image.
    '''
    image, mask = image_set[0], image_set[1]

    # Image.composite only accepts masks in mode "1", "L" or "RGBA"; any other
    # mode (e.g. a mask loaded as "RGB") raises ValueError, so normalize it.
    if mask.mode not in ("1", "L", "RGBA"):
        mask = mask.convert("L")

    blurred_image = image.filter(ImageFilter.GaussianBlur(blur_amount))

    return Image.composite(image, blurred_image, mask)
133
+
134
+
135
# Blur radius shown on start-up, applied right after a run, and restored by
# reset_slider; kept in one place so the three stay in sync.
DEFAULT_BLUR = 5

with gr.Blocks() as interface:
    default_im = Image.open("newman.jpg").convert("RGB")
    # Loaded as "L" rather than "RGB": this image seeds `current_images` and is
    # therefore used as the Image.composite mask if the slider is moved before
    # the first run; composite rejects an "RGB" mask.
    default_mask = Image.open("newman_mask.jpg").convert("L")
    examples_list = [os.path.join(os.path.dirname(__file__), "newman.jpg"),
                     os.path.join(os.path.dirname(__file__), "abbey.jpg"),
                     os.path.join(os.path.dirname(__file__), "julia.jpg")
                     ]

    # Per-session state: the [image, mask] pair being blurred, and whether the
    # mask preview is currently visible.
    current_images = gr.State([default_im, default_mask])
    mask_toggle = gr.State(False)

    gr.Markdown(
        """
        ### Intelligent Photo Blur Using Dichotomous Image Segmentation

        This app leverages the machine learning engine built by Xuebin Qin (https://github.com/xuebinqin/DIS) to mask the prominent subject within a photograph.
        The mask is used to keep the subject in clear focus while an adjustable slider is available to interactively blur the background.
        To use, upload a photo and press the run button. You can adjust the level of blur through the slider and view the mask using the "Show Generated Mask" button.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(value=default_im, type='filepath')
            run_button = gr.Button()
            gr.Examples(inputs=input_image, examples=examples_list)
        with gr.Column():
            output_image = gr.Image()
            blur_slider = gr.Slider(0, 16, DEFAULT_BLUR, step=1, label="Blur Amount")
            mask_button = gr.Button(value="Show Generated Mask")
            mask_image = gr.Image(value=default_mask, visible=False)

    def run(image: str, current_images: list):
        '''
        Segment the image at path `image` and blur its background.

        `current_images` is the previous state value; it is accepted because
        it is wired as an input but is not read.

        Returns the blurred result, the mask, and the new [image, mask] state.
        '''
        im_rgb = Image.open(image).convert("RGB")
        mask = infer_mask(image)

        return (
            blur([im_rgb, mask], DEFAULT_BLUR),
            mask,
            [im_rgb, mask]
        )

    def reset_slider():
        '''Return an update that puts the blur slider back to its default.'''
        return gr.update(value=DEFAULT_BLUR)

    def show_mask(mask_toggle: bool):
        '''Return a visibility update that hides the mask if shown, else shows it.'''
        return gr.update(visible=not mask_toggle)

    def toggle_mask(mask_toggle: bool):
        '''Return the flipped value of the mask-visible flag.'''
        return not mask_toggle

    def flip_mask(mask_toggle: bool):
        '''
        Update the mask visibility and the stored flag from the same reading.

        Two separate listeners that each read `mask_toggle` can observe
        different values if one finishes before the other starts; a single
        handler keeps the preview and the flag consistent.
        '''
        return show_mask(mask_toggle), toggle_mask(mask_toggle)

    run_button.click(run, [input_image, current_images], [output_image, mask_image, current_images])
    run_button.click(reset_slider, outputs=blur_slider)
    blur_slider.change(blur, [current_images, blur_slider], output_image, show_progress=False)
    mask_button.click(flip_mask, inputs=mask_toggle, outputs=[mask_image, mask_toggle])

interface.launch()
julia.jpg ADDED
newman.jpg ADDED
newman_mask.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ requests
4
+ gdown
5
+ matplotlib
6
+ opencv-python
7
+ Pillow==8.0.0
8
+ scikit-image==0.15.0