| | from core import runner |
| | import torch |
| | from torch import tensor |
| | from PIL import Image |
| | import numpy as np |
| | import torch.nn.functional as F |
| | import gradio as gr |
| |
|
| | |
| |
|
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'{device.type=}')
| |
|
# HTML help text passed to gr.Interface(description=...) below.
description = (
    '<p> Choose an example below; OR <br>'
    'Upload by yourself: <br>'
    '1. Upload any test image (query) with any target object you wish to segment <br>'
    '2. Upload another image (support) with the target object or a variation of it <br>'
    '3. Upload a binary mask that segments the target object in the support image <br>'
    '</p>'
)
| | |
| | |
# Each example episode is [query image, support image, support mask], all
# stored under ./imgs/. The column order matches the first three inputs.
example_episodes = [
    [f'./imgs/{query}', f'./imgs/{support}', f'./imgs/{mask}']
    for query, support, mask in (
        ('549870_35.jpg', '457070_00.jpg', '457070_00.png'),
        ('ISIC_0000372.jpg', 'ISIC_0013176.jpg', 'ISIC_0013176_segmentation.png'),
        ('d_r_450_.jpg', 'd_r_465_.jpg', 'd_r_465_.bmp'),
        ('CHNCXR_0282_0.png', 'CHNCXR_0324_0.png', 'CHNCXR_0324_0_mask.png'),
        ('1.jpg', '5.jpg', '5.png'),
        ('cake1.png', 'cake2.png', 'cake2_mask.png'),
    )
]
# Placeholder image returned in place of predictions when no model run happens.
blank_img = './imgs/blank.png'
| |
|
def gr_img(name):
    """Build a Gradio image input (upload or webcam) that yields PIL images."""
    return gr.Image(label=name, sources=['upload', 'webcam'], type="pil")


inputs = [
    gr_img('Query Img'),
    gr_img('Support Img'),
    gr_img('Support Mask'),
    gr.Checkbox(label='re-adapt'),
]
# The extra confirmation checkbox is only offered when running on the CPU.
if device.type == 'cpu':
    inputs.append(gr.Checkbox(label='Confirm CPU run (CHOOSE ONLY WHEN REQUESTED)'))
| |
|
def prepare_feat_maker():
    """Create a fresh feature maker from the default runner config.

    runner.makeFeatureMaker is given a stand-in dataset object whose only
    attribute is ``class_ids = [0]``, and the module-level ``device``.
    """
    class DummyDataset:
        # Single placeholder class id; no real dataset is loaded in the demo.
        class_ids = [0]

    cfg = runner.makeConfig()
    return runner.makeFeatureMaker(DummyDataset(), cfg, device=device)
| |
|
# Module-level state shared by the Gradio callbacks:
#   feat_maker - current feature maker; reset_layers() replaces it with a fresh one.
#   has_fit    - set to True by from_model() once a model run has happened.
feat_maker, has_fit = prepare_feat_maker(), False
| |
|
| | |
| |
|
def reset_layers():
    """Discard any fitted layers by rebuilding the global feature maker.

    Also clears ``has_fit``. The new feature maker has not been fitted yet, so
    if the fitting run that normally follows fails, later predictions must not
    assume fitted layers are available.
    """
    global feat_maker, has_fit
    feat_maker = prepare_feat_maker()
    has_fit = False
| | |
def prepare_batch(q_img_pil, s_img_pil, s_mask_pil):
    """Pack one (query, support, support-mask) triple of PIL images into a 1-shot batch.

    Both images go through ``FSSDataset.transform`` after the dataset class is
    initialized with ``img_size=400``. The mask is converted to greyscale ('L'),
    nearest-resized to the transformed support image's spatial size, and squeezed.

    Returns a dict with ``query_img`` (batch dim prepended), ``support_imgs`` and
    ``support_masks`` (batch and k-shot dims prepended), and ``class_id``
    (``tensor([0])``).
    """
    from data.dataset import FSSDataset
    FSSDataset.initialize(img_size=400, datapath='')

    query = FSSDataset.transform(q_img_pil)
    support = FSSDataset.transform(s_img_pil)

    # NOTE(review): grey levels stay in 0..255 here (no thresholding); confirm the
    # model does not expect a {0, 1} mask.
    mask = torch.tensor(np.array(s_mask_pil.convert('L')))
    mask = F.interpolate(mask[None, None].float(), size=support.size()[-2:], mode='nearest')
    mask = mask.squeeze()

    # [None] prepends the batch dim; [None, None] prepends the batch and k-shot dims.
    return {
        'query_img': query[None],
        'support_imgs': support[None, None],
        'support_masks': mask[None, None],
        'class_id': tensor([0]),
    }
| |
|
def norm(t):
    """Min-max rescale the array/tensor *t* to [0, 1].

    A constant input (zero range) yields all zeros instead of NaN from 0/0.
    """
    lo = t.min()
    span = t.max() - lo
    # t - lo is already all zeros when span == 0; dividing by 1 keeps the
    # true-division (float) result without producing NaN.
    return (t - lo) / (span if span != 0 else 1)


