ZhengPeng7 committed on
Commit
9439305
·
1 Parent(s): 4a3bbdd

Fix a bug in SliderImage.

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. app_local.py +222 -0
app.py CHANGED
@@ -157,7 +157,7 @@ def predict(images, resolution, weights_file):
157
  zipf.write(file, os.path.basename(file))
158
  return save_paths, zip_file_path
159
  else:
160
- return (image_ori, image_masked)
161
 
162
 
163
  examples = [[_] for _ in glob('examples/*')][:]
 
157
  zipf.write(file, os.path.basename(file))
158
  return save_paths, zip_file_path
159
  else:
160
+ return (image_masked, image_ori)
161
 
162
 
163
  examples = [[_] for _ in glob('examples/*')][:]
app_local.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+ # import spaces
7
+
8
+ from glob import glob
9
+ from typing import Tuple
10
+
11
+ from PIL import Image
12
+ # from gradio_imageslider import ImageSlider
13
+ from transformers import AutoModelForImageSegmentation
14
+ from torchvision import transforms
15
+
16
+ import requests
17
+ from io import BytesIO
18
+ import zipfile
19
+
20
+
21
# Select the 'high' float32 matmul precision setting for the whole process.
torch.set_float32_matmul_precision('high')
# torch.jit.script = lambda f: f

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
25
+
26
+ ### image_proc.py
27
def refine_foreground(image, mask, r=90):
    """Re-estimate the foreground colours of `image` under `mask`.

    Args:
        image: Source picture; expected to be a PIL image (its `.size` is read
            and it is converted with `np.array`).
        mask: Matte for `image`; resized to `image.size` when the sizes differ.
        r: Blur radius forwarded to the foreground estimator.

    Returns:
        A PIL image built from the estimated foreground, scaled back to uint8.
    """
    # The estimator combines both arrays element-wise, so sizes must agree.
    aligned_mask = mask if mask.size == image.size else mask.resize(image.size)

    # Scale both inputs from [0, 255] down to [0, 1].
    image_unit = np.array(image) / 255.0
    mask_unit = np.array(aligned_mask) / 255.0

    foreground = FB_blur_fusion_foreground_estimator_2(image_unit, mask_unit, r=r)
    return Image.fromarray((foreground * 255.0).astype(np.uint8))
35
+
36
+
37
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Estimate the foreground with two passes of the blur-fusion estimator.

    The first pass uses radius `r` and takes `image` itself as the initial
    foreground and background guess. The second pass refines that result with
    a fixed radius of 6.

    Args:
        image: Image array passed through to the single-pass estimator.
        alpha: Matte array; a trailing axis is appended before use.
        r: Blur radius of the first pass.

    Returns:
        The foreground estimate produced by the second pass.
    """
    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    alpha_3d = alpha[:, :, None]
    coarse_F, coarse_B = FB_blur_fusion_foreground_estimator(
        image, image, image, alpha_3d, r)
    refined_F, _ = FB_blur_fusion_foreground_estimator(
        image, coarse_F, coarse_B, alpha_3d, r=6)
    return refined_F
43
+
44
+
45
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    """Run one blur-fusion step that refines a foreground estimate.

    Args:
        image: Observed image; a PIL image is converted to a float array
            scaled by 1/255, anything else is used as given.
        F: Current foreground estimate.
        B: Current background estimate.
        alpha: Matte with a trailing channel axis (the blurred matte gets that
            axis re-added after `cv2.blur`).
        r: Side length of the square box-blur kernel.

    Returns:
        A tuple `(F, blurred_B)`: the new foreground clipped to [0, 1] and the
        blurred background estimate.
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0

    kernel = (r, r)
    eps = 1e-5  # keeps the two normalising divisions away from zero

    alpha_blur = cv2.blur(alpha, kernel)[:, :, None]

    # Alpha-weighted blur of the foreground, normalised by the blurred alpha.
    fg_weighted_blur = cv2.blur(F * alpha, kernel)
    fg_blur = fg_weighted_blur / (alpha_blur + eps)

    # Same for the background, weighted by (1 - alpha).
    bg_weighted_blur = cv2.blur(B * (1 - alpha), kernel)
    bg_blur = bg_weighted_blur / ((1 - alpha_blur) + eps)

    # Correct the blurred foreground by the residual of the compositing equation.
    residual = image - alpha * fg_blur - (1 - alpha) * bg_blur
    F = np.clip(fg_blur + alpha * residual, 0, 1)
    return F, bg_blur
59
+
60
+
61
class ImagePreprocessor():
    """Resize, tensorise and normalise an image before it is fed to the model."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        """Build the transform pipeline for the given target `resolution`."""
        steps = [
            # 1. keep consistent with the cv2.resize used in training
            # 2. redundant with that in path_to_image()
            transforms.Resize(resolution),
            transforms.ToTensor(),
            # Per-channel mean / std used for normalisation.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform_image = transforms.Compose(steps)

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform pipeline to `image` and return the result."""
        return self.transform_image(image)
72
+
73
+
74
# Maps the label shown in the UI to the model repository name under the
# `zhengpeng7/` namespace.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_lite',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy',
}

