# CIFAR10 image utilities (normalization transforms, random-image generation, schedules).
import torch
from enum import Enum
from torchvision import transforms
# Per-channel (RGB) normalization statistics for CIFAR10.
_CIFAR10_MEAN = [0.49139968, 0.48215827, 0.44653124]
_CIFAR10_STD = [0.24703233, 0.24348505, 0.26158768]

# Forward transform: x_norm = (x - mean) / std.
normalize = transforms.Normalize(mean=_CIFAR10_MEAN, std=_CIFAR10_STD)
# Inverse transform expressed as another Normalize:
# (x_norm - (-mean/std)) / (1/std) == x_norm * std + mean.
denormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_CIFAR10_MEAN, _CIFAR10_STD)],
    std=[1 / s for s in _CIFAR10_STD],
)
class DatasetNormalizations(Enum):
    """Per-channel (RGB) normalization constants, keyed by dataset.

    Values mirror the module-level mean/std lists used by ``normalize`` above;
    presumably computed over the CIFAR10 training set — TODO confirm.
    """
    # Per-channel mean for CIFAR10 images (R, G, B)
    CIFAR10_MEAN = [0.49139968, 0.48215827, 0.44653124]
    # Per-channel standard deviation for CIFAR10 images (R, G, B)
    CIFAR10_STD = [0.24703233, 0.24348505, 0.26158768]
def create_random_image(image_size, mean, std):
"""
Creates a random image from a defined mean and std normal distribution.
Used to create more accurate random images that are built off the models
dataset it was trained on. Mean and std must be the same length. This will
be used to give the images its color channels. Mean of length 3 means 3 channels.
:param image_size: Tuple of the 2D image size
:param mean: The mean of the distribution
:param std: The standard deviation of the distribution
:return: image - The created image
"""
channels = []
for i in range(len(mean)): # Create each channel with the specified custom distribution
channels.append(torch.empty((image_size[0], image_size[1])).normal_(mean=mean[i], std=std[i]))
return torch.stack(channels)
def expo_tuple(epochs, num_values):
    """
    Build a tuple of ``num_values`` quadratically-spaced epoch checkpoints
    from 1 up to ``epochs`` (inclusive), front-loaded toward the start.

    Fixes an off-by-one in the original: it produced only ``num_values - 1``
    elements (e.g. ``expo_tuple(10, 2)`` returned ``(1,)``). The comprehension
    now runs ``i`` over ``range(1, num_values)`` with denominator
    ``num_values - 1`` so the first element is 1 and the last is ``epochs``.

    :param epochs: The final (largest) value of the schedule
    :param num_values: How many values to generate (must be >= 2)
    :return: Tuple of ``num_values`` ints, starting at 1 and ending at epochs
    :raises ValueError: If num_values < 2
    """
    if num_values < 2:
        raise ValueError("Number of values must be greater than or equal to 2")
    exponential_values = [1] + [
        # Quadratic ramp: fraction i/(num_values-1) squared, scaled to [1, epochs].
        int(1 + (epochs - 1) * (i / (num_values - 1)) ** 2)
        for i in range(1, num_values)
    ]
    return tuple(exponential_values)