def overlay(img, mask):
    """Blend *img* (normalized with norm) 50/50 with *mask*.

    *mask* is indexed as 2-D and broadcast over a new trailing channel axis,
    so *img* is expected to be channels-last (H, W, C); TODO confirm at callers.
    """
    return norm(img) * 0.5 + mask[:, :, np.newaxis] * 0.5
| | |
def from_model(q_img, s_img, s_mask):
    """Run single-sample inference on one (query, support, support-mask) triple.

    Builds a batch with prepare_batch(), evaluates it with
    runner.SingleSampleEval using the global ``feat_maker``, and sets the
    global ``has_fit`` to True.

    Returns:
        (coarse, overlaid): the first item's logits min-max normalized with
        norm(), and the query image blended with the first predicted mask.
    """
    global has_fit
    batch = prepare_batch(q_img, s_img, s_mask)
    sseval = runner.SingleSampleEval(batch, feat_maker)
    pred_logits, pred_mask = sseval.forward()
    has_fit = True

    def _to_numpy(t):
        # Outputs may live on the GPU and/or carry autograd history;
        # Tensor.numpy() only works on detached CPU tensors.
        return t.detach().cpu().numpy()

    query_hwc = _to_numpy(batch['query_img'][0].permute(1, 2, 0))
    return norm(_to_numpy(pred_logits[0])), overlay(query_hwc, _to_numpy(pred_mask[0]))
| |
|
def _matches_example_mask(mask):
    """Return True if *mask* is pixel-identical to the support mask of any example episode."""
    candidate = np.array(mask)
    for episode in example_episodes:
        # Close each reference file right after reading it instead of leaking the handle.
        # NOTE(review): if Gradio converts uploads to RGB while the example masks are
        # stored in another mode, this comparison never matches - confirm.
        with Image.open(episode[2]) as reference:
            if np.array_equal(candidate, np.array(reference)):
                return True
    return False


def predict(q, s, m, re_adapt, confirmed=None):
    """Gradio callback: fit and predict, reuse fitted layers, or explain why not.

    Args:
        q, s, m: query image, support image and support mask (None when the
            corresponding input was left empty).
        re_adapt: value of the 're-adapt' checkbox.
        confirmed: value of the 'Confirm CPU run' checkbox. That checkbox is only
            added on CPU, so on GPU Gradio passes four arguments and this stays None.
        Both re_adapt and confirmed being None is treated as an example-caching run.

    Returns:
        (status message, coarse prediction, mask prediction); the two images are
        ``blank_img`` whenever no model run happens.
    """
    print(f'predict with {re_adapt=}, {confirmed=}')
    print(f'{type(q)=}')
    if q is None or s is None or m is None:
        msg = "Please provide a query image, a support image and a support mask."
        return msg, blank_img, blank_img

    is_cache_run = re_adapt is None and confirmed is None
    # Explicit confirmation is only required on CPU, the only place the checkbox exists.
    run_confirmed = bool(confirmed) or device.type != 'cpu'
    is_example = _matches_example_mask(m)
    print(f'{is_example=}')

    if is_cache_run:
        reset_layers()
        pred = from_model(q, s, m)
        msg = 'Results ready.'
        return msg, *pred

    if re_adapt:
        if not run_confirmed:
            msg = "You chose to re-adapt but are on CPU.\nThis may take 1 minute on your local machine or 4 minutes on huggingface space.\nSelect 'Confirm CPU run' to start."
            return msg, blank_img, blank_img
        reset_layers()
        pred = from_model(q, s, m)
        msg = "Results ready.\nRemember to untick 're-adapt' if you wish to predict more images with the same parameters."
        return msg, *pred

    if is_example:
        msg = "Cached results for example have been shown previously already.\nTo view it again, click the example again.\nTo run adaption again from scratch, select 're-adapt'."
        return msg, blank_img, blank_img

    if not has_fit:
        msg = "This is the first time you predict own images.\nThe attached layers need to be fitted.\nPlease select 're-adapt'."
        return msg, blank_img, blank_img

    pred = from_model(q, s, m)
    msg = "Results predicted based on layers fitted from previous run.\nIf you wish to re-adapt, select 're-adapt'."
    return msg, *pred
| | |
# Output components, in the order predict() returns its values:
# (status message, coarse prediction, mask prediction).
status_box = gr.Textbox(label="Status")
coarse_view = gr.Image(label="Coarse Query Prediction")
mask_view = gr.Image(label="Mask Prediction")

gradio_app = gr.Interface(
    fn=predict,
    title="abcdfss",
    description=description,
    inputs=inputs,
    outputs=[status_box, coarse_view, mask_view],
    examples=example_episodes,
)

gradio_app.launch()