"""Gradio demo: cell segmentation with a U-Net fed by a spatial-transformer branch.

The network architecture is rebuilt in code, its weights are loaded from
``WEIGHTS_PATH`` (relative to the current working directory) and a
single-image Gradio interface is served around :func:`cellsegmentor`.
"""
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import tensorflow as tf
from tensorflow import keras
# NOTE: several names below are imported more than once and the LAST import
# wins: Conv2D / MaxPooling2D / Flatten / Dense / Sequential end up bound to
# the ``tensorflow.keras`` versions, while UpSampling2D / Input / Concatenate /
# Dropout / Reshape / Lambda / Model come from standalone ``keras``. The lines
# are kept in their original order to preserve those bindings; consolidating
# onto a single Keras package is recommended (mixing them can break on
# Keras 3 / TF >= 2.16).
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input, Concatenate, Dropout, Flatten, Reshape
from keras.models import Model
from keras.optimizers import Adam
from keras.metrics import MeanIoU
from keras.layers import Lambda
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Input, Concatenate, Dropout, Flatten, Dense
from keras.models import Model
from keras.layers import Lambda
from keras import backend as K
from keras.metrics import MeanIoU
from keras.models import Sequential, model_from_json
import tensorflow_addons as tfa
import json
from keras.models import model_from_json
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# Side length (pixels) of the square single-channel input the network is built for.
IMAGE_SIZE = 512
# Trained weights, resolved relative to the current working directory.
WEIGHTS_PATH = './model_file/UNet+stn_model_longrun.h5'
# Sigmoid output above which a pixel is labelled foreground.
MASK_THRESHOLD = 0.5


def down_block(x, filters, kernel_size=(3, 3), padding="same", strides=1, dropout_rate=0.2):
    """Encoder stage: two ReLU convolutions, then 2x2 max-pooling and dropout.

    Returns:
        ``(c, p)`` where ``c`` is the pre-pooling feature map (the skip
        connection) and ``p`` is the pooled, dropped-out tensor for the next stage.
    """
    c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
    c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
    p = MaxPooling2D((2, 2), (2, 2))(c)
    p = Dropout(dropout_rate)(p)
    return c, p


def up_block(x, skip, filters, kernel_size=(3, 3), padding="same", strides=1, dropout_rate=0.2):
    """Decoder stage: 2x upsample, concatenate with ``skip``, two ReLU convolutions, dropout."""
    us = UpSampling2D((2, 2))(x)
    concat = Concatenate()([us, skip])
    c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(concat)
    c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
    c = Dropout(dropout_rate)(c)
    return c


def bottleneck(x, filters, kernel_size=(3, 3), padding="same", strides=1, dropout_rate=0.2):
    """Bottom of the U: two ReLU convolutions followed by dropout, no resampling."""
    c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(x)
    c = Conv2D(filters, kernel_size, padding=padding, strides=strides, activation="relu")(c)
    c = Dropout(dropout_rate)(c)
    return c


def spatial_transformer_network(input_layer):
    """Warp ``input_layer`` with a projective transform regressed from the input itself.

    A small localization CNN predicts 8 numbers per image. They are passed to
    ``tfa.image.transform`` as the flat transform ``[a0, a1, a2, b0, b1, b2, c0, c1]``,
    and the warped image (same shape as ``input_layer``) is returned.

    NOTE(review): check the following before retraining this branch.
      * The final Dense layer is zero-initialized, so at initialization every
        transform is all zeros. Per the tfa docs, that maps every output pixel
        to input point (0, 0). An identity start would use bias [1, 0, 0, 0, 1, 0, 0, 0].
      * As far as I know, the gradient tfa registers for this op returns None
        for the ``transforms`` argument. If so, the localization layers get no
        training signal at all. Confirm against the installed tfa version.
    Both are left as-is so the saved weights keep producing the same output.
    """
    loc_net = Conv2D(8, (3, 3), activation='relu')(input_layer)
    loc_net = MaxPooling2D(pool_size=(2, 2))(loc_net)
    loc_net = Conv2D(10, (3, 3), activation='relu')(loc_net)
    loc_net = MaxPooling2D(pool_size=(2, 2))(loc_net)
    loc_net = Flatten()(loc_net)
    loc_net = Dense(50, activation='relu')(loc_net)
    # (batch, 8): already the layout tfa.image.transform expects, so no reshape is needed.
    theta = Dense(8, kernel_initializer='zeros', bias_initializer='zeros')(loc_net)
    return Lambda(lambda args: tfa.image.transform(args[0], args[1]))([input_layer, theta])


