# Author: manikanta — commit 846ccba ("v-2")
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import tensorflow as tf
from tensorflow import keras
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input, Concatenate,Dropout,Flatten ,Reshape
from keras.models import Model
from keras.optimizers import Adam
from keras.metrics import MeanIoU
from keras.layers import Lambda
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input, Concatenate, Dropout, Flatten, Dense
from keras.models import Model
from keras.layers import Lambda
from keras import backend as K
from keras.metrics import MeanIoU
from keras.models import Sequential, model_from_json
import tensorflow_addons as tfa
import json
from keras.models import model_from_json
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1, dropout_rate=0.2):
    """Encoder stage: two ReLU convolutions, then 2x2 max-pooling and dropout.

    Args:
        x: Input Keras tensor.
        filters: Number of filters for both convolutions.
        kernel_size, padding, strides: Forwarded unchanged to both ``Conv2D`` layers.
        dropout_rate: Rate of the ``Dropout`` applied after pooling.

    Returns:
        ``(features, pooled)``. ``features`` is the pre-pooling map (used as the
        skip connection in ``UNet2_with_STN``). ``pooled`` is the pooled,
        dropped-out map passed to the next stage.
    """
    features = x
    for _ in range(2):
        features = Conv2D(
            filters, kernel_size, padding=padding, strides=strides, activation="relu"
        )(features)
    pooled = MaxPooling2D((2, 2), (2, 2))(features)
    pooled = Dropout(dropout_rate)(pooled)
    return features, pooled
def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1, dropout_rate=0.2):
    """Decoder stage: 2x upsample, merge with the encoder skip, two ReLU convs, dropout.

    Args:
        x: Keras tensor from the previous (coarser) stage.
        skip: Encoder feature map concatenated after the upsampled ``x`` along
            the channel axis. Its spatial size must match the upsampled ``x``.
        filters: Number of filters for both convolutions.
        kernel_size, padding, strides: Forwarded unchanged to both ``Conv2D`` layers.
        dropout_rate: Rate of the final ``Dropout``.

    Returns:
        The decoded Keras tensor.
    """
    upsampled = UpSampling2D((2, 2))(x)
    # Order matters: upsampled channels first, skip channels second.
    out = Concatenate()([upsampled, skip])
    for _ in range(2):
        out = Conv2D(
            filters, kernel_size, padding=padding, strides=strides, activation="relu"
        )(out)
    return Dropout(dropout_rate)(out)
def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1, dropout_rate=0.2):
    """Bridge between encoder and decoder: two ReLU convolutions followed by dropout.

    Args:
        x: Input Keras tensor (output of the deepest encoder stage).
        filters: Number of filters for both convolutions.
        kernel_size, padding, strides: Forwarded unchanged to both ``Conv2D`` layers.
        dropout_rate: Rate of the final ``Dropout``.

    Returns:
        The bottleneck Keras tensor.
    """
    out = x
    for _ in range(2):
        out = Conv2D(
            filters, kernel_size, padding=padding, strides=strides, activation="relu"
        )(out)
    return Dropout(dropout_rate)(out)
