ginipick commited on
Commit
adfe191
ยท
verified ยท
1 Parent(s): ec38b03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -74
app.py CHANGED
@@ -25,30 +25,49 @@ from transformers import (
25
  AutoConfig,
26
  AutoModelForImageSegmentation,
27
  )
 
 
28
 
29
- # 1) Config๋ฅผ ๋จผ์ € ๋กœ๋“œํ•˜์—ฌ tie_weights ์ถฉ๋Œ์„ ๋ฐฉ์ง€
 
 
 
 
 
30
  config = AutoConfig.from_pretrained(
31
- "zhengpeng7/BiRefNet", # ๐Ÿ‘‰ ์›ํ•˜๋Š” Hugging Face ๋ชจ๋ธ Repo
32
  trust_remote_code=True
33
  )
34
 
35
- # 2) config.get_text_config ์— ๋”๋ฏธ ๋ฉ”์„œ๋“œ ๋ถ€์—ฌ (tie_word_embeddings=False)
36
  def dummy_get_text_config(decoder=True):
37
  return type("DummyTextConfig", (), {"tie_word_embeddings": False})()
38
 
39
  config.get_text_config = dummy_get_text_config
40
 
41
- # 3) ๋ชจ๋ธ ๊ตฌ์กฐ๋งŒ ๋งŒ๋“ค๊ธฐ (from_config) -> tie_weights ์ž๋™ ํ˜ธ์ถœ ์•ˆ ๋จ
42
  birefnet = AutoModelForImageSegmentation.from_config(config, trust_remote_code=True)
43
  birefnet.eval()
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
45
  birefnet.to(device)
46
  birefnet.half()
47
 
48
- # 4) state_dict ๋กœ๋“œ (๊ฐ€์ค‘์น˜) - ๋กœ์ปฌ ํŒŒ์ผ ์‚ฌ์šฉ ์˜ˆ์‹œ
49
- # ์‹ค์ œ๋กœ๋Š” hf_hub_download / snapshot_download ๋“ฑ์œผ๋กœ "model.safetensors"๋ฅผ ๋ฏธ๋ฆฌ ๋ฐ›์€ ๋’ค ์‚ฌ์šฉ
50
- print("Loading BiRefNet weights from local file: model.safetensors")
51
- state_dict = torch.load("model.safetensors", map_location="cpu") # ์˜ˆ์‹œ
 
 
 
 
 
 
 
 
 
 
 
 
52
  missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)
53
  print("[Info] Missing keys:", missing)
54
  print("[Info] Unexpected keys:", unexpected)
@@ -56,7 +75,7 @@ torch.cuda.empty_cache()
56
 
57
 
58
  ##########################################################
59
- # 1. ์ด๋ฏธ์ง€ ํ›„์ฒ˜๋ฆฌ ํ•จ์ˆ˜๋“ค
60
  ##########################################################
61
 
62
  def refine_foreground(image, mask, r=90):
@@ -85,7 +104,6 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
85
  F = np.clip(F, 0, 1)
86
  return F, blurred_B
87
 
88
-
89
  class ImagePreprocessor():