def UNet2_with_STN():
    """Build the U-Net whose input is the raw image concatenated with its STN-warped copy.

    Input shape is ``(IMAGE_SIZE, IMAGE_SIZE, 1)``. The output has the same
    spatial size with 1 sigmoid channel. The four 2x2 poolings require
    ``IMAGE_SIZE`` to be divisible by 16 so the skip connections line up.
    """
    f = [64, 128, 256, 512, 1024]
    inputs = Input((IMAGE_SIZE, IMAGE_SIZE, 1))
    stn_output = spatial_transformer_network(inputs)
    # The encoder sees 2 channels: the original image and its warped version.
    p0 = Concatenate()([inputs, stn_output])

    c1, p1 = down_block(p0, f[0])
    c2, p2 = down_block(p1, f[1])
    c3, p3 = down_block(p2, f[2])
    c4, p4 = down_block(p3, f[3])

    bn = bottleneck(p4, f[4])

    u1 = up_block(bn, c4, f[3])
    u2 = up_block(u1, c3, f[2])
    u3 = up_block(u2, c2, f[1])
    u4 = up_block(u3, c1, f[0])

    outputs = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(u4)
    return Model(inputs, outputs)


model = UNet2_with_STN()
# NOTE(review): keras MeanIoU casts predictions to integer class ids, so fed
# raw sigmoid probabilities it likely under-reports IoU (BinaryIoU with a
# threshold may be what was intended). This only affects fit/evaluate;
# predict() in cellsegmentor does not use compiled metrics.
mean_iou = MeanIoU(num_classes=2)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["acc", mean_iou])
model.load_weights(WEIGHTS_PATH)


def _to_model_input(img):
    """Turn a Gradio image array into a ``(1, IMAGE_SIZE, IMAGE_SIZE, 1)`` float32 batch.

    Steps: grayscale conversion, resize, histogram equalization, then scaling to [0, 1].
    """
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim == 2:
        gray = img
    elif img.shape[-1] == 4:
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        # Gradio hands over RGB arrays, not OpenCV's BGR; BGR2GRAY would swap
        # the red/blue luminance weights.
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    gray = cv2.resize(gray, (IMAGE_SIZE, IMAGE_SIZE))
    gray = cv2.equalizeHist(gray)  # OpenCV requires 8-bit single-channel input here
    scaled = (gray / 255.0).astype(np.float32)
    return scaled[np.newaxis, :, :, np.newaxis]


def _mask_to_image(pred_mask, threshold=MASK_THRESHOLD):
    """Threshold the network's output map into a 3-channel uint8 image (0 or 255).

    To smooth the mask, apply ``cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)``
    before thresholding. The kernel must be larger than 1x1; a 1x1 closing is an identity op.
    """
    mask = np.squeeze(pred_mask)
    binary = (mask > threshold).astype(np.uint8) * 255
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)


def cellsegmentor(img):
    """Segment cells in ``img`` and return the binary mask as an image.

    Args:
        img: numpy image from ``gr.Image()``. This is RGB with Gradio's default
            settings; 2-D grayscale, single-channel and RGBA arrays are also
            accepted. ``None`` means no image was submitted.

    Returns:
        ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` uint8 array with foreground 255 and
        background 0, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None
    pred_mask = model.predict(_to_model_input(img), verbose=0)
    return _mask_to_image(pred_mask)


iface = gr.Interface(
    title="Enhanced Cell Segmentation through Spatial Transformer Networks in U-Net ",
    fn=cellsegmentor,
    inputs=gr.Image(),  # Input: microscopic cell image
    outputs=gr.Image(),  # Output: binary segmentation mask
)

if __name__ == "__main__":
    # share=True additionally creates a publicly shareable link.
    iface.launch(share=True, debug=True)