File size: 1,805 Bytes
032e687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import random
from pyexpat.errors import messages

import pycocotools.mask as maskUtils
import numpy as np

from projects.lisa.datasets.utils import DEFAULT_IMAGE_TOKEN


def region_llava_map_fn(example, k_regions=6):
    """Turn an example's annotated regions into a region-captioning conversation.

    ``k_regions`` objects are sampled from ``example['objects']``. Sampling is
    without replacement when there are at least ``k_regions`` objects; otherwise
    it is with replacement, so exactly ``k_regions`` regions are always produced.
    For each sampled object, one caption is chosen at random. Its ``'segm'``
    annotation is decoded with ``pycocotools.mask.decode`` and cast to ``np.uint8``.

    Args:
        example: Mapping with an ``'objects'`` list. Each object must provide a
            non-empty ``'captions'`` list of strings and a ``'segm'`` entry
            accepted by ``pycocotools.mask.decode`` (presumably a COCO RLE --
            confirm against the dataset).
        k_regions: Number of regions to sample per example. Defaults to 6.

    Returns:
        dict with:
          ``'conversation'``: list of ``k_regions`` turns of the form
            ``{'input': 'Please describe ' + DEFAULT_IMAGE_TOKEN + '.',
            'output': caption + '.'}``, in the same order as the masks.
          ``'region_masks'``: ``np.uint8`` array of the decoded masks stacked
            along a new leading axis of length ``k_regions``.

    Raises:
        ValueError: if ``k_regions`` is not positive or ``example['objects']``
            is empty.
    """
    if k_regions < 1:
        raise ValueError('k_regions must be a positive integer, got {}'.format(k_regions))
    object_datas = example['objects']
    num_objects = len(object_datas)
    if num_objects == 0:
        raise ValueError("example['objects'] is empty; cannot sample regions")

    # Sample without replacement whenever there are enough objects. This
    # includes exactly k_regions, so every object is used once. Only fall back
    # to sampling with replacement when there are fewer objects than regions.
    selected_indexes = np.random.choice(
        num_objects, size=k_regions, replace=num_objects < k_regions)

    region_masks = []
    region_captions = []
    for _idx in selected_indexes:
        object_data = object_datas[_idx]
        region_captions.append(random.choice(object_data['captions']))
        _mask = maskUtils.decode(object_data['segm']).astype(np.uint8)
        region_masks.append(_mask)
    # NOTE(review): np.stack requires every decoded mask to have the same shape
    # (presumably the image's H x W) -- confirm for the dataset.
    region_masks = np.stack(region_masks, axis=0)

    # Each turn pairs a fixed human prompt containing the image token with the
    # region caption as the assistant answer.
    prompt = 'Please describe {}.'.format(DEFAULT_IMAGE_TOKEN)
    conversation = [{'input': prompt, 'output': _cap + '.'} for _cap in region_captions]
    return {'conversation': conversation, 'region_masks': region_masks}