Spaces:
Runtime error
Runtime error
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
import gradio as gr | |
# Define the EDSR custom model
class EDSRModel(tf.keras.Model):
    """Keras model with a standard training step and a pixel-domain inference step.

    ``train_step`` follows the default Keras training loop (loss and metrics are
    the ones configured in ``compile()``). ``predict_step`` turns raw pixel input
    into a clipped, rounded uint8 image.
    """

    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: ``(x, y)`` pair as produced by the data passed to ``fit()``;
                the unpacking below requires exactly two elements.

        Returns:
            Dict mapping each metric name (including the loss tracker) to its
            current value.
        """
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # forward pass
            # The loss function is configured in compile(); regularization
            # losses collected on the model are added in.
            # NOTE(review): compiled_loss / compiled_metrics are tf.keras
            # (Keras 2) APIs; Keras 3 deprecates them in favour of
            # compute_loss() — confirm the pinned TF version before upgrading.
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (this includes the metric that tracks the loss).
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, x):
        """Super-resolve pixel input and return it as uint8.

        Args:
            x: A single image of shape (H, W, C), or an already-batched tensor of
                shape (N, H, W, C) such as Keras' ``predict()`` feeds to this step.

        Returns:
            uint8 tensor with values in [0, 255]. It has no batch axis for
            single-image input and keeps the batch axis for batched input.
        """
        x = tf.cast(x, tf.float32)
        # Only a rank-4 tensor is already batched. Anything else, including
        # unknown rank, gets a batch axis exactly as before; previously a batched
        # input became rank 5 and broke the model call.
        already_batched = x.shape.rank == 4
        if not already_batched:
            x = tf.expand_dims(x, axis=0)
        super_resolution_img = self(x, training=False)
        # Clamp to the valid pixel range, round, then convert to uint8.
        super_resolution_img = tf.clip_by_value(super_resolution_img, 0, 255)
        super_resolution_img = tf.round(super_resolution_img)
        super_resolution_img = tf.cast(super_resolution_img, tf.uint8)
        if not already_batched:
            # Drop the batch axis that was added above.
            super_resolution_img = tf.squeeze(super_resolution_img, axis=0)
        return super_resolution_img
# Residual block
def ResBlock(inputs, filters=None):
    """Apply one residual block: conv-ReLU, conv, then an identity skip-add.

    Args:
        inputs: Channels-last feature-map tensor.
        filters: Number of filters for both 3x3 convs. Defaults to the channel
            count of ``inputs``. This keeps the skip ``Add`` shape-compatible for
            any width; the previous hard-coded 64 failed for non-64-channel
            inputs. Falls back to 64 if the channel dimension is unknown.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    if filters is None:
        channels = inputs.shape[-1]
        filters = channels if channels is not None else 64
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    return layers.Add()([inputs, x])
# Upsampling block
def Upsampling(inputs, factor=2, **kwargs):
    """Upsample a feature map with two sub-pixel (depth-to-space) stages.

    Each stage is a 3x3 same-padded conv producing ``64 * factor**2`` channels,
    followed by ``tf.nn.depth_to_space`` with ``block_size=factor``, which trades
    those channels for spatial resolution. Two stages run back to back, so height
    and width grow by ``factor ** 2`` overall (4x with the default), and the result
    has 64 channels.

    Args:
        inputs: Feature-map tensor.
        factor: Per-stage depth_to_space block size.
        **kwargs: Extra keyword arguments forwarded to both Conv2D layers.

    Returns:
        The upsampled tensor.
    """
    stage_channels = 64 * (factor ** 2)
    upsampled = inputs
    for _stage in range(2):
        upsampled = layers.Conv2D(stage_channels, 3, padding="same", **kwargs)(upsampled)
        upsampled = tf.nn.depth_to_space(upsampled, block_size=factor)
    return upsampled
