File size: 1,746 Bytes
0e3bdb5
2d249b4
0e3bdb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8bb026
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from enum import Enum
from torchvision import transforms

# CIFAR10 Normalization values, one entry per channel
_CIFAR10_MEAN = [0.49139968, 0.48215827, 0.44653124]
_CIFAR10_STD = [0.24703233, 0.24348505, 0.26158768]

normalize = transforms.Normalize(mean=_CIFAR10_MEAN, std=_CIFAR10_STD)
# Normalizing with mean = -mean / std and std = 1 / std undoes `normalize`.
denormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_CIFAR10_MEAN, _CIFAR10_STD)],
    std=[1 / s for s in _CIFAR10_STD],
)

class DatasetNormalizations(Enum):
    """Normalization statistics for datasets, exposed as enum members.

    Each member's ``value`` is a list of floats, one entry per channel.
    """

    # CIFAR10 channel means
    CIFAR10_MEAN = [
        0.49139968,
        0.48215827,
        0.44653124,
    ]
    # CIFAR10 channel standard deviations
    CIFAR10_STD = [
        0.24703233,
        0.24348505,
        0.26158768,
    ]


def create_random_image(image_size, mean, std) -> torch.Tensor:
    """
    Creates a random image whose channels are drawn from normal distributions.
    Used to create more accurate random images that follow the statistics of
    the dataset a model was trained on. One channel is generated per
    (mean, std) pair, so the length of mean sets the number of channels:
    a mean of length 3 means 3 channels.

    :param image_size: Tuple of the 2D image size; its first two entries are
                       used as the shape of every channel
    :param mean: The per-channel means of the distribution
    :param std:  The per-channel standard deviations of the distribution
    :return: image - The created image, a tensor of shape
             (len(mean), image_size[0], image_size[1])
    :raises ValueError: If mean and std do not have the same length
    """
    # Enforce the documented contract up front; otherwise a longer std would be
    # silently ignored and a shorter one would surface as a bare IndexError.
    if len(mean) != len(std):
        raise ValueError(
            f"mean and std must be the same length, got {len(mean)} and {len(std)}"
        )

    height, width = image_size[0], image_size[1]
    # Channels are sampled one at a time, in order, each with its own distribution.
    channels = [
        torch.empty((height, width)).normal_(mean=channel_mean, std=channel_std)
        for channel_mean, channel_std in zip(mean, std)
    ]
    return torch.stack(channels)

def expo_tuple(epochs, num_values):
    """
    Builds a tuple of integer values running from 1 up to epochs.

    The values follow the quadratic curve 1 + (epochs - 1) * t ** 2, with t
    stepping evenly from 0 to 1, so they sit closer together near 1 and
    spread out towards epochs. Each value is truncated to an int, which means
    neighbouring entries can repeat when epochs is small.

    :param epochs: The value the tuple ends on
    :param num_values: How many values to return; must be at least 2
    :return: Tuple of num_values ints, starting at 1 and ending at epochs
    :raises ValueError: If num_values is less than 2
    """
    if num_values < 2:
        raise ValueError("Number of values must be greater than or equal to 2")

    # Dividing by (num_values - 1) puts step 0 on 1 and the final step on
    # epochs, so exactly num_values entries are produced (the smallest valid
    # request, num_values == 2, yields (1, epochs)).
    last_step = num_values - 1
    return tuple(
        int(1 + (epochs - 1) * (step / last_step) ** 2)
        for step in range(num_values)
    )