|
import gc
import glob
import os

import torch
import torchvision
import tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor
|
|
|
class TestModel(torch.nn.Module):
    """Small residual conv net mapping (N, 3, H, W) images to (N, 3, H, W).

    Every convolution is 3x3 with stride 1 and padding 1, so the spatial size
    is preserved. The output is clamped to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        def conv(c_in, c_out):
            # Shared layer shape: 3x3 kernel, stride 1, padding 1, no bias.
            return torch.nn.Conv2d(c_in, c_out, 3, 1, 1, bias=False)

        # Attribute names and creation order are significant: they fix the
        # state_dict keys, the parameters() order and the RNG draw order.
        self.start = conv(3, 16)
        self.conv1 = conv(16, 16)
        self.conv2 = conv(16, 16)
        self.conv3 = conv(16, 16)
        self.final = conv(16, 3)
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.bn2 = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        hidden = self.bn1(self.start(x))
        # Three residual steps: each conv's output is added to its input.
        for layer in (self.conv1, self.conv2, self.conv3):
            hidden = layer(hidden) + hidden
        result = self.final(self.bn2(hidden))
        return result.clamp(-1, 1)
|
|
|
class DS(Dataset):
    """Dataset over every file in a directory, returned as RGB tensors in [-1, 1].

    ``__getitem__`` returns a random ``crop_size`` x ``crop_size`` crop.
    ``gettest`` returns the first file (in sorted order) uncropped. Both are
    scaled to [-1, 1].

    Args:
        root: directory whose entries are all treated as images. Defaults to
            "./15k", the originally hard-coded location.
        crop_size: side length of the square random crop. Defaults to 256.
    """

    def __init__(self, root="./15k", crop_size=256):
        super().__init__()
        # glob order depends on the filesystem. Sort it so that indices, and
        # gettest's "first" image, are stable across runs and machines.
        self.g = sorted(glob.glob(glob.escape(root) + "/*"))
        self.trans = transforms.Compose([
            # pad_if_needed zero-pads images smaller than the crop instead of
            # letting RandomCrop raise on them.
            transforms.RandomCrop((crop_size, crop_size), pad_if_needed=True),
            transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        ])

    def __len__(self):
        return len(self.g)

    @staticmethod
    def _to_signed_unit(x):
        """Map values from [0, 1] (the ToTensor / to_tensor range) to [-1, 1].

        ToTensor already divides by 255. A [0, 255]-style ``x / 127.5 - 1``
        applied after it would squeeze every image into about [-1, -0.992].
        """
        return x * 2 - 1

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path* as RGB, closing the file before returning."""
        with Image.open(path) as img:
            # convert() returns a new, loaded image, so the result stays
            # valid after the file is closed.
            return img.convert("RGB")

    def __getitem__(self, idx):
        img = self._load_rgb(self.g[idx])
        return self._to_signed_unit(self.trans(img))

    def gettest(self):
        """Return the first image, uncropped, as a (3, H, W) tensor in [-1, 1].

        Raises:
            IndexError: if the dataset contains no files.
        """
        if not self.g:
            raise IndexError("DS.gettest: no images found in the dataset directory")
        img = self._load_rgb(self.g[0])
        return self._to_signed_unit(to_tensor(img))
|
|
|
def _save_preview(model, dataset, device, tag, out_dir="./test"):
    """Reconstruct ``dataset.gettest()`` with *model* and save it as a PNG.

    The image is written to ``<out_dir>/<tag>.png``. The model runs in eval
    mode for this forward pass and is switched back to train mode afterwards,
    even if the forward pass or the save raises.
    """
    model.eval()
    try:
        # Inference only: a preview does not need an autograd graph.
        with torch.no_grad():
            batch = dataset.gettest().to(device).unsqueeze(0)
            out = model(batch)
        # Assumes model output lies in [-1, 1] (TestModel clamps to that
        # range); shift it back to [0, 1] for to_pil_image.
        image = to_pil_image((out[0].cpu() + 1) / 2)
        image.save(os.path.join(out_dir, f"{tag}.png"))
    finally:
        model.train()


def main():
    """Train TestModel to reproduce its input images with an MSE loss.

    Uses the default DS dataset, Adam (lr 1e-4), batch size 64 and 10 epochs.
    A preview reconstruction is written to ./test/: "test.png" before
    training, then "<epoch>.png" after each epoch.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 64
    epochs = 10
    preview_dir = "./test"

    # Previews are saved here. Create it up front rather than failing with
    # FileNotFoundError on the first save.
    os.makedirs(preview_dir, exist_ok=True)

    # Move the model to its device before building the optimizer, as the
    # torch.optim documentation recommends.
    model = TestModel().to(device)
    model.train()

    dataset = DS()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = torch.nn.MSELoss()
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)

    _save_preview(model, dataset, device, "test", preview_dir)

    for epoch in range(epochs):
        loss = None  # stays None if the loader yields no batches
        for step, batch in enumerate(tqdm.tqdm(dataloader)):
            batch = batch.to(device)
            optim.zero_grad()
            out = model(batch)
            # Reconstruction objective: the target is the input batch itself.
            loss = criterion(out, batch)
            loss.backward()
            optim.step()

            if step % 100 == 0:
                # Periodically collect garbage and release cached GPU blocks.
                gc.collect()
                if device == "cuda":
                    torch.cuda.empty_cache()

        print("EPOCH", epoch)
        # .item() prints a plain number rather than a tensor repr with grad_fn.
        print("LAST LOSS", None if loss is None else loss.item())
        _save_preview(model, dataset, device, epoch, preview_dir)
|
|
|
|
|
# Script entry point: start training only when this file is executed
# directly, never when it is imported as a module.
if __name__ == "__main__":
    main()