# Load the default ('General') weights once at import time, move the model to
# the selected device and switch it to inference mode.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    'zhengpeng7/{}'.format(usage_to_weights_file['General']),
    trust_remote_code=True,
)
birefnet.to(device)
birefnet.eval()
88
+
89
+
90
+ # @spaces.GPU
91
def predict(images, resolution='1024x1024', weights_file='General'):
    """Segment the subject of one image, or of every image in a list.

    Args:
        images: A single image (a local file path, a URL, or an array accepted
            by `Image.fromarray`) or a list of such items. A list selects
            batch mode.
        resolution: Text of the form `AxB`, e.g. `1024x1024`. Each value is
            rounded down to a multiple of 32. Anything that does not parse to
            two positive integers falls back to 1024x1024.
            NOTE(review): the UI describes the text as `WxH`, and the parsed
            pair is handed to `transforms.Resize` unchanged - confirm the
            intended order for non-square sizes.
        weights_file: Key of `usage_to_weights_file`; `None` means 'General'.

    Returns:
        Batch mode: `(save_paths, zip_file_path)` - the PNG files written to
        `preds-BiRefNet/` and a zip archive containing them.
        Otherwise: the masked image, with the prediction as its alpha channel.

    Raises:
        ValueError: if `images` is None.
        KeyError: if `weights_file` is not a key of `usage_to_weights_file`.
    """
    # An explicit raise keeps the check alive under `python -O`, unlike `assert`.
    if images is None:
        raise ValueError('images cannot be None.')

    global birefnet
    # Load BiRefNet with chosen weights
    weights_key = weights_file if weights_file is not None else 'General'
    _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_key]))
    print('Using weights: {}.'.format(_weights_file))
    # Reload only when the requested weights differ from the ones this
    # function loaded last; re-creating the model on every call is wasted work.
    if getattr(predict, '_loaded_weights_file', None) != _weights_file:
        birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
        birefnet.to(device)
        birefnet.eval()
        predict._loaded_weights_file = _weights_file

    try:
        resolution = [int(int(reso) // 32 * 32) for reso in resolution.strip().split('x')]
        # Exactly two positive sizes are needed to build the resize transform.
        if len(resolution) != 2 or min(resolution) <= 0:
            raise ValueError('resolution must consist of two positive sizes')
    except (AttributeError, TypeError, ValueError):
        resolution = [1024, 1024]
        print('Invalid resolution input. Automatically changed to 1024x1024.')

    tab_is_batch = isinstance(images, list)
    if tab_is_batch:
        # For tab_batch
        save_paths = []
        save_dir = 'preds-BiRefNet'
        os.makedirs(save_dir, exist_ok=True)
    else:
        images = [images]

    # The preprocessing pipeline depends only on the resolution, so it is
    # built once rather than once per image.
    image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))

    for image_src in images:
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
                # Not a local file: treat the string as a URL. The timeout
                # stops an unresponsive host from blocking the request forever.
                response = requests.get(image_src, timeout=30)
                response.raise_for_status()
                image_ori = Image.open(BytesIO(response.content))
        else:
            image_ori = Image.fromarray(image_src)

        image = image_ori.convert('RGB')
        # Preprocess the image
        image_proc = image_preprocessor.proc(image).unsqueeze(0)

        # Prediction
        with torch.no_grad():
            preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Show Results
        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))

        torch.cuda.empty_cache()

        if tab_is_batch:
            save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
            image_masked.save(save_file_path)
            save_paths.append(save_file_path)

    if tab_is_batch:
        zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for file in save_paths:
                zipf.write(file, os.path.basename(file))
        return save_paths, zip_file_path
    return image_masked
161
+
162
+
163
# One example row per file found in the local `examples/` directory.
examples = [[path] for path in glob('examples/*')]
# Add the option of resolution in a text box.
for example in examples:
    example.append('1024x1024')
# Repeat the last example at a lower resolution. Guarded so that a missing or
# empty `examples/` directory does not raise IndexError at import time.
if examples:
    examples.append(examples[-1].copy())
    examples[-1][1] = '512x512'

examples_url = [
    ['https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg'],
]
for example_url in examples_url:
    example_url.append('1024x1024')

# Text shown above every tab.
descriptions = ('Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n'
                ' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
                ' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
                ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.')
180
+
181
def _resolution_input():
    """Create a fresh resolution text box (each tab gets its own instance)."""
    return gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`. Higher resolutions can be much slower for inference.", label="Resolution")


def _weights_input():
    """Create a fresh weights selector (each tab gets its own instance)."""
    return gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")


# Single image uploaded from disk.
tab_image = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label='Upload an image'),
        _resolution_input(),
        _weights_input(),
    ],
    outputs=gr.Image(label="BiRefNet's prediction", type="pil"),
    examples=examples,
    api_name="image",
    description=descriptions,
)

# Single image referenced by URL.
tab_text = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Paste an image URL"),
        _resolution_input(),
        _weights_input(),
    ],
    outputs=gr.Image(label="BiRefNet's prediction", type="pil"),
    examples=examples_url,
    api_name="text",
    description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
)

# Several images at once. `predict` takes (images, resolution, weights_file),
# so this tab supplies all three inputs like the other tabs do; with only the
# file input the call would be missing two arguments.
tab_batch = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
        _resolution_input(),
        _weights_input(),
    ],
    outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
    api_name="batch",
    description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
)

demo = gr.TabbedInterface(
    [tab_image, tab_text, tab_batch],
    ['image', 'text', 'batch'],
    title="BiRefNet demo for subject extraction (general / salient / camouflaged / portrait).",
)

if __name__ == "__main__":
    demo.launch(debug=True)