File size: 2,814 Bytes
06e7f8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

# Resize(224) with bicubic interpolation; referring_captioning applies this
# to the original image before the captioning pass.
t = [transforms.Resize(224, interpolation=Image.BICUBIC)]
transform_ret = transforms.Compose(t)

# Resize(512) with bicubic interpolation; referring_captioning applies this
# to the original image before the grounding pass.
t = [transforms.Resize(512, interpolation=Image.BICUBIC)]
transform_grd = transforms.Compose(t)

# Metadata entry handed to the Visualizer. The name is spelled `metedata`
# (sic) because referring_captioning refers to it under that spelling.
metedata = MetadataCatalog.get('coco_2017_train_panoptic')

def referring_captioning(model, image, texts, inpainting_text, *args, **kwargs):
    """Ground a referring expression in an image, then caption the image.

    The function runs two passes:

    1. Grounding: ``image`` is resized with ``transform_grd`` and passed to
       ``model_last.model.evaluate_grounding`` together with ``texts``. The
       predicted mask is thresholded at 0 and drawn over the resized image.
    2. Captioning: ``image`` is resized with ``transform_ret`` and passed to
       ``model_cap.model.evaluate_captioning`` with the tokenized ``texts``
       supplied as ``extra['token']``.

    Args:
        model: A two-element sequence ``(model_last, model_cap)``; the first
            is used for grounding, the second for tokenizing and captioning.
        image: Input image. It must be accepted by the torchvision ``Resize``
            transforms and expose ``.size`` as ``(width, height)`` afterwards
            (presumably a PIL RGB image -- confirm against the caller).
        texts: Referring expression. Surrounding whitespace is ignored; a
            trailing period is added for grounding if missing.
        inpainting_text: Unused; kept for interface compatibility.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        A tuple ``(overlay, caption, None)`` where ``overlay`` is a PIL image
        of the grounding mask drawn on the resized input and ``caption`` is
        the generated text, or ``'n/a'`` when the grounded region covers
        fewer than 5 pixels of the 224x224 downsampled mask.
    """
    model_last, model_cap = model

    # Normalize the query once. Previously whitespace was stripped only when
    # the text already ended with '.', so "dog. " became "dog. ." and leading
    # spaces leaked into both the grounding prompt and the caption tokens.
    query = texts.strip()
    has_period = query.endswith('.')
    grounding_text = query if has_period else query + '.'
    # Caption prefix without periods (all of them are removed when the query
    # ends with one, matching the existing behavior).
    token_text = query.replace('.', '') if has_period else query

    try:
        with torch.no_grad():
            # ---- grounding pass on the 512-resized image ----
            image_grd = transform_grd(image)
            width = image_grd.size[0]
            height = image_grd.size[1]
            image_grd = np.asarray(image_grd)
            # HWC array -> CHW tensor on the GPU.
            images = torch.from_numpy(image_grd.copy()).permute(2, 0, 1).cuda()

            batch_inputs = [{
                'image': images,
                'groundings': {'texts': [[grounding_text]]},
                'height': height,
                'width': width,
            }]
            outputs = model_last.model.evaluate_grounding(batch_inputs, None)

            # Binary (0/1) float mask of the grounded region.
            grd_mask = (outputs[-1]['grounding_mask'] > 0).float()
            # Nearest-neighbour downsample keeps the values exactly 0 or 1.
            mask_small = F.interpolate(grd_mask[None,], (224, 224), mode='nearest')[0]
            # True where the pixel is NOT part of the grounded region.
            inverse_mask = (1 - mask_small).bool()

            color = [252/255, 91/255, 129/255]
            visual = Visualizer(image_grd, metadata=metedata)
            demo = visual.draw_binary_mask(grd_mask.cpu().numpy()[0], color=color, text=texts)
            res = demo.get_image()

            # Fewer than 5 foreground pixels: nothing worth captioning.
            if mask_small.sum() < 5:
                return Image.fromarray(res), 'n/a', None

            # NOTE(review): multiplying by 0 blanks the mask, so the captioner
            # always receives an all-zero mask. This looks deliberate (mask
            # disabled) -- confirm before changing.
            captioning_mask = inverse_mask * 0

            # ---- captioning pass on the 224-resized image ----
            image_ret = np.asarray(transform_ret(image))
            images = torch.from_numpy(image_ret.copy()).permute(2, 0, 1).cuda()
            batch_inputs = [{'image': images, 'image_id': 0, 'captioning_mask': captioning_mask}]

            token = model_cap.model.sem_seg_head.predictor.lang_encoder.tokenizer.encode(token_text)
            # Add a batch dimension and drop the last token id (presumably the
            # end-of-sequence marker so generation can continue -- confirm).
            token = torch.tensor(token)[None, :-1]

            outputs = model_cap.model.evaluate_captioning(batch_inputs, extra={'token': token})
            text = outputs[-1]['captioning_text']

        return Image.fromarray(res), text, None
    finally:
        # Release cached GPU memory on every exit path, including errors.
        torch.cuda.empty_cache()