vkganesan commited on
Commit
12b5a88
1 Parent(s): 852d89f

create app

Browse files
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. __pycache__/adain.cpython-39.pyc +0 -0
  3. __pycache__/decoder.cpython-39.pyc +0 -0
  4. __pycache__/encoder.cpython-39.pyc +0 -0
  5. __pycache__/net.cpython-310.pyc +0 -0
  6. __pycache__/net.cpython-39.pyc +0 -0
  7. __pycache__/utils.cpython-39.pyc +0 -0
  8. adain.py +37 -0
  9. app.py +46 -0
  10. decoder.py +33 -0
  11. encoder.py +58 -0
  12. logs/events.out.tfevents.1673075465.Vikrams-MBP.lan +3 -0
  13. logs/events.out.tfevents.1673075531.Vikrams-MBP.lan +3 -0
  14. logs/events.out.tfevents.1673075820.Vikrams-MBP.lan +3 -0
  15. logs/events.out.tfevents.1673075821.Vikrams-MBP.lan +3 -0
  16. logs/events.out.tfevents.1673075850.Vikrams-MBP.lan +3 -0
  17. logs/events.out.tfevents.1673075852.Vikrams-MBP.lan +3 -0
  18. logs/events.out.tfevents.1673075889.Vikrams-MBP.lan +3 -0
  19. logs/events.out.tfevents.1673075890.Vikrams-MBP.lan +3 -0
  20. logs/events.out.tfevents.1673075982.Vikrams-MBP.lan +3 -0
  21. logs/events.out.tfevents.1673076026.Vikrams-MBP.lan +3 -0
  22. logs/events.out.tfevents.1673076079.Vikrams-MBP.lan +3 -0
  23. logs/events.out.tfevents.1673076142.Vikrams-MBP.lan +3 -0
  24. logs/events.out.tfevents.1673076233.Vikrams-MBP.lan +3 -0
  25. logs/events.out.tfevents.1673076507.Vikrams-MBP.lan +3 -0
  26. logs/events.out.tfevents.1673076723.Vikrams-MBP.lan +3 -0
  27. logs/events.out.tfevents.1673076832.Vikrams-MBP.lan +3 -0
  28. logs/events.out.tfevents.1673076887.Vikrams-MBP.lan +3 -0
  29. logs/events.out.tfevents.1673076993.Vikrams-MBP.lan +3 -0
  30. logs/events.out.tfevents.1673077155.Vikrams-MBP.lan +3 -0
  31. logs/events.out.tfevents.1673077187.Vikrams-MBP.lan +3 -0
  32. logs/events.out.tfevents.1673077234.Vikrams-MBP.lan +3 -0
  33. logs/events.out.tfevents.1673079573.Vikrams-MBP.lan +3 -0
  34. logs/events.out.tfevents.1673079783.Vikrams-MBP.lan +3 -0
  35. logs/events.out.tfevents.1673079809.Vikrams-MBP.lan +3 -0
  36. logs/events.out.tfevents.1673079875.Vikrams-MBP.lan +3 -0
  37. logs/events.out.tfevents.1673079932.Vikrams-MBP.lan +3 -0
  38. logs/events.out.tfevents.1673080014.Vikrams-MBP.lan +3 -0
  39. logs/events.out.tfevents.1673080084.Vikrams-MBP.lan +3 -0
  40. logs/events.out.tfevents.1673080471.Vikrams-MBP.lan +3 -0
  41. logs/events.out.tfevents.1673080709.Vikrams-MBP.lan +3 -0
  42. logs/events.out.tfevents.1673733387.Vikrams-MBP.lan +3 -0
  43. logs/events.out.tfevents.1673735400.Vikrams-MBP.lan +3 -0
  44. net.py +76 -0
  45. saved-models/.DS_Store +0 -0
  46. saved-models/decoder_iter_1000.pth.tar +3 -0
  47. saved-models/decoder_iter_500.pth.tar +3 -0
  48. train.py +144 -0
  49. utils.py +42 -0
  50. vgg_normalised.pth +3 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.pth.tar filter=lfs diff=lfs merge=lfs -text