def make_model(num_filters, num_of_residual_blocks):
    """Build the EDSR super-resolution network.

    Args:
        num_filters: Channel width of the head conv and of the residual trunk.
        num_of_residual_blocks: Number of ResBlock layers stacked in the trunk.

    Returns:
        An ``EDSRModel`` mapping an RGB image of any spatial size to an
        upsampled RGB image, with output rescaled by 255.
    """
    # Height and width are left as None so any image size is accepted.
    inputs = layers.Input(shape=(None, None, 3))
    # Scale pixel values by 1/255 before the first conv.
    normalized = layers.Rescaling(scale=1.0 / 255)(inputs)
    head = layers.Conv2D(num_filters, 3, padding="same")(normalized)

    trunk = head
    for _ in range(num_of_residual_blocks):
        trunk = ResBlock(trunk)
    trunk = layers.Conv2D(num_filters, 3, padding="same")(trunk)

    # Global skip connection around the whole residual trunk.
    features = layers.Add()([head, trunk])
    upscaled = Upsampling(features)
    rgb = layers.Conv2D(3, 3, padding="same")(upscaled)
    # Undo the input scaling so outputs are back on the pixel scale.
    outputs = layers.Rescaling(scale=255)(rgb)
    return EDSRModel(inputs, outputs)
# Define the PSNR metric
def PSNR(super_resolution, high_resolution):
    """Compute the peak signal-to-noise ratio (PSNR), a measure of image quality.

    Pixel values are treated as spanning [0, 255] (``max_val=255``).

    Args:
        super_resolution: Predicted image(s), shape (H, W, C) or (N, H, W, C).
        high_resolution: Ground-truth image(s) with the same shape and dtype.

    Returns:
        Scalar tensor: the PSNR of a single image, or the mean PSNR over a batch.
    """
    # tf.image.psnr yields one value per image, or a 0-d tensor for unbatched
    # input. Averaging reports the whole batch rather than only element [0], and
    # it also works for a single image, where indexing [0] on a 0-d tensor raises.
    per_image_psnr = tf.image.psnr(high_resolution, super_resolution, max_val=255)
    return tf.reduce_mean(per_image_psnr)
# Load the trained network. EDSRModel is registered through a custom-object
# scope, and PSNR is passed directly to load_model. Presumably PSNR was saved as
# a compiled metric — confirm against the training script.
custom_objects = dict(EDSRModel=EDSRModel)
with keras.utils.custom_object_scope(custom_objects):
    new_model = keras.models.load_model(
        "./trained.h5",
        custom_objects=dict(PSNR=PSNR),
    )
def process_image(img):
    """Crop a random low-resolution patch from ``img`` and super-resolve it.

    Args:
        img: Image array from the Gradio input, presumably (H, W, 3) uint8 RGB
            (Gradio's numpy default) — confirm. Grayscale (H, W) or (H, W, 1) and
            RGBA (H, W, 4) inputs are normalized to RGB.

    Returns:
        Tuple ``(lowres, preds)`` of uint8 numpy arrays: the cropped input patch
        followed by the model's super-resolved output.
    """
    # Upper bound on the patch side; smaller images are cropped to their own size.
    max_patch = 150
    lowres = tf.convert_to_tensor(img, dtype=tf.uint8)
    # The model expects exactly 3 channels: expand grayscale, drop any alpha.
    if lowres.shape.rank == 2:
        lowres = lowres[..., tf.newaxis]
    if lowres.shape[-1] == 1:
        lowres = tf.image.grayscale_to_rgb(lowres)
    lowres = lowres[..., :3]
    # random_crop raises if the requested size exceeds the image, so clamp it.
    height, width = lowres.shape[0], lowres.shape[1]
    crop_size = (min(max_patch, height), min(max_patch, width), 3)
    lowres = tf.image.random_crop(lowres, crop_size)
    preds = new_model.predict_step(lowres)
    return (lowres.numpy(), preds.numpy())
# Numpy RGB image input; gr.Image has the same defaults as the deprecated
# gr.inputs.Image, which (like gr.outputs.*) was removed in Gradio 4.
# The unused gr.outputs.Image() and interpretation='default' were dropped:
# interpretation does not apply to a Gallery output and is gone in Gradio 4.
image = gr.Image()
demo = gr.Interface(
    process_image,
    title="EDSR - Enhanced Deep Residual Networks for Single Image Super-Resolution",
    description="SuperResolution",
    inputs=image,
    outputs=gr.Gallery(
        label="Outputs, First image is low res, next one is High Res",
        visible=True,
    ),
    allow_flagging="never",
)
demo.launch()