Spaces:
Sleeping
Sleeping
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
import numpy as np | |
import torch.nn.functional as F | |
from PIL import Image | |
from torchvision import transforms | |
#from utils.visualizer import Visualizer | |
# from detectron2.utils.colormap import random_color | |
# from detectron2.data import MetadataCatalog | |
# from detectron2.structures import BitMasks | |
from modeling.language.loss import vl_similarity | |
from utilities.constants import BIOMED_CLASSES | |
#from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES | |
# import cv2 | |
# import os | |
# import glob | |
# import subprocess | |
from PIL import Image | |
import random | |
# Preprocessing pipeline: every input image is resized to a fixed 1024x1024
# grid with bicubic resampling before being handed to the model.
t = [transforms.Resize((1024, 1024), interpolation=Image.BICUBIC)]
transform = transforms.Compose(t)
#metadata = MetadataCatalog.get('coco_2017_train_panoptic') | |
# Label vocabulary: 'background' first, then every BIOMED_CLASSES name with
# the '-other' / '-merged' substrings removed, and a trailing "others" entry.
all_classes = [
    'background',
    *(name.replace('-other', '').replace('-merged', '') for name in BIOMED_CLASSES),
    "others",
]
# colors_list = [(np.array(color['color'])/255).tolist() for color in COCO_CATEGORIES] + [[1, 1, 1]] | |
# use color list from matplotlib | |
import matplotlib.colors as mcolors | |
# Merge matplotlib's Tableau palette with its base colors; when a key appears
# in both, the BASE_COLORS value is the one kept.
colors = {**mcolors.TABLEAU_COLORS, **mcolors.BASE_COLORS}
# Keep the first 16 color values, in dictionary order. Indexing (rather than
# slicing) preserves the IndexError if the palette ever has fewer entries.
_palette = list(colors.values())
colors_list = [_palette[i] for i in range(16)]
from .output_processing import mask_stats, combine_masks | |
def interactive_infer_image(model, image, prompts):
    """Run text-grounded segmentation on one image and return mask probabilities.

    The image is resized by the module-level ``transform``, moved to the GPU
    and passed to ``model.model.evaluate_demo`` together with ``prompts``.
    For each text embedding the predicted mask with the highest
    vision-language similarity is selected, upsampled to the original image
    size and passed through a sigmoid.

    Side effects:
        Overwrites ``model.model.task_switch`` so that only ``'grounding'``
        is enabled; the switches are not restored afterwards.

    Args:
        model: Wrapper object exposing ``model.model`` with ``task_switch``,
            ``evaluate_demo`` and
            ``sem_seg_head.predictor.lang_encoder.logit_scale``.
        image: Image object with a ``size`` attribute read as
            ``(width, height)`` and accepted by ``transform``.
            Assumed to convert to an (H, W, C) array -- TODO confirm callers
            always pass 3-channel images.
        prompts: Text prompt(s), forwarded unchanged as ``data['text']``.
            Presumably a list of strings -- confirm against the caller.

    Returns:
        numpy.ndarray: Sigmoid probabilities for the matched masks, indexed
        ``[mask, row, col]`` at the original image height and width.
        Presumably one mask per prompt -- verify against ``vl_similarity``.
    """
    # Remember the original size so the masks can be mapped back to it.
    width = image.size[0]
    height = image.size[1]

    # copy() gives torch.from_numpy its own array instead of the PIL-backed view.
    image_resize = np.asarray(transform(image))
    image = torch.from_numpy(image_resize.copy()).permute(2, 0, 1).cuda()
    data = {"image": image, 'text': prompts, "height": height, "width": width}

    # initialize task: enable text grounding only
    task_switch = model.model.task_switch
    task_switch['spatial'] = False
    task_switch['visual'] = False
    task_switch['audio'] = False
    task_switch['grounding'] = True

    results, _, extra = model.model.evaluate_demo([data])

    pred_masks = results['pred_masks'][0]
    v_emb = results['pred_captions'][0]
    t_emb = extra['grounding_class']

    # L2-normalise both embedding sets; the epsilon guards against zero norms.
    t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7)
    v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)

    temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale
    out_prob = vl_similarity(v_emb, t_emb, temperature=temperature)

    # Index of the best-scoring entry along dim 0 for each remaining column.
    matched_id = out_prob.max(0)[1]
    pred_masks_pos = pred_masks[matched_id, :, :]

    # interpolate mask to original size; the output is already exactly
    # (height, width), so no further cropping is required.
    upsampled = F.interpolate(pred_masks_pos[None,], (height, width), mode='bilinear')
    pred_mask_prob = upsampled[0].sigmoid().cpu().numpy()

    return pred_mask_prob
# def interactive_infer_panoptic_biomedseg(model, image, tasks, reftxt=None): | |
# image_ori = transform(image) | |
# #mask_ori = image['mask'] | |
# width = image_ori.size[0] | |
# height = image_ori.size[1] | |
# image_ori = np.asarray(image_ori) | |
# visual = Visualizer(image_ori, metadata=metadata) | |
# images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() | |
# data = {"image": images, "height": height, "width": width} | |
# if len(tasks) == 0: | |
# tasks = ["Panoptic"] | |
# # initialize task
# model.model.task_switch['spatial'] = False | |
# model.model.task_switch['visual'] = False | |
# model.model.task_switch['grounding'] = False | |
# model.model.task_switch['audio'] = False | |
# # check if reftxt is list of strings | |
# assert isinstance(reftxt, list), f"reftxt should be a list of strings, but got {type(reftxt)}" | |
# model.model.task_switch['grounding'] = True | |
# predicts = {} | |
# for i, txt in enumerate(reftxt): | |
# data['text'] = txt | |
# batch_inputs = [data] | |
# results,image_size,extra = model.model.evaluate_demo(batch_inputs) | |
# pred_masks = results['pred_masks'][0] | |
# v_emb = results['pred_captions'][0] | |
# t_emb = extra['grounding_class'] | |
# t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7) | |
# v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) | |
# temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale | |
# out_prob = vl_similarity(v_emb, t_emb, temperature=temperature) | |
# matched_id = out_prob.max(0)[1] | |
# pred_masks_pos = pred_masks[matched_id,:,:] | |
# pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1] | |
# # interpolate mask to ori size | |
# #pred_masks_pos = (F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']] > 0.0).float().cpu().numpy() | |
# # masks.append(pred_masks_pos[0]) | |
# # mask = pred_masks_pos[0] | |
# # masks.append(mask) | |
# # interpolate mask to ori size | |
# pred_mask_prob = F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy() | |
# #pred_masks_pos = 1*(pred_mask_prob > 0.5) | |
# predicts[txt] = pred_mask_prob[0] | |
# masks = combine_masks(predicts) | |
# predict_mask_stats = {} | |
# print(masks.keys()) | |
# for i, txt in enumerate(masks): | |
# mask = masks[txt] | |
# demo = visual.draw_binary_mask(mask, color=colors_list[i], text=txt) | |
# predict_mask_stats[txt] = mask_stats((predicts[txt]*255), image_ori) | |
# res = demo.get_image() | |
# torch.cuda.empty_cache() | |
# # return Image.fromarray(res), stroke_inimg, stroke_refimg | |
# return Image.fromarray(res), None, predict_mask_stats | |