__pycache__/adain.cpython-39.pyc ADDED
Binary file (877 Bytes). View file
 
__pycache__/decoder.cpython-39.pyc ADDED
Binary file (667 Bytes). View file
 
__pycache__/encoder.cpython-39.pyc ADDED
Binary file (958 Bytes). View file
 
__pycache__/net.cpython-310.pyc ADDED
Binary file (2.56 kB). View file
 
__pycache__/net.cpython-39.pyc ADDED
Binary file (2.56 kB). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.45 kB). View file
 
adain.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from utils import *
3
+
4
+
5
class AdaIN(torch.nn.Module):
    """Adaptive Instance Normalization (AdaIN) layer.

    Normalizes the content features to zero mean / unit variance per
    channel, then re-scales and re-centers them with the style features'
    per-channel statistics.
    """

    def __init__(self):
        super(AdaIN, self).__init__()
        # Instance normalization is the basis of AdaIN; it follows an
        # equation similar to a z-score: (x - mu) / sigma.
        # NOTE(review): constructed with num_features=3 but applied to
        # encoder feature maps (512 channels); harmless while affine=False,
        # since no per-channel parameters are allocated — confirm.
        self.instance_norm = torch.nn.InstanceNorm2d(3)

    def forward(self, x, y):
        """Apply AdaIN.

        x: content feature tensor, shaped (N, C, H, W).
        y: style feature tensor, shaped (N, C, H, W).
        Returns the content features carrying the style's statistics.
        """
        # Style statistics, shaped (N, C, 1, 1) by mean_and_std_of_image.
        # Content statistics are handled implicitly by instance_norm.
        y_mean, y_std = mean_and_std_of_image(y)

        x_norm = self.instance_norm(x)
        x_size = x_norm.size()

        # Re-statisticize: scale by the style std and shift by the style mean.
        return y_std.expand(x_size) * x_norm + y_mean.expand(x_size)
35
+
36
+
37
+
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from decoder import decoder as Decoder
from encoder import encoder as Encoder
from net import StyleTransfer
from PIL import Image

encoder = Encoder
decoder = Decoder

# map_location='cpu' so the demo loads checkpoints that were saved on a
# GPU/MPS machine onto any host.
encoder.load_state_dict(torch.load("./vgg_normalised.pth", map_location="cpu"))
# Only the layers up to relu4_1 are used as the style-transfer encoder.
encoder = nn.Sequential(*list(encoder.children())[:31])
decoder.load_state_dict(
    torch.load("./saved-models/decoder_iter_1000.pth.tar", map_location="cpu"))

net = StyleTransfer(encoder, decoder)
net.eval()


def train_transform():
    """Resize inputs to a fixed 512x512 and convert to a CHW float tensor."""
    transform_list = [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)


def cleanup(input, style):
    """Run style transfer on a pair of uploaded images.

    input: HxWxC numpy array providing the content.
    style: HxWxC numpy array providing the style.
    Returns a PIL image of the stylized result.
    """
    transform = train_transform()
    input_img = transform(Image.fromarray(input)).unsqueeze(0)
    style_img = transform(Image.fromarray(style)).unsqueeze(0)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        final_image_tensor = net(input_img, style_img)
    # Clamp to [0, 1] before converting back to an image; out-of-range
    # values would otherwise wrap around during the uint8 conversion.
    final_image_tensor = final_image_tensor.squeeze(0).clamp(0.0, 1.0)
    to_pil = transforms.ToPILImage()
    return to_pil(final_image_tensor)


demo = gr.Interface(fn=cleanup,
                    inputs=[gr.Image(shape=(224, 224)), gr.Image(shape=(224, 224))],
                    outputs="image")
