|
import os
import shutil
import subprocess
import sys
import tempfile
from shutil import copyfile

import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
|
|
|
# Working directories, all relative to the current working directory.
UPLOAD_DIR = "./uploaded_images"  # copies of user uploads handed to test.py
RESULTS_DIR = "./results"  # searched for test.py's generated images
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"  # model weights
SAMPLE_DIR = "./sample_images"  # example inputs offered in the UI

# Create every directory up front so later copy/listdir calls cannot fail on
# a missing folder; exist_ok makes this safe to re-run.
for _directory in (UPLOAD_DIR, RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory
|
|
|
# Generator weights published on the Hugging Face Hub.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# Fetch (or reuse) the weights through the HF cache rooted at CHECKPOINTS_DIR;
# the returned value is the path of the cached file.
model_path = hf_hub_download(
    filename=MODEL_FILE,
    repo_id=REPO_ID,
    cache_dir=CHECKPOINTS_DIR,
)

# Also place a plain copy directly in CHECKPOINTS_DIR.
# NOTE(review): presumably test.py loads <checkpoints>/<--name>/<--epoch>_net_G.pth,
# which is why the flat "310_net_G.pth" copy is needed — confirm against test.py.
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
if not os.path.exists(expected_model_path):
    copyfile(src=model_path, dst=expected_model_path)
|
|
|
def _run_inference(dataroot):
    """Run ``test.py`` on the images under *dataroot* (CPU, epoch-310 generator).

    Raises:
        subprocess.CalledProcessError: If ``test.py`` exits with a non-zero status.
    """
    cmd = [
        # Use the app's own interpreter so test.py sees the same installed packages.
        sys.executable, "test.py",
        "--dataroot", dataroot,
        "--name", "SingleImageReflectionRemoval",
        "--model", "test", "--netG", "unet_256",
        "--direction", "AtoB", "--dataset_mode", "single",
        "--norm", "batch", "--epoch", "310",
        "--num_test", "1",
        "--gpu_ids", "-1",
    ]
    subprocess.run(cmd, check=True)


def _take_result(stem):
    """Load ``<stem>_fake.png`` from RESULTS_DIR into memory, then delete this run's outputs.

    Returns:
        The generated PIL image, or None if no such file exists under RESULTS_DIR.
    """
    target = f"{stem}_fake.png"
    for root, _, files in os.walk(RESULTS_DIR):
        if target not in files:
            continue
        with Image.open(os.path.join(root, target)) as img:
            # copy() forces a full load, so the file can be closed (and removed) now.
            result = img.copy()
        # The stem is unique to this request, so every file with this prefix is ours.
        for name in files:
            if name.startswith(f"{stem}_"):
                try:
                    os.remove(os.path.join(root, name))
                except OSError:
                    pass  # best-effort cleanup; the result is already in memory
        return result
    return None


def reflection_removal(input_image):
    """Remove reflections from a single uploaded image.

    Args:
        input_image: Path of the uploaded image, as produced by
            ``gr.Image(type="filepath")``; None when nothing was uploaded.

    Returns:
        A one-element list containing the generated PIL image, for ``gr.Gallery``.

    Raises:
        gr.Error: If no image was given, the extension is not .jpg/.jpeg/.png,
            or no result image was produced. Errors are raised instead of being
            returned because the Gallery would treat a returned string as an image path.
        subprocess.CalledProcessError: If ``test.py`` fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    if not input_image.lower().endswith((".jpg", ".jpeg", ".png")):
        raise gr.Error("File is not supported (only .jpg, .jpeg, .png).")

    # Give each request its own dataroot holding exactly one image. With the shared
    # UPLOAD_DIR, --num_test 1 could pick an older upload instead of this one.
    # NOTE(review): assumes test.py's "single" dataset reads every image under
    # --dataroot and processes only the first --num_test of them — confirm.
    work_dir = tempfile.mkdtemp(dir=UPLOAD_DIR)
    # The temp dir's unique name doubles as the file stem, so this request's result
    # cannot collide with (or be confused for) the output of another upload.
    input_filename = os.path.basename(work_dir)
    extension = os.path.splitext(input_image)[1].lower()
    file_path = os.path.join(work_dir, input_filename + extension)

    print(f"Processing {os.path.basename(input_image)} (as {input_filename})...")
    try:
        shutil.copy(input_image, file_path)
        _run_inference(work_dir)
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)

    result = _take_result(input_filename)
    if result is None:
        raise gr.Error("No results found.")
    return [result]
|
|
|
|
|
def use_sample_image(sample_image_name):
    """Map a sample file name to its path inside SAMPLE_DIR.

    Returns the joined path when that file exists; otherwise returns the
    message string "Sample image not found." (not an exception).
    """
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."
|
|
|
# File names (not paths) of the example images found in SAMPLE_DIR.
# The extension check is case-insensitive, matching reflection_removal's check,
# and the list is sorted because os.listdir order is arbitrary.
sample_images = sorted(
    name
    for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)
|
|
|
# Gradio UI: one image upload in, a gallery of reflection-free results out.
iface = gr.Interface(
    fn=reflection_removal,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
    ],
    outputs=gr.Gallery(label="Results after Reflection Removal"),
    # Reuse the pre-filtered sample_images list and the SAMPLE_DIR constant
    # rather than re-listing a second, hard-coded folder. Sorted for a stable
    # order; None (Gradio's "no examples" value) when the folder has no images.
    examples=[os.path.join(SAMPLE_DIR, name) for name in sorted(sample_images)] or None,
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)
|
|
|
def _launch_app():
    """Start the Gradio server for the reflection-removal interface."""
    iface.launch()


# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    _launch_app()
|
|