File size: 2,912 Bytes
2bb76b3 a0c04c5 2bb76b3 9489bae 2bb76b3 9489bae 2bb76b3 0b76722 2bb76b3 a0c04c5 9489bae 26e1c44 7d9b835 9489bae 4c9b339 26e1c44 2bb76b3 4c9b339 c9b5e9c 4c9b339 2bb76b3 9489bae 2bb76b3 9489bae 2bb76b3 c9b5e9c 9489bae 4c9b339 9489bae 2bb76b3 9489bae 2bb76b3 30787ce 8b5cf1f 30787ce c01448f 9489bae c01448f 9489bae 30787ce 2bb76b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import os
import shutil
import subprocess
import sys
import tempfile
from shutil import copyfile

import gradio as gr
from huggingface_hub import hf_hub_download
from PIL import Image
# Working directories for uploads, model outputs, checkpoints and bundled
# samples; all are created up front so later copies/downloads have a parent.
UPLOAD_DIR = "./uploaded_images"
RESULTS_DIR = "./results"
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
SAMPLE_DIR = "./sample_images"
for _required_dir in (UPLOAD_DIR, RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# Model weights on the Hugging Face Hub. The "310" prefix matches the
# `--epoch 310` flag passed to test.py.
REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    cache_dir=CHECKPOINTS_DIR,
)

# With cache_dir, the hub keeps the file inside its own cache layout rather
# than at CHECKPOINTS_DIR/MODEL_FILE, so place a copy at that flat path
# (presumably where test.py looks for the weights -- confirm).
if not os.path.exists(expected_model_path):
    copyfile(model_path, expected_model_path)
def _find_latest_result(input_filename):
    """Return the newest ``<input_filename>*_fake.png`` under RESULTS_DIR, or None."""
    candidates = [
        os.path.join(root, name)
        for root, _, files in os.walk(RESULTS_DIR)
        for name in files
        if name.startswith(input_filename) and name.endswith("_fake.png")
    ]
    # The prefix test can also hit outputs of earlier runs (e.g. "img" vs
    # "img2_fake.png", or a previous upload with the same name); the file
    # test.py has just written is the most recently modified match.
    return max(candidates, key=os.path.getmtime, default=None)


def reflection_removal(input_image):
    """Run the reflection-removal model on a single uploaded image.

    The upload is copied into a fresh, private subdirectory of UPLOAD_DIR,
    which is passed to ``test.py`` as ``--dataroot`` together with
    ``--num_test 1``. The matching ``*_fake.png`` output is then looked up
    under RESULTS_DIR.

    Args:
        input_image: Path of the uploaded file, as supplied by
            ``gr.Image(type="filepath")``; ``None`` when nothing was uploaded.

    Returns:
        A one-element list holding the result as a ``PIL.Image.Image``, the
        value shown by the ``gr.Gallery`` output.

    Raises:
        gr.Error: If no image was given, the extension is unsupported, the
            model subprocess exits non-zero, or no output image is found.
    """
    if not input_image:
        raise gr.Error("Please upload an image first.")
    if not input_image.lower().endswith((".jpg", ".jpeg", ".png")):
        # A Gallery cannot display a plain message string, so report the
        # problem through Gradio's error display instead.
        raise gr.Error("File is not supported (only .jpg, .jpeg, .png).")

    input_basename = os.path.basename(input_image)
    input_filename = os.path.splitext(input_basename)[0]
    print(f"Processing {input_filename}...")

    # Give each request its own dataroot. With a shared UPLOAD_DIR every
    # earlier upload is also in the dataroot, and `--num_test 1` presumably
    # lets test.py handle only one of them -- not necessarily this one.
    with tempfile.TemporaryDirectory(dir=UPLOAD_DIR) as request_dir:
        shutil.copy(input_image, os.path.join(request_dir, input_basename))
        cmd = [
            # Run under the same interpreter/venv as this app.
            sys.executable or "python", "test.py",
            "--dataroot", request_dir,
            "--name", "SingleImageReflectionRemoval",
            "--model", "test", "--netG", "unet_256",
            "--direction", "AtoB", "--dataset_mode", "single",
            "--norm", "batch", "--epoch", "310",
            "--num_test", "1",
            "--gpu_ids", "-1",
        ]
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as err:
            raise gr.Error(
                f"Model inference failed (exit code {err.returncode})."
            ) from err

    result_path = _find_latest_result(input_filename)
    if result_path is None:
        raise gr.Error("No results found.")

    result = Image.open(result_path)
    # Read the pixels now so the file handle is released and a later run
    # overwriting the same output path cannot affect this response.
    result.load()
    return [result]
def use_sample_image(sample_image_name):
    """Map a bundled sample's file name to its path inside SAMPLE_DIR.

    Returns the joined path if that file exists; otherwise returns the
    message string "Sample image not found." (not an exception).
    NOTE(review): no caller of this function is visible in this file.
    """
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."
# Bundled example images, in a stable (sorted) order. The extension check is
# case-insensitive, matching the validation done in reflection_removal.
sample_images = sorted(
    name for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)

iface = gr.Interface(
    fn=reflection_removal,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)")
    ],
    outputs=gr.Gallery(label="Results after Reflection Removal"),
    # Build examples from sample_images/SAMPLE_DIR instead of re-listing a
    # hard-coded relative path; fall back to Gradio's default (None) when
    # no sample images are present.
    examples=[os.path.join(SAMPLE_DIR, img) for img in sample_images] or None,
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)

if __name__ == "__main__":
    iface.launch()
|