# NOTE: the next four lines look like hosting-page header text, not code.
# test / testvae.py
# junjuice0's picture
# Update testvae.py
# 96a87c7
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor
import glob
from PIL import Image
import tqdm
import gc
class TestModel(torch.nn.Module):
    """Small fully convolutional image-to-image network.

    A 3->16 channel stem is followed by batch norm, three additive
    residual 16->16 convolutions, a second batch norm and a 16->3 head.
    Every convolution is 3x3, stride 1, padding 1 and bias-free. The
    output is clamped to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_channels, out_channels):
            # 3x3 kernel, stride 1, padding 1, no bias.
            return torch.nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )

        width = 16
        # Creation order and attribute names are kept as-is: they determine
        # the state_dict keys and the order of random weight initialization.
        self.start = conv3x3(3, width)
        self.conv1 = conv3x3(width, width)
        self.conv2 = conv3x3(width, width)
        self.conv3 = conv3x3(width, width)
        self.final = conv3x3(width, 3)
        self.bn1 = torch.nn.BatchNorm2d(width)
        self.bn2 = torch.nn.BatchNorm2d(width)

    def forward(self, x):
        """Run the network on ``x`` and return the output clamped to [-1, 1]."""
        hidden = self.bn1(self.start(x))
        # Three residual steps, each adding the convolution output to its input.
        for residual_conv in (self.conv1, self.conv2, self.conv3):
            hidden = residual_conv(hidden) + hidden
        return torch.clamp(self.final(self.bn2(hidden)), -1, 1)
class DS(Dataset):
    """Image dataset over every path matched by ``./15k/*``.

    Items are RGB images converted to float tensors and rescaled to
    [-1, 1]. ``__getitem__`` returns a random 256x256 crop, while
    ``gettest`` returns the first image uncropped.

    ``RandomCrop`` is used without padding, so images are assumed to be at
    least 256x256 -- TODO confirm against the actual data.
    """

    def __init__(self):
        super().__init__()
        # glob order is arbitrary; sorting keeps indices (and the image
        # picked by gettest) stable between runs.
        self.g = sorted(glob.glob("./15k/*"))
        self.trans = transforms.Compose([
            transforms.RandomCrop((256, 256)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _to_signed(x):
        """Map a tensor scaled to [0, 1] onto [-1, 1].

        ``ToTensor``/``to_tensor`` already divide 8-bit pixel values by 255,
        so the rescale is ``x * 2 - 1`` (not ``x / 127.5 - 1``, which would
        squash everything into roughly [-1, -0.992]).
        """
        return x * 2 - 1

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            # convert() reads the pixel data and returns a new image, so the
            # result stays usable after the file is closed.
            return img.convert("RGB")

    def __len__(self):
        """Return the number of matched files."""
        return len(self.g)

    def __getitem__(self, idx):
        """Return a random 256x256 crop of image ``idx`` scaled to [-1, 1]."""
        image = self._load_rgb(self.g[idx])
        return self._to_signed(self.trans(image))

    def gettest(self):
        """Return the first image, uncropped, scaled to [-1, 1].

        Raises:
            IndexError: if no files matched ``./15k/*``.
        """
        image = self._load_rgb(self.g[0])
        return self._to_signed(to_tensor(image))
def main():
    """Train ``TestModel`` to reconstruct crops from ``DS``.

    Runs 10 epochs of Adam (lr=1e-4) on an MSE reconstruction loss with a
    batch size of 64, on CUDA when available and on CPU otherwise. A
    preview of the model output for ``dataset.gettest()`` is written to
    ``./test/test.png`` before training and to ``./test/<epoch>.png`` after
    every epoch.
    """
    import os  # local import: only needed to create the preview directory

    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 64
    epochs = 10
    out_dir = "./test/"

    model = TestModel().to(device)
    dataset = DS()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = torch.nn.MSELoss()
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Saving a preview fails if the target directory is missing.
    os.makedirs(out_dir, exist_ok=True)
    model.train()

    def log(l):
        """Save the model output for the test image as ``<out_dir><l>.png``."""
        model.eval()
        # Preview only: no gradients are needed for this forward pass.
        with torch.no_grad():
            x = dataset.gettest().to(device).unsqueeze(0)
            out = model(x)
        # The model emits values in [-1, 1]; map back to [0, 1] for saving.
        to_pil_image((out[0] + 1) / 2).save(out_dir + str(l) + ".png")
        model.train()

    log("test")
    for i in range(epochs):
        for j, k in enumerate(tqdm.tqdm(dataloader)):
            k = k.to(device)
            optim.zero_grad()
            out = model(k)
            # Reconstruction objective: the input batch is its own target.
            loss = criterion(out, k)
            loss.backward()
            optim.step()
            if j % 100 == 0:
                # Periodically release Python garbage and cached CUDA blocks.
                gc.collect()
                torch.cuda.empty_cache()
        print("EPOCH", i)
        print("LAST LOSS", loss.item())
        log(i)
# Run the training script only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()