"""Train a small residual conv net to reconstruct images (identity mapping).

Images are read from ``./15k/``; one reconstruction of a fixed test image is
written to ``./test/`` before training and after every epoch.
"""

import gc
import glob
import os

import torch
import torchvision
import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor


class TestModel(torch.nn.Module):
    """Fully-convolutional residual network: 3 channels in, 3 channels out.

    Every convolution is 3x3 / stride 1 / padding 1, so the spatial size of
    the input is preserved. The output is clamped to ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.start = torch.nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.conv1 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=False)
        self.conv2 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=False)
        self.conv3 = torch.nn.Conv2d(16, 16, 3, 1, 1, bias=False)
        self.final = torch.nn.Conv2d(16, 3, 3, 1, 1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.bn2 = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        """Run the network on a ``(N, 3, H, W)`` batch and return the same shape."""
        x = self.start(x)
        x = self.bn1(x)
        # Three residual 3x3 convolutions.
        x = self.conv1(x) + x
        x = self.conv2(x) + x
        x = self.conv3(x) + x
        x = self.bn2(x)
        x = self.final(x)
        return torch.clamp(x, -1, 1)


class DS(Dataset):
    """Dataset of RGB images found under ``./15k/``, normalised to ``[-1, 1]``."""

    def __init__(self):
        super().__init__()
        # glob() returns paths in arbitrary order; sort so that indices (and
        # the image used by gettest) are stable across runs and machines.
        self.g = sorted(glob.glob("./15k/*"))
        self.trans = transforms.Compose([
            # pad_if_needed avoids a crash on images smaller than the crop.
            transforms.RandomCrop((256, 256), pad_if_needed=True),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.g)

    @staticmethod
    def _load_rgb(path):
        """Open *path*, convert it to RGB and close the underlying file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _normalise(x):
        """Map a ``[0, 1]`` tensor to ``[-1, 1]``.

        ``ToTensor`` / ``to_tensor`` already divide 8-bit pixels by 255, so
        the scale factor here is 2 (not 1 / 127.5).
        """
        return x * 2 - 1

    def __getitem__(self, idx):
        """Return a random 256x256 crop of image *idx* as a ``[-1, 1]`` tensor."""
        x = self.trans(self._load_rgb(self.g[idx]))
        return self._normalise(x)

    def gettest(self):
        """Return the first image, uncropped, as a ``[-1, 1]`` tensor.

        Raises:
            IndexError: if no images were found.
        """
        if not self.g:
            raise IndexError("no images found under ./15k/")
        x = to_tensor(self._load_rgb(self.g[0]))
        return self._normalise(x)


def main():
    """Train ``TestModel`` as an autoencoder on ``DS`` and save sample outputs."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 64
    epochs = 10
    out_dir = "./test/"

    dataset = DS()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = TestModel().to(device)
    criterion = torch.nn.MSELoss().to(device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)

    # The sample images are written here; create it so save() cannot fail
    # on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)

    model.train()

    def log(tag):
        """Reconstruct the fixed test image and save it as ``<tag>.png``."""
        model.eval()
        # No autograd graph is needed for a preview; this also keeps the
        # full-resolution forward pass from holding activation memory.
        with torch.no_grad():
            x = dataset.gettest().to(device).unsqueeze(0)
            out = model(x)
        # Undo the [-1, 1] normalisation before converting to an image.
        image = to_pil_image(((out[0] + 1) / 2).cpu())
        image.save(os.path.join(out_dir, str(tag) + ".png"))
        model.train()

    log("test")

    for epoch in range(epochs):
        for step, batch in enumerate(tqdm.tqdm(dataloader)):
            batch = batch.to(device)
            optim.zero_grad()
            out = model(batch)
            # Autoencoder objective: reproduce the input.
            loss = criterion(out, batch)
            loss.backward()
            optim.step()
            if step % 100 == 0:
                gc.collect()
                if device == "cuda":
                    torch.cuda.empty_cache()
        print("EPOCH", epoch)
        print("LAST LOSS", loss.item())
        log(epoch)


if __name__ == "__main__":
    main()