Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import gradio as gr | |
from util import imread, imsave, get_examples | |
import torch | |
def torch_compile(model=None, *args, **kwargs):
    """No-op stand-in for ``torch.compile``.

    Supports both call forms of the real API without compiling anything:

    * ``torch.compile(fn_or_module, ...)`` and bare ``@torch.compile`` return
      the given object unchanged.
    * ``torch.compile(**options)`` and ``@torch.compile(...)`` return an
      identity decorator.

    Any compile options (``*args`` / ``**kwargs``) are accepted and ignored.
    """
    if model is not None:
        # Direct form: hand the function/module back untouched. Returning the
        # decorator here would silently replace the model with an identity fn.
        return model

    def decorator(func):
        return func

    return decorator


# Temporary workaround: make every torch.compile call in this process a no-op.
torch.compile = torch_compile
default_model = 'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c'


def predict(filename, model=None, device=None, reduce_labels=True):
    """Run CPN cell detection on one image file and save the label image.

    Args:
        filename: Path to the input image (the Gradio input component uses
            ``type="filepath"``, so a path string is expected).
        model: Model name passed to ``CpnInterface``. ``None``, empty, or
            whitespace-only values fall back to ``default_model``.
        device: Device string. Defaults to ``'cuda'`` when
            ``torch.cuda.device_count()`` is non-zero, otherwise ``'cpu'``.
        reduce_labels: Forwarded unchanged to the model call.

    Returns:
        Tuple ``(img, vis_labels, dst)``: the normalized input image, the
        ``label_cmap`` visualization of the labels, and the path of the saved
        ``*_labels.tiff`` label file (written next to the input file).

    Raises:
        ValueError: If ``filename`` is not a string (e.g. nothing uploaded).
    """
    # Validate before the heavy lazy imports. An explicit raise (unlike
    # ``assert``) is not stripped when Python runs with ``-O``.
    if not isinstance(filename, str):
        raise ValueError(f'Expected an image file path, got {filename!r}')

    import os

    from cpn import CpnInterface
    from prep import multi_norm
    from celldetection import label_cmap

    # NOTE(review): `spaces` is imported at file top but not used here; on
    # ZeroGPU hardware GPU access usually requires `@spaces.GPU` — confirm.
    if device is None:
        device = 'cuda' if torch.cuda.device_count() else 'cpu'

    print(dict(
        filename=filename,
        model=model,
        device=device,
        reduce_labels=reduce_labels
    ), flush=True)

    img = imread(filename)
    print('Image:', img.dtype, img.shape, (img.min(), img.max()), flush=True)

    # Treat None, '' and whitespace-only names as "use the default". The old
    # length check let '   ' through, which became '' after strip().
    model = '' if model is None else str(model).strip()
    if not model:
        model = default_model

    img = multi_norm(img, 'cstm-mix')  # TODO
    m = CpnInterface(model, device=device)
    y = m(img, reduce_labels=reduce_labels)
    labels = y['labels']
    vis_labels = label_cmap(labels)

    # Replace only the final extension. Splitting on every '.' misplaced the
    # output for extension-less names and for directories containing dots.
    dst = os.path.splitext(filename)[0] + '_labels.tiff'
    imsave(dst, labels)
    return img, vis_labels, dst
# Gradio UI: an uploaded image (handed to `predict` as a file path) and an
# editable model name go in; the processed image, the colorized labels, and
# a download link for the saved label file come out.
input_components = [
    gr.Image(label="Upload Input Image", type="filepath"),
    gr.Textbox(label='Model Name', value=default_model, max_lines=1),
]
output_components = [
    gr.Image(label="Processed Image"),
    gr.Image(label="Label Image"),
    gr.File(label="Download Label Image"),
]

demo = gr.Interface(
    predict,
    inputs=input_components,
    outputs=output_components,
    title="Cell Detection with Contour Proposal Networks",
    examples=get_examples(default_model),
)
demo.launch()