90
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
91
  self.transform_image = transforms.Compose([
@@ -99,7 +117,7 @@ class ImagePreprocessor():
99
 
100
 
101
  ##########################################################
102
- # 2. ์˜ˆ์ œ ์„ค์ • ๋ฐ ์œ ํ‹ธ
103
  ##########################################################
104
 
105
  usage_to_weights_file = {
@@ -130,30 +148,24 @@ descriptions = (
130
  "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
131
  )
132
 
133
-
134
  ##########################################################
135
- # 3. ์ถ”๋ก  ํ•จ์ˆ˜ (์ด๋ฏธ ๋กœ๋“œ๋œ birefnet ๋ชจ๋ธ ์‚ฌ์šฉ)
136
  ##########################################################
137
 
138
  @spaces.GPU
139
  def predict(images, resolution, weights_file):
140
- """
141
- ์—ฌ๊ธฐ์„œ๋Š”, ๋‹จ์ผ birefnet ๋ชจ๋ธ๋งŒ ์œ ์ง€ํ•˜๊ณ  ์žˆ์œผ๋ฉฐ,
142
- weight_file์„ ๋ฐ”๊พธ๋”๋ผ๋„ ์‹ค์ œ๋กœ๋Š” ์ด๋ฏธ ๋กœ๋“œ๋œ 'birefnet' ๋ชจ๋ธ๋งŒ ์‚ฌ์šฉ.
143
- (๋งŒ์•ฝ ๋‹ค๋ฅธ ๊ฐ€์ค‘์น˜๋ฅผ ๋กœ๋“œํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด, ์•„๋ž˜์ฒ˜๋Ÿผ ๋กœ์ปฌ state_dict ๊ต์ฒด ๋ฐฉ์‹ ์ถ”๊ฐ€ ๊ฐ€๋Šฅ.)
144
- """
145
  assert images is not None, 'Images cannot be None.'
146
 
147
- # Resolution parse
148
  try:
149
- w, h = resolution.strip().split('x')
150
- w, h = int(int(w)//32*32), int(int(h)//32*32)
151
- resolution_list = (w, h)
152
  except:
153
- print('[WARN] Invalid resolution input. Fallback to 1024x1024.')
154
- resolution_list = (1024, 1024)
155
 
156
- # ์ด๋ฏธ์ง€๊ฐ€ ์—ฌ๋Ÿฌ ์žฅ์ผ ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ ๋ฆฌ์ŠคํŠธ๋กœ ์ฒ˜๋ฆฌ
157
  if isinstance(images, list):
158
  is_batch = True
159
  outputs, save_paths = [], []
@@ -164,65 +176,57 @@ def predict(images, resolution, weights_file):
164
  is_batch = False
165
 
166
  for idx, image_src in enumerate(images):
167
- # str์ด๋ฉด ํŒŒ์ผ ๊ฒฝ๋กœ ํ˜น์€ URL
168
  if isinstance(image_src, str):
169
  if os.path.isfile(image_src):
170
  image_ori = Image.open(image_src)
171
  else:
172
  resp = requests.get(image_src)
173
  image_ori = Image.open(BytesIO(resp.content))
174
- # numpy ๋ฐฐ์—ด์ด๋ฉด Pillow ๋ณ€ํ™˜
175
  elif isinstance(image_src, np.ndarray):
176
  image_ori = Image.fromarray(image_src)
177
  else:
178
  image_ori = image_src.convert('RGB')
179
 
180
- image = image_ori.convert('RGB')
181
- preproc = ImagePreprocessor(resolution_list)
182
- image_proc = preproc.proc(image).unsqueeze(0).to(device).half()
183
 
184
- # ์‹ค์ œ ์ถ”๋ก 
185
  with torch.inference_mode():
186
- # ๊ฒฐ๊ณผ ๋งจ ๋งˆ์ง€๋ง‰ ๋ ˆ์ด์–ด preds
187
  preds = birefnet(image_proc)[-1].sigmoid().cpu()
188
  pred_mask = preds[0].squeeze()
189
 
190
  # ํ›„์ฒ˜๋ฆฌ
191
  pred_pil = transforms.ToPILImage()(pred_mask)
192
- image_masked = refine_foreground(image, pred_pil)
193
- image_masked.putalpha(pred_pil.resize(image.size))
194
 
195
  if is_batch:
196
- file_name = (
197
- os.path.splitext(os.path.basename(image_src))[0]
198
- if isinstance(image_src, str)
199
- else f"img_{idx}"
200
- )
201
- out_path = os.path.join(save_dir, f"{file_name}.png")
202
- image_masked.save(out_path)
203
- save_paths.append(out_path)
204
  outputs.append(image_masked)
205
  else:
206
  outputs = [image_masked, image_ori]
207
 
208
  torch.cuda.empty_cache()
209
 
210
- # ๋ฐฐ์น˜๋ผ๋ฉด ๊ฐค๋Ÿฌ๋ฆฌ + ZIP ๋ฐ˜ํ™˜
211
  if is_batch:
212
- zip_path = os.path.join(save_dir, f"{save_dir}.zip")
213
- with zipfile.ZipFile(zip_path, 'w') as zipf:
214
  for fpath in save_paths:
215
  zipf.write(fpath, os.path.basename(fpath))
216
- return (save_paths, zip_path)
217
  else:
218
  return outputs
219
 
220
-
221
  ##########################################################
222
- # 4. Gradio UI
223
  ##########################################################
224
 
225
- # ์ปค์Šคํ…€ CSS
226
  css = """
227
  body {
228
  background: linear-gradient(135deg, #667eea, #764ba2);
@@ -280,14 +284,13 @@ button:hover, .btn:hover {
280
  title_html = """
281
  <h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo (No Tie-Weights Crash)</h1>
282
  <p align="center" style="font-size:1.1em; color:#555;">
283
- Using <code>from_config()</code> + local <code>state_dict</code> to bypass tie_weights issues
284
  </p>
285
  """
286
 
287
  with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
288
  gr.Markdown(title_html)
289
  with gr.Tabs():
290
- # ํƒญ 1: Image
291
  with gr.Tab("Image"):
292
  with gr.Row():
293
  with gr.Column(scale=1):
@@ -297,13 +300,8 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
297
  predict_btn = gr.Button("Predict")
298
  with gr.Column(scale=2):
299
  output_slider = ImageSlider(label="Result", type="pil")
300
- gr.Examples(
301
- examples=examples_image,
302
- inputs=[image_input, resolution_input, weights_radio],
303
- label="Examples"
304
- )
305
 
306
- # ํƒญ 2: Text(URL)
307
  with gr.Tab("Text"):
308
  with gr.Row():
309
  with gr.Column(scale=1):
@@ -313,36 +311,23 @@ with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
313
  predict_btn_text = gr.Button("Predict")
314
  with gr.Column(scale=2):
315
  output_slider_text = ImageSlider(label="Result", type="pil")
316
- gr.Examples(
317
- examples=examples_text,
318
- inputs=[image_url, resolution_input_text, weights_radio_text],
319
- label="Examples"
320
- )
321
 
322
- # ํƒญ 3: Batch
323
  with gr.Tab("Batch"):
324
  with gr.Row():
325
  with gr.Column(scale=1):
326
- file_input = gr.File(
327
- label="Upload Multiple Images",
328
- type="filepath",
329
- file_count="multiple"
330
- )
331
  resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
332
  weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
333
  predict_btn_batch = gr.Button("Predict")
334
  with gr.Column(scale=2):
335
  output_gallery = gr.Gallery(label="Results", scale=1)
336
  zip_output = gr.File(label="Zip Download")
337
- gr.Examples(
338
- examples=examples_batch,
339
- inputs=[file_input, resolution_input_batch, weights_radio_batch],
340
- label="Examples"
341
- )
342
 
343
  gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")
344
 
345
- # ๋ฒ„ํŠผ ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
346
  predict_btn.click(
347
  fn=predict,
348
  inputs=[image_input, resolution_input, weights_radio],
 
25
  AutoConfig,
26
  AutoModelForImageSegmentation,
27
  )
28
+ # Hugging Face Hub
29
+ from huggingface_hub import hf_hub_download
30
 
31
+
32
+ ##########################################################
33
+ # 1. Config ๋ฐ from_config() ์ดˆ๊ธฐํ™”
34
+ ##########################################################
35
+
36
+ # 1) Config๋งŒ ๋จผ์ € ๋กœ๋“œ
37
  config = AutoConfig.from_pretrained(
38
+ "zhengpeng7/BiRefNet", # ์˜ˆ์‹œ
39
  trust_remote_code=True
40
  )
41
 
42
+ # 2) config.get_text_config์— ๋”๋ฏธ ๋ฉ”์„œ๋“œ ๋ถ€์—ฌ (tie_word_embeddings=False)
43
  def dummy_get_text_config(decoder=True):
44
  return type("DummyTextConfig", (), {"tie_word_embeddings": False})()
45
 
46
  config.get_text_config = dummy_get_text_config
47
 
48
+ # 3) ๋ชจ๋ธ ๊ตฌ์กฐ๋งŒ ๋งŒ๋“ค๊ธฐ
49
  birefnet = AutoModelForImageSegmentation.from_config(config, trust_remote_code=True)
50
  birefnet.eval()
51
  device = "cuda" if torch.cuda.is_available() else "cpu"
52
  birefnet.to(device)
53
  birefnet.half()
54
 
55
+ ##########################################################
56
+ # 2. ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋‹ค์šด๋กœ๋“œ & ๋กœ๋“œ
57
+ ##########################################################
58
+
59
+ # huggingface_hub์—์„œ safetensors ๋˜๋Š” bin ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ
60
+ # (repo_id, filename ๋“ฑ์€ ์‹ค์ œ ์‚ฌ์šฉ ํ™˜๊ฒฝ์— ๋งž๊ฒŒ ๋ณ€๊ฒฝ)
61
+ weights_path = hf_hub_download(
62
+ repo_id="zhengpeng7/BiRefNet", # ์˜ˆ์‹œ
63
+ filename="model.safetensors", # ๋˜๋Š” "pytorch_model.bin"
64
+ trust_remote_code=True
65
+ )
66
+ print("Downloaded weights to:", weights_path)
67
+
68
+ # state_dict ๋กœ๋“œ
69
+ print("Loading BiRefNet weights from HF Hub file:", weights_path)
70
+ state_dict = torch.load(weights_path, map_location="cpu")
71
  missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)
72
  print("[Info] Missing keys:", missing)
73
  print("[Info] Unexpected keys:", unexpected)
 
75
 
76
 
77
  ##########################################################
78
+ # 3. ์ด๋ฏธ์ง€ ํ›„์ฒ˜๋ฆฌ ํ•จ์ˆ˜๋“ค
79
  ##########################################################
80
 
81
  def refine_foreground(image, mask, r=90):
 
104
  F = np.clip(F, 0, 1)
105
  return F, blurred_B
106
 
 
107
  class ImagePreprocessor():
108
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
109
  self.transform_image = transforms.Compose([
 
117
 
118
 
119
  ##########################################################
120
+ # 4. ์˜ˆ์ œ ์„ค์ • ๋ฐ ๊ธฐํƒ€
121
  ##########################################################
122
 
123
  usage_to_weights_file = {
 
148
  "We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access."
149
  )
150
 
 
151
  ##########################################################
152
+ # 5. ์ถ”๋ก  ํ•จ์ˆ˜ (์ด๋ฏธ ๋กœ๋“œ๋œ birefnet ๋ชจ๋ธ ์‚ฌ์šฉ)
153
  ##########################################################
154
 
155
  @spaces.GPU
156
  def predict(images, resolution, weights_file):
157
+ # weights_file์€ ์—ฌ๊ธฐ์„œ๋Š” ๋ฌด์‹œํ•˜๊ณ , ์ด๋ฏธ ๋กœ๋“œ๋œ birefnet ์‚ฌ์šฉ
 
 
 
 
158
  assert images is not None, 'Images cannot be None.'
159
 
160
+ # Parse resolution
161
  try:
162
+ w, h = map(int, resolution.strip().split('x'))
163
+ w, h = int(w//32*32), int(h//32*32)
 
164
  except:
165
+ w, h = 1024, 1024
166
+ resolution_tuple = (w, h)
167
 
168
+ # ๋ฆฌ์ŠคํŠธ์ธ์ง€ ํ™•์ธ
169
  if isinstance(images, list):
170
  is_batch = True
171
  outputs, save_paths = [], []
 
176
  is_batch = False
177
 
178
  for idx, image_src in enumerate(images):
179
+ # ํŒŒ์ผ ๊ฒฝ๋กœ ํ˜น์€ URL
180
  if isinstance(image_src, str):
181
  if os.path.isfile(image_src):
182
  image_ori = Image.open(image_src)
183
  else:
184
  resp = requests.get(image_src)
185
  image_ori = Image.open(BytesIO(resp.content))
186
+ # numpy array โ†’ PIL
187
  elif isinstance(image_src, np.ndarray):
188
  image_ori = Image.fromarray(image_src)
189
  else:
190
  image_ori = image_src.convert('RGB')
191
 
192
+ # ์ „์ฒ˜๋ฆฌ
193
+ preproc = ImagePreprocessor(resolution_tuple)
194
+ image_proc = preproc.proc(image_ori.convert('RGB')).unsqueeze(0).to(device).half()
195
 
196
+ # ์ถ”๋ก 
197
  with torch.inference_mode():
 
198
  preds = birefnet(image_proc)[-1].sigmoid().cpu()
199
  pred_mask = preds[0].squeeze()
200
 
201
  # ํ›„์ฒ˜๋ฆฌ
202
  pred_pil = transforms.ToPILImage()(pred_mask)
203
+ image_masked = refine_foreground(image_ori, pred_pil)
204
+ image_masked.putalpha(pred_pil.resize(image_ori.size))
205
 
206
  if is_batch:
207
+ fbase = (os.path.splitext(os.path.basename(image_src))[0] if isinstance(image_src, str) else f"img_{idx}")
208
+ outpath = os.path.join(save_dir, f"{fbase}.png")
209
+ image_masked.save(outpath)
210
+ save_paths.append(outpath)
 
 
 
 
211
  outputs.append(image_masked)
212
  else:
213
  outputs = [image_masked, image_ori]
214
 
215
  torch.cuda.empty_cache()
216
 
 
217
  if is_batch:
218
+ zippath = os.path.join(save_dir, f"{save_dir}.zip")
219
+ with zipfile.ZipFile(zippath, 'w') as zipf:
220
  for fpath in save_paths:
221
  zipf.write(fpath, os.path.basename(fpath))
222
+ return outputs, zippath
223
  else:
224
  return outputs
225
 
 
226
  ##########################################################
227
+ # 6. Gradio UI
228
  ##########################################################
229
 
 
230
  css = """
231
  body {
232
  background: linear-gradient(135deg, #667eea, #764ba2);
 
284
  title_html = """
285
  <h1 align="center" style="margin-bottom: 0.2em;">BiRefNet Demo (No Tie-Weights Crash)</h1>
286
  <p align="center" style="font-size:1.1em; color:#555;">
287
+ Using <code>from_config()</code> + local <code>state_dict</code> or <code>hf_hub_download</code> to bypass tie_weights issues
288
  </p>
289
  """
290
 
291
  with gr.Blocks(css=css, title="BiRefNet Demo") as demo:
292
  gr.Markdown(title_html)
293
  with gr.Tabs():
 
294
  with gr.Tab("Image"):
295
  with gr.Row():
296
  with gr.Column(scale=1):
 
300
  predict_btn = gr.Button("Predict")
301
  with gr.Column(scale=2):
302
  output_slider = ImageSlider(label="Result", type="pil")
303
+ gr.Examples(examples=examples_image, inputs=[image_input, resolution_input, weights_radio], label="Examples")
 
 
 
 
304
 
 
305
  with gr.Tab("Text"):
306
  with gr.Row():
307
  with gr.Column(scale=1):
 
311
  predict_btn_text = gr.Button("Predict")
312
  with gr.Column(scale=2):
313
  output_slider_text = ImageSlider(label="Result", type="pil")
314
+ gr.Examples(examples=examples_text, inputs=[image_url, resolution_input_text, weights_radio_text], label="Examples")
 
 
 
 
315
 
 
316
  with gr.Tab("Batch"):
317
  with gr.Row():
318
  with gr.Column(scale=1):
319
+ file_input = gr.File(label="Upload Multiple Images", type="filepath", file_count="multiple")
 
 
 
 
320
  resolution_input_batch = gr.Textbox(lines=1, placeholder="e.g., 1024x1024", label="Resolution")
321
  weights_radio_batch = gr.Radio(list(usage_to_weights_file.keys()), value="General", label="Weights")
322
  predict_btn_batch = gr.Button("Predict")
323
  with gr.Column(scale=2):
324
  output_gallery = gr.Gallery(label="Results", scale=1)
325
  zip_output = gr.File(label="Zip Download")
326
+ gr.Examples(examples=examples_batch, inputs=[file_input, resolution_input_batch, weights_radio_batch], label="Examples")
 
 
 
 
327
 
328
  gr.Markdown("<p align='center'>Model by <a href='https://huggingface.co/ZhengPeng7/BiRefNet'>ZhengPeng7/BiRefNet</a></p>")
329
 
330
+ # ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
331
  predict_btn.click(
332
  fn=predict,
333
  inputs=[image_input, resolution_input, weights_radio],