# Spaces: Sleeping
import functools
import os

import gradio as gr
import numpy as np
# import requests
# import torch
import tensorflow as tf
from PIL import Image

from SpecSeg import SpecSeg
# Side length (pixels) of the square input the SpecSeg network is built for.
image_size = 128

# Pretrained checkpoint, resolved next to this script (like the example images)
# so loading does not depend on the process working directory.
_WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), "models", "specsegv3_chkpt.h5")


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the SpecSeg network and load its pretrained weights once per process."""
    model = SpecSeg(image_size, image_size, 1)
    model.load_weights(_WEIGHTS_PATH)
    return model


def predict(inp):
    """Predict a specular-highlight mask for one image.

    Args:
        inp: the image delivered by the Gradio ``gr.Image(type="pil")`` input,
            i.e. a PIL image (any array accepted by ``img_to_array`` also works).

    Returns:
        A ``PIL.Image`` with the same height and width as ``inp`` holding the
        network output scaled to 0-255.
    """
    # rgb_to_grayscale requires exactly 3 channels; normalise RGBA/L/P uploads.
    if isinstance(inp, Image.Image):
        inp = inp.convert("RGB")
    arr = tf.keras.utils.img_to_array(inp)
    h_orig, w_orig = arr.shape[:2]

    # The network consumes a single-channel image_size x image_size batch.
    x = tf.image.rgb_to_grayscale(arr)
    x = tf.image.resize(x, (image_size, image_size))
    x = np.expand_dims(x, 0)  # add batch dimension

    prediction = _load_model().predict(x)

    # Resize back to the original resolution for display.
    prediction = tf.image.resize(prediction, (h_orig, w_orig))
    mask = np.asarray(prediction)[0]  # drop the batch axis only
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]  # drop a singleton channel axis

    # Scale BEFORE casting: casting first truncates every value in [0, 1) to 0.
    # NOTE(review): assumes per-pixel scores in [0, 1] (the original `* 255`
    # implies this); clip guards against uint8 wrap-around. TODO confirm.
    mask = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(mask)
# Markdown rendered above the demo: paper title, introduction and BibTeX citation.
# Kept at column 0 so indented lines are not rendered as markdown code blocks.
intro = """
# SpecSeg Network for Specular Highlight Detection and Segmentation in Real-World Images
## Introduction
This repository is the implementation of our paper 'SpecSeg Network for Specular Highlight Detection and Segmentation in Real-World Images'. The developed network and pretrained weights can be used for network training and testing. Please cite the paper if you use them and find them useful.
## Citation
```
@Article{s22176552,
AUTHOR = {Anwer, Atif and Ainouz, Samia and Saad, Mohamad Naufal Mohamad and Ali, Syed Saad Azhar and Meriaudeau, Fabrice},
TITLE = {SpecSeg Network for Specular Highlight Detection and Segmentation in Real-World Images},
JOURNAL = {Sensors},
VOLUME = {22},
YEAR = {2022},
NUMBER = {17},
ARTICLE-NUMBER = {6552},
URL = {https://www.mdpi.com/1424-8220/22/17/6552},
ISSN = {1424-8220},
DOI = {10.3390/s22176552}
}
```
"""

# Example inputs shipped next to this script: images/img01.png ... images/img05.png.
_EXAMPLE_IMAGES = [
    os.path.join(os.path.dirname(__file__), f"images/img{i:02d}.png")
    for i in range(1, 6)
]

with gr.Blocks() as application:
    with gr.Tab("SpecSeg Demo"):
        gr.Markdown(intro)
        # Image in -> predicted specular mask out, via predict().
        SpecSeg_demo = gr.Interface(
            fn=predict,
            inputs=gr.Image(type="pil", label="Input-image"),
            outputs=gr.Image(type="pil", label="Specular-Mask"),
            css=".output-image, .input-image, .image-preview {height: 600px !important}",
            examples=_EXAMPLE_IMAGES,
        )
if __name__ == "__main__":
    # Log the framework versions at startup (helps diagnose environment
    # mismatches), then start the Gradio server.
    rule = "------------------------------------"
    print(rule)
    print("Tensorflow version:", tf.__version__)
    print("Keras Version", tf.keras.__version__)
    print(rule)
    application.launch()