demo.launch()
decoder.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
def _dec_block(in_channels, out_channels, activation=True):
    """Reflection-padded 3x3 convolution, optionally followed by ReLU."""
    modules = [
        nn.ReflectionPad2d((1, 1, 1, 1)),
        nn.Conv2d(in_channels, out_channels, (3, 3)),
    ]
    if activation:
        modules.append(nn.ReLU())
    return modules


# Decoder: mirrors the truncated VGG encoder, upsampling back to image
# resolution. The flat layer order (and hence state_dict indices) matches
# the original hand-written Sequential exactly.
decoder = nn.Sequential(
    *_dec_block(512, 256),
    nn.Upsample(scale_factor=2, mode='nearest'),
    *_dec_block(256, 256),
    *_dec_block(256, 256),
    *_dec_block(256, 256),
    *_dec_block(256, 128),
    nn.Upsample(scale_factor=2, mode='nearest'),
    *_dec_block(128, 128),
    *_dec_block(128, 64),
    nn.Upsample(scale_factor=2, mode='nearest'),
    *_dec_block(64, 64),
    *_dec_block(64, 3, activation=False),  # final layer: raw RGB, no ReLU
)
encoder.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
def _vgg_stage(channel_pairs):
    """One VGG stage: reflection-padded 3x3 conv + ReLU per channel pair."""
    modules = []
    for in_channels, out_channels in channel_pairs:
        modules.append(nn.ReflectionPad2d((1, 1, 1, 1)))
        modules.append(nn.Conv2d(in_channels, out_channels, (3, 3)))
        modules.append(nn.ReLU())
    return modules


def _pool():
    """2x2 max-pool with stride 2, ceil_mode as in the normalised VGG."""
    return nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True)


