Spanicin committed on
Commit
4d7bc0c
1 Parent(s): 42e8c1c

Update videoretalking/models/DNet.py

Browse files
Files changed (1) hide show
  1. videoretalking/models/DNet.py +118 -118
videoretalking/models/DNet.py CHANGED
@@ -1,118 +1,118 @@
1
- # TODO
2
- import functools
3
- import numpy as np
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
- from utils import flow_util
10
- from models.base_blocks import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
11
-
# DNet
class DNet(nn.Module):
    """Chain the mapping, warping and editing sub-networks.

    ``forward`` always runs the mapping and warping networks; the editing
    network is run as well unless ``stage == 'warp'``.
    """

    def __init__(self):
        super(DNet, self).__init__()
        self.mapping_net = MappingNet()
        # NOTE: "warpping" is misspelt, but a sub-module's attribute name is
        # also the prefix of its state_dict keys, so renaming it would stop
        # existing checkpoints from loading.
        self.warpping_net = WarpingNet()
        self.editing_net = EditingNet()

    def forward(self, input_image, driving_source, stage=None):
        """Warp ``input_image`` and, unless ``stage == 'warp'``, edit it.

        Args:
            input_image: Image tensor handed to the warping network and,
                when it runs, to the editing network.
            driving_source: Input of the mapping network (presumably a
                window of 3DMM coefficients -- confirm against the caller).
            stage: ``'warp'`` skips the editing network; any other value,
                including the default ``None``, runs the full pipeline.

        Returns:
            dict: The dict returned by the warping network, with an extra
            ``'fake_image'`` entry when the editing network was run.
        """
        # Both stages share the mapping and warping passes.
        descriptor = self.mapping_net(driving_source)
        output = self.warpping_net(input_image, descriptor)
        if stage != 'warp':
            output['fake_image'] = self.editing_net(
                input_image, output['warp_image'], descriptor)
        return output
29
-
class MappingNet(nn.Module):
    """Reduce a coefficient sequence to one descriptor vector per sample.

    An un-padded ``Conv1d`` (kernel 7) is followed by ``layer`` residual
    blocks -- each a ``LeakyReLU(0.1)`` plus an un-padded, dilated
    ``Conv1d`` -- and by an adaptive average pool down to length 1.
    """

    # Geometry of every residual convolution.  Without padding such a
    # convolution shortens the last axis by dilation * (kernel_size - 1)
    # samples, i.e. by _RESIDUAL_CROP samples on each side.
    _KERNEL_SIZE = 3
    _DILATION = 3
    _RESIDUAL_CROP = _DILATION * (_KERNEL_SIZE - 1) // 2

    def __init__(self, coeff_nc=73, descriptor_nc=256, layer=3):
        """Build the network.

        Args:
            coeff_nc: Number of input channels of the first convolution.
            descriptor_nc: Number of channels of every convolution output,
                and therefore of the returned descriptor.
            layer: Number of residual blocks.
        """
        super(MappingNet, self).__init__()

        self.layer = layer
        # A single activation instance is shared by all residual blocks.
        nonlinearity = nn.LeakyReLU(0.1)

        self.first = nn.Sequential(
            nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))

        for i in range(layer):
            block = nn.Sequential(
                nonlinearity,
                nn.Conv1d(descriptor_nc, descriptor_nc,
                          kernel_size=self._KERNEL_SIZE, padding=0,
                          dilation=self._DILATION))
            # Registered as encoder0, encoder1, ...; these names are part of
            # the state_dict keys and must stay as they are.
            setattr(self, 'encoder' + str(i), block)

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def forward(self, input_3dmm):
        """Map a coefficient sequence to a descriptor.

        Args:
            input_3dmm: Tensor of shape ``(batch, coeff_nc, length)``.  No
                convolution is padded, so ``length`` must be at least
                ``7 + 6 * layer`` (25 with the default ``layer=3``).

        Returns:
            Tensor of shape ``(batch, descriptor_nc, 1)``.
        """
        crop = self._RESIDUAL_CROP
        out = self.first(input_3dmm)
        for i in range(self.layer):
            block = getattr(self, 'encoder' + str(i))
            # Trim the skip path so it lines up with the shorter conv output.
            out = block(out) + out[:, :, crop:-crop]
        return self.pooling(out)
