aryrk
[debug]
c9b5e9c
raw
history blame
2.91 kB
import os
import shutil
import subprocess
import sys
from shutil import copyfile

import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
# Working directories used by the app; every one is created up front so the
# later copy / subprocess / listing steps never hit a missing folder.
UPLOAD_DIR = "./uploaded_images"
RESULTS_DIR = "./results"
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
SAMPLE_DIR = "./sample_images"
for _directory in (UPLOAD_DIR, RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory  # keep the module namespace free of the loop variable

# Pretrained generator weights hosted on the Hugging Face Hub.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    cache_dir=CHECKPOINTS_DIR,
)

# The downloaded file lives wherever hf_hub_download placed it under
# cache_dir; copy it once to the flat <CHECKPOINTS_DIR>/<MODEL_FILE> path.
# NOTE(review): presumably this flat path is where test.py looks for the
# epoch-310 generator weights — confirm against test.py.
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
if not os.path.exists(expected_model_path):
    copyfile(model_path, expected_model_path)
def _reset_upload_dir():
    """Empty UPLOAD_DIR so it holds only the image of the current request.

    test.py is pointed at the whole UPLOAD_DIR with ``--num_test 1``, so files
    left over from earlier requests could be processed instead of the new
    upload. NOTE(review): assumes test.py takes its single image from
    everything in --dataroot; confirm against its dataset loader.
    """
    shutil.rmtree(UPLOAD_DIR, ignore_errors=True)
    os.makedirs(UPLOAD_DIR, exist_ok=True)


def _find_result(input_filename):
    """Return the path of ``<input_filename>_fake.png`` under RESULTS_DIR, or None."""
    # Exact match: a prefix match would also accept results of other images
    # whose names merely start with input_filename (e.g. "img" vs "img2").
    target = f"{input_filename}_fake.png"
    for root, _, files in os.walk(RESULTS_DIR):
        if target in files:
            return os.path.join(root, target)
    return None


def reflection_removal(input_image):
    """Run the Pix2Pix reflection-removal model on one uploaded image.

    Args:
        input_image: Filesystem path of the uploaded image (from
            ``gr.Image(type="filepath")``), or None when nothing was uploaded.

    Returns:
        A one-element list holding the generated PIL image, suitable for the
        ``gr.Gallery`` output.

    Raises:
        gr.Error: If no image was given, its extension is not .jpg/.jpeg/.png,
            test.py exits with a non-zero status, or no result file is found.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if not input_image.lower().endswith((".jpg", ".jpeg", ".png")):
        # Returning a plain string to gr.Gallery would make it treat the
        # message as an image path; gr.Error shows it to the user instead.
        raise gr.Error("File is not supported (only .jpg, .jpeg, .png).")

    _reset_upload_dir()
    file_path = os.path.join(UPLOAD_DIR, os.path.basename(input_image))
    shutil.copy(input_image, file_path)

    input_filename = os.path.splitext(os.path.basename(file_path))[0]
    print(f"Processing {input_filename}...")

    cmd = [
        # Run test.py with the interpreter hosting this app rather than
        # whatever "python" happens to resolve to on PATH.
        sys.executable, "test.py",
        "--dataroot", UPLOAD_DIR,
        "--name", "SingleImageReflectionRemoval",
        "--model", "test", "--netG", "unet_256",
        "--direction", "AtoB", "--dataset_mode", "single",
        "--norm", "batch", "--epoch", "310",
        "--num_test", "1",
        "--gpu_ids", "-1",  # CPU only
    ]
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as err:
        raise gr.Error(
            f"Reflection removal failed (test.py exited with code {err.returncode})."
        ) from err

    result_path = _find_result(input_filename)
    if result_path is None:
        raise gr.Error("No results found.")

    # Image.open reads lazily and keeps the file open; load the pixels and
    # return a detached copy so the handle is closed right away.
    with Image.open(result_path) as img:
        img.load()
        return [img.copy()]
def use_sample_image(sample_image_name):
    """Resolve a bundled sample image name to its path inside SAMPLE_DIR.

    Returns the joined path when it exists on disk; otherwise returns the
    message string "Sample image not found." rather than raising.
    """
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."
# Bundled sample images offered as clickable examples. Sorted so the order is
# stable between runs (os.listdir returns entries in arbitrary order), and the
# extension check is case-insensitive like the one in reflection_removal.
sample_images = sorted(
    name
    for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)

iface = gr.Interface(
    fn=reflection_removal,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
    ],
    outputs=gr.Gallery(label="Results after Reflection Removal"),
    # Reuse the list above and SAMPLE_DIR rather than re-listing a hard-coded
    # "sample_images" folder; None disables the examples section when there
    # are no samples.
    examples=[os.path.join(SAMPLE_DIR, name) for name in sample_images] or None,
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)

if __name__ == "__main__":
    # Start the Gradio web server only when executed as a script.
    iface.launch()