# Normalised VGG-19 feature extractor. The flat layer order (and hence the
# vgg_normalised.pth state_dict indices) matches the original Sequential
# exactly; relu4_1 (index 30) is the last layer actually used downstream.
encoder = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),                                          # RGB pre-scaling
    *_vgg_stage([(3, 64), (64, 64)]),                                 # relu1_1..relu1_2
    _pool(),
    *_vgg_stage([(64, 128), (128, 128)]),                             # relu2_1..relu2_2
    _pool(),
    *_vgg_stage([(128, 256), (256, 256), (256, 256), (256, 256)]),    # relu3_1..relu3_4
    _pool(),
    *_vgg_stage([(256, 512), (512, 512), (512, 512), (512, 512)]),    # relu4_1..relu4_4
    _pool(),
    *_vgg_stage([(512, 512), (512, 512), (512, 512), (512, 512)]),    # relu5_1..relu5_4
)
logs/events.out.tfevents.1673075465.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dce644727440026ca34adf6e356baa976b1abdad898c4227947623aaa6c27242
3
+ size 40
logs/events.out.tfevents.1673075531.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ff3282e663a03174d06e037b601a40d8796d6ac3c2f02dc803ae854bd47f224
3
+ size 40
logs/events.out.tfevents.1673075820.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82618836bcd68966be78e10fecbf7316cca493eb6131a82dfcbbcfcd1d5ec66a
3
+ size 40
logs/events.out.tfevents.1673075821.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccbaca8977b8c54996f333c9cd844a3b4cd6f81f619f4e2cd355d96846110c2f
3
+ size 40
logs/events.out.tfevents.1673075850.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac12b47b00399439a60549dafeceba4525b0fbabf7c1046c63d2559fdd719562
3
+ size 40
logs/events.out.tfevents.1673075852.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1e5c0d77d3b9ed00d7bba6febcb5568f4ea1d01c77a13b48e5fa7c3bfe088b4
3
+ size 40
logs/events.out.tfevents.1673075889.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5ade4d30a4b1a32752bfe04ccf7499bf56fbb82d279ccc024254f465be3b253
3
+ size 40
logs/events.out.tfevents.1673075890.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:826d5080380cadac7bd1525880ef05aab8be1ed4c9ccf5c1d7349f1c6280ca03
3
+ size 40
logs/events.out.tfevents.1673075982.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f733d66ed2181abd3f250e23025231ee530cab128f976c282b502706834b11af
3
+ size 40
logs/events.out.tfevents.1673076026.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bc909a1adc9130a19913a9e53e634167b20cb15c50c8cc35d3c6d9a92b2f146
3
+ size 40
logs/events.out.tfevents.1673076079.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b03cba18d7045162a22c08091d1492af2ef99cd211a562f5bc82a1752be8c3c9
3
+ size 40
logs/events.out.tfevents.1673076142.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:156bf3ee127ca32ca84fa3405acdc4ffc1431a9dd96d5ccd5c5267cb05aaed62
3
+ size 40
logs/events.out.tfevents.1673076233.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e687762ebd774a4b8b84609559927223587f4f59128e41360d6e8b3c77d21320
3
+ size 40
logs/events.out.tfevents.1673076507.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c07751136a8cea75e817528ee97e0d018c9779150bbd69a431ca7aa35aeb49cd
3
+ size 40
logs/events.out.tfevents.1673076723.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bac6375530aa11a872827fd9006f07f37f9122c475b6398f4afbced98fa575d4
3
+ size 40
logs/events.out.tfevents.1673076832.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f82096b51b53c03476c219d7cef64374f21277e4f7d11a98877786458e215f1f
3
+ size 40
logs/events.out.tfevents.1673076887.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a33b760c94457ccf05b3b8c81c4c7a0dd2af318c72dafbb94388fb805e99cc0
3
+ size 40
logs/events.out.tfevents.1673076993.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1783ef3a87692e3d75c763073e30c9c5d9f0f577017cff2c3cee63a5e565dbf2
3
+ size 40
logs/events.out.tfevents.1673077155.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5baedc0a372f0ac969de671e31c15a60e7185972a15fc730cda35add60c51a9
3
+ size 40
logs/events.out.tfevents.1673077187.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5293293201e7252e90e863e550f2facdf95c7a84970424003f44e2e9285e0d0e
3
+ size 40
logs/events.out.tfevents.1673077234.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c998d2738fb5d91c2328fc55968fc59cb021c485c94b0e0eec33d043cd5b92f4
3
+ size 40
logs/events.out.tfevents.1673079573.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:439a6b8f06eacebeade084bc0b99cdf0fc2dd664f254f63035b85e6bea864954
3
+ size 40
logs/events.out.tfevents.1673079783.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19fa8348ce1a15cbcd2fbce80cdc4ca5935e2a776975dc914120b8e1989a77e5
3
+ size 40
logs/events.out.tfevents.1673079809.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd2327969e02ac17d9fd678ecb993897c2d48d1257cdb60043a7bfeed7188db3
3
+ size 40
logs/events.out.tfevents.1673079875.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07c018ca3aed8cc142059b3c38a35ce270db39d4e830bec58e2d432a6539ce97
3
+ size 40
logs/events.out.tfevents.1673079932.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5e140d7be63757c53ae28febf5a730b453986307698009cb0c1163cd3adabed
3
+ size 40
logs/events.out.tfevents.1673080014.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:646d967ffd7850501fb6bdde7df8aa44850d27f8d1e9011816b25531116d06f5
3
+ size 40
logs/events.out.tfevents.1673080084.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08baf672c0d6ab6eb18d6c0e4dd4b45d061171e346b51159e034497ccea12123
3
+ size 44786
logs/events.out.tfevents.1673080471.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a7d92ffc0a3f73e0a861c3b6cd709e9638e904442c6115459ce7faceced9d33
3
+ size 4940
logs/events.out.tfevents.1673080709.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b90008548a1bbd7852e46eb4b71dcf412f634bf93be55d7b84bce046cbe8fc6
3
+ size 109386
logs/events.out.tfevents.1673733387.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8fd458f41aa7271947c88f846db26cb0a58698505d219791ae5814accad1bab
3
+ size 109386
logs/events.out.tfevents.1673735400.Vikrams-MBP.lan ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb80a6f53db3f15fc30d4cf0267b4ce6b2b87288ae374949db56ce4d57fcb07b
3
+ size 109386
net.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from adain import AdaIN
4
+ from utils import *
5
+
6
+ class StyleTransfer(nn.Module):
7
+ def __init__(self, encoder, decoder):
8
+ super(StyleTransfer, self).__init__()
9
+ layers = list(encoder.children())
10
+ self.enc_1 = nn.Sequential(*layers[:4]) # input -> relu1_1
11
+ self.enc_2 = nn.Sequential(*layers[4:11]) # relu1_1 -> relu2_1
12
+ self.enc_3 = nn.Sequential(*layers[11:18]) # relu2_1 -> relu3_1
13
+ self.enc_4 = nn.Sequential(*layers[18:31]) # relu3_1 -> relu4_1]
14
+ self.relus = [self.enc_1, self.enc_2, self.enc_3, self.enc_4]
15
+ self.decoder = decoder
16
+ self.mse = nn.MSELoss()
17
+ self.adain = AdaIN()
18
+
19
+ for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
20
+ for param in getattr(self, name).parameters():
21
+ param.requires_grad = False
22
+
23
+ def encode_with_save(self, input):
24
+ results = [input]
25
+ for i in range(4):
26
+ func = getattr(self, 'enc_{:d}'.format(i + 1))
27
+ results.append(func(results[-1]))
28
+ return results[1:]
29
+
30
+ def encode(self, input):
31
+ res = input
32
+ for layer in self.relus:
33
+ res = layer(res)
34
+ return res
35
+
36
+ def forward(self, content, style):
37
+ if not self.training:
38
+ self.adain.eval()
39
+ encoded_style = self.encode_with_save(style)
40
+ encoded_content = self.encode(content)
41
+ t = self.adain(encoded_content, encoded_style[-1])
42
+
43
+
44
+ g_t = self.decoder(t)
45
+
46
+ if not self.training:
47
+ return g_t
48
+ g_t_encoding = self.encode_with_save(g_t)
49
+
50
+ s_loss = self.style_loss(g_t_encoding, encoded_style)
51
+ c_loss = self.content_loss(g_t_encoding[-1], t)
52
+
53
+ return g_t, s_loss, c_loss
54
+
55
+
56
+ def style_loss(self, encoded_image, encoded_style):
57
+ MSE = torch.nn.MSELoss()
58
+ initial_mean_image, initial_std_image = mean_and_std_of_image(encoded_image[0])
59
+ initial_mean_style, initial_std_style = mean_and_std_of_image(encoded_style[0])
60
+ loss = MSE(initial_mean_image, initial_mean_style) + MSE(initial_std_image, initial_std_style)
61
+ for i in range(1, 4, 1):
62
+ mean_image, std_image = mean_and_std_of_image(encoded_image[i])
63
+ mean_style, std_style = mean_and_std_of_image(encoded_style[i])
64
+ loss += MSE(mean_image, mean_style) + MSE(std_image, std_style)
65
+ return loss
66
+
67
+
68
+ def content_loss(self, encoded_image, style_content_combined):
69
+ MSE = torch.nn.MSELoss()
70
+ return MSE(encoded_image, style_content_combined)
71
+
72
+
73
+
74
+
75
+
76
+
saved-models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
saved-models/decoder_iter_1000.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5831e836f6bb7bab64fef12a1a070d69efb6ec312dc3c4a96653e22af55b5809
3
+ size 14026951
saved-models/decoder_iter_500.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08bd329dfe5e930bccebf1045365818574eb203008c9ea3a30fe53c657ac6b32
3
+ size 14026931
train.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from net import StyleTransfer
2
+ import torch
3
+ import torch.nn as nn
4
+ from pathlib import Path
5
+ import torchvision
6
+ import torch.utils.data as data
7
+ import torchvision.transforms as transforms
8
+ import matplotlib.pyplot as plt
9
+ import torch.multiprocessing
10
+ from utils import *
11
+ import argparse
12
+ from tqdm import tqdm
13
+ from tensorboardX import SummaryWriter
14
+ from decoder import decoder as Decoder
15
+ from encoder import encoder as Encoder
16
+ from PIL import Image, ImageFile
17
+
18
class FlatFolderDataset(data.Dataset):
    """Dataset over every file sitting directly inside one flat directory.

    Each item is the file at that index, opened as an RGB PIL image and
    passed through `transform`.
    """

    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        # Snapshot the directory listing once, at construction time.
        self.paths = list(Path(self.root).glob('*'))
        self.transform = transform

    def __getitem__(self, index):
        image = Image.open(str(self.paths[index])).convert('RGB')
        return self.transform(image)

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'
+
37
def main():
    """Train the AdaIN style-transfer decoder.

    The VGG encoder is loaded from a pre-trained checkpoint and frozen;
    only the decoder is optimized. Checkpoints go to --save_dir,
    TensorBoard logs to --log_dir.
    """
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Default dataset locations (overridable via --content_dir / --style_dir).
    content_dataset_dir = '../../content-dataset/images/images'
    style_dataset_dir = '../../style-dataset/images'

    def train_transform():
        """Resize to 512x512, then take a random 256x256 crop."""
        transform_list = [
            transforms.Resize(size=(512, 512)),
            transforms.RandomCrop(256),
            transforms.ToTensor()
        ]
        return transforms.Compose(transform_list)

    parser = argparse.ArgumentParser()
    # Basic options
    parser.add_argument('--content_dir', default=content_dataset_dir, type=str,
                        help='Directory path to a batch of content images')
    parser.add_argument('--style_dir', default=style_dataset_dir, type=str,
                        help='Directory path to a batch of style images')
    parser.add_argument('--encoder', type=str, default='./vgg_normalised.pth')

    # training options
    parser.add_argument('--save_dir', default='../saved-models',
                        help='Directory to save the model')
    parser.add_argument('--log_dir', default='./logs',
                        help='Directory to save the log')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--lr_decay', type=float, default=5e-5)
    parser.add_argument('--max_iter', type=int, default=8000)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--style_weight', type=float, default=10.0)
    parser.add_argument('--content_weight', type=float, default=1.0)
    parser.add_argument('--n_threads', type=int, default=8)
    parser.add_argument('--save_model_interval', type=int, default=500)
    parser.add_argument('--save-image-interval', type=int, default=50)
    args = parser.parse_args()

    # Pick the best available device instead of hard-coding 'mps', so the
    # script also runs on CUDA machines and plain CPUs.
    if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        device = torch.device('mps')
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    save_dir = Path(args.save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    log_dir = Path(args.log_dir)
    log_dir.mkdir(exist_ok=True, parents=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    decoder = Decoder
    encoder = Encoder

    encoder.load_state_dict(torch.load(args.encoder))
    # Only the layers up to relu4_1 are used as the feature extractor.
    encoder = nn.Sequential(*list(encoder.children())[:31])
    network = StyleTransfer(encoder, decoder)
    network.train()
    network.to(device)

    content_dataset = FlatFolderDataset(args.content_dir, transform=train_transform())
    style_dataset = FlatFolderDataset(args.style_dir, transform=train_transform())

    print(len(content_dataset), len(style_dataset))

    def infinite_batches(dataset):
        """Yield batches forever, restarting the loader when exhausted.

        A plain iter(DataLoader) raises StopIteration after one epoch,
        which would crash training once max_iter exceeds
        len(dataset) / batch_size iterations.
        """
        loader = data.DataLoader(dataset, batch_size=args.batch_size,
                                 num_workers=args.n_threads)
        while True:
            for batch in loader:
                yield batch

    content_iter = infinite_batches(content_dataset)
    style_iter = infinite_batches(style_dataset)
    # Only the decoder is trained; the encoder parameters are frozen.
    optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr)

    for batch in tqdm(range(args.max_iter)):
        adjust_learning_rate(optimizer, batch, args.lr_decay, args.lr)
        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)
        final_image, s_loss, c_loss = network(content_images, style_images)
        c_loss = args.content_weight * c_loss
        s_loss = args.style_weight * s_loss
        total_loss = c_loss + s_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        writer.add_scalar('loss_content', c_loss.item(), batch + 1)
        writer.add_scalar('loss_style', s_loss.item(), batch + 1)

        if (batch + 1) % args.save_model_interval == 0 or (batch + 1) == args.max_iter:
            # Save a CPU copy of the decoder weights so the checkpoint can
            # be loaded on any machine.
            state_dict = network.decoder.state_dict()
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(torch.device('cpu'))
            torch.save(state_dict, save_dir /
                       'decoder_iter_{:d}.pth.tar'.format(batch + 1))

        if (batch + 1) % args.save_image_interval == 0:
            # Concatenate content | style | result along width for inspection.
            print_img = torch.cat((content_images[:1], style_images[:1], final_image[:1]), 3).detach().cpu()
            concat_img(print_img, batch)

    # Close the writer once, after training completes (closing it inside
    # the loop would invalidate subsequent add_scalar calls).
    writer.close()