55
-
class WarpingNet(nn.Module):
    """Predict a 2-channel flow field and warp the input image with it.

    An ``ADAINHourglass`` processes the image together with the descriptor;
    a norm / ``LeakyReLU(0.1)`` / 7x7 ``Conv2d`` head turns its output into
    the flow field, which ``flow_util`` converts into a deformation and
    applies to the input image.
    """

    def __init__(
        self,
        image_nc=3,
        descriptor_nc=256,
        base_nc=32,
        max_nc=256,
        encoder_layer=5,
        decoder_layer=3,
        use_spect=False
    ):
        """Build the hourglass and the flow head.

        ``image_nc``, ``descriptor_nc``, ``base_nc``, ``max_nc``,
        ``encoder_layer`` and ``decoder_layer`` are forwarded positionally,
        in that order, to ``ADAINHourglass``; ``use_spect`` and the shared
        activation are forwarded to it as keyword arguments.
        """
        super(WarpingNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True)
        kwargs = {'nonlinearity': nonlinearity, 'use_spect': use_spect}

        self.descriptor_nc = descriptor_nc
        self.hourglass = ADAINHourglass(
            image_nc, self.descriptor_nc, base_nc,
            max_nc, encoder_layer, decoder_layer, **kwargs)

        # kernel_size=7 with padding=3 and stride=1 keeps the spatial size
        # of the hourglass output; the two output channels are the flow.
        self.flow_out = nn.Sequential(
            norm_layer(self.hourglass.output_nc),
            nonlinearity,
            nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))

        # Not used by forward(); kept so the attribute stays available to
        # any external code that reads it.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        """Estimate the flow for ``input_image`` and warp the image.

        Args:
            input_image: Image tensor fed to the hourglass and warped.
            descriptor: Conditioning tensor fed to the hourglass.

        Returns:
            dict: ``'flow_field'`` (output of the flow head) and
            ``'warp_image'`` (``input_image`` warped by the deformation
            derived from that flow).
        """
        final_output = {}
        features = self.hourglass(input_image, descriptor)
        final_output['flow_field'] = self.flow_out(features)

        deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
        final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
        return final_output
91
-
92
-
class EditingNet(nn.Module):
    """Generate an image from the input image, its warped version and a descriptor.

    The two images are concatenated along the channel axis, encoded by a
    ``FineEncoder`` and decoded, together with the descriptor, by a
    ``FineDecoder``.
    """

    def __init__(
        self,
        image_nc=3,
        descriptor_nc=256,
        layer=3,
        base_nc=64,
        max_nc=256,
        num_res_blocks=2,
        use_spect=False):
        """Build the encoder and the decoder.

        The encoder is created with ``image_nc * 2`` as its first argument
        because ``forward`` feeds it two channel-concatenated images; the
        decoder receives ``image_nc`` and ``descriptor_nc``.  ``base_nc``,
        ``max_nc``, ``layer``, ``num_res_blocks`` and ``use_spect`` are
        forwarded to ``FineEncoder`` / ``FineDecoder``.
        """
        super(EditingNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True)
        kwargs = {'norm_layer': norm_layer, 'nonlinearity': nonlinearity, 'use_spect': use_spect}
        self.descriptor_nc = descriptor_nc

        # encoder part
        self.encoder = FineEncoder(image_nc * 2, base_nc, max_nc, layer, **kwargs)
        self.decoder = FineDecoder(
            image_nc, self.descriptor_nc, base_nc, max_nc, layer,
            num_res_blocks, **kwargs)

    def forward(self, input_image, warp_image, descriptor):
        """Run the encoder / decoder on the concatenated image pair.

        Args:
            input_image: First image tensor.
            warp_image: Second image tensor; it is concatenated to
                ``input_image`` along dimension 1, so all other dimensions
                must match.
            descriptor: Conditioning tensor passed to the decoder.

        Returns:
            The decoder output.
        """
        stacked = torch.cat([input_image, warp_image], 1)
        features = self.encoder(stacked)
        return self.decoder(features, descriptor)
 
1
+ # TODO
2
+ import functools
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from videoretalking.utils import flow_util
10
+ from videoretalking.models.base_blocks import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
11
+
# DNet
class DNet(nn.Module):
    """Chain the mapping, warping and editing sub-networks.

    ``forward`` always runs the mapping and warping networks; the editing
    network is run as well unless ``stage == 'warp'``.
    """

    def __init__(self):
        super(DNet, self).__init__()
        self.mapping_net = MappingNet()
        # NOTE: "warpping" is misspelt, but a sub-module's attribute name is
        # also the prefix of its state_dict keys, so renaming it would stop
        # existing checkpoints from loading.
        self.warpping_net = WarpingNet()
        self.editing_net = EditingNet()

    def forward(self, input_image, driving_source, stage=None):
        """Warp ``input_image`` and, unless ``stage == 'warp'``, edit it.

        Args:
            input_image: Image tensor handed to the warping network and,
                when it runs, to the editing network.
            driving_source: Input of the mapping network (presumably a
                window of 3DMM coefficients -- confirm against the caller).
            stage: ``'warp'`` skips the editing network; any other value,
                including the default ``None``, runs the full pipeline.

        Returns:
            dict: The dict returned by the warping network, with an extra
            ``'fake_image'`` entry when the editing network was run.
        """
        # Both stages share the mapping and warping passes.
        descriptor = self.mapping_net(driving_source)
        output = self.warpping_net(input_image, descriptor)
        if stage != 'warp':
            output['fake_image'] = self.editing_net(
                input_image, output['warp_image'], descriptor)
        return output
29
+
class MappingNet(nn.Module):
    """Reduce a coefficient sequence to one descriptor vector per sample.

    An un-padded ``Conv1d`` (kernel 7) is followed by ``layer`` residual
    blocks -- each a ``LeakyReLU(0.1)`` plus an un-padded, dilated
    ``Conv1d`` -- and by an adaptive average pool down to length 1.
    """

    # Geometry of every residual convolution.  Without padding such a
    # convolution shortens the last axis by dilation * (kernel_size - 1)
    # samples, i.e. by _RESIDUAL_CROP samples on each side.
    _KERNEL_SIZE = 3
    _DILATION = 3
    _RESIDUAL_CROP = _DILATION * (_KERNEL_SIZE - 1) // 2

    def __init__(self, coeff_nc=73, descriptor_nc=256, layer=3):
        """Build the network.

        Args:
            coeff_nc: Number of input channels of the first convolution.
            descriptor_nc: Number of channels of every convolution output,
                and therefore of the returned descriptor.
            layer: Number of residual blocks.
        """
        super(MappingNet, self).__init__()

        self.layer = layer
        # A single activation instance is shared by all residual blocks.
        nonlinearity = nn.LeakyReLU(0.1)

        self.first = nn.Sequential(
            nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))

        for i in range(layer):
            block = nn.Sequential(
                nonlinearity,
                nn.Conv1d(descriptor_nc, descriptor_nc,
                          kernel_size=self._KERNEL_SIZE, padding=0,
                          dilation=self._DILATION))
            # Registered as encoder0, encoder1, ...; these names are part of
            # the state_dict keys and must stay as they are.
            setattr(self, 'encoder' + str(i), block)

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def forward(self, input_3dmm):
        """Map a coefficient sequence to a descriptor.

        Args:
            input_3dmm: Tensor of shape ``(batch, coeff_nc, length)``.  No
                convolution is padded, so ``length`` must be at least
                ``7 + 6 * layer`` (25 with the default ``layer=3``).

        Returns:
            Tensor of shape ``(batch, descriptor_nc, 1)``.
        """
        crop = self._RESIDUAL_CROP
        out = self.first(input_3dmm)
        for i in range(self.layer):
            block = getattr(self, 'encoder' + str(i))
            # Trim the skip path so it lines up with the shorter conv output.
            out = block(out) + out[:, :, crop:-crop]
        return self.pooling(out)
55
+
class WarpingNet(nn.Module):
    """Predict a 2-channel flow field and warp the input image with it.

    An ``ADAINHourglass`` processes the image together with the descriptor;
    a norm / ``LeakyReLU(0.1)`` / 7x7 ``Conv2d`` head turns its output into
    the flow field, which ``flow_util`` converts into a deformation and
    applies to the input image.
    """

    def __init__(
        self,
        image_nc=3,
        descriptor_nc=256,
        base_nc=32,
        max_nc=256,
        encoder_layer=5,
        decoder_layer=3,
        use_spect=False
    ):
        """Build the hourglass and the flow head.

        ``image_nc``, ``descriptor_nc``, ``base_nc``, ``max_nc``,
        ``encoder_layer`` and ``decoder_layer`` are forwarded positionally,
        in that order, to ``ADAINHourglass``; ``use_spect`` and the shared
        activation are forwarded to it as keyword arguments.
        """
        super(WarpingNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True)
        kwargs = {'nonlinearity': nonlinearity, 'use_spect': use_spect}

        self.descriptor_nc = descriptor_nc
        self.hourglass = ADAINHourglass(
            image_nc, self.descriptor_nc, base_nc,
            max_nc, encoder_layer, decoder_layer, **kwargs)

        # kernel_size=7 with padding=3 and stride=1 keeps the spatial size
        # of the hourglass output; the two output channels are the flow.
        self.flow_out = nn.Sequential(
            norm_layer(self.hourglass.output_nc),
            nonlinearity,
            nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3))

        # Not used by forward(); kept so the attribute stays available to
        # any external code that reads it.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        """Estimate the flow for ``input_image`` and warp the image.

        Args:
            input_image: Image tensor fed to the hourglass and warped.
            descriptor: Conditioning tensor fed to the hourglass.

        Returns:
            dict: ``'flow_field'`` (output of the flow head) and
            ``'warp_image'`` (``input_image`` warped by the deformation
            derived from that flow).
        """
        final_output = {}
        features = self.hourglass(input_image, descriptor)
        final_output['flow_field'] = self.flow_out(features)

        deformation = flow_util.convert_flow_to_deformation(final_output['flow_field'])
        final_output['warp_image'] = flow_util.warp_image(input_image, deformation)
        return final_output
91
+
92
+
class EditingNet(nn.Module):
    """Generate an image from the input image, its warped version and a descriptor.

    The two images are concatenated along the channel axis, encoded by a
    ``FineEncoder`` and decoded, together with the descriptor, by a
    ``FineDecoder``.
    """

    def __init__(
        self,
        image_nc=3,
        descriptor_nc=256,
        layer=3,
        base_nc=64,
        max_nc=256,
        num_res_blocks=2,
        use_spect=False):
        """Build the encoder and the decoder.

        The encoder is created with ``image_nc * 2`` as its first argument
        because ``forward`` feeds it two channel-concatenated images; the
        decoder receives ``image_nc`` and ``descriptor_nc``.  ``base_nc``,
        ``max_nc``, ``layer``, ``num_res_blocks`` and ``use_spect`` are
        forwarded to ``FineEncoder`` / ``FineDecoder``.
        """
        super(EditingNet, self).__init__()

        nonlinearity = nn.LeakyReLU(0.1)
        norm_layer = functools.partial(LayerNorm2d, affine=True)
        kwargs = {'norm_layer': norm_layer, 'nonlinearity': nonlinearity, 'use_spect': use_spect}
        self.descriptor_nc = descriptor_nc

        # encoder part
        self.encoder = FineEncoder(image_nc * 2, base_nc, max_nc, layer, **kwargs)
        self.decoder = FineDecoder(
            image_nc, self.descriptor_nc, base_nc, max_nc, layer,
            num_res_blocks, **kwargs)

    def forward(self, input_image, warp_image, descriptor):
        """Run the encoder / decoder on the concatenated image pair.

        Args:
            input_image: First image tensor.
            warp_image: Second image tensor; it is concatenated to
                ``input_image`` along dimension 1, so all other dimensions
                must match.
            descriptor: Conditioning tensor passed to the decoder.

        Returns:
            The decoder output.
        """
        stacked = torch.cat([input_image, warp_image], 1)
        features = self.encoder(stacked)
        return self.decoder(features, descriptor)