vlbthambawita commited on
Commit
df20d82
1 Parent(s): c7c9ff6

added files

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ pre_trained_checkpoint_4ch
__pycache__/models.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
generate_4ch.py CHANGED
@@ -5,6 +5,7 @@ import torch.nn.functional as F
5
  from torchvision.datasets import ImageFolder
6
  from torch.utils.data import DataLoader
7
  from torchvision import utils as vutils
 
8
 
9
  import os
10
  import random
@@ -36,12 +37,22 @@ def batch_save(images, folder_name):
36
  for i, image in enumerate(images):
37
  vutils.save_image(image.add(1).mul(0.5), folder_name+'/%d.jpg'%i)
38
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  if __name__ == "__main__":
41
  parser = argparse.ArgumentParser(
42
  description='generate images'
43
  )
44
- parser.add_argument('--ckpt', type=str, default="pre_trained_checkpoint_4ch/all_50000.pth")
45
  parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
46
  parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
47
  parser.add_argument('--start_iter', type=int, default=6)
@@ -50,7 +61,7 @@ if __name__ == "__main__":
50
  parser.add_argument('--dist', type=str, default='test_out')
51
  parser.add_argument('--size', type=int, default=256)
52
  parser.add_argument('--batch', default=1, type=int, help='batch size')
53
- parser.add_argument('--n_sample', type=int, default=1000)
54
  parser.add_argument('--big', action='store_true')
55
  parser.add_argument('--im_size', type=int, default=256)
56
  parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
@@ -59,8 +70,17 @@ if __name__ == "__main__":
59
 
60
  noise_dim = 256
61
  device = torch.device('cuda:%d'%(args.cuda))
 
 
 
 
 
62
 
63
- net_ig = Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
 
 
 
 
64
  net_ig.to(device)
65
 
66
  #for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
@@ -69,13 +89,25 @@ if __name__ == "__main__":
69
  checkpoint = torch.load(ckpt)
70
  # Remove prefix `module`.
71
  checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
72
- net_ig.load_state_dict(checkpoint['g'])
73
  #load_params(net_ig, checkpoint['g_ema'])
74
 
75
  #net_ig.eval()
76
  print("load checkpoint success")
77
 
78
  net_ig.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  del checkpoint
81
 
 
5
  from torchvision.datasets import ImageFolder
6
  from torch.utils.data import DataLoader
7
  from torchvision import utils as vutils
8
+ from huggingface_hub import PyTorchModelHubMixin
9
 
10
  import os
11
  import random
 
37
  for i, image in enumerate(images):
38
  vutils.save_image(image.add(1).mul(0.5), folder_name+'/%d.jpg'%i)
39
 
40
+ # To push the model to Huggingface model hub
41
+ class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
42
+
43
+ def __init__(self, config: dict) -> None:
44
+ super().__init__()
45
+
46
+ self.model = Generator( ngf=config["ngf"], nz=config["noise_dim"], nc=config["nc"], im_size=config["im_size"])
47
+
48
+ def forward(self, x):
49
+ return self.model(x)
50
 
51
  if __name__ == "__main__":
52
  parser = argparse.ArgumentParser(
53
  description='generate images'
54
  )
55
+ parser.add_argument('--ckpt', type=str, default="/work/vajira/DL/FastGAN-pytorch/train_results/test1_4ch/models/all_50000.pth")
56
  parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
57
  parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
58
  parser.add_argument('--start_iter', type=int, default=6)
 
61
  parser.add_argument('--dist', type=str, default='test_out')
62
  parser.add_argument('--size', type=int, default=256)
63
  parser.add_argument('--batch', default=1, type=int, help='batch size')
64
+ parser.add_argument('--n_sample', type=int, default=1)
65
  parser.add_argument('--big', action='store_true')
66
  parser.add_argument('--im_size', type=int, default=256)
67
  parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
 
70
 
71
  noise_dim = 256
72
  device = torch.device('cuda:%d'%(args.cuda))