if __name__ == "__main__":
    main()
utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+
5
def adjust_learning_rate(optimiser, iters, learning_rate_decay, LR):
    """Apply inverse-time decay: lr = LR / (1 + learning_rate_decay * iters).

    Mutates every param group of `optimiser` in place.
    """
    decayed_lr = LR / (1.0 + learning_rate_decay * iters)
    for group in optimiser.param_groups:
        group['lr'] = decayed_lr
8
+
9
def concat_img(imgs, batch):
    """Save the images of a batch concatenated side by side.

    imgs: tensor shaped (N, C, H, W); batch: iteration index, used in the
    output filename.

    NOTE: this function was previously defined twice; the duplicate (which
    omitted plt.close()) has been merged into this single definition.
    """
    plt.figure()
    #imgs = (imgs + 1) / 2
    # Move channels last: (N, C, H, W) -> (N, H, W, C) for imshow.
    imgs = imgs.movedim((0, 1, 2, 3), (0, 3, 1, 2)).detach().cpu().numpy()
    plt.imshow(np.concatenate(imgs.tolist(), axis=1))
    plt.axis('off')
    plt.savefig("../../produced-images/batch{}img.png".format(batch))
    # Close the figure so repeated calls during training do not leak memory.
    plt.close()
25
+
26
# takes in image tensor x as input
def mean_and_std_of_image(x):
    """Per-channel mean and standard deviation of a batch of feature maps.

    x: tensor shaped (batch_size, num_channels, H, W).
    Returns (mean, std), each shaped (batch_size, num_channels, 1, 1) so
    they broadcast/expand cleanly against feature maps of the same rank.
    """
    # Flatten the spatial dimensions: (N, C, H, W) -> (N, C, H*W).
    x = x.view(x.shape[0], x.shape[1], -1)
    # Statistics over the H*W dimension.
    mean = x.mean(dim=2)
    std = x.var(dim=2).sqrt()
    # (N, C) -> (N, C, 1, 1): mean/std are per-channel scalars, so the last
    # two dimensions are both 1.
    mean = mean.view(mean.shape[0], mean.shape[1], 1, 1)
    std = std.view(std.shape[0], std.shape[1], 1, 1)

    return (mean, std)
vgg_normalised.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:804ca2835ecf7539f0cd2a7ac3c18ce81e6f8468969ae7117ac0c148d286bb4a
3
+ size 80102481