aryrk commited on
Commit
0eb4294
·
1 Parent(s): c9b5e9c
Files changed (4) hide show
  1. .gitignore +0 -1
  2. app.py +4 -1
  3. sample_images/1.png +0 -0
  4. sample_images/2.png +0 -0
.gitignore CHANGED
@@ -5,7 +5,6 @@ checkpoints/
5
  results/
6
  build/
7
  dist/
8
- *.png
9
  torch.egg-info/
10
  */**/__pycache__
11
  torch/version.py
 
5
  results/
6
  build/
7
  dist/
 
8
  torch.egg-info/
9
  */**/__pycache__
10
  torch/version.py
app.py CHANGED
@@ -23,6 +23,9 @@ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHE
23
  expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
24
  if not os.path.exists(expected_model_path):
25
  copyfile(model_path, expected_model_path)
 
 
 
26
 
27
  def reflection_removal(input_image):
28
  if not input_image.lower().endswith((".jpg", ".jpeg", ".png")):
@@ -41,7 +44,7 @@ def reflection_removal(input_image):
41
  "--model", "test", "--netG", "unet_256",
42
  "--direction", "AtoB", "--dataset_mode", "single",
43
  "--norm", "batch", "--epoch", "310",
44
- "--num_test", "1",
45
  "--gpu_ids", "-1"
46
  ]
47
  subprocess.run(cmd, check=True)
 
23
  expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
24
  if not os.path.exists(expected_model_path):
25
  copyfile(model_path, expected_model_path)
26
+
27
+ def count_files(directory):
28
+ return sum([len(files) for _, _, files in os.walk(directory)])
29
 
30
  def reflection_removal(input_image):
31
  if not input_image.lower().endswith((".jpg", ".jpeg", ".png")):
 
44
  "--model", "test", "--netG", "unet_256",
45
  "--direction", "AtoB", "--dataset_mode", "single",
46
  "--norm", "batch", "--epoch", "310",
47
+ "--num_test", count_files(UPLOAD_DIR),
48
  "--gpu_ids", "-1"
49
  ]
50
  subprocess.run(cmd, check=True)
sample_images/1.png ADDED
sample_images/2.png ADDED