File size: 2,912 Bytes
2bb76b3
 
 
 
 
 
a0c04c5
2bb76b3
 
 
 
9489bae
2bb76b3
 
 
 
9489bae
2bb76b3
 
0b76722
2bb76b3
 
a0c04c5
 
 
 
9489bae
26e1c44
7d9b835
9489bae
4c9b339
26e1c44
2bb76b3
4c9b339
c9b5e9c
4c9b339
2bb76b3
9489bae
2bb76b3
 
 
 
 
9489bae
2bb76b3
 
 
 
 
 
c9b5e9c
 
 
 
 
9489bae
4c9b339
 
 
 
9489bae
 
 
 
 
 
2bb76b3
9489bae
 
 
 
2bb76b3
 
 
30787ce
8b5cf1f
30787ce
c01448f
9489bae
c01448f
9489bae
30787ce
 
2bb76b3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import shutil
import subprocess
import sys
import tempfile
from shutil import copyfile

import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image

# Working directories, all relative to the process's current directory.
UPLOAD_DIR = "./uploaded_images"                            # copies of uploaded images handed to test.py
RESULTS_DIR = "./results"                                   # searched for the generated *_fake.png files
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"  # Hub cache + copied model weights
SAMPLE_DIR = "./sample_images"                              # example images offered in the UI

# Create every directory up front; exist_ok makes repeated startups harmless.
for _working_dir in (UPLOAD_DIR, RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_working_dir, exist_ok=True)
del _working_dir

REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# Download (or reuse from the Hub cache kept under CHECKPOINTS_DIR) the model
# weights; hf_hub_download returns the path of the cached file.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    cache_dir=CHECKPOINTS_DIR,
)

# Also place a plain copy directly in CHECKPOINTS_DIR. NOTE(review): presumably
# test.py (run with --name SingleImageReflectionRemoval --epoch 310) loads the
# weights from this exact location — confirm against test.py.
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
if not os.path.exists(expected_model_path):
    # Copy only once; an existing file is left as-is.
    copyfile(model_path, expected_model_path)

def _find_fake_result(input_filename):
    """Return the path of ``<input_filename>_fake.png`` under RESULTS_DIR, or None.

    An exact filename match is used on purpose: a prefix match would also
    accept outputs of other images whose names merely start with
    ``input_filename`` (e.g. ``photo`` matching ``photo2_fake.png``).
    """
    target_name = f"{input_filename}_fake.png"
    for root, _, files in os.walk(RESULTS_DIR):
        if target_name in files:
            return os.path.join(root, target_name)
    return None


def reflection_removal(input_image):
    """Remove reflections from one uploaded image with the Pix2Pix model.

    The image is copied into a fresh per-request folder under UPLOAD_DIR,
    ``test.py`` is run on that folder, and the generated
    ``<stem>_fake.png`` is loaded from RESULTS_DIR.

    Args:
        input_image: Filesystem path of the uploaded image (from
            ``gr.Image(type="filepath")``); ``None`` when nothing was uploaded.

    Returns:
        A one-element list containing the result as a PIL image, suitable
        for the ``gr.Gallery`` output.

    Raises:
        gr.Error: if no image was given, the extension is unsupported, the
            model run fails, or no result file is produced. Gradio shows the
            message to the user (a plain string returned to a Gallery would
            instead be interpreted as an image path).
    """
    if not input_image:
        raise gr.Error("Please upload an image first.")
    if not input_image.lower().endswith((".jpg", ".jpeg", ".png")):
        raise gr.Error("File is not supported (only .jpg, .jpeg, .png).")

    image_name = os.path.basename(input_image)
    input_filename = os.path.splitext(image_name)[0]
    print(f"Processing {input_filename}...")

    # Each request gets its own dataroot. test.py runs with --num_test 1, so a
    # shared folder accumulating earlier uploads could make it process some
    # other image instead of this one. The folder is removed afterwards.
    with tempfile.TemporaryDirectory(dir=UPLOAD_DIR) as request_dir:
        shutil.copy(input_image, os.path.join(request_dir, image_name))

        cmd = [
            # Use the running interpreter rather than whatever "python" is on PATH.
            sys.executable, "test.py",
            "--dataroot", request_dir,
            "--name", "SingleImageReflectionRemoval",
            "--model", "test", "--netG", "unet_256",
            "--direction", "AtoB", "--dataset_mode", "single",
            "--norm", "batch", "--epoch", "310",
            "--num_test", "1",
            "--gpu_ids", "-1",
        ]
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as err:
            raise gr.Error(
                f"Reflection removal failed (exit code {err.returncode})."
            ) from err

    result_path = _find_fake_result(input_filename)
    if result_path is None:
        raise gr.Error("No results found.")

    # Load eagerly and close the file instead of leaving a lazy handle open.
    with Image.open(result_path) as result_image:
        return [result_image.copy()]


def use_sample_image(sample_image_name):
    """Resolve a bundled sample image by name.

    NOTE: not referenced elsewhere in this file.

    Args:
        sample_image_name: File name relative to SAMPLE_DIR.

    Returns:
        The joined path ``SAMPLE_DIR/<sample_image_name>`` when it exists,
        otherwise the message string "Sample image not found.".
    """
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."

# Names (not paths) of the JPG/PNG files in SAMPLE_DIR, in os.listdir order.
# The extension check is case-sensitive.
sample_images = list(
    filter(
        lambda name: name.endswith((".jpg", ".jpeg", ".png")),
        os.listdir(SAMPLE_DIR),
    )
)

# Gradio UI: one image in, a gallery with the reflection-free result out.
iface = gr.Interface(
    fn=reflection_removal,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
    ],
    outputs=gr.Gallery(label="Results after Reflection Removal"),
    # Build examples from SAMPLE_DIR / the already-scanned sample_images
    # instead of re-listing a hard-coded "sample_images" path; sort so the
    # examples appear in a stable order (os.listdir order is arbitrary).
    examples=[os.path.join(SAMPLE_DIR, name) for name in sorted(sample_images)],
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)

if __name__ == "__main__":
    # Start the web server with Gradio's default launch settings.
    iface.launch()