File size: 904 Bytes
3135a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90e4acb
 
 
 
 
3135a01
 
 
 
90e4acb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import random

import numpy as np
import skimage.color as sc

import torch

def set_channel(*args, n_channels=3):
    """Coerce every image in ``args`` to ``n_channels`` channels.

    Each image is first promoted from ``(H, W)`` to ``(H, W, 1)``. Then:

    * ``n_channels == 1`` and 3 channels -> keep only the first channel of
      ``skimage.color.rgb2ycbcr`` (the Y component), shaped ``(H, W, 1)``.
      NOTE(review): that output is whatever dtype/range skimage returns
      (not necessarily the input dtype) — confirm downstream expectations.
    * ``n_channels == 3`` and 1 channel -> replicate the channel three times
      along the last axis.
    * any other combination -> the (possibly promoted) image is returned
      unchanged.

    Args:
        *args: numpy image arrays with 2 or 3 dimensions.
        n_channels: requested channel count (only 1 and 3 trigger conversion).

    Returns:
        list of arrays, one per input, in the same order.
    """
    converted = []
    for image in args:
        # Grayscale 2-D input gets an explicit trailing channel axis.
        if image.ndim == 2:
            image = image[:, :, np.newaxis]

        channels = image.shape[2]
        if (n_channels, channels) == (1, 3):
            luma = sc.rgb2ycbcr(image)[:, :, 0]
            image = luma[:, :, np.newaxis]
        elif (n_channels, channels) == (3, 1):
            image = np.concatenate([image, image, image], axis=2)

        converted.append(image)
    return converted

def np2Tensor(*args, rgb_range=255, format='NCHW'):
    """Convert HWC numpy images into float32 tensors scaled by ``rgb_range / 255``.

    Args:
        *args: numpy arrays indexed as ``(H, W, C)``; the ``(2, 0, 1)``
            transpose used for ``'NCHW'`` requires exactly three dimensions.
        rgb_range: each image is multiplied by ``rgb_range / 255``, so inputs
            are presumably in ``[0, 255]`` — confirm with callers.
        format: ``'NCHW'`` moves the channel axis first, giving ``(C, H, W)``
            per image (no batch axis is added despite the name); ``'NHWC'``
            keeps the ``(H, W, C)`` layout.

    Returns:
        list of ``torch.float32`` tensors, one per input, in the same order.
        The returned tensors never share memory with the input arrays, so
        the caller's arrays are left untouched.

    Raises:
        ValueError: if ``format`` is not ``'NCHW'`` or ``'NHWC'``.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``,
    # which would silently treat an unknown format as NHWC.
    if format not in ('NCHW', 'NHWC'):
        raise ValueError(f"format must be 'NCHW' or 'NHWC', got {format!r}")

    scale = rgb_range / 255

    def _np2Tensor(img):
        if format == 'NCHW':
            img = img.transpose((2, 0, 1))
        # torch.from_numpy rejects negative strides (e.g. flipped arrays such
        # as img[:, ::-1]); make both layouts contiguous, not just NCHW.
        img = np.ascontiguousarray(img)
        tensor = torch.from_numpy(img).float()
        # Out-of-place multiply: from_numpy shares memory with ``img`` and
        # ``.float()`` is a no-op for float32 input, so an in-place ``mul_``
        # could rescale the caller's array.
        return tensor * scale

    return [_np2Tensor(a) for a in args]