bryandts commited on
Commit
a6111ec
1 Parent(s): 35de619

Create generator.py

Browse files
Files changed (1) hide show
  1. generator.py +44 -0
generator.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # The Generator model
5
+ class Generator(nn.Module):
6
+ def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
7
+ super(Generator, self).__init__()
8
+ self.channels = channels
9
+ self.noise_dim = noise_dim
10
+ self.embed_dim = embed_dim
11
+ self.embed_out_dim = embed_out_dim
12
+
13
+ # Text embedding layers
14
+ self.text_embedding = nn.Sequential(
15
+ nn.Linear(self.embed_dim, self.embed_out_dim),
16
+ nn.BatchNorm1d(1),
17
+ nn.LeakyReLU(0.2, inplace=True)
18
+ )
19
+
20
+ # Generator architecture
21
+ model = []
22
+ model += self._create_layer(self.noise_dim + self.embed_out_dim, 512, 4, stride=1, padding=0)
23
+ model += self._create_layer(512, 256, 4, stride=2, padding=1)
24
+ model += self._create_layer(256, 128, 4, stride=2, padding=1)
25
+ model += self._create_layer(128, 64, 4, stride=2, padding=1)
26
+ model += self._create_layer(64, 32, 4, stride=2, padding=1)
27
+ model += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)
28
+
29
+ self.model = nn.Sequential(*model)
30
+
31
+ def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
32
+ layers = [nn.ConvTranspose2d(size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False)]
33
+ if output:
34
+ layers.append(nn.Tanh()) # Tanh activation for the output layer
35
+ else:
36
+ layers += [nn.BatchNorm2d(size_out), nn.ReLU(True)] # Batch normalization and ReLU for other layers
37
+ return layers
38
+
39
+ def forward(self, noise, text):
40
+ # Apply text embedding to the input text
41
+ text = self.text_embedding(text)
42
+ text = text.view(text.shape[0], text.shape[2], 1, 1) # Reshape to match the generator input size
43
+ z = torch.cat([text, noise], 1) # Concatenate text embedding with noise
44
+ return self.model(z)