ZhengPeng7 commited on
Commit
85f9120
·
verified ·
1 Parent(s): 5023a18

Fix a bug in loading different types of the input in tab_batch.

Browse files
Files changed (1) hide show
  1. app.py +7 -29
app.py CHANGED
@@ -57,32 +57,6 @@ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7',
57
  birefnet.to(device)
58
  birefnet.eval()
59
 
60
- # for idx, image_path in enumerate(images):
61
- # im = load_img(image_path, output_type="pil")
62
- # if im is None:
63
- # continue
64
-
65
- # im = im.convert("RGB")
66
- # image_size = im.size
67
- # input_images = transform_image(im).unsqueeze(0).to("cpu")
68
-
69
- # with torch.no_grad():
70
- # preds = birefnet(input_images)[-1].sigmoid().cpu()
71
- # pred = preds[0].squeeze()
72
- # pred_pil = transforms.ToPILImage()(pred)
73
- # mask = pred_pil.resize(image_size)
74
-
75
- # im.putalpha(mask)
76
- # output_file_path = os.path.join(save_dir, f"output_image_batch_{idx + 1}.png")
77
- # im.save(output_file_path)
78
- # output_paths.append(output_file_path)
79
-
80
- # zip_file_path = os.path.join(save_dir, "processed_images.zip")
81
- # with zipfile.ZipFile(zip_file_path, 'w') as zipf:
82
- # for file in output_paths:
83
- # zipf.write(file, os.path.basename(file))
84
-
85
- # return output_paths, zip_file_path
86
 
87
  @spaces.GPU
88
  def predict(images, resolution, weights_file):
@@ -115,9 +89,13 @@ def predict(images, resolution, weights_file):
115
 
116
  for idx_image, image_src in enumerate(images):
117
  if isinstance(image_src, str):
118
- response = requests.get(image_src)
119
- image_data = BytesIO(response.content)
120
- image = np.array(Image.open(image_data))
 
 
 
 
121
  else:
122
  image = image_src
123
 
 
57
  birefnet.to(device)
58
  birefnet.eval()
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  @spaces.GPU
62
  def predict(images, resolution, weights_file):
 
89
 
90
  for idx_image, image_src in enumerate(images):
91
  if isinstance(image_src, str):
92
+ if os.path.isfile(image_src):
93
+ image = np.array(Image.open(image_src))
94
+ else:
95
+ image = np.array(Image.open(image_src))
96
+ response = requests.get(image_src)
97
+ image_data = BytesIO(response.content)
98
+ image = np.array(Image.open(image_data))
99
  else:
100
  image = image_src
101