73
+
74
+ # adding the model to the model hub
75
+ config={"ngf":64, "noise_dim":noise_dim, "nc":4, "im_size":args.im_size}
76
+ net_ig = MyFastGanModel(config=config)
77
+
78
 
79
+
80
+ # exit
81
+ #exit()
82
+
83
+ #net_ig = model #Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
84
  net_ig.to(device)
85
 
86
  #for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
 
89
  checkpoint = torch.load(ckpt)
90
  # Remove prefix `module`.
91
  checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
92
+ net_ig.model.load_state_dict(checkpoint['g'])
93
  #load_params(net_ig, checkpoint['g_ema'])
94
 
95
  #net_ig.eval()
96
  print("load checkpoint success")
97
 
98
  net_ig.to(device)
99
+ # Save locally
100
+ net_ig.save_pretrained("pre_trained_checkpoint_4ch", config=config) # Save the model locally
101
+ print("Model saved locally. Pushing to Huggingface model hub...")
102
+
103
+ # Push to the Huggingface model hub
104
+ # push to the hub
105
+ net_ig.push_to_hub("deepsynthbody/deepfake_gi_fastGAN", config=config)
106
+
107
+
108
+ print("pushed to the Huggingface model hub. Done.")
109
+ exit()
110
+
111
 
112
  del checkpoint
113
 
