ghlee94 commited on
Commit
048af86
1 Parent(s): a4bb9b1

Debug app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -1115,10 +1115,10 @@ def predict(img):
1115
  vflip_tta = VerticalFlip()
1116
 
1117
  img_name = img.name
1118
- # if img_name.endswith('.tif') or img_name.endswith('.tiff'):
1119
- # img_data = tif.imread(img_name)
1120
- # else:
1121
- # img_data = io.imread(img_name)
1122
 
1123
  img_data = pred_transforms(img_name)
1124
  img_data = img_data.to(device)
@@ -1128,7 +1128,7 @@ def predict(img):
1128
  overlap = 0.5
1129
  else:
1130
  overlap = 0.6
1131
-
1132
  with torch.no_grad():
1133
  img0 = img_data
1134
  outputs0 = sliding_window_inference(
@@ -1231,14 +1231,15 @@ def predict(img):
1231
  outputs = outputs0
1232
 
1233
  pred_mask = post_process(outputs.squeeze(0).cpu().numpy(), device)
1234
-
1235
  file_path = os.path.join(
1236
  os.getcwd(), img_name.split(".")[0] + "_label.tiff"
1237
  )
1238
 
1239
  tif.imwrite(file_path, pred_mask, compression="zlib")
 
1240
  # return img_data, seg_rgb, join(os.getcwd(), 'segmentation.tiff')
1241
- return img_data, pred_mask, file_path
1242
 
1243
  demo = gr.Interface(
1244
  predict,
 
1115
  vflip_tta = VerticalFlip()
1116
 
1117
  img_name = img.name
1118
+ if img_name.endswith('.tif') or img_name.endswith('.tiff'):
1119
+ origin_img = tif.imread(img_name)
1120
+ else:
1121
+ origin_img = io.imread(img_name)
1122
 
1123
  img_data = pred_transforms(img_name)
1124
  img_data = img_data.to(device)
 
1128
  overlap = 0.5
1129
  else:
1130
  overlap = 0.6
1131
+ print("start")
1132
  with torch.no_grad():
1133
  img0 = img_data
1134
  outputs0 = sliding_window_inference(
 
1231
  outputs = outputs0
1232
 
1233
  pred_mask = post_process(outputs.squeeze(0).cpu().numpy(), device)
1234
+ print("prediction end & file write")
1235
  file_path = os.path.join(
1236
  os.getcwd(), img_name.split(".")[0] + "_label.tiff"
1237
  )
1238
 
1239
  tif.imwrite(file_path, pred_mask, compression="zlib")
1240
+ print(np.max(pred_mask))
1241
  # return img_data, seg_rgb, join(os.getcwd(), 'segmentation.tiff')
1242
+ return origin_img, pred_mask, file_path
1243
 
1244
  demo = gr.Interface(
1245
  predict,