SpecSeg / app.py
atifanwerPK's picture
uploaded app.py (no model)
2d4a607 verified
import functools
import os

import gradio as gr
import numpy as np
# import requests
# import torch
import tensorflow as tf
from PIL import Image

from SpecSeg import SpecSeg
image_size = 128

# Resolve the checkpoint next to this file (as the example images are) so
# loading does not depend on the current working directory.
_WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), "models", "specsegv3_chkpt.h5")

# Score at or above which a pixel is reported as specular in the output mask.
# NOTE(review): assumes SpecSeg emits per-pixel probabilities in [0, 1]
# (e.g. a sigmoid head) — confirm against the SpecSeg model definition.
MASK_THRESHOLD = 0.5


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build SpecSeg and load its checkpoint once; later calls reuse the model.

    Rebuilding the network and re-reading the weights file on every request
    dominates per-request latency, so the constructed model is cached.
    """
    model = SpecSeg(image_size, image_size, 1)
    model.load_weights(_WEIGHTS_PATH)
    return model


def _scores_to_mask(scores, threshold=MASK_THRESHOLD):
    """Binarize a 2-D score map into a uint8 mask: 255 where score >= threshold, else 0.

    Thresholding happens before the uint8 cast; casting first would truncate
    every fractional score in [0, 1) to 0 and erase the mask.
    """
    return np.where(np.asarray(scores) >= threshold, 255, 0).astype(np.uint8)


def predict(inp):
    """Segment specular highlights in an image and return the mask as a PIL image.

    Args:
        inp: The input image. The Gradio ``Image(type="pil")`` component passes a
            PIL image; any value accepted by ``tf.keras.utils.img_to_array``
            with 3 channels also works.

    Returns:
        A ``PIL.Image`` with the same height and width as ``inp``. Pixels are
        255 where the model's score is at or above ``MASK_THRESHOLD`` and 0
        elsewhere.
    """
    if isinstance(inp, Image.Image):
        # rgb_to_grayscale needs exactly 3 channels; RGBA/L/P images would fail.
        inp = inp.convert("RGB")
    inp = tf.keras.utils.img_to_array(inp)
    h_orig, w_orig, _ = np.shape(inp)

    # The network is built for a fixed image_size x image_size, 1-channel input.
    gray = tf.image.rgb_to_grayscale(inp)
    gray = tf.image.resize(gray, (image_size, image_size))
    batch = np.expand_dims(gray, 0)

    prediction = _load_model().predict(batch)

    # Resize back to the original resolution for display.
    prediction = tf.image.resize(prediction, (h_orig, w_orig))
    # Drop only the batch axis and a singleton channel axis; a blanket squeeze
    # would also collapse 1-pixel-wide/tall images.
    scores = np.asarray(prediction)[0]
    if scores.ndim == 3 and scores.shape[-1] == 1:
        scores = scores[..., 0]

    # Convert back to PIL for Gradio.
    return Image.fromarray(_scores_to_mask(scores))
# Example inputs shipped with the Space, resolved relative to this file.
_EXAMPLE_DIR = os.path.join(os.path.dirname(__file__), "images")

# The page-level css is set on the top-level Blocks, which owns the app's
# styling; css given to a nested Interface may not be applied once that
# Interface is rendered inside another Blocks app.
with gr.Blocks(css=".output-image, .input-image, .image-preview {height: 600px !important}") as application:
    with gr.Tab("SpecSeg Demo"):
        intro = """
# SpecSeg Network for Specular Highlight Detection and Segmentation in Real-World Images
## Introduction
This repository is the implementation of our paper 'SpecSeg Network for Specular Highlight Detection and Segmentation in Real-World Images'. The developed network and pretrained weights can be used for network training and testing. Please cite the paper if you use them and find them useful.
## Citation
```
@Article{s22176552,
AUTHOR = {Anwer, Atif and Ainouz, Samia and Saad, Mohamad Naufal Mohamad and Ali, Syed Saad Azhar and Meriaudeau, Fabrice},
TITLE = {SpecSeg Network for Specular Highlight Detection and Segmentation in Real-World Images},
JOURNAL = {Sensors},
VOLUME = {22},
YEAR = {2022},
NUMBER = {17},
ARTICLE-NUMBER = {6552},
URL = {https://www.mdpi.com/1424-8220/22/17/6552},
ISSN = {1424-8220},
DOI = {10.3390/s22176552}
}
```
"""
        gr.Markdown(intro)
        SpecSeg_demo = gr.Interface(
            fn=predict,
            inputs=gr.Image(type="pil", label="Input-image"),
            outputs=gr.Image(type="pil", label="Specular-Mask"),
            examples=[
                os.path.join(_EXAMPLE_DIR, f"img{i:02d}.png") for i in range(1, 6)
            ],
        )
        # An Interface is a self-contained Blocks constructed unrendered; it is
        # not placed into the enclosing layout until render() mounts it here.
        SpecSeg_demo.render()
if __name__ == "__main__":
    # Log framework versions at startup to help diagnose deployment mismatches.
    # Some TF/Keras builds do not expose tf.keras.__version__, so fall back.
    keras_version = getattr(tf.keras, "__version__", "unknown")
    print(
        "------------------------------------",
        f"\nTensorflow version: {tf.__version__}",
        f"\nKeras version: {keras_version}",
        "\n------------------------------------",
    )
    application.launch()