generate_4ch_from_huggingface.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch import optim
4
+ import torch.nn.functional as F
5
+ from torchvision.datasets import ImageFolder
6
+ from torch.utils.data import DataLoader
7
+ from torchvision import utils as vutils
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+ import os
11
+ import random
12
+ import argparse
13
+ from tqdm import tqdm
14
+
15
+ from models import Generator
16
+
17
+
18
+ def load_params(model, new_param):
19
+ for p, new_p in zip(model.parameters(), new_param):
20
+ p.data.copy_(new_p)
21
+
22
+ def resize(img):
23
+ return F.interpolate(img, size=256)
24
+
25
+ def batch_generate(zs, netG, batch=8):
26
+ g_images = []
27
+ with torch.no_grad():
28
+ for i in range(len(zs)//batch):
29
+ g_images.append( netG(zs[i*batch:(i+1)*batch]).cpu() )
30
+ if len(zs)%batch>0:
31
+ g_images.append( netG(zs[-(len(zs)%batch):]).cpu() )
32
+ return torch.cat(g_images)
33
+
34
+ def batch_save(images, folder_name):
35
+ if not os.path.exists(folder_name):
36
+ os.mkdir(folder_name)
37
+ for i, image in enumerate(images):
38
+ vutils.save_image(image.add(1).mul(0.5), folder_name+'/%d.jpg'%i)
39
+
40
+ # To push the model to Huggingface model hub
41
+ class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
42
+
43
+ def __init__(self, config: dict) -> None:
44
+ super().__init__()
45
+
46
+ self.model = Generator( ngf=config["ngf"], nz=config["noise_dim"], nc=config["nc"], im_size=config["im_size"])
47
+
48
+ def forward(self, x):
49
+ return self.model(x)
50
+
51
+ if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser(
53
+ description='generate images'
54
+ )
55
+ parser.add_argument('--ckpt', type=str, default="/work/vajira/DL/FastGAN-pytorch/train_results/test1_4ch/models/all_50000.pth")
56
+ parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
57
+ parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
58
+ parser.add_argument('--start_iter', type=int, default=6)
59
+ parser.add_argument('--end_iter', type=int, default=10)
60
+
61
+ parser.add_argument('--dist', type=str, default='test_out')
62
+ parser.add_argument('--size', type=int, default=256)
63
+ parser.add_argument('--batch', default=1, type=int, help='batch size')
64
+ parser.add_argument('--n_sample', type=int, default=1)
65
+ parser.add_argument('--big', action='store_true')
66
+ parser.add_argument('--im_size', type=int, default=256)
67
+ parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
68
+ parser.set_defaults(big=False)
69
+ args = parser.parse_args()
70
+
71
+ noise_dim = 256
72
+ device = torch.device('cuda:%d'%(args.cuda))
73
+
74
+ # adding the model to the model hub
75
+ config={"ngf":64, "noise_dim":noise_dim, "nc":4, "im_size":args.im_size}
76
+ net_ig = MyFastGanModel(config=config)
77
+
78
+
79
+
80
+ # exit
81
+ #exit()
82
+
83
+ #net_ig = model #Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
84
+ #net_ig.to(device)
85
+
86
+ #for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
87
+ #ckpt = args.ckpt #f"{args.artifacts}/models/{epoch}.pth"
88
+ #checkpoint = torch.load(ckpt, map_location=lambda a,b: a)
89
+ #checkpoint = torch.load(ckpt)
90
+ # Remove prefix `module`.
91
+ #checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
92
+ #net_ig.model.load_state_dict(checkpoint['g'])
93
+ #load_params(net_ig, checkpoint['g_ema'])
94
+
95
+ net_ig = MyFastGanModel.from_pretrained("deepsynthbody/deepfake_gi_fastGAN", config=config) # Load the model from the hub
96
+
97
+ #net_ig.eval()
98
+ print("load checkpoint success")
99
+
100
+ net_ig.to(device)
101
+ # Save locally
102
+ # net_ig.save_pretrained("pre_trained_checkpoint_4ch", config=config) # Save the model locally
103
+ # print("Model saved locally. Pushing to Huggingface model hub...")
104
+
105
+ # Push to the Huggingface model hub
106
+ # push to the hub
107
+ # net_ig.push_to_hub("deepsynthbody/deepfake_gi_fastGAN", config=config)
108
+
109
+
110
+ #print("pushed to the Huggingface model hub. Done.")
111
+ #exit()
112
+
113
+
114
+ #del checkpoint
115
+
116
+ #dist = 'eval_%d'%(epoch)
117
+ #dist = os.path.join(args.dist, 'img')
118
+ dist = args.dist
119
+ os.makedirs(dist, exist_ok=True)
120
+
121
+ with torch.no_grad():
122
+ for i in tqdm(range(args.n_sample//args.batch)):
123
+ noise = torch.randn(args.batch, noise_dim).to(device)
124
+ g_imgs = net_ig(noise)[0]
125
+ g_imgs = F.interpolate(g_imgs, 512)
126
+
127
+
128
+ for j, g_img in enumerate( g_imgs ):
129
+ #print("img sahpe=", g_img.shape)
130
+ g_mask = g_img.add(1).mul(0.5)[-1, :, :].expand(3, -1, -1)
131
+ g_img = g_img.add(1).mul(0.5)[0:3, :, :]
132
+
133
+ # Clean generated data using clamping
134
+ g_mask = torch.clamp(g_mask, min=0, max=1)
135
+ g_img = torch.clamp(g_img, min=0, max=1)
136
+ #print(g_mask.type())
137
+ g_mask = (g_mask > 0.5) * 1.0
138
+ #print(g_mask.type())
139
+
140
+ #print("gmask_min:", g_mask.min())
141
+ #print("gmask_max:", g_mask.max())
142
+ #exit()
143
+
144
+ #print("img sahpe=", g_img.shape)
145
+
146
+ if args.save_option == "image_and_mask":
147
+ vutils.save_image(g_img,
148
+ os.path.join(dist, '%d_img.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
149
+ vutils.save_image(g_mask,
150
+ os.path.join(dist, '%d_mask.png'%(i*args.batch+j))) #, normalize=True, range=(0,1))
151
+
152
+ elif args.save_option == "image_only":
153
+ vutils.save_image(g_img,
154
+ os.path.join(dist, '%d_img.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
155
+
156
+ elif args.save_option == "mask_only":
157
+ vutils.save_image(g_mask,
158
+ os.path.join(dist, '%d_mask.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
159
+ else:
160
+ print("wrong choise to save option.")
test_out/0_img.png ADDED
test_out/0_mask.png ADDED