# enhancee / app.py
# Last change by ssaad5678: "Update app.py" (commit ed95bf3, verified)
import numpy as np
import matplotlib.pyplot as plt
import math
from PIL import Image
import gradio as gr
from utils import *
from models import *
# Generator input tile shape (rows, cols, channels): process_image feeds the
# generator single-channel ('L' mode) 256x256 tiles.
# NOTE(review): the `from models import *` above may also export a name
# `input_size`; this assignment overrides it — keep the two in agreement.
input_size = (256, 256, 1)
# Per-task generator configuration: task name -> (biggest_layer, weights file).
_TASK_WEIGHTS = {
    'binarize': (1024, "binarization_generator_weights.h5"),
    'deblur': (1024, "deblur_weights.h5"),
    'unwatermark': (512, "watermark_rem_weights.h5"),
}

# Generators already built in this process, keyed by task name, so the model
# is constructed and its weights read from disk only once per task.
_GENERATOR_CACHE = {}


def _get_generator(task):
    """Return the (cached) generator for *task*.

    Raises:
        ValueError: if *task* is not one of the keys of ``_TASK_WEIGHTS``.
    """
    generator = _GENERATOR_CACHE.get(task)
    if generator is None:
        try:
            biggest_layer, weights_path = _TASK_WEIGHTS[task]
        except KeyError:
            raise ValueError("Wrong task, please specify a correct task!") from None
        generator = generator_model(biggest_layer=biggest_layer)
        generator.load_weights(weights_path)
        _GENERATOR_CACHE[task] = generator
    return generator


def process_image(task, deg_image):
    """Enhance a degraded document image with the generator for *task*.

    The image is converted to 8-bit grayscale and scaled to [0, 1]. It is then
    padded with white (1.0) up to whole tiles and run through the generator one
    tile at a time. Finally it is re-assembled and cropped back to the input
    size.

    Args:
        task: one of 'binarize', 'deblur' or 'unwatermark'.
        deg_image: a PIL image (any mode; it is converted to 'L').

    Returns:
        A 2-D numpy array with the same height and width as *deg_image*.
        For 'binarize' it holds 0/1 integers (output > 0.95 becomes 1).
        For the other tasks it is the generator output as produced.

    Raises:
        ValueError: if *task* is unknown or no image was supplied.
    """
    generator = _get_generator(task)
    if deg_image is None:
        raise ValueError("Please upload an image to process.")

    # A PIL 'L' image becomes uint8 in 0..255. The white padding value (1.0)
    # and the 0.95 binarization threshold below both assume a [0, 1] scale,
    # so normalise here.
    test_image = np.asarray(deg_image.convert('L'), dtype=np.float64) / 255.0
    rows, cols = test_image.shape
    tile_rows, tile_cols = input_size[0], input_size[1]

    # Round up to whole tiles (ceiling division). A dimension that is already
    # a multiple of the tile size no longer gets an extra all-padding tile.
    h = -(-rows // tile_rows) * tile_rows
    w = -(-cols // tile_cols) * tile_cols
    test_padding = np.ones((h, w))
    test_padding[:rows, :cols] = test_image

    # NOTE(review): split2/merge_image2 come from utils. They are assumed to
    # cut and join tiles of the same size as input_size — confirm before
    # changing the tile shape.
    tiles = split2(test_padding.reshape(1, h, w, 1), 1, h, w)
    predicted_image = np.array([
        generator.predict(tile.reshape(1, tile_rows, tile_cols, 1))
        for tile in tiles
    ])
    predicted_image = merge_image2(predicted_image, h, w)
    # Drop the padding, then flatten any trailing channel axis to 2-D.
    predicted_image = predicted_image[:rows, :cols].reshape(rows, cols)

    if task == 'binarize':
        bin_thresh = 0.95
        predicted_image = (predicted_image > bin_thresh) * 1
    return predicted_image
def image_processing(task, deg_image, *, show=False):
    """Gradio callback: run *task* on *deg_image* and return the result array.

    Args:
        task: task name, forwarded to process_image.
        deg_image: PIL image, forwarded to process_image.
        show: if True, also display the result with matplotlib (for local
            debugging). It is off by default for two reasons. The result is
            already returned to Gradio. And plt.show() inside a request
            handler either blocks until the window is closed (GUI backend) or
            leaves the figure open on the global pyplot state (headless
            backend).

    Returns:
        The array produced by process_image.
    """
    processed_image = process_image(task, deg_image)
    if show:
        # Use a dedicated figure and close it, so repeated calls do not keep
        # drawing onto one shared, never-released pyplot figure.
        fig, ax = plt.subplots()
        ax.imshow(processed_image, cmap='gray')
        ax.axis('off')
        plt.show()
        plt.close(fig)
    return processed_image
# Interface
# gr.inputs.* / gr.outputs.* were removed in Gradio 4; the top-level component
# classes work in both Gradio 3.x and 4.x.
inputs = [
    gr.Radio(['binarize', 'deblur', 'unwatermark'], label="Select Task"),
    gr.Image(type='pil', label="Input Image"),
]
outputs = gr.Image(type='numpy', label="Processed Image")
title = "Image Processing"
description = "Choose a task and upload an image to process."
# Bound to a module-level name so tools that import this file (e.g.
# `gradio app.py` reload mode) can find the app without launching it twice.
demo = gr.Interface(image_processing, inputs, outputs, title=title, description=description)

if __name__ == "__main__":
    demo.launch()