VirtualStaining / app.py
edyoshikun's picture
Delete duplicated images and remove the GPU dependency
90ad424
raw
history blame
2.7 kB
from viscy.light.engine import VSUNet
import torch
import gradio as gr
import numpy as np
from numpy.typing import ArrayLike
from skimage import exposure
from huggingface_hub import hf_hub_download
class VSGradio:
    """Wrap a VSUNet checkpoint behind a single-image ``predict`` callable.

    The model is loaded once at construction time and placed on the GPU when
    CUDA is available, otherwise on the CPU.
    """

    def __init__(self, model_config, model_ckpt_path):
        """Store the configuration, select a device and load the model.

        Args:
            model_config: Network configuration; forwarded unchanged to
                ``VSUNet.load_from_checkpoint`` as ``model_config``.
            model_ckpt_path: Path of the checkpoint file to load.
        """
        self.model_config = model_config
        self.model_ckpt_path = model_ckpt_path
        # Fall back to the CPU when no CUDA device is available.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the checkpoint into ``self.model`` and put it in eval mode."""
        self.model = VSUNet.load_from_checkpoint(
            self.model_ckpt_path,
            # Remap stored tensors onto the selected device so a checkpoint
            # saved on a GPU can still be loaded on a CPU-only host.
            map_location=self.device,
            architecture="UNeXt2_2D",
            model_config=self.model_config,
        )
        self.model.to(self.device)
        self.model.eval()

    def normalize_fov(self, input: ArrayLike):
        """Normalize a field of view to zero mean and unit variance.

        Args:
            input: Array-like image data; converted with ``np.asarray``.

        Returns:
            ``(input - mean) / std`` as a NumPy array. For a constant input
            (``std == 0``) the centered array (all zeros) is returned
            instead, so the result never contains NaN/inf from a division
            by zero.
        """
        arr = np.asarray(input)
        mean = np.mean(arr)
        std = np.std(arr)
        centered = arr - mean
        if std == 0:
            # Constant image: dividing would produce NaN/inf.
            return centered
        return centered / std

    def predict(self, inp):
        """Run the model on one image and return two rescaled predictions.

        Args:
            inp: Input image. It is normalized, cast to float32 and given
                three leading singleton axes before being passed to the
                model, so a 2-D (H, W) image becomes a (1, 1, 1, H, W)
                tensor (the model expects (B, C, D, H, W)).

        Returns:
            Tuple ``(nuc_pred, mem_pred)``: output channels 0 and 1 of the
            first batch item and first depth slice, each rescaled to the
            range [0, 1]. NOTE(review): presumably nuclei and membrane
            channels, judging by the variable names -- confirm against the
            model.

        Raises:
            TypeError: If ``inp`` is ``None`` (no image supplied).
        """
        if inp is None:
            raise TypeError("predict() requires an input image, got None")

        normalized = self.normalize_fov(inp)
        tensor = torch.from_numpy(np.array(normalized).astype(np.float32))
        # Add batch, channel and depth axes in front of the image axes.
        batch = {
            "index": None,
            "source": tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
        }
        with torch.inference_mode():
            self.model.on_predict_start()
            pred = self.model.predict_step(batch, 0, 0).cpu().numpy()

        # Pick the 2D image of each output channel and stretch it to [0, 1].
        nuc_pred = exposure.rescale_intensity(pred[0, 0, 0], out_range=(0, 1))
        mem_pred = exposure.rescale_intensity(pred[0, 1, 0], out_range=(0, 1))
        return nuc_pred, mem_pred
# %%
if __name__ == "__main__":
    # Fetch the pretrained checkpoint from the Hugging Face Hub; the returned
    # value is used below as the checkpoint path.
    model_ckpt_path = hf_hub_download(
        filename="epoch=399-step=23200.ckpt",
        repo_id="compmicro-czb/VSCyto2D",
    )

    # Network configuration handed to VSGradio (and from there to the model).
    model_config = dict(
        in_channels=1,
        out_channels=2,
        encoder_blocks=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        decoder_conv_blocks=2,
        stem_kernel_size=[1, 2, 2],
        in_stack_depth=1,
        pretraining=False,
    )

    vsgradio = VSGradio(model_config, model_ckpt_path)

    # One grayscale ("L") input image, two output images (one per value
    # returned by ``VSGradio.predict``).
    input_image = gr.Image(type="numpy", image_mode="L", format="png")
    output_images = [gr.Image(type="numpy", format="png") for _ in range(2)]
    example_images = ["examples/a549.png", "examples/hek.png"]

    demo = gr.Interface(
        fn=vsgradio.predict,
        inputs=input_image,
        outputs=output_images,
        examples=example_images,
    )
    demo.launch()