StableDiffusionModel / diffusion_loss.py
bala1802's picture
Upload 6 files
1975737 verified
raw
history blame
765 Bytes
import torch
import torchvision.transforms as T
import torch.nn.functional as F
def blue_channel(images):
error = torch.abs(images[:,2] - 0.9).mean()
return error
def elastic_transform(images):
elastic_transformer = T.ElasticTransform(alpha=550.0,sigma=5.0)
transformed_imgs = elastic_transformer(images)
error = torch.abs(transformed_imgs - images).mean()
return error
def symmetry(images):
flipped_image = torch.flip(images, [3])
error = F.mse_loss(images, flipped_image)
print("Loss Calculated for the Symmetry : ", error)
return error
def saturation(images):
transformed_imgs = T.functional.adjust_saturation(images,saturation_factor = 10)
error = torch.abs(transformed_imgs - images).mean()
return error