def spatial_transformer_network(input_layer):
    """Warp ``input_layer`` with a projective transform predicted from the input itself.

    A small localization CNN regresses, per image, the 8 parameters
    ``[a0, a1, a2, b0, b1, b2, c0, c1]`` that ``tfa.image.transform`` expects.
    The input is then resampled with them.

    Args:
        input_layer: 4-D Keras tensor (channels-last, the ``Conv2D`` default).

    Returns:
        Keras tensor holding the warped image, same shape as ``input_layer``.
    """
    loc_net = Conv2D(8, (3, 3), activation='relu')(input_layer)
    loc_net = MaxPooling2D(pool_size=(2, 2))(loc_net)
    loc_net = Conv2D(10, (3, 3), activation='relu')(loc_net)
    loc_net = MaxPooling2D(pool_size=(2, 2))(loc_net)
    loc_net = Flatten()(loc_net)
    loc_net = Dense(50, activation='relu')(loc_net)
    # Zero kernel + identity bias makes a freshly built STN start as the identity
    # warp. An all-zero bias would give the transform [0]*8, which maps every
    # output pixel to input pixel (0, 0). Loaded weights override this initializer.
    identity_transform = tf.constant_initializer([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
    transforms = Dense(
        8, kernel_initializer='zeros', bias_initializer=identity_transform
    )(loc_net)
    # (batch, 8) is already the layout tfa.image.transform wants. The former
    # Reshape((2, 4)) -> Flatten() round-trip reproduced it exactly and is omitted.
    # NOTE(review): tfa's registered gradient for this op appears to cover only the
    # image input, not the transforms, so the localization net may get no training
    # signal through this path — confirm before relying on the STN to learn.
    return Lambda(lambda args: tfa.image.transform(args[0], args[1]))(
        [input_layer, transforms]
    )
def UNet2_with_STN(input_shape=(512, 512, 1), filters=(64, 128, 256, 512, 1024)):
    """Build a U-Net whose input is augmented with a spatial-transformer-warped copy.

    The raw input and its STN output are concatenated along the channel axis. The
    result goes through ``len(filters) - 1`` encoder stages, a bottleneck, and
    mirrored decoder stages with skip connections. A 1x1 sigmoid convolution
    produces a one-channel mask.

    Args:
        input_shape: ``(height, width, channels)`` of the input. Height and width
            must be divisible by ``2 ** (len(filters) - 1)`` so that pooling and
            upsampling line up with the skip connections. Default: (512, 512, 1).
        filters: Filter counts for each encoder stage followed by the bottleneck.
            The decoder reuses the encoder counts in reverse.
            Default: (64, 128, 256, 512, 1024).

    Returns:
        An uncompiled ``Model`` mapping ``input_shape`` to a single-channel
        sigmoid mask of the same height and width.

    Raises:
        ValueError: if ``filters`` is empty.
    """
    filters = list(filters)
    if not filters:
        raise ValueError("filters must contain at least the bottleneck filter count")
    *encoder_filters, bottleneck_filters = filters

    inputs = Input(input_shape)
    stn_output = spatial_transformer_network(inputs)
    x = Concatenate()([inputs, stn_output])

    # Encoder: keep each pre-pooling map as a skip connection.
    skips = []
    for n_filters in encoder_filters:
        skip, x = down_block(x, n_filters)
        skips.append(skip)

    x = bottleneck(x, bottleneck_filters)

    # Decoder: walk the encoder stages back from deepest to shallowest.
    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = up_block(x, skip, n_filters)

    outputs = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)
    return Model(inputs, outputs)
# Build the network and restore the trained weights once, at import time, so
# every cellsegmentor() call reuses the same in-memory model.
model = UNet2_with_STN()
mean_iou = MeanIoU(num_classes=2)
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["acc", mean_iou],
)
# NOTE(review): this relative path resolves against the current working
# directory, not this file's location — launch the app from the project root.
model.load_weights('./model_file/UNet+stn_model_longrun.h5')
def cellsegmentor(img):
    """Segment cells in a microscopy image with the module-level UNet+STN ``model``.

    Preprocessing mirrors a grayscale file load: convert to grayscale, resize to
    512x512, equalize the histogram, and scale to [0, 1]. The predicted
    probability map is thresholded at 0.5.

    Args:
        img: Image array from the Gradio ``gr.Image`` input (presumably an RGB or
            RGBA ``uint8`` array — confirm against the Gradio version in use), a
            2-D grayscale array, or ``None`` when no image was uploaded.

    Returns:
        A 512x512x3 ``uint8`` image with 255 where cells are predicted and 0
        elsewhere, or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        # Gradio passes None when the user submits without uploading an image.
        return None

    image = np.asarray(img)
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    elif image.ndim == 3:
        # Gradio delivers RGB(A), not BGR. BGR2GRAY would swap the R/B luminance
        # weights relative to a cv2.imread(path, 0) grayscale load.
        code = cv2.COLOR_RGBA2GRAY if image.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        image = cv2.cvtColor(image, code)

    image = cv2.resize(image, (512, 512))
    # equalizeHist requires an 8-bit single-channel image.
    image = cv2.equalizeHist(image)
    image = image / 255.0

    # Add batch and channel axes: (512, 512) -> (1, 512, 512, 1).
    batch = np.expand_dims(image, axis=(0, 3))
    pred_mask = model.predict(batch, verbose=0)
    mask = pred_mask.squeeze()

    # NOTE(review): a 1x1 structuring element makes this closing a no-op;
    # confirm whether a larger kernel was intended.
    kernel = np.ones((1, 1), np.uint8)
    proc_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    proc_mask = (proc_mask > 0.5).astype(np.uint8)

    # Scale {0, 1} to {0, 255} and replicate to 3 channels for display.
    return cv2.cvtColor(proc_mask * 255, cv2.COLOR_GRAY2BGR)
# Gradio UI: one microscopy image in, the predicted binary cell mask out.
iface = gr.Interface(
    title="Enhanced Cell Segmentation through Spatial Transformer Networks in U-Net ",
    fn=cellsegmentor,
    inputs=gr.Image(),  # Input: microscopic cell image
    outputs=gr.Image(),  # Output: Segmented image
)

# Launch only when run as a script, so importing this module (e.g. for tests or
# reuse of `model`/`cellsegmentor`) does not start a blocking web server.
if __name__ == "__main__":
    iface.launch(share=True, debug=True)